From 261be81127158ce76209f99e3977be5a975340c5 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Wed, 29 Apr 2026 19:18:49 +0800 Subject: [PATCH] Go: add drop instance models (#14485) ### What problem does this PR solve? 1. drop instance model 2. Fix issue of drop instance but not drop models. ### Type of change - [x] New Feature (non-breaking change which adds functionality) Signed-off-by: Jin Hai --- internal/cli/admin_parser.go | 2 - internal/cli/client.go | 2 + internal/cli/user_command.go | 54 +++++++++++++- internal/cli/user_parser.go | 113 ++++++++++++++---------------- internal/dao/tenant_model.go | 10 +++ internal/entity/models/types.go | 1 + internal/handler/providers.go | 51 +++++++++++++- internal/router/router.go | 1 + internal/service/model_service.go | 77 +++++++++++++++++--- 9 files changed, 237 insertions(+), 74 deletions(-) diff --git a/internal/cli/admin_parser.go b/internal/cli/admin_parser.go index 723aad512a..ef0394b189 100644 --- a/internal/cli/admin_parser.go +++ b/internal/cli/admin_parser.go @@ -700,8 +700,6 @@ func (p *Parser) parseAdminDropCommand() (*Command, error) { return p.parseDropUser() case TokenRole: return p.parseDropRole() - case TokenModel: - return p.parseDropModelProvider() case TokenDataset: return p.parseDropDataset() case TokenChat: diff --git a/internal/cli/client.go b/internal/cli/client.go index acd8eba175..f92aeb2d9c 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -242,6 +242,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.AlterProviderInstance(cmd) case "drop_provider_instance": return c.DropProviderInstance(cmd) + case "drop_instance_model": + return c.DropInstanceModel(cmd) case "enable_model": return c.EnableOrDisableModel(cmd, "enable") case "disable_model": diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index 2e30b52adb..c78a102960 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -1383,6 +1383,56 
@@ func (c *RAGFlowClient) DropProviderInstance(cmd *Command) (ResponseIf, error) { return &result, nil } +// DropInstanceModel deletes a model from a provider instance; only works for locally deployed models +// DROP MODEL 'model_name' FROM 'provider_name' 'instance_name' +func (c *RAGFlowClient) DropInstanceModel(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + instanceName, ok := cmd.Params["instance_name"].(string) + if !ok { + return nil, fmt.Errorf("instance name not provided") + } + + providerName, ok := cmd.Params["provider_name"].(string) + if !ok { + return nil, fmt.Errorf("provider name not provided") + } + + modelName, ok := cmd.Params["model_name"].(string) + if !ok { + return nil, fmt.Errorf("model name not provided") + } + + payload := map[string]interface{}{ + "models": []string{modelName}, + } + + url := fmt.Sprintf("/providers/%s/instances/%s/models", providerName, instanceName) + + resp, err := c.HTTPClient.Request("DELETE", url, true, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to drop instance model: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to drop instance model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("drop instance model failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + func (c *RAGFlowClient) ListInstanceModels(cmd *Command) (ResponseIf, error) { if c.ServerType != "user" { return nil, fmt.Errorf("this command is only allowed in USER mode") @@ -1722,7 +1772,7 @@ func (c *RAGFlowClient) AddCustomModel(cmd *Command) (ResponseIf, error) { } // chat, vision, embedding, rerank, tts, asr, ocr - modelType, ok := cmd.Params["model_type"].(string) + modelTypes, ok := cmd.Params["model_types"].([]string) if !ok { 
return nil, fmt.Errorf("model type not provided") } @@ -1738,7 +1788,7 @@ func (c *RAGFlowClient) AddCustomModel(cmd *Command) (ResponseIf, error) { "provider_name": providerName, "instance_name": instanceName, "model_name": modelName, - "model_type": modelType, + "model_types": modelTypes, "max_tokens": maxTokens, } diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index a31a374ec5..43317fe6ec 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -772,7 +772,7 @@ func (p *Parser) parseAddModel() (*Command, error) { } p.nextToken() - modelType := "" + var modelTypes []string var supportThink *bool = nil maxTokens := 0 if p.curToken.Type == TokenWith { @@ -789,46 +789,25 @@ func (p *Parser) parseAddModel() (*Command, error) { *supportThink = true case TokenVision: p.nextToken() - if modelType != "" { - return nil, fmt.Errorf("model type is %s, attempt to change to vision", modelType) - } - modelType = "vision" + modelTypes = append(modelTypes, "vision") case TokenChat: p.nextToken() - if modelType != "" { - return nil, fmt.Errorf("model type is %s, attempt to change to chat", modelType) - } - modelType = "chat" + modelTypes = append(modelTypes, "chat") case TokenEmbedding: - if modelType != "" { - return nil, fmt.Errorf("model type is %s, attempt to change to embedding", modelType) - } p.nextToken() - modelType = "embedding" + modelTypes = append(modelTypes, "embedding") case TokenRerank: - if modelType != "" { - return nil, fmt.Errorf("model type is %s, attempt to change to rerank", modelType) - } p.nextToken() - modelType = "rerank" + modelTypes = append(modelTypes, "rerank") case TokenOCR: - if modelType != "" { - return nil, fmt.Errorf("model type is %s, attempt to change to OCR", modelType) - } p.nextToken() - modelType = "ocr" + modelTypes = append(modelTypes, "ocr") case TokenTTS: - if modelType != "" { - return nil, fmt.Errorf("model type is %s, attempt to change to TTS", modelType) - } p.nextToken() - modelType = 
"tts" + modelTypes = append(modelTypes, "tts") case TokenASR: - if modelType != "" { - return nil, fmt.Errorf("model type is %s, attempt to change to ASR", modelType) - } p.nextToken() - modelType = "asr" + modelTypes = append(modelTypes, "asr") case TokenTokens: p.nextToken() // pass TOKENS if maxTokens != 0 { @@ -854,7 +833,7 @@ func (p *Parser) parseAddModel() (*Command, error) { cmd := NewCommand("add_custom_model") cmd.Params["model_name"] = modelName - cmd.Params["model_type"] = modelType + cmd.Params["model_types"] = modelTypes cmd.Params["provider_name"] = providerName cmd.Params["instance_name"] = instanceName if supportThink != nil { @@ -862,12 +841,6 @@ func (p *Parser) parseAddModel() (*Command, error) { } cmd.Params["max_tokens"] = maxTokens - if modelType != "chat" && modelType != "vision" { - if supportThink != nil && *supportThink { - return nil, fmt.Errorf("think not supported for model type %s", modelType) - } - } - return cmd, nil } @@ -951,8 +924,6 @@ func (p *Parser) parseDropCommand() (*Command, error) { return p.parseDropUser() case TokenRole: return p.parseDropRole() - case TokenModel: - return p.parseDropModelProvider() case TokenDataset: return p.parseDropDataset() case TokenChat: @@ -965,6 +936,8 @@ func (p *Parser) parseDropCommand() (*Command, error) { return p.parseDropMetadataTable() case TokenInstance: return p.parseDropInstance() + case TokenModel: + return p.parseDropInstanceModel() default: return nil, fmt.Errorf("unknown DROP target: %s", p.curToken.Value) } @@ -1099,29 +1072,6 @@ func (p *Parser) parseDropRole() (*Command, error) { return cmd, nil } -func (p *Parser) parseDropModelProvider() (*Command, error) { - p.nextToken() // consume MODEL - if p.curToken.Type != TokenProvider { - return nil, fmt.Errorf("expected PROVIDER") - } - p.nextToken() - - providerName, err := p.parseQuotedString() - if err != nil { - return nil, err - } - - cmd := NewCommand("drop_model_provider") - cmd.Params["provider_name"] = providerName - - 
p.nextToken() - // Semicolon is optional for UNSET TOKEN - if p.curToken.Type == TokenSemicolon { - p.nextToken() - } - return cmd, nil -} - // parseDeleteProvider parses DELETE PROVIDER command func (p *Parser) parseDeleteProvider() (*Command, error) { p.nextToken() // consume PROVIDER @@ -1610,6 +1560,47 @@ func (p *Parser) parseDropInstance() (*Command, error) { return cmd, nil } +// parseDropInstanceModel parses DROP MODEL 'model_name' FROM 'provider_name' 'instance_name' command +// Only works for locally deployed models +func (p *Parser) parseDropInstanceModel() (*Command, error) { + p.nextToken() // consume MODEL + + modelName, err := p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("expected model name: %w", err) + } + p.nextToken() + + if p.curToken.Type != TokenFrom { + return nil, fmt.Errorf("expected FROM") + } + p.nextToken() + + providerName, err := p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("expected provider name after FROM: %w", err) + } + p.nextToken() + + instanceName, err := p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("expected instance name after provider name: %w", err) + } + p.nextToken() + + cmd := NewCommand("drop_instance_model") + cmd.Params["instance_name"] = instanceName + cmd.Params["provider_name"] = providerName + cmd.Params["model_name"] = modelName + + // The instance name token was already consumed above; do not skip another token here. + // Semicolon is optional + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + func (p *Parser) parseGrantCommand() (*Command, error) { p.nextToken() // consume GRANT diff --git a/internal/dao/tenant_model.go b/internal/dao/tenant_model.go index bb3b4f41ba..fd69c3ca41 100644 --- a/internal/dao/tenant_model.go +++ b/internal/dao/tenant_model.go @@ -37,6 +37,16 @@ func (dao *TenantModelDAO) DeleteByModelID(modelID string) (int64, error) { return result.RowsAffected, result.Error } +func (dao *TenantModelDAO) DeleteByProviderIDAndInstanceID(providerID, instanceID string) (int64, error) { + result := 
DB.Unscoped().Where("provider_id = ? AND instance_id = ?", providerID, instanceID).Delete(&entity.TenantModel{}) + return result.RowsAffected, result.Error +} + +func (dao *TenantModelDAO) DeleteByProviderIDAndInstanceIDAndModelName(providerID, instanceID, modelName string) (int64, error) { + result := DB.Unscoped().Where("provider_id = ? AND instance_id = ? AND model_name = ?", providerID, instanceID, modelName).Delete(&entity.TenantModel{}) + return result.RowsAffected, result.Error +} + // GetByID get tenant model by primary key (id) func (dao *TenantModelDAO) GetByID(id string) (*entity.TenantModel, error) { var model entity.TenantModel diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index 90a9a69aee..c12f37c5f0 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -50,6 +50,7 @@ type URLSuffix struct { type ChatConfig struct { Stream *bool + Vision *bool Thinking *bool MaxTokens *int Temperature *float64 diff --git a/internal/handler/providers.go b/internal/handler/providers.go index 5104076d77..4db54759df 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -682,7 +682,7 @@ func (h *ProviderHandler) AddCustomModel(c *gin.Context) { return } - if req.ModelType == "" { + if len(req.ModelTypes) == 0 { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "Model type is required", @@ -707,6 +707,54 @@ func (h *ProviderHandler) AddCustomModel(c *gin.Context) { } +type DropInstanceModelRequest struct { + Models []string `json:"models" binding:"required"` +} + +func (h *ProviderHandler) DropInstanceModels(c *gin.Context) { + providerName := c.Param("provider_name") + if providerName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + instanceName := c.Param("instance_name") + if instanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + 
return + } + + var req DropInstanceModelRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + userID := c.GetString("user_id") + + code, err := h.modelProviderService.DropInstanceModels(providerName, instanceName, userID, req.Models) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + }) +} + type ChatToModelRequest struct { ProviderName *string `json:"provider_name"` InstanceName *string `json:"instance_name"` @@ -768,6 +816,7 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) { chatConfig := models.ChatConfig{ Thinking: &req.Thinking, Stream: &req.Stream, + Vision: nil, Stop: &[]string{}, DoSample: nil, MaxTokens: nil, diff --git a/internal/router/router.go b/internal/router/router.go index ab8c44197e..8c8d30dca2 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -219,6 +219,7 @@ func (r *Router) Setup(engine *gin.Engine) { provider.GET("/:provider_name/instances/:instance_name/models", r.providerHandler.ListInstanceModels) provider.PATCH("/:provider_name/instances/:instance_name/models/*model_name", r.providerHandler.EnableOrDisableModel) provider.POST("/:provider_name/instances/:instance_name/models", r.providerHandler.AddCustomModel) + provider.DELETE("/:provider_name/instances/:instance_name/models", r.providerHandler.DropInstanceModels) v1.POST("/chat/completions", r.providerHandler.ChatToModel) } diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 043b5ff4d7..7b95b745c1 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -478,7 +478,22 @@ func (m *ModelProviderService) DropProviderInstances(providerName, userID string } for _, instanceName := range instances { - count, err := 
m.modelInstanceDAO.DeleteByProviderIDAndInstanceName(provider.ID, instanceName) + // Get model instance + var tenantModelInstance *entity.TenantModelInstance + tenantModelInstance, err = m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return common.CodeServerError, err + } + + // Delete all models of this instance + var count int64 = 0 + _, err = m.modelDAO.DeleteByProviderIDAndInstanceID(provider.ID, tenantModelInstance.ID) + if err != nil { + return common.CodeServerError, err + } + + // Delete model instance + count, err = m.modelInstanceDAO.DeleteByProviderIDAndInstanceName(provider.ID, instanceName) if err != nil { return common.CodeServerError, err } @@ -491,6 +506,48 @@ func (m *ModelProviderService) DropProviderInstances(providerName, userID string return common.CodeSuccess, nil } +func (m *ModelProviderService) DropInstanceModels(providerName, instanceName, userID string, models []string) (common.ErrorCode, error) { + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return common.CodeServerError, err + } + + if len(tenants) == 0 { + return common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return common.CodeServerError, err + } + + var modelInstance *entity.TenantModelInstance + modelInstance, err = m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return common.CodeServerError, err + } + + for _, modelName := range models { + // Delete the named model from this instance + var count int64 = 0 + count, err = m.modelDAO.DeleteByProviderIDAndInstanceIDAndModelName(provider.ID, modelInstance.ID, modelName) + if err != nil { + return common.CodeServerError, err + } + + if count == 0 { + return common.CodeNotFound, 
fmt.Errorf("model: %s not found", modelName) + } + } + + return common.CodeSuccess, nil +} + func (m *ModelProviderService) ListInstanceModels(providerName, instanceName, userID string) ([]map[string]interface{}, error) { // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") @@ -693,6 +750,9 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName apiConfig.Region = &region apiConfig.ApiKey = &instance.APIKey + modelTypes := extra["model_types"] + _ = modelTypes // TODO: model types are looked up but not used yet; debug print removed + modelConfig.ModelClass = &providerInfo.Class newURL := map[string]string{ @@ -891,12 +951,12 @@ func (m *ModelProviderService) GetChatModel(tenantID, compositeModelName string) } type AddCustomModelRequest struct { - ProviderName string `json:"provider_name"` - InstanceName string `json:"instance_name"` - ModelName string `json:"model_name"` - ModelType string `json:"model_type"` - MaxTokens int `json:"max_tokens"` - Thinking *bool `json:"thinking"` + ProviderName string `json:"provider_name"` + InstanceName string `json:"instance_name"` + ModelName string `json:"model_name"` + ModelTypes []string `json:"model_types"` + MaxTokens int `json:"max_tokens"` + Thinking *bool `json:"thinking"` } func (m *ModelProviderService) AddCustomModel(request *AddCustomModelRequest, userID string) (common.ErrorCode, error) { @@ -938,6 +998,7 @@ func (m *ModelProviderService) AddCustomModel(request *AddCustomModelRequest, us if request.Thinking != nil { extra["thinking"] = *request.Thinking } + extra["model_types"] = request.ModelTypes // convert extra to string extraByte, err := json.Marshal(extra) if err != nil { @@ -948,7 +1009,7 @@ func (m *ModelProviderService) AddCustomModel(request *AddCustomModelRequest, us model := &entity.TenantModel{ ID: modelID, ModelName: request.ModelName, - ModelType: request.ModelType, + ModelType: request.ModelTypes[0], ProviderID: provider.ID, InstanceID: instance.ID, Status: "active",