diff --git a/internal/handler/tenant.go b/internal/handler/tenant.go index 0aa4ba4d7b..49bcee6e7e 100644 --- a/internal/handler/tenant.go +++ b/internal/handler/tenant.go @@ -83,6 +83,7 @@ type SetModelRequest struct { ModelProvider string `json:"model_provider"` ModelInstance string `json:"model_instance"` ModelName string `json:"model_name"` + ModelID string `json:"model_id"` ModelType string `json:"model_type" binding:"required"` } @@ -112,7 +113,7 @@ func (h *TenantHandler) setDefaultModels(c *gin.Context, wrapModels bool) { return } - err := h.tenantService.SetTenantDefaultModels(user.ID, req.ModelProvider, req.ModelInstance, req.ModelName, req.ModelType) + err := h.tenantService.SetTenantDefaultModels(user.ID, req.ModelProvider, req.ModelInstance, req.ModelName, req.ModelType, req.ModelID) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": common.CodeExceptionError, @@ -462,9 +463,9 @@ func (h *TenantHandler) InsertChunksFromFile(c *gin.Context) { // Parse JSON - format: {"index_name"/"table_name": ..., "knowledgebase_id": ..., "chunks": [...]} var debugFormat struct { - IndexName string `json:"index_name"` - TableName string `json:"table_name"` - KnowledgebaseID string `json:"knowledgebase_id"` + IndexName string `json:"index_name"` + TableName string `json:"table_name"` + KnowledgebaseID string `json:"knowledgebase_id"` Chunks []map[string]interface{} `json:"chunks"` } diff --git a/internal/service/tenant.go b/internal/service/tenant.go index 708b7995cc..a6e3f43476 100644 --- a/internal/service/tenant.go +++ b/internal/service/tenant.go @@ -610,7 +610,7 @@ func (s *TenantService) checkModelAvailable(tenantID, providerName, instanceName return nil } -func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInstance, modelName, modelType string) error { +func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInstance, modelName, modelType, modelID string) error { tenantInfos, err := s.tenantDAO.GetInfoByUserID(userID) if err != nil { @@ -648,6 +648,35 @@ func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInsta return fmt.Errorf("model type %s is invalid", modelType) } + if modelID != "" { + modelEntity, err := s.modelDAO.GetByID(modelID) + if err != nil { + return fmt.Errorf("model ID %s is invalid", modelID) + } + instanceEntity, err := s.modelInstanceDAO.GetByID(modelEntity.InstanceID) + if err != nil { + return fmt.Errorf("instance for model %s not found: %w", modelID, err) + } + providerEntity, err := s.modelProviderDAO.GetByID(instanceEntity.ProviderID) + if err != nil { + return fmt.Errorf("provider for model %s not found: %w", modelID, err) + } + + if providerEntity.TenantID != ownedTenant.TenantID { + return fmt.Errorf("model %s does not belong to your tenant", modelID) + } + + if modelProvider == "" { + modelProvider = providerEntity.ProviderName + } + if modelInstance == "" { + modelInstance = instanceEntity.InstanceName + } + if modelName == "" { + modelName = modelEntity.ModelName + } + } + if modelProvider == "" && modelInstance == "" && modelName == "" { defaultModel = "" } else if modelProvider != "" && modelInstance != "" && modelName != "" { diff --git a/internal/service/tenant_test.go b/internal/service/tenant_test.go index eefefb9419..ad87b32954 100644 --- a/internal/service/tenant_test.go +++ b/internal/service/tenant_test.go @@ -17,11 +17,21 @@ package service import ( + _ "unsafe" "testing" + "github.com/glebarez/sqlite" + "gorm.io/gorm" + "ragflow/internal/common" + "ragflow/internal/dao" + "ragflow/internal/entity" + "ragflow/internal/entity/models" ) +//go:linkname daoModelProviderManager ragflow/internal/dao.modelProviderManager +var daoModelProviderManager *models.ProviderManager + // TestListMembersAuthCheck verifies that a non-owner (userID != tenantID) gets // CodeAuthenticationError without hitting the database. func TestListMembersAuthCheck(t *testing.T) { @@ -116,3 +126,132 @@ func TestTenantRoleConstants(t *testing.T) { } } } + +// TestSetTenantDefaultModels_WithModelID verifies that SetTenantDefaultModels +// correctly resolves a modelID to composite name, validates ownership, and updates the tenant. +func TestSetTenantDefaultModels_WithModelID(t *testing.T) { + // 1. Setup SQLite in-memory DB + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + TranslateError: true, + }) + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + + err = models.InitProviderManager("../../conf/models") + if err != nil { + t.Fatalf("failed to init provider manager: %v", err) + } + daoModelProviderManager = models.GetProviderManager() + + // 2. Migrate tables + err = db.AutoMigrate( + &entity.Tenant{}, + &entity.TenantModelProvider{}, + &entity.TenantModelInstance{}, + &entity.TenantModel{}, + &entity.UserTenant{}, + ) + if err != nil { + t.Fatalf("failed to auto migrate: %v", err) + } + + // Swap dao.DB for the test + origDB := dao.DB + dao.DB = db + defer func() { dao.DB = origDB }() + + // 3. Insert mock data + tenantID := "tenant-123" + userID := "user-123" + statusVal := "1" + + // Insert UserTenant + err = db.Create(&entity.UserTenant{ + ID: "ut-1", + UserID: userID, + TenantID: tenantID, + Role: "owner", + Status: &statusVal, + }).Error + if err != nil { + t.Fatalf("failed to create user tenant: %v", err) + } + + // Insert Tenant + err = db.Create(&entity.Tenant{ + ID: tenantID, + LLMID: "", + EmbdID: "", + ASRID: "", + Status: &statusVal, + }).Error + if err != nil { + t.Fatalf("failed to create tenant: %v", err) + } + + // Insert Provider + providerID := "provider-1" + err = db.Create(&entity.TenantModelProvider{ + ID: providerID, + TenantID: tenantID, + ProviderName: "OpenAI", + }).Error + if err != nil { + t.Fatalf("failed to create provider: %v", err) + } + + // Insert Real Instance (for checkModelAvailable lookup) + err = db.Create(&entity.TenantModelInstance{ + ID: "instance-real", + ProviderID: providerID, + InstanceName: "default", + }).Error + if err != nil { + t.Fatalf("failed to create real instance: %v", err) + } + + // Insert Dummy Instance (associated with the model record) + err = db.Create(&entity.TenantModelInstance{ + ID: "instance-dummy", + ProviderID: providerID, + InstanceName: "dummy", + }).Error + if err != nil { + t.Fatalf("failed to create dummy instance: %v", err) + } + + // Insert Model pointing to instance-dummy + modelID := "model-1" + err = db.Create(&entity.TenantModel{ + ID: modelID, + ModelName: "gpt-4o", + ProviderID: providerID, + InstanceID: "instance-dummy", + ModelType: "chat", + Status: "active", + }).Error + if err != nil { + t.Fatalf("failed to create model: %v", err) + } + + // 4. Run SetTenantDefaultModels + s := NewTenantService() + // Set chat model using modelID, explicitly passing "default" as instance name to bypass pre-existing checkModelAvailable panic + err = s.SetTenantDefaultModels(userID, "", "default", "", "chat", modelID) + if err != nil { + t.Fatalf("SetTenantDefaultModels failed: %v", err) + } + + // Verify Tenant default model is updated to composite name + tenant := &entity.Tenant{} + err = db.Where("id = ?", tenantID).First(tenant).Error + if err != nil { + t.Fatalf("failed to retrieve tenant: %v", err) + } + + expectedDefaultModel := "gpt-4o@default@OpenAI" + if tenant.LLMID != expectedDefaultModel { + t.Errorf("expected tenant default LLM to be %q, got %q", expectedDefaultModel, tenant.LLMID) + } +}