mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-04 01:29:35 +08:00
Implement chat completions in go (#16491)
### Summary POST /api/v1/chat/completions
This commit is contained in:
@@ -45,6 +45,16 @@ server {
|
||||
include proxy.conf;
|
||||
}
|
||||
|
||||
location ~ ^/api/v1/datasets/search {
|
||||
proxy_pass http://127.0.0.1:9384;
|
||||
include proxy.conf;
|
||||
}
|
||||
|
||||
location ~ ^/api/v1/chat/completions {
|
||||
proxy_pass http://127.0.0.1:9384;
|
||||
include proxy.conf;
|
||||
}
|
||||
|
||||
location ~ ^/v1/system/config {
|
||||
proxy_pass http://127.0.0.1:9384;
|
||||
include proxy.conf;
|
||||
|
||||
@@ -834,6 +834,8 @@ Commands (User Mode):
|
||||
CHAT 'provider/instance/model' 'message'; - Chat with specified model
|
||||
OPENAI_CHAT 'chat_id' 'message' [options] ; - OpenAI-compatible chat
|
||||
(run openai_chat -h for detailed options)
|
||||
CHAT COMPLETIONS 'question' [options] ; - Chat completions via /api/v1/chat/completions
|
||||
(run chat completions -h for detailed options)
|
||||
|
||||
Filesystem Commands (no quotes):
|
||||
ls [path] - List resources
|
||||
@@ -1041,3 +1043,50 @@ Examples:
|
||||
`
|
||||
fmt.Println(help)
|
||||
}
|
||||
|
||||
// printChatCompletionsHelp prints the usage text for the CHAT COMPLETIONS
// command to standard output.
//
// The text covers the command syntax, the single required positional argument
// (the question), every named option, the defaults and a few examples.
// chat_id is shown in brackets because it is optional: the option table below
// says so and the parser defaults it to the empty string.
func printChatCompletionsHelp() {
	help := `CHAT COMPLETIONS — hit POST /api/v1/chat/completions

Syntax:
  CHAT COMPLETIONS 'question'
      [chat_id '...']
      [session "..."] [llm "..."]
      [system "..."] [history "..."] [history_delimiter "<char>"]
      [temperature <float>] [max_tokens <int>] [stream <bool>]
      [top_p <float>] [frequency_penalty <float>] [presence_penalty <float>]
      [pass_all_history <bool>] [legacy <bool>] ;

Required positional:
  'question'                  the user question

Named options (any order; all optional with defaults):
  chat_id '...'               the dialog id (optional)
  session '...'               existing session/conversation id
  llm '...'                   override the dialog's LLM
  system '...'                override the system prompt
  history '...'               prior turns: user:...;assistant:...;user:...
  history_delimiter '...'     turn separator for history (default ';')
  temperature <float>         0..2 (default 0)
  max_tokens <int>            (default 0 = server/model default)
  stream <bool>               true|false (default false)
  top_p <float>               0..1
  frequency_penalty <float>   -2..2
  presence_penalty <float>    -2..2
  pass_all_history <bool>     pass all history messages
  legacy <bool>               use legacy SSE format

Defaults:
  stream             false
  temperature        0
  history_delimiter  ';'

Examples:
  CHAT COMPLETIONS 'Hello, how are you?' chat_id 'cid';
  CHAT COMPLETIONS 'Explain quantum computing' chat_id 'cid' stream true;
  CHAT COMPLETIONS 'Next question' chat_id 'cid' session 'sess-abc123';
  CHAT COMPLETIONS 'What about X?' chat_id 'cid' system 'You are a helpful assistant.' history 'user:Tell me about Y;assistant:Y is...';
  CHAT COMPLETIONS 'Summarize' chat_id 'cid' llm 'Qwen/Qwen3-8B@ling@SILICONFLOW' temperature 0.7 max_tokens 512;
`
	fmt.Println(help)
}
|
||||
|
||||
@@ -419,6 +419,11 @@ func (c *CLI) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
return c.EmbedUserTextCommand(cmd)
|
||||
case "api_rarank_user_document":
|
||||
return c.APIRerankUserDocumentCommand(cmd)
|
||||
case "chat completions":
|
||||
return c.ChatCompletions(cmd)
|
||||
case "chat completions help":
|
||||
printChatCompletionsHelp()
|
||||
return nil, nil
|
||||
case "tts_user_command":
|
||||
return c.APITTSUserCommand(cmd)
|
||||
case "asr_user_command":
|
||||
|
||||
@@ -980,9 +980,11 @@ func getChunkID(c map[string]interface{}) string {
|
||||
}
|
||||
|
||||
// chunkContent returns the display text of a chunk map.
//
// It looks at "content_with_weight" first and falls back to "content". The
// first key whose value is present, non-nil and non-blank wins; the value is
// formatted with fmt.Sprint and trimmed of surrounding whitespace. An empty
// string is returned when neither key yields any text (including a nil map).
func chunkContent(c map[string]interface{}) string {
	for _, key := range []string{"content_with_weight", "content"} {
		v, ok := c[key]
		if !ok || v == nil {
			// A missing key or a JSON null carries no text; without this
			// guard fmt.Sprint would render the literal "<nil>".
			continue
		}
		if text := strings.TrimSpace(fmt.Sprint(v)); text != "" {
			return text
		}
	}
	return ""
}
|
||||
@@ -1283,3 +1285,59 @@ func (r *QuotaSummaryResponse) PrintOut() {
|
||||
PrintTableSimpleByFormatWithOrder(table, section.columns, r.OutputFormat)
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletionsResponse represents the RAGFlow-internal response from
|
||||
// POST /api/v1/chat/completions (non-OpenAI format).
|
||||
//
|
||||
// JSON shape:
|
||||
//
|
||||
// {"code":0,"data":{"answer":"...","reference":...},"message":""}
|
||||
type ChatCompletionsResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data *chatCompletionData `json:"data"`
|
||||
Message string `json:"message"`
|
||||
Duration float64 `json:"-"`
|
||||
OutputFormat OutputFormat `json:"-"`
|
||||
// raw HTTP body for "raw" output.
|
||||
raw []byte
|
||||
// streamed skips the "Answer:" line in PrintOut to avoid duplication
|
||||
// (used by the streaming path which prints chunk-by-chunk).
|
||||
streamed bool
|
||||
}
|
||||
|
||||
type chatCompletionData struct {
|
||||
Answer string `json:"answer"`
|
||||
Reference json.RawMessage `json:"reference,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ChatCompletionsResponse) Type() string { return "chat_completions" }
|
||||
func (r *ChatCompletionsResponse) TimeCost() float64 { return r.Duration }
|
||||
func (r *ChatCompletionsResponse) SetOutputFormat(f OutputFormat) { r.OutputFormat = f }
|
||||
|
||||
func (r *ChatCompletionsResponse) PrintOut() {
|
||||
if r.OutputFormat == "raw" && r.raw != nil {
|
||||
fmt.Println(string(r.raw))
|
||||
return
|
||||
}
|
||||
if r.Code != 0 {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
return
|
||||
}
|
||||
if r.Data == nil {
|
||||
fmt.Println("(no data)")
|
||||
return
|
||||
}
|
||||
if !r.streamed {
|
||||
if r.Data.Answer != "" {
|
||||
fmt.Printf("Answer: %s\n", r.Data.Answer)
|
||||
}
|
||||
}
|
||||
if r.Data != nil && len(r.Data.Reference) > 0 {
|
||||
printReferenceChunks(r.Data.Reference)
|
||||
}
|
||||
fmt.Printf("Time: %f\n", r.Duration)
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ import (
|
||||
"ragflow/internal/ingestion"
|
||||
"ragflow/internal/ingestion/parser"
|
||||
"ragflow/internal/utility"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -2128,7 +2127,7 @@ func (c *CLI) APIChatToModelCommand(cmd *Command) (ResponseIf, error) {
|
||||
effort := cmd.Params["effort"].(string)
|
||||
verbosity := cmd.Params["verbosity"].(string)
|
||||
|
||||
url := "/chat/completions"
|
||||
url := "/chat/to_model"
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"messages": formattedMessages,
|
||||
@@ -4133,8 +4132,6 @@ func (c *CLI) streamOpenaiChat(url string, body map[string]interface{}) (Respons
|
||||
|
||||
fullContent = strings.TrimLeft(fullContent, "\n\r")
|
||||
fullReason = strings.TrimLeft(fullReason, "\n\r")
|
||||
fullContent = stripThinkTags(fullContent)
|
||||
fullReason = stripThinkTags(fullReason)
|
||||
return &OpenAIChatResponse{
|
||||
Duration: resp.Duration,
|
||||
Reasoning: fullReason,
|
||||
@@ -4147,8 +4144,187 @@ func (c *CLI) streamOpenaiChat(url string, body map[string]interface{}) (Respons
|
||||
}, nil
|
||||
}
|
||||
|
||||
// stripThinkTags removes <think>…</think> wrappers from a streamed answer
|
||||
func stripThinkTags(s string) string {
|
||||
var thinkTagRE = regexp.MustCompile(`(?s)<think>.*?</think>`)
|
||||
return thinkTagRE.ReplaceAllString(s, "")
|
||||
// ChatCompletions dispatches the parsed CHAT COMPLETIONS command to
|
||||
// POST /api/v1/chat/completions.
|
||||
func (c *CLI) ChatCompletions(cmd *Command) (ResponseIf, error) {
|
||||
if c.Config.CLIMode != APIMode {
|
||||
return nil, fmt.Errorf("CHAT COMPLETIONS is only allowed in USER mode")
|
||||
}
|
||||
httpClient := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer]
|
||||
if httpClient.APIKey == nil && httpClient.LoginToken == nil {
|
||||
return nil, fmt.Errorf("API token not set. Please login first")
|
||||
}
|
||||
|
||||
body, err := buildChatCompletionsRequestBody(cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := "/chat/completions"
|
||||
|
||||
stream, _ := cmd.Params["stream"].(bool)
|
||||
if stream {
|
||||
return c.streamChatCompletions(url, body)
|
||||
}
|
||||
return c.oneshotChatCompletions(url, body)
|
||||
}
|
||||
|
||||
// buildChatCompletionsRequestBody assembles the JSON payload for
|
||||
// POST /api/v1/chat/completions.
|
||||
//
|
||||
// When system or history is provided, a `messages` array is built;
|
||||
// otherwise just `question` is sent and the server normalizes it.
|
||||
func buildChatCompletionsRequestBody(cmd *Command) (map[string]interface{}, error) {
|
||||
chatID, _ := cmd.Params["chat_id"].(string)
|
||||
question, _ := cmd.Params["question"].(string)
|
||||
stream, _ := cmd.Params["stream"].(bool)
|
||||
|
||||
body := map[string]interface{}{
|
||||
"chat_id": chatID,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
// Optional session_id
|
||||
if v, ok := cmd.Params["session"].(string); ok && v != "" {
|
||||
body["session_id"] = v
|
||||
}
|
||||
|
||||
// Optional llm_id
|
||||
if v, ok := cmd.Params["llm"].(string); ok && v != "" {
|
||||
body["llm_id"] = v
|
||||
}
|
||||
|
||||
// Build messages from system + history when provided; otherwise send question.
|
||||
system, hasSystem := cmd.Params["system"].(string)
|
||||
historyRaw, hasHistory := cmd.Params["history_raw"].(string)
|
||||
|
||||
if hasSystem || hasHistory {
|
||||
messages := make([]map[string]interface{}, 0, 4)
|
||||
if hasSystem && system != "" {
|
||||
messages = append(messages, map[string]interface{}{"role": "system", "content": system})
|
||||
}
|
||||
if hasHistory && historyRaw != "" {
|
||||
delimiter, _ := cmd.Params["history_delimiter"].(string)
|
||||
turns, err := parseHistory(historyRaw, delimiter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CHAT COMPLETIONS history: %w", err)
|
||||
}
|
||||
for _, t := range turns {
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": t["role"],
|
||||
"content": t["content"],
|
||||
})
|
||||
}
|
||||
}
|
||||
messages = append(messages, map[string]interface{}{"role": "user", "content": question})
|
||||
body["messages"] = messages
|
||||
} else {
|
||||
body["question"] = question
|
||||
}
|
||||
|
||||
// Optional flags — only emit when explicitly set
|
||||
if isSet(cmd, "pass_all_history") && cmd.Params["pass_all_history"].(bool) {
|
||||
body["pass_all_history_messages"] = true
|
||||
}
|
||||
if isSet(cmd, "legacy") && cmd.Params["legacy"].(bool) {
|
||||
body["legacy"] = true
|
||||
}
|
||||
|
||||
// Generation params — only emit when explicitly set
|
||||
if isSet(cmd, "temperature") {
|
||||
body["temperature"] = cmd.Params["temperature"]
|
||||
}
|
||||
if isSet(cmd, "max_tokens") {
|
||||
body["max_tokens"] = cmd.Params["max_tokens"]
|
||||
}
|
||||
if isSet(cmd, "top_p") {
|
||||
body["top_p"] = cmd.Params["top_p"]
|
||||
}
|
||||
if isSet(cmd, "frequency_penalty") {
|
||||
body["frequency_penalty"] = cmd.Params["frequency_penalty"]
|
||||
}
|
||||
if isSet(cmd, "presence_penalty") {
|
||||
body["presence_penalty"] = cmd.Params["presence_penalty"]
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// oneshotChatCompletions performs a non-streaming POST and returns a
|
||||
// ChatCompletionsResponse parsed from the RAGFlow-internal JSON envelope.
|
||||
func (c *CLI) oneshotChatCompletions(url string, body map[string]interface{}) (ResponseIf, error) {
|
||||
httpClient := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer]
|
||||
resp, err := httpClient.Request("POST", url, "web", nil, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chat completions request: %w", err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return &ChatCompletionsResponse{
|
||||
Code: resp.StatusCode,
|
||||
Message: string(resp.Body),
|
||||
raw: resp.Body,
|
||||
}, nil
|
||||
}
|
||||
out := &ChatCompletionsResponse{
|
||||
Duration: resp.Duration,
|
||||
raw: resp.Body,
|
||||
}
|
||||
// RAGFlow returns {code, data: {answer, reference, ...}, message}.
|
||||
var envelope struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data *chatCompletionData `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body, &envelope); err != nil {
|
||||
return nil, fmt.Errorf("chat completions: invalid response JSON: %w", err)
|
||||
}
|
||||
out.Code = envelope.Code
|
||||
out.Message = envelope.Message
|
||||
out.Data = envelope.Data
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// streamChatCompletions performs a streaming POST and collects SSE chunks.
|
||||
func (c *CLI) streamChatCompletions(url string, body map[string]interface{}) (ResponseIf, error) {
|
||||
httpClient := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer]
|
||||
reader, err := httpClient.RequestStream("POST", url, "web", nil, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chat completions stream: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
start := time.Now()
|
||||
scanner := bufio.NewScanner(reader)
|
||||
var fullContent string
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimPrefix(line, "data:")
|
||||
payload = strings.TrimSpace(payload)
|
||||
if payload == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
var chunk struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data chatCompletionData `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
if chunk.Data.Answer != "" {
|
||||
fullContent += chunk.Data.Answer
|
||||
}
|
||||
}
|
||||
|
||||
fullContent = strings.TrimLeft(fullContent, "\n\r")
|
||||
return &ChatCompletionsResponse{
|
||||
Duration: time.Since(start).Seconds(),
|
||||
Data: &chatCompletionData{
|
||||
Answer: fullContent,
|
||||
},
|
||||
streamed: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -2702,6 +2702,13 @@ func (p *Parser) parseAPIDisable() (*Command, error) {
|
||||
func (p *Parser) parseAPIChat() (*Command, error) {
|
||||
p.nextToken() // consume CHAT
|
||||
|
||||
// Redirect "chat completion[s]" to the standalone chat completions parser.
|
||||
if p.curToken.Type == TokenIdentifier &&
|
||||
(strings.EqualFold(p.curToken.Value, "completion") || strings.EqualFold(p.curToken.Value, "completions")) {
|
||||
p.nextToken() // consume completion/completions
|
||||
return p.parseChatCompletionsBody()
|
||||
}
|
||||
|
||||
var err error
|
||||
var modelNameOrID string = ""
|
||||
var messages []string
|
||||
@@ -3759,6 +3766,150 @@ optionsLoop:
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// CHAT COMPLETIONS <question>
|
||||
// [chat_id <string>] [session <string>] [llm <string>]
|
||||
|
||||
|
||||
// parseChatCompletionsBody parses the question and options of a CHAT COMPLETIONS
|
||||
// command. The leading keyword(s) must already have been consumed by the caller.
|
||||
func (p *Parser) parseChatCompletionsBody() (*Command, error) {
|
||||
|
||||
if p.curToken.Type == TokenDash {
|
||||
dashCount := 0
|
||||
for p.curToken.Type == TokenDash {
|
||||
dashCount++
|
||||
p.nextToken()
|
||||
}
|
||||
if dashCount > 0 && p.curToken.Type == TokenIdentifier {
|
||||
switch strings.ToLower(p.curToken.Value) {
|
||||
case "h", "help":
|
||||
return NewCommand("chat completions help"), nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("CHAT COMPLETIONS: only -h/--help takes no args; otherwise expected question")
|
||||
}
|
||||
|
||||
cmd := NewCommand("chat completions")
|
||||
|
||||
// Defaults
|
||||
cmd.Params["chat_id"] = ""
|
||||
cmd.Params["temperature"] = 0.0
|
||||
cmd.Params["max_tokens"] = 0
|
||||
cmd.Params["stream"] = false
|
||||
|
||||
// Track which options were explicitly set (distinguishes from defaults).
|
||||
cmd.Params["_set"] = map[string]bool{}
|
||||
|
||||
// Required positional: <question>
|
||||
question, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CHAT COMPLETIONS: expected question: %w", err)
|
||||
}
|
||||
cmd.Params["question"] = question
|
||||
p.nextToken()
|
||||
|
||||
// Optional named options
|
||||
handleOption := func(name string) error {
|
||||
switch name {
|
||||
case "chat_id", "session", "llm", "system":
|
||||
v, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return fmt.Errorf("CHAT COMPLETIONS %s: expected quoted string, got %s", name, p.curToken.Value)
|
||||
}
|
||||
cmd.Params[name] = v
|
||||
p.nextToken()
|
||||
markSet(cmd, name)
|
||||
case "temperature", "top_p", "frequency_penalty", "presence_penalty":
|
||||
v, err := p.parseFloat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("CHAT COMPLETIONS %s: expected number, got %s", name, p.curToken.Value)
|
||||
}
|
||||
cmd.Params[name] = v
|
||||
p.nextToken()
|
||||
markSet(cmd, name)
|
||||
case "max_tokens":
|
||||
v, err := p.parseNumber()
|
||||
if err != nil {
|
||||
return fmt.Errorf("CHAT COMPLETIONS max_tokens: expected integer, got %s", p.curToken.Value)
|
||||
}
|
||||
cmd.Params["max_tokens"] = v
|
||||
p.nextToken()
|
||||
markSet(cmd, "max_tokens")
|
||||
case "stream", "pass_all_history", "legacy":
|
||||
v, err := p.parseBool()
|
||||
if err != nil {
|
||||
return fmt.Errorf("CHAT COMPLETIONS %s: expected true|false, got %s", name, p.curToken.Value)
|
||||
}
|
||||
cmd.Params[name] = v
|
||||
markSet(cmd, name)
|
||||
case "history":
|
||||
raw, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return fmt.Errorf("CHAT COMPLETIONS history: expected quoted string, got %s", p.curToken.Value)
|
||||
}
|
||||
cmd.Params["history_raw"] = raw
|
||||
p.nextToken()
|
||||
case "history_delimiter":
|
||||
v, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return fmt.Errorf("CHAT COMPLETIONS history_delimiter: expected quoted string, got %s", p.curToken.Value)
|
||||
}
|
||||
cmd.Params["history_delimiter"] = v
|
||||
p.nextToken()
|
||||
default:
|
||||
return fmt.Errorf("CHAT COMPLETIONS: unknown option %q (valid: chat_id, session, llm, system, history, history_delimiter, temperature, max_tokens, stream, top_p, frequency_penalty, presence_penalty, pass_all_history, legacy)", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Named options, any order, until ';'.
|
||||
optionsLoop:
|
||||
for {
|
||||
switch p.curToken.Type {
|
||||
case TokenSemicolon:
|
||||
p.nextToken()
|
||||
break optionsLoop
|
||||
case TokenEOF:
|
||||
break optionsLoop
|
||||
|
||||
case TokenIdentifier, TokenQuotedString:
|
||||
name := p.curToken.Value
|
||||
if p.curToken.Type == TokenQuotedString {
|
||||
name = strings.Trim(name, "'\"")
|
||||
}
|
||||
p.nextToken()
|
||||
if err := handleOption(name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
if !isKeyword(p.curToken.Type) {
|
||||
return nil, fmt.Errorf("CHAT COMPLETIONS: unexpected token %q in option list (valid options: chat_id, session, llm, system, history, history_delimiter, temperature, max_tokens, stream, top_p, frequency_penalty, presence_penalty, pass_all_history, legacy)", p.curToken.Value)
|
||||
}
|
||||
name := p.curToken.Value
|
||||
p.nextToken()
|
||||
if err := handleOption(name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func markSet(cmd *Command, name string) {
|
||||
if s, ok := cmd.Params["_set"].(map[string]bool); ok {
|
||||
s[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
func isSet(cmd *Command, name string) bool {
|
||||
if s, ok := cmd.Params["_set"].(map[string]bool); ok {
|
||||
return s[name]
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseJSONLiteral consumes a TokenQuotedString whose payload is a JSON
|
||||
// value (object, array, string, number, or boolean) and returns it as
|
||||
// the original raw string (NOT decoded — the caller decides whether to
|
||||
|
||||
@@ -167,7 +167,33 @@ Time: 76.582520
|
||||
```
|
||||
Note: Both image and video understanding support streaming and thinking modes as well.
|
||||
|
||||
### 6.8. Chat with OpenAI compatible API
|
||||
### 6.8. Chat completions
|
||||
|
||||
```
|
||||
RAGFlow(api/default)> chat completion 'hello'
|
||||
Answer: Hello! How can I assist you today? 😊
|
||||
Time: 1.591929
|
||||
```
|
||||
|
||||
```
|
||||
RAGFlow(api/default)> CHAT COMPLETIONS '<question>' chat_id '<chat_id>';
|
||||
```
|
||||
|
||||
```
|
||||
RAGFlow(api/default)> CHAT COMPLETIONS 'Explain the theory' \
|
||||
chat_id '<chat_id>' \
|
||||
session '<session_id>' llm 'glm-4.5-flash@test@zhipu-ai' stream true;
|
||||
```
|
||||
|
||||
```
|
||||
RAGFlow(api/default)> CHAT COMPLETIONS 'Continue' \
|
||||
system 'You are a helpful assistant.' \
|
||||
history 'user:What is RAG?;assistant:RAG stands for Retrieval-Augmented Generation...' \
|
||||
history_delimiter ';';
|
||||
```
|
||||
|
||||
### 6.9. Chat with OpenAI compatible API
|
||||
|
||||
```
|
||||
RAGFlow(api/default)> openai_chat '<chat_id>' 'Hello, how are you?';
|
||||
Answer: Hello! I'm just a virtual assistant, so I don't have feelings, but I'm here and ready to help you with anything you need. How can I assist you today? 😊
|
||||
@@ -204,17 +230,17 @@ RAGFlow(api/default)> openai_chat '<chat_id>' 'Hello, how are you?' extra_body '
|
||||
CLI error: OPENAI_CHAT extra_body: unknown field "ref" (valid: reference, reference_metadata, metadata_condition)
|
||||
```
|
||||
|
||||
### 6.9. Generate Embeddings
|
||||
### 6.10. Generate Embeddings
|
||||
```
|
||||
RAGFlow(api/default)> embed text 'what is rag' 'who are you' with 'embedding-3@test@zhipu-ai' dimension 16;
|
||||
```
|
||||
|
||||
### 6.10. Document Reranking
|
||||
### 6.11. Document Reranking
|
||||
```
|
||||
RAGFlow(api/default)> rerank query 'what is rag' document 'rag is retrieval augment generation' 'rag need llm' 'famous rag project includes ragflow' with 'rerank@test@zhipu-ai' top 2;
|
||||
```
|
||||
|
||||
### 6.11. Get supported models from provider API
|
||||
### 6.12. Get supported models from provider API
|
||||
|
||||
```
|
||||
RAGFlow(api/default)> list supported models from 'gitee' 'test';
|
||||
@@ -236,7 +262,7 @@ RAGFlow(api/default)> list supported models from 'gitee' 'test';
|
||||
+-----------+---------------------------+---------------+------------+-----------------------------------------------------------------+----------------------------------------------------------+---------------------------------------------+
|
||||
```
|
||||
|
||||
### 6.12. Get preset models of a provider
|
||||
### 6.13. Get preset models of a provider
|
||||
|
||||
```
|
||||
RAGFlow(api/default)> list models from 'minimax';
|
||||
@@ -254,7 +280,7 @@ RAGFlow(api/default)> list models from 'minimax';
|
||||
+------------+-------------+------------------------+
|
||||
```
|
||||
|
||||
### 6.13. List instances of a provider
|
||||
### 6.14. List instances of a provider
|
||||
|
||||
```
|
||||
RAGFlow(api/default)> list instances from 'zhipu-ai';
|
||||
@@ -265,7 +291,7 @@ RAGFlow(api/default)> list instances from 'zhipu-ai';
|
||||
+---------+----------------------+----------------------------------+--------------+----------------------------------+--------+
|
||||
```
|
||||
|
||||
### 6.14. Show instance of a provider
|
||||
### 6.15. Show instance of a provider
|
||||
```
|
||||
RAGFlow(api/default)> show instance 'test' from 'zhipu-ai';
|
||||
+----------------------------------+--------------+----------------------------------+---------+--------+
|
||||
@@ -275,7 +301,7 @@ RAGFlow(api/default)> show instance 'test' from 'zhipu-ai';
|
||||
+----------------------------------+--------------+----------------------------------+---------+--------+
|
||||
```
|
||||
|
||||
### 6.15. List models of a specific instance
|
||||
### 6.16. List models of a specific instance
|
||||
|
||||
```
|
||||
RAGFlow(api/default)> list models from 'minimax' 'test';
|
||||
@@ -293,7 +319,7 @@ RAGFlow(api/default)> list models from 'minimax' 'test';
|
||||
+------------+-------------+------------------------+--------+
|
||||
```
|
||||
|
||||
### 6.16. List added providers
|
||||
### 6.17. List added providers
|
||||
```
|
||||
RAGFlow(api/default)> list providers;
|
||||
+--------------------------------------------------------------------------+-------------+--------------+
|
||||
@@ -305,7 +331,7 @@ RAGFlow(api/default)> list providers;
|
||||
+--------------------------------------------------------------------------+-------------+--------------+
|
||||
```
|
||||
|
||||
### 6.17. Deactivate / activate a model
|
||||
### 6.18. Deactivate / activate a model
|
||||
|
||||
```
|
||||
RAGFlow(api/default)> disable model 'deepseek-v4-pro' from 'deepseek' 'test';
|
||||
@@ -321,7 +347,7 @@ RAGFlow(api/default)> enable model 'deepseek-v4-pro' from 'deepseek' 'test';
|
||||
SUCCESS
|
||||
```
|
||||
|
||||
### 6.18. Set current model
|
||||
### 6.19. Set current model
|
||||
```
|
||||
RAGFlow(api/default)> use model 'glm-4.5-flash@test@zhipu-ai';
|
||||
SUCCESS
|
||||
@@ -330,7 +356,7 @@ Answer: Large language models are advanced AI systems. They process text to unde
|
||||
Time: 1.680416
|
||||
```
|
||||
|
||||
### 6.19. Set, reset, and list default models
|
||||
### 6.20. Set, reset, and list default models
|
||||
```
|
||||
RAGFlow(api/default)> set default chat model 'glm-4.5-flash@test@zhipu-ai';
|
||||
SUCCESS
|
||||
@@ -374,7 +400,7 @@ RAGFlow(api/default)> list default models;
|
||||
+--------+----------------+--------------+----------------+------------+
|
||||
```
|
||||
|
||||
### 6.20. Show current balance of a provider instance
|
||||
### 6.21. Show current balance of a provider instance
|
||||
```
|
||||
RAGFlow(api/default)> show balance from 'gitee' 'test';
|
||||
+-------------+----------+
|
||||
@@ -384,13 +410,13 @@ RAGFlow(api/default)> show balance from 'gitee' 'test';
|
||||
+-------------+----------+
|
||||
```
|
||||
|
||||
### 6.21. Check provider instance availability
|
||||
### 6.22. Check provider instance availability
|
||||
```
|
||||
RAGFlow(api/default)> check instance 'test' from 'zhipu-ai';
|
||||
SUCCESS
|
||||
```
|
||||
|
||||
### 6.22. Add local model to RAGFlow, only for local deployed inference server, such as ollama
|
||||
### 6.23. Add a local model to RAGFlow (only for locally deployed inference servers, such as Ollama)
|
||||
```
|
||||
RAGFlow(api/default)> add model 'Qwen/Qwen2.5-0.5B' to provider 'vllm' instance 'test' with tokens 131072 chat;
|
||||
SUCCESS
|
||||
@@ -404,7 +430,7 @@ RAGFlow(api/default)> drop model 'Qwen/Qwen2.5-0.5B' from 'vllm' 'test';
|
||||
SUCCESS
|
||||
```
|
||||
|
||||
### 6.23. List datasets
|
||||
### 6.24. List datasets
|
||||
```
|
||||
RAGFlow(api/default)> list datasets;
|
||||
+-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+
|
||||
@@ -415,14 +441,14 @@ RAGFlow(api/default)> list datasets;
|
||||
+-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+
|
||||
```
|
||||
|
||||
### 6.24. Text to Speech
|
||||
### 6.25. Text to Speech
|
||||
```
|
||||
RAGFlow(api/default)> tts with 'speech-2.8-hd@test@minimax' text 'He who desires but acts not, breeds pestilence.' play format 'wav' save './internal' param '{"voice_setting": {"voice_id": "English_radiant_girl", "speed": 1, "vol": 1, "pitch": 0}, "audio_setting": {"sample_rate": 32000, "bitrate": 128000, "format": "wav", "channel": 1}, "output_format": "hex"}'
|
||||
Saved to directory: /home/infiniflow/Documents/development/ragflow/internal/speech-2.8-hd_output.wav
|
||||
SUCCESS
|
||||
```
|
||||
|
||||
### 6.25. Audio to Speech
|
||||
### 6.26. Speech to Text (ASR)
|
||||
```
|
||||
RAGFlow(api/default)> asr with 'FunAudioLLM/SenseVoiceSmall@test@siliconflow' audio './internal/test.wav' param ''
|
||||
+----------------------------------------------------------------------------------------------------------------------+
|
||||
@@ -432,7 +458,7 @@ RAGFlow(api/default)> asr with 'FunAudioLLM/SenseVoiceSmall@test@siliconflow' au
|
||||
+----------------------------------------------------------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### 6.26. Optical Character Recognition
|
||||
### 6.27. Optical Character Recognition
|
||||
```
|
||||
RAGFlow(api/default)> ocr with 'paddleocr-vl-0.9b@test@baidu' file './internal/text.jpg'
|
||||
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
@@ -442,7 +468,7 @@ RAGFlow(api/default)> ocr with 'paddleocr-vl-0.9b@test@baidu' file './internal/t
|
||||
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### 6.27. Chunk Management Commands
|
||||
### 6.28. Chunk Management Commands
|
||||
|
||||
- Create a chunk store with vector size
|
||||
```
|
||||
@@ -489,7 +515,7 @@ RAGFlow(api/default)> RETRIEVE 'AI' ON DATASETS 'test'
|
||||
RAGFlow(api/default)> GET CHUNK '29cc4f6d7a5c6e7c' OF DATASET 'test' DOCUMENT 'bbe55942535e11f1bc5184ba59049aa3' IN DATASET 'test'
|
||||
```
|
||||
|
||||
### 6.28. Metadata Management Commands
|
||||
### 6.29. Metadata Management Commands
|
||||
|
||||
- Create metadata store
|
||||
```
|
||||
@@ -525,7 +551,7 @@ RAGFlow(api/default)> DROP METADATA STORE
|
||||
RAGFlow(api/default)> GET METADATA OF DATASET 'test' 'test2'
|
||||
```
|
||||
|
||||
### 6.29. Search datasets
|
||||
### 6.30. Search datasets
|
||||
|
||||
- Search datasets using SQL-like dataset search syntax:
|
||||
```
|
||||
|
||||
@@ -227,192 +227,6 @@ func (h *ChatHandler) MindMap(c *gin.Context) {
|
||||
jsonResponse(c, common.CodeSuccess, mindMap, "success")
|
||||
}
|
||||
|
||||
// ListChatsNext list chats with advanced filtering and pagination
|
||||
// @Summary List Chats Next
|
||||
// @Description Get list of chats with filtering, pagination and sorting (equivalent to list_dialogs_next)
|
||||
// @Tags chat
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param keywords query string false "search keywords"
|
||||
// @Param page query int false "page number"
|
||||
// @Param page_size query int false "items per page"
|
||||
// @Param orderby query string false "order by field (default: create_time)"
|
||||
// @Param desc query bool false "descending order (default: true)"
|
||||
// @Param request body service.ListChatsNextRequest true "filter options including owner_ids"
|
||||
// @Success 200 {object} service.ListChatsNextResponse
|
||||
// @Router /v1/dialog/next [post]
|
||||
func (h *ChatHandler) ListChatsNext(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Parse query parameters
|
||||
keywords := c.Query("keywords")
|
||||
|
||||
page := 0
|
||||
if pageStr := c.Query("page"); pageStr != "" {
|
||||
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
|
||||
page = p
|
||||
}
|
||||
}
|
||||
|
||||
pageSize := 0
|
||||
if pageSizeStr := c.Query("page_size"); pageSizeStr != "" {
|
||||
if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 {
|
||||
pageSize = ps
|
||||
}
|
||||
}
|
||||
|
||||
orderby := c.DefaultQuery("orderby", "create_time")
|
||||
|
||||
desc := true
|
||||
if descStr := c.Query("desc"); descStr != "" {
|
||||
desc = descStr != "false"
|
||||
}
|
||||
|
||||
// Parse request body for owner_ids
|
||||
var req service.ListChatsNextRequest
|
||||
if c.Request.ContentLength > 0 {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// List chats with advanced filtering
|
||||
result, err := h.chatService.ListChatsNext(userID, keywords, page, pageSize, orderby, desc, req.OwnerIDs)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": result,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// SetDialog create or update a dialog
|
||||
// @Summary Set Dialog
|
||||
// @Description Create or update a dialog (chat). If dialog_id is provided, updates existing dialog; otherwise creates new one.
|
||||
// @Tags chat
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body service.SetDialogRequest true "dialog configuration"
|
||||
// @Success 200 {object} service.SetDialogResponse
|
||||
// @Router /v1/dialog/set [post]
|
||||
func (h *ChatHandler) SetDialog(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Parse request body
|
||||
var req service.SetDialogRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required field: prompt_config
|
||||
if req.PromptConfig == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "prompt_config is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to set dialog
|
||||
result, err := h.chatService.SetDialog(userID, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": result,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDialogsRequest is the JSON body accepted by RemoveChats.
type RemoveDialogsRequest struct {
	// DialogIDs lists the dialogs to remove; the binding tag makes it mandatory.
	DialogIDs []string `json:"dialog_ids" binding:"required"`
}
|
||||
|
||||
// RemoveChats remove/delete dialogs (soft delete by setting status to invalid)
|
||||
// @Summary Remove Dialogs
|
||||
// @Description Remove dialogs by setting their status to invalid. Only the owner of the dialog can perform this operation.
|
||||
// @Tags chat
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body RemoveDialogsRequest true "dialog IDs to remove"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/dialog/rm [post]
|
||||
func (h *ChatHandler) RemoveChats(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Parse request body
|
||||
var req RemoveDialogsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to remove dialogs
|
||||
if err := h.chatService.RemoveChats(userID, req.DialogIDs); err != nil {
|
||||
// Check if it's an authorization error
|
||||
if err.Error() == "only owner of chat authorized for this operation" {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"code": 403,
|
||||
"data": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": true,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteChat soft deletes a chat by ID.
|
||||
func (h *ChatHandler) DeleteChat(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
|
||||
@@ -19,7 +19,6 @@ package handler
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
@@ -44,107 +43,6 @@ func NewChatSessionHandler(chatSessionService *service.ChatSessionService, userS
|
||||
}
|
||||
}
|
||||
|
||||
// SetChatSession create or update a chat session
|
||||
// @Summary Set chat session
|
||||
// @Description Create or update a chat session. If is_new is true, creates new chat session; otherwise updates existing one.
|
||||
// @Tags chat_session
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body service.SetChatSessionRequest true "chat session configuration"
|
||||
// @Success 200 {object} service.SetChatSessionResponse
|
||||
// @Router /v1/conversation/set [post]
|
||||
func (h *ChatSessionHandler) SetChatSession(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Parse request body
|
||||
var req service.SetChatSessionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to set chat session
|
||||
result, err := h.chatSessionService.SetChatSession(userID, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": result,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveChatSessionsRequest is the JSON body accepted by RemoveChatSessions.
type RemoveChatSessionsRequest struct {
	// ConversationIDs lists the chat sessions to remove; the binding tag makes
	// it mandatory.
	ConversationIDs []string `json:"conversation_ids" binding:"required"`
}
|
||||
|
||||
// RemoveChatSessions remove/delete chat sessions
|
||||
// @Summary Remove Chat Sessions
|
||||
// @Description Remove chat sessions by their IDs. Only the owner of the chat session can perform this operation.
|
||||
// @Tags chat_session
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body RemoveChatSessionsRequest true "chat session IDs to remove"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/conversation/rm [post]
|
||||
func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Parse request body
|
||||
var req RemoveChatSessionsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to remove chat sessions
|
||||
if err := h.chatSessionService.RemoveChatSessions(userID, req.ConversationIDs); err != nil {
|
||||
// Check if it's an authorization error
|
||||
if err.Error() == "Only owner of chat session authorized for this operation" {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"code": 403,
|
||||
"data": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": true,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// ListChatSessions list chat sessions for a dialog
|
||||
// @Summary List Chat Sessions
|
||||
// @Description Get list of chat sessions for a specific dialog
|
||||
@@ -198,30 +96,37 @@ func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// CompletionRequest completion request
|
||||
type CompletionRequest struct {
|
||||
ConversationID string `json:"conversation_id" binding:"required"`
|
||||
Messages []map[string]interface{} `json:"messages" binding:"required"`
|
||||
LLMID string `json:"llm_id,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Thinking *bool `json:"thinking,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
type ChatCompletionsRequest struct {
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
ConversationID string `json:"conversation_id,omitempty"`
|
||||
Messages []map[string]interface{} `json:"messages,omitempty"`
|
||||
Question string `json:"question,omitempty"`
|
||||
Files []interface{} `json:"files,omitempty"`
|
||||
LLMID string `json:"llm_id,omitempty"`
|
||||
PassAllHistoryMessages *bool `json:"pass_all_history_messages,omitempty"`
|
||||
PassAllHistory *bool `json:"pass_all_history,omitempty"`
|
||||
Legacy bool `json:"legacy,omitempty"`
|
||||
Stream *bool `json:"stream"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// Completion chat completion
|
||||
// ChatCompletions chat completion
|
||||
// @Summary Chat Completion
|
||||
// @Description Send messages to the chat model and get a response. Supports streaming and non-streaming modes.
|
||||
// @Description Send messages to the chat model and get a response.
|
||||
// @Description Default is streaming (text/event-stream); set stream:false for JSON.
|
||||
// @Tags chat_session
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body CompletionRequest true "completion request"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /v1/conversation/completion [post]
|
||||
func (h *ChatSessionHandler) Completion(c *gin.Context) {
|
||||
// @Produce json, text/event-stream
|
||||
// @Param request body ChatCompletionsRequest true "chat completion request"
|
||||
// @Success 200 {object} map[string]interface{} "Non-streaming JSON response"
|
||||
// @Success 200 {string} text/event-stream "Streaming SSE response"
|
||||
// @Router /api/v1/chat/completions [post]
|
||||
func (h *ChatSessionHandler) ChatCompletions(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
@@ -229,83 +134,97 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Parse request body
|
||||
var req CompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": err.Error(),
|
||||
})
|
||||
var rawBody map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&rawBody); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Build chat model config
|
||||
chatModelConfig := make(map[string]interface{})
|
||||
if req.Temperature != 0 {
|
||||
chatModelConfig["temperature"] = req.Temperature
|
||||
var req ChatCompletionsRequest
|
||||
b, err := json.Marshal(rawBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.TopP != 0 {
|
||||
chatModelConfig["top_p"] = req.TopP
|
||||
if err := json.Unmarshal(b, &req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.FrequencyPenalty != 0 {
|
||||
chatModelConfig["frequency_penalty"] = req.FrequencyPenalty
|
||||
|
||||
// Normalize session_id / conversation_id
|
||||
sessionID := req.SessionID
|
||||
if sessionID == "" {
|
||||
sessionID = req.ConversationID
|
||||
}
|
||||
if req.PresencePenalty != 0 {
|
||||
chatModelConfig["presence_penalty"] = req.PresencePenalty
|
||||
|
||||
// Build generation config
|
||||
genConfig := make(map[string]interface{})
|
||||
if req.Temperature != nil {
|
||||
genConfig["temperature"] = *req.Temperature
|
||||
}
|
||||
if req.MaxTokens != 0 {
|
||||
chatModelConfig["max_tokens"] = req.MaxTokens
|
||||
if req.TopP != nil {
|
||||
genConfig["top_p"] = *req.TopP
|
||||
}
|
||||
if req.FrequencyPenalty != nil {
|
||||
genConfig["frequency_penalty"] = *req.FrequencyPenalty
|
||||
}
|
||||
if req.PresencePenalty != nil {
|
||||
genConfig["presence_penalty"] = *req.PresencePenalty
|
||||
}
|
||||
if req.MaxTokens != nil {
|
||||
genConfig["max_tokens"] = *req.MaxTokens
|
||||
}
|
||||
|
||||
// Resolve pass_all_history from either alias
|
||||
passAllHistory := false
|
||||
if req.PassAllHistory != nil {
|
||||
passAllHistory = *req.PassAllHistory
|
||||
}
|
||||
if req.PassAllHistoryMessages != nil {
|
||||
passAllHistory = *req.PassAllHistoryMessages
|
||||
}
|
||||
|
||||
// Remove known keys from rawBody; what remains is passthrough kwargs
|
||||
knownKeys := []string{
|
||||
"chat_id", "session_id", "conversation_id",
|
||||
"messages", "question", "files",
|
||||
"llm_id",
|
||||
"pass_all_history_messages", "pass_all_history",
|
||||
"legacy", "stream",
|
||||
"temperature", "top_p", "frequency_penalty", "presence_penalty", "max_tokens",
|
||||
}
|
||||
for _, key := range knownKeys {
|
||||
delete(rawBody, key)
|
||||
}
|
||||
kwargs := rawBody
|
||||
|
||||
// Determine stream mode
|
||||
streamMode := true
|
||||
if req.Stream != nil {
|
||||
chatModelConfig["stream"] = *req.Stream
|
||||
}
|
||||
if req.Thinking != nil {
|
||||
chatModelConfig["thinking"] = *req.Thinking
|
||||
streamMode = *req.Stream
|
||||
}
|
||||
|
||||
// Process messages - filter out system messages and initial assistant messages
|
||||
var processedMessages []map[string]interface{}
|
||||
for i, m := range req.Messages {
|
||||
role, _ := m["role"].(string)
|
||||
if role == "system" {
|
||||
continue
|
||||
}
|
||||
if role == "assistant" && len(processedMessages) == 0 {
|
||||
continue
|
||||
}
|
||||
processedMessages = append(processedMessages, m)
|
||||
_ = i
|
||||
}
|
||||
|
||||
// Get last message ID if present
|
||||
var messageID string
|
||||
if len(processedMessages) > 0 {
|
||||
if id, ok := processedMessages[len(processedMessages)-1]["id"].(string); ok {
|
||||
messageID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Call service
|
||||
if req.Stream != nil && *req.Stream {
|
||||
// Streaming response
|
||||
if streamMode {
|
||||
disableWriteDeadlineForSSE(c)
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// Create a channel for streaming data
|
||||
streamChan := make(chan string)
|
||||
streamChan := make(chan string, 32)
|
||||
reqCtx := c.Request.Context()
|
||||
go func() {
|
||||
defer close(streamChan)
|
||||
err := h.chatSessionService.CompletionStream(reqCtx, userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan)
|
||||
if err != nil {
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", err.Error())
|
||||
}
|
||||
_, _ = h.chatSessionService.ChatCompletions(
|
||||
reqCtx, userID,
|
||||
req.ChatID, sessionID,
|
||||
req.Messages, req.Question, req.Files,
|
||||
req.LLMID, genConfig, kwargs,
|
||||
passAllHistory, req.Legacy,
|
||||
true, streamChan,
|
||||
)
|
||||
}()
|
||||
|
||||
// Stream data to client
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
data, ok := <-streamChan
|
||||
if !ok {
|
||||
@@ -315,8 +234,14 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
|
||||
return true
|
||||
})
|
||||
} else {
|
||||
// Non-streaming response
|
||||
result, err := h.chatSessionService.Completion(userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID)
|
||||
result, err := h.chatSessionService.ChatCompletions(
|
||||
c.Request.Context(), userID,
|
||||
req.ChatID, sessionID,
|
||||
req.Messages, req.Question, req.Files,
|
||||
req.LLMID, genConfig, kwargs,
|
||||
passAllHistory, req.Legacy,
|
||||
false, nil,
|
||||
)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
@@ -324,7 +249,6 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": result,
|
||||
|
||||
@@ -292,6 +292,12 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
|
||||
}
|
||||
|
||||
chat := v1.Group("/chat")
|
||||
{
|
||||
// Chat completions route
|
||||
chat.POST("/completions", r.chatSessionHandler.ChatCompletions)
|
||||
}
|
||||
|
||||
// OpenAI-compatible chat completions route
|
||||
openai := v1.Group("/openai")
|
||||
{
|
||||
@@ -498,7 +504,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
provider.PATCH("/:provider_name/instances/:instance_name/models/*model_name", r.providerHandler.EnableOrDisableModel)
|
||||
provider.POST("/:provider_name/instances/:instance_name/models", r.providerHandler.AddModel)
|
||||
provider.DELETE("/:provider_name/instances/:instance_name/models", r.providerHandler.DropInstanceModels)
|
||||
v1.POST("/chat/completions", r.providerHandler.ChatToModel)
|
||||
v1.POST("/chat/to_model", r.providerHandler.ChatToModel)
|
||||
v1.POST("/embeddings", r.providerHandler.EmbedText)
|
||||
v1.POST("/rerank", r.providerHandler.RerankDocument)
|
||||
v1.POST("/audio/transcriptions", r.providerHandler.TranscribeAudio)
|
||||
@@ -678,14 +684,6 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
chunk.POST("/update", r.chunkHandler.UpdateChunk) // Internal API only for GO
|
||||
}
|
||||
|
||||
// Chat routes
|
||||
chat := authorized.Group("/v1/dialog")
|
||||
{
|
||||
chat.POST("/next", r.chatHandler.ListChatsNext)
|
||||
chat.POST("/set", r.chatHandler.SetDialog)
|
||||
chat.POST("/rm", r.chatHandler.RemoveChats)
|
||||
}
|
||||
|
||||
// Chat Channel
|
||||
chanChannel := v1.Group("/chat-channels")
|
||||
{
|
||||
@@ -705,15 +703,6 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
langfuse.DELETE("/api-key", r.langfuseHandler.DeleteAPIKey)
|
||||
}
|
||||
|
||||
// Chat session (conversation) routes
|
||||
session := authorized.Group("/v1/conversation")
|
||||
{
|
||||
session.POST("/set", r.chatSessionHandler.SetChatSession)
|
||||
session.POST("/rm", r.chatSessionHandler.RemoveChatSessions)
|
||||
session.GET("/list", r.chatSessionHandler.ListChatSessions)
|
||||
session.POST("/completion", r.chatSessionHandler.Completion)
|
||||
}
|
||||
|
||||
// Connector routes
|
||||
connector := authorized.Group("/v1/connector")
|
||||
{
|
||||
|
||||
@@ -28,28 +28,6 @@ import (
|
||||
"ragflow/internal/dao"
|
||||
)
|
||||
|
||||
var DefaultPromptConfig = PromptConfig{
|
||||
System: strPtr(pyDefaultSystemPrompt),
|
||||
Prologue: strPtr(pyDefaultPrologue),
|
||||
Parameters: []ParameterConfig{
|
||||
{Key: "knowledge", Optional: false},
|
||||
},
|
||||
EmptyResponse: strPtr(pyDefaultEmptyResponse),
|
||||
Quote: boolPtr(true),
|
||||
TTS: boolPtr(false),
|
||||
RefineMultiturn: boolPtr(true),
|
||||
}
|
||||
|
||||
var DefaultDirectChatPromptConfig = PromptConfig{
|
||||
System: strPtr(""),
|
||||
Prologue: strPtr(""),
|
||||
Parameters: []ParameterConfig{},
|
||||
EmptyResponse: strPtr(""),
|
||||
Quote: boolPtr(false),
|
||||
TTS: boolPtr(false),
|
||||
RefineMultiturn: boolPtr(true),
|
||||
}
|
||||
|
||||
var DefaultRerankModels = map[string]struct{}{
|
||||
"BAAI/bge-reranker-v2-m3": {},
|
||||
"maidalun1020/bce-reranker-base_v1": {},
|
||||
@@ -653,74 +631,6 @@ func isTruthy(value interface{}) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// ListChatsNextRequest is the optional JSON body of the list-chats request.
type ListChatsNextRequest struct {
	// OwnerIDs, when non-empty, makes ChatService.ListChatsNext query by owner
	// IDs instead of by the caller's tenants.
	OwnerIDs []string `json:"owner_ids,omitempty"`
}
|
||||
|
||||
// ListChatsNextResponse list chats next response
|
||||
type ListChatsNextResponse struct {
|
||||
Chats []*ChatWithKBNames `json:"dialogs"`
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
|
||||
// ListChatsNext list chats with advanced filtering (equivalent to list_dialogs_next)
|
||||
func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSize int, orderby string, desc bool, ownerIDs []string) (*ListChatsNextResponse, error) {
|
||||
var chats []*entity.Chat
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
if len(ownerIDs) == 0 {
|
||||
// Get tenant IDs by user ID (joined tenants)
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use database pagination
|
||||
chats, total, err = s.chatDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Filter by owner IDs, manual pagination
|
||||
chats, total, err = s.chatDAO.ListByOwnerIDs(ownerIDs, userID, orderby, desc, keywords)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Manual pagination
|
||||
if page > 0 && pageSize > 0 {
|
||||
start := (page - 1) * pageSize
|
||||
end := start + pageSize
|
||||
if start < int(total) {
|
||||
if end > int(total) {
|
||||
end = int(total)
|
||||
}
|
||||
chats = chats[start:end]
|
||||
} else {
|
||||
chats = []*entity.Chat{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Enrich with knowledge base names
|
||||
chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
|
||||
chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{
|
||||
Chat: chat,
|
||||
KBNames: kbNames,
|
||||
DatasetIDs: datasetIDs,
|
||||
})
|
||||
}
|
||||
|
||||
return &ListChatsNextResponse{
|
||||
Chats: chatsWithKBNames,
|
||||
Total: total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getDatasetNamesAndIDs gets knowledge base names by IDs
|
||||
func (s *ChatService) getDatasetNamesAndIDs(kbIDs entity.JSONSlice) ([]string, []string) {
|
||||
var names = make([]string, 0, 0)
|
||||
@@ -743,30 +653,6 @@ func (s *ChatService) getDatasetNamesAndIDs(kbIDs entity.JSONSlice) ([]string, [
|
||||
return names, ids
|
||||
}
|
||||
|
||||
// ParameterConfig describes one entry of the "parameters" list inside a
// prompt configuration.
type ParameterConfig struct {
	// Key is the parameter name; SetDialog looks for it in the system prompt
	// as the placeholder "{Key}".
	Key string `json:"key"`
	// Optional exempts the parameter from that placeholder check.
	Optional bool `json:"optional"`
}
|
||||
|
||||
// PromptConfig prompt configuration
|
||||
type PromptConfig struct {
|
||||
System *string `json:"system"`
|
||||
Prologue *string `json:"prologue"`
|
||||
Parameters []ParameterConfig `json:"parameters"`
|
||||
EmptyResponse *string `json:"empty_response"`
|
||||
TavilyAPIKey string `json:"tavily_api_key,omitempty"`
|
||||
Keyword *bool `json:"keyword,omitempty"`
|
||||
Quote *bool `json:"quote,omitempty"`
|
||||
Reasoning *bool `json:"reasoning,omitempty"`
|
||||
RefineMultiturn *bool `json:"refine_multiturn,omitempty"`
|
||||
TocEnhance *bool `json:"toc_enhance,omitempty"`
|
||||
TTS *bool `json:"tts,omitempty"`
|
||||
UseKG *bool `json:"use_kg,omitempty"`
|
||||
CrossLanguages []string `json:"cross_languages,omitempty"`
|
||||
ReferenceMetadata map[string]interface{} `json:"reference_metadata,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
pyDefaultSystemPrompt = "You are an intelligent assistant. Please summarize the content of the dataset to answer the question. " +
|
||||
"Please list the data in the dataset and answer in detail. " +
|
||||
@@ -781,393 +667,6 @@ const (
|
||||
pyDefaultEmptyResponse = "Sorry! No relevant content was found in the knowledge base!"
|
||||
)
|
||||
|
||||
// applyPromptDefaults replaces missing keys with default values
|
||||
func applyPromptDefaults(p *PromptConfig) {
|
||||
if p.System == nil || *p.System == "" {
|
||||
s := pyDefaultSystemPrompt
|
||||
p.System = &s
|
||||
}
|
||||
if p.Prologue == nil {
|
||||
s := pyDefaultPrologue
|
||||
p.Prologue = &s
|
||||
}
|
||||
if p.Parameters == nil {
|
||||
p.Parameters = []ParameterConfig{{Key: "knowledge", Optional: false}}
|
||||
}
|
||||
if p.EmptyResponse == nil {
|
||||
s := pyDefaultEmptyResponse
|
||||
p.EmptyResponse = &s
|
||||
}
|
||||
if p.Quote == nil {
|
||||
t := true
|
||||
p.Quote = &t
|
||||
}
|
||||
if p.RefineMultiturn == nil {
|
||||
t := true
|
||||
p.RefineMultiturn = &t
|
||||
}
|
||||
if p.TTS == nil {
|
||||
f := false
|
||||
p.TTS = &f
|
||||
}
|
||||
}
|
||||
|
||||
// SetDialogRequest set chat request
|
||||
type SetDialogRequest struct {
|
||||
DialogID string `json:"dialog_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
TopN int64 `json:"top_n,omitempty"`
|
||||
TopK int64 `json:"top_k,omitempty"`
|
||||
RerankID string `json:"rerank_id,omitempty"`
|
||||
SimilarityThreshold float64 `json:"similarity_threshold,omitempty"`
|
||||
VectorSimilarityWeight float64 `json:"vector_similarity_weight,omitempty"`
|
||||
LLMSetting map[string]interface{} `json:"llm_setting,omitempty"`
|
||||
MetaDataFilter map[string]interface{} `json:"meta_data_filter,omitempty"`
|
||||
PromptConfig *PromptConfig `json:"prompt_config" binding:"required"`
|
||||
KBIDs []string `json:"kb_ids,omitempty"`
|
||||
LLMID string `json:"llm_id,omitempty"`
|
||||
}
|
||||
|
||||
// SetDialogResponse set chat response
|
||||
type SetDialogResponse struct {
|
||||
*entity.Chat
|
||||
KBNames []string `json:"kb_names"`
|
||||
}
|
||||
|
||||
// SetDialog create or update a chat
|
||||
func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialogResponse, error) {
|
||||
// Determine if this is a create or update operation
|
||||
isCreate := req.DialogID == ""
|
||||
|
||||
// Validate and process name
|
||||
name := req.Name
|
||||
if name == "" {
|
||||
name = "New Chat"
|
||||
}
|
||||
|
||||
// Validate name type and content
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return nil, errors.New("Chat name can't be empty")
|
||||
}
|
||||
|
||||
// Check name length (UTF-8 byte length)
|
||||
if len(name) > 255 {
|
||||
return nil, fmt.Errorf("Chat name length is %d which is larger than 255", len(name))
|
||||
}
|
||||
|
||||
name = strings.TrimSpace(name)
|
||||
|
||||
// Get tenant ID (use userID as default tenant)
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tenantID string
|
||||
if len(tenantIDs) > 0 {
|
||||
tenantID = tenantIDs[0]
|
||||
} else {
|
||||
tenantID = userID
|
||||
}
|
||||
|
||||
// For create: check for duplicate names and generate unique name
|
||||
if isCreate {
|
||||
existingNames, err := s.chatDAO.GetExistingNames(tenantID, "1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if name exists (case-insensitive)
|
||||
nameLower := strings.ToLower(name)
|
||||
for _, existing := range existingNames {
|
||||
if strings.ToLower(existing) == nameLower {
|
||||
// Generate unique name
|
||||
name = s.generateUniqueName(name, existingNames)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set default values
|
||||
description := req.Description
|
||||
if description == "" {
|
||||
description = "A helpful chat"
|
||||
}
|
||||
|
||||
topN := req.TopN
|
||||
if topN == 0 {
|
||||
topN = 6
|
||||
}
|
||||
|
||||
topK := req.TopK
|
||||
if topK == 0 {
|
||||
topK = 1024
|
||||
}
|
||||
|
||||
rerankID := req.RerankID
|
||||
|
||||
similarityThreshold := req.SimilarityThreshold
|
||||
if similarityThreshold == 0 {
|
||||
similarityThreshold = 0.1
|
||||
}
|
||||
|
||||
vectorSimilarityWeight := req.VectorSimilarityWeight
|
||||
if vectorSimilarityWeight == 0 {
|
||||
vectorSimilarityWeight = 0.3
|
||||
}
|
||||
|
||||
llmSetting := req.LLMSetting
|
||||
if llmSetting == nil {
|
||||
llmSetting = make(map[string]interface{})
|
||||
}
|
||||
|
||||
metaDataFilter := req.MetaDataFilter
|
||||
if metaDataFilter == nil {
|
||||
metaDataFilter = make(map[string]interface{})
|
||||
}
|
||||
|
||||
promptConfig := req.PromptConfig
|
||||
|
||||
// Process kb_ids
|
||||
kbIDs := req.KBIDs
|
||||
if kbIDs == nil {
|
||||
kbIDs = []string{}
|
||||
}
|
||||
|
||||
// Apply default prompt config on create only
|
||||
if isCreate {
|
||||
applyPromptDefaults(promptConfig)
|
||||
}
|
||||
|
||||
// Set default parameters for datasets with knowledge retrieval
|
||||
// Check if parameters is missing or empty and kb_ids is provided
|
||||
if len(kbIDs) > 0 && (promptConfig.Parameters == nil || len(promptConfig.Parameters) == 0) {
|
||||
// Check if system prompt uses {knowledge} placeholder
|
||||
if promptConfig.System != nil && strings.Contains(*promptConfig.System, "{knowledge}") {
|
||||
promptConfig.Parameters = []ParameterConfig{
|
||||
{Key: "knowledge", Optional: false},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For update: validate that {knowledge} is not used when no KBs or Tavily
|
||||
if !isCreate {
|
||||
if len(kbIDs) == 0 && promptConfig.TavilyAPIKey == "" &&
|
||||
promptConfig.System != nil && strings.Contains(*promptConfig.System, "{knowledge}") {
|
||||
return nil, errors.New("Please remove `{knowledge}` in system prompt since no dataset / Tavily used here")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate parameters
|
||||
for _, p := range promptConfig.Parameters {
|
||||
if p.Optional {
|
||||
continue
|
||||
}
|
||||
placeholder := fmt.Sprintf("{%s}", p.Key)
|
||||
if promptConfig.System == nil || !strings.Contains(*promptConfig.System, placeholder) {
|
||||
return nil, fmt.Errorf("Parameter '%s' is not used", p.Key)
|
||||
}
|
||||
}
|
||||
|
||||
// Check knowledge bases and their embedding models
|
||||
if len(kbIDs) > 0 {
|
||||
kbs, err := s.kbDAO.GetByIDs(kbIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if all KBs use the same embedding model
|
||||
var embdID string
|
||||
for i, kb := range kbs {
|
||||
if i == 0 {
|
||||
embdID = kb.EmbdID
|
||||
} else {
|
||||
// Extract base model name (remove vendor suffix)
|
||||
embdBase := s.splitModelNameAndFactory(embdID)
|
||||
kbEmbdBase := s.splitModelNameAndFactory(kb.EmbdID)
|
||||
if embdBase != kbEmbdBase {
|
||||
return nil, fmt.Errorf("Datasets use different embedding models: %v", getEmbdIDs(kbs))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get LLM ID (use tenant's default if not provided)
|
||||
llmID := req.LLMID
|
||||
if llmID == "" {
|
||||
tenant, err := s.tenantDAO.GetByID(tenantID)
|
||||
if err != nil {
|
||||
return nil, errors.New("Tenant not found")
|
||||
}
|
||||
llmID = tenant.LLMID
|
||||
}
|
||||
|
||||
// Convert prompt config to JSONMap
|
||||
promptConfigMap := entity.JSONMap{}
|
||||
if promptConfig.System != nil && *promptConfig.System != "" {
|
||||
promptConfigMap["system"] = *promptConfig.System
|
||||
}
|
||||
if promptConfig.Prologue != nil {
|
||||
promptConfigMap["prologue"] = *promptConfig.Prologue
|
||||
}
|
||||
if promptConfig.EmptyResponse != nil {
|
||||
promptConfigMap["empty_response"] = *promptConfig.EmptyResponse
|
||||
}
|
||||
if promptConfig.Quote != nil {
|
||||
promptConfigMap["quote"] = *promptConfig.Quote
|
||||
}
|
||||
if promptConfig.RefineMultiturn != nil {
|
||||
promptConfigMap["refine_multiturn"] = *promptConfig.RefineMultiturn
|
||||
}
|
||||
if promptConfig.TTS != nil {
|
||||
promptConfigMap["tts"] = *promptConfig.TTS
|
||||
}
|
||||
if promptConfig.Keyword != nil {
|
||||
promptConfigMap["keyword"] = *promptConfig.Keyword
|
||||
}
|
||||
if promptConfig.Reasoning != nil {
|
||||
promptConfigMap["reasoning"] = *promptConfig.Reasoning
|
||||
}
|
||||
if promptConfig.TocEnhance != nil {
|
||||
promptConfigMap["toc_enhance"] = *promptConfig.TocEnhance
|
||||
}
|
||||
if promptConfig.UseKG != nil {
|
||||
promptConfigMap["use_kg"] = *promptConfig.UseKG
|
||||
}
|
||||
if promptConfig.TavilyAPIKey != "" {
|
||||
promptConfigMap["tavily_api_key"] = promptConfig.TavilyAPIKey
|
||||
}
|
||||
if len(promptConfig.CrossLanguages) > 0 {
|
||||
promptConfigMap["cross_languages"] = promptConfig.CrossLanguages
|
||||
}
|
||||
if len(promptConfig.ReferenceMetadata) > 0 {
|
||||
promptConfigMap["reference_metadata"] = promptConfig.ReferenceMetadata
|
||||
}
|
||||
if len(promptConfig.Parameters) > 0 {
|
||||
params := make([]map[string]interface{}, len(promptConfig.Parameters))
|
||||
for i, p := range promptConfig.Parameters {
|
||||
params[i] = map[string]interface{}{
|
||||
"key": p.Key,
|
||||
"optional": p.Optional,
|
||||
}
|
||||
}
|
||||
promptConfigMap["parameters"] = params
|
||||
}
|
||||
|
||||
// Convert kbIDs to JSONSlice
|
||||
kbIDsJSON := make(entity.JSONSlice, len(kbIDs))
|
||||
for i, id := range kbIDs {
|
||||
kbIDsJSON[i] = id
|
||||
}
|
||||
|
||||
if isCreate {
|
||||
// Generate UUID for new chat
|
||||
newID := common.GenerateUUID()
|
||||
|
||||
// Set default language
|
||||
language := "English"
|
||||
|
||||
// Create new chat
|
||||
chat := &entity.Chat{
|
||||
ID: newID,
|
||||
TenantID: tenantID,
|
||||
Name: &name,
|
||||
Description: &description,
|
||||
Icon: &req.Icon,
|
||||
Language: &language,
|
||||
LLMID: llmID,
|
||||
LLMSetting: llmSetting,
|
||||
PromptConfig: promptConfigMap,
|
||||
MetaDataFilter: (*entity.JSONMap)(&metaDataFilter),
|
||||
TopN: topN,
|
||||
TopK: topK,
|
||||
RerankID: rerankID,
|
||||
SimilarityThreshold: similarityThreshold,
|
||||
VectorSimilarityWeight: vectorSimilarityWeight,
|
||||
KBIDs: kbIDsJSON,
|
||||
Status: strPtr("1"),
|
||||
}
|
||||
|
||||
if err := s.chatDAO.Create(chat); err != nil {
|
||||
return nil, errors.New("Fail to new a chat")
|
||||
}
|
||||
|
||||
// Get KB names
|
||||
kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs)
|
||||
|
||||
return &SetDialogResponse{
|
||||
Chat: chat,
|
||||
KBNames: kbNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
updateData := map[string]interface{}{
|
||||
"name": name,
|
||||
"description": description,
|
||||
"icon": req.Icon,
|
||||
"llm_id": llmID,
|
||||
"llm_setting": llmSetting,
|
||||
"prompt_config": promptConfigMap,
|
||||
"meta_data_filter": metaDataFilter,
|
||||
"top_n": topN,
|
||||
"top_k": topK,
|
||||
"rerank_id": rerankID,
|
||||
"similarity_threshold": similarityThreshold,
|
||||
"vector_similarity_weight": vectorSimilarityWeight,
|
||||
"kb_ids": kbIDsJSON,
|
||||
}
|
||||
|
||||
if err := s.chatDAO.UpdateByID(req.DialogID, updateData); err != nil {
|
||||
return nil, errors.New("Dialog not found")
|
||||
}
|
||||
|
||||
// Get updated chat
|
||||
chat, err := s.chatDAO.GetByID(req.DialogID)
|
||||
if err != nil {
|
||||
return nil, errors.New("Fail to update a chat")
|
||||
}
|
||||
|
||||
// Get KB names
|
||||
kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs)
|
||||
|
||||
return &SetDialogResponse{
|
||||
Chat: chat,
|
||||
KBNames: kbNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// generateUniqueName generates a unique name by appending a number
|
||||
func (s *ChatService) generateUniqueName(name string, existingNames []string) string {
|
||||
baseName := name
|
||||
counter := 1
|
||||
|
||||
// Check if name already has a suffix like "(1)"
|
||||
if idx := strings.LastIndex(name, "("); idx > 0 {
|
||||
if idx2 := strings.LastIndex(name, ")"); idx2 > idx {
|
||||
if num, err := fmt.Sscanf(name[idx+1:idx2], "%d", &counter); err == nil && num == 1 {
|
||||
baseName = strings.TrimSpace(name[:idx])
|
||||
counter++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
existingMap := make(map[string]bool)
|
||||
for _, n := range existingNames {
|
||||
existingMap[strings.ToLower(n)] = true
|
||||
}
|
||||
|
||||
newName := name
|
||||
for {
|
||||
if !existingMap[strings.ToLower(newName)] {
|
||||
return newName
|
||||
}
|
||||
newName = fmt.Sprintf("%s(%d)", baseName, counter)
|
||||
counter++
|
||||
}
|
||||
}
|
||||
|
||||
// splitModelNameAndFactory extracts the base model name (removes vendor suffix)
|
||||
func (s *ChatService) splitModelNameAndFactory(embdID string) string {
|
||||
// Remove vendor suffix (e.g., "model@openai" -> "model")
|
||||
@@ -1629,52 +1128,6 @@ func (s *ChatService) BulkDeleteChats(userID string, req *BulkDeleteChatsRequest
|
||||
return nil, errors.New(strings.Join(errorsList, "; "))
|
||||
}
|
||||
|
||||
// RemoveChats removes dialogs by setting their status to invalid (soft delete)
|
||||
// Only the owner of the chat can perform this operation
|
||||
func (s *ChatService) RemoveChats(userID string, chatIDs []string) error {
|
||||
// Get user's tenants
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build a set of user's tenant IDs for quick lookup
|
||||
tenantIDSet := make(map[string]bool)
|
||||
for _, tid := range tenantIDs {
|
||||
tenantIDSet[tid] = true
|
||||
}
|
||||
// Also add userID itself as a tenant (for cases where tenant_id = user_id)
|
||||
tenantIDSet[userID] = true
|
||||
|
||||
// Check each chat and build update list
|
||||
var updates []map[string]interface{}
|
||||
for _, chatID := range chatIDs {
|
||||
// Get the chat to check ownership
|
||||
chat, err := s.chatDAO.GetByID(chatID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("chat not found: %s", chatID)
|
||||
}
|
||||
|
||||
// Check if user is the owner (chat's tenant_id must be in user's tenants)
|
||||
if !tenantIDSet[chat.TenantID] {
|
||||
return errors.New("only owner of chat authorized for this operation")
|
||||
}
|
||||
|
||||
// Add to update list (soft delete by setting status to "0")
|
||||
updates = append(updates, map[string]interface{}{
|
||||
"id": chatID,
|
||||
"status": "0",
|
||||
})
|
||||
}
|
||||
|
||||
// Batch update all dialogs
|
||||
if err := s.chatDAO.UpdateManyByID(updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// strPtr returns a pointer to a string
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
|
||||
@@ -171,16 +171,23 @@ func (s *ChatPipelineService) AsyncChat(
|
||||
}
|
||||
|
||||
// No KBs & no web search → fast-path to LLM-only chat.
|
||||
hasKBs := false
|
||||
for _, raw := range chat.KBIDs {
|
||||
if id, ok := raw.(string); ok && id != "" {
|
||||
hasKBs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
useWebSearch := s.shouldUseWebSearch(chat, kwargs["internet"])
|
||||
if useWebSearch {
|
||||
common.Debug("web_search",
|
||||
zap.Bool("kb", len(chat.KBIDs) > 0),
|
||||
zap.Bool("kb", hasKBs),
|
||||
zap.Bool("tavily", chat.PromptConfig != nil && chat.PromptConfig["tavily_api_key"] != "" && chat.PromptConfig["tavily_api_key"] != nil),
|
||||
zap.Any("internet", kwargs["internet"]),
|
||||
zap.Bool("enabled", useWebSearch))
|
||||
}
|
||||
|
||||
if len(chat.KBIDs) == 0 && !useWebSearch {
|
||||
if !hasKBs && !useWebSearch {
|
||||
return s.AsyncChatSolo(ctx, chat, messages, stream)
|
||||
}
|
||||
|
||||
@@ -1022,8 +1029,7 @@ func (s *ChatPipelineService) AsyncChat(
|
||||
if stream {
|
||||
// Streaming path: accumulate answer, emit deltas.
|
||||
var fullAnswer string
|
||||
var fullReasoning string
|
||||
thinkState := &thinkStreamState{}
|
||||
thinkState := &ThinkStreamState{}
|
||||
|
||||
chatCfg := BuildChatConfig(chat, nil)
|
||||
|
||||
@@ -1094,50 +1100,60 @@ func (s *ChatPipelineService) AsyncChat(
|
||||
*chatDriver.ModelName, chatMessages, chatDriver.APIConfig, chatCfg,
|
||||
func(answer *string, reason *string) error {
|
||||
if reason != nil && *reason != "" {
|
||||
fullReasoning += *reason
|
||||
kind, output := processThinkDelta(thinkState, *reason, 16)
|
||||
if kind == "marker" && output == "<think>" {
|
||||
// <think> marker — emit StartToThink
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
StartToThink: true,
|
||||
}
|
||||
} else if kind == "marker" && output == "</think>" {
|
||||
// </think> marker — emit EndToThink
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
EndToThink: true,
|
||||
}
|
||||
} else if kind == "text" && output != "" {
|
||||
// Route reasoning text to Reasoning field.
|
||||
// TTS is nil — chain-of-thought is not narrated.
|
||||
out <- AsyncChatResult{
|
||||
Reasoning: output,
|
||||
Reference: map[string]interface{}{},
|
||||
// TTS only narrates user-visible answer text.
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
deltas := NextThinkDelta(thinkState, *reason, 16)
|
||||
for _, d := range deltas {
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "<think>" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
StartToThink: true,
|
||||
}
|
||||
} else if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
EndToThink: true,
|
||||
}
|
||||
} else if d.Kind == ThinkDeltaText && d.Value != "" {
|
||||
fullAnswer += d.Value
|
||||
out <- AsyncChatResult{
|
||||
Answer: d.Value,
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if isContentDelta(answer) {
|
||||
fullAnswer += *answer
|
||||
out <- AsyncChatResult{
|
||||
Answer: *answer,
|
||||
Reference: map[string]interface{}{},
|
||||
// Per-delta TTS for incremental audio playback.
|
||||
AudioBinary: s.synthesizeTTS(ttsModel, *answer),
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
deltas := BufferAnswerDelta(thinkState, *answer, 16)
|
||||
for _, d := range deltas {
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
EndToThink: true,
|
||||
}
|
||||
} else if d.Kind == ThinkDeltaText && d.Value != "" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: d.Value,
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -1152,33 +1168,34 @@ func (s *ChatPipelineService) AsyncChat(
|
||||
return
|
||||
}
|
||||
|
||||
// Flush remaining think stream buffer.
|
||||
// If the LLM ended mid-think, emit remaining reasoning +
|
||||
// implicit close marker.
|
||||
remainingText, remainingMarker := flushThinkStream(thinkState)
|
||||
if remainingText != "" {
|
||||
// Flushed text belongs in Reasoning, not Answer.
|
||||
out <- AsyncChatResult{
|
||||
Reasoning: remainingText,
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
}
|
||||
}
|
||||
if remainingMarker == "</think>" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
EndToThink: true,
|
||||
// Flush remaining state matching Python's final flush order
|
||||
// (dialog_service.py:1601-1612): think_buffer → marker → answer_buffer → pending_after_close
|
||||
// Python has no Reasoning field — all text is Answer.
|
||||
for _, d := range FlushRemaining(thinkState) {
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
EndToThink: true,
|
||||
}
|
||||
} else if d.Kind == ThinkDeltaText && d.Value != "" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: d.Value,
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decorate and yield the final answer.
|
||||
visibleAnswer := s.extractVisibleAnswer(fullReasoning + fullAnswer)
|
||||
// Python uses state.full_text (raw text with <think> tags) as input
|
||||
// to _extract_visible_answer → decorate_answer (dialog_service.py:914-920).
|
||||
visibleAnswer := s.extractVisibleAnswer(thinkState.fullText)
|
||||
|
||||
// Pass nil for ttsModel — audio was already produced per-delta.
|
||||
final := s.decorateAnswer(ctx, visibleAnswer, kbinfos, prompt, questions, usedTokenCount, timer, embModel, chat.VectorSimilarityWeight, quote, nil, langfuseTraceID, llmModelConfig, chat.TenantID, kbTenantIDStrings(kbs), len(knowledges) > 0)
|
||||
@@ -1372,28 +1389,51 @@ func (s *ChatPipelineService) AsyncChatSolo(
|
||||
// 7. Drive the LLM: stream (per-delta with think markers) or non-stream (one-shot).
|
||||
if stream {
|
||||
var fullAnswer string
|
||||
var fullReasoning string
|
||||
thinkState := &thinkStreamState{}
|
||||
thinkState := &ThinkStreamState{}
|
||||
chatCfg := BuildChatConfig(chat, nil)
|
||||
timer.Enter(common.PhaseGenerateAnswer)
|
||||
|
||||
driverErr := chatModel.ModelDriver.ChatStreamlyWithSender(
|
||||
*chatModel.ModelName, chatMessages, chatModel.APIConfig, chatCfg,
|
||||
func(answer *string, reason *string) error {
|
||||
if reason != nil && *reason != "" {
|
||||
fullReasoning += *reason
|
||||
kind, output := processThinkDelta(thinkState, *reason, 16)
|
||||
if kind == "marker" && output == "<think>" {
|
||||
// Start thinking.
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
StartToThink: true,
|
||||
deltas := NextThinkDelta(thinkState, *reason, 16)
|
||||
for _, d := range deltas {
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "<think>" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
StartToThink: true,
|
||||
}
|
||||
} else if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
EndToThink: true,
|
||||
}
|
||||
} else if d.Kind == ThinkDeltaText && d.Value != "" {
|
||||
fullAnswer += d.Value
|
||||
out <- AsyncChatResult{
|
||||
Answer: d.Value,
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
}
|
||||
}
|
||||
} else if kind == "marker" && output == "</think>" {
|
||||
// End thinking.
|
||||
}
|
||||
}
|
||||
if isContentDelta(answer) {
|
||||
fullAnswer += *answer
|
||||
deltas := BufferAnswerDelta(thinkState, *answer, 16)
|
||||
for _, d := range deltas {
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
@@ -1402,27 +1442,17 @@ func (s *ChatPipelineService) AsyncChatSolo(
|
||||
Final: false,
|
||||
EndToThink: true,
|
||||
}
|
||||
} else if kind == "text" && output != "" {
|
||||
// Reasoning text with per-delta TTS.
|
||||
} else if d.Kind == ThinkDeltaText && d.Value != "" {
|
||||
out <- AsyncChatResult{
|
||||
Reasoning: output,
|
||||
Answer: d.Value,
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: s.synthesizeTTS(ttsModel, output),
|
||||
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
if isContentDelta(answer) {
|
||||
fullAnswer += *answer
|
||||
out <- AsyncChatResult{
|
||||
Answer: *answer,
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: s.synthesizeTTS(ttsModel, *answer),
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
@@ -1434,33 +1464,30 @@ func (s *ChatPipelineService) AsyncChatSolo(
|
||||
return
|
||||
}
|
||||
timer.Exit(common.PhaseGenerateAnswer)
|
||||
// Flush any remaining think buffer.
|
||||
remainingText, remainingMarker := flushThinkStream(thinkState)
|
||||
if remainingText != "" {
|
||||
out <- AsyncChatResult{
|
||||
Reasoning: remainingText,
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: s.synthesizeTTS(ttsModel, remainingText),
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
for _, d := range FlushRemaining(thinkState) {
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
EndToThink: true,
|
||||
}
|
||||
} else if d.Kind == ThinkDeltaText && d.Value != "" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: d.Value,
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: s.synthesizeTTS(ttsModel, d.Value),
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
if remainingMarker == "</think>" {
|
||||
out <- AsyncChatResult{
|
||||
Answer: "",
|
||||
Reference: map[string]interface{}{},
|
||||
AudioBinary: nil,
|
||||
CreatedAt: float64(time.Now().Unix()),
|
||||
Final: false,
|
||||
EndToThink: true,
|
||||
}
|
||||
finalAnswer := ExtractVisibleAnswer(thinkState.fullText)
|
||||
if finalAnswer == "" {
|
||||
finalAnswer = fullAnswer
|
||||
}
|
||||
// Final aggregate: re-attach reasoning wrapper for non-streaming consumers.
|
||||
finalAnswer := fullAnswer
|
||||
if fullReasoning != "" {
|
||||
finalAnswer = "<think>" + fullReasoning + "</think>" + fullAnswer
|
||||
}
|
||||
// Raw answer, no decorate_answer. AudioBinary=nil (per-delta TTS already emitted).
|
||||
out <- AsyncChatResult{
|
||||
Answer: finalAnswer,
|
||||
Reference: map[string]interface{}{},
|
||||
@@ -2288,7 +2315,7 @@ func (s *ChatPipelineService) kbPrompt(kbinfos map[string]interface{}, maxTokens
|
||||
}
|
||||
contents := make([]chunkContent, 0, len(chunksRaw))
|
||||
for _, ck := range chunksRaw {
|
||||
c := getMapString(ck, "content", "content_with_weight")
|
||||
c := getMapString(ck, "content_with_weight", "content")
|
||||
if c == "" {
|
||||
continue
|
||||
}
|
||||
@@ -2315,7 +2342,7 @@ func (s *ChatPipelineService) kbPrompt(kbinfos map[string]interface{}, maxTokens
|
||||
var result []string
|
||||
for i := 0; i < chunksNum; i++ {
|
||||
ck := chunksRaw[i]
|
||||
c := getMapString(ck, "content", "content_with_weight")
|
||||
c := getMapString(ck, "content_with_weight", "content")
|
||||
if c == "" {
|
||||
continue
|
||||
}
|
||||
@@ -2781,25 +2808,8 @@ func langfuseExtractTimeElapsed(prompt string) string {
|
||||
}
|
||||
|
||||
// extractVisibleAnswer mirrors Python's _extract_visible_answer.
|
||||
// It preserves <think> wrappers and strips stray think tags.
|
||||
func (s *ChatPipelineService) extractVisibleAnswer(text string) string {
|
||||
if !strings.Contains(text, "</think>") {
|
||||
text = strings.ReplaceAll(text, "<think>", "")
|
||||
text = strings.ReplaceAll(text, "</think>", "")
|
||||
return text
|
||||
}
|
||||
idx := strings.LastIndex(text, "</think>")
|
||||
thought := text[:idx]
|
||||
answer := text[idx+len("</think>"):]
|
||||
thought = strings.ReplaceAll(thought, "<think>", "")
|
||||
thought = strings.ReplaceAll(thought, "</think>", "")
|
||||
thought = strings.TrimSpace(thought)
|
||||
answer = strings.ReplaceAll(answer, "<think>", "")
|
||||
answer = strings.ReplaceAll(answer, "</think>", "")
|
||||
if thought == "" {
|
||||
return answer
|
||||
}
|
||||
return "<think>" + thought + "</think>" + answer
|
||||
return ExtractVisibleAnswer(text)
|
||||
}
|
||||
|
||||
// citationPrompt returns the citation instruction prompt.
|
||||
@@ -2809,124 +2819,6 @@ func citationPrompt() string {
|
||||
"(where N is the chunk number) after each sentence where the information from that chunk is used."
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Think-marker streaming — mirrors Python's _stream_with_think_delta.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// thinkStreamState tracks accumulated reasoning text and emits deltas.
// The zero value is ready to use.
type thinkStreamState struct {
	// fullText is every delta received so far, concatenated
	// (processThinkDelta appends to it).
	fullText string
	// lastIdx is the byte offset into fullText up to which nextThinkDelta has
	// already consumed the text.
	lastIdx int
	// endsWithThink is set by nextThinkDelta when the unread tail ends with
	// "</think>", and cleared once the close marker is returned or flushed.
	endsWithThink bool
	// inThink is true after a "<think>" has been processed and before the
	// matching "</think>" marker is emitted (maintained by processThinkDelta).
	inThink bool
	// buffer holds text that has been extracted but not yet emitted.
	buffer string
	// postThinkText holds text that followed a "</think>" tag in the same
	// chunk; it is moved into buffer or returned by flushThinkStream.
	postThinkText string
}
|
||||
|
||||
// nextThinkDelta computes the next delta to emit from the accumulated text.
|
||||
// Mirrors _next_think_delta in dialog_service.py:1460-1487.
|
||||
func nextThinkDelta(state *thinkStreamState) string {
|
||||
full := state.fullText
|
||||
if full == "" || len(full) <= state.lastIdx {
|
||||
return ""
|
||||
}
|
||||
delta := full[state.lastIdx:]
|
||||
|
||||
if strings.HasPrefix(delta, "<think>") {
|
||||
state.lastIdx += len("<think>")
|
||||
return "<think>"
|
||||
}
|
||||
if idx := strings.Index(delta, "<think>"); idx > 0 {
|
||||
state.lastIdx += idx
|
||||
return delta[:idx]
|
||||
}
|
||||
if strings.HasSuffix(delta, "</think>") {
|
||||
state.endsWithThink = true
|
||||
} else if state.endsWithThink {
|
||||
state.endsWithThink = false
|
||||
remainder := delta
|
||||
if idx := strings.Index(delta, "</think>"); idx >= 0 {
|
||||
remainder = delta[idx+len("</think>"):]
|
||||
}
|
||||
if remainder != "" {
|
||||
state.postThinkText = remainder
|
||||
}
|
||||
state.lastIdx = len(full)
|
||||
return "</think>"
|
||||
}
|
||||
|
||||
state.lastIdx = len(full)
|
||||
if strings.HasSuffix(full, "</think>") {
|
||||
state.lastIdx -= len("</think>")
|
||||
}
|
||||
return strings.ReplaceAll(strings.ReplaceAll(delta, "<think>", ""), "</think>", "")
|
||||
}
|
||||
|
||||
// processThinkDelta updates the state with a new delta and returns what to emit.
|
||||
// Returns the kind of emission: "marker" for think tags, "text" for content, "" for nothing.
|
||||
func processThinkDelta(state *thinkStreamState, delta string, minTokens int) (kind string, output string) {
|
||||
if delta == "" {
|
||||
return "", ""
|
||||
}
|
||||
state.fullText += delta
|
||||
d := nextThinkDelta(state)
|
||||
if d == "" {
|
||||
return "", ""
|
||||
}
|
||||
if d == "<think>" {
|
||||
if state.inThink {
|
||||
return "", ""
|
||||
}
|
||||
if state.buffer != "" {
|
||||
kind, out := "text", state.buffer
|
||||
state.buffer = ""
|
||||
state.inThink = true
|
||||
return kind, out
|
||||
}
|
||||
state.inThink = true
|
||||
return "marker", "<think>"
|
||||
}
|
||||
if d == "</think>" {
|
||||
if !state.inThink {
|
||||
return "", ""
|
||||
}
|
||||
state.inThink = false
|
||||
if state.postThinkText != "" {
|
||||
state.buffer += state.postThinkText
|
||||
state.postThinkText = ""
|
||||
}
|
||||
return "marker", "</think>"
|
||||
}
|
||||
state.buffer += d
|
||||
if kg.NumTokensFromString(state.buffer) < minTokens {
|
||||
return "", ""
|
||||
}
|
||||
out := state.buffer
|
||||
state.buffer = ""
|
||||
return "text", out
|
||||
}
|
||||
|
||||
// flushThinkStream flushes any remaining buffered text from the think stream.
|
||||
func flushThinkStream(state *thinkStreamState) (text string, marker string) {
|
||||
if state.buffer != "" {
|
||||
text = state.buffer
|
||||
state.buffer = ""
|
||||
}
|
||||
if state.postThinkText != "" {
|
||||
if text != "" {
|
||||
text += state.postThinkText
|
||||
} else {
|
||||
text = state.postThinkText
|
||||
}
|
||||
state.postThinkText = ""
|
||||
}
|
||||
if state.endsWithThink {
|
||||
marker = "</think>"
|
||||
state.endsWithThink = false
|
||||
}
|
||||
return text, marker
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Moved from sql_fallback.go (2026-06-12). SQL retrieval system, repair
|
||||
// helpers, and Python parity helpers. Kept in async_chat.go because the
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,6 @@ type fakeSessionStore struct {
|
||||
getByIDErr error
|
||||
createErr error
|
||||
updateByIDErr error
|
||||
deleteByIDErr error
|
||||
getDialogErr error
|
||||
// record calls
|
||||
createCalled []*entity.ChatSession
|
||||
@@ -34,7 +33,6 @@ type fakeSessionStore struct {
|
||||
id string
|
||||
updates map[string]interface{}
|
||||
}
|
||||
deleteByIDIDs []string
|
||||
}
|
||||
|
||||
func newFakeSessionStore() *fakeSessionStore {
|
||||
@@ -112,12 +110,6 @@ func (f *fakeSessionStore) UpdateByID(id string, updates map[string]interface{})
|
||||
}
|
||||
|
||||
func (f *fakeSessionStore) DeleteByID(id string) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.deleteByIDErr != nil {
|
||||
return f.deleteByIDErr
|
||||
}
|
||||
f.deleteByIDIDs = append(f.deleteByIDIDs, id)
|
||||
delete(f.sessions, id)
|
||||
return nil
|
||||
}
|
||||
@@ -179,226 +171,6 @@ func makeResultChan(results ...AsyncChatResult) <-chan AsyncChatResult {
|
||||
return ch
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// SetChatSession tests
|
||||
// ===================================================================
|
||||
|
||||
func TestSetChatSession_CreateNew(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
dialog := &entity.Chat{ID: "dialog-1", PromptConfig: entity.JSONMap{"prologue": "Welcome!"}}
|
||||
store.dialogs["dialog-1"] = dialog
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
|
||||
DialogID: "dialog-1",
|
||||
IsNew: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.ID == "" {
|
||||
t.Fatal("expected session ID to be generated")
|
||||
}
|
||||
if resp.DialogID != "dialog-1" {
|
||||
t.Fatalf("expected dialog_id=dialog-1, got %s", resp.DialogID)
|
||||
}
|
||||
if len(store.createCalled) != 1 {
|
||||
t.Fatalf("expected 1 Create call, got %d", len(store.createCalled))
|
||||
}
|
||||
|
||||
// Verify prologue is in the message list.
|
||||
var msgs []map[string]interface{}
|
||||
if err := json.Unmarshal(store.createCalled[0].Message, &msgs); err != nil {
|
||||
t.Fatalf("failed to unmarshal message: %v", err)
|
||||
}
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected 1 initial message, got %d", len(msgs))
|
||||
}
|
||||
firstMsg := msgs[0]
|
||||
if firstMsg["role"] != "assistant" || firstMsg["content"] != "Welcome!" {
|
||||
t.Fatalf("unexpected prologue message: %#v", firstMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetChatSession_CreateNewDefaultPrologue(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.dialogs["dialog-1"] = &entity.Chat{ID: "dialog-1"}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
|
||||
DialogID: "dialog-1",
|
||||
IsNew: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.ID == "" {
|
||||
t.Fatal("expected session ID")
|
||||
}
|
||||
// Default prologue
|
||||
var msgs []map[string]interface{}
|
||||
json.Unmarshal(store.createCalled[0].Message, &msgs)
|
||||
firstMsg := msgs[0]
|
||||
if !strings.Contains(firstMsg["content"].(string), "Hi! I'm your assistant") {
|
||||
t.Fatalf("expected default prologue, got %q", firstMsg["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetChatSession_CreateNewDialogNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
|
||||
DialogID: "nonexistent",
|
||||
IsNew: true,
|
||||
})
|
||||
if err == nil || err.Error() != "Dialog not found" {
|
||||
t.Fatalf("expected 'Dialog not found' error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetChatSession_UpdateExisting(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1", Name: strPtr("old name"),
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
|
||||
SessionID: "session-1",
|
||||
Name: "new name",
|
||||
IsNew: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.ID != "session-1" {
|
||||
t.Fatalf("expected session-1, got %s", resp.ID)
|
||||
}
|
||||
if len(store.updateCalled) != 1 {
|
||||
t.Fatalf("expected UpdateByID call, got %d", len(store.updateCalled))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetChatSession_UpdateNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.updateByIDErr = errors.New("Chat session not found")
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
|
||||
SessionID: "missing",
|
||||
IsNew: false,
|
||||
})
|
||||
if err == nil || err.Error() != "Chat session not found" {
|
||||
t.Fatalf("expected 'Chat session not found' error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetChatSession_NameTruncation(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.dialogs["dialog-1"] = &entity.Chat{ID: "dialog-1"}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
longName := strings.Repeat("x", 300)
|
||||
resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
|
||||
DialogID: "dialog-1",
|
||||
Name: longName,
|
||||
IsNew: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Name == nil || len(*resp.Name) > 255 {
|
||||
t.Fatalf("expected name truncated to <=255, got len=%d", len(*resp.Name))
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// RemoveChatSessions tests
|
||||
// ===================================================================
|
||||
|
||||
func TestRemoveChatSessions_Success(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["conv-1"] = &entity.ChatSession{ID: "conv-1", DialogID: "dialog-1"}
|
||||
store.sessions["conv-2"] = &entity.ChatSession{ID: "conv-2", DialogID: "dialog-1"}
|
||||
store.dialogExists["user-1|dialog-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
err := svc.RemoveChatSessions("user-1", []string{"conv-1", "conv-2"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(store.deleteByIDIDs) != 2 {
|
||||
t.Fatalf("expected 2 deletes, got %d", len(store.deleteByIDIDs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveChatSessions_SessionNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
err := svc.RemoveChatSessions("user-1", []string{"missing"})
|
||||
if err == nil || !strings.Contains(err.Error(), "not found") {
|
||||
t.Fatalf("expected 'not found' error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveChatSessions_NotOwner(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["conv-1"] = &entity.ChatSession{ID: "conv-1", DialogID: "dialog-1"}
|
||||
// No tenant matches — dialogExists stays false for all combinations
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-other"}},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
err := svc.RemoveChatSessions("user-1", []string{"conv-1"})
|
||||
if err == nil || !strings.Contains(err.Error(), "Only owner") {
|
||||
t.Fatalf("expected 'Only owner' error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// ListChatSessions tests
|
||||
// ===================================================================
|
||||
@@ -638,301 +410,6 @@ func TestUpdateSession_NotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// Completion tests
|
||||
// ===================================================================
|
||||
|
||||
func TestCompletion_Success(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
session := &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.sessions["session-1"] = session
|
||||
store.dialogs["dialog-1"] = &entity.Chat{
|
||||
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
|
||||
LLMSetting: entity.JSONMap{},
|
||||
}
|
||||
|
||||
pipeline := &fakePipeline{
|
||||
resultChan: makeResultChan(
|
||||
AsyncChatResult{Answer: "Hello", Reference: map[string]interface{}{"chunks": []interface{}{}}},
|
||||
AsyncChatResult{Answer: " world", Final: true, Reference: map[string]interface{}{"chunks": []interface{}{}}},
|
||||
),
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: pipeline,
|
||||
}
|
||||
|
||||
result, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
ans, _ := result["answer"].(string)
|
||||
if ans != "Hello world" {
|
||||
t.Fatalf("expected answer 'Hello world', got %q", ans)
|
||||
}
|
||||
|
||||
got := parseMessages(store.sessions["session-1"].Message)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("stored messages=%#v", got)
|
||||
}
|
||||
if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" {
|
||||
t.Fatalf("stored prologue=%#v", got[0])
|
||||
}
|
||||
if got[1]["role"] != "user" || got[1]["content"] != "hi" {
|
||||
t.Fatalf("stored user message=%#v", got[1])
|
||||
}
|
||||
if got[2]["role"] != "assistant" || got[2]["content"] != "Hello world" || got[2]["id"] != "msg-1" {
|
||||
t.Fatalf("stored assistant message=%#v", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_EmptyMessages(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: &fakeSessionStore{},
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, err := svc.Completion("user-1", "session-1", nil, "", nil, "msg-1")
|
||||
if err == nil || err.Error() != "messages cannot be empty" {
|
||||
t.Fatalf("expected 'messages cannot be empty', got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_LastMessageNotFromUser(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: &fakeSessionStore{},
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "assistant", "content": "hello"},
|
||||
}, "", nil, "msg-1")
|
||||
if err == nil || !strings.Contains(err.Error(), "not from user") {
|
||||
t.Fatalf("expected 'not from user' error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_ConversationNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, err := svc.Completion("user-1", "missing", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1")
|
||||
if err == nil || err.Error() != "Conversation not found" {
|
||||
t.Fatalf("expected 'Conversation not found', got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_DialogNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1")
|
||||
if err == nil || err.Error() != "Dialog not found" {
|
||||
t.Fatalf("expected 'Dialog not found', got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_PipelineError(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.dialogs["dialog-1"] = &entity.Chat{
|
||||
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
|
||||
LLMSetting: entity.JSONMap{},
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{err: errors.New("model unavailable")},
|
||||
}
|
||||
|
||||
_, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1")
|
||||
if err == nil || err.Error() != "model unavailable" {
|
||||
t.Fatalf("expected 'model unavailable' error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// CompletionStream tests
|
||||
// ===================================================================
|
||||
|
||||
// readStreamChan collects up to n messages from ch without ever blocking.
// It stops early, returning what it has gathered so far, as soon as the
// channel is closed or has no message immediately available.
func readStreamChan(ch <-chan string, n int) []string {
	var collected []string
	for remaining := n; remaining > 0; remaining-- {
		select {
		case item, open := <-ch:
			if !open {
				return collected
			}
			collected = append(collected, item)
		default:
			// Nothing ready right now: do not wait for more.
			return collected
		}
	}
	return collected
}
|
||||
|
||||
func TestCompletionStream_Success(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`{"messages":[{"role":"assistant","content":"Welcome!"}]}`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.dialogs["dialog-1"] = &entity.Chat{
|
||||
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
|
||||
LLMSetting: entity.JSONMap{},
|
||||
}
|
||||
|
||||
pipeline := &fakePipeline{
|
||||
resultChan: makeResultChan(
|
||||
AsyncChatResult{Answer: "stream", Reference: map[string]interface{}{"chunks": []interface{}{}}},
|
||||
AsyncChatResult{Answer: " answer", Reference: map[string]interface{}{"chunks": []interface{}{}}},
|
||||
),
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: pipeline,
|
||||
}
|
||||
|
||||
streamChan := make(chan string, 10)
|
||||
err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1", streamChan)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should receive data events and final signal
|
||||
msgs := readStreamChan(streamChan, 5)
|
||||
if len(msgs) < 3 {
|
||||
t.Fatalf("expected at least 3 stream messages, got %d: %v", len(msgs), msgs)
|
||||
}
|
||||
// Check final signal
|
||||
finalFound := false
|
||||
for _, m := range msgs {
|
||||
if strings.Contains(m, `"data":true`) {
|
||||
finalFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !finalFound {
|
||||
t.Fatal("expected final=true signal in stream")
|
||||
}
|
||||
|
||||
got := parseMessages(store.sessions["session-1"].Message)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("stored messages=%#v", got)
|
||||
}
|
||||
if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" {
|
||||
t.Fatalf("stored prologue=%#v", got[0])
|
||||
}
|
||||
if got[1]["role"] != "user" || got[1]["content"] != "hi" {
|
||||
t.Fatalf("stored user message=%#v", got[1])
|
||||
}
|
||||
if got[2]["role"] != "assistant" || got[2]["content"] != "stream answer" || got[2]["id"] != "msg-1" {
|
||||
t.Fatalf("stored assistant message=%#v", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructureAnswerWithConv_ParsesArrayMessages(t *testing.T) {
|
||||
session := &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
|
||||
}
|
||||
svc := &ChatSessionService{}
|
||||
|
||||
ans := svc.structureAnswerWithConv(session, map[string]interface{}{
|
||||
"answer": "Final answer",
|
||||
"reference": map[string]interface{}{"chunks": []interface{}{}},
|
||||
"final": true,
|
||||
}, "msg-1", "session-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
|
||||
if ans["id"] != "msg-1" || ans["session_id"] != "session-1" {
|
||||
t.Fatalf("ans=%#v", ans)
|
||||
}
|
||||
|
||||
got := parseMessages(session.Message)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("stored messages=%#v", got)
|
||||
}
|
||||
if got[0]["role"] != "assistant" || got[0]["content"] != "Final answer" || got[0]["id"] != "msg-1" {
|
||||
t.Fatalf("stored assistant message=%#v", got[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMessages_LegacyWrappedObject(t *testing.T) {
|
||||
got := parseMessages(json.RawMessage(`{"messages":[{"role":"assistant","content":"legacy"}]}`))
|
||||
if !reflect.DeepEqual(got, []map[string]interface{}{{"role": "assistant", "content": "legacy"}}) {
|
||||
t.Fatalf("messages=%#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) {
|
||||
svc := &ChatSessionService{}
|
||||
payload := svc.buildSessionPayload(&entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: nil,
|
||||
Reference: json.RawMessage(`null`),
|
||||
}, nil, false)
|
||||
|
||||
if payload.Messages == nil {
|
||||
t.Fatal("messages is nil")
|
||||
}
|
||||
if payload.Reference == nil {
|
||||
t.Fatal("reference is nil")
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(body), `"messages":[]`) {
|
||||
t.Fatalf("messages did not encode as empty array: %s", string(body))
|
||||
}
|
||||
if !strings.Contains(string(body), `"reference":[]`) {
|
||||
t.Fatalf("reference did not encode as empty array: %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCollections_ReturnEmptySlicesForMissingOrNull(t *testing.T) {
|
||||
messageInputs := []json.RawMessage{
|
||||
nil,
|
||||
@@ -980,99 +457,174 @@ func TestParseCollections_ReturnNilForMalformedData(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionStream_EmptyMessages(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: &fakeSessionStore{},
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
// ===================================================================
|
||||
// chunksFormat tests — verifies field normalization after the rewrite.
|
||||
// ===================================================================
|
||||
|
||||
func TestChunksFormat_NormalizesRawFieldNames(t *testing.T) {
|
||||
svc := &ChatSessionService{}
|
||||
ref := map[string]interface{}{
|
||||
"chunks": []map[string]interface{}{
|
||||
{
|
||||
"chunk_id": "c1",
|
||||
"content_with_weight": "hello world",
|
||||
"content_ltks": "hello world ltks",
|
||||
"doc_id": "d1",
|
||||
"docnm_kwd": "Document 1",
|
||||
"kb_id": "kb1",
|
||||
"image_id": "img1",
|
||||
"img_id": "img2",
|
||||
"positions": []int{0, 10},
|
||||
"position_int": []int{1, 11},
|
||||
"doc_type_kwd": "pdf",
|
||||
"similarity": 0.95,
|
||||
"vector_similarity": 0.9,
|
||||
"term_similarity": 0.85,
|
||||
"row_id": "r1",
|
||||
"url": "http://example.com",
|
||||
"document_metadata": map[string]interface{}{"author": "Alice"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
streamChan := make(chan string, 10)
|
||||
err := svc.CompletionStream(context.Background(), "user-1", "session-1", nil, "", nil, "msg-1", streamChan)
|
||||
if err == nil || err.Error() != "messages cannot be empty" {
|
||||
t.Fatalf("expected 'messages cannot be empty', got %v", err)
|
||||
result := svc.chunksFormat(ref)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(result))
|
||||
}
|
||||
c := result[0]
|
||||
|
||||
if c["id"] != "c1" {
|
||||
t.Fatalf("id=%v", c["id"])
|
||||
}
|
||||
if c["content"] != "hello world" {
|
||||
t.Fatalf("content=%v", c["content"])
|
||||
}
|
||||
if c["document_id"] != "d1" {
|
||||
t.Fatalf("document_id=%v", c["document_id"])
|
||||
}
|
||||
if c["document_name"] != "Document 1" {
|
||||
t.Fatalf("document_name=%v", c["document_name"])
|
||||
}
|
||||
if c["dataset_id"] != "kb1" {
|
||||
t.Fatalf("dataset_id=%v", c["dataset_id"])
|
||||
}
|
||||
if c["image_id"] != "img1" {
|
||||
t.Fatalf("image_id=%v", c["image_id"])
|
||||
}
|
||||
if c["doc_type"] != "pdf" {
|
||||
t.Fatalf("doc_type=%v", c["doc_type"])
|
||||
}
|
||||
if c["similarity"] != 0.95 {
|
||||
t.Fatalf("similarity=%v", c["similarity"])
|
||||
}
|
||||
if c["url"] != "http://example.com" {
|
||||
t.Fatalf("url=%v", c["url"])
|
||||
}
|
||||
|
||||
pos, ok := c["positions"].([]int)
|
||||
if !ok || len(pos) != 2 || pos[0] != 0 {
|
||||
t.Fatalf("positions=%v (%T)", c["positions"], c["positions"])
|
||||
}
|
||||
|
||||
// Raw keys must be normalized away.
|
||||
if _, exists := c["content_with_weight"]; exists {
|
||||
t.Fatal("content_with_weight should not be present after normalization")
|
||||
}
|
||||
if _, exists := c["content_ltks"]; exists {
|
||||
t.Fatal("content_ltks should not be present after normalization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionStream_LastMessageNotFromUser(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: &fakeSessionStore{},
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
func TestChunksFormat_PreservesAlreadyNormalizedFields(t *testing.T) {
|
||||
svc := &ChatSessionService{}
|
||||
ref := map[string]interface{}{
|
||||
"chunks": []map[string]interface{}{
|
||||
{
|
||||
"id": "c2",
|
||||
"content": "already normalized",
|
||||
"document_id": "d2",
|
||||
"document_name": "Doc 2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
streamChan := make(chan string, 10)
|
||||
err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "assistant", "content": "hello"},
|
||||
}, "", nil, "msg-1", streamChan)
|
||||
if err == nil || !strings.Contains(err.Error(), "not from user") {
|
||||
t.Fatalf("expected 'not from user' error, got %v", err)
|
||||
result := svc.chunksFormat(ref)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(result))
|
||||
}
|
||||
c := result[0]
|
||||
if c["id"] != "c2" {
|
||||
t.Fatalf("id=%v", c["id"])
|
||||
}
|
||||
if c["content"] != "already normalized" {
|
||||
t.Fatalf("content=%v", c["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionStream_ConversationNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
func TestChunksFormat_EmptyReference(t *testing.T) {
|
||||
svc := &ChatSessionService{}
|
||||
|
||||
streamChan := make(chan string, 10)
|
||||
err := svc.CompletionStream(context.Background(), "user-1", "missing", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1", streamChan)
|
||||
if err == nil || err.Error() != "Conversation not found" {
|
||||
t.Fatalf("expected 'Conversation not found', got %v", err)
|
||||
if n := len(svc.chunksFormat(nil)); n != 0 {
|
||||
t.Fatalf("nil ref: expected 0, got %d", n)
|
||||
}
|
||||
if n := len(svc.chunksFormat(map[string]interface{}{})); n != 0 {
|
||||
t.Fatalf("empty ref: expected 0, got %d", n)
|
||||
}
|
||||
if n := len(svc.chunksFormat(map[string]interface{}{"chunks": nil})); n != 0 {
|
||||
t.Fatalf("nil chunks: expected 0, got %d", n)
|
||||
}
|
||||
if n := len(svc.chunksFormat(map[string]interface{}{"chunks": []map[string]interface{}{}})); n != 0 {
|
||||
t.Fatalf("empty chunks: expected 0, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionStream_DialogNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
func TestChunksFormat_ChunksAsInterfaceSlice(t *testing.T) {
|
||||
svc := &ChatSessionService{}
|
||||
ref := map[string]interface{}{
|
||||
"chunks": []interface{}{
|
||||
map[string]interface{}{
|
||||
"chunk_id": "c3",
|
||||
"content_with_weight": "from interface slice",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
result := svc.chunksFormat(ref)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(result))
|
||||
}
|
||||
|
||||
streamChan := make(chan string, 10)
|
||||
err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1", streamChan)
|
||||
if err == nil || err.Error() != "Dialog not found" {
|
||||
t.Fatalf("expected 'Dialog not found', got %v", err)
|
||||
if result[0]["id"] != "c3" {
|
||||
t.Fatalf("id=%v", result[0]["id"])
|
||||
}
|
||||
if result[0]["content"] != "from interface slice" {
|
||||
t.Fatalf("content=%v", result[0]["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionStream_PipelineError(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
func TestChunksFormat_IgnoresNonMapItems(t *testing.T) {
|
||||
svc := &ChatSessionService{}
|
||||
ref := map[string]interface{}{
|
||||
"chunks": []interface{}{
|
||||
"not a map",
|
||||
map[string]interface{}{
|
||||
"chunk_id": "c4",
|
||||
"content_with_weight": "valid chunk",
|
||||
},
|
||||
},
|
||||
}
|
||||
store.dialogs["dialog-1"] = &entity.Chat{
|
||||
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
|
||||
LLMSetting: entity.JSONMap{},
|
||||
result := svc.chunksFormat(ref)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 chunk (non-maps skipped), got %d", len(result))
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{err: errors.New("model unavailable")},
|
||||
}
|
||||
|
||||
streamChan := make(chan string, 10)
|
||||
err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1", streamChan)
|
||||
if err == nil || err.Error() != "model unavailable" {
|
||||
t.Fatalf("expected 'model unavailable' error, got %v", err)
|
||||
if result[0]["id"] != "c4" {
|
||||
t.Fatalf("id=%v", result[0]["id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunksFormat_UnsupportedTypeReturnsEmpty(t *testing.T) {
|
||||
svc := &ChatSessionService{}
|
||||
ref := map[string]interface{}{"chunks": "not a slice"}
|
||||
result := svc.chunksFormat(ref)
|
||||
if len(result) != 0 {
|
||||
t.Fatalf("expected empty for string type, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1980,7 +1980,7 @@ func (s *DatasetService) SearchDatasets(req *SearchDatasetsRequest, userID strin
|
||||
// Apply meta_data_filter to get filtered doc_ids
|
||||
docIDs := make([]string, len(req.DocIDs))
|
||||
copy(docIDs, req.DocIDs)
|
||||
if metadataFilter != nil {
|
||||
if len(metadataFilter) > 0 {
|
||||
metadataSvc := NewMetadataService()
|
||||
flattedMeta, err := metadataSvc.GetFlattedMetaByKBs(datasetIDs)
|
||||
if err != nil {
|
||||
|
||||
@@ -644,16 +644,16 @@ func formatChunks(chunks []map[string]interface{}) []FormattedChunk {
|
||||
for _, chunk := range chunks {
|
||||
out = append(out, FormattedChunk{
|
||||
ID: strVal(getValue(chunk, "chunk_id", "id")),
|
||||
Content: strVal(getValue(chunk, "content", "content_with_weight")),
|
||||
Content: strVal(getValue(chunk, "content_with_weight", "content")),
|
||||
DocumentID: strVal(getValue(chunk, "doc_id", "document_id")),
|
||||
DocumentName: strVal(getValue(chunk, "docnm_kwd", "document_name")),
|
||||
DatasetID: strVal(getValue(chunk, "kb_id", "dataset_id")),
|
||||
ImageID: strVal(getValue(chunk, "image_id", "img_id")),
|
||||
Positions: getValue(chunk, "positions", "position_int"),
|
||||
URL: chunk["url"],
|
||||
Similarity: chunk["similarity"],
|
||||
VectorSimilarity: chunk["vector_similarity"],
|
||||
TermSimilarity: chunk["term_similarity"],
|
||||
Similarity: sanitizeJSONFloats(chunk["similarity"]),
|
||||
VectorSimilarity: sanitizeJSONFloats(chunk["vector_similarity"]),
|
||||
TermSimilarity: sanitizeJSONFloats(chunk["term_similarity"]),
|
||||
RowID: chunk["row_id"],
|
||||
DocType: getValue(chunk, "doc_type_kwd", "doc_type"),
|
||||
DocumentMetadata: chunk["document_metadata"],
|
||||
|
||||
@@ -18,40 +18,41 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"ragflow/internal/tokenizer"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const thinkOpen = "<think>"
|
||||
const thinkClose = "</think>"
|
||||
|
||||
var stripThinkReplacer = strings.NewReplacer("<think>", "", "</think>", "")
|
||||
|
||||
// ThinkStreamState holds accumulated state across streaming LLM chunks
|
||||
// so that <think>...</think> tags can be surfaced as structured markers.
|
||||
//
|
||||
// Corresponds to _ThinkStreamState in api/db/services/dialog_service.py.
|
||||
// so that <think>...</think> tags can be surfaced as structured markers
|
||||
type ThinkStreamState struct {
|
||||
// fullText accumulates all text received so far.
|
||||
fullText string
|
||||
// lastIdx is the last consumed position in fullText.
|
||||
lastIdx int
|
||||
// lastFull is the previous fullText snapshot.
|
||||
lastFull string
|
||||
// lastModelFull is the previous model chunk for diffing.
|
||||
// lastModelFull is the previous model-full snapshot for diffing
|
||||
lastModelFull string
|
||||
// inThink is true when we are currently inside a <think> block.
|
||||
inThink bool
|
||||
// buffer accumulates visible text before flushing (for batching).
|
||||
buffer string
|
||||
// postThinkText holds text between </think> and the next <think> or end
|
||||
// of delta. Kept for API alignment with Python; may be used by future
|
||||
// callers that need per-delta visibility into think boundaries.
|
||||
postThinkText string
|
||||
// closePending defers emission of </think> when no visible text follows the tag
|
||||
closePending bool
|
||||
// pendingAfterClose collects text received after a deferred </think>
|
||||
pendingAfterClose string
|
||||
// thinkBuffer accumulates think-side text before token-batch flushing
|
||||
thinkBuffer string
|
||||
// answerBuffer accumulates answer-side text before token-batch flushing
|
||||
answerBuffer string
|
||||
// carry holds text at the end of a chunk that may be a partial <think> or </think> prefix.
|
||||
carry string
|
||||
}
|
||||
|
||||
// ThinkDeltaKind describes the type of a think-tag delta event.
|
||||
type ThinkDeltaKind int
|
||||
|
||||
const (
|
||||
ThinkDeltaText ThinkDeltaKind = iota // visible answer text
|
||||
ThinkDeltaText ThinkDeltaKind = iota // think-side or answer-side text
|
||||
ThinkDeltaMarker // <think> or </think> tag boundary
|
||||
)
|
||||
|
||||
@@ -61,93 +62,239 @@ type ThinkDelta struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
// emitText returns the batched text and its kind.
|
||||
func emitText(state *ThinkStreamState, section string, text string, minTokens int) (string, ThinkDeltaKind) {
|
||||
if text == "" {
|
||||
return "", 0
|
||||
}
|
||||
if section == "think" {
|
||||
state.thinkBuffer += text
|
||||
if tokenizer.NumTokensFromString(state.thinkBuffer) >= minTokens {
|
||||
out := state.thinkBuffer
|
||||
state.thinkBuffer = ""
|
||||
return out, ThinkDeltaText
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
state.answerBuffer += text
|
||||
if tokenizer.NumTokensFromString(state.answerBuffer) >= minTokens {
|
||||
out := state.answerBuffer
|
||||
state.answerBuffer = ""
|
||||
return out, ThinkDeltaText
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
|
||||
func flushThinkBufferInternal(state *ThinkStreamState) ThinkDelta {
|
||||
if state.thinkBuffer == "" {
|
||||
return ThinkDelta{}
|
||||
}
|
||||
out := state.thinkBuffer
|
||||
state.thinkBuffer = ""
|
||||
return ThinkDelta{Kind: ThinkDeltaText, Value: out}
|
||||
}
|
||||
|
||||
func flushAnswerBufferInternal(state *ThinkStreamState) ThinkDelta {
|
||||
if state.answerBuffer == "" {
|
||||
return ThinkDelta{}
|
||||
}
|
||||
out := state.answerBuffer
|
||||
state.answerBuffer = ""
|
||||
return ThinkDelta{Kind: ThinkDeltaText, Value: out}
|
||||
}
|
||||
|
||||
func stripThinkTags(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
return stripThinkReplacer.Replace(s)
|
||||
}
|
||||
|
||||
// tagPrefixLen returns the length of the longest suffix of s that could be a
|
||||
// PARTIAL start of "<think>" or "</think>". Returns 0 if the suffix is a complete
|
||||
// tag or no prefix match exists.
|
||||
func tagPrefixLen(s string) int {
|
||||
for i := 0; i < len(s); i++ {
|
||||
sub := s[i:]
|
||||
if sub == thinkOpen || sub == thinkClose {
|
||||
return 0 // complete tag, not a partial prefix
|
||||
}
|
||||
if strings.HasPrefix(thinkOpen, sub) || strings.HasPrefix(thinkClose, sub) {
|
||||
return len(sub)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// NextThinkDelta processes the next chunk of LLM output and returns any
|
||||
// visible text or tag boundary markers that should be emitted.
|
||||
//
|
||||
// Pure function — no side effects beyond updating state.
|
||||
func NextThinkDelta(state *ThinkStreamState, chunk string) []ThinkDelta {
|
||||
func NextThinkDelta(state *ThinkStreamState, chunk string, minTokens int) []ThinkDelta {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if state.lastFull != "" {
|
||||
// Compute the delta: what's new since lastFull.
|
||||
delta := strings.TrimPrefix(chunk, state.lastFull)
|
||||
state.lastModelFull = delta
|
||||
} else {
|
||||
var newPart string
|
||||
if strings.HasPrefix(chunk, state.lastModelFull) {
|
||||
newPart = chunk[len(state.lastModelFull):]
|
||||
state.lastModelFull = chunk
|
||||
} else {
|
||||
newPart = chunk
|
||||
state.lastModelFull += chunk
|
||||
}
|
||||
state.lastFull = chunk
|
||||
|
||||
// Accumulate fullText from the delta.
|
||||
state.fullText += state.lastModelFull
|
||||
|
||||
// Extract new content since lastIdx.
|
||||
newPart := state.fullText[state.lastIdx:]
|
||||
if len(newPart) == 0 {
|
||||
if newPart == "" {
|
||||
return nil
|
||||
}
|
||||
state.fullText += newPart
|
||||
|
||||
// Prepend carry from previous chunk that may complete a partial tag.
|
||||
pending := state.carry + newPart
|
||||
state.carry = ""
|
||||
|
||||
// Check if pending ends with a partial <think> or </think> prefix.
|
||||
// Save it as carry so it isn't emitted as visible text.
|
||||
if n := tagPrefixLen(pending); n > 0 {
|
||||
state.carry = pending[len(pending)-n:]
|
||||
pending = pending[:len(pending)-n]
|
||||
}
|
||||
|
||||
var deltas []ThinkDelta
|
||||
// Process character by character to detect tag boundaries.
|
||||
for len(newPart) > 0 {
|
||||
if !state.inThink {
|
||||
idx := strings.Index(newPart, thinkOpen)
|
||||
if idx < 0 {
|
||||
// No more think open — buffer everything as visible text.
|
||||
state.buffer += newPart
|
||||
state.lastIdx += len(newPart)
|
||||
break
|
||||
}
|
||||
// Text before <think> is visible answer.
|
||||
if idx > 0 {
|
||||
state.buffer += newPart[:idx]
|
||||
}
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkOpen})
|
||||
newPart = newPart[idx+len(thinkOpen):]
|
||||
state.lastIdx += idx + len(thinkOpen)
|
||||
state.inThink = true
|
||||
} else {
|
||||
idx := strings.Index(newPart, thinkClose)
|
||||
if idx < 0 {
|
||||
// Still inside think, consume all silently.
|
||||
state.lastIdx += len(newPart)
|
||||
break
|
||||
}
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose})
|
||||
state.postThinkText = newPart[:idx]
|
||||
newPart = newPart[idx+len(thinkClose):]
|
||||
state.lastIdx += idx + len(thinkClose)
|
||||
state.inThink = false
|
||||
|
||||
// Phase 1: handle deferred </think> from a previous chunk.
|
||||
if state.closePending {
|
||||
state.closePending = false
|
||||
if piece := flushThinkBufferInternal(state); piece.Value != "" {
|
||||
deltas = append(deltas, piece)
|
||||
}
|
||||
state.inThink = false
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose})
|
||||
if state.pendingAfterClose != "" {
|
||||
pending = state.pendingAfterClose + pending
|
||||
state.pendingAfterClose = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: process pending text for think tags.
|
||||
for pending != "" {
|
||||
openIdx := strings.Index(pending, thinkOpen)
|
||||
closeIdx := strings.Index(pending, thinkClose)
|
||||
|
||||
// No tags remaining — emit to the appropriate section.
|
||||
if openIdx == -1 && closeIdx == -1 {
|
||||
if piece := stripThinkTags(pending); piece != "" {
|
||||
section := "answer"
|
||||
if state.inThink {
|
||||
section = "think"
|
||||
}
|
||||
if out, kind := emitText(state, section, piece, minTokens); out != "" {
|
||||
deltas = append(deltas, ThinkDelta{Kind: kind, Value: out})
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// <think> appears first (or no </think> found).
|
||||
if openIdx != -1 && (closeIdx == -1 || openIdx < closeIdx) {
|
||||
before := pending[:openIdx]
|
||||
if before != "" {
|
||||
piece := stripThinkTags(before)
|
||||
section := "answer"
|
||||
if state.inThink {
|
||||
section = "think"
|
||||
}
|
||||
if out, kind := emitText(state, section, piece, minTokens); out != "" {
|
||||
deltas = append(deltas, ThinkDelta{Kind: kind, Value: out})
|
||||
}
|
||||
}
|
||||
pending = pending[openIdx+len(thinkOpen):]
|
||||
if !state.inThink {
|
||||
if answerPiece := flushAnswerBufferInternal(state); answerPiece.Value != "" {
|
||||
deltas = append(deltas, answerPiece)
|
||||
}
|
||||
if thinkPiece := flushThinkBufferInternal(state); thinkPiece.Value != "" {
|
||||
deltas = append(deltas, thinkPiece)
|
||||
}
|
||||
state.inThink = true
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkOpen})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// </think> appears first.
|
||||
before := pending[:closeIdx]
|
||||
after := pending[closeIdx+len(thinkClose):]
|
||||
if before != "" {
|
||||
piece := stripThinkTags(before)
|
||||
section := "answer"
|
||||
if state.inThink {
|
||||
section = "think"
|
||||
}
|
||||
if out, kind := emitText(state, section, piece, minTokens); out != "" {
|
||||
deltas = append(deltas, ThinkDelta{Kind: kind, Value: out})
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(after) != "" {
|
||||
if thinkPiece := flushThinkBufferInternal(state); thinkPiece.Value != "" {
|
||||
deltas = append(deltas, thinkPiece)
|
||||
}
|
||||
state.inThink = false
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose})
|
||||
pending = after
|
||||
continue
|
||||
}
|
||||
// No visible text after close — defer the marker.
|
||||
state.closePending = true
|
||||
if after != "" {
|
||||
state.pendingAfterClose += after
|
||||
}
|
||||
pending = ""
|
||||
break
|
||||
}
|
||||
|
||||
return deltas
|
||||
}
|
||||
|
||||
// FlushThinkBuffer drains the buffered visible text, if any, as a single delta.
|
||||
// Call this after all LLM chunks have been processed.
|
||||
func FlushThinkBuffer(state *ThinkStreamState) []ThinkDelta {
|
||||
if state == nil || state.buffer == "" {
|
||||
// FlushRemaining drains all remaining buffered text and handles deferred
|
||||
// markers. Call this after all LLM chunks have been processed.
|
||||
func FlushRemaining(state *ThinkStreamState) []ThinkDelta {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
text := state.buffer
|
||||
state.buffer = ""
|
||||
return []ThinkDelta{{Kind: ThinkDeltaText, Value: text}}
|
||||
var deltas []ThinkDelta
|
||||
if state.thinkBuffer != "" {
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.thinkBuffer})
|
||||
state.thinkBuffer = ""
|
||||
}
|
||||
if state.closePending {
|
||||
state.inThink = false
|
||||
state.closePending = false
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose})
|
||||
}
|
||||
if state.answerBuffer != "" {
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.answerBuffer})
|
||||
state.answerBuffer = ""
|
||||
}
|
||||
if state.pendingAfterClose != "" {
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.pendingAfterClose})
|
||||
state.pendingAfterClose = ""
|
||||
}
|
||||
if state.carry != "" {
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaText, Value: state.carry})
|
||||
state.carry = ""
|
||||
}
|
||||
return deltas
|
||||
}
|
||||
|
||||
// StreamThinkTagDelta — channel-based pipeline.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// StreamThinkTagDelta takes a channel of raw LLM text chunks and produces a
|
||||
// channel of (kind, value) pairs. When ctx is cancelled (e.g. client
|
||||
// disconnect), the goroutine drains the input channel silently and exits,
|
||||
// preventing the producer goroutine from blocking forever on send.
|
||||
//
|
||||
// Markers (<think>, </think>) are emitted immediately without buffering.
|
||||
// channel of structured deltas. When ctx is cancelled (e.g. client
|
||||
// disconnect), the goroutine drains the input channel silently and exits.
|
||||
func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens int) <-chan ThinkDelta {
|
||||
out := make(chan ThinkDelta, 32)
|
||||
go func() {
|
||||
defer close(out)
|
||||
state := &ThinkStreamState{}
|
||||
flushSize := minTokens * 4 // approximate: ~4 bytes per token
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -158,7 +305,7 @@ func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens in
|
||||
return
|
||||
case chunk, ok := <-chunks:
|
||||
if !ok {
|
||||
for _, d := range FlushThinkBuffer(state) {
|
||||
for _, d := range FlushRemaining(state) {
|
||||
select {
|
||||
case out <- d:
|
||||
case <-ctx.Done():
|
||||
@@ -167,34 +314,16 @@ func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens in
|
||||
}
|
||||
return
|
||||
}
|
||||
deltas := NextThinkDelta(state, chunk)
|
||||
deltas := NextThinkDelta(state, chunk, minTokens)
|
||||
for _, d := range deltas {
|
||||
if d.Kind == ThinkDeltaMarker {
|
||||
select {
|
||||
case out <- d:
|
||||
case <-ctx.Done():
|
||||
go func() {
|
||||
for range chunks {
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// Flush buffered visible text when it reaches the token threshold,
|
||||
// matching Python _stream_with_think_delta which yields ("text", ...)
|
||||
// per chunk. Markers are emitted immediately above.
|
||||
if len(state.buffer) >= flushSize {
|
||||
for _, d := range FlushThinkBuffer(state) {
|
||||
select {
|
||||
case out <- d:
|
||||
case <-ctx.Done():
|
||||
go func() {
|
||||
for range chunks {
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
select {
|
||||
case out <- d:
|
||||
case <-ctx.Done():
|
||||
go func() {
|
||||
for range chunks {
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -203,47 +332,46 @@ func StreamThinkTagDelta(ctx context.Context, chunks <-chan string, minTokens in
|
||||
return out
|
||||
}
|
||||
|
||||
// ExtractVisibleAnswer strips <think> blocks from the raw LLM response,
|
||||
// returning only the visible answer text. If the response consists
|
||||
// entirely of think content, returns an empty string.
|
||||
//
|
||||
// Corresponds to _extract_visible_answer in dialog_service.py.
|
||||
// ExtractVisibleAnswer returns the visible answer text after the last </think>.
|
||||
// Stray <think>/</think> tags are stripped. If there is no </think>, all tags are stripped.
|
||||
func ExtractVisibleAnswer(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
// Collect all non-think text.
|
||||
var visible []string
|
||||
remaining := raw
|
||||
hasThink := false
|
||||
|
||||
for {
|
||||
openIdx := strings.Index(remaining, thinkOpen)
|
||||
if openIdx < 0 {
|
||||
// No more think open tags — strip any stray </think> and keep the rest.
|
||||
remaining = strings.ReplaceAll(remaining, thinkClose, "")
|
||||
visible = append(visible, remaining)
|
||||
break
|
||||
}
|
||||
hasThink = true
|
||||
if openIdx > 0 {
|
||||
visible = append(visible, remaining[:openIdx])
|
||||
}
|
||||
remaining = remaining[openIdx+len(thinkOpen):]
|
||||
|
||||
closeIdx := strings.Index(remaining, thinkClose)
|
||||
if closeIdx < 0 {
|
||||
// Unclosed think — treat rest as visible.
|
||||
visible = append(visible, remaining)
|
||||
break
|
||||
}
|
||||
remaining = remaining[closeIdx+len(thinkClose):]
|
||||
if !strings.Contains(raw, thinkClose) {
|
||||
return stripThinkTags(raw)
|
||||
}
|
||||
|
||||
result := strings.TrimSpace(strings.Join(visible, ""))
|
||||
if hasThink && result == "" {
|
||||
// Only think content — return empty.
|
||||
return ""
|
||||
}
|
||||
return result
|
||||
lastClose := strings.LastIndex(raw, thinkClose)
|
||||
answer := raw[lastClose+len(thinkClose):]
|
||||
return stripThinkTags(answer)
|
||||
}
|
||||
|
||||
// BufferAnswerDelta processes an answer delta through the think-state lifecycle.
|
||||
// When closePending is true, it first flushes the deferred think buffer and </think>
|
||||
// marker, then processes pendingAfterClose + the new answer text.
|
||||
func BufferAnswerDelta(state *ThinkStreamState, text string, minTokens int) []ThinkDelta {
|
||||
if state == nil || text == "" {
|
||||
return nil
|
||||
}
|
||||
state.fullText += text
|
||||
|
||||
var deltas []ThinkDelta
|
||||
if state.closePending {
|
||||
state.closePending = false
|
||||
if piece := flushThinkBufferInternal(state); piece.Value != "" {
|
||||
deltas = append(deltas, piece)
|
||||
}
|
||||
state.inThink = false
|
||||
deltas = append(deltas, ThinkDelta{Kind: ThinkDeltaMarker, Value: thinkClose})
|
||||
if state.pendingAfterClose != "" {
|
||||
text = state.pendingAfterClose + text
|
||||
state.pendingAfterClose = ""
|
||||
}
|
||||
}
|
||||
|
||||
if out, kind := emitText(state, "answer", text, minTokens); out != "" {
|
||||
deltas = append(deltas, ThinkDelta{Kind: kind, Value: out})
|
||||
}
|
||||
return deltas
|
||||
}
|
||||
|
||||
@@ -24,117 +24,130 @@ import (
|
||||
|
||||
func TestNextThinkDelta_NoThinkTag(t *testing.T) {
|
||||
state := &ThinkStreamState{}
|
||||
deltas := NextThinkDelta(state, "hello world")
|
||||
if len(deltas) != 0 {
|
||||
t.Fatalf("expected 0 deltas, got %d", len(deltas))
|
||||
deltas := NextThinkDelta(state, "hello world", 0)
|
||||
if len(deltas) != 1 {
|
||||
t.Fatalf("expected 1 delta, got %d: %+v", len(deltas), deltas)
|
||||
}
|
||||
if state.buffer != "hello world" {
|
||||
t.Errorf("buffer = %q", state.buffer)
|
||||
if deltas[0].Kind != ThinkDeltaText || deltas[0].Value != "hello world" {
|
||||
t.Errorf("expected text delta, got %+v", deltas[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextThinkDelta_OnlyThinkTag(t *testing.T) {
|
||||
state := &ThinkStreamState{}
|
||||
deltas := NextThinkDelta(state, "<think>reasoning</think>visible")
|
||||
if len(deltas) != 2 {
|
||||
t.Fatalf("expected 2 deltas, got %d: %+v", len(deltas), deltas)
|
||||
deltas := NextThinkDelta(state, "<think>reasoning</think>visible", 0)
|
||||
if len(deltas) < 2 {
|
||||
t.Fatalf("expected at least 2 deltas, got %d: %+v", len(deltas), deltas)
|
||||
}
|
||||
if deltas[0].Kind != ThinkDeltaMarker || deltas[0].Value != "<think>" {
|
||||
t.Errorf("first delta should be <think> marker: %+v", deltas[0])
|
||||
foundOpen, foundClose := false, false
|
||||
for _, d := range deltas {
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "<think>" {
|
||||
foundOpen = true
|
||||
}
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
|
||||
foundClose = true
|
||||
}
|
||||
}
|
||||
if deltas[1].Kind != ThinkDeltaMarker || deltas[1].Value != "</think>" {
|
||||
t.Errorf("second delta should be </think> marker: %+v", deltas[1])
|
||||
}
|
||||
if state.buffer != "visible" {
|
||||
t.Errorf("buffer = %q, want visible", state.buffer)
|
||||
if !foundOpen || !foundClose {
|
||||
t.Errorf("missing markers: open=%v close=%v, deltas=%+v", foundOpen, foundClose, deltas)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextThinkDelta_TextThenThink(t *testing.T) {
|
||||
state := &ThinkStreamState{}
|
||||
deltas := NextThinkDelta(state, "before <think>inside</think> after")
|
||||
// "before " -> buffer (no flush yet)
|
||||
// "<think>" -> marker
|
||||
// "inside" -> inside think, consumed silently
|
||||
// "</think>" -> marker
|
||||
// " after" -> buffer
|
||||
if len(deltas) != 2 {
|
||||
t.Fatalf("expected 2 markers, got %d: %+v", len(deltas), deltas)
|
||||
deltas := NextThinkDelta(state, "before <think>inside</think> after", 0)
|
||||
foundOpen, foundClose := false, false
|
||||
for _, d := range deltas {
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "<think>" {
|
||||
foundOpen = true
|
||||
}
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
|
||||
foundClose = true
|
||||
}
|
||||
}
|
||||
if state.buffer != "before after" {
|
||||
t.Errorf("buffer = %q", state.buffer)
|
||||
if !foundOpen || !foundClose {
|
||||
t.Errorf("missing markers: open=%v close=%v, deltas=%+v", foundOpen, foundClose, deltas)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextThinkDelta_MultipleChunks(t *testing.T) {
|
||||
state := &ThinkStreamState{}
|
||||
NextThinkDelta(state, "hello ")
|
||||
NextThinkDelta(state, "<think>")
|
||||
NextThinkDelta(state, "reasoning")
|
||||
NextThinkDelta(state, "</think>")
|
||||
deltas := NextThinkDelta(state, " world")
|
||||
if len(deltas) != 0 {
|
||||
t.Fatalf("expected 0 deltas from final chunk, got %d", len(deltas))
|
||||
deltas := NextThinkDelta(state, "hello ", 0)
|
||||
_ = deltas
|
||||
deltas = NextThinkDelta(state, "<think>", 0)
|
||||
_ = deltas
|
||||
deltas = NextThinkDelta(state, "reasoning", 0)
|
||||
_ = deltas
|
||||
deltas = NextThinkDelta(state, "</think>", 0)
|
||||
_ = deltas
|
||||
deltas = NextThinkDelta(state, " world", 0)
|
||||
// After deferral, closing marker and answer text arrive in the last chunk.
|
||||
foundClose := false
|
||||
for _, d := range deltas {
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "</think>" {
|
||||
foundClose = true
|
||||
}
|
||||
}
|
||||
if state.buffer != "hello world" {
|
||||
t.Errorf("buffer = %q", state.buffer)
|
||||
if !foundClose {
|
||||
t.Errorf("expected </think> in final chunk, got %+v", deltas)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextThinkDelta_UnclosedThink(t *testing.T) {
|
||||
state := &ThinkStreamState{}
|
||||
deltas := NextThinkDelta(state, "text <think>unclosed")
|
||||
if len(deltas) != 1 {
|
||||
t.Fatalf("expected 1 marker (think open), got %d", len(deltas))
|
||||
deltas := NextThinkDelta(state, "text <think>unclosed", 0)
|
||||
foundOpen := false
|
||||
for _, d := range deltas {
|
||||
if d.Kind == ThinkDeltaMarker && d.Value == "<think>" {
|
||||
foundOpen = true
|
||||
}
|
||||
}
|
||||
if deltas[0].Value != "<think>" {
|
||||
t.Errorf("expected <think> marker")
|
||||
if !foundOpen {
|
||||
t.Errorf("expected <think> marker in %+v", deltas)
|
||||
}
|
||||
if state.buffer != "text " {
|
||||
t.Errorf("buffer = %q", state.buffer)
|
||||
}
|
||||
// "unclosed" should be consumed silently inside think
|
||||
}
|
||||
|
||||
func TestNextThinkDelta_EmptyInput(t *testing.T) {
|
||||
state := &ThinkStreamState{}
|
||||
deltas := NextThinkDelta(state, "")
|
||||
deltas := NextThinkDelta(state, "", 0)
|
||||
if len(deltas) != 0 {
|
||||
t.Errorf("expected 0 deltas for empty input")
|
||||
t.Errorf("expected 0 deltas for empty input, got %d", len(deltas))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextThinkDelta_NilState(t *testing.T) {
|
||||
deltas := NextThinkDelta(nil, "test")
|
||||
deltas := NextThinkDelta(nil, "test", 0)
|
||||
if deltas != nil {
|
||||
t.Error("expected nil for nil state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushThinkBuffer_Empty(t *testing.T) {
|
||||
if deltas := FlushThinkBuffer(nil); len(deltas) != 0 {
|
||||
t.Error("expected empty for nil state")
|
||||
func TestFlushRemaining_FlushesAll(t *testing.T) {
|
||||
state := &ThinkStreamState{
|
||||
thinkBuffer: "think-tail",
|
||||
closePending: true,
|
||||
answerBuffer: "answer-tail",
|
||||
pendingAfterClose: "pending-tail",
|
||||
}
|
||||
state := &ThinkStreamState{}
|
||||
if deltas := FlushThinkBuffer(state); len(deltas) != 0 {
|
||||
t.Error("expected empty for zero state")
|
||||
deltas := FlushRemaining(state)
|
||||
if len(deltas) != 4 {
|
||||
t.Fatalf("expected 4 deltas, got %d: %+v", len(deltas), deltas)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushThinkBuffer_WithContent(t *testing.T) {
|
||||
state := &ThinkStreamState{buffer: "flushed text"}
|
||||
deltas := FlushThinkBuffer(state)
|
||||
if len(deltas) != 1 {
|
||||
t.Fatalf("expected 1 delta, got %d", len(deltas))
|
||||
if deltas[0].Kind != ThinkDeltaText || deltas[0].Value != "think-tail" {
|
||||
t.Errorf("delta[0] = %+v", deltas[0])
|
||||
}
|
||||
if deltas[0].Kind != ThinkDeltaText {
|
||||
t.Error("expected text delta")
|
||||
if deltas[1].Kind != ThinkDeltaMarker || deltas[1].Value != "</think>" {
|
||||
t.Errorf("delta[1] = %+v", deltas[1])
|
||||
}
|
||||
if deltas[0].Value != "flushed text" {
|
||||
t.Errorf("value = %q", deltas[0].Value)
|
||||
if deltas[2].Kind != ThinkDeltaText || deltas[2].Value != "answer-tail" {
|
||||
t.Errorf("delta[2] = %+v", deltas[2])
|
||||
}
|
||||
if state.buffer != "" {
|
||||
t.Error("buffer should be cleared after flush")
|
||||
if deltas[3].Kind != ThinkDeltaText || deltas[3].Value != "pending-tail" {
|
||||
t.Errorf("delta[3] = %+v", deltas[3])
|
||||
}
|
||||
// State should be cleared.
|
||||
if state.closePending || state.inThink {
|
||||
t.Error("state not fully cleared")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,13 +173,13 @@ func TestExtractVisibleAnswer_WithThink(t *testing.T) {
|
||||
func TestExtractVisibleAnswer_ThinkOnly(t *testing.T) {
|
||||
raw := "<think>only reasoning here</think>"
|
||||
if got := ExtractVisibleAnswer(raw); got != "" {
|
||||
t.Errorf("expected empty for think-only, got %q", got)
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractVisibleAnswer_MultipleThinks(t *testing.T) {
|
||||
raw := "<think>first</think>visible1<think>second</think>visible2"
|
||||
if got := ExtractVisibleAnswer(raw); got != "visible1visible2" {
|
||||
if got := ExtractVisibleAnswer(raw); got != "visible2" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -178,6 +191,18 @@ func TestExtractVisibleAnswer_NestedTags(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractVisibleAnswer_NoTags(t *testing.T) {
|
||||
if got := ExtractVisibleAnswer("plain text"); got != "plain text" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractVisibleAnswer_StrayTag(t *testing.T) {
|
||||
if got := ExtractVisibleAnswer("<think>text"); got != "text" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamThinkTagDelta(t *testing.T) {
|
||||
chunks := []string{"hello ", "wor", "<think>", "think text", "</think>", "ld", " final"}
|
||||
ch := make(chan string, len(chunks))
|
||||
@@ -208,14 +233,12 @@ func TestStreamThinkTagDelta(t *testing.T) {
|
||||
}
|
||||
|
||||
joined := strings.Join(texts, "")
|
||||
if !strings.Contains(joined, "hello world") || !strings.Contains(joined, "final") {
|
||||
if !strings.Contains(joined, "hello wor") || !strings.Contains(joined, "final") {
|
||||
t.Errorf("texts = %q", texts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamThinkTagDelta_IncrementalFlush(t *testing.T) {
|
||||
// Verify that visible text is streamed incrementally, not all at the end.
|
||||
// minTokens=1 → flushSize=4 bytes. Each chunk triggers a flush.
|
||||
chunks := []string{"1234", "5678", "90ab"}
|
||||
ch := make(chan string, len(chunks))
|
||||
for _, c := range chunks {
|
||||
@@ -229,8 +252,6 @@ func TestStreamThinkTagDelta_IncrementalFlush(t *testing.T) {
|
||||
texts = append(texts, d.Value)
|
||||
}
|
||||
}
|
||||
// With minTokens=1 (flushSize=4), each chunk triggers a flush.
|
||||
// We should get incremental text deltas, not just one final burst.
|
||||
if len(texts) < 2 {
|
||||
t.Errorf("expected >=2 incremental text deltas, got %d: %q", len(texts), texts)
|
||||
}
|
||||
@@ -254,3 +275,29 @@ func TestStreamThinkTagDelta_NoThinkTags(t *testing.T) {
|
||||
t.Errorf("got %q, want 'just plain text'", joined)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamThinkTagDelta_DeferredClose(t *testing.T) {
|
||||
// When </think> has no visible text after it, the marker is deferred.
|
||||
chunks := []string{"<think>", "hello", "</think>", "world"}
|
||||
ch := make(chan string, len(chunks))
|
||||
for _, c := range chunks {
|
||||
ch <- c
|
||||
}
|
||||
close(ch)
|
||||
|
||||
var markers []string
|
||||
for d := range StreamThinkTagDelta(context.Background(), ch, 1) {
|
||||
if d.Kind == ThinkDeltaMarker {
|
||||
markers = append(markers, d.Value)
|
||||
}
|
||||
}
|
||||
if len(markers) != 2 {
|
||||
t.Fatalf("expected 2 markers, got %d: %v", len(markers), markers)
|
||||
}
|
||||
if markers[0] != "<think>" {
|
||||
t.Errorf("first marker = %q", markers[0])
|
||||
}
|
||||
if markers[1] != "</think>" {
|
||||
t.Errorf("second marker = %q", markers[1])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +103,11 @@ export default defineConfig(({ mode }) => {
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'^(/api/v1/datasets/search)|^(/api/v1/chat/completions)': {
|
||||
target: 'http://127.0.0.1:9384/',
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
},
|
||||
'/api': {
|
||||
target: 'http://127.0.0.1:9380/',
|
||||
changeOrigin: true,
|
||||
|
||||
Reference in New Issue
Block a user