diff --git a/conf/models/nvidia.json b/conf/models/nvidia.json index d07f12e4d6..9f2f9a415d 100644 --- a/conf/models/nvidia.json +++ b/conf/models/nvidia.json @@ -6,7 +6,8 @@ "url_suffix": { "chat": "chat/completions", "models": "models", - "embedding": "embeddings" + "embedding": "embeddings", + "rerank": "ranking" }, "class": "nvidia", "models": [ @@ -396,6 +397,20 @@ "embedding" ] }, + { + "name": "nvidia/nv-rerankqa-mistral-4b-v3", + "max_tokens": 4096, + "model_types": [ + "rerank" + ] + }, + { + "name": "nvidia/llama-3.2-nv-rerankqa-1b-v2", + "max_tokens": 4096, + "model_types": [ + "rerank" + ] + }, { "name": "nvidia/nvidia-nemotron-nano-9b-v2", "max_tokens": 131072, diff --git a/internal/entity/models/nvidia.go b/internal/entity/models/nvidia.go index fe50dcd425..88029dac15 100644 --- a/internal/entity/models/nvidia.go +++ b/internal/entity/models/nvidia.go @@ -423,8 +423,133 @@ func (n NvidiaModel) Embed(modelName *string, texts []string, apiConfig *APIConf return embeddings, nil } +// nvidiaRerankRequest mirrors the NIM /ranking request shape: +// query is an object with a "text" field, passages is an array of +// objects each with a "text" field. truncate=END matches the Python +// NvidiaRerank reference at rag/llm/rerank_model.py. +type nvidiaRerankRequest struct { + Model string `json:"model"` + Query nvidiaRerankText `json:"query"` + Passages []nvidiaRerankText `json:"passages"` + Truncate string `json:"truncate,omitempty"` + TopN int `json:"top_n"` +} + +type nvidiaRerankText struct { + Text string `json:"text"` +} + +// nvidiaRerankResponse maps the NIM rankings array. Each entry pairs +// the original passage index with a logit score; the caller uses the +// index to restore original input order. +type nvidiaRerankResponse struct { + Rankings []struct { + Index int `json:"index"` + Logit float64 `json:"logit"` + } `json:"rankings"` +} + +// Rerank scores documents against the query using an NVIDIA NIM +// reranking model. Mirrors the Python NvidiaRerank class in +// rag/llm/rerank_model.py for payload shape (passages/query/logit). +// Defaults top_n to len(documents) so the API returns a score per +// input; callers may shrink it via RerankConfig.TopN, in which case +// only the top RerankConfig.TopN entries come back. Returned +// RerankResult entries are in the API's ranking order; callers that +// need original-input order should sort by Index. Same return-shape +// contract as the Aliyun and ZhipuAI Rerank drivers. func (n NvidiaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("no such method") + if len(documents) == 0 { + return &RerankResponse{}, nil + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := n.BaseURL[region] + if baseURL == "" { + baseURL = n.BaseURL["default"] + } + if baseURL == "" { + return nil, fmt.Errorf("nvidia: no base URL configured for region %q", region) + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), n.URLSuffix.Rerank) + + topN := len(documents) + if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN { + topN = rerankConfig.TopN + } + + passages := make([]nvidiaRerankText, len(documents)) + for i, doc := range documents { + passages[i] = nvidiaRerankText{Text: doc} + } + + reqBody := nvidiaRerankRequest{ + Model: *modelName, + Query: nvidiaRerankText{Text: query}, + Passages: passages, + Truncate: "END", + TopN: topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := n.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Nvidia rerank API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed nvidiaRerankResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Rankings))} + for _, r := range parsed.Rankings { + if r.Index < 0 || r.Index >= len(documents) { + return nil, fmt.Errorf("unexpected rerank index %d for %d inputs", r.Index, len(documents)) + } + rerankResponse.Data = append(rerankResponse.Data, RerankResult{ + Index: r.Index, + RelevanceScore: r.Logit, + }) + } + + return &rerankResponse, nil } // ListModels calls /v1/models on the configured NVIDIA NIM base URL diff --git a/internal/entity/models/nvidia_rerank_test.go b/internal/entity/models/nvidia_rerank_test.go new file mode 100644 index 0000000000..c92249bfbb --- /dev/null +++ b/internal/entity/models/nvidia_rerank_test.go @@ -0,0 +1,195 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newNvidiaRerankServer(t *testing.T, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + // Use t.Errorf + return inside the handler goroutine; t.Fatalf would + // only Goexit the handler goroutine and the test would silently pass. + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + return + } + if r.URL.Path != "/ranking" { + t.Errorf("expected path=/ranking, got %s", r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("expected Content-Type=application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("invalid JSON body: %v\n%s", err, string(raw)) + return + } + handler(t, body, w) + })) +} + +func newNvidiaModelForTest(baseURL string) *NvidiaModel { + return NewNvidiaModel( + map[string]string{"default": baseURL}, + URLSuffix{Rerank: "ranking"}, + ) +} + +func TestNvidiaRerankHappyPath(t *testing.T) { + srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "nvidia/nv-rerankqa-mistral-4b-v3" { + t.Errorf("expected model=nvidia/nv-rerankqa-mistral-4b-v3, got %v", body["model"]) + } + query, ok := body["query"].(map[string]interface{}) + if !ok || query["text"] != "What is RAPTOR?" { + t.Errorf("expected query.text=What is RAPTOR?, got %v", body["query"]) + } + passages, ok := body["passages"].([]interface{}) + if !ok || len(passages) != 3 { + t.Errorf("expected 3 passages, got %v", body["passages"]) + return + } + if body["truncate"] != "END" { + t.Errorf("expected truncate=END, got %v", body["truncate"]) + } + if body["top_n"] != float64(3) { + t.Errorf("expected top_n=3 (matching len(documents)), got %v", body["top_n"]) + } + // Return rankings out of input order to verify Index preservation. + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "rankings": []map[string]interface{}{ + {"index": 2, "logit": 9.5}, + {"index": 0, "logit": 4.25}, + {"index": 1, "logit": 7.8}, + }, + }) + }) + defer srv.Close() + + model := newNvidiaModelForTest(srv.URL) + apiKey := "test-key" + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + resp, err := model.Rerank( + &modelName, + "What is RAPTOR?", + []string{"doc-zero", "doc-one", "doc-two"}, + &APIConfig{ApiKey: &apiKey}, + &RerankConfig{}, + ) + if err != nil { + t.Fatalf("Rerank failed: %v", err) + } + if len(resp.Data) != 3 { + t.Fatalf("expected 3 results, got %d", len(resp.Data)) + } + want := map[int]float64{0: 4.25, 1: 7.8, 2: 9.5} + for _, r := range resp.Data { + if got, ok := want[r.Index]; !ok || got != r.RelevanceScore { + t.Errorf("unexpected result Index=%d RelevanceScore=%v", r.Index, r.RelevanceScore) + } + } +} + +func TestNvidiaRerankTopNClamp(t *testing.T) { + srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["top_n"] != float64(2) { + t.Errorf("expected top_n clamp to RerankConfig.TopN=2, got %v", body["top_n"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"rankings": []map[string]interface{}{}}) + }) + defer srv.Close() + + model := newNvidiaModelForTest(srv.URL) + apiKey := "test-key" + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + if _, err := model.Rerank( + &modelName, "q", + []string{"a", "b", "c", "d"}, + &APIConfig{ApiKey: &apiKey}, + &RerankConfig{TopN: 2}, + ); err != nil { + t.Fatalf("Rerank failed: %v", err) + } +} + +func TestNvidiaRerankEmptyDocuments(t *testing.T) { + model := newNvidiaModelForTest("http://unused") + apiKey := "test-key" + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + resp, err := model.Rerank(&modelName, "q", nil, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err != nil { + t.Fatalf("expected nil error for empty documents, got %v", err) + } + if len(resp.Data) != 0 { + t.Errorf("expected empty Data, got %d entries", len(resp.Data)) + } +} + +func TestNvidiaRerankRequiresAPIKey(t *testing.T) { + model := newNvidiaModelForTest("http://unused") + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + _, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +func TestNvidiaRerankRequiresModelName(t *testing.T) { + model := newNvidiaModelForTest("http://unused") + apiKey := "test-key" + _, err := model.Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("expected model-name error, got %v", err) + } +} + +func TestNvidiaRerankRejectsHTTPError(t *testing.T) { + srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + }) + defer srv.Close() + + model := newNvidiaModelForTest(srv.URL) + apiKey := "test-key" + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + _, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "Nvidia rerank API error") { + t.Errorf("expected API error, got %v", err) + } +} + +func TestNvidiaRerankRejectsOutOfRangeIndex(t *testing.T) { + srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "rankings": []map[string]interface{}{ + {"index": 5, "logit": 1.0}, // out of range for 2-input request + }, + }) + }) + defer srv.Close() + + model := newNvidiaModelForTest(srv.URL) + apiKey := "test-key" + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + _, err := model.Rerank(&modelName, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "unexpected rerank index") { + t.Errorf("expected out-of-range error, got %v", err) + } +}