From 7862f69f3914b8b03dbed65ae9ff79dd7bc28ff2 Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Wed, 1 Jul 2026 15:52:52 +0800 Subject: [PATCH] Implement chat completions in go (#16491) ### Summary POST /api/v1/chat/completions --- docker/nginx/ragflow.conf.hybrid | 10 + internal/cli/cli.go | 49 ++ internal/cli/cli_http.go | 5 + internal/cli/response.go | 64 +- internal/cli/user_command.go | 192 +++++- internal/cli/user_parser.go | 151 +++++ internal/development.md | 70 +- internal/handler/chat.go | 186 ----- internal/handler/chat_session.go | 282 +++----- internal/router/router.go | 25 +- internal/service/chat.go | 547 --------------- internal/service/chat_pipeline.go | 406 ++++------- internal/service/chat_session.go | 936 +++++++++++++++++--------- internal/service/chat_session_test.go | 744 ++++---------------- internal/service/dataset.go | 2 +- internal/service/openai_chat.go | 8 +- internal/service/think_tag.go | 406 +++++++---- internal/service/think_tag_test.go | 191 ++++-- web/vite.config.ts | 5 + 19 files changed, 1917 insertions(+), 2362 deletions(-) diff --git a/docker/nginx/ragflow.conf.hybrid b/docker/nginx/ragflow.conf.hybrid index 1f68187063..3745c788d3 100644 --- a/docker/nginx/ragflow.conf.hybrid +++ b/docker/nginx/ragflow.conf.hybrid @@ -45,6 +45,16 @@ server { include proxy.conf; } + location ~ ^/api/v1/datasets/search { + proxy_pass http://127.0.0.1:9384; + include proxy.conf; + } + + location ~ ^/api/v1/chat/completions { + proxy_pass http://127.0.0.1:9384; + include proxy.conf; + } + location ~ ^/v1/system/config { proxy_pass http://127.0.0.1:9384; include proxy.conf; diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 185c3ea3ce..0738d3521f 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -834,6 +834,8 @@ Commands (User Mode): CHAT 'provider/instance/model' 'message'; - Chat with specified model OPENAI_CHAT 'chat_id' 'message' [options] ; - OpenAI-compatible chat (run openai_chat -h for detailed options) + CHAT COMPLETIONS 'question' [options] ; - Chat completions via /api/v1/chat/completions + (run chat completions -h for detailed options) Filesystem Commands (no quotes): ls [path] - List resources @@ -1041,3 +1043,50 @@ Examples: ` fmt.Println(help) } + +// printChatCompletionsHelp prints help for the CHAT COMPLETIONS command. +func printChatCompletionsHelp() { + help := `CHAT COMPLETIONS — hit POST /api/v1/chat/completions + +Syntax: + CHAT COMPLETIONS 'question' + chat_id '...' + [session "..."] [llm "..."] + [system "..."] [history "..."] [history_delimiter ""] + [temperature ] [max_tokens ] [stream ] + [top_p ] [frequency_penalty ] [presence_penalty ] + [pass_all_history ] [legacy ] ; + +Required positional: + 'question' the user question + +Named options (any order; all optional with defaults): + chat_id '...' the dialog id (optional) + session '...' existing session/conversation id + llm '...' override the dialog's LLM + system '...' override the system prompt + history '...' prior turns: user:...;assistant:...;user:... + history_delimiter '...' turn separator for history (default ';') + temperature 0..2 (default 0) + max_tokens (default 0 = server/model default) + stream true|false (default false) + top_p 0..1 + frequency_penalty -2..2 + presence_penalty -2..2 + pass_all_history pass all history messages + legacy use legacy SSE format + +Defaults: + stream false + temperature 0 + history_delimiter ';' + +Examples: + CHAT COMPLETIONS 'Hello, how are you?' chat_id 'cid'; + CHAT COMPLETIONS 'Explain quantum computing' chat_id 'cid' stream true; + CHAT COMPLETIONS 'Next question' chat_id 'cid' session 'sess-abc123'; + CHAT COMPLETIONS 'What about X?' chat_id 'cid' system 'You are a helpful assistant.' history 'user:Tell me about Y;assistant:Y is...'; + CHAT COMPLETIONS 'Summarize' chat_id 'cid' llm 'Qwen/Qwen3-8B@ling@SILICONFLOW' temperature 0.7 max_tokens 512; +` + fmt.Println(help) +} diff --git a/internal/cli/cli_http.go b/internal/cli/cli_http.go index af53086d2e..4ef04b39cc 100644 --- a/internal/cli/cli_http.go +++ b/internal/cli/cli_http.go @@ -419,6 +419,11 @@ func (c *CLI) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.EmbedUserTextCommand(cmd) case "api_rarank_user_document": return c.APIRerankUserDocumentCommand(cmd) + case "chat completions": + return c.ChatCompletions(cmd) + case "chat completions help": + printChatCompletionsHelp() + return nil, nil case "tts_user_command": return c.APITTSUserCommand(cmd) case "asr_user_command": diff --git a/internal/cli/response.go b/internal/cli/response.go index 54fb9974cb..36a143186d 100644 --- a/internal/cli/response.go +++ b/internal/cli/response.go @@ -980,9 +980,11 @@ func getChunkID(c map[string]interface{}) string { } func chunkContent(c map[string]interface{}) string { - if v, ok := c["content"]; ok { - s := fmt.Sprint(v) - return strings.TrimSpace(s) + for _, key := range []string{"content_with_weight", "content"} { + if v, ok := c[key]; ok { + s := fmt.Sprint(v) + return strings.TrimSpace(s) + } } return "" } @@ -1283,3 +1285,59 @@ func (r *QuotaSummaryResponse) PrintOut() { PrintTableSimpleByFormatWithOrder(table, section.columns, r.OutputFormat) } } + +// ChatCompletionsResponse represents the RAGFlow-internal response from +// POST /api/v1/chat/completions (non-OpenAI format). +// +// JSON shape: +// +// {"code":0,"data":{"answer":"...","reference":...},"message":""} +type ChatCompletionsResponse struct { + Code int `json:"code"` + Data *chatCompletionData `json:"data"` + Message string `json:"message"` + Duration float64 `json:"-"` + OutputFormat OutputFormat `json:"-"` + // raw HTTP body for "raw" output. + raw []byte + // streamed skips the "Answer:" line in PrintOut to avoid duplication + // (used by the streaming path which prints chunk-by-chunk). + streamed bool +} + +type chatCompletionData struct { + Answer string `json:"answer"` + Reference json.RawMessage `json:"reference,omitempty"` + ID string `json:"id,omitempty"` + SessionID string `json:"session_id,omitempty"` + ChatID string `json:"chat_id,omitempty"` +} + +func (r *ChatCompletionsResponse) Type() string { return "chat_completions" } +func (r *ChatCompletionsResponse) TimeCost() float64 { return r.Duration } +func (r *ChatCompletionsResponse) SetOutputFormat(f OutputFormat) { r.OutputFormat = f } + +func (r *ChatCompletionsResponse) PrintOut() { + if r.OutputFormat == "raw" && r.raw != nil { + fmt.Println(string(r.raw)) + return + } + if r.Code != 0 { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + return + } + if r.Data == nil { + fmt.Println("(no data)") + return + } + if !r.streamed { + if r.Data.Answer != "" { + fmt.Printf("Answer: %s\n", r.Data.Answer) + } + } + if r.Data != nil && len(r.Data.Reference) > 0 { + printReferenceChunks(r.Data.Reference) + } + fmt.Printf("Time: %f\n", r.Duration) +} diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index 0d9e81a0eb..6781166d9c 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -32,7 +32,6 @@ import ( "ragflow/internal/ingestion" "ragflow/internal/ingestion/parser" "ragflow/internal/utility" - "regexp" "strings" "time" ) @@ -2128,7 +2127,7 @@ func (c *CLI) APIChatToModelCommand(cmd *Command) (ResponseIf, error) { effort := cmd.Params["effort"].(string) verbosity := cmd.Params["verbosity"].(string) - url := "/chat/completions" + url := "/chat/to_model" payload := map[string]interface{}{ "messages": formattedMessages, @@ -4133,8 +4132,6 @@ func (c *CLI) streamOpenaiChat(url string, body map[string]interface{}) (Respons fullContent = strings.TrimLeft(fullContent, "\n\r") fullReason = strings.TrimLeft(fullReason, "\n\r") - fullContent = stripThinkTags(fullContent) - fullReason = stripThinkTags(fullReason) return &OpenAIChatResponse{ Duration: resp.Duration, Reasoning: fullReason, @@ -4147,8 +4144,187 @@ func (c *CLI) streamOpenaiChat(url string, body map[string]interface{}) (Respons }, nil } -// stripThinkTags removes wrappers from a streamed answer -func stripThinkTags(s string) string { - var thinkTagRE = regexp.MustCompile(`(?s).*?`) - return thinkTagRE.ReplaceAllString(s, "") +// ChatCompletions dispatches the parsed CHAT COMPLETIONS command to +// POST /api/v1/chat/completions. +func (c *CLI) ChatCompletions(cmd *Command) (ResponseIf, error) { + if c.Config.CLIMode != APIMode { + return nil, fmt.Errorf("CHAT COMPLETIONS is only allowed in USER mode") + } + httpClient := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer] + if httpClient.APIKey == nil && httpClient.LoginToken == nil { + return nil, fmt.Errorf("API token not set. Please login first") + } + + body, err := buildChatCompletionsRequestBody(cmd) + if err != nil { + return nil, err + } + + url := "/chat/completions" + + stream, _ := cmd.Params["stream"].(bool) + if stream { + return c.streamChatCompletions(url, body) + } + return c.oneshotChatCompletions(url, body) +} + +// buildChatCompletionsRequestBody assembles the JSON payload for +// POST /api/v1/chat/completions. +// +// When system or history is provided, a `messages` array is built; +// otherwise just `question` is sent and the server normalizes it. +func buildChatCompletionsRequestBody(cmd *Command) (map[string]interface{}, error) { + chatID, _ := cmd.Params["chat_id"].(string) + question, _ := cmd.Params["question"].(string) + stream, _ := cmd.Params["stream"].(bool) + + body := map[string]interface{}{ + "chat_id": chatID, + "stream": stream, + } + + // Optional session_id + if v, ok := cmd.Params["session"].(string); ok && v != "" { + body["session_id"] = v + } + + // Optional llm_id + if v, ok := cmd.Params["llm"].(string); ok && v != "" { + body["llm_id"] = v + } + + // Build messages from system + history when provided; otherwise send question. + system, hasSystem := cmd.Params["system"].(string) + historyRaw, hasHistory := cmd.Params["history_raw"].(string) + + if hasSystem || hasHistory { + messages := make([]map[string]interface{}, 0, 4) + if hasSystem && system != "" { + messages = append(messages, map[string]interface{}{"role": "system", "content": system}) + } + if hasHistory && historyRaw != "" { + delimiter, _ := cmd.Params["history_delimiter"].(string) + turns, err := parseHistory(historyRaw, delimiter) + if err != nil { + return nil, fmt.Errorf("CHAT COMPLETIONS history: %w", err) + } + for _, t := range turns { + messages = append(messages, map[string]interface{}{ + "role": t["role"], + "content": t["content"], + }) + } + } + messages = append(messages, map[string]interface{}{"role": "user", "content": question}) + body["messages"] = messages + } else { + body["question"] = question + } + + // Optional flags — only emit when explicitly set + if isSet(cmd, "pass_all_history") && cmd.Params["pass_all_history"].(bool) { + body["pass_all_history_messages"] = true + } + if isSet(cmd, "legacy") && cmd.Params["legacy"].(bool) { + body["legacy"] = true + } + + // Generation params — only emit when explicitly set + if isSet(cmd, "temperature") { + body["temperature"] = cmd.Params["temperature"] + } + if isSet(cmd, "max_tokens") { + body["max_tokens"] = cmd.Params["max_tokens"] + } + if isSet(cmd, "top_p") { + body["top_p"] = cmd.Params["top_p"] + } + if isSet(cmd, "frequency_penalty") { + body["frequency_penalty"] = cmd.Params["frequency_penalty"] + } + if isSet(cmd, "presence_penalty") { + body["presence_penalty"] = cmd.Params["presence_penalty"] + } + + return body, nil +} + +// oneshotChatCompletions performs a non-streaming POST and returns a +// ChatCompletionsResponse parsed from the RAGFlow-internal JSON envelope. +func (c *CLI) oneshotChatCompletions(url string, body map[string]interface{}) (ResponseIf, error) { + httpClient := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer] + resp, err := httpClient.Request("POST", url, "web", nil, body) + if err != nil { + return nil, fmt.Errorf("chat completions request: %w", err) + } + if resp.StatusCode != 200 { + return &ChatCompletionsResponse{ + Code: resp.StatusCode, + Message: string(resp.Body), + raw: resp.Body, + }, nil + } + out := &ChatCompletionsResponse{ + Duration: resp.Duration, + raw: resp.Body, + } + // RAGFlow returns {code, data: {answer, reference, ...}, message}. + var envelope struct { + Code int `json:"code"` + Message string `json:"message"` + Data *chatCompletionData `json:"data"` + } + if err := json.Unmarshal(resp.Body, &envelope); err != nil { + return nil, fmt.Errorf("chat completions: invalid response JSON: %w", err) + } + out.Code = envelope.Code + out.Message = envelope.Message + out.Data = envelope.Data + return out, nil +} + +// streamChatCompletions performs a streaming POST and collects SSE chunks. +func (c *CLI) streamChatCompletions(url string, body map[string]interface{}) (ResponseIf, error) { + httpClient := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer] + reader, err := httpClient.RequestStream("POST", url, "web", nil, body) + if err != nil { + return nil, fmt.Errorf("chat completions stream: %w", err) + } + defer reader.Close() + + start := time.Now() + scanner := bufio.NewScanner(reader) + var fullContent string + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimPrefix(line, "data:") + payload = strings.TrimSpace(payload) + if payload == "[DONE]" { + continue + } + var chunk struct { + Code int `json:"code"` + Message string `json:"message"` + Data chatCompletionData `json:"data"` + } + if err := json.Unmarshal([]byte(payload), &chunk); err != nil { + continue + } + if chunk.Data.Answer != "" { + fullContent += chunk.Data.Answer + } + } + + fullContent = strings.TrimLeft(fullContent, "\n\r") + return &ChatCompletionsResponse{ + Duration: time.Since(start).Seconds(), + Data: &chatCompletionData{ + Answer: fullContent, + }, + streamed: true, + }, nil } diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index 175f00522d..8a9791d2ff 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -2702,6 +2702,13 @@ func (p *Parser) parseAPIDisable() (*Command, error) { func (p *Parser) parseAPIChat() (*Command, error) { p.nextToken() // consume CHAT + // Redirect "chat completion[s]" to the standalone chat completions parser. + if p.curToken.Type == TokenIdentifier && + (strings.EqualFold(p.curToken.Value, "completion") || strings.EqualFold(p.curToken.Value, "completions")) { + p.nextToken() // consume completion/completions + return p.parseChatCompletionsBody() + } + var err error var modelNameOrID string = "" var messages []string @@ -3759,6 +3766,150 @@ optionsLoop: return cmd, nil } +// CHAT COMPLETIONS +// [chat_id ] [session ] [llm ] + + +// parseChatCompletionsBody parses the question and options of a CHAT COMPLETIONS +// command. The leading keyword(s) must already have been consumed by the caller. +func (p *Parser) parseChatCompletionsBody() (*Command, error) { + + if p.curToken.Type == TokenDash { + dashCount := 0 + for p.curToken.Type == TokenDash { + dashCount++ + p.nextToken() + } + if dashCount > 0 && p.curToken.Type == TokenIdentifier { + switch strings.ToLower(p.curToken.Value) { + case "h", "help": + return NewCommand("chat completions help"), nil + } + } + return nil, fmt.Errorf("CHAT COMPLETIONS: only -h/--help takes no args; otherwise expected question") + } + + cmd := NewCommand("chat completions") + + // Defaults + cmd.Params["chat_id"] = "" + cmd.Params["temperature"] = 0.0 + cmd.Params["max_tokens"] = 0 + cmd.Params["stream"] = false + + // Track which options were explicitly set (distinguishes from defaults). + cmd.Params["_set"] = map[string]bool{} + + // Required positional: + question, err := p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("CHAT COMPLETIONS: expected question: %w", err) + } + cmd.Params["question"] = question + p.nextToken() + + // Optional named options + handleOption := func(name string) error { + switch name { + case "chat_id", "session", "llm", "system": + v, err := p.parseQuotedString() + if err != nil { + return fmt.Errorf("CHAT COMPLETIONS %s: expected quoted string, got %s", name, p.curToken.Value) + } + cmd.Params[name] = v + p.nextToken() + markSet(cmd, name) + case "temperature", "top_p", "frequency_penalty", "presence_penalty": + v, err := p.parseFloat() + if err != nil { + return fmt.Errorf("CHAT COMPLETIONS %s: expected number, got %s", name, p.curToken.Value) + } + cmd.Params[name] = v + p.nextToken() + markSet(cmd, name) + case "max_tokens": + v, err := p.parseNumber() + if err != nil { + return fmt.Errorf("CHAT COMPLETIONS max_tokens: expected integer, got %s", p.curToken.Value) + } + cmd.Params["max_tokens"] = v + p.nextToken() + markSet(cmd, "max_tokens") + case "stream", "pass_all_history", "legacy": + v, err := p.parseBool() + if err != nil { + return fmt.Errorf("CHAT COMPLETIONS %s: expected true|false, got %s", name, p.curToken.Value) + } + cmd.Params[name] = v + markSet(cmd, name) + case "history": + raw, err := p.parseQuotedString() + if err != nil { + return fmt.Errorf("CHAT COMPLETIONS history: expected quoted string, got %s", p.curToken.Value) + } + cmd.Params["history_raw"] = raw + p.nextToken() + case "history_delimiter": + v, err := p.parseQuotedString() + if err != nil { + return fmt.Errorf("CHAT COMPLETIONS history_delimiter: expected quoted string, got %s", p.curToken.Value) + } + cmd.Params["history_delimiter"] = v + p.nextToken() + default: + return fmt.Errorf("CHAT COMPLETIONS: unknown option %q (valid: chat_id, session, llm, system, history, history_delimiter, temperature, max_tokens, stream, top_p, frequency_penalty, presence_penalty, pass_all_history, legacy)", name) + } + return nil + } + + // Named options, any order, until ';'. +optionsLoop: + for { + switch p.curToken.Type { + case TokenSemicolon: + p.nextToken() + break optionsLoop + case TokenEOF: + break optionsLoop + + case TokenIdentifier, TokenQuotedString: + name := p.curToken.Value + if p.curToken.Type == TokenQuotedString { + name = strings.Trim(name, "'\"") + } + p.nextToken() + if err := handleOption(name); err != nil { + return nil, err + } + + default: + if !isKeyword(p.curToken.Type) { + return nil, fmt.Errorf("CHAT COMPLETIONS: unexpected token %q in option list (valid options: chat_id, session, llm, system, history, history_delimiter, temperature, max_tokens, stream, top_p, frequency_penalty, presence_penalty, pass_all_history, legacy)", p.curToken.Value) + } + name := p.curToken.Value + p.nextToken() + if err := handleOption(name); err != nil { + return nil, err + } + } + } + + return cmd, nil +} + +func markSet(cmd *Command, name string) { + if s, ok := cmd.Params["_set"].(map[string]bool); ok { + s[name] = true + } +} + +func isSet(cmd *Command, name string) bool { + if s, ok := cmd.Params["_set"].(map[string]bool); ok { + return s[name] + } + return false +} + // parseJSONLiteral consumes a TokenQuotedString whose payload is a JSON // value (object, array, string, number, or boolean) and returns it as // the original raw string (NOT decoded — the caller decides whether to diff --git a/internal/development.md b/internal/development.md index 85c50eb190..ea1e11fe63 100644 --- a/internal/development.md +++ b/internal/development.md @@ -167,7 +167,33 @@ Time: 76.582520 ``` Note: Both image and video understanding support streaming and thinking modes as well. -### 6.8. Chat with OpenAI compatible API +### 6.8. Chat completions + +``` +RAGFlow(api/default)> chat completion 'hello' +Answer: Hello! How can I assist you today? 😊 +Time: 1.591929 +``` + +``` +RAGFlow(api/default)> CHAT COMPLETIONS '' chat_id ''; +``` + +``` +RAGFlow(api/default)> CHAT COMPLETIONS 'Explain the theory' \ + chat_id '' \ + session '' llm 'glm-4.5-flash@test@zhipu-ai' stream true; +``` + +``` +RAGFlow(api/default)> CHAT COMPLETIONS 'Continue' \ + system 'You are a helpful assistant.' \ + history 'user:What is RAG?;assistant:RAG stands for Retrieval-Augmented Generation...' \ + history_delimiter ';'; +``` + +### 6.9. Chat with OpenAI compatible API + ``` RAGFlow(api/default)> openai_chat '' 'Hello, how are you?'; Answer: Hello! I'm just a virtual assistant, so I don't have feelings, but I'm here and ready to help you with anything you need. How can I assist you today? 😊 @@ -204,17 +230,17 @@ RAGFlow(api/default)> openai_chat '' 'Hello, how are you?' extra_body ' CLI error: OPENAI_CHAT extra_body: unknown field "ref" (valid: reference, reference_metadata, metadata_condition) ``` -### 6.9. Generate Embeddings +### 6.10. Generate Embeddings ``` RAGFlow(api/default)> embed text 'what is rag' 'who are you' with 'embedding-3@test@zhipu-ai' dimension 16; ``` -### 6.10. Document Reranking +### 6.11. Document Reranking ``` RAGFlow(api/default)> rerank query 'what is rag' document 'rag is retrieval augment generation' 'rag need llm' 'famous rag project includes ragflow' with 'rerank@test@zhipu-ai' top 2; ``` -### 6.11. Get supported models from provider API +### 6.12. Get supported models from provider API ``` RAGFlow(api/default)> list supported models from 'gitee' 'test'; @@ -236,7 +262,7 @@ RAGFlow(api/default)> list supported models from 'gitee' 'test'; +-----------+---------------------------+---------------+------------+-----------------------------------------------------------------+----------------------------------------------------------+---------------------------------------------+ ``` -### 6.12. Get preset models of a provider +### 6.13. Get preset models of a provider ``` RAGFlow(api/default)> list models from 'minimax'; @@ -254,7 +280,7 @@ RAGFlow(api/default)> list models from 'minimax'; +------------+-------------+------------------------+ ``` -### 6.13. List instances of a provider +### 6.14. List instances of a provider ``` RAGFlow(api/default)> list instances from 'zhipu-ai'; @@ -265,7 +291,7 @@ RAGFlow(api/default)> list instances from 'zhipu-ai'; +---------+----------------------+----------------------------------+--------------+----------------------------------+--------+ ``` -### 6.14. Show instance of a provider +### 6.15. Show instance of a provider ``` RAGFlow(api/default)> show instance 'test' from 'zhipu-ai'; +----------------------------------+--------------+----------------------------------+---------+--------+ @@ -275,7 +301,7 @@ RAGFlow(api/default)> show instance 'test' from 'zhipu-ai'; +----------------------------------+--------------+----------------------------------+---------+--------+ ``` -### 6.15. List models of a specific instance +### 6.16. List models of a specific instance ``` RAGFlow(api/default)> list models from 'minimax' 'test'; @@ -293,7 +319,7 @@ RAGFlow(api/default)> list models from 'minimax' 'test'; +------------+-------------+------------------------+--------+ ``` -### 6.16. List added providers +### 6.17. List added providers ``` RAGFlow(api/default)> list providers; +--------------------------------------------------------------------------+-------------+--------------+ @@ -305,7 +331,7 @@ RAGFlow(api/default)> list providers; +--------------------------------------------------------------------------+-------------+--------------+ ``` -### 6.17. Deactivate / activate a model +### 6.18. Deactivate / activate a model ``` RAGFlow(api/default)> disable model 'deepseek-v4-pro' from 'deepseek' 'test'; @@ -321,7 +347,7 @@ RAGFlow(api/default)> enable model 'deepseek-v4-pro' from 'deepseek' 'test'; SUCCESS ``` -### 6.18. Set current model +### 6.19. Set current model ``` RAGFlow(api/default)> use model 'glm-4.5-flash@test@zhipu-ai'; SUCCESS @@ -330,7 +356,7 @@ Answer: Large language models are advanced AI systems. They process text to unde Time: 1.680416 ``` -### 6.19. Set, reset, and list default models +### 6.20. Set, reset, and list default models ``` RAGFlow(api/default)> set default chat model 'glm-4.5-flash@test@zhipu-ai'; SUCCESS @@ -374,7 +400,7 @@ RAGFlow(api/default)> list default models; +--------+----------------+--------------+----------------+------------+ ``` -### 6.20. Show current balance of a provider instance +### 6.21. Show current balance of a provider instance ``` RAGFlow(api/default)> show balance from 'gitee' 'test'; +-------------+----------+ @@ -384,13 +410,13 @@ RAGFlow(api/default)> show balance from 'gitee' 'test'; +-------------+----------+ ``` -### 6.21. Check provider instance availability +### 6.22. Check provider instance availability ``` RAGFlow(api/default)> check instance 'test' from 'zhipu-ai'; SUCCESS ``` -### 6.22. Add local model to RAGFlow, only for local deployed inference server, such as ollama +### 6.23. Add local model to RAGFlow, only for local deployed inference server, such as ollama ``` RAGFlow(api/default)> add model 'Qwen/Qwen2.5-0.5B' to provider 'vllm' instance 'test' with tokens 131072 chat; SUCCESS @@ -404,7 +430,7 @@ RAGFlow(api/default)> drop model 'Qwen/Qwen2.5-0.5B' from 'vllm' 'test'; SUCCESS ``` -### 6.23. List datasets +### 6.24. List datasets ``` RAGFlow(api/default)> list datasets; +-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+ @@ -415,14 +441,14 @@ RAGFlow(api/default)> list datasets; +-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+ ``` -### 6.24. Text to Speech +### 6.25. Text to Speech ``` RAGFlow(api/default)> tts with 'speech-2.8-hd@test@minimax' text 'He who desires but acts not, breeds pestilence.' play format 'wav' save './internal' param '{"voice_setting": {"voice_id": "English_radiant_girl", "speed": 1, "vol": 1, "pitch": 0}, "audio_setting": {"sample_rate": 32000, "bitrate": 128000, "format": "wav", "channel": 1}, "output_format": "hex"}' Saved to directory: /home/infiniflow/Documents/development/ragflow/internal/speech-2.8-hd_output.wav SUCCESS ``` -### 6.25. Audio to Speech +### 6.26. Audio to Speech ``` RAGFlow(api/default)> asr with 'FunAudioLLM/SenseVoiceSmall@test@siliconflow' audio './internal/test.wav' param '' +----------------------------------------------------------------------------------------------------------------------+ @@ -432,7 +458,7 @@ RAGFlow(api/default)> asr with 'FunAudioLLM/SenseVoiceSmall@test@siliconflow' au +----------------------------------------------------------------------------------------------------------------------+ ``` -### 6.26. Optical Character Recognition +### 6.27. Optical Character Recognition ``` RAGFlow(api/default)> ocr with 'paddleocr-vl-0.9b@test@baidu' file './internal/text.jpg' +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -442,7 +468,7 @@ RAGFlow(api/default)> ocr with 'paddleocr-vl-0.9b@test@baidu' file './internal/t +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ ``` -### 6.27. Chunk Management Commands +### 6.28. Chunk Management Commands - Create a chunk store with vector size ``` @@ -489,7 +515,7 @@ RAGFlow(api/default)> RETRIEVE 'AI' ON DATASETS 'test' RAGFlow(api/default)> GET CHUNK '29cc4f6d7a5c6e7c' OF DATASET 'test' DOCUMENT 'bbe55942535e11f1bc5184ba59049aa3' IN DATASET 'test' ``` -### 6.28. Metadata Management Commands +### 6.29. Metadata Management Commands - Create metadata store ``` @@ -525,7 +551,7 @@ RAGFlow(api/default)> DROP METADATA STORE RAGFlow(api/default)> GET METADATA OF DATASET 'test' 'test2' ``` -### 6.29. Search datasets +### 6.30. Search datasets - Search datasets using SQL-like dataset search syntax: ``` diff --git a/internal/handler/chat.go b/internal/handler/chat.go index eba5f41266..c4f998eac0 100644 --- a/internal/handler/chat.go +++ b/internal/handler/chat.go @@ -227,192 +227,6 @@ func (h *ChatHandler) MindMap(c *gin.Context) { jsonResponse(c, common.CodeSuccess, mindMap, "success") } -// ListChatsNext list chats with advanced filtering and pagination -// @Summary List Chats Next -// @Description Get list of chats with filtering, pagination and sorting (equivalent to list_dialogs_next) -// @Tags chat -// @Accept json -// @Produce json -// @Param keywords query string false "search keywords" -// @Param page query int false "page number" -// @Param page_size query int false "items per page" -// @Param orderby query string false "order by field (default: create_time)" -// @Param desc query bool false "descending order (default: true)" -// @Param request body service.ListChatsNextRequest true "filter options including owner_ids" -// @Success 200 {object} service.ListChatsNextResponse -// @Router /v1/dialog/next [post] -func (h *ChatHandler) ListChatsNext(c *gin.Context) { - user, errorCode, errorMessage := GetUser(c) - if errorCode != common.CodeSuccess { - jsonError(c, errorCode, errorMessage) - return - } - userID := user.ID - - // Parse query parameters - keywords := c.Query("keywords") - - page := 0 - if pageStr := c.Query("page"); pageStr != "" { - if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { - page = p - } - } - - pageSize := 0 - if pageSizeStr := c.Query("page_size"); pageSizeStr != "" { - if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 { - pageSize = ps - } - } - - orderby := c.DefaultQuery("orderby", "create_time") - - desc := true - if descStr := c.Query("desc"); descStr != "" { - desc = descStr != "false" - } - - // Parse request body for owner_ids - var req service.ListChatsNextRequest - if c.Request.ContentLength > 0 { - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": err.Error(), - }) - return - } - } - - // List chats with advanced filtering - result, err := h.chatService.ListChatsNext(userID, keywords, page, pageSize, orderby, desc, req.OwnerIDs) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": result, - "message": "success", - }) -} - -// SetDialog create or update a dialog -// @Summary Set Dialog -// @Description Create or update a dialog (chat). If dialog_id is provided, updates existing dialog; otherwise creates new one. -// @Tags chat -// @Accept json -// @Produce json -// @Param request body service.SetDialogRequest true "dialog configuration" -// @Success 200 {object} service.SetDialogResponse -// @Router /v1/dialog/set [post] -func (h *ChatHandler) SetDialog(c *gin.Context) { - user, errorCode, errorMessage := GetUser(c) - if errorCode != common.CodeSuccess { - jsonError(c, errorCode, errorMessage) - return - } - userID := user.ID - - // Parse request body - var req service.SetDialogRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": err.Error(), - }) - return - } - - // Validate required field: prompt_config - if req.PromptConfig == nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "prompt_config is required", - }) - return - } - - // Call service to set dialog - result, err := h.chatService.SetDialog(userID, &req) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": result, - "message": "success", - }) -} - -// RemoveDialogsRequest remove dialogs request -type RemoveDialogsRequest struct { - DialogIDs []string `json:"dialog_ids" binding:"required"` -} - -// RemoveChats remove/delete dialogs (soft delete by setting status to invalid) -// @Summary Remove Dialogs -// @Description Remove dialogs by setting their status to invalid. Only the owner of the dialog can perform this operation. -// @Tags chat -// @Accept json -// @Produce json -// @Param request body RemoveDialogsRequest true "dialog IDs to remove" -// @Success 200 {object} map[string]interface{} -// @Router /v1/dialog/rm [post] -func (h *ChatHandler) RemoveChats(c *gin.Context) { - user, errorCode, errorMessage := GetUser(c) - if errorCode != common.CodeSuccess { - jsonError(c, errorCode, errorMessage) - return - } - userID := user.ID - - // Parse request body - var req RemoveDialogsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": err.Error(), - }) - return - } - - // Call service to remove dialogs - if err := h.chatService.RemoveChats(userID, req.DialogIDs); err != nil { - // Check if it's an authorization error - if err.Error() == "only owner of chat authorized for this operation" { - c.JSON(http.StatusForbidden, gin.H{ - "code": 403, - "data": false, - "message": err.Error(), - }) - return - } - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": true, - "message": "success", - }) -} - -// DeleteChat soft deletes a chat by ID. func (h *ChatHandler) DeleteChat(c *gin.Context) { user, errorCode, errorMessage := GetUser(c) if errorCode != common.CodeSuccess { diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index bbc34c95c9..ce3adf8808 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -19,7 +19,6 @@ package handler import ( "encoding/json" "errors" - "fmt" "io" "net/http" "ragflow/internal/common" @@ -44,107 +43,6 @@ func NewChatSessionHandler(chatSessionService *service.ChatSessionService, userS } } -// SetChatSession create or update a chat session -// @Summary Set chat session -// @Description Create or update a chat session. If is_new is true, creates new chat session; otherwise updates existing one. -// @Tags chat_session -// @Accept json -// @Produce json -// @Param request body service.SetChatSessionRequest true "chat session configuration" -// @Success 200 {object} service.SetChatSessionResponse -// @Router /v1/conversation/set [post] -func (h *ChatSessionHandler) SetChatSession(c *gin.Context) { - user, errorCode, errorMessage := GetUser(c) - if errorCode != common.CodeSuccess { - jsonError(c, errorCode, errorMessage) - return - } - userID := user.ID - - // Parse request body - var req service.SetChatSessionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": err.Error(), - }) - return - } - - // Call service to set chat session - result, err := h.chatSessionService.SetChatSession(userID, &req) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": result, - "message": "success", - }) -} - -// RemoveChatSessionsRequest remove chat sessions request -type RemoveChatSessionsRequest struct { - ConversationIDs []string `json:"conversation_ids" binding:"required"` -} - -// RemoveChatSessions remove/delete chat sessions -// @Summary Remove Chat Sessions -// @Description Remove chat sessions by their IDs. Only the owner of the chat session can perform this operation. -// @Tags chat_session -// @Accept json -// @Produce json -// @Param request body RemoveChatSessionsRequest true "chat session IDs to remove" -// @Success 200 {object} map[string]interface{} -// @Router /v1/conversation/rm [post] -func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) { - user, errorCode, errorMessage := GetUser(c) - if errorCode != common.CodeSuccess { - jsonError(c, errorCode, errorMessage) - return - } - userID := user.ID - - // Parse request body - var req RemoveChatSessionsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": err.Error(), - }) - return - } - - // Call service to remove chat sessions - if err := h.chatSessionService.RemoveChatSessions(userID, req.ConversationIDs); err != nil { - // Check if it's an authorization error - if err.Error() == "Only owner of chat session authorized for this operation" { - c.JSON(http.StatusForbidden, gin.H{ - "code": 403, - "data": false, - "message": err.Error(), - }) - return - } - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": true, - "message": "success", - }) -} - // ListChatSessions list chat sessions for a dialog // @Summary List Chat Sessions // @Description Get list of chat sessions for a specific dialog @@ -198,30 +96,37 @@ func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) { }) } -// CompletionRequest completion request -type CompletionRequest struct { - ConversationID string `json:"conversation_id" binding:"required"` - Messages []map[string]interface{} `json:"messages" binding:"required"` - LLMID string `json:"llm_id,omitempty"` - Stream *bool `json:"stream,omitempty"` - Thinking *bool `json:"thinking,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` +type ChatCompletionsRequest struct { + ChatID string `json:"chat_id,omitempty"` + SessionID string `json:"session_id,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + Messages []map[string]interface{} `json:"messages,omitempty"` + Question string `json:"question,omitempty"` + Files []interface{} `json:"files,omitempty"` + LLMID string `json:"llm_id,omitempty"` + PassAllHistoryMessages *bool `json:"pass_all_history_messages,omitempty"` + PassAllHistory *bool `json:"pass_all_history,omitempty"` + Legacy bool `json:"legacy,omitempty"` + Stream *bool `json:"stream"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` } -// Completion chat completion +// ChatCompletions chat completion // @Summary Chat Completion -// @Description Send messages to the chat model and get a response. Supports streaming and non-streaming modes. +// @Description Send messages to the chat model and get a response. +// @Description Default is streaming (text/event-stream); set stream:false for JSON. // @Tags chat_session // @Accept json -// @Produce json -// @Param request body CompletionRequest true "completion request" -// @Success 200 {object} map[string]interface{} -// @Router /v1/conversation/completion [post] -func (h *ChatSessionHandler) Completion(c *gin.Context) { +// @Produce json, text/event-stream +// @Param request body ChatCompletionsRequest true "chat completion request" +// @Success 200 {object} map[string]interface{} "Non-streaming JSON response" +// @Success 200 {string} text/event-stream "Streaming SSE response" +// @Router /api/v1/chat/completions [post] +func (h *ChatSessionHandler) ChatCompletions(c *gin.Context) { user, errorCode, errorMessage := GetUser(c) if errorCode != common.CodeSuccess { jsonError(c, errorCode, errorMessage) @@ -229,83 +134,97 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) { } userID := user.ID - // Parse request body - var req CompletionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": err.Error(), - }) + var rawBody map[string]interface{} + if err := c.ShouldBindJSON(&rawBody); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()}) return } - // Build chat model config - chatModelConfig := make(map[string]interface{}) - if req.Temperature != 0 { - chatModelConfig["temperature"] = req.Temperature + var req ChatCompletionsRequest + b, err := json.Marshal(rawBody) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()}) + return } - if req.TopP != 0 { - chatModelConfig["top_p"] = req.TopP + if err := json.Unmarshal(b, &req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()}) + return } - if req.FrequencyPenalty != 0 { - chatModelConfig["frequency_penalty"] = req.FrequencyPenalty + + // Normalize session_id / conversation_id + sessionID := req.SessionID + if sessionID == "" { + sessionID = req.ConversationID } - if req.PresencePenalty != 0 { - chatModelConfig["presence_penalty"] = req.PresencePenalty + + // Build generation config + genConfig := make(map[string]interface{}) + if req.Temperature != nil { + genConfig["temperature"] = *req.Temperature } - if req.MaxTokens != 0 { - chatModelConfig["max_tokens"] = req.MaxTokens + if req.TopP != nil { + genConfig["top_p"] = *req.TopP } + if req.FrequencyPenalty != nil { + genConfig["frequency_penalty"] = *req.FrequencyPenalty + } + if req.PresencePenalty != nil { + genConfig["presence_penalty"] = *req.PresencePenalty + } + if req.MaxTokens != nil { + genConfig["max_tokens"] = *req.MaxTokens + } + + // Resolve pass_all_history from either alias + passAllHistory := false + if req.PassAllHistory != nil { + passAllHistory = *req.PassAllHistory + } + if req.PassAllHistoryMessages != nil { + passAllHistory = *req.PassAllHistoryMessages + } + + // Remove known keys from rawBody; what remains is passthrough kwargs + knownKeys := []string{ + "chat_id", "session_id", "conversation_id", + "messages", "question", "files", + "llm_id", + "pass_all_history_messages", "pass_all_history", + "legacy", "stream", + "temperature", "top_p", "frequency_penalty", "presence_penalty", "max_tokens", + } + for _, key := range knownKeys { + delete(rawBody, key) + } + kwargs := rawBody + + // Determine stream mode + streamMode := true if req.Stream != nil { - chatModelConfig["stream"] = *req.Stream - } - if req.Thinking != nil { - chatModelConfig["thinking"] = *req.Thinking + streamMode = *req.Stream } - // Process messages - filter out system messages and initial assistant messages - var processedMessages []map[string]interface{} - for i, m := range req.Messages { - role, _ := m["role"].(string) - if role == "system" { - continue - } - if role == "assistant" && len(processedMessages) == 0 { - continue - } - processedMessages = append(processedMessages, m) - _ = i - } - - // Get last message ID if present - var messageID string - if len(processedMessages) > 0 { - if id, ok := processedMessages[len(processedMessages)-1]["id"].(string); ok { - messageID = id - } - } - - // Call service - if req.Stream != nil && *req.Stream { - // Streaming response + if streamMode { disableWriteDeadlineForSSE(c) c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("X-Accel-Buffering", "no") - // Create a channel for streaming data - streamChan := make(chan string) + streamChan := make(chan string, 32) reqCtx := c.Request.Context() go func() { defer close(streamChan) - err := h.chatSessionService.CompletionStream(reqCtx, userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan) - if err != nil { - streamChan <- fmt.Sprintf("data: %s\n\n", err.Error()) - } + _, _ = h.chatSessionService.ChatCompletions( + reqCtx, userID, + req.ChatID, sessionID, + req.Messages, req.Question, req.Files, + req.LLMID, genConfig, kwargs, + passAllHistory, req.Legacy, + true, streamChan, + ) }() - // Stream data to client c.Stream(func(w io.Writer) bool { data, ok := <-streamChan if !ok { @@ -315,8 +234,14 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) { return true }) } else { - // Non-streaming response - result, err := h.chatSessionService.Completion(userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID) + result, err := h.chatSessionService.ChatCompletions( + c.Request.Context(), userID, + req.ChatID, sessionID, + req.Messages, req.Question, req.Files, + req.LLMID, genConfig, kwargs, + passAllHistory, req.Legacy, + false, nil, + ) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, @@ -324,7 +249,6 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) { }) return } - c.JSON(http.StatusOK, gin.H{ "code": 0, "data": result, diff --git a/internal/router/router.go b/internal/router/router.go index a6700de896..826ed379dc 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -292,6 +292,12 @@ func (r *Router) Setup(engine *gin.Engine) { chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession) } + chat := v1.Group("/chat") + { + // Chat completions route + chat.POST("/completions", r.chatSessionHandler.ChatCompletions) + } + // OpenAI-compatible chat completions route openai := v1.Group("/openai") { @@ -498,7 +504,7 @@ func (r *Router) Setup(engine *gin.Engine) { provider.PATCH("/:provider_name/instances/:instance_name/models/*model_name", r.providerHandler.EnableOrDisableModel) provider.POST("/:provider_name/instances/:instance_name/models", r.providerHandler.AddModel) provider.DELETE("/:provider_name/instances/:instance_name/models", r.providerHandler.DropInstanceModels) - v1.POST("/chat/completions", r.providerHandler.ChatToModel) + v1.POST("/chat/to_model", r.providerHandler.ChatToModel) v1.POST("/embeddings", r.providerHandler.EmbedText) v1.POST("/rerank", r.providerHandler.RerankDocument) v1.POST("/audio/transcriptions", r.providerHandler.TranscribeAudio) @@ -678,14 +684,6 @@ func (r *Router) Setup(engine *gin.Engine) { chunk.POST("/update", r.chunkHandler.UpdateChunk) // Internal API only for GO } - // Chat routes - chat := authorized.Group("/v1/dialog") - { - chat.POST("/next", r.chatHandler.ListChatsNext) - chat.POST("/set", r.chatHandler.SetDialog) - chat.POST("/rm", r.chatHandler.RemoveChats) - } - // Chat Channel chanChannel := v1.Group("/chat-channels") { @@ -705,15 +703,6 @@ func (r *Router) Setup(engine *gin.Engine) { langfuse.DELETE("/api-key", r.langfuseHandler.DeleteAPIKey) } - // Chat session (conversation) routes - session := authorized.Group("/v1/conversation") - { - session.POST("/set", r.chatSessionHandler.SetChatSession) - session.POST("/rm", r.chatSessionHandler.RemoveChatSessions) - session.GET("/list", r.chatSessionHandler.ListChatSessions) - session.POST("/completion", r.chatSessionHandler.Completion) - } - // Connector routes connector := authorized.Group("/v1/connector") { diff --git a/internal/service/chat.go b/internal/service/chat.go index 4fa939d75c..6106b4b73d 100644 --- a/internal/service/chat.go +++ b/internal/service/chat.go @@ -28,28 +28,6 @@ import ( "ragflow/internal/dao" ) -var DefaultPromptConfig = PromptConfig{ - System: strPtr(pyDefaultSystemPrompt), - Prologue: strPtr(pyDefaultPrologue), - Parameters: []ParameterConfig{ - {Key: "knowledge", Optional: false}, - }, - EmptyResponse: strPtr(pyDefaultEmptyResponse), - Quote: boolPtr(true), - TTS: boolPtr(false), - RefineMultiturn: boolPtr(true), -} - -var DefaultDirectChatPromptConfig = PromptConfig{ - System: strPtr(""), - Prologue: strPtr(""), - Parameters: []ParameterConfig{}, - EmptyResponse: strPtr(""), - Quote: boolPtr(false), - TTS: boolPtr(false), - RefineMultiturn: boolPtr(true), -} - var DefaultRerankModels = map[string]struct{}{ "BAAI/bge-reranker-v2-m3": {}, "maidalun1020/bce-reranker-base_v1": {}, @@ -653,74 +631,6 @@ func isTruthy(value interface{}) bool { } } -// ListChatsNextRequest list chats next request -type ListChatsNextRequest struct { - OwnerIDs []string `json:"owner_ids,omitempty"` -} - -// ListChatsNextResponse list chats next response -type ListChatsNextResponse struct { - Chats []*ChatWithKBNames `json:"dialogs"` - Total int64 `json:"total"` -} - -// ListChatsNext list chats with advanced filtering (equivalent to list_dialogs_next) -func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSize int, orderby string, desc bool, ownerIDs []string) (*ListChatsNextResponse, error) { - var chats []*entity.Chat - var total int64 - var err error - - if len(ownerIDs) == 0 { - // Get tenant IDs by user ID (joined tenants) - tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) - if err != nil { - return nil, err - } - - // Use database pagination - chats, total, err = s.chatDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords) - if err != nil { - return nil, err - } - } else { - // Filter by owner IDs, manual pagination - chats, total, err = s.chatDAO.ListByOwnerIDs(ownerIDs, userID, orderby, desc, keywords) - if err != nil { - return nil, err - } - - // Manual pagination - if page > 0 && pageSize > 0 { - start := (page - 1) * pageSize - end := start + pageSize - if start < int(total) { - if end > int(total) { - end = int(total) - } - chats = chats[start:end] - } else { - chats = []*entity.Chat{} - } - } - } - - // Enrich with knowledge base names - chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats)) - for _, chat := range chats { - kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs) - chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ - Chat: chat, - KBNames: kbNames, - DatasetIDs: datasetIDs, - }) - } - - return &ListChatsNextResponse{ - Chats: chatsWithKBNames, - Total: total, - }, nil -} - // getDatasetNamesAndIDs gets knowledge base names by IDs func (s *ChatService) getDatasetNamesAndIDs(kbIDs entity.JSONSlice) ([]string, []string) { var names = make([]string, 0, 0) @@ -743,30 +653,6 @@ func (s *ChatService) getDatasetNamesAndIDs(kbIDs entity.JSONSlice) ([]string, [ return names, ids } -// ParameterConfig parameter configuration in prompt_config -type ParameterConfig struct { - Key string `json:"key"` - Optional bool `json:"optional"` -} - -// PromptConfig prompt configuration -type PromptConfig struct { - System *string `json:"system"` - Prologue *string `json:"prologue"` - Parameters []ParameterConfig `json:"parameters"` - EmptyResponse *string `json:"empty_response"` - TavilyAPIKey string `json:"tavily_api_key,omitempty"` - Keyword *bool `json:"keyword,omitempty"` - Quote *bool `json:"quote,omitempty"` - Reasoning *bool `json:"reasoning,omitempty"` - RefineMultiturn *bool `json:"refine_multiturn,omitempty"` - TocEnhance *bool `json:"toc_enhance,omitempty"` - TTS *bool `json:"tts,omitempty"` - UseKG *bool `json:"use_kg,omitempty"` - CrossLanguages []string `json:"cross_languages,omitempty"` - ReferenceMetadata map[string]interface{} `json:"reference_metadata,omitempty"` -} - const ( pyDefaultSystemPrompt = "You are an intelligent assistant. Please summarize the content of the dataset to answer the question. " + "Please list the data in the dataset and answer in detail. " + @@ -781,393 +667,6 @@ const ( pyDefaultEmptyResponse = "Sorry! No relevant content was found in the knowledge base!" ) -// applyPromptDefaults replaces missing keys with default values -func applyPromptDefaults(p *PromptConfig) { - if p.System == nil || *p.System == "" { - s := pyDefaultSystemPrompt - p.System = &s - } - if p.Prologue == nil { - s := pyDefaultPrologue - p.Prologue = &s - } - if p.Parameters == nil { - p.Parameters = []ParameterConfig{{Key: "knowledge", Optional: false}} - } - if p.EmptyResponse == nil { - s := pyDefaultEmptyResponse - p.EmptyResponse = &s - } - if p.Quote == nil { - t := true - p.Quote = &t - } - if p.RefineMultiturn == nil { - t := true - p.RefineMultiturn = &t - } - if p.TTS == nil { - f := false - p.TTS = &f - } -} - -// SetDialogRequest set chat request -type SetDialogRequest struct { - DialogID string `json:"dialog_id,omitempty"` - Name string `json:"name,omitempty"` - Description string `json:"description,omitempty"` - Icon string `json:"icon,omitempty"` - TopN int64 `json:"top_n,omitempty"` - TopK int64 `json:"top_k,omitempty"` - RerankID string `json:"rerank_id,omitempty"` - SimilarityThreshold float64 `json:"similarity_threshold,omitempty"` - VectorSimilarityWeight float64 `json:"vector_similarity_weight,omitempty"` - LLMSetting map[string]interface{} `json:"llm_setting,omitempty"` - MetaDataFilter map[string]interface{} `json:"meta_data_filter,omitempty"` - PromptConfig *PromptConfig `json:"prompt_config" binding:"required"` - KBIDs []string `json:"kb_ids,omitempty"` - LLMID string `json:"llm_id,omitempty"` -} - -// SetDialogResponse set chat response -type SetDialogResponse struct { - *entity.Chat - KBNames []string `json:"kb_names"` -} - -// SetDialog create or update a chat -func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialogResponse, error) { - // Determine if this is a create or update operation - isCreate := req.DialogID == "" - - // Validate and process name - name := req.Name - if name == "" { - name = "New Chat" - } - - // Validate name type and content - if strings.TrimSpace(name) == "" { - return nil, errors.New("Chat name can't be empty") - } - - // Check name length (UTF-8 byte length) - if len(name) > 255 { - return nil, fmt.Errorf("Chat name length is %d which is larger than 255", len(name)) - } - - name = strings.TrimSpace(name) - - // Get tenant ID (use userID as default tenant) - tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) - if err != nil { - return nil, err - } - - var tenantID string - if len(tenantIDs) > 0 { - tenantID = tenantIDs[0] - } else { - tenantID = userID - } - - // For create: check for duplicate names and generate unique name - if isCreate { - existingNames, err := s.chatDAO.GetExistingNames(tenantID, "1") - if err != nil { - return nil, err - } - - // Check if name exists (case-insensitive) - nameLower := strings.ToLower(name) - for _, existing := range existingNames { - if strings.ToLower(existing) == nameLower { - // Generate unique name - name = s.generateUniqueName(name, existingNames) - break - } - } - } - - // Set default values - description := req.Description - if description == "" { - description = "A helpful chat" - } - - topN := req.TopN - if topN == 0 { - topN = 6 - } - - topK := req.TopK - if topK == 0 { - topK = 1024 - } - - rerankID := req.RerankID - - similarityThreshold := req.SimilarityThreshold - if similarityThreshold == 0 { - similarityThreshold = 0.1 - } - - vectorSimilarityWeight := req.VectorSimilarityWeight - if vectorSimilarityWeight == 0 { - vectorSimilarityWeight = 0.3 - } - - llmSetting := req.LLMSetting - if llmSetting == nil { - llmSetting = make(map[string]interface{}) - } - - metaDataFilter := req.MetaDataFilter - if metaDataFilter == nil { - metaDataFilter = make(map[string]interface{}) - } - - promptConfig := req.PromptConfig - - // Process kb_ids - kbIDs := req.KBIDs - if kbIDs == nil { - kbIDs = []string{} - } - - // Apply default prompt config on create only - if isCreate { - applyPromptDefaults(promptConfig) - } - - // Set default parameters for datasets with knowledge retrieval - // Check if parameters is missing or empty and kb_ids is provided - if len(kbIDs) > 0 && (promptConfig.Parameters == nil || len(promptConfig.Parameters) == 0) { - // Check if system prompt uses {knowledge} placeholder - if promptConfig.System != nil && strings.Contains(*promptConfig.System, "{knowledge}") { - promptConfig.Parameters = []ParameterConfig{ - {Key: "knowledge", Optional: false}, - } - } - } - - // For update: validate that {knowledge} is not used when no KBs or Tavily - if !isCreate { - if len(kbIDs) == 0 && promptConfig.TavilyAPIKey == "" && - promptConfig.System != nil && strings.Contains(*promptConfig.System, "{knowledge}") { - return nil, errors.New("Please remove `{knowledge}` in system prompt since no dataset / Tavily used here") - } - } - - // Validate parameters - for _, p := range promptConfig.Parameters { - if p.Optional { - continue - } - placeholder := fmt.Sprintf("{%s}", p.Key) - if promptConfig.System == nil || !strings.Contains(*promptConfig.System, placeholder) { - return nil, fmt.Errorf("Parameter '%s' is not used", p.Key) - } - } - - // Check knowledge bases and their embedding models - if len(kbIDs) > 0 { - kbs, err := s.kbDAO.GetByIDs(kbIDs) - if err != nil { - return nil, err - } - - // Check if all KBs use the same embedding model - var embdID string - for i, kb := range kbs { - if i == 0 { - embdID = kb.EmbdID - } else { - // Extract base model name (remove vendor suffix) - embdBase := s.splitModelNameAndFactory(embdID) - kbEmbdBase := s.splitModelNameAndFactory(kb.EmbdID) - if embdBase != kbEmbdBase { - return nil, fmt.Errorf("Datasets use different embedding models: %v", getEmbdIDs(kbs)) - } - } - } - } - - // Get LLM ID (use tenant's default if not provided) - llmID := req.LLMID - if llmID == "" { - tenant, err := s.tenantDAO.GetByID(tenantID) - if err != nil { - return nil, errors.New("Tenant not found") - } - llmID = tenant.LLMID - } - - // Convert prompt config to JSONMap - promptConfigMap := entity.JSONMap{} - if promptConfig.System != nil && *promptConfig.System != "" { - promptConfigMap["system"] = *promptConfig.System - } - if promptConfig.Prologue != nil { - promptConfigMap["prologue"] = *promptConfig.Prologue - } - if promptConfig.EmptyResponse != nil { - promptConfigMap["empty_response"] = *promptConfig.EmptyResponse - } - if promptConfig.Quote != nil { - promptConfigMap["quote"] = *promptConfig.Quote - } - if promptConfig.RefineMultiturn != nil { - promptConfigMap["refine_multiturn"] = *promptConfig.RefineMultiturn - } - if promptConfig.TTS != nil { - promptConfigMap["tts"] = *promptConfig.TTS - } - if promptConfig.Keyword != nil { - promptConfigMap["keyword"] = *promptConfig.Keyword - } - if promptConfig.Reasoning != nil { - promptConfigMap["reasoning"] = *promptConfig.Reasoning - } - if promptConfig.TocEnhance != nil { - promptConfigMap["toc_enhance"] = *promptConfig.TocEnhance - } - if promptConfig.UseKG != nil { - promptConfigMap["use_kg"] = *promptConfig.UseKG - } - if promptConfig.TavilyAPIKey != "" { - promptConfigMap["tavily_api_key"] = promptConfig.TavilyAPIKey - } - if len(promptConfig.CrossLanguages) > 0 { - promptConfigMap["cross_languages"] = promptConfig.CrossLanguages - } - if len(promptConfig.ReferenceMetadata) > 0 { - promptConfigMap["reference_metadata"] = promptConfig.ReferenceMetadata - } - if len(promptConfig.Parameters) > 0 { - params := make([]map[string]interface{}, len(promptConfig.Parameters)) - for i, p := range promptConfig.Parameters { - params[i] = map[string]interface{}{ - "key": p.Key, - "optional": p.Optional, - } - } - promptConfigMap["parameters"] = params - } - - // Convert kbIDs to JSONSlice - kbIDsJSON := make(entity.JSONSlice, len(kbIDs)) - for i, id := range kbIDs { - kbIDsJSON[i] = id - } - - if isCreate { - // Generate UUID for new chat - newID := common.GenerateUUID() - - // Set default language - language := "English" - - // Create new chat - chat := &entity.Chat{ - ID: newID, - TenantID: tenantID, - Name: &name, - Description: &description, - Icon: &req.Icon, - Language: &language, - LLMID: llmID, - LLMSetting: llmSetting, - PromptConfig: promptConfigMap, - MetaDataFilter: (*entity.JSONMap)(&metaDataFilter), - TopN: topN, - TopK: topK, - RerankID: rerankID, - SimilarityThreshold: similarityThreshold, - VectorSimilarityWeight: vectorSimilarityWeight, - KBIDs: kbIDsJSON, - Status: strPtr("1"), - } - - if err := s.chatDAO.Create(chat); err != nil { - return nil, errors.New("Fail to new a chat") - } - - // Get KB names - kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs) - - return &SetDialogResponse{ - Chat: chat, - KBNames: kbNames, - }, nil - } - - updateData := map[string]interface{}{ - "name": name, - "description": description, - "icon": req.Icon, - "llm_id": llmID, - "llm_setting": llmSetting, - "prompt_config": promptConfigMap, - "meta_data_filter": metaDataFilter, - "top_n": topN, - "top_k": topK, - "rerank_id": rerankID, - "similarity_threshold": similarityThreshold, - "vector_similarity_weight": vectorSimilarityWeight, - "kb_ids": kbIDsJSON, - } - - if err := s.chatDAO.UpdateByID(req.DialogID, updateData); err != nil { - return nil, errors.New("Dialog not found") - } - - // Get updated chat - chat, err := s.chatDAO.GetByID(req.DialogID) - if err != nil { - return nil, errors.New("Fail to update a chat") - } - - // Get KB names - kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs) - - return &SetDialogResponse{ - Chat: chat, - KBNames: kbNames, - }, nil -} - -// generateUniqueName generates a unique name by appending a number -func (s *ChatService) generateUniqueName(name string, existingNames []string) string { - baseName := name - counter := 1 - - // Check if name already has a suffix like "(1)" - if idx := strings.LastIndex(name, "("); idx > 0 { - if idx2 := strings.LastIndex(name, ")"); idx2 > idx { - if num, err := fmt.Sscanf(name[idx+1:idx2], "%d", &counter); err == nil && num == 1 { - baseName = strings.TrimSpace(name[:idx]) - counter++ - } - } - } - - existingMap := make(map[string]bool) - for _, n := range existingNames { - existingMap[strings.ToLower(n)] = true - } - - newName := name - for { - if !existingMap[strings.ToLower(newName)] { - return newName - } - newName = fmt.Sprintf("%s(%d)", baseName, counter) - counter++ - } -} - // splitModelNameAndFactory extracts the base model name (removes vendor suffix) func (s *ChatService) splitModelNameAndFactory(embdID string) string { // Remove vendor suffix (e.g., "model@openai" -> "model") @@ -1629,52 +1128,6 @@ func (s *ChatService) BulkDeleteChats(userID string, req *BulkDeleteChatsRequest return nil, errors.New(strings.Join(errorsList, "; ")) } -// RemoveChats removes dialogs by setting their status to invalid (soft delete) -// Only the owner of the chat can perform this operation -func (s *ChatService) RemoveChats(userID string, chatIDs []string) error { - // Get user's tenants - tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) - if err != nil { - return err - } - - // Build a set of user's tenant IDs for quick lookup - tenantIDSet := make(map[string]bool) - for _, tid := range tenantIDs { - tenantIDSet[tid] = true - } - // Also add userID itself as a tenant (for cases where tenant_id = user_id) - tenantIDSet[userID] = true - - // Check each chat and build update list - var updates []map[string]interface{} - for _, chatID := range chatIDs { - // Get the chat to check ownership - chat, err := s.chatDAO.GetByID(chatID) - if err != nil { - return fmt.Errorf("chat not found: %s", chatID) - } - - // Check if user is the owner (chat's tenant_id must be in user's tenants) - if !tenantIDSet[chat.TenantID] { - return errors.New("only owner of chat authorized for this operation") - } - - // Add to update list (soft delete by setting status to "0") - updates = append(updates, map[string]interface{}{ - "id": chatID, - "status": "0", - }) - } - - // Batch update all dialogs - if err := s.chatDAO.UpdateManyByID(updates); err != nil { - return err - } - - return nil -} - // strPtr returns a pointer to a string func strPtr(s string) *string { return &s diff --git a/internal/service/chat_pipeline.go b/internal/service/chat_pipeline.go index c649fca3c2..97c120d65d 100644 --- a/internal/service/chat_pipeline.go +++ b/internal/service/chat_pipeline.go @@ -171,16 +171,23 @@ func (s *ChatPipelineService) AsyncChat( } // No KBs & no web search → fast-path to LLM-only chat. + hasKBs := false + for _, raw := range chat.KBIDs { + if id, ok := raw.(string); ok && id != "" { + hasKBs = true + break + } + } useWebSearch := s.shouldUseWebSearch(chat, kwargs["internet"]) if useWebSearch { common.Debug("web_search", - zap.Bool("kb", len(chat.KBIDs) > 0), + zap.Bool("kb", hasKBs), zap.Bool("tavily", chat.PromptConfig != nil && chat.PromptConfig["tavily_api_key"] != "" && chat.PromptConfig["tavily_api_key"] != nil), zap.Any("internet", kwargs["internet"]), zap.Bool("enabled", useWebSearch)) } - if len(chat.KBIDs) == 0 && !useWebSearch { + if !hasKBs && !useWebSearch { return s.AsyncChatSolo(ctx, chat, messages, stream) } @@ -1022,8 +1029,7 @@ func (s *ChatPipelineService) AsyncChat( if stream { // Streaming path: accumulate answer, emit deltas. var fullAnswer string - var fullReasoning string - thinkState := &thinkStreamState{} + thinkState := &ThinkStreamState{} chatCfg := BuildChatConfig(chat, nil) @@ -1094,50 +1100,60 @@ func (s *ChatPipelineService) AsyncChat( *chatDriver.ModelName, chatMessages, chatDriver.APIConfig, chatCfg, func(answer *string, reason *string) error { if reason != nil && *reason != "" { - fullReasoning += *reason - kind, output := processThinkDelta(thinkState, *reason, 16) - if kind == "marker" && output == "" { - // marker — emit StartToThink - out <- AsyncChatResult{ - Answer: "", - Reference: map[string]interface{}{}, - AudioBinary: nil, - CreatedAt: float64(time.Now().Unix()), - Final: false, - StartToThink: true, - } - } else if kind == "marker" && output == "" { - // marker — emit EndToThink - out <- AsyncChatResult{ - Answer: "", - Reference: map[string]interface{}{}, - AudioBinary: nil, - CreatedAt: float64(time.Now().Unix()), - Final: false, - EndToThink: true, - } - } else if kind == "text" && output != "" { - // Route reasoning text to Reasoning field. - // TTS is nil — chain-of-thought is not narrated. - out <- AsyncChatResult{ - Reasoning: output, - Reference: map[string]interface{}{}, - // TTS only narrates user-visible answer text. - AudioBinary: nil, - CreatedAt: float64(time.Now().Unix()), - Final: false, + deltas := NextThinkDelta(thinkState, *reason, 16) + for _, d := range deltas { + if d.Kind == ThinkDeltaMarker && d.Value == "" { + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + StartToThink: true, + } + } else if d.Kind == ThinkDeltaMarker && d.Value == "" { + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + EndToThink: true, + } + } else if d.Kind == ThinkDeltaText && d.Value != "" { + fullAnswer += d.Value + out <- AsyncChatResult{ + Answer: d.Value, + Reference: map[string]interface{}{}, + AudioBinary: s.synthesizeTTS(ttsModel, d.Value), + CreatedAt: float64(time.Now().Unix()), + Final: false, + } } } } if isContentDelta(answer) { fullAnswer += *answer - out <- AsyncChatResult{ - Answer: *answer, - Reference: map[string]interface{}{}, - // Per-delta TTS for incremental audio playback. - AudioBinary: s.synthesizeTTS(ttsModel, *answer), - CreatedAt: float64(time.Now().Unix()), - Final: false, + deltas := BufferAnswerDelta(thinkState, *answer, 16) + for _, d := range deltas { + if d.Kind == ThinkDeltaMarker && d.Value == "" { + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + EndToThink: true, + } + } else if d.Kind == ThinkDeltaText && d.Value != "" { + out <- AsyncChatResult{ + Answer: d.Value, + Reference: map[string]interface{}{}, + AudioBinary: s.synthesizeTTS(ttsModel, d.Value), + CreatedAt: float64(time.Now().Unix()), + Final: false, + } + } } } return nil @@ -1152,33 +1168,34 @@ func (s *ChatPipelineService) AsyncChat( return } - // Flush remaining think stream buffer. - // If the LLM ended mid-think, emit remaining reasoning + - // implicit close marker. - remainingText, remainingMarker := flushThinkStream(thinkState) - if remainingText != "" { - // Flushed text belongs in Reasoning, not Answer. - out <- AsyncChatResult{ - Reasoning: remainingText, - Reference: map[string]interface{}{}, - AudioBinary: nil, - CreatedAt: float64(time.Now().Unix()), - Final: false, - } - } - if remainingMarker == "" { - out <- AsyncChatResult{ - Answer: "", - Reference: map[string]interface{}{}, - AudioBinary: nil, - CreatedAt: float64(time.Now().Unix()), - Final: false, - EndToThink: true, + // Flush remaining state matching Python's final flush order + // (dialog_service.py:1601-1612): think_buffer → marker → answer_buffer → pending_after_close + // Python has no Reasoning field — all text is Answer. + for _, d := range FlushRemaining(thinkState) { + if d.Kind == ThinkDeltaMarker && d.Value == "" { + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + EndToThink: true, + } + } else if d.Kind == ThinkDeltaText && d.Value != "" { + out <- AsyncChatResult{ + Answer: d.Value, + Reference: map[string]interface{}{}, + AudioBinary: s.synthesizeTTS(ttsModel, d.Value), + CreatedAt: float64(time.Now().Unix()), + Final: false, + } } } // Decorate and yield the final answer. - visibleAnswer := s.extractVisibleAnswer(fullReasoning + fullAnswer) + // Python uses state.full_text (raw text with tags) as input + // to _extract_visible_answer → decorate_answer (dialog_service.py:914-920). + visibleAnswer := s.extractVisibleAnswer(thinkState.fullText) // Pass nil for ttsModel — audio was already produced per-delta. final := s.decorateAnswer(ctx, visibleAnswer, kbinfos, prompt, questions, usedTokenCount, timer, embModel, chat.VectorSimilarityWeight, quote, nil, langfuseTraceID, llmModelConfig, chat.TenantID, kbTenantIDStrings(kbs), len(knowledges) > 0) @@ -1372,28 +1389,51 @@ func (s *ChatPipelineService) AsyncChatSolo( // 7. Drive the LLM: stream (per-delta with think markers) or non-stream (one-shot). if stream { var fullAnswer string - var fullReasoning string - thinkState := &thinkStreamState{} + thinkState := &ThinkStreamState{} chatCfg := BuildChatConfig(chat, nil) timer.Enter(common.PhaseGenerateAnswer) + driverErr := chatModel.ModelDriver.ChatStreamlyWithSender( *chatModel.ModelName, chatMessages, chatModel.APIConfig, chatCfg, func(answer *string, reason *string) error { if reason != nil && *reason != "" { - fullReasoning += *reason - kind, output := processThinkDelta(thinkState, *reason, 16) - if kind == "marker" && output == "" { - // Start thinking. - out <- AsyncChatResult{ - Answer: "", - Reference: map[string]interface{}{}, - AudioBinary: nil, - CreatedAt: float64(time.Now().Unix()), - Final: false, - StartToThink: true, + deltas := NextThinkDelta(thinkState, *reason, 16) + for _, d := range deltas { + if d.Kind == ThinkDeltaMarker && d.Value == "" { + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + StartToThink: true, + } + } else if d.Kind == ThinkDeltaMarker && d.Value == "" { + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + EndToThink: true, + } + } else if d.Kind == ThinkDeltaText && d.Value != "" { + fullAnswer += d.Value + out <- AsyncChatResult{ + Answer: d.Value, + Reference: map[string]interface{}{}, + AudioBinary: s.synthesizeTTS(ttsModel, d.Value), + CreatedAt: float64(time.Now().Unix()), + Final: false, + } } - } else if kind == "marker" && output == "" { - // End thinking. + } + } + if isContentDelta(answer) { + fullAnswer += *answer + deltas := BufferAnswerDelta(thinkState, *answer, 16) + for _, d := range deltas { + if d.Kind == ThinkDeltaMarker && d.Value == "" { out <- AsyncChatResult{ Answer: "", Reference: map[string]interface{}{}, @@ -1402,27 +1442,17 @@ func (s *ChatPipelineService) AsyncChatSolo( Final: false, EndToThink: true, } - } else if kind == "text" && output != "" { - // Reasoning text with per-delta TTS. + } else if d.Kind == ThinkDeltaText && d.Value != "" { out <- AsyncChatResult{ - Reasoning: output, + Answer: d.Value, Reference: map[string]interface{}{}, - AudioBinary: s.synthesizeTTS(ttsModel, output), + AudioBinary: s.synthesizeTTS(ttsModel, d.Value), CreatedAt: float64(time.Now().Unix()), Final: false, } } } - if isContentDelta(answer) { - fullAnswer += *answer - out <- AsyncChatResult{ - Answer: *answer, - Reference: map[string]interface{}{}, - AudioBinary: s.synthesizeTTS(ttsModel, *answer), - CreatedAt: float64(time.Now().Unix()), - Final: false, - } - } + } return nil }, ) @@ -1434,33 +1464,30 @@ func (s *ChatPipelineService) AsyncChatSolo( return } timer.Exit(common.PhaseGenerateAnswer) - // Flush any remaining think buffer. - remainingText, remainingMarker := flushThinkStream(thinkState) - if remainingText != "" { - out <- AsyncChatResult{ - Reasoning: remainingText, - Reference: map[string]interface{}{}, - AudioBinary: s.synthesizeTTS(ttsModel, remainingText), - CreatedAt: float64(time.Now().Unix()), - Final: false, + for _, d := range FlushRemaining(thinkState) { + if d.Kind == ThinkDeltaMarker && d.Value == "" { + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + EndToThink: true, + } + } else if d.Kind == ThinkDeltaText && d.Value != "" { + out <- AsyncChatResult{ + Answer: d.Value, + Reference: map[string]interface{}{}, + AudioBinary: s.synthesizeTTS(ttsModel, d.Value), + CreatedAt: float64(time.Now().Unix()), + Final: false, + } } } - if remainingMarker == "" { - out <- AsyncChatResult{ - Answer: "", - Reference: map[string]interface{}{}, - AudioBinary: nil, - CreatedAt: float64(time.Now().Unix()), - Final: false, - EndToThink: true, - } + finalAnswer := ExtractVisibleAnswer(thinkState.fullText) + if finalAnswer == "" { + finalAnswer = fullAnswer } - // Final aggregate: re-attach reasoning wrapper for non-streaming consumers. - finalAnswer := fullAnswer - if fullReasoning != "" { - finalAnswer = "" + fullReasoning + "" + fullAnswer - } - // Raw answer, no decorate_answer. AudioBinary=nil (per-delta TTS already emitted). out <- AsyncChatResult{ Answer: finalAnswer, Reference: map[string]interface{}{}, @@ -2288,7 +2315,7 @@ func (s *ChatPipelineService) kbPrompt(kbinfos map[string]interface{}, maxTokens } contents := make([]chunkContent, 0, len(chunksRaw)) for _, ck := range chunksRaw { - c := getMapString(ck, "content", "content_with_weight") + c := getMapString(ck, "content_with_weight", "content") if c == "" { continue } @@ -2315,7 +2342,7 @@ func (s *ChatPipelineService) kbPrompt(kbinfos map[string]interface{}, maxTokens var result []string for i := 0; i < chunksNum; i++ { ck := chunksRaw[i] - c := getMapString(ck, "content", "content_with_weight") + c := getMapString(ck, "content_with_weight", "content") if c == "" { continue } @@ -2781,25 +2808,8 @@ func langfuseExtractTimeElapsed(prompt string) string { } // extractVisibleAnswer mirrors Python's _extract_visible_answer. -// It preserves wrappers and strips stray think tags. func (s *ChatPipelineService) extractVisibleAnswer(text string) string { - if !strings.Contains(text, "") { - text = strings.ReplaceAll(text, "", "") - text = strings.ReplaceAll(text, "", "") - return text - } - idx := strings.LastIndex(text, "") - thought := text[:idx] - answer := text[idx+len(""):] - thought = strings.ReplaceAll(thought, "", "") - thought = strings.ReplaceAll(thought, "", "") - thought = strings.TrimSpace(thought) - answer = strings.ReplaceAll(answer, "", "") - answer = strings.ReplaceAll(answer, "", "") - if thought == "" { - return answer - } - return "" + thought + "" + answer + return ExtractVisibleAnswer(text) } // citationPrompt returns the citation instruction prompt. @@ -2809,124 +2819,6 @@ func citationPrompt() string { "(where N is the chunk number) after each sentence where the information from that chunk is used." } -// --------------------------------------------------------------------------- -// Think-marker streaming — mirrors Python's _stream_with_think_delta. -// --------------------------------------------------------------------------- - -// thinkStreamState tracks accumulated reasoning text and emits deltas. -type thinkStreamState struct { - fullText string - lastIdx int - endsWithThink bool - inThink bool - buffer string - postThinkText string -} - -// nextThinkDelta computes the next delta to emit from the accumulated text. -// Mirrors _next_think_delta in dialog_service.py:1460-1487. -func nextThinkDelta(state *thinkStreamState) string { - full := state.fullText - if full == "" || len(full) <= state.lastIdx { - return "" - } - delta := full[state.lastIdx:] - - if strings.HasPrefix(delta, "") { - state.lastIdx += len("") - return "" - } - if idx := strings.Index(delta, ""); idx > 0 { - state.lastIdx += idx - return delta[:idx] - } - if strings.HasSuffix(delta, "") { - state.endsWithThink = true - } else if state.endsWithThink { - state.endsWithThink = false - remainder := delta - if idx := strings.Index(delta, ""); idx >= 0 { - remainder = delta[idx+len(""):] - } - if remainder != "" { - state.postThinkText = remainder - } - state.lastIdx = len(full) - return "" - } - - state.lastIdx = len(full) - if strings.HasSuffix(full, "") { - state.lastIdx -= len("") - } - return strings.ReplaceAll(strings.ReplaceAll(delta, "", ""), "", "") -} - -// processThinkDelta updates the state with a new delta and returns what to emit. -// Returns the kind of emission: "marker" for think tags, "text" for content, "" for nothing. -func processThinkDelta(state *thinkStreamState, delta string, minTokens int) (kind string, output string) { - if delta == "" { - return "", "" - } - state.fullText += delta - d := nextThinkDelta(state) - if d == "" { - return "", "" - } - if d == "" { - if state.inThink { - return "", "" - } - if state.buffer != "" { - kind, out := "text", state.buffer - state.buffer = "" - state.inThink = true - return kind, out - } - state.inThink = true - return "marker", "" - } - if d == "" { - if !state.inThink { - return "", "" - } - state.inThink = false - if state.postThinkText != "" { - state.buffer += state.postThinkText - state.postThinkText = "" - } - return "marker", "" - } - state.buffer += d - if kg.NumTokensFromString(state.buffer) < minTokens { - return "", "" - } - out := state.buffer - state.buffer = "" - return "text", out -} - -// flushThinkStream flushes any remaining buffered text from the think stream. -func flushThinkStream(state *thinkStreamState) (text string, marker string) { - if state.buffer != "" { - text = state.buffer - state.buffer = "" - } - if state.postThinkText != "" { - if text != "" { - text += state.postThinkText - } else { - text = state.postThinkText - } - state.postThinkText = "" - } - if state.endsWithThink { - marker = "" - state.endsWithThink = false - } - return text, marker -} - // ----------------------------------------------------------------------- // Moved from sql_fallback.go (2026-06-12). SQL retrieval system, repair // helpers, and Python parity helpers. Kept in async_chat.go because the diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index 7a48f708fd..b700658a87 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "ragflow/internal/common" "ragflow/internal/storage" "strings" @@ -29,6 +30,7 @@ import ( "go.uber.org/zap" "gorm.io/gorm" + "github.com/google/uuid" "ragflow/internal/dao" "ragflow/internal/entity" ) @@ -87,150 +89,6 @@ func NewChatSessionService() *ChatSessionService { } } -// SetChatSessionRequest set chat session request -type SetChatSessionRequest struct { - SessionID string `json:"conversation_id,omitempty"` - DialogID string `json:"dialog_id,omitempty"` - Name string `json:"name,omitempty"` - IsNew bool `json:"is_new"` -} - -// SetChatSessionResponse set chat session response -type SetChatSessionResponse struct { - *entity.ChatSession -} - -// SetChatSession create or update a chat session -func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRequest) (*SetChatSessionResponse, error) { - name := req.Name - if name == "" { - name = "New chat session" - } - // Limit name length to 255 characters - if len(name) > 255 { - name = name[:255] - } - - if !req.IsNew { - // Update existing chat session - updates := map[string]interface{}{ - "name": name, - "user_id": userID, - } - - if err := s.chatSessionDAO.UpdateByID(req.SessionID, updates); err != nil { - return nil, errors.New("Chat session not found") - } - - // Get updated chat session - session, err := s.chatSessionDAO.GetByID(req.SessionID) - if err != nil { - return nil, errors.New("Fail to update a chat session") - } - - return &SetChatSessionResponse{ChatSession: session}, nil - } - - // Create new chat session - // Check if dialog exists - dialog, err := s.chatSessionDAO.GetDialogByID(req.DialogID) - if err != nil { - return nil, errors.New("Dialog not found") - } - - // Generate UUID for new chat session - newID := common.GenerateUUID() - - // Get prologue from dialog's prompt_config - prologue := "Hi! I'm your assistant. What can I do for you?" - if dialog.PromptConfig != nil { - if p, ok := dialog.PromptConfig["prologue"].(string); ok && p != "" { - prologue = p - } - } - - // Store messages in the same list shape as Python Conversation.message. - messagesJSON, _ := json.Marshal([]map[string]interface{}{ - { - "role": "assistant", - "content": prologue, - }, - }) - - // Create reference - store as JSON array - referenceJSON, _ := json.Marshal([]interface{}{}) - - // Create chat session - session := &entity.ChatSession{ - ID: newID, - DialogID: req.DialogID, - Name: &name, - Message: messagesJSON, - UserID: &userID, - Reference: referenceJSON, - } - - if err := s.chatSessionDAO.Create(session); err != nil { - return nil, errors.New("Fail to create a chat session") - } - - return &SetChatSessionResponse{ChatSession: session}, nil -} - -// RemoveChatSessions removes chat sessions (hard delete) -func (s *ChatSessionService) RemoveChatSessions(userID string, chatSessions []string) error { - // Get user's tenants - tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) - if err != nil { - return err - } - - // Build a set of user's tenant IDs for quick lookup - tenantIDSet := make(map[string]bool) - for _, tid := range tenantIDs { - tenantIDSet[tid] = true - } - tenantIDSet[userID] = true - - // Check each chat session - for _, convID := range chatSessions { - // Get the chat session - session, err := s.chatSessionDAO.GetByID(convID) - if err != nil { - return fmt.Errorf("Chat session not found: %s", convID) - } - - // Check if user is the owner by checking dialog ownership - isOwner := false - for tenantID := range tenantIDSet { - exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, session.DialogID) - if err != nil { - return err - } - if exists { - isOwner = true - break - } - } - - if !isOwner { - return errors.New("Only owner of chat session authorized for this operation") - } - - // Delete the chat session - if err := s.chatSessionDAO.DeleteByID(convID); err != nil { - return err - } - } - - return nil -} - -// ListChatSessionsRequest list chat sessions request -type ListChatSessionsRequest struct { - DialogID string `json:"dialog_id" binding:"required"` -} - // ListChatSessionsResponse list chat sessions response type ListChatSessionsResponse struct { Sessions []*entity.ChatSession @@ -993,201 +851,441 @@ func isChatSessionNotFound(err error) bool { return errors.Is(err, gorm.ErrRecordNotFound) } -// Completion performs chat completion with full RAG support via ChatPipelineService. -func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) { - // Validate the last message is from user - if len(messages) == 0 { - return nil, errors.New("messages cannot be empty") - } - lastRole, _ := messages[len(messages)-1]["role"].(string) - if lastRole != "user" { - return nil, errors.New("the last content of this conversation is not from user") - } - - // Get conversation - session, err := s.chatSessionDAO.GetByID(conversationID) - if err != nil { - return nil, errors.New("Conversation not found") - } - - // Get dialog - dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID) - if err != nil { - return nil, errors.New("Dialog not found") - } - - // Deep copy messages to session, preserving the stored prologue that handler strips from requests. - sessionMessages := s.buildSessionMessages(session, messages) - - // Initialize reference if empty - reference := s.initializeReference(session) - - // Check if custom LLM is specified and validate API key - isEmbedded := llmID != "" - if llmID != "" { - hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID) - if err != nil || !hasKey { - return nil, fmt.Errorf("Cannot use specified model %s", llmID) - } - dialog.LLMID = llmID - if chatModelConfig != nil { - dialog.LLMSetting = chatModelConfig - } - } - - // Perform chat completion via shared RAG pipeline - ctx := context.Background() - kwargs := chatModelConfig - if kwargs == nil { - kwargs = map[string]interface{}{} - } - resultChan, err := s.pipeline.AsyncChat(ctx, dialog, messages, false, kwargs) - if err != nil { - return nil, err - } - - // Collect results from the pipeline - var answer strings.Builder - var finalRef map[string]interface{} - for result := range resultChan { - if result.Answer != "" { - answer.WriteString(result.Answer) - } - if result.Reference != nil { - finalRef = result.Reference - } - } - - // Structure the answer - ans := map[string]interface{}{ - "answer": answer.String(), - "reference": finalRef, - "final": true, - } - result := s.structureAnswerWithConv(session, ans, messageID, session.ID, reference) - - // Update conversation if not embedded - if !isEmbedded { - sessionMessages = append(sessionMessages, map[string]interface{}{ - "role": "assistant", - "content": answer.String(), - "id": messageID, - "created_at": float64(time.Now().Unix()), - }) - s.updateSessionMessages(session, sessionMessages, reference) - } - - return result, nil -} - -// CompletionStream performs streaming chat completion with full RAG support via ChatPipelineService. -func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error { +// ChatCompletions handles chat completion matching Python's session_completion. +// When stream=true, returns nil result and streams SSE via streamChan. +// When stream=false, returns the structured answer map. +func (s *ChatSessionService) ChatCompletions( + ctx context.Context, + userID string, + chatID string, sessionID string, + messages []map[string]interface{}, question string, files []interface{}, + llmID string, genConfig map[string]interface{}, kwargs map[string]interface{}, + passAllHistory bool, legacy bool, + stream bool, streamChan chan<- string, +) (map[string]interface{}, error) { if ctx == nil { ctx = context.Background() } - // Validate the last message is from user - if len(messages) == 0 { - streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "messages cannot be empty", "data": {"answer": "**ERROR**: messages cannot be empty", "reference": []}}`) - return errors.New("messages cannot be empty") - } - lastRole, _ := messages[len(messages)-1]["role"].(string) - if lastRole != "user" { - streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "the last content of this conversation is not from user", "data": {"answer": "**ERROR**: the last content of this conversation is not from user", "reference": []}}`) - return errors.New("the last content of this conversation is not from user") + fail := func(err error) (map[string]interface{}, error) { + if stream && streamChan != nil { + s.sendSSEError(streamChan, err.Error()) + } + return nil, err } - // Get conversation - session, err := s.chatSessionDAO.GetByID(conversationID) + sendOrCancel := func(data string) bool { + select { + case streamChan <- data: + return true + case <-ctx.Done(): + return false + } + } + + common.Info("ChatCompletions started") + + // --- 1. Normalize messages --- + requestMessages, requestMsg, messageID, err := s.normalizeCompletionMessages(messages, question, files) if err != nil { - streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Conversation not found", "data": {"answer": "**ERROR**: Conversation not found", "reference": []}}`) - return errors.New("Conversation not found") + return fail(err) } - // Get dialog - dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID) - if err != nil { - streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Dialog not found", "data": {"answer": "**ERROR**: Dialog not found", "reference": []}}`) - return errors.New("Dialog not found") + // --- 2. Validate --- + if sessionID != "" && chatID == "" { + return fail(errors.New("`chat_id` is required when `session_id` is provided.")) } - // Deep copy messages to session, preserving the stored prologue that handler strips from requests. - sessionMessages := s.buildSessionMessages(session, messages) + // --- 3. Resolve dialog and session --- + var dialog *entity.Chat + var session *entity.ChatSession + if chatID != "" { + if err := s.checkDialogOwnership(userID, chatID); err != nil { + return fail(err) + } + dialog, err = s.chatSessionDAO.GetDialogByID(chatID) + if err != nil { + return fail(errors.New("Chat not found!")) + } + if sessionID != "" { + session, err = s.chatSessionDAO.GetByID(sessionID) + if err != nil { + return fail(errors.New("Session not found!")) + } + if session.DialogID != chatID { + return fail(errors.New("Session does not belong to this chat!")) + } + } else { + session, err = s.createSessionForCompletion(chatID, dialog, userID) + if err != nil { + return fail(err) + } + sessionID = session.ID + } - // Initialize reference if empty - reference := s.initializeReference(session) + if passAllHistory { + session.Message, _ = json.Marshal(requestMessages) + } else { + session = s.appendSessionMessage(session, requestMsg) + } + requestMsg = s.filterSystemAndLeadingAssistant(session) + _ = messageID + } else { + dialog = s.buildDefaultCompletionDialog(userID) + if !stream { + genConfig["stream"] = false + } + } - // Check if custom LLM is specified and validate API key - isEmbedded := llmID != "" + // --- 4. Initialize reference --- + var reference []interface{} + if session != nil { + reference = s.initializeReference(session) + } + + // --- 5. LLM override --- + if genConfig == nil { + genConfig = map[string]interface{}{} + } if llmID != "" { hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID) if err != nil || !hasKey { - errMsg := fmt.Sprintf(`{"code": 500, "message": "Cannot use specified model %s", "data": {"answer": "**ERROR**: Cannot use specified model", "reference": []}}`, llmID) - streamChan <- fmt.Sprintf("data: %s\n\n", errMsg) - return fmt.Errorf("Cannot use specified model %s", llmID) + return fail(fmt.Errorf("Cannot use specified model %s", llmID)) } dialog.LLMID = llmID - if chatModelConfig != nil { - dialog.LLMSetting = chatModelConfig + dialog.LLMSetting = genConfig + } else if dialog.LLMID == "" { + tenant, err := dao.NewTenantDAO().GetByID(dialog.TenantID) + if err != nil || tenant.LLMID == "" { + return fail(errors.New("No default chat model for tenant.")) + } + dialog.LLMID = tenant.LLMID + if dialog.LLMSetting == nil { + dialog.LLMSetting = entity.JSONMap{} + } + for k, v := range genConfig { + dialog.LLMSetting[k] = v } } - // Perform streaming chat via shared RAG pipeline - kwargs := chatModelConfig if kwargs == nil { kwargs = map[string]interface{}{} } - resultChan, err := s.pipeline.AsyncChat(ctx, dialog, messages, true, kwargs) + for k, v := range genConfig { + kwargs[k] = v + } + + // --- 6. Run pipeline --- + resultChan, err := s.pipeline.AsyncChat(ctx, dialog, requestMsg, stream, kwargs) + if err != nil { + return fail(err) + } + + if stream && streamChan != nil { + var fullAnswer strings.Builder + var finalLegacyAnswer map[string]interface{} + + for result := range resultChan { + if result.Reference != nil && len(reference) > 0 { + reference[len(reference)-1] = result.Reference + } + + if legacy { + if result.Final { + if strings.Contains(result.Answer, "**ERROR**") { + ans := s.structureAnswer(session, result.Answer, messageID, sessionID, reference) + if chatID != "" { + ans["chat_id"] = chatID + } + sendOrCancel(fmt.Sprintf("data:%s\n\n", sseMarshalChunk(sanitizeJSONFloats(ans).(map[string]interface{}), chatID))) + } + finalLegacyAnswer = s.structureAnswer(session, result.Answer, messageID, sessionID, reference) + continue + } + if result.StartToThink { + fullAnswer.WriteString("") + } else if result.EndToThink { + fullAnswer.WriteString("") + } else if result.Answer != "" { + fullAnswer.WriteString(result.Answer) + } + if session != nil { + s.appendAssistantToSession(session, fullAnswer.String(), messageID) + } + ans := s.structureAnswer(session, fullAnswer.String(), messageID, sessionID, reference) + ans["start_to_think"] = nil + ans["end_to_think"] = nil + delete(ans, "start_to_think") + delete(ans, "end_to_think") + if chatID != "" { + ans["chat_id"] = chatID + } + sendOrCancel(fmt.Sprintf("data:%s\n\n", sseMarshalChunk(sanitizeJSONFloats(ans).(map[string]interface{}), chatID))) + } else { + if result.Final { + if strings.Contains(result.Answer, "**ERROR**") { + ans := s.structureAnswer(session, result.Answer, messageID, sessionID, reference) + if chatID != "" { + ans["chat_id"] = chatID + } + sendOrCancel(fmt.Sprintf("data:%s\n\n", sseMarshalChunk(sanitizeJSONFloats(ans).(map[string]interface{}), chatID))) + } + continue + } + if result.StartToThink { + fullAnswer.WriteString("") + } else if result.EndToThink { + fullAnswer.WriteString("") + } else if result.Answer != "" { + fullAnswer.WriteString(result.Answer) + } + if session != nil { + s.appendAssistantToSession(session, fullAnswer.String(), messageID) + } + ans := s.structureAnswer(session, result.Answer, messageID, sessionID, reference) + if chatID != "" { + ans["chat_id"] = chatID + } + sendOrCancel(fmt.Sprintf("data:%s\n\n", sseMarshalChunk(sanitizeJSONFloats(ans).(map[string]interface{}), chatID))) + } + } + if legacy && finalLegacyAnswer != nil { + finalLegacyAnswer["answer"] = fullAnswer.String() + delete(finalLegacyAnswer, "start_to_think") + delete(finalLegacyAnswer, "end_to_think") + finalChunk := sseWrapper{Code: 0, Message: "", Data: sanitizeJSONFloats(finalLegacyAnswer)} + sendOrCancel(fmt.Sprintf("data:%s\n\n", marshalJSONWithSpaces(finalChunk))) + } + + wrapper := sseWrapper{Code: 0, Message: "", Data: true} + sendOrCancel(fmt.Sprintf("data:%s\n\n", marshalJSONWithSpaces(wrapper))) + + // Persist session state (matches Python's update_by_id after loop) + if session != nil { + s.updateSessionMessages(session, s.getSessionMessagesAsSlice(session), reference) + } + } else { + var answer strings.Builder + var finalRef map[string]interface{} + for result := range resultChan { + if result.Answer != "" { + answer.WriteString(result.Answer) + } + if result.Reference != nil { + finalRef = result.Reference + } + } + ans := map[string]interface{}{ + "answer": answer.String(), + "reference": finalRef, + "final": true, + } + if session != nil { + result := s.structureAnswerWithConv(session, ans, messageID, sessionID, reference) + if chatID != "" { + result["chat_id"] = chatID + } + s.updateSessionMessages(session, s.getSessionMessagesAsSlice(session), reference) + return sanitizeJSONFloats(result).(map[string]interface{}), nil + } + ans["id"] = messageID + ans["session_id"] = sessionID + if chatID != "" { + ans["chat_id"] = chatID + } + return sanitizeJSONFloats(ans).(map[string]interface{}), nil + } + + return nil, nil +} + +// --- Helpers for ChatCompletions --- + +// normalizeCompletionMessages mirrors Python _normalize_completion_messages. +func (s *ChatSessionService) normalizeCompletionMessages( + messages []map[string]interface{}, question string, files []interface{}, +) (requestMessages []map[string]interface{}, requestMsg []map[string]interface{}, messageID string, err error) { + if len(messages) == 0 { + if question == "" { + return nil, nil, "", errors.New("required argument are missing: messages") + } + messages = []map[string]interface{}{{"role": "user", "content": question}} + if len(files) > 0 { + messages[0]["files"] = files + } + } + + requestMessages = make([]map[string]interface{}, len(messages)) + for i, m := range messages { + requestMessages[i] = make(map[string]interface{}) + for k, v := range m { + requestMessages[i][k] = v + } + } + + // Filter system and leading assistant messages + requestMsg = make([]map[string]interface{}, 0, len(messages)) + for _, m := range messages { + role, _ := m["role"].(string) + if role == "system" { + continue + } + if role == "assistant" && len(requestMsg) == 0 { + continue + } + requestMsg = append(requestMsg, m) + } + + if len(requestMsg) == 0 { + return nil, nil, "", errors.New("`messages` must contain a user message.") + } + lastRole, _ := requestMsg[len(requestMsg)-1]["role"].(string) + if lastRole != "user" { + return nil, nil, "", errors.New("The last content of this conversation is not from user.") + } + + // Generate message ID if missing — matches Python's get_uuid() in _normalize_completion_messages. + lastUserMsg := requestMsg[len(requestMsg)-1] + if id, ok := lastUserMsg["id"].(string); ok && id != "" { + messageID = id + } else { + messageID = strings.ReplaceAll(uuid.New().String(), "-", "") + lastUserMsg["id"] = messageID + for i := len(requestMessages) - 1; i >= 0; i-- { + if role, _ := requestMessages[i]["role"].(string); role == "user" { + requestMessages[i]["id"] = messageID + break + } + } + } + return requestMessages, requestMsg, messageID, nil +} + +// checkDialogOwnership checks if the user owns the dialog. +func (s *ChatSessionService) checkDialogOwnership(userID, chatID string) error { + ok, err := s.ensureOwnedChat(userID, chatID) if err != nil { - streamChan <- fmt.Sprintf("data: %s\n\n", fmt.Sprintf(`{"code": 500, "message": "%s", "data": {"answer": "**ERROR**: %s", "reference": []}}`, err.Error(), err.Error())) return err } - - // Stream results, accumulating the answer - var fullAnswer strings.Builder - for result := range resultChan { - if result.Reference != nil && len(reference) > 0 { - reference[len(reference)-1] = result.Reference - } - if result.Final { - if result.Answer != "" { - fullAnswer.Reset() - fullAnswer.WriteString(result.Answer) - } - } else if result.Answer != "" { - fullAnswer.WriteString(result.Answer) - } - ans := s.structureAnswer(session, fullAnswer.String(), messageID, session.ID, reference) - data, _ := json.Marshal(map[string]interface{}{ - "code": 0, - "message": "", - "data": ans, - }) - streamChan <- fmt.Sprintf("data: %s\n\n", string(data)) + if !ok { + return errors.New("No authorization.") } - - // Send final completion signal - finalData, _ := json.Marshal(map[string]interface{}{ - "code": 0, - "message": "", - "data": true, - }) - streamChan <- fmt.Sprintf("data: %s\n\n", string(finalData)) - - // Update conversation if not embedded - if !isEmbedded { - sessionMessages = append(sessionMessages, map[string]interface{}{ - "role": "assistant", - "content": fullAnswer.String(), - "id": messageID, - "created_at": float64(time.Now().Unix()), - }) - s.updateSessionMessages(session, sessionMessages, reference) - } - return nil } +// buildDefaultCompletionDialog mirrors Python _build_default_completion_dialog. +func (s *ChatSessionService) buildDefaultCompletionDialog(tenantID string) *entity.Chat { + return &entity.Chat{ + TenantID: tenantID, + LLMID: "", + LLMSetting: entity.JSONMap{}, + PromptConfig: entity.JSONMap{}, + KBIDs: entity.JSONSlice{}, + TopN: 6, + TopK: 1024, + RerankID: "", + SimilarityThreshold: 0.1, + VectorSimilarityWeight: 0.3, + } +} + +// createSessionForCompletion mirrors Python _create_session_for_completion. +func (s *ChatSessionService) createSessionForCompletion(chatID string, dialog *entity.Chat, userID string) (*entity.ChatSession, error) { + newID := common.GenerateUUID() + name := "New session" + + prologue := "Hi! I'm your assistant. What can I do for you?" + if dialog.PromptConfig != nil { + if p, ok := dialog.PromptConfig["prologue"].(string); ok && p != "" { + prologue = p + } + } + + msg := []map[string]interface{}{ + {"role": "assistant", "content": prologue}, + } + msgJSON, _ := json.Marshal(msg) + refJSON, _ := json.Marshal([]interface{}{}) + + session := &entity.ChatSession{ + ID: newID, + DialogID: chatID, + Name: &name, + Message: msgJSON, + UserID: &userID, + Reference: refJSON, + } + if err := s.chatSessionDAO.Create(session); err != nil { + return nil, err + } + return session, nil +} + +// appendSessionMessage appends the last user message to the session's message history. +func (s *ChatSessionService) appendSessionMessage(session *entity.ChatSession, requestMsg []map[string]interface{}) *entity.ChatSession { + msgs := parseMessages(session.Message) + msgs = append(msgs, requestMsg[len(requestMsg)-1]) + session.Message, _ = json.Marshal(msgs) + return session +} + +// filterSystemAndLeadingAssistant filters system messages and leading assistant messages from session history. +func (s *ChatSessionService) filterSystemAndLeadingAssistant(session *entity.ChatSession) []map[string]interface{} { + messages := parseMessages(session.Message) + var result []map[string]interface{} + for _, msg := range messages { + role, _ := msg["role"].(string) + if role == "system" { + continue + } + if role == "assistant" && len(result) == 0 { + continue + } + result = append(result, msg) + } + return result +} + +// appendAssistantToSession appends or updates the assistant message in session.Message. +func (s *ChatSessionService) appendAssistantToSession(session *entity.ChatSession, content string, messageID string) { + messages := parseMessages(session.Message) + if len(messages) == 0 || s.getLastRole(messages) != "assistant" { + messages = append(messages, map[string]interface{}{ + "role": "assistant", + "content": content, + "created_at": float64(time.Now().Unix()), + "id": messageID, + }) + } else { + lastIdx := len(messages) - 1 + messages[lastIdx]["content"] = content + messages[lastIdx]["created_at"] = float64(time.Now().Unix()) + messages[lastIdx]["id"] = messageID + } + session.Message, _ = json.Marshal(messages) +} + +// getSessionMessagesAsSlice returns the session's messages as a slice of maps. +func (s *ChatSessionService) getSessionMessagesAsSlice(session *entity.ChatSession) []map[string]interface{} { + if session == nil { + return nil + } + return parseMessages(session.Message) +} + +// sendSSEError sends an error in SSE format through the stream channel. +func (s *ChatSessionService) sendSSEError(streamChan chan<- string, errMsg string) { + wrapper := sseWrapper{ + Code: 500, + Message: errMsg, + Data: map[string]interface{}{ + "answer": "**ERROR**: " + errMsg, + "reference": []interface{}{}, + }, + } + streamChan <- fmt.Sprintf("data:%s\n\n", marshalJSONWithSpaces(wrapper)) +} + // Helper methods func (s *ChatSessionService) buildSessionMessages(session *entity.ChatSession, messages []map[string]interface{}) []map[string]interface{} { @@ -1241,29 +1339,186 @@ func (s *ChatSessionService) initializeReference(session *entity.ChatSession) [] } func (s *ChatSessionService) checkTenantLLMAPIKey(tenantID, modelName string) (bool, error) { - // Simplified check - in real implementation, check if tenant has API key for this model + _, err := NewTenantLLMService().GetAPIKeyFromInstance(tenantID, modelName) + if err != nil { + return false, err + } return true, nil } +// sseAnswerChunk has deterministic JSON field order matching Python's structure_answer output. +type sseAnswerChunk struct { + Answer string `json:"answer"` + Reference map[string]interface{} `json:"reference"` + AudioBinary interface{} `json:"audio_binary"` + Prompt string `json:"prompt"` + CreatedAt float64 `json:"created_at"` + Final bool `json:"final"` + ID string `json:"id"` + SessionID string `json:"session_id"` + ChatID string `json:"chat_id,omitempty"` +} + +// sseWrapper wraps the SSE response with deterministic field order matching Python: +// +// {"code": 0, "message": "", "data": ...} +type sseWrapper struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data"` +} + +// marshalJSONWithSpaces marshals v to JSON and adds spaces after ':' and ',' +// to match Python's json.dumps format. +func marshalJSONWithSpaces(v interface{}) string { + data, err := json.Marshal(v) + if err != nil { + return "{}" + } + return addJSONSpacesOutsideStrings(data) +} + +func addJSONSpacesOutsideStrings(data []byte) string { + var b strings.Builder + b.Grow(len(data) + 16) + inString := false + escaped := false + for _, c := range data { + b.WriteByte(c) + if escaped { + escaped = false + continue + } + if inString && c == '\\' { + escaped = true + continue + } + if c == '"' { + inString = !inString + continue + } + if !inString && (c == ':' || c == ',') { + b.WriteByte(' ') + } + } + return b.String() +} + +// sanitizeJSONFloats recursively replaces NaN/Infinity with nil. +// Matches Python's _sanitize_json_floats in chat_api.py. +func sanitizeJSONFloats(v interface{}) interface{} { + switch val := v.(type) { + case float64: + if math.IsNaN(val) || math.IsInf(val, 0) { + return nil + } + return val + case float32: + if math.IsNaN(float64(val)) || math.IsInf(float64(val), 0) { + return nil + } + return val + case map[string]interface{}: + out := make(map[string]interface{}, len(val)) + for k, vv := range val { + out[k] = sanitizeJSONFloats(vv) + } + return out + case []interface{}: + out := make([]interface{}, len(val)) + for i, vv := range val { + out[i] = sanitizeJSONFloats(vv) + } + return out + case []map[string]interface{}: + out := make([]map[string]interface{}, len(val)) + for i, item := range val { + sanitized, _ := sanitizeJSONFloats(item).(map[string]interface{}) + out[i] = sanitized + } + return out + default: + return v + } +} + +// sseMarshalChunk converts an answer map to the ordered sseAnswerChunk struct +// and marshals it with Python-compatible JSON formatting (spaces, field order). +func sseMarshalChunk(ans map[string]interface{}, chatID string) string { + ref, _ := ans["reference"].(map[string]interface{}) + if ref == nil { + ref = map[string]interface{}{"chunks": []interface{}{}} + } + answer, _ := ans["answer"].(string) + prompt, _ := ans["prompt"].(string) + id, _ := ans["id"].(string) + sessionID, _ := ans["session_id"].(string) + createdAt, _ := ans["created_at"].(float64) + final, _ := ans["final"].(bool) + + chunk := sseAnswerChunk{ + Answer: answer, + Reference: ref, + AudioBinary: ans["audio_binary"], + Prompt: prompt, + CreatedAt: createdAt, + Final: final, + ID: id, + SessionID: sessionID, + ChatID: chatID, + } + wrapper := sseWrapper{Code: 0, Message: "", Data: chunk} + return marshalJSONWithSpaces(wrapper) +} + func (s *ChatSessionService) structureAnswer(session *entity.ChatSession, answer string, messageID, conversationID string, reference []interface{}) map[string]interface{} { + // Match Python's structure_answer output: + // {"answer", "reference": {"chunks": [...]}, "audio_binary": null, "prompt": "", + // "created_at": ..., "final": false, "id": "...", "session_id": "..."} + refMap := map[string]interface{}{ + "chunks": []interface{}{}, + "doc_aggs": []interface{}{}, + } + if len(reference) > 0 { + if latest, ok := reference[len(reference)-1].(map[string]interface{}); ok && latest != nil { + refMap = latest + if _, ok := refMap["chunks"]; !ok { + refMap["chunks"] = []interface{}{} + } + } + } return map[string]interface{}{ - "answer": answer, - "reference": reference, - "conversation_id": conversationID, - "message_id": messageID, + "answer": answer, + "reference": refMap, + "audio_binary": nil, + "prompt": "", + "created_at": float64(time.Now().UnixNano()) / 1e9, + "final": false, + "id": messageID, + "session_id": conversationID, } } func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, messages []map[string]interface{}, reference []interface{}) { - // Update session with new messages and reference - messagesJSON, _ := json.Marshal(messages) - referenceJSON, _ := json.Marshal(reference) + messagesJSON, err := json.Marshal(messages) + if err != nil { + common.Warn("updateSessionMessages: failed to marshal messages", zap.Error(err)) + return + } + referenceJSON, err := json.Marshal(reference) + if err != nil { + common.Warn("updateSessionMessages: failed to marshal reference", zap.Error(err)) + return + } updates := map[string]interface{}{ "message": messagesJSON, "reference": referenceJSON, } - s.chatSessionDAO.UpdateByID(session.ID, updates) + if err := s.chatSessionDAO.UpdateByID(session.ID, updates); err != nil { + common.Warn("updateSessionMessages: DAO update failed", zap.Error(err)) + return + } session.Message = messagesJSON session.Reference = referenceJSON } @@ -1340,22 +1595,43 @@ func (s *ChatSessionService) getLastRole(messages []map[string]interface{}) stri return role } -// chunksFormat formats chunks for reference (simplified version) +// chunksFormat normalizes chunk fields to a canonical schema (matching +// formatChunks in openai_chat.go and Python's chunks_format), returning +// []map[string]interface{} for JSON serialization. func (s *ChatSessionService) chunksFormat(reference map[string]interface{}) []map[string]interface{} { - switch c := reference["chunks"].(type) { - case []map[string]interface{}: - formatted := make([]map[string]interface{}, len(c)) - copy(formatted, c) - return formatted - case []interface{}: - formatted := make([]map[string]interface{}, 0, len(c)) - for _, item := range c { - if m, ok := item.(map[string]interface{}); ok { - formatted = append(formatted, m) + raw, ok := reference["chunks"].([]map[string]interface{}) + if !ok { + // Coerce []interface{} → []map[string]interface{} + if ifaces, ok2 := reference["chunks"].([]interface{}); ok2 { + raw = make([]map[string]interface{}, 0, len(ifaces)) + for _, item := range ifaces { + if m, ok3 := item.(map[string]interface{}); ok3 { + raw = append(raw, m) + } } } - return formatted - default: + } + if len(raw) == 0 { return []map[string]interface{}{} } + out := make([]map[string]interface{}, 0, len(raw)) + for _, chunk := range raw { + out = append(out, map[string]interface{}{ + "id": getValue(chunk, "chunk_id", "id"), + "content": getValue(chunk, "content_with_weight", "content"), + "document_id": getValue(chunk, "doc_id", "document_id"), + "document_name": getValue(chunk, "docnm_kwd", "document_name"), + "dataset_id": getValue(chunk, "kb_id", "dataset_id"), + "image_id": getValue(chunk, "image_id", "img_id"), + "positions": getValue(chunk, "positions", "position_int"), + "url": chunk["url"], + "similarity": sanitizeJSONFloats(chunk["similarity"]), + "vector_similarity": sanitizeJSONFloats(chunk["vector_similarity"]), + "term_similarity": sanitizeJSONFloats(chunk["term_similarity"]), + "row_id": chunk["row_id"], + "doc_type": getValue(chunk, "doc_type_kwd", "doc_type"), + "document_metadata": chunk["document_metadata"], + }) + } + return out } diff --git a/internal/service/chat_session_test.go b/internal/service/chat_session_test.go index 54d7525fbf..24cf4d91bf 100644 --- a/internal/service/chat_session_test.go +++ b/internal/service/chat_session_test.go @@ -26,7 +26,6 @@ type fakeSessionStore struct { getByIDErr error createErr error updateByIDErr error - deleteByIDErr error getDialogErr error // record calls createCalled []*entity.ChatSession @@ -34,7 +33,6 @@ type fakeSessionStore struct { id string updates map[string]interface{} } - deleteByIDIDs []string } func newFakeSessionStore() *fakeSessionStore { @@ -112,12 +110,6 @@ func (f *fakeSessionStore) UpdateByID(id string, updates map[string]interface{}) } func (f *fakeSessionStore) DeleteByID(id string) error { - f.mu.Lock() - defer f.mu.Unlock() - if f.deleteByIDErr != nil { - return f.deleteByIDErr - } - f.deleteByIDIDs = append(f.deleteByIDIDs, id) delete(f.sessions, id) return nil } @@ -179,226 +171,6 @@ func makeResultChan(results ...AsyncChatResult) <-chan AsyncChatResult { return ch } -// =================================================================== -// SetChatSession tests -// =================================================================== - -func TestSetChatSession_CreateNew(t *testing.T) { - store := newFakeSessionStore() - dialog := &entity.Chat{ID: "dialog-1", PromptConfig: entity.JSONMap{"prologue": "Welcome!"}} - store.dialogs["dialog-1"] = dialog - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } - - resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ - DialogID: "dialog-1", - IsNew: true, - }) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.ID == "" { - t.Fatal("expected session ID to be generated") - } - if resp.DialogID != "dialog-1" { - t.Fatalf("expected dialog_id=dialog-1, got %s", resp.DialogID) - } - if len(store.createCalled) != 1 { - t.Fatalf("expected 1 Create call, got %d", len(store.createCalled)) - } - - // Verify prologue is in the message list. - var msgs []map[string]interface{} - if err := json.Unmarshal(store.createCalled[0].Message, &msgs); err != nil { - t.Fatalf("failed to unmarshal message: %v", err) - } - if len(msgs) != 1 { - t.Fatalf("expected 1 initial message, got %d", len(msgs)) - } - firstMsg := msgs[0] - if firstMsg["role"] != "assistant" || firstMsg["content"] != "Welcome!" { - t.Fatalf("unexpected prologue message: %#v", firstMsg) - } -} - -func TestSetChatSession_CreateNewDefaultPrologue(t *testing.T) { - store := newFakeSessionStore() - store.dialogs["dialog-1"] = &entity.Chat{ID: "dialog-1"} - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } - - resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ - DialogID: "dialog-1", - IsNew: true, - }) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.ID == "" { - t.Fatal("expected session ID") - } - // Default prologue - var msgs []map[string]interface{} - json.Unmarshal(store.createCalled[0].Message, &msgs) - firstMsg := msgs[0] - if !strings.Contains(firstMsg["content"].(string), "Hi! I'm your assistant") { - t.Fatalf("expected default prologue, got %q", firstMsg["content"]) - } -} - -func TestSetChatSession_CreateNewDialogNotFound(t *testing.T) { - store := newFakeSessionStore() - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } - - _, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ - DialogID: "nonexistent", - IsNew: true, - }) - if err == nil || err.Error() != "Dialog not found" { - t.Fatalf("expected 'Dialog not found' error, got %v", err) - } -} - -func TestSetChatSession_UpdateExisting(t *testing.T) { - store := newFakeSessionStore() - store.sessions["session-1"] = &entity.ChatSession{ - ID: "session-1", DialogID: "dialog-1", Name: strPtr("old name"), - } - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } - - resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ - SessionID: "session-1", - Name: "new name", - IsNew: false, - }) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.ID != "session-1" { - t.Fatalf("expected session-1, got %s", resp.ID) - } - if len(store.updateCalled) != 1 { - t.Fatalf("expected UpdateByID call, got %d", len(store.updateCalled)) - } -} - -func TestSetChatSession_UpdateNotFound(t *testing.T) { - store := newFakeSessionStore() - store.updateByIDErr = errors.New("Chat session not found") - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } - - _, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ - SessionID: "missing", - IsNew: false, - }) - if err == nil || err.Error() != "Chat session not found" { - t.Fatalf("expected 'Chat session not found' error, got %v", err) - } -} - -func TestSetChatSession_NameTruncation(t *testing.T) { - store := newFakeSessionStore() - store.dialogs["dialog-1"] = &entity.Chat{ID: "dialog-1"} - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } - - longName := strings.Repeat("x", 300) - resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ - DialogID: "dialog-1", - Name: longName, - IsNew: true, - }) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.Name == nil || len(*resp.Name) > 255 { - t.Fatalf("expected name truncated to <=255, got len=%d", len(*resp.Name)) - } -} - -// =================================================================== -// RemoveChatSessions tests -// =================================================================== - -func TestRemoveChatSessions_Success(t *testing.T) { - store := newFakeSessionStore() - store.sessions["conv-1"] = &entity.ChatSession{ID: "conv-1", DialogID: "dialog-1"} - store.sessions["conv-2"] = &entity.ChatSession{ID: "conv-2", DialogID: "dialog-1"} - store.dialogExists["user-1|dialog-1"] = true - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}}, - pipeline: &fakePipeline{}, - } - - err := svc.RemoveChatSessions("user-1", []string{"conv-1", "conv-2"}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(store.deleteByIDIDs) != 2 { - t.Fatalf("expected 2 deletes, got %d", len(store.deleteByIDIDs)) - } -} - -func TestRemoveChatSessions_SessionNotFound(t *testing.T) { - store := newFakeSessionStore() - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}}, - pipeline: &fakePipeline{}, - } - - err := svc.RemoveChatSessions("user-1", []string{"missing"}) - if err == nil || !strings.Contains(err.Error(), "not found") { - t.Fatalf("expected 'not found' error, got %v", err) - } -} - -func TestRemoveChatSessions_NotOwner(t *testing.T) { - store := newFakeSessionStore() - store.sessions["conv-1"] = &entity.ChatSession{ID: "conv-1", DialogID: "dialog-1"} - // No tenant matches — dialogExists stays false for all combinations - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-other"}}, - pipeline: &fakePipeline{}, - } - - err := svc.RemoveChatSessions("user-1", []string{"conv-1"}) - if err == nil || !strings.Contains(err.Error(), "Only owner") { - t.Fatalf("expected 'Only owner' error, got %v", err) - } -} - // =================================================================== // ListChatSessions tests // =================================================================== @@ -638,301 +410,6 @@ func TestUpdateSession_NotFound(t *testing.T) { } } -// =================================================================== -// Completion tests -// =================================================================== - -func TestCompletion_Success(t *testing.T) { - store := newFakeSessionStore() - session := &entity.ChatSession{ - ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`), - Reference: json.RawMessage(`[]`), - } - store.sessions["session-1"] = session - store.dialogs["dialog-1"] = &entity.Chat{ - ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", - LLMSetting: entity.JSONMap{}, - } - - pipeline := &fakePipeline{ - resultChan: makeResultChan( - AsyncChatResult{Answer: "Hello", Reference: map[string]interface{}{"chunks": []interface{}{}}}, - AsyncChatResult{Answer: " world", Final: true, Reference: map[string]interface{}{"chunks": []interface{}{}}}, - ), - } - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: pipeline, - } - - result, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ - {"role": "user", "content": "hi"}, - }, "", nil, "msg-1") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - ans, _ := result["answer"].(string) - if ans != "Hello world" { - t.Fatalf("expected answer 'Hello world', got %q", ans) - } - - got := parseMessages(store.sessions["session-1"].Message) - if len(got) != 3 { - t.Fatalf("stored messages=%#v", got) - } - if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" { - t.Fatalf("stored prologue=%#v", got[0]) - } - if got[1]["role"] != "user" || got[1]["content"] != "hi" { - t.Fatalf("stored user message=%#v", got[1]) - } - if got[2]["role"] != "assistant" || got[2]["content"] != "Hello world" || got[2]["id"] != "msg-1" { - t.Fatalf("stored assistant message=%#v", got[2]) - } -} - -func TestCompletion_EmptyMessages(t *testing.T) { - svc := &ChatSessionService{ - chatSessionDAO: &fakeSessionStore{}, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } - - _, err := svc.Completion("user-1", "session-1", nil, "", nil, "msg-1") - if err == nil || err.Error() != "messages cannot be empty" { - t.Fatalf("expected 'messages cannot be empty', got %v", err) - } -} - -func TestCompletion_LastMessageNotFromUser(t *testing.T) { - svc := &ChatSessionService{ - chatSessionDAO: &fakeSessionStore{}, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } - - _, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ - {"role": "assistant", "content": "hello"}, - }, "", nil, "msg-1") - if err == nil || !strings.Contains(err.Error(), "not from user") { - t.Fatalf("expected 'not from user' error, got %v", err) - } -} - -func TestCompletion_ConversationNotFound(t *testing.T) { - store := newFakeSessionStore() - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } - - _, err := svc.Completion("user-1", "missing", []map[string]interface{}{ - {"role": "user", "content": "hi"}, - }, "", nil, "msg-1") - if err == nil || err.Error() != "Conversation not found" { - t.Fatalf("expected 'Conversation not found', got %v", err) - } -} - -func TestCompletion_DialogNotFound(t *testing.T) { - store := newFakeSessionStore() - store.sessions["session-1"] = &entity.ChatSession{ - ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`[]`), - Reference: json.RawMessage(`[]`), - } - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } - - _, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ - {"role": "user", "content": "hi"}, - }, "", nil, "msg-1") - if err == nil || err.Error() != "Dialog not found" { - t.Fatalf("expected 'Dialog not found', got %v", err) - } -} - -func TestCompletion_PipelineError(t *testing.T) { - store := newFakeSessionStore() - store.sessions["session-1"] = &entity.ChatSession{ - ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`[]`), - Reference: json.RawMessage(`[]`), - } - store.dialogs["dialog-1"] = &entity.Chat{ - ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", - LLMSetting: entity.JSONMap{}, - } - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{err: errors.New("model unavailable")}, - } - - _, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ - {"role": "user", "content": "hi"}, - }, "", nil, "msg-1") - if err == nil || err.Error() != "model unavailable" { - t.Fatalf("expected 'model unavailable' error, got %v", err) - } -} - -// =================================================================== -// CompletionStream tests -// =================================================================== - -func readStreamChan(ch <-chan string, n int) []string { - var msgs []string - for i := 0; i < n; i++ { - select { - case msg, ok := <-ch: - if !ok { - return msgs - } - msgs = append(msgs, msg) - default: - return msgs - } - } - return msgs -} - -func TestCompletionStream_Success(t *testing.T) { - store := newFakeSessionStore() - store.sessions["session-1"] = &entity.ChatSession{ - ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`{"messages":[{"role":"assistant","content":"Welcome!"}]}`), - Reference: json.RawMessage(`[]`), - } - store.dialogs["dialog-1"] = &entity.Chat{ - ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", - LLMSetting: entity.JSONMap{}, - } - - pipeline := &fakePipeline{ - resultChan: makeResultChan( - AsyncChatResult{Answer: "stream", Reference: map[string]interface{}{"chunks": []interface{}{}}}, - AsyncChatResult{Answer: " answer", Reference: map[string]interface{}{"chunks": []interface{}{}}}, - ), - } - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: pipeline, - } - - streamChan := make(chan string, 10) - err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{ - {"role": "user", "content": "hi"}, - }, "", nil, "msg-1", streamChan) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Should receive data events and final signal - msgs := readStreamChan(streamChan, 5) - if len(msgs) < 3 { - t.Fatalf("expected at least 3 stream messages, got %d: %v", len(msgs), msgs) - } - // Check final signal - finalFound := false - for _, m := range msgs { - if strings.Contains(m, `"data":true`) { - finalFound = true - break - } - } - if !finalFound { - t.Fatal("expected final=true signal in stream") - } - - got := parseMessages(store.sessions["session-1"].Message) - if len(got) != 3 { - t.Fatalf("stored messages=%#v", got) - } - if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" { - t.Fatalf("stored prologue=%#v", got[0]) - } - if got[1]["role"] != "user" || got[1]["content"] != "hi" { - t.Fatalf("stored user message=%#v", got[1]) - } - if got[2]["role"] != "assistant" || got[2]["content"] != "stream answer" || got[2]["id"] != "msg-1" { - t.Fatalf("stored assistant message=%#v", got[2]) - } -} - -func TestStructureAnswerWithConv_ParsesArrayMessages(t *testing.T) { - session := &entity.ChatSession{ - ID: "session-1", - Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`), - } - svc := &ChatSessionService{} - - ans := svc.structureAnswerWithConv(session, map[string]interface{}{ - "answer": "Final answer", - "reference": map[string]interface{}{"chunks": []interface{}{}}, - "final": true, - }, "msg-1", "session-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) - - if ans["id"] != "msg-1" || ans["session_id"] != "session-1" { - t.Fatalf("ans=%#v", ans) - } - - got := parseMessages(session.Message) - if len(got) != 1 { - t.Fatalf("stored messages=%#v", got) - } - if got[0]["role"] != "assistant" || got[0]["content"] != "Final answer" || got[0]["id"] != "msg-1" { - t.Fatalf("stored assistant message=%#v", got[0]) - } -} - -func TestParseMessages_LegacyWrappedObject(t *testing.T) { - got := parseMessages(json.RawMessage(`{"messages":[{"role":"assistant","content":"legacy"}]}`)) - if !reflect.DeepEqual(got, []map[string]interface{}{{"role": "assistant", "content": "legacy"}}) { - t.Fatalf("messages=%#v", got) - } -} - -func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) { - svc := &ChatSessionService{} - payload := svc.buildSessionPayload(&entity.ChatSession{ - ID: "session-1", - DialogID: "chat-1", - Message: nil, - Reference: json.RawMessage(`null`), - }, nil, false) - - if payload.Messages == nil { - t.Fatal("messages is nil") - } - if payload.Reference == nil { - t.Fatal("reference is nil") - } - - body, err := json.Marshal(payload) - if err != nil { - t.Fatalf("Marshal failed: %v", err) - } - if !strings.Contains(string(body), `"messages":[]`) { - t.Fatalf("messages did not encode as empty array: %s", string(body)) - } - if !strings.Contains(string(body), `"reference":[]`) { - t.Fatalf("reference did not encode as empty array: %s", string(body)) - } -} - func TestParseCollections_ReturnEmptySlicesForMissingOrNull(t *testing.T) { messageInputs := []json.RawMessage{ nil, @@ -980,99 +457,174 @@ func TestParseCollections_ReturnNilForMalformedData(t *testing.T) { } } -func TestCompletionStream_EmptyMessages(t *testing.T) { - svc := &ChatSessionService{ - chatSessionDAO: &fakeSessionStore{}, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, +// =================================================================== +// chunksFormat tests — verifies field normalization after the rewrite. +// =================================================================== + +func TestChunksFormat_NormalizesRawFieldNames(t *testing.T) { + svc := &ChatSessionService{} + ref := map[string]interface{}{ + "chunks": []map[string]interface{}{ + { + "chunk_id": "c1", + "content_with_weight": "hello world", + "content_ltks": "hello world ltks", + "doc_id": "d1", + "docnm_kwd": "Document 1", + "kb_id": "kb1", + "image_id": "img1", + "img_id": "img2", + "positions": []int{0, 10}, + "position_int": []int{1, 11}, + "doc_type_kwd": "pdf", + "similarity": 0.95, + "vector_similarity": 0.9, + "term_similarity": 0.85, + "row_id": "r1", + "url": "http://example.com", + "document_metadata": map[string]interface{}{"author": "Alice"}, + }, + }, } - streamChan := make(chan string, 10) - err := svc.CompletionStream(context.Background(), "user-1", "session-1", nil, "", nil, "msg-1", streamChan) - if err == nil || err.Error() != "messages cannot be empty" { - t.Fatalf("expected 'messages cannot be empty', got %v", err) + result := svc.chunksFormat(ref) + if len(result) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(result)) + } + c := result[0] + + if c["id"] != "c1" { + t.Fatalf("id=%v", c["id"]) + } + if c["content"] != "hello world" { + t.Fatalf("content=%v", c["content"]) + } + if c["document_id"] != "d1" { + t.Fatalf("document_id=%v", c["document_id"]) + } + if c["document_name"] != "Document 1" { + t.Fatalf("document_name=%v", c["document_name"]) + } + if c["dataset_id"] != "kb1" { + t.Fatalf("dataset_id=%v", c["dataset_id"]) + } + if c["image_id"] != "img1" { + t.Fatalf("image_id=%v", c["image_id"]) + } + if c["doc_type"] != "pdf" { + t.Fatalf("doc_type=%v", c["doc_type"]) + } + if c["similarity"] != 0.95 { + t.Fatalf("similarity=%v", c["similarity"]) + } + if c["url"] != "http://example.com" { + t.Fatalf("url=%v", c["url"]) + } + + pos, ok := c["positions"].([]int) + if !ok || len(pos) != 2 || pos[0] != 0 { + t.Fatalf("positions=%v (%T)", c["positions"], c["positions"]) + } + + // Raw keys must be normalized away. + if _, exists := c["content_with_weight"]; exists { + t.Fatal("content_with_weight should not be present after normalization") + } + if _, exists := c["content_ltks"]; exists { + t.Fatal("content_ltks should not be present after normalization") } } -func TestCompletionStream_LastMessageNotFromUser(t *testing.T) { - svc := &ChatSessionService{ - chatSessionDAO: &fakeSessionStore{}, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, +func TestChunksFormat_PreservesAlreadyNormalizedFields(t *testing.T) { + svc := &ChatSessionService{} + ref := map[string]interface{}{ + "chunks": []map[string]interface{}{ + { + "id": "c2", + "content": "already normalized", + "document_id": "d2", + "document_name": "Doc 2", + }, + }, } - streamChan := make(chan string, 10) - err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{ - {"role": "assistant", "content": "hello"}, - }, "", nil, "msg-1", streamChan) - if err == nil || !strings.Contains(err.Error(), "not from user") { - t.Fatalf("expected 'not from user' error, got %v", err) + result := svc.chunksFormat(ref) + if len(result) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(result)) + } + c := result[0] + if c["id"] != "c2" { + t.Fatalf("id=%v", c["id"]) + } + if c["content"] != "already normalized" { + t.Fatalf("content=%v", c["content"]) } } -func TestCompletionStream_ConversationNotFound(t *testing.T) { - store := newFakeSessionStore() - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, - } +func TestChunksFormat_EmptyReference(t *testing.T) { + svc := &ChatSessionService{} - streamChan := make(chan string, 10) - err := svc.CompletionStream(context.Background(), "user-1", "missing", []map[string]interface{}{ - {"role": "user", "content": "hi"}, - }, "", nil, "msg-1", streamChan) - if err == nil || err.Error() != "Conversation not found" { - t.Fatalf("expected 'Conversation not found', got %v", err) + if n := len(svc.chunksFormat(nil)); n != 0 { + t.Fatalf("nil ref: expected 0, got %d", n) + } + if n := len(svc.chunksFormat(map[string]interface{}{})); n != 0 { + t.Fatalf("empty ref: expected 0, got %d", n) + } + if n := len(svc.chunksFormat(map[string]interface{}{"chunks": nil})); n != 0 { + t.Fatalf("nil chunks: expected 0, got %d", n) + } + if n := len(svc.chunksFormat(map[string]interface{}{"chunks": []map[string]interface{}{}})); n != 0 { + t.Fatalf("empty chunks: expected 0, got %d", n) } } -func TestCompletionStream_DialogNotFound(t *testing.T) { - store := newFakeSessionStore() - store.sessions["session-1"] = &entity.ChatSession{ - ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`[]`), - Reference: json.RawMessage(`[]`), +func TestChunksFormat_ChunksAsInterfaceSlice(t *testing.T) { + svc := &ChatSessionService{} + ref := map[string]interface{}{ + "chunks": []interface{}{ + map[string]interface{}{ + "chunk_id": "c3", + "content_with_weight": "from interface slice", + }, + }, } - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{}, + result := svc.chunksFormat(ref) + if len(result) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(result)) } - - streamChan := make(chan string, 10) - err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{ - {"role": "user", "content": "hi"}, - }, "", nil, "msg-1", streamChan) - if err == nil || err.Error() != "Dialog not found" { - t.Fatalf("expected 'Dialog not found', got %v", err) + if result[0]["id"] != "c3" { + t.Fatalf("id=%v", result[0]["id"]) + } + if result[0]["content"] != "from interface slice" { + t.Fatalf("content=%v", result[0]["content"]) } } -func TestCompletionStream_PipelineError(t *testing.T) { - store := newFakeSessionStore() - store.sessions["session-1"] = &entity.ChatSession{ - ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`[]`), - Reference: json.RawMessage(`[]`), +func TestChunksFormat_IgnoresNonMapItems(t *testing.T) { + svc := &ChatSessionService{} + ref := map[string]interface{}{ + "chunks": []interface{}{ + "not a map", + map[string]interface{}{ + "chunk_id": "c4", + "content_with_weight": "valid chunk", + }, + }, } - store.dialogs["dialog-1"] = &entity.Chat{ - ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", - LLMSetting: entity.JSONMap{}, + result := svc.chunksFormat(ref) + if len(result) != 1 { + t.Fatalf("expected 1 chunk (non-maps skipped), got %d", len(result)) } - - svc := &ChatSessionService{ - chatSessionDAO: store, - userTenantDAO: &fakeTenantStore{}, - pipeline: &fakePipeline{err: errors.New("model unavailable")}, - } - - streamChan := make(chan string, 10) - err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{ - {"role": "user", "content": "hi"}, - }, "", nil, "msg-1", streamChan) - if err == nil || err.Error() != "model unavailable" { - t.Fatalf("expected 'model unavailable' error, got %v", err) + if result[0]["id"] != "c4" { + t.Fatalf("id=%v", result[0]["id"]) + } +} + +func TestChunksFormat_UnsupportedTypeReturnsEmpty(t *testing.T) { + svc := &ChatSessionService{} + ref := map[string]interface{}{"chunks": "not a slice"} + result := svc.chunksFormat(ref) + if len(result) != 0 { + t.Fatalf("expected empty for string type, got %d", len(result)) } } diff --git a/internal/service/dataset.go b/internal/service/dataset.go index c680c5fd07..041c092f2e 100644 --- a/internal/service/dataset.go +++ b/internal/service/dataset.go @@ -1980,7 +1980,7 @@ func (s *DatasetService) SearchDatasets(req *SearchDatasetsRequest, userID strin // Apply meta_data_filter to get filtered doc_ids docIDs := make([]string, len(req.DocIDs)) copy(docIDs, req.DocIDs) - if metadataFilter != nil { + if len(metadataFilter) > 0 { metadataSvc := NewMetadataService() flattedMeta, err := metadataSvc.GetFlattedMetaByKBs(datasetIDs) if err != nil { diff --git a/internal/service/openai_chat.go b/internal/service/openai_chat.go index c076e770a0..f298277924 100644 --- a/internal/service/openai_chat.go +++ b/internal/service/openai_chat.go @@ -644,16 +644,16 @@ func formatChunks(chunks []map[string]interface{}) []FormattedChunk { for _, chunk := range chunks { out = append(out, FormattedChunk{ ID: strVal(getValue(chunk, "chunk_id", "id")), - Content: strVal(getValue(chunk, "content", "content_with_weight")), + Content: strVal(getValue(chunk, "content_with_weight", "content")), DocumentID: strVal(getValue(chunk, "doc_id", "document_id")), DocumentName: strVal(getValue(chunk, "docnm_kwd", "document_name")), DatasetID: strVal(getValue(chunk, "kb_id", "dataset_id")), ImageID: strVal(getValue(chunk, "image_id", "img_id")), Positions: getValue(chunk, "positions", "position_int"), URL: chunk["url"], - Similarity: chunk["similarity"], - VectorSimilarity: chunk["vector_similarity"], - TermSimilarity: chunk["term_similarity"], + Similarity: sanitizeJSONFloats(chunk["similarity"]), + VectorSimilarity: sanitizeJSONFloats(chunk["vector_similarity"]), + TermSimilarity: sanitizeJSONFloats(chunk["term_similarity"]), RowID: chunk["row_id"], DocType: getValue(chunk, "doc_type_kwd", "doc_type"), DocumentMetadata: chunk["document_metadata"], diff --git a/internal/service/think_tag.go b/internal/service/think_tag.go index 93e04271d6..582238f66d 100644 --- a/internal/service/think_tag.go +++ b/internal/service/think_tag.go @@ -18,40 +18,41 @@ package service import ( "context" + "ragflow/internal/tokenizer" "strings" ) const thinkOpen = "" const thinkClose = "" +var stripThinkReplacer = strings.NewReplacer("", "", "", "") + // ThinkStreamState holds accumulated state across streaming LLM chunks -// so that ... tags can be surfaced as structured markers. -// -// Corresponds to _ThinkStreamState in api/db/services/dialog_service.py. +// so that ... tags can be surfaced as structured markers type ThinkStreamState struct { // fullText accumulates all text received so far. fullText string - // lastIdx is the last consumed position in fullText. - lastIdx int - // lastFull is the previous fullText snapshot. - lastFull string - // lastModelFull is the previous model chunk for diffing. + // lastModelFull is the previous model-full snapshot for diffing lastModelFull string // inThink is true when we are currently inside a block. inThink bool - // buffer accumulates visible text before flushing (for batching). - buffer string - // postThinkText holds text between and the next or end - // of delta. Kept for API alignment with Python; may be used by future - // callers that need per-delta visibility into think boundaries. - postThinkText string + // closePending defers emission of when no visible text follows the tag + closePending bool + // pendingAfterClose collects text received after a deferred + pendingAfterClose string + // thinkBuffer is the think-buffer + thinkBuffer string + // answerBuffer accumulates answer-side text before token-batch flushing + answerBuffer string + // carry holds text at the end of a chunk that may be a partial or prefix. + carry string } // ThinkDeltaKind describes the type of a think-tag delta event. type ThinkDeltaKind int const ( - ThinkDeltaText ThinkDeltaKind = iota // visible answer text + ThinkDeltaText ThinkDeltaKind = iota // think-side or answer-side text ThinkDeltaMarker // or tag boundary ) @@ -61,93 +62,239 @@ type ThinkDelta struct { Value string } +// emitText returns the batched text and its kind. +func emitText(state *ThinkStreamState, section string, text string, minTokens int) (string, ThinkDeltaKind) { + if text == "" { + return "", 0 + } + if section == "think" { + state.thinkBuffer += text + if tokenizer.NumTokensFromString(state.thinkBuffer) >= minTokens { + out := state.thinkBuffer + state.thinkBuffer = "" + return out, ThinkDeltaText + } + return "", 0 + } + state.answerBuffer += text + if tokenizer.NumTokensFromString(state.answerBuffer) >= minTokens { + out := state.answerBuffer + state.answerBuffer = "" + return out, ThinkDeltaText + } + return "", 0 +} + +func flushThinkBufferInternal(state *ThinkStreamState) ThinkDelta { + if state.thinkBuffer == "" { + return ThinkDelta{} + } + out := state.thinkBuffer + state.thinkBuffer = "" + return ThinkDelta{Kind: ThinkDeltaText, Value: out} +} + +func flushAnswerBufferInternal(state *ThinkStreamState) ThinkDelta { + if state.answerBuffer == "" { + return ThinkDelta{} + } + out := state.answerBuffer + state.answerBuffer = "" + return ThinkDelta{Kind: ThinkDeltaText, Value: out} +} + +func stripThinkTags(s string) string { + if s == "" { + return "" + } + return stripThinkReplacer.Replace(s) +} + +// tagPrefixLen returns the length of the longest suffix of s that could be a +// PARTIAL start of "" or "". Returns 0 if the suffix is a complete +// tag or no prefix match exists. +func tagPrefixLen(s string) int { + for i := 0; i < len(s); i++ { + sub := s[i:] + if sub == thinkOpen || sub == thinkClose { + return 0 // complete tag, not a partial prefix + } + if strings.HasPrefix(thinkOpen, sub) || strings.HasPrefix(thinkClose, sub) { + return len(sub) + } + } + return 0 +} + // NextThinkDelta processes the next chunk of LLM output and returns any // visible text or tag boundary markers that should be emitted. -// -// Pure function — no side effects beyond updating state. -func NextThinkDelta(state *ThinkStreamState, chunk string) []ThinkDelta { +func NextThinkDelta(state *ThinkStreamState, chunk string, minTokens int) []ThinkDelta { if state == nil { return nil } - if state.lastFull != "" { - // Compute the delta: what's new since lastFull. - delta := strings.TrimPrefix(chunk, state.lastFull) - state.lastModelFull = delta - } else { + var newPart string + if strings.HasPrefix(chunk, state.lastModelFull) { + newPart = chunk[len(state.lastModelFull):] state.lastModelFull = chunk + } else { + newPart = chunk + state.lastModelFull += chunk } - state.lastFull = chunk - - // Accumulate fullText from the delta. - state.fullText += state.lastModelFull - - // Extract new content since lastIdx. - newPart := state.fullText[state.lastIdx:] - if len(newPart) == 0 { + if newPart == "" { return nil } + state.fullText += newPart + + // Prepend carry from previous chunk that may complete a partial tag. + pending := state.carry + newPart + state.carry = "" + + // Check if pending ends with a partial or prefix. + // Save it as carry so it isn't emitted as visible text. + if n := tagPrefixLen(pending); n > 0 { + state.carry = pending[len(pending)-n:] + pending = pending[:len(pending)-n] + } var deltas []ThinkDelta - // Process character by character to detect tag boundaries. - for len(newPart) > 0 { - if !state.inThink { - idx := strings.Index(newPart, thinkOpen) - if idx < 0 { - // No more think open — buffer everything as visible text. - state.buffer += newPart - state.lastIdx += len(newPart) - break - } - // Text before is visible answer. - if idx > 0 { - state.buffer += newPart[:idx] - } - deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkOpen}) - newPart = newPart[idx+len(thinkOpen):] - state.lastIdx += idx + len(thinkOpen) - state.inThink = true - } else { - idx := strings.Index(newPart, thinkClose) - if idx < 0 { - // Still inside think, consume all silently. - state.lastIdx += len(newPart) - break - } - deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose}) - state.postThinkText = newPart[:idx] - newPart = newPart[idx+len(thinkClose):] - state.lastIdx += idx + len(thinkClose) - state.inThink = false + + // Phase 1: handle deferred from a previous chunk. + if state.closePending { + state.closePending = false + if piece := flushThinkBufferInternal(state); piece.Value != "" { + deltas = append(deltas, piece) } + state.inThink = false + deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose}) + if state.pendingAfterClose != "" { + pending = state.pendingAfterClose + pending + state.pendingAfterClose = "" + } + } + + // Phase 2: process pending text for think tags. + for pending != "" { + openIdx := strings.Index(pending, thinkOpen) + closeIdx := strings.Index(pending, thinkClose) + + // No tags remaining — emit to the appropriate section. + if openIdx == -1 && closeIdx == -1 { + if piece := stripThinkTags(pending); piece != "" { + section := "answer" + if state.inThink { + section = "think" + } + if out, kind := emitText(state, section, piece, minTokens); out != "" { + deltas = append(deltas, ThinkDelta{Kind: kind, Value: out}) + } + } + break + } + + // appears first (or no found). + if openIdx != -1 && (closeIdx == -1 || openIdx < closeIdx) { + before := pending[:openIdx] + if before != "" { + piece := stripThinkTags(before) + section := "answer" + if state.inThink { + section = "think" + } + if out, kind := emitText(state, section, piece, minTokens); out != "" { + deltas = append(deltas, ThinkDelta{Kind: kind, Value: out}) + } + } + pending = pending[openIdx+len(thinkOpen):] + if !state.inThink { + if answerPiece := flushAnswerBufferInternal(state); answerPiece.Value != "" { + deltas = append(deltas, answerPiece) + } + if thinkPiece := flushThinkBufferInternal(state); thinkPiece.Value != "" { + deltas = append(deltas, thinkPiece) + } + state.inThink = true + deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkOpen}) + } + continue + } + + // appears first. + before := pending[:closeIdx] + after := pending[closeIdx+len(thinkClose):] + if before != "" { + piece := stripThinkTags(before) + section := "answer" + if state.inThink { + section = "think" + } + if out, kind := emitText(state, section, piece, minTokens); out != "" { + deltas = append(deltas, ThinkDelta{Kind: kind, Value: out}) + } + } + if strings.TrimSpace(after) != "" { + if thinkPiece := flushThinkBufferInternal(state); thinkPiece.Value != "" { + deltas = append(deltas, thinkPiece) + } + state.inThink = false + deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose}) + pending = after + continue + } + // No visible text after close — defer the marker. + state.closePending = true + if after != "" { + state.pendingAfterClose += after + } + pending = "" + break } return deltas } -// FlushThinkBuffer drains the buffered visible text, if any, as a single delta. -// Call this after all LLM chunks have been processed. -func FlushThinkBuffer(state *ThinkStreamState) []ThinkDelta { - if state == nil || state.buffer == "" { +// FlushRemaining drains all remaining buffered text and handles deferred +// markers. Call this after all LLM chunks have been processed. +func FlushRemaining(state *ThinkStreamState) []ThinkDelta { + if state == nil { return nil } - text := state.buffer - state.buffer = "" - return []ThinkDelta{{Kind: ThinkDeltaText, Value: text}} + var deltas []ThinkDelta + if state.thinkBuffer != "" { + deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.thinkBuffer}) + state.thinkBuffer = "" + } + if state.closePending { + state.inThink = false + state.closePending = false + deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose}) + } + if state.answerBuffer != "" { + deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.answerBuffer}) + state.answerBuffer = "" + } + if state.pendingAfterClose != "" { + deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.pendingAfterClose}) + state.pendingAfterClose = "" + } + if state.carry != "" { + deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.carry}) + state.carry = "" + } + return deltas } +// StreamThinkTagDelta — channel-based pipeline. +// --------------------------------------------------------------------------- + // StreamThinkTagDelta takes a channel of raw LLM text chunks and produces a -// channel of (kind, value) pairs. When ctx is cancelled (e.g. client -// disconnect), the goroutine drains the input channel silently and exits, -// preventing the producer goroutine from blocking forever on send. -// -// Markers (, ) are emitted immediately without buffering. +// channel of structured deltas. When ctx is cancelled (e.g. client +// disconnect), the goroutine drains the input channel silently and exits. func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens int) <-chan ThinkDelta { out := make(chan ThinkDelta, 32) go func() { defer close(out) state := &ThinkStreamState{} - flushSize := minTokens * 4 // approximate: ~4 bytes per token for { select { case <-ctx.Done(): @@ -158,7 +305,7 @@ func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens in return case chunk, ok := <-chunks: if !ok { - for _, d := range FlushThinkBuffer(state) { + for _, d := range FlushRemaining(state) { select { case out <- d: case <-ctx.Done(): @@ -167,34 +314,16 @@ func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens in } return } - deltas := NextThinkDelta(state, chunk) + deltas := NextThinkDelta(state, chunk, minTokens) for _, d := range deltas { - if d.Kind == ThinkDeltaMarker { - select { - case out <- d: - case <-ctx.Done(): - go func() { - for range chunks { - } - }() - return - } - } - } - // Flush buffered visible text when it reaches the token threshold, - // matching Python _stream_with_think_delta which yields ("text", ...) - // per chunk. Markers are emitted immediately above. - if len(state.buffer) >= flushSize { - for _, d := range FlushThinkBuffer(state) { - select { - case out <- d: - case <-ctx.Done(): - go func() { - for range chunks { - } - }() - return - } + select { + case out <- d: + case <-ctx.Done(): + go func() { + for range chunks { + } + }() + return } } } @@ -203,47 +332,46 @@ func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens in return out } -// ExtractVisibleAnswer strips blocks from the raw LLM response, -// returning only the visible answer text. If the response consists -// entirely of think content, returns an empty string. -// -// Corresponds to _extract_visible_answer in dialog_service.py. +// ExtractVisibleAnswer returns the visible answer text after the last . +// Stray / tags are stripped. If there is no , all tags are stripped. func ExtractVisibleAnswer(raw string) string { if raw == "" { return "" } - // Collect all non-think text. - var visible []string - remaining := raw - hasThink := false - - for { - openIdx := strings.Index(remaining, thinkOpen) - if openIdx < 0 { - // No more think open tags — strip any stray and keep the rest. - remaining = strings.ReplaceAll(remaining, thinkClose, "") - visible = append(visible, remaining) - break - } - hasThink = true - if openIdx > 0 { - visible = append(visible, remaining[:openIdx]) - } - remaining = remaining[openIdx+len(thinkOpen):] - - closeIdx := strings.Index(remaining, thinkClose) - if closeIdx < 0 { - // Unclosed think — treat rest as visible. - visible = append(visible, remaining) - break - } - remaining = remaining[closeIdx+len(thinkClose):] + if !strings.Contains(raw, thinkClose) { + return stripThinkTags(raw) } - result := strings.TrimSpace(strings.Join(visible, "")) - if hasThink && result == "" { - // Only think content — return empty. - return "" - } - return result + lastClose := strings.LastIndex(raw, thinkClose) + answer := raw[lastClose+len(thinkClose):] + return stripThinkTags(answer) +} + +// BufferAnswerDelta processes an answer delta through the think-state lifecycle. +// When closePending is true, it first flushes the deferred think buffer and +// marker, then processes pendingAfterClose + the new answer text. +func BufferAnswerDelta(state *ThinkStreamState, text string, minTokens int) []ThinkDelta { + if state == nil || text == "" { + return nil + } + state.fullText += text + + var deltas []ThinkDelta + if state.closePending { + state.closePending = false + if piece := flushThinkBufferInternal(state); piece.Value != "" { + deltas = append(deltas, piece) + } + state.inThink = false + deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose}) + if state.pendingAfterClose != "" { + text = state.pendingAfterClose + text + state.pendingAfterClose = "" + } + } + + if out, kind := emitText(state, "answer", text, minTokens); out != "" { + deltas = append(deltas, ThinkDelta{Kind: kind, Value: out}) + } + return deltas } diff --git a/internal/service/think_tag_test.go b/internal/service/think_tag_test.go index e19ac0d7a6..e87008b58a 100644 --- a/internal/service/think_tag_test.go +++ b/internal/service/think_tag_test.go @@ -24,117 +24,130 @@ import ( func TestNextThinkDelta_NoThinkTag(t *testing.T) { state := &ThinkStreamState{} - deltas := NextThinkDelta(state, "hello world") - if len(deltas) != 0 { - t.Fatalf("expected 0 deltas, got %d", len(deltas)) + deltas := NextThinkDelta(state, "hello world", 0) + if len(deltas) != 1 { + t.Fatalf("expected 1 delta, got %d: %+v", len(deltas), deltas) } - if state.buffer != "hello world" { - t.Errorf("buffer = %q", state.buffer) + if deltas[0].Kind != ThinkDeltaText || deltas[0].Value != "hello world" { + t.Errorf("expected text delta, got %+v", deltas[0]) } } func TestNextThinkDelta_OnlyThinkTag(t *testing.T) { state := &ThinkStreamState{} - deltas := NextThinkDelta(state, "reasoningvisible") - if len(deltas) != 2 { - t.Fatalf("expected 2 deltas, got %d: %+v", len(deltas), deltas) + deltas := NextThinkDelta(state, "reasoningvisible", 0) + if len(deltas) < 2 { + t.Fatalf("expected at least 2 deltas, got %d: %+v", len(deltas), deltas) } - if deltas[0].Kind != ThinkDeltaMarker || deltas[0].Value != "" { - t.Errorf("first delta should be marker: %+v", deltas[0]) + foundOpen, foundClose := false, false + for _, d := range deltas { + if d.Kind == ThinkDeltaMarker && d.Value == "" { + foundOpen = true + } + if d.Kind == ThinkDeltaMarker && d.Value == "" { + foundClose = true + } } - if deltas[1].Kind != ThinkDeltaMarker || deltas[1].Value != "" { - t.Errorf("second delta should be marker: %+v", deltas[1]) - } - if state.buffer != "visible" { - t.Errorf("buffer = %q, want visible", state.buffer) + if !foundOpen || !foundClose { + t.Errorf("missing markers: open=%v close=%v, deltas=%+v", foundOpen, foundClose, deltas) } } func TestNextThinkDelta_TextThenThink(t *testing.T) { state := &ThinkStreamState{} - deltas := NextThinkDelta(state, "before inside after") - // "before " -> buffer (no flush yet) - // "" -> marker - // "inside" -> inside think, consumed silently - // "" -> marker - // " after" -> buffer - if len(deltas) != 2 { - t.Fatalf("expected 2 markers, got %d: %+v", len(deltas), deltas) + deltas := NextThinkDelta(state, "before inside after", 0) + foundOpen, foundClose := false, false + for _, d := range deltas { + if d.Kind == ThinkDeltaMarker && d.Value == "" { + foundOpen = true + } + if d.Kind == ThinkDeltaMarker && d.Value == "" { + foundClose = true + } } - if state.buffer != "before after" { - t.Errorf("buffer = %q", state.buffer) + if !foundOpen || !foundClose { + t.Errorf("missing markers: open=%v close=%v, deltas=%+v", foundOpen, foundClose, deltas) } } func TestNextThinkDelta_MultipleChunks(t *testing.T) { state := &ThinkStreamState{} - NextThinkDelta(state, "hello ") - NextThinkDelta(state, "") - NextThinkDelta(state, "reasoning") - NextThinkDelta(state, "") - deltas := NextThinkDelta(state, " world") - if len(deltas) != 0 { - t.Fatalf("expected 0 deltas from final chunk, got %d", len(deltas)) + deltas := NextThinkDelta(state, "hello ", 0) + _ = deltas + deltas = NextThinkDelta(state, "", 0) + _ = deltas + deltas = NextThinkDelta(state, "reasoning", 0) + _ = deltas + deltas = NextThinkDelta(state, "", 0) + _ = deltas + deltas = NextThinkDelta(state, " world", 0) + // After deferral, closing marker and answer text arrive in the last chunk. + foundClose := false + for _, d := range deltas { + if d.Kind == ThinkDeltaMarker && d.Value == "" { + foundClose = true + } } - if state.buffer != "hello world" { - t.Errorf("buffer = %q", state.buffer) + if !foundClose { + t.Errorf("expected in final chunk, got %+v", deltas) } } func TestNextThinkDelta_UnclosedThink(t *testing.T) { state := &ThinkStreamState{} - deltas := NextThinkDelta(state, "text unclosed") - if len(deltas) != 1 { - t.Fatalf("expected 1 marker (think open), got %d", len(deltas)) + deltas := NextThinkDelta(state, "text unclosed", 0) + foundOpen := false + for _, d := range deltas { + if d.Kind == ThinkDeltaMarker && d.Value == "" { + foundOpen = true + } } - if deltas[0].Value != "" { - t.Errorf("expected marker") + if !foundOpen { + t.Errorf("expected marker in %+v", deltas) } - if state.buffer != "text " { - t.Errorf("buffer = %q", state.buffer) - } - // "unclosed" should be consumed silently inside think } func TestNextThinkDelta_EmptyInput(t *testing.T) { state := &ThinkStreamState{} - deltas := NextThinkDelta(state, "") + deltas := NextThinkDelta(state, "", 0) if len(deltas) != 0 { - t.Errorf("expected 0 deltas for empty input") + t.Errorf("expected 0 deltas for empty input, got %d", len(deltas)) } } func TestNextThinkDelta_NilState(t *testing.T) { - deltas := NextThinkDelta(nil, "test") + deltas := NextThinkDelta(nil, "test", 0) if deltas != nil { t.Error("expected nil for nil state") } } -func TestFlushThinkBuffer_Empty(t *testing.T) { - if deltas := FlushThinkBuffer(nil); len(deltas) != 0 { - t.Error("expected empty for nil state") +func TestFlushRemaining_FlushesAll(t *testing.T) { + state := &ThinkStreamState{ + thinkBuffer: "think-tail", + closePending: true, + answerBuffer: "answer-tail", + pendingAfterClose: "pending-tail", } - state := &ThinkStreamState{} - if deltas := FlushThinkBuffer(state); len(deltas) != 0 { - t.Error("expected empty for zero state") + deltas := FlushRemaining(state) + if len(deltas) != 4 { + t.Fatalf("expected 4 deltas, got %d: %+v", len(deltas), deltas) } -} - -func TestFlushThinkBuffer_WithContent(t *testing.T) { - state := &ThinkStreamState{buffer: "flushed text"} - deltas := FlushThinkBuffer(state) - if len(deltas) != 1 { - t.Fatalf("expected 1 delta, got %d", len(deltas)) + if deltas[0].Kind != ThinkDeltaText || deltas[0].Value != "think-tail" { + t.Errorf("delta[0] = %+v", deltas[0]) } - if deltas[0].Kind != ThinkDeltaText { - t.Error("expected text delta") + if deltas[1].Kind != ThinkDeltaMarker || deltas[1].Value != "" { + t.Errorf("delta[1] = %+v", deltas[1]) } - if deltas[0].Value != "flushed text" { - t.Errorf("value = %q", deltas[0].Value) + if deltas[2].Kind != ThinkDeltaText || deltas[2].Value != "answer-tail" { + t.Errorf("delta[2] = %+v", deltas[2]) } - if state.buffer != "" { - t.Error("buffer should be cleared after flush") + if deltas[3].Kind != ThinkDeltaText || deltas[3].Value != "pending-tail" { + t.Errorf("delta[3] = %+v", deltas[3]) + } + // State should be cleared. + if state.closePending || state.inThink { + t.Error("state not fully cleared") } } @@ -160,13 +173,13 @@ func TestExtractVisibleAnswer_WithThink(t *testing.T) { func TestExtractVisibleAnswer_ThinkOnly(t *testing.T) { raw := "only reasoning here" if got := ExtractVisibleAnswer(raw); got != "" { - t.Errorf("expected empty for think-only, got %q", got) + t.Errorf("got %q", got) } } func TestExtractVisibleAnswer_MultipleThinks(t *testing.T) { raw := "firstvisible1secondvisible2" - if got := ExtractVisibleAnswer(raw); got != "visible1visible2" { + if got := ExtractVisibleAnswer(raw); got != "visible2" { t.Errorf("got %q", got) } } @@ -178,6 +191,18 @@ func TestExtractVisibleAnswer_NestedTags(t *testing.T) { } } +func TestExtractVisibleAnswer_NoTags(t *testing.T) { + if got := ExtractVisibleAnswer("plain text"); got != "plain text" { + t.Errorf("got %q", got) + } +} + +func TestExtractVisibleAnswer_StrayTag(t *testing.T) { + if got := ExtractVisibleAnswer("text"); got != "text" { + t.Errorf("got %q", got) + } +} + func TestStreamThinkTagDelta(t *testing.T) { chunks := []string{"hello ", "wor", "", "think text", "", "ld", " final"} ch := make(chan string, len(chunks)) @@ -208,14 +233,12 @@ func TestStreamThinkTagDelta(t *testing.T) { } joined := strings.Join(texts, "") - if !strings.Contains(joined, "hello world") || !strings.Contains(joined, "final") { + if !strings.Contains(joined, "hello wor") || !strings.Contains(joined, "final") { t.Errorf("texts = %q", texts) } } func TestStreamThinkTagDelta_IncrementalFlush(t *testing.T) { - // Verify that visible text is streamed incrementally, not all at the end. - // minTokens=1 → flushSize=4 bytes. Each chunk triggers a flush. chunks := []string{"1234", "5678", "90ab"} ch := make(chan string, len(chunks)) for _, c := range chunks { @@ -229,8 +252,6 @@ func TestStreamThinkTagDelta_IncrementalFlush(t *testing.T) { texts = append(texts, d.Value) } } - // With minTokens=1 (flushSize=4), each chunk triggers a flush. - // We should get incremental text deltas, not just one final burst. if len(texts) < 2 { t.Errorf("expected >=2 incremental text deltas, got %d: %q", len(texts), texts) } @@ -254,3 +275,29 @@ func TestStreamThinkTagDelta_NoThinkTags(t *testing.T) { t.Errorf("got %q, want 'just plain text'", joined) } } + +func TestStreamThinkTagDelta_DeferredClose(t *testing.T) { + // When has no visible text after it, the marker is deferred. + chunks := []string{"", "hello", "", "world"} + ch := make(chan string, len(chunks)) + for _, c := range chunks { + ch <- c + } + close(ch) + + var markers []string + for d := range StreamThinkTagDelta(context.Background(), ch, 1) { + if d.Kind == ThinkDeltaMarker { + markers = append(markers, d.Value) + } + } + if len(markers) != 2 { + t.Fatalf("expected 2 markers, got %d: %v", len(markers), markers) + } + if markers[0] != "" { + t.Errorf("first marker = %q", markers[0]) + } + if markers[1] != "" { + t.Errorf("second marker = %q", markers[1]) + } +} diff --git a/web/vite.config.ts b/web/vite.config.ts index 9357640ec2..4bc3f1fcad 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -103,6 +103,11 @@ export default defineConfig(({ mode }) => { changeOrigin: true, ws: true, }, + '^(/api/v1/datasets/search)|^(/api/v1/chat/completions)': { + target: 'http://127.0.0.1:9384/', + changeOrigin: true, + ws: true, + }, '/api': { target: 'http://127.0.0.1:9380/', changeOrigin: true,