From 13e6554901d7ae0c2a987b63a312c003ded1edd7 Mon Sep 17 00:00:00 2001 From: Joseff Date: Mon, 11 May 2026 00:57:11 -0400 Subject: [PATCH] Fix(Go): make OpenRouter Encode fail loudly on malformed responses (#14717) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? The OpenRouter `Encode` method silently swallowed malformed responses. If a `data[]` item from the API was missing a field (`index`, `embedding`, or unexpected shape), the loop did `continue` instead of returning an error — leaving `nil` entries in the result slice. Callers got back partial results with no indication anything went wrong, which then crashes downstream consumers when they try to use a `nil` vector. There were three concrete gaps: - No count-mismatch check between `data` length and input texts (only checked for empty) - No duplicate-index detection (a duplicate would silently overwrite) - Parse failures on individual items returned partial slices instead of erroring This PR replaces `map[string]interface{}` parsing with a typed `openrouterEmbeddingResponse` struct and applies the same 3-layer validation used in the other drivers (count mismatch → out-of-range index → duplicate index), so any malformed response produces a clear error instead of corrupted data. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- internal/entity/models/openrouter.go | 62 +++++++++++----------------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/internal/entity/models/openrouter.go b/internal/entity/models/openrouter.go index a48707e97e..1be3f49e56 100644 --- a/internal/entity/models/openrouter.go +++ b/internal/entity/models/openrouter.go @@ -351,10 +351,20 @@ func (o *OpenRouterModel) ChatStreamlyWithSender(modelName string, messages []Me return scanner.Err() } +type openrouterEmbeddingResponse struct { + Data []struct { + Index int `json:"index"` + Embedding []float64 `json:"embedding"` + } `json:"data"` +} + func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { if len(texts) == 0 { return [][]float64{}, nil } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } var region = "default" if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { @@ -368,6 +378,10 @@ func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *A "input": texts, } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + jsonData, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) @@ -398,52 +412,26 @@ func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *A return nil, fmt.Errorf("OpenRouter embedding API error: status %d, body: %s", resp.StatusCode, string(body)) } - var result map[string]interface{} + var result openrouterEmbeddingResponse if err = json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("failed to decode response: %w", err) } - dataObj, ok := result["data"].([]interface{}) - if !ok || len(dataObj) == 0 { - return nil, fmt.Errorf("OpenRouter embedding response contains no data: %s", string(body)) + if len(result.Data) != len(texts) { + return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(result.Data)) } embeddings := make([][]float64, len(texts)) - - for _, item := range dataObj { - dataMap, ok := item.(map[string]interface{}) - if !ok { - continue + seen := make([]bool, len(texts)) + for _, item := range result.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("embedding index %d out of range", item.Index) } - - indexFloat, ok := dataMap["index"].(float64) - if !ok { - continue + if seen[item.Index] { + return nil, fmt.Errorf("duplicate embedding index %d", item.Index) } - index := int(indexFloat) - - if index < 0 || index >= len(texts) { - continue - } - - embeddingSlice, ok := dataMap["embedding"].([]interface{}) - if !ok { - continue - } - - embedding := make([]float64, len(embeddingSlice)) - for j, v := range embeddingSlice { - switch val := v.(type) { - case float64: - embedding[j] = val - case float32: - embedding[j] = float64(val) - default: - return nil, fmt.Errorf("unexpected embedding value type") - } - } - - embeddings[index] = embedding + seen[item.Index] = true + embeddings[item.Index] = item.Embedding } return embeddings, nil