From fec0b968e7708588aa412363d67b4904707dbf29 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Wed, 20 May 2026 16:19:17 -1000 Subject: [PATCH] Go: implement Rerank in Novita driver (#15014) ### What problem does this PR solve? Fixes #15012 The Novita Go driver landed in #14850 and shipped a stub `Rerank` method that returned `"novita, no such method"`, so Novita could not be used as a rerank provider in RAGFlow. This PR fills that gap, in the same way #14895 filled the Embed gap on the same driver. Novita exposes a public rerank endpoint at `POST https://api.novita.ai/openai/v1/rerank` that accepts the Cohere-compatible request shape (`{model, query, documents, top_n}`) with `Authorization: Bearer `. `baai/bge-reranker-v2-m3` is documented in Novita's model library with a 1024-token limit. --- conf/models/novita.json | 10 +- internal/entity/models/novita.go | 107 ++++++++++++++++++- internal/entity/models/novita_test.go | 143 ++++++++++++++++++++++++-- 3 files changed, 248 insertions(+), 12 deletions(-) diff --git a/conf/models/novita.json b/conf/models/novita.json index f95e684929..6dad88c2fa 100644 --- a/conf/models/novita.json +++ b/conf/models/novita.json @@ -6,7 +6,8 @@ "url_suffix": { "chat": "openai/v1/chat/completions", "models": "openai/v1/models", - "embedding": "openai/v1/embeddings" + "embedding": "openai/v1/embeddings", + "rerank": "openai/v1/rerank" }, "class": "novita", "models": [ @@ -65,6 +66,13 @@ "model_types": [ "embedding" ] + }, + { + "name": "baai/bge-reranker-v2-m3", + "max_tokens": 1024, + "model_types": [ + "rerank" + ] } ] } diff --git a/internal/entity/models/novita.go b/internal/entity/models/novita.go index 33e945f613..7335dbff68 100644 --- a/internal/entity/models/novita.go +++ b/internal/entity/models/novita.go @@ -733,9 +733,112 @@ func (n *NovitaModel) Embed(modelName *string, texts []string, apiConfig *APICon return embeddings, nil } -// Rerank is not exposed by the Novita API. +type novitaRerankResult struct { + Document struct { + Text string `json:"text"` + } `json:"document"` + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` +} + +type novitaRerankResponse struct { + Results []novitaRerankResult `json:"results"` +} + +// Rerank scores documents against the query using the Novita +// /openai/v1/rerank endpoint and returns one RerankResult per scored +// document in the API's ranking order. Caller may sort by Index to +// recover original input order. func (n *NovitaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s, no such method", n.Name()) + if len(documents) == 0 { + return &RerankResponse{}, nil + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := n.baseURLForRegion(region) + if err != nil { + return nil, err + } + if n.URLSuffix.Rerank == "" { + return nil, fmt.Errorf("novita: no rerank URL suffix configured") + } + url := fmt.Sprintf("%s/%s", baseURL, n.URLSuffix.Rerank) + + topN := len(documents) + if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN { + topN = rerankConfig.TopN + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "query": query, + "documents": documents, + "top_n": topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := n.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Novita rerank API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed novitaRerankResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Results))} + seen := make([]bool, len(documents)) + for _, item := range parsed.Results { + if item.Index < 0 || item.Index >= len(documents) { + return nil, fmt.Errorf("novita: rerank index %d out of range for %d inputs", item.Index, len(documents)) + } + if seen[item.Index] { + return nil, fmt.Errorf("novita: duplicate rerank index %d in response", item.Index) + } + rerankResponse.Data = append(rerankResponse.Data, RerankResult{ + Index: item.Index, + RelevanceScore: item.RelevanceScore, + }) + seen[item.Index] = true + } + + return &rerankResponse, nil } // Balance is not exposed by the Novita API. diff --git a/internal/entity/models/novita_test.go b/internal/entity/models/novita_test.go index 29cbdace18..177085cad9 100644 --- a/internal/entity/models/novita_test.go +++ b/internal/entity/models/novita_test.go @@ -49,7 +49,12 @@ func newNovitaServer(t *testing.T, expectedPath string, handler func(t *testing. func newNovitaForTest(baseURL string) *NovitaModel { return NewNovitaModel( map[string]string{"default": baseURL}, - URLSuffix{Chat: "openai/v1/chat/completions", Models: "openai/v1/models"}, + URLSuffix{ + Chat: "openai/v1/chat/completions", + Models: "openai/v1/models", + Embedding: "openai/v1/embeddings", + Rerank: "openai/v1/rerank", + }, ) } @@ -650,18 +655,138 @@ func TestNovitaCheckConnection(t *testing.T) { } } -func TestNovitaEmbedReturnsNoSuchMethod(t *testing.T) { - m := "x" - _, err := newNovitaForTest("http://unused").Embed(&m, []string{"a"}, &APIConfig{}, nil) - if err == nil || !strings.Contains(err.Error(), "no such method") { +func TestNovitaRerankHappyPathReordersByIndex(t *testing.T) { + srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "baai/bge-reranker-v2-m3" { + t.Errorf("model=%v", body["model"]) + } + if body["query"] != "what is rag" { + t.Errorf("query=%v", body["query"]) + } + docs, ok := body["documents"].([]interface{}) + if !ok || len(docs) != 3 || docs[0] != "a" || docs[1] != "b" || docs[2] != "c" { + t.Errorf("documents=%v", body["documents"]) + } + if body["top_n"] != float64(3) { + t.Errorf("top_n=%v, want 3", body["top_n"]) + } + _, _ = io.WriteString(w, `{"results":[{"document":{"text":"c"},"index":2,"relevance_score":0.91},{"document":{"text":"a"},"index":0,"relevance_score":0.42},{"document":{"text":"b"},"index":1,"relevance_score":0.08}]}`) + }) + defer srv.Close() + + apiKey := "test-key" + model := "baai/bge-reranker-v2-m3" + resp, err := newNovitaForTest(srv.URL).Rerank(&model, "what is rag", []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if len(resp.Data) != 3 { + t.Fatalf("len(Data)=%d, want 3", len(resp.Data)) + } + want := map[int]float64{2: 0.91, 0: 0.42, 1: 0.08} + for i, item := range resp.Data { + if want[item.Index] != item.RelevanceScore { + t.Errorf("Data[%d]={Index:%d, Score:%v}", i, item.Index, item.RelevanceScore) + } + } +} + +func TestNovitaRerankRespectsTopNConfig(t *testing.T) { + srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["top_n"] != float64(2) { + t.Errorf("top_n=%v, want 2", body["top_n"]) + } + _, _ = io.WriteString(w, `{"results":[{"index":0,"relevance_score":0.9},{"index":1,"relevance_score":0.5}]}`) + }) + defer srv.Close() + + apiKey := "test-key" + model := "baai/bge-reranker-v2-m3" + if _, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 2}); err != nil { + t.Fatalf("Rerank: %v", err) + } +} + +func TestNovitaRerankEmptyDocumentsShortCircuits(t *testing.T) { + apiKey := "test-key" + model := "x" + resp, err := newNovitaForTest("http://unused").Rerank(&model, "q", nil, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("expected nil error for empty docs, got %v", err) + } + if len(resp.Data) != 0 { + t.Errorf("len(Data)=%d, want 0", len(resp.Data)) + } +} + +func TestNovitaRerankRequiresApiKey(t *testing.T) { + model := "x" + _, err := newNovitaForTest("http://unused").Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { t.Errorf("got %v", err) } } -func TestNovitaRerankReturnsNoSuchMethod(t *testing.T) { - m := "x" - _, err := newNovitaForTest("http://unused").Rerank(&m, "q", []string{"a"}, &APIConfig{}, &RerankConfig{TopN: 1}) - if err == nil || !strings.Contains(err.Error(), "no such method") { +func TestNovitaRerankRequiresModelName(t *testing.T) { + apiKey := "test-key" + _, err := newNovitaForTest("http://unused").Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("got %v", err) + } +} + +func TestNovitaRerankRejectsOutOfRangeIndex(t *testing.T) { + srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _, _ = io.WriteString(w, `{"results":[{"index":5,"relevance_score":0.5}]}`) + }) + defer srv.Close() + + apiKey := "test-key" + model := "x" + _, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("got %v", err) + } +} + +func TestNovitaRerankRejectsDuplicateIndex(t *testing.T) { + srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _, _ = io.WriteString(w, `{"results":[{"index":0,"relevance_score":0.9},{"index":0,"relevance_score":0.5}]}`) + }) + defer srv.Close() + + apiKey := "test-key" + model := "x" + _, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "duplicate") { + t.Errorf("got %v", err) + } +} + +func TestNovitaRerankSurfacesHTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = io.WriteString(w, `{"error":"bad key"}`) + })) + defer srv.Close() + + apiKey := "test-key" + model := "x" + _, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "Novita rerank API error") { + t.Errorf("got %v", err) + } +} + +func TestNovitaRerankRejectsMissingRerankSuffix(t *testing.T) { + apiKey := "test-key" + model := "x" + driver := NewNovitaModel( + map[string]string{"default": "http://unused"}, + URLSuffix{Chat: "openai/v1/chat/completions"}, + ) + _, err := driver.Rerank(&model, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "no rerank URL suffix configured") { t.Errorf("got %v", err) } }