diff --git a/conf/models/openai.json b/conf/models/openai.json index e7c8a61d40..f89c6c0d1d 100644 --- a/conf/models/openai.json +++ b/conf/models/openai.json @@ -1,7 +1,8 @@ { "name": "OpenAI", - "tags": "LLM,TEXT EMBEDDING,TTS,TEXT RE-RANK,SPEECH2TEXT,MODERATION", - "url": "https://api.openai.com/v1", + "url": { + "default": "https://api.openai.com/v1" + }, "url_suffix": { "chat": "chat/completions" }, @@ -10,8 +11,8 @@ "name": "gpt-5.2-pro", "max_tokens": 400000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -19,8 +20,8 @@ "name": "gpt-5.2", "max_tokens": 400000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -28,8 +29,8 @@ "name": "gpt-5.1", "max_tokens": 400000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -37,8 +38,8 @@ "name": "gpt-5.1-chat-latest", "max_tokens": 400000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -46,8 +47,8 @@ "name": "gpt-5", "max_tokens": 400000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -55,8 +56,8 @@ "name": "gpt-5-mini", "max_tokens": 400000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -64,8 +65,8 @@ "name": "gpt-5-nano", "max_tokens": 400000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -73,8 +74,8 @@ "name": "gpt-5-chat-latest", "max_tokens": 400000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -82,8 +83,8 @@ "name": "gpt-4.1", "max_tokens": 1047576, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -91,8 +92,8 @@ "name": "gpt-4.1-mini", "max_tokens": 1047576, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -100,8 +101,8 @@ "name": "gpt-4.1-nano", "max_tokens": 1047576, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -109,7 +110,7 @@ "name": "gpt-4.5-preview", "max_tokens": 128000, "model_types": [ - "llm" + "chat" ], "features": {} }, @@ -117,8 +118,8 @@ "name": "o3", "max_tokens": 200000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -126,8 +127,8 @@ "name": "o4-mini", "max_tokens": 200000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -135,8 +136,8 @@ "name": "o4-mini-high", "max_tokens": 200000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -144,8 +145,8 @@ "name": "gpt-4o-mini", "max_tokens": 128000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -153,8 +154,8 @@ "name": "gpt-4o", "max_tokens": 128000, "model_types": [ - "llm", - "vlm" + "chat", + "vision" ], "features": {} }, @@ -162,7 +163,7 @@ "name": "gpt-3.5-turbo", "max_tokens": 4096, "model_types": [ - "llm" + "chat" ], "features": {} }, @@ -170,7 +171,7 @@ "name": "gpt-3.5-turbo-16k-0613", "max_tokens": 16385, "model_types": [ - "llm" + "chat" ], "features": {} }, @@ -202,7 +203,7 @@ "name": "whisper-1", "max_tokens": 26214400, "model_types": [ - "speech2text" + "asr" ], "features": {} }, @@ -210,7 +211,7 @@ "name": "gpt-4", "max_tokens": 8191, "model_types": [ - "llm" + "chat" ], "features": {} }, @@ -218,7 +219,7 @@ "name": "gpt-4-turbo", "max_tokens": 8191, "model_types": [ - "llm" + "chat" ], "features": {} }, @@ -226,7 +227,7 @@ "name": "gpt-4-32k", "max_tokens": 32768, "model_types": [ - "llm" + "chat" ], "features": {} }, diff --git a/conf/models/xai.json b/conf/models/xai.json index 455069140b..5e12776c92 100644 --- a/conf/models/xai.json +++ b/conf/models/xai.json @@ -1,7 +1,8 @@ { "name": "xAI", - "tags": "LLM", - "url": "https://api.x.ai/v1", + "url": { + "default": "https://api.x.ai/v1" + }, "url_suffix": { "chat": "chat/completions" }, @@ -9,44 +10,38 @@ { "name": "grok-4", "max_tokens": 256000, - "model_types": ["llm"], + "model_types": ["chat"], "features": {} }, { "name": "grok-3", "max_tokens": 131072, - "model_types": ["llm"], + "model_types": ["chat"], "features": {} }, { "name": "grok-3-fast", "max_tokens": 131072, - "model_types": ["llm"], + "model_types": ["chat"], "features": {} }, { "name": "grok-3-mini", "max_tokens": 131072, - "model_types": ["llm"], + "model_types": ["chat"], "features": {} }, { "name": "grok-3-mini-mini-fast", "max_tokens": 131072, - "model_types": ["llm"], + "model_types": ["chat"], "features": {} }, { "name": "grok-2-vision", "max_tokens": 32768, - "model_types": ["vlm"], - "features": { - "multimodal": { - "enabled": true, - "input_modalities": ["image"], - "output_modalities": ["text"] - } - } + "model_types": ["vision"], + "features": {} } ] } \ No newline at end of file diff --git a/conf/models/zhipu-ai.json b/conf/models/zhipu-ai.json index 87fcb43448..f41166acae 100644 --- a/conf/models/zhipu-ai.json +++ b/conf/models/zhipu-ai.json @@ -1,7 +1,8 @@ { "name": "ZHIPU-AI", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "url": "https://open.bigmodel.cn/api/paas/v4", + "url": { + "default": "https://open.bigmodel.cn/api/paas/v4" + }, "url_suffix": { "chat": "chat/completions", "async_chat": "async/chat/completions", @@ -31,7 +32,7 @@ "max_tokens": 128000, "model_types": [ "chat", - "image2text" + "vision" ], "features": {} }, @@ -71,7 +72,7 @@ "name": "glm-4.5v", "max_tokens": 64000, "model_types": [ - "image2text" + "vision" ], "features": {} }, @@ -151,7 +152,7 @@ "name": "glm-4v", "max_tokens": 2000, "model_types": [ - "image2text" + "vision" ], "features": {} }, @@ -183,7 +184,28 @@ "name": "glm-asr", "max_tokens": 4096, "model_types": [ - "speech2text" + "asr" + ], + "features": {} + }, + { + "name": "glm-tts", + "model_types": [ + "tts" + ], + "features": {} + }, + { + "name": "glm-ocr", + "model_types": [ + "ocr" + ], + "features": {} + }, + { + "name": "glm-rerank", + "model_types": [ + "rerank" ], "features": {} } diff --git a/internal/cli/admin_parser.go b/internal/cli/admin_parser.go index 52ff7d03b2..723aad512a 100644 --- a/internal/cli/admin_parser.go +++ b/internal/cli/admin_parser.go @@ -1187,35 +1187,36 @@ func (p *Parser) parseAdminSetVariable() (*Command, error) { func (p *Parser) parseAdminSetDefault() (*Command, error) { p.nextToken() // consume DEFAULT - var modelType, modelID string + var modelType string switch p.curToken.Type { - case TokenLLM: - modelType = "llm_id" - case TokenVLM: - modelType = "img2txt_id" + case TokenChat: + modelType = "chat" + case TokenVision: + modelType = "vision" case TokenEmbedding: - modelType = "embd_id" - case TokenReranker: - modelType = "reranker_id" + modelType = "embedding" + case TokenRerank: + modelType = "rerank" case TokenASR: - modelType = "asr_id" + modelType = "asr" case TokenTTS: - modelType = "tts_id" + modelType = "tts" + case TokenOCR: + modelType = "ocr" default: return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value) } p.nextToken() - id, err := p.parseQuotedString() + compositeModelName, err := p.parseQuotedString() if err != nil { return nil, err } - modelID = id cmd := NewCommand("set_default_model") cmd.Params["model_type"] = modelType - cmd.Params["model_id"] = modelID + cmd.Params["composite_model_name"] = compositeModelName p.nextToken() // Semicolon is optional for UNSET TOKEN @@ -1254,18 +1255,20 @@ func (p *Parser) parseAdminResetCommand() (*Command, error) { var modelType string switch p.curToken.Type { - case TokenLLM: - modelType = "llm_id" - case TokenVLM: - modelType = "img2txt_id" + case TokenChat: + modelType = "chat" + case TokenVision: + modelType = "vision" case TokenEmbedding: - modelType = "embd_id" - case TokenReranker: - modelType = "reranker_id" + modelType = "embedding" + case TokenRerank: + modelType = "rerank" case TokenASR: - modelType = "asr_id" + modelType = "asr" case TokenTTS: - modelType = "tts_id" + modelType = "tts" + case TokenOCR: + modelType = "ocr" default: return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value) } diff --git a/internal/cli/client.go b/internal/cli/client.go index 39ae488d28..2fbaa30a36 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -250,6 +250,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.ShowCurrentModel(cmd) case "set_default_model": return c.SetDefaultModel(cmd) + case "reset_default_model": + return c.ResetDefaultModel(cmd) case "list_user_default_models": return c.ListDefaultModels(cmd) // Dataset, metadata commands diff --git a/internal/cli/common_command.go b/internal/cli/common_command.go index 695e559178..111af8bc6b 100644 --- a/internal/cli/common_command.go +++ b/internal/cli/common_command.go @@ -375,27 +375,29 @@ func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) { func (c *RAGFlowClient) SetDefaultModel(cmd *Command) (ResponseIf, error) { - modeType, ok := cmd.Params["model_type"].(string) + modelType, ok := cmd.Params["model_type"].(string) if !ok { return nil, fmt.Errorf("model_type not provided") } - modelProvider, ok := cmd.Params["model_provider"].(string) - if !ok { - return nil, fmt.Errorf("model_provider not provided") - } - modelInstance, ok := cmd.Params["model_instance"].(string) - if !ok { - return nil, fmt.Errorf("model_instance not provided") - } - modelName, ok := cmd.Params["model_name"].(string) + + compositeModelName, ok := cmd.Params["composite_model_name"].(string) if !ok { return nil, fmt.Errorf("model_name not provided") } + var providerName, instanceName, modelName string + names := strings.Split(compositeModelName, "/") + if len(names) != 3 { + return nil, fmt.Errorf("model name must be in format 'provider/instance/model'") + } + providerName = names[0] + instanceName = names[1] + modelName = names[2] + payload := map[string]interface{}{ - "model_type": modeType, - "model_provider": modelProvider, - "model_instance": modelInstance, + "model_type": modelType, + "model_provider": providerName, + "model_instance": instanceName, "model_name": modelName, } @@ -420,6 +422,38 @@ func (c *RAGFlowClient) SetDefaultModel(cmd *Command) (ResponseIf, error) { return &result, nil } +func (c *RAGFlowClient) ResetDefaultModel(cmd *Command) (ResponseIf, error) { + + modelType, ok := cmd.Params["model_type"].(string) + if !ok { + return nil, fmt.Errorf("model_type not provided") + } + + payload := map[string]interface{}{ + "model_type": modelType, + } + + resp, err := c.HTTPClient.Request("PATCH", "/models", true, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to reset default model: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to reset default model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("failed to reset default model: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + func (c *RAGFlowClient) ListDefaultModels(cmd *Command) (ResponseIf, error) { resp, err := c.HTTPClient.Request("GET", "/models", true, "web", nil, nil) if err != nil { diff --git a/internal/cli/lexer.go b/internal/cli/lexer.go index afa990206f..4bef26c34b 100644 --- a/internal/cli/lexer.go +++ b/internal/cli/lexer.go @@ -327,18 +327,18 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenSearch, Value: ident} case "CURRENT": return Token{Type: TokenCurrent, Value: ident} - case "LLM": - return Token{Type: TokenLLM, Value: ident} - case "VLM": - return Token{Type: TokenVLM, Value: ident} + case "VISION": + return Token{Type: TokenVision, Value: ident} case "EMBEDDING": return Token{Type: TokenEmbedding, Value: ident} - case "RERANKER": - return Token{Type: TokenReranker, Value: ident} + case "RERANK": + return Token{Type: TokenRerank, Value: ident} case "ASR": return Token{Type: TokenASR, Value: ident} case "TTS": return Token{Type: TokenTTS, Value: ident} + case "OCR": + return Token{Type: TokenOCR, Value: ident} case "ASYNC": return Token{Type: TokenAsync, Value: ident} case "SYNC": diff --git a/internal/cli/types.go b/internal/cli/types.go index a3c879a33c..90aeff1cb9 100644 --- a/internal/cli/types.go +++ b/internal/cli/types.go @@ -90,12 +90,12 @@ const ( TokenPipeline TokenSearch TokenCurrent - TokenLLM - TokenVLM + TokenVision TokenEmbedding - TokenReranker + TokenRerank TokenASR TokenTTS + TokenOCR TokenAsync TokenSync TokenBenchmark diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index 48a3639cb3..e1bb27b1d9 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -1417,8 +1417,8 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) { var providerName, instanceName, modelName string - // Check if model_name is provided in command - if compositeModelName, ok := cmd.Params["model_name"].(string); ok && compositeModelName != "" { + // Check if composite_model_name is provided in command + if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { names := strings.Split(compositeModelName, "/") if len(names) != 3 { return nil, fmt.Errorf("model name must be in format 'provider/instance/model'") @@ -1524,12 +1524,12 @@ func (c *RAGFlowClient) UseModel(cmd *Command) (ResponseIf, error) { return nil, fmt.Errorf("this command is only allowed in USER mode") } - modelIdentifier, ok := cmd.Params["model_identifier"].(string) - if !ok || modelIdentifier == "" { + compositeModelName, ok := cmd.Params["composite_model_name"].(string) + if !ok || compositeModelName == "" { return nil, fmt.Errorf("model identifier not provided") } - names := strings.Split(modelIdentifier, "/") + names := strings.Split(compositeModelName, "/") if len(names) != 3 { return nil, fmt.Errorf("model identifier must be in format 'provider/instance/model'") } diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index 43c9437824..c6c97779b2 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -1597,49 +1597,48 @@ func (p *Parser) parseSetVariable() (*Command, error) { func (p *Parser) parseSetDefault() (*Command, error) { p.nextToken() // consume DEFAULT - var modelType, modelProvider, modelInstance, modelName string + var modelType, compositeModelName string var err error switch p.curToken.Type { - case TokenLLM: + case TokenChat: modelType = "chat" - case TokenVLM: - modelType = "image2text" + case TokenVision: + modelType = "vision" case TokenEmbedding: modelType = "embedding" - case TokenReranker: + case TokenRerank: modelType = "rerank" case TokenASR: modelType = "asr" case TokenTTS: modelType = "tts" + case TokenOCR: + modelType = "ocr" default: return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value) } + p.nextToken() // pass model type - p.nextToken() - modelProvider, err = p.parseQuotedString() + if p.curToken.Type != TokenModel { + return nil, fmt.Errorf("expected MODEL") + } + p.nextToken() // pass MODEL + + // Format: 'provider/instance/model' or just 'message' + if p.curToken.Type != TokenQuotedString { + return nil, fmt.Errorf("expected quoted string with format provider/instance/model") + } + + compositeModelName, err = p.parseQuotedString() if err != nil { return nil, err } - p.nextToken() - modelInstance, err = p.parseQuotedString() - if err != nil { - return nil, err - } - - p.nextToken() - modelName, err = p.parseQuotedString() - if err != nil { - return nil, err - } cmd := NewCommand("set_default_model") cmd.Params["model_type"] = modelType - cmd.Params["model_provider"] = modelProvider - cmd.Params["model_instance"] = modelInstance - cmd.Params["model_name"] = modelName + cmd.Params["composite_model_name"] = compositeModelName p.nextToken() // Semicolon is optional for UNSET TOKEN @@ -1717,26 +1716,33 @@ func (p *Parser) parseResetCommand() (*Command, error) { var modelType string switch p.curToken.Type { - case TokenLLM: - modelType = "llm_id" - case TokenVLM: - modelType = "img2txt_id" + case TokenChat: + modelType = "chat" + case TokenVision: + modelType = "vision" case TokenEmbedding: - modelType = "embd_id" - case TokenReranker: - modelType = "reranker_id" + modelType = "embedding" + case TokenRerank: + modelType = "rerank" case TokenASR: - modelType = "asr_id" + modelType = "asr" case TokenTTS: - modelType = "tts_id" + modelType = "tts" + case TokenOCR: + modelType = "ocr" default: return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value) } cmd := NewCommand("reset_default_model") cmd.Params["model_type"] = modelType - p.nextToken() + + if p.curToken.Type != TokenModel { + return nil, fmt.Errorf("expected MODEL") + } + p.nextToken() // pass MODEL + // Semicolon is optional for UNSET TOKEN if p.curToken.Type == TokenSemicolon { p.nextToken() @@ -2144,7 +2150,7 @@ func (p *Parser) parseDisableCommand() (*Command, error) { func (p *Parser) parseChatCommand() (*Command, error) { p.nextToken() // consume CHAT - var modelName string + var compositeModelName string var message string // Check if we have a quoted string that looks like a model identifier (contains two slashes) @@ -2156,7 +2162,7 @@ func (p *Parser) parseChatCommand() (*Command, error) { slashCount := strings.Count(firstArg, "/") if slashCount == 2 { // This is likely a model identifier, expect another quoted string for message - modelName = firstArg + compositeModelName = firstArg p.nextToken() // After model name, expect message @@ -2184,8 +2190,8 @@ func (p *Parser) parseChatCommand() (*Command, error) { } cmd := NewCommand("chat_to_model") - if modelName != "" { - cmd.Params["model_name"] = modelName + if compositeModelName != "" { + cmd.Params["composite_model_name"] = compositeModelName } cmd.Params["message"] = message cmd.Params["reasoning"] = false @@ -2213,7 +2219,7 @@ func (p *Parser) parseUseCommand() (*Command, error) { p.nextToken() // consume MODEL // Parse model identifier in format 'provider/instance/model' - modelIdentifier, err := p.parseQuotedString() + compositeModelName, err := p.parseQuotedString() if err != nil { return nil, fmt.Errorf("expected model identifier in format 'provider/instance/model': %w", err) } @@ -2225,7 +2231,7 @@ func (p *Parser) parseUseCommand() (*Command, error) { } cmd := NewCommand("use_model") - cmd.Params["model_identifier"] = modelIdentifier + cmd.Params["composite_model_name"] = compositeModelName return cmd, nil } diff --git a/internal/dao/tenant.go b/internal/dao/tenant.go index b5540246fa..044adfbcb3 100644 --- a/internal/dao/tenant.go +++ b/internal/dao/tenant.go @@ -62,6 +62,7 @@ type TenantInfo struct { ASRID string `gorm:"column:asr_id" json:"asr_id"` Img2TxtID string `gorm:"column:img2txt_id" json:"img2txt_id"` TTSID *string `gorm:"column:tts_id" json:"tts_id,omitempty"` + OCRID string `gorm:"column:ocr_id" json:"ocr_id"` ParserIDs string `gorm:"column:parser_ids" json:"parser_ids"` Role string `gorm:"column:role" json:"role"` } @@ -71,7 +72,7 @@ func (dao *TenantDAO) GetInfoByUserID(userID string) ([]*TenantInfo, error) { var results []*TenantInfo err := DB.Model(&entity.Tenant{}). - Select("tenant.id as tenant_id, tenant.name, tenant.llm_id, tenant.embd_id, tenant.rerank_id, tenant.asr_id, tenant.img2txt_id, tenant.tts_id, tenant.parser_ids, user_tenant.role"). + Select("tenant.id as tenant_id, tenant.name, tenant.llm_id, tenant.embd_id, tenant.rerank_id, tenant.asr_id, tenant.img2txt_id, tenant.tts_id, tenant.ocr_id, tenant.parser_ids, user_tenant.role"). Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id"). Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "owner", "1"). Scan(&results).Error diff --git a/internal/entity/model.go b/internal/entity/model.go index 0017b65663..93125807a4 100644 --- a/internal/entity/model.go +++ b/internal/entity/model.go @@ -145,11 +145,10 @@ type Model struct { // Provider represents an LLM provider type Provider struct { - Name string `json:"name"` - Tags string `json:"tags"` - URL string `json:"url"` - URLSuffix models.URLSuffix `json:"url_suffix"` - Models []*Model `json:"models"` + Name string `json:"name"` + URL map[string]string `json:"url"` + URLSuffix models.URLSuffix `json:"url_suffix"` + Models []*Model `json:"models"` ModelDriver models.ModelDriver } @@ -236,11 +235,24 @@ func (pm *ProviderManager) ListProviders() ([]map[string]interface{}, error) { var providers []map[string]interface{} for _, provider := range pm.Providers { + + modelTypeSet := make(map[string]struct{}) + for _, model := range provider.Models { + for _, modelType := range model.ModelTypes { + modelTypeSet[modelType] = struct{}{} + } + } + + var modelTypes []string + for modelType := range modelTypeSet { + modelTypes = append(modelTypes, modelType) + } + providerData := map[string]interface{}{ - "name": provider.Name, - "tags": provider.Tags, - "url": provider.URL, - "url_suffix": provider.URLSuffix, + "name": provider.Name, + "url": provider.URL, + "model_types": modelTypes, + "url_suffix": provider.URLSuffix, } providers = append(providers, providerData) } @@ -262,7 +274,6 @@ func (pm *ProviderManager) GetProviderByName(providerName string) (map[string]in providerInfo := map[string]interface{}{ "name": provider.Name, - "tags": provider.Tags, "base_url": provider.URL, "total_models": len(provider.Models), } diff --git a/internal/entity/models/dummy.go b/internal/entity/models/dummy.go index ab74463feb..647abff6a4 100644 --- a/internal/entity/models/dummy.go +++ b/internal/entity/models/dummy.go @@ -20,14 +20,14 @@ import ( "fmt" ) -// DummyModel implements ModelDriver for Zhipu AI (智谱 AI) +// DummyModel implements ModelDriver for Zhipu AI type DummyModel struct { - BaseURL string + BaseURL map[string]string URLSuffix URLSuffix } // NewDummyModel creates a new Zhipu AI model instance -func NewDummyModel(baseURL string, urlSuffix URLSuffix) *DummyModel { +func NewDummyModel(baseURL map[string]string, urlSuffix URLSuffix) *DummyModel { return &DummyModel{ BaseURL: baseURL, URLSuffix: urlSuffix, @@ -35,26 +35,16 @@ func NewDummyModel(baseURL string, urlSuffix URLSuffix) *DummyModel { } // Chat sends a message and returns response -func (z *DummyModel) Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error) { +func (z *DummyModel) Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error) { return "", fmt.Errorf("not implemented") } -// ChatStreamly sends a message and streams response -func (z *DummyModel) ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error) { - return nil, fmt.Errorf("not implemented") -} - -// ChatStreamlyWithChannel sends a message and streams response to channel (better performance) -func (z *DummyModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error { - return fmt.Errorf("not implemented") -} - // ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) func (z *DummyModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error { return fmt.Errorf("not implemented") } // EncodeToEmbedding encodes a list of texts into embeddings -func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) { +func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) { return nil, fmt.Errorf("not implemented") } diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index 2531c50f23..1b490b0e6a 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -30,7 +30,7 @@ func NewModelFactory() *ModelFactory { } // CreateModelDriver creates a ModelDriver for the given provider and model -func (f *ModelFactory) CreateModelDriver(providerName string, baseURL string, urlSuffix URLSuffix) (ModelDriver, error) { +func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string]string, urlSuffix URLSuffix) (ModelDriver, error) { providerLower := strings.ToLower(providerName) switch providerLower { case "zhipu-ai": diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index dc13db942a..7d360796dc 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -3,15 +3,11 @@ package models // EmbeddingModel interface for embedding models type ModelDriver interface { // Chat sends a message and returns response - Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error) - // ChatStreamly sends a message and streams response - ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error) - // ChatStreamlyWithChannel sends a message and streams response to channel (better performance) - ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error + Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error) // ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error // Encode encodes a list of texts into embeddings - EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) + EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) } // URLSuffix represents the URL suffixes for different API endpoints @@ -31,4 +27,9 @@ type ChatConfig struct { TopP *float64 DoSample *bool Stop *[]string + Region *string +} + +type EmbeddingConfig struct { + Region *string } diff --git a/internal/entity/models/zhipu-ai.go b/internal/entity/models/zhipu-ai.go index 417f242f7c..37de091ae6 100644 --- a/internal/entity/models/zhipu-ai.go +++ b/internal/entity/models/zhipu-ai.go @@ -30,13 +30,13 @@ import ( // ZhipuAIModel implements ModelDriver for Zhipu AI type ZhipuAIModel struct { - BaseURL string + BaseURL map[string]string URLSuffix URLSuffix httpClient *http.Client // Reusable HTTP client with connection pool } // NewZhipuAIModel creates a new Zhipu AI model instance -func NewZhipuAIModel(baseURL string, urlSuffix URLSuffix) *ZhipuAIModel { +func NewZhipuAIModel(baseURL map[string]string, urlSuffix URLSuffix) *ZhipuAIModel { return &ZhipuAIModel{ BaseURL: baseURL, URLSuffix: urlSuffix, @@ -53,7 +53,7 @@ func NewZhipuAIModel(baseURL string, urlSuffix URLSuffix) *ZhipuAIModel { } // Chat sends a message and returns response -func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error) { +func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, chatModelConfig *ChatConfig) (string, error) { if message == nil { return "", fmt.Errorf("message is nil") } @@ -70,16 +70,17 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[stri "temperature": 1, } - // Add generation config if provided - if genConf != nil { - if maxTokens, ok := genConf["max_tokens"]; ok { - reqBody["max_tokens"] = maxTokens + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens } - if temperature, ok := genConf["temperature"]; ok { - reqBody["temperature"] = temperature + + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature } - if topP, ok := genConf["top_p"]; ok { - reqBody["top_p"] = topP + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP } } @@ -140,229 +141,14 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[stri return content, nil } -// ChatStreamly sends a message and streams response -func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error) { - url := fmt.Sprintf("%s/chat/completions", z.BaseURL) - - // Build request body with streaming enabled - reqBody := map[string]interface{}{ - "model": modelName, - "messages": []map[string]string{ - {"role": "user", "content": *message}, - }, - "stream": true, - "temperature": 1, - } - - // Add generation config if provided - if genConf != nil { - if maxTokens, ok := genConf["max_tokens"]; ok { - reqBody["max_tokens"] = maxTokens - } - if temperature, ok := genConf["temperature"]; ok { - reqBody["temperature"] = temperature - } - if topP, ok := genConf["top_p"]; ok { - reqBody["top_p"] = topP - } - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) - - resp, err := z.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) - } - - // Create channel for streaming - resultChan := make(chan string) - - go func() { - defer close(resultChan) - defer resp.Body.Close() - - // SSE parsing: read line by line - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - line := scanner.Text() - - // SSE data line starts with "data:" - if !strings.HasPrefix(line, "data:") { - continue - } - - // Extract JSON after "data:" - data := strings.TrimSpace(line[5:]) - - // [DONE] marks the end of stream - if data == "[DONE]" { - break - } - - // Parse the JSON event - var event map[string]interface{} - if err := json.Unmarshal([]byte(data), &event); err != nil { - continue - } - - choices, ok := event["choices"].([]interface{}) - if !ok || len(choices) == 0 { - continue - } - - firstChoice, ok := choices[0].(map[string]interface{}) - if !ok { - continue - } - - delta, ok := firstChoice["delta"].(map[string]interface{}) - if !ok { - continue - } - - content, ok := delta["content"].(string) - if ok && content != "" { - resultChan <- content - } - - finishReason, ok := firstChoice["finish_reason"].(string) - if ok && finishReason != "" { - break - } - } - }() - - return resultChan, nil -} - -// ChatStreamlyWithChannel sends a message and streams response to channel (better performance) -func (z *ZhipuAIModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error { - url := fmt.Sprintf("%s/chat/completions", z.BaseURL) - - // Build request body with streaming enabled - reqBody := map[string]interface{}{ - "model": modelName, - "messages": []map[string]string{ - {"role": "user", "content": *message}, - }, - "stream": true, - "temperature": 1, - } - - // Add generation config if provided - if genConf != nil { - if maxTokens, ok := genConf["max_tokens"]; ok { - reqBody["max_tokens"] = maxTokens - } - if temperature, ok := genConf["temperature"]; ok { - reqBody["temperature"] = temperature - } - if topP, ok := genConf["top_p"]; ok { - reqBody["top_p"] = topP - } - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) - - resp, err := z.httpClient.Do(req) - if err != nil { - return fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) - } - - // SSE parsing: read line by line - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - line := scanner.Text() - logger.Info(line) - - // SSE data line starts with "data:" - if !strings.HasPrefix(line, "data:") { - continue - } - - // Extract JSON after "data:" - data := strings.TrimSpace(line[5:]) - - // [DONE] marks the end of stream - if data == "[DONE]" { - break - } - - // Parse the JSON event - var event map[string]interface{} - if err := json.Unmarshal([]byte(data), &event); err != nil { - continue - } - - choices, ok := event["choices"].([]interface{}) - if !ok || len(choices) == 0 { - continue - } - - firstChoice, ok := choices[0].(map[string]interface{}) - if !ok { - continue - } - - delta, ok := firstChoice["delta"].(map[string]interface{}) - if !ok { - continue - } - - content, ok := delta["content"].(string) - if ok && content != "" { - resultChan <- content - } - - finishReason, ok := firstChoice["finish_reason"].(string) - if ok && finishReason != "" { - break - } - } - - // Send [DONE] marker for OpenAI compatibility - resultChan <- "[DONE]" - - return scanner.Err() -} - // ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel) -func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error { - url := fmt.Sprintf("%s/chat/completions", z.BaseURL) +func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + var region = "default" + if chatModelConfig.Region != nil { + region = *chatModelConfig.Region + } + + url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region]) // Build request body with streaming enabled reqBody := map[string]interface{}{ @@ -374,33 +160,33 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string "temperature": 1, } - if modelConfig != nil { - if modelConfig.Stream != nil { - reqBody["stream"] = *modelConfig.Stream + if chatModelConfig != nil { + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream } - if modelConfig.MaxTokens != nil { - reqBody["max_tokens"] = *modelConfig.MaxTokens + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens } - if modelConfig.Temperature != nil { - reqBody["temperature"] = *modelConfig.Temperature + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature } - if modelConfig.DoSample != nil { - reqBody["do_sample"] = *modelConfig.DoSample + if chatModelConfig.DoSample != nil { + reqBody["do_sample"] = *chatModelConfig.DoSample } - if modelConfig.TopP != nil { - reqBody["top_p"] = *modelConfig.TopP + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP } - if modelConfig.Stop != nil { - reqBody["stop"] = *modelConfig.Stop + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop } - if modelConfig.Reasoning != nil { - if *modelConfig.Reasoning { + if chatModelConfig.Reasoning != nil { + if *chatModelConfig.Reasoning { reqBody["thinking"] = map[string]interface{}{ "type": "enabled", } @@ -506,8 +292,13 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string } // EncodeToEmbedding encodes a list of texts into embeddings -func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) { - url := fmt.Sprintf("%s/embedding", z.BaseURL) +func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) { + var region = "default" + if embeddingConfig.Region != nil { + region = *embeddingConfig.Region + } + + url := fmt.Sprintf("%s/embedding", z.BaseURL[region]) embeddings := make([][]float64, len(texts)) diff --git a/internal/entity/tenant.go b/internal/entity/tenant.go index 34a56f0640..0865ab29a0 100644 --- a/internal/entity/tenant.go +++ b/internal/entity/tenant.go @@ -34,6 +34,8 @@ type Tenant struct { TTSID *string `gorm:"column:tts_id;size:256;index" json:"tts_id,omitempty"` TenantTTSID *int64 `gorm:"column:tenant_tts_id;index" json:"tenant_tts_id,omitempty"` ParserIDs string `gorm:"column:parser_ids;size:256;not null;index" json:"parser_ids"` + OCRID string `gorm:"column:ocr_id;size:256;not null" json:"ocr_id"` + TenantOCRID *int64 `gorm:"column:tenant_ocr_id" json:"tenant_ocr_id,omitempty"` Credit int64 `gorm:"column:credit;default:512;index" json:"credit"` Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"` BaseModel diff --git a/internal/entity/tenant_model_instance.go b/internal/entity/tenant_model_instance.go index 0a0a9f5149..8a2ffaa6be 100644 --- a/internal/entity/tenant_model_instance.go +++ b/internal/entity/tenant_model_instance.go @@ -23,6 +23,7 @@ type TenantModelInstance struct { ProviderID string `gorm:"column:provider_id;size:32;not null;uniqueIndex:idx_api_key_provider_id" json:"provider_id"` APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex:idx_api_key_provider_id" json:"api_key"` Status string `gorm:"column:status;size:32;default:'active'" json:"status"` + Extra string `gorm:"column:extra;size:512;default:'active'" json:"extra"` BaseModel } diff --git a/internal/handler/providers.go b/internal/handler/providers.go index 5c9f4fdc08..bb4b7a6be3 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -192,7 +192,7 @@ func (h *ProviderHandler) ListModels(c *gin.Context) { }) return } - models, err := dao.GetModelProviderManager().ListModels(providerName) + providerModels, err := dao.GetModelProviderManager().ListModels(providerName) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": common.CodeNotFound, @@ -203,7 +203,7 @@ func (h *ProviderHandler) ListModels(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "code": 0, "message": "success", - "data": models, + "data": providerModels, }) } @@ -274,7 +274,7 @@ func (h *ProviderHandler) CreateProviderInstance(c *gin.Context) { userID := c.GetString("user_id") - _, err := h.modelProviderService.CreateProviderInstance(providerName, req.InstanceName, req.APIKey, userID) + _, err := h.modelProviderService.CreateProviderInstance(providerName, req.InstanceName, req.APIKey, userID, "default") if err != nil { c.JSON(http.StatusOK, gin.H{ "code": common.CodeServerError, @@ -458,7 +458,7 @@ func (h *ProviderHandler) ListInstanceModels(c *gin.Context) { }) return } - models, err := h.modelProviderService.ListInstanceModels(providerName, instanceName, c.GetString("user_id")) + modelInstances, err := h.modelProviderService.ListInstanceModels(providerName, instanceName, c.GetString("user_id")) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": common.CodeNotFound, @@ -469,7 +469,7 @@ func (h *ProviderHandler) ListInstanceModels(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "code": 0, "message": "success", - "data": models, + "data": modelInstances, }) } @@ -618,6 +618,7 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { MaxTokens: nil, Temperature: nil, TopP: nil, + Region: nil, } // Stream response using sender function (best performance, no channel) @@ -629,8 +630,19 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { return } + chatConfig := models.ChatConfig{ + Reasoning: &req.Reasoning, + Stream: &req.Stream, + Stop: &[]string{}, + DoSample: nil, + MaxTokens: nil, + Temperature: nil, + TopP: nil, + Region: nil, + } + // Non-stream response - response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message) + response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message, &chatConfig) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": errorCode, diff --git a/internal/handler/tenant.go b/internal/handler/tenant.go index b01515af63..90fcde4580 100644 --- a/internal/handler/tenant.go +++ b/internal/handler/tenant.go @@ -76,9 +76,9 @@ func (h *TenantHandler) GetModels(c *gin.Context) { } type SetModelRequest struct { - ModelProvider string `json:"model_provider" binding:"required"` - ModelInstance string `json:"model_instance" binding:"required"` - ModelName string `json:"model_name" binding:"required"` + ModelProvider string `json:"model_provider"` + ModelInstance string `json:"model_instance"` + ModelName string `json:"model_name"` ModelType string `json:"model_type" binding:"required"` } diff --git a/internal/service/model_service.go b/internal/service/model_service.go index b3d0685180..7f96db1778 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -18,6 +18,7 @@ package service import ( "context" + "encoding/json" "errors" "fmt" "net/http" @@ -228,7 +229,7 @@ func (m *ModelProviderService) DeleteModelProvider(providerName, userID string) return common.CodeSuccess, nil } -func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, userID string) (common.ErrorCode, error) { +func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, userID, region string) (common.ErrorCode, error) { // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") if err != nil { @@ -252,6 +253,15 @@ func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName return common.CodeServerError, errors.New("fail to get UUID") } + extra := make(map[string]string) + extra["region"] = region + // convert extra to string + extraByte, err := json.Marshal(extra) + if err != nil { + return common.CodeServerError, errors.New("fail to marshal extra") + } + extraStr := string(extraByte) + now := time.Now().Unix() nowDate := time.Now().Truncate(time.Second) tenantModelProvider := &entity.TenantModelInstance{ @@ -259,7 +269,8 @@ func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName InstanceName: instanceName, ProviderID: provider.ID, APIKey: apiKey, - Status: "active", + Status: "enable", + Extra: extraStr, } tenantModelProvider.CreateTime = &now tenantModelProvider.UpdateTime = &now @@ -301,12 +312,20 @@ func (m *ModelProviderService) ListProviderInstances(providerName, userID string var result []map[string]interface{} for _, instance := range instances { + // convert instance.Extra (json string) to map + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + result = append(result, map[string]interface{}{ "id": instance.ID, "instanceName": instance.InstanceName, "providerID": instance.ProviderID, "apiKey": instance.APIKey, "status": instance.Status, + "region": extra["region"], }) } @@ -338,11 +357,19 @@ func (m *ModelProviderService) ShowProviderInstance(providerName, instanceName, return nil, common.CodeServerError, err } + // convert instance.Extra (json string) to map + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + result := map[string]interface{}{ "id": instance.ID, "instanceName": instance.InstanceName, "providerID": instance.ProviderID, "status": instance.Status, + "region": extra["region"], } return result, common.CodeSuccess, nil @@ -504,7 +531,7 @@ func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, mod return common.CodeSuccess, nil } -func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string) (*string, common.ErrorCode, error) { +func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig) (*string, common.ErrorCode, error) { // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") @@ -541,8 +568,17 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) } + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + modelConfig.Region = ®ion + var response string - response, err = providerInfo.ModelDriver.Chat(&modelName, &instance.APIKey, &message, nil) + response, err = providerInfo.ModelDriver.Chat(&modelName, &instance.APIKey, &message, modelConfig) if err != nil { return nil, common.CodeServerError, err } @@ -553,77 +589,6 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName return nil, common.CodeServerError, errors.New("model is disabled") } -// ChatToModelStream -func (m *ModelProviderService) ChatToModelStream(providerName, instanceName, modelName, userID, message string) (<-chan string, <-chan error, common.ErrorCode, error) { - streamChan := make(chan string) - errChan := make(chan error, 1) - - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") - if err != nil { - close(streamChan) - close(errChan) - return streamChan, errChan, common.CodeServerError, err - } - - if len(tenants) == 0 { - close(streamChan) - close(errChan) - return streamChan, errChan, common.CodeNotFound, errors.New("user has no tenants") - } - - tenantID := tenants[0].TenantID - - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) - if err != nil { - close(streamChan) - close(errChan) - return streamChan, errChan, common.CodeServerError, err - } - - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - close(streamChan) - close(errChan) - return streamChan, errChan, common.CodeServerError, err - } - - _, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - close(streamChan) - close(errChan) - return streamChan, errChan, common.CodeNotFound, errors.New("provider not found") - } - - _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - close(streamChan) - close(errChan) - return streamChan, errChan, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) - } - - // Async call stream interface using channel for better performance - go func() { - defer close(streamChan) - defer close(errChan) - - err := providerInfo.ModelDriver.ChatStreamlyWithChannel(&modelName, &instance.APIKey, &message, nil, streamChan) - if err != nil { - errChan <- err - } - }() - - return streamChan, errChan, common.CodeSuccess, nil - } - - close(streamChan) - close(errChan) - return streamChan, errChan, common.CodeServerError, errors.New("model is disabled") -} - // ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel) func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) common.ErrorCode { // Get tenant ID from user @@ -661,6 +626,15 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc return common.CodeNotFound } + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return common.CodeServerError + } + + region := extra["region"] + modelConfig.Region = ®ion + // Direct call with sender function err := providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &instance.APIKey, &message, modelConfig, sender) if err != nil { diff --git a/internal/service/tenant.go b/internal/service/tenant.go index 5b8a2d33a5..e994d08c14 100644 --- a/internal/service/tenant.go +++ b/internal/service/tenant.go @@ -303,31 +303,6 @@ type ModelItem struct { type DefaultModelResponse struct { Models []ModelItem `json:"models,omitempty"` - //TenantID string `json:"tenant_id"` - //ChatModelProvider *string `json:"chat_model_provider"` - //ChatModelInstance *string `json:"chat_model_instance"` - //ChatModelName *string `json:"chat_model_name"` - //ChatModelEnable bool `json:"chat_model_enable"` - //EmbeddingModelProvider *string `json:"embedding_model_provider"` - //EmbeddingModelInstance *string `json:"embedding_model_instance"` - //EmbeddingModelName *string `json:"embedding_model_name"` - //EmbeddingModelEnable bool `json:"embedding_model_enable"` - //RerankModelProvider *string `json:"rerank_model_provider"` - //RerankModelInstance *string `json:"rerank_model_instance"` - //RerankModelName *string `json:"rerank_model_name"` - //RerankModelEnable bool `json:"rerank_model_enable"` - //ASRModelProvider *string `json:"asr_model_provider"` - //ASRModelInstance *string `json:"asr_model_instance"` - //ASRModelName *string `json:"asr_model_name"` - //ASREnable bool `json:"asr_enable"` - //Image2TextModelProvider *string `json:"image2text_model_provider"` - //Image2TextModelInstance *string `json:"image2text_model_instance"` - //Image2TextModelName *string `json:"image2text_model_name"` - //Image2TextModelEnable bool `json:"image2text_model_enable"` - //TTSModelProvider *string `json:"tts_model_provider"` - //TTSModelInstance *string `json:"tts_model_instance"` - //TTSModelName *string `json:"tts_model_name"` - //TTSModelEnable bool `json:"tts_model_enable"` } func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, modelType string) (*string, *string, *string, bool, error) { @@ -351,6 +326,12 @@ func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, model return nil, nil, nil, false, fmt.Errorf("invalid model string: %s", defaultModel) } + if modelType == "ocr" { + if *providerName == "infiniflow" && *instanceName == "default" && *modelName == "deepdoc" { + return providerName, instanceName, modelName, true, nil + } + } + // Check if the provider and instance exists modelProvider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, *providerName) if err != nil { @@ -406,7 +387,7 @@ func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, err ModelProvider: defaultChatModelProvider, ModelInstance: defaultChatModelInstance, ModelName: defaultChatModelName, - ModelType: "llm", + ModelType: "chat", Enable: defaultChatModelEnable, }) } @@ -444,17 +425,28 @@ func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, err }) } - defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "image2text") + defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "vision") if err == nil { result = append(result, ModelItem{ ModelProvider: defaultImage2TextModelProvider, ModelInstance: defaultImage2TextModelInstance, ModelName: defaultImage2TextModelName, - ModelType: "image2text", + ModelType: "vision", Enable: defaultImage2TextModelEnable, }) } + defaultOCRModelProvider, defaultOCRModelInstance, defaultOCRModelName, defaultOCRModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.OCRID, "ocr") + if err == nil { + result = append(result, ModelItem{ + ModelProvider: defaultOCRModelProvider, + ModelInstance: defaultOCRModelInstance, + ModelName: defaultOCRModelName, + ModelType: "ocr", + Enable: defaultOCRModelEnable, + }) + } + if ownedTenant.TTSID == nil { return result, nil } @@ -518,11 +510,7 @@ func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInsta } ownedTenant := tenantInfos[0] - err = s.checkModelAvailable(ownedTenant.TenantID, modelProvider, modelInstance, modelName, modelType) - if err != nil { - return err - } - + var defaultModel string var modelTypeID string if modelType == "chat" { modelTypeID = "llm_id" @@ -536,17 +524,31 @@ func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInsta if modelType == "asr" { modelTypeID = "asr_id" } - if modelType == "image2text" { + if modelType == "vision" { modelTypeID = "img2txt_id" } if modelType == "tts" { modelTypeID = "tts_id" } + if modelType == "ocr" { + modelTypeID = "ocr_id" + } if modelTypeID == "" { return fmt.Errorf("model type %s is invalid", modelType) } - defaultModel := fmt.Sprintf("%s@%s@%s", modelName, modelInstance, modelProvider) + if modelProvider == "" && modelInstance == "" && modelName == "" { + defaultModel = "" + } else if modelProvider != "" && modelInstance != "" && modelName != "" { + err = s.checkModelAvailable(ownedTenant.TenantID, modelProvider, modelInstance, modelName, modelType) + if err != nil { + return err + } + defaultModel = fmt.Sprintf("%s@%s@%s", modelName, modelInstance, modelProvider) + } else { + return fmt.Errorf("model provider, instance and name must be specified together") + } + err = s.tenantDAO.Update(ownedTenant.TenantID, map[string]interface{}{ modelTypeID: defaultModel, })