feat(go): Add embedding dimension metadata and validation (#15939)

### What problem does this PR solve?

- Replace embedding model `dimension` metadata with `max_dimension`.
- Add optional `dimensions` metadata for models that only support a fixed
set of selectable output dimensions.
- Include `max_dimension` and `dimensions` in model list responses.
- Validate requested embedding dimensions before calling provider
embedding APIs.
- Forward SiliconFlow embedding dimensions with the correct `dimensions`
request field.
- Add unit coverage for embedding dimension validation rules.
This commit is contained in:
Hz_
2026-06-11 17:55:13 +08:00
committed by GitHub
parent 9d5950963b
commit 312514c032
6 changed files with 138 additions and 13 deletions

View File

@@ -95,7 +95,8 @@ func ParseListModel(modelList ModelList) []ListModelResponse {
}
modelResponse.Name = modelName
if modelEntity != nil {
modelResponse.Dimension = modelEntity.Dimension
modelResponse.MaxDimension = modelEntity.MaxDimension
modelResponse.Dimensions = modelEntity.Dimensions
modelResponse.MaxTokens = modelEntity.MaxTokens
modelResponse.ModelTypes = modelEntity.ModelTypes
modelResponse.Thinking = modelEntity.Thinking

View File

@@ -160,7 +160,8 @@ type Model struct {
ModelTypes []string `json:"model_types"`
Thinking *ModelThinking `json:"thinking"`
Class *string `json:"class"`
Dimension *int `json:"dimension"` // used by embedding models
MaxDimension *int `json:"max_dimension"` // used by embedding models
Dimensions []int `json:"dimensions"`
Alias []string `json:"alias"`
ModelTypeMap map[string]bool
}
@@ -386,6 +387,12 @@ func (pm *ProviderManager) ListAllModels() ([]map[string]interface{}, error) {
if model.MaxTokens != nil {
modelData["max_tokens"] = *model.MaxTokens
}
if model.MaxDimension != nil {
modelData["max_dimension"] = *model.MaxDimension
}
if len(model.Dimensions) > 0 {
modelData["dimensions"] = model.Dimensions
}
modelList = append(modelList, modelData)
}
@@ -437,6 +444,12 @@ func (pm *ProviderManager) ListModels(providerName string) ([]map[string]interfa
"max_tokens": model.MaxTokens,
"model_types": model.ModelTypes,
}
if model.MaxDimension != nil {
modelData["max_dimension"] = *model.MaxDimension
}
if len(model.Dimensions) > 0 {
modelData["dimensions"] = model.Dimensions
}
modelList = append(modelList, modelData)
}

View File

@@ -557,7 +557,8 @@ func (r *ReplicateModel) ListModels(apiConfig *APIConfig) ([]ListModelResponse,
}
modelResponse.Name = modelName
if modelEntity != nil {
modelResponse.Dimension = modelEntity.Dimension
modelResponse.MaxDimension = modelEntity.MaxDimension
modelResponse.Dimensions = modelEntity.Dimensions
modelResponse.MaxTokens = modelEntity.MaxTokens
modelResponse.ModelTypes = modelEntity.ModelTypes
modelResponse.Thinking = modelEntity.Thinking

View File

@@ -79,11 +79,12 @@ type OCRFileResponse struct {
}
type ListModelResponse struct {
Name string `json:"name"`
MaxTokens *int `json:"max_tokens"`
ModelTypes []string `json:"model_types"`
Thinking *ModelThinking `json:"thinking"`
Dimension *int `json:"dimension"` // used by embedding models
Name string `json:"name"`
MaxTokens *int `json:"max_tokens"`
ModelTypes []string `json:"model_types"`
Thinking *ModelThinking `json:"thinking"`
MaxDimension *int `json:"max_dimension"` // used by embedding models
Dimensions []int `json:"dimensions"`
}
type ParseFileResponse struct {

View File

@@ -247,11 +247,12 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u
var result []map[string]interface{}
for _, model := range modelList {
result = append(result, map[string]interface{}{
"name": model.Name,
"dimension": model.Dimension,
"max_tokens": model.MaxTokens,
"model_types": model.ModelTypes,
"thinking": model.Thinking,
"name": model.Name,
"max_dimension": model.MaxDimension,
"dimensions": model.Dimensions,
"max_tokens": model.MaxTokens,
"model_types": model.ModelTypes,
"thinking": model.Thinking,
})
}
return result, nil
@@ -1108,6 +1109,36 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
return common.CodeServerError, errors.New("model is disabled")
}
// validateEmbeddingDimension checks a requested embedding output size against
// the dimension metadata declared on model.
//
// Validation is skipped (nil is returned) when requested is zero or negative,
// or when model is nil. When model.Dimensions is non-empty, requested must
// equal one of its entries and MaxDimension is not consulted. Otherwise, when
// model.MaxDimension is set, requested must not exceed it. A model with
// neither field set accepts any requested value.
func validateEmbeddingDimension(model *modelModule.Model, requested int) error {
	// Nothing to validate without a schema or without an explicit request.
	if model == nil || requested <= 0 {
		return nil
	}

	// An explicit list of dimensions takes precedence over the upper bound.
	if allowed := model.Dimensions; len(allowed) > 0 {
		for i := range allowed {
			if allowed[i] == requested {
				return nil
			}
		}
		return fmt.Errorf(
			"dimension %d is not supported by model %s, supported dimensions: %v",
			requested, model.Name, allowed,
		)
	}

	// No explicit list: only the optional upper bound applies.
	limit := model.MaxDimension
	if limit == nil || requested <= *limit {
		return nil
	}
	return fmt.Errorf(
		"dimension %d is not supported by model %s, max dimension: %d",
		requested, model.Name, *limit,
	)
}
// EmbedText sends texts to the embedding model
func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, userID string, texts []string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, common.ErrorCode, error) {
if apiConfig == nil {
@@ -1167,6 +1198,10 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName,
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
if err := validateEmbeddingDimension(model, modelConfig.Dimension); err != nil {
return nil, common.CodeBadRequest, err
}
var response []modelModule.EmbeddingData
response, err = providerInfo.ModelDriver.Embed(&modelName, texts, apiConfig, modelConfig)
if err != nil {
@@ -1204,6 +1239,11 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName,
return nil, common.CodeServerError, err
}
modelSchema, _ := dao.GetModelProviderManager().GetModelByName(providerName, modelName)
if err := validateEmbeddingDimension(modelSchema, modelConfig.Dimension); err != nil {
return nil, common.CodeBadRequest, err
}
var response []modelModule.EmbeddingData
response, err = newProviderInfo.Embed(&modelName, texts, apiConfig, modelConfig)
if err != nil {

View File

@@ -1 +1,70 @@
package service
import (
"strings"
"testing"
modelModule "ragflow/internal/entity/models"
)
// TestValidateEmbeddingDimension runs validateEmbeddingDimension over a table
// of cases: an unset request, a missing model schema, a model with an explicit
// dimension list, and a model with only a maximum dimension. A case with an
// empty wantErr must succeed; otherwise the returned error must contain
// wantErr as a substring.
func TestValidateEmbeddingDimension(t *testing.T) {
	limit := 2048

	cases := []struct {
		name      string
		model     *modelModule.Model
		requested int
		wantErr   string
	}{
		{
			name:      "allows unset requested dimension",
			model:     &modelModule.Model{MaxDimension: &limit, Dimensions: []int{256, 512}},
			requested: 0,
		},
		{
			name:      "allows missing model schema",
			model:     nil,
			requested: 256,
		},
		{
			name:      "allows dimension listed in explicit options",
			model:     &modelModule.Model{Name: "embedding-3", MaxDimension: &limit, Dimensions: []int{256, 512, 1024, 2048}},
			requested: 1024,
		},
		{
			name:      "rejects dimension not listed in explicit options",
			model:     &modelModule.Model{Name: "embedding-3", MaxDimension: &limit, Dimensions: []int{256, 512, 1024, 2048}},
			requested: 1536,
			wantErr:   "supported dimensions",
		},
		{
			name:      "allows custom dimension within max dimension",
			model:     &modelModule.Model{Name: "flex-embedding", MaxDimension: &limit},
			requested: 1536,
		},
		{
			name:      "rejects custom dimension above max dimension",
			model:     &modelModule.Model{Name: "flex-embedding", MaxDimension: &limit},
			requested: 4096,
			wantErr:   "max dimension",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := validateEmbeddingDimension(tc.model, tc.requested)

			switch {
			case tc.wantErr == "" && got != nil:
				// Success was expected but an error came back.
				t.Fatalf("validateEmbeddingDimension() error = %v", got)
			case tc.wantErr == "":
				// Success expected and observed.
				return
			case got == nil:
				t.Fatalf("validateEmbeddingDimension() expected error containing %q", tc.wantErr)
			case !strings.Contains(got.Error(), tc.wantErr):
				t.Fatalf("validateEmbeddingDimension() error = %v, want substring %q", got, tc.wantErr)
			}
		})
	}
}