Go: implement Rerank in Novita driver (#15014)

### What problem does this PR solve?

Fixes #15012

The Novita Go driver landed in #14850 and shipped a stub `Rerank` method
that returned `"novita, no such method"`, so Novita could not be used as
a rerank provider in RAGFlow. This PR fills that gap, in the same way
#14895 filled the Embed gap on the same driver.

Novita exposes a public rerank endpoint at `POST
https://api.novita.ai/openai/v1/rerank` that accepts the
Cohere-compatible request shape (`{model, query, documents, top_n}`)
with `Authorization: Bearer <api_key>`. `baai/bge-reranker-v2-m3` is
documented in Novita's model library with a 1024-token limit.
This commit is contained in:
Renzo
2026-05-20 16:19:17 -10:00
committed by GitHub
parent 536ed07d27
commit fec0b968e7
3 changed files with 248 additions and 12 deletions

View File

@@ -6,7 +6,8 @@
"url_suffix": {
"chat": "openai/v1/chat/completions",
"models": "openai/v1/models",
"embedding": "openai/v1/embeddings"
"embedding": "openai/v1/embeddings",
"rerank": "openai/v1/rerank"
},
"class": "novita",
"models": [
@@ -65,6 +66,13 @@
"model_types": [
"embedding"
]
},
{
"name": "baai/bge-reranker-v2-m3",
"max_tokens": 1024,
"model_types": [
"rerank"
]
}
]
}

View File

@@ -733,9 +733,112 @@ func (n *NovitaModel) Embed(modelName *string, texts []string, apiConfig *APICon
return embeddings, nil
}
// Rerank is not exposed by the Novita API.
type novitaRerankResult struct {
Document struct {
Text string `json:"text"`
} `json:"document"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
type novitaRerankResponse struct {
Results []novitaRerankResult `json:"results"`
}
// Rerank scores documents against the query using the Novita
// /openai/v1/rerank endpoint and returns one RerankResult per scored
// document in the API's ranking order. Caller may sort by Index to
// recover original input order.
func (n *NovitaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
return nil, fmt.Errorf("%s, no such method", n.Name())
if len(documents) == 0 {
return &RerankResponse{}, nil
}
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
return nil, fmt.Errorf("api key is required")
}
if modelName == nil || *modelName == "" {
return nil, fmt.Errorf("model name is required")
}
region := "default"
if apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL, err := n.baseURLForRegion(region)
if err != nil {
return nil, err
}
if n.URLSuffix.Rerank == "" {
return nil, fmt.Errorf("novita: no rerank URL suffix configured")
}
url := fmt.Sprintf("%s/%s", baseURL, n.URLSuffix.Rerank)
topN := len(documents)
if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN {
topN = rerankConfig.TopN
}
reqBody := map[string]interface{}{
"model": *modelName,
"query": query,
"documents": documents,
"top_n": topN,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := n.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Novita rerank API error: %s, body: %s", resp.Status, string(body))
}
var parsed novitaRerankResponse
if err = json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Results))}
seen := make([]bool, len(documents))
for _, item := range parsed.Results {
if item.Index < 0 || item.Index >= len(documents) {
return nil, fmt.Errorf("novita: rerank index %d out of range for %d inputs", item.Index, len(documents))
}
if seen[item.Index] {
return nil, fmt.Errorf("novita: duplicate rerank index %d in response", item.Index)
}
rerankResponse.Data = append(rerankResponse.Data, RerankResult{
Index: item.Index,
RelevanceScore: item.RelevanceScore,
})
seen[item.Index] = true
}
return &rerankResponse, nil
}
// Balance is not exposed by the Novita API.

View File

@@ -49,7 +49,12 @@ func newNovitaServer(t *testing.T, expectedPath string, handler func(t *testing.
func newNovitaForTest(baseURL string) *NovitaModel {
return NewNovitaModel(
map[string]string{"default": baseURL},
URLSuffix{Chat: "openai/v1/chat/completions", Models: "openai/v1/models"},
URLSuffix{
Chat: "openai/v1/chat/completions",
Models: "openai/v1/models",
Embedding: "openai/v1/embeddings",
Rerank: "openai/v1/rerank",
},
)
}
@@ -650,18 +655,138 @@ func TestNovitaCheckConnection(t *testing.T) {
}
}
func TestNovitaEmbedReturnsNoSuchMethod(t *testing.T) {
m := "x"
_, err := newNovitaForTest("http://unused").Embed(&m, []string{"a"}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "no such method") {
func TestNovitaRerankHappyPathReordersByIndex(t *testing.T) {
srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
if body["model"] != "baai/bge-reranker-v2-m3" {
t.Errorf("model=%v", body["model"])
}
if body["query"] != "what is rag" {
t.Errorf("query=%v", body["query"])
}
docs, ok := body["documents"].([]interface{})
if !ok || len(docs) != 3 || docs[0] != "a" || docs[1] != "b" || docs[2] != "c" {
t.Errorf("documents=%v", body["documents"])
}
if body["top_n"] != float64(3) {
t.Errorf("top_n=%v, want 3", body["top_n"])
}
_, _ = io.WriteString(w, `{"results":[{"document":{"text":"c"},"index":2,"relevance_score":0.91},{"document":{"text":"a"},"index":0,"relevance_score":0.42},{"document":{"text":"b"},"index":1,"relevance_score":0.08}]}`)
})
defer srv.Close()
apiKey := "test-key"
model := "baai/bge-reranker-v2-m3"
resp, err := newNovitaForTest(srv.URL).Rerank(&model, "what is rag", []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil)
if err != nil {
t.Fatalf("Rerank: %v", err)
}
if len(resp.Data) != 3 {
t.Fatalf("len(Data)=%d, want 3", len(resp.Data))
}
want := map[int]float64{2: 0.91, 0: 0.42, 1: 0.08}
for i, item := range resp.Data {
if want[item.Index] != item.RelevanceScore {
t.Errorf("Data[%d]={Index:%d, Score:%v}", i, item.Index, item.RelevanceScore)
}
}
}
func TestNovitaRerankRespectsTopNConfig(t *testing.T) {
srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
if body["top_n"] != float64(2) {
t.Errorf("top_n=%v, want 2", body["top_n"])
}
_, _ = io.WriteString(w, `{"results":[{"index":0,"relevance_score":0.9},{"index":1,"relevance_score":0.5}]}`)
})
defer srv.Close()
apiKey := "test-key"
model := "baai/bge-reranker-v2-m3"
if _, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 2}); err != nil {
t.Fatalf("Rerank: %v", err)
}
}
func TestNovitaRerankEmptyDocumentsShortCircuits(t *testing.T) {
apiKey := "test-key"
model := "x"
resp, err := newNovitaForTest("http://unused").Rerank(&model, "q", nil, &APIConfig{ApiKey: &apiKey}, nil)
if err != nil {
t.Fatalf("expected nil error for empty docs, got %v", err)
}
if len(resp.Data) != 0 {
t.Errorf("len(Data)=%d, want 0", len(resp.Data))
}
}
func TestNovitaRerankRequiresApiKey(t *testing.T) {
model := "x"
_, err := newNovitaForTest("http://unused").Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "api key is required") {
t.Errorf("got %v", err)
}
}
func TestNovitaRerankReturnsNoSuchMethod(t *testing.T) {
m := "x"
_, err := newNovitaForTest("http://unused").Rerank(&m, "q", []string{"a"}, &APIConfig{}, &RerankConfig{TopN: 1})
if err == nil || !strings.Contains(err.Error(), "no such method") {
func TestNovitaRerankRequiresModelName(t *testing.T) {
apiKey := "test-key"
_, err := newNovitaForTest("http://unused").Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
if err == nil || !strings.Contains(err.Error(), "model name is required") {
t.Errorf("got %v", err)
}
}
func TestNovitaRerankRejectsOutOfRangeIndex(t *testing.T) {
srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
_, _ = io.WriteString(w, `{"results":[{"index":5,"relevance_score":0.5}]}`)
})
defer srv.Close()
apiKey := "test-key"
model := "x"
_, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
if err == nil || !strings.Contains(err.Error(), "out of range") {
t.Errorf("got %v", err)
}
}
func TestNovitaRerankRejectsDuplicateIndex(t *testing.T) {
srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
_, _ = io.WriteString(w, `{"results":[{"index":0,"relevance_score":0.9},{"index":0,"relevance_score":0.5}]}`)
})
defer srv.Close()
apiKey := "test-key"
model := "x"
_, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
if err == nil || !strings.Contains(err.Error(), "duplicate") {
t.Errorf("got %v", err)
}
}
func TestNovitaRerankSurfacesHTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, `{"error":"bad key"}`)
}))
defer srv.Close()
apiKey := "test-key"
model := "x"
_, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
if err == nil || !strings.Contains(err.Error(), "Novita rerank API error") {
t.Errorf("got %v", err)
}
}
func TestNovitaRerankRejectsMissingRerankSuffix(t *testing.T) {
apiKey := "test-key"
model := "x"
driver := NewNovitaModel(
map[string]string{"default": "http://unused"},
URLSuffix{Chat: "openai/v1/chat/completions"},
)
_, err := driver.Rerank(&model, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
if err == nil || !strings.Contains(err.Error(), "no rerank URL suffix configured") {
t.Errorf("got %v", err)
}
}