diff --git a/conf/models/zhipu-ai.json b/conf/models/zhipu-ai.json index 910c675fae..b38624bffe 100644 --- a/conf/models/zhipu-ai.json +++ b/conf/models/zhipu-ai.json @@ -222,10 +222,20 @@ "glm-4.6v", "glm-4.5", "glm-4.5v" - ], - "clear": { - "default_value": true - } + ] + }, + "clear_thinking": { + "default_value": true, + "supported_models": [ + "glm-5.1", + "glm-5", + "glm-5v-turbo", + "glm-4.7", + "glm-4.6", + "glm-4.6v", + "glm-4.5", + "glm-4.5v" + ] } } } \ No newline at end of file diff --git a/internal/cli/client.go b/internal/cli/client.go index 2fbaa30a36..fc9e920ed7 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -164,6 +164,8 @@ func (c *RAGFlowClient) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) { return c.ShowProvider(cmd) case "list_provider_models": return c.ListModels(cmd) + case "list_supported_models": + return c.ListSupportedModels(cmd) case "list_instance_models": return c.ListInstanceModels(cmd) case "show_model": @@ -214,6 +216,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.ShowProvider(cmd) case "list_provider_models": return c.ListModels(cmd) + case "list_supported_models": + return c.ListSupportedModels(cmd) case "list_instance_models": return c.ListInstanceModels(cmd) case "show_model": diff --git a/internal/cli/common_command.go b/internal/cli/common_command.go index 111af8bc6b..045d53206d 100644 --- a/internal/cli/common_command.go +++ b/internal/cli/common_command.go @@ -335,6 +335,45 @@ func (c *RAGFlowClient) ListModels(cmd *Command) (ResponseIf, error) { return &result, nil } +func (c *RAGFlowClient) ListSupportedModels(cmd *Command) (ResponseIf, error) { + + providerName, ok := cmd.Params["provider_name"].(string) + if !ok { + return nil, fmt.Errorf("provider_name not provided") + } + instanceName, ok := cmd.Params["instance_name"].(string) + if !ok { + return nil, fmt.Errorf("instance_name not provided") + } + + var endPoint string + if c.ServerType == "admin" { + endPoint = fmt.Sprintf("/admin/providers/%s/instances/%s/models?supported=true", providerName, instanceName) + } else { + endPoint = fmt.Sprintf("/providers/%s/instances/%s/models?supported=true", providerName, instanceName) + } + + resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list models: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("failed to list models: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) { providerName, ok := cmd.Params["provider_name"].(string) if !ok { diff --git a/internal/cli/lexer.go b/internal/cli/lexer.go index 4bef26c34b..26d3f647a0 100644 --- a/internal/cli/lexer.go +++ b/internal/cli/lexer.go @@ -303,6 +303,8 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenChat, Value: ident} case "THINK": return Token{Type: TokenThink, Value: ident} + case "STREAM": + return Token{Type: TokenStream, Value: ident} case "LS": return Token{Type: TokenLS, Value: ident} case "CAT": @@ -363,6 +365,8 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenTable, Value: ident} case "AVAILABLE": return Token{Type: TokenAvailable, Value: ident} + case "SUPPORTED": + return Token{Type: TokenSupported, Value: ident} case "NAME": return Token{Type: TokenName, Value: ident} case "INSTANCE": diff --git a/internal/cli/parser.go b/internal/cli/parser.go index f431f812bf..85271b2725 100644 --- a/internal/cli/parser.go +++ b/internal/cli/parser.go @@ -190,6 +190,8 @@ func (p *Parser) parseUserCommand() (*Command, error) { return p.parseEnableCommand() case TokenDisable: return p.parseDisableCommand() + case TokenStream: + return p.parseStreamCommand() case TokenChat: return p.parseChatCommand() case TokenThink: diff --git a/internal/cli/response.go b/internal/cli/response.go index 712f41c102..f611467ee3 100644 --- a/internal/cli/response.go +++ b/internal/cli/response.go @@ -113,28 +113,33 @@ func (r *SimpleResponse) PrintOut() { } } -type MessageResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Duration float64 - OutputFormat OutputFormat +type NonStreamResponse struct { + Code int `json:"code"` + ReasoningContent string `json:"reasoning_content"` + Answer string `json:"answer"` + Message string `json:"message"` + Duration float64 + OutputFormat OutputFormat } -func (r *MessageResponse) Type() string { - return "message" +func (r *NonStreamResponse) Type() string { + return "non_stream_message" } -func (r *MessageResponse) TimeCost() float64 { +func (r *NonStreamResponse) TimeCost() float64 { return r.Duration } -func (r *MessageResponse) SetOutputFormat(format OutputFormat) { +func (r *NonStreamResponse) SetOutputFormat(format OutputFormat) { r.OutputFormat = format } -func (r *MessageResponse) PrintOut() { +func (r *NonStreamResponse) PrintOut() { if r.Code == 0 { - fmt.Println(r.Message) + if r.ReasoningContent != "" { + fmt.Printf("Thinking: %s\n", r.ReasoningContent) + } + fmt.Printf("Answer: %s\n", r.Answer) } else { fmt.Println("ERROR") fmt.Printf("%d, %s\n", r.Code, r.Message) diff --git a/internal/cli/types.go b/internal/cli/types.go index 90aeff1cb9..b8b2115ec9 100644 --- a/internal/cli/types.go +++ b/internal/cli/types.go @@ -73,6 +73,7 @@ const ( TokenKeys TokenGenerate TokenAvailable + TokenSupported TokenModel TokenModels TokenProvider @@ -80,6 +81,7 @@ const ( TokenDefault TokenChats TokenChat + TokenStream TokenFiles TokenAs TokenParse @@ -106,7 +108,6 @@ const ( TokenIndex TokenVector TokenSize - TokenDocMeta TokenName // For ALTER PROVIDER NAME TokenInstance TokenInstances diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index e1bb27b1d9..23d20c8da5 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -1436,83 +1436,106 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) { } message := cmd.Params["message"].(string) - reasoning := cmd.Params["reasoning"].(bool) + thinking := cmd.Params["thinking"].(bool) + stream := cmd.Params["stream"].(bool) url := fmt.Sprintf("/providers/%s/instances/%s/models/%s", providerName, instanceName, modelName) payload := map[string]interface{}{ - "message": message, - "stream": true, // use stream API - "reasoning": reasoning, + "message": message, + "stream": stream, // use stream API + "thinking": thinking, } - // Call stream http api - reader, duration, err := c.HTTPClient.RequestStream("POST", url, true, "web", nil, payload) - if err != nil { - return nil, fmt.Errorf("failed to chat model: %w", err) - } - defer reader.Close() - - // Parse SSE and output to console - scanner := bufio.NewScanner(reader) - var fullMessage strings.Builder - - reasoningPrint := true - messagePrint := true - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "data:") { - data := strings.TrimPrefix(line, "data:") - data = strings.TrimSpace(data) - - if strings.HasPrefix(data, "[REASONING]") { - data = strings.TrimPrefix(data, "[REASONING]") - if reasoningPrint { - fmt.Print("Thinking: ") - reasoningPrint = false - } else { - fmt.Print(data) - } - os.Stdout.Sync() - } - if strings.HasPrefix(data, "[MESSAGE]") { - data = strings.TrimPrefix(data, "[MESSAGE]") - if messagePrint { - if reasoning { - fmt.Println() - } - fmt.Print("Answer: ") - messagePrint = false - } else { - fmt.Print(data) - os.Stdout.Sync() - fullMessage.WriteString(data) - } - } - } else if strings.HasPrefix(line, "event:error") { - // error event - if scanner.Scan() { - errData := strings.TrimPrefix(scanner.Text(), "data:") - errData = strings.TrimSpace(errData) - return nil, fmt.Errorf("chat error: %s", errData) - } - // If there's an error, return a generic error - return nil, fmt.Errorf("chat error: received error event from server") + if stream { + // Call stream http api + reader, duration, err := c.HTTPClient.RequestStream("POST", url, true, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to chat model: %w", err) } + defer reader.Close() + + // Parse SSE and output to console + scanner := bufio.NewScanner(reader) + var fullMessage strings.Builder + + reasoningPrint := true + messagePrint := true + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data:") { + data := strings.TrimPrefix(line, "data:") + data = strings.TrimSpace(data) + + if strings.HasPrefix(data, "[REASONING]") { + data = strings.TrimPrefix(data, "[REASONING]") + if reasoningPrint { + fmt.Print("Thinking: ") + reasoningPrint = false + } else { + fmt.Print(data) + } + os.Stdout.Sync() + } + if strings.HasPrefix(data, "[MESSAGE]") { + data = strings.TrimPrefix(data, "[MESSAGE]") + if messagePrint { + if thinking { + fmt.Println() + } + fmt.Print("Answer: ") + messagePrint = false + } else { + fmt.Print(data) + os.Stdout.Sync() + fullMessage.WriteString(data) + } + } + } else if strings.HasPrefix(line, "event:error") { + // error event + if scanner.Scan() { + errData := strings.TrimPrefix(scanner.Text(), "data:") + errData = strings.TrimSpace(errData) + return nil, fmt.Errorf("chat error: %s", errData) + } + // If there's an error, return a generic error + return nil, fmt.Errorf("chat error: received error event from server") + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading stream: %w", err) + } + + fmt.Println() + + result := &StreamMessageResponse{ + Code: 0, + Message: fullMessage.String(), + Duration: duration, + } + return result, nil } - if err := scanner.Err(); err != nil { - return nil, fmt.Errorf("error reading stream: %w", err) + resp, err := c.HTTPClient.Request("POST", url, true, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to list instance models: %w", err) } - fmt.Println() - - result := &StreamMessageResponse{ - Code: 0, - Message: fullMessage.String(), - Duration: duration, + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list instance models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) } - return result, nil + + var result NonStreamResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("failed to list instance models: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil } // UseModel sets the current model for chat diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index c6c97779b2..ff46c0e378 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -163,6 +163,8 @@ func (p *Parser) parseListCommand() (*Command, error) { return p.parseListTokens() case TokenModel: return p.parseListModelProviders() + case TokenSupported: + return p.parseListModelsOfProvider() case TokenModels: return p.parseListModelsOfProvider() case TokenProviders: @@ -2014,11 +2016,55 @@ func (p *Parser) parseSearchCommand() (*Command, error) { } func (p *Parser) parseListModelsOfProvider() (*Command, error) { + + if p.curToken.Type == TokenSupported { + // List supported models + p.nextToken() + + cmd := NewCommand("list_supported_models") + if p.curToken.Type != TokenModels { + return nil, fmt.Errorf("expected MODELS") + } + p.nextToken() + + if p.curToken.Type != TokenFrom { + return nil, fmt.Errorf("expected FROM") + } + p.nextToken() + + if p.curToken.Type != TokenQuotedString { + return nil, fmt.Errorf("expected quoted string for provider name") + } + firstName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + if p.curToken.Type != TokenQuotedString { + return nil, fmt.Errorf("expected quoted string for instance name") + } + secondName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + cmd.Params["provider_name"] = firstName + cmd.Params["instance_name"] = secondName + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil + } + if p.curToken.Type != TokenModels { return nil, fmt.Errorf("expected MODELS") } - p.nextToken() + if p.curToken.Type != TokenFrom { return nil, fmt.Errorf("expected FROM") } @@ -2194,19 +2240,47 @@ func (p *Parser) parseChatCommand() (*Command, error) { cmd.Params["composite_model_name"] = compositeModelName } cmd.Params["message"] = message - cmd.Params["reasoning"] = false + cmd.Params["thinking"] = false + cmd.Params["stream"] = false return cmd, nil } func (p *Parser) parseThinkCommand() (*Command, error) { p.nextToken() // consume THINK + + if p.curToken.Type != TokenChat { + return nil, fmt.Errorf("expected CHAT after THINK") + } + command, err := p.parseChatCommand() if err != nil { return nil, err } - command.Type = "think_chat_to_model" - command.Params["reasoning"] = true + command.Params["thinking"] = true + return command, nil +} + +func (p *Parser) parseStreamCommand() (*Command, error) { + + p.nextToken() // consume STREAM + + var command *Command + var err error + + if p.curToken.Type == TokenChat { + command, err = p.parseChatCommand() + if err != nil { + return nil, err + } + } else if p.curToken.Type == TokenThink { + command, err = p.parseThinkCommand() + if err != nil { + return nil, err + } + } + + command.Params["stream"] = true return command, nil } diff --git a/internal/entity/model.go b/internal/entity/model.go index e60af28a58..e8307b7ae3 100644 --- a/internal/entity/model.go +++ b/internal/entity/model.go @@ -61,14 +61,14 @@ type Reasoning struct { // Reasoning represents the reasoning capability (can be one of three types) type ClearReasoningContent struct { - DefaultValue bool `json:"default_value"` + DefaultValue bool `json:"default_value"` + SupportedModels []string `json:"supported_models"` } // Reasoning represents the reasoning capability (can be one of three types) type Thinking struct { - DefaultValue bool `json:"default_value"` - SupportedModels []string `json:"supported_models"` - Clear ClearReasoningContent `json:"clear"` + DefaultValue bool `json:"default_value"` + SupportedModels []string `json:"supported_models"` } // UnmarshalJSON custom unmarshal for Reasoning @@ -142,9 +142,10 @@ type Multimodal struct { // Features represents all features of a model type Features struct { - Multimodal *Multimodal `json:"multimodal,omitempty"` - Reasoning *Reasoning `json:"reasoning,omitempty"` - Thinking *Thinking `json:"thinking,omitempty"` + Multimodal *Multimodal `json:"multimodal,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + Thinking *Thinking `json:"thinking,omitempty"` + ClearThinking *ClearReasoningContent `json:"clear_thinking,omitempty"` } type ModelThinking struct { @@ -231,16 +232,29 @@ func NewProviderManager(dirPath string) (*ProviderManager, error) { } } + modelClearThinking := make(map[string]bool) + if provider.Features.ClearThinking != nil { + for _, modelName := range provider.Features.ClearThinking.SupportedModels { + modelClearThinking[modelName] = true + } + } + for _, model := range provider.Models { // if the prefix of mode.Name is matched with keys of modelSupportThinking for modelPrefix, _ := range modelSupportThinking { if strings.HasPrefix(model.Name, modelPrefix) { model.Thinking = &ModelThinking{ DefaultValue: provider.Features.Thinking.DefaultValue, - ClearContent: provider.Features.Thinking.Clear.DefaultValue, } } } + + for modelPrefix, _ := range modelClearThinking { + if strings.HasPrefix(model.Name, modelPrefix) { + model.Thinking.ClearContent = true + } + } + model.ModelTypeMap = make(map[string]bool) for _, modelType := range model.ModelTypes { model.ModelTypeMap[modelType] = true diff --git a/internal/entity/models/deepseek.go b/internal/entity/models/deepseek.go new file mode 100644 index 0000000000..ef3a81a0f2 --- /dev/null +++ b/internal/entity/models/deepseek.go @@ -0,0 +1,147 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// DeepSeekModel implements ModelDriver for DeepSeek +type DeepSeekModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client // Reusable HTTP client with connection pool +} + +// NewDeepSeekModel creates a new DeepSeek model instance +func NewDeepSeekModel(baseURL map[string]string, urlSuffix URLSuffix) *DeepSeekModel { + return &DeepSeekModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + }, + }, + } +} + +// Chat sends a message and returns response +func (z *DeepSeekModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + return nil, fmt.Errorf("not implemented") +} + +// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) +func (z *DeepSeekModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + return fmt.Errorf("not implemented") +} + +// EncodeToEmbedding encodes a list of texts into embeddings +func (z *DeepSeekModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { + return nil, fmt.Errorf("not implemented") +} + +/* +{ + "object": "list", + "data": [ + { + "id": "deepseek-chat", + "object": "model", + "owned_by": "deepseek" + }, + { + "id": "deepseek-reasoner", + "object": "model", + "owned_by": "deepseek" + } + ] +} +*/ + +type Model struct { + ID string `json:"id"` + Object string `json:"object"` + OwnedBy string `json:"owned_by"` +} + +type ModelList struct { + Object string `json:"object"` + Models []Model `json:"data"` +} + +func (z *DeepSeekModel) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig.Region != nil { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Models) + + // Build request body + reqBody := map[string]interface{}{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var modelList ModelList + if err = json.Unmarshal(body, &modelList); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + var models []string + for _, model := range modelList.Models { + models = append(models, model.ID) + } + + return models, nil +} diff --git a/internal/entity/models/dummy.go b/internal/entity/models/dummy.go index 647abff6a4..ed07ad6647 100644 --- a/internal/entity/models/dummy.go +++ b/internal/entity/models/dummy.go @@ -35,16 +35,20 @@ func NewDummyModel(baseURL map[string]string, urlSuffix URLSuffix) *DummyModel { } // Chat sends a message and returns response -func (z *DummyModel) Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error) { - return "", fmt.Errorf("not implemented") +func (z *DummyModel) Chat(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig) (*ChatResponse, error) { + return nil, fmt.Errorf("not implemented") } // ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) -func (z *DummyModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error { +func (z *DummyModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { return fmt.Errorf("not implemented") } // EncodeToEmbedding encodes a list of texts into embeddings -func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +func (z *DummyModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { + return nil, fmt.Errorf("not implemented") +} + +func (z *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("not implemented") } diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index 1b490b0e6a..1a4ef46138 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -35,6 +35,10 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string switch providerLower { case "zhipu-ai": return NewZhipuAIModel(baseURL, urlSuffix), nil + case "deepseek": + return NewDeepSeekModel(baseURL, urlSuffix), nil + case "moonshot": + return NewMooshotModel(baseURL, urlSuffix), nil default: return NewDummyModel(baseURL, urlSuffix), nil } diff --git a/internal/entity/models/moonshot.go b/internal/entity/models/moonshot.go new file mode 100644 index 0000000000..85b16a80a1 --- /dev/null +++ b/internal/entity/models/moonshot.go @@ -0,0 +1,118 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// MooshotModel implements ModelDriver for Mooshot +type MooshotModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client // Reusable HTTP client with connection pool +} + +// NewMooshotModel creates a new Mooshot model instance +func NewMooshotModel(baseURL map[string]string, urlSuffix URLSuffix) *MooshotModel { + return &MooshotModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + }, + }, + } +} + +// Chat sends a message and returns response +func (z *MooshotModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + return nil, fmt.Errorf("not implemented") +} + +// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) +func (z *MooshotModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + return fmt.Errorf("not implemented") +} + +// EncodeToEmbedding encodes a list of texts into embeddings +func (z *MooshotModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { + return nil, fmt.Errorf("not implemented") +} + +func (z *MooshotModel) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig.Region != nil { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Models) + + // Build request body + reqBody := map[string]interface{}{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + models, ok := result["models"].([]string) + if !ok || len(models) == 0 { + return nil, fmt.Errorf("no models in response") + } + + return models, nil +} diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index 7d360796dc..db005e740e 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -3,11 +3,18 @@ package models // EmbeddingModel interface for embedding models type ModelDriver interface { // Chat sends a message and returns response - Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error) + Chat(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig) (*ChatResponse, error) // ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) - ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error + ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error // Encode encodes a list of texts into embeddings - EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) + EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) + // List suppported models + ListModels(apiConfig *APIConfig) ([]string, error) +} + +type ChatResponse struct { + Answer *string `json:"answer"` + ReasonContent *string `json:"reason_content"` } // URLSuffix represents the URL suffixes for different API endpoints @@ -17,19 +24,24 @@ type URLSuffix struct { AsyncResult string `json:"async_result"` Embedding string `json:"embedding"` Rerank string `json:"rerank"` + Models string `json:"models"` + Balance string `json:"balance"` } type ChatConfig struct { Stream *bool - Reasoning *bool + Thinking *bool MaxTokens *int Temperature *float64 TopP *float64 DoSample *bool Stop *[]string - Region *string +} + +type APIConfig struct { + ApiKey *string + Region *string } type EmbeddingConfig struct { - Region *string } diff --git a/internal/entity/models/zhipu-ai.go b/internal/entity/models/zhipu-ai.go index 37de091ae6..502593ea9b 100644 --- a/internal/entity/models/zhipu-ai.go +++ b/internal/entity/models/zhipu-ai.go @@ -53,12 +53,17 @@ func NewZhipuAIModel(baseURL map[string]string, urlSuffix URLSuffix) *ZhipuAIMod } // Chat sends a message and returns response -func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, chatModelConfig *ChatConfig) (string, error) { +func (z *ZhipuAIModel) Chat(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { if message == nil { - return "", fmt.Errorf("message is nil") + return nil, fmt.Errorf("message is nil") } - url := fmt.Sprintf("%s/%s", z.BaseURL, z.URLSuffix.Chat) + var region = "default" + if apiConfig.Region != nil { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat) // Build request body reqBody := map[string]interface{}{ @@ -70,82 +75,117 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, chatModelConfig "temperature": 1, } - if chatModelConfig != nil { - if chatModelConfig.MaxTokens != nil { - reqBody["max_tokens"] = *chatModelConfig.MaxTokens - } + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } - if chatModelConfig.Temperature != nil { - reqBody["temperature"] = *chatModelConfig.Temperature - } + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } - if chatModelConfig.TopP != nil { - reqBody["top_p"] = *chatModelConfig.TopP + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + + if chatModelConfig.Thinking != nil { + if *chatModelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } } } jsonData, err := json.Marshal(reqBody) if err != nil { - return "", fmt.Errorf("failed to marshal request: %w", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) } req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) resp, err := z.httpClient.Do(req) if err != nil { - return "", fmt.Errorf("failed to send request: %w", err) + return nil, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) + return nil, fmt.Errorf("failed to read response: %w", err) } if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) } // Parse response var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { - return "", fmt.Errorf("failed to parse response: %w", err) + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) } choices, ok := result["choices"].([]interface{}) if !ok || len(choices) == 0 { - return "", fmt.Errorf("no choices in response") + return nil, fmt.Errorf("no choices in response") } firstChoice, ok := choices[0].(map[string]interface{}) if !ok { - return "", fmt.Errorf("invalid choice format") + return nil, fmt.Errorf("invalid choice format") } messageMap, ok := firstChoice["message"].(map[string]interface{}) if !ok { - return "", fmt.Errorf("invalid message format") + return nil, fmt.Errorf("invalid message format") } content, ok := messageMap["content"].(string) if !ok { - return "", fmt.Errorf("invalid content format") + return nil, fmt.Errorf("invalid content format") } - return content, nil + var reasonContent string + if chatModelConfig.Thinking != nil && *chatModelConfig.Thinking { + reasonContent, ok = messageMap["reasoning_content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + // if first char of reasonContent is \n remove the '\n' + if reasonContent != "" && reasonContent[0] == '\n' { + reasonContent = reasonContent[1:] + } + } + + chatResponse := &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + } + + return chatResponse, nil } // ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) -func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { +func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, message *string, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { var region = "default" - if chatModelConfig.Region != nil { - region = *chatModelConfig.Region + if apiConfig.Region != nil { + region = *apiConfig.Region } url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region]) @@ -160,40 +200,38 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string "temperature": 1, } - if chatModelConfig != nil { - if chatModelConfig.Stream != nil { - reqBody["stream"] = *chatModelConfig.Stream - } + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } - if chatModelConfig.MaxTokens != nil { - reqBody["max_tokens"] = *chatModelConfig.MaxTokens - } + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } - if chatModelConfig.Temperature != nil { - reqBody["temperature"] = *chatModelConfig.Temperature - } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } - if chatModelConfig.DoSample != nil { - reqBody["do_sample"] = *chatModelConfig.DoSample - } + if chatModelConfig.DoSample != nil { + reqBody["do_sample"] = *chatModelConfig.DoSample + } - if chatModelConfig.TopP != nil { - reqBody["top_p"] = *chatModelConfig.TopP - } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } - if chatModelConfig.Stop != nil { - reqBody["stop"] = *chatModelConfig.Stop - } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } - if chatModelConfig.Reasoning != nil { - if *chatModelConfig.Reasoning { - reqBody["thinking"] = map[string]interface{}{ - "type": "enabled", - } - } else { - reqBody["thinking"] = map[string]interface{}{ - "type": "disabled", - } + if chatModelConfig.Thinking != nil { + if *chatModelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", } } } @@ -209,7 +247,7 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) resp, err := z.httpClient.Do(req) if err != nil { @@ -292,10 +330,10 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string } // EncodeToEmbedding encodes a list of texts into embeddings -func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +func (z *ZhipuAIModel) EncodeToEmbedding(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { var region = "default" - if embeddingConfig.Region != nil { - region = *embeddingConfig.Region + if apiConfig.Region != nil { + region = *apiConfig.Region } url := fmt.Sprintf("%s/embedding", z.BaseURL[region]) @@ -319,7 +357,7 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []stri } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) resp, err := z.httpClient.Do(req) if err != nil { @@ -375,3 +413,7 @@ func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []stri return embeddings, nil } + +func (z *ZhipuAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { + return nil, fmt.Errorf("no such method") +} diff --git a/internal/handler/providers.go b/internal/handler/providers.go index bb4b7a6be3..71ff9c1846 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -458,6 +458,41 @@ func (h *ProviderHandler) ListInstanceModels(c *gin.Context) { }) return } + + keywords := "" + if queryKeywords := c.Query("supported"); queryKeywords != "" { + keywords = queryKeywords + } + + // convert keywords to small case + keywords = strings.ToLower(keywords) + if keywords == "true" { + // list supported models + + modelList, err := h.modelProviderService.ListSupportedModels(providerName, instanceName, c.GetString("user_id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": err.Error(), + }) + return + } + + var modelResponse []map[string]string + for _, modelName := range modelList { + modelResponse = append(modelResponse, map[string]string{ + "model_name": modelName, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": modelResponse, + }) + return + } + modelInstances, err := h.modelProviderService.ListInstanceModels(providerName, instanceName, c.GetString("user_id")) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -533,9 +568,9 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) { } type ChatToModelRequest struct { - Message string `json:"message" binding:"required"` - Stream bool `json:"stream"` - Reasoning bool `json:"reasoning"` + Message string `json:"message" binding:"required"` + Stream bool `json:"stream"` + Thinking bool `json:"thinking"` } func (h *ProviderHandler) ChatToModel(c *gin.Context) { @@ -610,19 +645,23 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { return nil } + apiConfig := models.APIConfig{ + ApiKey: nil, + Region: nil, + } + chatConfig := models.ChatConfig{ - Reasoning: &req.Reasoning, + Thinking: &req.Thinking, Stream: &req.Stream, Stop: &[]string{}, DoSample: nil, MaxTokens: nil, Temperature: nil, TopP: nil, - Region: nil, } // Stream response using sender function (best performance, no channel) - errorCode := h.modelProviderService.ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, req.Message, &chatConfig, sender) + errorCode := h.modelProviderService.ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, req.Message, &apiConfig, &chatConfig, sender) if errorCode != common.CodeSuccess { c.SSEvent("error", "stream failed") @@ -630,19 +669,23 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { return } + apiConfig := models.APIConfig{ + ApiKey: nil, + Region: nil, + } + chatConfig := models.ChatConfig{ - Reasoning: &req.Reasoning, + Thinking: &req.Thinking, Stream: &req.Stream, Stop: &[]string{}, DoSample: nil, MaxTokens: nil, Temperature: nil, TopP: nil, - Region: nil, } // Non-stream response - response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message, &chatConfig) + response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message, &apiConfig, &chatConfig) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": errorCode, @@ -652,7 +695,8 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{ - "code": 0, - "message": response, + "code": 0, + "reasoning_content": response.ReasonContent, + "answer": response.Answer, }) } diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 7f96db1778..a7aa82d6b8 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -229,6 +229,54 @@ func (m *ModelProviderService) DeleteModelProvider(providerName, userID string) return common.CodeSuccess, nil } +func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, userID string) ([]string, error) { + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, errors.New("fail to get tenant") + } + + if len(tenants) == 0 { + return nil, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, err + } + + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, fmt.Errorf("provider %s not found", providerName) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, err + } + + apiConfig := &modelModule.APIConfig{ + ApiKey: nil, + Region: nil, + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + return providerInfo.ModelDriver.ListModels(apiConfig) +} + func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, userID, region string) (common.ErrorCode, error) { // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") @@ -531,7 +579,7 @@ func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, mod return common.CodeSuccess, nil } -func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig) (*string, common.ErrorCode, error) { +func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig) (*modelModule.ChatResponse, common.ErrorCode, error) { // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") @@ -575,22 +623,23 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName } region := extra["region"] - modelConfig.Region = ®ion + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey - var response string - response, err = providerInfo.ModelDriver.Chat(&modelName, &instance.APIKey, &message, modelConfig) + var response *modelModule.ChatResponse + response, err = providerInfo.ModelDriver.Chat(&modelName, &message, apiConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } - return &response, common.CodeSuccess, nil + return response, common.CodeSuccess, nil } return nil, common.CodeServerError, errors.New("model is disabled") } // ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel) -func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) common.ErrorCode { +func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) common.ErrorCode { // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") if err != nil { @@ -633,10 +682,11 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc } region := extra["region"] - modelConfig.Region = ®ion + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey // Direct call with sender function - err := providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &instance.APIKey, &message, modelConfig, sender) + err = providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &message, apiConfig, modelConfig, sender) if err != nil { return common.CodeServerError }