Go: add dimensions for list models and fix some embed-bug in providers (#15940)

### What problem does this PR solve?

As title

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
This commit is contained in:
Haruko386
2026-06-11 19:18:49 +08:00
committed by GitHub
parent 92c4b7688b
commit 9c30557ef7
8 changed files with 67 additions and 33 deletions

View File

@@ -399,6 +399,9 @@ func (a *AstraflowModel) Embed(modelName *string, texts []string, apiConfig *API
"model": *modelName,
"input": texts,
}
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
reqBody["dimensions"] = embeddingConfig.Dimension
}
jsonData, err := json.Marshal(reqBody)
if err != nil {

View File

@@ -100,6 +100,7 @@ func ParseListModel(modelList ModelList) []ListModelResponse {
modelResponse.MaxTokens = modelEntity.MaxTokens
modelResponse.ModelTypes = modelEntity.ModelTypes
modelResponse.Thinking = modelEntity.Thinking
modelResponse.Dimensions = modelEntity.Dimensions
}
models = append(models, modelResponse)

View File

@@ -372,6 +372,10 @@ func (c *CoHereModel) Embed(modelName *string, texts []string, apiConfig *APICon
"input_type": "search_document",
"embedding_types": []string{"float"},
}
// This is only available for embed-v4 and newer models. Possible values are 256, 512, 1024, and 1536. The default is 1536.
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
reqBody["output_dimension"] = embeddingConfig.Dimension
}
jsonData, err := json.Marshal(reqBody)
if err != nil {

View File

@@ -399,6 +399,9 @@ func (m *MistralModel) Embed(modelName *string, texts []string, apiConfig *APICo
"model": *modelName,
"input": texts,
}
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
reqBody["output_dimension"] = embeddingConfig.Dimension
}
jsonData, err := json.Marshal(reqBody)
if err != nil {

View File

@@ -76,18 +76,21 @@ type replicatePrediction struct {
URLs replicatePredictionURLs `json:"urls"`
}
type replicateModelsResponse struct {
Results []struct {
Owner string `json:"owner"`
Name string `json:"name"`
} `json:"results"`
}
type replicateSSEEvent struct {
event string
data string
}
type replicateModelList struct {
Results []replicateModelSummary `json:"results"`
}
type replicateModelSummary struct {
ID string `json:"id"`
Owner string `json:"owner"`
Name string `json:"name"`
}
func (r *ReplicateModel) endpoint(apiConfig *APIConfig, suffix string) (string, error) {
baseURL, err := r.baseModel.GetBaseURL(apiConfig)
@@ -538,35 +541,38 @@ func (r *ReplicateModel) ListModels(apiConfig *APIConfig) ([]ListModelResponse,
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
}
var result replicateModelsResponse
if err = json.Unmarshal(body, &result); err != nil {
var modelList ModelList
if err = json.Unmarshal(body, &modelList); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
models := make([]ListModelResponse, 0, len(result.Results))
pm := GetProviderManager()
for _, model := range result.Results {
modelName := model.Name
var modelResponse ListModelResponse
var modelEntity *Model
if pm != nil {
modelEntity = pm.GetModelByNameOrAlias(modelName)
}
if model.Owner != "" {
modelName = model.Name + "@" + model.Owner
}
modelResponse.Name = modelName
if modelEntity != nil {
modelResponse.MaxDimension = modelEntity.MaxDimension
modelResponse.Dimensions = modelEntity.Dimensions
modelResponse.MaxTokens = modelEntity.MaxTokens
modelResponse.ModelTypes = modelEntity.ModelTypes
modelResponse.Thinking = modelEntity.Thinking
}
models = append(models, modelResponse)
if modelList.Models != nil {
return ParseListModel(modelList), nil
}
return models, nil
var replicateList replicateModelList
if err = json.Unmarshal(body, &replicateList); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
if replicateList.Results == nil {
return nil, fmt.Errorf("invalid models list format")
}
for _, model := range replicateList.Results {
modelName := strings.TrimSpace(model.ID)
if modelName == "" && model.Owner != "" && model.Name != "" {
modelName = fmt.Sprintf("%s/%s", model.Owner, model.Name)
}
if modelName == "" {
modelName = strings.TrimSpace(model.Name)
}
if modelName == "" {
continue
}
modelList.Models = append(modelList.Models, DSModel{
ID: modelName,
})
}
return ParseListModel(modelList), nil
}
func (r *ReplicateModel) CheckConnection(apiConfig *APIConfig) error {

View File

@@ -444,6 +444,9 @@ func (s *SiliconflowModel) Embed(modelName *string, texts []string, apiConfig *A
"model": modelName,
"input": texts,
}
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
reqBody["dimensions"] = embeddingConfig.Dimension
}
jsonData, err := json.Marshal(reqBody)
if err != nil {

View File

@@ -479,6 +479,9 @@ func (v *VolcEngine) Embed(modelName *string, texts []string, apiConfig *APIConf
},
},
}
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
reqBody["dimensions"] = embeddingConfig.Dimension
}
jsonData, err := json.Marshal(reqBody)
if err != nil {

View File

@@ -254,6 +254,17 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u
"model_types": model.ModelTypes,
"thinking": model.Thinking,
})
modelData := map[string]interface{}{
"name": model.Name,
"dimension": model.MaxDimension,
"max_tokens": model.MaxTokens,
"model_types": model.ModelTypes,
"thinking": model.Thinking,
}
if len(model.Dimensions) > 0 {
modelData["dimensions"] = model.Dimensions
}
result = append(result, modelData)
}
return result, nil
}