diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index c628e82b4e..48a3639cb3 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -1014,7 +1014,7 @@ func (c *RAGFlowClient) AddProvider(cmd *Command) (ResponseIf, error) { "provider_name": providerName, } - resp, err := c.HTTPClient.Request("POST", "/providers", true, "web", nil, payload) + resp, err := c.HTTPClient.Request("PUT", "/providers", true, "web", nil, payload) if err != nil { return nil, fmt.Errorf("failed to add provider: %w", err) } @@ -1301,9 +1301,13 @@ func (c *RAGFlowClient) DropProviderInstance(cmd *Command) (ResponseIf, error) { return nil, fmt.Errorf("provider name not provided") } - url := fmt.Sprintf("/providers/%s/instances/%s", providerName, instanceName) + payload := map[string]interface{}{ + "instances": []string{instanceName}, + } - resp, err := c.HTTPClient.Request("DELETE", url, true, "web", nil, nil) + url := fmt.Sprintf("/providers/%s/instances", providerName) + + resp, err := c.HTTPClient.Request("DELETE", url, true, "web", nil, payload) if err != nil { return nil, fmt.Errorf("failed to drop instance: %w", err) } @@ -1388,7 +1392,7 @@ func (c *RAGFlowClient) EnableOrDisableModel(cmd *Command, status string) (Respo "status": status, } - resp, err := c.HTTPClient.Request("PUT", url, true, "web", nil, payload) + resp, err := c.HTTPClient.Request("PATCH", url, true, "web", nil, payload) if err != nil { return nil, fmt.Errorf("failed to enable/disable model: %w", err) } diff --git a/internal/handler/providers.go b/internal/handler/providers.go index d93cb9df57..5c9f4fdc08 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -402,6 +402,10 @@ func (h *ProviderHandler) AlterProviderInstance(c *gin.Context) { }) } +type DropProviderInstanceRequest struct { + Instances []string `json:"instances" binding:"required"` +} + func (h *ProviderHandler) DropProviderInstance(c *gin.Context) { providerName := c.Param("provider_name") if providerName == "" { @@ -411,19 +415,18 @@ func (h *ProviderHandler) DropProviderInstance(c *gin.Context) { }) return } - - instanceName := c.Param("instance_name") - if instanceName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Instance name is required", + var req DropProviderInstanceRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), }) return } userID := c.GetString("user_id") - _, err := h.modelProviderService.DropProviderInstance(providerName, instanceName, userID) + _, err := h.modelProviderService.DropProviderInstances(providerName, userID, req.Instances) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": common.CodeServerError, diff --git a/internal/router/router.go b/internal/router/router.go index af255675cf..def6a96d83 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -204,7 +204,7 @@ func (r *Router) Setup(engine *gin.Engine) { provider := v1.Group("/providers") { provider.GET("/", r.providerHandler.ListProviders) - provider.POST("/", r.providerHandler.AddProvider) + provider.PUT("/", r.providerHandler.AddProvider) provider.GET("/:provider_name", r.providerHandler.ShowProvider) provider.DELETE("/:provider_name", r.providerHandler.DeleteProvider) provider.GET("/:provider_name/models", r.providerHandler.ListModels) @@ -213,9 +213,9 @@ func (r *Router) Setup(engine *gin.Engine) { provider.GET("/:provider_name/instances", r.providerHandler.ListProviderInstances) provider.GET("/:provider_name/instances/:instance_name", r.providerHandler.ShowProviderInstance) provider.PUT("/:provider_name/instances/:instance_name", r.providerHandler.AlterProviderInstance) - provider.DELETE("/:provider_name/instances/:instance_name", r.providerHandler.DropProviderInstance) + provider.DELETE("/:provider_name/instances", r.providerHandler.DropProviderInstance) provider.GET("/:provider_name/instances/:instance_name/models", r.providerHandler.ListInstanceModels) - provider.PUT("/:provider_name/instances/:instance_name/models/:model_name", r.providerHandler.EnableOrDisableModel) + provider.PATCH("/:provider_name/instances/:instance_name/models/:model_name", r.providerHandler.EnableOrDisableModel) provider.POST("/:provider_name/instances/:instance_name/models/:model_name", r.providerHandler.ChatToModel) } diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 8761a9fa36..b3d0685180 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -24,10 +24,10 @@ import ( "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" "strings" "time" - model "ragflow/internal/entity/models" "ragflow/internal/service/models" ) @@ -351,7 +351,7 @@ func (m *ModelProviderService) ShowProviderInstance(providerName, instanceName, func (m *ModelProviderService) AlterProviderInstance(providerName, instanceName, newInstanceName, apiKey, userID string) (common.ErrorCode, error) { return common.CodeSuccess, nil } -func (m *ModelProviderService) DropProviderInstance(providerName, instanceName, userID string) (common.ErrorCode, error) { +func (m *ModelProviderService) DropProviderInstances(providerName, userID string, instances []string) (common.ErrorCode, error) { // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") @@ -371,13 +371,15 @@ func (m *ModelProviderService) DropProviderInstance(providerName, instanceName, return common.CodeServerError, err } - count, err := m.modelInstanceDAO.DeleteByProviderIDAndInstanceName(provider.ID, instanceName) - if err != nil { - return common.CodeServerError, err - } + for _, instanceName := range instances { + count, err := m.modelInstanceDAO.DeleteByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return common.CodeServerError, err + } - if count == 0 { - return common.CodeNotFound, errors.New("provider instance not found") + if count == 0 { + return common.CodeNotFound, errors.New("provider instance not found") + } } return common.CodeSuccess, nil @@ -468,11 +470,18 @@ func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, mod if err != nil { return common.CodeServerError, errors.New("fail to get UUID") } + + var modelSchema *entity.Model + modelSchema, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + } + // Get model info from provider model = &entity.TenantModel{ ID: modelID, ModelName: modelName, - ModelType: model.ModelType, + ModelType: modelSchema.ModelTypes[0], ProviderID: provider.ID, InstanceID: instance.ID, Status: status, @@ -616,7 +625,7 @@ func (m *ModelProviderService) ChatToModelStream(providerName, instanceName, mod } // ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel) -func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *model.ChatConfig, sender func(*string, *string) error) common.ErrorCode { +func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanceName, modelName, userID, message string, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) common.ErrorCode { // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") if err != nil {