Add extra field to model instance (#14203)

### What problem does this PR solve?

Now each model support region with different URL

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2026-04-20 15:31:12 +08:00
committed by GitHub
parent 939933649a
commit af2ed416a7
22 changed files with 398 additions and 550 deletions

View File

@@ -1,7 +1,8 @@
{
"name": "OpenAI",
"tags": "LLM,TEXT EMBEDDING,TTS,TEXT RE-RANK,SPEECH2TEXT,MODERATION",
"url": "https://api.openai.com/v1",
"url": {
"default": "https://api.openai.com/v1"
},
"url_suffix": {
"chat": "chat/completions"
},
@@ -10,8 +11,8 @@
"name": "gpt-5.2-pro",
"max_tokens": 400000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -19,8 +20,8 @@
"name": "gpt-5.2",
"max_tokens": 400000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -28,8 +29,8 @@
"name": "gpt-5.1",
"max_tokens": 400000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -37,8 +38,8 @@
"name": "gpt-5.1-chat-latest",
"max_tokens": 400000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -46,8 +47,8 @@
"name": "gpt-5",
"max_tokens": 400000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -55,8 +56,8 @@
"name": "gpt-5-mini",
"max_tokens": 400000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -64,8 +65,8 @@
"name": "gpt-5-nano",
"max_tokens": 400000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -73,8 +74,8 @@
"name": "gpt-5-chat-latest",
"max_tokens": 400000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -82,8 +83,8 @@
"name": "gpt-4.1",
"max_tokens": 1047576,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -91,8 +92,8 @@
"name": "gpt-4.1-mini",
"max_tokens": 1047576,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -100,8 +101,8 @@
"name": "gpt-4.1-nano",
"max_tokens": 1047576,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -109,7 +110,7 @@
"name": "gpt-4.5-preview",
"max_tokens": 128000,
"model_types": [
"llm"
"chat"
],
"features": {}
},
@@ -117,8 +118,8 @@
"name": "o3",
"max_tokens": 200000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -126,8 +127,8 @@
"name": "o4-mini",
"max_tokens": 200000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -135,8 +136,8 @@
"name": "o4-mini-high",
"max_tokens": 200000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -144,8 +145,8 @@
"name": "gpt-4o-mini",
"max_tokens": 128000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -153,8 +154,8 @@
"name": "gpt-4o",
"max_tokens": 128000,
"model_types": [
"llm",
"vlm"
"chat",
"vision"
],
"features": {}
},
@@ -162,7 +163,7 @@
"name": "gpt-3.5-turbo",
"max_tokens": 4096,
"model_types": [
"llm"
"chat"
],
"features": {}
},
@@ -170,7 +171,7 @@
"name": "gpt-3.5-turbo-16k-0613",
"max_tokens": 16385,
"model_types": [
"llm"
"chat"
],
"features": {}
},
@@ -202,7 +203,7 @@
"name": "whisper-1",
"max_tokens": 26214400,
"model_types": [
"speech2text"
"asr"
],
"features": {}
},
@@ -210,7 +211,7 @@
"name": "gpt-4",
"max_tokens": 8191,
"model_types": [
"llm"
"chat"
],
"features": {}
},
@@ -218,7 +219,7 @@
"name": "gpt-4-turbo",
"max_tokens": 8191,
"model_types": [
"llm"
"chat"
],
"features": {}
},
@@ -226,7 +227,7 @@
"name": "gpt-4-32k",
"max_tokens": 32768,
"model_types": [
"llm"
"chat"
],
"features": {}
},

View File

@@ -1,7 +1,8 @@
{
"name": "xAI",
"tags": "LLM",
"url": "https://api.x.ai/v1",
"url": {
"default": "https://api.x.ai/v1"
},
"url_suffix": {
"chat": "chat/completions"
},
@@ -9,44 +10,38 @@
{
"name": "grok-4",
"max_tokens": 256000,
"model_types": ["llm"],
"model_types": ["chat"],
"features": {}
},
{
"name": "grok-3",
"max_tokens": 131072,
"model_types": ["llm"],
"model_types": ["chat"],
"features": {}
},
{
"name": "grok-3-fast",
"max_tokens": 131072,
"model_types": ["llm"],
"model_types": ["chat"],
"features": {}
},
{
"name": "grok-3-mini",
"max_tokens": 131072,
"model_types": ["llm"],
"model_types": ["chat"],
"features": {}
},
{
"name": "grok-3-mini-mini-fast",
"max_tokens": 131072,
"model_types": ["llm"],
"model_types": ["chat"],
"features": {}
},
{
"name": "grok-2-vision",
"max_tokens": 32768,
"model_types": ["vlm"],
"features": {
"multimodal": {
"enabled": true,
"input_modalities": ["image"],
"output_modalities": ["text"]
}
}
"model_types": ["vision"],
"features": {}
}
]
}

View File

@@ -1,7 +1,8 @@
{
"name": "ZHIPU-AI",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"url": "https://open.bigmodel.cn/api/paas/v4",
"url": {
"default": "https://open.bigmodel.cn/api/paas/v4"
},
"url_suffix": {
"chat": "chat/completions",
"async_chat": "async/chat/completions",
@@ -31,7 +32,7 @@
"max_tokens": 128000,
"model_types": [
"chat",
"image2text"
"vision"
],
"features": {}
},
@@ -71,7 +72,7 @@
"name": "glm-4.5v",
"max_tokens": 64000,
"model_types": [
"image2text"
"vision"
],
"features": {}
},
@@ -151,7 +152,7 @@
"name": "glm-4v",
"max_tokens": 2000,
"model_types": [
"image2text"
"vision"
],
"features": {}
},
@@ -183,7 +184,28 @@
"name": "glm-asr",
"max_tokens": 4096,
"model_types": [
"speech2text"
"asr"
],
"features": {}
},
{
"name": "glm-tts",
"model_types": [
"tts"
],
"features": {}
},
{
"name": "glm-ocr",
"model_types": [
"ocr"
],
"features": {}
},
{
"name": "glm-rerank",
"model_types": [
"rerank"
],
"features": {}
}

View File

@@ -1187,35 +1187,36 @@ func (p *Parser) parseAdminSetVariable() (*Command, error) {
func (p *Parser) parseAdminSetDefault() (*Command, error) {
p.nextToken() // consume DEFAULT
var modelType, modelID string
var modelType string
switch p.curToken.Type {
case TokenLLM:
modelType = "llm_id"
case TokenVLM:
modelType = "img2txt_id"
case TokenChat:
modelType = "chat"
case TokenVision:
modelType = "vision"
case TokenEmbedding:
modelType = "embd_id"
case TokenReranker:
modelType = "reranker_id"
modelType = "embedding"
case TokenRerank:
modelType = "rerank"
case TokenASR:
modelType = "asr_id"
modelType = "asr"
case TokenTTS:
modelType = "tts_id"
modelType = "tts"
case TokenOCR:
modelType = "ocr"
default:
return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
}
p.nextToken()
id, err := p.parseQuotedString()
compositeModelName, err := p.parseQuotedString()
if err != nil {
return nil, err
}
modelID = id
cmd := NewCommand("set_default_model")
cmd.Params["model_type"] = modelType
cmd.Params["model_id"] = modelID
cmd.Params["composite_model_name"] = compositeModelName
p.nextToken()
// Semicolon is optional for UNSET TOKEN
@@ -1254,18 +1255,20 @@ func (p *Parser) parseAdminResetCommand() (*Command, error) {
var modelType string
switch p.curToken.Type {
case TokenLLM:
modelType = "llm_id"
case TokenVLM:
modelType = "img2txt_id"
case TokenChat:
modelType = "chat"
case TokenVision:
modelType = "vision"
case TokenEmbedding:
modelType = "embd_id"
case TokenReranker:
modelType = "reranker_id"
modelType = "embedding"
case TokenRerank:
modelType = "rerank"
case TokenASR:
modelType = "asr_id"
modelType = "asr"
case TokenTTS:
modelType = "tts_id"
modelType = "tts"
case TokenOCR:
modelType = "ocr"
default:
return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
}

View File

@@ -250,6 +250,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
return c.ShowCurrentModel(cmd)
case "set_default_model":
return c.SetDefaultModel(cmd)
case "reset_default_model":
return c.ResetDefaultModel(cmd)
case "list_user_default_models":
return c.ListDefaultModels(cmd)
// Dataset, metadata commands

View File

@@ -375,27 +375,29 @@ func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) {
func (c *RAGFlowClient) SetDefaultModel(cmd *Command) (ResponseIf, error) {
modeType, ok := cmd.Params["model_type"].(string)
modelType, ok := cmd.Params["model_type"].(string)
if !ok {
return nil, fmt.Errorf("model_type not provided")
}
modelProvider, ok := cmd.Params["model_provider"].(string)
if !ok {
return nil, fmt.Errorf("model_provider not provided")
}
modelInstance, ok := cmd.Params["model_instance"].(string)
if !ok {
return nil, fmt.Errorf("model_instance not provided")
}
modelName, ok := cmd.Params["model_name"].(string)
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
if !ok {
return nil, fmt.Errorf("model_name not provided")
}
var providerName, instanceName, modelName string
names := strings.Split(compositeModelName, "/")
if len(names) != 3 {
return nil, fmt.Errorf("model name must be in format 'provider/instance/model'")
}
providerName = names[0]
instanceName = names[1]
modelName = names[2]
payload := map[string]interface{}{
"model_type": modeType,
"model_provider": modelProvider,
"model_instance": modelInstance,
"model_type": modelType,
"model_provider": providerName,
"model_instance": instanceName,
"model_name": modelName,
}
@@ -420,6 +422,38 @@ func (c *RAGFlowClient) SetDefaultModel(cmd *Command) (ResponseIf, error) {
return &result, nil
}
func (c *RAGFlowClient) ResetDefaultModel(cmd *Command) (ResponseIf, error) {
modelType, ok := cmd.Params["model_type"].(string)
if !ok {
return nil, fmt.Errorf("model_type not provided")
}
payload := map[string]interface{}{
"model_type": modelType,
}
resp, err := c.HTTPClient.Request("PATCH", "/models", true, "web", nil, payload)
if err != nil {
return nil, fmt.Errorf("failed to reset default model: %w", err)
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to reset default model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
}
var result SimpleResponse
if err = json.Unmarshal(resp.Body, &result); err != nil {
return nil, fmt.Errorf("failed to reset default model: invalid JSON (%w)", err)
}
if result.Code != 0 {
return nil, fmt.Errorf("%s", result.Message)
}
result.Duration = resp.Duration
return &result, nil
}
func (c *RAGFlowClient) ListDefaultModels(cmd *Command) (ResponseIf, error) {
resp, err := c.HTTPClient.Request("GET", "/models", true, "web", nil, nil)
if err != nil {

View File

@@ -327,18 +327,18 @@ func (l *Lexer) lookupIdent(ident string) Token {
return Token{Type: TokenSearch, Value: ident}
case "CURRENT":
return Token{Type: TokenCurrent, Value: ident}
case "LLM":
return Token{Type: TokenLLM, Value: ident}
case "VLM":
return Token{Type: TokenVLM, Value: ident}
case "VISION":
return Token{Type: TokenVision, Value: ident}
case "EMBEDDING":
return Token{Type: TokenEmbedding, Value: ident}
case "RERANKER":
return Token{Type: TokenReranker, Value: ident}
case "RERANK":
return Token{Type: TokenRerank, Value: ident}
case "ASR":
return Token{Type: TokenASR, Value: ident}
case "TTS":
return Token{Type: TokenTTS, Value: ident}
case "OCR":
return Token{Type: TokenOCR, Value: ident}
case "ASYNC":
return Token{Type: TokenAsync, Value: ident}
case "SYNC":

View File

@@ -90,12 +90,12 @@ const (
TokenPipeline
TokenSearch
TokenCurrent
TokenLLM
TokenVLM
TokenVision
TokenEmbedding
TokenReranker
TokenRerank
TokenASR
TokenTTS
TokenOCR
TokenAsync
TokenSync
TokenBenchmark

View File

@@ -1417,8 +1417,8 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) {
var providerName, instanceName, modelName string
// Check if model_name is provided in command
if compositeModelName, ok := cmd.Params["model_name"].(string); ok && compositeModelName != "" {
// Check if composite_model_name is provided in command
if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" {
names := strings.Split(compositeModelName, "/")
if len(names) != 3 {
return nil, fmt.Errorf("model name must be in format 'provider/instance/model'")
@@ -1524,12 +1524,12 @@ func (c *RAGFlowClient) UseModel(cmd *Command) (ResponseIf, error) {
return nil, fmt.Errorf("this command is only allowed in USER mode")
}
modelIdentifier, ok := cmd.Params["model_identifier"].(string)
if !ok || modelIdentifier == "" {
compositeModelName, ok := cmd.Params["composite_model_name"].(string)
if !ok || compositeModelName == "" {
return nil, fmt.Errorf("model identifier not provided")
}
names := strings.Split(modelIdentifier, "/")
names := strings.Split(compositeModelName, "/")
if len(names) != 3 {
return nil, fmt.Errorf("model identifier must be in format 'provider/instance/model'")
}

View File

@@ -1597,49 +1597,48 @@ func (p *Parser) parseSetVariable() (*Command, error) {
func (p *Parser) parseSetDefault() (*Command, error) {
p.nextToken() // consume DEFAULT
var modelType, modelProvider, modelInstance, modelName string
var modelType, compositeModelName string
var err error
switch p.curToken.Type {
case TokenLLM:
case TokenChat:
modelType = "chat"
case TokenVLM:
modelType = "image2text"
case TokenVision:
modelType = "vision"
case TokenEmbedding:
modelType = "embedding"
case TokenReranker:
case TokenRerank:
modelType = "rerank"
case TokenASR:
modelType = "asr"
case TokenTTS:
modelType = "tts"
case TokenOCR:
modelType = "ocr"
default:
return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
}
p.nextToken() // pass model type
p.nextToken()
modelProvider, err = p.parseQuotedString()
if p.curToken.Type != TokenModel {
return nil, fmt.Errorf("expected MODEL")
}
p.nextToken() // pass MODEL
// Format: 'provider/instance/model' or just 'message'
if p.curToken.Type != TokenQuotedString {
return nil, fmt.Errorf("expected quoted string with format provider/instance/model")
}
compositeModelName, err = p.parseQuotedString()
if err != nil {
return nil, err
}
p.nextToken()
modelInstance, err = p.parseQuotedString()
if err != nil {
return nil, err
}
p.nextToken()
modelName, err = p.parseQuotedString()
if err != nil {
return nil, err
}
cmd := NewCommand("set_default_model")
cmd.Params["model_type"] = modelType
cmd.Params["model_provider"] = modelProvider
cmd.Params["model_instance"] = modelInstance
cmd.Params["model_name"] = modelName
cmd.Params["composite_model_name"] = compositeModelName
p.nextToken()
// Semicolon is optional for UNSET TOKEN
@@ -1717,26 +1716,33 @@ func (p *Parser) parseResetCommand() (*Command, error) {
var modelType string
switch p.curToken.Type {
case TokenLLM:
modelType = "llm_id"
case TokenVLM:
modelType = "img2txt_id"
case TokenChat:
modelType = "chat"
case TokenVision:
modelType = "vision"
case TokenEmbedding:
modelType = "embd_id"
case TokenReranker:
modelType = "reranker_id"
modelType = "embedding"
case TokenRerank:
modelType = "rerank"
case TokenASR:
modelType = "asr_id"
modelType = "asr"
case TokenTTS:
modelType = "tts_id"
modelType = "tts"
case TokenOCR:
modelType = "ocr"
default:
return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
}
cmd := NewCommand("reset_default_model")
cmd.Params["model_type"] = modelType
p.nextToken()
if p.curToken.Type != TokenModel {
return nil, fmt.Errorf("expected MODEL")
}
p.nextToken() // pass MODEL
// Semicolon is optional for UNSET TOKEN
if p.curToken.Type == TokenSemicolon {
p.nextToken()
@@ -2144,7 +2150,7 @@ func (p *Parser) parseDisableCommand() (*Command, error) {
func (p *Parser) parseChatCommand() (*Command, error) {
p.nextToken() // consume CHAT
var modelName string
var compositeModelName string
var message string
// Check if we have a quoted string that looks like a model identifier (contains two slashes)
@@ -2156,7 +2162,7 @@ func (p *Parser) parseChatCommand() (*Command, error) {
slashCount := strings.Count(firstArg, "/")
if slashCount == 2 {
// This is likely a model identifier, expect another quoted string for message
modelName = firstArg
compositeModelName = firstArg
p.nextToken()
// After model name, expect message
@@ -2184,8 +2190,8 @@ func (p *Parser) parseChatCommand() (*Command, error) {
}
cmd := NewCommand("chat_to_model")
if modelName != "" {
cmd.Params["model_name"] = modelName
if compositeModelName != "" {
cmd.Params["composite_model_name"] = compositeModelName
}
cmd.Params["message"] = message
cmd.Params["reasoning"] = false
@@ -2213,7 +2219,7 @@ func (p *Parser) parseUseCommand() (*Command, error) {
p.nextToken() // consume MODEL
// Parse model identifier in format 'provider/instance/model'
modelIdentifier, err := p.parseQuotedString()
compositeModelName, err := p.parseQuotedString()
if err != nil {
return nil, fmt.Errorf("expected model identifier in format 'provider/instance/model': %w", err)
}
@@ -2225,7 +2231,7 @@ func (p *Parser) parseUseCommand() (*Command, error) {
}
cmd := NewCommand("use_model")
cmd.Params["model_identifier"] = modelIdentifier
cmd.Params["composite_model_name"] = compositeModelName
return cmd, nil
}

View File

@@ -62,6 +62,7 @@ type TenantInfo struct {
ASRID string `gorm:"column:asr_id" json:"asr_id"`
Img2TxtID string `gorm:"column:img2txt_id" json:"img2txt_id"`
TTSID *string `gorm:"column:tts_id" json:"tts_id,omitempty"`
OCRID string `gorm:"column:ocr_id" json:"ocr_id"`
ParserIDs string `gorm:"column:parser_ids" json:"parser_ids"`
Role string `gorm:"column:role" json:"role"`
}
@@ -71,7 +72,7 @@ func (dao *TenantDAO) GetInfoByUserID(userID string) ([]*TenantInfo, error) {
var results []*TenantInfo
err := DB.Model(&entity.Tenant{}).
Select("tenant.id as tenant_id, tenant.name, tenant.llm_id, tenant.embd_id, tenant.rerank_id, tenant.asr_id, tenant.img2txt_id, tenant.tts_id, tenant.parser_ids, user_tenant.role").
Select("tenant.id as tenant_id, tenant.name, tenant.llm_id, tenant.embd_id, tenant.rerank_id, tenant.asr_id, tenant.img2txt_id, tenant.tts_id, tenant.ocr_id, tenant.parser_ids, user_tenant.role").
Joins("INNER JOIN user_tenant ON user_tenant.tenant_id = tenant.id").
Where("user_tenant.user_id = ? AND user_tenant.status = ? AND user_tenant.role = ? AND tenant.status = ?", userID, "1", "owner", "1").
Scan(&results).Error

View File

@@ -145,11 +145,10 @@ type Model struct {
// Provider represents an LLM provider
type Provider struct {
Name string `json:"name"`
Tags string `json:"tags"`
URL string `json:"url"`
URLSuffix models.URLSuffix `json:"url_suffix"`
Models []*Model `json:"models"`
Name string `json:"name"`
URL map[string]string `json:"url"`
URLSuffix models.URLSuffix `json:"url_suffix"`
Models []*Model `json:"models"`
ModelDriver models.ModelDriver
}
@@ -236,11 +235,24 @@ func (pm *ProviderManager) ListProviders() ([]map[string]interface{}, error) {
var providers []map[string]interface{}
for _, provider := range pm.Providers {
modelTypeSet := make(map[string]struct{})
for _, model := range provider.Models {
for _, modelType := range model.ModelTypes {
modelTypeSet[modelType] = struct{}{}
}
}
var modelTypes []string
for modelType := range modelTypeSet {
modelTypes = append(modelTypes, modelType)
}
providerData := map[string]interface{}{
"name": provider.Name,
"tags": provider.Tags,
"url": provider.URL,
"url_suffix": provider.URLSuffix,
"name": provider.Name,
"url": provider.URL,
"model_types": modelTypes,
"url_suffix": provider.URLSuffix,
}
providers = append(providers, providerData)
}
@@ -262,7 +274,6 @@ func (pm *ProviderManager) GetProviderByName(providerName string) (map[string]in
providerInfo := map[string]interface{}{
"name": provider.Name,
"tags": provider.Tags,
"base_url": provider.URL,
"total_models": len(provider.Models),
}

View File

@@ -20,14 +20,14 @@ import (
"fmt"
)
// DummyModel implements ModelDriver for Zhipu AI (智谱 AI)
// DummyModel implements ModelDriver for Zhipu AI
type DummyModel struct {
BaseURL string
BaseURL map[string]string
URLSuffix URLSuffix
}
// NewDummyModel creates a new Zhipu AI model instance
func NewDummyModel(baseURL string, urlSuffix URLSuffix) *DummyModel {
func NewDummyModel(baseURL map[string]string, urlSuffix URLSuffix) *DummyModel {
return &DummyModel{
BaseURL: baseURL,
URLSuffix: urlSuffix,
@@ -35,26 +35,16 @@ func NewDummyModel(baseURL string, urlSuffix URLSuffix) *DummyModel {
}
// Chat sends a message and returns response
func (z *DummyModel) Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error) {
func (z *DummyModel) Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error) {
return "", fmt.Errorf("not implemented")
}
// ChatStreamly sends a message and streams response
func (z *DummyModel) ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error) {
return nil, fmt.Errorf("not implemented")
}
// ChatStreamlyWithChannel sends a message and streams response to channel (better performance)
func (z *DummyModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error {
return fmt.Errorf("not implemented")
}
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
func (z *DummyModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error {
return fmt.Errorf("not implemented")
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) {
func (z *DummyModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
return nil, fmt.Errorf("not implemented")
}

View File

@@ -30,7 +30,7 @@ func NewModelFactory() *ModelFactory {
}
// CreateModelDriver creates a ModelDriver for the given provider and model
func (f *ModelFactory) CreateModelDriver(providerName string, baseURL string, urlSuffix URLSuffix) (ModelDriver, error) {
func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string]string, urlSuffix URLSuffix) (ModelDriver, error) {
providerLower := strings.ToLower(providerName)
switch providerLower {
case "zhipu-ai":

View File

@@ -3,15 +3,11 @@ package models
// EmbeddingModel interface for embedding models
type ModelDriver interface {
// Chat sends a message and returns response
Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error)
// ChatStreamly sends a message and streams response
ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error)
// ChatStreamlyWithChannel sends a message and streams response to channel (better performance)
ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error
Chat(modelName, apiKey, message *string, modelConfig *ChatConfig) (string, error)
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error
// Encode encodes a list of texts into embeddings
EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error)
EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error)
}
// URLSuffix represents the URL suffixes for different API endpoints
@@ -31,4 +27,9 @@ type ChatConfig struct {
TopP *float64
DoSample *bool
Stop *[]string
Region *string
}
type EmbeddingConfig struct {
Region *string
}

View File

@@ -30,13 +30,13 @@ import (
// ZhipuAIModel implements ModelDriver for Zhipu AI
type ZhipuAIModel struct {
BaseURL string
BaseURL map[string]string
URLSuffix URLSuffix
httpClient *http.Client // Reusable HTTP client with connection pool
}
// NewZhipuAIModel creates a new Zhipu AI model instance
func NewZhipuAIModel(baseURL string, urlSuffix URLSuffix) *ZhipuAIModel {
func NewZhipuAIModel(baseURL map[string]string, urlSuffix URLSuffix) *ZhipuAIModel {
return &ZhipuAIModel{
BaseURL: baseURL,
URLSuffix: urlSuffix,
@@ -53,7 +53,7 @@ func NewZhipuAIModel(baseURL string, urlSuffix URLSuffix) *ZhipuAIModel {
}
// Chat sends a message and returns response
func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[string]interface{}) (string, error) {
func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, chatModelConfig *ChatConfig) (string, error) {
if message == nil {
return "", fmt.Errorf("message is nil")
}
@@ -70,16 +70,17 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[stri
"temperature": 1,
}
// Add generation config if provided
if genConf != nil {
if maxTokens, ok := genConf["max_tokens"]; ok {
reqBody["max_tokens"] = maxTokens
if chatModelConfig != nil {
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if temperature, ok := genConf["temperature"]; ok {
reqBody["temperature"] = temperature
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if topP, ok := genConf["top_p"]; ok {
reqBody["top_p"] = topP
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
}
@@ -140,229 +141,14 @@ func (z *ZhipuAIModel) Chat(modelName, apiKey, message *string, genConf map[stri
return content, nil
}
// ChatStreamly sends a message and streams response
func (z *ZhipuAIModel) ChatStreamly(modelName, apiKey, message *string, genConf map[string]interface{}) (<-chan string, error) {
url := fmt.Sprintf("%s/chat/completions", z.BaseURL)
// Build request body with streaming enabled
reqBody := map[string]interface{}{
"model": modelName,
"messages": []map[string]string{
{"role": "user", "content": *message},
},
"stream": true,
"temperature": 1,
}
// Add generation config if provided
if genConf != nil {
if maxTokens, ok := genConf["max_tokens"]; ok {
reqBody["max_tokens"] = maxTokens
}
if temperature, ok := genConf["temperature"]; ok {
reqBody["temperature"] = temperature
}
if topP, ok := genConf["top_p"]; ok {
reqBody["top_p"] = topP
}
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
resp, err := z.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// Create channel for streaming
resultChan := make(chan string)
go func() {
defer close(resultChan)
defer resp.Body.Close()
// SSE parsing: read line by line
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
// SSE data line starts with "data:"
if !strings.HasPrefix(line, "data:") {
continue
}
// Extract JSON after "data:"
data := strings.TrimSpace(line[5:])
// [DONE] marks the end of stream
if data == "[DONE]" {
break
}
// Parse the JSON event
var event map[string]interface{}
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
}
choices, ok := event["choices"].([]interface{})
if !ok || len(choices) == 0 {
continue
}
firstChoice, ok := choices[0].(map[string]interface{})
if !ok {
continue
}
delta, ok := firstChoice["delta"].(map[string]interface{})
if !ok {
continue
}
content, ok := delta["content"].(string)
if ok && content != "" {
resultChan <- content
}
finishReason, ok := firstChoice["finish_reason"].(string)
if ok && finishReason != "" {
break
}
}
}()
return resultChan, nil
}
// ChatStreamlyWithChannel sends a message and streams response to channel (better performance)
func (z *ZhipuAIModel) ChatStreamlyWithChannel(modelName, apiKey, message *string, genConf map[string]interface{}, resultChan chan<- string) error {
url := fmt.Sprintf("%s/chat/completions", z.BaseURL)
// Build request body with streaming enabled
reqBody := map[string]interface{}{
"model": modelName,
"messages": []map[string]string{
{"role": "user", "content": *message},
},
"stream": true,
"temperature": 1,
}
// Add generation config if provided
if genConf != nil {
if maxTokens, ok := genConf["max_tokens"]; ok {
reqBody["max_tokens"] = maxTokens
}
if temperature, ok := genConf["temperature"]; ok {
reqBody["temperature"] = temperature
}
if topP, ok := genConf["top_p"]; ok {
reqBody["top_p"] = topP
}
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
resp, err := z.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
// SSE parsing: read line by line
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
logger.Info(line)
// SSE data line starts with "data:"
if !strings.HasPrefix(line, "data:") {
continue
}
// Extract JSON after "data:"
data := strings.TrimSpace(line[5:])
// [DONE] marks the end of stream
if data == "[DONE]" {
break
}
// Parse the JSON event
var event map[string]interface{}
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
}
choices, ok := event["choices"].([]interface{})
if !ok || len(choices) == 0 {
continue
}
firstChoice, ok := choices[0].(map[string]interface{})
if !ok {
continue
}
delta, ok := firstChoice["delta"].(map[string]interface{})
if !ok {
continue
}
content, ok := delta["content"].(string)
if ok && content != "" {
resultChan <- content
}
finishReason, ok := firstChoice["finish_reason"].(string)
if ok && finishReason != "" {
break
}
}
// Send [DONE] marker for OpenAI compatibility
resultChan <- "[DONE]"
return scanner.Err()
}
// ChatStreamlyWithSender sends a message and streams response via sender function (best performance, no channel)
func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, modelConfig *ChatConfig, sender func(*string, *string) error) error {
url := fmt.Sprintf("%s/chat/completions", z.BaseURL)
func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
var region = "default"
if chatModelConfig.Region != nil {
region = *chatModelConfig.Region
}
url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region])
// Build request body with streaming enabled
reqBody := map[string]interface{}{
@@ -374,33 +160,33 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string
"temperature": 1,
}
if modelConfig != nil {
if modelConfig.Stream != nil {
reqBody["stream"] = *modelConfig.Stream
if chatModelConfig != nil {
if chatModelConfig.Stream != nil {
reqBody["stream"] = *chatModelConfig.Stream
}
if modelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *modelConfig.MaxTokens
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if modelConfig.Temperature != nil {
reqBody["temperature"] = *modelConfig.Temperature
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if modelConfig.DoSample != nil {
reqBody["do_sample"] = *modelConfig.DoSample
if chatModelConfig.DoSample != nil {
reqBody["do_sample"] = *chatModelConfig.DoSample
}
if modelConfig.TopP != nil {
reqBody["top_p"] = *modelConfig.TopP
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
if modelConfig.Stop != nil {
reqBody["stop"] = *modelConfig.Stop
if chatModelConfig.Stop != nil {
reqBody["stop"] = *chatModelConfig.Stop
}
if modelConfig.Reasoning != nil {
if *modelConfig.Reasoning {
if chatModelConfig.Reasoning != nil {
if *chatModelConfig.Reasoning {
reqBody["thinking"] = map[string]interface{}{
"type": "enabled",
}
@@ -506,8 +292,13 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName, apiKey, message *string
}
// EncodeToEmbedding encodes a list of texts into embeddings
func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string) ([][]float64, error) {
url := fmt.Sprintf("%s/embedding", z.BaseURL)
func (z *ZhipuAIModel) EncodeToEmbedding(modelName, apiKey *string, texts []string, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
var region = "default"
if embeddingConfig.Region != nil {
region = *embeddingConfig.Region
}
url := fmt.Sprintf("%s/embedding", z.BaseURL[region])
embeddings := make([][]float64, len(texts))

View File

@@ -34,6 +34,8 @@ type Tenant struct {
TTSID *string `gorm:"column:tts_id;size:256;index" json:"tts_id,omitempty"`
TenantTTSID *int64 `gorm:"column:tenant_tts_id;index" json:"tenant_tts_id,omitempty"`
ParserIDs string `gorm:"column:parser_ids;size:256;not null;index" json:"parser_ids"`
OCRID string `gorm:"column:ocr_id;size:256;not null" json:"ocr_id"`
TenantOCRID *int64 `gorm:"column:tenant_ocr_id" json:"tenant_ocr_id,omitempty"`
Credit int64 `gorm:"column:credit;default:512;index" json:"credit"`
Status *string `gorm:"column:status;size:1;index" json:"status,omitempty"`
BaseModel

View File

@@ -23,6 +23,7 @@ type TenantModelInstance struct {
ProviderID string `gorm:"column:provider_id;size:32;not null;uniqueIndex:idx_api_key_provider_id" json:"provider_id"`
APIKey string `gorm:"column:api_key;size:512;not null;uniqueIndex:idx_api_key_provider_id" json:"api_key"`
Status string `gorm:"column:status;size:32;default:'active'" json:"status"`
Extra string `gorm:"column:extra;size:512;default:'active'" json:"extra"`
BaseModel
}

View File

@@ -192,7 +192,7 @@ func (h *ProviderHandler) ListModels(c *gin.Context) {
})
return
}
models, err := dao.GetModelProviderManager().ListModels(providerName)
providerModels, err := dao.GetModelProviderManager().ListModels(providerName)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeNotFound,
@@ -203,7 +203,7 @@ func (h *ProviderHandler) ListModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "success",
"data": models,
"data": providerModels,
})
}
@@ -274,7 +274,7 @@ func (h *ProviderHandler) CreateProviderInstance(c *gin.Context) {
userID := c.GetString("user_id")
_, err := h.modelProviderService.CreateProviderInstance(providerName, req.InstanceName, req.APIKey, userID)
_, err := h.modelProviderService.CreateProviderInstance(providerName, req.InstanceName, req.APIKey, userID, "default")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
@@ -458,7 +458,7 @@ func (h *ProviderHandler) ListInstanceModels(c *gin.Context) {
})
return
}
models, err := h.modelProviderService.ListInstanceModels(providerName, instanceName, c.GetString("user_id"))
modelInstances, err := h.modelProviderService.ListInstanceModels(providerName, instanceName, c.GetString("user_id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeNotFound,
@@ -469,7 +469,7 @@ func (h *ProviderHandler) ListInstanceModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"message": "success",
"data": models,
"data": modelInstances,
})
}
@@ -618,6 +618,7 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
MaxTokens: nil,
Temperature: nil,
TopP: nil,
Region: nil,
}
// Stream response using sender function (best performance, no channel)
@@ -629,8 +630,19 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
return
}
chatConfig := models.ChatConfig{
Reasoning: &req.Reasoning,
Stream: &req.Stream,
Stop: &[]string{},
DoSample: nil,
MaxTokens: nil,
Temperature: nil,
TopP: nil,
Region: nil,
}
// Non-stream response
response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message)
response, errorCode, err := h.modelProviderService.ChatToModel(providerName, instanceName, modelName, userID, req.Message, &chatConfig)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": errorCode,

View File

@@ -76,9 +76,9 @@ func (h *TenantHandler) GetModels(c *gin.Context) {
}
type SetModelRequest struct {
ModelProvider string `json:"model_provider" binding:"required"`
ModelInstance string `json:"model_instance" binding:"required"`
ModelName string `json:"model_name" binding:"required"`
ModelProvider string `json:"model_provider"`
ModelInstance string `json:"model_instance"`
ModelName string `json:"model_name"`
ModelType string `json:"model_type" binding:"required"`
}

View File

@@ -18,6 +18,7 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
@@ -228,7 +229,7 @@ func (m *ModelProviderService) DeleteModelProvider(providerName, userID string)
return common.CodeSuccess, nil
}
func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, userID string) (common.ErrorCode, error) {
func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, userID, region string) (common.ErrorCode, error) {
// Get tenant ID from user
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
if err != nil {
@@ -252,6 +253,15 @@ func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName
return common.CodeServerError, errors.New("fail to get UUID")
}
extra := make(map[string]string)
extra["region"] = region
// convert extra to string
extraByte, err := json.Marshal(extra)
if err != nil {
return common.CodeServerError, errors.New("fail to marshal extra")
}
extraStr := string(extraByte)
now := time.Now().Unix()
nowDate := time.Now().Truncate(time.Second)
tenantModelProvider := &entity.TenantModelInstance{
@@ -259,7 +269,8 @@ func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName
InstanceName: instanceName,
ProviderID: provider.ID,
APIKey: apiKey,
Status: "active",
Status: "enable",
Extra: extraStr,
}
tenantModelProvider.CreateTime = &now
tenantModelProvider.UpdateTime = &now
@@ -301,12 +312,20 @@ func (m *ModelProviderService) ListProviderInstances(providerName, userID string
var result []map[string]interface{}
for _, instance := range instances {
// convert instance.Extra (json string) to map
var extra map[string]string
err = json.Unmarshal([]byte(instance.Extra), &extra)
if err != nil {
return nil, common.CodeServerError, err
}
result = append(result, map[string]interface{}{
"id": instance.ID,
"instanceName": instance.InstanceName,
"providerID": instance.ProviderID,
"apiKey": instance.APIKey,
"status": instance.Status,
"region": extra["region"],
})
}
@@ -338,11 +357,19 @@ func (m *ModelProviderService) ShowProviderInstance(providerName, instanceName,
return nil, common.CodeServerError, err
}
// convert instance.Extra (json string) to map
var extra map[string]string
err = json.Unmarshal([]byte(instance.Extra), &extra)
if err != nil {
return nil, common.CodeServerError, err
}
result := map[string]interface{}{
"id": instance.ID,
"instanceName": instance.InstanceName,
"providerID": instance.ProviderID,
"status": instance.Status,
"region": extra["region"],
}
return result, common.CodeSuccess, nil
@@ -504,7 +531,7 @@ func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, mod
return common.CodeSuccess, nil
}
func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string) (*string, common.ErrorCode, error) {
func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig) (*string, common.ErrorCode, error) {
// Get tenant ID from user
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
@@ -541,8 +568,17 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName
return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName))
}
var extra map[string]string
err = json.Unmarshal([]byte(instance.Extra), &extra)
if err != nil {
return nil, common.CodeServerError, err
}
region := extra["region"]
modelConfig.Region = &region
var response string
response, err = providerInfo.ModelDriver.Chat(&modelName, &instance.APIKey, &message, nil)
response, err = providerInfo.ModelDriver.Chat(&modelName, &instance.APIKey, &message, modelConfig)
if err != nil {
return nil, common.CodeServerError, err
}
@@ -553,77 +589,6 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName
return nil, common.CodeServerError, errors.New("model is disabled")
}
// ChatToModelStream
func (m *ModelProviderService) ChatToModelStream(providerName, instanceName, modelName, userID, message string) (<-chan string, <-chan error, common.ErrorCode, error) {
streamChan := make(chan string)
errChan := make(chan error, 1)
// Get tenant ID from user
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
if err != nil {
close(streamChan)
close(errChan)
return streamChan, errChan, common.CodeServerError, err
}
if len(tenants) == 0 {
close(streamChan)
close(errChan)
return streamChan, errChan, common.CodeNotFound, errors.New("user has no tenants")
}
tenantID := tenants[0].TenantID
// Check if provider exists
provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
if err != nil {
close(streamChan)
close(errChan)
return streamChan, errChan, common.CodeServerError, err
}
instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
if err != nil {
close(streamChan)
close(errChan)
return streamChan, errChan, common.CodeServerError, err
}
_, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName)
if err != nil {
providerInfo := dao.GetModelProviderManager().FindProvider(providerName)
if providerInfo == nil {
close(streamChan)
close(errChan)
return streamChan, errChan, common.CodeNotFound, errors.New("provider not found")
}
_, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName)
if err != nil {
close(streamChan)
close(errChan)
return streamChan, errChan, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName))
}
// Async call stream interface using channel for better performance
go func() {
defer close(streamChan)
defer close(errChan)
err := providerInfo.ModelDriver.ChatStreamlyWithChannel(&modelName, &instance.APIKey, &message, nil, streamChan)
if err != nil {
errChan <- err
}
}()
return streamChan, errChan, common.CodeSuccess, nil
}
close(streamChan)
close(errChan)
return streamChan, errChan, common.CodeServerError, errors.New("model is disabled")
}
// ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel)
func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) common.ErrorCode {
// Get tenant ID from user
@@ -661,6 +626,15 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
return common.CodeNotFound
}
var extra map[string]string
err = json.Unmarshal([]byte(instance.Extra), &extra)
if err != nil {
return common.CodeServerError
}
region := extra["region"]
modelConfig.Region = &region
// Direct call with sender function
err := providerInfo.ModelDriver.ChatStreamlyWithSender(&modelName, &instance.APIKey, &message, modelConfig, sender)
if err != nil {

View File

@@ -303,31 +303,6 @@ type ModelItem struct {
type DefaultModelResponse struct {
Models []ModelItem `json:"models,omitempty"`
//TenantID string `json:"tenant_id"`
//ChatModelProvider *string `json:"chat_model_provider"`
//ChatModelInstance *string `json:"chat_model_instance"`
//ChatModelName *string `json:"chat_model_name"`
//ChatModelEnable bool `json:"chat_model_enable"`
//EmbeddingModelProvider *string `json:"embedding_model_provider"`
//EmbeddingModelInstance *string `json:"embedding_model_instance"`
//EmbeddingModelName *string `json:"embedding_model_name"`
//EmbeddingModelEnable bool `json:"embedding_model_enable"`
//RerankModelProvider *string `json:"rerank_model_provider"`
//RerankModelInstance *string `json:"rerank_model_instance"`
//RerankModelName *string `json:"rerank_model_name"`
//RerankModelEnable bool `json:"rerank_model_enable"`
//ASRModelProvider *string `json:"asr_model_provider"`
//ASRModelInstance *string `json:"asr_model_instance"`
//ASRModelName *string `json:"asr_model_name"`
//ASREnable bool `json:"asr_enable"`
//Image2TextModelProvider *string `json:"image2text_model_provider"`
//Image2TextModelInstance *string `json:"image2text_model_instance"`
//Image2TextModelName *string `json:"image2text_model_name"`
//Image2TextModelEnable bool `json:"image2text_model_enable"`
//TTSModelProvider *string `json:"tts_model_provider"`
//TTSModelInstance *string `json:"tts_model_instance"`
//TTSModelName *string `json:"tts_model_name"`
//TTSModelEnable bool `json:"tts_model_enable"`
}
func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, modelType string) (*string, *string, *string, bool, error) {
@@ -351,6 +326,12 @@ func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, model
return nil, nil, nil, false, fmt.Errorf("invalid model string: %s", defaultModel)
}
if modelType == "ocr" {
if *providerName == "infiniflow" && *instanceName == "default" && *modelName == "deepdoc" {
return providerName, instanceName, modelName, true, nil
}
}
// Check if the provider and instance exists
modelProvider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, *providerName)
if err != nil {
@@ -406,7 +387,7 @@ func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, err
ModelProvider: defaultChatModelProvider,
ModelInstance: defaultChatModelInstance,
ModelName: defaultChatModelName,
ModelType: "llm",
ModelType: "chat",
Enable: defaultChatModelEnable,
})
}
@@ -444,17 +425,28 @@ func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, err
})
}
defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "image2text")
defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "vision")
if err == nil {
result = append(result, ModelItem{
ModelProvider: defaultImage2TextModelProvider,
ModelInstance: defaultImage2TextModelInstance,
ModelName: defaultImage2TextModelName,
ModelType: "image2text",
ModelType: "vision",
Enable: defaultImage2TextModelEnable,
})
}
defaultOCRModelProvider, defaultOCRModelInstance, defaultOCRModelName, defaultOCRModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.OCRID, "ocr")
if err == nil {
result = append(result, ModelItem{
ModelProvider: defaultOCRModelProvider,
ModelInstance: defaultOCRModelInstance,
ModelName: defaultOCRModelName,
ModelType: "ocr",
Enable: defaultOCRModelEnable,
})
}
if ownedTenant.TTSID == nil {
return result, nil
}
@@ -518,11 +510,7 @@ func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInsta
}
ownedTenant := tenantInfos[0]
err = s.checkModelAvailable(ownedTenant.TenantID, modelProvider, modelInstance, modelName, modelType)
if err != nil {
return err
}
var defaultModel string
var modelTypeID string
if modelType == "chat" {
modelTypeID = "llm_id"
@@ -536,17 +524,31 @@ func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInsta
if modelType == "asr" {
modelTypeID = "asr_id"
}
if modelType == "image2text" {
if modelType == "vision" {
modelTypeID = "img2txt_id"
}
if modelType == "tts" {
modelTypeID = "tts_id"
}
if modelType == "ocr" {
modelTypeID = "ocr_id"
}
if modelTypeID == "" {
return fmt.Errorf("model type %s is invalid", modelType)
}
defaultModel := fmt.Sprintf("%s@%s@%s", modelName, modelInstance, modelProvider)
if modelProvider == "" && modelInstance == "" && modelName == "" {
defaultModel = ""
} else if modelProvider != "" && modelInstance != "" && modelName != "" {
err = s.checkModelAvailable(ownedTenant.TenantID, modelProvider, modelInstance, modelName, modelType)
if err != nil {
return err
}
defaultModel = fmt.Sprintf("%s@%s@%s", modelName, modelInstance, modelProvider)
} else {
return fmt.Errorf("model provider, instance and name must be specified together")
}
err = s.tenantDAO.Update(ownedTenant.TenantID, map[string]interface{}{
modelTypeID: defaultModel,
})