mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-04 18:45:38 +08:00
feat(go): Add embedding dimension metadata and validation (#15939)
### What problem does this PR solve?

- Replace embedding model `dimension` metadata with `max_dimension`.
- Add optional `dimensions` metadata for models with fixed selectable output dimensions.
- Include `max_dimension` and `dimensions` in model list responses.
- Validate requested embedding dimensions before calling provider embedding APIs.
- Forward SiliconFlow embedding dimensions with the correct `dimensions` request field.
- Add unit coverage for embedding dimension validation rules.
This commit is contained in:
@@ -95,7 +95,8 @@ func ParseListModel(modelList ModelList) []ListModelResponse {
|
||||
}
|
||||
modelResponse.Name = modelName
|
||||
if modelEntity != nil {
|
||||
modelResponse.Dimension = modelEntity.Dimension
|
||||
modelResponse.MaxDimension = modelEntity.MaxDimension
|
||||
modelResponse.Dimensions = modelEntity.Dimensions
|
||||
modelResponse.MaxTokens = modelEntity.MaxTokens
|
||||
modelResponse.ModelTypes = modelEntity.ModelTypes
|
||||
modelResponse.Thinking = modelEntity.Thinking
|
||||
|
||||
@@ -160,7 +160,8 @@ type Model struct {
|
||||
ModelTypes []string `json:"model_types"`
|
||||
Thinking *ModelThinking `json:"thinking"`
|
||||
Class *string `json:"class"`
|
||||
Dimension *int `json:"dimension"` // used by embedding models
|
||||
MaxDimension *int `json:"max_dimension"` // used by embedding models
|
||||
Dimensions []int `json:"dimensions"`
|
||||
Alias []string `json:"alias"`
|
||||
ModelTypeMap map[string]bool
|
||||
}
|
||||
@@ -386,6 +387,12 @@ func (pm *ProviderManager) ListAllModels() ([]map[string]interface{}, error) {
|
||||
if model.MaxTokens != nil {
|
||||
modelData["max_tokens"] = *model.MaxTokens
|
||||
}
|
||||
if model.MaxDimension != nil {
|
||||
modelData["max_dimension"] = *model.MaxDimension
|
||||
}
|
||||
if len(model.Dimensions) > 0 {
|
||||
modelData["dimensions"] = model.Dimensions
|
||||
}
|
||||
modelList = append(modelList, modelData)
|
||||
}
|
||||
|
||||
@@ -437,6 +444,12 @@ func (pm *ProviderManager) ListModels(providerName string) ([]map[string]interfa
|
||||
"max_tokens": model.MaxTokens,
|
||||
"model_types": model.ModelTypes,
|
||||
}
|
||||
if model.MaxDimension != nil {
|
||||
modelData["max_dimension"] = *model.MaxDimension
|
||||
}
|
||||
if len(model.Dimensions) > 0 {
|
||||
modelData["dimensions"] = model.Dimensions
|
||||
}
|
||||
modelList = append(modelList, modelData)
|
||||
}
|
||||
|
||||
|
||||
@@ -557,7 +557,8 @@ func (r *ReplicateModel) ListModels(apiConfig *APIConfig) ([]ListModelResponse,
|
||||
}
|
||||
modelResponse.Name = modelName
|
||||
if modelEntity != nil {
|
||||
modelResponse.Dimension = modelEntity.Dimension
|
||||
modelResponse.MaxDimension = modelEntity.MaxDimension
|
||||
modelResponse.Dimensions = modelEntity.Dimensions
|
||||
modelResponse.MaxTokens = modelEntity.MaxTokens
|
||||
modelResponse.ModelTypes = modelEntity.ModelTypes
|
||||
modelResponse.Thinking = modelEntity.Thinking
|
||||
|
||||
@@ -79,11 +79,12 @@ type OCRFileResponse struct {
|
||||
}
|
||||
|
||||
type ListModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
ModelTypes []string `json:"model_types"`
|
||||
Thinking *ModelThinking `json:"thinking"`
|
||||
Dimension *int `json:"dimension"` // used by embedding models
|
||||
Name string `json:"name"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
ModelTypes []string `json:"model_types"`
|
||||
Thinking *ModelThinking `json:"thinking"`
|
||||
MaxDimension *int `json:"max_dimension"` // used by embedding models
|
||||
Dimensions []int `json:"dimensions"`
|
||||
}
|
||||
|
||||
type ParseFileResponse struct {
|
||||
|
||||
@@ -247,11 +247,12 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u
|
||||
var result []map[string]interface{}
|
||||
for _, model := range modelList {
|
||||
result = append(result, map[string]interface{}{
|
||||
"name": model.Name,
|
||||
"dimension": model.Dimension,
|
||||
"max_tokens": model.MaxTokens,
|
||||
"model_types": model.ModelTypes,
|
||||
"thinking": model.Thinking,
|
||||
"name": model.Name,
|
||||
"max_dimension": model.MaxDimension,
|
||||
"dimensions": model.Dimensions,
|
||||
"max_tokens": model.MaxTokens,
|
||||
"model_types": model.ModelTypes,
|
||||
"thinking": model.Thinking,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
@@ -1108,6 +1109,36 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
|
||||
return common.CodeServerError, errors.New("model is disabled")
|
||||
}
|
||||
|
||||
func validateEmbeddingDimension(model *modelModule.Model, requested int) error {
|
||||
if requested <= 0 || model == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(model.Dimensions) > 0 {
|
||||
for _, dim := range model.Dimensions {
|
||||
if dim == requested {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf(
|
||||
"dimension %d is not supported by model %s, supported dimensions: %v",
|
||||
requested,
|
||||
model.Name,
|
||||
model.Dimensions,
|
||||
)
|
||||
}
|
||||
if model.MaxDimension != nil && requested > *model.MaxDimension {
|
||||
return fmt.Errorf(
|
||||
"dimension %d is not supported by model %s, max dimension: %d",
|
||||
requested,
|
||||
model.Name,
|
||||
*model.MaxDimension,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EmbedText sends texts to the embedding model
|
||||
func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, userID string, texts []string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, common.ErrorCode, error) {
|
||||
if apiConfig == nil {
|
||||
@@ -1167,6 +1198,10 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName,
|
||||
apiConfig.Region = &region
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
if err := validateEmbeddingDimension(model, modelConfig.Dimension); err != nil {
|
||||
return nil, common.CodeBadRequest, err
|
||||
}
|
||||
|
||||
var response []modelModule.EmbeddingData
|
||||
response, err = providerInfo.ModelDriver.Embed(&modelName, texts, apiConfig, modelConfig)
|
||||
if err != nil {
|
||||
@@ -1204,6 +1239,11 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName,
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
modelSchema, _ := dao.GetModelProviderManager().GetModelByName(providerName, modelName)
|
||||
if err := validateEmbeddingDimension(modelSchema, modelConfig.Dimension); err != nil {
|
||||
return nil, common.CodeBadRequest, err
|
||||
}
|
||||
|
||||
var response []modelModule.EmbeddingData
|
||||
response, err = newProviderInfo.Embed(&modelName, texts, apiConfig, modelConfig)
|
||||
if err != nil {
|
||||
|
||||
@@ -1 +1,70 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
func TestValidateEmbeddingDimension(t *testing.T) {
|
||||
maxDimension := 2048
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model *modelModule.Model
|
||||
requested int
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "allows unset requested dimension",
|
||||
model: &modelModule.Model{MaxDimension: &maxDimension, Dimensions: []int{256, 512}},
|
||||
requested: 0,
|
||||
},
|
||||
{
|
||||
name: "allows missing model schema",
|
||||
model: nil,
|
||||
requested: 256,
|
||||
},
|
||||
{
|
||||
name: "allows dimension listed in explicit options",
|
||||
model: &modelModule.Model{Name: "embedding-3", MaxDimension: &maxDimension, Dimensions: []int{256, 512, 1024, 2048}},
|
||||
requested: 1024,
|
||||
},
|
||||
{
|
||||
name: "rejects dimension not listed in explicit options",
|
||||
model: &modelModule.Model{Name: "embedding-3", MaxDimension: &maxDimension, Dimensions: []int{256, 512, 1024, 2048}},
|
||||
requested: 1536,
|
||||
wantErr: "supported dimensions",
|
||||
},
|
||||
{
|
||||
name: "allows custom dimension within max dimension",
|
||||
model: &modelModule.Model{Name: "flex-embedding", MaxDimension: &maxDimension},
|
||||
requested: 1536,
|
||||
},
|
||||
{
|
||||
name: "rejects custom dimension above max dimension",
|
||||
model: &modelModule.Model{Name: "flex-embedding", MaxDimension: &maxDimension},
|
||||
requested: 4096,
|
||||
wantErr: "max dimension",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateEmbeddingDimension(tt.model, tt.requested)
|
||||
if tt.wantErr == "" {
|
||||
if err != nil {
|
||||
t.Fatalf("validateEmbeddingDimension() error = %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("validateEmbeddingDimension() expected error containing %q", tt.wantErr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("validateEmbeddingDimension() error = %v, want substring %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user