From 312514c032fa0f245029ed7139341df9a9124be2 Mon Sep 17 00:00:00 2001 From: Hz_ Date: Thu, 11 Jun 2026 17:55:13 +0800 Subject: [PATCH] feat(go): Add embedding dimension metadata and validation (#15939) ### What problem does this PR solve? - Replace embedding model `dimension` metadata with `max_dimension`. - Add optional `dimensions` metadata for models with fixed selectable output dimensions. - Include `max_dimension` and `dimensions` in model list responses. - Validate requested embedding dimensions before calling provider embedding APIs. - Forward SiliconFlow embedding dimensions with the correct `dimensions` request field. - Add unit coverage for embedding dimension validation rules. --- internal/entity/models/base_model.go | 3 +- internal/entity/models/model.go | 15 +++++- internal/entity/models/replicate.go | 3 +- internal/entity/models/types.go | 11 ++-- internal/service/model_service.go | 50 +++++++++++++++++-- internal/service/model_service_test.go | 69 ++++++++++++++++++++++++++ 6 files changed, 138 insertions(+), 13 deletions(-) diff --git a/internal/entity/models/base_model.go b/internal/entity/models/base_model.go index f4037c928a..e76f20f9cb 100644 --- a/internal/entity/models/base_model.go +++ b/internal/entity/models/base_model.go @@ -95,7 +95,8 @@ func ParseListModel(modelList ModelList) []ListModelResponse { } modelResponse.Name = modelName if modelEntity != nil { - modelResponse.Dimension = modelEntity.Dimension + modelResponse.MaxDimension = modelEntity.MaxDimension + modelResponse.Dimensions = modelEntity.Dimensions modelResponse.MaxTokens = modelEntity.MaxTokens modelResponse.ModelTypes = modelEntity.ModelTypes modelResponse.Thinking = modelEntity.Thinking diff --git a/internal/entity/models/model.go b/internal/entity/models/model.go index 5b80cd20b8..091bdd8f8a 100644 --- a/internal/entity/models/model.go +++ b/internal/entity/models/model.go @@ -160,7 +160,8 @@ type Model struct { ModelTypes []string `json:"model_types"` Thinking 
*ModelThinking `json:"thinking"` Class *string `json:"class"` - Dimension *int `json:"dimension"` // used by embedding models + MaxDimension *int `json:"max_dimension"` // used by embedding models + Dimensions []int `json:"dimensions"` Alias []string `json:"alias"` ModelTypeMap map[string]bool } @@ -386,6 +387,12 @@ func (pm *ProviderManager) ListAllModels() ([]map[string]interface{}, error) { if model.MaxTokens != nil { modelData["max_tokens"] = *model.MaxTokens } + if model.MaxDimension != nil { + modelData["max_dimension"] = *model.MaxDimension + } + if len(model.Dimensions) > 0 { + modelData["dimensions"] = model.Dimensions + } modelList = append(modelList, modelData) } @@ -437,6 +444,12 @@ func (pm *ProviderManager) ListModels(providerName string) ([]map[string]interfa "max_tokens": model.MaxTokens, "model_types": model.ModelTypes, } + if model.MaxDimension != nil { + modelData["max_dimension"] = *model.MaxDimension + } + if len(model.Dimensions) > 0 { + modelData["dimensions"] = model.Dimensions + } modelList = append(modelList, modelData) } diff --git a/internal/entity/models/replicate.go b/internal/entity/models/replicate.go index dc95701866..f4a50c27d1 100644 --- a/internal/entity/models/replicate.go +++ b/internal/entity/models/replicate.go @@ -557,7 +557,8 @@ func (r *ReplicateModel) ListModels(apiConfig *APIConfig) ([]ListModelResponse, } modelResponse.Name = modelName if modelEntity != nil { - modelResponse.Dimension = modelEntity.Dimension + modelResponse.MaxDimension = modelEntity.MaxDimension + modelResponse.Dimensions = modelEntity.Dimensions modelResponse.MaxTokens = modelEntity.MaxTokens modelResponse.ModelTypes = modelEntity.ModelTypes modelResponse.Thinking = modelEntity.Thinking diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index caad547c40..863a33a0d0 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -79,11 +79,12 @@ type OCRFileResponse struct { } type ListModelResponse 
struct { - Name string `json:"name"` - MaxTokens *int `json:"max_tokens"` - ModelTypes []string `json:"model_types"` - Thinking *ModelThinking `json:"thinking"` - Dimension *int `json:"dimension"` // used by embedding models + Name string `json:"name"` + MaxTokens *int `json:"max_tokens"` + ModelTypes []string `json:"model_types"` + Thinking *ModelThinking `json:"thinking"` + MaxDimension *int `json:"max_dimension"` // used by embedding models + Dimensions []int `json:"dimensions"` } type ParseFileResponse struct { diff --git a/internal/service/model_service.go b/internal/service/model_service.go index e0f1fa6c8b..180075986b 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -247,11 +247,12 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u var result []map[string]interface{} for _, model := range modelList { result = append(result, map[string]interface{}{ - "name": model.Name, - "dimension": model.Dimension, - "max_tokens": model.MaxTokens, - "model_types": model.ModelTypes, - "thinking": model.Thinking, + "name": model.Name, + "max_dimension": model.MaxDimension, + "dimensions": model.Dimensions, + "max_tokens": model.MaxTokens, + "model_types": model.ModelTypes, + "thinking": model.Thinking, }) } return result, nil @@ -1108,6 +1109,36 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc return common.CodeServerError, errors.New("model is disabled") } +func validateEmbeddingDimension(model *modelModule.Model, requested int) error { + if requested <= 0 || model == nil { + return nil + } + + if len(model.Dimensions) > 0 { + for _, dim := range model.Dimensions { + if dim == requested { + return nil + } + } + return fmt.Errorf( + "dimension %d is not supported by model %s, supported dimensions: %v", + requested, + model.Name, + model.Dimensions, + ) + } + if model.MaxDimension != nil && requested > *model.MaxDimension { + return fmt.Errorf( + "dimension %d is not 
supported by model %s, max dimension: %d", + requested, + model.Name, + *model.MaxDimension, + ) + } + + return nil +} + // EmbedText sends texts to the embedding model func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, userID string, texts []string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, common.ErrorCode, error) { if apiConfig == nil { @@ -1167,6 +1198,10 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, apiConfig.Region = &region apiConfig.ApiKey = &instance.APIKey + if err := validateEmbeddingDimension(model, modelConfig.Dimension); err != nil { + return nil, common.CodeBadRequest, err + } + var response []modelModule.EmbeddingData response, err = providerInfo.ModelDriver.Embed(&modelName, texts, apiConfig, modelConfig) if err != nil { @@ -1204,6 +1239,11 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, return nil, common.CodeServerError, err } + modelSchema, _ := dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err := validateEmbeddingDimension(modelSchema, modelConfig.Dimension); err != nil { + return nil, common.CodeBadRequest, err + } + var response []modelModule.EmbeddingData response, err = newProviderInfo.Embed(&modelName, texts, apiConfig, modelConfig) if err != nil { diff --git a/internal/service/model_service_test.go b/internal/service/model_service_test.go index 6d43c3366c..a0be4082da 100644 --- a/internal/service/model_service_test.go +++ b/internal/service/model_service_test.go @@ -1 +1,70 @@ package service + +import ( + "strings" + "testing" + + modelModule "ragflow/internal/entity/models" +) + +func TestValidateEmbeddingDimension(t *testing.T) { + maxDimension := 2048 + + tests := []struct { + name string + model *modelModule.Model + requested int + wantErr string + }{ + { + name: "allows unset requested dimension", + model: &modelModule.Model{MaxDimension: 
&maxDimension, Dimensions: []int{256, 512}}, + requested: 0, + }, + { + name: "allows missing model schema", + model: nil, + requested: 256, + }, + { + name: "allows dimension listed in explicit options", + model: &modelModule.Model{Name: "embedding-3", MaxDimension: &maxDimension, Dimensions: []int{256, 512, 1024, 2048}}, + requested: 1024, + }, + { + name: "rejects dimension not listed in explicit options", + model: &modelModule.Model{Name: "embedding-3", MaxDimension: &maxDimension, Dimensions: []int{256, 512, 1024, 2048}}, + requested: 1536, + wantErr: "supported dimensions", + }, + { + name: "allows custom dimension within max dimension", + model: &modelModule.Model{Name: "flex-embedding", MaxDimension: &maxDimension}, + requested: 1536, + }, + { + name: "rejects custom dimension above max dimension", + model: &modelModule.Model{Name: "flex-embedding", MaxDimension: &maxDimension}, + requested: 4096, + wantErr: "max dimension", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateEmbeddingDimension(tt.model, tt.requested) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("validateEmbeddingDimension() error = %v", err) + } + return + } + if err == nil { + t.Fatalf("validateEmbeddingDimension() expected error containing %q", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("validateEmbeddingDimension() error = %v, want substring %q", err, tt.wantErr) + } + }) + } +}