mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 17:21:59 +08:00
Fix(Go): make OpenRouter Encode fail loudly on malformed responses (#14717)
### What problem does this PR solve?
The OpenRouter `Encode` method silently swallowed malformed responses.
If a `data[]` item from the API was missing a field (`index`,
`embedding`, or unexpected shape), the loop did `continue` instead of
returning an error — leaving `nil` entries in the result slice. Callers
got back partial results with no indication anything went wrong, which
then crashes downstream consumers when they try to use a `nil` vector.
There were three concrete gaps:
- No count-mismatch check between `data` length and input texts (only
checked for empty)
- No duplicate-index detection (a duplicate would silently overwrite)
- Parse failures on individual items returned partial slices instead of
erroring
This PR replaces `map[string]interface{}` parsing with a typed
`openrouterEmbeddingResponse` struct and applies the same 3-layer
validation used in the other drivers (count mismatch → out-of-range
index → duplicate index), so any malformed response produces a clear
error instead of corrupted data.
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -351,10 +351,20 @@ func (o *OpenRouterModel) ChatStreamlyWithSender(modelName string, messages []Me
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
type openrouterEmbeddingResponse struct {
|
||||
Data []struct {
|
||||
Index int `json:"index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
@@ -368,6 +378,10 @@ func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *A
|
||||
"input": texts,
|
||||
}
|
||||
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["dimensions"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
@@ -398,52 +412,26 @@ func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *A
|
||||
return nil, fmt.Errorf("OpenRouter embedding API error: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
var result openrouterEmbeddingResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
dataObj, ok := result["data"].([]interface{})
|
||||
if !ok || len(dataObj) == 0 {
|
||||
return nil, fmt.Errorf("OpenRouter embedding response contains no data: %s", string(body))
|
||||
if len(result.Data) != len(texts) {
|
||||
return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(result.Data))
|
||||
}
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
|
||||
for _, item := range dataObj {
|
||||
dataMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
seen := make([]bool, len(texts))
|
||||
for _, item := range result.Data {
|
||||
if item.Index < 0 || item.Index >= len(texts) {
|
||||
return nil, fmt.Errorf("embedding index %d out of range", item.Index)
|
||||
}
|
||||
|
||||
indexFloat, ok := dataMap["index"].(float64)
|
||||
if !ok {
|
||||
continue
|
||||
if seen[item.Index] {
|
||||
return nil, fmt.Errorf("duplicate embedding index %d", item.Index)
|
||||
}
|
||||
index := int(indexFloat)
|
||||
|
||||
if index < 0 || index >= len(texts) {
|
||||
continue
|
||||
}
|
||||
|
||||
embeddingSlice, ok := dataMap["embedding"].([]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
embedding := make([]float64, len(embeddingSlice))
|
||||
for j, v := range embeddingSlice {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
embedding[j] = val
|
||||
case float32:
|
||||
embedding[j] = float64(val)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected embedding value type")
|
||||
}
|
||||
}
|
||||
|
||||
embeddings[index] = embedding
|
||||
seen[item.Index] = true
|
||||
embeddings[item.Index] = item.Embedding
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
|
||||
Reference in New Issue
Block a user