diff --git a/conf/models/xinference.json b/conf/models/xinference.json index cf50dbc731..5076a63a51 100644 --- a/conf/models/xinference.json +++ b/conf/models/xinference.json @@ -2,7 +2,8 @@ "name": "xinference", "url_suffix": { "chat": "v1/chat/completions", - "models": "v1/models" + "models": "v1/models", + "rerank": "v1/rerank" }, "class": "local" } diff --git a/internal/entity/models/xinference.go b/internal/entity/models/xinference.go index d8f8fa39f5..30325c8006 100644 --- a/internal/entity/models/xinference.go +++ b/internal/entity/models/xinference.go @@ -384,8 +384,102 @@ func (x *XinferenceModel) Embed(modelName *string, texts []string, apiConfig *AP return nil, fmt.Errorf("%s, no such method", x.Name()) } +type xinferenceRerankResult struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` +} + +type xinferenceRerankResponse struct { + Results []xinferenceRerankResult `json:"results"` +} + +// Rerank scores documents against the query using the Xinference +// /v1/rerank endpoint and returns one RerankResult per scored document +// in the API's ranking order. Caller may sort by Index to recover +// original input order. Xinference rerank models are launched with +// --model-type rerank and exposed under the OpenAI-compatible base URL. func (x *XinferenceModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s, no such method", x.Name()) + if len(documents) == 0 { + return &RerankResponse{}, nil + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + baseURL, err := x.baseURLForRegion(xinferenceRegion(apiConfig)) + if err != nil { + return nil, err + } + if x.URLSuffix.Rerank == "" { + return nil, fmt.Errorf("xinference: no rerank URL suffix configured") + } + url := fmt.Sprintf("%s/%s", baseURL, x.URLSuffix.Rerank) + + topN := len(documents) + if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN { + topN = rerankConfig.TopN + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "query": query, + "documents": documents, + "top_n": topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + setXinferenceAuth(req, apiConfig) + + resp, err := x.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Xinference rerank API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed xinferenceRerankResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Results))} + seen := make([]bool, len(documents)) + for _, item := range parsed.Results { + if item.Index < 0 || item.Index >= len(documents) { + return nil, fmt.Errorf("xinference: rerank index %d out of range for %d inputs", item.Index, len(documents)) + } + if seen[item.Index] { + return nil, fmt.Errorf("xinference: duplicate rerank index %d in response", item.Index) + } + rerankResponse.Data = append(rerankResponse.Data, RerankResult{ + Index: item.Index, + RelevanceScore: item.RelevanceScore, + }) + seen[item.Index] = true + } + + return &rerankResponse, nil } func (x *XinferenceModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { diff --git a/internal/entity/models/xinference_test.go b/internal/entity/models/xinference_test.go index af3179ea0e..95d577aa72 100644 --- a/internal/entity/models/xinference_test.go +++ b/internal/entity/models/xinference_test.go @@ -16,6 +16,7 @@ func newXinferenceForTest(baseURL string) *XinferenceModel { URLSuffix{ Chat: "v1/chat/completions", Models: "v1/models", + Rerank: "v1/rerank", }, ) } @@ -289,9 +290,6 @@ func TestXinferenceUnsupportedMethodsReturnNoSuchMethod(t *testing.T) { if _, err := x.Embed(&model, []string{"x"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { t.Errorf("Embed: expected no such method, got %v", err) } - if _, err := x.Rerank(&model, "q", []string{"d"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { - t.Errorf("Rerank: expected no such method, got %v", err) - } if _, err := x.Balance(&APIConfig{}); err == nil || !strings.Contains(err.Error(), "no such method") { t.Errorf("Balance: expected no such method, got %v", err) } @@ -311,3 +309,190 @@ func TestXinferenceUnsupportedMethodsReturnNoSuchMethod(t *testing.T) { t.Errorf("OCRFile: expected no such method, got %v", err) } } + +func newXinferenceRerankServer(t *testing.T, expectedAuth string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/rerank" { + t.Errorf("path=%s want /v1/rerank", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Errorf("method=%s want POST", r.Method) + } + if got := r.Header.Get("Authorization"); got != expectedAuth { + t.Errorf("Authorization=%q want %q", got, expectedAuth) + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("Content-Type=%q", got) + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal: %v\nraw=%s", err, string(raw)) + return + } + handler(t, body, w) + })) +} + +func TestXinferenceRerankHappyPathReordersByIndex(t *testing.T) { + srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "bge-reranker-v2-m3" { + t.Errorf("model=%v", body["model"]) + } + if body["query"] != "capital of France" { + t.Errorf("query=%v", body["query"]) + } + if got := body["top_n"].(float64); got != 3 { + t.Errorf("top_n=%v want 3", got) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "results": []map[string]interface{}{ + {"index": 2, "relevance_score": 0.91}, + {"index": 0, "relevance_score": 0.88}, + {"index": 1, "relevance_score": 0.42}, + }, + }) + }) + defer srv.Close() + + x := newXinferenceForTest(srv.URL) + model := "bge-reranker-v2-m3" + resp, err := x.Rerank(&model, "capital of France", + []string{"Paris is the capital of France.", "Eiffel Tower.", "Berlin is the capital of Germany."}, + &APIConfig{}, nil, + ) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if len(resp.Data) != 3 { + t.Fatalf("Data len=%d", len(resp.Data)) + } + if resp.Data[0].Index != 2 || resp.Data[1].Index != 0 || resp.Data[2].Index != 1 { + t.Errorf("order=%v %v %v", resp.Data[0].Index, resp.Data[1].Index, resp.Data[2].Index) + } + if resp.Data[0].RelevanceScore != 0.91 { + t.Errorf("top score=%v", resp.Data[0].RelevanceScore) + } +} + +func TestXinferenceRerankNormalizesV1BaseURL(t *testing.T) { + srv := newXinferenceRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"results": []map[string]interface{}{}}) + }) + defer srv.Close() + + x := NewXinferenceModel( + map[string]string{"default": srv.URL + "/v1"}, + URLSuffix{Rerank: "v1/rerank"}, + ) + apiKey := "test-key" + model := "bge-reranker-v2-m3" + _, err := x.Rerank(&model, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Rerank: %v", err) + } +} + +func TestXinferenceRerankRespectsTopNConfig(t *testing.T) { + srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if got := body["top_n"].(float64); got != 2 { + t.Errorf("top_n=%v want 2", got) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"results": []map[string]interface{}{}}) + }) + defer srv.Close() + + x := newXinferenceForTest(srv.URL) + model := "bge-reranker-v2-m3" + _, err := x.Rerank(&model, "q", []string{"a", "b", "c", "d"}, &APIConfig{}, &RerankConfig{TopN: 2}) + if err != nil { + t.Fatalf("Rerank: %v", err) + } +} + +func TestXinferenceRerankEmptyDocumentsShortCircuits(t *testing.T) { + x := newXinferenceForTest("http://unused") + model := "bge-reranker-v2-m3" + resp, err := x.Rerank(&model, "q", nil, &APIConfig{}, nil) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if len(resp.Data) != 0 { + t.Errorf("Data len=%d want 0", len(resp.Data)) + } +} + +func TestXinferenceRerankRequiresModelName(t *testing.T) { + x := newXinferenceForTest("http://unused") + _, err := x.Rerank(nil, "q", []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("err=%v", err) + } +} + +func TestXinferenceRerankRejectsOutOfRangeIndex(t *testing.T) { + srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "results": []map[string]interface{}{{"index": 5, "relevance_score": 0.1}}, + }) + }) + defer srv.Close() + + x := newXinferenceForTest(srv.URL) + model := "bge-reranker-v2-m3" + _, err := x.Rerank(&model, "q", []string{"a", "b"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("err=%v", err) + } +} + +func TestXinferenceRerankRejectsDuplicateIndex(t *testing.T) { + srv := newXinferenceRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "results": []map[string]interface{}{ + {"index": 0, "relevance_score": 0.9}, + {"index": 0, "relevance_score": 0.8}, + }, + }) + }) + defer srv.Close() + + x := newXinferenceForTest(srv.URL) + model := "bge-reranker-v2-m3" + _, err := x.Rerank(&model, "q", []string{"a", "b"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "duplicate") { + t.Errorf("err=%v", err) + } +} + +func TestXinferenceRerankSurfacesHTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"model not loaded"}`)) + })) + defer srv.Close() + + x := newXinferenceForTest(srv.URL) + model := "bge-reranker-v2-m3" + _, err := x.Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "Xinference rerank API error") { + t.Errorf("err=%v", err) + } +} + +func TestXinferenceRerankRejectsMissingRerankSuffix(t *testing.T) { + x := NewXinferenceModel( + map[string]string{"default": "http://unused"}, + URLSuffix{Chat: "v1/chat/completions"}, + ) + model := "bge-reranker-v2-m3" + _, err := x.Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "no rerank URL suffix configured") { + t.Errorf("err=%v", err) + } +}