diff --git a/conf/models/zhipu-ai.json b/conf/models/zhipu-ai.json
index 34f856c139..87fcb43448 100644
--- a/conf/models/zhipu-ai.json
+++ b/conf/models/zhipu-ai.json
@@ -26,6 +26,15 @@
             ],
             "features": {}
         },
+        {
+            "name": "glm-4.6v-Flash",
+            "max_tokens": 128000,
+            "model_types": [
+                "chat",
+                "image2text"
+            ],
+            "features": {}
+        },
         {
             "name": "glm-4.5-x",
             "max_tokens": 128000,
diff --git a/internal/cli/client.go b/internal/cli/client.go
index 1a6cc0326d..39ae488d28 100644
--- a/internal/cli/client.go
+++ b/internal/cli/client.go
@@ -248,6 +248,10 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
 		return c.UseModel(cmd)
 	case "show_current_model":
 		return c.ShowCurrentModel(cmd)
+	case "set_default_model":
+		return c.SetDefaultModel(cmd)
+	case "list_user_default_models":
+		return c.ListDefaultModels(cmd)
 	// Dataset, metadata commands
 	case "create_dataset_table":
 		return c.CreateDatasetInDocEngine(cmd)
diff --git a/internal/cli/common_command.go b/internal/cli/common_command.go
index 52aacfea08..695e559178 100644
--- a/internal/cli/common_command.go
+++ b/internal/cli/common_command.go
@@ -373,6 +373,79 @@ func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) {
 	return &result, nil
 }
 
+// SetDefaultModel sets the default model for one model type.
+// It reads model_type, model_provider, model_instance and model_name from
+// cmd.Params and sends them to the server as a PATCH /models request.
+func (c *RAGFlowClient) SetDefaultModel(cmd *Command) (ResponseIf, error) {
+	modelType, ok := cmd.Params["model_type"].(string)
+	if !ok {
+		return nil, fmt.Errorf("model_type not provided")
+	}
+	modelProvider, ok := cmd.Params["model_provider"].(string)
+	if !ok {
+		return nil, fmt.Errorf("model_provider not provided")
+	}
+	modelInstance, ok := cmd.Params["model_instance"].(string)
+	if !ok {
+		return nil, fmt.Errorf("model_instance not provided")
+	}
+	modelName, ok := cmd.Params["model_name"].(string)
+	if !ok {
+		return nil, fmt.Errorf("model_name not provided")
+	}
+
+	payload := map[string]interface{}{
+		"model_type":     modelType,
+		"model_provider": modelProvider,
+		"model_instance": modelInstance,
+		"model_name":     modelName,
+	}
+
+	resp, err := c.HTTPClient.Request("PATCH", "/models", true, "web", nil, payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to set default model: %w", err)
+	}
+
+	if resp.StatusCode != 200 {
+		return nil, fmt.Errorf("failed to set default model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
+	}
+
+	var result SimpleResponse
+	if err = json.Unmarshal(resp.Body, &result); err != nil {
+		return nil, fmt.Errorf("failed to set default model: invalid JSON (%w)", err)
+	}
+
+	if result.Code != 0 {
+		return nil, fmt.Errorf("%s", result.Message)
+	}
+	result.Duration = resp.Duration
+	return &result, nil
+}
+
+// ListDefaultModels fetches the current default models with GET /models and
+// returns the decoded response.
+func (c *RAGFlowClient) ListDefaultModels(cmd *Command) (ResponseIf, error) {
+	resp, err := c.HTTPClient.Request("GET", "/models", true, "web", nil, nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to list default models: %w", err)
+	}
+
+	if resp.StatusCode != 200 {
+		return nil, fmt.Errorf("failed to list default models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
+	}
+
+	var result CommonResponse
+	if err = json.Unmarshal(resp.Body, &result); err != nil {
+		return nil, fmt.Errorf("failed to list default models: invalid JSON (%w)", err)
+	}
+
+	if result.Code != 0 {
+		return nil, fmt.Errorf("%s", result.Message)
+	}
+	result.Duration = resp.Duration
+	return &result, nil
+}
+
 // readPassword reads password from terminal without echoing
 func ReadPassword() (string, error) {
 	if !term.IsTerminal(int(os.Stdin.Fd())) {
diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go
index f74dae494b..43c9437824 100644
--- a/internal/cli/user_parser.go
+++ b/internal/cli/user_parser.go
@@ -1597,35 +1597,49 @@ func (p *Parser) parseSetVariable() (*Command, error) {
 func (p *Parser) parseSetDefault() (*Command, error) {
 	p.nextToken() // consume DEFAULT
 
-	var modelType, modelID string
+	var modelType, modelProvider, modelInstance, modelName string
+	var err error
 
 	switch p.curToken.Type {
 	case TokenLLM:
-		modelType = "llm_id"
+		modelType = "chat"
 	case TokenVLM:
-		modelType = "img2txt_id"
+		modelType = "image2text"
 	case TokenEmbedding:
-		modelType = "embd_id"
+		modelType = "embedding"
 	case TokenReranker:
-		modelType = "reranker_id"
+		modelType = "rerank"
 	case TokenASR:
-		modelType = "asr_id"
+		modelType = "asr"
 	case TokenTTS:
-		modelType = "tts_id"
+		modelType = "tts"
 	default:
 		return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
 	}
 
 	p.nextToken()
-	id, err := p.parseQuotedString()
+	modelProvider, err = p.parseQuotedString()
+	if err != nil {
+		return nil, err
+	}
+
+	p.nextToken()
+	modelInstance, err = p.parseQuotedString()
+	if err != nil {
+		return nil, err
+	}
+
+	p.nextToken()
+	modelName, err = p.parseQuotedString()
 	if err != nil {
 		return nil, err
 	}
-	modelID = id
 
 	cmd := NewCommand("set_default_model")
 	cmd.Params["model_type"] = modelType
-	cmd.Params["model_id"] = modelID
+	cmd.Params["model_provider"] = modelProvider
+	cmd.Params["model_instance"] = modelInstance
+	cmd.Params["model_name"] = modelName
 
 	p.nextToken()
 	// Semicolon is optional for UNSET TOKEN
@@ -2601,7 +2615,6 @@ func (p *Parser) parseRemoveTags() (*Command, error) {
 	return cmd, nil
 }
 
-
 // parseRemoveChunk parses:
 // - REMOVE CHUNKS 'chunk_id1', 'chunk_id2' FROM DOCUMENT 'doc_id';
 // - REMOVE ALL CHUNKS FROM DOCUMENT 'doc_id';
diff --git a/internal/entity/model.go b/internal/entity/model.go
index 0b8f208cb3..0017b65663 100644
--- a/internal/entity/model.go
+++ b/internal/entity/model.go
@@ -149,7 +149,7 @@ type Provider struct {
 	Tags        string           `json:"tags"`
 	URL         string           `json:"url"`
 	URLSuffix   models.URLSuffix `json:"url_suffix"`
-	Models      []Model          `json:"models"`
+	Models      []*Model         `json:"models"`
 	ModelDriver models.ModelDriver
 }
 
@@ -547,7 +547,7 @@ func (pm *ProviderManager) FindProvider(name string) *Provider {
 func (pm *ProviderManager) findModel(provider *Provider, modelName string) *Model {
 	for i := range provider.Models {
 		if strings.EqualFold(provider.Models[i].Name, modelName) {
-			return &provider.Models[i]
+			return provider.Models[i]
 		}
 	}
 	return nil
diff --git a/internal/handler/tenant.go b/internal/handler/tenant.go
index 79ffdda090..b01515af63 100644
--- a/internal/handler/tenant.go
+++ b/internal/handler/tenant.go
@@ -42,6 +42,86 @@ func NewTenantHandler(tenantService *service.TenantService, userService *service
 	}
 }
 
+// GetModels responds with the default models configured for the current
+// user's tenant.
+func (h *TenantHandler) GetModels(c *gin.Context) {
+	user, errorCode, errorMessage := GetUser(c)
+	if errorCode != common.CodeSuccess {
+		jsonError(c, errorCode, errorMessage)
+		return
+	}
+
+	defaultModels, err := h.tenantService.ListTenantDefaultModels(user.ID)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"code":    common.CodeExceptionError,
+			"message": err.Error(),
+			"data":    false,
+		})
+		return
+	}
+
+	if defaultModels == nil {
+		c.JSON(http.StatusOK, gin.H{
+			"code":    common.CodeDataError,
+			"message": "No default models",
+			"data":    nil,
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"code":    common.CodeSuccess,
+		"message": "success",
+		"data":    defaultModels,
+	})
+}
+
+// SetModelRequest is the JSON body accepted by SetModels; every field is required.
+type SetModelRequest struct {
+	ModelProvider string `json:"model_provider" binding:"required"`
+	ModelInstance string `json:"model_instance" binding:"required"`
+	ModelName     string `json:"model_name" binding:"required"`
+	ModelType     string `json:"model_type" binding:"required"`
+}
+
+// SetModels sets the default model of the requested model type for the
+// current user's tenant.
+func (h *TenantHandler) SetModels(c *gin.Context) {
+	user, errorCode, errorMessage := GetUser(c)
+	if errorCode != common.CodeSuccess {
+		jsonError(c, errorCode, errorMessage)
+		return
+	}
+
+	// Parse request body (same as Python get_request_json())
+	var req SetModelRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"code":    common.CodeBadRequest,
+			"data":    nil,
+			"message": "Invalid request body: " + err.Error(),
+		})
+		return
+	}
+
+	err := h.tenantService.SetTenantDefaultModels(user.ID, req.ModelProvider, req.ModelInstance, req.ModelName, req.ModelType)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"code":    common.CodeExceptionError,
+			"message": err.Error(),
+			"data":    false,
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"code":    common.CodeSuccess,
+		"message": "success",
+		"data":    nil,
+	})
+}
+
 // TenantInfo get tenant information
 // @Summary Get Tenant Information
 // @Description Get current user's tenant information (owner tenant)
diff --git a/internal/router/router.go b/internal/router/router.go
index def6a96d83..bc979b8b70 100644
--- a/internal/router/router.go
+++ b/internal/router/router.go
@@ -219,6 +219,12 @@ func (r *Router) Setup(engine *gin.Engine) {
 			provider.POST("/:provider_name/instances/:instance_name/models/:model_name", r.providerHandler.ChatToModel)
 		}
 
+		model := v1.Group("/models")
+		{
+			model.GET("/", r.tenantHandler.GetModels)
+			model.PATCH("/", r.tenantHandler.SetModels)
+		}
+
 		system := v1.Group("/system")
 		{
 			system.GET("/version", r.systemHandler.GetVersion)
diff --git a/internal/service/tenant.go b/internal/service/tenant.go
index 1415eabd32..5b8a2d33a5 100644
--- a/internal/service/tenant.go
+++ b/internal/service/tenant.go
@@ -28,17 +28,27 @@ import (
 
 // TenantService tenant service
 type TenantService struct {
-	tenantDAO     *dao.TenantDAO
-	userTenantDAO *dao.UserTenantDAO
-	docEngine     engine.DocEngine
+	tenantDAO            *dao.TenantDAO
+	userTenantDAO        *dao.UserTenantDAO
+	modelProviderDAO     *dao.TenantModelProviderDAO
+	modelInstanceDAO     *dao.TenantModelInstanceDAO
+	modelDAO             *dao.TenantModelDAO
+	modelGroupDAO        *dao.TenantModelGroupDAO
+	modelGroupMappingDAO *dao.TenantModelGroupMappingDAO
+	docEngine            engine.DocEngine
 }
 
 // NewTenantService create tenant service
 func NewTenantService() *TenantService {
 	return &TenantService{
-		tenantDAO:     dao.NewTenantDAO(),
-		userTenantDAO: dao.NewUserTenantDAO(),
-		docEngine:     engine.Get(),
+		tenantDAO:            dao.NewTenantDAO(),
+		userTenantDAO:        dao.NewUserTenantDAO(),
+		modelProviderDAO:     dao.NewTenantModelProviderDAO(),
+		modelInstanceDAO:     dao.NewTenantModelInstanceDAO(),
+		modelDAO:             dao.NewTenantModelDAO(),
+		modelGroupDAO:        dao.NewTenantModelGroupDAO(),
+		modelGroupMappingDAO: dao.NewTenantModelGroupMappingDAO(),
+		docEngine:            engine.Get(),
 	}
 }
 
@@ -282,3 +292,280 @@ func (s *TenantService) DeleteMetadataInDocEngine(tenantID string) (common.Error
 
 	return common.CodeSuccess, nil
 }
+
+// ModelItem describes one default-model entry returned by ListTenantDefaultModels.
+type ModelItem struct {
+	ModelProvider *string `json:"model_provider"`
+	ModelInstance *string `json:"model_instance"`
+	ModelName     *string `json:"model_name"`
+	ModelType     string  `json:"model_type"`
+	Enable        bool    `json:"enable"`
+}
+
+// DefaultModelResponse wraps a list of default-model entries.
+type DefaultModelResponse struct {
+	Models []ModelItem `json:"models,omitempty"`
+	//TenantID string `json:"tenant_id"`
+	//ChatModelProvider *string `json:"chat_model_provider"`
+	//ChatModelInstance *string `json:"chat_model_instance"`
+	//ChatModelName *string `json:"chat_model_name"`
+	//ChatModelEnable bool `json:"chat_model_enable"`
+	//EmbeddingModelProvider *string `json:"embedding_model_provider"`
+	//EmbeddingModelInstance *string `json:"embedding_model_instance"`
+	//EmbeddingModelName *string `json:"embedding_model_name"`
+	//EmbeddingModelEnable bool `json:"embedding_model_enable"`
+	//RerankModelProvider *string `json:"rerank_model_provider"`
+	//RerankModelInstance *string `json:"rerank_model_instance"`
+	//RerankModelName *string `json:"rerank_model_name"`
+	//RerankModelEnable bool `json:"rerank_model_enable"`
+	//ASRModelProvider *string `json:"asr_model_provider"`
+	//ASRModelInstance *string `json:"asr_model_instance"`
+	//ASRModelName *string `json:"asr_model_name"`
+	//ASREnable bool `json:"asr_enable"`
+	//Image2TextModelProvider *string `json:"image2text_model_provider"`
+	//Image2TextModelInstance *string `json:"image2text_model_instance"`
+	//Image2TextModelName *string `json:"image2text_model_name"`
+	//Image2TextModelEnable bool `json:"image2text_model_enable"`
+	//TTSModelProvider *string `json:"tts_model_provider"`
+	//TTSModelInstance *string `json:"tts_model_instance"`
+	//TTSModelName *string `json:"tts_model_name"`
+	//TTSModelEnable bool `json:"tts_model_enable"`
+}
+
+// GetModelInfo parses a stored default-model string and resolves it for a tenant.
+// The string is either "modelName@instanceName@providerName" or
+// "modelName@providerName"; the two-part form uses the instance "default".
+// It returns the provider, instance and model names plus an enable flag, or an
+// error when the string is malformed, a provider/instance/schema lookup fails,
+// or the model does not support modelType.
+func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, modelType string) (*string, *string, *string, bool, error) {
+	parts := strings.Split(defaultModel, "@")
+	var providerName *string
+	var instanceName *string
+	var modelName *string
+	switch len(parts) {
+	case 3:
+		providerName = &parts[2]
+		instanceName = &parts[1]
+		modelName = &parts[0]
+	case 2:
+		providerName = &parts[1]
+		instanceName = new(string)
+		*instanceName = "default"
+		modelName = &parts[0]
+	default:
+		return nil, nil, nil, false, fmt.Errorf("invalid model string: %s", defaultModel)
+	}
+
+	// Check if the provider and instance exists
+	modelProvider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, *providerName)
+	if err != nil {
+		return nil, nil, nil, false, err
+	}
+
+	modelInstance, err := s.modelInstanceDAO.GetByProviderIDAndInstanceName(modelProvider.ID, *instanceName)
+	if err != nil {
+		return nil, nil, nil, false, err
+	}
+
+	modelSchema, err := dao.GetModelProviderManager().GetModelByName(*providerName, *modelName)
+	if err != nil {
+		return nil, nil, nil, false, err
+	}
+
+	if !modelSchema.ModelTypeMap[modelType] {
+		return nil, nil, nil, false, fmt.Errorf("model %s isn't a %s model", *modelName, modelType)
+	}
+
+	var modelEntity *entity.TenantModel
+	modelEntity, err = s.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(modelProvider.ID, modelInstance.ID, *modelName)
+	// A missing row is not an error here; any other lookup failure is.
+	if err != nil && !strings.Contains(err.Error(), "record not found") {
+		return nil, nil, nil, false, err
+	}
+
+	// NOTE(review): enable is true only when no tenant-model row exists, which
+	// suggests a stored row marks the model as disabled — confirm.
+	enable := modelEntity == nil
+
+	return providerName, instanceName, modelName, enable, nil
+}
+
+// ListTenantDefaultModels returns the default models of the first tenant found
+// for userID. Model types whose stored default cannot be resolved by
+// GetModelInfo are left out of the result; (nil, nil) is returned when the user
+// has no tenant.
+func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, error) {
+	tenantInfos, err := s.tenantDAO.GetInfoByUserID(userID)
+	if err != nil {
+		return nil, err
+	}
+	if len(tenantInfos) == 0 {
+		return nil, nil // No tenant found (should not happen for valid user)
+	}
+
+	ownedTenant := tenantInfos[0]
+
+	var result []ModelItem
+
+	defaultChatModelProvider, defaultChatModelInstance, defaultChatModelName, defaultChatModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.LLMID, "chat")
+	if err == nil {
+		result = append(result, ModelItem{
+			ModelProvider: defaultChatModelProvider,
+			ModelInstance: defaultChatModelInstance,
+			ModelName:     defaultChatModelName,
+			ModelType:     "llm",
+			Enable:        defaultChatModelEnable,
+		})
+	}
+
+	defaultEmbeddingModelProvider, defaultEmbeddingModelInstance, defaultEmbeddingModelName, defaultEmbeddingModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.EmbDID, "embedding")
+	if err == nil {
+		result = append(result, ModelItem{
+			ModelProvider: defaultEmbeddingModelProvider,
+			ModelInstance: defaultEmbeddingModelInstance,
+			ModelName:     defaultEmbeddingModelName,
+			ModelType:     "embedding",
+			Enable:        defaultEmbeddingModelEnable,
+		})
+	}
+
+	defaultRerankModelProvider, defaultRerankModelInstance, defaultRerankModelName, defaultRerankModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.RerankID, "rerank")
+	if err == nil {
+		result = append(result, ModelItem{
+			ModelProvider: defaultRerankModelProvider,
+			ModelInstance: defaultRerankModelInstance,
+			ModelName:     defaultRerankModelName,
+			ModelType:     "rerank",
+			Enable:        defaultRerankModelEnable,
+		})
+	}
+
+	defaultASRModelProvider, defaultASRModelInstance, defaultASRModelName, defaultASREnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.ASRID, "asr")
+	if err == nil {
+		result = append(result, ModelItem{
+			ModelProvider: defaultASRModelProvider,
+			ModelInstance: defaultASRModelInstance,
+			ModelName:     defaultASRModelName,
+			ModelType:     "asr",
+			Enable:        defaultASREnable,
+		})
+	}
+
+	defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "image2text")
+	if err == nil {
+		result = append(result, ModelItem{
+			ModelProvider: defaultImage2TextModelProvider,
+			ModelInstance: defaultImage2TextModelInstance,
+			ModelName:     defaultImage2TextModelName,
+			ModelType:     "image2text",
+			Enable:        defaultImage2TextModelEnable,
+		})
+	}
+
+	if ownedTenant.TTSID == nil {
+		return result, nil
+	}
+
+	defaultTTSModelProvider, defaultTTSModelInstance, defaultTTSModelName, defaultTTSModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, *ownedTenant.TTSID, "tts")
+	if err == nil {
+		result = append(result, ModelItem{
+			ModelProvider: defaultTTSModelProvider,
+			ModelInstance: defaultTTSModelInstance,
+			ModelName:     defaultTTSModelName,
+			ModelType:     "tts",
+			Enable:        defaultTTSModelEnable,
+		})
+	}
+
+	return result, nil
+}
+
+// checkModelAvailable verifies that modelName may be used as the tenant's
+// default model of modelType: the provider and instance must exist for the
+// tenant, and the provider schema must list modelType for the model.
+// It returns nil when the model is usable and an error otherwise.
+func (s *TenantService) checkModelAvailable(tenantID, providerName, instanceName, modelName, modelType string) error {
+	// Check if the provider and instance exists
+	modelProvider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
+	if err != nil {
+		return err
+	}
+
+	modelInstance, err := s.modelInstanceDAO.GetByProviderIDAndInstanceName(modelProvider.ID, instanceName)
+	if err != nil {
+		return err
+	}
+
+	modelSchema, err := dao.GetModelProviderManager().GetModelByName(providerName, modelName)
+	if err != nil {
+		return err
+	}
+
+	if !modelSchema.ModelTypeMap[modelType] {
+		return fmt.Errorf("model %s isn't a %s model", modelName, modelType)
+	}
+
+	var modelEntity *entity.TenantModel
+	modelEntity, err = s.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(modelProvider.ID, modelInstance.ID, modelName)
+	if err != nil {
+		// No stored row for this model: nothing marks it unavailable.
+		if strings.Contains(err.Error(), "record not found") {
+			return nil
+		}
+		return fmt.Errorf("model %s isn't available: %w", modelName, err)
+	}
+	// err is nil here, so it must not be dereferenced; a stored row is treated
+	// as "not available", as the original condition intended.
+	if modelEntity != nil {
+		return fmt.Errorf("model %s isn't available", modelName)
+	}
+
+	return nil
+}
+
+// SetTenantDefaultModels stores "modelName@modelInstance@modelProvider" as the
+// default model of modelType for the first tenant found for userID.
+// modelType must be one of chat, embedding, rerank, asr, image2text or tts.
+// It returns an error when the user has no tenant, the model is not
+// available, modelType is unknown, or the update fails.
+func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInstance, modelName, modelType string) error {
+	tenantInfos, err := s.tenantDAO.GetInfoByUserID(userID)
+	if err != nil {
+		return err
+	}
+	if len(tenantInfos) == 0 {
+		// Report the no-op instead of claiming the default was set.
+		return fmt.Errorf("no tenant found for user %s", userID)
+	}
+
+	ownedTenant := tenantInfos[0]
+	err = s.checkModelAvailable(ownedTenant.TenantID, modelProvider, modelInstance, modelName, modelType)
+	if err != nil {
+		return err
+	}
+
+	var modelTypeID string
+	switch modelType {
+	case "chat":
+		modelTypeID = "llm_id"
+	case "embedding":
+		modelTypeID = "embd_id"
+	case "rerank":
+		modelTypeID = "rerank_id"
+	case "asr":
+		modelTypeID = "asr_id"
+	case "image2text":
+		modelTypeID = "img2txt_id"
+	case "tts":
+		modelTypeID = "tts_id"
+	default:
+		return fmt.Errorf("model type %s is invalid", modelType)
+	}
+
+	defaultModel := fmt.Sprintf("%s@%s@%s", modelName, modelInstance, modelProvider)
+	// Return the update error so a failed write is not reported as success.
+	return s.tenantDAO.Update(ownedTenant.TenantID, map[string]interface{}{
+		modelTypeID: defaultModel,
+	})
+}