mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 08:15:44 +08:00
Fix(Go): prevent global state pollution in local model connection check (#14669)
### What problem does this PR solve? 1. **Fix Global State Pollution in Local Providers (Critical Bug):** - Resolved a severe concurrency and architecture issue in `model_service.go`. Previously, `ListSupportedModels` would permanently overwrite the global provider singleton with a localized URL instance (`driver.NewInstance`). This caused cross-request contamination in multi-tenant environments. - Fixed `CheckProviderConnection` for local models (LM Studio, vLLM, Ollama). It now properly creates a localized driver copy and injects the `base_url` before testing the connection, entirely eliminating the false-positive `missing base URL` error without polluting the global state. 2. **Implement `VolcEngine` Embeddings:** - Fully implemented the `Encode` method for the `volcengine` provider, enabling text embedding capabilities for VolcEngine models. 3. **Enhance Region Validation in `SiliconFlow`:** - Added a strict empty string check (`*apiConfig.Region != ""`) alongside the existing `nil` check when parsing regions. This ensures that if an empty string is passed, the system safely falls back to the `"default"` region, preventing malformed URL requests and `unsupported protocol scheme` errors. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -5,7 +5,8 @@
|
||||
},
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"files": "files"
|
||||
"files": "files",
|
||||
"embedding": "embeddings/multimodal"
|
||||
},
|
||||
"class": "volcengine",
|
||||
"models": [
|
||||
@@ -19,6 +20,13 @@
|
||||
"default_value": true,
|
||||
"clear_thinking": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "doubao-embedding-vision-250615",
|
||||
"max_tokens": 131072,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1764,16 +1764,7 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) {
|
||||
|
||||
resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, fmt.Errorf("connection closed (EOF): upstream overloaded or proxy timeout: %w", err)
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return nil, fmt.Errorf("request timeout: model took too long to respond: %w", err)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
return nil, formatRequestError("Chat request", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
@@ -2407,3 +2398,21 @@ func (c *RAGFlowClient) RemoveChunks(cmd *Command) (ResponseIf, error) {
|
||||
result.Duration = 0
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// formatRequestError Uniformly handle and format network errors in HTTP requests
|
||||
func formatRequestError(action string, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
|
||||
switch {
|
||||
case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF):
|
||||
return fmt.Errorf("%s failed - connection closed (EOF): upstream overloaded or proxy timeout: %w", action, err)
|
||||
case errors.As(err, &netErr) && netErr.Timeout():
|
||||
return fmt.Errorf("%s failed - request timeout: server took too long to respond: %w", action, err)
|
||||
default:
|
||||
return fmt.Errorf("%s failed: %w", action, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -442,27 +442,8 @@ func (l *LmStudioModel) Balance(apiConfig *APIConfig) (map[string]interface{}, e
|
||||
return nil, fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
// CheckConnection verifies that the configured LM Studio base URL
|
||||
// is reachable and that the API key (if any) is accepted, by issuing
|
||||
// a lightweight ListModels call. The empty-URL guard runs first so
|
||||
// a user who has not yet set the local access address gets a clear,
|
||||
// actionable error instead of a low-level transport message.
|
||||
// CheckConnection verifies that the configured LM Studio base URL is reachable
|
||||
func (l *LmStudioModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL := l.BaseURL[region]
|
||||
if baseURL == "" {
|
||||
baseURL = l.BaseURL["default"]
|
||||
}
|
||||
if baseURL == "" {
|
||||
return fmt.Errorf("missing base URL: please configure the local access address for LM Studio (e.g., http://127.0.0.1:1234/v1)")
|
||||
}
|
||||
|
||||
if _, err := l.ListModels(apiConfig); err != nil {
|
||||
return fmt.Errorf("connection check failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
_, err := l.ListModels(apiConfig)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -440,27 +440,8 @@ func (o *OllamaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, err
|
||||
return nil, fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
// CheckConnection verifies that the configured Ollama base URL is
|
||||
// reachable and that the API key (if any) is accepted, by issuing a
|
||||
// lightweight ListModels call. The empty-URL guard runs first so a
|
||||
// user who has not yet set the local access address gets a clear,
|
||||
// actionable error instead of a low-level transport message.
|
||||
// CheckConnection verifies that the configured Ollama base URL is reachable
|
||||
func (o *OllamaModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL := o.BaseURL[region]
|
||||
if baseURL == "" {
|
||||
baseURL = o.BaseURL["default"]
|
||||
}
|
||||
if baseURL == "" {
|
||||
return fmt.Errorf("missing base URL: please configure the local access address for Ollama (e.g., http://127.0.0.1:11434/v1)")
|
||||
}
|
||||
|
||||
if _, err := o.ListModels(apiConfig); err != nil {
|
||||
return fmt.Errorf("connection check failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
_, err := o.ListModels(apiConfig)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -474,7 +474,7 @@ func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *
|
||||
|
||||
func (z *SiliconflowModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
|
||||
@@ -455,29 +455,10 @@ func (z *VllmModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error
|
||||
return nil, fmt.Errorf("no such method")
|
||||
}
|
||||
|
||||
// CheckConnection verifies that the configured vLLM base URL is
|
||||
// reachable and that the API key (if any) is accepted, by issuing a
|
||||
// lightweight ListModels call. The empty-URL guard runs first so a
|
||||
// user who has not yet set the local access address gets a clear,
|
||||
// actionable error instead of a low-level transport message.
|
||||
// CheckConnection verifies that the configured vLLM base URL is reachable
|
||||
func (z *VllmModel) CheckConnection(apiConfig *APIConfig) error {
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL := z.BaseURL[region]
|
||||
if baseURL == "" {
|
||||
baseURL = z.BaseURL["default"]
|
||||
}
|
||||
if baseURL == "" {
|
||||
return fmt.Errorf("missing base URL: please configure the local access address for vLLM (e.g., http://127.0.0.1:8000/v1)")
|
||||
}
|
||||
|
||||
if _, err := z.ListModels(apiConfig); err != nil {
|
||||
return fmt.Errorf("connection check failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
_, err := z.ListModels(apiConfig)
|
||||
return err
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
|
||||
@@ -408,7 +408,86 @@ func (z *VolcEngine) ChatStreamlyWithSender(modelName string, messages []Message
|
||||
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Embedding)
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
|
||||
for i, text := range texts {
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"encoding_format": "float",
|
||||
"input": []map[string]interface{}{
|
||||
{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to marshal request: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Volcengine multimodal embedding response
|
||||
type VolcengineEmbeddingResponse struct {
|
||||
Data struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Object string `json:"object"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
var result VolcengineEmbeddingResponse
|
||||
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Data.Embedding) == 0 {
|
||||
return nil, fmt.Errorf("empty embedding in response")
|
||||
}
|
||||
|
||||
embeddings[i] = result.Data.Embedding
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// Rerank calculates similarity scores between query and texts
|
||||
|
||||
@@ -199,15 +199,18 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u
|
||||
apiConfig.Region = ®ion
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
driver := providerInfo.ModelDriver
|
||||
|
||||
// For local deployed models
|
||||
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
|
||||
newURL := map[string]string{
|
||||
region: baseURL,
|
||||
}
|
||||
providerInfo.ModelDriver = providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
driver = driver.NewInstance(newURL)
|
||||
}
|
||||
|
||||
return providerInfo.ModelDriver.ListModels(apiConfig)
|
||||
return driver.ListModels(apiConfig)
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, baseURL, region, userID string) (common.ErrorCode, error) {
|
||||
@@ -455,7 +458,15 @@ func (m *ModelProviderService) CheckProviderConnection(providerName, instanceNam
|
||||
apiConfig.Region = ®ion
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
err = providerInfo.ModelDriver.CheckConnection(apiConfig)
|
||||
driver := providerInfo.ModelDriver
|
||||
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
|
||||
newURL := map[string]string{
|
||||
region: baseURL,
|
||||
}
|
||||
driver = driver.NewInstance(newURL)
|
||||
}
|
||||
|
||||
err = driver.CheckConnection(apiConfig)
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user