Fix(Go): prevent global state pollution in local model connection check (#14669)

### What problem does this PR solve?

1. **Fix Global State Pollution in Local Providers (Critical Bug):**
   - Resolved a severe concurrency and architecture issue in
     `model_service.go`. Previously, `ListSupportedModels` permanently
     overwrote the global provider singleton with an instance bound to a
     request-specific URL (`driver.NewInstance`). This caused
     cross-request contamination in multi-tenant environments.
   - Fixed `CheckProviderConnection` for local models (LM Studio, vLLM,
     Ollama). It now creates a localized driver copy and injects the
     `base_url` before testing the connection, eliminating the
     false-positive `missing base URL` error without polluting the global
     state.
2. **Implement `VolcEngine` Embeddings:**
   - Implemented the `Encode` method for the `volcengine` provider,
     enabling text embedding for VolcEngine models.
3. **Enhance Region Validation in `SiliconFlow`:**
   - Added an empty-string check (`*apiConfig.Region != ""`) alongside the
     existing `nil` check when parsing regions. If an empty string is
     passed, the system now falls back to the `"default"` region, which
     prevents malformed request URLs and `unsupported protocol scheme`
     errors.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Haruko386
2026-05-08 15:54:27 +08:00
committed by GitHub
parent ee5ae6f1a4
commit 94f82acd03
8 changed files with 132 additions and 82 deletions

View File

@@ -5,7 +5,8 @@
},
"url_suffix": {
"chat": "chat/completions",
"files": "files"
"files": "files",
"embedding": "embeddings/multimodal"
},
"class": "volcengine",
"models": [
@@ -19,6 +20,13 @@
"default_value": true,
"clear_thinking": true
}
},
{
"name": "doubao-embedding-vision-250615",
"max_tokens": 131072,
"model_types": [
"embedding"
]
}
]
}

View File

@@ -1764,16 +1764,7 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) {
resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
return nil, fmt.Errorf("connection closed (EOF): upstream overloaded or proxy timeout: %w", err)
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return nil, fmt.Errorf("request timeout: model took too long to respond: %w", err)
}
return nil, fmt.Errorf("request failed: %w", err)
return nil, formatRequestError("Chat request", err)
}
if resp.StatusCode != 200 {
@@ -2407,3 +2398,21 @@ func (c *RAGFlowClient) RemoveChunks(cmd *Command) (ResponseIf, error) {
result.Duration = 0
return &result, nil
}
// formatRequestError turns a transport-level failure of an HTTP request into
// an error whose message names the action that failed and the likely cause.
//
// action is a short label for the operation (for example "Chat request") and
// is used as the message prefix. A nil err yields nil. The original error is
// always wrapped with %w, so callers can still match it with errors.Is and
// errors.As.
func formatRequestError(action string, err error) error {
	if err == nil {
		return nil
	}

	// A closed connection surfaces as EOF or unexpected EOF.
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
		return fmt.Errorf("%s failed - connection closed (EOF): upstream overloaded or proxy timeout: %w", action, err)
	}

	// Any error in the chain that reports itself as a network timeout.
	var timeoutErr net.Error
	if errors.As(err, &timeoutErr) && timeoutErr.Timeout() {
		return fmt.Errorf("%s failed - request timeout: server took too long to respond: %w", action, err)
	}

	return fmt.Errorf("%s failed: %w", action, err)
}

View File

@@ -442,27 +442,8 @@ func (l *LmStudioModel) Balance(apiConfig *APIConfig) (map[string]interface{}, e
return nil, fmt.Errorf("no such method")
}
// CheckConnection reports whether the LM Studio endpoint described by
// apiConfig can be used.
//
// It uses apiConfig.Region when that is set and non-empty, otherwise
// "default". It looks that region up in l.BaseURL, falling back to the
// "default" entry, and returns a "missing base URL" error when neither entry
// is set, so the caller gets an actionable message. It then calls ListModels
// and wraps any failure as "connection check failed".
//
// NOTE(review): this listing appears to interleave the removed and added
// sides of a diff. As written, the two statements after the first
// `return nil` are unreachable; they look like the intended replacement body
// (delegate directly to ListModels and return its error unchanged) — confirm
// against the real file before relying on either behaviour.
func (l *LmStudioModel) CheckConnection(apiConfig *APIConfig) error {
var region = "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
// Prefer the region-specific URL, then the "default" entry.
baseURL := l.BaseURL[region]
if baseURL == "" {
baseURL = l.BaseURL["default"]
}
if baseURL == "" {
return fmt.Errorf("missing base URL: please configure the local access address for LM Studio (e.g., http://127.0.0.1:1234/v1)")
}
// The model list itself is discarded; only success or failure matters.
if _, err := l.ListModels(apiConfig); err != nil {
return fmt.Errorf("connection check failed: %w", err)
}
return nil
// Unreachable as written (see NOTE above).
_, err := l.ListModels(apiConfig)
return err
}

View File

@@ -440,27 +440,8 @@ func (o *OllamaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, err
return nil, fmt.Errorf("no such method")
}
// CheckConnection reports whether the Ollama endpoint described by apiConfig
// can be used.
//
// It uses apiConfig.Region when that is set and non-empty, otherwise
// "default". It looks that region up in o.BaseURL, falling back to the
// "default" entry, and returns a "missing base URL" error when neither entry
// is set, so the caller gets an actionable message. It then calls ListModels
// and wraps any failure as "connection check failed".
//
// NOTE(review): this listing appears to interleave the removed and added
// sides of a diff. As written, the two statements after the first
// `return nil` are unreachable; they look like the intended replacement body
// (delegate directly to ListModels and return its error unchanged) — confirm
// against the real file before relying on either behaviour.
func (o *OllamaModel) CheckConnection(apiConfig *APIConfig) error {
var region = "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
// Prefer the region-specific URL, then the "default" entry.
baseURL := o.BaseURL[region]
if baseURL == "" {
baseURL = o.BaseURL["default"]
}
if baseURL == "" {
return fmt.Errorf("missing base URL: please configure the local access address for Ollama (e.g., http://127.0.0.1:11434/v1)")
}
// The model list itself is discarded; only success or failure matters.
if _, err := o.ListModels(apiConfig); err != nil {
return fmt.Errorf("connection check failed: %w", err)
}
return nil
// Unreachable as written (see NOTE above).
_, err := o.ListModels(apiConfig)
return err
}

View File

@@ -474,7 +474,7 @@ func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *
func (z *SiliconflowModel) ListModels(apiConfig *APIConfig) ([]string, error) {
var region = "default"
if apiConfig.Region != nil {
if apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}

View File

@@ -455,29 +455,10 @@ func (z *VllmModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error
return nil, fmt.Errorf("no such method")
}
// CheckConnection reports whether the vLLM endpoint described by apiConfig
// can be used.
//
// It uses apiConfig.Region when that is set and non-empty, otherwise
// "default". It looks that region up in z.BaseURL, falling back to the
// "default" entry, and returns a "missing base URL" error when neither entry
// is set, so the caller gets an actionable message. It then calls ListModels
// and wraps any failure as "connection check failed".
//
// NOTE(review): this listing appears to interleave the removed and added
// sides of a diff. As written, the two statements after the first
// `return nil` are unreachable; they look like the intended replacement body
// (delegate directly to ListModels and return its error unchanged) — confirm
// against the real file before relying on either behaviour.
func (z *VllmModel) CheckConnection(apiConfig *APIConfig) error {
var region = "default"
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
// Prefer the region-specific URL, then the "default" entry.
baseURL := z.BaseURL[region]
if baseURL == "" {
baseURL = z.BaseURL["default"]
}
if baseURL == "" {
return fmt.Errorf("missing base URL: please configure the local access address for vLLM (e.g., http://127.0.0.1:8000/v1)")
}
// The model list itself is discarded; only success or failure matters.
if _, err := z.ListModels(apiConfig); err != nil {
return fmt.Errorf("connection check failed: %w", err)
}
return nil
// Unreachable as written (see NOTE above).
_, err := z.ListModels(apiConfig)
return err
}
// Rerank calculates similarity scores between query and texts

View File

@@ -408,7 +408,86 @@ func (z *VolcEngine) ChatStreamlyWithSender(modelName string, messages []Message
// Encode converts each entry of texts into an embedding vector by calling the
// VolcEngine embedding endpoint (BaseURL for the selected region joined with
// URLSuffix.Embedding).
//
// One POST request is sent per text, in order, and the result slice has the
// same length and order as texts. An empty texts slice returns an empty,
// non-nil result without any network traffic.
//
// Parameters:
//   - modelName: required; the model identifier sent in the request body.
//   - texts: the inputs to embed.
//   - apiConfig: required; ApiKey is sent as a Bearer token, and Region (when
//     set and non-empty) selects the base URL, otherwise "default" is used.
//   - embeddingConfig: currently unused by this provider.
//
// An error is returned when a required argument is missing, when no base URL
// is configured for the region (or for "default"), when a request fails or
// answers with a non-200 status, or when a response cannot be parsed or
// carries an empty embedding. No partial result is returned on error.
func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
	if len(texts) == 0 {
		return [][]float64{}, nil
	}

	// Validate pointers up front so a misconfigured caller gets an error
	// instead of a nil-pointer panic in the middle of the request loop.
	if modelName == nil || *modelName == "" {
		return nil, fmt.Errorf("model name is required")
	}
	if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
		return nil, fmt.Errorf("api key is required")
	}

	region := "default"
	if apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}

	// Fall back to the default region, and refuse to build a scheme-less URL
	// such as "/embeddings/multimodal" when nothing is configured.
	baseURL := z.BaseURL[region]
	if baseURL == "" {
		baseURL = z.BaseURL["default"]
	}
	if baseURL == "" {
		return nil, fmt.Errorf("missing base URL for region %q", region)
	}
	endpoint := fmt.Sprintf("%s/%s", baseURL, z.URLSuffix.Embedding)

	// Shape of the multimodal embedding response: "data" is a single object.
	type embeddingResponse struct {
		Data struct {
			Embedding []float64 `json:"embedding"`
			Object    string    `json:"object"`
		} `json:"data"`
	}

	authHeader := fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)
	embeddings := make([][]float64, len(texts))

	for i, text := range texts {
		reqBody := map[string]interface{}{
			"model":           *modelName,
			"encoding_format": "float",
			"input": []map[string]interface{}{
				{
					"type": "text",
					"text": text,
				},
			},
		}

		jsonData, err := json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal request: %w", err)
		}

		req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
		if err != nil {
			return nil, fmt.Errorf("failed to create request: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Authorization", authHeader)

		resp, err := z.httpClient.Do(req)
		if err != nil {
			return nil, fmt.Errorf("failed to send request: %w", err)
		}

		// Read and close inside the iteration; a defer would keep every
		// response body open until the whole batch finished.
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("failed to read response: %w", err)
		}

		if resp.StatusCode != http.StatusOK {
			return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
		}

		var result embeddingResponse
		if err = json.Unmarshal(body, &result); err != nil {
			return nil, fmt.Errorf("failed to parse response: %w", err)
		}
		if len(result.Data.Embedding) == 0 {
			return nil, fmt.Errorf("empty embedding in response")
		}

		embeddings[i] = result.Data.Embedding
	}

	return embeddings, nil
}
// Rerank calculates similarity scores between query and texts

View File

@@ -199,15 +199,18 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
driver := providerInfo.ModelDriver
// For local deployed models
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
newURL := map[string]string{
region: baseURL,
}
providerInfo.ModelDriver = providerInfo.ModelDriver.NewInstance(newURL)
driver = driver.NewInstance(newURL)
}
return providerInfo.ModelDriver.ListModels(apiConfig)
return driver.ListModels(apiConfig)
}
func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName, apiKey, baseURL, region, userID string) (common.ErrorCode, error) {
@@ -455,7 +458,15 @@ func (m *ModelProviderService) CheckProviderConnection(providerName, instanceNam
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
err = providerInfo.ModelDriver.CheckConnection(apiConfig)
driver := providerInfo.ModelDriver
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
newURL := map[string]string{
region: baseURL,
}
driver = driver.NewInstance(newURL)
}
err = driver.CheckConnection(apiConfig)
if err != nil {
return common.CodeServerError, err
}