mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat[Go] add max_dimension and dimensions for ModelRequest (#16019)
### What problem does this PR solve? As title ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -1133,6 +1133,79 @@ type ModelInstanceAndProviderInfo struct {
|
||||
APIConfig *modelModule.APIConfig
|
||||
}
|
||||
|
||||
type tenantModelExtra struct {
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
ModelTypes []string `json:"model_types"`
|
||||
MaxDimension *int `json:"max_dimension"`
|
||||
Dimensions []int `json:"dimensions"`
|
||||
Thinking *bool `json:"thinking"`
|
||||
}
|
||||
|
||||
func modelInfoWithTenantExtra(modelInfo *modelModule.Model, modelEntity *entity.TenantModel) (*modelModule.Model, error) {
|
||||
if modelInfo == nil || modelEntity == nil || strings.TrimSpace(modelEntity.Extra) == "" {
|
||||
return modelInfo, nil
|
||||
}
|
||||
|
||||
var extra tenantModelExtra
|
||||
if err := json.Unmarshal([]byte(modelEntity.Extra), &extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model := *modelInfo
|
||||
model.ModelTypes = append([]string(nil), modelInfo.ModelTypes...)
|
||||
model.Dimensions = append([]int(nil), modelInfo.Dimensions...)
|
||||
model.Alias = append([]string(nil), modelInfo.Alias...)
|
||||
if modelInfo.ModelTypeMap != nil {
|
||||
model.ModelTypeMap = make(map[string]bool, len(modelInfo.ModelTypeMap))
|
||||
for modelType, enabled := range modelInfo.ModelTypeMap {
|
||||
model.ModelTypeMap[modelType] = enabled
|
||||
}
|
||||
}
|
||||
if modelInfo.Thinking != nil {
|
||||
thinking := *modelInfo.Thinking
|
||||
model.Thinking = &thinking
|
||||
}
|
||||
|
||||
if extra.MaxTokens != nil && *extra.MaxTokens > 0 {
|
||||
model.MaxTokens = extra.MaxTokens
|
||||
}
|
||||
if len(extra.ModelTypes) > 0 {
|
||||
model.ModelTypes = append([]string(nil), extra.ModelTypes...)
|
||||
model.ModelTypeMap = make(map[string]bool, len(extra.ModelTypes))
|
||||
for _, modelType := range extra.ModelTypes {
|
||||
model.ModelTypeMap[modelType] = true
|
||||
}
|
||||
}
|
||||
if extra.MaxDimension != nil && *extra.MaxDimension > 0 {
|
||||
model.MaxDimension = extra.MaxDimension
|
||||
}
|
||||
if len(extra.Dimensions) > 0 {
|
||||
model.Dimensions = append([]int(nil), extra.Dimensions...)
|
||||
}
|
||||
if extra.Thinking != nil {
|
||||
if model.Thinking == nil {
|
||||
model.Thinking = &modelModule.ModelThinking{}
|
||||
}
|
||||
model.Thinking.DefaultValue = *extra.Thinking
|
||||
}
|
||||
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
func maxTokensFromTenantModelExtra(modelEntity *entity.TenantModel, fallback int) (int, error) {
|
||||
if modelEntity == nil || strings.TrimSpace(modelEntity.Extra) == "" {
|
||||
return fallback, nil
|
||||
}
|
||||
var extra tenantModelExtra
|
||||
if err := json.Unmarshal([]byte(modelEntity.Extra), &extra); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if extra.MaxTokens != nil && *extra.MaxTokens > 0 {
|
||||
return *extra.MaxTokens, nil
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) getModelInstanceAndProviderByName(providerName, instanceName, modelName *string, userID string, apiConfig *modelModule.APIConfig) (*ModelInstanceAndProviderInfo, error) {
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
@@ -1172,6 +1245,10 @@ func (m *ModelProviderService) getModelInstanceAndProviderByName(providerName, i
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("provider %s model %s not found", *providerName, *modelName))
|
||||
}
|
||||
modelInfo, err = modelInfoWithTenantExtra(modelInfo, modelEntity)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var extra map[string]string
|
||||
err = json.Unmarshal([]byte(instanceEntity.Extra), &extra)
|
||||
@@ -1243,6 +1320,10 @@ func (m *ModelProviderService) getModelInstanceAndProviderByID(modelID *string,
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("provider %s model %s not found", providerEntity.ProviderName, modelEntity.ModelName))
|
||||
}
|
||||
modelInfo, err = modelInfoWithTenantExtra(modelInfo, modelEntity)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var extra map[string]string
|
||||
err = json.Unmarshal([]byte(instanceEntity.Extra), &extra)
|
||||
@@ -1295,6 +1376,10 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam
|
||||
modelConfig = &modelModule.ChatConfig{}
|
||||
}
|
||||
modelConfig.ModelClass = info.ModelInfo.Class
|
||||
if modelConfig.Thinking == nil && info.ModelInfo.Thinking != nil {
|
||||
thinking := info.ModelInfo.Thinking.DefaultValue
|
||||
modelConfig.Thinking = &thinking
|
||||
}
|
||||
|
||||
var response *modelModule.ChatResponse
|
||||
var modelDriver modelModule.ModelDriver
|
||||
@@ -1353,6 +1438,10 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
|
||||
modelConfig = &modelModule.ChatConfig{}
|
||||
}
|
||||
modelConfig.ModelClass = info.ModelInfo.Class
|
||||
if modelConfig.Thinking == nil && info.ModelInfo.Thinking != nil {
|
||||
thinking := info.ModelInfo.Thinking.DefaultValue
|
||||
modelConfig.Thinking = &thinking
|
||||
}
|
||||
|
||||
var modelDriver modelModule.ModelDriver
|
||||
|
||||
@@ -1374,7 +1463,7 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
|
||||
}
|
||||
}
|
||||
|
||||
err = modelDriver.ChatStreamlyWithSender(*modelName, messages, apiConfig, modelConfig, sender)
|
||||
err = modelDriver.ChatStreamlyWithSender(*modelName, messages, info.APIConfig, modelConfig, sender)
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
@@ -1461,7 +1550,7 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName,
|
||||
}
|
||||
|
||||
var response []modelModule.EmbeddingData
|
||||
response, err = modelDriver.Embed(modelName, texts, apiConfig, modelConfig)
|
||||
response, err = modelDriver.Embed(modelName, texts, info.APIConfig, modelConfig)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
@@ -1985,10 +2074,12 @@ type AddCustomModelRequest struct {
|
||||
}
|
||||
|
||||
type ModelRequest struct {
|
||||
ModelName string `json:"model_name"`
|
||||
ModelTypes []string `json:"model_types"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Thinking *bool `json:"thinking"`
|
||||
ModelName string `json:"model_name"`
|
||||
ModelTypes []string `json:"model_types"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
MaxDimension int `json:"max_dimension"`
|
||||
Dimensions []int `json:"dimensions"`
|
||||
Thinking *bool `json:"thinking"`
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) AddModel(request *AddModelRequest, userID string) (common.ErrorCode, error) {
|
||||
@@ -2052,9 +2143,22 @@ func (m *ModelProviderService) AddModel(request *AddModelRequest, userID string)
|
||||
|
||||
modelID := utility.GenerateToken()
|
||||
|
||||
if model.MaxDimension < 0 {
|
||||
return common.CodeBadRequest, errors.New("max_dimension must be non-negative")
|
||||
}
|
||||
for _, dimension := range model.Dimensions {
|
||||
if dimension <= 0 {
|
||||
return common.CodeBadRequest, errors.New("dimensions must contain positive values")
|
||||
}
|
||||
if model.MaxDimension > 0 && dimension > model.MaxDimension {
|
||||
return common.CodeBadRequest, fmt.Errorf("dimension %d exceeds max_dimension %d", dimension, model.MaxDimension)
|
||||
}
|
||||
}
|
||||
extra := map[string]interface{}{
|
||||
"max_tokens": model.MaxTokens,
|
||||
"model_types": []string{modelType},
|
||||
"max_tokens": model.MaxTokens,
|
||||
"model_types": []string{modelType},
|
||||
"max_dimension": model.MaxDimension,
|
||||
"dimensions": model.Dimensions,
|
||||
}
|
||||
if model.Thinking != nil {
|
||||
extra["thinking"] = *model.Thinking
|
||||
@@ -2209,6 +2313,10 @@ func (m *ModelProviderService) GetModelConfigFromProviderInstance(tenantID strin
|
||||
maxTokens = *mi.MaxTokens
|
||||
}
|
||||
}
|
||||
maxTokens, driverErr = maxTokensFromTenantModelExtra(modelObj, maxTokens)
|
||||
if driverErr != nil {
|
||||
return nil, "", nil, 0, driverErr
|
||||
}
|
||||
apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: ®ion, BaseURL: &baseURL}
|
||||
return driver, modelObj.ModelName, apiConfig, maxTokens, nil
|
||||
case errors.Is(modelErr, gorm.ErrRecordNotFound):
|
||||
@@ -2300,12 +2408,14 @@ func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName strin
|
||||
|
||||
var extra map[string]string
|
||||
var region string
|
||||
var baseURL string
|
||||
if instance != nil {
|
||||
err = json.Unmarshal([]byte(instance.Extra), &extra)
|
||||
if err != nil {
|
||||
return nil, "", nil, 0, err
|
||||
}
|
||||
region = extra["region"]
|
||||
baseURL = extra["base_url"]
|
||||
}
|
||||
|
||||
providerInfo := dao.GetModelProviderManager().FindProvider(providerName)
|
||||
@@ -2340,17 +2450,30 @@ func (m *ModelProviderService) getModelConfig(tenantID, compositeModelName strin
|
||||
return builtinDriver, modelName, apiConfig, maxTokens, nil
|
||||
}
|
||||
|
||||
_, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(providerID, instance.ID, modelName)
|
||||
var modelRecord *entity.TenantModel
|
||||
modelRecord, err = m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(providerID, instance.ID, modelName)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, "", nil, 0, fmt.Errorf("tenant model %q lookup failed: %w", modelName, err)
|
||||
}
|
||||
_, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName)
|
||||
if err != nil {
|
||||
return nil, "", nil, 0, fmt.Errorf("provider %s model %s not found", providerName, modelName)
|
||||
}
|
||||
}
|
||||
maxTokens, err = maxTokensFromTenantModelExtra(modelRecord, maxTokens)
|
||||
if err != nil {
|
||||
return nil, "", nil, 0, err
|
||||
}
|
||||
apiKey = instance.APIKey
|
||||
|
||||
apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: ®ion}
|
||||
return providerInfo.ModelDriver, modelName, apiConfig, maxTokens, nil
|
||||
driver, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, baseURL)
|
||||
if err != nil {
|
||||
return nil, "", nil, 0, err
|
||||
}
|
||||
|
||||
apiConfig := &modelModule.APIConfig{ApiKey: &apiKey, Region: ®ion, BaseURL: &baseURL}
|
||||
return driver, modelName, apiConfig, maxTokens, nil
|
||||
}
|
||||
|
||||
// ListAllModels list all models
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
@@ -68,3 +69,43 @@ func TestValidateEmbeddingDimension(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelInfoWithTenantExtraAppliesEmbeddingDimensions(t *testing.T) {
|
||||
factoryMaxDimension := 2048
|
||||
modelInfo := &modelModule.Model{
|
||||
Name: "embedding-3",
|
||||
MaxDimension: &factoryMaxDimension,
|
||||
Dimensions: []int{1024, 2048},
|
||||
ModelTypes: []string{"embedding"},
|
||||
ModelTypeMap: map[string]bool{"embedding": true},
|
||||
}
|
||||
modelEntity := &entity.TenantModel{
|
||||
Extra: `{"max_dimension":768,"dimensions":[384,768],"model_types":["embedding"]}`,
|
||||
}
|
||||
|
||||
merged, err := modelInfoWithTenantExtra(modelInfo, modelEntity)
|
||||
if err != nil {
|
||||
t.Fatalf("modelInfoWithTenantExtra() error = %v", err)
|
||||
}
|
||||
if merged == modelInfo {
|
||||
t.Fatalf("modelInfoWithTenantExtra() returned original model pointer")
|
||||
}
|
||||
if merged.MaxDimension == nil || *merged.MaxDimension != 768 {
|
||||
t.Fatalf("MaxDimension = %v, want 768", merged.MaxDimension)
|
||||
}
|
||||
if len(merged.Dimensions) != 2 || merged.Dimensions[0] != 384 || merged.Dimensions[1] != 768 {
|
||||
t.Fatalf("Dimensions = %v, want [384 768]", merged.Dimensions)
|
||||
}
|
||||
if err := validateEmbeddingDimension(merged, 1024); err == nil || !strings.Contains(err.Error(), "supported dimensions") {
|
||||
t.Fatalf("validateEmbeddingDimension() error = %v, want supported dimensions error", err)
|
||||
}
|
||||
if err := validateEmbeddingDimension(merged, 768); err != nil {
|
||||
t.Fatalf("validateEmbeddingDimension() error = %v", err)
|
||||
}
|
||||
if modelInfo.MaxDimension == nil || *modelInfo.MaxDimension != factoryMaxDimension {
|
||||
t.Fatalf("factory MaxDimension was mutated: %v", modelInfo.MaxDimension)
|
||||
}
|
||||
if len(modelInfo.Dimensions) != 2 || modelInfo.Dimensions[0] != 1024 || modelInfo.Dimensions[1] != 2048 {
|
||||
t.Fatalf("factory Dimensions were mutated: %v", modelInfo.Dimensions)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user