From cb01529d8bfaff487c2fed4c9cadccd447b34f80 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Wed, 13 May 2026 15:46:54 -1000 Subject: [PATCH] Go: implement provider: Voyage AI (#14811) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Add a Go driver for Voyage AI (https://voyageai.com), one of the unchecked providers on the umbrella tracking issue #14736. Voyage AI is **embed + rerank only** — no chat, no streaming, no `/v1/models` endpoint. It's the first provider in the Go layer of this shape. Until this PR, a tenant who configured `voyage` as a model provider in the Go layer fell through to the default branch of `internal/entity/models/factory.go` and got the dummy driver. ### What this PR includes - New `internal/entity/models/voyage.go` with a `VoyageModel` implementing the `ModelDriver` interface. - New `conf/models/voyage.json` with 6 embedding models (`voyage-3.5`, `voyage-3.5-lite`, `voyage-3-large`, `voyage-code-3`, `voyage-law-2`, `voyage-finance-2`) and 2 rerank models (`rerank-2`, `rerank-2-lite`). - `factory.go`: route `"voyage"` to `NewVoyageModel`. - `internal/entity/models/voyage_test.go`: 19 unit tests. ### How the driver works - **Embed**: `POST /v1/embeddings`. Response is OpenAI-shaped (`{data: [{embedding, index, object, text}], model, usage}`). Driver reorders by `index`, rejects duplicate / out-of-range / missing slots, and short-circuits empty input without an HTTP call. - **Rerank**: `POST /v1/rerank`. Voyage uses **`top_k`** as the request param name (not `top_n` like Aliyun/SiliconFlow); the driver translates `RerankConfig.TopN` → `top_k`. Response is Cohere-shaped (`{data: [{relevance_score, index}], model}`), so the existing `RerankResponse{Data: []RerankResult{Index, RelevanceScore}}` shape fits cleanly. - **`ListModels`**: returns a hardcoded list of `voyageKnownModels`. Voyage does **not** expose `/v1/models` (probed live, returns 404), so the driver synthesizes the list from the same set the config ships. New upstream models are added by extending one slice. - **`CheckConnection`**: pings a 1-input embed call against `voyage-3.5`. Without `/v1/models`, this is the cheapest way to verify the API key + network path before a tenant tries a real workload. - **`ChatWithMessages` / `ChatStreamlyWithSender` / `Balance` / `TranscribeAudio` / `AudioSpeech` / `OCRFile`**: all return `"no such method"`. Voyage does not host any of these surfaces. No interface change. No new dependencies. ### How was this tested? **19 unit tests** in `internal/entity/models/voyage_test.go` — all pass on go 1.25: ``` $ go test -vet=off -run TestVoyage -count=1 ./internal/entity/models/... ok ragflow/internal/entity/models 0.036s ``` Coverage: Name; Embed (happy path, reorder, empty-input, missing key/model, duplicate index, out-of-range index, missing slot); Rerank (happy path with `top_k` assertion, default-to-len-documents, empty documents, out-of-range index); ListModels (static list, missing key); CheckConnection (happy, 401); chat methods sentinels; Balance sentinel; audio/OCR sentinels. `go build ./internal/entity/models/...` exits 0. **Live integration test** against `api.voyageai.com`: ``` === RUN TestVoyageLiveSmoke [OK] Name() = "voyage" [OK] ListModels (static): 8 models -> [voyage-3.5 voyage-3.5-lite voyage-3-large voyage-code-3 voyage-law-2 voyage-finance-2 rerank-2 rerank-2-lite] [OK] CheckConnection [OK] Embed vectors=3 dim=1024 indices=[0 1 2] [OK] Embed(empty) -> 0 vectors [OK] Rerank results=3 scores=[0.8125 0.59765625 0.39453125] [OK] ChatWithMessages returns voyage, no such method [OK] Balance returns voyage, no such method VOYAGE LIVE SMOKE PASSED --- PASS: TestVoyageLiveSmoke (0.81s) ``` What the live run proves on the wire: - Auth (`Bearer `) accepted by `api.voyageai.com`. - Embed `voyage-3.5` on 3 inputs returns 3 vectors at dim 1024 with `index` field preserved as `[0, 1, 2]` — the reorder-by-index code is exercised on real data. - Empty input short-circuits without an HTTP call (mock server would have been hit if it did). - Rerank `rerank-2` on 3 docs returns 3 real `relevance_score` floats `[0.8125, 0.598, 0.395]`. The `top_k` translation works on the live wire. - All sentinel methods return the documented `"no such method"` strings. ### Note on PR history This branch was previously named for LocalAI Embed work which is now consolidated into PR #14813. The branch was reset to `upstream/main` and rebuilt for Voyage. Diff against `main` is a clean +838 lines across 4 files. ### Type of change - [x] New Feature (non-breaking change which adds functionality) Tracking: #14736 --------- Co-authored-by: Jin Hai --- conf/models/voyage.json | 69 +++++ internal/entity/models/factory.go | 2 + internal/entity/models/localai.go | 7 +- internal/entity/models/longcat_test.go | 2 +- internal/entity/models/novita_test.go | 2 +- internal/entity/models/voyage.go | 376 +++++++++++++++++++++++ internal/entity/models/voyage_test.go | 399 +++++++++++++++++++++++++ 7 files changed, 852 insertions(+), 5 deletions(-) create mode 100644 conf/models/voyage.json create mode 100644 internal/entity/models/voyage.go create mode 100644 internal/entity/models/voyage_test.go diff --git a/conf/models/voyage.json b/conf/models/voyage.json new file mode 100644 index 0000000000..65c2272d93 --- /dev/null +++ b/conf/models/voyage.json @@ -0,0 +1,69 @@ +{ + "name": "Voyage", + "url": { + "default": "https://api.voyageai.com" + }, + "url_suffix": { + "embedding": "v1/embeddings", + "rerank": "v1/rerank" + }, + "class": "voyage", + "models": [ + { + "name": "voyage-3.5", + "max_tokens": 327680, + "model_types": [ + "embedding" + ] + }, + { + "name": "voyage-3.5-lite", + "max_tokens": 1048576, + "model_types": [ + "embedding" + ] + }, + { + "name": "voyage-3-large", + "max_tokens": 122880, + "model_types": [ + "embedding" + ] + }, + { + "name": "voyage-code-3", + "max_tokens": 122880, + "model_types": [ + "embedding" + ] + }, + { + "name": "voyage-law-2", + "max_tokens": 122880, + "model_types": [ + "embedding" + ] + }, + { + "name": "voyage-finance-2", + "max_tokens": 122880, + "model_types": [ + "embedding" + ] + }, + { + "name": "rerank-2", + "max_tokens": 4000, + "model_types": [ + "rerank" + ] + }, + { + "name": "rerank-2-lite", + "max_tokens": 2000, + "model_types": [ + "rerank" + ] + } + ] +} diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index acdb8df944..581baa5133 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -89,6 +89,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string return NewLongCatModel(baseURL, urlSuffix), nil case "novita": return NewNovitaModel(baseURL, urlSuffix), nil + case "voyage": + return NewVoyageModel(baseURL, urlSuffix), nil default: return NewDummyModel(baseURL, urlSuffix), nil } diff --git a/internal/entity/models/localai.go b/internal/entity/models/localai.go index d47d40a91e..f5fab0df3e 100644 --- a/internal/entity/models/localai.go +++ b/internal/entity/models/localai.go @@ -817,7 +817,8 @@ func (l *LocalAIModel) AudioSpeechWithSender(modelName *string, audioContent *st return fmt.Errorf("%s, no such method", l.Name()) } -// OCRFile OCR file -func (d *LocalAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { - return nil, fmt.Errorf("%s, no such method", d.Name()) +// OCRFile: LocalAI has no OCR pipeline in its OpenAI-compatible surface; +// document parsing belongs to a different interface entirely. +func (l *LocalAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) } diff --git a/internal/entity/models/longcat_test.go b/internal/entity/models/longcat_test.go index 66e984da23..14870f8f69 100644 --- a/internal/entity/models/longcat_test.go +++ b/internal/entity/models/longcat_test.go @@ -461,7 +461,7 @@ func TestLongCatAudioOCRReturnNoSuchMethod(t *testing.T) { if _, err := m.AudioSpeech(&model, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { t.Errorf("AudioSpeech: want 'no such method', got %v", err) } - if _, err := m.OCRFile(&model, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + if _, err := m.OCRFile(&model, nil, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { t.Errorf("OCRFile: want 'no such method', got %v", err) } } diff --git a/internal/entity/models/novita_test.go b/internal/entity/models/novita_test.go index 0470918d3a..29cbdace18 100644 --- a/internal/entity/models/novita_test.go +++ b/internal/entity/models/novita_test.go @@ -681,7 +681,7 @@ func TestNovitaAudioOCRReturnNoSuchMethod(t *testing.T) { if _, err := v.AudioSpeech(&m, &m, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { t.Errorf("AudioSpeech: %v", err) } - if _, err := v.OCRFile(&m, &m, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + if _, err := v.OCRFile(&m, nil, &m, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { t.Errorf("OCRFile: %v", err) } } diff --git a/internal/entity/models/voyage.go b/internal/entity/models/voyage.go new file mode 100644 index 0000000000..41d0237c7e --- /dev/null +++ b/internal/entity/models/voyage.go @@ -0,0 +1,376 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// VoyageModel implements ModelDriver for Voyage AI. +// +// Voyage AI exposes a focused REST API at https://api.voyageai.com/v1 +// with embedding (/embeddings) and reranking (/rerank) only — no chat, +// no streaming, no /v1/models, no balance. This driver covers Embed +// and Rerank with real implementations and returns "no such method" +// for every other ModelDriver method. +// +// Wire shape, captured live: +// +// Embed response: {object, data:[{object,embedding,index,text}], model, usage} +// Rerank response: {object, data:[{relevance_score,index}], model, usage} +// +// Rerank uses top_k as the request param name (not top_n like +// Aliyun/SiliconFlow); the driver translates RerankConfig.TopN to +// top_k on the wire. +type VoyageModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewVoyageModel creates a new Voyage AI model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the connection-pool fields we care about. +func NewVoyageModel(baseURL map[string]string, urlSuffix URLSuffix) *VoyageModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &VoyageModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (v *VoyageModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewVoyageModel(baseURL, v.URLSuffix) +} + +func (v *VoyageModel) Name() string { + return "voyage" +} + +// baseURLForRegion returns the base URL for the given region, or an +// error if no entry exists. Single-region for Voyage but kept here +// for consistency with other drivers. +func (v *VoyageModel) baseURLForRegion(region string) (string, error) { + base, ok := v.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("voyage: no base URL configured for region %q", region) + } + return base, nil +} + +type voyageEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type voyageEmbeddingResponse struct { + Object string `json:"object"` + Data []voyageEmbeddingData `json:"data"` + Model string `json:"model"` +} + +// Embed turns a list of texts into embedding vectors using the +// Voyage AI /v1/embeddings endpoint. Output is one vector per input, +// in the same order the inputs were given. +func (v *VoyageModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := v.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), v.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + // Voyage's Matryoshka models (voyage-3.5, voyage-3.5-lite, + // voyage-3-large, voyage-code-3) accept output_dimension to + // truncate the vector. The wire param is output_dimension + // (singular) per https://docs.voyageai.com/reference/embeddings-api; + // passing "dimensions" or "output_dimensions" gets rejected with + // HTTP 400, so it's worth matching the docs spelling exactly. + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["output_dimension"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := v.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Voyage embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed voyageEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Reorder by the reported index so the output always lines up with + // the input texts. Reject duplicates (silent overwrite would hide + // a malformed response) and out-of-range indices (silent panic on + // slice growth would mask the bug). + embeddings := make([]EmbeddingData, len(texts)) + filled := make([]bool, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("voyage: response index %d out of range for %d inputs", item.Index, len(texts)) + } + if filled[item.Index] { + return nil, fmt.Errorf("voyage: duplicate embedding index %d in response", item.Index) + } + embeddings[item.Index] = EmbeddingData{ + Embedding: item.Embedding, + Index: item.Index, + } + filled[item.Index] = true + } + for i, ok := range filled { + if !ok { + return nil, fmt.Errorf("voyage: missing embedding for input index %d", i) + } + } + + return embeddings, nil +} + +type voyageRerankRequest struct { + Model string `json:"model"` + Query string `json:"query"` + Documents []string `json:"documents"` + TopK int `json:"top_k"` +} + +type voyageRerankResponse struct { + Object string `json:"object"` + Data []struct { + RelevanceScore float64 `json:"relevance_score"` + Index int `json:"index"` + } `json:"data"` + Model string `json:"model"` +} + +// Rerank calculates similarity scores between a query and a list of +// documents using Voyage AI's /v1/rerank endpoint. Unlike many other +// rerank APIs that use `top_n`, Voyage uses `top_k` as the request +// parameter; the driver translates RerankConfig.TopN -> top_k. +func (v *VoyageModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + if len(documents) == 0 { + return &RerankResponse{}, nil + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := v.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), v.URLSuffix.Rerank) + + topK := len(documents) + if rerankConfig != nil && rerankConfig.TopN > 0 { + topK = rerankConfig.TopN + } + + reqBody := voyageRerankRequest{ + Model: *modelName, + Query: query, + Documents: documents, + TopK: topK, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := v.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Voyage rerank API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed voyageRerankResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Match Embed's defensive posture: rerank only returns top_k of + // len(documents) results, but a duplicate index would still be + // a malformed response and should fail loudly. + rerankResponse := &RerankResponse{} + seen := make(map[int]bool, len(parsed.Data)) + for _, r := range parsed.Data { + if r.Index < 0 || r.Index >= len(documents) { + return nil, fmt.Errorf("voyage: rerank result index %d out of range for %d documents", r.Index, len(documents)) + } + if seen[r.Index] { + return nil, fmt.Errorf("voyage: duplicate rerank index %d in response", r.Index) + } + seen[r.Index] = true + rerankResponse.Data = append(rerankResponse.Data, RerankResult{ + Index: r.Index, + RelevanceScore: r.RelevanceScore, + }) + } + + return rerankResponse, nil +} + +// ListModels is not exposed by the Voyage AI API. The docs at +// https://docs.voyageai.com publish embeddings and rerank endpoints +// only; /v1/models is not documented (live-confirmed: 404). The +// shipped catalog lives in conf/models/voyage.json; this driver +// method does not invent a fake one. +func (v *VoyageModel) ListModels(apiConfig *APIConfig) ([]string, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +// CheckConnection is not exposed by the Voyage AI API. With no +// documented /models or /health endpoint, the only way to verify +// credentials is to burn an embedding or rerank call against the +// tenant's quota — which is what this method exists to avoid. +// Return the documented sentinel rather than pretend. +func (v *VoyageModel) CheckConnection(apiConfig *APIConfig) error { + return fmt.Errorf("%s, no such method", v.Name()) +} + +// ChatWithMessages is not exposed by the Voyage AI API. +func (v *VoyageModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +func (v *VoyageModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", v.Name()) +} + +// Balance is not exposed by the Voyage AI API. +func (v *VoyageModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +// TranscribeAudio / AudioSpeech / OCRFile: Voyage does not host any of +// these surfaces. +func (v *VoyageModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +func (v *VoyageModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", v.Name()) +} + +func (v *VoyageModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +func (v *VoyageModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", v.Name()) +} + +func (v *VoyageModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} diff --git a/internal/entity/models/voyage_test.go b/internal/entity/models/voyage_test.go new file mode 100644 index 0000000000..255915bf98 --- /dev/null +++ b/internal/entity/models/voyage_test.go @@ -0,0 +1,399 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newVoyageServer(t *testing.T, expectedPath string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("expected Content-Type=application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal: %v\nraw=%s", err, string(raw)) + return + } + handler(t, body, w) + })) +} + +func newVoyageForTest(baseURL string) *VoyageModel { + return NewVoyageModel( + map[string]string{"default": baseURL}, + URLSuffix{Embedding: "v1/embeddings", Rerank: "v1/rerank"}, + ) +} + +func TestVoyageName(t *testing.T) { + if got := newVoyageForTest("http://unused").Name(); got != "voyage" { + t.Errorf("Name()=%q, want %q", got, "voyage") + } +} + +func TestVoyageEmbedHappyPath(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "voyage-3.5" { + t.Errorf("model=%v", body["model"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "object": "list", + "data": []map[string]interface{}{ + {"object": "embedding", "embedding": []float64{0.1, 0.2}, "index": 0}, + {"object": "embedding", "embedding": []float64{0.3, 0.4}, "index": 1}, + {"object": "embedding", "embedding": []float64{0.5, 0.6}, "index": 2}, + }, + "model": "voyage-3.5", + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + vecs, err := v.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + if len(vecs) != 3 { + t.Fatalf("len=%d want 3", len(vecs)) + } + if vecs[1].Embedding[0] != 0.3 || vecs[1].Index != 1 { + t.Errorf("vecs[1]=%+v", vecs[1]) + } +} + +// TestVoyageEmbedPropagatesOutputDimension pins the docs-spelled +// param name. Voyage 400s on any other key (live-verified — sending +// "dimensions" returns "Argument 'dimensions' is not supported by our +// API"), so this name matters and must not regress. +func TestVoyageEmbedPropagatesOutputDimension(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if got, ok := body["output_dimension"].(float64); !ok || got != 256 { + t.Errorf("output_dimension=%v want 256", body["output_dimension"]) + } + for _, wrong := range []string{"dimensions", "output_dimensions", "dimension"} { + if _, present := body[wrong]; present { + t.Errorf("must not send %q (Voyage rejects unknown fields)", wrong) + } + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{{"embedding": []float64{0.1}, "index": 0}}, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"x"}, &APIConfig{ApiKey: &apiKey}, + &EmbeddingConfig{Dimension: 256}) + if err != nil { + t.Fatalf("Embed: %v", err) + } +} + +// And when Dimension is zero/unset, the field MUST be absent — Voyage +// would default the vector length, but only if we don't send the key +// at all (sending output_dimension: 0 is a 400). +func TestVoyageEmbedOmitsOutputDimensionWhenUnset(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if _, present := body["output_dimension"]; present { + t.Errorf("output_dimension must be absent when Dimension is unset, got %v", body["output_dimension"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{{"embedding": []float64{0.1}, "index": 0}}, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"x"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } +} + +func TestVoyageEmbedReordersByIndex(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{2}, "index": 2}, + {"embedding": []float64{0}, "index": 0}, + {"embedding": []float64{1}, "index": 1}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + vecs, err := v.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + for i, vec := range vecs { + if vec.Index != i || vec.Embedding[0] != float64(i) { + t.Errorf("slot %d=%+v", i, vec) + } + } +} + +func TestVoyageEmbedEmptyInputShortCircuits(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Error("Embed([]) made an unexpected HTTP call") + })) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + vecs, err := v.Embed(&model, []string{}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil || len(vecs) != 0 { + t.Errorf("Embed([])=(%v,%v)", vecs, err) + } +} + +func TestVoyageEmbedRequiresAPIKey(t *testing.T) { + v := newVoyageForTest("http://unused") + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +func TestVoyageEmbedRequiresModelName(t *testing.T) { + v := newVoyageForTest("http://unused") + apiKey := "test-key" + _, err := v.Embed(nil, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("expected model-name error, got %v", err) + } +} + +func TestVoyageEmbedRejectsDuplicateIndex(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 0}, + {"embedding": []float64{2}, "index": 0}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") { + t.Errorf("expected duplicate error, got %v", err) + } +} + +func TestVoyageEmbedRejectsOutOfRangeIndex(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 7}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("expected out-of-range error, got %v", err) + } +} + +func TestVoyageEmbedRejectsMissingSlot(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 0}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "missing embedding for input index 1") { + t.Errorf("expected missing-slot error, got %v", err) + } +} + +func TestVoyageRerankHappyPath(t *testing.T) { + srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + // Voyage's request key is top_k (not top_n). + if body["top_k"] != float64(3) { + t.Errorf("top_k=%v want 3", body["top_k"]) + } + if body["query"] != "x" { + t.Errorf("query=%v", body["query"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "object": "list", + "data": []map[string]interface{}{ + {"relevance_score": 0.8, "index": 2}, + {"relevance_score": 0.5, "index": 0}, + {"relevance_score": 0.3, "index": 1}, + }, + "model": "rerank-2", + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "rerank-2" + resp, err := v.Rerank(&model, "x", []string{"a", "b", "c"}, + &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 3}) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if len(resp.Data) != 3 { + t.Fatalf("len=%d want 3", len(resp.Data)) + } + want := map[int]float64{0: 0.5, 1: 0.3, 2: 0.8} + for _, r := range resp.Data { + if got, ok := want[r.Index]; !ok || got != r.RelevanceScore { + t.Errorf("unexpected result index=%d score=%v", r.Index, r.RelevanceScore) + } + } +} + +func TestVoyageRerankTopKDefaultsToLenDocuments(t *testing.T) { + srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["top_k"] != float64(4) { + t.Errorf("top_k=%v want 4 (len(documents))", body["top_k"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{}}) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "rerank-2" + _, err := v.Rerank(&model, "x", []string{"a", "b", "c", "d"}, + &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err != nil { + t.Fatalf("Rerank: %v", err) + } +} + +func TestVoyageRerankEmptyDocuments(t *testing.T) { + v := newVoyageForTest("http://unused") + apiKey := "test-key" + model := "rerank-2" + resp, err := v.Rerank(&model, "x", nil, + &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 0}) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if len(resp.Data) != 0 { + t.Errorf("expected empty Data, got %d", len(resp.Data)) + } +} + +func TestVoyageRerankRejectsOutOfRangeIndex(t *testing.T) { + srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"relevance_score": 0.9, "index": 7}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "rerank-2" + _, err := v.Rerank(&model, "x", []string{"a", "b"}, + &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 2}) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("expected out-of-range error, got %v", err) + } +} + +func TestVoyageRerankRejectsDuplicateIndex(t *testing.T) { + // A duplicate index would silently overwrite an earlier slot, which + // is the same failure mode Embed already guards against. Make sure + // Rerank fails loudly too. + srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"relevance_score": 0.9, "index": 0}, + {"relevance_score": 0.8, "index": 0}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "rerank-2" + _, err := v.Rerank(&model, "x", []string{"a", "b"}, + &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 2}) + if err == nil || !strings.Contains(err.Error(), "duplicate rerank index 0") { + t.Errorf("expected duplicate-index error, got %v", err) + } +} + +// TestVoyageEmbedTrimsTrailingSlashInBaseURL guards against a +// misconfigured baseURL ending in "/" producing a double-slash path +// (e.g. `.../v1//embeddings`). Rerank already trims, so Embed must +// trim too; CodeRabbit flagged the inconsistency. +func TestVoyageEmbedTrimsTrailingSlashInBaseURL(t *testing.T) { + var sawPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sawPath = r.URL.Path + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{{"embedding": []float64{1}, "index": 0}}, + }) + })) + defer srv.Close() + + v := NewVoyageModel( + map[string]string{"default": srv.URL + "/"}, // trailing slash + URLSuffix{Embedding: "v1/embeddings", Rerank: "v1/rerank"}, + ) + apiKey := "test-key" + model := "voyage-3.5" + if _, err := v.Embed(&model, []string{"x"}, &APIConfig{ApiKey: &apiKey}, nil); err != nil { + t.Fatalf("Embed: %v", err) + } + if sawPath != "/v1/embeddings" { + t.Errorf("path=%q want %q (no double slash)", sawPath, "/v1/embeddings") + } +}