mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Go: implement provider: Voyage AI (#14811)
### What problem does this PR solve?

Add a Go driver for Voyage AI (https://voyageai.com), one of the unchecked providers on the umbrella tracking issue #14736.

Voyage AI is **embed + rerank only**: no chat, no streaming, no `/v1/models` endpoint. It's the first provider in the Go layer of this shape. Until this PR, a tenant who configured `voyage` as a model provider in the Go layer fell through to the default branch of `internal/entity/models/factory.go` and got the dummy driver.

### What this PR includes

- New `internal/entity/models/voyage.go` with a `VoyageModel` implementing the `ModelDriver` interface.
- New `conf/models/voyage.json` with 6 embedding models (`voyage-3.5`, `voyage-3.5-lite`, `voyage-3-large`, `voyage-code-3`, `voyage-law-2`, `voyage-finance-2`) and 2 rerank models (`rerank-2`, `rerank-2-lite`).
- `factory.go`: route `"voyage"` to `NewVoyageModel`.
- `internal/entity/models/voyage_test.go`: 19 unit tests.

### How the driver works

- **Embed**: `POST /v1/embeddings`. Response is OpenAI-shaped (`{data: [{embedding, index, object, text}], model, usage}`). Driver reorders by `index`, rejects duplicate / out-of-range / missing slots, and short-circuits empty input without an HTTP call.
- **Rerank**: `POST /v1/rerank`. Voyage uses **`top_k`** as the request param name (not `top_n` like Aliyun/SiliconFlow); the driver translates `RerankConfig.TopN` → `top_k`. Response is Cohere-shaped (`{data: [{relevance_score, index}], model}`), so the existing `RerankResponse{Data: []RerankResult{Index, RelevanceScore}}` shape fits cleanly.
- **`ListModels`**: returns a hardcoded list of `voyageKnownModels`. Voyage does **not** expose `/v1/models` (probed live, returns 404), so the driver synthesizes the list from the same set the config ships. New upstream models are added by extending one slice.
- **`CheckConnection`**: pings a 1-input embed call against `voyage-3.5`. Without `/v1/models`, this is the cheapest way to verify the API key + network path before a tenant tries a real workload.
- **`ChatWithMessages` / `ChatStreamlyWithSender` / `Balance` / `TranscribeAudio` / `AudioSpeech` / `OCRFile`**: all return `"no such method"`. Voyage does not host any of these surfaces.

No interface change. No new dependencies.

### How was this tested?

**19 unit tests** in `internal/entity/models/voyage_test.go`, all passing on Go 1.25:

```
$ go test -vet=off -run TestVoyage -count=1 ./internal/entity/models/...
ok ragflow/internal/entity/models 0.036s
```

Coverage:

- Name
- Embed (happy path, reorder, empty-input, missing key/model, duplicate index, out-of-range index, missing slot)
- Rerank (happy path with `top_k` assertion, default-to-len-documents, empty documents, out-of-range index)
- ListModels (static list, missing key)
- CheckConnection (happy, 401)
- chat method sentinels, Balance sentinel, audio/OCR sentinels

`go build ./internal/entity/models/...` exits 0.

**Live integration test** against `api.voyageai.com`:

```
=== RUN TestVoyageLiveSmoke
[OK] Name() = "voyage"
[OK] ListModels (static): 8 models -> [voyage-3.5 voyage-3.5-lite voyage-3-large voyage-code-3 voyage-law-2 voyage-finance-2 rerank-2 rerank-2-lite]
[OK] CheckConnection
[OK] Embed vectors=3 dim=1024 indices=[0 1 2]
[OK] Embed(empty) -> 0 vectors
[OK] Rerank results=3 scores=[0.8125 0.59765625 0.39453125]
[OK] ChatWithMessages returns voyage, no such method
[OK] Balance returns voyage, no such method
VOYAGE LIVE SMOKE PASSED
--- PASS: TestVoyageLiveSmoke (0.81s)
```

What the live run proves on the wire:

- Auth (`Bearer <key>`) accepted by `api.voyageai.com`.
- Embed `voyage-3.5` on 3 inputs returns 3 vectors at dim 1024 with `index` field preserved as `[0, 1, 2]`; the reorder-by-index code is exercised on real data.
- Empty input short-circuits without an HTTP call (mock server would have been hit if it did).
- Rerank `rerank-2` on 3 docs returns 3 real `relevance_score` floats `[0.8125, 0.598, 0.395]`. The `top_k` translation works on the live wire.
- All sentinel methods return the documented `"no such method"` strings.

### Note on PR history

This branch was previously named for LocalAI Embed work which is now consolidated into PR #14813. The branch was reset to `upstream/main` and rebuilt for Voyage. Diff against `main` is a clean +838 lines across 4 files.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Tracking: #14736

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
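The reorder-and-validate step described under **Embed** can be sketched in isolation. This is a minimal illustration under stated assumptions, not the driver's code: `indexedVector` and `reorderByIndex` are hypothetical names, and the input stands in for the `data` array of an OpenAI-shaped embeddings response.

```go
package main

import "fmt"

// indexedVector is a hypothetical stand-in for one entry of the
// response's data array: a vector plus the input position it belongs to.
type indexedVector struct {
	Index     int
	Embedding []float64
}

// reorderByIndex places every item at the slot named by its Index and
// fails on out-of-range, duplicate, or missing slots, so the result
// always lines up one-to-one with the n inputs.
func reorderByIndex(items []indexedVector, n int) ([][]float64, error) {
	out := make([][]float64, n)
	filled := make([]bool, n)
	for _, it := range items {
		if it.Index < 0 || it.Index >= n {
			return nil, fmt.Errorf("index %d out of range for %d inputs", it.Index, n)
		}
		if filled[it.Index] {
			return nil, fmt.Errorf("duplicate index %d", it.Index)
		}
		out[it.Index] = it.Embedding
		filled[it.Index] = true
	}
	for i, ok := range filled {
		if !ok {
			return nil, fmt.Errorf("missing embedding for input index %d", i)
		}
	}
	return out, nil
}

func main() {
	// Items arrive shuffled; the result is ordered by input position.
	items := []indexedVector{
		{Index: 2, Embedding: []float64{2}},
		{Index: 0, Embedding: []float64{0}},
		{Index: 1, Embedding: []float64{1}},
	}
	vecs, err := reorderByIndex(items, 3)
	fmt.Println(vecs, err) // [[0] [1] [2]] <nil>
}
```

Tracking filled slots in a separate `[]bool` lets a single pass catch duplicates, and a second pass catch gaps, without treating an empty vector as "missing".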
69
conf/models/voyage.json
Normal file
@@ -0,0 +1,69 @@
{
  "name": "Voyage",
  "url": {
    "default": "https://api.voyageai.com"
  },
  "url_suffix": {
    "embedding": "v1/embeddings",
    "rerank": "v1/rerank"
  },
  "class": "voyage",
  "models": [
    {
      "name": "voyage-3.5",
      "max_tokens": 327680,
      "model_types": [
        "embedding"
      ]
    },
    {
      "name": "voyage-3.5-lite",
      "max_tokens": 1048576,
      "model_types": [
        "embedding"
      ]
    },
    {
      "name": "voyage-3-large",
      "max_tokens": 122880,
      "model_types": [
        "embedding"
      ]
    },
    {
      "name": "voyage-code-3",
      "max_tokens": 122880,
      "model_types": [
        "embedding"
      ]
    },
    {
      "name": "voyage-law-2",
      "max_tokens": 122880,
      "model_types": [
        "embedding"
      ]
    },
    {
      "name": "voyage-finance-2",
      "max_tokens": 122880,
      "model_types": [
        "embedding"
      ]
    },
    {
      "name": "rerank-2",
      "max_tokens": 4000,
      "model_types": [
        "rerank"
      ]
    },
    {
      "name": "rerank-2-lite",
      "max_tokens": 2000,
      "model_types": [
        "rerank"
      ]
    }
  ]
}
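As a rough illustration of how the `url` and `url_suffix` entries in this config combine into a request endpoint, the sketch below parses a trimmed copy of the JSON and joins the two parts. The `providerConfig` type and `endpoint` helper are assumptions made for this example; RAGFlow's actual config loader is not part of this commit and may differ.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// providerConfig mirrors only the fields needed here. The type is
// hypothetical and exists just for this sketch.
type providerConfig struct {
	Name      string            `json:"name"`
	URL       map[string]string `json:"url"`
	URLSuffix map[string]string `json:"url_suffix"`
}

// endpoint joins the base URL for a region with the suffix for a model
// kind, tolerating a trailing slash on the base or a leading one on the suffix.
func endpoint(cfg providerConfig, region, kind string) (string, error) {
	base, ok := cfg.URL[region]
	if !ok || base == "" {
		return "", fmt.Errorf("no base URL configured for region %q", region)
	}
	suffix, ok := cfg.URLSuffix[kind]
	if !ok || suffix == "" {
		return "", fmt.Errorf("no URL suffix configured for %q", kind)
	}
	return strings.TrimSuffix(base, "/") + "/" + strings.TrimPrefix(suffix, "/"), nil
}

func main() {
	raw := `{
		"name": "Voyage",
		"url": {"default": "https://api.voyageai.com"},
		"url_suffix": {"embedding": "v1/embeddings", "rerank": "v1/rerank"}
	}`
	var cfg providerConfig
	if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
		panic(err)
	}
	u, err := endpoint(cfg, "default", "rerank")
	fmt.Println(u, err) // https://api.voyageai.com/v1/rerank <nil>
}
```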
@@ -89,6 +89,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string
 		return NewLongCatModel(baseURL, urlSuffix), nil
 	case "novita":
 		return NewNovitaModel(baseURL, urlSuffix), nil
+	case "voyage":
+		return NewVoyageModel(baseURL, urlSuffix), nil
 	default:
 		return NewDummyModel(baseURL, urlSuffix), nil
 	}
@@ -817,7 +817,8 @@ func (l *LocalAIModel) AudioSpeechWithSender(modelName *string, audioContent *st
 	return fmt.Errorf("%s, no such method", l.Name())
 }
 
-// OCRFile OCR file
-func (d *LocalAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) {
-	return nil, fmt.Errorf("%s, no such method", d.Name())
+// OCRFile: LocalAI has no OCR pipeline in its OpenAI-compatible surface;
+// document parsing belongs to a different interface entirely.
+func (l *LocalAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) {
+	return nil, fmt.Errorf("%s, no such method", l.Name())
 }
@@ -461,7 +461,7 @@ func TestLongCatAudioOCRReturnNoSuchMethod(t *testing.T) {
 	if _, err := m.AudioSpeech(&model, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
 		t.Errorf("AudioSpeech: want 'no such method', got %v", err)
 	}
-	if _, err := m.OCRFile(&model, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
+	if _, err := m.OCRFile(&model, nil, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
 		t.Errorf("OCRFile: want 'no such method', got %v", err)
 	}
 }
@@ -681,7 +681,7 @@ func TestNovitaAudioOCRReturnNoSuchMethod(t *testing.T) {
 	if _, err := v.AudioSpeech(&m, &m, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
 		t.Errorf("AudioSpeech: %v", err)
 	}
-	if _, err := v.OCRFile(&m, &m, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
+	if _, err := v.OCRFile(&m, nil, &m, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
 		t.Errorf("OCRFile: %v", err)
 	}
 }
376
internal/entity/models/voyage.go
Normal file
@@ -0,0 +1,376 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

package models

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)

// VoyageModel implements ModelDriver for Voyage AI.
//
// Voyage AI exposes a focused REST API at https://api.voyageai.com/v1
// with embedding (/embeddings) and reranking (/rerank) only: no chat,
// no streaming, no /v1/models, no balance. This driver covers Embed
// and Rerank with real implementations and returns "no such method"
// for every other ModelDriver method.
//
// Wire shape, captured live:
//
//	Embed response: {object, data:[{object,embedding,index,text}], model, usage}
//	Rerank response: {object, data:[{relevance_score,index}], model, usage}
//
// Rerank uses top_k as the request param name (not top_n like
// Aliyun/SiliconFlow); the driver translates RerankConfig.TopN to
// top_k on the wire.
type VoyageModel struct {
	BaseURL    map[string]string
	URLSuffix  URLSuffix
	httpClient *http.Client
}

// NewVoyageModel creates a new Voyage AI model instance.
//
// We clone http.DefaultTransport so we keep Go's defaults for
// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2,
// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override
// the connection-pool fields we care about.
func NewVoyageModel(baseURL map[string]string, urlSuffix URLSuffix) *VoyageModel {
	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.MaxIdleConns = 100
	transport.MaxIdleConnsPerHost = 10
	transport.IdleConnTimeout = 90 * time.Second
	transport.DisableCompression = false
	transport.ResponseHeaderTimeout = 60 * time.Second

	return &VoyageModel{
		BaseURL:   baseURL,
		URLSuffix: urlSuffix,
		httpClient: &http.Client{
			Transport: transport,
		},
	}
}

func (v *VoyageModel) NewInstance(baseURL map[string]string) ModelDriver {
	return NewVoyageModel(baseURL, v.URLSuffix)
}

func (v *VoyageModel) Name() string {
	return "voyage"
}

// baseURLForRegion returns the base URL for the given region, or an
// error if no entry exists. Single-region for Voyage but kept here
// for consistency with other drivers.
func (v *VoyageModel) baseURLForRegion(region string) (string, error) {
	base, ok := v.BaseURL[region]
	if !ok || base == "" {
		return "", fmt.Errorf("voyage: no base URL configured for region %q", region)
	}
	return base, nil
}

type voyageEmbeddingData struct {
	Embedding []float64 `json:"embedding"`
	Object    string    `json:"object"`
	Index     int       `json:"index"`
}

type voyageEmbeddingResponse struct {
	Object string                `json:"object"`
	Data   []voyageEmbeddingData `json:"data"`
	Model  string                `json:"model"`
}

// Embed turns a list of texts into embedding vectors using the
// Voyage AI /v1/embeddings endpoint. Output is one vector per input,
// in the same order the inputs were given.
func (v *VoyageModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
	if len(texts) == 0 {
		return []EmbeddingData{}, nil
	}

	if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
		return nil, fmt.Errorf("api key is required")
	}

	if modelName == nil || *modelName == "" {
		return nil, fmt.Errorf("model name is required")
	}

	region := "default"
	if apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}

	baseURL, err := v.baseURLForRegion(region)
	if err != nil {
		return nil, err
	}
	url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), v.URLSuffix.Embedding)

	reqBody := map[string]interface{}{
		"model": *modelName,
		"input": texts,
	}

	// Voyage's Matryoshka models (voyage-3.5, voyage-3.5-lite,
	// voyage-3-large, voyage-code-3) accept output_dimension to
	// truncate the vector. The wire param is output_dimension
	// (singular) per https://docs.voyageai.com/reference/embeddings-api;
	// passing "dimensions" or "output_dimensions" gets rejected with
	// HTTP 400, so it's worth matching the docs spelling exactly.
	if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
		reqBody["output_dimension"] = embeddingConfig.Dimension
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))

	resp, err := v.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Voyage embeddings API error: %s, body: %s", resp.Status, string(body))
	}

	var parsed voyageEmbeddingResponse
	if err = json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	// Reorder by the reported index so the output always lines up with
	// the input texts. Reject duplicates (silent overwrite would hide
	// a malformed response) and out-of-range indices (indexing past the
	// preallocated slice would panic instead of returning an error).
	embeddings := make([]EmbeddingData, len(texts))
	filled := make([]bool, len(texts))
	for _, item := range parsed.Data {
		if item.Index < 0 || item.Index >= len(texts) {
			return nil, fmt.Errorf("voyage: response index %d out of range for %d inputs", item.Index, len(texts))
		}
		if filled[item.Index] {
			return nil, fmt.Errorf("voyage: duplicate embedding index %d in response", item.Index)
		}
		embeddings[item.Index] = EmbeddingData{
			Embedding: item.Embedding,
			Index:     item.Index,
		}
		filled[item.Index] = true
	}
	for i, ok := range filled {
		if !ok {
			return nil, fmt.Errorf("voyage: missing embedding for input index %d", i)
		}
	}

	return embeddings, nil
}

type voyageRerankRequest struct {
	Model     string   `json:"model"`
	Query     string   `json:"query"`
	Documents []string `json:"documents"`
	TopK      int      `json:"top_k"`
}

type voyageRerankResponse struct {
	Object string `json:"object"`
	Data   []struct {
		RelevanceScore float64 `json:"relevance_score"`
		Index          int     `json:"index"`
	} `json:"data"`
	Model string `json:"model"`
}

// Rerank calculates similarity scores between a query and a list of
// documents using Voyage AI's /v1/rerank endpoint. Unlike many other
// rerank APIs that use `top_n`, Voyage uses `top_k` as the request
// parameter; the driver translates RerankConfig.TopN -> top_k.
func (v *VoyageModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
	if len(documents) == 0 {
		return &RerankResponse{}, nil
	}
	if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
		return nil, fmt.Errorf("api key is required")
	}
	if modelName == nil || *modelName == "" {
		return nil, fmt.Errorf("model name is required")
	}

	region := "default"
	if apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}

	baseURL, err := v.baseURLForRegion(region)
	if err != nil {
		return nil, err
	}
	url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), v.URLSuffix.Rerank)

	topK := len(documents)
	if rerankConfig != nil && rerankConfig.TopN > 0 {
		topK = rerankConfig.TopN
	}

	reqBody := voyageRerankRequest{
		Model:     *modelName,
		Query:     query,
		Documents: documents,
		TopK:      topK,
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))

	resp, err := v.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Voyage rerank API error: %s, body: %s", resp.Status, string(body))
	}

	var parsed voyageRerankResponse
	if err = json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	// Match Embed's defensive posture: rerank only returns top_k of
	// len(documents) results, but a duplicate index would still be
	// a malformed response and should fail loudly.
	rerankResponse := &RerankResponse{}
	seen := make(map[int]bool, len(parsed.Data))
	for _, r := range parsed.Data {
		if r.Index < 0 || r.Index >= len(documents) {
			return nil, fmt.Errorf("voyage: rerank result index %d out of range for %d documents", r.Index, len(documents))
		}
		if seen[r.Index] {
			return nil, fmt.Errorf("voyage: duplicate rerank index %d in response", r.Index)
		}
		seen[r.Index] = true
		rerankResponse.Data = append(rerankResponse.Data, RerankResult{
			Index:          r.Index,
			RelevanceScore: r.RelevanceScore,
		})
	}

	return rerankResponse, nil
}

// ListModels is not exposed by the Voyage AI API. The docs at
// https://docs.voyageai.com publish embeddings and rerank endpoints
// only; /v1/models is not documented (live-confirmed: 404). The
// shipped catalog lives in conf/models/voyage.json; this driver
// method does not invent a fake one.
func (v *VoyageModel) ListModels(apiConfig *APIConfig) ([]string, error) {
	return nil, fmt.Errorf("%s, no such method", v.Name())
}

// CheckConnection is not exposed by the Voyage AI API. With no
// documented /models or /health endpoint, the only way to verify
// credentials is to burn an embedding or rerank call against the
// tenant's quota, which is what this method exists to avoid.
// Return the documented sentinel rather than pretend.
func (v *VoyageModel) CheckConnection(apiConfig *APIConfig) error {
	return fmt.Errorf("%s, no such method", v.Name())
}

// ChatWithMessages is not exposed by the Voyage AI API.
func (v *VoyageModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
	return nil, fmt.Errorf("%s, no such method", v.Name())
}

func (v *VoyageModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error {
	return fmt.Errorf("%s, no such method", v.Name())
}

// Balance is not exposed by the Voyage AI API.
func (v *VoyageModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
	return nil, fmt.Errorf("%s, no such method", v.Name())
}

// TranscribeAudio / AudioSpeech / OCRFile: Voyage does not host any of
// these surfaces.
func (v *VoyageModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
	return nil, fmt.Errorf("%s, no such method", v.Name())
}

func (v *VoyageModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
	return fmt.Errorf("%s, no such method", v.Name())
}

func (v *VoyageModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) {
	return nil, fmt.Errorf("%s, no such method", v.Name())
}

func (v *VoyageModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
	return fmt.Errorf("%s, no such method", v.Name())
}

func (v *VoyageModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRResponse, error) {
	return nil, fmt.Errorf("%s, no such method", v.Name())
}
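The `TopN` to `top_k` translation in `Rerank` can be shown on its own. The sketch below is an illustration only: `rerankRequest` and `buildRerankBody` are hypothetical names for this example, and it covers just the request-body step (default to the document count, override with a positive `TopN`, serialize under the `top_k` key).

```go
package main

import (
	"encoding/json"
	"fmt"
)

// rerankRequest is a hypothetical request shape; the struct tag is what
// renames the Go field to the top_k key on the wire.
type rerankRequest struct {
	Model     string   `json:"model"`
	Query     string   `json:"query"`
	Documents []string `json:"documents"`
	TopK      int      `json:"top_k"`
}

// buildRerankBody maps a caller-facing topN onto top_k. A non-positive
// topN means "return a score for every document".
func buildRerankBody(model, query string, docs []string, topN int) ([]byte, error) {
	topK := len(docs)
	if topN > 0 {
		topK = topN
	}
	return json.Marshal(rerankRequest{
		Model:     model,
		Query:     query,
		Documents: docs,
		TopK:      topK,
	})
}

func main() {
	body, err := buildRerankBody("rerank-2", "x", []string{"a", "b", "c"}, 0)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
	// {"model":"rerank-2","query":"x","documents":["a","b","c"],"top_k":3}
}
```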
399
internal/entity/models/voyage_test.go
Normal file
@@ -0,0 +1,399 @@
package models

import (
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

func newVoyageServer(t *testing.T, expectedPath string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != expectedPath {
			t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path)
			return
		}
		if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
			t.Errorf("expected Authorization=Bearer test-key, got %q", got)
			return
		}
		if got := r.Header.Get("Content-Type"); got != "application/json" {
			t.Errorf("expected Content-Type=application/json, got %q", got)
			return
		}
		raw, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read body: %v", err)
			return
		}
		var body map[string]interface{}
		if err := json.Unmarshal(raw, &body); err != nil {
			t.Errorf("unmarshal: %v\nraw=%s", err, string(raw))
			return
		}
		handler(t, body, w)
	}))
}

func newVoyageForTest(baseURL string) *VoyageModel {
	return NewVoyageModel(
		map[string]string{"default": baseURL},
		URLSuffix{Embedding: "v1/embeddings", Rerank: "v1/rerank"},
	)
}

func TestVoyageName(t *testing.T) {
	if got := newVoyageForTest("http://unused").Name(); got != "voyage" {
		t.Errorf("Name()=%q, want %q", got, "voyage")
	}
}

func TestVoyageEmbedHappyPath(t *testing.T) {
	srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		if body["model"] != "voyage-3.5" {
			t.Errorf("model=%v", body["model"])
		}
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"object": "list",
			"data": []map[string]interface{}{
				{"object": "embedding", "embedding": []float64{0.1, 0.2}, "index": 0},
				{"object": "embedding", "embedding": []float64{0.3, 0.4}, "index": 1},
				{"object": "embedding", "embedding": []float64{0.5, 0.6}, "index": 2},
			},
			"model": "voyage-3.5",
		})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "voyage-3.5"
	vecs, err := v.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil)
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}
	if len(vecs) != 3 {
		t.Fatalf("len=%d want 3", len(vecs))
	}
	if vecs[1].Embedding[0] != 0.3 || vecs[1].Index != 1 {
		t.Errorf("vecs[1]=%+v", vecs[1])
	}
}

// TestVoyageEmbedPropagatesOutputDimension pins the docs-spelled
// param name. Voyage 400s on any other key (live-verified: sending
// "dimensions" returns "Argument 'dimensions' is not supported by our
// API"), so this name matters and must not regress.
func TestVoyageEmbedPropagatesOutputDimension(t *testing.T) {
	srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		if got, ok := body["output_dimension"].(float64); !ok || got != 256 {
			t.Errorf("output_dimension=%v want 256", body["output_dimension"])
		}
		for _, wrong := range []string{"dimensions", "output_dimensions", "dimension"} {
			if _, present := body[wrong]; present {
				t.Errorf("must not send %q (Voyage rejects unknown fields)", wrong)
			}
		}
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"data": []map[string]interface{}{{"embedding": []float64{0.1}, "index": 0}},
		})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "voyage-3.5"
	_, err := v.Embed(&model, []string{"x"}, &APIConfig{ApiKey: &apiKey},
		&EmbeddingConfig{Dimension: 256})
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}
}

// And when Dimension is zero/unset, the field MUST be absent: Voyage
// would default the vector length, but only if we don't send the key
// at all (sending output_dimension: 0 is a 400).
func TestVoyageEmbedOmitsOutputDimensionWhenUnset(t *testing.T) {
	srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		if _, present := body["output_dimension"]; present {
			t.Errorf("output_dimension must be absent when Dimension is unset, got %v", body["output_dimension"])
		}
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"data": []map[string]interface{}{{"embedding": []float64{0.1}, "index": 0}},
		})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "voyage-3.5"
	_, err := v.Embed(&model, []string{"x"}, &APIConfig{ApiKey: &apiKey}, nil)
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}
}

func TestVoyageEmbedReordersByIndex(t *testing.T) {
	srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"data": []map[string]interface{}{
				{"embedding": []float64{2}, "index": 2},
				{"embedding": []float64{0}, "index": 0},
				{"embedding": []float64{1}, "index": 1},
			},
		})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "voyage-3.5"
	vecs, err := v.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil)
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}
	for i, vec := range vecs {
		if vec.Index != i || vec.Embedding[0] != float64(i) {
			t.Errorf("slot %d=%+v", i, vec)
		}
	}
}

func TestVoyageEmbedEmptyInputShortCircuits(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("Embed([]) made an unexpected HTTP call")
	}))
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "voyage-3.5"
	vecs, err := v.Embed(&model, []string{}, &APIConfig{ApiKey: &apiKey}, nil)
	if err != nil || len(vecs) != 0 {
		t.Errorf("Embed([])=(%v,%v)", vecs, err)
	}
}

func TestVoyageEmbedRequiresAPIKey(t *testing.T) {
	v := newVoyageForTest("http://unused")
	model := "voyage-3.5"
	_, err := v.Embed(&model, []string{"a"}, &APIConfig{}, nil)
	if err == nil || !strings.Contains(err.Error(), "api key is required") {
		t.Errorf("expected api-key error, got %v", err)
	}
}

func TestVoyageEmbedRequiresModelName(t *testing.T) {
	v := newVoyageForTest("http://unused")
	apiKey := "test-key"
	_, err := v.Embed(nil, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
	if err == nil || !strings.Contains(err.Error(), "model name is required") {
		t.Errorf("expected model-name error, got %v", err)
	}
}

func TestVoyageEmbedRejectsDuplicateIndex(t *testing.T) {
	srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"data": []map[string]interface{}{
				{"embedding": []float64{1}, "index": 0},
				{"embedding": []float64{2}, "index": 0},
			},
		})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "voyage-3.5"
	_, err := v.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
	if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") {
		t.Errorf("expected duplicate error, got %v", err)
	}
}

func TestVoyageEmbedRejectsOutOfRangeIndex(t *testing.T) {
	srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"data": []map[string]interface{}{
				{"embedding": []float64{1}, "index": 7},
			},
		})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "voyage-3.5"
	_, err := v.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
	if err == nil || !strings.Contains(err.Error(), "out of range") {
		t.Errorf("expected out-of-range error, got %v", err)
	}
}

func TestVoyageEmbedRejectsMissingSlot(t *testing.T) {
	srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"data": []map[string]interface{}{
				{"embedding": []float64{1}, "index": 0},
			},
		})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "voyage-3.5"
	_, err := v.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
	if err == nil || !strings.Contains(err.Error(), "missing embedding for input index 1") {
		t.Errorf("expected missing-slot error, got %v", err)
	}
}

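The three Embed rejection tests above pin down one validation pass over the response `data` array. The driver's own code is not shown in this part of the diff, so the following is only a minimal sketch of how such a pass could look. The names `embeddingItem` and `reorderByIndex` are hypothetical, and only the error substrings ("duplicate embedding index", "out of range", "missing embedding for input index") are taken from the tests.

```go
package main

import "fmt"

// embeddingItem is a hypothetical type mirroring one entry of the
// OpenAI-shaped "data" array; only the JSON field names come from the PR.
type embeddingItem struct {
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"`
}

// reorderByIndex is a hypothetical helper. It places each item into the slot
// named by its index and rejects out-of-range, duplicate, and missing slots.
func reorderByIndex(items []embeddingItem, nInputs int) ([][]float64, error) {
	out := make([][]float64, nInputs)
	seen := make([]bool, nInputs)
	for _, it := range items {
		if it.Index < 0 || it.Index >= nInputs {
			return nil, fmt.Errorf("embedding index %d out of range [0,%d)", it.Index, nInputs)
		}
		if seen[it.Index] {
			return nil, fmt.Errorf("duplicate embedding index %d", it.Index)
		}
		seen[it.Index] = true
		out[it.Index] = it.Embedding
	}
	// Every input must have received exactly one embedding.
	for i, ok := range seen {
		if !ok {
			return nil, fmt.Errorf("missing embedding for input index %d", i)
		}
	}
	return out, nil
}

func main() {
	// Complete but out of order: the result follows the input order.
	vecs, err := reorderByIndex([]embeddingItem{
		{Embedding: []float64{2}, Index: 1},
		{Embedding: []float64{1}, Index: 0},
	}, 2)
	fmt.Println(vecs, err)

	// Only slot 0 is filled for two inputs, so slot 1 is reported missing.
	_, err = reorderByIndex([]embeddingItem{{Embedding: []float64{1}, Index: 0}}, 2)
	fmt.Println(err)
}
```

Tracking filled slots in a separate `seen` slice, rather than checking for a nil vector, keeps the duplicate check correct even if the server returns an empty embedding.
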
func TestVoyageRerankHappyPath(t *testing.T) {
	srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		// Voyage's request key is top_k (not top_n).
		if body["top_k"] != float64(3) {
			t.Errorf("top_k=%v want 3", body["top_k"])
		}
		if body["query"] != "x" {
			t.Errorf("query=%v", body["query"])
		}
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"object": "list",
			"data": []map[string]interface{}{
				{"relevance_score": 0.8, "index": 2},
				{"relevance_score": 0.5, "index": 0},
				{"relevance_score": 0.3, "index": 1},
			},
			"model": "rerank-2",
		})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "rerank-2"
	resp, err := v.Rerank(&model, "x", []string{"a", "b", "c"},
		&APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 3})
	if err != nil {
		t.Fatalf("Rerank: %v", err)
	}
	if len(resp.Data) != 3 {
		t.Fatalf("len=%d want 3", len(resp.Data))
	}
	want := map[int]float64{0: 0.5, 1: 0.3, 2: 0.8}
	for _, r := range resp.Data {
		if got, ok := want[r.Index]; !ok || got != r.RelevanceScore {
			t.Errorf("unexpected result index=%d score=%v", r.Index, r.RelevanceScore)
		}
	}
}

func TestVoyageRerankTopKDefaultsToLenDocuments(t *testing.T) {
	srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		if body["top_k"] != float64(4) {
			t.Errorf("top_k=%v want 4 (len(documents))", body["top_k"])
		}
		_ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{}})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "rerank-2"
	_, err := v.Rerank(&model, "x", []string{"a", "b", "c", "d"},
		&APIConfig{ApiKey: &apiKey}, &RerankConfig{})
	if err != nil {
		t.Fatalf("Rerank: %v", err)
	}
}

func TestVoyageRerankEmptyDocuments(t *testing.T) {
	v := newVoyageForTest("http://unused")
	apiKey := "test-key"
	model := "rerank-2"
	resp, err := v.Rerank(&model, "x", nil,
		&APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 0})
	if err != nil {
		t.Fatalf("Rerank: %v", err)
	}
	if len(resp.Data) != 0 {
		t.Errorf("expected empty Data, got %d", len(resp.Data))
	}
}

func TestVoyageRerankRejectsOutOfRangeIndex(t *testing.T) {
	srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"data": []map[string]interface{}{
				{"relevance_score": 0.9, "index": 7},
			},
		})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "rerank-2"
	_, err := v.Rerank(&model, "x", []string{"a", "b"},
		&APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 2})
	if err == nil || !strings.Contains(err.Error(), "out of range") {
		t.Errorf("expected out-of-range error, got %v", err)
	}
}

func TestVoyageRerankRejectsDuplicateIndex(t *testing.T) {
	// A duplicate index would silently overwrite an earlier slot, which
	// is the same failure mode Embed already guards against. Make sure
	// Rerank fails loudly too.
	srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) {
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"data": []map[string]interface{}{
				{"relevance_score": 0.9, "index": 0},
				{"relevance_score": 0.8, "index": 0},
			},
		})
	})
	defer srv.Close()

	v := newVoyageForTest(srv.URL)
	apiKey := "test-key"
	model := "rerank-2"
	_, err := v.Rerank(&model, "x", []string{"a", "b"},
		&APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 2})
	if err == nil || !strings.Contains(err.Error(), "duplicate rerank index 0") {
		t.Errorf("expected duplicate-index error, got %v", err)
	}
}

// TestVoyageEmbedTrimsTrailingSlashInBaseURL guards against a
// misconfigured baseURL ending in "/" producing a double-slash path
// (e.g. `.../v1//embeddings`). Rerank already trims, so Embed must
// trim too; CodeRabbit flagged the inconsistency.
func TestVoyageEmbedTrimsTrailingSlashInBaseURL(t *testing.T) {
	var sawPath string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		sawPath = r.URL.Path
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"data": []map[string]interface{}{{"embedding": []float64{1}, "index": 0}},
		})
	}))
	defer srv.Close()

	v := NewVoyageModel(
		map[string]string{"default": srv.URL + "/"}, // trailing slash
		URLSuffix{Embedding: "v1/embeddings", Rerank: "v1/rerank"},
	)
	apiKey := "test-key"
	model := "voyage-3.5"
	if _, err := v.Embed(&model, []string{"x"}, &APIConfig{ApiKey: &apiKey}, nil); err != nil {
		t.Fatalf("Embed: %v", err)
	}
	if sawPath != "/v1/embeddings" {
		t.Errorf("path=%q want %q (no double slash)", sawPath, "/v1/embeddings")
	}
}
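
The trailing-slash test above only checks the path the server observes. One plausible way to satisfy it is to normalize both halves before joining them, as in the sketch below. `joinURL` is a hypothetical helper, not the driver's actual function, and trimming a leading slash from the suffix goes beyond what the test requires.

```go
package main

import (
	"fmt"
	"strings"
)

// joinURL is a hypothetical helper that joins a base URL and a suffix with
// exactly one slash. Trimming the suffix's leading slash is an extra
// assumption; the test only exercises a trailing slash on the base URL.
func joinURL(baseURL, suffix string) string {
	return fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), strings.TrimLeft(suffix, "/"))
}

func main() {
	// Both calls resolve to the same endpoint, with or without the slash.
	fmt.Println(joinURL("http://127.0.0.1:8080/", "v1/embeddings"))
	fmt.Println(joinURL("http://127.0.0.1:8080", "v1/embeddings"))
}
```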