mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Add think chat to CLI (#13922)
### What problem does this PR solve? Now user can use 'think mode' to chat with LLM ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -332,7 +332,7 @@ func looksLikeSQL(s string) bool {
|
||||
"LIST ", "SHOW ", "CREATE ", "DROP ", "ALTER ",
|
||||
"LOGIN ", "REGISTER ", "PING", "GRANT ", "REVOKE ",
|
||||
"SET ", "UNSET ", "UPDATE ", "DELETE ", "INSERT ",
|
||||
"SELECT ", "DESCRIBE ", "EXPLAIN ", "ADD ", "ENABLE ", "DISABLE ", "CHAT ", "USE",
|
||||
"SELECT ", "DESCRIBE ", "EXPLAIN ", "ADD ", "ENABLE ", "DISABLE ", "CHAT ", "USE", "THINK",
|
||||
}
|
||||
for _, prefix := range sqlPrefixes {
|
||||
if strings.HasPrefix(s, prefix) {
|
||||
@@ -480,26 +480,19 @@ func NewCLIWithArgs(args *ConnectionArgs) (*CLI, error) {
|
||||
func (c *CLI) Run() error {
|
||||
// If username is provided without password, prompt for password
|
||||
if c.args != nil && c.args.UserName != "" && c.args.Password == "" && c.args.APIToken == "" {
|
||||
// Allow 3 attempts for password verification
|
||||
maxAttempts := 3
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
var input string
|
||||
var err error
|
||||
fmt.Print("Please input your password: ")
|
||||
|
||||
passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Println()
|
||||
|
||||
// Check if terminal supports password masking
|
||||
if term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
input, err = c.line.PasswordPrompt("Please input your password: ")
|
||||
} else {
|
||||
// Terminal doesn't support password masking, use regular prompt
|
||||
fmt.Println("Warning: This terminal does not support secure password input")
|
||||
input, err = c.line.Prompt("Please input your password (will be visible): ")
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Printf("Error reading input: %v\n", err)
|
||||
fmt.Printf("Error reading password: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
input = strings.TrimSpace(input)
|
||||
input := strings.TrimSpace(string(passwordBytes))
|
||||
|
||||
if input == "" {
|
||||
if attempt < maxAttempts {
|
||||
@@ -509,7 +502,6 @@ func (c *CLI) Run() error {
|
||||
return errors.New("no password provided after 3 attempts")
|
||||
}
|
||||
|
||||
// Set the password for verification
|
||||
c.args.Password = input
|
||||
|
||||
if err = c.VerifyAuth(); err != nil {
|
||||
@@ -520,7 +512,6 @@ func (c *CLI) Run() error {
|
||||
return fmt.Errorf("authentication failed after %d attempts: %v", maxAttempts, err)
|
||||
}
|
||||
|
||||
// Authentication successful
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +247,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
return c.EnableOrDisableModel(cmd, "disable")
|
||||
case "chat_to_model":
|
||||
return c.ChatToModel(cmd)
|
||||
case "think_chat_to_model":
|
||||
return c.ChatToModel(cmd)
|
||||
case "use_model":
|
||||
return c.UseModel(cmd)
|
||||
case "show_current_model":
|
||||
|
||||
@@ -21,10 +21,9 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// LoginUserInteractive performs interactive login with username and password
|
||||
@@ -376,44 +375,24 @@ func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
// readPassword reads password from terminal without echoing
|
||||
func readPassword() (string, error) {
|
||||
// Check if stdin is a terminal by trying to get terminal size
|
||||
if isTerminal() {
|
||||
// Use stty to disable echo
|
||||
cmd := exec.Command("stty", "-echo")
|
||||
cmd.Stdin = os.Stdin
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Fallback: read normally
|
||||
return readPasswordFallback()
|
||||
}
|
||||
defer func() {
|
||||
// Re-enable echo
|
||||
cmd := exec.Command("stty", "echo")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Run()
|
||||
}()
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
password, err := reader.ReadString('\n')
|
||||
fmt.Println() // New line after password input
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(password), nil
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
return readPasswordFallback()
|
||||
}
|
||||
|
||||
// Fallback for non-terminal input (e.g., piped input)
|
||||
return readPasswordFallback()
|
||||
}
|
||||
fmt.Print("Password: ")
|
||||
passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Println()
|
||||
|
||||
// isTerminal checks if stdin is a terminal
|
||||
func isTerminal() bool {
|
||||
var termios syscall.Termios
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, os.Stdin.Fd(), syscall.TCGETS, uintptr(unsafe.Pointer(&termios)), 0, 0, 0)
|
||||
return err == 0
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(passwordBytes)), nil
|
||||
}
|
||||
|
||||
// readPasswordFallback reads password as plain text (fallback mode)
|
||||
func readPasswordFallback() (string, error) {
|
||||
fmt.Print("Password (will be visible): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
password, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
|
||||
@@ -326,3 +326,50 @@ func (c *HTTPClient) RequestJSON(method, path string, useAPIBase bool, authKind
|
||||
}
|
||||
return resp.JSON()
|
||||
}
|
||||
|
||||
// RequestStream makes an HTTP request for SSE streaming and returns the response body reader
|
||||
func (c *HTTPClient) RequestStream(method, path string, useAPIBase bool, authKind string, headers map[string]string, jsonBody map[string]interface{}) (io.ReadCloser, float64, error) {
|
||||
url := c.BuildURL(path, useAPIBase)
|
||||
mergedHeaders := c.Headers(authKind, headers)
|
||||
|
||||
var body io.Reader
|
||||
if jsonBody != nil {
|
||||
jsonData, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
body = bytes.NewReader(jsonData)
|
||||
if mergedHeaders == nil {
|
||||
mergedHeaders = make(map[string]string)
|
||||
}
|
||||
mergedHeaders["Content-Type"] = "application/json"
|
||||
}
|
||||
// Add Accept header for SSE
|
||||
if mergedHeaders == nil {
|
||||
mergedHeaders = make(map[string]string)
|
||||
}
|
||||
mergedHeaders["Accept"] = "text/event-stream"
|
||||
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
for k, v := range mergedHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
duration := time.Since(startTime).Seconds()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
resp.Body.Close()
|
||||
return nil, duration, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp.Body, duration, nil
|
||||
}
|
||||
|
||||
@@ -255,6 +255,8 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenChats, Value: ident}
|
||||
case "CHAT":
|
||||
return Token{Type: TokenChat, Value: ident}
|
||||
case "THINK":
|
||||
return Token{Type: TokenThink, Value: ident}
|
||||
case "FILES":
|
||||
return Token{Type: TokenFiles, Value: ident}
|
||||
case "AS":
|
||||
|
||||
@@ -192,6 +192,8 @@ func (p *Parser) parseUserCommand() (*Command, error) {
|
||||
return p.parseDisableCommand()
|
||||
case TokenChat:
|
||||
return p.parseChatCommand()
|
||||
case TokenThink:
|
||||
return p.parseThinkCommand()
|
||||
case TokenUse:
|
||||
return p.parseUseCommand()
|
||||
default:
|
||||
|
||||
@@ -141,6 +141,32 @@ func (r *MessageResponse) PrintOut() {
|
||||
}
|
||||
}
|
||||
|
||||
type StreamMessageResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Duration float64
|
||||
outputFormat OutputFormat
|
||||
}
|
||||
|
||||
func (r *StreamMessageResponse) Type() string {
|
||||
return "stream_message"
|
||||
}
|
||||
|
||||
func (r *StreamMessageResponse) TimeCost() float64 {
|
||||
return r.Duration
|
||||
}
|
||||
|
||||
func (r *StreamMessageResponse) SetOutputFormat(format OutputFormat) {
|
||||
r.outputFormat = format
|
||||
}
|
||||
|
||||
func (r *StreamMessageResponse) PrintOut() {
|
||||
if r.Code != 0 {
|
||||
fmt.Println("ERROR")
|
||||
fmt.Printf("%d, %s\n", r.Code, r.Message)
|
||||
}
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
|
||||
@@ -111,6 +111,7 @@ const (
|
||||
TokenDisable
|
||||
TokenEnable
|
||||
TokenUse
|
||||
TokenThink
|
||||
TokenInsert
|
||||
TokenFile
|
||||
TokenMetadata
|
||||
|
||||
@@ -17,9 +17,11 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
ce "ragflow/internal/cli/contextengine"
|
||||
"strings"
|
||||
)
|
||||
@@ -1187,29 +1189,83 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
message := cmd.Params["message"].(string)
|
||||
reasoning := cmd.Params["reasoning"].(bool)
|
||||
|
||||
url := fmt.Sprintf("/providers/%s/instances/%s/models/%s", providerName, instanceName, modelName)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"message": message,
|
||||
"message": message,
|
||||
"stream": true, // use stream API
|
||||
"reasoning": reasoning,
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("POST", url, true, "web", nil, payload)
|
||||
// Call stream http api
|
||||
reader, duration, err := c.HTTPClient.RequestStream("POST", url, true, "web", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to chat model: %w", err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to chat model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
defer reader.Close()
|
||||
|
||||
// Parse SSE and output to console
|
||||
scanner := bufio.NewScanner(reader)
|
||||
var fullMessage strings.Builder
|
||||
|
||||
reasoningPrint := true
|
||||
messagePrint := true
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
data := strings.TrimPrefix(line, "data:")
|
||||
data = strings.TrimSpace(data)
|
||||
|
||||
if strings.HasPrefix(data, "[REASONING]") {
|
||||
data = strings.TrimPrefix(data, "[REASONING]")
|
||||
if reasoningPrint {
|
||||
fmt.Print("Thinking: ")
|
||||
reasoningPrint = false
|
||||
} else {
|
||||
fmt.Print(data)
|
||||
}
|
||||
os.Stdout.Sync()
|
||||
}
|
||||
if strings.HasPrefix(data, "[MESSAGE]") {
|
||||
data = strings.TrimPrefix(data, "[MESSAGE]")
|
||||
if messagePrint {
|
||||
if reasoning {
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Print("Answer: ")
|
||||
messagePrint = false
|
||||
} else {
|
||||
fmt.Print(data)
|
||||
os.Stdout.Sync()
|
||||
fullMessage.WriteString(data)
|
||||
}
|
||||
}
|
||||
} else if strings.HasPrefix(line, "event:error") {
|
||||
// error event
|
||||
if scanner.Scan() {
|
||||
errData := strings.TrimPrefix(scanner.Text(), "data:")
|
||||
errData = strings.TrimSpace(errData)
|
||||
return nil, fmt.Errorf("chat error: %s", errData)
|
||||
}
|
||||
// If there's an error, return a generic error
|
||||
return nil, fmt.Errorf("chat error: received error event from server")
|
||||
}
|
||||
}
|
||||
var result MessageResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("chat model failed: invalid JSON (%w)", err)
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading stream: %w", err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
|
||||
fmt.Println()
|
||||
|
||||
result := &StreamMessageResponse{
|
||||
Code: 0,
|
||||
Message: fullMessage.String(),
|
||||
Duration: duration,
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UseModel sets the current model for chat
|
||||
|
||||
@@ -2025,14 +2025,14 @@ func (p *Parser) parseChatCommand() (*Command, error) {
|
||||
// Format: 'provider/instance/model' or just 'message'
|
||||
if p.curToken.Type == TokenQuotedString {
|
||||
firstArg := p.curToken.Value
|
||||
|
||||
|
||||
// Check if it looks like a model identifier (contains exactly 2 slashes)
|
||||
slashCount := strings.Count(firstArg, "/")
|
||||
if slashCount == 2 {
|
||||
// This is likely a model identifier, expect another quoted string for message
|
||||
modelName = firstArg
|
||||
p.nextToken()
|
||||
|
||||
|
||||
// After model name, expect message
|
||||
if p.curToken.Type != TokenQuotedString {
|
||||
return nil, fmt.Errorf("expected message after model name")
|
||||
@@ -2062,9 +2062,22 @@ func (p *Parser) parseChatCommand() (*Command, error) {
|
||||
cmd.Params["model_name"] = modelName
|
||||
}
|
||||
cmd.Params["message"] = message
|
||||
cmd.Params["reasoning"] = false
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseThinkCommand() (*Command, error) {
|
||||
|
||||
p.nextToken() // consume THINK
|
||||
command, err := p.parseChatCommand()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
command.Type = "think_chat_to_model"
|
||||
command.Params["reasoning"] = true
|
||||
return command, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseUseCommand() (*Command, error) {
|
||||
p.nextToken() // consume USE
|
||||
|
||||
|
||||
@@ -44,6 +44,16 @@ func (z *DummyModel) ChatStreamly(modelName, apiKey, message *string, genConf ma
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// ChatStreamlyWithChannel sends a message and streams response to channel (better performance)
|
||||
func (z *DummyModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
func (z *DummyModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
|
||||
@@ -6,6 +6,10 @@ type ModelDriver interface {
|
||||
Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error)
|
||||
// ChatStreamly sends a message and streams response
|
||||
ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error)
|
||||
// ChatStreamlyWithChannel sends a message and streams response to channel (better performance)
|
||||
ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error
|
||||
// Encode encodes a list of texts into embeddings
|
||||
EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error)
|
||||
}
|
||||
@@ -18,3 +22,13 @@ type URLSuffix struct {
|
||||
Embedding string `json:"embedding"`
|
||||
Rerank string `json:"rerank"`
|
||||
}
|
||||
|
||||
type ChatConfig struct {
|
||||
Stream *bool
|
||||
Reasoning *bool
|
||||
MaxTokens *int
|
||||
Temperature *float64
|
||||
TopP *float64
|
||||
DoSample *bool
|
||||
Stop *[]string
|
||||
}
|
||||
|
||||
@@ -17,17 +17,22 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/logger"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ZhipuAIModel implements ModelDriver for Zhipu AI (智谱 AI)
|
||||
// ZhipuAIModel implements ModelDriver for Zhipu AI
|
||||
type ZhipuAIModel struct {
|
||||
BaseURL string
|
||||
URLSuffix URLSuffix
|
||||
BaseURL string
|
||||
URLSuffix URLSuffix
|
||||
httpClient *http.Client // Reusable HTTP client with connection pool
|
||||
}
|
||||
|
||||
// NewZhipuAIModel creates a new Zhipu AI model instance
|
||||
@@ -35,6 +40,15 @@ func NewZhipuAIModel(baseURL string, urlSuffix URLSuffix) *ZhipuAIModel {
|
||||
return &ZhipuAIModel{
|
||||
BaseURL: baseURL,
|
||||
URLSuffix: urlSuffix,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DisableCompression: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,8 +96,7 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[stri
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
@@ -137,7 +150,8 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": *message},
|
||||
},
|
||||
"stream": true,
|
||||
"stream": true,
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
// Add generation config if provided
|
||||
@@ -164,10 +178,9 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
@@ -185,14 +198,28 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf
|
||||
defer close(resultChan)
|
||||
defer resp.Body.Close()
|
||||
|
||||
decoder := json.NewDecoder(resp.Body)
|
||||
for {
|
||||
// SSE parsing: read line by line
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// SSE data line starts with "data:"
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract JSON after "data:"
|
||||
data := strings.TrimSpace(line[5:])
|
||||
|
||||
// [DONE] marks the end of stream
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse the JSON event
|
||||
var event map[string]interface{}
|
||||
if err := decoder.Decode(&event); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
choices, ok := event["choices"].([]interface{})
|
||||
@@ -225,6 +252,259 @@ func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf
|
||||
return resultChan, nil
|
||||
}
|
||||
|
||||
// ChatStreamlyWithChannel sends a message and streams response to channel (better performance)
|
||||
func (z *ZhipuAIModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error {
|
||||
url := fmt.Sprintf("%s/chat/completions", z.BaseURL)
|
||||
|
||||
// Build request body with streaming enabled
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": *message},
|
||||
},
|
||||
"stream": true,
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
// Add generation config if provided
|
||||
if genConf != nil {
|
||||
if maxTokens, ok := genConf["max_tokens"]; ok {
|
||||
reqBody["max_tokens"] = maxTokens
|
||||
}
|
||||
if temperature, ok := genConf["temperature"]; ok {
|
||||
reqBody["temperature"] = temperature
|
||||
}
|
||||
if topP, ok := genConf["top_p"]; ok {
|
||||
reqBody["top_p"] = topP
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// SSE parsing: read line by line
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
logger.Info(line)
|
||||
|
||||
// SSE data line starts with "data:"
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract JSON after "data:"
|
||||
data := strings.TrimSpace(line[5:])
|
||||
|
||||
// [DONE] marks the end of stream
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse the JSON event
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
choices, ok := event["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
delta, ok := firstChoice["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
content, ok := delta["content"].(string)
|
||||
if ok && content != "" {
|
||||
resultChan <- content
|
||||
}
|
||||
|
||||
finishReason, ok := firstChoice["finish_reason"].(string)
|
||||
if ok && finishReason != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Send [DONE] marker for OpenAI compatibility
|
||||
resultChan <- "[DONE]"
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
url := fmt.Sprintf("%s/chat/completions", z.BaseURL)
|
||||
|
||||
// Build request body with streaming enabled
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": *message},
|
||||
},
|
||||
"stream": false,
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
if modelConfig != nil {
|
||||
if modelConfig.Stream != nil {
|
||||
reqBody["stream"] = *modelConfig.Stream
|
||||
}
|
||||
|
||||
if modelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *modelConfig.MaxTokens
|
||||
}
|
||||
|
||||
if modelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *modelConfig.Temperature
|
||||
}
|
||||
|
||||
if modelConfig.DoSample != nil {
|
||||
reqBody["do_sample"] = *modelConfig.DoSample
|
||||
}
|
||||
|
||||
if modelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *modelConfig.TopP
|
||||
}
|
||||
|
||||
if modelConfig.Stop != nil {
|
||||
reqBody["stop"] = *modelConfig.Stop
|
||||
}
|
||||
|
||||
if modelConfig.Reasoning != nil {
|
||||
if *modelConfig.Reasoning {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// SSE parsing: read line by line
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
logger.Info(line)
|
||||
|
||||
// SSE data line starts with "data:"
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract JSON after "data:"
|
||||
data := strings.TrimSpace(line[5:])
|
||||
|
||||
// [DONE] marks the end of stream
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse the JSON event
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
choices, ok := event["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
delta, ok := firstChoice["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
content, ok := delta["content"].(string)
|
||||
if ok && content != "" {
|
||||
if err := sender(&content, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
reasoningContent, ok := delta["reasoning_content"].(string)
|
||||
if ok && reasoningContent != "" {
|
||||
if err := sender(nil, &reasoningContent); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
finishReason, ok := firstChoice["finish_reason"].(string)
|
||||
if ok && finishReason != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Send [DONE] marker for OpenAI compatibility
|
||||
endOfStream := "[DONE]"
|
||||
if err := sender(&endOfStream, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) {
|
||||
url := fmt.Sprintf("%s/embedding", z.BaseURL)
|
||||
@@ -248,10 +528,9 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []stri
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
@@ -20,8 +20,8 @@ package entity
|
||||
type TenantModelInstance struct {
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
InstanceName string `gorm:"column:instance_name;size:128;not null" json:"instance_name"`
|
||||
ProviderID string `gorm:"column:provider_id;size:32;not null;index" json:"provider_id"`
|
||||
APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex" json:"api_key"`
|
||||
ProviderID string `gorm:"column:provider_id;size:32;not null;uniqueIndex:idx_api_key_provider_id" json:"provider_id"`
|
||||
APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex:idx_api_key_provider_id" json:"api_key"`
|
||||
Status string `gorm:"column:status;size:32;default:'active'" json:"status"`
|
||||
BaseModel
|
||||
}
|
||||
|
||||
@@ -17,9 +17,11 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity/models"
|
||||
"ragflow/internal/service"
|
||||
"strings"
|
||||
|
||||
@@ -528,7 +530,9 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) {
|
||||
}
|
||||
|
||||
type ChatToModelRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
Message string `json:"message" binding:"required"`
|
||||
Stream bool `json:"stream"`
|
||||
Reasoning bool `json:"reasoning"`
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
@@ -571,6 +575,58 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
// Check if it's a stream request
|
||||
if req.Stream {
|
||||
// Set SSE headers
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create sender function that writes directly to response
|
||||
sender := func(content, reasoningContent *string) error {
|
||||
// Check for [DONE] marker (OpenAI compatible)
|
||||
if content != nil {
|
||||
if *content == "[DONE]" {
|
||||
c.SSEvent("done", "[DONE]")
|
||||
return nil
|
||||
}
|
||||
message := fmt.Sprintf("[MESSAGE]%s", *content)
|
||||
c.SSEvent("message", message)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
if reasoningContent != nil {
|
||||
message := fmt.Sprintf("[REASONING]%s", *reasoningContent)
|
||||
c.SSEvent("message", message)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
//logger.Info(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
chatConfig := models.ChatConfig{
|
||||
Reasoning: &req.Reasoning,
|
||||
Stream: &req.Stream,
|
||||
Stop: &[]string{},
|
||||
DoSample: nil,
|
||||
MaxTokens: nil,
|
||||
Temperature: nil,
|
||||
TopP: nil,
|
||||
}
|
||||
|
||||
// Stream response using sender function (best performance, no channel)
|
||||
errorCode := h.modelProviderService.ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, req.Message, &chatConfig, sender)
|
||||
|
||||
if errorCode != common.CodeSuccess {
|
||||
c.SSEvent("error", "stream failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Non-stream response
|
||||
response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "ragflow/internal/entity/models"
|
||||
"ragflow/internal/service/models"
|
||||
)
|
||||
|
||||
@@ -542,3 +543,123 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName
|
||||
|
||||
return nil, common.CodeServerError, errors.New("model is disabled")
|
||||
}
|
||||
|
||||
// ChatToModelStream
|
||||
func (m *ModelProviderService) ChatToModelStream(providerName, instanceName, modelName, userID, message string) (<-chan string, <-chan error, common.ErrorCode, error) {
|
||||
streamChan := make(chan string)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
if err != nil {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeServerError, err
|
||||
}
|
||||
|
||||
if len(tenants) == 0 {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeNotFound, errors.New("user has no tenants")
|
||||
}
|
||||
|
||||
tenantID := tenants[0].TenantID
|
||||
|
||||
// Check if provider exists
|
||||
provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
|
||||
if err != nil {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeServerError, err
|
||||
}
|
||||
|
||||
instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
|
||||
if err != nil {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeServerError, err
|
||||
}
|
||||
|
||||
_, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName)
|
||||
if err != nil {
|
||||
providerInfo := dao.GetModelProviderManager().FindProvider(providerName)
|
||||
if providerInfo == nil {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeNotFound, errors.New("provider not found")
|
||||
}
|
||||
|
||||
_, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName)
|
||||
if err != nil {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName))
|
||||
}
|
||||
|
||||
// Async call stream interface using channel for better performance
|
||||
go func() {
|
||||
defer close(streamChan)
|
||||
defer close(errChan)
|
||||
|
||||
err := providerInfo.ModelDriver.ChatStreamlyWithChannel(&modelName, &instance.APIKey, &message, nil, streamChan)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
return streamChan, errChan, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeServerError, errors.New("model is disabled")
|
||||
}
|
||||
|
||||
// ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel)
|
||||
func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *model.ChatConfig, sender func(*string, *string) error) common.ErrorCode {
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
if err != nil {
|
||||
return common.CodeServerError
|
||||
}
|
||||
|
||||
if len(tenants) == 0 {
|
||||
return common.CodeNotFound
|
||||
}
|
||||
|
||||
tenantID := tenants[0].TenantID
|
||||
|
||||
// Check if provider exists
|
||||
provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
|
||||
if err != nil {
|
||||
return common.CodeServerError
|
||||
}
|
||||
|
||||
instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
|
||||
if err != nil {
|
||||
return common.CodeServerError
|
||||
}
|
||||
|
||||
_, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName)
|
||||
if err != nil {
|
||||
providerInfo := dao.GetModelProviderManager().FindProvider(providerName)
|
||||
if providerInfo == nil {
|
||||
return common.CodeNotFound
|
||||
}
|
||||
|
||||
_, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName)
|
||||
if err != nil {
|
||||
return common.CodeNotFound
|
||||
}
|
||||
|
||||
// Direct call with sender function
|
||||
err := providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &instance.APIKey, &message, modelConfig, sender)
|
||||
if err != nil {
|
||||
return common.CodeServerError
|
||||
}
|
||||
|
||||
return common.CodeSuccess
|
||||
}
|
||||
|
||||
return common.CodeServerError
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user