mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-30 16:01:58 +08:00
Go: implement Rerank in Novita driver (#15014)
### What problem does this PR solve?

Fixes #15012.

The Novita Go driver landed in #14850 and shipped with a stub `Rerank` method that returned `"novita, no such method"`, so Novita could not be used as a rerank provider in RAGFlow. This PR fills that gap, in the same way that #14895 filled the Embed gap in the same driver.

Novita exposes a public rerank endpoint at `POST https://api.novita.ai/openai/v1/rerank` that accepts the Cohere-compatible request shape (`{model, query, documents, top_n}`) with `Authorization: Bearer <api_key>`. `baai/bge-reranker-v2-m3` is documented in Novita's model library with a 1024-token limit.
This commit is contained in:
@@ -6,7 +6,8 @@
|
||||
"url_suffix": {
|
||||
"chat": "openai/v1/chat/completions",
|
||||
"models": "openai/v1/models",
|
||||
"embedding": "openai/v1/embeddings"
|
||||
"embedding": "openai/v1/embeddings",
|
||||
"rerank": "openai/v1/rerank"
|
||||
},
|
||||
"class": "novita",
|
||||
"models": [
|
||||
@@ -65,6 +66,13 @@
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "baai/bge-reranker-v2-m3",
|
||||
"max_tokens": 1024,
|
||||
"model_types": [
|
||||
"rerank"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -733,9 +733,112 @@ func (n *NovitaModel) Embed(modelName *string, texts []string, apiConfig *APICon
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// novitaRerankResult is a single entry of the "results" array in a Novita
// rerank response.
type novitaRerankResult struct {
	// Document carries the document text when the API includes it in the
	// result; it stays empty when the "document" object is absent.
	Document struct {
		Text string `json:"text"`
	} `json:"document"`
	// Index is the position of the scored document in the request's
	// "documents" array.
	Index int `json:"index"`
	// RelevanceScore is the score the API assigned to the document.
	RelevanceScore float64 `json:"relevance_score"`
}
|
||||
|
||||
type novitaRerankResponse struct {
|
||||
Results []novitaRerankResult `json:"results"`
|
||||
}
|
||||
|
||||
// Rerank scores documents against the query using the Novita
|
||||
// /openai/v1/rerank endpoint and returns one RerankResult per scored
|
||||
// document in the API's ranking order. Caller may sort by Index to
|
||||
// recover original input order.
|
||||
func (n *NovitaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", n.Name())
|
||||
if len(documents) == 0 {
|
||||
return &RerankResponse{}, nil
|
||||
}
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL, err := n.baseURLForRegion(region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n.URLSuffix.Rerank == "" {
|
||||
return nil, fmt.Errorf("novita: no rerank URL suffix configured")
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", baseURL, n.URLSuffix.Rerank)
|
||||
|
||||
topN := len(documents)
|
||||
if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN {
|
||||
topN = rerankConfig.TopN
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": topN,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := n.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Novita rerank API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var parsed novitaRerankResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Results))}
|
||||
seen := make([]bool, len(documents))
|
||||
for _, item := range parsed.Results {
|
||||
if item.Index < 0 || item.Index >= len(documents) {
|
||||
return nil, fmt.Errorf("novita: rerank index %d out of range for %d inputs", item.Index, len(documents))
|
||||
}
|
||||
if seen[item.Index] {
|
||||
return nil, fmt.Errorf("novita: duplicate rerank index %d in response", item.Index)
|
||||
}
|
||||
rerankResponse.Data = append(rerankResponse.Data, RerankResult{
|
||||
Index: item.Index,
|
||||
RelevanceScore: item.RelevanceScore,
|
||||
})
|
||||
seen[item.Index] = true
|
||||
}
|
||||
|
||||
return &rerankResponse, nil
|
||||
}
|
||||
|
||||
// Balance is not exposed by the Novita API.
|
||||
|
||||
@@ -49,7 +49,12 @@ func newNovitaServer(t *testing.T, expectedPath string, handler func(t *testing.
|
||||
func newNovitaForTest(baseURL string) *NovitaModel {
|
||||
return NewNovitaModel(
|
||||
map[string]string{"default": baseURL},
|
||||
URLSuffix{Chat: "openai/v1/chat/completions", Models: "openai/v1/models"},
|
||||
URLSuffix{
|
||||
Chat: "openai/v1/chat/completions",
|
||||
Models: "openai/v1/models",
|
||||
Embedding: "openai/v1/embeddings",
|
||||
Rerank: "openai/v1/rerank",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -650,18 +655,138 @@ func TestNovitaCheckConnection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNovitaEmbedReturnsNoSuchMethod(t *testing.T) {
|
||||
m := "x"
|
||||
_, err := newNovitaForTest("http://unused").Embed(&m, []string{"a"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
func TestNovitaRerankHappyPathReordersByIndex(t *testing.T) {
|
||||
srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if body["model"] != "baai/bge-reranker-v2-m3" {
|
||||
t.Errorf("model=%v", body["model"])
|
||||
}
|
||||
if body["query"] != "what is rag" {
|
||||
t.Errorf("query=%v", body["query"])
|
||||
}
|
||||
docs, ok := body["documents"].([]interface{})
|
||||
if !ok || len(docs) != 3 || docs[0] != "a" || docs[1] != "b" || docs[2] != "c" {
|
||||
t.Errorf("documents=%v", body["documents"])
|
||||
}
|
||||
if body["top_n"] != float64(3) {
|
||||
t.Errorf("top_n=%v, want 3", body["top_n"])
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"results":[{"document":{"text":"c"},"index":2,"relevance_score":0.91},{"document":{"text":"a"},"index":0,"relevance_score":0.42},{"document":{"text":"b"},"index":1,"relevance_score":0.08}]}`)
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "baai/bge-reranker-v2-m3"
|
||||
resp, err := newNovitaForTest(srv.URL).Rerank(&model, "what is rag", []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Rerank: %v", err)
|
||||
}
|
||||
if len(resp.Data) != 3 {
|
||||
t.Fatalf("len(Data)=%d, want 3", len(resp.Data))
|
||||
}
|
||||
want := map[int]float64{2: 0.91, 0: 0.42, 1: 0.08}
|
||||
for i, item := range resp.Data {
|
||||
if want[item.Index] != item.RelevanceScore {
|
||||
t.Errorf("Data[%d]={Index:%d, Score:%v}", i, item.Index, item.RelevanceScore)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNovitaRerankRespectsTopNConfig(t *testing.T) {
|
||||
srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if body["top_n"] != float64(2) {
|
||||
t.Errorf("top_n=%v, want 2", body["top_n"])
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"results":[{"index":0,"relevance_score":0.9},{"index":1,"relevance_score":0.5}]}`)
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "baai/bge-reranker-v2-m3"
|
||||
if _, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 2}); err != nil {
|
||||
t.Fatalf("Rerank: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNovitaRerankEmptyDocumentsShortCircuits(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
model := "x"
|
||||
resp, err := newNovitaForTest("http://unused").Rerank(&model, "q", nil, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error for empty docs, got %v", err)
|
||||
}
|
||||
if len(resp.Data) != 0 {
|
||||
t.Errorf("len(Data)=%d, want 0", len(resp.Data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNovitaRerankRequiresApiKey(t *testing.T) {
|
||||
model := "x"
|
||||
_, err := newNovitaForTest("http://unused").Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "api key is required") {
|
||||
t.Errorf("got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNovitaRerankReturnsNoSuchMethod(t *testing.T) {
|
||||
m := "x"
|
||||
_, err := newNovitaForTest("http://unused").Rerank(&m, "q", []string{"a"}, &APIConfig{}, &RerankConfig{TopN: 1})
|
||||
if err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
func TestNovitaRerankRequiresModelName(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
_, err := newNovitaForTest("http://unused").Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "model name is required") {
|
||||
t.Errorf("got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNovitaRerankRejectsOutOfRangeIndex(t *testing.T) {
|
||||
srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
_, _ = io.WriteString(w, `{"results":[{"index":5,"relevance_score":0.5}]}`)
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "x"
|
||||
_, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "out of range") {
|
||||
t.Errorf("got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNovitaRerankRejectsDuplicateIndex(t *testing.T) {
|
||||
srv := newNovitaServer(t, "/openai/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
_, _ = io.WriteString(w, `{"results":[{"index":0,"relevance_score":0.9},{"index":0,"relevance_score":0.5}]}`)
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "x"
|
||||
_, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "duplicate") {
|
||||
t.Errorf("got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNovitaRerankSurfacesHTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, `{"error":"bad key"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "x"
|
||||
_, err := newNovitaForTest(srv.URL).Rerank(&model, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "Novita rerank API error") {
|
||||
t.Errorf("got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNovitaRerankRejectsMissingRerankSuffix(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
model := "x"
|
||||
driver := NewNovitaModel(
|
||||
map[string]string{"default": "http://unused"},
|
||||
URLSuffix{Chat: "openai/v1/chat/completions"},
|
||||
)
|
||||
_, err := driver.Rerank(&model, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "no rerank URL suffix configured") {
|
||||
t.Errorf("got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user