diff --git a/internal/entity/models/openrouter.go b/internal/entity/models/openrouter.go index a48707e97e..1be3f49e56 100644 --- a/internal/entity/models/openrouter.go +++ b/internal/entity/models/openrouter.go @@ -351,10 +351,20 @@ func (o *OpenRouterModel) ChatStreamlyWithSender(modelName string, messages []Me return scanner.Err() } +type openrouterEmbeddingResponse struct { + Data []struct { + Index int `json:"index"` + Embedding []float64 `json:"embedding"` + } `json:"data"` +} + func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { if len(texts) == 0 { return [][]float64{}, nil } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } var region = "default" if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { @@ -368,6 +378,10 @@ func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *A "input": texts, } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + jsonData, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) @@ -398,52 +412,26 @@ func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *A return nil, fmt.Errorf("OpenRouter embedding API error: status %d, body: %s", resp.StatusCode, string(body)) } - var result map[string]interface{} + var result openrouterEmbeddingResponse if err = json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("failed to decode response: %w", err) } - dataObj, ok := result["data"].([]interface{}) - if !ok || len(dataObj) == 0 { - return nil, fmt.Errorf("OpenRouter embedding response contains no data: %s", string(body)) + if len(result.Data) != len(texts) { + return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(result.Data)) } embeddings := make([][]float64, len(texts)) - - for _, item := range dataObj { - dataMap, ok := item.(map[string]interface{}) - if !ok { - continue + seen := make([]bool, len(texts)) + for _, item := range result.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("embedding index %d out of range", item.Index) } - - indexFloat, ok := dataMap["index"].(float64) - if !ok { - continue + if seen[item.Index] { + return nil, fmt.Errorf("duplicate embedding index %d", item.Index) } - index := int(indexFloat) - - if index < 0 || index >= len(texts) { - continue - } - - embeddingSlice, ok := dataMap["embedding"].([]interface{}) - if !ok { - continue - } - - embedding := make([]float64, len(embeddingSlice)) - for j, v := range embeddingSlice { - switch val := v.(type) { - case float64: - embedding[j] = val - case float32: - embedding[j] = float64(val) - default: - return nil, fmt.Errorf("unexpected embedding value type") - } - } - - embeddings[index] = embedding + seen[item.Index] = true + embeddings[item.Index] = item.Embedding } return embeddings, nil