diff --git a/internal/dao/tenant_model.go b/internal/dao/tenant_model.go index 8fc31c450b..d9572e4db5 100644 --- a/internal/dao/tenant_model.go +++ b/internal/dao/tenant_model.go @@ -54,6 +54,11 @@ func (dao *TenantModelDAO) DeleteByModelID(modelID string) (int64, error) { return result.RowsAffected, result.Error } +func (dao *TenantModelDAO) DeleteByModelIDAndProviderIDAndInstanceID(modelID, providerID, instanceID string) (int64, error) { + result := DB.Unscoped().Where("id = ? AND provider_id = ? AND instance_id = ?", modelID, providerID, instanceID).Delete(&entity.TenantModel{}) + return result.RowsAffected, result.Error +} + func (dao *TenantModelDAO) DeleteByProviderIDAndInstanceID(provideID, instanceID string) (int64, error) { result := DB.Unscoped().Where("provider_id = ? AND instance_id = ?", provideID, instanceID).Delete(&entity.TenantModel{}) return result.RowsAffected, result.Error @@ -64,6 +69,11 @@ func (dao *TenantModelDAO) DeleteByProviderIDAndInstanceIDAndModelName(provideID return result.RowsAffected, result.Error } +func (dao *TenantModelDAO) UpdateStatusByIDAndScope(modelID, providerID, instanceID, status string) (int64, error) { + result := DB.Model(&entity.TenantModel{}).Where("id = ? AND provider_id = ? AND instance_id = ?", modelID, providerID, instanceID).Update("status", status) + return result.RowsAffected, result.Error +} + // GetByID get tenant model by primary key (id) func (dao *TenantModelDAO) GetByID(id string) (*entity.TenantModel, error) { var model entity.TenantModel diff --git a/internal/dao/tenant_model_test.go b/internal/dao/tenant_model_test.go new file mode 100644 index 0000000000..07b33c5a76 --- /dev/null +++ b/internal/dao/tenant_model_test.go @@ -0,0 +1,113 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "testing" + + "github.com/glebarez/sqlite" + "gorm.io/gorm" + + "ragflow/internal/entity" +) + +func setupTenantModelDAOTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + if err := db.AutoMigrate(&entity.TenantModel{}); err != nil { + t.Fatalf("failed to migrate tenant_model: %v", err) + } + return db +} + +func useTenantModelDAOTestDB(t *testing.T, db *gorm.DB) { + t.Helper() + orig := DB + DB = db + t.Cleanup(func() { DB = orig }) +} + +func seedTenantModel(t *testing.T, db *gorm.DB, model *entity.TenantModel) { + t.Helper() + if err := db.Create(model).Error; err != nil { + t.Fatalf("failed to seed tenant model: %v", err) + } +} + +func TestTenantModelDAODeleteByModelIDAndScopeDeletesOnlyMatchingModel(t *testing.T) { + db := setupTenantModelDAOTestDB(t) + useTenantModelDAOTestDB(t, db) + + seedTenantModel(t, db, &entity.TenantModel{ID: "model-delete", ModelName: "m", ModelType: "chat", ProviderID: "provider-1", InstanceID: "instance-1", Status: "active"}) + seedTenantModel(t, db, &entity.TenantModel{ID: "model-keep", ModelName: "m", ModelType: "chat", ProviderID: "provider-1", InstanceID: "instance-2", Status: "active"}) + + rows, err := NewTenantModelDAO().DeleteByModelIDAndProviderIDAndInstanceID("model-delete", "provider-1", "instance-1") + if err != nil { + t.Fatalf("DeleteByModelIDAndProviderIDAndInstanceID() error = %v", err) + } + if rows != 1 { + t.Fatalf("rows = %d, want 1", rows) + } + + var count int64 + if err := db.Model(&entity.TenantModel{}).Where("id = ?", "model-delete").Count(&count).Error; err != nil { + t.Fatalf("count deleted model: %v", err) + } + if count != 0 { + t.Fatalf("deleted model count = %d, want 0", count) + } + if err := db.Model(&entity.TenantModel{}).Where("id = ?", "model-keep").Count(&count).Error; err != nil { + t.Fatalf("count kept model: %v", err) + } + if count != 1 { + t.Fatalf("kept model count = %d, want 1", count) + } +} + +func TestTenantModelDAOUpdateStatusByIDAndScope(t *testing.T) { + db := setupTenantModelDAOTestDB(t) + useTenantModelDAOTestDB(t, db) + + seedTenantModel(t, db, &entity.TenantModel{ID: "model-status", ModelName: "m", ModelType: "chat", ProviderID: "provider-1", InstanceID: "instance-1", Status: "active"}) + + rows, err := NewTenantModelDAO().UpdateStatusByIDAndScope("model-status", "provider-1", "instance-1", "inactive") + if err != nil { + t.Fatalf("UpdateStatusByIDAndScope() error = %v", err) + } + if rows != 1 { + t.Fatalf("rows = %d, want 1", rows) + } + + var got entity.TenantModel + if err := db.Where("id = ?", "model-status").First(&got).Error; err != nil { + t.Fatalf("failed to reload model: %v", err) + } + if got.Status != "inactive" { + t.Fatalf("status = %q, want inactive", got.Status) + } + + rows, err = NewTenantModelDAO().UpdateStatusByIDAndScope("model-status", "provider-1", "wrong-instance", "active") + if err != nil { + t.Fatalf("UpdateStatusByIDAndScope() wrong scope error = %v", err) + } + if rows != 0 { + t.Fatalf("wrong-scope rows = %d, want 0", rows) + } +} diff --git a/internal/dao/tenant_test.go b/internal/dao/tenant_test.go new file mode 100644 index 0000000000..fe5ea3cd12 --- /dev/null +++ b/internal/dao/tenant_test.go @@ -0,0 +1,112 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "testing" + + "github.com/glebarez/sqlite" + "gorm.io/gorm" + + "ragflow/internal/entity" +) + +func setupTenantDAOTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + if err := db.AutoMigrate(&entity.Tenant{}); err != nil { + t.Fatalf("failed to migrate tenant: %v", err) + } + return db +} + +func useTenantDAOTestDB(t *testing.T, db *gorm.DB) { + t.Helper() + orig := DB + DB = db + t.Cleanup(func() { DB = orig }) +} + +func TestTenantDAODeleteSoftDeletesTenant(t *testing.T) { + db := setupTenantDAOTestDB(t) + useTenantDAOTestDB(t, db) + + active := "1" + tenant := &entity.Tenant{ + ID: "tenant-delete", + LLMID: "llm", + EmbdID: "embd", + ASRID: "asr", + Img2TxtID: "img2txt", + RerankID: "rerank", + ParserIDs: "naive", + Status: &active, + } + if err := NewTenantDAO().Create(tenant); err != nil { + t.Fatalf("Create() error = %v", err) + } + + if err := NewTenantDAO().Delete(tenant.ID); err != nil { + t.Fatalf("Delete() error = %v", err) + } + + var got entity.Tenant + if err := db.Where("id = ?", tenant.ID).First(&got).Error; err != nil { + t.Fatalf("failed to reload tenant: %v", err) + } + if got.Status == nil || *got.Status != "0" { + t.Fatalf("status = %v, want 0", got.Status) + } + if _, err := NewTenantDAO().GetByID(tenant.ID); err == nil { + t.Fatalf("GetByID() after Delete() error = nil, want not found") + } +} + +func TestTenantDAOUpdateStatus(t *testing.T) { + db := setupTenantDAOTestDB(t) + useTenantDAOTestDB(t, db) + + active := "1" + tenant := &entity.Tenant{ + ID: "tenant-update", + LLMID: "llm", + EmbdID: "embd", + ASRID: "asr", + Img2TxtID: "img2txt", + RerankID: "rerank", + ParserIDs: "naive", + Status: &active, + } + if err := NewTenantDAO().Create(tenant); err != nil { + t.Fatalf("Create() error = %v", err) + } + + if err := NewTenantDAO().Update(tenant.ID, map[string]interface{}{"status": "0"}); err != nil { + t.Fatalf("Update() error = %v", err) + } + + var got entity.Tenant + if err := db.Where("id = ?", tenant.ID).First(&got).Error; err != nil { + t.Fatalf("failed to reload tenant: %v", err) + } + if got.Status == nil || *got.Status != "0" { + t.Fatalf("status = %v, want 0", got.Status) + } +} diff --git a/internal/handler/providers.go b/internal/handler/providers.go index 9f3324e76c..76a865bca7 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -704,7 +704,8 @@ func (h *ProviderHandler) ListInstanceModels(c *gin.Context) { } type EnableOrDisableModelRequest struct { - Status string `json:"status" binding:"required"` + ModelID string `json:"model_id"` + Status string `json:"status"` } func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) { @@ -726,18 +727,6 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) { return } - modelName := c.Param("model_name") - if modelName != "" { - modelName = strings.TrimPrefix(modelName, "/") - } - if modelName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Model name is required", - }) - return - } - var req EnableOrDisableModelRequest if err := c.ShouldBindJSON(&req); err != nil { println("JSON bind error: %v (type: %T)", err, err) @@ -749,11 +738,30 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) { } userID := c.GetString("user_id") + modelID := strings.TrimSpace(req.ModelID) + modelName := strings.TrimPrefix(c.Param("model_name"), "/") + modelName = strings.TrimSpace(modelName) + if modelName == "" && modelID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeBadRequest, + "message": "model_name or model_id is required", + }) + return + } - _, err := h.modelProviderService.UpdateModelStatus(providerName, instanceName, modelName, userID, req.Status) + status := strings.TrimSpace(req.Status) + if status != "active" && status != "inactive" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeBadRequest, + "message": "Status must be active or inactive", + }) + return + } + + code, err := h.modelProviderService.UpdateModelStatus(providerName, instanceName, modelName, userID, modelID, status) if err != nil { c.JSON(http.StatusOK, gin.H{ - "code": common.CodeServerError, + "code": code, "message": err.Error(), }) return @@ -846,7 +854,8 @@ func (h *ProviderHandler) AddModel(c *gin.Context) { } type DropInstanceModelRequest struct { - Models []string `json:"models" binding:"required"` + ModelIDs []string `json:"model_ids"` + Models []string `json:"models"` } func (h *ProviderHandler) DropInstanceModels(c *gin.Context) { @@ -875,13 +884,20 @@ func (h *ProviderHandler) DropInstanceModels(c *gin.Context) { }) return } + if len(req.ModelIDs) == 0 && len(req.Models) == 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeBadRequest, + "message": "model_ids or models is required", + }) + return + } userID := c.GetString("user_id") - _, err := h.modelProviderService.DropInstanceModels(providerName, instanceName, userID, req.Models) + code, err := h.modelProviderService.DropInstanceModels(providerName, instanceName, userID, req.ModelIDs, req.Models) if err != nil { c.JSON(http.StatusOK, gin.H{ - "code": common.CodeServerError, + "code": code, "message": err.Error(), }) return diff --git a/internal/handler/providers_test.go b/internal/handler/providers_test.go index abeebd162e..8f65e293fb 100644 --- a/internal/handler/providers_test.go +++ b/internal/handler/providers_test.go @@ -1 +1,156 @@ package handler + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/glebarez/sqlite" + "gorm.io/gorm" + + "ragflow/internal/common" + "ragflow/internal/dao" + "ragflow/internal/entity" + "ragflow/internal/service" +) + +func setupProviderHandlerTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + if err := db.AutoMigrate( + &entity.UserTenant{}, + &entity.TenantModelProvider{}, + &entity.TenantModelInstance{}, + &entity.TenantModel{}, + ); err != nil { + t.Fatalf("failed to migrate provider handler tables: %v", err) + } + return db +} + +func useProviderHandlerTestDB(t *testing.T, db *gorm.DB) { + t.Helper() + orig := dao.DB + dao.DB = db + t.Cleanup(func() { dao.DB = orig }) +} + +func seedProviderHandlerModel(t *testing.T, db *gorm.DB) { + t.Helper() + activeStatus := "1" + rows := []interface{}{ + &entity.UserTenant{ID: "user-tenant-1", UserID: "user-1", TenantID: "tenant-1", Role: "owner", InvitedBy: "user-1", Status: &activeStatus}, + &entity.TenantModelProvider{ID: "provider-1", TenantID: "tenant-1", ProviderName: "OpenAI"}, + &entity.TenantModelInstance{ID: "instance-1", ProviderID: "provider-1", InstanceName: "default", APIKey: "sk-test", Status: "active", Extra: "{}"}, + &entity.TenantModel{ID: "model-1", ProviderID: "provider-1", InstanceID: "instance-1", ModelName: "gpt-test", ModelType: "chat", Status: "active"}, + } + for _, row := range rows { + if err := db.Create(row).Error; err != nil { + t.Fatalf("failed to seed %T: %v", row, err) + } + } +} + +func newProviderHandlerRequest(t *testing.T, body map[string]interface{}, params ...gin.Param) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + + payload, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshal body: %v", err) + } + req := httptest.NewRequest(http.MethodPatch, "/providers/OpenAI/instances/default/models/gpt-test", bytes.NewReader(payload)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = req + ctx.Params = params + ctx.Set("user_id", "user-1") + return ctx, recorder +} + +func decodeProviderHandlerResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]interface{} { + t.Helper() + var body map[string]interface{} + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatalf("decode response %q: %v", recorder.Body.String(), err) + } + return body +} + +func TestProviderHandlerEnableOrDisableModelRejectsMissingModelSelector(t *testing.T) { + ctx, recorder := newProviderHandlerRequest( + t, + map[string]interface{}{"status": "active"}, + gin.Param{Key: "provider_name", Value: "OpenAI"}, + gin.Param{Key: "instance_name", Value: "default"}, + ) + + NewProviderHandler(nil, service.NewModelProviderService()).EnableOrDisableModel(ctx) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body=%s", recorder.Code, http.StatusBadRequest, recorder.Body.String()) + } + body := decodeProviderHandlerResponse(t, recorder) + if common.ErrorCode(body["code"].(float64)) != common.CodeBadRequest { + t.Fatalf("code = %v, want %v", body["code"], common.CodeBadRequest) + } +} + +func TestProviderHandlerEnableOrDisableModelRejectsInvalidStatus(t *testing.T) { + ctx, recorder := newProviderHandlerRequest( + t, + map[string]interface{}{"status": "disabled"}, + gin.Param{Key: "provider_name", Value: "OpenAI"}, + gin.Param{Key: "instance_name", Value: "default"}, + gin.Param{Key: "model_name", Value: "gpt-test"}, + ) + + NewProviderHandler(nil, service.NewModelProviderService()).EnableOrDisableModel(ctx) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body=%s", recorder.Code, http.StatusBadRequest, recorder.Body.String()) + } + body := decodeProviderHandlerResponse(t, recorder) + if common.ErrorCode(body["code"].(float64)) != common.CodeBadRequest { + t.Fatalf("code = %v, want %v", body["code"], common.CodeBadRequest) + } +} + +func TestProviderHandlerEnableOrDisableModelUpdatesStatus(t *testing.T) { + db := setupProviderHandlerTestDB(t) + useProviderHandlerTestDB(t, db) + seedProviderHandlerModel(t, db) + + ctx, recorder := newProviderHandlerRequest( + t, + map[string]interface{}{"status": "inactive"}, + gin.Param{Key: "provider_name", Value: "OpenAI"}, + gin.Param{Key: "instance_name", Value: "default"}, + gin.Param{Key: "model_name", Value: "gpt-test"}, + ) + + NewProviderHandler(nil, service.NewModelProviderService()).EnableOrDisableModel(ctx) + + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", recorder.Code, http.StatusOK, recorder.Body.String()) + } + body := decodeProviderHandlerResponse(t, recorder) + if common.ErrorCode(body["code"].(float64)) != common.CodeSuccess { + t.Fatalf("code = %v, want %v", body["code"], common.CodeSuccess) + } + + var got entity.TenantModel + if err := db.Where("id = ?", "model-1").First(&got).Error; err != nil { + t.Fatalf("failed to reload model: %v", err) + } + if got.Status != "inactive" { + t.Fatalf("status = %q, want inactive", got.Status) + } +} diff --git a/internal/service/model_service.go b/internal/service/model_service.go index d75278db54..93d7926365 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -966,7 +966,7 @@ func (m *ModelProviderService) DropProviderInstances(providerName, userID string return common.CodeSuccess, nil } -func (m *ModelProviderService) DropInstanceModels(providerName, instanceName, userID string, models []string) (common.ErrorCode, error) { +func (m *ModelProviderService) DropInstanceModels(providerName, instanceName, userID string, modelIDs, models []string) (common.ErrorCode, error) { // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") @@ -992,7 +992,27 @@ func (m *ModelProviderService) DropInstanceModels(providerName, instanceName, us return common.CodeServerError, err } + for _, modelID := range modelIDs { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return common.CodeBadRequest, errors.New("model ID is required") + } + var count int64 = 0 + count, err = m.modelDAO.DeleteByModelIDAndProviderIDAndInstanceID(modelID, provider.ID, modelInstance.ID) + if err != nil { + return common.CodeServerError, err + } + + if count == 0 { + return common.CodeNotFound, fmt.Errorf("model %s not found", modelID) + } + } + for _, modelName := range models { + modelName = strings.TrimSpace(modelName) + if modelName == "" { + return common.CodeBadRequest, errors.New("model name is required") + } // Delete all models of this instance var count int64 = 0 count, err = m.modelDAO.DeleteByProviderIDAndInstanceIDAndModelName(provider.ID, modelInstance.ID, modelName) @@ -1068,7 +1088,16 @@ func (m *ModelProviderService) ListInstanceModels(providerName, instanceName, us return allModels, nil } -func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, modelName, userID, status string) (common.ErrorCode, error) { +func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, modelName, userID, modelID, status string) (common.ErrorCode, error) { + modelName = strings.TrimSpace(modelName) + modelID = strings.TrimSpace(modelID) + status = strings.TrimSpace(status) + if status != "active" && status != "inactive" { + return common.CodeBadRequest, errors.New("status must be active or inactive") + } + if modelName == "" && modelID == "" { + return common.CodeBadRequest, errors.New("model name or model ID is required") + } // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") @@ -1093,34 +1122,66 @@ func (m *ModelProviderService) UpdateModelStatus(providerName, instanceName, mod return common.CodeServerError, err } - model, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) - if err != nil { - var modelID string - modelID = utility.GenerateToken() + var model *entity.TenantModel - var modelSchema *modelModule.Model - modelSchema, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) - if err != nil { - return common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + if modelID != "" { + model, err = m.modelDAO.GetByID(modelID) + if err != nil || model == nil { + return common.CodeNotFound, errors.New("model not found") } - // Get model info from provider - model = &entity.TenantModel{ - ID: modelID, - ModelName: modelName, - ModelType: modelSchema.ModelTypes[0], - ProviderID: provider.ID, - InstanceID: instance.ID, - Status: status, + if model.ProviderID != provider.ID || model.InstanceID != instance.ID { + return common.CodeNotFound, errors.New("model not found") } - err = m.modelDAO.Create(model) + + if modelName != "" && model.ModelName != modelName { + return common.CodeBadRequest, errors.New("model ID does not match model name") + } + + count, err := m.modelDAO.UpdateStatusByIDAndScope(modelID, provider.ID, instance.ID, status) if err != nil { - return common.CodeServerError, errors.New("fail to create model") + return common.CodeServerError, err + } + if count == 0 { + return common.CodeNotFound, errors.New("model not found") } return common.CodeSuccess, nil + } else { + model, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return common.CodeServerError, err + } + + modelID = utility.GenerateToken() + + var modelSchema *modelModule.Model + modelSchema, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + } + + // Get model info from provider + if len(modelSchema.ModelTypes) == 0 { + return common.CodeServerError, fmt.Errorf("provider %s model %s has no model types", providerName, modelName) + } + model = &entity.TenantModel{ + ID: modelID, + ModelName: modelName, + ModelType: modelSchema.ModelTypes[0], + ProviderID: provider.ID, + InstanceID: instance.ID, + Status: status, + } + err = m.modelDAO.Create(model) + if err != nil { + return common.CodeServerError, errors.New("fail to create model") + } + return common.CodeSuccess, nil + } } - count, err := m.modelDAO.DeleteByModelID(model.ID) + count, err := m.modelDAO.UpdateStatusByIDAndScope(model.ID, provider.ID, instance.ID, status) if err != nil { return common.CodeServerError, err } @@ -1387,23 +1448,28 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam thinking := info.ModelInfo.Thinking.DefaultValue modelConfig.Thinking = &thinking } + resolvedModelName := info.ModelInfo.Name + if info.ModelEntity != nil && info.ModelEntity.ModelName != "" { + resolvedModelName = info.ModelEntity.ModelName + } + resolvedProviderName := info.ProviderEntity.ProviderName var response *modelModule.ChatResponse var modelDriver modelModule.ModelDriver if info.ModelEntity == nil { if !info.ModelInfo.ModelTypeMap["chat"] && !info.ModelInfo.ModelTypeMap["vision"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", *modelName, *providerName)) + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", resolvedModelName, resolvedProviderName)) } modelDriver = info.ProviderInfo.ModelDriver } else { // model entity exists if info.ModelEntity.Status == "active" { if info.ModelEntity.ModelType != "chat" && info.ModelEntity.ModelType != "vision" { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", *modelName, *providerName)) + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", resolvedModelName, resolvedProviderName)) } - modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, resolvedProviderName, *info.APIConfig.Region, *info.APIConfig.BaseURL) if err != nil { return nil, common.CodeServerError, err } @@ -1412,7 +1478,7 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam } } - response, err = modelDriver.ChatWithMessages(*modelName, messages, info.APIConfig, modelConfig) + response, err = modelDriver.ChatWithMessages(resolvedModelName, messages, info.APIConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } @@ -1449,6 +1515,11 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc thinking := info.ModelInfo.Thinking.DefaultValue modelConfig.Thinking = &thinking } + resolvedModelName := info.ModelInfo.Name + if info.ModelEntity != nil && info.ModelEntity.ModelName != "" { + resolvedModelName = info.ModelEntity.ModelName + } + resolvedProviderName := info.ProviderEntity.ProviderName var modelDriver modelModule.ModelDriver @@ -1458,10 +1529,10 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc // model entity exists if info.ModelEntity.Status == "active" { if info.ModelEntity.ModelType != "chat" && info.ModelEntity.ModelType != "vision" { - return common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", *modelName, *providerName)) + return common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", resolvedModelName, resolvedProviderName)) } - modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, resolvedProviderName, *info.APIConfig.Region, *info.APIConfig.BaseURL) if err != nil { return common.CodeServerError, err } @@ -1470,7 +1541,7 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc } } - err = modelDriver.ChatStreamlyWithSender(*modelName, messages, info.APIConfig, modelConfig, sender) + err = modelDriver.ChatStreamlyWithSender(resolvedModelName, messages, info.APIConfig, modelConfig, sender) if err != nil { return common.CodeServerError, err } @@ -1528,22 +1599,27 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, if modelConfig == nil { modelConfig = &modelModule.EmbeddingConfig{} } + resolvedModelName := info.ModelInfo.Name + if info.ModelEntity != nil && info.ModelEntity.ModelName != "" { + resolvedModelName = info.ModelEntity.ModelName + } + resolvedProviderName := info.ProviderEntity.ProviderName var modelDriver modelModule.ModelDriver if info.ModelEntity == nil { if !info.ModelInfo.ModelTypeMap["embedding"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an embedding model", *modelName, *providerName)) + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an embedding model", resolvedModelName, resolvedProviderName)) } modelDriver = info.ProviderInfo.ModelDriver } else { // model entity exists if info.ModelEntity.Status == "active" { if info.ModelEntity.ModelType != "embedding" { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an embedding model", *modelName, *providerName)) + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is an embedding model", resolvedModelName, resolvedProviderName)) } - modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, resolvedProviderName, *info.APIConfig.Region, *info.APIConfig.BaseURL) if err != nil { return nil, common.CodeServerError, err } @@ -1557,7 +1633,7 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, } var response []modelModule.EmbeddingData - response, err = modelDriver.Embed(modelName, texts, info.APIConfig, modelConfig) + response, err = modelDriver.Embed(&resolvedModelName, texts, info.APIConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } @@ -1589,22 +1665,27 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN if modelConfig == nil { modelConfig = &modelModule.RerankConfig{} } + resolvedModelName := info.ModelInfo.Name + if info.ModelEntity != nil && info.ModelEntity.ModelName != "" { + resolvedModelName = info.ModelEntity.ModelName + } + resolvedProviderName := info.ProviderEntity.ProviderName var modelDriver modelModule.ModelDriver if info.ModelEntity == nil { if !info.ModelInfo.ModelTypeMap["rerank"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a rerank model", *modelName, *providerName)) + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a rerank model", resolvedModelName, resolvedProviderName)) } modelDriver = info.ProviderInfo.ModelDriver } else { // model entity exists if info.ModelEntity.Status == "active" { if info.ModelEntity.ModelType != "rerank" { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a rerank model", *modelName, *providerName)) + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a rerank model", resolvedModelName, resolvedProviderName)) } - modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, *providerName, *info.APIConfig.Region, *info.APIConfig.BaseURL) + modelDriver, err = newModelDriverForBaseURL(info.ProviderInfo.ModelDriver, resolvedProviderName, *info.APIConfig.Region, *info.APIConfig.BaseURL) if err != nil { return nil, common.CodeServerError, err } @@ -1614,7 +1695,7 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN } var response *modelModule.RerankResponse - response, err = modelDriver.Rerank(modelName, query, documents, apiConfig, modelConfig) + response, err = modelDriver.Rerank(&resolvedModelName, query, documents, info.APIConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } diff --git a/internal/service/model_service_test.go b/internal/service/model_service_test.go index b482f4fbbe..0d4b6b2199 100644 --- a/internal/service/model_service_test.go +++ b/internal/service/model_service_test.go @@ -4,6 +4,11 @@ import ( "strings" "testing" + "github.com/glebarez/sqlite" + "gorm.io/gorm" + + "ragflow/internal/common" + "ragflow/internal/dao" "ragflow/internal/entity" modelModule "ragflow/internal/entity/models" ) @@ -109,3 +114,108 @@ func TestModelInfoWithTenantExtraAppliesEmbeddingDimensions(t *testing.T) { t.Fatalf("factory Dimensions were mutated: %v", modelInfo.Dimensions) } } + +func setupModelProviderServiceTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + if err := db.AutoMigrate( + &entity.UserTenant{}, + &entity.TenantModelProvider{}, + &entity.TenantModelInstance{}, + &entity.TenantModel{}, + ); err != nil { + t.Fatalf("failed to migrate model service tables: %v", err) + } + return db +} + +func useModelProviderServiceTestDB(t *testing.T, db *gorm.DB) { + t.Helper() + orig := dao.DB + dao.DB = db + t.Cleanup(func() { dao.DB = orig }) +} + +func seedModelProviderServiceScope(t *testing.T, db *gorm.DB) { + t.Helper() + activeStatus := "1" + rows := []interface{}{ + &entity.UserTenant{ID: "user-tenant-1", UserID: "user-1", TenantID: "tenant-1", Role: "owner", InvitedBy: "user-1", Status: &activeStatus}, + &entity.TenantModelProvider{ID: "provider-1", TenantID: "tenant-1", ProviderName: "OpenAI"}, + &entity.TenantModelInstance{ID: "instance-1", ProviderID: "provider-1", InstanceName: "default", APIKey: "sk-test", Status: "active", Extra: "{}"}, + &entity.TenantModel{ID: "model-1", ProviderID: "provider-1", InstanceID: "instance-1", ModelName: "gpt-test", ModelType: "chat", Status: "active"}, + } + for _, row := range rows { + if err := db.Create(row).Error; err != nil { + t.Fatalf("failed to seed %T: %v", row, err) + } + } +} + +func TestModelProviderServiceUpdateModelStatusByID(t *testing.T) { + db := setupModelProviderServiceTestDB(t) + useModelProviderServiceTestDB(t, db) + seedModelProviderServiceScope(t, db) + + code, err := NewModelProviderService().UpdateModelStatus("OpenAI", "default", "", "user-1", "model-1", "inactive") + if err != nil { + t.Fatalf("UpdateModelStatus() error = %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code = %v, want %v", code, common.CodeSuccess) + } + + var got entity.TenantModel + if err := db.Where("id = ?", "model-1").First(&got).Error; err != nil { + t.Fatalf("failed to reload tenant model: %v", err) + } + if got.Status != "inactive" { + t.Fatalf("status = %q, want inactive", got.Status) + } +} + +func TestModelProviderServiceUpdateModelStatusRejectsInvalidStatus(t *testing.T) { + code, err := NewModelProviderService().UpdateModelStatus("OpenAI", "default", "", "user-1", "model-1", "disabled") + if err == nil { + t.Fatalf("UpdateModelStatus() error = nil, want invalid status error") + } + if code != common.CodeBadRequest { + t.Fatalf("code = %v, want %v", code, common.CodeBadRequest) + } + if !strings.Contains(err.Error(), "status must be active or inactive") { + t.Fatalf("error = %v, want status validation message", err) + } +} + +func TestModelProviderServiceUpdateModelStatusRejectsMissingModelSelector(t *testing.T) { + code, err := NewModelProviderService().UpdateModelStatus("OpenAI", "default", "", "user-1", "", "active") + if err == nil { + t.Fatalf("UpdateModelStatus() error = nil, want missing model selector error") + } + if code != common.CodeBadRequest { + t.Fatalf("code = %v, want %v", code, common.CodeBadRequest) + } + if !strings.Contains(err.Error(), "model name or model ID is required") { + t.Fatalf("error = %v, want missing model selector message", err) + } +} + +func TestModelProviderServiceUpdateModelStatusRejectsWrongScopedModelID(t *testing.T) { + db := setupModelProviderServiceTestDB(t) + useModelProviderServiceTestDB(t, db) + seedModelProviderServiceScope(t, db) + if err := db.Create(&entity.TenantModelInstance{ID: "instance-2", ProviderID: "provider-1", InstanceName: "other", APIKey: "sk-test", Status: "active", Extra: "{}"}).Error; err != nil { + t.Fatalf("failed to seed second instance: %v", err) + } + + code, err := NewModelProviderService().UpdateModelStatus("OpenAI", "other", "", "user-1", "model-1", "inactive") + if err == nil { + t.Fatalf("UpdateModelStatus() error = nil, want not found error") + } + if code != common.CodeNotFound { + t.Fatalf("code = %v, want %v", code, common.CodeNotFound) + } +}