Add think chat to CLI (#13922)

### What problem does this PR solve?

Users can now chat with an LLM in 'think mode' from the CLI.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2026-04-03 18:11:23 +08:00
committed by GitHub
parent e518c20736
commit 6d9430a125
16 changed files with 684 additions and 85 deletions

View File

@@ -332,7 +332,7 @@ func looksLikeSQL(s string) bool {
"LIST ", "SHOW ", "CREATE ", "DROP ", "ALTER ",
"LOGIN ", "REGISTER ", "PING", "GRANT ", "REVOKE ",
"SET ", "UNSET ", "UPDATE ", "DELETE ", "INSERT ",
"SELECT ", "DESCRIBE ", "EXPLAIN ", "ADD ", "ENABLE ", "DISABLE ", "CHAT ", "USE",
"SELECT ", "DESCRIBE ", "EXPLAIN ", "ADD ", "ENABLE ", "DISABLE ", "CHAT ", "USE", "THINK",
}
for _, prefix := range sqlPrefixes {
if strings.HasPrefix(s, prefix) {
@@ -480,26 +480,19 @@ func NewCLIWithArgs(args *ConnectionArgs) (*CLI, error) {
func (c *CLI) Run() error {
// If username is provided without password, prompt for password
if c.args != nil && c.args.UserName != "" && c.args.Password == "" && c.args.APIToken == "" {
// Allow 3 attempts for password verification
maxAttempts := 3
for attempt := 1; attempt <= maxAttempts; attempt++ {
var input string
var err error
fmt.Print("Please input your password: ")
passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println()
// Check if terminal supports password masking
if term.IsTerminal(int(os.Stdin.Fd())) {
input, err = c.line.PasswordPrompt("Please input your password: ")
} else {
// Terminal doesn't support password masking, use regular prompt
fmt.Println("Warning: This terminal does not support secure password input")
input, err = c.line.Prompt("Please input your password (will be visible): ")
}
if err != nil {
fmt.Printf("Error reading input: %v\n", err)
fmt.Printf("Error reading password: %v\n", err)
return err
}
input = strings.TrimSpace(input)
input := strings.TrimSpace(string(passwordBytes))
if input == "" {
if attempt < maxAttempts {
@@ -509,7 +502,6 @@ func (c *CLI) Run() error {
return errors.New("no password provided after 3 attempts")
}
// Set the password for verification
c.args.Password = input
if err = c.VerifyAuth(); err != nil {
@@ -520,7 +512,6 @@ func (c *CLI) Run() error {
return fmt.Errorf("authentication failed after %d attempts: %v", maxAttempts, err)
}
// Authentication successful
break
}
}

View File

@@ -247,6 +247,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
return c.EnableOrDisableModel(cmd, "disable")
case "chat_to_model":
return c.ChatToModel(cmd)
case "think_chat_to_model":
return c.ChatToModel(cmd)
case "use_model":
return c.UseModel(cmd)
case "show_current_model":

View File

@@ -21,10 +21,9 @@ import (
"encoding/json"
"fmt"
"os"
"os/exec"
"strings"
"syscall"
"unsafe"
"golang.org/x/term"
)
// LoginUserInteractive performs interactive login with username and password
@@ -376,44 +375,24 @@ func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) {
// readPassword reads password from terminal without echoing
func readPassword() (string, error) {
// Check if stdin is a terminal by trying to get terminal size
if isTerminal() {
// Use stty to disable echo
cmd := exec.Command("stty", "-echo")
cmd.Stdin = os.Stdin
if err := cmd.Run(); err != nil {
// Fallback: read normally
return readPasswordFallback()
}
defer func() {
// Re-enable echo
cmd := exec.Command("stty", "echo")
cmd.Stdin = os.Stdin
cmd.Run()
}()
reader := bufio.NewReader(os.Stdin)
password, err := reader.ReadString('\n')
fmt.Println() // New line after password input
if err != nil {
return "", err
}
return strings.TrimSpace(password), nil
if !term.IsTerminal(int(os.Stdin.Fd())) {
return readPasswordFallback()
}
// Fallback for non-terminal input (e.g., piped input)
return readPasswordFallback()
}
fmt.Print("Password: ")
passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println()
// isTerminal checks if stdin is a terminal
func isTerminal() bool {
var termios syscall.Termios
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, os.Stdin.Fd(), syscall.TCGETS, uintptr(unsafe.Pointer(&termios)), 0, 0, 0)
return err == 0
if err != nil {
return "", err
}
return strings.TrimSpace(string(passwordBytes)), nil
}
// readPasswordFallback reads password as plain text (fallback mode)
func readPasswordFallback() (string, error) {
fmt.Print("Password (will be visible): ")
reader := bufio.NewReader(os.Stdin)
password, err := reader.ReadString('\n')
if err != nil {

View File

@@ -326,3 +326,50 @@ func (c *HTTPClient) RequestJSON(method, path string, useAPIBase bool, authKind
}
return resp.JSON()
}
// RequestStream issues an HTTP request that asks for a Server-Sent Events
// (text/event-stream) response and returns the still-open response body.
//
// path and useAPIBase are handed to BuildURL; authKind and headers are handed
// to Headers. When jsonBody is non-nil it is marshalled and sent with a
// Content-Type of application/json. The Accept header is always set to
// text/event-stream.
//
// Returns the response body (the caller must Close it), the number of seconds
// that elapsed while c.client.Do ran, and an error. A status other than 200 is
// reported as an error whose text starts with "HTTP <status>" and, when the
// server sent a body, is followed by the beginning of that body so the caller
// can see why the request was rejected.
func (c *HTTPClient) RequestStream(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}) (io.ReadCloser, float64, error) {
	// Upper bound on how much of an error response is copied into the error text.
	const maxErrorBody = 4 << 10

	url := c.BuildURL(path, useAPIBase)
	mergedHeaders := c.Headers(authKind, headers)
	if mergedHeaders == nil {
		mergedHeaders = make(map[string]string)
	}

	var body io.Reader
	if jsonBody != nil {
		jsonData, err := json.Marshal(jsonBody)
		if err != nil {
			return nil, 0, err
		}
		body = bytes.NewReader(jsonData)
		mergedHeaders["Content-Type"] = "application/json"
	}

	// Add Accept header for SSE
	mergedHeaders["Accept"] = "text/event-stream"

	req, err := http.NewRequest(method, url, body)
	if err != nil {
		return nil, 0, err
	}
	for k, v := range mergedHeaders {
		req.Header.Set(k, v)
	}

	startTime := time.Now()
	resp, err := c.client.Do(req)
	if err != nil {
		return nil, 0, err
	}
	duration := time.Since(startTime).Seconds()

	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		// Best effort: a failed read only costs the extra detail, so its error
		// is deliberately ignored and the bare status is reported instead.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		if detail := bytes.TrimSpace(snippet); len(detail) > 0 {
			return nil, duration, fmt.Errorf("HTTP %d: %s", resp.StatusCode, detail)
		}
		return nil, duration, fmt.Errorf("HTTP %d", resp.StatusCode)
	}

	return resp.Body, duration, nil
}

View File

@@ -255,6 +255,8 @@ func (l *Lexer) lookupIdent(ident string) Token {
return Token{Type: TokenChats, Value: ident}
case "CHAT":
return Token{Type: TokenChat, Value: ident}
case "THINK":
return Token{Type: TokenThink, Value: ident}
case "FILES":
return Token{Type: TokenFiles, Value: ident}
case "AS":

View File

@@ -192,6 +192,8 @@ func (p *Parser) parseUserCommand() (*Command, error) {
return p.parseDisableCommand()
case TokenChat:
return p.parseChatCommand()
case TokenThink:
return p.parseThinkCommand()
case TokenUse:
return p.parseUseCommand()
default:

View File

@@ -141,6 +141,32 @@ func (r *MessageResponse) PrintOut() {
}
}
// StreamMessageResponse describes the outcome of a streamed chat request.
// Code is zero on success; Message and Duration are supplied by whoever
// builds the value.
type StreamMessageResponse struct {
	Code         int    `json:"code"`
	Message      string `json:"message"`
	Duration     float64
	outputFormat OutputFormat
}

// Type reports the response kind, which is always "stream_message".
func (r *StreamMessageResponse) Type() string { return "stream_message" }

// TimeCost returns the stored Duration.
func (r *StreamMessageResponse) TimeCost() float64 { return r.Duration }

// SetOutputFormat records the output format on the response.
func (r *StreamMessageResponse) SetOutputFormat(format OutputFormat) { r.outputFormat = format }

// PrintOut writes nothing when Code is zero. For any other code it prints an
// "ERROR" line followed by a "<code>, <message>" line.
func (r *StreamMessageResponse) PrintOut() {
	if r.Code == 0 {
		return
	}
	fmt.Printf("ERROR\n%d, %s\n", r.Code, r.Message)
}
type RegisterResponse struct {
Code int `json:"code"`
Message string `json:"message"`

View File

@@ -111,6 +111,7 @@ const (
TokenDisable
TokenEnable
TokenUse
TokenThink
TokenInsert
TokenFile
TokenMetadata

View File

@@ -17,9 +17,11 @@
package cli
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
ce "ragflow/internal/cli/contextengine"
"strings"
)
@@ -1187,29 +1189,83 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) {
}
message := cmd.Params["message"].(string)
reasoning := cmd.Params["reasoning"].(bool)
url := fmt.Sprintf("/providers/%s/instances/%s/models/%s", providerName, instanceName, modelName)
payload := map[string]interface{}{
"message": message,
"message": message,
"stream": true, // use stream API
"reasoning": reasoning,
}
resp, err := c.HTTPClient.Request("POST", url, true, "web", nil, payload)
// Call stream http api
reader, duration, err := c.HTTPClient.RequestStream("POST", url, true, "web", nil, payload)
if err != nil {
return nil, fmt.Errorf("failed to chat model: %w", err)
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to chat model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
defer reader.Close()
// Parse SSE and output to console
scanner := bufio.NewScanner(reader)
var fullMessage strings.Builder
reasoningPrint := true
messagePrint := true
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data:") {
data := strings.TrimPrefix(line, "data:")
data = strings.TrimSpace(data)
if strings.HasPrefix(data, "[REASONING]") {
data = strings.TrimPrefix(data, "[REASONING]")
if reasoningPrint {
fmt.Print("Thinking: ")
reasoningPrint = false
} else {
fmt.Print(data)
}
os.Stdout.Sync()
}
if strings.HasPrefix(data, "[MESSAGE]") {
data = strings.TrimPrefix(data, "[MESSAGE]")
if messagePrint {
if reasoning {
fmt.Println()
}
fmt.Print("Answer: ")
messagePrint = false
} else {
fmt.Print(data)
os.Stdout.Sync()
fullMessage.WriteString(data)
}
}
} else if strings.HasPrefix(line, "event:error") {
// error event
if scanner.Scan() {
errData := strings.TrimPrefix(scanner.Text(), "data:")
errData = strings.TrimSpace(errData)
return nil, fmt.Errorf("chat error: %s", errData)
}
// If there's an error, return a generic error
return nil, fmt.Errorf("chat error: received error event from server")
}
}
var result MessageResponse
if err = json.Unmarshal(resp.Body, &result); err != nil {
return nil, fmt.Errorf("chat model failed: invalid JSON (%w)", err)
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading stream: %w", err)
}
if result.Code != 0 {
return nil, fmt.Errorf("%s", result.Message)
fmt.Println()
result := &StreamMessageResponse{
Code: 0,
Message: fullMessage.String(),
Duration: duration,
}
result.Duration = resp.Duration
return &result, nil
return result, nil
}
// UseModel sets the current model for chat

View File

@@ -2025,14 +2025,14 @@ func (p *Parser) parseChatCommand() (*Command, error) {
// Format: 'provider/instance/model' or just 'message'
if p.curToken.Type == TokenQuotedString {
firstArg := p.curToken.Value
// Check if it looks like a model identifier (contains exactly 2 slashes)
slashCount := strings.Count(firstArg, "/")
if slashCount == 2 {
// This is likely a model identifier, expect another quoted string for message
modelName = firstArg
p.nextToken()
// After model name, expect message
if p.curToken.Type != TokenQuotedString {
return nil, fmt.Errorf("expected message after model name")
@@ -2062,9 +2062,22 @@ func (p *Parser) parseChatCommand() (*Command, error) {
cmd.Params["model_name"] = modelName
}
cmd.Params["message"] = message
cmd.Params["reasoning"] = false
return cmd, nil
}
// parseThinkCommand parses a statement that starts with THINK. It consumes the
// THINK token, lets parseChatCommand parse the remainder, then retags the
// resulting command as "think_chat_to_model" and sets its "reasoning"
// parameter to true. Any parseChatCommand error is returned unchanged.
func (p *Parser) parseThinkCommand() (*Command, error) {
	p.nextToken() // consume THINK

	chatCmd, err := p.parseChatCommand()
	if err != nil {
		return nil, err
	}

	chatCmd.Params["reasoning"] = true
	chatCmd.Type = "think_chat_to_model"
	return chatCmd, nil
}
func (p *Parser) parseUseCommand() (*Command, error) {
p.nextToken() // consume USE

View File

@@ -44,6 +44,16 @@ func (z *DummyModel) ChatStreamly(modelName, apiKey, message *string, genConf ma
return nil, fmt.Errorf("not implemented")
}
// ChatStreamlyWithChannel is the DummyModel placeholder for channel-based
// streaming chat. It ignores every argument, sends nothing on resultChan and
// always returns a "not implemented" error.
func (z *DummyModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error {
return fmt.Errorf("not implemented")
}
// ChatStreamlyWithSender is the DummyModel placeholder for callback-based
// streaming chat. It ignores every argument, never invokes sender and always
// returns a "not implemented" error.
func (z *DummyModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error {
return fmt.Errorf("not implemented")
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) {
return nil, fmt.Errorf("not implemented")

View File

@@ -6,6 +6,10 @@ type ModelDriver interface {
Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error)
// ChatStreamly sends a message and streams response
ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error)
// ChatStreamlyWithChannel sends a message and streams response to channel (better performance)
ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error
// Encode encodes a list of texts into embeddings
EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error)
}
@@ -18,3 +22,13 @@ type URLSuffix struct {
Embedding string `json:"embedding"`
Rerank string `json:"rerank"`
}
// ChatConfig carries optional per-request chat settings that are passed to
// ModelDriver.ChatStreamlyWithSender. Every field is a pointer so that nil can
// stand for "not set"; the Zhipu AI driver only forwards the fields that are
// non-nil.
// NOTE(review): the per-field descriptions below follow the Zhipu AI driver's
// usage; confirm them against any other driver that consumes this type.
type ChatConfig struct {
// Stream is forwarded as the request's "stream" flag.
Stream *bool
// Reasoning switches the provider's thinking mode on (true) or off (false).
Reasoning *bool
// MaxTokens is forwarded as "max_tokens".
MaxTokens *int
// Temperature is forwarded as "temperature".
Temperature *float64
// TopP is forwarded as "top_p".
TopP *float64
// DoSample is forwarded as "do_sample".
DoSample *bool
// Stop is forwarded as "stop".
Stop *[]string
}

View File

@@ -17,17 +17,22 @@
package models
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"ragflow/internal/logger"
"strings"
"time"
)
// ZhipuAIModel implements ModelDriver for Zhipu AI (智谱 AI)
// ZhipuAIModel implements ModelDriver for Zhipu AI
type ZhipuAIModel struct {
BaseURL string
URLSuffix URLSuffix
BaseURL string
URLSuffix URLSuffix
httpClient *http.Client // Reusable HTTP client with connection pool
}
// NewZhipuAIModel creates a new Zhipu AI model instance
@@ -35,6 +40,15 @@ func NewZhipuAIModel(baseURL string, urlSuffix URLSuffix) *ZhipuAIModel {
return &ZhipuAIModel{
BaseURL: baseURL,
URLSuffix: urlSuffix,
httpClient: &http.Client{
Timeout: 120 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
DisableCompression: false,
},
},
}
}
@@ -82,8 +96,7 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[stri
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
client := &http.Client{}
resp, err := client.Do(req)
resp, err := z.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send request: %w", err)
}
@@ -137,7 +150,8 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf
"messages": []map[string]string{
{"role": "user", "content": *message},
},
"stream": true,
"stream": true,
"temperature": 1,
}
// Add generation config if provided
@@ -164,10 +178,9 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
client := &http.Client{}
resp, err := client.Do(req)
resp, err := z.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
@@ -185,14 +198,28 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf
defer close(resultChan)
defer resp.Body.Close()
decoder := json.NewDecoder(resp.Body)
for {
// SSE parsing: read line by line
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
// SSE data line starts with "data:"
if !strings.HasPrefix(line, "data:") {
continue
}
// Extract JSON after "data:"
data := strings.TrimSpace(line[5:])
// [DONE] marks the end of stream
if data == "[DONE]" {
break
}
// Parse the JSON event
var event map[string]interface{}
if err := decoder.Decode(&event); err != nil {
if err == io.EOF {
break
}
return
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
}
choices, ok := event["choices"].([]interface{})
@@ -225,6 +252,259 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf
return resultChan, nil
}
// ChatStreamlyWithChannel posts a single-turn user message to the Zhipu AI
// chat-completions endpoint with streaming enabled and forwards every
// non-empty content delta to resultChan.
//
// genConf may be nil; only its "max_tokens", "temperature" and "top_p" entries
// are forwarded, and temperature defaults to 1. The call blocks until the
// stream ends. Once the stream has been opened it always finishes by sending
// the literal "[DONE]" on resultChan; it never closes resultChan.
//
// Returns an error when modelName, apiKey or message is nil, when the request
// cannot be built or sent, when the API answers with a status other than 200,
// or when reading the stream fails.
func (z *ZhipuAIModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error {
	// The pointers are dereferenced below; fail cleanly instead of panicking.
	if modelName == nil || apiKey == nil || message == nil {
		return fmt.Errorf("model name, API key and message are required")
	}

	url := fmt.Sprintf("%s/chat/completions", z.BaseURL)

	// Build request body with streaming enabled
	reqBody := map[string]interface{}{
		"model": modelName,
		"messages": []map[string]string{
			{"role": "user", "content": *message},
		},
		"stream":      true,
		"temperature": 1,
	}

	// Add generation config if provided (indexing a nil map is safe).
	for _, key := range []string{"max_tokens", "temperature", "top_p"} {
		if value, ok := genConf[key]; ok {
			reqBody[key] = value
		}
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))

	resp, err := z.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	// SSE parsing: read line by line. The default Scanner limit is 64 KiB per
	// line; raise it so one large event does not abort the whole stream.
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		logger.Info(line)

		// SSE data line starts with "data:"
		if !strings.HasPrefix(line, "data:") {
			continue
		}

		// Extract JSON after "data:"
		data := strings.TrimSpace(line[5:])

		// [DONE] marks the end of stream
		if data == "[DONE]" {
			break
		}

		// Parse the JSON event; malformed events are skipped.
		var event map[string]interface{}
		if err := json.Unmarshal([]byte(data), &event); err != nil {
			continue
		}

		choices, ok := event["choices"].([]interface{})
		if !ok || len(choices) == 0 {
			continue
		}
		firstChoice, ok := choices[0].(map[string]interface{})
		if !ok {
			continue
		}
		delta, ok := firstChoice["delta"].(map[string]interface{})
		if !ok {
			continue
		}

		if content, ok := delta["content"].(string); ok && content != "" {
			resultChan <- content
		}

		// A non-empty finish_reason ends the stream even without [DONE].
		if finishReason, ok := firstChoice["finish_reason"].(string); ok && finishReason != "" {
			break
		}
	}

	// Send [DONE] marker for OpenAI compatibility
	resultChan <- "[DONE]"
	return scanner.Err()
}
// ChatStreamlyWithSender posts a single-turn user message to the Zhipu AI
// chat-completions endpoint and relays the streamed answer through sender.
//
// sender is called as sender(&content, nil) for every non-empty content delta
// and as sender(nil, &reasoning) for every non-empty reasoning_content delta.
// After the stream ends it is called once more with the literal "[DONE]" as
// content. An error returned by sender stops processing and is returned.
//
// modelConfig may be nil. Non-nil fields override the defaults (stream: true,
// temperature: 1); Reasoning is translated into the request's "thinking" type
// ("enabled" or "disabled"). Streaming is the default because the loop below
// only understands SSE; an explicit Stream value is still honoured.
//
// Returns an error when modelName, apiKey, message or sender is nil, when the
// request cannot be built or sent, when the API answers with a status other
// than 200, when sender fails, or when reading the stream fails.
func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error {
	// These are dereferenced or called below; fail cleanly instead of panicking.
	if modelName == nil || apiKey == nil || message == nil {
		return fmt.Errorf("model name, API key and message are required")
	}
	if sender == nil {
		return fmt.Errorf("sender is required")
	}

	url := fmt.Sprintf("%s/chat/completions", z.BaseURL)

	// Build request body with streaming enabled
	reqBody := map[string]interface{}{
		"model": modelName,
		"messages": []map[string]string{
			{"role": "user", "content": *message},
		},
		"stream":      true,
		"temperature": 1,
	}

	if modelConfig != nil {
		if modelConfig.Stream != nil {
			reqBody["stream"] = *modelConfig.Stream
		}
		if modelConfig.MaxTokens != nil {
			reqBody["max_tokens"] = *modelConfig.MaxTokens
		}
		if modelConfig.Temperature != nil {
			reqBody["temperature"] = *modelConfig.Temperature
		}
		if modelConfig.DoSample != nil {
			reqBody["do_sample"] = *modelConfig.DoSample
		}
		if modelConfig.TopP != nil {
			reqBody["top_p"] = *modelConfig.TopP
		}
		if modelConfig.Stop != nil {
			reqBody["stop"] = *modelConfig.Stop
		}
		if modelConfig.Reasoning != nil {
			thinking := "disabled"
			if *modelConfig.Reasoning {
				thinking = "enabled"
			}
			reqBody["thinking"] = map[string]interface{}{
				"type": thinking,
			}
		}
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))

	resp, err := z.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	// SSE parsing: read line by line. The default Scanner limit is 64 KiB per
	// line; raise it so one large event does not abort the whole stream.
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		logger.Info(line)

		// SSE data line starts with "data:"
		if !strings.HasPrefix(line, "data:") {
			continue
		}

		// Extract JSON after "data:"
		data := strings.TrimSpace(line[5:])

		// [DONE] marks the end of stream
		if data == "[DONE]" {
			break
		}

		// Parse the JSON event; malformed events are skipped.
		var event map[string]interface{}
		if err := json.Unmarshal([]byte(data), &event); err != nil {
			continue
		}

		choices, ok := event["choices"].([]interface{})
		if !ok || len(choices) == 0 {
			continue
		}
		firstChoice, ok := choices[0].(map[string]interface{})
		if !ok {
			continue
		}
		delta, ok := firstChoice["delta"].(map[string]interface{})
		if !ok {
			continue
		}

		if content, ok := delta["content"].(string); ok && content != "" {
			if err := sender(&content, nil); err != nil {
				return err
			}
		}

		if reasoningContent, ok := delta["reasoning_content"].(string); ok && reasoningContent != "" {
			if err := sender(nil, &reasoningContent); err != nil {
				return err
			}
		}

		// A non-empty finish_reason ends the stream even without [DONE].
		if finishReason, ok := firstChoice["finish_reason"].(string); ok && finishReason != "" {
			break
		}
	}

	// Send [DONE] marker for OpenAI compatibility
	endOfStream := "[DONE]"
	if err := sender(&endOfStream, nil); err != nil {
		return err
	}

	return scanner.Err()
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) {
url := fmt.Sprintf("%s/embedding", z.BaseURL)
@@ -248,10 +528,9 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []stri
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
client := &http.Client{}
resp, err := client.Do(req)
resp, err := z.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}

View File

@@ -20,8 +20,8 @@ package entity
type TenantModelInstance struct {
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
InstanceName string `gorm:"column:instance_name;size:128;not null" json:"instance_name"`
ProviderID string `gorm:"column:provider_id;size:32;not null;index" json:"provider_id"`
APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex" json:"api_key"`
ProviderID string `gorm:"column:provider_id;size:32;not null;uniqueIndex:idx_api_key_provider_id" json:"provider_id"`
APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex:idx_api_key_provider_id" json:"api_key"`
Status string `gorm:"column:status;size:32;default:'active'" json:"status"`
BaseModel
}

View File

@@ -17,9 +17,11 @@
package handler
import (
"fmt"
"net/http"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity/models"
"ragflow/internal/service"
"strings"
@@ -528,7 +530,9 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) {
}
type ChatToModelRequest struct {
Message string `json:"message" binding:"required"`
Message string `json:"message" binding:"required"`
Stream bool `json:"stream"`
Reasoning bool `json:"reasoning"`
}
func (h *ProviderHandler) ChatToModel(c *gin.Context) {
@@ -571,6 +575,58 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
userID := c.GetString("user_id")
// Check if it's a stream request
if req.Stream {
// Set SSE headers
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Writer.WriteHeader(http.StatusOK)
c.Writer.Flush()
// Create sender function that writes directly to response
sender := func(content, reasoningContent *string) error {
// Check for [DONE] marker (OpenAI compatible)
if content != nil {
if *content == "[DONE]" {
c.SSEvent("done", "[DONE]")
return nil
}
message := fmt.Sprintf("[MESSAGE]%s", *content)
c.SSEvent("message", message)
c.Writer.Flush()
}
if reasoningContent != nil {
message := fmt.Sprintf("[REASONING]%s", *reasoningContent)
c.SSEvent("message", message)
c.Writer.Flush()
}
//logger.Info(data)
return nil
}
chatConfig := models.ChatConfig{
Reasoning: &req.Reasoning,
Stream: &req.Stream,
Stop: &[]string{},
DoSample: nil,
MaxTokens: nil,
Temperature: nil,
TopP: nil,
}
// Stream response using sender function (best performance, no channel)
errorCode := h.modelProviderService.ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, req.Message, &chatConfig, sender)
if errorCode != common.CodeSuccess {
c.SSEvent("error", "stream failed")
}
return
}
// Non-stream response
response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message)
if err != nil {
c.JSON(http.StatusOK, gin.H{

View File

@@ -27,6 +27,7 @@ import (
"strings"
"time"
model "ragflow/internal/entity/models"
"ragflow/internal/service/models"
)
@@ -542,3 +543,123 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName
return nil, common.CodeServerError, errors.New("model is disabled")
}
// ChatToModelStream looks up the caller's tenant, provider and model instance
// and, when the request is accepted, starts a goroutine that streams the
// model's answer.
//
// It returns a channel of response chunks, a channel that carries at most one
// streaming error, an error code and an error. On every failure path both
// channels are already closed when they are returned. On success the
// goroutine closes both channels once the driver call returns.
//
// Streaming is started only when the modelDAO lookup returns an error and the
// provider manager knows both the provider and the model; when the modelDAO
// lookup succeeds the call is rejected with "model is disabled".
func (m *ModelProviderService) ChatToModelStream(providerName, instanceName, modelName, userID, message string) (<-chan string, <-chan error, common.ErrorCode, error) {
	chunks := make(chan string)
	failures := make(chan error, 1)

	// abort closes both channels so receivers never block, then reports the failure.
	abort := func(code common.ErrorCode, cause error) (<-chan string, <-chan error, common.ErrorCode, error) {
		close(chunks)
		close(failures)
		return chunks, failures, code, cause
	}

	// Resolve the tenant owned by this user.
	owned, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
	if err != nil {
		return abort(common.CodeServerError, err)
	}
	if len(owned) == 0 {
		return abort(common.CodeNotFound, errors.New("user has no tenants"))
	}

	// The provider and the instance must both exist for that tenant.
	provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(owned[0].TenantID, providerName)
	if err != nil {
		return abort(common.CodeServerError, err)
	}
	instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
	if err != nil {
		return abort(common.CodeServerError, err)
	}

	if _, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName); err == nil {
		return abort(common.CodeServerError, errors.New("model is disabled"))
	}

	driverInfo := dao.GetModelProviderManager().FindProvider(providerName)
	if driverInfo == nil {
		return abort(common.CodeNotFound, errors.New("provider not found"))
	}
	if _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName); err != nil {
		return abort(common.CodeNotFound, fmt.Errorf("provider %s model %s not found", providerName, modelName))
	}

	// Run the driver in the background; it writes chunks straight to the channel.
	go func() {
		defer close(chunks)
		defer close(failures)

		if err := driverInfo.ModelDriver.ChatStreamlyWithChannel(&modelName, &instance.APIKey, &message, nil, chunks); err != nil {
			failures <- err
		}
	}()

	return chunks, failures, common.CodeSuccess, nil
}
// ChatToModelStreamWithSender looks up the caller's tenant, provider and model
// instance and then has the provider's driver stream the answer synchronously
// through sender. Only an error code is returned; the underlying errors are
// not surfaced.
//
// The driver is invoked only when the modelDAO lookup returns an error and the
// provider manager knows both the provider and the model. When the modelDAO
// lookup succeeds the result is CodeServerError.
func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *model.ChatConfig, sender func(*string, *string) error) common.ErrorCode {
	// Resolve the tenant owned by this user.
	owned, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
	switch {
	case err != nil:
		return common.CodeServerError
	case len(owned) == 0:
		return common.CodeNotFound
	}

	// The provider and the instance must both exist for that tenant.
	provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(owned[0].TenantID, providerName)
	if err != nil {
		return common.CodeServerError
	}
	instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
	if err != nil {
		return common.CodeServerError
	}

	if _, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName); err == nil {
		return common.CodeServerError
	}

	driverInfo := dao.GetModelProviderManager().FindProvider(providerName)
	if driverInfo == nil {
		return common.CodeNotFound
	}
	if _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName); err != nil {
		return common.CodeNotFound
	}

	// Direct call with sender function
	if err := driverInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &instance.APIKey, &message, modelConfig, sender); err != nil {
		return common.CodeServerError
	}
	return common.CodeSuccess
}