diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 111bbaa05c..c0c7b4667c 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -1133,6 +1133,79 @@ type ModelInstanceAndProviderInfo struct { APIConfig *modelModule.APIConfig } +type tenantModelExtra struct { + MaxTokens *int `json:"max_tokens"` + ModelTypes []string `json:"model_types"` + MaxDimension *int `json:"max_dimension"` + Dimensions []int `json:"dimensions"` + Thinking *bool `json:"thinking"` +} + +func modelInfoWithTenantExtra(modelInfo *modelModule.Model, modelEntity *entity.TenantModel) (*modelModule.Model, error) { + if modelInfo == nil || modelEntity == nil || strings.TrimSpace(modelEntity.Extra) == "" { + return modelInfo, nil + } + + var extra tenantModelExtra + if err := json.Unmarshal([]byte(modelEntity.Extra), &extra); err != nil { + return nil, err + } + + model := *modelInfo + model.ModelTypes = append([]string(nil), modelInfo.ModelTypes...) + model.Dimensions = append([]int(nil), modelInfo.Dimensions...) + model.Alias = append([]string(nil), modelInfo.Alias...) + if modelInfo.ModelTypeMap != nil { + model.ModelTypeMap = make(map[string]bool, len(modelInfo.ModelTypeMap)) + for modelType, enabled := range modelInfo.ModelTypeMap { + model.ModelTypeMap[modelType] = enabled + } + } + if modelInfo.Thinking != nil { + thinking := *modelInfo.Thinking + model.Thinking = &thinking + } + + if extra.MaxTokens != nil && *extra.MaxTokens > 0 { + model.MaxTokens = extra.MaxTokens + } + if len(extra.ModelTypes) > 0 { + model.ModelTypes = append([]string(nil), extra.ModelTypes...) + model.ModelTypeMap = make(map[string]bool, len(extra.ModelTypes)) + for _, modelType := range extra.ModelTypes { + model.ModelTypeMap[modelType] = true + } + } + if extra.MaxDimension != nil && *extra.MaxDimension > 0 { + model.MaxDimension = extra.MaxDimension + } + if len(extra.Dimensions) > 0 { + model.Dimensions = append([]int(nil), extra.Dimensions...) 
+ } + if extra.Thinking != nil { + if model.Thinking == nil { + model.Thinking = &modelModule.ModelThinking{} + } + model.Thinking.DefaultValue = *extra.Thinking + } + + return &model, nil +} + +func maxTokensFromTenantModelExtra(modelEntity *entity.TenantModel, fallback int) (int, error) { + if modelEntity == nil || strings.TrimSpace(modelEntity.Extra) == "" { + return fallback, nil + } + var extra tenantModelExtra + if err := json.Unmarshal([]byte(modelEntity.Extra), &extra); err != nil { + return 0, err + } + if extra.MaxTokens != nil && *extra.MaxTokens > 0 { + return *extra.MaxTokens, nil + } + return fallback, nil +} + func (m *ModelProviderService) getModelInstanceAndProviderByName(providerName, instanceName, modelName *string, userID string, apiConfig *modelModule.APIConfig) (*ModelInstanceAndProviderInfo, error) { // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") @@ -1172,6 +1245,10 @@ func (m *ModelProviderService) getModelInstanceAndProviderByName(providerName, i if err != nil { return nil, errors.New(fmt.Sprintf("provider %s model %s not found", *providerName, *modelName)) } + modelInfo, err = modelInfoWithTenantExtra(modelInfo, modelEntity) + if err != nil { + return nil, err + } var extra map[string]string err = json.Unmarshal([]byte(instanceEntity.Extra), &extra) @@ -1243,6 +1320,10 @@ func (m *ModelProviderService) getModelInstanceAndProviderByID(modelID *string, if err != nil { return nil, errors.New(fmt.Sprintf("provider %s model %s not found", providerEntity.ProviderName, modelEntity.ModelName)) } + modelInfo, err = modelInfoWithTenantExtra(modelInfo, modelEntity) + if err != nil { + return nil, err + } var extra map[string]string err = json.Unmarshal([]byte(instanceEntity.Extra), &extra) @@ -1295,6 +1376,10 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam modelConfig = &modelModule.ChatConfig{} } modelConfig.ModelClass = info.ModelInfo.Class + if modelConfig.Thinking 
// ModelRequest is the JSON payload describing one model entry in a request.
//
// NOTE(review): AddModel appears to consume values of this type; it rejects a
// negative MaxDimension and any Dimensions entry that is not positive or that
// exceeds a positive MaxDimension — confirm against the full AddModel body.
type ModelRequest struct {
	// Identity and capabilities.
	ModelName  string   `json:"model_name"`
	ModelTypes []string `json:"model_types"`

	// Size limits. Zero means the request did not set the value.
	MaxTokens    int   `json:"max_tokens"`
	MaxDimension int   `json:"max_dimension"`
	Dimensions   []int `json:"dimensions"`

	// Thinking is a pointer so that an omitted key (nil) can be told apart
	// from an explicit false.
	Thinking *bool `json:"thinking"`
}
*AddModelRequest, userID string) modelID := utility.GenerateToken() + if model.MaxDimension < 0 { + return common.CodeBadRequest, errors.New("max_dimension must be non-negative") + } + for _, dimension := range model.Dimensions { + if dimension <= 0 { + return common.CodeBadRequest, errors.New("dimensions must contain positive values") + } + if model.MaxDimension > 0 && dimension > model.MaxDimension { + return common.CodeBadRequest, fmt.Errorf("dimension %d exceeds max_dimension %d", dimension, model.MaxDimension) + } + } extra := map[string]interface{}{ - "max_tokens": model.MaxTokens, - "model_types": []string{modelType}, + "max_tokens": model.MaxTokens, + "model_types": []string{modelType}, + "max_dimension": model.MaxDimension, + "dimensions": model.Dimensions, } if model.Thinking != nil { extra["thinking"] = *model.Thinking @@ -2209,6 +2313,10 @@ func (m *ModelProviderService) GetModelConfigFromProviderInstance(tenantID strin maxTokens = *mi.MaxTokens } } + maxTokens, driverErr = maxTokensFromTenantModelExtra(modelObj, maxTokens) + if driverErr != nil { + return nil, "", nil, 0, driverErr + } apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: ®ion, BaseURL: &baseURL} return driver, modelObj.ModelName, apiConfig, maxTokens, nil case errors.Is(modelErr, gorm.ErrRecordNotFound): @@ -2300,12 +2408,14 @@ func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName strin var extra map[string]string var region string + var baseURL string if instance != nil { err = json.Unmarshal([]byte(instance.Extra), &extra) if err != nil { return nil, "", nil, 0, err } region = extra["region"] + baseURL = extra["base_url"] } providerInfo := dao.GetModelProviderManager().FindProvider(providerName) @@ -2340,17 +2450,30 @@ func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName strin return builtinDriver, modelName, apiConfig, maxTokens, nil } - _, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(providerID, instance.ID, 
modelName) + var modelRecord *entity.TenantModel + modelRecord, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(providerID, instance.ID, modelName) if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, "", nil, 0, fmt.Errorf("tenant model %q lookup failed: %w", modelName, err) + } _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) if err != nil { return nil, "", nil, 0, fmt.Errorf("provider %s model %s not found", providerName, modelName) } } + maxTokens, err = maxTokensFromTenantModelExtra(modelRecord, maxTokens) + if err != nil { + return nil, "", nil, 0, err + } apiKey = instance.APIKey - apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: ®ion} - return providerInfo.ModelDriver, modelName, apiConfig, maxTokens, nil + driver, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, baseURL) + if err != nil { + return nil, "", nil, 0, err + } + + apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: ®ion, BaseURL: &baseURL} + return driver, modelName, apiConfig, maxTokens, nil } // ListAllModels list all models diff --git a/internal/service/model_service_test.go b/internal/service/model_service_test.go index a0be4082da..b482f4fbbe 100644 --- a/internal/service/model_service_test.go +++ b/internal/service/model_service_test.go @@ -4,6 +4,7 @@ import ( "strings" "testing" + "ragflow/internal/entity" modelModule "ragflow/internal/entity/models" ) @@ -68,3 +69,43 @@ func TestValidateEmbeddingDimension(t *testing.T) { }) } } + +func TestModelInfoWithTenantExtraAppliesEmbeddingDimensions(t *testing.T) { + factoryMaxDimension := 2048 + modelInfo := &modelModule.Model{ + Name: "embedding-3", + MaxDimension: &factoryMaxDimension, + Dimensions: []int{1024, 2048}, + ModelTypes: []string{"embedding"}, + ModelTypeMap: map[string]bool{"embedding": true}, + } + modelEntity := &entity.TenantModel{ + Extra: 
`{"max_dimension":768,"dimensions":[384,768],"model_types":["embedding"]}`, + } + + merged, err := modelInfoWithTenantExtra(modelInfo, modelEntity) + if err != nil { + t.Fatalf("modelInfoWithTenantExtra() error = %v", err) + } + if merged == modelInfo { + t.Fatalf("modelInfoWithTenantExtra() returned original model pointer") + } + if merged.MaxDimension == nil || *merged.MaxDimension != 768 { + t.Fatalf("MaxDimension = %v, want 768", merged.MaxDimension) + } + if len(merged.Dimensions) != 2 || merged.Dimensions[0] != 384 || merged.Dimensions[1] != 768 { + t.Fatalf("Dimensions = %v, want [384 768]", merged.Dimensions) + } + if err := validateEmbeddingDimension(merged, 1024); err == nil || !strings.Contains(err.Error(), "supported dimensions") { + t.Fatalf("validateEmbeddingDimension() error = %v, want supported dimensions error", err) + } + if err := validateEmbeddingDimension(merged, 768); err != nil { + t.Fatalf("validateEmbeddingDimension() error = %v", err) + } + if modelInfo.MaxDimension == nil || *modelInfo.MaxDimension != factoryMaxDimension { + t.Fatalf("factory MaxDimension was mutated: %v", modelInfo.MaxDimension) + } + if len(modelInfo.Dimensions) != 2 || modelInfo.Dimensions[0] != 1024 || modelInfo.Dimensions[1] != 2048 { + t.Fatalf("factory Dimensions were mutated: %v", modelInfo.Dimensions) + } +}