feat[Go] add max_dimension and dimensions for ModelRequest (#16019)

### What problem does this PR solve?

As title

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Haruko386
2026-06-16 10:31:27 +08:00
committed by GitHub
parent e7c068747e
commit efdd58df66
2 changed files with 175 additions and 11 deletions

View File

@@ -1133,6 +1133,79 @@ type ModelInstanceAndProviderInfo struct {
APIConfig *modelModule.APIConfig
}
type tenantModelExtra struct {
MaxTokens *int `json:"max_tokens"`
ModelTypes []string `json:"model_types"`
MaxDimension *int `json:"max_dimension"`
Dimensions []int `json:"dimensions"`
Thinking *bool `json:"thinking"`
}
func modelInfoWithTenantExtra(modelInfo *modelModule.Model, modelEntity *entity.TenantModel) (*modelModule.Model, error) {
if modelInfo == nil || modelEntity == nil || strings.TrimSpace(modelEntity.Extra) == "" {
return modelInfo, nil
}
var extra tenantModelExtra
if err := json.Unmarshal([]byte(modelEntity.Extra), &extra); err != nil {
return nil, err
}
model := *modelInfo
model.ModelTypes = append([]string(nil), modelInfo.ModelTypes...)
model.Dimensions = append([]int(nil), modelInfo.Dimensions...)
model.Alias = append([]string(nil), modelInfo.Alias...)
if modelInfo.ModelTypeMap != nil {
model.ModelTypeMap = make(map[string]bool, len(modelInfo.ModelTypeMap))
for modelType, enabled := range modelInfo.ModelTypeMap {
model.ModelTypeMap[modelType] = enabled
}
}
if modelInfo.Thinking != nil {
thinking := *modelInfo.Thinking
model.Thinking = &thinking
}
if extra.MaxTokens != nil && *extra.MaxTokens > 0 {
model.MaxTokens = extra.MaxTokens
}
if len(extra.ModelTypes) > 0 {
model.ModelTypes = append([]string(nil), extra.ModelTypes...)
model.ModelTypeMap = make(map[string]bool, len(extra.ModelTypes))
for _, modelType := range extra.ModelTypes {
model.ModelTypeMap[modelType] = true
}
}
if extra.MaxDimension != nil && *extra.MaxDimension > 0 {
model.MaxDimension = extra.MaxDimension
}
if len(extra.Dimensions) > 0 {
model.Dimensions = append([]int(nil), extra.Dimensions...)
}
if extra.Thinking != nil {
if model.Thinking == nil {
model.Thinking = &modelModule.ModelThinking{}
}
model.Thinking.DefaultValue = *extra.Thinking
}
return &model, nil
}
func maxTokensFromTenantModelExtra(modelEntity *entity.TenantModel, fallback int) (int, error) {
if modelEntity == nil || strings.TrimSpace(modelEntity.Extra) == "" {
return fallback, nil
}
var extra tenantModelExtra
if err := json.Unmarshal([]byte(modelEntity.Extra), &extra); err != nil {
return 0, err
}
if extra.MaxTokens != nil && *extra.MaxTokens > 0 {
return *extra.MaxTokens, nil
}
return fallback, nil
}
func (m *ModelProviderService) getModelInstanceAndProviderByName(providerName, instanceName, modelName *string, userID string, apiConfig *modelModule.APIConfig) (*ModelInstanceAndProviderInfo, error) {
// Get tenant ID from user
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
@@ -1172,6 +1245,10 @@ func (m *ModelProviderService) getModelInstanceAndProviderByName(providerName, i
if err != nil {
return nil, errors.New(fmt.Sprintf("provider %s model %s not found", *providerName, *modelName))
}
modelInfo, err = modelInfoWithTenantExtra(modelInfo, modelEntity)
if err != nil {
return nil, err
}
var extra map[string]string
err = json.Unmarshal([]byte(instanceEntity.Extra), &extra)
@@ -1243,6 +1320,10 @@ func (m *ModelProviderService) getModelInstanceAndProviderByID(modelID *string,
if err != nil {
return nil, errors.New(fmt.Sprintf("provider %s model %s not found", providerEntity.ProviderName, modelEntity.ModelName))
}
modelInfo, err = modelInfoWithTenantExtra(modelInfo, modelEntity)
if err != nil {
return nil, err
}
var extra map[string]string
err = json.Unmarshal([]byte(instanceEntity.Extra), &extra)
@@ -1295,6 +1376,10 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam
modelConfig = &modelModule.ChatConfig{}
}
modelConfig.ModelClass = info.ModelInfo.Class
if modelConfig.Thinking == nil && info.ModelInfo.Thinking != nil {
thinking := info.ModelInfo.Thinking.DefaultValue
modelConfig.Thinking = &thinking
}
var response *modelModule.ChatResponse
var modelDriver modelModule.ModelDriver
@@ -1353,6 +1438,10 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
modelConfig = &modelModule.ChatConfig{}
}
modelConfig.ModelClass = info.ModelInfo.Class
if modelConfig.Thinking == nil && info.ModelInfo.Thinking != nil {
thinking := info.ModelInfo.Thinking.DefaultValue
modelConfig.Thinking = &thinking
}
var modelDriver modelModule.ModelDriver
@@ -1374,7 +1463,7 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
}
}
err = modelDriver.ChatStreamlyWithSender(*modelName, messages, apiConfig, modelConfig, sender)
err = modelDriver.ChatStreamlyWithSender(*modelName, messages, info.APIConfig, modelConfig, sender)
if err != nil {
return common.CodeServerError, err
}
@@ -1461,7 +1550,7 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName,
}
var response []modelModule.EmbeddingData
response, err = modelDriver.Embed(modelName, texts, apiConfig, modelConfig)
response, err = modelDriver.Embed(modelName, texts, info.APIConfig, modelConfig)
if err != nil {
return nil, common.CodeServerError, err
}
@@ -1985,10 +2074,12 @@ type AddCustomModelRequest struct {
}
type ModelRequest struct {
ModelName string `json:"model_name"`
ModelTypes []string `json:"model_types"`
MaxTokens int `json:"max_tokens"`
Thinking *bool `json:"thinking"`
ModelName string `json:"model_name"`
ModelTypes []string `json:"model_types"`
MaxTokens int `json:"max_tokens"`
MaxDimension int `json:"max_dimension"`
Dimensions []int `json:"dimensions"`
Thinking *bool `json:"thinking"`
}
func (m *ModelProviderService) AddModel(request *AddModelRequest, userID string) (common.ErrorCode, error) {
@@ -2052,9 +2143,22 @@ func (m *ModelProviderService) AddModel(request *AddModelRequest, userID string)
modelID := utility.GenerateToken()
if model.MaxDimension < 0 {
return common.CodeBadRequest, errors.New("max_dimension must be non-negative")
}
for _, dimension := range model.Dimensions {
if dimension <= 0 {
return common.CodeBadRequest, errors.New("dimensions must contain positive values")
}
if model.MaxDimension > 0 && dimension > model.MaxDimension {
return common.CodeBadRequest, fmt.Errorf("dimension %d exceeds max_dimension %d", dimension, model.MaxDimension)
}
}
extra := map[string]interface{}{
"max_tokens": model.MaxTokens,
"model_types": []string{modelType},
"max_tokens": model.MaxTokens,
"model_types": []string{modelType},
"max_dimension": model.MaxDimension,
"dimensions": model.Dimensions,
}
if model.Thinking != nil {
extra["thinking"] = *model.Thinking
@@ -2209,6 +2313,10 @@ func (m *ModelProviderService) GetModelConfigFromProviderInstance(tenantID strin
maxTokens = *mi.MaxTokens
}
}
maxTokens, driverErr = maxTokensFromTenantModelExtra(modelObj, maxTokens)
if driverErr != nil {
return nil, "", nil, 0, driverErr
}
apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: &region, BaseURL: &baseURL}
return driver, modelObj.ModelName, apiConfig, maxTokens, nil
case errors.Is(modelErr, gorm.ErrRecordNotFound):
@@ -2300,12 +2408,14 @@ func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName strin
var extra map[string]string
var region string
var baseURL string
if instance != nil {
err = json.Unmarshal([]byte(instance.Extra), &extra)
if err != nil {
return nil, "", nil, 0, err
}
region = extra["region"]
baseURL = extra["base_url"]
}
providerInfo := dao.GetModelProviderManager().FindProvider(providerName)
@@ -2340,17 +2450,30 @@ func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName strin
return builtinDriver, modelName, apiConfig, maxTokens, nil
}
_, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(providerID, instance.ID, modelName)
var modelRecord *entity.TenantModel
modelRecord, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(providerID, instance.ID, modelName)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, "", nil, 0, fmt.Errorf("tenant model %q lookup failed: %w", modelName, err)
}
_, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName)
if err != nil {
return nil, "", nil, 0, fmt.Errorf("provider %s model %s not found", providerName, modelName)
}
}
maxTokens, err = maxTokensFromTenantModelExtra(modelRecord, maxTokens)
if err != nil {
return nil, "", nil, 0, err
}
apiKey = instance.APIKey
apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: &region}
return providerInfo.ModelDriver, modelName, apiConfig, maxTokens, nil
driver, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, baseURL)
if err != nil {
return nil, "", nil, 0, err
}
apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: &region, BaseURL: &baseURL}
return driver, modelName, apiConfig, maxTokens, nil
}
// ListAllModels list all models

View File

@@ -4,6 +4,7 @@ import (
"strings"
"testing"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
)
@@ -68,3 +69,43 @@ func TestValidateEmbeddingDimension(t *testing.T) {
})
}
}
func TestModelInfoWithTenantExtraAppliesEmbeddingDimensions(t *testing.T) {
factoryMaxDimension := 2048
modelInfo := &modelModule.Model{
Name: "embedding-3",
MaxDimension: &factoryMaxDimension,
Dimensions: []int{1024, 2048},
ModelTypes: []string{"embedding"},
ModelTypeMap: map[string]bool{"embedding": true},
}
modelEntity := &entity.TenantModel{
Extra: `{"max_dimension":768,"dimensions":[384,768],"model_types":["embedding"]}`,
}
merged, err := modelInfoWithTenantExtra(modelInfo, modelEntity)
if err != nil {
t.Fatalf("modelInfoWithTenantExtra() error = %v", err)
}
if merged == modelInfo {
t.Fatalf("modelInfoWithTenantExtra() returned original model pointer")
}
if merged.MaxDimension == nil || *merged.MaxDimension != 768 {
t.Fatalf("MaxDimension = %v, want 768", merged.MaxDimension)
}
if len(merged.Dimensions) != 2 || merged.Dimensions[0] != 384 || merged.Dimensions[1] != 768 {
t.Fatalf("Dimensions = %v, want [384 768]", merged.Dimensions)
}
if err := validateEmbeddingDimension(merged, 1024); err == nil || !strings.Contains(err.Error(), "supported dimensions") {
t.Fatalf("validateEmbeddingDimension() error = %v, want supported dimensions error", err)
}
if err := validateEmbeddingDimension(merged, 768); err != nil {
t.Fatalf("validateEmbeddingDimension() error = %v", err)
}
if modelInfo.MaxDimension == nil || *modelInfo.MaxDimension != factoryMaxDimension {
t.Fatalf("factory MaxDimension was mutated: %v", modelInfo.MaxDimension)
}
if len(modelInfo.Dimensions) != 2 || modelInfo.Dimensions[0] != 1024 || modelInfo.Dimensions[1] != 2048 {
t.Fatalf("factory Dimensions were mutated: %v", modelInfo.Dimensions)
}
}