mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Go: implement Rerank in Xinference driver (#15032)
### What problem does this PR solve?

Fixes #14816.

The Xinference Go driver landed chat in #14938 and Embed is in review in #14932, but `Rerank` shipped as a stub that returns `"xinference, no such method"`. Tenants who launch a rerank model with `--model-type rerank` on their Xinference instance cannot route it through the Go API server. This PR fills the gap.

Xinference exposes an OpenAI-compatible REST API. The rerank endpoint is at `POST <base>/v1/rerank` and accepts the Cohere-shaped body `{model, query, documents, top_n}`, returning `{results: [{index, relevance_score}]}` — the same wire shape used by the merged NVIDIA (#14778), Aliyun (#14676), Gitee (#14656), ZhipuAI (#14608), Novita (#15014), and LocalAI (#14813) Rerank implementations. The endpoint is documented in the [Xinference rerank docs](https://inference.readthedocs.io/en/v1.6.1/models/model_abilities/rerank.html); the [builtin rerank model catalog](https://inference.readthedocs.io/en/stable/models/builtin/rerank/) lists `bge-reranker-base`, `bge-reranker-large`, `bge-reranker-v2-m3`, and others.
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
"name": "xinference",
|
||||
"url_suffix": {
|
||||
"chat": "v1/chat/completions",
|
||||
"models": "v1/models"
|
||||
"models": "v1/models",
|
||||
"rerank": "v1/rerank"
|
||||
},
|
||||
"class": "local"
|
||||
}
|
||||
|
||||
@@ -384,8 +384,102 @@ func (x *XinferenceModel) Embed(modelName *string, texts []string, apiConfig *AP
|
||||
return nil, fmt.Errorf("%s, no such method", x.Name())
|
||||
}
|
||||
|
||||
// xinferenceRerankResult is a single scored entry decoded from the "results"
// array of a rerank response.
type xinferenceRerankResult struct {
	// Index is decoded from the "index" JSON key.
	Index int `json:"index"`
	// RelevanceScore is decoded from the "relevance_score" JSON key.
	RelevanceScore float64 `json:"relevance_score"`
}
|
||||
|
||||
type xinferenceRerankResponse struct {
|
||||
Results []xinferenceRerankResult `json:"results"`
|
||||
}
|
||||
|
||||
// Rerank scores documents against the query using the Xinference
|
||||
// /v1/rerank endpoint and returns one RerankResult per scored document
|
||||
// in the API's ranking order. Caller may sort by Index to recover
|
||||
// original input order. Xinference rerank models are launched with
|
||||
// --model-type rerank and exposed under the OpenAI-compatible base URL.
|
||||
func (x *XinferenceModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", x.Name())
|
||||
if len(documents) == 0 {
|
||||
return &RerankResponse{}, nil
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
baseURL, err := x.baseURLForRegion(xinferenceRegion(apiConfig))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if x.URLSuffix.Rerank == "" {
|
||||
return nil, fmt.Errorf("xinference: no rerank URL suffix configured")
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", baseURL, x.URLSuffix.Rerank)
|
||||
|
||||
topN := len(documents)
|
||||
if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN {
|
||||
topN = rerankConfig.TopN
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": topN,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
setXinferenceAuth(req, apiConfig)
|
||||
|
||||
resp, err := x.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Xinference rerank API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var parsed xinferenceRerankResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Results))}
|
||||
seen := make([]bool, len(documents))
|
||||
for _, item := range parsed.Results {
|
||||
if item.Index < 0 || item.Index >= len(documents) {
|
||||
return nil, fmt.Errorf("xinference: rerank index %d out of range for %d inputs", item.Index, len(documents))
|
||||
}
|
||||
if seen[item.Index] {
|
||||
return nil, fmt.Errorf("xinference: duplicate rerank index %d in response", item.Index)
|
||||
}
|
||||
rerankResponse.Data = append(rerankResponse.Data, RerankResult{
|
||||
Index: item.Index,
|
||||
RelevanceScore: item.RelevanceScore,
|
||||
})
|
||||
seen[item.Index] = true
|
||||
}
|
||||
|
||||
return &rerankResponse, nil
|
||||
}
|
||||
|
||||
func (x *XinferenceModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
||||
|
||||
@@ -16,6 +16,7 @@ func newXinferenceForTest(baseURL string) *XinferenceModel {
|
||||
URLSuffix{
|
||||
Chat: "v1/chat/completions",
|
||||
Models: "v1/models",
|
||||
Rerank: "v1/rerank",
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -289,9 +290,6 @@ func TestXinferenceUnsupportedMethodsReturnNoSuchMethod(t *testing.T) {
|
||||
if _, err := x.Embed(&model, []string{"x"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
t.Errorf("Embed: expected no such method, got %v", err)
|
||||
}
|
||||
if _, err := x.Rerank(&model, "q", []string{"d"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
t.Errorf("Rerank: expected no such method, got %v", err)
|
||||
}
|
||||
if _, err := x.Balance(&APIConfig{}); err == nil || !strings.Contains(err.Error(), "no such method") {
|
||||
t.Errorf("Balance: expected no such method, got %v", err)
|
||||
}
|
||||
@@ -311,3 +309,190 @@ func TestXinferenceUnsupportedMethodsReturnNoSuchMethod(t *testing.T) {
|
||||
t.Errorf("OCRFile: expected no such method, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// newXinferenceRerankServer starts an httptest server for rerank tests.
// Each request is checked for the /v1/rerank path, the POST method, an
// Authorization header equal to expectedAuth, and a Content-Type beginning
// with application/json; mismatches are reported with t.Errorf and do not
// stop processing. The body is then decoded as a JSON object and passed to
// handler, which writes the reply. If the body cannot be read or decoded the
// failure is reported and handler is not called. The caller closes the
// returned server.
func newXinferenceRerankServer(t *testing.T, expectedAuth string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server {
	t.Helper()

	serve := func(w http.ResponseWriter, r *http.Request) {
		if path := r.URL.Path; path != "/v1/rerank" {
			t.Errorf("path=%s want /v1/rerank", path)
		}
		if method := r.Method; method != http.MethodPost {
			t.Errorf("method=%s want POST", method)
		}
		if auth := r.Header.Get("Authorization"); auth != expectedAuth {
			t.Errorf("Authorization=%q want %q", auth, expectedAuth)
		}
		if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "application/json") {
			t.Errorf("Content-Type=%q", contentType)
		}

		payload, readErr := io.ReadAll(r.Body)
		if readErr != nil {
			t.Errorf("read body: %v", readErr)
			return
		}

		var decoded map[string]interface{}
		if decodeErr := json.Unmarshal(payload, &decoded); decodeErr != nil {
			t.Errorf("unmarshal: %v\nraw=%s", decodeErr, string(payload))
			return
		}

		handler(t, decoded, w)
	}

	return httptest.NewServer(http.HandlerFunc(serve))
}
|
||||
|
||||
func TestXinferenceRerankHappyPathReordersByIndex(t *testing.T) {
|
||||
srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if body["model"] != "bge-reranker-v2-m3" {
|
||||
t.Errorf("model=%v", body["model"])
|
||||
}
|
||||
if body["query"] != "capital of France" {
|
||||
t.Errorf("query=%v", body["query"])
|
||||
}
|
||||
if got := body["top_n"].(float64); got != 3 {
|
||||
t.Errorf("top_n=%v want 3", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"results": []map[string]interface{}{
|
||||
{"index": 2, "relevance_score": 0.91},
|
||||
{"index": 0, "relevance_score": 0.88},
|
||||
{"index": 1, "relevance_score": 0.42},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
x := newXinferenceForTest(srv.URL)
|
||||
model := "bge-reranker-v2-m3"
|
||||
resp, err := x.Rerank(&model, "capital of France",
|
||||
[]string{"Paris is the capital of France.", "Eiffel Tower.", "Berlin is the capital of Germany."},
|
||||
&APIConfig{}, nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Rerank: %v", err)
|
||||
}
|
||||
if len(resp.Data) != 3 {
|
||||
t.Fatalf("Data len=%d", len(resp.Data))
|
||||
}
|
||||
if resp.Data[0].Index != 2 || resp.Data[1].Index != 0 || resp.Data[2].Index != 1 {
|
||||
t.Errorf("order=%v %v %v", resp.Data[0].Index, resp.Data[1].Index, resp.Data[2].Index)
|
||||
}
|
||||
if resp.Data[0].RelevanceScore != 0.91 {
|
||||
t.Errorf("top score=%v", resp.Data[0].RelevanceScore)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXinferenceRerankNormalizesV1BaseURL(t *testing.T) {
|
||||
srv := newXinferenceRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{"results": []map[string]interface{}{}})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
x := NewXinferenceModel(
|
||||
map[string]string{"default": srv.URL + "/v1"},
|
||||
URLSuffix{Rerank: "v1/rerank"},
|
||||
)
|
||||
apiKey := "test-key"
|
||||
model := "bge-reranker-v2-m3"
|
||||
_, err := x.Rerank(&model, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Rerank: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXinferenceRerankRespectsTopNConfig(t *testing.T) {
|
||||
srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if got := body["top_n"].(float64); got != 2 {
|
||||
t.Errorf("top_n=%v want 2", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{"results": []map[string]interface{}{}})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
x := newXinferenceForTest(srv.URL)
|
||||
model := "bge-reranker-v2-m3"
|
||||
_, err := x.Rerank(&model, "q", []string{"a", "b", "c", "d"}, &APIConfig{}, &RerankConfig{TopN: 2})
|
||||
if err != nil {
|
||||
t.Fatalf("Rerank: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXinferenceRerankEmptyDocumentsShortCircuits(t *testing.T) {
|
||||
x := newXinferenceForTest("http://unused")
|
||||
model := "bge-reranker-v2-m3"
|
||||
resp, err := x.Rerank(&model, "q", nil, &APIConfig{}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Rerank: %v", err)
|
||||
}
|
||||
if len(resp.Data) != 0 {
|
||||
t.Errorf("Data len=%d want 0", len(resp.Data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestXinferenceRerankRequiresModelName(t *testing.T) {
|
||||
x := newXinferenceForTest("http://unused")
|
||||
_, err := x.Rerank(nil, "q", []string{"a"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "model name is required") {
|
||||
t.Errorf("err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXinferenceRerankRejectsOutOfRangeIndex(t *testing.T) {
|
||||
srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"results": []map[string]interface{}{{"index": 5, "relevance_score": 0.1}},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
x := newXinferenceForTest(srv.URL)
|
||||
model := "bge-reranker-v2-m3"
|
||||
_, err := x.Rerank(&model, "q", []string{"a", "b"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "out of range") {
|
||||
t.Errorf("err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXinferenceRerankRejectsDuplicateIndex(t *testing.T) {
|
||||
srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"results": []map[string]interface{}{
|
||||
{"index": 0, "relevance_score": 0.9},
|
||||
{"index": 0, "relevance_score": 0.8},
|
||||
},
|
||||
})
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
x := newXinferenceForTest(srv.URL)
|
||||
model := "bge-reranker-v2-m3"
|
||||
_, err := x.Rerank(&model, "q", []string{"a", "b"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "duplicate") {
|
||||
t.Errorf("err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXinferenceRerankSurfacesHTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"model not loaded"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
x := newXinferenceForTest(srv.URL)
|
||||
model := "bge-reranker-v2-m3"
|
||||
_, err := x.Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "Xinference rerank API error") {
|
||||
t.Errorf("err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXinferenceRerankRejectsMissingRerankSuffix(t *testing.T) {
|
||||
x := NewXinferenceModel(
|
||||
map[string]string{"default": "http://unused"},
|
||||
URLSuffix{Chat: "v1/chat/completions"},
|
||||
)
|
||||
model := "bge-reranker-v2-m3"
|
||||
_, err := x.Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "no rerank URL suffix configured") {
|
||||
t.Errorf("err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user