diff --git a/internal/cli/admin_parser.go b/internal/cli/admin_parser.go index 08654a898c..2b1a44bacb 100644 --- a/internal/cli/admin_parser.go +++ b/internal/cli/admin_parser.go @@ -18,6 +18,7 @@ package cli import ( "fmt" + "ragflow/internal/common" "strings" ) @@ -1301,14 +1302,20 @@ func (p *Parser) parseAdminSetDefault() (*Command, error) { } p.nextToken() - compositeModelName, err := p.parseQuotedString() + modelNameOrID, err := p.parseQuotedString() if err != nil { return nil, err } cmd := NewCommand("set_default_model") cmd.Params["model_type"] = modelType - cmd.Params["composite_model_name"] = compositeModelName + if common.IsCompositeModelName(modelNameOrID) { + cmd.Params["composite_model_name"] = modelNameOrID + } else if common.IsUUID(modelNameOrID) { + cmd.Params["model_id"] = modelNameOrID + } else { + return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID) + } p.nextToken() // Semicolon is optional for UNSET TOKEN diff --git a/internal/cli/common_command.go b/internal/cli/common_command.go index 9b190126a0..f523913098 100644 --- a/internal/cli/common_command.go +++ b/internal/cli/common_command.go @@ -23,6 +23,7 @@ import ( "fmt" "net/http" "os" + "ragflow/internal/common" "strings" "time" @@ -464,13 +465,11 @@ func (c *CLI) SetDefaultModel(cmd *Command) (ResponseIf, error) { } var providerName, instanceName, modelName string - names := strings.Split(compositeModelName, "/") - if len(names) != 3 { - return nil, fmt.Errorf("model name must be in format 'provider/instance/model'") + var err error + modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName) + if err != nil { + return nil, err } - providerName = names[0] - instanceName = names[1] - modelName = names[2] payload := map[string]interface{}{ "model_type": modelType, @@ -480,7 +479,6 @@ func (c *CLI) SetDefaultModel(cmd *Command) (ResponseIf, error) { } var resp *Response - var err error switch c.Config.CLIMode { case AdminMode: resp, err = c.AdminServerClient.Request("PATCH", "/admin/models", "web", nil, payload) @@ -560,7 +558,7 @@ func (c *CLI) ListDefaultModels(cmd *Command) (ResponseIf, error) { case AdminMode: resp, err = c.AdminServerClient.Request("GET", "/admin/models", "web", nil, nil) case APIMode: - resp, err = c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].Request("GET", "/models", "web", nil, nil) + resp, err = c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].Request("GET", "/models/default", "web", nil, nil) default: return nil, fmt.Errorf("invalid server type") } @@ -573,7 +571,7 @@ func (c *CLI) ListDefaultModels(cmd *Command) (ResponseIf, error) { return nil, fmt.Errorf("failed to list default models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) } - var result CommonResponse + var result ModelsResponse if err = json.Unmarshal(resp.Body, &result); err != nil { return nil, fmt.Errorf("failed to list default models: invalid JSON (%w)", err) } diff --git a/internal/cli/http_client.go b/internal/cli/http_client.go index 9bb63f46a3..9f270ef176 100644 --- a/internal/cli/http_client.go +++ b/internal/cli/http_client.go @@ -374,6 +374,7 @@ type CurrentModel struct { Provider string Instance string Model string + ModelID string } // httpClientAdapter adapts HTTPClient to ce.HTTPClientInterface diff --git a/internal/cli/response.go b/internal/cli/response.go index 4f8654a2c4..c2b3efbfa2 100644 --- a/internal/cli/response.go +++ b/internal/cli/response.go @@ -57,6 +57,36 @@ func (r *CommonResponse) PrintOut() { } } +type ModelsResponse struct { + Code int `json:"code"` + Data map[string][]map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 + OutputFormat OutputFormat +} + +func (r *ModelsResponse) Type() string { + return "models" +} + +func (r *ModelsResponse) TimeCost() float64 { + return r.Duration +} + +func (r *ModelsResponse) SetOutputFormat(format OutputFormat) { + r.OutputFormat = format +} + +func (r *ModelsResponse) PrintOut() { + if r.Code == 0 { + models := r.Data["models"] + PrintTableSimpleByFormat(models, r.OutputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + type CommonDataResponse struct { Code int `json:"code"` Data map[string]interface{} `json:"data"` diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index 9f58f7eca2..1f6cc1ac9a 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -28,6 +28,7 @@ import ( "os" "os/exec" "path/filepath" + "ragflow/internal/common" "ragflow/internal/ingestion" "ragflow/internal/ingestion/parser" "ragflow/internal/utility" @@ -1632,23 +1633,34 @@ func (c *CLI) ChatToModel(cmd *Command) (ResponseIf, error) { } var providerName, instanceName, modelName string + var err error - // Check if composite_model_name is provided in command - if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { - names := strings.Split(compositeModelName, "@") - if len(names) != 3 { - return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + compositeModelName, ok := cmd.Params["composite_model_name"].(string) + if ok { + modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName) + if err != nil { + return nil, err } - providerName = names[2] - instanceName = names[1] - modelName = names[0] - } else if c.CurrentModel != nil { + } + + modelID, ok := cmd.Params["model_id"].(string) + if !ok { + modelID = "" + } + + if modelID == "" && compositeModelName == "" { + if c.CurrentModel == nil { + return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first") + } + // Use current model if set - providerName = c.CurrentModel.Provider - instanceName = c.CurrentModel.Instance - modelName = c.CurrentModel.Model - } else { - return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + if c.CurrentModel.ModelID != "" { + modelID = c.CurrentModel.ModelID + } else { + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } } formattedMessages := []map[string]interface{}{} @@ -1773,12 +1785,16 @@ func (c *CLI) ChatToModel(cmd *Command) (ResponseIf, error) { url := "/chat/completions" payload := map[string]interface{}{ - "provider_name": providerName, - "instance_name": instanceName, - "model_name": modelName, - "messages": formattedMessages, - "stream": stream, - "thinking": thinking, + "messages": formattedMessages, + "stream": stream, + "thinking": thinking, + } + if modelID == "" { + payload["provider_name"] = providerName + payload["instance_name"] = instanceName + payload["model_name"] = modelName + } else { + payload["model_id"] = modelID } if thinking { @@ -1864,12 +1880,12 @@ func (c *CLI) ChatToModel(cmd *Command) (ResponseIf, error) { } if resp.StatusCode != 200 { - return nil, fmt.Errorf("failed to list instance models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + return nil, fmt.Errorf("failed to chat model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) } var result NonStreamResponse if err = json.Unmarshal(resp.Body, &result); err != nil { - return nil, fmt.Errorf("failed to list instance models: invalid JSON (%w)", err) + return nil, fmt.Errorf("failed to chat model: invalid JSON (%w)", err) } if result.Code != 0 { @@ -1889,23 +1905,35 @@ func (c *CLI) EmbedUserText(cmd *Command) (ResponseIf, error) { } var providerName, instanceName, modelName string + var err error // Check if composite_model_name is provided in command - if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { - names := strings.Split(compositeModelName, "@") - if len(names) != 3 { - return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + compositeModelName, ok := cmd.Params["composite_model_name"].(string) + if ok { + modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName) + if err != nil { + return nil, err } - providerName = names[2] - instanceName = names[1] - modelName = names[0] - } else if c.CurrentModel != nil { + } + + modelID, ok := cmd.Params["model_id"].(string) + if !ok { + modelID = "" + } + + if modelID == "" && compositeModelName == "" { + if c.CurrentModel == nil { + return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first") + } + // Use current model if set - providerName = c.CurrentModel.Provider - instanceName = c.CurrentModel.Instance - modelName = c.CurrentModel.Model - } else { - return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + if c.CurrentModel.ModelID != "" { + modelID = c.CurrentModel.ModelID + } else { + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } } texts, ok := cmd.Params["texts"].([]string) @@ -1919,11 +1947,15 @@ func (c *CLI) EmbedUserText(cmd *Command) (ResponseIf, error) { } payload := map[string]interface{}{ - "provider_name": providerName, - "instance_name": instanceName, - "model_name": modelName, - "texts": texts, - "dimension": dimension, + "texts": texts, + "dimension": dimension, + } + if modelID == "" { + payload["provider_name"] = providerName + payload["instance_name"] = instanceName + payload["model_name"] = modelName + } else { + payload["model_id"] = modelID } url := "/embeddings" @@ -1956,23 +1988,35 @@ func (c *CLI) RerankUserDocument(cmd *Command) (ResponseIf, error) { } var providerName, instanceName, modelName string + var err error // Check if composite_model_name is provided in command - if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { - names := strings.Split(compositeModelName, "@") - if len(names) != 3 { - return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + compositeModelName, ok := cmd.Params["composite_model_name"].(string) + if ok { + modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName) + if err != nil { + return nil, err } - providerName = names[2] - instanceName = names[1] - modelName = names[0] - } else if c.CurrentModel != nil { + } + + modelID, ok := cmd.Params["model_id"].(string) + if !ok { + modelID = "" + } + + if modelID == "" && compositeModelName == "" { + if c.CurrentModel == nil { + return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first") + } + // Use current model if set - providerName = c.CurrentModel.Provider - instanceName = c.CurrentModel.Instance - modelName = c.CurrentModel.Model - } else { - return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + if c.CurrentModel.ModelID != "" { + modelID = c.CurrentModel.ModelID + } else { + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } } query, ok := cmd.Params["query"].(string) @@ -1991,12 +2035,16 @@ func (c *CLI) RerankUserDocument(cmd *Command) (ResponseIf, error) { } payload := map[string]interface{}{ - "provider_name": providerName, - "instance_name": instanceName, - "model_name": modelName, - "query": query, - "documents": documents, - "top_n": topN, + "query": query, + "documents": documents, + "top_n": topN, + } + if modelID == "" { + payload["provider_name"] = providerName + payload["instance_name"] = instanceName + payload["model_name"] = modelName + } else { + payload["model_id"] = modelID } url := "/rerank" @@ -2029,23 +2077,35 @@ func (c *CLI) TTSUserCommand(cmd *Command) (ResponseIf, error) { } var providerName, instanceName, modelName string + var err error // Check if composite_model_name is provided in command - if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { - names := strings.Split(compositeModelName, "@") - if len(names) != 3 { - return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + compositeModelName, ok := cmd.Params["composite_model_name"].(string) + if ok { + modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName) + if err != nil { + return nil, err } - providerName = names[2] - instanceName = names[1] - modelName = names[0] - } else if c.CurrentModel != nil { + } + + modelID, ok := cmd.Params["model_id"].(string) + if !ok { + modelID = "" + } + + if modelID == "" && compositeModelName == "" { + if c.CurrentModel == nil { + return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first") + } + // Use current model if set - providerName = c.CurrentModel.Provider - instanceName = c.CurrentModel.Instance - modelName = c.CurrentModel.Model - } else { - return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + if c.CurrentModel.ModelID != "" { + modelID = c.CurrentModel.ModelID + } else { + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } } text, ok := cmd.Params["text"].(string) @@ -2059,10 +2119,14 @@ func (c *CLI) TTSUserCommand(cmd *Command) (ResponseIf, error) { //} payload := map[string]interface{}{ - "provider_name": providerName, - "instance_name": instanceName, - "model_name": modelName, - "text": text, + "text": text, + } + if modelID == "" { + payload["provider_name"] = providerName + payload["instance_name"] = instanceName + payload["model_name"] = modelName + } else { + payload["model_id"] = modelID } ttsConfigPayload := make(map[string]interface{}) @@ -2221,23 +2285,35 @@ func (c *CLI) ASRUserCommand(cmd *Command) (ResponseIf, error) { } var providerName, instanceName, modelName string + var err error // Check if composite_model_name is provided in command - if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { - names := strings.Split(compositeModelName, "@") - if len(names) != 3 { - return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + compositeModelName, ok := cmd.Params["composite_model_name"].(string) + if ok { + modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName) + if err != nil { + return nil, err } - providerName = names[2] - instanceName = names[1] - modelName = names[0] - } else if c.CurrentModel != nil { + } + + modelID, ok := cmd.Params["model_id"].(string) + if !ok { + modelID = "" + } + + if modelID == "" && compositeModelName == "" { + if c.CurrentModel == nil { + return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first") + } + // Use current model if set - providerName = c.CurrentModel.Provider - instanceName = c.CurrentModel.Instance - modelName = c.CurrentModel.Model - } else { - return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + if c.CurrentModel.ModelID != "" { + modelID = c.CurrentModel.ModelID + } else { + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } } audioFile, ok := cmd.Params["audio_file"].(string) @@ -2246,10 +2322,15 @@ func (c *CLI) ASRUserCommand(cmd *Command) (ResponseIf, error) { } payload := map[string]interface{}{ - "provider_name": providerName, - "instance_name": instanceName, - "model_name": modelName, - "file": audioFile, + "file": audioFile, + } + + if modelID == "" { + payload["provider_name"] = providerName + payload["instance_name"] = instanceName + payload["model_name"] = modelName + } else { + payload["model_id"] = modelID } asrConfigPayload := make(map[string]interface{}) @@ -2308,28 +2389,38 @@ func (c *CLI) OCRUserCommand(cmd *Command) (ResponseIf, error) { } var providerName, instanceName, modelName string + var err error // Check if composite_model_name is provided in command - if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { - names := strings.Split(compositeModelName, "@") - if len(names) != 3 { - return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + compositeModelName, ok := cmd.Params["composite_model_name"].(string) + if ok { + modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName) + if err != nil { + return nil, err } - providerName = names[2] - instanceName = names[1] - modelName = names[0] - } else if c.CurrentModel != nil { - // Use current model if set - providerName = c.CurrentModel.Provider - instanceName = c.CurrentModel.Instance - modelName = c.CurrentModel.Model - } else { - return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") } + modelID, ok := cmd.Params["model_id"].(string) + if !ok { + modelID = "" + } + + if modelID == "" && compositeModelName == "" { + if c.CurrentModel == nil { + return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first") + } + + // Use current model if set + if c.CurrentModel.ModelID != "" { + modelID = c.CurrentModel.ModelID + } else { + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } + } var filename string var fileURL string - var ok bool var fileContent []byte filename, ok = cmd.Params["file"].(string) @@ -2347,10 +2438,14 @@ func (c *CLI) OCRUserCommand(cmd *Command) (ResponseIf, error) { } } - payload := map[string]interface{}{ - "provider_name": providerName, - "instance_name": instanceName, - "model_name": modelName, + payload := map[string]interface{}{} + + if modelID == "" { + payload["provider_name"] = providerName + payload["instance_name"] = instanceName + payload["model_name"] = modelName + } else { + payload["model_id"] = modelID } if fileContent != nil { @@ -2390,28 +2485,39 @@ func (c *CLI) ParseFileUserCommand(cmd *Command) (ResponseIf, error) { } var providerName, instanceName, modelName string + var err error // Check if composite_model_name is provided in command - if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { - names := strings.Split(compositeModelName, "@") - if len(names) != 3 { - return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + compositeModelName, ok := cmd.Params["composite_model_name"].(string) + if ok { + modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName) + if err != nil { + return nil, err } - providerName = names[2] - instanceName = names[1] - modelName = names[0] - } else if c.CurrentModel != nil { + } + + modelID, ok := cmd.Params["model_id"].(string) + if !ok { + modelID = "" + } + + if modelID == "" && compositeModelName == "" { + if c.CurrentModel == nil { + return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first") + } + // Use current model if set - providerName = c.CurrentModel.Provider - instanceName = c.CurrentModel.Instance - modelName = c.CurrentModel.Model - } else { - return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + if c.CurrentModel.ModelID != "" { + modelID = c.CurrentModel.ModelID + } else { + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } } var filename string var fileURL string - var ok bool var fileContent []byte filename, ok = cmd.Params["file"].(string) @@ -2434,10 +2540,14 @@ func (c *CLI) ParseFileUserCommand(cmd *Command) (ResponseIf, error) { } } - payload := map[string]interface{}{ - "provider_name": providerName, - "instance_name": instanceName, - "model_name": modelName, + payload := map[string]interface{}{} + + if modelID == "" { + payload["provider_name"] = providerName + payload["instance_name"] = instanceName + payload["model_name"] = modelName + } else { + payload["model_id"] = modelID } if fileContent != nil { @@ -2664,20 +2774,30 @@ func (c *CLI) UseModel(cmd *Command) (ResponseIf, error) { return nil, fmt.Errorf("this command is only allowed in USER mode") } + var modelName, instanceName, providerName string + var err error compositeModelName, ok := cmd.Params["composite_model_name"].(string) - if !ok || compositeModelName == "" { - return nil, fmt.Errorf("model identifier not provided") + if ok { + modelName, instanceName, providerName, err = common.ExtractCompositeName(compositeModelName) + if err != nil { + return nil, err + } } - names := strings.Split(compositeModelName, "@") - if len(names) != 3 { - return nil, fmt.Errorf("model identifier must be in format 'model@instance@provider'") + modelID, ok := cmd.Params["model_id"].(string) + if !ok { + modelID = "" + } + + if modelID == "" && compositeModelName == "" { + return nil, fmt.Errorf("model name or ID not provided and no current model set. Use 'use model' command first") } c.CurrentModel = &CurrentModel{ - Provider: names[2], - Instance: names[1], - Model: names[0], + Provider: providerName, + Instance: instanceName, + Model: modelName, + ModelID: modelID, } var result SimpleResponse diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index 8efdd19fbb..c9c8d29d35 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "ragflow/internal/common" "strconv" "strings" ) @@ -2285,7 +2286,7 @@ func (p *Parser) parseSetVariable() (*Command, error) { func (p *Parser) parseSetDefault() (*Command, error) { p.nextToken() // consume DEFAULT - var modelType, compositeModelName string + var modelType, modelNameOrID string var err error switch p.curToken.Type { @@ -2313,12 +2314,12 @@ func (p *Parser) parseSetDefault() (*Command, error) { } p.nextToken() // pass MODEL - // Format: 'provider/instance/model' or just 'message' + // Format: 'model@instance@provider' or just 'message' if p.curToken.Type != TokenQuotedString { - return nil, fmt.Errorf("expected quoted string with format provider/instance/model") + return nil, fmt.Errorf("expected quoted string with format model@instance@provider") } - compositeModelName, err = p.parseQuotedString() + modelNameOrID, err = p.parseQuotedString() if err != nil { return nil, err } @@ -2326,7 +2327,14 @@ func (p *Parser) parseSetDefault() (*Command, error) { cmd := NewCommand("set_default_model") cmd.Params["model_type"] = modelType - cmd.Params["composite_model_name"] = compositeModelName + + if common.IsCompositeModelName(modelNameOrID) { + cmd.Params["composite_model_name"] = modelNameOrID + } else if common.IsUUID(modelNameOrID) { + cmd.Params["model_id"] = modelNameOrID + } else { + return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID) + } p.nextToken() // Semicolon is optional for UNSET TOKEN @@ -3024,7 +3032,7 @@ func (p *Parser) parseChatCommand() (*Command, error) { p.nextToken() // consume CHAT var err error - var compositeModelName string = "" + var modelNameOrID string = "" var messages []string var images []string var videos []string @@ -3038,11 +3046,11 @@ optionsLoop: switch p.curToken.Type { case TokenWith: p.nextToken() - // 'model@instance@provider' - if compositeModelName != "" { - return nil, fmt.Errorf("model name is already set") + // 'model@instance@provider' or model ID + if modelNameOrID != "" { + return nil, fmt.Errorf("model name or ID is already set") } - compositeModelName, err = p.parseQuotedString() + modelNameOrID, err = p.parseQuotedString() if err != nil { return nil, err } @@ -3182,7 +3190,13 @@ optionsLoop: } cmd := NewCommand("chat_to_model") - cmd.Params["composite_model_name"] = compositeModelName + if common.IsCompositeModelName(modelNameOrID) { + cmd.Params["composite_model_name"] = modelNameOrID + } else if common.IsUUID(modelNameOrID) { + cmd.Params["model_id"] = modelNameOrID + } else { + return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID) + } cmd.Params["messages"] = messages cmd.Params["images"] = images cmd.Params["videos"] = videos @@ -3276,7 +3290,7 @@ textLoop: } p.nextToken() // consume WITH - compositeModelName, err := p.parseQuotedString() + modelNameOrID, err := p.parseQuotedString() if err != nil { return nil, err } @@ -3306,7 +3320,13 @@ textLoop: } cmd := NewCommand("embed_user_text") - cmd.Params["composite_model_name"] = compositeModelName + if common.IsCompositeModelName(modelNameOrID) { + cmd.Params["composite_model_name"] = modelNameOrID + } else if common.IsUUID(modelNameOrID) { + cmd.Params["model_id"] = modelNameOrID + } else { + return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID) + } cmd.Params["texts"] = texts if dimension > 0 { cmd.Params["dimension"] = dimension @@ -3356,7 +3376,7 @@ documentLoop: } p.nextToken() // consume WITH - compositeModelName, err := p.parseQuotedString() + modelNameOrID, err := p.parseQuotedString() if err != nil { return nil, err } @@ -3374,7 +3394,13 @@ documentLoop: p.nextToken() cmd := NewCommand("rarank_user_document") - cmd.Params["composite_model_name"] = compositeModelName + if common.IsCompositeModelName(modelNameOrID) { + cmd.Params["composite_model_name"] = modelNameOrID + } else if common.IsUUID(modelNameOrID) { + cmd.Params["model_id"] = modelNameOrID + } else { + return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID) + } cmd.Params["query"] = query cmd.Params["documents"] = documents cmd.Params["top_n"] = topN @@ -3389,7 +3415,7 @@ func (p *Parser) parseASRCommand() (*Command, error) { } p.nextToken() // consume WITH - compositeModelName, err := p.parseQuotedString() + modelNameOrID, err := p.parseQuotedString() if err != nil { return nil, err } @@ -3407,7 +3433,13 @@ func (p *Parser) parseASRCommand() (*Command, error) { p.nextToken() cmd := NewCommand("asr_user_command") - cmd.Params["composite_model_name"] = compositeModelName + if common.IsCompositeModelName(modelNameOrID) { + cmd.Params["composite_model_name"] = modelNameOrID + } else if common.IsUUID(modelNameOrID) { + cmd.Params["model_id"] = modelNameOrID + } else { + return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID) + } cmd.Params["audio_file"] = audioFile for p.curToken.Type != TokenEOF && p.curToken.Type != TokenSemicolon { @@ -3445,9 +3477,18 @@ func (p *Parser) parseTTSCommand() (*Command, error) { if p.curToken.Type != TokenQuotedString && p.curToken.Type != TokenIdentifier { return nil, fmt.Errorf("expect model name after 'with'") } - cmd.Params["composite_model_name"] = strings.Trim(p.curToken.Value, "\"'") + + modelNameOrID := strings.Trim(p.curToken.Value, "\"'") p.nextToken() + if common.IsCompositeModelName(modelNameOrID) { + cmd.Params["composite_model_name"] = modelNameOrID + } else if common.IsUUID(modelNameOrID) { + cmd.Params["model_id"] = modelNameOrID + } else { + return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID) + } + if p.curToken.Type != TokenText { return nil, fmt.Errorf("expect 'text' parameter") } @@ -3509,7 +3550,7 @@ func (p *Parser) parseOCRCommand() (*Command, error) { } p.nextToken() // consume WITH - compositeModelName, err := p.parseQuotedString() + modelNameOrID, err := p.parseQuotedString() if err != nil { return nil, err } @@ -3540,7 +3581,18 @@ func (p *Parser) parseOCRCommand() (*Command, error) { return nil, fmt.Errorf("expected FILE or URL") } - cmd.Params["composite_model_name"] = compositeModelName + if common.IsCompositeModelName(modelNameOrID) { + cmd.Params["composite_model_name"] = modelNameOrID + } else if common.IsUUID(modelNameOrID) { + cmd.Params["model_id"] = modelNameOrID + } else { + return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID) + } + + // Semicolon is optional + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } return cmd, nil } @@ -3548,7 +3600,7 @@ func (p *Parser) parseOCRCommand() (*Command, error) { func (p *Parser) parseModelParseCommand() (*Command, error) { p.nextToken() // consume WITH - compositeModelName, err := p.parseQuotedString() + modelNameOrID, err := p.parseQuotedString() if err != nil { return nil, err } @@ -3579,7 +3631,18 @@ func (p *Parser) parseModelParseCommand() (*Command, error) { return nil, fmt.Errorf("expected FILE or URL") } - cmd.Params["composite_model_name"] = compositeModelName + if common.IsCompositeModelName(modelNameOrID) { + cmd.Params["composite_model_name"] = modelNameOrID + } else if common.IsUUID(modelNameOrID) { + cmd.Params["model_id"] = modelNameOrID + } else { + return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID) + } + + // Semicolon is optional + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } return cmd, nil } @@ -3711,7 +3774,7 @@ func (p *Parser) parseUseCommand() (*Command, error) { func (p *Parser) parseUseModel() (*Command, error) { p.nextToken() // consume MODEL - compositeModelName, err := p.parseQuotedString() + modelNameOrID, err := p.parseQuotedString() if err != nil { return nil, fmt.Errorf("expected model identifier in format 'model@instance@provider': %w", err) } @@ -3723,7 +3786,19 @@ func (p *Parser) parseUseModel() (*Command, error) { } cmd := NewCommand("use_model") - cmd.Params["composite_model_name"] = compositeModelName + + if common.IsCompositeModelName(modelNameOrID) { + cmd.Params["composite_model_name"] = modelNameOrID + } else if common.IsUUID(modelNameOrID) { + cmd.Params["model_id"] = modelNameOrID + } else { + return nil, fmt.Errorf("invalid format of model name or ID: %s", modelNameOrID) + } + + // Semicolon is optional + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } return cmd, nil } diff --git a/internal/common/format.go b/internal/common/format.go index e7ccc1eaef..b1c1835654 100644 --- a/internal/common/format.go +++ b/internal/common/format.go @@ -16,7 +16,11 @@ package common -import "fmt" +import ( + "fmt" + "regexp" + "strings" +) // PtrString formats a pointer value as a string for debug/log output. // Returns "" for nil pointers. @@ -26,3 +30,45 @@ func PtrString[T any](p *T) string { } return fmt.Sprintf("%v", *p) } + +// composite model name format: model_name@instance_name@provider_name +func IsCompositeModelName(modelName string) bool { + parts := strings.Split(modelName, "@") + if len(parts) != 3 { + return false + } + for _, p := range parts { + if p == "" { + return false + } + } + return true +} + +func IsUUID(uuid string) bool { + // only lower case letters and numbers, length is 32 + if len(uuid) != 32 { + return false + } + uuidRegex := regexp.MustCompile(`^[a-z0-9]+$`) + if uuidRegex.MatchString(uuid) { + return true + } + return false +} + +// ExtractCompositeName splits a composite model name into three parts. +// Returns (modelName, instanceName, providerName, true) on success, +// or ("", "", "", false) if the name is not a valid composite name. +func ExtractCompositeName(modelName string) (string, string, string, error) { + parts := strings.Split(modelName, "@") + if len(parts) != 3 { + return "", "", "", fmt.Errorf("invalid model name format") + } + for _, p := range parts { + if p == "" { + return "", "", "", fmt.Errorf("invalid model name format") + } + } + return parts[0], parts[1], parts[2], nil +} diff --git a/internal/handler/providers.go b/internal/handler/providers.go index cf896629a6..80a83afbee 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -897,6 +897,7 @@ type ChatToModelRequest struct { ProviderName *string `json:"provider_name"` InstanceName *string `json:"instance_name"` ModelName *string `json:"model_name"` + ModelID *string `json:"model_id"` Messages []map[string]interface{} `json:"messages"` Stream bool `json:"stream"` Thinking bool `json:"thinking"` @@ -915,28 +916,38 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { return } - if req.ProviderName == nil || *req.ProviderName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Provider name is required", - }) - return - } + if req.ModelID == nil { + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } - if req.InstanceName == nil || *req.InstanceName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Instance name is required", - }) - return - } + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } - if req.ModelName == nil || *req.ModelName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Model name is required", - }) - return + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + } else { + if *req.ModelID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model ID is empty", + }) + return + } } userID := c.GetString("user_id") @@ -1005,8 +1016,8 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { messages[i] = models.Message{Role: role, Content: content} } - // Stream response using sender function (best performance, no channel) - errorCode, err := h.modelProviderService.ChatToModelStreamWithSender(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, messages, &apiConfig, &chatConfig, sender) + // Stream response using sender function (the best performance, no channel) + errorCode, err := h.modelProviderService.ChatToModelStreamWithSender(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, messages, &apiConfig, &chatConfig, sender) if errorCode != common.CodeSuccess { c.SSEvent("error", err.Error()) @@ -1026,7 +1037,7 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { content := msg["content"] messages[i] = models.Message{Role: role, Content: content} } - response, errorCode, err = h.modelProviderService.ChatToModelWithMessages(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, messages, &apiConfig, &chatConfig) + response, errorCode, err = h.modelProviderService.ChatToModelWithMessages(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, messages, &apiConfig, &chatConfig) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -1047,6 +1058,7 @@ type EmbedTextRequest struct { ProviderName *string `json:"provider_name"` InstanceName *string `json:"instance_name"` ModelName *string `json:"model_name"` + ModelID *string `json:"model_id"` Texts []string `json:"texts"` Dimension int `json:"dimension"` } @@ -1062,28 +1074,38 @@ func (h *ProviderHandler) EmbedText(c *gin.Context) { return } - if req.ProviderName == nil || *req.ProviderName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Provider name is required", - }) - return - } + if req.ModelID == nil { + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } - if req.InstanceName == nil || *req.InstanceName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Instance name is required", - }) - return - } + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } - if req.ModelName == nil || *req.ModelName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Model name is required", - }) - return + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + } else { + if *req.ModelID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model ID is empty", + }) + return + } } userID := c.GetString("user_id") @@ -1102,8 +1124,7 @@ func (h *ProviderHandler) EmbedText(c *gin.Context) { var errorCode common.ErrorCode var err error - response, errorCode, err = h.modelProviderService.EmbedText(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Texts, &apiConfig, &embeddingConfig) - + response, errorCode, err = h.modelProviderService.EmbedText(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Texts, &apiConfig, &embeddingConfig) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": errorCode, @@ -1123,6 +1144,7 @@ type RerankDocumentRequest struct { ProviderName *string `json:"provider_name"` InstanceName *string `json:"instance_name"` ModelName *string `json:"model_name"` + ModelID *string `json:"model_id"` Query string `json:"query"` Documents []string `json:"documents"` TopN int `json:"top_n"` @@ -1139,28 +1161,38 @@ func (h *ProviderHandler) RerankDocument(c *gin.Context) { return } - if req.ProviderName == nil || *req.ProviderName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Provider name is required", - }) - return - } + if req.ModelID == nil { + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } - if req.InstanceName == nil || *req.InstanceName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Instance name is required", - }) - return - } + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } - if req.ModelName == nil || *req.ModelName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Model name is required", - }) - return + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + } else { + if *req.ModelID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model ID is empty", + }) + return + } } userID := c.GetString("user_id") @@ -1179,8 +1211,7 @@ func (h *ProviderHandler) RerankDocument(c *gin.Context) { var errorCode common.ErrorCode var err error - response, errorCode, err = h.modelProviderService.RerankDocument(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Query, req.Documents, &apiConfig, &rerankConfig) - + response, errorCode, err = h.modelProviderService.RerankDocument(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Query, req.Documents, &apiConfig, &rerankConfig) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": errorCode, @@ -1200,6 +1231,7 @@ type TranscribeAudioRequest struct { ProviderName *string `json:"provider_name"` InstanceName *string `json:"instance_name"` ModelName *string `json:"model_name"` + ModelID *string `json:"model_id"` File *string `json:"file"` Language []string `json:"language"` Prompt int `json:"prompt"` @@ -1218,28 +1250,38 @@ func (h *ProviderHandler) TranscribeAudio(c *gin.Context) { return } - if req.ProviderName == nil || *req.ProviderName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Provider name is required", - }) - return - } + if req.ModelID == nil { + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } - if req.InstanceName == nil || *req.InstanceName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Instance name is required", - }) - return - } + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } - if req.ModelName == nil || *req.ModelName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Model name is required", - }) - return + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + } else { + if *req.ModelID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model ID is empty", + }) + return + } } userID := c.GetString("user_id") @@ -1287,9 +1329,8 @@ func (h *ProviderHandler) TranscribeAudio(c *gin.Context) { return nil } - // Stream response using sender function (best performance, no channel) - errorCode, err := h.modelProviderService.TranscribeAudioStream(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.File, &apiConfig, &asrConfig, sender) - + // Stream response using sender function ( the best performance, no channel) + errorCode, err := h.modelProviderService.TranscribeAudioStream(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.File, &apiConfig, &asrConfig, sender) if errorCode != common.CodeSuccess { c.SSEvent("error", err.Error()) } @@ -1301,8 +1342,7 @@ func (h *ProviderHandler) TranscribeAudio(c *gin.Context) { var errorCode common.ErrorCode var err error - response, errorCode, err = h.modelProviderService.TranscribeAudio(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.File, &apiConfig, &asrConfig) - + response, errorCode, err = h.modelProviderService.TranscribeAudio(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.File, &apiConfig, &asrConfig) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": errorCode, @@ -1322,6 +1362,7 @@ type AudioSpeechRequest struct { ProviderName *string `json:"provider_name"` InstanceName *string `json:"instance_name"` ModelName *string `json:"model_name"` + ModelID *string `json:"model_id"` Text *string `json:"text"` Stream bool `json:"stream"` TTSConfig *models.TTSConfig `json:"tts_config"` @@ -1338,28 +1379,38 @@ func (h *ProviderHandler) AudioSpeech(c *gin.Context) { return } - if req.ProviderName == nil || *req.ProviderName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Provider name is required", - }) - return - } + if req.ModelID == nil { + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } - if req.InstanceName == nil || *req.InstanceName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Instance name is required", - }) - return - } + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } - if req.ModelName == nil || *req.ModelName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Model name is required", - }) - return + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + } else { + if *req.ModelID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model ID is empty", + }) + return + } } userID := c.GetString("user_id") @@ -1407,9 +1458,8 @@ func (h *ProviderHandler) AudioSpeech(c *gin.Context) { return nil } - // Stream response using sender function (best performance, no channel) - errorCode, err := h.modelProviderService.AudioSpeechStream(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Text, &apiConfig, &ttsConfig, sender) - + // Stream response using sender function ( the best performance, no channel) + errorCode, err := h.modelProviderService.AudioSpeechStream(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Text, &apiConfig, &ttsConfig, sender) if errorCode != common.CodeSuccess { c.SSEvent("error", err.Error()) } @@ -1421,8 +1471,7 @@ func (h *ProviderHandler) AudioSpeech(c *gin.Context) { var errorCode common.ErrorCode var err error - response, errorCode, err = h.modelProviderService.AudioSpeech(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Text, &apiConfig, &ttsConfig) - + response, errorCode, err = h.modelProviderService.AudioSpeech(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Text, &apiConfig, &ttsConfig) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": errorCode, @@ -1442,6 +1491,7 @@ type OCRFileRequest struct { ProviderName *string `json:"provider_name"` InstanceName *string `json:"instance_name"` ModelName *string `json:"model_name"` + ModelID *string `json:"model_id"` Content []byte `json:"content"` URL *string `json:"url"` } @@ -1457,28 +1507,38 @@ func (h *ProviderHandler) OCRFile(c *gin.Context) { return } - if req.ProviderName == nil || *req.ProviderName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Provider name is required", - }) - return - } + if req.ModelID == nil { + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } - if req.InstanceName == nil || *req.InstanceName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Instance name is required", - }) - return - } + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } - if req.ModelName == nil || *req.ModelName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Model name is required", - }) - return + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + } else { + if *req.ModelID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model ID is empty", + }) + return + } } userID := c.GetString("user_id") @@ -1495,8 +1555,7 @@ func (h *ProviderHandler) OCRFile(c *gin.Context) { var errorCode common.ErrorCode var err error - response, errorCode, err = h.modelProviderService.OCRFile(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Content, req.URL, &apiConfig, &OCRConfig) - + response, errorCode, err = h.modelProviderService.OCRFile(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Content, req.URL, &apiConfig, &OCRConfig) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": errorCode, @@ -1516,6 +1575,7 @@ type ParseFileRequest struct { ProviderName *string `json:"provider_name"` InstanceName *string `json:"instance_name"` ModelName *string `json:"model_name"` + ModelID *string `json:"model_id"` Content []byte `json:"content"` URL *string `json:"url"` } @@ -1531,28 +1591,38 @@ func (h *ProviderHandler) ParseFile(c *gin.Context) { return } - if req.ProviderName == nil || *req.ProviderName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Provider name is required", - }) - return - } + if req.ModelID == nil { + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } - if req.InstanceName == nil || *req.InstanceName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Instance name is required", - }) - return - } + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } - if req.ModelName == nil || *req.ModelName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Model name is required", - }) - return + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + } else { + if *req.ModelID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model ID is empty", + }) + return + } } userID := c.GetString("user_id") @@ -1569,8 +1639,7 @@ func (h *ProviderHandler) ParseFile(c *gin.Context) { var errorCode common.ErrorCode var err error - response, errorCode, err = h.modelProviderService.ParseFile(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Content, req.URL, &apiConfig, &parseFileConfig) - + response, errorCode, err = h.modelProviderService.ParseFile(req.ProviderName, req.InstanceName, req.ModelName, req.ModelID, userID, req.Content, req.URL, &apiConfig, &parseFileConfig) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": errorCode, diff --git a/internal/router/router.go b/internal/router/router.go index 80d8e3e2f3..d587edb4a1 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -399,6 +399,9 @@ func (r *Router) Setup(engine *gin.Engine) { // provider handler because that's where the // modelProviderService is wired. model.GET("/", r.providerHandler.ListTenantAddedModels) + + // TODO: list default models? + //model.GET("/", r.tenantHandler.GetModels) model.PATCH("/", r.tenantHandler.SetModels) // Tenant default-model selection (used by the agent // page's useFetchDefaultModels hook). Mirrors the diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 362a9089f5..38812a570f 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -44,9 +44,9 @@ func parseModelName(compositeName string) (modelName, instanceName, providerName return parts[0], "default", parts[1], nil } else if len(parts) == 1 { return parts[0], "", "", fmt.Errorf("provider name missing in model name: %s", compositeName) - } else { - return "", "", "", fmt.Errorf("invalid model name format: %s", compositeName) } + + return "", "", "", fmt.Errorf("invalid model name format: %s", compositeName) } func newModelDriverForBaseURL(driver modelModule.ModelDriver, providerName, region, baseURL string) (modelModule.ModelDriver, error) { @@ -386,7 +386,7 @@ func (m *ModelProviderService) ListProviderInstances(providerName, userID string // crash on a freshly created tenant. result := make([]map[string]interface{}, 0, len(instances)) for _, instance := range instances { - // convert instance.Extra (json string) to map + // convert instance.Extra (JSON string) to map var extra map[string]string err = json.Unmarshal([]byte(instance.Extra), &extra) if err != nil { @@ -440,7 +440,7 @@ func (m *ModelProviderService) ShowProviderInstance(providerName, instanceName, return nil, common.CodeServerError, err } - // convert instance.Extra (json string) to map + // convert instance.Extra (JSON string) to map var extra map[string]string err = json.Unmarshal([]byte(instance.Extra), &extra) if err != nil { @@ -748,14 +748,14 @@ func (m *ModelProviderService) ShowTask(providerName, instanceName, taskID, user // contract for GET /api/v1/models — the endpoint that // web/src/hooks/use-llm-request.tsx → useFetchAllAddedModels consumes. // -// Per the Python algorithm, for each (provider × instance) we cross- -// reference the factory catalog (internal/entity/models/model.go +// Per the Python algorithm, for each (provider × instance) we cross-reference the factory catalog (internal/entity/models/model.go // ProviderManager.Providers) with the per-tenant overrides in // tenant_model: -// active_model_types = tenant_model rows with status='active' -// inactive_model_types = tenant_model rows with status='inactive' -// factory_model_types = provider.Models[i].ModelTypes -// model_types = (factory ∪ active) \ inactive +// +// active_model_types = tenant_model rows with status='active' +// inactive_model_types = tenant_model rows with status='inactive' +// factory_model_types = provider.Models[i].ModelTypes +// model_types = (factory ∪ active) \ inactive // // The Go port never WRITES to tenant_model, so in practice every model // from the factory catalog is treated as added unless explicitly @@ -1144,216 +1144,261 @@ func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, mod return common.CodeSuccess, nil } -// ChatToModelWithMessages sends messages to the model with messages array -func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceName, modelName, userID string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig) (*modelModule.ChatResponse, common.ErrorCode, error) { +type ModelInstanceAndProviderInfo struct { + ProviderEntity *entity.TenantModelProvider + ProviderInfo *modelModule.Provider + InstanceEntity *entity.TenantModelInstance + ModelEntity *entity.TenantModel + ModelInfo *modelModule.Model + APIConfig *modelModule.APIConfig +} + +func (m *ModelProviderService) getModelInstanceAndProviderByName(providerName, instanceName, modelName *string, userID string, apiConfig *modelModule.APIConfig) (*ModelInstanceAndProviderInfo, error) { + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, err + } + + if len(tenants) == 0 { + return nil, err + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + providerEntity, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, *providerName) + if err != nil { + return nil, err + } + + instanceEntity, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(providerEntity.ID, *instanceName) + if err != nil { + return nil, err + } + + modelEntity, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(providerEntity.ID, instanceEntity.ID, *modelName) + if err != nil { + // Not found model + modelEntity = nil + } + + providerInfo := dao.GetModelProviderManager().FindProvider(*providerName) + if providerInfo == nil { + return nil, errors.New("provider not found") + } + + modelInfo, err := dao.GetModelProviderManager().GetModelByName(*providerName, *modelName) + if err != nil { + return nil, errors.New(fmt.Sprintf("provider %s model %s not found", *providerName, *modelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instanceEntity.Extra), &extra) + if err != nil { + return nil, err + } + if apiConfig == nil { apiConfig = &modelModule.APIConfig{} } + + region := extra["region"] + baseURL := extra["base_url"] + + apiConfig.ApiKey = &instanceEntity.APIKey + apiConfig.BaseURL = &baseURL + apiConfig.Region = ®ion + + var result = &ModelInstanceAndProviderInfo{ + ProviderEntity: providerEntity, + ProviderInfo: providerInfo, + InstanceEntity: instanceEntity, + ModelEntity: modelEntity, + ModelInfo: modelInfo, + APIConfig: apiConfig, + } + + return result, nil +} + +func (m *ModelProviderService) getModelInstanceAndProviderByID(modelID *string, userID string, apiConfig *modelModule.APIConfig) (*ModelInstanceAndProviderInfo, error) { + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, err + } + + if len(tenants) == 0 { + return nil, err + } + + tenantID := tenants[0].TenantID + + modelEntity, err := m.modelDAO.GetByID(*modelID) + if err != nil { + return nil, err + } + + instanceEntity, err := m.modelInstanceDAO.GetByID(modelEntity.InstanceID) + if err != nil { + return nil, err + } + + providerEntity, err := m.modelProviderDAO.GetByID(instanceEntity.ProviderID) + if err != nil { + return nil, err + } + + if providerEntity.TenantID != tenantID { + return nil, errors.New("provider not found") + } + + providerInfo := dao.GetModelProviderManager().FindProvider(providerEntity.ProviderName) + if providerInfo == nil { + return nil, errors.New("provider not found") + } + + modelInfo, err := dao.GetModelProviderManager().GetModelByName(providerEntity.ProviderName, modelEntity.ModelName) + if err != nil { + return nil, errors.New(fmt.Sprintf("provider %s model %s not found", providerEntity.ProviderName, modelEntity.ModelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instanceEntity.Extra), &extra) + if err != nil { + return nil, err + } + + if apiConfig == nil { + apiConfig = &modelModule.APIConfig{} + } + + region := extra["region"] + baseURL := extra["base_url"] + + apiConfig.ApiKey = &instanceEntity.APIKey + apiConfig.BaseURL = &baseURL + apiConfig.Region = ®ion + + var result = &ModelInstanceAndProviderInfo{ + ProviderEntity: providerEntity, + ProviderInfo: providerInfo, + InstanceEntity: instanceEntity, + ModelEntity: modelEntity, + ModelInfo: modelInfo, + APIConfig: apiConfig, + } + + return result, nil +} + +// ChatToModelWithMessages sends messages to the model with messages array +func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceName, modelName, modelID *string, userID string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig) (*modelModule.ChatResponse, common.ErrorCode, error) { + + var err error + var info *ModelInstanceAndProviderInfo + + if modelID != nil { + info, err = m.getModelInstanceAndProviderByID(modelID, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } + } else { + info, err = m.getModelInstanceAndProviderByName(providerName, instanceName, modelName, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } + } + if modelConfig == nil { modelConfig = &modelModule.ChatConfig{} } + modelConfig.ModelClass = info.ModelInfo.Class - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + var response *modelModule.ChatResponse + var modelDriver modelModule.ModelDriver + + if info.ModelEntity == nil { + if !info.ModelInfo.ModelTypeMap["chat"] && !info.ModelInfo.ModelTypeMap["vision"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", *modelName, *providerName)) + } + modelDriver = info.ProviderInfo.ModelDriver + } else { + // model entity exists + if info.ModelEntity.Status == "active" { + if info.ModelEntity.ModelType != "chat" && info.ModelEntity.ModelType != "vision" { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", *modelName, *providerName)) + } + + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + if err != nil { + return nil, common.CodeServerError, err + } + } else { + return nil, common.CodeServerError, errors.New("model is inactive") + } + } + + response, err = modelDriver.ChatWithMessages(*modelName, messages, info.APIConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } - - if len(tenants) == 0 { - return nil, common.CodeNotFound, errors.New("user has no tenants") + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") } - tenantID := tenants[0].TenantID - - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) - if err != nil { - return nil, common.CodeServerError, err - } - - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return nil, common.CodeServerError, err - } - - modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var model *modelModule.Model = nil - model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) - } - - if !model.ModelTypeMap["chat"] && !model.ModelTypeMap["vision"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", modelName, providerName)) - } - - modelConfig.ModelClass = model.Class - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - var response *modelModule.ChatResponse - response, err = providerInfo.ModelDriver.ChatWithMessages(modelName, messages, apiConfig, modelConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil { - return nil, common.CodeServerError, errors.New("empty chat response") - } - - return response, common.CodeSuccess, nil - } - - if modelInfo.Status == "active" { - if modelInfo.ModelType != "chat" && modelInfo.ModelType != "vision" { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", modelName, providerName)) - } - // For local deployed models - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - modelConfig.ModelClass = &providerInfo.Class - - newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) - if err != nil { - return nil, common.CodeServerError, err - } - - var response *modelModule.ChatResponse - response, err = newProviderInfo.ChatWithMessages(modelName, messages, apiConfig, modelConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil { - return nil, common.CodeServerError, errors.New("empty chat response") - } - - return response, common.CodeSuccess, nil - } - - return nil, common.CodeServerError, errors.New("model is disabled") + return response, common.CodeSuccess, nil } -// ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel) -func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) (common.ErrorCode, error) { - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") - if err != nil { - return common.CodeServerError, err - } +// ChatToModelStreamWithSender streams chat response directly via sender function ( the best performance, no channel) +func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, modelID *string, userID string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) (common.ErrorCode, error) { - if len(tenants) == 0 { - return common.CodeNotFound, errors.New("user has no tenants") - } + var err error + var info *ModelInstanceAndProviderInfo - tenantID := tenants[0].TenantID - - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) - if err != nil { - return common.CodeServerError, err - } - - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return common.CodeServerError, err - } - - modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { + if modelID != nil { + info, err = m.getModelInstanceAndProviderByID(modelID, userID, apiConfig) + if err != nil || info == nil { return common.CodeNotFound, err } - - var model *modelModule.Model = nil - model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { + } else { + info, err = m.getModelInstanceAndProviderByName(providerName, instanceName, modelName, userID, apiConfig) + if err != nil || info == nil { return common.CodeNotFound, err } - - if !model.ModelTypeMap["chat"] && !model.ModelTypeMap["vision"] { - return common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", modelName, providerName)) - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - err = providerInfo.ModelDriver.ChatStreamlyWithSender(modelName, messages, apiConfig, modelConfig, sender) - if err != nil { - return common.CodeServerError, err - } - - return common.CodeSuccess, nil } - if modelInfo.Status == "active" { - if modelInfo.ModelType != "chat" && modelInfo.ModelType != "vision" { - return common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", modelName, providerName)) - } - // For local deployed models - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return common.CodeNotFound, errors.New("provider not found") - } + if modelConfig == nil { + modelConfig = &modelModule.ChatConfig{} + } + modelConfig.ModelClass = info.ModelInfo.Class - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return common.CodeServerError, err + var modelDriver modelModule.ModelDriver + + if info.ModelEntity == nil { + modelDriver = info.ProviderInfo.ModelDriver + } else { + // model entity exists + if info.ModelEntity.Status == "active" { + if info.ModelEntity.ModelType != "chat" && info.ModelEntity.ModelType != "vision" { + return common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", *modelName, *providerName)) + } + + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + if err != nil { + return common.CodeServerError, err + } + } else { + return common.CodeServerError, errors.New("model is inactive") } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - modelConfig.ModelClass = &providerInfo.Class - - newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) - if err != nil { - return common.CodeServerError, err - } - - err = newProviderInfo.ChatStreamlyWithSender(modelName, messages, apiConfig, modelConfig, sender) - if err != nil { - return common.CodeServerError, err - } - return common.CodeSuccess, nil } - return common.CodeServerError, errors.New("model is disabled") + err = modelDriver.ChatStreamlyWithSender(*modelName, messages, apiConfig, modelConfig, sender) + if err != nil { + return common.CodeServerError, err + } + return common.CodeSuccess, nil } func validateEmbeddingDimension(model *modelModule.Model, requested int) error { @@ -1387,858 +1432,449 @@ func validateEmbeddingDimension(model *modelModule.Model, requested int) error { } // EmbedText sends texts to the embedding model -func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, userID string, texts []string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, common.ErrorCode, error) { - if apiConfig == nil { - apiConfig = &modelModule.APIConfig{} +func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, modelID *string, userID string, texts []string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, common.ErrorCode, error) { + + var err error + var info *ModelInstanceAndProviderInfo + + if modelID != nil { + info, err = m.getModelInstanceAndProviderByID(modelID, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } + } else { + info, err = m.getModelInstanceAndProviderByName(providerName, instanceName, modelName, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } } + if modelConfig == nil { modelConfig = &modelModule.EmbeddingConfig{} } - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + var modelDriver modelModule.ModelDriver + + if info.ModelEntity == nil { + if !info.ModelInfo.ModelTypeMap["embedding"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an embedding model", *modelName, *providerName)) + } + modelDriver = info.ProviderInfo.ModelDriver + } else { + // model entity exists + if info.ModelEntity.Status == "active" { + if info.ModelEntity.ModelType != "embedding" { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an embedding model", *modelName, *providerName)) + } + + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + if err != nil { + return nil, common.CodeServerError, err + } + } else { + return nil, common.CodeServerError, errors.New("model is inactive") + } + } + + if err = validateEmbeddingDimension(info.ModelInfo, modelConfig.Dimension); err != nil { + return nil, common.CodeBadRequest, err + } + + var response []modelModule.EmbeddingData + response, err = modelDriver.Embed(modelName, texts, apiConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } - - if len(tenants) == 0 { - return nil, common.CodeNotFound, errors.New("user has no tenants") + if response == nil || len(response) == 0 { + return nil, common.CodeServerError, errors.New("empty embed response") } - tenantID := tenants[0].TenantID - - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) - if err != nil { - return nil, common.CodeServerError, err - } - - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return nil, common.CodeServerError, err - } - - modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var model *modelModule.Model = nil - model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) - } - - if !model.ModelTypeMap["embedding"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not an embedding model", providerName, modelName)) - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - if err := validateEmbeddingDimension(model, modelConfig.Dimension); err != nil { - return nil, common.CodeBadRequest, err - } - - var response []modelModule.EmbeddingData - response, err = providerInfo.ModelDriver.Embed(&modelName, texts, apiConfig, modelConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil || len(response) == 0 { - return nil, common.CodeServerError, errors.New("empty embed response") - } - - return response, common.CodeSuccess, nil - } - - if modelInfo.Status == "active" { - if modelInfo.ModelType != "embedding" { - return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is an embedding model", modelName, providerName)) - } - // For local deployed models - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) - if err != nil { - return nil, common.CodeServerError, err - } - - modelSchema, _ := dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err := validateEmbeddingDimension(modelSchema, modelConfig.Dimension); err != nil { - return nil, common.CodeBadRequest, err - } - - var response []modelModule.EmbeddingData - response, err = newProviderInfo.Embed(&modelName, texts, apiConfig, modelConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil || len(response) == 0 { - return nil, common.CodeServerError, errors.New("empty embed response") - } - - return response, common.CodeSuccess, nil - } - - return nil, common.CodeServerError, errors.New("model is disabled") + return response, common.CodeSuccess, nil } // RerankDocument sends texts to the embedding model -func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelName, userID, query string, documents []string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.RerankConfig) (*modelModule.RerankResponse, common.ErrorCode, error) { - if apiConfig == nil { - apiConfig = &modelModule.APIConfig{} +func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelName, modelID *string, userID, query string, documents []string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.RerankConfig) (*modelModule.RerankResponse, common.ErrorCode, error) { + + var err error + var info *ModelInstanceAndProviderInfo + + if modelID != nil { + info, err = m.getModelInstanceAndProviderByID(modelID, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } + } else { + info, err = m.getModelInstanceAndProviderByName(providerName, instanceName, modelName, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } } + if modelConfig == nil { modelConfig = &modelModule.RerankConfig{} } - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + var modelDriver modelModule.ModelDriver + + if info.ModelEntity == nil { + if !info.ModelInfo.ModelTypeMap["rerank"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a rerank model", *modelName, *providerName)) + } + modelDriver = info.ProviderInfo.ModelDriver + } else { + // model entity exists + if info.ModelEntity.Status == "active" { + if info.ModelEntity.ModelType != "rerank" { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a rerank model", *modelName, *providerName)) + } + + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + if err != nil { + return nil, common.CodeServerError, err + } + } else { + return nil, common.CodeServerError, errors.New("model is inactive") + } + } + + var response *modelModule.RerankResponse + response, err = modelDriver.Rerank(modelName, query, documents, apiConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } - if len(tenants) == 0 { - return nil, common.CodeNotFound, errors.New("user has no tenants") - } - - tenantID := tenants[0].TenantID - - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) - if err != nil { - return nil, common.CodeServerError, err - } - - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return nil, common.CodeServerError, err - } - - modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var model *modelModule.Model = nil - model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) - } - - if !model.ModelTypeMap["rerank"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not a rerank model", providerName, modelName)) - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - var response *modelModule.RerankResponse - response, err = providerInfo.ModelDriver.Rerank(&modelName, query, documents, apiConfig, modelConfig) - if err != nil { - return nil, common.CodeServerError, err - } - - return response, common.CodeSuccess, nil - } - - if modelInfo.Status == "active" { - if modelInfo.ModelType != "rerank" { - return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is a rerank model", modelName, providerName)) - } - // For local deployed models - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) - if err != nil { - return nil, common.CodeServerError, err - } - - var response *modelModule.RerankResponse - response, err = newProviderInfo.Rerank(&modelName, query, documents, apiConfig, modelConfig) - if err != nil { - return nil, common.CodeServerError, err - } - - return response, common.CodeSuccess, nil - } - - return nil, common.CodeServerError, errors.New("model is disabled") + return response, common.CodeSuccess, nil } // TranscribeAudio transcribe audio file to text -func (m *ModelProviderService) TranscribeAudio(providerName, instanceName, modelName, userID string, audioFile *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig) (*modelModule.ASRResponse, common.ErrorCode, error) { - if apiConfig == nil { - apiConfig = &modelModule.APIConfig{} - } - if asrConfig == nil { - asrConfig = &modelModule.ASRConfig{} +func (m *ModelProviderService) TranscribeAudio(providerName, instanceName, modelName, modelID *string, userID string, audioFile *string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ASRConfig) (*modelModule.ASRResponse, common.ErrorCode, error) { + + var err error + var info *ModelInstanceAndProviderInfo + + if modelID != nil { + info, err = m.getModelInstanceAndProviderByID(modelID, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } + } else { + info, err = m.getModelInstanceAndProviderByName(providerName, instanceName, modelName, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } } - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if modelConfig == nil { + modelConfig = &modelModule.ASRConfig{} + } + + var modelDriver modelModule.ModelDriver + + if info.ModelEntity == nil { + if !info.ModelInfo.ModelTypeMap["asr"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an ASR model", *modelName, *providerName)) + } + modelDriver = info.ProviderInfo.ModelDriver + } else { + // model entity exists + if info.ModelEntity.Status == "active" { + if info.ModelEntity.ModelType != "asr" { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an ASR model", *modelName, *providerName)) + } + + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + if err != nil { + return nil, common.CodeServerError, err + } + } else { + return nil, common.CodeServerError, errors.New("model is inactive") + } + } + + var response *modelModule.ASRResponse + response, err = modelDriver.TranscribeAudio(modelName, audioFile, apiConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } - - if len(tenants) == 0 { - return nil, common.CodeNotFound, errors.New("user has no tenants") + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") } - tenantID := tenants[0].TenantID - - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) - if err != nil { - return nil, common.CodeServerError, err - } - - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return nil, common.CodeServerError, err - } - - modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var model *modelModule.Model = nil - model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) - } - - if !model.ModelTypeMap["asr"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not an ASR model", providerName, modelName)) - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - var response *modelModule.ASRResponse - response, err = providerInfo.ModelDriver.TranscribeAudio(&modelName, audioFile, apiConfig, asrConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil { - return nil, common.CodeServerError, errors.New("empty chat response") - } - - return response, common.CodeSuccess, nil - } - - if modelInfo.Status == "active" { - if modelInfo.ModelType != "asr" { - return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is an ASR model", modelName, providerName)) - } - // For local deployed models - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) - if err != nil { - return nil, common.CodeServerError, err - } - - var response *modelModule.ASRResponse - response, err = newProviderInfo.TranscribeAudio(&modelName, audioFile, apiConfig, asrConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil { - return nil, common.CodeServerError, errors.New("empty chat response") - } - - return response, common.CodeSuccess, nil - } - - return nil, common.CodeServerError, errors.New("model is disabled") + return response, common.CodeSuccess, nil } -// ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel) -func (m *ModelProviderService) TranscribeAudioStream(providerName, instanceName, modelName, userID string, audioFile *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig, sender func(*string, *string) error) (common.ErrorCode, error) { - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") - if err != nil { - return common.CodeServerError, err - } +// TranscribeAudioStream transcribe audio file to text stream directly via sender function ( the best performance, no channel) +func (m *ModelProviderService) TranscribeAudioStream(providerName, instanceName, modelName, modelID *string, userID string, audioFile *string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ASRConfig, sender func(*string, *string) error) (common.ErrorCode, error) { - if len(tenants) == 0 { - return common.CodeNotFound, errors.New("user has no tenants") - } + var err error + var info *ModelInstanceAndProviderInfo - tenantID := tenants[0].TenantID - - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) - if err != nil { - return common.CodeServerError, err - } - - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return common.CodeServerError, err - } - - modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { + if modelID != nil { + info, err = m.getModelInstanceAndProviderByID(modelID, userID, apiConfig) + if err != nil || info == nil { return common.CodeNotFound, err } - - var model *modelModule.Model = nil - model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { + } else { + info, err = m.getModelInstanceAndProviderByName(providerName, instanceName, modelName, userID, apiConfig) + if err != nil || info == nil { return common.CodeNotFound, err } - if !model.ModelTypeMap["asr"] { - return common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not an ASR model", providerName, modelName)) - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - err = providerInfo.ModelDriver.TranscribeAudioWithSender(&modelName, audioFile, apiConfig, asrConfig, sender) - if err != nil { - return common.CodeServerError, err - } - - return common.CodeSuccess, nil } - if modelInfo.Status == "active" { - if modelInfo.ModelType != "asr" { - return common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is an ASR model", modelName, providerName)) - } - // For local deployed models - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return common.CodeNotFound, errors.New("provider not found") - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) - if err != nil { - return common.CodeServerError, err - } - - err = newProviderInfo.TranscribeAudioWithSender(&modelName, audioFile, apiConfig, asrConfig, sender) - if err != nil { - return common.CodeServerError, err - } - return common.CodeSuccess, nil + if modelConfig == nil { + modelConfig = &modelModule.ASRConfig{} } - return common.CodeServerError, errors.New("model is disabled") -} + var modelDriver modelModule.ModelDriver -// TranscribeAudio transcribe audio file to text -func (m *ModelProviderService) AudioSpeech(providerName, instanceName, modelName, userID string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig) (*modelModule.TTSResponse, common.ErrorCode, error) { - if apiConfig == nil { - apiConfig = &modelModule.APIConfig{} - } - if ttsConfig == nil { - ttsConfig = &modelModule.TTSConfig{} + if info.ModelEntity == nil { + if !info.ModelInfo.ModelTypeMap["asr"] { + return common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an ASR model", *modelName, *providerName)) + } + modelDriver = info.ProviderInfo.ModelDriver + } else { + // model entity exists + if info.ModelEntity.Status == "active" { + if info.ModelEntity.ModelType != "asr" { + return common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an ASR model", *modelName, *providerName)) + } + + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + if err != nil { + return common.CodeServerError, err + } + } else { + return common.CodeServerError, errors.New("model is inactive") + } } - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") - if err != nil { - return nil, common.CodeServerError, err - } - - if len(tenants) == 0 { - return nil, common.CodeNotFound, errors.New("user has no tenants") - } - - tenantID := tenants[0].TenantID - - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) - if err != nil { - return nil, common.CodeServerError, err - } - - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return nil, common.CodeServerError, err - } - - modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var model *modelModule.Model = nil - model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) - } - - if !model.ModelTypeMap["tts"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not a TTS model", providerName, modelName)) - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - var response *modelModule.TTSResponse - response, err = providerInfo.ModelDriver.AudioSpeech(&modelName, audioContent, apiConfig, ttsConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil { - return nil, common.CodeServerError, errors.New("empty chat response") - } - - return response, common.CodeSuccess, nil - } - - if modelInfo.Status == "active" { - if modelInfo.ModelType != "tts" { - return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is a TTS model", modelName, providerName)) - } - // For local deployed models - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) - if err != nil { - return nil, common.CodeServerError, err - } - - var response *modelModule.TTSResponse - response, err = newProviderInfo.AudioSpeech(&modelName, audioContent, apiConfig, ttsConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil { - return nil, common.CodeServerError, errors.New("empty chat response") - } - - return response, common.CodeSuccess, nil - } - - return nil, common.CodeServerError, errors.New("model is disabled") -} - -func (m *ModelProviderService) AudioSpeechStream(providerName, instanceName, modelName, userID string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig, sender func(*string, *string) error) (common.ErrorCode, error) { - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + err = modelDriver.TranscribeAudioWithSender(modelName, audioFile, apiConfig, modelConfig, sender) if err != nil { return common.CodeServerError, err } - if len(tenants) == 0 { - return common.CodeNotFound, errors.New("user has no tenants") + return common.CodeSuccess, nil +} + +// AudioSpeech convert audio to speech +func (m *ModelProviderService) AudioSpeech(providerName, instanceName, modelName, modelID *string, userID string, audioContent *string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.TTSConfig) (*modelModule.TTSResponse, common.ErrorCode, error) { + + var err error + var info *ModelInstanceAndProviderInfo + + if modelID != nil { + info, err = m.getModelInstanceAndProviderByID(modelID, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } + } else { + info, err = m.getModelInstanceAndProviderByName(providerName, instanceName, modelName, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } } - tenantID := tenants[0].TenantID + if modelConfig == nil { + modelConfig = &modelModule.TTSConfig{} + } - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + var modelDriver modelModule.ModelDriver + + if info.ModelEntity == nil { + if !info.ModelInfo.ModelTypeMap["tts"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a TTS model", *modelName, *providerName)) + } + modelDriver = info.ProviderInfo.ModelDriver + } else { + // model entity exists + if info.ModelEntity.Status == "active" { + if info.ModelEntity.ModelType != "tts" { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a TTS model", *modelName, *providerName)) + } + + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + if err != nil { + return nil, common.CodeServerError, err + } + } else { + return nil, common.CodeServerError, errors.New("model is inactive") + } + } + + var response *modelModule.TTSResponse + response, err = modelDriver.AudioSpeech(modelName, audioContent, apiConfig, modelConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil +} + +func (m *ModelProviderService) AudioSpeechStream(providerName, instanceName, modelName, modelID *string, userID string, audioContent *string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.TTSConfig, sender func(*string, *string) error) (common.ErrorCode, error) { + + var err error + var info *ModelInstanceAndProviderInfo + + if modelID != nil { + info, err = m.getModelInstanceAndProviderByID(modelID, userID, apiConfig) + if err != nil || info == nil { + return common.CodeNotFound, err + } + } else { + info, err = m.getModelInstanceAndProviderByName(providerName, instanceName, modelName, userID, apiConfig) + if err != nil || info == nil { + return common.CodeNotFound, err + } + } + + if modelConfig == nil { + modelConfig = &modelModule.TTSConfig{} + } + + var modelDriver modelModule.ModelDriver + + if info.ModelEntity == nil { + if !info.ModelInfo.ModelTypeMap["tts"] { + return common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a TTS model", *modelName, *providerName)) + } + modelDriver = info.ProviderInfo.ModelDriver + } else { + // model entity exists + if info.ModelEntity.Status == "active" { + if info.ModelEntity.ModelType != "tts" { + return common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a TTS model", *modelName, *providerName)) + } + + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + if err != nil { + return common.CodeServerError, err + } + } else { + return common.CodeServerError, errors.New("model is inactive") + } + } + + err = modelDriver.AudioSpeechWithSender(modelName, audioContent, apiConfig, modelConfig, sender) if err != nil { return common.CodeServerError, err } - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return common.CodeServerError, err - } - - modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return common.CodeNotFound, err - } - - var model *modelModule.Model = nil - model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - return common.CodeNotFound, err - } - - if !model.ModelTypeMap["tts"] { - return common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not a TTS model", providerName, modelName)) - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - err = providerInfo.ModelDriver.AudioSpeechWithSender(&modelName, audioContent, apiConfig, ttsConfig, sender) - if err != nil { - return common.CodeServerError, err - } - - return common.CodeSuccess, nil - } - - if modelInfo.Status == "active" { - if modelInfo.ModelType != "tts" { - return common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is a TTS model", modelName, providerName)) - } - // For local deployed models - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return common.CodeNotFound, errors.New("provider not found") - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) - if err != nil { - return common.CodeServerError, err - } - - err = newProviderInfo.AudioSpeechWithSender(&modelName, audioContent, apiConfig, ttsConfig, sender) - if err != nil { - return common.CodeServerError, err - } - return common.CodeSuccess, nil - } - - return common.CodeServerError, errors.New("model is disabled") + return common.CodeSuccess, nil } -func (m *ModelProviderService) OCRFile(providerName, instanceName, modelName, userID string, content []byte, url *string, apiConfig *modelModule.APIConfig, ocrConfig *modelModule.OCRConfig) (*modelModule.OCRFileResponse, common.ErrorCode, error) { - if apiConfig == nil { - apiConfig = &modelModule.APIConfig{} - } - if ocrConfig == nil { - ocrConfig = &modelModule.OCRConfig{} +func (m *ModelProviderService) OCRFile(providerName, instanceName, modelName, modelID *string, userID string, content []byte, url *string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.OCRConfig) (*modelModule.OCRFileResponse, common.ErrorCode, error) { + + var err error + var info *ModelInstanceAndProviderInfo + + if modelID != nil { + info, err = m.getModelInstanceAndProviderByID(modelID, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } + } else { + info, err = m.getModelInstanceAndProviderByName(providerName, instanceName, modelName, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } } - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if modelConfig == nil { + modelConfig = &modelModule.OCRConfig{} + } + + var modelDriver modelModule.ModelDriver + + if info.ModelEntity == nil { + if !info.ModelInfo.ModelTypeMap["ocr"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an OCR model", *modelName, *providerName)) + } + modelDriver = info.ProviderInfo.ModelDriver + } else { + // model entity exists + if info.ModelEntity.Status == "active" { + if info.ModelEntity.ModelType != "ocr" { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an OCR model", *modelName, *providerName)) + } + + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + if err != nil { + return nil, common.CodeServerError, err + } + } else { + return nil, common.CodeServerError, errors.New("model is inactive") + } + } + + var response *modelModule.OCRFileResponse + response, err = modelDriver.OCRFile(modelName, content, url, apiConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } - - if len(tenants) == 0 { - return nil, common.CodeNotFound, errors.New("user has no tenants") + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") } - tenantID := tenants[0].TenantID - - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) - if err != nil { - return nil, common.CodeServerError, err - } - - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return nil, common.CodeServerError, err - } - - modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var model *modelModule.Model = nil - model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) - } - - if !model.ModelTypeMap["ocr"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not an OCR model", providerName, modelName)) - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - var response *modelModule.OCRFileResponse - response, err = providerInfo.ModelDriver.OCRFile(&modelName, content, url, apiConfig, ocrConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil { - return nil, common.CodeServerError, errors.New("empty chat response") - } - - return response, common.CodeSuccess, nil - } - - if modelInfo.Status == "active" { - if modelInfo.ModelType != "ocr" { - return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is an OCR model", modelName, providerName)) - } - // For local deployed models - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) - if err != nil { - return nil, common.CodeServerError, err - } - - var response *modelModule.OCRFileResponse - response, err = newProviderInfo.OCRFile(&modelName, content, url, apiConfig, ocrConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil { - return nil, common.CodeServerError, errors.New("empty chat response") - } - - return response, common.CodeSuccess, nil - } - - return nil, common.CodeServerError, errors.New("model is disabled") + return response, common.CodeSuccess, nil } -func (m *ModelProviderService) ParseFile(providerName, instanceName, modelName, userID string, content []byte, url *string, apiConfig *modelModule.APIConfig, parseFileConfig *modelModule.ParseFileConfig) (*modelModule.ParseFileResponse, common.ErrorCode, error) { - if apiConfig == nil { - apiConfig = &modelModule.APIConfig{} - } - if parseFileConfig == nil { - parseFileConfig = &modelModule.ParseFileConfig{} +func (m *ModelProviderService) ParseFile(providerName, instanceName, modelName, modelID *string, userID string, content []byte, url *string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ParseFileConfig) (*modelModule.ParseFileResponse, common.ErrorCode, error) { + + var err error + var info *ModelInstanceAndProviderInfo + + if modelID != nil { + info, err = m.getModelInstanceAndProviderByID(modelID, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } + } else { + info, err = m.getModelInstanceAndProviderByName(providerName, instanceName, modelName, userID, apiConfig) + if err != nil || info == nil { + return nil, common.CodeNotFound, err + } } - // Get tenant ID from user - tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if modelConfig == nil { + modelConfig = &modelModule.ParseFileConfig{} + } + + var modelDriver modelModule.ModelDriver + + if info.ModelEntity == nil { + if !info.ModelInfo.ModelTypeMap["doc_parse"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a ParseFile model", *modelName, *providerName)) + } + modelDriver = info.ProviderInfo.ModelDriver + } else { + // model entity exists + if info.ModelEntity.Status == "active" { + if info.ModelEntity.ModelType != "doc_parse" { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a ParseFile model", *modelName, *providerName)) + } + + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + if err != nil { + return nil, common.CodeServerError, err + } + } else { + return nil, common.CodeServerError, errors.New("model is inactive") + } + } + + var response *modelModule.ParseFileResponse + response, err = modelDriver.ParseFile(modelName, content, url, apiConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } - - if len(tenants) == 0 { - return nil, common.CodeNotFound, errors.New("user has no tenants") + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") } - tenantID := tenants[0].TenantID - - // Check if provider exists - provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) - if err != nil { - return nil, common.CodeServerError, err - } - - instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return nil, common.CodeServerError, err - } - - modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var model *modelModule.Model = nil - model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) - } - - if !model.ModelTypeMap["doc_parse"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not a Document Parse model", providerName, modelName)) - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - var response *modelModule.ParseFileResponse - response, err = providerInfo.ModelDriver.ParseFile(&modelName, content, url, apiConfig, parseFileConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil { - return nil, common.CodeServerError, errors.New("empty chat response") - } - - return response, common.CodeSuccess, nil - } - - if modelInfo.Status == "active" { - if modelInfo.ModelType != "doc_parse" { - return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is a Document Parse model", modelName, providerName)) - } - // For local deployed models - providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerInfo == nil { - return nil, common.CodeNotFound, errors.New("provider not found") - } - - var extra map[string]string - err = json.Unmarshal([]byte(instance.Extra), &extra) - if err != nil { - return nil, common.CodeServerError, err - } - - region := extra["region"] - apiConfig.Region = ®ion - apiConfig.ApiKey = &instance.APIKey - - newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) - if err != nil { - return nil, common.CodeServerError, err - } - - var response *modelModule.ParseFileResponse - response, err = newProviderInfo.ParseFile(&modelName, content, url, apiConfig, parseFileConfig) - if err != nil { - return nil, common.CodeServerError, err - } - if response == nil { - return nil, common.CodeServerError, errors.New("empty chat response") - } - - return response, common.CodeSuccess, nil - } - - return nil, common.CodeServerError, errors.New("model is disabled") + return response, common.CodeSuccess, nil } // GetEmbeddingModel returns an EmbeddingModel wrapper for the given tenant @@ -2472,11 +2108,12 @@ func (m *ModelProviderService) AddModel(request *AddModelRequest, userID string) return common.CodeSuccess, nil } -// modelName must be a composite name of the form "model@instance@provider" or +// GetModelConfigFromProviderInstance get model config from provider instance +// modelName@instance@provider or // "model@provider" — the provider is required and is looked up directly via // tenant_model_provider. For 2-part names the instance defaults to "default". // If the model is enrolled in tenant_model, that row is used (and INACTIVE rows -// raise). Otherwise the factory's LLM catalog is consulted, with +// raise). Otherwise, the factory's LLM catalog is consulted, with // region=intl + siliconflow redirected to the siliconflow_intl factory. func (m *ModelProviderService) GetModelConfigFromProviderInstance(tenantID string, modelType entity.ModelType, modelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) { common.Debug("GetModelConfigFromProviderInstance", @@ -2696,8 +2333,10 @@ func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName strin } providerInfo := dao.GetModelProviderManager().FindProvider(providerName) - if providerName != "Builtin" && providerInfo == nil { - return nil, "", nil, 0, fmt.Errorf("provider %s not found", providerName) + if providerInfo == nil { + if providerName != "Builtin" { + return nil, "", nil, 0, fmt.Errorf("model provider config not found: %s", providerName) + } } // Get model info to extract max_tokens @@ -2723,22 +2362,22 @@ func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName strin } apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: ®ion} return builtinDriver, modelName, apiConfig, maxTokens, nil - } else { - _, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(providerID, instance.ID, modelName) - if err != nil { - _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - return nil, "", nil, 0, fmt.Errorf("provider %s model %s not found", providerName, modelName) - } - } - apiKey = instance.APIKey } + _, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(providerID, instance.ID, modelName) + if err != nil { + _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return nil, "", nil, 0, fmt.Errorf("provider %s model %s not found", providerName, modelName) + } + } + apiKey = instance.APIKey + apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: ®ion} return providerInfo.ModelDriver, modelName, apiConfig, maxTokens, nil } -// getModelConfig returns the model driver, model name, API config, and max tokens for a model +// ListAllModels list all models func (m *ModelProviderService) ListAllModels(pageIndex, pageSize int) ([]map[string]interface{}, error) { models, err := dao.GetModelProviderManager().ListAllModels() if err != nil {