mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-05 10:58:34 +08:00
Go: add dimensions for list models and fix some embed-bug in providers (#15940)
### What problem does this PR solve? As title ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring
This commit is contained in:
@@ -399,6 +399,9 @@ func (a *AstraflowModel) Embed(modelName *string, texts []string, apiConfig *API
|
||||
"model": *modelName,
|
||||
"input": texts,
|
||||
}
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["dimensions"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -100,6 +100,7 @@ func ParseListModel(modelList ModelList) []ListModelResponse {
|
||||
modelResponse.MaxTokens = modelEntity.MaxTokens
|
||||
modelResponse.ModelTypes = modelEntity.ModelTypes
|
||||
modelResponse.Thinking = modelEntity.Thinking
|
||||
modelResponse.Dimensions = modelEntity.Dimensions
|
||||
}
|
||||
|
||||
models = append(models, modelResponse)
|
||||
|
||||
@@ -372,6 +372,10 @@ func (c *CoHereModel) Embed(modelName *string, texts []string, apiConfig *APICon
|
||||
"input_type": "search_document",
|
||||
"embedding_types": []string{"float"},
|
||||
}
|
||||
// This is only available for embed-v4 and newer models. Possible values are 256, 512, 1024, and 1536. The default is 1536.
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["output_dimension"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -399,6 +399,9 @@ func (m *MistralModel) Embed(modelName *string, texts []string, apiConfig *APICo
|
||||
"model": *modelName,
|
||||
"input": texts,
|
||||
}
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["output_dimension"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -76,18 +76,21 @@ type replicatePrediction struct {
|
||||
URLs replicatePredictionURLs `json:"urls"`
|
||||
}
|
||||
|
||||
type replicateModelsResponse struct {
|
||||
Results []struct {
|
||||
Owner string `json:"owner"`
|
||||
Name string `json:"name"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
type replicateSSEEvent struct {
|
||||
event string
|
||||
data string
|
||||
}
|
||||
|
||||
type replicateModelList struct {
|
||||
Results []replicateModelSummary `json:"results"`
|
||||
}
|
||||
|
||||
type replicateModelSummary struct {
|
||||
ID string `json:"id"`
|
||||
Owner string `json:"owner"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) endpoint(apiConfig *APIConfig, suffix string) (string, error) {
|
||||
|
||||
baseURL, err := r.baseModel.GetBaseURL(apiConfig)
|
||||
@@ -538,35 +541,38 @@ func (r *ReplicateModel) ListModels(apiConfig *APIConfig) ([]ListModelResponse,
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result replicateModelsResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
var modelList ModelList
|
||||
if err = json.Unmarshal(body, &modelList); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
models := make([]ListModelResponse, 0, len(result.Results))
|
||||
pm := GetProviderManager()
|
||||
for _, model := range result.Results {
|
||||
modelName := model.Name
|
||||
var modelResponse ListModelResponse
|
||||
var modelEntity *Model
|
||||
if pm != nil {
|
||||
modelEntity = pm.GetModelByNameOrAlias(modelName)
|
||||
}
|
||||
if model.Owner != "" {
|
||||
modelName = model.Name + "@" + model.Owner
|
||||
}
|
||||
modelResponse.Name = modelName
|
||||
if modelEntity != nil {
|
||||
modelResponse.MaxDimension = modelEntity.MaxDimension
|
||||
modelResponse.Dimensions = modelEntity.Dimensions
|
||||
modelResponse.MaxTokens = modelEntity.MaxTokens
|
||||
modelResponse.ModelTypes = modelEntity.ModelTypes
|
||||
modelResponse.Thinking = modelEntity.Thinking
|
||||
}
|
||||
|
||||
models = append(models, modelResponse)
|
||||
if modelList.Models != nil {
|
||||
return ParseListModel(modelList), nil
|
||||
}
|
||||
return models, nil
|
||||
|
||||
var replicateList replicateModelList
|
||||
if err = json.Unmarshal(body, &replicateList); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
if replicateList.Results == nil {
|
||||
return nil, fmt.Errorf("invalid models list format")
|
||||
}
|
||||
for _, model := range replicateList.Results {
|
||||
modelName := strings.TrimSpace(model.ID)
|
||||
if modelName == "" && model.Owner != "" && model.Name != "" {
|
||||
modelName = fmt.Sprintf("%s/%s", model.Owner, model.Name)
|
||||
}
|
||||
if modelName == "" {
|
||||
modelName = strings.TrimSpace(model.Name)
|
||||
}
|
||||
if modelName == "" {
|
||||
continue
|
||||
}
|
||||
modelList.Models = append(modelList.Models, DSModel{
|
||||
ID: modelName,
|
||||
})
|
||||
}
|
||||
|
||||
return ParseListModel(modelList), nil
|
||||
}
|
||||
|
||||
func (r *ReplicateModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
|
||||
@@ -444,6 +444,9 @@ func (s *SiliconflowModel) Embed(modelName *string, texts []string, apiConfig *A
|
||||
"model": modelName,
|
||||
"input": texts,
|
||||
}
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["dimensions"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -479,6 +479,9 @@ func (v *VolcEngine) Embed(modelName *string, texts []string, apiConfig *APIConf
|
||||
},
|
||||
},
|
||||
}
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["dimensions"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -254,6 +254,17 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u
|
||||
"model_types": model.ModelTypes,
|
||||
"thinking": model.Thinking,
|
||||
})
|
||||
modelData := map[string]interface{}{
|
||||
"name": model.Name,
|
||||
"dimension": model.MaxDimension,
|
||||
"max_tokens": model.MaxTokens,
|
||||
"model_types": model.ModelTypes,
|
||||
"thinking": model.Thinking,
|
||||
}
|
||||
if len(model.Dimensions) > 0 {
|
||||
modelData["dimensions"] = model.Dimensions
|
||||
}
|
||||
result = append(result, modelData)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user