mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Go: set and list default models (#14191)
### What problem does this PR solve? ``` RAGFlow(user)> set default vlm "zhipu-ai" "ccc" "glm-4.6v-flash"; SUCCESS RAGFlow(user)> list default models; +--------+----------------+----------------+----------------+------------+ | enable | model_instance | model_name | model_provider | model_type | +--------+----------------+----------------+----------------+------------+ | true | ccc | glm-4.6v-flash | zhipu-ai | llm | | true | ccc | glm-4.6v-flash | zhipu-ai | image2text | +--------+----------------+----------------+----------------+------------+ ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality) Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -26,6 +26,15 @@
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "glm-4.6v-Flash",
|
||||
"max_tokens": 128000,
|
||||
"model_types": [
|
||||
"chat",
|
||||
"image2text"
|
||||
],
|
||||
"features": {}
|
||||
},
|
||||
{
|
||||
"name": "glm-4.5-x",
|
||||
"max_tokens": 128000,
|
||||
|
||||
@@ -248,6 +248,10 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
return c.UseModel(cmd)
|
||||
case "show_current_model":
|
||||
return c.ShowCurrentModel(cmd)
|
||||
case "set_default_model":
|
||||
return c.SetDefaultModel(cmd)
|
||||
case "list_user_default_models":
|
||||
return c.ListDefaultModels(cmd)
|
||||
// Dataset, metadata commands
|
||||
case "create_dataset_table":
|
||||
return c.CreateDatasetInDocEngine(cmd)
|
||||
|
||||
@@ -373,6 +373,75 @@ func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) SetDefaultModel(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
modeType, ok := cmd.Params["model_type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model_type not provided")
|
||||
}
|
||||
modelProvider, ok := cmd.Params["model_provider"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model_provider not provided")
|
||||
}
|
||||
modelInstance, ok := cmd.Params["model_instance"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model_instance not provided")
|
||||
}
|
||||
modelName, ok := cmd.Params["model_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model_name not provided")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"model_type": modeType,
|
||||
"model_provider": modelProvider,
|
||||
"model_instance": modelInstance,
|
||||
"model_name": modelName,
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("PATCH", "/models", true, "web", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set default model: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to set default model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to set default model: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) ListDefaultModels(cmd *Command) (ResponseIf, error) {
|
||||
resp, err := c.HTTPClient.Request("GET", "/models", true, "web", nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list default models: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to list default models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result CommonResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to list default models: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// readPassword reads password from terminal without echoing
|
||||
func ReadPassword() (string, error) {
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
|
||||
@@ -1597,35 +1597,49 @@ func (p *Parser) parseSetVariable() (*Command, error) {
|
||||
func (p *Parser) parseSetDefault() (*Command, error) {
|
||||
p.nextToken() // consume DEFAULT
|
||||
|
||||
var modelType, modelID string
|
||||
var modelType, modelProvider, modelInstance, modelName string
|
||||
var err error
|
||||
|
||||
switch p.curToken.Type {
|
||||
case TokenLLM:
|
||||
modelType = "llm_id"
|
||||
modelType = "chat"
|
||||
case TokenVLM:
|
||||
modelType = "img2txt_id"
|
||||
modelType = "image2text"
|
||||
case TokenEmbedding:
|
||||
modelType = "embd_id"
|
||||
modelType = "embedding"
|
||||
case TokenReranker:
|
||||
modelType = "reranker_id"
|
||||
modelType = "rerank"
|
||||
case TokenASR:
|
||||
modelType = "asr_id"
|
||||
modelType = "asr"
|
||||
case TokenTTS:
|
||||
modelType = "tts_id"
|
||||
modelType = "tts"
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
id, err := p.parseQuotedString()
|
||||
modelProvider, err = p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
modelInstance, err = p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
modelName, err = p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modelID = id
|
||||
|
||||
cmd := NewCommand("set_default_model")
|
||||
cmd.Params["model_type"] = modelType
|
||||
cmd.Params["model_id"] = modelID
|
||||
cmd.Params["model_provider"] = modelProvider
|
||||
cmd.Params["model_instance"] = modelInstance
|
||||
cmd.Params["model_name"] = modelName
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
@@ -2601,7 +2615,6 @@ func (p *Parser) parseRemoveTags() (*Command, error) {
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
|
||||
// parseRemoveChunk parses:
|
||||
// - REMOVE CHUNKS 'chunk_id1', 'chunk_id2' FROM DOCUMENT 'doc_id';
|
||||
// - REMOVE ALL CHUNKS FROM DOCUMENT 'doc_id';
|
||||
|
||||
@@ -149,7 +149,7 @@ type Provider struct {
|
||||
Tags string `json:"tags"`
|
||||
URL string `json:"url"`
|
||||
URLSuffix models.URLSuffix `json:"url_suffix"`
|
||||
Models []Model `json:"models"`
|
||||
Models []*Model `json:"models"`
|
||||
ModelDriver models.ModelDriver
|
||||
}
|
||||
|
||||
@@ -547,7 +547,7 @@ func (pm *ProviderManager) FindProvider(name string) *Provider {
|
||||
func (pm *ProviderManager) findModel(provider *Provider, modelName string) *Model {
|
||||
for i := range provider.Models {
|
||||
if strings.EqualFold(provider.Models[i].Name, modelName) {
|
||||
return &provider.Models[i]
|
||||
return provider.Models[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -42,6 +42,81 @@ func NewTenantHandler(tenantService *service.TenantService, userService *service
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TenantHandler) GetModels(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
defaultModels, err := h.tenantService.ListTenantDefaultModels(user.ID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeExceptionError,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if defaultModels == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeDataError,
|
||||
"message": "No default models",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": defaultModels,
|
||||
})
|
||||
}
|
||||
|
||||
type SetModelRequest struct {
|
||||
ModelProvider string `json:"model_provider" binding:"required"`
|
||||
ModelInstance string `json:"model_instance" binding:"required"`
|
||||
ModelName string `json:"model_name" binding:"required"`
|
||||
ModelType string `json:"model_type" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *TenantHandler) SetModels(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body (same as Python get_request_json())
|
||||
var req SetModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"data": nil,
|
||||
"message": "Invalid request body: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err := h.tenantService.SetTenantDefaultModels(user.ID, req.ModelProvider, req.ModelInstance, req.ModelName, req.ModelType)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeExceptionError,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
|
||||
// TenantInfo get tenant information
|
||||
// @Summary Get Tenant Information
|
||||
// @Description Get current user's tenant information (owner tenant)
|
||||
|
||||
@@ -219,6 +219,12 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
provider.POST("/:provider_name/instances/:instance_name/models/:model_name", r.providerHandler.ChatToModel)
|
||||
}
|
||||
|
||||
model := v1.Group("/models")
|
||||
{
|
||||
model.GET("/", r.tenantHandler.GetModels)
|
||||
model.PATCH("/", r.tenantHandler.SetModels)
|
||||
}
|
||||
|
||||
system := v1.Group("/system")
|
||||
{
|
||||
system.GET("/version", r.systemHandler.GetVersion)
|
||||
|
||||
@@ -28,17 +28,27 @@ import (
|
||||
|
||||
// TenantService tenant service
|
||||
type TenantService struct {
|
||||
tenantDAO *dao.TenantDAO
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
docEngine engine.DocEngine
|
||||
tenantDAO *dao.TenantDAO
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
modelProviderDAO *dao.TenantModelProviderDAO
|
||||
modelInstanceDAO *dao.TenantModelInstanceDAO
|
||||
modelDAO *dao.TenantModelDAO
|
||||
modelGroupDAO *dao.TenantModelGroupDAO
|
||||
modelGroupMappingDAO *dao.TenantModelGroupMappingDAO
|
||||
docEngine engine.DocEngine
|
||||
}
|
||||
|
||||
// NewTenantService create tenant service
|
||||
func NewTenantService() *TenantService {
|
||||
return &TenantService{
|
||||
tenantDAO: dao.NewTenantDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
docEngine: engine.Get(),
|
||||
tenantDAO: dao.NewTenantDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
modelProviderDAO: dao.NewTenantModelProviderDAO(),
|
||||
modelInstanceDAO: dao.NewTenantModelInstanceDAO(),
|
||||
modelDAO: dao.NewTenantModelDAO(),
|
||||
modelGroupDAO: dao.NewTenantModelGroupDAO(),
|
||||
modelGroupMappingDAO: dao.NewTenantModelGroupMappingDAO(),
|
||||
docEngine: engine.Get(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,3 +292,264 @@ func (s *TenantService) DeleteMetadataInDocEngine(tenantID string) (common.Error
|
||||
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
type ModelItem struct {
|
||||
ModelProvider *string `json:"model_provider"`
|
||||
ModelInstance *string `json:"model_instance"`
|
||||
ModelName *string `json:"model_name"`
|
||||
ModelType string `json:"model_type"`
|
||||
Enable bool `json:"enable"`
|
||||
}
|
||||
|
||||
type DefaultModelResponse struct {
|
||||
Models []ModelItem `json:"models,omitempty"`
|
||||
//TenantID string `json:"tenant_id"`
|
||||
//ChatModelProvider *string `json:"chat_model_provider"`
|
||||
//ChatModelInstance *string `json:"chat_model_instance"`
|
||||
//ChatModelName *string `json:"chat_model_name"`
|
||||
//ChatModelEnable bool `json:"chat_model_enable"`
|
||||
//EmbeddingModelProvider *string `json:"embedding_model_provider"`
|
||||
//EmbeddingModelInstance *string `json:"embedding_model_instance"`
|
||||
//EmbeddingModelName *string `json:"embedding_model_name"`
|
||||
//EmbeddingModelEnable bool `json:"embedding_model_enable"`
|
||||
//RerankModelProvider *string `json:"rerank_model_provider"`
|
||||
//RerankModelInstance *string `json:"rerank_model_instance"`
|
||||
//RerankModelName *string `json:"rerank_model_name"`
|
||||
//RerankModelEnable bool `json:"rerank_model_enable"`
|
||||
//ASRModelProvider *string `json:"asr_model_provider"`
|
||||
//ASRModelInstance *string `json:"asr_model_instance"`
|
||||
//ASRModelName *string `json:"asr_model_name"`
|
||||
//ASREnable bool `json:"asr_enable"`
|
||||
//Image2TextModelProvider *string `json:"image2text_model_provider"`
|
||||
//Image2TextModelInstance *string `json:"image2text_model_instance"`
|
||||
//Image2TextModelName *string `json:"image2text_model_name"`
|
||||
//Image2TextModelEnable bool `json:"image2text_model_enable"`
|
||||
//TTSModelProvider *string `json:"tts_model_provider"`
|
||||
//TTSModelInstance *string `json:"tts_model_instance"`
|
||||
//TTSModelName *string `json:"tts_model_name"`
|
||||
//TTSModelEnable bool `json:"tts_model_enable"`
|
||||
}
|
||||
|
||||
func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, modelType string) (*string, *string, *string, bool, error) {
|
||||
// normally the model string is: modelName@instanceName@providerName, sometimes it's just modelName@providerName
|
||||
// for the 1st case, parse defaultChatModel into three parts
|
||||
defaultChatModelParts := strings.Split(defaultModel, "@")
|
||||
var providerName *string
|
||||
var instanceName *string
|
||||
var modelName *string
|
||||
if len(defaultChatModelParts) == 3 {
|
||||
providerName = &defaultChatModelParts[2]
|
||||
instanceName = &defaultChatModelParts[1]
|
||||
modelName = &defaultChatModelParts[0]
|
||||
|
||||
} else if len(defaultChatModelParts) == 2 {
|
||||
providerName = &defaultChatModelParts[1]
|
||||
instanceName = new(string)
|
||||
*instanceName = "default"
|
||||
modelName = &defaultChatModelParts[0]
|
||||
} else {
|
||||
return nil, nil, nil, false, fmt.Errorf("invalid model string: %s", defaultModel)
|
||||
}
|
||||
|
||||
// Check if the provider and instance exists
|
||||
modelProvider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, *providerName)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
modelInstance, err := s.modelInstanceDAO.GetByProviderIDAndInstanceName(modelProvider.ID, *instanceName)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
modelSchema, err := dao.GetModelProviderManager().GetModelByName(*providerName, *modelName)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
if !modelSchema.ModelTypeMap[modelType] {
|
||||
return nil, nil, nil, false, fmt.Errorf("model %s isn't a chat model", *modelName)
|
||||
}
|
||||
|
||||
var modelEntity *entity.TenantModel
|
||||
modelEntity, err = s.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(modelProvider.ID, modelInstance.ID, *modelName)
|
||||
if err != nil {
|
||||
errString := err.Error()
|
||||
if !strings.Contains(errString, "record not found") {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
enable := modelEntity == nil
|
||||
|
||||
return providerName, instanceName, modelName, enable, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, error) {
|
||||
|
||||
tenantInfos, err := s.tenantDAO.GetInfoByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(tenantInfos) == 0 {
|
||||
return nil, nil // No tenant found (should not happen for valid user)
|
||||
}
|
||||
|
||||
ownedTenant := tenantInfos[0]
|
||||
|
||||
var result []ModelItem
|
||||
|
||||
defaultChatModelProvider, defaultChatModelInstance, defaultChatModelName, defaultChatModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.LLMID, "chat")
|
||||
if err == nil {
|
||||
result = append(result, ModelItem{
|
||||
ModelProvider: defaultChatModelProvider,
|
||||
ModelInstance: defaultChatModelInstance,
|
||||
ModelName: defaultChatModelName,
|
||||
ModelType: "llm",
|
||||
Enable: defaultChatModelEnable,
|
||||
})
|
||||
}
|
||||
|
||||
defaultEmbeddingModelProvider, defaultEmbeddingModelInstance, defaultEmbeddingModelName, defaultEmbeddingModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.EmbDID, "embedding")
|
||||
if err == nil {
|
||||
result = append(result, ModelItem{
|
||||
ModelProvider: defaultEmbeddingModelProvider,
|
||||
ModelInstance: defaultEmbeddingModelInstance,
|
||||
ModelName: defaultEmbeddingModelName,
|
||||
ModelType: "embedding",
|
||||
Enable: defaultEmbeddingModelEnable,
|
||||
})
|
||||
}
|
||||
|
||||
defaultRerankModelProvider, defaultRerankModelInstance, defaultRerankModelName, defaultRerankModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.RerankID, "rerank")
|
||||
if err == nil {
|
||||
result = append(result, ModelItem{
|
||||
ModelProvider: defaultRerankModelProvider,
|
||||
ModelInstance: defaultRerankModelInstance,
|
||||
ModelName: defaultRerankModelName,
|
||||
ModelType: "rerank",
|
||||
Enable: defaultRerankModelEnable,
|
||||
})
|
||||
}
|
||||
|
||||
defaultASRModelProvider, defaultASRModelInstance, defaultASRModelName, defaultASREnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.ASRID, "asr")
|
||||
if err == nil {
|
||||
result = append(result, ModelItem{
|
||||
ModelProvider: defaultASRModelProvider,
|
||||
ModelInstance: defaultASRModelInstance,
|
||||
ModelName: defaultASRModelName,
|
||||
ModelType: "asr",
|
||||
Enable: defaultASREnable,
|
||||
})
|
||||
}
|
||||
|
||||
defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "image2text")
|
||||
if err == nil {
|
||||
result = append(result, ModelItem{
|
||||
ModelProvider: defaultImage2TextModelProvider,
|
||||
ModelInstance: defaultImage2TextModelInstance,
|
||||
ModelName: defaultImage2TextModelName,
|
||||
ModelType: "image2text",
|
||||
Enable: defaultImage2TextModelEnable,
|
||||
})
|
||||
}
|
||||
|
||||
if ownedTenant.TTSID == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
defaultTTSModelProvider, defaultTTSModelInstance, defaultTTSModelName, defaultTTSModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, *ownedTenant.TTSID, "tts")
|
||||
if err == nil {
|
||||
result = append(result, ModelItem{
|
||||
ModelProvider: defaultTTSModelProvider,
|
||||
ModelInstance: defaultTTSModelInstance,
|
||||
ModelName: defaultTTSModelName,
|
||||
ModelType: "tts",
|
||||
Enable: defaultTTSModelEnable,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *TenantService) checkModelAvailable(tenantID, providerName, instanceName, modelName, modelType string) error {
|
||||
// Check if the provider and instance exists
|
||||
modelProvider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelInstance, err := s.modelInstanceDAO.GetByProviderIDAndInstanceName(modelProvider.ID, instanceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelSchema, err := dao.GetModelProviderManager().GetModelByName(providerName, modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !modelSchema.ModelTypeMap[modelType] {
|
||||
return fmt.Errorf("model %s isn't a chat model", modelName)
|
||||
}
|
||||
|
||||
var modelEntity *entity.TenantModel
|
||||
modelEntity, err = s.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(modelProvider.ID, modelInstance.ID, modelName)
|
||||
if err != nil || modelEntity != nil {
|
||||
var errString = err.Error()
|
||||
if errString == "record not found" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("model %s isn't available", modelName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInstance, modelName, modelType string) error {
|
||||
|
||||
tenantInfos, err := s.tenantDAO.GetInfoByUserID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(tenantInfos) == 0 {
|
||||
return nil // No tenant found (should not happen for valid user)
|
||||
}
|
||||
|
||||
ownedTenant := tenantInfos[0]
|
||||
err = s.checkModelAvailable(ownedTenant.TenantID, modelProvider, modelInstance, modelName, modelType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var modelTypeID string
|
||||
if modelType == "chat" {
|
||||
modelTypeID = "llm_id"
|
||||
}
|
||||
if modelType == "embedding" {
|
||||
modelTypeID = "embd_id"
|
||||
}
|
||||
if modelType == "rerank" {
|
||||
modelTypeID = "rerank_id"
|
||||
}
|
||||
if modelType == "asr" {
|
||||
modelTypeID = "asr_id"
|
||||
}
|
||||
if modelType == "image2text" {
|
||||
modelTypeID = "img2txt_id"
|
||||
}
|
||||
if modelType == "tts" {
|
||||
modelTypeID = "tts_id"
|
||||
}
|
||||
if modelTypeID == "" {
|
||||
return fmt.Errorf("model type %s is invalid", modelType)
|
||||
}
|
||||
|
||||
defaultModel := fmt.Sprintf("%s@%s@%s", modelName, modelInstance, modelProvider)
|
||||
err = s.tenantDAO.Update(ownedTenant.TenantID, map[string]interface{}{
|
||||
modelTypeID: defaultModel,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user