Go: implement Rerank in NVIDIA driver (#14778)

## Summary

- Replaces the `"no such method"` stub on `NvidiaModel.Rerank`
(`internal/entity/models/nvidia.go`) with a real implementation against
NVIDIA NIM's `/ranking` endpoint.
- Mirrors the existing Python `NvidiaRerank` class at
`rag/llm/rerank_model.py:149-190` for behavior parity: same
`passages`/`query.text`/`logit` payload shape; `top_n` set to
`len(documents)` so every input gets a score returned in original order
(the issue body's spec omitted `top_n`, which would cause silent data
loss).
- Adds the `"rerank": "ranking"` URL suffix and two NIM rerank model
entries (`nvidia/nv-rerankqa-mistral-4b-v3`,
`nvidia/llama-3.2-nv-rerankqa-1b-v2`) to `conf/models/nvidia.json` so
the picker exposes them.
- Follows the same shape as the recently merged Aliyun (#14676), Gitee
(#14656), and ZhipuAI (#14608) Rerank implementations: lowercase
per-driver request/response types, conversion to the project-wide
`RerankResponse{Data: []RerankResult}`, per-call `context.WithTimeout`
of 30s.

Closes #14720

## Test plan

- [x] `gofmt -l internal/entity/models/nvidia.go` — clean
- [x] `go vet ./internal/entity/models/...` — no new errors introduced
(the two pre-existing vet errors in `baidu.go:642` and
`openrouter.go:566` are unrelated to this PR)
- [x] `go build ./internal/entity/models/...` — succeeds
- [x] `python3 -c "import json;
json.load(open('conf/models/nvidia.json'))"` — JSON valid
- [ ] Live smoke test against NVIDIA NIM with a real API key (requires
reviewer with NIM credentials)

## Notes for reviewers

- The issue body suggested omitting `top_n`. The Python reference
includes it (`top_n: len(texts)`), and without it NVIDIA returns only
the default top-K rankings rather than scores for every input. This PR
follows the Python.
- The URL host is `integrate.api.nvidia.com` (kept consistent with the
existing chat/embeddings BaseURL in `nvidia.go`), not the legacy
`ai.api.nvidia.com` host the Python uses. NIM's unified endpoint accepts
the model names as-is, so no per-model URL transform is needed.
This commit is contained in:
Renzo
2026-05-11 11:21:16 +02:00
committed by GitHub
parent 9b3850339b
commit 39ee2fb120
3 changed files with 337 additions and 2 deletions

View File

@@ -6,7 +6,8 @@
"url_suffix": {
"chat": "chat/completions",
"models": "models",
"embedding": "embeddings"
"embedding": "embeddings",
"rerank": "ranking"
},
"class": "nvidia",
"models": [
@@ -396,6 +397,20 @@
"embedding"
]
},
{
"name": "nvidia/nv-rerankqa-mistral-4b-v3",
"max_tokens": 4096,
"model_types": [
"rerank"
]
},
{
"name": "nvidia/llama-3.2-nv-rerankqa-1b-v2",
"max_tokens": 4096,
"model_types": [
"rerank"
]
},
{
"name": "nvidia/nvidia-nemotron-nano-9b-v2",
"max_tokens": 131072,

View File

@@ -423,8 +423,133 @@ func (n NvidiaModel) Embed(modelName *string, texts []string, apiConfig *APIConf
return embeddings, nil
}
// nvidiaRerankRequest mirrors the NIM /ranking request shape:
// query is an object with a "text" field, passages is an array of
// objects each with a "text" field. truncate=END matches the Python
// NvidiaRerank reference at rag/llm/rerank_model.py.
type nvidiaRerankRequest struct {
Model string `json:"model"`
Query nvidiaRerankText `json:"query"`
Passages []nvidiaRerankText `json:"passages"`
Truncate string `json:"truncate,omitempty"`
TopN int `json:"top_n"`
}
type nvidiaRerankText struct {
Text string `json:"text"`
}
// nvidiaRerankResponse maps the NIM rankings array. Each entry pairs
// the original passage index with a logit score; the caller uses the
// index to restore original input order.
type nvidiaRerankResponse struct {
Rankings []struct {
Index int `json:"index"`
Logit float64 `json:"logit"`
} `json:"rankings"`
}
// Rerank scores documents against the query using an NVIDIA NIM
// reranking model. Mirrors the Python NvidiaRerank class in
// rag/llm/rerank_model.py for payload shape (passages/query/logit).
// Defaults top_n to len(documents) so the API returns a score per
// input; callers may shrink it via RerankConfig.TopN, in which case
// only the top RerankConfig.TopN entries come back. Returned
// RerankResult entries are in the API's ranking order; callers that
// need original-input order should sort by Index. Same return-shape
// contract as the Aliyun and ZhipuAI Rerank drivers.
func (n NvidiaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
return nil, fmt.Errorf("no such method")
if len(documents) == 0 {
return &RerankResponse{}, nil
}
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
return nil, fmt.Errorf("api key is required")
}
if modelName == nil || *modelName == "" {
return nil, fmt.Errorf("model name is required")
}
region := "default"
if apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL := n.BaseURL[region]
if baseURL == "" {
baseURL = n.BaseURL["default"]
}
if baseURL == "" {
return nil, fmt.Errorf("nvidia: no base URL configured for region %q", region)
}
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), n.URLSuffix.Rerank)
topN := len(documents)
if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN {
topN = rerankConfig.TopN
}
passages := make([]nvidiaRerankText, len(documents))
for i, doc := range documents {
passages[i] = nvidiaRerankText{Text: doc}
}
reqBody := nvidiaRerankRequest{
Model: *modelName,
Query: nvidiaRerankText{Text: query},
Passages: passages,
Truncate: "END",
TopN: topN,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := n.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Nvidia rerank API error: %s, body: %s", resp.Status, string(body))
}
var parsed nvidiaRerankResponse
if err = json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Rankings))}
for _, r := range parsed.Rankings {
if r.Index < 0 || r.Index >= len(documents) {
return nil, fmt.Errorf("unexpected rerank index %d for %d inputs", r.Index, len(documents))
}
rerankResponse.Data = append(rerankResponse.Data, RerankResult{
Index: r.Index,
RelevanceScore: r.Logit,
})
}
return &rerankResponse, nil
}
// ListModels calls /v1/models on the configured NVIDIA NIM base URL

View File

@@ -0,0 +1,195 @@
package models
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func newNvidiaRerankServer(t *testing.T, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server {
t.Helper()
// Use t.Errorf + return inside the handler goroutine; t.Fatalf would
// only Goexit the handler goroutine and the test would silently pass.
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
return
}
if r.URL.Path != "/ranking" {
t.Errorf("expected path=/ranking, got %s", r.URL.Path)
return
}
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
t.Errorf("expected Authorization=Bearer test-key, got %q", got)
return
}
if got := r.Header.Get("Content-Type"); got != "application/json" {
t.Errorf("expected Content-Type=application/json, got %q", got)
return
}
raw, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("failed to read body: %v", err)
return
}
var body map[string]interface{}
if err := json.Unmarshal(raw, &body); err != nil {
t.Errorf("invalid JSON body: %v\n%s", err, string(raw))
return
}
handler(t, body, w)
}))
}
func newNvidiaModelForTest(baseURL string) *NvidiaModel {
return NewNvidiaModel(
map[string]string{"default": baseURL},
URLSuffix{Rerank: "ranking"},
)
}
func TestNvidiaRerankHappyPath(t *testing.T) {
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
if body["model"] != "nvidia/nv-rerankqa-mistral-4b-v3" {
t.Errorf("expected model=nvidia/nv-rerankqa-mistral-4b-v3, got %v", body["model"])
}
query, ok := body["query"].(map[string]interface{})
if !ok || query["text"] != "What is RAPTOR?" {
t.Errorf("expected query.text=What is RAPTOR?, got %v", body["query"])
}
passages, ok := body["passages"].([]interface{})
if !ok || len(passages) != 3 {
t.Errorf("expected 3 passages, got %v", body["passages"])
return
}
if body["truncate"] != "END" {
t.Errorf("expected truncate=END, got %v", body["truncate"])
}
if body["top_n"] != float64(3) {
t.Errorf("expected top_n=3 (matching len(documents)), got %v", body["top_n"])
}
// Return rankings out of input order to verify Index preservation.
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"rankings": []map[string]interface{}{
{"index": 2, "logit": 9.5},
{"index": 0, "logit": 4.25},
{"index": 1, "logit": 7.8},
},
})
})
defer srv.Close()
model := newNvidiaModelForTest(srv.URL)
apiKey := "test-key"
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
resp, err := model.Rerank(
&modelName,
"What is RAPTOR?",
[]string{"doc-zero", "doc-one", "doc-two"},
&APIConfig{ApiKey: &apiKey},
&RerankConfig{},
)
if err != nil {
t.Fatalf("Rerank failed: %v", err)
}
if len(resp.Data) != 3 {
t.Fatalf("expected 3 results, got %d", len(resp.Data))
}
want := map[int]float64{0: 4.25, 1: 7.8, 2: 9.5}
for _, r := range resp.Data {
if got, ok := want[r.Index]; !ok || got != r.RelevanceScore {
t.Errorf("unexpected result Index=%d RelevanceScore=%v", r.Index, r.RelevanceScore)
}
}
}
func TestNvidiaRerankTopNClamp(t *testing.T) {
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
if body["top_n"] != float64(2) {
t.Errorf("expected top_n clamp to RerankConfig.TopN=2, got %v", body["top_n"])
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{"rankings": []map[string]interface{}{}})
})
defer srv.Close()
model := newNvidiaModelForTest(srv.URL)
apiKey := "test-key"
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
if _, err := model.Rerank(
&modelName, "q",
[]string{"a", "b", "c", "d"},
&APIConfig{ApiKey: &apiKey},
&RerankConfig{TopN: 2},
); err != nil {
t.Fatalf("Rerank failed: %v", err)
}
}
func TestNvidiaRerankEmptyDocuments(t *testing.T) {
model := newNvidiaModelForTest("http://unused")
apiKey := "test-key"
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
resp, err := model.Rerank(&modelName, "q", nil, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
if err != nil {
t.Fatalf("expected nil error for empty documents, got %v", err)
}
if len(resp.Data) != 0 {
t.Errorf("expected empty Data, got %d entries", len(resp.Data))
}
}
func TestNvidiaRerankRequiresAPIKey(t *testing.T) {
model := newNvidiaModelForTest("http://unused")
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
_, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{}, &RerankConfig{})
if err == nil || !strings.Contains(err.Error(), "api key is required") {
t.Errorf("expected api-key error, got %v", err)
}
}
func TestNvidiaRerankRequiresModelName(t *testing.T) {
model := newNvidiaModelForTest("http://unused")
apiKey := "test-key"
_, err := model.Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
if err == nil || !strings.Contains(err.Error(), "model name is required") {
t.Errorf("expected model-name error, got %v", err)
}
}
func TestNvidiaRerankRejectsHTTPError(t *testing.T) {
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
})
defer srv.Close()
model := newNvidiaModelForTest(srv.URL)
apiKey := "test-key"
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
_, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
if err == nil || !strings.Contains(err.Error(), "Nvidia rerank API error") {
t.Errorf("expected API error, got %v", err)
}
}
func TestNvidiaRerankRejectsOutOfRangeIndex(t *testing.T) {
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"rankings": []map[string]interface{}{
{"index": 5, "logit": 1.0}, // out of range for 2-input request
},
})
})
defer srv.Close()
model := newNvidiaModelForTest(srv.URL)
apiKey := "test-key"
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
_, err := model.Rerank(&modelName, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
if err == nil || !strings.Contains(err.Error(), "unexpected rerank index") {
t.Errorf("expected out-of-range error, got %v", err)
}
}