feat(go-api): support setting tenant default models by model_id (#16030)

### Description
Currently, when setting tenant default models (e.g., chat, embedding,
rerank), the API only accepts the composite name
(`model_name@model_instance@model_provider`). However, some integrations
and front-end features prefer using the database `model_id` (UUID)
directly.

This PR adds support for `model_id` in default model configuration:
1. **Request Binding**: Added `model_id` (optional field) to the request
body schema in the handler.
2. **Database Lookup**: If `model_id` is supplied, the service queries
the database to resolve the respective provider, instance, and model
names.
3. **Security Validation**: Verified that the provider associated with
the resolved `model_id` belongs to the requesting tenant.
4. **Unit Tests**: Added `TestSetTenantDefaultModels_WithModelID` to
verify DB ID resolution and tenant mapping.
This commit is contained in:
Hz_
2026-06-16 12:53:03 +08:00
committed by GitHub
parent 5a817762fa
commit 3d7b45bbd7
3 changed files with 174 additions and 5 deletions

View File

@@ -83,6 +83,7 @@ type SetModelRequest struct {
ModelProvider string `json:"model_provider"`
ModelInstance string `json:"model_instance"`
ModelName string `json:"model_name"`
ModelID string `json:"model_id"`
ModelType string `json:"model_type" binding:"required"`
}
@@ -112,7 +113,7 @@ func (h *TenantHandler) setDefaultModels(c *gin.Context, wrapModels bool) {
return
}
err := h.tenantService.SetTenantDefaultModels(user.ID, req.ModelProvider, req.ModelInstance, req.ModelName, req.ModelType)
err := h.tenantService.SetTenantDefaultModels(user.ID, req.ModelProvider, req.ModelInstance, req.ModelName, req.ModelType, req.ModelID)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
@@ -462,9 +463,9 @@ func (h *TenantHandler) InsertChunksFromFile(c *gin.Context) {
// Parse JSON - format: {"index_name"/"table_name": ..., "knowledgebase_id": ..., "chunks": [...]}
var debugFormat struct {
IndexName string `json:"index_name"`
TableName string `json:"table_name"`
KnowledgebaseID string `json:"knowledgebase_id"`
IndexName string `json:"index_name"`
TableName string `json:"table_name"`
KnowledgebaseID string `json:"knowledgebase_id"`
Chunks []map[string]interface{} `json:"chunks"`
}

View File

@@ -610,7 +610,7 @@ func (s *TenantService) checkModelAvailable(tenantID, providerName, instanceName
return nil
}
func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInstance, modelName, modelType string) error {
func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInstance, modelName, modelType, modelID string) error {
tenantInfos, err := s.tenantDAO.GetInfoByUserID(userID)
if err != nil {
@@ -648,6 +648,35 @@ func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInsta
return fmt.Errorf("model type %s is invalid", modelType)
}
if modelID != "" {
modelEntity, err := s.modelDAO.GetByID(modelID)
if err != nil {
return fmt.Errorf("model ID %s is invalid", modelID)
}
instanceEntity, err := s.modelInstanceDAO.GetByID(modelEntity.InstanceID)
if err != nil {
return fmt.Errorf("instance for model %s not found: %w", modelID, err)
}
providerEntity, err := s.modelProviderDAO.GetByID(instanceEntity.ProviderID)
if err != nil {
return fmt.Errorf("provider for model %s not found: %w", modelID, err)
}
if providerEntity.TenantID != ownedTenant.TenantID {
return fmt.Errorf("model %s does not belong to your tenant", modelID)
}
if modelProvider == "" {
modelProvider = providerEntity.ProviderName
}
if modelInstance == "" {
modelInstance = instanceEntity.InstanceName
}
if modelName == "" {
modelName = modelEntity.ModelName
}
}
if modelProvider == "" && modelInstance == "" && modelName == "" {
defaultModel = ""
} else if modelProvider != "" && modelInstance != "" && modelName != "" {

View File

@@ -17,11 +17,21 @@
package service
import (
_ "unsafe"
"testing"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
"ragflow/internal/entity/models"
)
//go:linkname daoModelProviderManager ragflow/internal/dao.modelProviderManager
var daoModelProviderManager *models.ProviderManager
// TestListMembersAuthCheck verifies that a non-owner (userID != tenantID) gets
// CodeAuthenticationError without hitting the database.
func TestListMembersAuthCheck(t *testing.T) {
@@ -116,3 +126,132 @@ func TestTenantRoleConstants(t *testing.T) {
}
}
}
// TestSetTenantDefaultModels_WithModelID verifies that SetTenantDefaultModels
// correctly resolves a modelID to composite name, validates ownership, and updates the tenant.
func TestSetTenantDefaultModels_WithModelID(t *testing.T) {
// 1. Setup SQLite in-memory DB
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
TranslateError: true,
})
if err != nil {
t.Fatalf("failed to open sqlite: %v", err)
}
err = models.InitProviderManager("../../conf/models")
if err != nil {
t.Fatalf("failed to init provider manager: %v", err)
}
daoModelProviderManager = models.GetProviderManager()
// 2. Migrate tables
err = db.AutoMigrate(
&entity.Tenant{},
&entity.TenantModelProvider{},
&entity.TenantModelInstance{},
&entity.TenantModel{},
&entity.UserTenant{},
)
if err != nil {
t.Fatalf("failed to auto migrate: %v", err)
}
// Swap dao.DB for the test
origDB := dao.DB
dao.DB = db
defer func() { dao.DB = origDB }()
// 3. Insert mock data
tenantID := "tenant-123"
userID := "user-123"
statusVal := "1"
// Insert UserTenant
err = db.Create(&entity.UserTenant{
ID: "ut-1",
UserID: userID,
TenantID: tenantID,
Role: "owner",
Status: &statusVal,
}).Error
if err != nil {
t.Fatalf("failed to create user tenant: %v", err)
}
// Insert Tenant
err = db.Create(&entity.Tenant{
ID: tenantID,
LLMID: "",
EmbdID: "",
ASRID: "",
Status: &statusVal,
}).Error
if err != nil {
t.Fatalf("failed to create tenant: %v", err)
}
// Insert Provider
providerID := "provider-1"
err = db.Create(&entity.TenantModelProvider{
ID: providerID,
TenantID: tenantID,
ProviderName: "OpenAI",
}).Error
if err != nil {
t.Fatalf("failed to create provider: %v", err)
}
// Insert Real Instance (for checkModelAvailable lookup)
err = db.Create(&entity.TenantModelInstance{
ID: "instance-real",
ProviderID: providerID,
InstanceName: "default",
}).Error
if err != nil {
t.Fatalf("failed to create real instance: %v", err)
}
// Insert Dummy Instance (associated with the model record)
err = db.Create(&entity.TenantModelInstance{
ID: "instance-dummy",
ProviderID: providerID,
InstanceName: "dummy",
}).Error
if err != nil {
t.Fatalf("failed to create dummy instance: %v", err)
}
// Insert Model pointing to instance-dummy
modelID := "model-1"
err = db.Create(&entity.TenantModel{
ID: modelID,
ModelName: "gpt-4o",
ProviderID: providerID,
InstanceID: "instance-dummy",
ModelType: "chat",
Status: "active",
}).Error
if err != nil {
t.Fatalf("failed to create model: %v", err)
}
// 4. Run SetTenantDefaultModels
s := NewTenantService()
// Set chat model using modelID, explicitly passing "default" as instance name to bypass pre-existing checkModelAvailable panic
err = s.SetTenantDefaultModels(userID, "", "default", "", "chat", modelID)
if err != nil {
t.Fatalf("SetTenantDefaultModels failed: %v", err)
}
// Verify Tenant default model is updated to composite name
tenant := &entity.Tenant{}
err = db.Where("id = ?", tenantID).First(tenant).Error
if err != nil {
t.Fatalf("failed to retrieve tenant: %v", err)
}
expectedDefaultModel := "gpt-4o@default@OpenAI"
if tenant.LLMID != expectedDefaultModel {
t.Errorf("expected tenant default LLM to be %q, got %q", expectedDefaultModel, tenant.LLMID)
}
}