Files
ragflow/internal/entity/models/vllm_rerank_test.go
Hunnyboy1217 86bcf9767d Go: implement Rerank in vLLM driver (#14878) (#14880)
### What problem does this PR solve?

Closes #14878.

`VllmModel.Rerank()` in
[internal/entity/models/vllm.go:551](internal/entity/models/vllm.go#L551)
is currently a stub returning `nil, fmt.Errorf("%s, Rerank not
implemented", z.Name())`, and
[conf/models/vllm.json](conf/models/vllm.json) is missing a `rerank`
entry in `url_suffix`. Chat (long-standing) and embeddings (#14688)
already work, so rerank is the last missing leg of the retrieval
pipeline for operators running everything on a single self-hosted vLLM
server — today they have to point rerank at a different provider, which
defeats the point of a fully local deployment.

Upstream vLLM has supported a Jina/Cohere-compatible `POST /v1/rerank`
endpoint since v0.7
([vllm-project/vllm#12376](https://github.com/vllm-project/vllm/pull/12376)).
The request/response shape is essentially identical to the NVIDIA driver
landed in #14778, so this PR mirrors that structure with two
vLLM-specific adjustments.

This PR replaces the stub with a real implementation against vLLM's
`/v1/rerank`:

- `POST {baseURL}/rerank`
- Request body: `{"model": "<modelName>", "query": "<query>",
"documents": [...], "top_n": <int>}` — documents are a flat `[]string`,
**not** wrapped as `{text: "..."}` like NVIDIA's `/ranking`.
- Response body: `{"results": [{"index": int, "relevance_score": float},
...]}` (Jina-compatible; the optional `document` field is ignored since
callers reconstruct text via `Index`).
- `Authorization: Bearer <ApiKey>` is set **only when `APIConfig.ApiKey`
is non-empty**, matching the existing `Embed`/`ListModels` behaviour in
this file. vLLM is a local driver and can be deployed without an API
key.

The return shape matches the existing `*RerankResponse` contract used by
the NVIDIA ([nvidia.go:461](internal/entity/models/nvidia.go#L461)),
Aliyun ([aliyun.go:507](internal/entity/models/aliyun.go#L507)), and
ZhipuAI ([zhipu-ai.go:554](internal/entity/models/zhipu-ai.go#L554))
drivers, i.e. `Data []RerankResult` carrying `{Index, RelevanceScore}`
in the API's ranking order. Callers that need original-input order sort
by `Index`.

Behaviour requirements from the issue, all covered:

1. Empty `documents` → returns `&RerankResponse{}` without an HTTP call.
2. Missing `modelName` → `"model name is required"` validation error.
3. `rerankConfig.TopN` honored when `0 < TopN < len(documents)`;
otherwise `top_n` defaults to `len(documents)` so callers get a score
per input.
4. Non-200 responses return an error including upstream status and body
(`"vLLM rerank API error: <status>, body: <body>"`).
5. Response `index` values are bounds-checked against `len(documents)`.

**Scope:**

- [internal/entity/models/vllm.go](internal/entity/models/vllm.go) —
replaces the `Rerank` stub at line 551 with a real implementation; adds
`vllmRerankRequest`/`vllmRerankResponse` types for the slim subset of
the payload we need. Region/baseURL resolution, 30s context timeout,
conditional bearer header, and error wrapping all follow the existing
patterns in this file.
- [conf/models/vllm.json](conf/models/vllm.json) — adds `"rerank":
"rerank"` to `url_suffix`, joined to the operator-configured vLLM base
URL the same way the NVIDIA driver joins at
[nvidia.go:485](internal/entity/models/nvidia.go#L485).
- [internal/entity/models/vllm_rerank_test.go](internal/entity/models/vllm_rerank_test.go)
— adds 7 `httptest`-backed tests mirroring `nvidia_rerank_test.go`:
happy path (out-of-order ranking → Index preservation), `top_n` clamp to
`RerankConfig.TopN`, empty-documents short-circuit, missing-model-name
validation, HTTP error propagation, out-of-range index rejection, and a
vLLM-specific `TestVllmRerankWithoutAPIKey` locking in the optional-auth
behaviour that distinguishes this driver from NVIDIA.

**Out of scope:** no interface change, no DDL, no frontend change. Chat,
embeddings, and balance paths are untouched. No new user-facing docs
required beyond the existing rerank model setup page — vLLM joins the
list of providers whose rerank model can be selected once `/v1/rerank`
is exposed by the server.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-05-15 13:27:22 +08:00

210 lines
6.7 KiB
Go

package models
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// newVllmRerankServer starts an httptest server that plays the role of a
// rerank endpoint for the tests in this file.
//
// Each incoming request is validated in this order: the method must be POST,
// the path must be /rerank, the Authorization header must equal expectAuth
// (an empty expectAuth therefore matches a request without that header), and
// the Content-Type header must be exactly application/json. The body is then
// read in full and unmarshalled into a map. The first check that fails is
// reported via t.Errorf and the request is dropped without calling handler;
// otherwise handler receives the decoded body and the ResponseWriter and is
// responsible for writing the reply.
//
// The caller must Close the returned server.
func newVllmRerankServer(t *testing.T, expectAuth string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server {
	t.Helper()

	serve := func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		contentType := r.Header.Get("Content-Type")

		// Cases are evaluated top to bottom, so only the first violated
		// expectation is reported for a given request.
		switch {
		case r.Method != http.MethodPost:
			t.Errorf("expected POST, got %s", r.Method)
			return
		case r.URL.Path != "/rerank":
			t.Errorf("expected path=/rerank, got %s", r.URL.Path)
			return
		case authHeader != expectAuth:
			t.Errorf("expected Authorization=%q, got %q", expectAuth, authHeader)
			return
		case contentType != "application/json":
			t.Errorf("expected Content-Type=application/json, got %q", contentType)
			return
		}

		payload, readErr := io.ReadAll(r.Body)
		if readErr != nil {
			t.Errorf("failed to read body: %v", readErr)
			return
		}

		var decoded map[string]interface{}
		if decodeErr := json.Unmarshal(payload, &decoded); decodeErr != nil {
			t.Errorf("invalid JSON body: %v\n%s", decodeErr, string(payload))
			return
		}

		handler(t, decoded, w)
	}

	return httptest.NewServer(http.HandlerFunc(serve))
}
// newVllmModelForTest constructs a VllmModel via NewVllmModel, registering
// baseURL under the "default" key of the base-URL map and using "rerank" as
// the rerank URL suffix.
func newVllmModelForTest(baseURL string) *VllmModel {
	baseURLs := map[string]string{"default": baseURL}
	suffixes := URLSuffix{Rerank: "rerank"}
	return NewVllmModel(baseURLs, suffixes)
}
// TestVllmRerankHappyPath exercises a successful rerank round trip.
//
// The stub server asserts the outgoing request: model and query are passed
// through, documents is a flat list of strings in input order, and top_n
// equals the number of documents when RerankConfig.TopN is unset. It then
// replies with three results whose order differs from the input order.
//
// On the client side the test requires exactly three results in which every
// input index appears exactly once with the score the server assigned to it.
func TestVllmRerankHappyPath(t *testing.T) {
	srv := newVllmRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		if body["model"] != "BAAI/bge-reranker-v2-m3" {
			t.Errorf("expected model=BAAI/bge-reranker-v2-m3, got %v", body["model"])
		}
		if body["query"] != "What is RAPTOR?" {
			t.Errorf("expected query=What is RAPTOR?, got %v", body["query"])
		}
		// vLLM differs from NVIDIA: documents is a flat []string, not [{text}].
		docs, ok := body["documents"].([]interface{})
		if !ok || len(docs) != 3 {
			t.Errorf("expected 3 documents, got %v", body["documents"])
			return
		}
		for i, want := range []string{"doc-zero", "doc-one", "doc-two"} {
			if docs[i] != want {
				t.Errorf("documents[%d]=%v, want %s", i, docs[i], want)
			}
		}
		// JSON numbers decode to float64 when the target is interface{}.
		if body["top_n"] != float64(3) {
			t.Errorf("expected top_n=3 (matching len(documents)), got %v", body["top_n"])
		}
		err := json.NewEncoder(w).Encode(map[string]interface{}{
			"results": []map[string]interface{}{
				{"index": 2, "relevance_score": 0.95},
				{"index": 0, "relevance_score": 0.42},
				{"index": 1, "relevance_score": 0.78},
			},
		})
		if err != nil {
			t.Errorf("failed to write stub response: %v", err)
		}
	})
	defer srv.Close()

	model := newVllmModelForTest(srv.URL)
	apiKey := "test-key"
	modelName := "BAAI/bge-reranker-v2-m3"
	resp, err := model.Rerank(
		&modelName,
		"What is RAPTOR?",
		[]string{"doc-zero", "doc-one", "doc-two"},
		&APIConfig{ApiKey: &apiKey},
		&RerankConfig{},
	)
	if err != nil {
		t.Fatalf("Rerank failed: %v", err)
	}
	if resp == nil {
		t.Fatal("Rerank returned a nil response without an error")
	}
	if len(resp.Data) != 3 {
		t.Fatalf("expected 3 results, got %d", len(resp.Data))
	}

	// Matched entries are removed from the map so that a response repeating
	// the same index is flagged instead of silently passing.
	remaining := map[int]float64{0: 0.42, 1: 0.78, 2: 0.95}
	for _, r := range resp.Data {
		score, ok := remaining[r.Index]
		if !ok || score != r.RelevanceScore {
			t.Errorf("unexpected result Index=%d RelevanceScore=%v", r.Index, r.RelevanceScore)
			continue
		}
		delete(remaining, r.Index)
	}
	for idx, score := range remaining {
		t.Errorf("missing result for Index=%d (expected RelevanceScore=%v)", idx, score)
	}
}
// TestVllmRerankTopNClamp checks that a RerankConfig.TopN smaller than the
// number of documents is what gets sent as top_n: four documents with TopN=2
// must produce top_n=2 in the request body. The stub replies with an empty
// results list, and the call is expected to succeed.
func TestVllmRerankTopNClamp(t *testing.T) {
	assertTopN := func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		// JSON numbers decode to float64 when the target is interface{}.
		if topN := body["top_n"]; topN != float64(2) {
			t.Errorf("expected top_n clamp to RerankConfig.TopN=2, got %v", topN)
		}
		emptyResults := map[string]interface{}{"results": []map[string]interface{}{}}
		_ = json.NewEncoder(w).Encode(emptyResults)
	}
	srv := newVllmRerankServer(t, "Bearer test-key", assertTopN)
	defer srv.Close()

	var (
		apiKey    = "test-key"
		modelName = "BAAI/bge-reranker-v2-m3"
		documents = []string{"a", "b", "c", "d"}
	)
	model := newVllmModelForTest(srv.URL)

	_, err := model.Rerank(&modelName, "q", documents, &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 2})
	if err != nil {
		t.Fatalf("Rerank failed: %v", err)
	}
}
// TestVllmRerankEmptyDocuments checks that reranking zero documents succeeds
// with an empty result set. Both a nil slice and an empty non-nil slice are
// covered, since callers may pass either.
//
// The model points at a host that does not exist, so any attempt to contact
// the server would surface as an error and fail the test.
func TestVllmRerankEmptyDocuments(t *testing.T) {
	cases := []struct {
		name string
		docs []string
	}{
		{name: "nil slice", docs: nil},
		{name: "empty slice", docs: []string{}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			model := newVllmModelForTest("http://unused")
			apiKey := "test-key"
			modelName := "BAAI/bge-reranker-v2-m3"

			resp, err := model.Rerank(&modelName, "q", tc.docs, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
			if err != nil {
				t.Fatalf("expected nil error for empty documents, got %v", err)
			}
			// Guard before touching resp.Data so a nil response is a clean
			// test failure rather than a nil-pointer panic.
			if resp == nil {
				t.Fatal("expected non-nil response for empty documents")
			}
			if len(resp.Data) != 0 {
				t.Errorf("expected empty Data, got %d entries", len(resp.Data))
			}
		})
	}
}
// TestVllmRerankWithoutAPIKey pins the optional-auth behaviour: with an
// APIConfig that carries no ApiKey, the request must reach the server with an
// empty Authorization header value and the call must still succeed. vLLM is a
// local driver, which is why this differs from the NVIDIA driver, where an API
// key is required.
func TestVllmRerankWithoutAPIKey(t *testing.T) {
	replyWithOneResult := func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		single := map[string]interface{}{
			"results": []map[string]interface{}{
				{"index": 0, "relevance_score": 0.5},
			},
		}
		_ = json.NewEncoder(w).Encode(single)
	}
	// expectAuth is "" so the helper rejects any request carrying a token.
	srv := newVllmRerankServer(t, "", replyWithOneResult)
	defer srv.Close()

	modelName := "BAAI/bge-reranker-v2-m3"
	resp, err := newVllmModelForTest(srv.URL).Rerank(&modelName, "q", []string{"a"}, &APIConfig{}, &RerankConfig{})
	if err != nil {
		t.Fatalf("Rerank failed without api key: %v", err)
	}

	gotSingleFirstDoc := len(resp.Data) == 1 && resp.Data[0].Index == 0
	if !gotSingleFirstDoc {
		t.Errorf("unexpected response: %+v", resp)
	}
}
func TestVllmRerankRequiresModelName(t *testing.T) {
model := newVllmModelForTest("http://unused")
apiKey := "test-key"
_, err := model.Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})
if err == nil || !strings.Contains(err.Error(), "model name is required") {
t.Errorf("expected model-name error, got %v", err)
}
}
// TestVllmRerankRejectsHTTPError checks that a non-200 reply from the server
// is surfaced as an error: the stub answers 500 with a JSON error body, and
// Rerank must return an error containing "vLLM rerank API error".
func TestVllmRerankRejectsHTTPError(t *testing.T) {
	failWith500 := func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte(`{"error":"boom"}`))
	}
	srv := newVllmRerankServer(t, "Bearer test-key", failWith500)
	defer srv.Close()

	var (
		apiKey    = "test-key"
		modelName = "BAAI/bge-reranker-v2-m3"
	)
	_, err := newVllmModelForTest(srv.URL).Rerank(&modelName, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})

	if err != nil && strings.Contains(err.Error(), "vLLM rerank API error") {
		return
	}
	t.Errorf("expected API error, got %v", err)
}
// TestVllmRerankRejectsOutOfRangeIndex checks that a result whose index does
// not refer to any input document is rejected: two documents are sent, the
// stub answers with index 5, and Rerank must return an error containing
// "unexpected rerank index".
func TestVllmRerankRejectsOutOfRangeIndex(t *testing.T) {
	replyWithBadIndex := func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
		bogus := map[string]interface{}{
			"results": []map[string]interface{}{
				{"index": 5, "relevance_score": 0.9},
			},
		}
		_ = json.NewEncoder(w).Encode(bogus)
	}
	srv := newVllmRerankServer(t, "Bearer test-key", replyWithBadIndex)
	defer srv.Close()

	var (
		apiKey    = "test-key"
		modelName = "BAAI/bge-reranker-v2-m3"
	)
	_, err := newVllmModelForTest(srv.URL).Rerank(&modelName, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{})

	if err != nil && strings.Contains(err.Error(), "unexpected rerank index") {
		return
	}
	t.Errorf("expected out-of-range error, got %v", err)
}