mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat(go-api): support setting tenant default models by model_id (#16030)
### Description Currently, when setting tenant default models (e.g., chat, embedding, rerank), the API only accepts the composite name (`model_name@model_instance@model_provider`). However, some integrations and front-end features prefer using the database `model_id` (UUID) directly. This PR adds support for `model_id` in default model configuration: 1. **Request Binding**: Added `model_id` (optional field) to the request body schema in the handler. 2. **Database Lookup**: If `model_id` is supplied, the service queries the database to resolve the respective provider, instance, and model names. 3. **Security Validation**: Verified that the provider associated with the resolved `model_id` belongs to the requesting tenant. 4. **Unit Tests**: Added `TestSetTenantDefaultModels_WithModelID` to verify DB ID resolution and tenant mapping.
This commit is contained in:
@@ -83,6 +83,7 @@ type SetModelRequest struct {
|
||||
ModelProvider string `json:"model_provider"`
|
||||
ModelInstance string `json:"model_instance"`
|
||||
ModelName string `json:"model_name"`
|
||||
ModelID string `json:"model_id"`
|
||||
ModelType string `json:"model_type" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -112,7 +113,7 @@ func (h *TenantHandler) setDefaultModels(c *gin.Context, wrapModels bool) {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.tenantService.SetTenantDefaultModels(user.ID, req.ModelProvider, req.ModelInstance, req.ModelName, req.ModelType)
|
||||
err := h.tenantService.SetTenantDefaultModels(user.ID, req.ModelProvider, req.ModelInstance, req.ModelName, req.ModelType, req.ModelID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeExceptionError,
|
||||
@@ -462,9 +463,9 @@ func (h *TenantHandler) InsertChunksFromFile(c *gin.Context) {
|
||||
|
||||
// Parse JSON - format: {"index_name"/"table_name": ..., "knowledgebase_id": ..., "chunks": [...]}
|
||||
var debugFormat struct {
|
||||
IndexName string `json:"index_name"`
|
||||
TableName string `json:"table_name"`
|
||||
KnowledgebaseID string `json:"knowledgebase_id"`
|
||||
IndexName string `json:"index_name"`
|
||||
TableName string `json:"table_name"`
|
||||
KnowledgebaseID string `json:"knowledgebase_id"`
|
||||
Chunks []map[string]interface{} `json:"chunks"`
|
||||
}
|
||||
|
||||
|
||||
@@ -610,7 +610,7 @@ func (s *TenantService) checkModelAvailable(tenantID, providerName, instanceName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInstance, modelName, modelType string) error {
|
||||
func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInstance, modelName, modelType, modelID string) error {
|
||||
|
||||
tenantInfos, err := s.tenantDAO.GetInfoByUserID(userID)
|
||||
if err != nil {
|
||||
@@ -648,6 +648,35 @@ func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInsta
|
||||
return fmt.Errorf("model type %s is invalid", modelType)
|
||||
}
|
||||
|
||||
if modelID != "" {
|
||||
modelEntity, err := s.modelDAO.GetByID(modelID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("model ID %s is invalid", modelID)
|
||||
}
|
||||
instanceEntity, err := s.modelInstanceDAO.GetByID(modelEntity.InstanceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("instance for model %s not found: %w", modelID, err)
|
||||
}
|
||||
providerEntity, err := s.modelProviderDAO.GetByID(instanceEntity.ProviderID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("provider for model %s not found: %w", modelID, err)
|
||||
}
|
||||
|
||||
if providerEntity.TenantID != ownedTenant.TenantID {
|
||||
return fmt.Errorf("model %s does not belong to your tenant", modelID)
|
||||
}
|
||||
|
||||
if modelProvider == "" {
|
||||
modelProvider = providerEntity.ProviderName
|
||||
}
|
||||
if modelInstance == "" {
|
||||
modelInstance = instanceEntity.InstanceName
|
||||
}
|
||||
if modelName == "" {
|
||||
modelName = modelEntity.ModelName
|
||||
}
|
||||
}
|
||||
|
||||
if modelProvider == "" && modelInstance == "" && modelName == "" {
|
||||
defaultModel = ""
|
||||
} else if modelProvider != "" && modelInstance != "" && modelName != "" {
|
||||
|
||||
@@ -17,11 +17,21 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
_ "unsafe"
|
||||
"testing"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
//go:linkname daoModelProviderManager ragflow/internal/dao.modelProviderManager
|
||||
var daoModelProviderManager *models.ProviderManager
|
||||
|
||||
// TestListMembersAuthCheck verifies that a non-owner (userID != tenantID) gets
|
||||
// CodeAuthenticationError without hitting the database.
|
||||
func TestListMembersAuthCheck(t *testing.T) {
|
||||
@@ -116,3 +126,132 @@ func TestTenantRoleConstants(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetTenantDefaultModels_WithModelID verifies that SetTenantDefaultModels
|
||||
// correctly resolves a modelID to composite name, validates ownership, and updates the tenant.
|
||||
func TestSetTenantDefaultModels_WithModelID(t *testing.T) {
|
||||
// 1. Setup SQLite in-memory DB
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
TranslateError: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
|
||||
err = models.InitProviderManager("../../conf/models")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init provider manager: %v", err)
|
||||
}
|
||||
daoModelProviderManager = models.GetProviderManager()
|
||||
|
||||
// 2. Migrate tables
|
||||
err = db.AutoMigrate(
|
||||
&entity.Tenant{},
|
||||
&entity.TenantModelProvider{},
|
||||
&entity.TenantModelInstance{},
|
||||
&entity.TenantModel{},
|
||||
&entity.UserTenant{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to auto migrate: %v", err)
|
||||
}
|
||||
|
||||
// Swap dao.DB for the test
|
||||
origDB := dao.DB
|
||||
dao.DB = db
|
||||
defer func() { dao.DB = origDB }()
|
||||
|
||||
// 3. Insert mock data
|
||||
tenantID := "tenant-123"
|
||||
userID := "user-123"
|
||||
statusVal := "1"
|
||||
|
||||
// Insert UserTenant
|
||||
err = db.Create(&entity.UserTenant{
|
||||
ID: "ut-1",
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Role: "owner",
|
||||
Status: &statusVal,
|
||||
}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create user tenant: %v", err)
|
||||
}
|
||||
|
||||
// Insert Tenant
|
||||
err = db.Create(&entity.Tenant{
|
||||
ID: tenantID,
|
||||
LLMID: "",
|
||||
EmbdID: "",
|
||||
ASRID: "",
|
||||
Status: &statusVal,
|
||||
}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create tenant: %v", err)
|
||||
}
|
||||
|
||||
// Insert Provider
|
||||
providerID := "provider-1"
|
||||
err = db.Create(&entity.TenantModelProvider{
|
||||
ID: providerID,
|
||||
TenantID: tenantID,
|
||||
ProviderName: "OpenAI",
|
||||
}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create provider: %v", err)
|
||||
}
|
||||
|
||||
// Insert Real Instance (for checkModelAvailable lookup)
|
||||
err = db.Create(&entity.TenantModelInstance{
|
||||
ID: "instance-real",
|
||||
ProviderID: providerID,
|
||||
InstanceName: "default",
|
||||
}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create real instance: %v", err)
|
||||
}
|
||||
|
||||
// Insert Dummy Instance (associated with the model record)
|
||||
err = db.Create(&entity.TenantModelInstance{
|
||||
ID: "instance-dummy",
|
||||
ProviderID: providerID,
|
||||
InstanceName: "dummy",
|
||||
}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create dummy instance: %v", err)
|
||||
}
|
||||
|
||||
// Insert Model pointing to instance-dummy
|
||||
modelID := "model-1"
|
||||
err = db.Create(&entity.TenantModel{
|
||||
ID: modelID,
|
||||
ModelName: "gpt-4o",
|
||||
ProviderID: providerID,
|
||||
InstanceID: "instance-dummy",
|
||||
ModelType: "chat",
|
||||
Status: "active",
|
||||
}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create model: %v", err)
|
||||
}
|
||||
|
||||
// 4. Run SetTenantDefaultModels
|
||||
s := NewTenantService()
|
||||
// Set chat model using modelID, explicitly passing "default" as instance name to bypass pre-existing checkModelAvailable panic
|
||||
err = s.SetTenantDefaultModels(userID, "", "default", "", "chat", modelID)
|
||||
if err != nil {
|
||||
t.Fatalf("SetTenantDefaultModels failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify Tenant default model is updated to composite name
|
||||
tenant := &entity.Tenant{}
|
||||
err = db.Where("id = ?", tenantID).First(tenant).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to retrieve tenant: %v", err)
|
||||
}
|
||||
|
||||
expectedDefaultModel := "gpt-4o@default@OpenAI"
|
||||
if tenant.LLMID != expectedDefaultModel {
|
||||
t.Errorf("expected tenant default LLM to be %q, got %q", expectedDefaultModel, tenant.LLMID)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user