mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-04 01:29:35 +08:00
### What problem does this PR solve? Closes #14878. `VllmModel.Rerank()` in [internal/entity/models/vllm.go:551](internal/entity/models/vllm.go#L551) is currently a stub returning `nil, fmt.Errorf("%s, Rerank not implemented", z.Name())`, and [conf/models/vllm.json](conf/models/vllm.json) is missing a `rerank` entry in `url_suffix`. Chat (long-standing) and embeddings (#14688) already work, so rerank is the last missing leg of the retrieval pipeline for operators running everything on a single self-hosted vLLM server — today they have to point rerank at a different provider, which defeats the point of a fully local deployment. Upstream vLLM has supported a Jina/Cohere-compatible `POST /v1/rerank` endpoint since v0.7 ([vllm-project/vllm#12376](https://github.com/vllm-project/vllm/pull/12376)). The request/response shape is essentially identical to the NVIDIA driver landed in #14778, so this PR mirrors that structure with two vLLM-specific adjustments. This PR replaces the stub with a real implementation against vLLM's `/v1/rerank`: - `POST {baseURL}/rerank` - Request body: `{"model": "<modelName>", "query": "<query>", "documents": [...], "top_n": <int>}` — documents are a flat `[]string`, **not** wrapped as `{text: "..."}` like NVIDIA's `/ranking`. - Response body: `{"results": [{"index": int, "relevance_score": float}, ...]}` (Jina-compatible; the optional `document` field is ignored since callers reconstruct text via `Index`). - `Authorization: Bearer <ApiKey>` is set **only when `APIConfig.ApiKey` is non-empty**, matching the existing `Embed`/`ListModels` behaviour in this file. vLLM is a local driver and can be deployed without an API key. The return shape matches the existing `*RerankResponse` contract used by the NVIDIA ([nvidia.go:461](internal/entity/models/nvidia.go#L461)), Aliyun ([aliyun.go:507](internal/entity/models/aliyun.go#L507)), and ZhipuAI ([zhipu-ai.go:554](internal/entity/models/zhipu-ai.go#L554)) drivers, i.e. 
`Data []RerankResult` carrying `{Index, RelevanceScore}` in the API's ranking order. Callers that need original-input order sort by `Index`. Behaviour requirements from the issue, all covered: 1. Empty `documents` → returns `&RerankResponse{}` without an HTTP call. 2. Missing `modelName` → `"model name is required"` validation error. 3. `rerankConfig.TopN` honored when `0 < TopN < len(documents)`; otherwise `top_n` defaults to `len(documents)` so callers get a score per input. 4. Non-200 responses return an error including upstream status and body (`"vLLM rerank API error: <status>, body: <body>"`). 5. Response `index` values are bounds-checked against `len(documents)`. **Scope:** - [internal/entity/models/vllm.go](internal/entity/models/vllm.go) — replaces the `Rerank` stub at line 551 with a real implementation; adds `vllmRerankRequest`/`vllmRerankResponse` types for the slim subset of the payload we need. Region/baseURL resolution, 30s context timeout, conditional bearer header, and error wrapping all follow the existing patterns in this file. - [conf/models/vllm.json](conf/models/vllm.json) — adds `"rerank": "rerank"` to `url_suffix`, joined to the operator-configured vLLM base URL the same way the NVIDIA driver joins at [nvidia.go:485](internal/entity/models/nvidia.go#L485). - [internal/entity/models/vllm_rerank_test.go](internal/entity/models/vllm_rerank_test.go) — adds 7 `httptest`-backed tests mirroring `nvidia_rerank_test.go`: happy path (out-of-order ranking → Index preservation), `top_n` clamp to `RerankConfig.TopN`, empty-documents short-circuit, missing-model-name validation, HTTP error propagation, out-of-range index rejection, and a vLLM-specific `TestVllmRerankWithoutAPIKey` locking in the optional-auth behaviour that distinguishes this driver from NVIDIA. **Out of scope:** no interface change, no DDL, no frontend change. Chat, embeddings, and balance paths are untouched. 
No new user-facing docs required beyond the existing rerank model setup page — vLLM joins the list of providers whose rerank model can be selected once `/v1/rerank` is exposed by the server. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
210 lines
6.7 KiB
Go
210 lines
6.7 KiB
Go
package models
|
|
|
|
import (
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// newVllmRerankServer starts an httptest server that stands in for a vLLM
// rerank endpoint.
//
// Each incoming request is checked, in order, for: method POST, path
// "/rerank", an Authorization header exactly equal to expectAuth (pass "" to
// require that no credential is sent), and a Content-Type header exactly equal
// to "application/json". The body is then read and decoded as a JSON object.
// If any of these steps fails, the problem is reported through t.Errorf and
// the handler returns without writing a response. Otherwise the decoded body
// is passed to handler, which is responsible for writing the reply.
//
// The caller owns the returned server and must Close it.
func newVllmRerankServer(t *testing.T, expectAuth string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server {
	t.Helper()

	serve := func(w http.ResponseWriter, r *http.Request) {
		// Request-line checks first: wrong verb or wrong route.
		switch {
		case r.Method != http.MethodPost:
			t.Errorf("expected POST, got %s", r.Method)
			return
		case r.URL.Path != "/rerank":
			t.Errorf("expected path=/rerank, got %s", r.URL.Path)
			return
		}

		// Header checks: both are exact string comparisons.
		auth := r.Header.Get("Authorization")
		if auth != expectAuth {
			t.Errorf("expected Authorization=%q, got %q", expectAuth, auth)
			return
		}
		contentType := r.Header.Get("Content-Type")
		if contentType != "application/json" {
			t.Errorf("expected Content-Type=application/json, got %q", contentType)
			return
		}

		payload, readErr := io.ReadAll(r.Body)
		if readErr != nil {
			t.Errorf("failed to read body: %v", readErr)
			return
		}

		var decoded map[string]interface{}
		if decodeErr := json.Unmarshal(payload, &decoded); decodeErr != nil {
			t.Errorf("invalid JSON body: %v\n%s", decodeErr, string(payload))
			return
		}

		handler(t, decoded, w)
	}

	return httptest.NewServer(http.HandlerFunc(serve))
}
|
|
|
|
func newVllmModelForTest(baseURL string) *VllmModel {
|
|
return NewVllmModel(
|
|
map[string]string{"default": baseURL},
|
|
URLSuffix{Rerank: "rerank"},
|
|
)
|
|
}
|
|
|
|
func TestVllmRerankHappyPath(t *testing.T) {
|
|
srv := newVllmRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
|
if body["model"] != "BAAI/bge-reranker-v2-m3" {
|
|
t.Errorf("expected model=BAAI/bge-reranker-v2-m3, got %v", body["model"])
|
|
}
|
|
if body["query"] != "What is RAPTOR?" {
|
|
t.Errorf("expected query=What is RAPTOR?, got %v", body["query"])
|
|
}
|
|
// vLLM differs from NVIDIA: documents is a flat []string, not [{text}].
|
|
docs, ok := body["documents"].([]interface{})
|
|
if !ok || len(docs) != 3 {
|
|
t.Errorf("expected 3 documents, got %v", body["documents"])
|
|
return
|
|
}
|
|
for i, want := range []string{"doc-zero", "doc-one", "doc-two"} {
|
|
if docs[i] != want {
|
|
t.Errorf("documents[%d]=%v, want %s", i, docs[i], want)
|
|
}
|
|
}
|
|
if body["top_n"] != float64(3) {
|
|
t.Errorf("expected top_n=3 (matching len(documents)), got %v", body["top_n"])
|
|
}
|
|
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"results": []map[string]interface{}{
|
|
{"index": 2, "relevance_score": 0.95},
|
|
{"index": 0, "relevance_score": 0.42},
|
|
{"index": 1, "relevance_score": 0.78},
|
|
},
|
|
})
|
|
})
|
|
defer srv.Close()
|
|
|
|
model := newVllmModelForTest(srv.URL)
|
|
apiKey := "test-key"
|
|
modelName := "BAAI/bge-reranker-v2-m3"
|
|
resp, err := model.Rerank(
|
|
&modelName,
|
|
"What is RAPTOR?",
|
|
[]string{"doc-zero", "doc-one", "doc-two"},
|
|
&APIConfig{ApiKey: &apiKey},
|
|
&RerankConfig{},
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Rerank failed: %v", err)
|
|
}
|
|
if len(resp.Data) != 3 {
|
|
t.Fatalf("expected 3 results, got %d", len(resp.Data))
|
|
}
|
|
want := map[int]float64{0: 0.42, 1: 0.78, 2: 0.95}
|
|
for _, r := range resp.Data {
|
|
if got, ok := want[r.Index]; !ok || got != r.RelevanceScore {
|
|
t.Errorf("unexpected result Index=%d RelevanceScore=%v", r.Index, r.RelevanceScore)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestVllmRerankTopNClamp(t *testing.T) {
|
|
srv := newVllmRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
|
if body["top_n"] != float64(2) {
|
|
t.Errorf("expected top_n clamp to RerankConfig.TopN=2, got %v", body["top_n"])
|
|
}
|
|
_ = json.NewEncoder(w).Encode(map[string]interface{}{"results": []map[string]interface{}{}})
|
|
})
|
|
defer srv.Close()
|
|
|
|
model := newVllmModelForTest(srv.URL)
|
|
apiKey := "test-key"
|
|
modelName := "BAAI/bge-reranker-v2-m3"
|
|
if _, err := model.Rerank(
|
|
&modelName, "q",
|
|
[]string{"a", "b", "c", "d"},
|
|
&APIConfig{ApiKey: &apiKey},
|
|
&RerankConfig{TopN: 2},
|
|
); err != nil {
|
|
t.Fatalf("Rerank failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestVllmRerankEmptyDocuments(t *testing.T) {
|
|
model := newVllmModelForTest("http://unused")
|
|
apiKey := "test-key"
|
|
modelName := "BAAI/bge-reranker-v2-m3"
|
|
resp, err := model.Rerank(&modelName, "q", nil, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
|
if err != nil {
|
|
t.Fatalf("expected nil error for empty documents, got %v", err)
|
|
}
|
|
if len(resp.Data) != 0 {
|
|
t.Errorf("expected empty Data, got %d entries", len(resp.Data))
|
|
}
|
|
}
|
|
|
|
// vLLM is a local driver; the Authorization header must be omitted when
|
|
// no APIConfig.ApiKey is configured. This diverges from the NVIDIA driver
|
|
// which requires an API key.
|
|
func TestVllmRerankWithoutAPIKey(t *testing.T) {
|
|
srv := newVllmRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
|
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"results": []map[string]interface{}{
|
|
{"index": 0, "relevance_score": 0.5},
|
|
},
|
|
})
|
|
})
|
|
defer srv.Close()
|
|
|
|
model := newVllmModelForTest(srv.URL)
|
|
modelName := "BAAI/bge-reranker-v2-m3"
|
|
resp, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{}, &RerankConfig{})
|
|
if err != nil {
|
|
t.Fatalf("Rerank failed without api key: %v", err)
|
|
}
|
|
if len(resp.Data) != 1 || resp.Data[0].Index != 0 {
|
|
t.Errorf("unexpected response: %+v", resp)
|
|
}
|
|
}
|
|
|
|
func TestVllmRerankRequiresModelName(t *testing.T) {
|
|
model := newVllmModelForTest("http://unused")
|
|
apiKey := "test-key"
|
|
_, err := model.Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
|
if err == nil || !strings.Contains(err.Error(), "model name is required") {
|
|
t.Errorf("expected model-name error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestVllmRerankRejectsHTTPError(t *testing.T) {
|
|
srv := newVllmRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
_, _ = w.Write([]byte(`{"error":"boom"}`))
|
|
})
|
|
defer srv.Close()
|
|
|
|
model := newVllmModelForTest(srv.URL)
|
|
apiKey := "test-key"
|
|
modelName := "BAAI/bge-reranker-v2-m3"
|
|
_, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
|
if err == nil || !strings.Contains(err.Error(), "vLLM rerank API error") {
|
|
t.Errorf("expected API error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestVllmRerankRejectsOutOfRangeIndex(t *testing.T) {
|
|
srv := newVllmRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
|
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"results": []map[string]interface{}{
|
|
{"index": 5, "relevance_score": 0.9},
|
|
},
|
|
})
|
|
})
|
|
defer srv.Close()
|
|
|
|
model := newVllmModelForTest(srv.URL)
|
|
apiKey := "test-key"
|
|
modelName := "BAAI/bge-reranker-v2-m3"
|
|
_, err := model.Rerank(&modelName, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
|
|
if err == nil || !strings.Contains(err.Error(), "unexpected rerank index") {
|
|
t.Errorf("expected out-of-range error, got %v", err)
|
|
}
|
|
}
|