diff --git a/conf/models/volcengine.json b/conf/models/volcengine.json index e4200ce576..96a6004097 100644 --- a/conf/models/volcengine.json +++ b/conf/models/volcengine.json @@ -5,7 +5,8 @@ }, "url_suffix": { "chat": "chat/completions", - "files": "files" + "files": "files", + "embedding": "embeddings/multimodal" }, "class": "volcengine", "models": [ @@ -19,6 +20,13 @@ "default_value": true, "clear_thinking": true } + }, + { + "name": "doubao-embedding-vision-250615", + "max_tokens": 131072, + "model_types": [ + "embedding" + ] } ] } \ No newline at end of file diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index 82c1a72896..4adfaea488 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -1764,16 +1764,7 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) { resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload) if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - return nil, fmt.Errorf("connection closed (EOF): upstream overloaded or proxy timeout: %w", err) - } - - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - return nil, fmt.Errorf("request timeout: model took too long to respond: %w", err) - } - - return nil, fmt.Errorf("request failed: %w", err) + return nil, formatRequestError("Chat request", err) } if resp.StatusCode != 200 { @@ -2407,3 +2398,21 @@ func (c *RAGFlowClient) RemoveChunks(cmd *Command) (ResponseIf, error) { result.Duration = 0 return &result, nil } + +// formatRequestError Uniformly handle and format network errors in HTTP requests +func formatRequestError(action string, err error) error { + if err == nil { + return nil + } + + var netErr net.Error + + switch { + case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF): + return fmt.Errorf("%s failed - connection closed (EOF): upstream overloaded or proxy timeout: %w", action, err) + case errors.As(err, &netErr) && netErr.Timeout(): + return fmt.Errorf("%s failed - request timeout: server took too long to respond: %w", action, err) + default: + return fmt.Errorf("%s failed: %w", action, err) + } +} diff --git a/internal/entity/models/lmstudio.go b/internal/entity/models/lmstudio.go index 061203837e..b9d1fee277 100644 --- a/internal/entity/models/lmstudio.go +++ b/internal/entity/models/lmstudio.go @@ -442,27 +442,8 @@ func (l *LmStudioModel) Balance(apiConfig *APIConfig) (map[string]interface{}, e return nil, fmt.Errorf("no such method") } -// CheckConnection verifies that the configured LM Studio base URL -// is reachable and that the API key (if any) is accepted, by issuing -// a lightweight ListModels call. The empty-URL guard runs first so -// a user who has not yet set the local access address gets a clear, -// actionable error instead of a low-level transport message. +// CheckConnection verifies that the configured LM Studio base URL is reachable func (l *LmStudioModel) CheckConnection(apiConfig *APIConfig) error { - var region = "default" - if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { - region = *apiConfig.Region - } - - baseURL := l.BaseURL[region] - if baseURL == "" { - baseURL = l.BaseURL["default"] - } - if baseURL == "" { - return fmt.Errorf("missing base URL: please configure the local access address for LM Studio (e.g., http://127.0.0.1:1234/v1)") - } - - if _, err := l.ListModels(apiConfig); err != nil { - return fmt.Errorf("connection check failed: %w", err) - } - return nil + _, err := l.ListModels(apiConfig) + return err } diff --git a/internal/entity/models/ollama.go b/internal/entity/models/ollama.go index 4e936fd9d7..f2352bc6a8 100644 --- a/internal/entity/models/ollama.go +++ b/internal/entity/models/ollama.go @@ -440,27 +440,8 @@ func (o *OllamaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, err return nil, fmt.Errorf("no such method") } -// CheckConnection verifies that the configured Ollama base URL is -// reachable and that the API key (if any) is accepted, by issuing a -// lightweight ListModels call. The empty-URL guard runs first so a -// user who has not yet set the local access address gets a clear, -// actionable error instead of a low-level transport message. +// CheckConnection verifies that the configured Ollama base URL is reachable func (o *OllamaModel) CheckConnection(apiConfig *APIConfig) error { - var region = "default" - if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { - region = *apiConfig.Region - } - - baseURL := o.BaseURL[region] - if baseURL == "" { - baseURL = o.BaseURL["default"] - } - if baseURL == "" { - return fmt.Errorf("missing base URL: please configure the local access address for Ollama (e.g., http://127.0.0.1:11434/v1)") - } - - if _, err := o.ListModels(apiConfig); err != nil { - return fmt.Errorf("connection check failed: %w", err) - } - return nil + _, err := o.ListModels(apiConfig) + return err } diff --git a/internal/entity/models/siliconflow.go b/internal/entity/models/siliconflow.go index 6c85d96abf..61a300ce69 100644 --- a/internal/entity/models/siliconflow.go +++ b/internal/entity/models/siliconflow.go @@ -474,7 +474,7 @@ func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig * func (z *SiliconflowModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" - if apiConfig.Region != nil { + if apiConfig.Region != nil && *apiConfig.Region != "" { region = *apiConfig.Region } diff --git a/internal/entity/models/vllm.go b/internal/entity/models/vllm.go index 8f6d1e19be..b1ffe578fe 100644 --- a/internal/entity/models/vllm.go +++ b/internal/entity/models/vllm.go @@ -455,29 +455,10 @@ func (z *VllmModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error return nil, fmt.Errorf("no such method") } -// CheckConnection verifies that the configured vLLM base URL is -// reachable and that the API key (if any) is accepted, by issuing a -// lightweight ListModels call. The empty-URL guard runs first so a -// user who has not yet set the local access address gets a clear, -// actionable error instead of a low-level transport message. +// CheckConnection verifies that the configured vLLM base URL is reachable func (z *VllmModel) CheckConnection(apiConfig *APIConfig) error { - var region = "default" - if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { - region = *apiConfig.Region - } - - baseURL := z.BaseURL[region] - if baseURL == "" { - baseURL = z.BaseURL["default"] - } - if baseURL == "" { - return fmt.Errorf("missing base URL: please configure the local access address for vLLM (e.g., http://127.0.0.1:8000/v1)") - } - - if _, err := z.ListModels(apiConfig); err != nil { - return fmt.Errorf("connection check failed: %w", err) - } - return nil + _, err := z.ListModels(apiConfig) + return err } // Rerank calculates similarity scores between query and texts diff --git a/internal/entity/models/volcengine.go b/internal/entity/models/volcengine.go index 6269ebef5a..8b7ee8dab4 100644 --- a/internal/entity/models/volcengine.go +++ b/internal/entity/models/volcengine.go @@ -408,7 +408,86 @@ func (z *VolcEngine) ChatStreamlyWithSender(modelName string, messages []Message // Encode encodes a list of texts into embeddings func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("not implemented") + if len(texts) == 0 { + return [][]float64{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Embedding) + + embeddings := make([][]float64, len(texts)) + + for i, text := range texts { + + reqBody := map[string]interface{}{ + "model": *modelName, + "encoding_format": "float", + "input": []map[string]interface{}{ + { + "type": "text", + "text": text, + }, + }, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf( + "failed to marshal request: %w", + err, + ) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Volcengine multimodal embedding response + type VolcengineEmbeddingResponse struct { + Data struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + } `json:"data"` + } + + var result VolcengineEmbeddingResponse + + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if len(result.Data.Embedding) == 0 { + return nil, fmt.Errorf("empty embedding in response") + } + + embeddings[i] = result.Data.Embedding + } + + return embeddings, nil } // Rerank calculates similarity scores between query and texts diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 71abefb6fe..953a1b51cf 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -199,15 +199,18 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey + driver := providerInfo.ModelDriver + // For local deployed models if baseURL, ok := extra["base_url"]; ok && baseURL != "" { newURL := map[string]string{ region: baseURL, } - providerInfo.ModelDriver = providerInfo.ModelDriver.NewInstance(newURL) + + driver = driver.NewInstance(newURL) } - return providerInfo.ModelDriver.ListModels(apiConfig) + return driver.ListModels(apiConfig) } func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, baseURL, region, userID string) (common.ErrorCode, error) { @@ -455,7 +458,15 @@ func (m *ModelProviderService) CheckProviderConnection(providerName, instanceNam apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - err = providerInfo.ModelDriver.CheckConnection(apiConfig) + driver := providerInfo.ModelDriver + if baseURL, ok := extra["base_url"]; ok && baseURL != "" { + newURL := map[string]string{ + region: baseURL, + } + driver = driver.NewInstance(newURL) + } + + err = driver.CheckConnection(apiConfig) if err != nil { return common.CodeServerError, err }