Go: implement Rerank in DeepInfra driver (#15185)

### What problem does this PR solve?

The Go DeepInfra driver returned a stub error for `Rerank()` even though
DeepInfra serves reranker models at `POST /v1/inference/{model}` with
`query`, `documents`, and a `scores[]` response.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
glorydavid03023
2026-05-25 21:52:09 -05:00
committed by GitHub
parent 67f7d87dff
commit 3dbd874a79
4 changed files with 270 additions and 24 deletions

View File

@@ -171,26 +171,28 @@ jobs:
RUNNER_NUM=$(sudo docker inspect $(hostname) --format '{{index .Config.Labels "com.docker.compose.container-number"}}' 2>/dev/null || true)
RUNNER_NUM=${RUNNER_NUM:-1}
# Compute port numbers using bash arithmetic
ES_PORT=$((1200 + RUNNER_NUM * 10))
OS_PORT=$((1201 + RUNNER_NUM * 10))
INFINITY_THRIFT_PORT=$((23817 + RUNNER_NUM * 10))
INFINITY_HTTP_PORT=$((23820 + RUNNER_NUM * 10))
INFINITY_PSQL_PORT=$((5432 + RUNNER_NUM * 10))
EXPOSE_MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
MINIO_PORT=$((9000 + RUNNER_NUM * 10))
MINIO_CONSOLE_PORT=$((9001 + RUNNER_NUM * 10))
REDIS_PORT=$((6379 + RUNNER_NUM * 10))
TEI_PORT=$((6380 + RUNNER_NUM * 10))
KIBANA_PORT=$((6601 + RUNNER_NUM * 10))
SVR_HTTP_PORT=$((9380 + RUNNER_NUM * 10))
ADMIN_SVR_HTTP_PORT=$((9381 + RUNNER_NUM * 10))
SVR_MCP_PORT=$((9382 + RUNNER_NUM * 10))
GO_HTTP_PORT=$((9384 + RUNNER_NUM * 10))
GO_ADMIN_PORT=$((9383 + RUNNER_NUM * 10))
SANDBOX_EXECUTOR_MANAGER_PORT=$((9385 + RUNNER_NUM * 10))
SVR_WEB_HTTP_PORT=$((80 + RUNNER_NUM * 10))
SVR_WEB_HTTPS_PORT=$((443 + RUNNER_NUM * 10))
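# Per-runner base plus per-workflow-run offset reduces port clashes when
# multiple CI jobs share the same self-hosted runner concurrently.
# NOTE: this is not collision-free. Base ports below are as close as 1 apart
# (e.g. ES 1200 vs OS 1201), so two runs whose PORT_OFFSET values differ by 1
# can still map different services onto the same host port.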
PORT_OFFSET=$(( (GITHUB_RUN_ID % 400) + RUNNER_NUM * 10 ))
ES_PORT=$((1200 + PORT_OFFSET))
OS_PORT=$((1201 + PORT_OFFSET))
INFINITY_THRIFT_PORT=$((23817 + PORT_OFFSET))
INFINITY_HTTP_PORT=$((23820 + PORT_OFFSET))
INFINITY_PSQL_PORT=$((5432 + PORT_OFFSET))
EXPOSE_MYSQL_PORT=$((5455 + PORT_OFFSET))
MINIO_PORT=$((9000 + PORT_OFFSET))
MINIO_CONSOLE_PORT=$((9001 + PORT_OFFSET))
REDIS_PORT=$((6379 + PORT_OFFSET))
TEI_PORT=$((6380 + PORT_OFFSET))
KIBANA_PORT=$((6601 + PORT_OFFSET))
SVR_HTTP_PORT=$((9380 + PORT_OFFSET))
ADMIN_SVR_HTTP_PORT=$((9381 + PORT_OFFSET))
SVR_MCP_PORT=$((9382 + PORT_OFFSET))
GO_HTTP_PORT=$((9384 + PORT_OFFSET))
GO_ADMIN_PORT=$((9383 + PORT_OFFSET))
SANDBOX_EXECUTOR_MANAGER_PORT=$((9385 + PORT_OFFSET))
SVR_WEB_HTTP_PORT=$((80 + PORT_OFFSET))
SVR_WEB_HTTPS_PORT=$((443 + PORT_OFFSET))
# Persist computed ports into .env so docker-compose uses the correct host bindings
echo "" >> .env
@@ -228,6 +230,8 @@ jobs:
- name: Start ragflow:nightly for Infinity
run: |
# Point the stack at the Infinity document engine by rewriting DOC_ENGINE in docker/.env.
sed -i 's/^DOC_ENGINE=.*$/DOC_ENGINE=infinity/' docker/.env
# Tear down anything left under this run's compose project (including volumes);
# "|| true" keeps the step going when there is nothing to tear down.
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} down -v || true
# Force-remove any containers still labelled with this compose project.
# xargs -r skips the rm entirely when the list is empty.
sudo docker ps -a --filter "label=com.docker.compose.project=${GITHUB_RUN_ID}" -q | xargs -r sudo docker rm -f
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
- name: Run sdk tests against Infinity
@@ -442,6 +446,8 @@ jobs:
- name: Start ragflow:nightly for Elasticsearch
run: |
# Point the stack at the Elasticsearch document engine by rewriting DOC_ENGINE in docker/.env.
sed -i 's/^DOC_ENGINE=.*$/DOC_ENGINE=elasticsearch/' docker/.env
# Tear down anything left under this run's compose project (including volumes);
# "|| true" keeps the step going when there is nothing to tear down.
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} down -v || true
# Force-remove any containers still labelled with this compose project.
# xargs -r skips the rm entirely when the list is empty.
sudo docker ps -a --filter "label=com.docker.compose.project=${GITHUB_RUN_ID}" -q | xargs -r sudo docker rm -f
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
- name: Run sdk tests against Elasticsearch

View File

@@ -7,6 +7,7 @@
"chat": "v1/chat/completions",
"models": "models/list",
"balance": "payment/checklist",
"rerank": "v1/inference",
"embedding": "v1/embeddings",
"tts": "v1/text-to-speech",
"asr": "v1/audio/transcriptions"

View File

@@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"ragflow/internal/common"
"slices"
"strconv"
"strings"
"time"
@@ -418,12 +419,109 @@ func (d *DeepInfraModel) Embed(modelName *string, texts []string, apiConfig *API
}
return embeddings, nil
return embeddings, nil
}
func (d *DeepInfraModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
return nil, fmt.Errorf("%s no such method", d.Name())
// deepinfraRerankResponse is the JSON body returned by DeepInfra reranker models.
// Rerank decodes the response into this type and rejects it unless the number
// of scores equals the number of submitted documents; the score at position i
// is reported with result index i.
type deepinfraRerankResponse struct {
// Scores is decoded from the "scores" JSON array.
Scores []float64 `json:"scores"`
}
// Rerank scores documents against a query using DeepInfra's inference endpoint.
//
// The model id is part of the URL path (e.g. Qwen/Qwen3-Reranker-4B), so the
// request is POSTed to {baseURL}/{URLSuffix.Rerank}/{model} with a JSON body
// carrying "query" and "documents". The API returns one score per input
// document; score i is reported with Index i.
//
// Parameters:
//   - modelName: required. Surrounding whitespace is ignored, both for the
//     emptiness check and when building the URL.
//   - query, documents: sent verbatim. An empty documents slice short-circuits
//     to an empty response without validating the other arguments or calling
//     the API.
//   - apiConfig: must carry a non-empty ApiKey. Region selects the base URL
//     and falls back to "default" when unset or empty.
//   - rerankConfig: optional. When TopN is positive and smaller than
//     len(documents), results are sorted by descending score (ties keep their
//     input order) and truncated to TopN. Otherwise every result is returned
//     in input order.
//
// An error is returned for missing credentials or model name, an unconfigured
// region, transport failures, non-200 responses, undecodable bodies, and a
// score count that does not match len(documents).
func (d *DeepInfraModel) Rerank(
	modelName *string,
	query string,
	documents []string,
	apiConfig *APIConfig,
	rerankConfig *RerankConfig,
) (*RerankResponse, error) {
	if len(documents) == 0 {
		return &RerankResponse{}, nil
	}
	if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
		return nil, fmt.Errorf("api key is required")
	}

	// Validate and use the same trimmed value, so stray whitespace never ends
	// up in the request path.
	model := ""
	if modelName != nil {
		model = strings.TrimSpace(*modelName)
	}
	if model == "" {
		return nil, fmt.Errorf("model name is required")
	}

	region := "default"
	if apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}
	baseURL := d.BaseURL[region]
	if baseURL == "" {
		return nil, fmt.Errorf("deepinfra: no base URL configured for region %q", region)
	}

	// Reranker model ids may contain slashes (e.g. Qwen/Qwen3-Reranker-4B);
	// they are appended as-is so they form additional path segments.
	endpoint := fmt.Sprintf("%s/%s/%s", strings.TrimSuffix(baseURL, "/"), d.URLSuffix.Rerank, model)

	reqBody := map[string]any{
		"query":     query,
		"documents": documents,
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))

	resp, err := d.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// Read the body before checking the status so error responses can be
	// surfaced in the returned error.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("DeepInfra rerank API error: status %d, body: %s", resp.StatusCode, string(body))
	}

	var parsed deepinfraRerankResponse
	if err = json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	// Scores are matched to documents purely by position, so a length
	// mismatch would silently misattribute scores.
	if len(parsed.Scores) != len(documents) {
		return nil, fmt.Errorf("deepinfra: expected %d scores, got %d", len(documents), len(parsed.Scores))
	}

	results := make([]RerankResult, len(parsed.Scores))
	for i, score := range parsed.Scores {
		results[i] = RerankResult{
			Index:          i,
			RelevanceScore: score,
		}
	}

	// TopN is enforced client-side. A stable sort keeps equally scored
	// documents in input order, so the cut-off is deterministic.
	if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < len(results) {
		slices.SortStableFunc(results, func(a, b RerankResult) int {
			if a.RelevanceScore > b.RelevanceScore {
				return -1
			}
			if a.RelevanceScore < b.RelevanceScore {
				return 1
			}
			return 0
		})
		results = results[:rerankConfig.TopN]
	}

	return &RerankResponse{Data: results}, nil
}
func (d *DeepInfraModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {

View File

@@ -0,0 +1,141 @@
package models
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// newDeepInfraForTest returns a DeepInfra driver whose only region, "default",
// resolves to baseURL (normally an httptest server), with the chat, embedding
// and rerank URL suffixes filled in.
func newDeepInfraForTest(baseURL string) *DeepInfraModel {
	regions := map[string]string{"default": baseURL}
	suffixes := URLSuffix{
		Chat:      "v1/chat/completions",
		Embedding: "v1/embeddings",
		Rerank:    "v1/inference",
	}
	return NewDeepInfraModel(regions, suffixes)
}
// TestDeepInfraRerankHappyPath verifies the request shape (method, path,
// headers, JSON body) and the score mapping. The best-scoring document is
// deliberately placed second so that TopN=1 only passes when results are
// sorted by score before truncation, not merely cut to the first entry.
func TestDeepInfraRerankHappyPath(t *testing.T) {
	const modelPath = "/v1/inference/Qwen/Qwen3-Reranker-4B"
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("method=%s want %s", r.Method, http.MethodPost)
			return
		}
		if r.URL.Path != modelPath {
			t.Errorf("path=%s want %s", r.URL.Path, modelPath)
			return
		}
		if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
			t.Errorf("Authorization=%q", got)
			return
		}
		if got := r.Header.Get("Content-Type"); got != "application/json" {
			t.Errorf("Content-Type=%q", got)
			return
		}
		raw, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read body: %v", err)
			return
		}
		var body map[string]interface{}
		if err := json.Unmarshal(raw, &body); err != nil {
			t.Errorf("unmarshal: %v", err)
			return
		}
		if body["query"] != "capital of France?" {
			t.Errorf("query=%v", body["query"])
		}
		docs, ok := body["documents"].([]interface{})
		if !ok || len(docs) != 2 {
			t.Errorf("documents=%v", body["documents"])
		}
		// Highest score belongs to the second document (index 1).
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"scores": []float64{0.1, 0.9},
		})
	}))
	defer srv.Close()

	apiKey := "test-key"
	model := "Qwen/Qwen3-Reranker-4B"
	resp, err := newDeepInfraForTest(srv.URL).Rerank(
		&model,
		"capital of France?",
		[]string{"Berlin is the capital.", "Paris is the capital."},
		&APIConfig{ApiKey: &apiKey},
		&RerankConfig{TopN: 1},
	)
	if err != nil {
		t.Fatalf("Rerank: %v", err)
	}
	if len(resp.Data) != 1 ||
		resp.Data[0].RelevanceScore != 0.9 || resp.Data[0].Index != 1 {
		t.Errorf("resp=%+v", resp.Data)
	}
}
// TestDeepInfraRerankNoTopNLimit checks that, with a nil RerankConfig, every
// scored document comes back in input order.
func TestDeepInfraRerankNoTopNLimit(t *testing.T) {
	respond := func(w http.ResponseWriter, _ *http.Request) {
		payload := map[string]interface{}{
			"scores": []float64{0.9, 0.1},
		}
		_ = json.NewEncoder(w).Encode(payload)
	}
	server := httptest.NewServer(http.HandlerFunc(respond))
	defer server.Close()

	key := "test-key"
	modelID := "Qwen/Qwen3-Reranker-4B"
	docs := []string{"Paris is the capital.", "Berlin is the capital."}

	driver := newDeepInfraForTest(server.URL)
	got, err := driver.Rerank(&modelID, "capital of France?", docs, &APIConfig{ApiKey: &key}, nil)
	if err != nil {
		t.Fatalf("Rerank: %v", err)
	}

	if len(got.Data) != 2 {
		t.Errorf("resp=%+v", got.Data)
		return
	}
	first, second := got.Data[0], got.Data[1]
	firstOK := first.RelevanceScore == 0.9 && first.Index == 0
	secondOK := second.RelevanceScore == 0.1 && second.Index == 1
	if !firstOK || !secondOK {
		t.Errorf("resp=%+v", got.Data)
	}
}
// TestDeepInfraRerankEmptyDocuments checks that a nil document list yields an
// empty result and no error. The base URL is a placeholder that no server
// listens on, so any attempt to call the API would surface as an error.
func TestDeepInfraRerankEmptyDocuments(t *testing.T) {
	key, modelID := "test-key", "Qwen/Qwen3-Reranker-4B"
	driver := newDeepInfraForTest("http://unused")

	got, err := driver.Rerank(&modelID, "q", nil, &APIConfig{ApiKey: &key}, nil)
	if err != nil {
		t.Fatalf("Rerank: %v", err)
	}
	if n := len(got.Data); n != 0 {
		t.Errorf("len=%d want 0", n)
	}
}
// TestDeepInfraRerankRequiresAPIKey checks that an APIConfig without an ApiKey
// is rejected with the "api key is required" error.
func TestDeepInfraRerankRequiresAPIKey(t *testing.T) {
	modelID := "Qwen/Qwen3-Reranker-4B"
	driver := newDeepInfraForTest("http://unused")

	_, err := driver.Rerank(&modelID, "q", []string{"a"}, &APIConfig{}, nil)
	if err != nil && strings.Contains(err.Error(), "api key is required") {
		return
	}
	t.Errorf("expected api-key error, got %v", err)
}
// TestDeepInfraRerankRejectsScoreCountMismatch checks that a response carrying
// one score for two documents is reported as an error instead of being mapped
// onto the wrong documents.
func TestDeepInfraRerankRejectsScoreCountMismatch(t *testing.T) {
	respond := func(w http.ResponseWriter, _ *http.Request) {
		payload := map[string]interface{}{"scores": []float64{0.5}}
		_ = json.NewEncoder(w).Encode(payload)
	}
	server := httptest.NewServer(http.HandlerFunc(respond))
	defer server.Close()

	key := "test-key"
	modelID := "cross-encoder/ms-marco-MiniLM-L-12-v2"
	driver := newDeepInfraForTest(server.URL)

	_, err := driver.Rerank(&modelID, "q", []string{"a", "b"}, &APIConfig{ApiKey: &key}, nil)
	if err != nil && strings.Contains(err.Error(), "expected 2 scores") {
		return
	}
	t.Errorf("expected score-count error, got %v", err)
}