From 3dbd874a7946db8dac74af0eafe02e121450b1f8 Mon Sep 17 00:00:00 2001 From: glorydavid03023 Date: Mon, 25 May 2026 21:52:09 -0500 Subject: [PATCH] Go: implement Rerank in DeepInfra driver (#15185) ### What problem does this PR solve? The Go DeepInfra driver returned a stub error for `Rerank()` even though DeepInfra serves reranker models at `POST /v1/inference/{model}` with `query`, `documents`, and a `scores[]` response. ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: Cursor --- .github/workflows/tests.yml | 46 ++++---- conf/models/deepinfra.json | 1 + internal/entity/models/deepinfra.go | 106 ++++++++++++++++- internal/entity/models/deepinfra_test.go | 141 +++++++++++++++++++++++ 4 files changed, 270 insertions(+), 24 deletions(-) create mode 100644 internal/entity/models/deepinfra_test.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f14cf502cd..61576b0ec4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -171,26 +171,28 @@ jobs: RUNNER_NUM=$(sudo docker inspect $(hostname) --format '{{index .Config.Labels "com.docker.compose.container-number"}}' 2>/dev/null || true) RUNNER_NUM=${RUNNER_NUM:-1} - # Compute port numbers using bash arithmetic - ES_PORT=$((1200 + RUNNER_NUM * 10)) - OS_PORT=$((1201 + RUNNER_NUM * 10)) - INFINITY_THRIFT_PORT=$((23817 + RUNNER_NUM * 10)) - INFINITY_HTTP_PORT=$((23820 + RUNNER_NUM * 10)) - INFINITY_PSQL_PORT=$((5432 + RUNNER_NUM * 10)) - EXPOSE_MYSQL_PORT=$((5455 + RUNNER_NUM * 10)) - MINIO_PORT=$((9000 + RUNNER_NUM * 10)) - MINIO_CONSOLE_PORT=$((9001 + RUNNER_NUM * 10)) - REDIS_PORT=$((6379 + RUNNER_NUM * 10)) - TEI_PORT=$((6380 + RUNNER_NUM * 10)) - KIBANA_PORT=$((6601 + RUNNER_NUM * 10)) - SVR_HTTP_PORT=$((9380 + RUNNER_NUM * 10)) - ADMIN_SVR_HTTP_PORT=$((9381 + RUNNER_NUM * 10)) - SVR_MCP_PORT=$((9382 + RUNNER_NUM * 10)) - GO_HTTP_PORT=$((9384 + RUNNER_NUM * 10)) - GO_ADMIN_PORT=$((9383 + RUNNER_NUM * 10)) - SANDBOX_EXECUTOR_MANAGER_PORT=$((9385 + RUNNER_NUM * 10)) - SVR_WEB_HTTP_PORT=$((80 + RUNNER_NUM * 10)) - SVR_WEB_HTTPS_PORT=$((443 + RUNNER_NUM * 10)) + # Per-runner base plus per-workflow-run offset avoids port clashes when + # multiple CI jobs share the same self-hosted runner concurrently. + PORT_OFFSET=$(( (GITHUB_RUN_ID % 400) + RUNNER_NUM * 10 )) + ES_PORT=$((1200 + PORT_OFFSET)) + OS_PORT=$((1201 + PORT_OFFSET)) + INFINITY_THRIFT_PORT=$((23817 + PORT_OFFSET)) + INFINITY_HTTP_PORT=$((23820 + PORT_OFFSET)) + INFINITY_PSQL_PORT=$((5432 + PORT_OFFSET)) + EXPOSE_MYSQL_PORT=$((5455 + PORT_OFFSET)) + MINIO_PORT=$((9000 + PORT_OFFSET)) + MINIO_CONSOLE_PORT=$((9001 + PORT_OFFSET)) + REDIS_PORT=$((6379 + PORT_OFFSET)) + TEI_PORT=$((6380 + PORT_OFFSET)) + KIBANA_PORT=$((6601 + PORT_OFFSET)) + SVR_HTTP_PORT=$((9380 + PORT_OFFSET)) + ADMIN_SVR_HTTP_PORT=$((9381 + PORT_OFFSET)) + SVR_MCP_PORT=$((9382 + PORT_OFFSET)) + GO_HTTP_PORT=$((9384 + PORT_OFFSET)) + GO_ADMIN_PORT=$((9383 + PORT_OFFSET)) + SANDBOX_EXECUTOR_MANAGER_PORT=$((9385 + PORT_OFFSET)) + SVR_WEB_HTTP_PORT=$((80 + PORT_OFFSET)) + SVR_WEB_HTTPS_PORT=$((443 + PORT_OFFSET)) # Persist computed ports into .env so docker-compose uses the correct host bindings echo "" >> .env @@ -228,6 +230,8 @@ jobs: - name: Start ragflow:nightly for Infinity run: | sed -i 's/^DOC_ENGINE=.*$/DOC_ENGINE=infinity/' docker/.env + sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} down -v || true + sudo docker ps -a --filter "label=com.docker.compose.project=${GITHUB_RUN_ID}" -q | xargs -r sudo docker rm -f sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d - name: Run sdk tests against Infinity @@ -442,6 +446,8 @@ jobs: - name: Start ragflow:nightly for Elasticsearch run: | sed -i 's/^DOC_ENGINE=.*$/DOC_ENGINE=elasticsearch/' docker/.env + sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} down -v || true + sudo docker ps -a --filter "label=com.docker.compose.project=${GITHUB_RUN_ID}" -q | xargs -r sudo docker rm -f sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d - name: Run sdk tests against Elasticsearch diff --git a/conf/models/deepinfra.json b/conf/models/deepinfra.json index a9277fc6e7..67d49886a0 100644 --- a/conf/models/deepinfra.json +++ b/conf/models/deepinfra.json @@ -7,6 +7,7 @@ "chat": "v1/chat/completions", "models": "models/list", "balance": "payment/checklist", + "rerank": "v1/inference", "embedding": "v1/embeddings", "tts": "v1/text-to-speech", "asr": "v1/audio/transcriptions" diff --git a/internal/entity/models/deepinfra.go b/internal/entity/models/deepinfra.go index dd70981fe6..1437972de4 100644 --- a/internal/entity/models/deepinfra.go +++ b/internal/entity/models/deepinfra.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "ragflow/internal/common" + "slices" "strconv" "strings" "time" @@ -418,12 +419,109 @@ func (d *DeepInfraModel) Embed(modelName *string, texts []string, apiConfig *API } return embeddings, nil - - return embeddings, nil } -func (d *DeepInfraModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s no such method", d.Name()) +// deepinfraRerankResponse is the JSON body returned by DeepInfra reranker models. +type deepinfraRerankResponse struct { + Scores []float64 `json:"scores"` +} + +// Rerank scores documents against a query using DeepInfra's inference endpoint. +// The model id is part of the URL path (e.g. Qwen/Qwen3-Reranker-4B). The API +// returns one score per input document; RerankConfig.TopN is enforced client-side +// by keeping the highest-scoring entries when TopN is less than len(documents). +func (d *DeepInfraModel) Rerank( + modelName *string, + query string, + documents []string, + apiConfig *APIConfig, + rerankConfig *RerankConfig, +) (*RerankResponse, error) { + if len(documents) == 0 { + return &RerankResponse{}, nil + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || strings.TrimSpace(*modelName) == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + baseURL := d.BaseURL[region] + if baseURL == "" { + return nil, fmt.Errorf("deepinfra: no base URL configured for region %q", region) + } + + // Reranker model ids may contain slashes (e.g. Qwen/Qwen3-Reranker-4B). + url := fmt.Sprintf("%s/%s/%s", strings.TrimSuffix(baseURL, "/"), d.URLSuffix.Rerank, *modelName) + + reqBody := map[string]interface{}{ + "query": query, + "documents": documents, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("DeepInfra rerank API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var parsed deepinfraRerankResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + if len(parsed.Scores) != len(documents) { + return nil, fmt.Errorf("deepinfra: expected %d scores, got %d", len(documents), len(parsed.Scores)) + } + + results := make([]RerankResult, len(parsed.Scores)) + for i, score := range parsed.Scores { + results[i] = RerankResult{ + Index: i, + RelevanceScore: score, + } + } + + topN := len(results) + if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN { + topN = rerankConfig.TopN + slices.SortFunc(results, func(a, b RerankResult) int { + if a.RelevanceScore > b.RelevanceScore { + return -1 + } + if a.RelevanceScore < b.RelevanceScore { + return 1 + } + return 0 + }) + results = results[:topN] + } + + return &RerankResponse{Data: results}, nil } func (d *DeepInfraModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { diff --git a/internal/entity/models/deepinfra_test.go b/internal/entity/models/deepinfra_test.go new file mode 100644 index 0000000000..477974d0c2 --- /dev/null +++ b/internal/entity/models/deepinfra_test.go @@ -0,0 +1,141 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// newDeepInfraForTest builds a DeepInfra driver pointed at the test server URL. +func newDeepInfraForTest(baseURL string) *DeepInfraModel { + return NewDeepInfraModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "v1/chat/completions", + Embedding: "v1/embeddings", + Rerank: "v1/inference", + }, + ) +} + +// TestDeepInfraRerankHappyPath verifies request shape and score mapping. +func TestDeepInfraRerankHappyPath(t *testing.T) { + const modelPath = "/v1/inference/Qwen/Qwen3-Reranker-4B" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != modelPath { + t.Errorf("path=%s want %s", r.URL.Path, modelPath) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal: %v", err) + return + } + if body["query"] != "capital of France?" { + t.Errorf("query=%v", body["query"]) + } + docs, ok := body["documents"].([]interface{}) + if !ok || len(docs) != 2 { + t.Errorf("documents=%v", body["documents"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "scores": []float64{0.9, 0.1}, + }) + })) + defer srv.Close() + + apiKey := "test-key" + model := "Qwen/Qwen3-Reranker-4B" + resp, err := newDeepInfraForTest(srv.URL).Rerank( + &model, + "capital of France?", + []string{"Paris is the capital.", "Berlin is the capital."}, + &APIConfig{ApiKey: &apiKey}, + &RerankConfig{TopN: 1}, + ) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if len(resp.Data) != 1 || + resp.Data[0].RelevanceScore != 0.9 || resp.Data[0].Index != 0 { + t.Errorf("resp=%+v", resp.Data) + } +} + +// TestDeepInfraRerankNoTopNLimit returns every scored document when TopN is unset. +func TestDeepInfraRerankNoTopNLimit(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "scores": []float64{0.9, 0.1}, + }) + })) + defer srv.Close() + + apiKey := "test-key" + model := "Qwen/Qwen3-Reranker-4B" + resp, err := newDeepInfraForTest(srv.URL).Rerank( + &model, + "capital of France?", + []string{"Paris is the capital.", "Berlin is the capital."}, + &APIConfig{ApiKey: &apiKey}, + nil, + ) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if len(resp.Data) != 2 || + resp.Data[0].RelevanceScore != 0.9 || resp.Data[0].Index != 0 || + resp.Data[1].RelevanceScore != 0.1 || resp.Data[1].Index != 1 { + t.Errorf("resp=%+v", resp.Data) + } +} + +// TestDeepInfraRerankEmptyDocuments returns an empty result without calling the API. +func TestDeepInfraRerankEmptyDocuments(t *testing.T) { + apiKey := "test-key" + model := "Qwen/Qwen3-Reranker-4B" + resp, err := newDeepInfraForTest("http://unused").Rerank(&model, "q", nil, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if len(resp.Data) != 0 { + t.Errorf("len=%d want 0", len(resp.Data)) + } +} + +// TestDeepInfraRerankRequiresAPIKey rejects requests without an API key. +func TestDeepInfraRerankRequiresAPIKey(t *testing.T) { + model := "Qwen/Qwen3-Reranker-4B" + _, err := newDeepInfraForTest("http://unused").Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +// TestDeepInfraRerankRejectsScoreCountMismatch errors when scores length mismatches documents. +func TestDeepInfraRerankRejectsScoreCountMismatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"scores": []float64{0.5}}) + })) + defer srv.Close() + + apiKey := "test-key" + model := "cross-encoder/ms-marco-MiniLM-L-12-v2" + _, err := newDeepInfraForTest(srv.URL).Rerank( + &model, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "expected 2 scores") { + t.Errorf("expected score-count error, got %v", err) + } +}