diff --git a/internal/entity/models/astraflow.go b/internal/entity/models/astraflow.go index 46b5bc2e03..ffe2f909ad 100644 --- a/internal/entity/models/astraflow.go +++ b/internal/entity/models/astraflow.go @@ -399,6 +399,9 @@ func (a *AstraflowModel) Embed(modelName *string, texts []string, apiConfig *API "model": *modelName, "input": texts, } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } jsonData, err := json.Marshal(reqBody) if err != nil { diff --git a/internal/entity/models/base_model.go b/internal/entity/models/base_model.go index e76f20f9cb..1e43849f08 100644 --- a/internal/entity/models/base_model.go +++ b/internal/entity/models/base_model.go @@ -100,6 +100,7 @@ func ParseListModel(modelList ModelList) []ListModelResponse { modelResponse.MaxTokens = modelEntity.MaxTokens modelResponse.ModelTypes = modelEntity.ModelTypes modelResponse.Thinking = modelEntity.Thinking + modelResponse.Dimensions = modelEntity.Dimensions } models = append(models, modelResponse) diff --git a/internal/entity/models/cohere.go b/internal/entity/models/cohere.go index 72ed987165..8c27b12571 100644 --- a/internal/entity/models/cohere.go +++ b/internal/entity/models/cohere.go @@ -372,6 +372,10 @@ func (c *CoHereModel) Embed(modelName *string, texts []string, apiConfig *APICon "input_type": "search_document", "embedding_types": []string{"float"}, } + // This is only available for embed-v4 and newer models. Possible values are 256, 512, 1024, and 1536. The default is 1536. + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["output_dimension"] = embeddingConfig.Dimension + } jsonData, err := json.Marshal(reqBody) if err != nil { diff --git a/internal/entity/models/mistral.go b/internal/entity/models/mistral.go index 19cea47625..0daaee0313 100644 --- a/internal/entity/models/mistral.go +++ b/internal/entity/models/mistral.go @@ -399,6 +399,9 @@ func (m *MistralModel) Embed(modelName *string, texts []string, apiConfig *APICo "model": *modelName, "input": texts, } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["output_dimension"] = embeddingConfig.Dimension + } jsonData, err := json.Marshal(reqBody) if err != nil { diff --git a/internal/entity/models/replicate.go b/internal/entity/models/replicate.go index f4a50c27d1..0f3396e287 100644 --- a/internal/entity/models/replicate.go +++ b/internal/entity/models/replicate.go @@ -76,18 +76,21 @@ type replicatePrediction struct { URLs replicatePredictionURLs `json:"urls"` } -type replicateModelsResponse struct { - Results []struct { - Owner string `json:"owner"` - Name string `json:"name"` - } `json:"results"` -} - type replicateSSEEvent struct { event string data string } +type replicateModelList struct { + Results []replicateModelSummary `json:"results"` +} + +type replicateModelSummary struct { + ID string `json:"id"` + Owner string `json:"owner"` + Name string `json:"name"` +} + func (r *ReplicateModel) endpoint(apiConfig *APIConfig, suffix string) (string, error) { baseURL, err := r.baseModel.GetBaseURL(apiConfig) @@ -538,35 +541,38 @@ func (r *ReplicateModel) ListModels(apiConfig *APIConfig) ([]ListModelResponse, return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) } - var result replicateModelsResponse - if err = json.Unmarshal(body, &result); err != nil { + var modelList ModelList + if err = json.Unmarshal(body, &modelList); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } - - models := make([]ListModelResponse, 0, len(result.Results)) - pm := GetProviderManager() - for _, model := range result.Results { - modelName := model.Name - var modelResponse ListModelResponse - var modelEntity *Model - if pm != nil { - modelEntity = pm.GetModelByNameOrAlias(modelName) - } - if model.Owner != "" { - modelName = model.Name + "@" + model.Owner - } - modelResponse.Name = modelName - if modelEntity != nil { - modelResponse.MaxDimension = modelEntity.MaxDimension - modelResponse.Dimensions = modelEntity.Dimensions - modelResponse.MaxTokens = modelEntity.MaxTokens - modelResponse.ModelTypes = modelEntity.ModelTypes - modelResponse.Thinking = modelEntity.Thinking - } - - models = append(models, modelResponse) + if modelList.Models != nil { + return ParseListModel(modelList), nil } - return models, nil + + var replicateList replicateModelList + if err = json.Unmarshal(body, &replicateList); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + if replicateList.Results == nil { + return nil, fmt.Errorf("invalid models list format") + } + for _, model := range replicateList.Results { + modelName := strings.TrimSpace(model.ID) + if modelName == "" && model.Owner != "" && model.Name != "" { + modelName = fmt.Sprintf("%s/%s", model.Owner, model.Name) + } + if modelName == "" { + modelName = strings.TrimSpace(model.Name) + } + if modelName == "" { + continue + } + modelList.Models = append(modelList.Models, DSModel{ + ID: modelName, + }) + } + + return ParseListModel(modelList), nil } func (r *ReplicateModel) CheckConnection(apiConfig *APIConfig) error { diff --git a/internal/entity/models/siliconflow.go b/internal/entity/models/siliconflow.go index 2d3b62a897..95c7922c91 100644 --- a/internal/entity/models/siliconflow.go +++ b/internal/entity/models/siliconflow.go @@ -444,6 +444,9 @@ func (s *SiliconflowModel) Embed(modelName *string, texts []string, apiConfig *A "model": modelName, "input": texts, } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } jsonData, err := json.Marshal(reqBody) if err != nil { diff --git a/internal/entity/models/volcengine.go b/internal/entity/models/volcengine.go index 8031a1a1e4..7abd7e8061 100644 --- a/internal/entity/models/volcengine.go +++ b/internal/entity/models/volcengine.go @@ -479,6 +479,9 @@ func (v *VolcEngine) Embed(modelName *string, texts []string, apiConfig *APIConf }, }, } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } jsonData, err := json.Marshal(reqBody) if err != nil { diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 180075986b..da91475f9b 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -254,6 +254,17 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u "model_types": model.ModelTypes, "thinking": model.Thinking, }) + modelData := map[string]interface{}{ + "name": model.Name, + "dimension": model.MaxDimension, + "max_tokens": model.MaxTokens, + "model_types": model.ModelTypes, + "thinking": model.Thinking, + } + if len(model.Dimensions) > 0 { + modelData["dimensions"] = model.Dimensions + } + result = append(result, modelData) } return result, nil }