From 6d9430a1254e0e33177a3fcf41544c9ed5c5355a Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Fri, 3 Apr 2026 18:11:23 +0800 Subject: [PATCH] Add think chat to CLI (#13922) ### What problem does this PR solve? Now user can use 'think mode' to chat with LLM ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai --- internal/cli/cli.go | 23 +- internal/cli/client.go | 2 + internal/cli/common_command.go | 47 +--- internal/cli/http_client.go | 47 ++++ internal/cli/lexer.go | 2 + internal/cli/parser.go | 2 + internal/cli/response.go | 26 ++ internal/cli/types.go | 1 + internal/cli/user_command.go | 78 +++++- internal/cli/user_parser.go | 17 +- internal/entity/models/dummy.go | 10 + internal/entity/models/types.go | 14 + internal/entity/models/zhipu-ai.go | 317 +++++++++++++++++++++-- internal/entity/tenant_model_instance.go | 4 +- internal/handler/providers.go | 58 ++++- internal/service/model_service.go | 121 +++++++++ 16 files changed, 684 insertions(+), 85 deletions(-) diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 198cb92ff1..5a72327ec6 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -332,7 +332,7 @@ func looksLikeSQL(s string) bool { "LIST ", "SHOW ", "CREATE ", "DROP ", "ALTER ", "LOGIN ", "REGISTER ", "PING", "GRANT ", "REVOKE ", "SET ", "UNSET ", "UPDATE ", "DELETE ", "INSERT ", - "SELECT ", "DESCRIBE ", "EXPLAIN ", "ADD ", "ENABLE ", "DISABLE ", "CHAT ", "USE", + "SELECT ", "DESCRIBE ", "EXPLAIN ", "ADD ", "ENABLE ", "DISABLE ", "CHAT ", "USE", "THINK", } for _, prefix := range sqlPrefixes { if strings.HasPrefix(s, prefix) { @@ -480,26 +480,19 @@ func NewCLIWithArgs(args *ConnectionArgs) (*CLI, error) { func (c *CLI) Run() error { // If username is provided without password, prompt for password if c.args != nil && c.args.UserName != "" && c.args.Password == "" && c.args.APIToken == "" { - // Allow 3 attempts for password verification maxAttempts := 3 for attempt := 1; attempt <= maxAttempts; attempt++ { - var input string - var err error + fmt.Print("Please input your password: ") + + passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() - // Check if terminal supports password masking - if term.IsTerminal(int(os.Stdin.Fd())) { - input, err = c.line.PasswordPrompt("Please input your password: ") - } else { - // Terminal doesn't support password masking, use regular prompt - fmt.Println("Warning: This terminal does not support secure password input") - input, err = c.line.Prompt("Please input your password (will be visible): ") - } if err != nil { - fmt.Printf("Error reading input: %v\n", err) + fmt.Printf("Error reading password: %v\n", err) return err } - input = strings.TrimSpace(input) + input := strings.TrimSpace(string(passwordBytes)) if input == "" { if attempt < maxAttempts { @@ -509,7 +502,6 @@ func (c *CLI) Run() error { return errors.New("no password provided after 3 attempts") } - // Set the password for verification c.args.Password = input if err = c.VerifyAuth(); err != nil { @@ -520,7 +512,6 @@ func (c *CLI) Run() error { return fmt.Errorf("authentication failed after %d attempts: %v", maxAttempts, err) } - // Authentication successful break } } diff --git a/internal/cli/client.go b/internal/cli/client.go index 3054db7d66..c760c25ac4 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -247,6 +247,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.EnableOrDisableModel(cmd, "disable") case "chat_to_model": return c.ChatToModel(cmd) + case "think_chat_to_model": + return c.ChatToModel(cmd) case "use_model": return c.UseModel(cmd) case "show_current_model": diff --git a/internal/cli/common_command.go b/internal/cli/common_command.go index 2940b131c1..0c71e79e41 100644 --- a/internal/cli/common_command.go +++ b/internal/cli/common_command.go @@ -21,10 +21,9 @@ import ( "encoding/json" "fmt" "os" - "os/exec" "strings" - "syscall" - "unsafe" + + "golang.org/x/term" ) // LoginUserInteractive performs interactive login with username and password @@ -376,44 +375,24 @@ func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) { // readPassword reads password from terminal without echoing func readPassword() (string, error) { - // Check if stdin is a terminal by trying to get terminal size - if isTerminal() { - // Use stty to disable echo - cmd := exec.Command("stty", "-echo") - cmd.Stdin = os.Stdin - if err := cmd.Run(); err != nil { - // Fallback: read normally - return readPasswordFallback() - } - defer func() { - // Re-enable echo - cmd := exec.Command("stty", "echo") - cmd.Stdin = os.Stdin - cmd.Run() - }() - - reader := bufio.NewReader(os.Stdin) - password, err := reader.ReadString('\n') - fmt.Println() // New line after password input - if err != nil { - return "", err - } - return strings.TrimSpace(password), nil + if !term.IsTerminal(int(os.Stdin.Fd())) { + return readPasswordFallback() } - // Fallback for non-terminal input (e.g., piped input) - return readPasswordFallback() -} + fmt.Print("Password: ") + passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() -// isTerminal checks if stdin is a terminal -func isTerminal() bool { - var termios syscall.Termios - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, os.Stdin.Fd(), syscall.TCGETS, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) - return err == 0 + if err != nil { + return "", err + } + + return strings.TrimSpace(string(passwordBytes)), nil } // readPasswordFallback reads password as plain text (fallback mode) func readPasswordFallback() (string, error) { + fmt.Print("Password (will be visible): ") reader := bufio.NewReader(os.Stdin) password, err := reader.ReadString('\n') if err != nil { diff --git a/internal/cli/http_client.go b/internal/cli/http_client.go index eb5ab1a804..eed0d8be7b 100644 --- a/internal/cli/http_client.go +++ b/internal/cli/http_client.go @@ -326,3 +326,50 @@ func (c *HTTPClient) RequestJSON(method, path string, useAPIBase bool, authKind } return resp.JSON() } + +// RequestStream makes an HTTP request for SSE streaming and returns the response body reader +func (c *HTTPClient) RequestStream(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}) (io.ReadCloser, float64, error) { + url := c.BuildURL(path, useAPIBase) + mergedHeaders := c.Headers(authKind, headers) + + var body io.Reader + if jsonBody != nil { + jsonData, err := json.Marshal(jsonBody) + if err != nil { + return nil, 0, err + } + body = bytes.NewReader(jsonData) + if mergedHeaders == nil { + mergedHeaders = make(map[string]string) + } + mergedHeaders["Content-Type"] = "application/json" + } + // Add Accept header for SSE + if mergedHeaders == nil { + mergedHeaders = make(map[string]string) + } + mergedHeaders["Accept"] = "text/event-stream" + + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, 0, err + } + + for k, v := range mergedHeaders { + req.Header.Set(k, v) + } + + startTime := time.Now() + resp, err := c.client.Do(req) + if err != nil { + return nil, 0, err + } + duration := time.Since(startTime).Seconds() + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, duration, fmt.Errorf("HTTP %d", resp.StatusCode) + } + + return resp.Body, duration, nil +} diff --git a/internal/cli/lexer.go b/internal/cli/lexer.go index cc8d6c6d4a..d641e31431 100644 --- a/internal/cli/lexer.go +++ b/internal/cli/lexer.go @@ -255,6 +255,8 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenChats, Value: ident} case "CHAT": return Token{Type: TokenChat, Value: ident} + case "THINK": + return Token{Type: TokenThink, Value: ident} case "FILES": return Token{Type: TokenFiles, Value: ident} case "AS": diff --git a/internal/cli/parser.go b/internal/cli/parser.go index cb26220252..bc2e83fbb6 100644 --- a/internal/cli/parser.go +++ b/internal/cli/parser.go @@ -192,6 +192,8 @@ func (p *Parser) parseUserCommand() (*Command, error) { return p.parseDisableCommand() case TokenChat: return p.parseChatCommand() + case TokenThink: + return p.parseThinkCommand() case TokenUse: return p.parseUseCommand() default: diff --git a/internal/cli/response.go b/internal/cli/response.go index 16934aa0e4..d97ec0a461 100644 --- a/internal/cli/response.go +++ b/internal/cli/response.go @@ -141,6 +141,32 @@ func (r *MessageResponse) PrintOut() { } } +type StreamMessageResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Duration float64 + outputFormat OutputFormat +} + +func (r *StreamMessageResponse) Type() string { + return "stream_message" +} + +func (r *StreamMessageResponse) TimeCost() float64 { + return r.Duration +} + +func (r *StreamMessageResponse) SetOutputFormat(format OutputFormat) { + r.outputFormat = format +} + +func (r *StreamMessageResponse) PrintOut() { + if r.Code != 0 { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + type RegisterResponse struct { Code int `json:"code"` Message string `json:"message"` diff --git a/internal/cli/types.go b/internal/cli/types.go index d1f5826056..6fc46e91d7 100644 --- a/internal/cli/types.go +++ b/internal/cli/types.go @@ -111,6 +111,7 @@ const ( TokenDisable TokenEnable TokenUse + TokenThink TokenInsert TokenFile TokenMetadata diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index 90dd160d2e..260c040a23 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -17,9 +17,11 @@ package cli import ( + "bufio" "context" "encoding/json" "fmt" + "os" ce "ragflow/internal/cli/contextengine" "strings" ) @@ -1187,29 +1189,83 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) { } message := cmd.Params["message"].(string) + reasoning := cmd.Params["reasoning"].(bool) url := fmt.Sprintf("/providers/%s/instances/%s/models/%s", providerName, instanceName, modelName) payload := map[string]interface{}{ - "message": message, + "message": message, + "stream": true, // use stream API + "reasoning": reasoning, } - resp, err := c.HTTPClient.Request("POST", url, true, "web", nil, payload) + // Call stream http api + reader, duration, err := c.HTTPClient.RequestStream("POST", url, true, "web", nil, payload) if err != nil { return nil, fmt.Errorf("failed to chat model: %w", err) } - if resp.StatusCode != 200 { - return nil, fmt.Errorf("failed to chat model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + defer reader.Close() + + // Parse SSE and output to console + scanner := bufio.NewScanner(reader) + var fullMessage strings.Builder + + reasoningPrint := true + messagePrint := true + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data:") { + data := strings.TrimPrefix(line, "data:") + data = strings.TrimSpace(data) + + if strings.HasPrefix(data, "[REASONING]") { + data = strings.TrimPrefix(data, "[REASONING]") + if reasoningPrint { + fmt.Print("Thinking: ") + reasoningPrint = false + } else { + fmt.Print(data) + } + os.Stdout.Sync() + } + if strings.HasPrefix(data, "[MESSAGE]") { + data = strings.TrimPrefix(data, "[MESSAGE]") + if messagePrint { + if reasoning { + fmt.Println() + } + fmt.Print("Answer: ") + messagePrint = false + } else { + fmt.Print(data) + os.Stdout.Sync() + fullMessage.WriteString(data) + } + } + } else if strings.HasPrefix(line, "event:error") { + // error event + if scanner.Scan() { + errData := strings.TrimPrefix(scanner.Text(), "data:") + errData = strings.TrimSpace(errData) + return nil, fmt.Errorf("chat error: %s", errData) + } + // If there's an error, return a generic error + return nil, fmt.Errorf("chat error: received error event from server") + } } - var result MessageResponse - if err = json.Unmarshal(resp.Body, &result); err != nil { - return nil, fmt.Errorf("chat model failed: invalid JSON (%w)", err) + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading stream: %w", err) } - if result.Code != 0 { - return nil, fmt.Errorf("%s", result.Message) + + fmt.Println() + + result := &StreamMessageResponse{ + Code: 0, + Message: fullMessage.String(), + Duration: duration, } - result.Duration = resp.Duration - return &result, nil + return result, nil } // UseModel sets the current model for chat diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index 2069fa3fde..4ad5742071 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -2025,14 +2025,14 @@ func (p *Parser) parseChatCommand() (*Command, error) { // Format: 'provider/instance/model' or just 'message' if p.curToken.Type == TokenQuotedString { firstArg := p.curToken.Value - + // Check if it looks like a model identifier (contains exactly 2 slashes) slashCount := strings.Count(firstArg, "/") if slashCount == 2 { // This is likely a model identifier, expect another quoted string for message modelName = firstArg p.nextToken() - + // After model name, expect message if p.curToken.Type != TokenQuotedString { return nil, fmt.Errorf("expected message after model name") @@ -2062,9 +2062,22 @@ func (p *Parser) parseChatCommand() (*Command, error) { cmd.Params["model_name"] = modelName } cmd.Params["message"] = message + cmd.Params["reasoning"] = false return cmd, nil } +func (p *Parser) parseThinkCommand() (*Command, error) { + + p.nextToken() // consume THINK + command, err := p.parseChatCommand() + if err != nil { + return nil, err + } + command.Type = "think_chat_to_model" + command.Params["reasoning"] = true + return command, nil +} + func (p *Parser) parseUseCommand() (*Command, error) { p.nextToken() // consume USE diff --git a/internal/entity/models/dummy.go b/internal/entity/models/dummy.go index 84ebea3191..ab74463feb 100644 --- a/internal/entity/models/dummy.go +++ b/internal/entity/models/dummy.go @@ -44,6 +44,16 @@ func (z *DummyModel) ChatStreamly(modelName, apiKey, message *string, genConf ma return nil, fmt.Errorf("not implemented") } +// ChatStreamlyWithChannel sends a message and streams response to channel (better performance) +func (z *DummyModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error { + return fmt.Errorf("not implemented") +} + +// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) +func (z *DummyModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error { + return fmt.Errorf("not implemented") +} + // EncodeToEmbedding encodes a list of texts into embeddings func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) { return nil, fmt.Errorf("not implemented") diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index 42f80039ee..dc13db942a 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -6,6 +6,10 @@ type ModelDriver interface { Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error) // ChatStreamly sends a message and streams response ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error) + // ChatStreamlyWithChannel sends a message and streams response to channel (better performance) + ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error + // ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) + ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error // Encode encodes a list of texts into embeddings EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) } @@ -18,3 +22,13 @@ type URLSuffix struct { Embedding string `json:"embedding"` Rerank string `json:"rerank"` } + +type ChatConfig struct { + Stream *bool + Reasoning *bool + MaxTokens *int + Temperature *float64 + TopP *float64 + DoSample *bool + Stop *[]string +} diff --git a/internal/entity/models/zhipu-ai.go b/internal/entity/models/zhipu-ai.go index 1f17c8f322..417f242f7c 100644 --- a/internal/entity/models/zhipu-ai.go +++ b/internal/entity/models/zhipu-ai.go @@ -17,17 +17,22 @@ package models import ( + "bufio" "bytes" "encoding/json" "fmt" "io" "net/http" + "ragflow/internal/logger" + "strings" + "time" ) -// ZhipuAIModel implements ModelDriver for Zhipu AI (智谱 AI) +// ZhipuAIModel implements ModelDriver for Zhipu AI type ZhipuAIModel struct { - BaseURL string - URLSuffix URLSuffix + BaseURL string + URLSuffix URLSuffix + httpClient *http.Client // Reusable HTTP client with connection pool } // NewZhipuAIModel creates a new Zhipu AI model instance @@ -35,6 +40,15 @@ func NewZhipuAIModel(baseURL string, urlSuffix URLSuffix) *ZhipuAIModel { return &ZhipuAIModel{ BaseURL: baseURL, URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + }, + }, } } @@ -82,8 +96,7 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[stri req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) - client := &http.Client{} - resp, err := client.Do(req) + resp, err := z.httpClient.Do(req) if err != nil { return "", fmt.Errorf("failed to send request: %w", err) } @@ -137,7 +150,8 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf "messages": []map[string]string{ {"role": "user", "content": *message}, }, - "stream": true, + "stream": true, + "temperature": 1, } // Add generation config if provided @@ -164,10 +178,9 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) - client := &http.Client{} - resp, err := client.Do(req) + resp, err := z.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -185,14 +198,28 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf defer close(resultChan) defer resp.Body.Close() - decoder := json.NewDecoder(resp.Body) - for { + // SSE parsing: read line by line + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of stream + if data == "[DONE]" { + break + } + + // Parse the JSON event var event map[string]interface{} - if err := decoder.Decode(&event); err != nil { - if err == io.EOF { - break - } - return + if err := json.Unmarshal([]byte(data), &event); err != nil { + continue } choices, ok := event["choices"].([]interface{}) @@ -225,6 +252,259 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf return resultChan, nil } +// ChatStreamlyWithChannel sends a message and streams response to channel (better performance) +func (z *ZhipuAIModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error { + url := fmt.Sprintf("%s/chat/completions", z.BaseURL) + + // Build request body with streaming enabled + reqBody := map[string]interface{}{ + "model": modelName, + "messages": []map[string]string{ + {"role": "user", "content": *message}, + }, + "stream": true, + "temperature": 1, + } + + // Add generation config if provided + if genConf != nil { + if maxTokens, ok := genConf["max_tokens"]; ok { + reqBody["max_tokens"] = maxTokens + } + if temperature, ok := genConf["temperature"]; ok { + reqBody["temperature"] = temperature + } + if topP, ok := genConf["top_p"]; ok { + reqBody["top_p"] = topP + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: read line by line + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + logger.Info(line) + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of stream + if data == "[DONE]" { + break + } + + // Parse the JSON event + var event map[string]interface{} + if err := json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + content, ok := delta["content"].(string) + if ok && content != "" { + resultChan <- content + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + break + } + } + + // Send [DONE] marker for OpenAI compatibility + resultChan <- "[DONE]" + + return scanner.Err() +} + +// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) +func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error { + url := fmt.Sprintf("%s/chat/completions", z.BaseURL) + + // Build request body with streaming enabled + reqBody := map[string]interface{}{ + "model": modelName, + "messages": []map[string]string{ + {"role": "user", "content": *message}, + }, + "stream": false, + "temperature": 1, + } + + if modelConfig != nil { + if modelConfig.Stream != nil { + reqBody["stream"] = *modelConfig.Stream + } + + if modelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *modelConfig.MaxTokens + } + + if modelConfig.Temperature != nil { + reqBody["temperature"] = *modelConfig.Temperature + } + + if modelConfig.DoSample != nil { + reqBody["do_sample"] = *modelConfig.DoSample + } + + if modelConfig.TopP != nil { + reqBody["top_p"] = *modelConfig.TopP + } + + if modelConfig.Stop != nil { + reqBody["stop"] = *modelConfig.Stop + } + + if modelConfig.Reasoning != nil { + if *modelConfig.Reasoning { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: read line by line + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + logger.Info(line) + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of stream + if data == "[DONE]" { + break + } + + // Parse the JSON event + var event map[string]interface{} + if err := json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + reasoningContent, ok := delta["reasoning_content"].(string) + if ok && reasoningContent != "" { + if err := sender(nil, &reasoningContent); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + break + } + } + + // Send [DONE] marker for OpenAI compatibility + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return scanner.Err() +} + // EncodeToEmbedding encodes a list of texts into embeddings func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) { url := fmt.Sprintf("%s/embedding", z.BaseURL) @@ -248,10 +528,9 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []stri } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) - client := &http.Client{} - resp, err := client.Do(req) + resp, err := z.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } diff --git a/internal/entity/tenant_model_instance.go b/internal/entity/tenant_model_instance.go index de5da075af..0a0a9f5149 100644 --- a/internal/entity/tenant_model_instance.go +++ b/internal/entity/tenant_model_instance.go @@ -20,8 +20,8 @@ package entity type TenantModelInstance struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` InstanceName string `gorm:"column:instance_name;size:128;not null" json:"instance_name"` - ProviderID string `gorm:"column:provider_id;size:32;not null;index" json:"provider_id"` - APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex" json:"api_key"` + ProviderID string `gorm:"column:provider_id;size:32;not null;uniqueIndex:idx_api_key_provider_id" json:"provider_id"` + APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex:idx_api_key_provider_id" json:"api_key"` Status string `gorm:"column:status;size:32;default:'active'" json:"status"` BaseModel } diff --git a/internal/handler/providers.go b/internal/handler/providers.go index be99355555..d93cb9df57 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -17,9 +17,11 @@ package handler import ( + "fmt" "net/http" "ragflow/internal/common" "ragflow/internal/dao" + "ragflow/internal/entity/models" "ragflow/internal/service" "strings" @@ -528,7 +530,9 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) { } type ChatToModelRequest struct { - Message string `json:"message" binding:"required"` + Message string `json:"message" binding:"required"` + Stream bool `json:"stream"` + Reasoning bool `json:"reasoning"` } func (h *ProviderHandler) ChatToModel(c *gin.Context) { @@ -571,6 +575,58 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { userID := c.GetString("user_id") + // Check if it's a stream request + if req.Stream { + // Set SSE headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Writer.WriteHeader(http.StatusOK) + c.Writer.Flush() + + // Create sender function that writes directly to response + sender := func(content, reasoningContent *string) error { + // Check for [DONE] marker (OpenAI compatible) + if content != nil { + if *content == "[DONE]" { + c.SSEvent("done", "[DONE]") + return nil + } + message := fmt.Sprintf("[MESSAGE]%s", *content) + c.SSEvent("message", message) + c.Writer.Flush() + } + + if reasoningContent != nil { + message := fmt.Sprintf("[REASONING]%s", *reasoningContent) + c.SSEvent("message", message) + c.Writer.Flush() + } + + //logger.Info(data) + return nil + } + + chatConfig := models.ChatConfig{ + Reasoning: &req.Reasoning, + Stream: &req.Stream, + Stop: &[]string{}, + DoSample: nil, + MaxTokens: nil, + Temperature: nil, + TopP: nil, + } + + // Stream response using sender function (best performance, no channel) + errorCode := h.modelProviderService.ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, req.Message, &chatConfig, sender) + + if errorCode != common.CodeSuccess { + c.SSEvent("error", "stream failed") + } + return + } + + // Non-stream response response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message) if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 50e53aa97f..8761a9fa36 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -27,6 +27,7 @@ import ( "strings" "time" + model "ragflow/internal/entity/models" "ragflow/internal/service/models" ) @@ -542,3 +543,123 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName return nil, common.CodeServerError, errors.New("model is disabled") } + +// ChatToModelStream +func (m *ModelProviderService) ChatToModelStream(providerName, instanceName, modelName, userID, message string) (<-chan string, <-chan error, common.ErrorCode, error) { + streamChan := make(chan string) + errChan := make(chan error, 1) + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + close(streamChan) + close(errChan) + return streamChan, errChan, common.CodeServerError, err + } + + if len(tenants) == 0 { + close(streamChan) + close(errChan) + return streamChan, errChan, common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + close(streamChan) + close(errChan) + return streamChan, errChan, common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + close(streamChan) + close(errChan) + return streamChan, errChan, common.CodeServerError, err + } + + _, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + close(streamChan) + close(errChan) + return streamChan, errChan, common.CodeNotFound, errors.New("provider not found") + } + + _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + close(streamChan) + close(errChan) + return streamChan, errChan, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + } + + // Async call stream interface using channel for better performance + go func() { + defer close(streamChan) + defer close(errChan) + + err := providerInfo.ModelDriver.ChatStreamlyWithChannel(&modelName, &instance.APIKey, &message, nil, streamChan) + if err != nil { + errChan <- err + } + }() + + return streamChan, errChan, common.CodeSuccess, nil + } + + close(streamChan) + close(errChan) + return streamChan, errChan, common.CodeServerError, errors.New("model is disabled") +} + +// ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel) +func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *model.ChatConfig, sender func(*string, *string) error) common.ErrorCode { + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return common.CodeServerError + } + + if len(tenants) == 0 { + return common.CodeNotFound + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return common.CodeServerError + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return common.CodeServerError + } + + _, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return common.CodeNotFound + } + + _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return common.CodeNotFound + } + + // Direct call with sender function + err := providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &instance.APIKey, &message, modelConfig, sender) + if err != nil { + return common.CodeServerError + } + + return common.CodeSuccess + } + + return common.CodeServerError +}