Implement chat completions in go (#16491)

### Summary

POST   /api/v1/chat/completions
This commit is contained in:
qinling0210
2026-07-01 15:52:52 +08:00
committed by GitHub
parent b8e960e6c8
commit 7862f69f39
19 changed files with 1917 additions and 2362 deletions

View File

@@ -45,6 +45,16 @@ server {
include proxy.conf;
}
location ~ ^/api/v1/datasets/search {
proxy_pass http://127.0.0.1:9384;
include proxy.conf;
}
location ~ ^/api/v1/chat/completions {
proxy_pass http://127.0.0.1:9384;
include proxy.conf;
}
location ~ ^/v1/system/config {
proxy_pass http://127.0.0.1:9384;
include proxy.conf;

View File

@@ -834,6 +834,8 @@ Commands (User Mode):
CHAT 'provider/instance/model' 'message'; - Chat with specified model
OPENAI_CHAT 'chat_id' 'message' [options] ; - OpenAI-compatible chat
(run openai_chat -h for detailed options)
CHAT COMPLETIONS 'question' [options] ; - Chat completions via /api/v1/chat/completions
(run chat completions -h for detailed options)
Filesystem Commands (no quotes):
ls [path] - List resources
@@ -1041,3 +1043,50 @@ Examples:
`
fmt.Println(help)
}
// printChatCompletionsHelp prints help for the CHAT COMPLETIONS command.
func printChatCompletionsHelp() {
help := `CHAT COMPLETIONS — hit POST /api/v1/chat/completions
Syntax:
CHAT COMPLETIONS 'question'
chat_id '...'
[session "..."] [llm "..."]
[system "..."] [history "..."] [history_delimiter "<char>"]
[temperature <float>] [max_tokens <int>] [stream <bool>]
[top_p <float>] [frequency_penalty <float>] [presence_penalty <float>]
[pass_all_history <bool>] [legacy <bool>] ;
Required positional:
'question' the user question
Named options (any order; all optional with defaults):
chat_id '...' the dialog id (optional)
session '...' existing session/conversation id
llm '...' override the dialog's LLM
system '...' override the system prompt
history '...' prior turns: user:...;assistant:...;user:...
history_delimiter '...' turn separator for history (default ';')
temperature <float> 0..2 (default 0)
max_tokens <int> (default 0 = server/model default)
stream <bool> true|false (default false)
top_p <float> 0..1
frequency_penalty <float> -2..2
presence_penalty <float> -2..2
pass_all_history <bool> pass all history messages
legacy <bool> use legacy SSE format
Defaults:
stream false
temperature 0
history_delimiter ';'
Examples:
CHAT COMPLETIONS 'Hello, how are you?' chat_id 'cid';
CHAT COMPLETIONS 'Explain quantum computing' chat_id 'cid' stream true;
CHAT COMPLETIONS 'Next question' chat_id 'cid' session 'sess-abc123';
CHAT COMPLETIONS 'What about X?' chat_id 'cid' system 'You are a helpful assistant.' history 'user:Tell me about Y;assistant:Y is...';
CHAT COMPLETIONS 'Summarize' chat_id 'cid' llm 'Qwen/Qwen3-8B@ling@SILICONFLOW' temperature 0.7 max_tokens 512;
`
fmt.Println(help)
}

View File

@@ -419,6 +419,11 @@ func (c *CLI) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
return c.EmbedUserTextCommand(cmd)
case "api_rarank_user_document":
return c.APIRerankUserDocumentCommand(cmd)
case "chat completions":
return c.ChatCompletions(cmd)
case "chat completions help":
printChatCompletionsHelp()
return nil, nil
case "tts_user_command":
return c.APITTSUserCommand(cmd)
case "asr_user_command":

View File

@@ -980,9 +980,11 @@ func getChunkID(c map[string]interface{}) string {
}
func chunkContent(c map[string]interface{}) string {
if v, ok := c["content"]; ok {
s := fmt.Sprint(v)
return strings.TrimSpace(s)
for _, key := range []string{"content_with_weight", "content"} {
if v, ok := c[key]; ok {
s := fmt.Sprint(v)
return strings.TrimSpace(s)
}
}
return ""
}
@@ -1283,3 +1285,59 @@ func (r *QuotaSummaryResponse) PrintOut() {
PrintTableSimpleByFormatWithOrder(table, section.columns, r.OutputFormat)
}
}
type (
	// ChatCompletionsResponse is the CLI-side view of the RAGFlow-internal
	// (non-OpenAI) reply to POST /api/v1/chat/completions.
	//
	// Wire format:
	//
	//	{"code":0,"data":{"answer":"...","reference":...},"message":""}
	//
	// Only Code, Data and Message take part in JSON encoding; the remaining
	// fields are bookkeeping for rendering the result.
	ChatCompletionsResponse struct {
		Code    int                 `json:"code"`
		Data    *chatCompletionData `json:"data"`
		Message string              `json:"message"`

		// Duration is the value reported by TimeCost and printed as "Time:".
		Duration float64 `json:"-"`
		// OutputFormat selects how PrintOut renders the response.
		OutputFormat OutputFormat `json:"-"`

		// raw keeps the unparsed HTTP body so the "raw" output format can
		// echo it verbatim.
		raw []byte
		// streamed makes PrintOut leave out the "Answer:" line; it is meant
		// for callers that have already echoed the answer chunk-by-chunk
		// and would otherwise show it twice.
		streamed bool
	}

	// chatCompletionData is the "data" object of the response envelope.
	// Reference is kept undecoded so it can be handed on as-is.
	chatCompletionData struct {
		Answer    string          `json:"answer"`
		Reference json.RawMessage `json:"reference,omitempty"`
		ID        string          `json:"id,omitempty"`
		SessionID string          `json:"session_id,omitempty"`
		ChatID    string          `json:"chat_id,omitempty"`
	}
)
// Type returns the fixed kind identifier of this response, "chat_completions".
func (r *ChatCompletionsResponse) Type() string {
	return "chat_completions"
}

// TimeCost returns the duration recorded on the response.
func (r *ChatCompletionsResponse) TimeCost() float64 {
	return r.Duration
}

// SetOutputFormat stores the format that PrintOut should render with.
func (r *ChatCompletionsResponse) SetOutputFormat(f OutputFormat) {
	r.OutputFormat = f
}
// PrintOut renders the response on standard output.
//
// The first matching rule wins:
//   - "raw" output format with a captured body: the body is echoed verbatim;
//   - non-zero Code: "ERROR" followed by "<code>, <message>";
//   - missing Data: "(no data)".
//
// Otherwise it prints the answer (unless it is empty or the response is
// flagged as streamed), then any reference payload, then the elapsed time.
func (r *ChatCompletionsResponse) PrintOut() {
	switch {
	case r.OutputFormat == "raw" && r.raw != nil:
		fmt.Println(string(r.raw))
		return
	case r.Code != 0:
		fmt.Println("ERROR")
		fmt.Printf("%d, %s\n", r.Code, r.Message)
		return
	case r.Data == nil:
		fmt.Println("(no data)")
		return
	}

	data := r.Data
	if answer := data.Answer; answer != "" && !r.streamed {
		fmt.Printf("Answer: %s\n", answer)
	}
	if len(data.Reference) > 0 {
		printReferenceChunks(data.Reference)
	}
	fmt.Printf("Time: %f\n", r.Duration)
}

View File

@@ -32,7 +32,6 @@ import (
"ragflow/internal/ingestion"
"ragflow/internal/ingestion/parser"
"ragflow/internal/utility"
"regexp"
"strings"
"time"
)
@@ -2128,7 +2127,7 @@ func (c *CLI) APIChatToModelCommand(cmd *Command) (ResponseIf, error) {
effort := cmd.Params["effort"].(string)
verbosity := cmd.Params["verbosity"].(string)
url := "/chat/completions"
url := "/chat/to_model"
payload := map[string]interface{}{
"messages": formattedMessages,
@@ -4133,8 +4132,6 @@ func (c *CLI) streamOpenaiChat(url string, body map[string]interface{}) (Respons
fullContent = strings.TrimLeft(fullContent, "\n\r")
fullReason = strings.TrimLeft(fullReason, "\n\r")
fullContent = stripThinkTags(fullContent)
fullReason = stripThinkTags(fullReason)
return &OpenAIChatResponse{
Duration: resp.Duration,
Reasoning: fullReason,
@@ -4147,8 +4144,187 @@ func (c *CLI) streamOpenaiChat(url string, body map[string]interface{}) (Respons
}, nil
}
// stripThinkTags removes <think>…</think> wrappers from a streamed answer
func stripThinkTags(s string) string {
var thinkTagRE = regexp.MustCompile(`(?s)<think>.*?</think>`)
return thinkTagRE.ReplaceAllString(s, "")
// ChatCompletions runs a parsed CHAT COMPLETIONS command against
// POST /api/v1/chat/completions.
//
// It refuses to run outside API mode or when the current API server client
// has neither an API key nor a login token. The request body is derived from
// cmd; the "stream" parameter picks the streaming or the one-shot transport.
func (c *CLI) ChatCompletions(cmd *Command) (ResponseIf, error) {
	if c.Config.CLIMode != APIMode {
		return nil, fmt.Errorf("CHAT COMPLETIONS is only allowed in USER mode")
	}

	client := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer]
	authenticated := client.APIKey != nil || client.LoginToken != nil
	if !authenticated {
		return nil, fmt.Errorf("API token not set. Please login first")
	}

	payload, err := buildChatCompletionsRequestBody(cmd)
	if err != nil {
		return nil, err
	}

	const endpoint = "/chat/completions"
	if wantStream, _ := cmd.Params["stream"].(bool); wantStream {
		return c.streamChatCompletions(endpoint, payload)
	}
	return c.oneshotChatCompletions(endpoint, payload)
}
// buildChatCompletionsRequestBody turns a parsed CHAT COMPLETIONS command
// into the JSON payload for POST /api/v1/chat/completions.
//
// chat_id and stream are always present. session and llm are forwarded as
// session_id and llm_id when non-empty. If a system prompt or a history
// string was given, a `messages` array is sent (system, parsed history
// turns, then the question as the final user turn); otherwise the bare
// `question` is sent. Boolean flags and generation parameters are included
// only when the user set them explicitly.
//
// An error is returned only when the history string cannot be parsed.
func buildChatCompletionsRequestBody(cmd *Command) (map[string]interface{}, error) {
	params := cmd.Params
	chatID, _ := params["chat_id"].(string)
	question, _ := params["question"].(string)
	stream, _ := params["stream"].(bool)

	body := map[string]interface{}{
		"chat_id": chatID,
		"stream":  stream,
	}

	// Optional string options and the request field each one maps to.
	for _, opt := range [...]struct{ param, field string }{
		{"session", "session_id"},
		{"llm", "llm_id"},
	} {
		if v, ok := params[opt.param].(string); ok && v != "" {
			body[opt.field] = v
		}
	}

	system, hasSystem := params["system"].(string)
	historyRaw, hasHistory := params["history_raw"].(string)
	if !hasSystem && !hasHistory {
		body["question"] = question
	} else {
		messages := make([]map[string]interface{}, 0, 4)
		if hasSystem && system != "" {
			messages = append(messages, map[string]interface{}{"role": "system", "content": system})
		}
		if hasHistory && historyRaw != "" {
			delimiter, _ := params["history_delimiter"].(string)
			turns, err := parseHistory(historyRaw, delimiter)
			if err != nil {
				return nil, fmt.Errorf("CHAT COMPLETIONS history: %w", err)
			}
			for _, turn := range turns {
				messages = append(messages, map[string]interface{}{
					"role":    turn["role"],
					"content": turn["content"],
				})
			}
		}
		// The current question always closes the conversation.
		messages = append(messages, map[string]interface{}{"role": "user", "content": question})
		body["messages"] = messages
	}

	// Boolean switches: sent only when explicitly set to true.
	for _, flag := range [...]struct{ param, field string }{
		{"pass_all_history", "pass_all_history_messages"},
		{"legacy", "legacy"},
	} {
		if isSet(cmd, flag.param) && params[flag.param].(bool) {
			body[flag.field] = true
		}
	}

	// Generation parameters: forwarded untouched, but only when explicitly
	// set, so parser defaults never override server-side settings.
	for _, name := range [...]string{
		"temperature",
		"max_tokens",
		"top_p",
		"frequency_penalty",
		"presence_penalty",
	} {
		if isSet(cmd, name) {
			body[name] = params[name]
		}
	}
	return body, nil
}
// oneshotChatCompletions issues a single non-streaming POST to url and
// decodes the RAGFlow-internal envelope {code, data, message} into a
// ChatCompletionsResponse.
//
// A transport failure or an undecodable 200 body is returned as an error.
// Any other HTTP status is reported as a response whose Code is the status
// and whose Message is the body text. The unparsed body is always kept for
// the "raw" output format.
func (c *CLI) oneshotChatCompletions(url string, body map[string]interface{}) (ResponseIf, error) {
	client := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer]
	resp, err := client.Request("POST", url, "web", nil, body)
	if err != nil {
		return nil, fmt.Errorf("chat completions request: %w", err)
	}

	result := &ChatCompletionsResponse{raw: resp.Body}
	if resp.StatusCode != 200 {
		result.Code = resp.StatusCode
		result.Message = string(resp.Body)
		return result, nil
	}

	result.Duration = resp.Duration
	// The response type carries the envelope's JSON tags itself, so the body
	// decodes straight into it; untagged bookkeeping fields are unaffected.
	if err := json.Unmarshal(resp.Body, result); err != nil {
		return nil, fmt.Errorf("chat completions: invalid response JSON: %w", err)
	}
	return result, nil
}
// streamChatCompletions performs a streaming POST to url and folds the
// server-sent "data:" events into one ChatCompletionsResponse.
//
// Every event is expected to be a {code, message, data} envelope. Lines that
// are not "data:" events, the "[DONE]" sentinel, and payloads that are not
// valid JSON are skipped. The answer fragments of all events are
// concatenated in arrival order, and leading newlines are trimmed from the
// combined answer.
//
// Errors and failure reporting:
//   - a failed request, or a read failure while consuming the stream
//     (including an event longer than the scanner limit), is returned as an
//     error instead of silently yielding a truncated answer;
//   - an event with a non-zero code ends the stream and is returned as a
//     response carrying that code and message, so PrintOut reports it.
//
// Nothing is echoed while the stream is read, so the response is not marked
// as streamed: PrintOut has to print the collected answer itself.
func (c *CLI) streamChatCompletions(url string, body map[string]interface{}) (ResponseIf, error) {
	httpClient := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer]

	// Start the clock before the request so the reported time also covers
	// connection setup and the wait for the first event.
	start := time.Now()
	reader, err := httpClient.RequestStream("POST", url, "web", nil, body)
	if err != nil {
		return nil, fmt.Errorf("chat completions stream: %w", err)
	}
	defer reader.Close()

	scanner := bufio.NewScanner(reader)
	// A single event can be far larger than bufio's default 64 KiB token
	// limit; allow up to 4 MiB per line.
	scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)

	var answer strings.Builder
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if payload == "[DONE]" {
			continue
		}

		// Decode the envelope first and keep data undecoded, so that an
		// error code is still seen when data is not an object.
		var chunk struct {
			Code    int             `json:"code"`
			Message string          `json:"message"`
			Data    json.RawMessage `json:"data"`
		}
		if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
			continue
		}
		if chunk.Code != 0 {
			return &ChatCompletionsResponse{
				Code:     chunk.Code,
				Message:  chunk.Message,
				Duration: time.Since(start).Seconds(),
			}, nil
		}

		var data chatCompletionData
		if err := json.Unmarshal(chunk.Data, &data); err != nil {
			// data is absent or not an object (a bare boolean, for
			// instance): there is no answer text to collect.
			continue
		}
		answer.WriteString(data.Answer)
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("chat completions stream: %w", err)
	}

	return &ChatCompletionsResponse{
		Duration: time.Since(start).Seconds(),
		Data: &chatCompletionData{
			Answer: strings.TrimLeft(answer.String(), "\n\r"),
		},
	}, nil
}

View File

@@ -2702,6 +2702,13 @@ func (p *Parser) parseAPIDisable() (*Command, error) {
func (p *Parser) parseAPIChat() (*Command, error) {
p.nextToken() // consume CHAT
// Redirect "chat completion[s]" to the standalone chat completions parser.
if p.curToken.Type == TokenIdentifier &&
(strings.EqualFold(p.curToken.Value, "completion") || strings.EqualFold(p.curToken.Value, "completions")) {
p.nextToken() // consume completion/completions
return p.parseChatCompletionsBody()
}
var err error
var modelNameOrID string = ""
var messages []string
@@ -3759,6 +3766,150 @@ optionsLoop:
return cmd, nil
}
// parseChatCompletionsBody parses everything that follows the CHAT
// COMPLETIONS keywords; the caller must already have consumed them.
//
// Grammar:
//
//	CHAT COMPLETIONS <question>
//	    [chat_id <string>] [session <string>] [llm <string>] [system <string>]
//	    [history <string>] [history_delimiter <string>]
//	    [temperature|top_p|frequency_penalty|presence_penalty <float>]
//	    [max_tokens <int>] [stream|pass_all_history|legacy <bool>] ;
//
// One or more leading dashes followed by the identifier h or help yield a
// "chat completions help" command; any other input starting with a dash is
// an error. Otherwise the result is a "chat completions" command whose
// Params hold the question, the defaults, every option that was given, and
// a "_set" map naming the options recorded through markSet.
func (p *Parser) parseChatCompletionsBody() (*Command, error) {
// Help form: -h, --help (any number of dashes).
if p.curToken.Type == TokenDash {
dashCount := 0
for p.curToken.Type == TokenDash {
dashCount++
p.nextToken()
}
if dashCount > 0 && p.curToken.Type == TokenIdentifier {
switch strings.ToLower(p.curToken.Value) {
case "h", "help":
return NewCommand("chat completions help"), nil
}
}
return nil, fmt.Errorf("CHAT COMPLETIONS: only -h/--help takes no args; otherwise expected question")
}
cmd := NewCommand("chat completions")
// Defaults
cmd.Params["chat_id"] = ""
cmd.Params["temperature"] = 0.0
cmd.Params["max_tokens"] = 0
cmd.Params["stream"] = false
// Track which options were explicitly set (distinguishes from defaults).
cmd.Params["_set"] = map[string]bool{}
// Required positional: <question>
question, err := p.parseQuotedString()
if err != nil {
return nil, fmt.Errorf("CHAT COMPLETIONS: expected question: %w", err)
}
cmd.Params["question"] = question
p.nextToken()
// Optional named options
// handleOption parses the value of the option called name. It is invoked
// with the current token positioned on the value, i.e. after the options
// loop below has stepped past the option name.
handleOption := func(name string) error {
switch name {
// Quoted-string options, stored under their own name.
case "chat_id", "session", "llm", "system":
v, err := p.parseQuotedString()
if err != nil {
return fmt.Errorf("CHAT COMPLETIONS %s: expected quoted string, got %s", name, p.curToken.Value)
}
cmd.Params[name] = v
p.nextToken()
markSet(cmd, name)
// Floating-point generation parameters.
case "temperature", "top_p", "frequency_penalty", "presence_penalty":
v, err := p.parseFloat()
if err != nil {
return fmt.Errorf("CHAT COMPLETIONS %s: expected number, got %s", name, p.curToken.Value)
}
cmd.Params[name] = v
p.nextToken()
markSet(cmd, name)
case "max_tokens":
v, err := p.parseNumber()
if err != nil {
return fmt.Errorf("CHAT COMPLETIONS max_tokens: expected integer, got %s", p.curToken.Value)
}
cmd.Params["max_tokens"] = v
p.nextToken()
markSet(cmd, "max_tokens")
// Boolean switches.
// NOTE(review): unlike the other cases there is no p.nextToken() after
// the value here — presumably parseBool advances past the literal
// itself; confirm against parseBool.
case "stream", "pass_all_history", "legacy":
v, err := p.parseBool()
if err != nil {
return fmt.Errorf("CHAT COMPLETIONS %s: expected true|false, got %s", name, p.curToken.Value)
}
cmd.Params[name] = v
markSet(cmd, name)
// The history text is stored unparsed under "history_raw". Neither
// history nor history_delimiter is recorded in "_set".
case "history":
raw, err := p.parseQuotedString()
if err != nil {
return fmt.Errorf("CHAT COMPLETIONS history: expected quoted string, got %s", p.curToken.Value)
}
cmd.Params["history_raw"] = raw
p.nextToken()
case "history_delimiter":
v, err := p.parseQuotedString()
if err != nil {
return fmt.Errorf("CHAT COMPLETIONS history_delimiter: expected quoted string, got %s", p.curToken.Value)
}
cmd.Params["history_delimiter"] = v
p.nextToken()
default:
return fmt.Errorf("CHAT COMPLETIONS: unknown option %q (valid: chat_id, session, llm, system, history, history_delimiter, temperature, max_tokens, stream, top_p, frequency_penalty, presence_penalty, pass_all_history, legacy)", name)
}
return nil
}
// Named options, any order, until ';'.
// The loop also stops at end of input, so the trailing ';' is optional.
optionsLoop:
for {
switch p.curToken.Type {
case TokenSemicolon:
p.nextToken()
break optionsLoop
case TokenEOF:
break optionsLoop
// An option name may be written bare or quoted; surrounding quote
// characters are trimmed from a quoted name.
case TokenIdentifier, TokenQuotedString:
name := p.curToken.Value
if p.curToken.Type == TokenQuotedString {
name = strings.Trim(name, "'\"")
}
p.nextToken()
if err := handleOption(name); err != nil {
return nil, err
}
// Tokens that isKeyword accepts are also treated as option names, using
// the token's Value; every other token type is rejected.
default:
if !isKeyword(p.curToken.Type) {
return nil, fmt.Errorf("CHAT COMPLETIONS: unexpected token %q in option list (valid options: chat_id, session, llm, system, history, history_delimiter, temperature, max_tokens, stream, top_p, frequency_penalty, presence_penalty, pass_all_history, legacy)", p.curToken.Value)
}
name := p.curToken.Value
p.nextToken()
if err := handleOption(name); err != nil {
return nil, err
}
}
}
return cmd, nil
}
// markSet records that the option called name was given explicitly on the
// command line. It does nothing when cmd carries no "_set" bookkeeping map.
func markSet(cmd *Command, name string) {
	explicit, ok := cmd.Params["_set"].(map[string]bool)
	if !ok {
		return
	}
	explicit[name] = true
}
// isSet reports whether the option called name was recorded as explicitly
// given. It is false when cmd carries no "_set" bookkeeping map.
func isSet(cmd *Command, name string) bool {
	explicit, ok := cmd.Params["_set"].(map[string]bool)
	return ok && explicit[name]
}
// parseJSONLiteral consumes a TokenQuotedString whose payload is a JSON
// value (object, array, string, number, or boolean) and returns it as
// the original raw string (NOT decoded — the caller decides whether to

View File

@@ -167,7 +167,33 @@ Time: 76.582520
```
Note: Both image and video understanding support streaming and thinking modes as well.
### 6.8. Chat with OpenAI compatible API
### 6.8. Chat completions
```
RAGFlow(api/default)> chat completion 'hello'
Answer: Hello! How can I assist you today? 😊
Time: 1.591929
```
```
RAGFlow(api/default)> CHAT COMPLETIONS '<question>' chat_id '<chat_id>';
```
```
RAGFlow(api/default)> CHAT COMPLETIONS 'Explain the theory' \
chat_id '<chat_id>' \
session '<session_id>' llm 'glm-4.5-flash@test@zhipu-ai' stream true;
```
```
RAGFlow(api/default)> CHAT COMPLETIONS 'Continue' \
system 'You are a helpful assistant.' \
history 'user:What is RAG?;assistant:RAG stands for Retrieval-Augmented Generation...' \
history_delimiter ';';
```
### 6.9. Chat with OpenAI compatible API
```
RAGFlow(api/default)> openai_chat '<chat_id>' 'Hello, how are you?';
Answer: Hello! I'm just a virtual assistant, so I don't have feelings, but I'm here and ready to help you with anything you need. How can I assist you today? 😊
@@ -204,17 +230,17 @@ RAGFlow(api/default)> openai_chat '<chat_id>' 'Hello, how are you?' extra_body '
CLI error: OPENAI_CHAT extra_body: unknown field "ref" (valid: reference, reference_metadata, metadata_condition)
```
### 6.9. Generate Embeddings
### 6.10. Generate Embeddings
```
RAGFlow(api/default)> embed text 'what is rag' 'who are you' with 'embedding-3@test@zhipu-ai' dimension 16;
```
### 6.10. Document Reranking
### 6.11. Document Reranking
```
RAGFlow(api/default)> rerank query 'what is rag' document 'rag is retrieval augment generation' 'rag need llm' 'famous rag project includes ragflow' with 'rerank@test@zhipu-ai' top 2;
```
### 6.11. Get supported models from provider API
### 6.12. Get supported models from provider API
```
RAGFlow(api/default)> list supported models from 'gitee' 'test';
@@ -236,7 +262,7 @@ RAGFlow(api/default)> list supported models from 'gitee' 'test';
+-----------+---------------------------+---------------+------------+-----------------------------------------------------------------+----------------------------------------------------------+---------------------------------------------+
```
### 6.12. Get preset models of a provider
### 6.13. Get preset models of a provider
```
RAGFlow(api/default)> list models from 'minimax';
@@ -254,7 +280,7 @@ RAGFlow(api/default)> list models from 'minimax';
+------------+-------------+------------------------+
```
### 6.13. List instances of a provider
### 6.14. List instances of a provider
```
RAGFlow(api/default)> list instances from 'zhipu-ai';
@@ -265,7 +291,7 @@ RAGFlow(api/default)> list instances from 'zhipu-ai';
+---------+----------------------+----------------------------------+--------------+----------------------------------+--------+
```
### 6.14. Show instance of a provider
### 6.15. Show instance of a provider
```
RAGFlow(api/default)> show instance 'test' from 'zhipu-ai';
+----------------------------------+--------------+----------------------------------+---------+--------+
@@ -275,7 +301,7 @@ RAGFlow(api/default)> show instance 'test' from 'zhipu-ai';
+----------------------------------+--------------+----------------------------------+---------+--------+
```
### 6.15. List models of a specific instance
### 6.16. List models of a specific instance
```
RAGFlow(api/default)> list models from 'minimax' 'test';
@@ -293,7 +319,7 @@ RAGFlow(api/default)> list models from 'minimax' 'test';
+------------+-------------+------------------------+--------+
```
### 6.16. List added providers
### 6.17. List added providers
```
RAGFlow(api/default)> list providers;
+--------------------------------------------------------------------------+-------------+--------------+
@@ -305,7 +331,7 @@ RAGFlow(api/default)> list providers;
+--------------------------------------------------------------------------+-------------+--------------+
```
### 6.17. Deactivate / activate a model
### 6.18. Deactivate / activate a model
```
RAGFlow(api/default)> disable model 'deepseek-v4-pro' from 'deepseek' 'test';
@@ -321,7 +347,7 @@ RAGFlow(api/default)> enable model 'deepseek-v4-pro' from 'deepseek' 'test';
SUCCESS
```
### 6.18. Set current model
### 6.19. Set current model
```
RAGFlow(api/default)> use model 'glm-4.5-flash@test@zhipu-ai';
SUCCESS
@@ -330,7 +356,7 @@ Answer: Large language models are advanced AI systems. They process text to unde
Time: 1.680416
```
### 6.19. Set, reset, and list default models
### 6.20. Set, reset, and list default models
```
RAGFlow(api/default)> set default chat model 'glm-4.5-flash@test@zhipu-ai';
SUCCESS
@@ -374,7 +400,7 @@ RAGFlow(api/default)> list default models;
+--------+----------------+--------------+----------------+------------+
```
### 6.20. Show current balance of a provider instance
### 6.21. Show current balance of a provider instance
```
RAGFlow(api/default)> show balance from 'gitee' 'test';
+-------------+----------+
@@ -384,13 +410,13 @@ RAGFlow(api/default)> show balance from 'gitee' 'test';
+-------------+----------+
```
### 6.21. Check provider instance availability
### 6.22. Check provider instance availability
```
RAGFlow(api/default)> check instance 'test' from 'zhipu-ai';
SUCCESS
```
### 6.22. Add local model to RAGFlow, only for local deployed inference server, such as ollama
### 6.23. Add a local model to RAGFlow (only for locally deployed inference servers, such as Ollama or vLLM)
```
RAGFlow(api/default)> add model 'Qwen/Qwen2.5-0.5B' to provider 'vllm' instance 'test' with tokens 131072 chat;
SUCCESS
@@ -404,7 +430,7 @@ RAGFlow(api/default)> drop model 'Qwen/Qwen2.5-0.5B' from 'vllm' 'test';
SUCCESS
```
### 6.23. List datasets
### 6.24. List datasets
```
RAGFlow(api/default)> list datasets;
+-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+
@@ -415,14 +441,14 @@ RAGFlow(api/default)> list datasets;
+-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+
```
### 6.24. Text to Speech
### 6.25. Text to Speech
```
RAGFlow(api/default)> tts with 'speech-2.8-hd@test@minimax' text 'He who desires but acts not, breeds pestilence.' play format 'wav' save './internal' param '{"voice_setting": {"voice_id": "English_radiant_girl", "speed": 1, "vol": 1, "pitch": 0}, "audio_setting": {"sample_rate": 32000, "bitrate": 128000, "format": "wav", "channel": 1}, "output_format": "hex"}'
Saved to directory: /home/infiniflow/Documents/development/ragflow/internal/speech-2.8-hd_output.wav
SUCCESS
```
### 6.25. Audio to Speech
### 6.26. Speech to Text (ASR)
```
RAGFlow(api/default)> asr with 'FunAudioLLM/SenseVoiceSmall@test@siliconflow' audio './internal/test.wav' param ''
+----------------------------------------------------------------------------------------------------------------------+
@@ -432,7 +458,7 @@ RAGFlow(api/default)> asr with 'FunAudioLLM/SenseVoiceSmall@test@siliconflow' au
+----------------------------------------------------------------------------------------------------------------------+
```
### 6.26. Optical Character Recognition
### 6.27. Optical Character Recognition
```
RAGFlow(api/default)> ocr with 'paddleocr-vl-0.9b@test@baidu' file './internal/text.jpg'
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
@@ -442,7 +468,7 @@ RAGFlow(api/default)> ocr with 'paddleocr-vl-0.9b@test@baidu' file './internal/t
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
```
### 6.27. Chunk Management Commands
### 6.28. Chunk Management Commands
- Create a chunk store with vector size
```
@@ -489,7 +515,7 @@ RAGFlow(api/default)> RETRIEVE 'AI' ON DATASETS 'test'
RAGFlow(api/default)> GET CHUNK '29cc4f6d7a5c6e7c' OF DATASET 'test' DOCUMENT 'bbe55942535e11f1bc5184ba59049aa3' IN DATASET 'test'
```
### 6.28. Metadata Management Commands
### 6.29. Metadata Management Commands
- Create metadata store
```
@@ -525,7 +551,7 @@ RAGFlow(api/default)> DROP METADATA STORE
RAGFlow(api/default)> GET METADATA OF DATASET 'test' 'test2'
```
### 6.29. Search datasets
### 6.30. Search datasets
- Search datasets using SQL-like dataset search syntax:
```

View File

@@ -227,192 +227,6 @@ func (h *ChatHandler) MindMap(c *gin.Context) {
jsonResponse(c, common.CodeSuccess, mindMap, "success")
}
// ListChatsNext lists chats with advanced filtering and pagination.
// @Summary List Chats Next
// @Description Get list of chats with filtering, pagination and sorting (equivalent to list_dialogs_next)
// @Tags chat
// @Accept json
// @Produce json
// @Param keywords query string false "search keywords"
// @Param page query int false "page number"
// @Param page_size query int false "items per page"
// @Param orderby query string false "order by field (default: create_time)"
// @Param desc query bool false "descending order (default: true)"
// @Param request body service.ListChatsNextRequest true "filter options including owner_ids"
// @Success 200 {object} service.ListChatsNextResponse
// @Router /v1/dialog/next [post]
func (h *ChatHandler) ListChatsNext(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	// positiveQueryInt reads an integer query parameter and yields 0 when
	// it is missing, malformed, or not strictly positive.
	positiveQueryInt := func(key string) int {
		if n, err := strconv.Atoi(c.Query(key)); err == nil && n > 0 {
			return n
		}
		return 0
	}

	keywords := c.Query("keywords")
	page := positiveQueryInt("page")
	pageSize := positiveQueryInt("page_size")
	orderby := c.DefaultQuery("orderby", "create_time")
	// Descending unless the caller passes exactly "false".
	desc := c.Query("desc") != "false"

	// The JSON body (owner_ids filter) is only bound when the request
	// declares a positive Content-Length.
	var req service.ListChatsNextRequest
	if c.Request.ContentLength > 0 {
		if err := c.ShouldBindJSON(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
			return
		}
	}

	result, err := h.chatService.ListChatsNext(user.ID, keywords, page, pageSize, orderby, desc, req.OwnerIDs)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": 0, "data": result, "message": "success"})
}
// SetDialog creates or updates a dialog.
// @Summary Set Dialog
// @Description Create or update a dialog (chat). If dialog_id is provided, updates existing dialog; otherwise creates new one.
// @Tags chat
// @Accept json
// @Produce json
// @Param request body service.SetDialogRequest true "dialog configuration"
// @Success 200 {object} service.SetDialogResponse
// @Router /v1/dialog/set [post]
func (h *ChatHandler) SetDialog(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	// badRequest answers with HTTP 400 and the given message.
	badRequest := func(message string) {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": message})
	}

	var req service.SetDialogRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		badRequest(err.Error())
		return
	}
	// prompt_config is mandatory even though binding does not enforce it.
	if req.PromptConfig == nil {
		badRequest("prompt_config is required")
		return
	}

	result, err := h.chatService.SetDialog(user.ID, &req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": 0, "data": result, "message": "success"})
}
// RemoveDialogsRequest is the JSON request body bound by RemoveChats.
// DialogIDs lists the dialogs to remove; its binding tag marks the
// "dialog_ids" field as required.
type RemoveDialogsRequest struct {
DialogIDs []string `json:"dialog_ids" binding:"required"`
}
// RemoveChats removes dialogs (soft delete by setting status to invalid).
// @Summary Remove Dialogs
// @Description Remove dialogs by setting their status to invalid. Only the owner of the dialog can perform this operation.
// @Tags chat
// @Accept json
// @Produce json
// @Param request body RemoveDialogsRequest true "dialog IDs to remove"
// @Success 200 {object} map[string]interface{}
// @Router /v1/dialog/rm [post]
func (h *ChatHandler) RemoveChats(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	var req RemoveDialogsRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
		return
	}

	err := h.chatService.RemoveChats(user.ID, req.DialogIDs)
	switch {
	case err == nil:
		c.JSON(http.StatusOK, gin.H{"code": 0, "data": true, "message": "success"})
	case err.Error() == "only owner of chat authorized for this operation":
		// The service signals a missing ownership through this exact
		// message; it is mapped to 403 instead of a generic failure.
		c.JSON(http.StatusForbidden, gin.H{"code": 403, "data": false, "message": err.Error()})
	default:
		c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": err.Error()})
	}
}
// DeleteChat soft deletes a chat by ID.
func (h *ChatHandler) DeleteChat(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {

View File

@@ -19,7 +19,6 @@ package handler
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"ragflow/internal/common"
@@ -44,107 +43,6 @@ func NewChatSessionHandler(chatSessionService *service.ChatSessionService, userS
}
}
// SetChatSession creates or updates a chat session.
// @Summary Set chat session
// @Description Create or update a chat session. If is_new is true, creates new chat session; otherwise updates existing one.
// @Tags chat_session
// @Accept json
// @Produce json
// @Param request body service.SetChatSessionRequest true "chat session configuration"
// @Success 200 {object} service.SetChatSessionResponse
// @Router /v1/conversation/set [post]
func (h *ChatSessionHandler) SetChatSession(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	var req service.SetChatSessionRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": bindErr.Error()})
		return
	}

	// Bad input is answered with 400 above; a service failure maps to 500.
	result, err := h.chatSessionService.SetChatSession(user.ID, &req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": 0, "data": result, "message": "success"})
}
// RemoveChatSessionsRequest is the JSON request body bound by
// RemoveChatSessions. ConversationIDs lists the chat sessions to remove; its
// binding tag marks the "conversation_ids" field as required.
type RemoveChatSessionsRequest struct {
ConversationIDs []string `json:"conversation_ids" binding:"required"`
}
// RemoveChatSessions removes chat sessions by ID.
// @Summary Remove Chat Sessions
// @Description Remove chat sessions by their IDs. Only the owner of the chat session can perform this operation.
// @Tags chat_session
// @Accept json
// @Produce json
// @Param request body RemoveChatSessionsRequest true "chat session IDs to remove"
// @Success 200 {object} map[string]interface{}
// @Router /v1/conversation/rm [post]
func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	var req RemoveChatSessionsRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
		return
	}

	err := h.chatSessionService.RemoveChatSessions(user.ID, req.ConversationIDs)
	switch {
	case err == nil:
		c.JSON(http.StatusOK, gin.H{"code": 0, "data": true, "message": "success"})
	case err.Error() == "Only owner of chat session authorized for this operation":
		// The service signals a missing ownership through this exact
		// message; it is mapped to 403 instead of a generic failure.
		c.JSON(http.StatusForbidden, gin.H{"code": 403, "data": false, "message": err.Error()})
	default:
		c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": err.Error()})
	}
}
// ListChatSessions list chat sessions for a dialog
// @Summary List Chat Sessions
// @Description Get list of chat sessions for a specific dialog
@@ -198,30 +96,37 @@ func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) {
})
}
// CompletionRequest completion request
type CompletionRequest struct {
ConversationID string `json:"conversation_id" binding:"required"`
Messages []map[string]interface{} `json:"messages" binding:"required"`
LLMID string `json:"llm_id,omitempty"`
Stream *bool `json:"stream,omitempty"`
Thinking *bool `json:"thinking,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
// ChatCompletionsRequest is the typed view of the chat completions JSON body.
// Every field is optional at the binding level.
type ChatCompletionsRequest struct {
// ChatID is forwarded to the service as the chat identifier.
ChatID string `json:"chat_id,omitempty"`
// SessionID and ConversationID are aliases; the handler uses SessionID and
// falls back to ConversationID when SessionID is empty.
SessionID string `json:"session_id,omitempty"`
ConversationID string `json:"conversation_id,omitempty"`
Messages []map[string]interface{} `json:"messages,omitempty"`
Question string `json:"question,omitempty"`
Files []interface{} `json:"files,omitempty"`
LLMID string `json:"llm_id,omitempty"`
// PassAllHistoryMessages and PassAllHistory are aliases; when both are
// present the handler lets PassAllHistoryMessages win.
PassAllHistoryMessages *bool `json:"pass_all_history_messages,omitempty"`
PassAllHistory *bool `json:"pass_all_history,omitempty"`
Legacy bool `json:"legacy,omitempty"`
// Stream is a pointer so that an absent key can be told apart from false;
// the handler streams when it is nil.
Stream *bool `json:"stream"`
// The generation settings below are pointers; the handler copies only the
// non-nil ones into the generation config, so an explicit zero is forwarded.
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
}
// Completion chat completion
// ChatCompletions chat completion
// @Summary Chat Completion
// @Description Send messages to the chat model and get a response. Supports streaming and non-streaming modes.
// @Description Send messages to the chat model and get a response.
// @Description Default is streaming (text/event-stream); set stream:false for JSON.
// @Tags chat_session
// @Accept json
// @Produce json
// @Param request body CompletionRequest true "completion request"
// @Success 200 {object} map[string]interface{}
// @Router /v1/conversation/completion [post]
func (h *ChatSessionHandler) Completion(c *gin.Context) {
// @Produce json, text/event-stream
// @Param request body ChatCompletionsRequest true "chat completion request"
// @Success 200 {object} map[string]interface{} "Non-streaming JSON response"
// @Success 200 {string} text/event-stream "Streaming SSE response"
// @Router /api/v1/chat/completions [post]
func (h *ChatSessionHandler) ChatCompletions(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
@@ -229,83 +134,97 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
}
userID := user.ID
// Parse request body
var req CompletionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": err.Error(),
})
var rawBody map[string]interface{}
if err := c.ShouldBindJSON(&rawBody); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
return
}
// Build chat model config
chatModelConfig := make(map[string]interface{})
if req.Temperature != 0 {
chatModelConfig["temperature"] = req.Temperature
var req ChatCompletionsRequest
b, err := json.Marshal(rawBody)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
return
}
if req.TopP != 0 {
chatModelConfig["top_p"] = req.TopP
if err := json.Unmarshal(b, &req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
return
}
if req.FrequencyPenalty != 0 {
chatModelConfig["frequency_penalty"] = req.FrequencyPenalty
// Normalize session_id / conversation_id
sessionID := req.SessionID
if sessionID == "" {
sessionID = req.ConversationID
}
if req.PresencePenalty != 0 {
chatModelConfig["presence_penalty"] = req.PresencePenalty
// Build generation config
genConfig := make(map[string]interface{})
if req.Temperature != nil {
genConfig["temperature"] = *req.Temperature
}
if req.MaxTokens != 0 {
chatModelConfig["max_tokens"] = req.MaxTokens
if req.TopP != nil {
genConfig["top_p"] = *req.TopP
}
if req.FrequencyPenalty != nil {
genConfig["frequency_penalty"] = *req.FrequencyPenalty
}
if req.PresencePenalty != nil {
genConfig["presence_penalty"] = *req.PresencePenalty
}
if req.MaxTokens != nil {
genConfig["max_tokens"] = *req.MaxTokens
}
// Resolve pass_all_history from either alias
passAllHistory := false
if req.PassAllHistory != nil {
passAllHistory = *req.PassAllHistory
}
if req.PassAllHistoryMessages != nil {
passAllHistory = *req.PassAllHistoryMessages
}
// Remove known keys from rawBody; what remains is passthrough kwargs
knownKeys := []string{
"chat_id", "session_id", "conversation_id",
"messages", "question", "files",
"llm_id",
"pass_all_history_messages", "pass_all_history",
"legacy", "stream",
"temperature", "top_p", "frequency_penalty", "presence_penalty", "max_tokens",
}
for _, key := range knownKeys {
delete(rawBody, key)
}
kwargs := rawBody
// Determine stream mode
streamMode := true
if req.Stream != nil {
chatModelConfig["stream"] = *req.Stream
}
if req.Thinking != nil {
chatModelConfig["thinking"] = *req.Thinking
streamMode = *req.Stream
}
// Process messages - filter out system messages and initial assistant messages
var processedMessages []map[string]interface{}
for i, m := range req.Messages {
role, _ := m["role"].(string)
if role == "system" {
continue
}
if role == "assistant" && len(processedMessages) == 0 {
continue
}
processedMessages = append(processedMessages, m)
_ = i
}
// Get last message ID if present
var messageID string
if len(processedMessages) > 0 {
if id, ok := processedMessages[len(processedMessages)-1]["id"].(string); ok {
messageID = id
}
}
// Call service
if req.Stream != nil && *req.Stream {
// Streaming response
if streamMode {
disableWriteDeadlineForSSE(c)
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// Create a channel for streaming data
streamChan := make(chan string)
streamChan := make(chan string, 32)
reqCtx := c.Request.Context()
go func() {
defer close(streamChan)
err := h.chatSessionService.CompletionStream(reqCtx, userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan)
if err != nil {
streamChan <- fmt.Sprintf("data: %s\n\n", err.Error())
}
_, _ = h.chatSessionService.ChatCompletions(
reqCtx, userID,
req.ChatID, sessionID,
req.Messages, req.Question, req.Files,
req.LLMID, genConfig, kwargs,
passAllHistory, req.Legacy,
true, streamChan,
)
}()
// Stream data to client
c.Stream(func(w io.Writer) bool {
data, ok := <-streamChan
if !ok {
@@ -315,8 +234,14 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
return true
})
} else {
// Non-streaming response
result, err := h.chatSessionService.Completion(userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID)
result, err := h.chatSessionService.ChatCompletions(
c.Request.Context(), userID,
req.ChatID, sessionID,
req.Messages, req.Question, req.Files,
req.LLMID, genConfig, kwargs,
passAllHistory, req.Legacy,
false, nil,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"code": 500,
@@ -324,7 +249,6 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": result,

View File

@@ -292,6 +292,12 @@ func (r *Router) Setup(engine *gin.Engine) {
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
}
chat := v1.Group("/chat")
{
// Chat completions route
chat.POST("/completions", r.chatSessionHandler.ChatCompletions)
}
// OpenAI-compatible chat completions route
openai := v1.Group("/openai")
{
@@ -498,7 +504,7 @@ func (r *Router) Setup(engine *gin.Engine) {
provider.PATCH("/:provider_name/instances/:instance_name/models/*model_name", r.providerHandler.EnableOrDisableModel)
provider.POST("/:provider_name/instances/:instance_name/models", r.providerHandler.AddModel)
provider.DELETE("/:provider_name/instances/:instance_name/models", r.providerHandler.DropInstanceModels)
v1.POST("/chat/completions", r.providerHandler.ChatToModel)
v1.POST("/chat/to_model", r.providerHandler.ChatToModel)
v1.POST("/embeddings", r.providerHandler.EmbedText)
v1.POST("/rerank", r.providerHandler.RerankDocument)
v1.POST("/audio/transcriptions", r.providerHandler.TranscribeAudio)
@@ -678,14 +684,6 @@ func (r *Router) Setup(engine *gin.Engine) {
chunk.POST("/update", r.chunkHandler.UpdateChunk) // Internal API only for GO
}
// Chat routes
chat := authorized.Group("/v1/dialog")
{
chat.POST("/next", r.chatHandler.ListChatsNext)
chat.POST("/set", r.chatHandler.SetDialog)
chat.POST("/rm", r.chatHandler.RemoveChats)
}
// Chat Channel
chanChannel := v1.Group("/chat-channels")
{
@@ -705,15 +703,6 @@ func (r *Router) Setup(engine *gin.Engine) {
langfuse.DELETE("/api-key", r.langfuseHandler.DeleteAPIKey)
}
// Chat session (conversation) routes
session := authorized.Group("/v1/conversation")
{
session.POST("/set", r.chatSessionHandler.SetChatSession)
session.POST("/rm", r.chatSessionHandler.RemoveChatSessions)
session.GET("/list", r.chatSessionHandler.ListChatSessions)
session.POST("/completion", r.chatSessionHandler.Completion)
}
// Connector routes
connector := authorized.Group("/v1/connector")
{

View File

@@ -28,28 +28,6 @@ import (
"ragflow/internal/dao"
)
// DefaultPromptConfig is a PromptConfig populated with the package's default
// system prompt, prologue and empty-response text. It declares one required
// "knowledge" parameter, turns quoting and multi-turn refinement on, and
// leaves TTS off.
var DefaultPromptConfig = PromptConfig{
System: strPtr(pyDefaultSystemPrompt),
Prologue: strPtr(pyDefaultPrologue),
Parameters: []ParameterConfig{
{Key: "knowledge", Optional: false},
},
EmptyResponse: strPtr(pyDefaultEmptyResponse),
Quote: boolPtr(true),
TTS: boolPtr(false),
RefineMultiturn: boolPtr(true),
}
// DefaultDirectChatPromptConfig is a PromptConfig with empty system prompt,
// prologue and empty-response text, no parameters, quoting and TTS off, and
// multi-turn refinement on.
// NOTE(review): presumably meant for chats that talk to the model without
// dataset retrieval — confirm against the call sites.
var DefaultDirectChatPromptConfig = PromptConfig{
System: strPtr(""),
Prologue: strPtr(""),
Parameters: []ParameterConfig{},
EmptyResponse: strPtr(""),
Quote: boolPtr(false),
TTS: boolPtr(false),
RefineMultiturn: boolPtr(true),
}
var DefaultRerankModels = map[string]struct{}{
"BAAI/bge-reranker-v2-m3": {},
"maidalun1020/bce-reranker-base_v1": {},
@@ -653,74 +631,6 @@ func isTruthy(value interface{}) bool {
}
}
// ListChatsNextRequest is the request body for the "list chats next" call.
type ListChatsNextRequest struct {
// OwnerIDs optionally restricts the listing to chats of these owners.
OwnerIDs []string `json:"owner_ids,omitempty"`
}
// ListChatsNextResponse is the result returned by ChatService.ListChatsNext.
type ListChatsNextResponse struct {
// Chats is the returned page of chats; it is serialized under the "dialogs" key.
Chats []*ChatWithKBNames `json:"dialogs"`
// Total is the count reported by the DAO query that produced Chats.
Total int64 `json:"total"`
}
// ListChatsNext lists chats with advanced filtering (equivalent to
// list_dialogs_next) and attaches dataset names and IDs to each chat.
//
// When ownerIDs is empty, the chats of every tenant the user belongs to are
// listed and the DAO paginates. Otherwise the DAO returns all chats of the
// given owners and the page is cut out in memory; that cut is skipped unless
// both page and pageSize are positive.
//
// It returns the page of chats with the DAO-reported total, or the first DAO
// error encountered.
func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSize int, orderby string, desc bool, ownerIDs []string) (*ListChatsNextResponse, error) {
	var (
		chats []*entity.Chat
		total int64
	)

	if len(ownerIDs) == 0 {
		// Get tenant IDs by user ID (joined tenants).
		tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
		if err != nil {
			return nil, err
		}

		// Use database pagination.
		chats, total, err = s.chatDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords)
		if err != nil {
			return nil, err
		}
	} else {
		// Filter by owner IDs; pagination is done manually below.
		var err error
		chats, total, err = s.chatDAO.ListByOwnerIDs(ownerIDs, userID, orderby, desc, keywords)
		if err != nil {
			return nil, err
		}

		if page > 0 && pageSize > 0 {
			// Bound the window by the rows actually held, not only by the
			// reported total: a slice expression past len(chats) would panic
			// or expose elements beyond the slice length.
			limit := len(chats)
			if int(total) < limit {
				limit = int(total)
			}

			start := (page - 1) * pageSize
			if start >= limit {
				chats = []*entity.Chat{}
			} else {
				end := start + pageSize
				if end > limit {
					end = limit
				}
				chats = chats[start:end]
			}
		}
	}

	// Enrich with knowledge base names.
	chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats))
	for _, chat := range chats {
		kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
		chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{
			Chat:       chat,
			KBNames:    kbNames,
			DatasetIDs: datasetIDs,
		})
	}

	return &ListChatsNextResponse{
		Chats: chatsWithKBNames,
		Total: total,
	}, nil
}
// getDatasetNamesAndIDs gets knowledge base names by IDs
func (s *ChatService) getDatasetNamesAndIDs(kbIDs entity.JSONSlice) ([]string, []string) {
var names = make([]string, 0, 0)
@@ -743,30 +653,6 @@ func (s *ChatService) getDatasetNamesAndIDs(kbIDs entity.JSONSlice) ([]string, [
return names, ids
}
// ParameterConfig describes one entry of prompt_config.parameters.
type ParameterConfig struct {
// Key is the placeholder name; SetDialog looks for "{Key}" in the system prompt.
Key string `json:"key"`
// Optional parameters are skipped by SetDialog's "is the placeholder used" validation.
Optional bool `json:"optional"`
}
// PromptConfig is the prompt configuration of a chat.
// The pointer-typed fields use nil to mean "not provided"; applyPromptDefaults
// and SetDialog rely on that to tell a missing value from an explicit one.
type PromptConfig struct {
System *string `json:"system"`
Prologue *string `json:"prologue"`
Parameters []ParameterConfig `json:"parameters"`
EmptyResponse *string `json:"empty_response"`
TavilyAPIKey string `json:"tavily_api_key,omitempty"`
Keyword *bool `json:"keyword,omitempty"`
Quote *bool `json:"quote,omitempty"`
Reasoning *bool `json:"reasoning,omitempty"`
RefineMultiturn *bool `json:"refine_multiturn,omitempty"`
TocEnhance *bool `json:"toc_enhance,omitempty"`
TTS *bool `json:"tts,omitempty"`
UseKG *bool `json:"use_kg,omitempty"`
CrossLanguages []string `json:"cross_languages,omitempty"`
ReferenceMetadata map[string]interface{} `json:"reference_metadata,omitempty"`
}
const (
pyDefaultSystemPrompt = "You are an intelligent assistant. Please summarize the content of the dataset to answer the question. " +
"Please list the data in the dataset and answer in detail. " +
@@ -781,393 +667,6 @@ const (
pyDefaultEmptyResponse = "Sorry! No relevant content was found in the knowledge base!"
)
// applyPromptDefaults fills in, in place, every prompt setting that p leaves
// unset. Values that are already present are never overwritten, except an
// empty system prompt, which is treated like a missing one.
func applyPromptDefaults(p *PromptConfig) {
	if p.System == nil || *p.System == "" {
		p.System = new(string)
		*p.System = pyDefaultSystemPrompt
	}

	if p.Prologue == nil {
		p.Prologue = new(string)
		*p.Prologue = pyDefaultPrologue
	}

	if p.Parameters == nil {
		p.Parameters = []ParameterConfig{{Key: "knowledge", Optional: false}}
	}

	if p.EmptyResponse == nil {
		p.EmptyResponse = new(string)
		*p.EmptyResponse = pyDefaultEmptyResponse
	}

	// Quoting and multi-turn refinement default to on.
	if p.Quote == nil {
		p.Quote = new(bool)
		*p.Quote = true
	}
	if p.RefineMultiturn == nil {
		p.RefineMultiturn = new(bool)
		*p.RefineMultiturn = true
	}

	// TTS defaults to off; new(bool) already points at false.
	if p.TTS == nil {
		p.TTS = new(bool)
	}
}
// SetDialogRequest is the input of ChatService.SetDialog.
// An empty DialogID selects the create path; otherwise the chat with that ID
// is updated. SetDialog substitutes defaults for zero-valued Name,
// Description, TopN, TopK, SimilarityThreshold and VectorSimilarityWeight,
// and for nil LLMSetting, MetaDataFilter and KBIDs.
type SetDialogRequest struct {
DialogID string `json:"dialog_id,omitempty"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Icon string `json:"icon,omitempty"`
TopN int64 `json:"top_n,omitempty"`
TopK int64 `json:"top_k,omitempty"`
RerankID string `json:"rerank_id,omitempty"`
SimilarityThreshold float64 `json:"similarity_threshold,omitempty"`
VectorSimilarityWeight float64 `json:"vector_similarity_weight,omitempty"`
LLMSetting map[string]interface{} `json:"llm_setting,omitempty"`
MetaDataFilter map[string]interface{} `json:"meta_data_filter,omitempty"`
// PromptConfig is the only field tagged as required for binding.
PromptConfig *PromptConfig `json:"prompt_config" binding:"required"`
KBIDs []string `json:"kb_ids,omitempty"`
// LLMID falls back to the tenant's LLMID when empty.
LLMID string `json:"llm_id,omitempty"`
}
// SetDialogResponse is the result of ChatService.SetDialog: the stored chat
// plus the names of its datasets.
type SetDialogResponse struct {
*entity.Chat
KBNames []string `json:"kb_names"`
}
// SetDialog creates a chat when req.DialogID is empty and otherwise updates
// the chat with that ID.
//
// Steps, in order: normalise and validate the name, pick the tenant (first
// tenant ID returned for the user, else userID), de-duplicate the name on
// create, substitute defaults for zero-valued settings, validate the prompt
// parameters and the datasets' embedding models, resolve the LLM ID, flatten
// the prompt config into a JSON map, then create or update the record.
//
// It returns the stored chat together with its dataset names, or an error
// describing the first validation or persistence failure.
func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialogResponse, error) {
// Determine if this is a create or update operation
isCreate := req.DialogID == ""
// Validate and process name
name := req.Name
if name == "" {
name = "New Chat"
}
// Validate name type and content
// An empty name was replaced above, so this only rejects whitespace-only names.
if strings.TrimSpace(name) == "" {
return nil, errors.New("Chat name can't be empty")
}
// Check name length (UTF-8 byte length)
// The length is measured before trimming, so surrounding whitespace counts.
if len(name) > 255 {
return nil, fmt.Errorf("Chat name length is %d which is larger than 255", len(name))
}
name = strings.TrimSpace(name)
// Get tenant ID (use userID as default tenant)
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
if err != nil {
return nil, err
}
var tenantID string
if len(tenantIDs) > 0 {
tenantID = tenantIDs[0]
} else {
tenantID = userID
}
// For create: check for duplicate names and generate unique name
if isCreate {
// NOTE(review): "1" is presumably the "valid" status value (new chats are stored with Status "1" below) — confirm.
existingNames, err := s.chatDAO.GetExistingNames(tenantID, "1")
if err != nil {
return nil, err
}
// Check if name exists (case-insensitive)
nameLower := strings.ToLower(name)
for _, existing := range existingNames {
if strings.ToLower(existing) == nameLower {
// Generate unique name
name = s.generateUniqueName(name, existingNames)
break
}
}
}
// Set default values
// These defaults are applied on update as well as on create.
description := req.Description
if description == "" {
description = "A helpful chat"
}
topN := req.TopN
if topN == 0 {
topN = 6
}
topK := req.TopK
if topK == 0 {
topK = 1024
}
rerankID := req.RerankID
similarityThreshold := req.SimilarityThreshold
if similarityThreshold == 0 {
similarityThreshold = 0.1
}
vectorSimilarityWeight := req.VectorSimilarityWeight
if vectorSimilarityWeight == 0 {
vectorSimilarityWeight = 0.3
}
llmSetting := req.LLMSetting
if llmSetting == nil {
llmSetting = make(map[string]interface{})
}
metaDataFilter := req.MetaDataFilter
if metaDataFilter == nil {
metaDataFilter = make(map[string]interface{})
}
// NOTE(review): promptConfig is dereferenced below without a nil check; this
// relies on the caller having enforced the `binding:"required"` tag — confirm.
promptConfig := req.PromptConfig
// Process kb_ids
kbIDs := req.KBIDs
if kbIDs == nil {
kbIDs = []string{}
}
// Apply default prompt config on create only
if isCreate {
applyPromptDefaults(promptConfig)
}
// Set default parameters for datasets with knowledge retrieval
// Check if parameters is missing or empty and kb_ids is provided
if len(kbIDs) > 0 && (promptConfig.Parameters == nil || len(promptConfig.Parameters) == 0) {
// Check if system prompt uses {knowledge} placeholder
if promptConfig.System != nil && strings.Contains(*promptConfig.System, "{knowledge}") {
promptConfig.Parameters = []ParameterConfig{
{Key: "knowledge", Optional: false},
}
}
}
// For update: validate that {knowledge} is not used when no KBs or Tavily
if !isCreate {
if len(kbIDs) == 0 && promptConfig.TavilyAPIKey == "" &&
promptConfig.System != nil && strings.Contains(*promptConfig.System, "{knowledge}") {
return nil, errors.New("Please remove `{knowledge}` in system prompt since no dataset / Tavily used here")
}
}
// Validate parameters
// Every non-optional parameter must appear as "{key}" in the system prompt.
for _, p := range promptConfig.Parameters {
if p.Optional {
continue
}
placeholder := fmt.Sprintf("{%s}", p.Key)
if promptConfig.System == nil || !strings.Contains(*promptConfig.System, placeholder) {
return nil, fmt.Errorf("Parameter '%s' is not used", p.Key)
}
}
// Check knowledge bases and their embedding models
if len(kbIDs) > 0 {
kbs, err := s.kbDAO.GetByIDs(kbIDs)
if err != nil {
return nil, err
}
// Check if all KBs use the same embedding model
// Each dataset is compared against the first one returned.
var embdID string
for i, kb := range kbs {
if i == 0 {
embdID = kb.EmbdID
} else {
// Extract base model name (remove vendor suffix)
embdBase := s.splitModelNameAndFactory(embdID)
kbEmbdBase := s.splitModelNameAndFactory(kb.EmbdID)
if embdBase != kbEmbdBase {
return nil, fmt.Errorf("Datasets use different embedding models: %v", getEmbdIDs(kbs))
}
}
}
}
// Get LLM ID (use tenant's default if not provided)
llmID := req.LLMID
if llmID == "" {
tenant, err := s.tenantDAO.GetByID(tenantID)
if err != nil {
// The underlying DAO error is replaced by a generic message.
return nil, errors.New("Tenant not found")
}
llmID = tenant.LLMID
}
// Convert prompt config to JSONMap
// Only fields that are set (non-nil, non-empty) are written to the map; an
// empty system prompt is omitted as well.
promptConfigMap := entity.JSONMap{}
if promptConfig.System != nil && *promptConfig.System != "" {
promptConfigMap["system"] = *promptConfig.System
}
if promptConfig.Prologue != nil {
promptConfigMap["prologue"] = *promptConfig.Prologue
}
if promptConfig.EmptyResponse != nil {
promptConfigMap["empty_response"] = *promptConfig.EmptyResponse
}
if promptConfig.Quote != nil {
promptConfigMap["quote"] = *promptConfig.Quote
}
if promptConfig.RefineMultiturn != nil {
promptConfigMap["refine_multiturn"] = *promptConfig.RefineMultiturn
}
if promptConfig.TTS != nil {
promptConfigMap["tts"] = *promptConfig.TTS
}
if promptConfig.Keyword != nil {
promptConfigMap["keyword"] = *promptConfig.Keyword
}
if promptConfig.Reasoning != nil {
promptConfigMap["reasoning"] = *promptConfig.Reasoning
}
if promptConfig.TocEnhance != nil {
promptConfigMap["toc_enhance"] = *promptConfig.TocEnhance
}
if promptConfig.UseKG != nil {
promptConfigMap["use_kg"] = *promptConfig.UseKG
}
if promptConfig.TavilyAPIKey != "" {
promptConfigMap["tavily_api_key"] = promptConfig.TavilyAPIKey
}
if len(promptConfig.CrossLanguages) > 0 {
promptConfigMap["cross_languages"] = promptConfig.CrossLanguages
}
if len(promptConfig.ReferenceMetadata) > 0 {
promptConfigMap["reference_metadata"] = promptConfig.ReferenceMetadata
}
if len(promptConfig.Parameters) > 0 {
params := make([]map[string]interface{}, len(promptConfig.Parameters))
for i, p := range promptConfig.Parameters {
params[i] = map[string]interface{}{
"key": p.Key,
"optional": p.Optional,
}
}
promptConfigMap["parameters"] = params
}
// Convert kbIDs to JSONSlice
kbIDsJSON := make(entity.JSONSlice, len(kbIDs))
for i, id := range kbIDs {
kbIDsJSON[i] = id
}
if isCreate {
// Generate UUID for new chat
newID := common.GenerateUUID()
// Set default language
language := "English"
// Create new chat
chat := &entity.Chat{
ID: newID,
TenantID: tenantID,
Name: &name,
Description: &description,
Icon: &req.Icon,
Language: &language,
LLMID: llmID,
LLMSetting: llmSetting,
PromptConfig: promptConfigMap,
MetaDataFilter: (*entity.JSONMap)(&metaDataFilter),
TopN: topN,
TopK: topK,
RerankID: rerankID,
SimilarityThreshold: similarityThreshold,
VectorSimilarityWeight: vectorSimilarityWeight,
KBIDs: kbIDsJSON,
Status: strPtr("1"),
}
if err := s.chatDAO.Create(chat); err != nil {
return nil, errors.New("Fail to new a chat")
}
// Get KB names
kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs)
return &SetDialogResponse{
Chat: chat,
KBNames: kbNames,
}, nil
}
// Update path: the fields below overwrite the stored values.
updateData := map[string]interface{}{
"name": name,
"description": description,
"icon": req.Icon,
"llm_id": llmID,
"llm_setting": llmSetting,
"prompt_config": promptConfigMap,
"meta_data_filter": metaDataFilter,
"top_n": topN,
"top_k": topK,
"rerank_id": rerankID,
"similarity_threshold": similarityThreshold,
"vector_similarity_weight": vectorSimilarityWeight,
"kb_ids": kbIDsJSON,
}
// NOTE(review): no tenant/ownership check on req.DialogID is visible in this
// function — confirm it is enforced elsewhere. Any update failure is reported
// as "Dialog not found".
if err := s.chatDAO.UpdateByID(req.DialogID, updateData); err != nil {
return nil, errors.New("Dialog not found")
}
// Get updated chat
chat, err := s.chatDAO.GetByID(req.DialogID)
if err != nil {
return nil, errors.New("Fail to update a chat")
}
// Get KB names
kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs)
return &SetDialogResponse{
Chat: chat,
KBNames: kbNames,
}, nil
}
// generateUniqueName returns name unchanged when no entry of existingNames
// equals it case-insensitively. Otherwise it returns the first
// "<base>(<n>)" candidate that is free. If name already ends in a numeric
// "(k)" group, base is the trimmed text before that group and numbering
// starts at k+1; otherwise base is name itself and numbering starts at 1.
func (s *ChatService) generateUniqueName(name string, existingNames []string) string {
	// Lower-cased lookup set of the names already in use.
	taken := make(map[string]bool)
	for _, existing := range existingNames {
		taken[strings.ToLower(existing)] = true
	}

	// Detect an existing "(k)" group so that numbering continues from it.
	stem, next := name, 1
	open := strings.LastIndex(name, "(")
	if open > 0 {
		closing := strings.LastIndex(name, ")")
		if closing > open {
			parsed, err := fmt.Sscanf(name[open+1:closing], "%d", &next)
			if err == nil && parsed == 1 {
				stem = strings.TrimSpace(name[:open])
				next++
			}
		}
	}

	candidate := name
	for taken[strings.ToLower(candidate)] {
		candidate = fmt.Sprintf("%s(%d)", stem, next)
		next++
	}
	return candidate
}
// splitModelNameAndFactory extracts the base model name (removes vendor suffix)
func (s *ChatService) splitModelNameAndFactory(embdID string) string {
// Remove vendor suffix (e.g., "model@openai" -> "model")
@@ -1629,52 +1128,6 @@ func (s *ChatService) BulkDeleteChats(userID string, req *BulkDeleteChatsRequest
return nil, errors.New(strings.Join(errorsList, "; "))
}
// RemoveChats soft-deletes the given chats by setting their status to "0".
// A chat may only be removed when its tenant ID is one of the caller's
// tenants or the caller's own user ID. The first chat that cannot be loaded
// or is not owned aborts the call before anything is written.
func (s *ChatService) RemoveChats(userID string, chatIDs []string) error {
	memberTenants, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
	if err != nil {
		return err
	}

	// Tenants the caller may act for; the user ID itself counts as one
	// (covers records where tenant_id = user_id).
	allowed := map[string]bool{userID: true}
	for _, tenantID := range memberTenants {
		allowed[tenantID] = true
	}

	// Verify ownership of every chat first, collecting the status changes.
	var pending []map[string]interface{}
	for _, id := range chatIDs {
		target, lookupErr := s.chatDAO.GetByID(id)
		if lookupErr != nil {
			return fmt.Errorf("chat not found: %s", id)
		}
		if !allowed[target.TenantID] {
			return errors.New("only owner of chat authorized for this operation")
		}
		pending = append(pending, map[string]interface{}{
			"id":     id,
			"status": "0",
		})
	}

	// Apply all status changes in one batch.
	if updateErr := s.chatDAO.UpdateManyByID(pending); updateErr != nil {
		return updateErr
	}
	return nil
}
// strPtr returns a pointer to a string
func strPtr(s string) *string {
return &s

View File

@@ -171,16 +171,23 @@ func (s *ChatPipelineService) AsyncChat(
}
// No KBs & no web search → fast-path to LLM-only chat.
hasKBs := false
for _, raw := range chat.KBIDs {
if id, ok := raw.(string); ok && id != "" {
hasKBs = true
break
}
}
useWebSearch := s.shouldUseWebSearch(chat, kwargs["internet"])
if useWebSearch {
common.Debug("web_search",
zap.Bool("kb", len(chat.KBIDs) > 0),
zap.Bool("kb", hasKBs),
zap.Bool("tavily", chat.PromptConfig != nil && chat.PromptConfig["tavily_api_key"] != "" && chat.PromptConfig["tavily_api_key"] != nil),
zap.Any("internet", kwargs["internet"]),
zap.Bool("enabled", useWebSearch))
}
if len(chat.KBIDs) == 0 && !useWebSearch {
if !hasKBs && !useWebSearch {
return s.AsyncChatSolo(ctx, chat, messages, stream)
}
@@ -1022,8 +1029,7 @@ func (s *ChatPipelineService) AsyncChat(
if stream {
// Streaming path: accumulate answer, emit deltas.
var fullAnswer string
var fullReasoning string
thinkState := &thinkStreamState{}
thinkState := &ThinkStreamState{}
chatCfg := BuildChatConfig(chat, nil)
@@ -1094,50 +1100,60 @@ func (s *ChatPipelineService) AsyncChat(
*chatDriver.ModelName, chatMessages, chatDriver.APIConfig, chatCfg,
func(answer *string, reason *string) error {
if reason != nil && *reason != "" {
fullReasoning += *reason
kind, output := processThinkDelta(thinkState, *reason, 16)
if kind == "marker" && output == "<think>" {
// <think> marker — emit StartToThink
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
StartToThink: true,
}
} else if kind == "marker" && output == "</think>" {
// </think> marker — emit EndToThink
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
EndToThink: true,
}
} else if kind == "text" && output != "" {
// Route reasoning text to Reasoning field.
// TTS is nil — chain-of-thought is not narrated.
out <- AsyncChatResult{
Reasoning: output,
Reference: map[string]interface{}{},
// TTS only narrates user-visible answer text.
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
deltas := NextThinkDelta(thinkState, *reason, 16)
for _, d := range deltas {
if d.Kind == ThinkDeltaMarker && d.Value == "<think>" {
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
StartToThink: true,
}
} else if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
EndToThink: true,
}
} else if d.Kind == ThinkDeltaText && d.Value != "" {
fullAnswer += d.Value
out <- AsyncChatResult{
Answer: d.Value,
Reference: map[string]interface{}{},
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
CreatedAt: float64(time.Now().Unix()),
Final: false,
}
}
}
}
if isContentDelta(answer) {
fullAnswer += *answer
out <- AsyncChatResult{
Answer: *answer,
Reference: map[string]interface{}{},
// Per-delta TTS for incremental audio playback.
AudioBinary: s.synthesizeTTS(ttsModel, *answer),
CreatedAt: float64(time.Now().Unix()),
Final: false,
deltas := BufferAnswerDelta(thinkState, *answer, 16)
for _, d := range deltas {
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
EndToThink: true,
}
} else if d.Kind == ThinkDeltaText && d.Value != "" {
out <- AsyncChatResult{
Answer: d.Value,
Reference: map[string]interface{}{},
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
CreatedAt: float64(time.Now().Unix()),
Final: false,
}
}
}
}
return nil
@@ -1152,33 +1168,34 @@ func (s *ChatPipelineService) AsyncChat(
return
}
// Flush remaining think stream buffer.
// If the LLM ended mid-think, emit remaining reasoning +
// implicit close marker.
remainingText, remainingMarker := flushThinkStream(thinkState)
if remainingText != "" {
// Flushed text belongs in Reasoning, not Answer.
out <- AsyncChatResult{
Reasoning: remainingText,
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
}
}
if remainingMarker == "</think>" {
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
EndToThink: true,
// Flush remaining state matching Python's final flush order
// (dialog_service.py:1601-1612): think_buffer → marker → answer_buffer → pending_after_close
// Python has no Reasoning field — all text is Answer.
for _, d := range FlushRemaining(thinkState) {
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
EndToThink: true,
}
} else if d.Kind == ThinkDeltaText && d.Value != "" {
out <- AsyncChatResult{
Answer: d.Value,
Reference: map[string]interface{}{},
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
CreatedAt: float64(time.Now().Unix()),
Final: false,
}
}
}
// Decorate and yield the final answer.
visibleAnswer := s.extractVisibleAnswer(fullReasoning + fullAnswer)
// Python uses state.full_text (raw text with <think> tags) as input
// to _extract_visible_answer → decorate_answer (dialog_service.py:914-920).
visibleAnswer := s.extractVisibleAnswer(thinkState.fullText)
// Pass nil for ttsModel — audio was already produced per-delta.
final := s.decorateAnswer(ctx, visibleAnswer, kbinfos, prompt, questions, usedTokenCount, timer, embModel, chat.VectorSimilarityWeight, quote, nil, langfuseTraceID, llmModelConfig, chat.TenantID, kbTenantIDStrings(kbs), len(knowledges) > 0)
@@ -1372,28 +1389,51 @@ func (s *ChatPipelineService) AsyncChatSolo(
// 7. Drive the LLM: stream (per-delta with think markers) or non-stream (one-shot).
if stream {
var fullAnswer string
var fullReasoning string
thinkState := &thinkStreamState{}
thinkState := &ThinkStreamState{}
chatCfg := BuildChatConfig(chat, nil)
timer.Enter(common.PhaseGenerateAnswer)
driverErr := chatModel.ModelDriver.ChatStreamlyWithSender(
*chatModel.ModelName, chatMessages, chatModel.APIConfig, chatCfg,
func(answer *string, reason *string) error {
if reason != nil && *reason != "" {
fullReasoning += *reason
kind, output := processThinkDelta(thinkState, *reason, 16)
if kind == "marker" && output == "<think>" {
// Start thinking.
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
StartToThink: true,
deltas := NextThinkDelta(thinkState, *reason, 16)
for _, d := range deltas {
if d.Kind == ThinkDeltaMarker && d.Value == "<think>" {
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
StartToThink: true,
}
} else if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
EndToThink: true,
}
} else if d.Kind == ThinkDeltaText && d.Value != "" {
fullAnswer += d.Value
out <- AsyncChatResult{
Answer: d.Value,
Reference: map[string]interface{}{},
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
CreatedAt: float64(time.Now().Unix()),
Final: false,
}
}
} else if kind == "marker" && output == "</think>" {
// End thinking.
}
}
if isContentDelta(answer) {
fullAnswer += *answer
deltas := BufferAnswerDelta(thinkState, *answer, 16)
for _, d := range deltas {
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
@@ -1402,27 +1442,17 @@ func (s *ChatPipelineService) AsyncChatSolo(
Final: false,
EndToThink: true,
}
} else if kind == "text" && output != "" {
// Reasoning text with per-delta TTS.
} else if d.Kind == ThinkDeltaText && d.Value != "" {
out <- AsyncChatResult{
Reasoning: output,
Answer: d.Value,
Reference: map[string]interface{}{},
AudioBinary: s.synthesizeTTS(ttsModel, output),
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
CreatedAt: float64(time.Now().Unix()),
Final: false,
}
}
}
if isContentDelta(answer) {
fullAnswer += *answer
out <- AsyncChatResult{
Answer: *answer,
Reference: map[string]interface{}{},
AudioBinary: s.synthesizeTTS(ttsModel, *answer),
CreatedAt: float64(time.Now().Unix()),
Final: false,
}
}
}
return nil
},
)
@@ -1434,33 +1464,30 @@ func (s *ChatPipelineService) AsyncChatSolo(
return
}
timer.Exit(common.PhaseGenerateAnswer)
// Flush any remaining think buffer.
remainingText, remainingMarker := flushThinkStream(thinkState)
if remainingText != "" {
out <- AsyncChatResult{
Reasoning: remainingText,
Reference: map[string]interface{}{},
AudioBinary: s.synthesizeTTS(ttsModel, remainingText),
CreatedAt: float64(time.Now().Unix()),
Final: false,
for _, d := range FlushRemaining(thinkState) {
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
EndToThink: true,
}
} else if d.Kind == ThinkDeltaText && d.Value != "" {
out <- AsyncChatResult{
Answer: d.Value,
Reference: map[string]interface{}{},
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
CreatedAt: float64(time.Now().Unix()),
Final: false,
}
}
}
if remainingMarker == "</think>" {
out <- AsyncChatResult{
Answer: "",
Reference: map[string]interface{}{},
AudioBinary: nil,
CreatedAt: float64(time.Now().Unix()),
Final: false,
EndToThink: true,
}
finalAnswer := ExtractVisibleAnswer(thinkState.fullText)
if finalAnswer == "" {
finalAnswer = fullAnswer
}
// Final aggregate: re-attach reasoning wrapper for non-streaming consumers.
finalAnswer := fullAnswer
if fullReasoning != "" {
finalAnswer = "<think>" + fullReasoning + "</think>" + fullAnswer
}
// Raw answer, no decorate_answer. AudioBinary=nil (per-delta TTS already emitted).
out <- AsyncChatResult{
Answer: finalAnswer,
Reference: map[string]interface{}{},
@@ -2288,7 +2315,7 @@ func (s *ChatPipelineService) kbPrompt(kbinfos map[string]interface{}, maxTokens
}
contents := make([]chunkContent, 0, len(chunksRaw))
for _, ck := range chunksRaw {
c := getMapString(ck, "content", "content_with_weight")
c := getMapString(ck, "content_with_weight", "content")
if c == "" {
continue
}
@@ -2315,7 +2342,7 @@ func (s *ChatPipelineService) kbPrompt(kbinfos map[string]interface{}, maxTokens
var result []string
for i := 0; i < chunksNum; i++ {
ck := chunksRaw[i]
c := getMapString(ck, "content", "content_with_weight")
c := getMapString(ck, "content_with_weight", "content")
if c == "" {
continue
}
@@ -2781,25 +2808,8 @@ func langfuseExtractTimeElapsed(prompt string) string {
}
// extractVisibleAnswer mirrors Python's _extract_visible_answer.
// It preserves <think> wrappers and strips stray think tags.
//
// When text has no "</think>", every think tag is removed and the rest is
// returned. Otherwise text is split at the LAST "</think>": the part before it
// is the thought (tags removed, whitespace trimmed) and the part after it is
// the answer (tags removed). An empty thought yields just the answer; a
// non-empty one is re-wrapped as "<think>" + thought + "</think>" + answer.
func (s *ChatPipelineService) extractVisibleAnswer(text string) string {
// No closing tag: nothing can be treated as a completed thought, so drop all tags.
if !strings.Contains(text, "</think>") {
text = strings.ReplaceAll(text, "<think>", "")
text = strings.ReplaceAll(text, "</think>", "")
return text
}
// Split on the last closing tag so earlier stray tags end up inside the thought.
idx := strings.LastIndex(text, "</think>")
thought := text[:idx]
answer := text[idx+len("</think>"):]
thought = strings.ReplaceAll(thought, "<think>", "")
thought = strings.ReplaceAll(thought, "</think>", "")
thought = strings.TrimSpace(thought)
answer = strings.ReplaceAll(answer, "<think>", "")
answer = strings.ReplaceAll(answer, "</think>", "")
if thought == "" {
return answer
}
return "<think>" + thought + "</think>" + answer
// NOTE(review): this statement is unreachable because every path above has
// already returned. Presumably the body is meant to delegate entirely to
// ExtractVisibleAnswer; confirm which implementation should remain.
return ExtractVisibleAnswer(text)
}
// citationPrompt returns the citation instruction prompt.
@@ -2809,124 +2819,6 @@ func citationPrompt() string {
"(where N is the chunk number) after each sentence where the information from that chunk is used."
}
// ---------------------------------------------------------------------------
// Think-marker streaming — mirrors Python's _stream_with_think_delta.
// ---------------------------------------------------------------------------
// thinkStreamState tracks accumulated reasoning text and emits deltas.
// It is the mutable state threaded through nextThinkDelta, processThinkDelta
// and flushThinkStream.
type thinkStreamState struct {
// fullText is every delta received so far, concatenated by processThinkDelta.
fullText string
// lastIdx is the byte offset into fullText up to which nextThinkDelta has consumed.
lastIdx int
// endsWithThink is set by nextThinkDelta when the unconsumed text ends with
// "</think>", and cleared when that closing marker is emitted or flushed.
endsWithThink bool
// inThink is true after a "<think>" has been accepted by processThinkDelta
// and until the matching "</think>" marker is returned.
inThink bool
// buffer holds text that has not yet reached the minimum token count for emission.
buffer string
// postThinkText holds text that followed "</think>" in the same delta; it is
// moved into buffer by processThinkDelta or returned by flushThinkStream.
postThinkText string
}
// nextThinkDelta returns the next not-yet-consumed piece of state.fullText and
// advances state.lastIdx accordingly.
// Mirrors _next_think_delta in dialog_service.py:1460-1487.
//
// The result is one of:
//   - "" when there is nothing new to consume;
//   - "<think>" when the pending text starts with the opening tag;
//   - the text preceding an opening tag found later in the pending text;
//   - "</think>" once text arrives after a pending closing tag (any text after
//     the tag is stored in state.postThinkText);
//   - otherwise the pending text with think tags removed.
func nextThinkDelta(state *thinkStreamState) string {
	const (
		openTag  = "<think>"
		closeTag = "</think>"
	)

	accumulated := state.fullText
	if accumulated == "" || state.lastIdx >= len(accumulated) {
		return ""
	}
	pending := accumulated[state.lastIdx:]

	// An opening tag at the very front is emitted on its own.
	if strings.HasPrefix(pending, openTag) {
		state.lastIdx += len(openTag)
		return openTag
	}
	// Text in front of a later opening tag goes out first; the tag itself is
	// then at the front of the pending text on the following call.
	if openAt := strings.Index(pending, openTag); openAt > 0 {
		state.lastIdx += openAt
		return pending[:openAt]
	}

	switch {
	case strings.HasSuffix(pending, closeTag):
		// Remember the trailing closing tag; it is reported once more text
		// arrives (or by the final flush).
		state.endsWithThink = true
	case state.endsWithThink:
		state.endsWithThink = false
		trailing := pending
		if closeAt := strings.Index(pending, closeTag); closeAt >= 0 {
			trailing = pending[closeAt+len(closeTag):]
		}
		if trailing != "" {
			state.postThinkText = trailing
		}
		state.lastIdx = len(accumulated)
		return closeTag
	}

	// Consume everything, except that a trailing closing tag stays pending.
	state.lastIdx = len(accumulated)
	if strings.HasSuffix(accumulated, closeTag) {
		state.lastIdx -= len(closeTag)
	}
	withoutOpen := strings.ReplaceAll(pending, openTag, "")
	return strings.ReplaceAll(withoutOpen, closeTag, "")
}
// processThinkDelta appends delta to the accumulated text and reports what, if
// anything, should be emitted for it.
//
// kind is "marker" when output is a think tag ("<think>" or "</think>"),
// "text" when output is buffered content, and "" when nothing is emitted.
// Plain text is held in state.buffer until kg.NumTokensFromString reports at
// least minTokens for it.
//
// NOTE(review): when an opening tag arrives while the buffer is non-empty, the
// buffered text is returned and the "<think>" marker itself is not emitted by
// this call; confirm that callers do not rely on seeing it.
func processThinkDelta(state *thinkStreamState, delta string, minTokens int) (kind string, output string) {
	if delta == "" {
		return "", ""
	}
	state.fullText += delta

	switch piece := nextThinkDelta(state); piece {
	case "":
		return "", ""

	case "<think>":
		// A nested opening tag is ignored.
		if state.inThink {
			return "", ""
		}
		state.inThink = true
		if pendingText := state.buffer; pendingText != "" {
			state.buffer = ""
			return "text", pendingText
		}
		return "marker", "<think>"

	case "</think>":
		// A closing tag without a matching opening tag is ignored.
		if !state.inThink {
			return "", ""
		}
		state.inThink = false
		// Text that followed the closing tag joins the buffer for later emission.
		if state.postThinkText != "" {
			state.buffer += state.postThinkText
			state.postThinkText = ""
		}
		return "marker", "</think>"

	default:
		state.buffer += piece
		if kg.NumTokensFromString(state.buffer) < minTokens {
			return "", ""
		}
		ready := state.buffer
		state.buffer = ""
		return "text", ready
	}
}
// flushThinkStream drains whatever the think stream is still holding.
//
// text is state.buffer followed by state.postThinkText (either may be empty);
// both fields are cleared. marker is "</think>" when a closing tag was still
// pending (state.endsWithThink), in which case that flag is cleared as well.
func flushThinkStream(state *thinkStreamState) (text string, marker string) {
	text = state.buffer + state.postThinkText
	state.buffer, state.postThinkText = "", ""

	if state.endsWithThink {
		state.endsWithThink = false
		marker = "</think>"
	}
	return text, marker
}
// -----------------------------------------------------------------------
// Moved from sql_fallback.go (2026-06-12). SQL retrieval system, repair
// helpers, and Python parity helpers. Kept in async_chat.go because the

File diff suppressed because it is too large Load Diff

View File

@@ -26,7 +26,6 @@ type fakeSessionStore struct {
getByIDErr error
createErr error
updateByIDErr error
deleteByIDErr error
getDialogErr error
// record calls
createCalled []*entity.ChatSession
@@ -34,7 +33,6 @@ type fakeSessionStore struct {
id string
updates map[string]interface{}
}
deleteByIDIDs []string
}
func newFakeSessionStore() *fakeSessionStore {
@@ -112,12 +110,6 @@ func (f *fakeSessionStore) UpdateByID(id string, updates map[string]interface{})
}
// DeleteByID removes the session with the given id from the fake store and
// records the id in deleteByIDIDs. When deleteByIDErr is set it is returned
// instead and nothing is recorded or removed.
func (f *fakeSessionStore) DeleteByID(id string) error {
	f.mu.Lock()
	defer f.mu.Unlock()

	if injected := f.deleteByIDErr; injected != nil {
		return injected
	}
	delete(f.sessions, id)
	f.deleteByIDIDs = append(f.deleteByIDIDs, id)
	return nil
}
@@ -179,226 +171,6 @@ func makeResultChan(results ...AsyncChatResult) <-chan AsyncChatResult {
return ch
}
// ===================================================================
// SetChatSession tests
// ===================================================================
// TestSetChatSession_CreateNew expects a new session to get a generated ID, to
// be bound to the requested dialog, and to start with the dialog's configured
// prologue as its only message.
func TestSetChatSession_CreateNew(t *testing.T) {
	sessions := newFakeSessionStore()
	sessions.dialogs["dialog-1"] = &entity.Chat{
		ID:           "dialog-1",
		PromptConfig: entity.JSONMap{"prologue": "Welcome!"},
	}
	service := &ChatSessionService{
		chatSessionDAO: sessions,
		userTenantDAO:  &fakeTenantStore{},
		pipeline:       &fakePipeline{},
	}

	req := &SetChatSessionRequest{DialogID: "dialog-1", IsNew: true}
	resp, err := service.SetChatSession("user-1", req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.ID == "" {
		t.Fatal("expected session ID to be generated")
	}
	if resp.DialogID != "dialog-1" {
		t.Fatalf("expected dialog_id=dialog-1, got %s", resp.DialogID)
	}
	if calls := len(sessions.createCalled); calls != 1 {
		t.Fatalf("expected 1 Create call, got %d", calls)
	}

	// The persisted message list must consist of the prologue only.
	var stored []map[string]interface{}
	if err := json.Unmarshal(sessions.createCalled[0].Message, &stored); err != nil {
		t.Fatalf("failed to unmarshal message: %v", err)
	}
	if len(stored) != 1 {
		t.Fatalf("expected 1 initial message, got %d", len(stored))
	}
	prologue := stored[0]
	if prologue["role"] != "assistant" || prologue["content"] != "Welcome!" {
		t.Fatalf("unexpected prologue message: %#v", prologue)
	}
}
// TestSetChatSession_CreateNewDefaultPrologue expects a dialog without a
// configured prologue to seed the new session with the default greeting.
func TestSetChatSession_CreateNewDefaultPrologue(t *testing.T) {
	store := newFakeSessionStore()
	store.dialogs["dialog-1"] = &entity.Chat{ID: "dialog-1"}
	svc := &ChatSessionService{
		chatSessionDAO: store,
		userTenantDAO:  &fakeTenantStore{},
		pipeline:       &fakePipeline{},
	}
	resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
		DialogID: "dialog-1",
		IsNew:    true,
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.ID == "" {
		t.Fatal("expected session ID")
	}

	// Guard every step that could otherwise panic, so a regression shows up as
	// a readable test failure rather than an index or type-assertion panic.
	if len(store.createCalled) != 1 {
		t.Fatalf("expected 1 Create call, got %d", len(store.createCalled))
	}
	var msgs []map[string]interface{}
	if err := json.Unmarshal(store.createCalled[0].Message, &msgs); err != nil {
		t.Fatalf("failed to unmarshal message: %v", err)
	}
	if len(msgs) == 0 {
		t.Fatal("expected at least one initial message, got none")
	}

	// Default prologue
	content, ok := msgs[0]["content"].(string)
	if !ok || !strings.Contains(content, "Hi! I'm your assistant") {
		t.Fatalf("expected default prologue, got %#v", msgs[0]["content"])
	}
}
func TestSetChatSession_CreateNewDialogNotFound(t *testing.T) {
store := newFakeSessionStore()
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
DialogID: "nonexistent",
IsNew: true,
})
if err == nil || err.Error() != "Dialog not found" {
t.Fatalf("expected 'Dialog not found' error, got %v", err)
}
}
// TestSetChatSession_UpdateExisting expects a non-new request for an existing
// session to return that session and issue exactly one UpdateByID call.
func TestSetChatSession_UpdateExisting(t *testing.T) {
	sessions := newFakeSessionStore()
	sessions.sessions["session-1"] = &entity.ChatSession{
		ID:       "session-1",
		DialogID: "dialog-1",
		Name:     strPtr("old name"),
	}
	service := &ChatSessionService{
		chatSessionDAO: sessions,
		userTenantDAO:  &fakeTenantStore{},
		pipeline:       &fakePipeline{},
	}

	req := &SetChatSessionRequest{SessionID: "session-1", Name: "new name", IsNew: false}
	resp, err := service.SetChatSession("user-1", req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.ID != "session-1" {
		t.Fatalf("expected session-1, got %s", resp.ID)
	}
	if updates := len(sessions.updateCalled); updates != 1 {
		t.Fatalf("expected UpdateByID call, got %d", updates)
	}
}
// TestSetChatSession_UpdateNotFound expects the store's update error to be
// returned to the caller when the session being updated does not exist.
func TestSetChatSession_UpdateNotFound(t *testing.T) {
	sessions := newFakeSessionStore()
	sessions.updateByIDErr = errors.New("Chat session not found")
	service := &ChatSessionService{
		chatSessionDAO: sessions,
		userTenantDAO:  &fakeTenantStore{},
		pipeline:       &fakePipeline{},
	}

	req := &SetChatSessionRequest{SessionID: "missing", IsNew: false}
	if _, err := service.SetChatSession("user-1", req); err == nil || err.Error() != "Chat session not found" {
		t.Fatalf("expected 'Chat session not found' error, got %v", err)
	}
}
// TestSetChatSession_NameTruncation expects a 300-byte session name to come
// back set and no longer than 255 bytes.
func TestSetChatSession_NameTruncation(t *testing.T) {
	store := newFakeSessionStore()
	store.dialogs["dialog-1"] = &entity.Chat{ID: "dialog-1"}
	svc := &ChatSessionService{
		chatSessionDAO: store,
		userTenantDAO:  &fakeTenantStore{},
		pipeline:       &fakePipeline{},
	}
	longName := strings.Repeat("x", 300)
	resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
		DialogID: "dialog-1",
		Name:     longName,
		IsNew:    true,
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Check nil on its own: dereferencing resp.Name inside the failure message
	// would panic in exactly the case the test is meant to report.
	if resp.Name == nil {
		t.Fatal("expected name truncated to <=255, got nil name")
	}
	if len(*resp.Name) > 255 {
		t.Fatalf("expected name truncated to <=255, got len=%d", len(*resp.Name))
	}
}
// ===================================================================
// RemoveChatSessions tests
// ===================================================================
// TestRemoveChatSessions_Success expects both requested sessions to be deleted
// when the user owns the dialog they belong to.
func TestRemoveChatSessions_Success(t *testing.T) {
	sessions := newFakeSessionStore()
	ids := []string{"conv-1", "conv-2"}
	for _, id := range ids {
		sessions.sessions[id] = &entity.ChatSession{ID: id, DialogID: "dialog-1"}
	}
	sessions.dialogExists["user-1|dialog-1"] = true

	service := &ChatSessionService{
		chatSessionDAO: sessions,
		userTenantDAO:  &fakeTenantStore{tenantIDs: []string{"tenant-1"}},
		pipeline:       &fakePipeline{},
	}

	if err := service.RemoveChatSessions("user-1", ids); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if deleted := len(sessions.deleteByIDIDs); deleted != 2 {
		t.Fatalf("expected 2 deletes, got %d", deleted)
	}
}
func TestRemoveChatSessions_SessionNotFound(t *testing.T) {
store := newFakeSessionStore()
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}},
pipeline: &fakePipeline{},
}
err := svc.RemoveChatSessions("user-1", []string{"missing"})
if err == nil || !strings.Contains(err.Error(), "not found") {
t.Fatalf("expected 'not found' error, got %v", err)
}
}
// TestRemoveChatSessions_NotOwner expects removal to be rejected with an
// "Only owner" error when no dialog-ownership entry exists for the user.
func TestRemoveChatSessions_NotOwner(t *testing.T) {
	sessions := newFakeSessionStore()
	sessions.sessions["conv-1"] = &entity.ChatSession{ID: "conv-1", DialogID: "dialog-1"}
	// dialogExists is left untouched, so no ownership lookup can succeed.
	service := &ChatSessionService{
		chatSessionDAO: sessions,
		userTenantDAO:  &fakeTenantStore{tenantIDs: []string{"tenant-other"}},
		pipeline:       &fakePipeline{},
	}

	err := service.RemoveChatSessions("user-1", []string{"conv-1"})
	if err != nil && strings.Contains(err.Error(), "Only owner") {
		return
	}
	t.Fatalf("expected 'Only owner' error, got %v", err)
}
// ===================================================================
// ListChatSessions tests
// ===================================================================
@@ -638,301 +410,6 @@ func TestUpdateSession_NotFound(t *testing.T) {
}
}
// ===================================================================
// Completion tests
// ===================================================================
// TestCompletion_Success expects Completion to concatenate the pipeline's
// partial answers and to persist prologue, user turn and assistant reply (with
// the supplied message id) on the session.
func TestCompletion_Success(t *testing.T) {
	emptyRef := func() map[string]interface{} {
		return map[string]interface{}{"chunks": []interface{}{}}
	}

	sessions := newFakeSessionStore()
	sessions.sessions["session-1"] = &entity.ChatSession{
		ID:        "session-1",
		DialogID:  "dialog-1",
		Message:   json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
		Reference: json.RawMessage(`[]`),
	}
	sessions.dialogs["dialog-1"] = &entity.Chat{
		ID:         "dialog-1",
		TenantID:   "tenant-1",
		LLMID:      "chat@factory",
		LLMSetting: entity.JSONMap{},
	}
	service := &ChatSessionService{
		chatSessionDAO: sessions,
		userTenantDAO:  &fakeTenantStore{},
		pipeline: &fakePipeline{
			resultChan: makeResultChan(
				AsyncChatResult{Answer: "Hello", Reference: emptyRef()},
				AsyncChatResult{Answer: " world", Final: true, Reference: emptyRef()},
			),
		},
	}

	userTurn := []map[string]interface{}{{"role": "user", "content": "hi"}}
	result, err := service.Completion("user-1", "session-1", userTurn, "", nil, "msg-1")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if answer, _ := result["answer"].(string); answer != "Hello world" {
		t.Fatalf("expected answer 'Hello world', got %q", answer)
	}

	stored := parseMessages(sessions.sessions["session-1"].Message)
	if len(stored) != 3 {
		t.Fatalf("stored messages=%#v", stored)
	}
	prologue, question, reply := stored[0], stored[1], stored[2]
	if prologue["role"] != "assistant" || prologue["content"] != "Welcome!" {
		t.Fatalf("stored prologue=%#v", prologue)
	}
	if question["role"] != "user" || question["content"] != "hi" {
		t.Fatalf("stored user message=%#v", question)
	}
	if reply["role"] != "assistant" || reply["content"] != "Hello world" || reply["id"] != "msg-1" {
		t.Fatalf("stored assistant message=%#v", reply)
	}
}
// TestCompletion_EmptyMessages expects Completion to reject a nil message list
// with "messages cannot be empty".
func TestCompletion_EmptyMessages(t *testing.T) {
	service := &ChatSessionService{
		chatSessionDAO: &fakeSessionStore{},
		userTenantDAO:  &fakeTenantStore{},
		pipeline:       &fakePipeline{},
	}

	if _, err := service.Completion("user-1", "session-1", nil, "", nil, "msg-1"); err == nil || err.Error() != "messages cannot be empty" {
		t.Fatalf("expected 'messages cannot be empty', got %v", err)
	}
}
// TestCompletion_LastMessageNotFromUser expects Completion to fail with a
// "not from user" error when the final message has the assistant role.
func TestCompletion_LastMessageNotFromUser(t *testing.T) {
	service := &ChatSessionService{
		chatSessionDAO: &fakeSessionStore{},
		userTenantDAO:  &fakeTenantStore{},
		pipeline:       &fakePipeline{},
	}

	assistantOnly := []map[string]interface{}{{"role": "assistant", "content": "hello"}}
	_, err := service.Completion("user-1", "session-1", assistantOnly, "", nil, "msg-1")
	if err != nil && strings.Contains(err.Error(), "not from user") {
		return
	}
	t.Fatalf("expected 'not from user' error, got %v", err)
}
func TestCompletion_ConversationNotFound(t *testing.T) {
store := newFakeSessionStore()
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, err := svc.Completion("user-1", "missing", []map[string]interface{}{
{"role": "user", "content": "hi"},
}, "", nil, "msg-1")
if err == nil || err.Error() != "Conversation not found" {
t.Fatalf("expected 'Conversation not found', got %v", err)
}
}
// TestCompletion_DialogNotFound expects Completion to fail with
// "Dialog not found" when the session exists but its dialog does not.
func TestCompletion_DialogNotFound(t *testing.T) {
	sessions := newFakeSessionStore()
	sessions.sessions["session-1"] = &entity.ChatSession{
		ID:        "session-1",
		DialogID:  "dialog-1",
		Message:   json.RawMessage(`[]`),
		Reference: json.RawMessage(`[]`),
	}
	service := &ChatSessionService{
		chatSessionDAO: sessions,
		userTenantDAO:  &fakeTenantStore{},
		pipeline:       &fakePipeline{},
	}

	userTurn := []map[string]interface{}{{"role": "user", "content": "hi"}}
	if _, err := service.Completion("user-1", "session-1", userTurn, "", nil, "msg-1"); err == nil || err.Error() != "Dialog not found" {
		t.Fatalf("expected 'Dialog not found', got %v", err)
	}
}
// TestCompletion_PipelineError expects an error raised by the chat pipeline to
// be returned unchanged by Completion.
func TestCompletion_PipelineError(t *testing.T) {
	sessions := newFakeSessionStore()
	sessions.sessions["session-1"] = &entity.ChatSession{
		ID:        "session-1",
		DialogID:  "dialog-1",
		Message:   json.RawMessage(`[]`),
		Reference: json.RawMessage(`[]`),
	}
	sessions.dialogs["dialog-1"] = &entity.Chat{
		ID:         "dialog-1",
		TenantID:   "tenant-1",
		LLMID:      "chat@factory",
		LLMSetting: entity.JSONMap{},
	}
	service := &ChatSessionService{
		chatSessionDAO: sessions,
		userTenantDAO:  &fakeTenantStore{},
		pipeline:       &fakePipeline{err: errors.New("model unavailable")},
	}

	userTurn := []map[string]interface{}{{"role": "user", "content": "hi"}}
	if _, err := service.Completion("user-1", "session-1", userTurn, "", nil, "msg-1"); err == nil || err.Error() != "model unavailable" {
		t.Fatalf("expected 'model unavailable' error, got %v", err)
	}
}
// ===================================================================
// CompletionStream tests
// ===================================================================
// readStreamChan drains up to n messages that are immediately available on ch
// without ever blocking. It stops early when the channel is empty or closed
// and returns the messages collected so far (nil when there were none).
func readStreamChan(ch <-chan string, n int) []string {
	var collected []string
	for remaining := n; remaining > 0; remaining-- {
		select {
		case item, open := <-ch:
			if !open {
				return collected
			}
			collected = append(collected, item)
		default:
			// Nothing buffered right now; do not wait for more.
			return collected
		}
	}
	return collected
}
// TestCompletionStream_Success expects CompletionStream to push data events
// followed by a final `"data":true` signal, and to persist prologue, user turn
// and the concatenated assistant reply on a session whose stored messages use
// the legacy {"messages": [...]} wrapper.
func TestCompletionStream_Success(t *testing.T) {
	emptyRef := func() map[string]interface{} {
		return map[string]interface{}{"chunks": []interface{}{}}
	}

	sessions := newFakeSessionStore()
	sessions.sessions["session-1"] = &entity.ChatSession{
		ID:        "session-1",
		DialogID:  "dialog-1",
		Message:   json.RawMessage(`{"messages":[{"role":"assistant","content":"Welcome!"}]}`),
		Reference: json.RawMessage(`[]`),
	}
	sessions.dialogs["dialog-1"] = &entity.Chat{
		ID:         "dialog-1",
		TenantID:   "tenant-1",
		LLMID:      "chat@factory",
		LLMSetting: entity.JSONMap{},
	}
	service := &ChatSessionService{
		chatSessionDAO: sessions,
		userTenantDAO:  &fakeTenantStore{},
		pipeline: &fakePipeline{
			resultChan: makeResultChan(
				AsyncChatResult{Answer: "stream", Reference: emptyRef()},
				AsyncChatResult{Answer: " answer", Reference: emptyRef()},
			),
		},
	}

	events := make(chan string, 10)
	userTurn := []map[string]interface{}{{"role": "user", "content": "hi"}}
	if err := service.CompletionStream(context.Background(), "user-1", "session-1", userTurn, "", nil, "msg-1", events); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Should receive data events and final signal
	received := readStreamChan(events, 5)
	if len(received) < 3 {
		t.Fatalf("expected at least 3 stream messages, got %d: %v", len(received), received)
	}

	// Check final signal
	sawFinal := false
	for i := 0; i < len(received) && !sawFinal; i++ {
		sawFinal = strings.Contains(received[i], `"data":true`)
	}
	if !sawFinal {
		t.Fatal("expected final=true signal in stream")
	}

	stored := parseMessages(sessions.sessions["session-1"].Message)
	if len(stored) != 3 {
		t.Fatalf("stored messages=%#v", stored)
	}
	prologue, question, reply := stored[0], stored[1], stored[2]
	if prologue["role"] != "assistant" || prologue["content"] != "Welcome!" {
		t.Fatalf("stored prologue=%#v", prologue)
	}
	if question["role"] != "user" || question["content"] != "hi" {
		t.Fatalf("stored user message=%#v", question)
	}
	if reply["role"] != "assistant" || reply["content"] != "stream answer" || reply["id"] != "msg-1" {
		t.Fatalf("stored assistant message=%#v", reply)
	}
}
// TestStructureAnswerWithConv_ParsesArrayMessages expects structureAnswerWithConv
// to tag the answer with the message and session ids and to overwrite the
// single array-encoded stored message with the final answer.
func TestStructureAnswerWithConv_ParsesArrayMessages(t *testing.T) {
	conv := &entity.ChatSession{
		ID:      "session-1",
		Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
	}
	finalAnswer := map[string]interface{}{
		"answer":    "Final answer",
		"reference": map[string]interface{}{"chunks": []interface{}{}},
		"final":     true,
	}
	references := []interface{}{
		map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}},
	}

	service := &ChatSessionService{}
	structured := service.structureAnswerWithConv(conv, finalAnswer, "msg-1", "session-1", references)
	if structured["id"] != "msg-1" || structured["session_id"] != "session-1" {
		t.Fatalf("ans=%#v", structured)
	}

	stored := parseMessages(conv.Message)
	if len(stored) != 1 {
		t.Fatalf("stored messages=%#v", stored)
	}
	if only := stored[0]; only["role"] != "assistant" || only["content"] != "Final answer" || only["id"] != "msg-1" {
		t.Fatalf("stored assistant message=%#v", only)
	}
}
// TestParseMessages_LegacyWrappedObject expects parseMessages to unwrap the
// legacy {"messages": [...]} encoding into a plain message slice.
func TestParseMessages_LegacyWrappedObject(t *testing.T) {
	raw := json.RawMessage(`{"messages":[{"role":"assistant","content":"legacy"}]}`)
	want := []map[string]interface{}{{"role": "assistant", "content": "legacy"}}

	if got := parseMessages(raw); !reflect.DeepEqual(got, want) {
		t.Fatalf("messages=%#v", got)
	}
}
func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) {
svc := &ChatSessionService{}
payload := svc.buildSessionPayload(&entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: nil,
Reference: json.RawMessage(`null`),
}, nil, false)
if payload.Messages == nil {
t.Fatal("messages is nil")
}
if payload.Reference == nil {
t.Fatal("reference is nil")
}
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if !strings.Contains(string(body), `"messages":[]`) {
t.Fatalf("messages did not encode as empty array: %s", string(body))
}
if !strings.Contains(string(body), `"reference":[]`) {
t.Fatalf("reference did not encode as empty array: %s", string(body))
}
}
func TestParseCollections_ReturnEmptySlicesForMissingOrNull(t *testing.T) {
messageInputs := []json.RawMessage{
nil,
@@ -980,99 +457,174 @@ func TestParseCollections_ReturnNilForMalformedData(t *testing.T) {
}
}
func TestCompletionStream_EmptyMessages(t *testing.T) {
svc := &ChatSessionService{
chatSessionDAO: &fakeSessionStore{},
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
// ===================================================================
// chunksFormat tests — verifies field normalization after the rewrite.
// ===================================================================
func TestChunksFormat_NormalizesRawFieldNames(t *testing.T) {
svc := &ChatSessionService{}
ref := map[string]interface{}{
"chunks": []map[string]interface{}{
{
"chunk_id": "c1",
"content_with_weight": "hello world",
"content_ltks": "hello world ltks",
"doc_id": "d1",
"docnm_kwd": "Document 1",
"kb_id": "kb1",
"image_id": "img1",
"img_id": "img2",
"positions": []int{0, 10},
"position_int": []int{1, 11},
"doc_type_kwd": "pdf",
"similarity": 0.95,
"vector_similarity": 0.9,
"term_similarity": 0.85,
"row_id": "r1",
"url": "http://example.com",
"document_metadata": map[string]interface{}{"author": "Alice"},
},
},
}
streamChan := make(chan string, 10)
err := svc.CompletionStream(context.Background(), "user-1", "session-1", nil, "", nil, "msg-1", streamChan)
if err == nil || err.Error() != "messages cannot be empty" {
t.Fatalf("expected 'messages cannot be empty', got %v", err)
result := svc.chunksFormat(ref)
if len(result) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(result))
}
c := result[0]
if c["id"] != "c1" {
t.Fatalf("id=%v", c["id"])
}
if c["content"] != "hello world" {
t.Fatalf("content=%v", c["content"])
}
if c["document_id"] != "d1" {
t.Fatalf("document_id=%v", c["document_id"])
}
if c["document_name"] != "Document 1" {
t.Fatalf("document_name=%v", c["document_name"])
}
if c["dataset_id"] != "kb1" {
t.Fatalf("dataset_id=%v", c["dataset_id"])
}
if c["image_id"] != "img1" {
t.Fatalf("image_id=%v", c["image_id"])
}
if c["doc_type"] != "pdf" {
t.Fatalf("doc_type=%v", c["doc_type"])
}
if c["similarity"] != 0.95 {
t.Fatalf("similarity=%v", c["similarity"])
}
if c["url"] != "http://example.com" {
t.Fatalf("url=%v", c["url"])
}
pos, ok := c["positions"].([]int)
if !ok || len(pos) != 2 || pos[0] != 0 {
t.Fatalf("positions=%v (%T)", c["positions"], c["positions"])
}
// Raw keys must be normalized away.
if _, exists := c["content_with_weight"]; exists {
t.Fatal("content_with_weight should not be present after normalization")
}
if _, exists := c["content_ltks"]; exists {
t.Fatal("content_ltks should not be present after normalization")
}
}
func TestCompletionStream_LastMessageNotFromUser(t *testing.T) {
svc := &ChatSessionService{
chatSessionDAO: &fakeSessionStore{},
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
func TestChunksFormat_PreservesAlreadyNormalizedFields(t *testing.T) {
svc := &ChatSessionService{}
ref := map[string]interface{}{
"chunks": []map[string]interface{}{
{
"id": "c2",
"content": "already normalized",
"document_id": "d2",
"document_name": "Doc 2",
},
},
}
streamChan := make(chan string, 10)
err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{
{"role": "assistant", "content": "hello"},
}, "", nil, "msg-1", streamChan)
if err == nil || !strings.Contains(err.Error(), "not from user") {
t.Fatalf("expected 'not from user' error, got %v", err)
result := svc.chunksFormat(ref)
if len(result) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(result))
}
c := result[0]
if c["id"] != "c2" {
t.Fatalf("id=%v", c["id"])
}
if c["content"] != "already normalized" {
t.Fatalf("content=%v", c["content"])
}
}
func TestCompletionStream_ConversationNotFound(t *testing.T) {
store := newFakeSessionStore()
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
func TestChunksFormat_EmptyReference(t *testing.T) {
svc := &ChatSessionService{}
streamChan := make(chan string, 10)
err := svc.CompletionStream(context.Background(), "user-1", "missing", []map[string]interface{}{
{"role": "user", "content": "hi"},
}, "", nil, "msg-1", streamChan)
if err == nil || err.Error() != "Conversation not found" {
t.Fatalf("expected 'Conversation not found', got %v", err)
if n := len(svc.chunksFormat(nil)); n != 0 {
t.Fatalf("nil ref: expected 0, got %d", n)
}
if n := len(svc.chunksFormat(map[string]interface{}{})); n != 0 {
t.Fatalf("empty ref: expected 0, got %d", n)
}
if n := len(svc.chunksFormat(map[string]interface{}{"chunks": nil})); n != 0 {
t.Fatalf("nil chunks: expected 0, got %d", n)
}
if n := len(svc.chunksFormat(map[string]interface{}{"chunks": []map[string]interface{}{}})); n != 0 {
t.Fatalf("empty chunks: expected 0, got %d", n)
}
}
func TestCompletionStream_DialogNotFound(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
func TestChunksFormat_ChunksAsInterfaceSlice(t *testing.T) {
svc := &ChatSessionService{}
ref := map[string]interface{}{
"chunks": []interface{}{
map[string]interface{}{
"chunk_id": "c3",
"content_with_weight": "from interface slice",
},
},
}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
result := svc.chunksFormat(ref)
if len(result) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(result))
}
streamChan := make(chan string, 10)
err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{
{"role": "user", "content": "hi"},
}, "", nil, "msg-1", streamChan)
if err == nil || err.Error() != "Dialog not found" {
t.Fatalf("expected 'Dialog not found', got %v", err)
if result[0]["id"] != "c3" {
t.Fatalf("id=%v", result[0]["id"])
}
if result[0]["content"] != "from interface slice" {
t.Fatalf("content=%v", result[0]["content"])
}
}
func TestCompletionStream_PipelineError(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
func TestChunksFormat_IgnoresNonMapItems(t *testing.T) {
svc := &ChatSessionService{}
ref := map[string]interface{}{
"chunks": []interface{}{
"not a map",
map[string]interface{}{
"chunk_id": "c4",
"content_with_weight": "valid chunk",
},
},
}
store.dialogs["dialog-1"] = &entity.Chat{
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
LLMSetting: entity.JSONMap{},
result := svc.chunksFormat(ref)
if len(result) != 1 {
t.Fatalf("expected 1 chunk (non-maps skipped), got %d", len(result))
}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{err: errors.New("model unavailable")},
}
streamChan := make(chan string, 10)
err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{
{"role": "user", "content": "hi"},
}, "", nil, "msg-1", streamChan)
if err == nil || err.Error() != "model unavailable" {
t.Fatalf("expected 'model unavailable' error, got %v", err)
if result[0]["id"] != "c4" {
t.Fatalf("id=%v", result[0]["id"])
}
}
// TestChunksFormat_UnsupportedTypeReturnsEmpty checks that chunksFormat yields
// no chunks when the "chunks" entry is a plain string rather than a slice.
func TestChunksFormat_UnsupportedTypeReturnsEmpty(t *testing.T) {
	service := &ChatSessionService{}
	got := service.chunksFormat(map[string]interface{}{"chunks": "not a slice"})
	if n := len(got); n != 0 {
		t.Fatalf("expected empty for string type, got %d", n)
	}
}

View File

@@ -1980,7 +1980,7 @@ func (s *DatasetService) SearchDatasets(req *SearchDatasetsRequest, userID strin
// Apply meta_data_filter to get filtered doc_ids
docIDs := make([]string, len(req.DocIDs))
copy(docIDs, req.DocIDs)
if metadataFilter != nil {
if len(metadataFilter) > 0 {
metadataSvc := NewMetadataService()
flattedMeta, err := metadataSvc.GetFlattedMetaByKBs(datasetIDs)
if err != nil {

View File

@@ -644,16 +644,16 @@ func formatChunks(chunks []map[string]interface{}) []FormattedChunk {
for _, chunk := range chunks {
out = append(out, FormattedChunk{
ID: strVal(getValue(chunk, "chunk_id", "id")),
Content: strVal(getValue(chunk, "content", "content_with_weight")),
Content: strVal(getValue(chunk, "content_with_weight", "content")),
DocumentID: strVal(getValue(chunk, "doc_id", "document_id")),
DocumentName: strVal(getValue(chunk, "docnm_kwd", "document_name")),
DatasetID: strVal(getValue(chunk, "kb_id", "dataset_id")),
ImageID: strVal(getValue(chunk, "image_id", "img_id")),
Positions: getValue(chunk, "positions", "position_int"),
URL: chunk["url"],
Similarity: chunk["similarity"],
VectorSimilarity: chunk["vector_similarity"],
TermSimilarity: chunk["term_similarity"],
Similarity: sanitizeJSONFloats(chunk["similarity"]),
VectorSimilarity: sanitizeJSONFloats(chunk["vector_similarity"]),
TermSimilarity: sanitizeJSONFloats(chunk["term_similarity"]),
RowID: chunk["row_id"],
DocType: getValue(chunk, "doc_type_kwd", "doc_type"),
DocumentMetadata: chunk["document_metadata"],

View File

@@ -18,40 +18,41 @@ package service
import (
"context"
"ragflow/internal/tokenizer"
"strings"
)
const thinkOpen = "<think>"
const thinkClose = "</think>"
var stripThinkReplacer = strings.NewReplacer("<think>", "", "</think>", "")
// ThinkStreamState holds accumulated state across streaming LLM chunks
// so that <think>...</think> tags can be surfaced as structured markers.
//
// Corresponds to _ThinkStreamState in api/db/services/dialog_service.py.
type ThinkStreamState struct {
	// fullText accumulates all text received so far.
	fullText string
	// lastIdx is the last consumed position in fullText.
	lastIdx int
	// lastFull is the previous fullText snapshot.
	lastFull string
	// lastModelFull is the previous model-full snapshot for diffing: when an
	// incoming chunk starts with it, only the remainder is treated as new text.
	lastModelFull string
	// inThink is true when we are currently inside a <think> block.
	inThink bool
	// buffer accumulates visible text before flushing (for batching).
	buffer string
	// postThinkText holds text between </think> and the next <think> or end
	// of delta. Kept for API alignment with Python; may be used by future
	// callers that need per-delta visibility into think boundaries.
	postThinkText string
	// closePending defers emission of </think> when no visible text follows
	// the tag; the marker is emitted by a later chunk or by the final flush.
	closePending bool
	// pendingAfterClose collects text received after a deferred </think>.
	pendingAfterClose string
	// thinkBuffer accumulates think-side text before token-batch flushing.
	thinkBuffer string
	// answerBuffer accumulates answer-side text before token-batch flushing.
	answerBuffer string
	// carry holds text at the end of a chunk that may be a partial <think> or
	// </think> prefix; it is prepended to the next chunk before tag detection.
	carry string
}
// ThinkDeltaKind describes the type of a think-tag delta event.
type ThinkDeltaKind int
const (
ThinkDeltaText ThinkDeltaKind = iota // visible answer text
ThinkDeltaText ThinkDeltaKind = iota // think-side or answer-side text
ThinkDeltaMarker // <think> or </think> tag boundary
)
@@ -61,93 +62,239 @@ type ThinkDelta struct {
Value string
}
// emitText appends text to the buffer selected by section ("think" selects
// thinkBuffer; any other value selects answerBuffer) and releases the whole
// buffer once it holds at least minTokens tokens.
//
// It returns the released text together with ThinkDeltaText, or an empty
// string and a zero kind when text is empty or the buffer is still below the
// threshold.
func emitText(state *ThinkStreamState, section string, text string, minTokens int) (string, ThinkDeltaKind) {
	if text == "" {
		return "", 0
	}

	// Both sections batch the same way; only the destination buffer differs.
	target := &state.answerBuffer
	if section == "think" {
		target = &state.thinkBuffer
	}

	*target += text
	if tokenizer.NumTokensFromString(*target) < minTokens {
		// Not enough accumulated yet: keep buffering.
		return "", 0
	}

	batched := *target
	*target = ""
	return batched, ThinkDeltaText
}
// flushThinkBufferInternal drains state.thinkBuffer regardless of its token
// count. It returns the drained text as a ThinkDeltaText delta, or the zero
// ThinkDelta when the buffer is empty.
func flushThinkBufferInternal(state *ThinkStreamState) ThinkDelta {
	var delta ThinkDelta
	if buffered := state.thinkBuffer; buffered != "" {
		state.thinkBuffer = ""
		delta = ThinkDelta{Kind: ThinkDeltaText, Value: buffered}
	}
	return delta
}
// flushAnswerBufferInternal drains state.answerBuffer regardless of its token
// count. It returns the drained text as a ThinkDeltaText delta, or the zero
// ThinkDelta when the buffer is empty.
func flushAnswerBufferInternal(state *ThinkStreamState) ThinkDelta {
	var delta ThinkDelta
	if buffered := state.answerBuffer; buffered != "" {
		state.answerBuffer = ""
		delta = ThinkDelta{Kind: ThinkDeltaText, Value: buffered}
	}
	return delta
}
// stripThinkTags removes every literal "<think>" and "</think>" from s using
// the shared stripThinkReplacer. An empty input is returned unchanged without
// consulting the replacer.
func stripThinkTags(s string) string {
	if s != "" {
		s = stripThinkReplacer.Replace(s)
	}
	return s
}
// tagPrefixLen returns the length of the longest suffix of s that could be a
// PARTIAL start of "<think>" or "</think>". Returns 0 if the suffix is a complete
// tag or no prefix match exists.
func tagPrefixLen(s string) int {
for i := 0; i < len(s); i++ {
sub := s[i:]
if sub == thinkOpen || sub == thinkClose {
return 0 // complete tag, not a partial prefix
}
if strings.HasPrefix(thinkOpen, sub) || strings.HasPrefix(thinkClose, sub) {
return len(sub)
}
}
return 0
}
// NextThinkDelta processes the next chunk of LLM output and returns any
// visible text or tag boundary markers that should be emitted.
//
// Pure function — no side effects beyond updating state.
func NextThinkDelta(state *ThinkStreamState, chunk string) []ThinkDelta {
func NextThinkDelta(state *ThinkStreamState, chunk string, minTokens int) []ThinkDelta {
if state == nil {
return nil
}
if state.lastFull != "" {
// Compute the delta: what's new since lastFull.
delta := strings.TrimPrefix(chunk, state.lastFull)
state.lastModelFull = delta
} else {
var newPart string
if strings.HasPrefix(chunk, state.lastModelFull) {
newPart = chunk[len(state.lastModelFull):]
state.lastModelFull = chunk
} else {
newPart = chunk
state.lastModelFull += chunk
}
state.lastFull = chunk
// Accumulate fullText from the delta.
state.fullText += state.lastModelFull
// Extract new content since lastIdx.
newPart := state.fullText[state.lastIdx:]
if len(newPart) == 0 {
if newPart == "" {
return nil
}
state.fullText += newPart
// Prepend carry from previous chunk that may complete a partial tag.
pending := state.carry + newPart
state.carry = ""
// Check if pending ends with a partial <think> or </think> prefix.
// Save it as carry so it isn't emitted as visible text.
if n := tagPrefixLen(pending); n > 0 {
state.carry = pending[len(pending)-n:]
pending = pending[:len(pending)-n]
}
var deltas []ThinkDelta
// Process character by character to detect tag boundaries.
for len(newPart) > 0 {
if !state.inThink {
idx := strings.Index(newPart, thinkOpen)
if idx < 0 {
// No more think open — buffer everything as visible text.
state.buffer += newPart
state.lastIdx += len(newPart)
break
}
// Text before <think> is visible answer.
if idx > 0 {
state.buffer += newPart[:idx]
}
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkOpen})
newPart = newPart[idx+len(thinkOpen):]
state.lastIdx += idx + len(thinkOpen)
state.inThink = true
} else {
idx := strings.Index(newPart, thinkClose)
if idx < 0 {
// Still inside think, consume all silently.
state.lastIdx += len(newPart)
break
}
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose})
state.postThinkText = newPart[:idx]
newPart = newPart[idx+len(thinkClose):]
state.lastIdx += idx + len(thinkClose)
state.inThink = false
// Phase 1: handle deferred </think> from a previous chunk.
if state.closePending {
state.closePending = false
if piece := flushThinkBufferInternal(state); piece.Value != "" {
deltas = append(deltas, piece)
}
state.inThink = false
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose})
if state.pendingAfterClose != "" {
pending = state.pendingAfterClose + pending
state.pendingAfterClose = ""
}
}
// Phase 2: process pending text for think tags.
for pending != "" {
openIdx := strings.Index(pending, thinkOpen)
closeIdx := strings.Index(pending, thinkClose)
// No tags remaining — emit to the appropriate section.
if openIdx == -1 && closeIdx == -1 {
if piece := stripThinkTags(pending); piece != "" {
section := "answer"
if state.inThink {
section = "think"
}
if out, kind := emitText(state, section, piece, minTokens); out != "" {
deltas = append(deltas, ThinkDelta{Kind: kind, Value: out})
}
}
break
}
// <think> appears first (or no </think> found).
if openIdx != -1 && (closeIdx == -1 || openIdx < closeIdx) {
before := pending[:openIdx]
if before != "" {
piece := stripThinkTags(before)
section := "answer"
if state.inThink {
section = "think"
}
if out, kind := emitText(state, section, piece, minTokens); out != "" {
deltas = append(deltas, ThinkDelta{Kind: kind, Value: out})
}
}
pending = pending[openIdx+len(thinkOpen):]
if !state.inThink {
if answerPiece := flushAnswerBufferInternal(state); answerPiece.Value != "" {
deltas = append(deltas, answerPiece)
}
if thinkPiece := flushThinkBufferInternal(state); thinkPiece.Value != "" {
deltas = append(deltas, thinkPiece)
}
state.inThink = true
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkOpen})
}
continue
}
// </think> appears first.
before := pending[:closeIdx]
after := pending[closeIdx+len(thinkClose):]
if before != "" {
piece := stripThinkTags(before)
section := "answer"
if state.inThink {
section = "think"
}
if out, kind := emitText(state, section, piece, minTokens); out != "" {
deltas = append(deltas, ThinkDelta{Kind: kind, Value: out})
}
}
if strings.TrimSpace(after) != "" {
if thinkPiece := flushThinkBufferInternal(state); thinkPiece.Value != "" {
deltas = append(deltas, thinkPiece)
}
state.inThink = false
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose})
pending = after
continue
}
// No visible text after close — defer the marker.
state.closePending = true
if after != "" {
state.pendingAfterClose += after
}
pending = ""
break
}
return deltas
}
// FlushThinkBuffer drains the buffered visible text, if any, as a single delta.
// Call this after all LLM chunks have been processed.
func FlushThinkBuffer(state *ThinkStreamState) []ThinkDelta {
if state == nil || state.buffer == "" {
// FlushRemaining drains all remaining buffered text and handles deferred
// markers. Call this after all LLM chunks have been processed.
func FlushRemaining(state *ThinkStreamState) []ThinkDelta {
if state == nil {
return nil
}
text := state.buffer
state.buffer = ""
return []ThinkDelta{{Kind: ThinkDeltaText, Value: text}}
var deltas []ThinkDelta
if state.thinkBuffer != "" {
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.thinkBuffer})
state.thinkBuffer = ""
}
if state.closePending {
state.inThink = false
state.closePending = false
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose})
}
if state.answerBuffer != "" {
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.answerBuffer})
state.answerBuffer = ""
}
if state.pendingAfterClose != "" {
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.pendingAfterClose})
state.pendingAfterClose = ""
}
if state.carry != "" {
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.carry})
state.carry = ""
}
return deltas
}
// StreamThinkTagDelta — channel-based pipeline.
// ---------------------------------------------------------------------------
// StreamThinkTagDelta takes a channel of raw LLM text chunks and produces a
// channel of (kind, value) pairs. When ctx is cancelled (e.g. client
// disconnect), the goroutine drains the input channel silently and exits,
// preventing the producer goroutine from blocking forever on send.
//
// Markers (<think>, </think>) are emitted immediately without buffering.
// channel of structured deltas. When ctx is cancelled (e.g. client
// disconnect), the goroutine drains the input channel silently and exits.
func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens int) <-chan ThinkDelta {
out := make(chan ThinkDelta, 32)
go func() {
defer close(out)
state := &ThinkStreamState{}
flushSize := minTokens * 4 // approximate: ~4 bytes per token
for {
select {
case <-ctx.Done():
@@ -158,7 +305,7 @@ func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens in
return
case chunk, ok := <-chunks:
if !ok {
for _, d := range FlushThinkBuffer(state) {
for _, d := range FlushRemaining(state) {
select {
case out <- d:
case <-ctx.Done():
@@ -167,34 +314,16 @@ func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens in
}
return
}
deltas := NextThinkDelta(state, chunk)
deltas := NextThinkDelta(state, chunk, minTokens)
for _, d := range deltas {
if d.Kind == ThinkDeltaMarker {
select {
case out <- d:
case <-ctx.Done():
go func() {
for range chunks {
}
}()
return
}
}
}
// Flush buffered visible text when it reaches the token threshold,
// matching Python _stream_with_think_delta which yields ("text", ...)
// per chunk. Markers are emitted immediately above.
if len(state.buffer) >= flushSize {
for _, d := range FlushThinkBuffer(state) {
select {
case out <- d:
case <-ctx.Done():
go func() {
for range chunks {
}
}()
return
}
select {
case out <- d:
case <-ctx.Done():
go func() {
for range chunks {
}
}()
return
}
}
}
@@ -203,47 +332,46 @@ func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens in
return out
}
// ExtractVisibleAnswer strips <think> blocks from the raw LLM response,
// returning only the visible answer text. If the response consists
// entirely of think content, returns an empty string.
//
// Corresponds to _extract_visible_answer in dialog_service.py.
// ExtractVisibleAnswer returns the visible answer text after the last </think>.
// Stray <think>/</think> tags are stripped. If there is no </think>, all tags are stripped.
func ExtractVisibleAnswer(raw string) string {
if raw == "" {
return ""
}
// Collect all non-think text.
var visible []string
remaining := raw
hasThink := false
for {
openIdx := strings.Index(remaining, thinkOpen)
if openIdx < 0 {
// No more think open tags — strip any stray </think> and keep the rest.
remaining = strings.ReplaceAll(remaining, thinkClose, "")
visible = append(visible, remaining)
break
}
hasThink = true
if openIdx > 0 {
visible = append(visible, remaining[:openIdx])
}
remaining = remaining[openIdx+len(thinkOpen):]
closeIdx := strings.Index(remaining, thinkClose)
if closeIdx < 0 {
// Unclosed think — treat rest as visible.
visible = append(visible, remaining)
break
}
remaining = remaining[closeIdx+len(thinkClose):]
if !strings.Contains(raw, thinkClose) {
return stripThinkTags(raw)
}
result := strings.TrimSpace(strings.Join(visible, ""))
if hasThink && result == "" {
// Only think content — return empty.
return ""
}
return result
lastClose := strings.LastIndex(raw, thinkClose)
answer := raw[lastClose+len(thinkClose):]
return stripThinkTags(answer)
}
// BufferAnswerDelta feeds answer-side text through the think-state lifecycle
// and returns whatever deltas become ready to emit.
//
// A nil state or empty text yields nil. Otherwise text is appended to
// fullText. If a </think> marker was deferred (closePending), the buffered
// think text and the marker are emitted first, and any pendingAfterClose text
// is placed in front of the new answer text. The answer text is then batched
// via emitText using minTokens as the flush threshold.
func BufferAnswerDelta(state *ThinkStreamState, text string, minTokens int) []ThinkDelta {
	if state == nil || text == "" {
		return nil
	}
	state.fullText += text

	var result []ThinkDelta
	if state.closePending {
		state.closePending = false

		// Release any remaining think text before closing the block.
		thinkTail := flushThinkBufferInternal(state)
		if thinkTail.Value != "" {
			result = append(result, thinkTail)
		}
		state.inThink = false
		result = append(result, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose})

		// Text held back behind the deferred marker precedes the new text.
		text = state.pendingAfterClose + text
		state.pendingAfterClose = ""
	}

	batched, kind := emitText(state, "answer", text, minTokens)
	if batched == "" {
		return result
	}
	return append(result, ThinkDelta{Kind: kind, Value: batched})
}

View File

@@ -24,117 +24,130 @@ import (
func TestNextThinkDelta_NoThinkTag(t *testing.T) {
state := &ThinkStreamState{}
deltas := NextThinkDelta(state, "hello world")
if len(deltas) != 0 {
t.Fatalf("expected 0 deltas, got %d", len(deltas))
deltas := NextThinkDelta(state, "hello world", 0)
if len(deltas) != 1 {
t.Fatalf("expected 1 delta, got %d: %+v", len(deltas), deltas)
}
if state.buffer != "hello world" {
t.Errorf("buffer = %q", state.buffer)
if deltas[0].Kind != ThinkDeltaText || deltas[0].Value != "hello world" {
t.Errorf("expected text delta, got %+v", deltas[0])
}
}
func TestNextThinkDelta_OnlyThinkTag(t *testing.T) {
state := &ThinkStreamState{}
deltas := NextThinkDelta(state, "<think>reasoning</think>visible")
if len(deltas) != 2 {
t.Fatalf("expected 2 deltas, got %d: %+v", len(deltas), deltas)
deltas := NextThinkDelta(state, "<think>reasoning</think>visible", 0)
if len(deltas) < 2 {
t.Fatalf("expected at least 2 deltas, got %d: %+v", len(deltas), deltas)
}
if deltas[0].Kind != ThinkDeltaMarker || deltas[0].Value != "<think>" {
t.Errorf("first delta should be <think> marker: %+v", deltas[0])
foundOpen, foundClose := false, false
for _, d := range deltas {
if d.Kind == ThinkDeltaMarker && d.Value == "<think>" {
foundOpen = true
}
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
foundClose = true
}
}
if deltas[1].Kind != ThinkDeltaMarker || deltas[1].Value != "</think>" {
t.Errorf("second delta should be </think> marker: %+v", deltas[1])
}
if state.buffer != "visible" {
t.Errorf("buffer = %q, want visible", state.buffer)
if !foundOpen || !foundClose {
t.Errorf("missing markers: open=%v close=%v, deltas=%+v", foundOpen, foundClose, deltas)
}
}
func TestNextThinkDelta_TextThenThink(t *testing.T) {
state := &ThinkStreamState{}
deltas := NextThinkDelta(state, "before <think>inside</think> after")
// "before " -> buffer (no flush yet)
// "<think>" -> marker
// "inside" -> inside think, consumed silently
// "</think>" -> marker
// " after" -> buffer
if len(deltas) != 2 {
t.Fatalf("expected 2 markers, got %d: %+v", len(deltas), deltas)
deltas := NextThinkDelta(state, "before <think>inside</think> after", 0)
foundOpen, foundClose := false, false
for _, d := range deltas {
if d.Kind == ThinkDeltaMarker && d.Value == "<think>" {
foundOpen = true
}
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
foundClose = true
}
}
if state.buffer != "before after" {
t.Errorf("buffer = %q", state.buffer)
if !foundOpen || !foundClose {
t.Errorf("missing markers: open=%v close=%v, deltas=%+v", foundOpen, foundClose, deltas)
}
}
func TestNextThinkDelta_MultipleChunks(t *testing.T) {
state := &ThinkStreamState{}
NextThinkDelta(state, "hello ")
NextThinkDelta(state, "<think>")
NextThinkDelta(state, "reasoning")
NextThinkDelta(state, "</think>")
deltas := NextThinkDelta(state, " world")
if len(deltas) != 0 {
t.Fatalf("expected 0 deltas from final chunk, got %d", len(deltas))
deltas := NextThinkDelta(state, "hello ", 0)
_ = deltas
deltas = NextThinkDelta(state, "<think>", 0)
_ = deltas
deltas = NextThinkDelta(state, "reasoning", 0)
_ = deltas
deltas = NextThinkDelta(state, "</think>", 0)
_ = deltas
deltas = NextThinkDelta(state, " world", 0)
// After deferral, closing marker and answer text arrive in the last chunk.
foundClose := false
for _, d := range deltas {
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
foundClose = true
}
}
if state.buffer != "hello world" {
t.Errorf("buffer = %q", state.buffer)
if !foundClose {
t.Errorf("expected </think> in final chunk, got %+v", deltas)
}
}
func TestNextThinkDelta_UnclosedThink(t *testing.T) {
state := &ThinkStreamState{}
deltas := NextThinkDelta(state, "text <think>unclosed")
if len(deltas) != 1 {
t.Fatalf("expected 1 marker (think open), got %d", len(deltas))
deltas := NextThinkDelta(state, "text <think>unclosed", 0)
foundOpen := false
for _, d := range deltas {
if d.Kind == ThinkDeltaMarker && d.Value == "<think>" {
foundOpen = true
}
}
if deltas[0].Value != "<think>" {
t.Errorf("expected <think> marker")
if !foundOpen {
t.Errorf("expected <think> marker in %+v", deltas)
}
if state.buffer != "text " {
t.Errorf("buffer = %q", state.buffer)
}
// "unclosed" should be consumed silently inside think
}
func TestNextThinkDelta_EmptyInput(t *testing.T) {
state := &ThinkStreamState{}
deltas := NextThinkDelta(state, "")
deltas := NextThinkDelta(state, "", 0)
if len(deltas) != 0 {
t.Errorf("expected 0 deltas for empty input")
t.Errorf("expected 0 deltas for empty input, got %d", len(deltas))
}
}
func TestNextThinkDelta_NilState(t *testing.T) {
deltas := NextThinkDelta(nil, "test")
deltas := NextThinkDelta(nil, "test", 0)
if deltas != nil {
t.Error("expected nil for nil state")
}
}
func TestFlushThinkBuffer_Empty(t *testing.T) {
if deltas := FlushThinkBuffer(nil); len(deltas) != 0 {
t.Error("expected empty for nil state")
func TestFlushRemaining_FlushesAll(t *testing.T) {
state := &ThinkStreamState{
thinkBuffer: "think-tail",
closePending: true,
answerBuffer: "answer-tail",
pendingAfterClose: "pending-tail",
}
state := &ThinkStreamState{}
if deltas := FlushThinkBuffer(state); len(deltas) != 0 {
t.Error("expected empty for zero state")
deltas := FlushRemaining(state)
if len(deltas) != 4 {
t.Fatalf("expected 4 deltas, got %d: %+v", len(deltas), deltas)
}
}
func TestFlushThinkBuffer_WithContent(t *testing.T) {
state := &ThinkStreamState{buffer: "flushed text"}
deltas := FlushThinkBuffer(state)
if len(deltas) != 1 {
t.Fatalf("expected 1 delta, got %d", len(deltas))
if deltas[0].Kind != ThinkDeltaText || deltas[0].Value != "think-tail" {
t.Errorf("delta[0] = %+v", deltas[0])
}
if deltas[0].Kind != ThinkDeltaText {
t.Error("expected text delta")
if deltas[1].Kind != ThinkDeltaMarker || deltas[1].Value != "</think>" {
t.Errorf("delta[1] = %+v", deltas[1])
}
if deltas[0].Value != "flushed text" {
t.Errorf("value = %q", deltas[0].Value)
if deltas[2].Kind != ThinkDeltaText || deltas[2].Value != "answer-tail" {
t.Errorf("delta[2] = %+v", deltas[2])
}
if state.buffer != "" {
t.Error("buffer should be cleared after flush")
if deltas[3].Kind != ThinkDeltaText || deltas[3].Value != "pending-tail" {
t.Errorf("delta[3] = %+v", deltas[3])
}
// State should be cleared.
if state.closePending || state.inThink {
t.Error("state not fully cleared")
}
}
@@ -160,13 +173,13 @@ func TestExtractVisibleAnswer_WithThink(t *testing.T) {
func TestExtractVisibleAnswer_ThinkOnly(t *testing.T) {
raw := "<think>only reasoning here</think>"
if got := ExtractVisibleAnswer(raw); got != "" {
t.Errorf("expected empty for think-only, got %q", got)
t.Errorf("got %q", got)
}
}
func TestExtractVisibleAnswer_MultipleThinks(t *testing.T) {
raw := "<think>first</think>visible1<think>second</think>visible2"
if got := ExtractVisibleAnswer(raw); got != "visible1visible2" {
if got := ExtractVisibleAnswer(raw); got != "visible2" {
t.Errorf("got %q", got)
}
}
@@ -178,6 +191,18 @@ func TestExtractVisibleAnswer_NestedTags(t *testing.T) {
}
}
// TestExtractVisibleAnswer_NoTags checks that input without any think tags is
// returned unchanged.
func TestExtractVisibleAnswer_NoTags(t *testing.T) {
	const input = "plain text"
	got := ExtractVisibleAnswer(input)
	if got != "plain text" {
		t.Errorf("got %q", got)
	}
}
// TestExtractVisibleAnswer_StrayTag checks that an opening <think> with no
// matching </think> is stripped while the text after it is kept.
func TestExtractVisibleAnswer_StrayTag(t *testing.T) {
	const input = "<think>text"
	got := ExtractVisibleAnswer(input)
	if got != "text" {
		t.Errorf("got %q", got)
	}
}
func TestStreamThinkTagDelta(t *testing.T) {
chunks := []string{"hello ", "wor", "<think>", "think text", "</think>", "ld", " final"}
ch := make(chan string, len(chunks))
@@ -208,14 +233,12 @@ func TestStreamThinkTagDelta(t *testing.T) {
}
joined := strings.Join(texts, "")
if !strings.Contains(joined, "hello world") || !strings.Contains(joined, "final") {
if !strings.Contains(joined, "hello wor") || !strings.Contains(joined, "final") {
t.Errorf("texts = %q", texts)
}
}
func TestStreamThinkTagDelta_IncrementalFlush(t *testing.T) {
// Verify that visible text is streamed incrementally, not all at the end.
// minTokens=1 → flushSize=4 bytes. Each chunk triggers a flush.
chunks := []string{"1234", "5678", "90ab"}
ch := make(chan string, len(chunks))
for _, c := range chunks {
@@ -229,8 +252,6 @@ func TestStreamThinkTagDelta_IncrementalFlush(t *testing.T) {
texts = append(texts, d.Value)
}
}
// With minTokens=1 (flushSize=4), each chunk triggers a flush.
// We should get incremental text deltas, not just one final burst.
if len(texts) < 2 {
t.Errorf("expected >=2 incremental text deltas, got %d: %q", len(texts), texts)
}
@@ -254,3 +275,29 @@ func TestStreamThinkTagDelta_NoThinkTags(t *testing.T) {
t.Errorf("got %q, want 'just plain text'", joined)
}
}
// TestStreamThinkTagDelta_DeferredClose streams a think block whose </think>
// arrives in its own chunk and checks that exactly one <think> and one
// </think> marker come out, in that order.
func TestStreamThinkTagDelta_DeferredClose(t *testing.T) {
	// When </think> has no visible text after it, the marker is deferred.
	input := []string{"<think>", "hello", "</think>", "world"}
	source := make(chan string, len(input))
	for i := range input {
		source <- input[i]
	}
	close(source)

	var markers []string
	for delta := range StreamThinkTagDelta(context.Background(), source, 1) {
		if delta.Kind != ThinkDeltaMarker {
			continue
		}
		markers = append(markers, delta.Value)
	}

	if count := len(markers); count != 2 {
		t.Fatalf("expected 2 markers, got %d: %v", count, markers)
	}
	if first := markers[0]; first != "<think>" {
		t.Errorf("first marker = %q", first)
	}
	if second := markers[1]; second != "</think>" {
		t.Errorf("second marker = %q", second)
	}
}

View File

@@ -103,6 +103,11 @@ export default defineConfig(({ mode }) => {
changeOrigin: true,
ws: true,
},
'^(/api/v1/datasets/search)|^(/api/v1/chat/completions)': {
target: 'http://127.0.0.1:9384/',
changeOrigin: true,
ws: true,
},
'/api': {
target: 'http://127.0.0.1:9380/',
changeOrigin: true,