Go: implement Rerank in Xinference driver (#15032)

### What problem does this PR solve?

Fixes #14816

The Xinference Go driver landed chat in #14938 and Embed is in review in
#14932, but `Rerank` shipped as a stub that returns `"xinference, no
such method"`. Tenants who launch a rerank model with `--model-type
rerank` on their Xinference instance cannot route it through the Go API
server. This PR fills the gap.

Xinference exposes an OpenAI-compatible REST API. The rerank endpoint is
at `POST <base>/v1/rerank` and accepts the Cohere-shaped body `{model,
query, documents, top_n}`, returning `{results: [{index,
relevance_score}]}` — the same wire shape used by the merged NVIDIA
(#14778), Aliyun (#14676), Gitee (#14656), ZhipuAI (#14608), Novita
(#15014), and LocalAI (#14813) Rerank implementations. Documented in
[Xinference rerank
docs](https://inference.readthedocs.io/en/v1.6.1/models/model_abilities/rerank.html);
the [builtin rerank model
catalog](https://inference.readthedocs.io/en/stable/models/builtin/rerank/)
lists `bge-reranker-base`, `bge-reranker-large`, `bge-reranker-v2-m3`,
and others.
This commit is contained in:
Renzo
2026-05-20 16:14:30 -10:00
committed by GitHub
parent 63db30f0d9
commit 536ed07d27
3 changed files with 285 additions and 5 deletions

View File

@@ -2,7 +2,8 @@
"name": "xinference",
"url_suffix": {
"chat": "v1/chat/completions",
"models": "v1/models"
"models": "v1/models",
"rerank": "v1/rerank"
},
"class": "local"
}

View File

@@ -384,8 +384,102 @@ func (x *XinferenceModel) Embed(modelName *string, texts []string, apiConfig *AP
return nil, fmt.Errorf("%s, no such method", x.Name())
}
type xinferenceRerankResult struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
type xinferenceRerankResponse struct {
Results []xinferenceRerankResult `json:"results"`
}
// Rerank scores documents against the query using the Xinference
// /v1/rerank endpoint and returns one RerankResult per scored document
// in the API's ranking order. Caller may sort by Index to recover
// original input order. Xinference rerank models are launched with
// --model-type rerank and exposed under the OpenAI-compatible base URL.
func (x *XinferenceModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
return nil, fmt.Errorf("%s, no such method", x.Name())
if len(documents) == 0 {
return &RerankResponse{}, nil
}
if modelName == nil || *modelName == "" {
return nil, fmt.Errorf("model name is required")
}
baseURL, err := x.baseURLForRegion(xinferenceRegion(apiConfig))
if err != nil {
return nil, err
}
if x.URLSuffix.Rerank == "" {
return nil, fmt.Errorf("xinference: no rerank URL suffix configured")
}
url := fmt.Sprintf("%s/%s", baseURL, x.URLSuffix.Rerank)
topN := len(documents)
if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN {
topN = rerankConfig.TopN
}
reqBody := map[string]interface{}{
"model": *modelName,
"query": query,
"documents": documents,
"top_n": topN,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
setXinferenceAuth(req, apiConfig)
resp, err := x.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Xinference rerank API error: %s, body: %s", resp.Status, string(body))
}
var parsed xinferenceRerankResponse
if err = json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Results))}
seen := make([]bool, len(documents))
for _, item := range parsed.Results {
if item.Index < 0 || item.Index >= len(documents) {
return nil, fmt.Errorf("xinference: rerank index %d out of range for %d inputs", item.Index, len(documents))
}
if seen[item.Index] {
return nil, fmt.Errorf("xinference: duplicate rerank index %d in response", item.Index)
}
rerankResponse.Data = append(rerankResponse.Data, RerankResult{
Index: item.Index,
RelevanceScore: item.RelevanceScore,
})
seen[item.Index] = true
}
return &rerankResponse, nil
}
func (x *XinferenceModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {

View File

@@ -16,6 +16,7 @@ func newXinferenceForTest(baseURL string) *XinferenceModel {
URLSuffix{
Chat: "v1/chat/completions",
Models: "v1/models",
Rerank: "v1/rerank",
},
)
}
@@ -289,9 +290,6 @@ func TestXinferenceUnsupportedMethodsReturnNoSuchMethod(t *testing.T) {
if _, err := x.Embed(&model, []string{"x"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("Embed: expected no such method, got %v", err)
}
if _, err := x.Rerank(&model, "q", []string{"d"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("Rerank: expected no such method, got %v", err)
}
if _, err := x.Balance(&APIConfig{}); err == nil || !strings.Contains(err.Error(), "no such method") {
t.Errorf("Balance: expected no such method, got %v", err)
}
@@ -311,3 +309,190 @@ func TestXinferenceUnsupportedMethodsReturnNoSuchMethod(t *testing.T) {
t.Errorf("OCRFile: expected no such method, got %v", err)
}
}
func newXinferenceRerankServer(t *testing.T, expectedAuth string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/rerank" {
t.Errorf("path=%s want /v1/rerank", r.URL.Path)
}
if r.Method != http.MethodPost {
t.Errorf("method=%s want POST", r.Method)
}
if got := r.Header.Get("Authorization"); got != expectedAuth {
t.Errorf("Authorization=%q want %q", got, expectedAuth)
}
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") {
t.Errorf("Content-Type=%q", got)
}
raw, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("read body: %v", err)
return
}
var body map[string]interface{}
if err := json.Unmarshal(raw, &body); err != nil {
t.Errorf("unmarshal: %v\nraw=%s", err, string(raw))
return
}
handler(t, body, w)
}))
}
func TestXinferenceRerankHappyPathReordersByIndex(t *testing.T) {
srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
if body["model"] != "bge-reranker-v2-m3" {
t.Errorf("model=%v", body["model"])
}
if body["query"] != "capital of France" {
t.Errorf("query=%v", body["query"])
}
if got := body["top_n"].(float64); got != 3 {
t.Errorf("top_n=%v want 3", got)
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"results": []map[string]interface{}{
{"index": 2, "relevance_score": 0.91},
{"index": 0, "relevance_score": 0.88},
{"index": 1, "relevance_score": 0.42},
},
})
})
defer srv.Close()
x := newXinferenceForTest(srv.URL)
model := "bge-reranker-v2-m3"
resp, err := x.Rerank(&model, "capital of France",
[]string{"Paris is the capital of France.", "Eiffel Tower.", "Berlin is the capital of Germany."},
&APIConfig{}, nil,
)
if err != nil {
t.Fatalf("Rerank: %v", err)
}
if len(resp.Data) != 3 {
t.Fatalf("Data len=%d", len(resp.Data))
}
if resp.Data[0].Index != 2 || resp.Data[1].Index != 0 || resp.Data[2].Index != 1 {
t.Errorf("order=%v %v %v", resp.Data[0].Index, resp.Data[1].Index, resp.Data[2].Index)
}
if resp.Data[0].RelevanceScore != 0.91 {
t.Errorf("top score=%v", resp.Data[0].RelevanceScore)
}
}
func TestXinferenceRerankNormalizesV1BaseURL(t *testing.T) {
srv := newXinferenceRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
_ = json.NewEncoder(w).Encode(map[string]interface{}{"results": []map[string]interface{}{}})
})
defer srv.Close()
x := NewXinferenceModel(
map[string]string{"default": srv.URL + "/v1"},
URLSuffix{Rerank: "v1/rerank"},
)
apiKey := "test-key"
model := "bge-reranker-v2-m3"
_, err := x.Rerank(&model, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
if err != nil {
t.Fatalf("Rerank: %v", err)
}
}
func TestXinferenceRerankRespectsTopNConfig(t *testing.T) {
srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
if got := body["top_n"].(float64); got != 2 {
t.Errorf("top_n=%v want 2", got)
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{"results": []map[string]interface{}{}})
})
defer srv.Close()
x := newXinferenceForTest(srv.URL)
model := "bge-reranker-v2-m3"
_, err := x.Rerank(&model, "q", []string{"a", "b", "c", "d"}, &APIConfig{}, &RerankConfig{TopN: 2})
if err != nil {
t.Fatalf("Rerank: %v", err)
}
}
func TestXinferenceRerankEmptyDocumentsShortCircuits(t *testing.T) {
x := newXinferenceForTest("http://unused")
model := "bge-reranker-v2-m3"
resp, err := x.Rerank(&model, "q", nil, &APIConfig{}, nil)
if err != nil {
t.Fatalf("Rerank: %v", err)
}
if len(resp.Data) != 0 {
t.Errorf("Data len=%d want 0", len(resp.Data))
}
}
func TestXinferenceRerankRequiresModelName(t *testing.T) {
x := newXinferenceForTest("http://unused")
_, err := x.Rerank(nil, "q", []string{"a"}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "model name is required") {
t.Errorf("err=%v", err)
}
}
func TestXinferenceRerankRejectsOutOfRangeIndex(t *testing.T) {
srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"results": []map[string]interface{}{{"index": 5, "relevance_score": 0.1}},
})
})
defer srv.Close()
x := newXinferenceForTest(srv.URL)
model := "bge-reranker-v2-m3"
_, err := x.Rerank(&model, "q", []string{"a", "b"}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "out of range") {
t.Errorf("err=%v", err)
}
}
func TestXinferenceRerankRejectsDuplicateIndex(t *testing.T) {
srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"results": []map[string]interface{}{
{"index": 0, "relevance_score": 0.9},
{"index": 0, "relevance_score": 0.8},
},
})
})
defer srv.Close()
x := newXinferenceForTest(srv.URL)
model := "bge-reranker-v2-m3"
_, err := x.Rerank(&model, "q", []string{"a", "b"}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "duplicate") {
t.Errorf("err=%v", err)
}
}
func TestXinferenceRerankSurfacesHTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"error":"model not loaded"}`))
}))
defer srv.Close()
x := newXinferenceForTest(srv.URL)
model := "bge-reranker-v2-m3"
_, err := x.Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "Xinference rerank API error") {
t.Errorf("err=%v", err)
}
}
func TestXinferenceRerankRejectsMissingRerankSuffix(t *testing.T) {
x := NewXinferenceModel(
map[string]string{"default": "http://unused"},
URLSuffix{Chat: "v1/chat/completions"},
)
model := "bge-reranker-v2-m3"
_, err := x.Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil)
if err == nil || !strings.Contains(err.Error(), "no rerank URL suffix configured") {
t.Errorf("err=%v", err)
}
}