mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Go: implement Rerank in NVIDIA driver (#14778)
## Summary - Replaces the `"no such method"` stub on `NvidiaModel.Rerank` (`internal/entity/models/nvidia.go`) with a real implementation against NVIDIA NIM's `/ranking` endpoint. - Mirrors the existing Python `NvidiaRerank` class at `rag/llm/rerank_model.py:149-190` for behavior parity: same `passages`/`query.text`/`logit` payload shape; `top_n` set to `len(documents)` so every input gets a score returned in original order (the issue body's spec omitted `top_n`, which would cause silent data loss). - Adds the `"rerank": "ranking"` URL suffix and two NIM rerank model entries (`nvidia/nv-rerankqa-mistral-4b-v3`, `nvidia/llama-3.2-nv-rerankqa-1b-v2`) to `conf/models/nvidia.json` so the picker exposes them. - Follows the same shape as the recently merged Aliyun (#14676), Gitee (#14656), and ZhipuAI (#14608) Rerank implementations: lowercase per-driver request/response types, conversion to the project-wide `RerankResponse{Data: []RerankResult}`, per-call `context.WithTimeout` of 30s. Closes #14720 ## Test plan - [x] `gofmt -l internal/entity/models/nvidia.go` — clean - [x] `go vet ./internal/entity/models/...` — no new errors introduced (the two pre-existing vet errors in `baidu.go:642` and `openrouter.go:566` are unrelated to this PR) - [x] `go build ./internal/entity/models/...` — succeeds - [x] `python3 -c "import json; json.load(open('conf/models/nvidia.json'))"` — JSON valid - [ ] Live smoke test against NVIDIA NIM with a real API key (requires reviewer with NIM credentials) ## Notes for reviewers - The issue body suggested omitting `top_n`. The Python reference includes it (`top_n: len(texts)`), and without it NVIDIA returns only the default top-K rankings rather than scores for every input. This PR follows the Python. - The URL host is `integrate.api.nvidia.com` (kept consistent with the existing chat/embeddings BaseURL in `nvidia.go`), not the legacy `ai.api.nvidia.com` host the Python uses. 
NIM's unified endpoint accepts the model names as-is, so no per-model URL transform is needed.
This commit is contained in:
@@ -6,7 +6,8 @@
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"models": "models",
|
||||
"embedding": "embeddings"
|
||||
"embedding": "embeddings",
|
||||
"rerank": "ranking"
|
||||
},
|
||||
"class": "nvidia",
|
||||
"models": [
|
||||
@@ -396,6 +397,20 @@
|
||||
"embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "nvidia/nv-rerankqa-mistral-4b-v3",
|
||||
"max_tokens": 4096,
|
||||
"model_types": [
|
||||
"rerank"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "nvidia/llama-3.2-nv-rerankqa-1b-v2",
|
||||
"max_tokens": 4096,
|
||||
"model_types": [
|
||||
"rerank"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "nvidia/nvidia-nemotron-nano-9b-v2",
|
||||
"max_tokens": 131072,
|
||||
|
||||
@@ -423,8 +423,133 @@ func (n NvidiaModel) Embed(modelName *string, texts []string, apiConfig *APIConf
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// nvidiaRerankRequest mirrors the NIM /ranking request shape:
|
||||
// query is an object with a "text" field, passages is an array of
|
||||
// objects each with a "text" field. truncate=END matches the Python
|
||||
// NvidiaRerank reference at rag/llm/rerank_model.py.
|
||||
type nvidiaRerankRequest struct {
|
||||
Model string `json:"model"`
|
||||
Query nvidiaRerankText `json:"query"`
|
||||
Passages []nvidiaRerankText `json:"passages"`
|
||||
Truncate string `json:"truncate,omitempty"`
|
||||
TopN int `json:"top_n"`
|
||||
}
|
||||
|
||||
// nvidiaRerankText wraps a single string so that it is serialized as
// a JSON object of the form {"text": "..."}. It is used for both the
// query and each passage of a rerank request.
type nvidiaRerankText struct {
	Text string `json:"text"`
}
|
||||
|
||||
// nvidiaRerankResponse decodes the "rankings" array of a NIM /ranking
// reply. Every element carries the index of the passage it refers to
// together with that passage's logit score, so a consumer can map each
// score back to the input it belongs to.
type nvidiaRerankResponse struct {
	Rankings []struct {
		Index int     `json:"index"`
		Logit float64 `json:"logit"`
	} `json:"rankings"`
}
|
||||
|
||||
// Rerank scores documents against the query using an NVIDIA NIM
|
||||
// reranking model. Mirrors the Python NvidiaRerank class in
|
||||
// rag/llm/rerank_model.py for payload shape (passages/query/logit).
|
||||
// Defaults top_n to len(documents) so the API returns a score per
|
||||
// input; callers may shrink it via RerankConfig.TopN, in which case
|
||||
// only the top RerankConfig.TopN entries come back. Returned
|
||||
// RerankResult entries are in the API's ranking order; callers that
|
||||
// need original-input order should sort by Index. Same return-shape
|
||||
// contract as the Aliyun and ZhipuAI Rerank drivers.
|
||||
func (n NvidiaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
return nil, fmt.Errorf("no such method")
|
||||
if len(documents) == 0 {
|
||||
return &RerankResponse{}, nil
|
||||
}
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL := n.BaseURL[region]
|
||||
if baseURL == "" {
|
||||
baseURL = n.BaseURL["default"]
|
||||
}
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("nvidia: no base URL configured for region %q", region)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), n.URLSuffix.Rerank)
|
||||
|
||||
topN := len(documents)
|
||||
if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN {
|
||||
topN = rerankConfig.TopN
|
||||
}
|
||||
|
||||
passages := make([]nvidiaRerankText, len(documents))
|
||||
for i, doc := range documents {
|
||||
passages[i] = nvidiaRerankText{Text: doc}
|
||||
}
|
||||
|
||||
reqBody := nvidiaRerankRequest{
|
||||
Model: *modelName,
|
||||
Query: nvidiaRerankText{Text: query},
|
||||
Passages: passages,
|
||||
Truncate: "END",
|
||||
TopN: topN,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := n.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Nvidia rerank API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var parsed nvidiaRerankResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Rankings))}
|
||||
for _, r := range parsed.Rankings {
|
||||
if r.Index < 0 || r.Index >= len(documents) {
|
||||
return nil, fmt.Errorf("unexpected rerank index %d for %d inputs", r.Index, len(documents))
|
||||
}
|
||||
rerankResponse.Data = append(rerankResponse.Data, RerankResult{
|
||||
Index: r.Index,
|
||||
RelevanceScore: r.Logit,
|
||||
})
|
||||
}
|
||||
|
||||
return &rerankResponse, nil
|
||||
}
|
||||
|
||||
// ListModels calls /v1/models on the configured NVIDIA NIM base URL
|
||||
|
||||
195
internal/entity/models/nvidia_rerank_test.go
Normal file
195
internal/entity/models/nvidia_rerank_test.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// newNvidiaRerankServer starts an httptest server that validates the
// shape of an incoming rerank request before delegating to handler.
// It checks, in order, that the request is a POST to /ranking, that it
// carries "Authorization: Bearer test-key" and
// "Content-Type: application/json", and that the body is valid JSON.
// The decoded body is then passed to handler along with the response
// writer. On the first failed check the problem is reported through
// t.Errorf and the request is abandoned without calling handler.
//
// The caller is responsible for closing the returned server.
func newNvidiaRerankServer(t *testing.T, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server {
	t.Helper()

	wantHeaders := []struct{ name, value string }{
		{"Authorization", "Bearer test-key"},
		{"Content-Type", "application/json"},
	}

	// Failures are reported with t.Errorf followed by return: calling
	// t.Fatalf here would only Goexit the handler goroutine, and the
	// test itself would silently pass.
	validate := func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method != http.MethodPost:
			t.Errorf("expected POST, got %s", r.Method)
			return
		case r.URL.Path != "/ranking":
			t.Errorf("expected path=/ranking, got %s", r.URL.Path)
			return
		}

		for _, h := range wantHeaders {
			if got := r.Header.Get(h.name); got != h.value {
				t.Errorf("expected %s=%s, got %q", h.name, h.value, got)
				return
			}
		}

		payload, readErr := io.ReadAll(r.Body)
		if readErr != nil {
			t.Errorf("failed to read body: %v", readErr)
			return
		}

		var decoded map[string]interface{}
		if jsonErr := json.Unmarshal(payload, &decoded); jsonErr != nil {
			t.Errorf("invalid JSON body: %v\n%s", jsonErr, string(payload))
			return
		}

		handler(t, decoded, w)
	}

	return httptest.NewServer(http.HandlerFunc(validate))
}
|
||||
|
||||
func newNvidiaModelForTest(baseURL string) *NvidiaModel {
|
||||
return NewNvidiaModel(
|
||||
map[string]string{"default": baseURL},
|
||||
URLSuffix{Rerank: "ranking"},
|
||||
)
|
||||
}
|
||||
|
||||
func TestNvidiaRerankHappyPath(t *testing.T) {
|
||||
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if body["model"] != "nvidia/nv-rerankqa-mistral-4b-v3" {
|
||||
t.Errorf("expected model=nvidia/nv-rerankqa-mistral-4b-v3, got %v", body["model"])
|
||||
}
|
||||
query, ok := body["query"].(map[string]interface{})
|
||||
if !ok || query["text"] != "What is RAPTOR?" {
|
||||
t.Errorf("expected query.text=What is RAPTOR?, got %v", body["query"])
|
||||
}
|
||||
passages, ok := body["passages"].([]interface{})
|
||||
if !ok || len(passages) != 3 {
|
||||
t.Errorf("expected 3 passages, got %v", body["passages"])
|
||||
return
|
||||
}
|
||||
if body["truncate"] != "END" {
|
||||
t.Errorf("expected truncate=END, got %v", body["truncate"])
|
||||
}
|
||||
if body["top_n"] != float64(3) {
|
||||
t.Errorf("expected top_n=3 (matching len(documents)), got %v", body["top_n"])
|
||||
}
|
||||
// Return rankings out of input order to verify Index preservation.
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"rankings": []map[string]interface{}{
|
||||
{"index": 2, "logit": 9.5},
|
||||
{"index": 0, "logit": 4.25},
|
||||
{"index": 1, "logit": 7.8},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
model := newNvidiaModelForTest(srv.URL)
|
||||
apiKey := "test-key"
|
||||
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
||||
resp, err := model.Rerank(
|
||||
&modelName,
|
||||
"What is RAPTOR?",
|
||||
[]string{"doc-zero", "doc-one", "doc-two"},
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&RerankConfig{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Rerank failed: %v", err)
|
||||
}
|
||||
if len(resp.Data) != 3 {
|
||||
t.Fatalf("expected 3 results, got %d", len(resp.Data))
|
||||
}
|
||||
want := map[int]float64{0: 4.25, 1: 7.8, 2: 9.5}
|
||||
for _, r := range resp.Data {
|
||||
if got, ok := want[r.Index]; !ok || got != r.RelevanceScore {
|
||||
t.Errorf("unexpected result Index=%d RelevanceScore=%v", r.Index, r.RelevanceScore)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNvidiaRerankTopNClamp(t *testing.T) {
|
||||
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if body["top_n"] != float64(2) {
|
||||
t.Errorf("expected top_n clamp to RerankConfig.TopN=2, got %v", body["top_n"])
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{"rankings": []map[string]interface{}{}})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
model := newNvidiaModelForTest(srv.URL)
|
||||
apiKey := "test-key"
|
||||
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
||||
if _, err := model.Rerank(
|
||||
&modelName, "q",
|
||||
[]string{"a", "b", "c", "d"},
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&RerankConfig{TopN: 2},
|
||||
); err != nil {
|
||||
t.Fatalf("Rerank failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNvidiaRerankEmptyDocuments(t *testing.T) {
|
||||
model := newNvidiaModelForTest("http://unused")
|
||||
apiKey := "test-key"
|
||||
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
||||
resp, err := model.Rerank(&modelName, "q", nil, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error for empty documents, got %v", err)
|
||||
}
|
||||
if len(resp.Data) != 0 {
|
||||
t.Errorf("expected empty Data, got %d entries", len(resp.Data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNvidiaRerankRequiresAPIKey(t *testing.T) {
|
||||
model := newNvidiaModelForTest("http://unused")
|
||||
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
||||
_, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{}, &RerankConfig{})
|
||||
if err == nil || !strings.Contains(err.Error(), "api key is required") {
|
||||
t.Errorf("expected api-key error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNvidiaRerankRequiresModelName(t *testing.T) {
|
||||
model := newNvidiaModelForTest("http://unused")
|
||||
apiKey := "test-key"
|
||||
_, err := model.Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
||||
if err == nil || !strings.Contains(err.Error(), "model name is required") {
|
||||
t.Errorf("expected model-name error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNvidiaRerankRejectsHTTPError(t *testing.T) {
|
||||
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
model := newNvidiaModelForTest(srv.URL)
|
||||
apiKey := "test-key"
|
||||
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
||||
_, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
||||
if err == nil || !strings.Contains(err.Error(), "Nvidia rerank API error") {
|
||||
t.Errorf("expected API error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNvidiaRerankRejectsOutOfRangeIndex(t *testing.T) {
|
||||
srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"rankings": []map[string]interface{}{
|
||||
{"index": 5, "logit": 1.0}, // out of range for 2-input request
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
model := newNvidiaModelForTest(srv.URL)
|
||||
apiKey := "test-key"
|
||||
modelName := "nvidia/nv-rerankqa-mistral-4b-v3"
|
||||
_, err := model.Rerank(&modelName, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
||||
if err == nil || !strings.Contains(err.Error(), "unexpected rerank index") {
|
||||
t.Errorf("expected out-of-range error, got %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user