mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-02 16:55:42 +08:00
Add extra field to model instance (#14203)
### What problem does this PR solve? Now each model support region with different URL ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"name": "OpenAI",
|
||||
"tags": "LLM,TEXT EMBEDDING,TTS,TEXT RE-RANK,SPEECH2TEXT,MODERATION",
|
||||
"url": "https://api.openai.com/v1",
|
||||
"url": {
|
||||
"default": "https://api.openai.com/v1"
|
||||
},
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions"
|
||||
},
|
||||
@@ -10,8 +11,8 @@
|
||||
"name": "gpt-5.2-pro",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -19,8 +20,8 @@
|
||||
"name": "gpt-5.2",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -28,8 +29,8 @@
|
||||
"name": "gpt-5.1",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -37,8 +38,8 @@
|
||||
"name": "gpt-5.1-chat-latest",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -46,8 +47,8 @@
|
||||
"name": "gpt-5",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -55,8 +56,8 @@
|
||||
"name": "gpt-5-mini",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -64,8 +65,8 @@
|
||||
"name": "gpt-5-nano",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -73,8 +74,8 @@
|
||||
"name": "gpt-5-chat-latest",
|
||||
"max_tokens": 400000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -82,8 +83,8 @@
|
||||
"name": "gpt-4.1",
|
||||
"max_tokens": 1047576,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -91,8 +92,8 @@
|
||||
"name": "gpt-4.1-mini",
|
||||
"max_tokens": 1047576,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -100,8 +101,8 @@
|
||||
"name": "gpt-4.1-nano",
|
||||
"max_tokens": 1047576,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -109,7 +110,7 @@
|
||||
"name": "gpt-4.5-preview",
|
||||
"max_tokens": 128000,
|
||||
"model_types": [
|
||||
"llm"
|
||||
"chat"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -117,8 +118,8 @@
|
||||
"name": "o3",
|
||||
"max_tokens": 200000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -126,8 +127,8 @@
|
||||
"name": "o4-mini",
|
||||
"max_tokens": 200000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -135,8 +136,8 @@
|
||||
"name": "o4-mini-high",
|
||||
"max_tokens": 200000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -144,8 +145,8 @@
|
||||
"name": "gpt-4o-mini",
|
||||
"max_tokens": 128000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -153,8 +154,8 @@
|
||||
"name": "gpt-4o",
|
||||
"max_tokens": 128000,
|
||||
"model_types": [
|
||||
"llm",
|
||||
"vlm"
|
||||
"chat",
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -162,7 +163,7 @@
|
||||
"name": "gpt-3.5-turbo",
|
||||
"max_tokens": 4096,
|
||||
"model_types": [
|
||||
"llm"
|
||||
"chat"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -170,7 +171,7 @@
|
||||
"name": "gpt-3.5-turbo-16k-0613",
|
||||
"max_tokens": 16385,
|
||||
"model_types": [
|
||||
"llm"
|
||||
"chat"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -202,7 +203,7 @@
|
||||
"name": "whisper-1",
|
||||
"max_tokens": 26214400,
|
||||
"model_types": [
|
||||
"speech2text"
|
||||
"asr"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -210,7 +211,7 @@
|
||||
"name": "gpt-4",
|
||||
"max_tokens": 8191,
|
||||
"model_types": [
|
||||
"llm"
|
||||
"chat"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -218,7 +219,7 @@
|
||||
"name": "gpt-4-turbo",
|
||||
"max_tokens": 8191,
|
||||
"model_types": [
|
||||
"llm"
|
||||
"chat"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -226,7 +227,7 @@
|
||||
"name": "gpt-4-32k",
|
||||
"max_tokens": 32768,
|
||||
"model_types": [
|
||||
"llm"
|
||||
"chat"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"name": "xAI",
|
||||
"tags": "LLM",
|
||||
"url": "https://api.x.ai/v1",
|
||||
"url": {
|
||||
"default": "https://api.x.ai/v1"
|
||||
},
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions"
|
||||
},
|
||||
@@ -9,44 +10,38 @@
|
||||
{
|
||||
"name": "grok-4",
|
||||
"max_tokens": 256000,
|
||||
"model_types": ["llm"],
|
||||
"model_types": ["chat"],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "grok-3",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["llm"],
|
||||
"model_types": ["chat"],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "grok-3-fast",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["llm"],
|
||||
"model_types": ["chat"],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "grok-3-mini",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["llm"],
|
||||
"model_types": ["chat"],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "grok-3-mini-mini-fast",
|
||||
"max_tokens": 131072,
|
||||
"model_types": ["llm"],
|
||||
"model_types": ["chat"],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "grok-2-vision",
|
||||
"max_tokens": 32768,
|
||||
"model_types": ["vlm"],
|
||||
"features": {
|
||||
"multimodal": {
|
||||
"enabled": true,
|
||||
"input_modalities": ["image"],
|
||||
"output_modalities": ["text"]
|
||||
}
|
||||
}
|
||||
"model_types": ["vision"],
|
||||
"features": {}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"name": "ZHIPU-AI",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"url": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"url": {
|
||||
"default": "https://open.bigmodel.cn/api/paas/v4"
|
||||
},
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"async_chat": "async/chat/completions",
|
||||
@@ -31,7 +32,7 @@
|
||||
"max_tokens": 128000,
|
||||
"model_types": [
|
||||
"chat",
|
||||
"image2text"
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -71,7 +72,7 @@
|
||||
"name": "glm-4.5v",
|
||||
"max_tokens": 64000,
|
||||
"model_types": [
|
||||
"image2text"
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -151,7 +152,7 @@
|
||||
"name": "glm-4v",
|
||||
"max_tokens": 2000,
|
||||
"model_types": [
|
||||
"image2text"
|
||||
"vision"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
@@ -183,7 +184,28 @@
|
||||
"name": "glm-asr",
|
||||
"max_tokens": 4096,
|
||||
"model_types": [
|
||||
"speech2text"
|
||||
"asr"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "glm-tts",
|
||||
"model_types": [
|
||||
"tts"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "glm-ocr",
|
||||
"model_types": [
|
||||
"ocr"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "glm-rerank",
|
||||
"model_types": [
|
||||
"rerank"
|
||||
],
|
||||
"features": {}
|
||||
}
|
||||
|
||||
@@ -1187,35 +1187,36 @@ func (p *Parser) parseAdminSetVariable() (*Command, error) {
|
||||
func (p *Parser) parseAdminSetDefault() (*Command, error) {
|
||||
p.nextToken() // consume DEFAULT
|
||||
|
||||
var modelType, modelID string
|
||||
var modelType string
|
||||
|
||||
switch p.curToken.Type {
|
||||
case TokenLLM:
|
||||
modelType = "llm_id"
|
||||
case TokenVLM:
|
||||
modelType = "img2txt_id"
|
||||
case TokenChat:
|
||||
modelType = "chat"
|
||||
case TokenVision:
|
||||
modelType = "vision"
|
||||
case TokenEmbedding:
|
||||
modelType = "embd_id"
|
||||
case TokenReranker:
|
||||
modelType = "reranker_id"
|
||||
modelType = "embedding"
|
||||
case TokenRerank:
|
||||
modelType = "rerank"
|
||||
case TokenASR:
|
||||
modelType = "asr_id"
|
||||
modelType = "asr"
|
||||
case TokenTTS:
|
||||
modelType = "tts_id"
|
||||
modelType = "tts"
|
||||
case TokenOCR:
|
||||
modelType = "ocr"
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
id, err := p.parseQuotedString()
|
||||
compositeModelName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modelID = id
|
||||
|
||||
cmd := NewCommand("set_default_model")
|
||||
cmd.Params["model_type"] = modelType
|
||||
cmd.Params["model_id"] = modelID
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
@@ -1254,18 +1255,20 @@ func (p *Parser) parseAdminResetCommand() (*Command, error) {
|
||||
|
||||
var modelType string
|
||||
switch p.curToken.Type {
|
||||
case TokenLLM:
|
||||
modelType = "llm_id"
|
||||
case TokenVLM:
|
||||
modelType = "img2txt_id"
|
||||
case TokenChat:
|
||||
modelType = "chat"
|
||||
case TokenVision:
|
||||
modelType = "vision"
|
||||
case TokenEmbedding:
|
||||
modelType = "embd_id"
|
||||
case TokenReranker:
|
||||
modelType = "reranker_id"
|
||||
modelType = "embedding"
|
||||
case TokenRerank:
|
||||
modelType = "rerank"
|
||||
case TokenASR:
|
||||
modelType = "asr_id"
|
||||
modelType = "asr"
|
||||
case TokenTTS:
|
||||
modelType = "tts_id"
|
||||
modelType = "tts"
|
||||
case TokenOCR:
|
||||
modelType = "ocr"
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
|
||||
}
|
||||
|
||||
@@ -250,6 +250,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
return c.ShowCurrentModel(cmd)
|
||||
case "set_default_model":
|
||||
return c.SetDefaultModel(cmd)
|
||||
case "reset_default_model":
|
||||
return c.ResetDefaultModel(cmd)
|
||||
case "list_user_default_models":
|
||||
return c.ListDefaultModels(cmd)
|
||||
// Dataset, metadata commands
|
||||
|
||||
@@ -375,27 +375,29 @@ func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
func (c *RAGFlowClient) SetDefaultModel(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
modeType, ok := cmd.Params["model_type"].(string)
|
||||
modelType, ok := cmd.Params["model_type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model_type not provided")
|
||||
}
|
||||
modelProvider, ok := cmd.Params["model_provider"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model_provider not provided")
|
||||
}
|
||||
modelInstance, ok := cmd.Params["model_instance"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model_instance not provided")
|
||||
}
|
||||
modelName, ok := cmd.Params["model_name"].(string)
|
||||
|
||||
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model_name not provided")
|
||||
}
|
||||
|
||||
var providerName, instanceName, modelName string
|
||||
names := strings.Split(compositeModelName, "/")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model name must be in format 'provider/instance/model'")
|
||||
}
|
||||
providerName = names[0]
|
||||
instanceName = names[1]
|
||||
modelName = names[2]
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"model_type": modeType,
|
||||
"model_provider": modelProvider,
|
||||
"model_instance": modelInstance,
|
||||
"model_type": modelType,
|
||||
"model_provider": providerName,
|
||||
"model_instance": instanceName,
|
||||
"model_name": modelName,
|
||||
}
|
||||
|
||||
@@ -420,6 +422,38 @@ func (c *RAGFlowClient) SetDefaultModel(cmd *Command) (ResponseIf, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) ResetDefaultModel(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
modelType, ok := cmd.Params["model_type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model_type not provided")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"model_type": modelType,
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("PATCH", "/models", true, "web", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to reset default model: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to reset default model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to reset default model: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) ListDefaultModels(cmd *Command) (ResponseIf, error) {
|
||||
resp, err := c.HTTPClient.Request("GET", "/models", true, "web", nil, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -327,18 +327,18 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenSearch, Value: ident}
|
||||
case "CURRENT":
|
||||
return Token{Type: TokenCurrent, Value: ident}
|
||||
case "LLM":
|
||||
return Token{Type: TokenLLM, Value: ident}
|
||||
case "VLM":
|
||||
return Token{Type: TokenVLM, Value: ident}
|
||||
case "VISION":
|
||||
return Token{Type: TokenVision, Value: ident}
|
||||
case "EMBEDDING":
|
||||
return Token{Type: TokenEmbedding, Value: ident}
|
||||
case "RERANKER":
|
||||
return Token{Type: TokenReranker, Value: ident}
|
||||
case "RERANK":
|
||||
return Token{Type: TokenRerank, Value: ident}
|
||||
case "ASR":
|
||||
return Token{Type: TokenASR, Value: ident}
|
||||
case "TTS":
|
||||
return Token{Type: TokenTTS, Value: ident}
|
||||
case "OCR":
|
||||
return Token{Type: TokenOCR, Value: ident}
|
||||
case "ASYNC":
|
||||
return Token{Type: TokenAsync, Value: ident}
|
||||
case "SYNC":
|
||||
|
||||
@@ -90,12 +90,12 @@ const (
|
||||
TokenPipeline
|
||||
TokenSearch
|
||||
TokenCurrent
|
||||
TokenLLM
|
||||
TokenVLM
|
||||
TokenVision
|
||||
TokenEmbedding
|
||||
TokenReranker
|
||||
TokenRerank
|
||||
TokenASR
|
||||
TokenTTS
|
||||
TokenOCR
|
||||
TokenAsync
|
||||
TokenSync
|
||||
TokenBenchmark
|
||||
|
||||
@@ -1417,8 +1417,8 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
var providerName, instanceName, modelName string
|
||||
|
||||
// Check if model_name is provided in command
|
||||
if compositeModelName, ok := cmd.Params["model_name"].(string); ok && compositeModelName != "" {
|
||||
// Check if composite_model_name is provided in command
|
||||
if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" {
|
||||
names := strings.Split(compositeModelName, "/")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model name must be in format 'provider/instance/model'")
|
||||
@@ -1524,12 +1524,12 @@ func (c *RAGFlowClient) UseModel(cmd *Command) (ResponseIf, error) {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
modelIdentifier, ok := cmd.Params["model_identifier"].(string)
|
||||
if !ok || modelIdentifier == "" {
|
||||
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
|
||||
if !ok || compositeModelName == "" {
|
||||
return nil, fmt.Errorf("model identifier not provided")
|
||||
}
|
||||
|
||||
names := strings.Split(modelIdentifier, "/")
|
||||
names := strings.Split(compositeModelName, "/")
|
||||
if len(names) != 3 {
|
||||
return nil, fmt.Errorf("model identifier must be in format 'provider/instance/model'")
|
||||
}
|
||||
|
||||
@@ -1597,49 +1597,48 @@ func (p *Parser) parseSetVariable() (*Command, error) {
|
||||
func (p *Parser) parseSetDefault() (*Command, error) {
|
||||
p.nextToken() // consume DEFAULT
|
||||
|
||||
var modelType, modelProvider, modelInstance, modelName string
|
||||
var modelType, compositeModelName string
|
||||
var err error
|
||||
|
||||
switch p.curToken.Type {
|
||||
case TokenLLM:
|
||||
case TokenChat:
|
||||
modelType = "chat"
|
||||
case TokenVLM:
|
||||
modelType = "image2text"
|
||||
case TokenVision:
|
||||
modelType = "vision"
|
||||
case TokenEmbedding:
|
||||
modelType = "embedding"
|
||||
case TokenReranker:
|
||||
case TokenRerank:
|
||||
modelType = "rerank"
|
||||
case TokenASR:
|
||||
modelType = "asr"
|
||||
case TokenTTS:
|
||||
modelType = "tts"
|
||||
case TokenOCR:
|
||||
modelType = "ocr"
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
|
||||
}
|
||||
p.nextToken() // pass model type
|
||||
|
||||
p.nextToken()
|
||||
modelProvider, err = p.parseQuotedString()
|
||||
if p.curToken.Type != TokenModel {
|
||||
return nil, fmt.Errorf("expected MODEL")
|
||||
}
|
||||
p.nextToken() // pass MODEL
|
||||
|
||||
// Format: 'provider/instance/model' or just 'message'
|
||||
if p.curToken.Type != TokenQuotedString {
|
||||
return nil, fmt.Errorf("expected quoted string with format provider/instance/model")
|
||||
}
|
||||
|
||||
compositeModelName, err = p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
modelInstance, err = p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
modelName, err = p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := NewCommand("set_default_model")
|
||||
cmd.Params["model_type"] = modelType
|
||||
cmd.Params["model_provider"] = modelProvider
|
||||
cmd.Params["model_instance"] = modelInstance
|
||||
cmd.Params["model_name"] = modelName
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
@@ -1717,26 +1716,33 @@ func (p *Parser) parseResetCommand() (*Command, error) {
|
||||
|
||||
var modelType string
|
||||
switch p.curToken.Type {
|
||||
case TokenLLM:
|
||||
modelType = "llm_id"
|
||||
case TokenVLM:
|
||||
modelType = "img2txt_id"
|
||||
case TokenChat:
|
||||
modelType = "chat"
|
||||
case TokenVision:
|
||||
modelType = "vision"
|
||||
case TokenEmbedding:
|
||||
modelType = "embd_id"
|
||||
case TokenReranker:
|
||||
modelType = "reranker_id"
|
||||
modelType = "embedding"
|
||||
case TokenRerank:
|
||||
modelType = "rerank"
|
||||
case TokenASR:
|
||||
modelType = "asr_id"
|
||||
modelType = "asr"
|
||||
case TokenTTS:
|
||||
modelType = "tts_id"
|
||||
modelType = "tts"
|
||||
case TokenOCR:
|
||||
modelType = "ocr"
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
|
||||
}
|
||||
|
||||
cmd := NewCommand("reset_default_model")
|
||||
cmd.Params["model_type"] = modelType
|
||||
|
||||
p.nextToken()
|
||||
|
||||
if p.curToken.Type != TokenModel {
|
||||
return nil, fmt.Errorf("expected MODEL")
|
||||
}
|
||||
p.nextToken() // pass MODEL
|
||||
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
@@ -2144,7 +2150,7 @@ func (p *Parser) parseDisableCommand() (*Command, error) {
|
||||
func (p *Parser) parseChatCommand() (*Command, error) {
|
||||
p.nextToken() // consume CHAT
|
||||
|
||||
var modelName string
|
||||
var compositeModelName string
|
||||
var message string
|
||||
|
||||
// Check if we have a quoted string that looks like a model identifier (contains two slashes)
|
||||
@@ -2156,7 +2162,7 @@ func (p *Parser) parseChatCommand() (*Command, error) {
|
||||
slashCount := strings.Count(firstArg, "/")
|
||||
if slashCount == 2 {
|
||||
// This is likely a model identifier, expect another quoted string for message
|
||||
modelName = firstArg
|
||||
compositeModelName = firstArg
|
||||
p.nextToken()
|
||||
|
||||
// After model name, expect message
|
||||
@@ -2184,8 +2190,8 @@ func (p *Parser) parseChatCommand() (*Command, error) {
|
||||
}
|
||||
|
||||
cmd := NewCommand("chat_to_model")
|
||||
if modelName != "" {
|
||||
cmd.Params["model_name"] = modelName
|
||||
if compositeModelName != "" {
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
}
|
||||
cmd.Params["message"] = message
|
||||
cmd.Params["reasoning"] = false
|
||||
@@ -2213,7 +2219,7 @@ func (p *Parser) parseUseCommand() (*Command, error) {
|
||||
p.nextToken() // consume MODEL
|
||||
|
||||
// Parse model identifier in format 'provider/instance/model'
|
||||
modelIdentifier, err := p.parseQuotedString()
|
||||
compositeModelName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected model identifier in format 'provider/instance/model': %w", err)
|
||||
}
|
||||
@@ -2225,7 +2231,7 @@ func (p *Parser) parseUseCommand() (*Command, error) {
|
||||
}
|
||||
|
||||
cmd := NewCommand("use_model")
|
||||
cmd.Params["model_identifier"] = modelIdentifier
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ type TenantInfo struct {
|
||||
ASRID string `gorm:"column:asr_id" json:"asr_id"`
|
||||
Img2TxtID string `gorm:"column:img2txt_id" json:"img2txt_id"`
|
||||
TTSID *string `gorm:"column:tts_id" json:"tts_id,omitempty"`
|
||||
OCRID string `gorm:"column:ocr_id" json:"ocr_id"`
|
||||
ParserIDs string `gorm:"column:parser_ids" json:"parser_ids"`
|
||||
Role string `gorm:"column:role" json:"role"`
|
||||
}
|
||||
@@ -71,7 +72,7 @@ func (dao *TenantDAO) GetInfoByUserID(userID string) ([]*TenantInfo, error) {
|
||||
var results []*TenantInfo
|
||||
|
||||
err := DB.Model(&entity.Tenant{}).
|
||||
Select("tenant.id as tenant_id, tenant.name, tenant.llm_id, tenant.embd_id, tenant.rerank_id, tenant.asr_id, tenant.img2txt_id, tenant.tts_id, tenant.parser_ids, user_tenant.role").
|
||||
Select("tenant.id as tenant_id, tenant.name, tenant.llm_id, tenant.embd_id, tenant.rerank_id, tenant.asr_id, tenant.img2txt_id, tenant.tts_id, tenant.ocr_id, tenant.parser_ids, user_tenant.role").
|
||||
Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id").
|
||||
Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "owner", "1").
|
||||
Scan(&results).Error
|
||||
|
||||
@@ -145,11 +145,10 @@ type Model struct {
|
||||
|
||||
// Provider represents an LLM provider
|
||||
type Provider struct {
|
||||
Name string `json:"name"`
|
||||
Tags string `json:"tags"`
|
||||
URL string `json:"url"`
|
||||
URLSuffix models.URLSuffix `json:"url_suffix"`
|
||||
Models []*Model `json:"models"`
|
||||
Name string `json:"name"`
|
||||
URL map[string]string `json:"url"`
|
||||
URLSuffix models.URLSuffix `json:"url_suffix"`
|
||||
Models []*Model `json:"models"`
|
||||
ModelDriver models.ModelDriver
|
||||
}
|
||||
|
||||
@@ -236,11 +235,24 @@ func (pm *ProviderManager) ListProviders() ([]map[string]interface{}, error) {
|
||||
var providers []map[string]interface{}
|
||||
|
||||
for _, provider := range pm.Providers {
|
||||
|
||||
modelTypeSet := make(map[string]struct{})
|
||||
for _, model := range provider.Models {
|
||||
for _, modelType := range model.ModelTypes {
|
||||
modelTypeSet[modelType] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var modelTypes []string
|
||||
for modelType := range modelTypeSet {
|
||||
modelTypes = append(modelTypes, modelType)
|
||||
}
|
||||
|
||||
providerData := map[string]interface{}{
|
||||
"name": provider.Name,
|
||||
"tags": provider.Tags,
|
||||
"url": provider.URL,
|
||||
"url_suffix": provider.URLSuffix,
|
||||
"name": provider.Name,
|
||||
"url": provider.URL,
|
||||
"model_types": modelTypes,
|
||||
"url_suffix": provider.URLSuffix,
|
||||
}
|
||||
providers = append(providers, providerData)
|
||||
}
|
||||
@@ -262,7 +274,6 @@ func (pm *ProviderManager) GetProviderByName(providerName string) (map[string]in
|
||||
|
||||
providerInfo := map[string]interface{}{
|
||||
"name": provider.Name,
|
||||
"tags": provider.Tags,
|
||||
"base_url": provider.URL,
|
||||
"total_models": len(provider.Models),
|
||||
}
|
||||
|
||||
@@ -20,14 +20,14 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DummyModel implements ModelDriver for Zhipu AI (智谱 AI)
|
||||
// DummyModel implements ModelDriver for Zhipu AI
|
||||
type DummyModel struct {
|
||||
BaseURL string
|
||||
BaseURL map[string]string
|
||||
URLSuffix URLSuffix
|
||||
}
|
||||
|
||||
// NewDummyModel creates a new Zhipu AI model instance
|
||||
func NewDummyModel(baseURL string, urlSuffix URLSuffix) *DummyModel {
|
||||
func NewDummyModel(baseURL map[string]string, urlSuffix URLSuffix) *DummyModel {
|
||||
return &DummyModel{
|
||||
BaseURL: baseURL,
|
||||
URLSuffix: urlSuffix,
|
||||
@@ -35,26 +35,16 @@ func NewDummyModel(baseURL string, urlSuffix URLSuffix) *DummyModel {
|
||||
}
|
||||
|
||||
// Chat sends a message and returns response
|
||||
func (z *DummyModel) Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error) {
|
||||
func (z *DummyModel) Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error) {
|
||||
return "", fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// ChatStreamly sends a message and streams response
|
||||
func (z *DummyModel) ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// ChatStreamlyWithChannel sends a message and streams response to channel (better performance)
|
||||
func (z *DummyModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
func (z *DummyModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) {
|
||||
func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ func NewModelFactory() *ModelFactory {
|
||||
}
|
||||
|
||||
// CreateModelDriver creates a ModelDriver for the given provider and model
|
||||
func (f *ModelFactory) CreateModelDriver(providerName string, baseURL string, urlSuffix URLSuffix) (ModelDriver, error) {
|
||||
func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string]string, urlSuffix URLSuffix) (ModelDriver, error) {
|
||||
providerLower := strings.ToLower(providerName)
|
||||
switch providerLower {
|
||||
case "zhipu-ai":
|
||||
|
||||
@@ -3,15 +3,11 @@ package models
|
||||
// EmbeddingModel interface for embedding models
|
||||
type ModelDriver interface {
|
||||
// Chat sends a message and returns response
|
||||
Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error)
|
||||
// ChatStreamly sends a message and streams response
|
||||
ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error)
|
||||
// ChatStreamlyWithChannel sends a message and streams response to channel (better performance)
|
||||
ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error
|
||||
Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error)
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error
|
||||
// Encode encodes a list of texts into embeddings
|
||||
EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error)
|
||||
EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error)
|
||||
}
|
||||
|
||||
// URLSuffix represents the URL suffixes for different API endpoints
|
||||
@@ -31,4 +27,9 @@ type ChatConfig struct {
|
||||
TopP *float64
|
||||
DoSample *bool
|
||||
Stop *[]string
|
||||
Region *string
|
||||
}
|
||||
|
||||
type EmbeddingConfig struct {
|
||||
Region *string
|
||||
}
|
||||
|
||||
@@ -30,13 +30,13 @@ import (
|
||||
|
||||
// ZhipuAIModel implements ModelDriver for Zhipu AI
|
||||
type ZhipuAIModel struct {
|
||||
BaseURL string
|
||||
BaseURL map[string]string
|
||||
URLSuffix URLSuffix
|
||||
httpClient *http.Client // Reusable HTTP client with connection pool
|
||||
}
|
||||
|
||||
// NewZhipuAIModel creates a new Zhipu AI model instance
|
||||
func NewZhipuAIModel(baseURL string, urlSuffix URLSuffix) *ZhipuAIModel {
|
||||
func NewZhipuAIModel(baseURL map[string]string, urlSuffix URLSuffix) *ZhipuAIModel {
|
||||
return &ZhipuAIModel{
|
||||
BaseURL: baseURL,
|
||||
URLSuffix: urlSuffix,
|
||||
@@ -53,7 +53,7 @@ func NewZhipuAIModel(baseURL string, urlSuffix URLSuffix) *ZhipuAIModel {
|
||||
}
|
||||
|
||||
// Chat sends a message and returns response
|
||||
func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error) {
|
||||
func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, chatModelConfig *ChatConfig) (string, error) {
|
||||
if message == nil {
|
||||
return "", fmt.Errorf("message is nil")
|
||||
}
|
||||
@@ -70,16 +70,17 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[stri
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
// Add generation config if provided
|
||||
if genConf != nil {
|
||||
if maxTokens, ok := genConf["max_tokens"]; ok {
|
||||
reqBody["max_tokens"] = maxTokens
|
||||
if chatModelConfig != nil {
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
if temperature, ok := genConf["temperature"]; ok {
|
||||
reqBody["temperature"] = temperature
|
||||
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
if topP, ok := genConf["top_p"]; ok {
|
||||
reqBody["top_p"] = topP
|
||||
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,229 +141,14 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[stri
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// ChatStreamly sends a message and streams response
|
||||
func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error) {
|
||||
url := fmt.Sprintf("%s/chat/completions", z.BaseURL)
|
||||
|
||||
// Build request body with streaming enabled
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": *message},
|
||||
},
|
||||
"stream": true,
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
// Add generation config if provided
|
||||
if genConf != nil {
|
||||
if maxTokens, ok := genConf["max_tokens"]; ok {
|
||||
reqBody["max_tokens"] = maxTokens
|
||||
}
|
||||
if temperature, ok := genConf["temperature"]; ok {
|
||||
reqBody["temperature"] = temperature
|
||||
}
|
||||
if topP, ok := genConf["top_p"]; ok {
|
||||
reqBody["top_p"] = topP
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Create channel for streaming
|
||||
resultChan := make(chan string)
|
||||
|
||||
go func() {
|
||||
defer close(resultChan)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// SSE parsing: read line by line
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// SSE data line starts with "data:"
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract JSON after "data:"
|
||||
data := strings.TrimSpace(line[5:])
|
||||
|
||||
// [DONE] marks the end of stream
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse the JSON event
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
choices, ok := event["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
delta, ok := firstChoice["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
content, ok := delta["content"].(string)
|
||||
if ok && content != "" {
|
||||
resultChan <- content
|
||||
}
|
||||
|
||||
finishReason, ok := firstChoice["finish_reason"].(string)
|
||||
if ok && finishReason != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return resultChan, nil
|
||||
}
|
||||
|
||||
// ChatStreamlyWithChannel sends a message and streams response to channel (better performance)
|
||||
func (z *ZhipuAIModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error {
|
||||
url := fmt.Sprintf("%s/chat/completions", z.BaseURL)
|
||||
|
||||
// Build request body with streaming enabled
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": *message},
|
||||
},
|
||||
"stream": true,
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
// Add generation config if provided
|
||||
if genConf != nil {
|
||||
if maxTokens, ok := genConf["max_tokens"]; ok {
|
||||
reqBody["max_tokens"] = maxTokens
|
||||
}
|
||||
if temperature, ok := genConf["temperature"]; ok {
|
||||
reqBody["temperature"] = temperature
|
||||
}
|
||||
if topP, ok := genConf["top_p"]; ok {
|
||||
reqBody["top_p"] = topP
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// SSE parsing: read line by line
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
logger.Info(line)
|
||||
|
||||
// SSE data line starts with "data:"
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract JSON after "data:"
|
||||
data := strings.TrimSpace(line[5:])
|
||||
|
||||
// [DONE] marks the end of stream
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse the JSON event
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
choices, ok := event["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
firstChoice, ok := choices[0].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
delta, ok := firstChoice["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
content, ok := delta["content"].(string)
|
||||
if ok && content != "" {
|
||||
resultChan <- content
|
||||
}
|
||||
|
||||
finishReason, ok := firstChoice["finish_reason"].(string)
|
||||
if ok && finishReason != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Send [DONE] marker for OpenAI compatibility
|
||||
resultChan <- "[DONE]"
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
|
||||
func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
url := fmt.Sprintf("%s/chat/completions", z.BaseURL)
|
||||
func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
|
||||
var region = "default"
|
||||
if chatModelConfig.Region != nil {
|
||||
region = *chatModelConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region])
|
||||
|
||||
// Build request body with streaming enabled
|
||||
reqBody := map[string]interface{}{
|
||||
@@ -374,33 +160,33 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string
|
||||
"temperature": 1,
|
||||
}
|
||||
|
||||
if modelConfig != nil {
|
||||
if modelConfig.Stream != nil {
|
||||
reqBody["stream"] = *modelConfig.Stream
|
||||
if chatModelConfig != nil {
|
||||
if chatModelConfig.Stream != nil {
|
||||
reqBody["stream"] = *chatModelConfig.Stream
|
||||
}
|
||||
|
||||
if modelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *modelConfig.MaxTokens
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
|
||||
if modelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *modelConfig.Temperature
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
|
||||
if modelConfig.DoSample != nil {
|
||||
reqBody["do_sample"] = *modelConfig.DoSample
|
||||
if chatModelConfig.DoSample != nil {
|
||||
reqBody["do_sample"] = *chatModelConfig.DoSample
|
||||
}
|
||||
|
||||
if modelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *modelConfig.TopP
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
|
||||
if modelConfig.Stop != nil {
|
||||
reqBody["stop"] = *modelConfig.Stop
|
||||
if chatModelConfig.Stop != nil {
|
||||
reqBody["stop"] = *chatModelConfig.Stop
|
||||
}
|
||||
|
||||
if modelConfig.Reasoning != nil {
|
||||
if *modelConfig.Reasoning {
|
||||
if chatModelConfig.Reasoning != nil {
|
||||
if *chatModelConfig.Reasoning {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
@@ -506,8 +292,13 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string
|
||||
}
|
||||
|
||||
// EncodeToEmbedding encodes a list of texts into embeddings
|
||||
func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) {
|
||||
url := fmt.Sprintf("%s/embedding", z.BaseURL)
|
||||
func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
var region = "default"
|
||||
if embeddingConfig.Region != nil {
|
||||
region = *embeddingConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/embedding", z.BaseURL[region])
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
|
||||
|
||||
@@ -34,6 +34,8 @@ type Tenant struct {
|
||||
TTSID *string `gorm:"column:tts_id;size:256;index" json:"tts_id,omitempty"`
|
||||
TenantTTSID *int64 `gorm:"column:tenant_tts_id;index" json:"tenant_tts_id,omitempty"`
|
||||
ParserIDs string `gorm:"column:parser_ids;size:256;not null;index" json:"parser_ids"`
|
||||
OCRID string `gorm:"column:ocr_id;size:256;not null" json:"ocr_id"`
|
||||
TenantOCRID *int64 `gorm:"column:tenant_ocr_id" json:"tenant_ocr_id,omitempty"`
|
||||
Credit int64 `gorm:"column:credit;default:512;index" json:"credit"`
|
||||
Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"`
|
||||
BaseModel
|
||||
|
||||
@@ -23,6 +23,7 @@ type TenantModelInstance struct {
|
||||
ProviderID string `gorm:"column:provider_id;size:32;not null;uniqueIndex:idx_api_key_provider_id" json:"provider_id"`
|
||||
APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex:idx_api_key_provider_id" json:"api_key"`
|
||||
Status string `gorm:"column:status;size:32;default:'active'" json:"status"`
|
||||
Extra string `gorm:"column:extra;size:512;default:'active'" json:"extra"`
|
||||
BaseModel
|
||||
}
|
||||
|
||||
|
||||
@@ -192,7 +192,7 @@ func (h *ProviderHandler) ListModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
models, err := dao.GetModelProviderManager().ListModels(providerName)
|
||||
providerModels, err := dao.GetModelProviderManager().ListModels(providerName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeNotFound,
|
||||
@@ -203,7 +203,7 @@ func (h *ProviderHandler) ListModels(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": models,
|
||||
"data": providerModels,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -274,7 +274,7 @@ func (h *ProviderHandler) CreateProviderInstance(c *gin.Context) {
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
_, err := h.modelProviderService.CreateProviderInstance(providerName, req.InstanceName, req.APIKey, userID)
|
||||
_, err := h.modelProviderService.CreateProviderInstance(providerName, req.InstanceName, req.APIKey, userID, "default")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
@@ -458,7 +458,7 @@ func (h *ProviderHandler) ListInstanceModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
models, err := h.modelProviderService.ListInstanceModels(providerName, instanceName, c.GetString("user_id"))
|
||||
modelInstances, err := h.modelProviderService.ListInstanceModels(providerName, instanceName, c.GetString("user_id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeNotFound,
|
||||
@@ -469,7 +469,7 @@ func (h *ProviderHandler) ListInstanceModels(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": models,
|
||||
"data": modelInstances,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -618,6 +618,7 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
MaxTokens: nil,
|
||||
Temperature: nil,
|
||||
TopP: nil,
|
||||
Region: nil,
|
||||
}
|
||||
|
||||
// Stream response using sender function (best performance, no channel)
|
||||
@@ -629,8 +630,19 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
chatConfig := models.ChatConfig{
|
||||
Reasoning: &req.Reasoning,
|
||||
Stream: &req.Stream,
|
||||
Stop: &[]string{},
|
||||
DoSample: nil,
|
||||
MaxTokens: nil,
|
||||
Temperature: nil,
|
||||
TopP: nil,
|
||||
Region: nil,
|
||||
}
|
||||
|
||||
// Non-stream response
|
||||
response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message)
|
||||
response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message, &chatConfig)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": errorCode,
|
||||
|
||||
@@ -76,9 +76,9 @@ func (h *TenantHandler) GetModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
type SetModelRequest struct {
|
||||
ModelProvider string `json:"model_provider" binding:"required"`
|
||||
ModelInstance string `json:"model_instance" binding:"required"`
|
||||
ModelName string `json:"model_name" binding:"required"`
|
||||
ModelProvider string `json:"model_provider"`
|
||||
ModelInstance string `json:"model_instance"`
|
||||
ModelName string `json:"model_name"`
|
||||
ModelType string `json:"model_type" binding:"required"`
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -228,7 +229,7 @@ func (m *ModelProviderService) DeleteModelProvider(providerName, userID string)
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, userID string) (common.ErrorCode, error) {
|
||||
func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, userID, region string) (common.ErrorCode, error) {
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
if err != nil {
|
||||
@@ -252,6 +253,15 @@ func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName
|
||||
return common.CodeServerError, errors.New("fail to get UUID")
|
||||
}
|
||||
|
||||
extra := make(map[string]string)
|
||||
extra["region"] = region
|
||||
// convert extra to string
|
||||
extraByte, err := json.Marshal(extra)
|
||||
if err != nil {
|
||||
return common.CodeServerError, errors.New("fail to marshal extra")
|
||||
}
|
||||
extraStr := string(extraByte)
|
||||
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now().Truncate(time.Second)
|
||||
tenantModelProvider := &entity.TenantModelInstance{
|
||||
@@ -259,7 +269,8 @@ func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName
|
||||
InstanceName: instanceName,
|
||||
ProviderID: provider.ID,
|
||||
APIKey: apiKey,
|
||||
Status: "active",
|
||||
Status: "enable",
|
||||
Extra: extraStr,
|
||||
}
|
||||
tenantModelProvider.CreateTime = &now
|
||||
tenantModelProvider.UpdateTime = &now
|
||||
@@ -301,12 +312,20 @@ func (m *ModelProviderService) ListProviderInstances(providerName, userID string
|
||||
|
||||
var result []map[string]interface{}
|
||||
for _, instance := range instances {
|
||||
// convert instance.Extra (json string) to map
|
||||
var extra map[string]string
|
||||
err = json.Unmarshal([]byte(instance.Extra), &extra)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
"id": instance.ID,
|
||||
"instanceName": instance.InstanceName,
|
||||
"providerID": instance.ProviderID,
|
||||
"apiKey": instance.APIKey,
|
||||
"status": instance.Status,
|
||||
"region": extra["region"],
|
||||
})
|
||||
}
|
||||
|
||||
@@ -338,11 +357,19 @@ func (m *ModelProviderService) ShowProviderInstance(providerName, instanceName,
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
// convert instance.Extra (json string) to map
|
||||
var extra map[string]string
|
||||
err = json.Unmarshal([]byte(instance.Extra), &extra)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"id": instance.ID,
|
||||
"instanceName": instance.InstanceName,
|
||||
"providerID": instance.ProviderID,
|
||||
"status": instance.Status,
|
||||
"region": extra["region"],
|
||||
}
|
||||
|
||||
return result, common.CodeSuccess, nil
|
||||
@@ -504,7 +531,7 @@ func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, mod
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string) (*string, common.ErrorCode, error) {
|
||||
func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig) (*string, common.ErrorCode, error) {
|
||||
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
@@ -541,8 +568,17 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName
|
||||
return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName))
|
||||
}
|
||||
|
||||
var extra map[string]string
|
||||
err = json.Unmarshal([]byte(instance.Extra), &extra)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
region := extra["region"]
|
||||
modelConfig.Region = ®ion
|
||||
|
||||
var response string
|
||||
response, err = providerInfo.ModelDriver.Chat(&modelName, &instance.APIKey, &message, nil)
|
||||
response, err = providerInfo.ModelDriver.Chat(&modelName, &instance.APIKey, &message, modelConfig)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
@@ -553,77 +589,6 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName
|
||||
return nil, common.CodeServerError, errors.New("model is disabled")
|
||||
}
|
||||
|
||||
// ChatToModelStream
|
||||
func (m *ModelProviderService) ChatToModelStream(providerName, instanceName, modelName, userID, message string) (<-chan string, <-chan error, common.ErrorCode, error) {
|
||||
streamChan := make(chan string)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
if err != nil {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeServerError, err
|
||||
}
|
||||
|
||||
if len(tenants) == 0 {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeNotFound, errors.New("user has no tenants")
|
||||
}
|
||||
|
||||
tenantID := tenants[0].TenantID
|
||||
|
||||
// Check if provider exists
|
||||
provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
|
||||
if err != nil {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeServerError, err
|
||||
}
|
||||
|
||||
instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
|
||||
if err != nil {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeServerError, err
|
||||
}
|
||||
|
||||
_, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName)
|
||||
if err != nil {
|
||||
providerInfo := dao.GetModelProviderManager().FindProvider(providerName)
|
||||
if providerInfo == nil {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeNotFound, errors.New("provider not found")
|
||||
}
|
||||
|
||||
_, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName)
|
||||
if err != nil {
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName))
|
||||
}
|
||||
|
||||
// Async call stream interface using channel for better performance
|
||||
go func() {
|
||||
defer close(streamChan)
|
||||
defer close(errChan)
|
||||
|
||||
err := providerInfo.ModelDriver.ChatStreamlyWithChannel(&modelName, &instance.APIKey, &message, nil, streamChan)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
return streamChan, errChan, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
close(streamChan)
|
||||
close(errChan)
|
||||
return streamChan, errChan, common.CodeServerError, errors.New("model is disabled")
|
||||
}
|
||||
|
||||
// ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel)
|
||||
func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) common.ErrorCode {
|
||||
// Get tenant ID from user
|
||||
@@ -661,6 +626,15 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
|
||||
return common.CodeNotFound
|
||||
}
|
||||
|
||||
var extra map[string]string
|
||||
err = json.Unmarshal([]byte(instance.Extra), &extra)
|
||||
if err != nil {
|
||||
return common.CodeServerError
|
||||
}
|
||||
|
||||
region := extra["region"]
|
||||
modelConfig.Region = ®ion
|
||||
|
||||
// Direct call with sender function
|
||||
err := providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &instance.APIKey, &message, modelConfig, sender)
|
||||
if err != nil {
|
||||
|
||||
@@ -303,31 +303,6 @@ type ModelItem struct {
|
||||
|
||||
type DefaultModelResponse struct {
|
||||
Models []ModelItem `json:"models,omitempty"`
|
||||
//TenantID string `json:"tenant_id"`
|
||||
//ChatModelProvider *string `json:"chat_model_provider"`
|
||||
//ChatModelInstance *string `json:"chat_model_instance"`
|
||||
//ChatModelName *string `json:"chat_model_name"`
|
||||
//ChatModelEnable bool `json:"chat_model_enable"`
|
||||
//EmbeddingModelProvider *string `json:"embedding_model_provider"`
|
||||
//EmbeddingModelInstance *string `json:"embedding_model_instance"`
|
||||
//EmbeddingModelName *string `json:"embedding_model_name"`
|
||||
//EmbeddingModelEnable bool `json:"embedding_model_enable"`
|
||||
//RerankModelProvider *string `json:"rerank_model_provider"`
|
||||
//RerankModelInstance *string `json:"rerank_model_instance"`
|
||||
//RerankModelName *string `json:"rerank_model_name"`
|
||||
//RerankModelEnable bool `json:"rerank_model_enable"`
|
||||
//ASRModelProvider *string `json:"asr_model_provider"`
|
||||
//ASRModelInstance *string `json:"asr_model_instance"`
|
||||
//ASRModelName *string `json:"asr_model_name"`
|
||||
//ASREnable bool `json:"asr_enable"`
|
||||
//Image2TextModelProvider *string `json:"image2text_model_provider"`
|
||||
//Image2TextModelInstance *string `json:"image2text_model_instance"`
|
||||
//Image2TextModelName *string `json:"image2text_model_name"`
|
||||
//Image2TextModelEnable bool `json:"image2text_model_enable"`
|
||||
//TTSModelProvider *string `json:"tts_model_provider"`
|
||||
//TTSModelInstance *string `json:"tts_model_instance"`
|
||||
//TTSModelName *string `json:"tts_model_name"`
|
||||
//TTSModelEnable bool `json:"tts_model_enable"`
|
||||
}
|
||||
|
||||
func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, modelType string) (*string, *string, *string, bool, error) {
|
||||
@@ -351,6 +326,12 @@ func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, model
|
||||
return nil, nil, nil, false, fmt.Errorf("invalid model string: %s", defaultModel)
|
||||
}
|
||||
|
||||
if modelType == "ocr" {
|
||||
if *providerName == "infiniflow" && *instanceName == "default" && *modelName == "deepdoc" {
|
||||
return providerName, instanceName, modelName, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the provider and instance exists
|
||||
modelProvider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, *providerName)
|
||||
if err != nil {
|
||||
@@ -406,7 +387,7 @@ func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, err
|
||||
ModelProvider: defaultChatModelProvider,
|
||||
ModelInstance: defaultChatModelInstance,
|
||||
ModelName: defaultChatModelName,
|
||||
ModelType: "llm",
|
||||
ModelType: "chat",
|
||||
Enable: defaultChatModelEnable,
|
||||
})
|
||||
}
|
||||
@@ -444,17 +425,28 @@ func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, err
|
||||
})
|
||||
}
|
||||
|
||||
defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "image2text")
|
||||
defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "vision")
|
||||
if err == nil {
|
||||
result = append(result, ModelItem{
|
||||
ModelProvider: defaultImage2TextModelProvider,
|
||||
ModelInstance: defaultImage2TextModelInstance,
|
||||
ModelName: defaultImage2TextModelName,
|
||||
ModelType: "image2text",
|
||||
ModelType: "vision",
|
||||
Enable: defaultImage2TextModelEnable,
|
||||
})
|
||||
}
|
||||
|
||||
defaultOCRModelProvider, defaultOCRModelInstance, defaultOCRModelName, defaultOCRModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.OCRID, "ocr")
|
||||
if err == nil {
|
||||
result = append(result, ModelItem{
|
||||
ModelProvider: defaultOCRModelProvider,
|
||||
ModelInstance: defaultOCRModelInstance,
|
||||
ModelName: defaultOCRModelName,
|
||||
ModelType: "ocr",
|
||||
Enable: defaultOCRModelEnable,
|
||||
})
|
||||
}
|
||||
|
||||
if ownedTenant.TTSID == nil {
|
||||
return result, nil
|
||||
}
|
||||
@@ -518,11 +510,7 @@ func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInsta
|
||||
}
|
||||
|
||||
ownedTenant := tenantInfos[0]
|
||||
err = s.checkModelAvailable(ownedTenant.TenantID, modelProvider, modelInstance, modelName, modelType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var defaultModel string
|
||||
var modelTypeID string
|
||||
if modelType == "chat" {
|
||||
modelTypeID = "llm_id"
|
||||
@@ -536,17 +524,31 @@ func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInsta
|
||||
if modelType == "asr" {
|
||||
modelTypeID = "asr_id"
|
||||
}
|
||||
if modelType == "image2text" {
|
||||
if modelType == "vision" {
|
||||
modelTypeID = "img2txt_id"
|
||||
}
|
||||
if modelType == "tts" {
|
||||
modelTypeID = "tts_id"
|
||||
}
|
||||
if modelType == "ocr" {
|
||||
modelTypeID = "ocr_id"
|
||||
}
|
||||
if modelTypeID == "" {
|
||||
return fmt.Errorf("model type %s is invalid", modelType)
|
||||
}
|
||||
|
||||
defaultModel := fmt.Sprintf("%s@%s@%s", modelName, modelInstance, modelProvider)
|
||||
if modelProvider == "" && modelInstance == "" && modelName == "" {
|
||||
defaultModel = ""
|
||||
} else if modelProvider != "" && modelInstance != "" && modelName != "" {
|
||||
err = s.checkModelAvailable(ownedTenant.TenantID, modelProvider, modelInstance, modelName, modelType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defaultModel = fmt.Sprintf("%s@%s@%s", modelName, modelInstance, modelProvider)
|
||||
} else {
|
||||
return fmt.Errorf("model provider, instance and name must be specified together")
|
||||
}
|
||||
|
||||
err = s.tenantDAO.Update(ownedTenant.TenantID, map[string]interface{}{
|
||||
modelTypeID: defaultModel,
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user