mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Go: implement Rerank in DeepInfra driver (#15185)
### What problem does this PR solve?
The Go DeepInfra driver returned a stub error for `Rerank()` even though
DeepInfra serves reranker models at `POST /v1/inference/{model}` with
`query`, `documents`, and a `scores[]` response.
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
46
.github/workflows/tests.yml
vendored
46
.github/workflows/tests.yml
vendored
@@ -171,26 +171,28 @@ jobs:
|
||||
RUNNER_NUM=$(sudo docker inspect $(hostname) --format '{{index .Config.Labels "com.docker.compose.container-number"}}' 2>/dev/null || true)
|
||||
RUNNER_NUM=${RUNNER_NUM:-1}
|
||||
|
||||
# Compute port numbers using bash arithmetic
|
||||
ES_PORT=$((1200 + RUNNER_NUM * 10))
|
||||
OS_PORT=$((1201 + RUNNER_NUM * 10))
|
||||
INFINITY_THRIFT_PORT=$((23817 + RUNNER_NUM * 10))
|
||||
INFINITY_HTTP_PORT=$((23820 + RUNNER_NUM * 10))
|
||||
INFINITY_PSQL_PORT=$((5432 + RUNNER_NUM * 10))
|
||||
EXPOSE_MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
|
||||
MINIO_PORT=$((9000 + RUNNER_NUM * 10))
|
||||
MINIO_CONSOLE_PORT=$((9001 + RUNNER_NUM * 10))
|
||||
REDIS_PORT=$((6379 + RUNNER_NUM * 10))
|
||||
TEI_PORT=$((6380 + RUNNER_NUM * 10))
|
||||
KIBANA_PORT=$((6601 + RUNNER_NUM * 10))
|
||||
SVR_HTTP_PORT=$((9380 + RUNNER_NUM * 10))
|
||||
ADMIN_SVR_HTTP_PORT=$((9381 + RUNNER_NUM * 10))
|
||||
SVR_MCP_PORT=$((9382 + RUNNER_NUM * 10))
|
||||
GO_HTTP_PORT=$((9384 + RUNNER_NUM * 10))
|
||||
GO_ADMIN_PORT=$((9383 + RUNNER_NUM * 10))
|
||||
SANDBOX_EXECUTOR_MANAGER_PORT=$((9385 + RUNNER_NUM * 10))
|
||||
SVR_WEB_HTTP_PORT=$((80 + RUNNER_NUM * 10))
|
||||
SVR_WEB_HTTPS_PORT=$((443 + RUNNER_NUM * 10))
|
||||
# Per-runner base plus per-workflow-run offset avoids port clashes when
|
||||
# multiple CI jobs share the same self-hosted runner concurrently.
|
||||
PORT_OFFSET=$(( (GITHUB_RUN_ID % 400) + RUNNER_NUM * 10 ))
|
||||
ES_PORT=$((1200 + PORT_OFFSET))
|
||||
OS_PORT=$((1201 + PORT_OFFSET))
|
||||
INFINITY_THRIFT_PORT=$((23817 + PORT_OFFSET))
|
||||
INFINITY_HTTP_PORT=$((23820 + PORT_OFFSET))
|
||||
INFINITY_PSQL_PORT=$((5432 + PORT_OFFSET))
|
||||
EXPOSE_MYSQL_PORT=$((5455 + PORT_OFFSET))
|
||||
MINIO_PORT=$((9000 + PORT_OFFSET))
|
||||
MINIO_CONSOLE_PORT=$((9001 + PORT_OFFSET))
|
||||
REDIS_PORT=$((6379 + PORT_OFFSET))
|
||||
TEI_PORT=$((6380 + PORT_OFFSET))
|
||||
KIBANA_PORT=$((6601 + PORT_OFFSET))
|
||||
SVR_HTTP_PORT=$((9380 + PORT_OFFSET))
|
||||
ADMIN_SVR_HTTP_PORT=$((9381 + PORT_OFFSET))
|
||||
SVR_MCP_PORT=$((9382 + PORT_OFFSET))
|
||||
GO_HTTP_PORT=$((9384 + PORT_OFFSET))
|
||||
GO_ADMIN_PORT=$((9383 + PORT_OFFSET))
|
||||
SANDBOX_EXECUTOR_MANAGER_PORT=$((9385 + PORT_OFFSET))
|
||||
SVR_WEB_HTTP_PORT=$((80 + PORT_OFFSET))
|
||||
SVR_WEB_HTTPS_PORT=$((443 + PORT_OFFSET))
|
||||
|
||||
# Persist computed ports into .env so docker-compose uses the correct host bindings
|
||||
echo "" >> .env
|
||||
@@ -228,6 +230,8 @@ jobs:
|
||||
- name: Start ragflow:nightly for Infinity
|
||||
run: |
|
||||
sed -i 's/^DOC_ENGINE=.*$/DOC_ENGINE=infinity/' docker/.env
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} down -v || true
|
||||
sudo docker ps -a --filter "label=com.docker.compose.project=${GITHUB_RUN_ID}" -q | xargs -r sudo docker rm -f
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
|
||||
|
||||
- name: Run sdk tests against Infinity
|
||||
@@ -442,6 +446,8 @@ jobs:
|
||||
- name: Start ragflow:nightly for Elasticsearch
|
||||
run: |
|
||||
sed -i 's/^DOC_ENGINE=.*$/DOC_ENGINE=elasticsearch/' docker/.env
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} down -v || true
|
||||
sudo docker ps -a --filter "label=com.docker.compose.project=${GITHUB_RUN_ID}" -q | xargs -r sudo docker rm -f
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
|
||||
|
||||
- name: Run sdk tests against Elasticsearch
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
"chat": "v1/chat/completions",
|
||||
"models": "models/list",
|
||||
"balance": "payment/checklist",
|
||||
"rerank": "v1/inference",
|
||||
"embedding": "v1/embeddings",
|
||||
"tts": "v1/text-to-speech",
|
||||
"asr": "v1/audio/transcriptions"
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"ragflow/internal/common"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -418,12 +419,109 @@ func (d *DeepInfraModel) Embed(modelName *string, texts []string, apiConfig *API
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (d *DeepInfraModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
return nil, fmt.Errorf("%s no such method", d.Name())
|
||||
// deepinfraRerankResponse is the JSON body returned by DeepInfra reranker models.
|
||||
type deepinfraRerankResponse struct {
|
||||
Scores []float64 `json:"scores"`
|
||||
}
|
||||
|
||||
// Rerank scores documents against a query using DeepInfra's inference endpoint.
|
||||
// The model id is part of the URL path (e.g. Qwen/Qwen3-Reranker-4B). The API
|
||||
// returns one score per input document; RerankConfig.TopN is enforced client-side
|
||||
// by keeping the highest-scoring entries when TopN is less than len(documents).
|
||||
func (d *DeepInfraModel) Rerank(
|
||||
modelName *string,
|
||||
query string,
|
||||
documents []string,
|
||||
apiConfig *APIConfig,
|
||||
rerankConfig *RerankConfig,
|
||||
) (*RerankResponse, error) {
|
||||
if len(documents) == 0 {
|
||||
return &RerankResponse{}, nil
|
||||
}
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
if modelName == nil || strings.TrimSpace(*modelName) == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
baseURL := d.BaseURL[region]
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("deepinfra: no base URL configured for region %q", region)
|
||||
}
|
||||
|
||||
// Reranker model ids may contain slashes (e.g. Qwen/Qwen3-Reranker-4B).
|
||||
url := fmt.Sprintf("%s/%s/%s", strings.TrimSuffix(baseURL, "/"), d.URLSuffix.Rerank, *modelName)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := d.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("DeepInfra rerank API error: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var parsed deepinfraRerankResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
if len(parsed.Scores) != len(documents) {
|
||||
return nil, fmt.Errorf("deepinfra: expected %d scores, got %d", len(documents), len(parsed.Scores))
|
||||
}
|
||||
|
||||
results := make([]RerankResult, len(parsed.Scores))
|
||||
for i, score := range parsed.Scores {
|
||||
results[i] = RerankResult{
|
||||
Index: i,
|
||||
RelevanceScore: score,
|
||||
}
|
||||
}
|
||||
|
||||
topN := len(results)
|
||||
if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN {
|
||||
topN = rerankConfig.TopN
|
||||
slices.SortFunc(results, func(a, b RerankResult) int {
|
||||
if a.RelevanceScore > b.RelevanceScore {
|
||||
return -1
|
||||
}
|
||||
if a.RelevanceScore < b.RelevanceScore {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
results = results[:topN]
|
||||
}
|
||||
|
||||
return &RerankResponse{Data: results}, nil
|
||||
}
|
||||
|
||||
func (d *DeepInfraModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
||||
|
||||
141
internal/entity/models/deepinfra_test.go
Normal file
141
internal/entity/models/deepinfra_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// newDeepInfraForTest builds a DeepInfra driver pointed at the test server URL.
|
||||
func newDeepInfraForTest(baseURL string) *DeepInfraModel {
|
||||
return NewDeepInfraModel(
|
||||
map[string]string{"default": baseURL},
|
||||
URLSuffix{
|
||||
Chat: "v1/chat/completions",
|
||||
Embedding: "v1/embeddings",
|
||||
Rerank: "v1/inference",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// TestDeepInfraRerankHappyPath verifies request shape and score mapping.
|
||||
func TestDeepInfraRerankHappyPath(t *testing.T) {
|
||||
const modelPath = "/v1/inference/Qwen/Qwen3-Reranker-4B"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != modelPath {
|
||||
t.Errorf("path=%s want %s", r.URL.Path, modelPath)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization=%q", got)
|
||||
return
|
||||
}
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read body: %v", err)
|
||||
return
|
||||
}
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &body); err != nil {
|
||||
t.Errorf("unmarshal: %v", err)
|
||||
return
|
||||
}
|
||||
if body["query"] != "capital of France?" {
|
||||
t.Errorf("query=%v", body["query"])
|
||||
}
|
||||
docs, ok := body["documents"].([]interface{})
|
||||
if !ok || len(docs) != 2 {
|
||||
t.Errorf("documents=%v", body["documents"])
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"scores": []float64{0.9, 0.1},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "Qwen/Qwen3-Reranker-4B"
|
||||
resp, err := newDeepInfraForTest(srv.URL).Rerank(
|
||||
&model,
|
||||
"capital of France?",
|
||||
[]string{"Paris is the capital.", "Berlin is the capital."},
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&RerankConfig{TopN: 1},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Rerank: %v", err)
|
||||
}
|
||||
if len(resp.Data) != 1 ||
|
||||
resp.Data[0].RelevanceScore != 0.9 || resp.Data[0].Index != 0 {
|
||||
t.Errorf("resp=%+v", resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeepInfraRerankNoTopNLimit returns every scored document when TopN is unset.
|
||||
func TestDeepInfraRerankNoTopNLimit(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"scores": []float64{0.9, 0.1},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "Qwen/Qwen3-Reranker-4B"
|
||||
resp, err := newDeepInfraForTest(srv.URL).Rerank(
|
||||
&model,
|
||||
"capital of France?",
|
||||
[]string{"Paris is the capital.", "Berlin is the capital."},
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Rerank: %v", err)
|
||||
}
|
||||
if len(resp.Data) != 2 ||
|
||||
resp.Data[0].RelevanceScore != 0.9 || resp.Data[0].Index != 0 ||
|
||||
resp.Data[1].RelevanceScore != 0.1 || resp.Data[1].Index != 1 {
|
||||
t.Errorf("resp=%+v", resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeepInfraRerankEmptyDocuments returns an empty result without calling the API.
|
||||
func TestDeepInfraRerankEmptyDocuments(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
model := "Qwen/Qwen3-Reranker-4B"
|
||||
resp, err := newDeepInfraForTest("http://unused").Rerank(&model, "q", nil, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Rerank: %v", err)
|
||||
}
|
||||
if len(resp.Data) != 0 {
|
||||
t.Errorf("len=%d want 0", len(resp.Data))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeepInfraRerankRequiresAPIKey rejects requests without an API key.
|
||||
func TestDeepInfraRerankRequiresAPIKey(t *testing.T) {
|
||||
model := "Qwen/Qwen3-Reranker-4B"
|
||||
_, err := newDeepInfraForTest("http://unused").Rerank(&model, "q", []string{"a"}, &APIConfig{}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "api key is required") {
|
||||
t.Errorf("expected api-key error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeepInfraRerankRejectsScoreCountMismatch errors when scores length mismatches documents.
|
||||
func TestDeepInfraRerankRejectsScoreCountMismatch(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{"scores": []float64{0.5}})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
model := "cross-encoder/ms-marco-MiniLM-L-12-v2"
|
||||
_, err := newDeepInfraForTest(srv.URL).Rerank(
|
||||
&model, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "expected 2 scores") {
|
||||
t.Errorf("expected score-count error, got %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user