Update Rerank logic in GO (#15755)

### What problem does this PR solve?

Sync the rerank logic in the following PR  to  GO.
https://github.com/infiniflow/ragflow/pull/15429
https://github.com/infiniflow/ragflow/pull/15434

### Type of change

- [x] Refactoring
This commit is contained in:
qinling0210
2026-06-08 15:28:10 +08:00
committed by GitHub
parent 6bf7056422
commit 5e0a7ce408
5 changed files with 415 additions and 94 deletions

View File

@@ -2031,6 +2031,32 @@ func getDefaultSkillMapping() map[string]interface{} {
}
}
// rerankWindow returns the candidate-window size shared by retrieval's
// block fetch and slice. Mirrors Dealer._rerank_window in rag/nlp/search.py.
//
// `size` is the per-page size; the window MUST be an exact multiple of it,
// otherwise the block fetched (offset // window) and the in-block page slice
// (offset % window) drift apart and deep pagination silently drops results.
//
// The window targets a provider-friendly pool of ~64 candidates, bounded by
// `topK` when given (i.e. when an external reranker is active), and is always
// rounded UP to a whole number of pages to preserve the alignment invariant.
func rerankWindow(size, topK int) int {
if size <= 1 {
if topK > 0 {
return min(30, topK)
}
return 30
}
window := ((64 + size - 1) / size) * size // ceil(64/size) * size
if topK > 0 {
if aligned := ((topK + size - 1) / size) * size; window > aligned {
window = aligned
}
}
return window
}
// calculatePagination calculates offset and limit based on page, size and topK
func calculatePagination(page, size, topK int) (int, int) {
if page < 1 {
@@ -2043,99 +2069,14 @@ func calculatePagination(page, size, topK int) (int, int) {
topK = 1024
}
RERANK_LIMIT := max(30, (64/size)*size)
if RERANK_LIMIT < size {
RERANK_LIMIT = size
}
if RERANK_LIMIT > topK {
RERANK_LIMIT = topK
}
window := rerankWindow(size, topK)
offset := (page - 1) * RERANK_LIMIT
offset := (page - 1) * window
if offset < 0 {
offset = 0
}
return offset, RERANK_LIMIT
}
// buildFilterClauses builds ES filter clauses from kb_ids and available_int
// Reference: rag/utils/es_conn.py L60-L78
// When available=0: available_int < 1
// When available!=0: NOT (available_int < 1)
func buildFilterClauses(datasetIDs []string, available int) []map[string]interface{} {
var filters []map[string]interface{}
if len(datasetIDs) > 0 {
filters = append(filters, map[string]interface{}{
"terms": map[string]interface{}{"kb_id": datasetIDs},
})
}
// Add available_int filter
// Reference: rag/utils/es_conn.py L63-L68
if available == 0 {
// available_int < 1
filters = append(filters, map[string]interface{}{
"range": map[string]interface{}{
"available_int": map[string]interface{}{
"lt": 1,
},
},
})
} else {
// must_not: available_int < 1 (i.e., available_int >= 1)
filters = append(filters, map[string]interface{}{
"bool": map[string]interface{}{
"must_not": []map[string]interface{}{
{
"range": map[string]interface{}{
"available_int": map[string]interface{}{
"lt": 1,
},
},
},
},
},
})
}
return filters
}
// buildSkillFilterClauses builds ES filter clauses for skill index
// Skill index uses 'status' field instead of 'available_int'
func buildSkillFilterClauses() []map[string]interface{} {
// Filter for active skills (status = "1")
return []map[string]interface{}{
{
"term": map[string]interface{}{
"status": "1",
},
},
}
}
// buildFilterFromMap converts a generic filter map to ES filter clauses
func buildFilterFromMap(filter map[string]interface{}) []map[string]interface{} {
var filters []map[string]interface{}
for field, value := range filter {
switch v := value.(type) {
case []string:
filters = append(filters, map[string]interface{}{
"terms": map[string]interface{}{field: v},
})
case []interface{}:
filters = append(filters, map[string]interface{}{
"terms": map[string]interface{}{field: v},
})
default:
filters = append(filters, map[string]interface{}{
"term": map[string]interface{}{field: v},
})
}
}
return filters
return offset, window
}
// convertESResponse converts ES SearchResponse to unified chunks format

View File

@@ -431,3 +431,106 @@ func TestBuildBoolQueryFromConditionIDFilter(t *testing.T) {
"id": 42,
}, []string{"id", "_id"})
}
// paginationGRID mirrors the (page_size, top) grid from
// rag/nlp/search.py::Dealer._rerank_window tests. It covers the common page
// sizes that do NOT divide 64 (the exact case the legacy min(..., 64) clamp
// broke) plus tiny / large / page-aligned tops.
var paginationGRID = func() []struct{ size, topK int } {
sizes := []int{1, 5, 7, 10, 30, 50, 64}
tops := []int{0, 5, 30, 50, 55, 64, 100, 1024}
out := make([]struct{ size, topK int }, 0, len(sizes)*len(tops))
for _, s := range sizes {
for _, t := range tops {
out = append(out, struct{ size, topK int }{s, t})
}
}
return out
}()
// paginate replays the (block-fetch + in-block slice) math that
// calculatePagination's window is consumed by: for every page whose start is
// inside the candidate pool, return the in-block page slice. The block is
// window-aligned, so on the aligned invariant every page is full and the
// concatenation reconstructs [0, cap).
func paginate(total, size, topK int) (window, capN int, surfaced []int) {
window = rerankWindow(size, topK)
capN = total
if topK > 0 && capN > topK {
capN = topK
}
for page := 1; (page-1)*size < capN; page++ {
globalOffset := (page - 1) * size
blockIndex := globalOffset / window
blockStart := blockIndex * window
block := make([]int, 0, window)
for i := blockStart; i < blockStart+window && i < capN; i++ {
block = append(block, i)
}
begin := globalOffset % window
end := begin + size
if end > len(block) {
end = len(block)
}
surfaced = append(surfaced, block[begin:end]...)
}
return window, capN, surfaced
}
func TestRerankWindowIsPageAligned(t *testing.T) {
for _, g := range paginationGRID {
window := rerankWindow(g.size, g.topK)
if window < 1 {
t.Errorf("rerankWindow(%d, %d) = %d, want >= 1", g.size, g.topK, window)
}
if g.size > 1 && window%g.size != 0 {
t.Errorf("rerankWindow(%d, %d) = %d, want multiple of %d", g.size, g.topK, window, g.size)
}
}
}
func TestRerankWindowPaginationReconstructsPool(t *testing.T) {
// Walking every page reconstructs the candidate pool exactly: in order,
// no gaps, no duplicates, and no short interior pages.
const total = 250
for _, g := range paginationGRID {
window, capN, surfaced := paginate(total, g.size, g.topK)
if len(surfaced) != capN {
t.Errorf("size=%d topK=%d: surfaced %d, want %d (window=%d)",
g.size, g.topK, len(surfaced), capN, window)
continue
}
for i, v := range surfaced {
if v != i {
t.Errorf("size=%d topK=%d: surfaced[%d] = %d, want %d (window=%d)",
g.size, g.topK, i, v, i, window)
break
}
}
}
}
func TestCalculatePaginationReportedRegression(t *testing.T) {
// The reported case: size=10, topK=1024. Legacy min(..., 64) clamped the
// window to 64 (not a multiple of 10), so page 7 (global offset 60) used
// to return only 4 of 10 results. With the fix, the window is 70 and
// page 7 is full and contiguous.
_, limit := calculatePagination(7, 10, 1024)
if limit != 70 {
t.Fatalf("calculatePagination(7, 10, 1024) limit = %d, want 70", limit)
}
if limit%10 != 0 {
t.Fatalf("calculatePagination(7, 10, 1024) limit = %d, want multiple of 10", limit)
}
// And the simulated end-to-end page walk covers positions 60..69 fully.
_, capN, surfaced := paginate(250, 10, 1024)
if capN < 70 || len(surfaced) < 70 {
t.Fatalf("paginate(250, 10, 1024) returned cap=%d surfaced=%d, want >= 70", capN, len(surfaced))
}
for i := 60; i < 70; i++ {
if surfaced[i] != i {
t.Errorf("page 7: surfaced[%d] = %d, want %d", i, surfaced[i], i)
}
}
}

View File

@@ -18,7 +18,7 @@ import (
"reflect"
"testing"
"ragflow/internal/engine/infinity"
"ragflow/internal/engine/types"
)
func TestNewQueryBuilder(t *testing.T) {
@@ -229,7 +229,7 @@ func TestQueryBuilder_Question(t *testing.T) {
tbl string
minMatch float64
expectNil bool
checkExpr func(*infinity.MatchTextExpr) bool
checkExpr func(*types.MatchTextExpr) bool
checkKeywords func([]string) bool
}{
{
@@ -237,7 +237,7 @@ func TestQueryBuilder_Question(t *testing.T) {
txt: "请问如何安装软件",
tbl: "test",
minMatch: 0.5,
checkExpr: func(expr *infinity.MatchTextExpr) bool {
checkExpr: func(expr *types.MatchTextExpr) bool {
// Should return a valid query expression with processed text
return expr != nil && expr.MatchingText != ""
},
@@ -251,7 +251,7 @@ func TestQueryBuilder_Question(t *testing.T) {
txt: "How to install software",
tbl: "test",
minMatch: 0.5,
checkExpr: func(expr *infinity.MatchTextExpr) bool {
checkExpr: func(expr *types.MatchTextExpr) bool {
// Should return a valid query expression with processed text
return expr != nil && expr.MatchingText != ""
},
@@ -265,7 +265,7 @@ func TestQueryBuilder_Question(t *testing.T) {
txt: "hello世界",
tbl: "test",
minMatch: 0.5,
checkExpr: func(expr *infinity.MatchTextExpr) bool {
checkExpr: func(expr *types.MatchTextExpr) bool {
// Should return a valid query expression with processed text
return expr != nil && expr.MatchingText != ""
},
@@ -280,7 +280,7 @@ func TestQueryBuilder_Question(t *testing.T) {
tbl: "test",
minMatch: 0.5,
expectNil: true,
checkExpr: func(expr *infinity.MatchTextExpr) bool {
checkExpr: func(expr *types.MatchTextExpr) bool {
return expr == nil
},
checkKeywords: func(keywords []string) bool {

View File

@@ -164,6 +164,16 @@ func RerankByModel(
}
}
// Reranker drivers do not agree on a score scale: Cohere/Jina/Voyage emit
// calibrated [0, 1] relevance scores, but NVIDIA returns raw, often
// negative logits. The hybrid blend below (tkWeight * tksim + vtWeight *
// modelSim) lives on a fixed [0, 1] scale, so an un-normalized logit
// weighted by vtWeight=0.7 can sink a relevant chunk below pure keyword
// matches and dominate the blend. Centralize the normalization here so
// every provider contributes on the same scale. See
// NormalizeRerankScores for the contract.
modelSim = NormalizeRerankScores(modelSim)
// Combine token similarity with model similarity
// Model similarity is treated as vector similarity component
sim = make([]float64, len(insTw))
@@ -179,6 +189,68 @@ func RerankByModel(
return sim, tsim, modelSim
}
// NormalizeRerankScores rescales reranker scores into [0, 1] for the
// hybrid blend in RerankByModel. Mirrors the contract enforced by
// Base.similarity / _normalize_rank in rag/llm/rerank_model.py.
//
// Providers that already return calibrated [0, 1] relevance scores
// (Cohere, Jina, Voyage, ...) are returned unchanged, so
// similarity_threshold filtering and reported vector_similarity keep
// their absolute magnitudes. Only out-of-range output (e.g. NVIDIA's
// unbounded, often negative logits) is rescaled: a batch with usable
// spread is min-max mapped onto [0, 1] (which stops a negative logit
// from dragging a relevant chunk below pure keyword matches once
// weighted by vtweight), while a spreadless batch (including a single
// candidate) is clamped per element so a lone high score is not silently
// zeroed and no NaN leaks into the blend.
//
// An empty input is returned verbatim. Mutates the input slice in place
// to keep the RerankByModel call site allocation-free; the returned
// slice is the same backing array.
func NormalizeRerankScores(scores []float64) []float64 {
n := len(scores)
if n == 0 {
return scores
}
minScore := scores[0]
maxScore := scores[0]
for _, s := range scores[1:] {
if s < minScore {
minScore = s
}
if s > maxScore {
maxScore = s
}
}
// Already in [0, 1]? Keep absolute magnitudes so calibrated providers
// and degenerate (but valid) batches are NOT collapsed to zero.
if minScore >= 0.0 && maxScore <= 1.0 {
return scores
}
// Spreadless out-of-range batch: clamp per element instead of
// collapsing to zero or dividing by ~0.
span := maxScore - minScore
if span < 1e-3 {
for i, s := range scores {
if s < 0.0 {
scores[i] = 0.0
} else if s > 1.0 {
scores[i] = 1.0
}
}
return scores
}
// Min-max rescale onto [0, 1].
invSpan := 1.0 / span
for i, s := range scores {
scores[i] = (s - minScore) * invSpan
}
return scores
}
// RerankStandard performs standard reranking without a reranker model
// Used for Elasticsearch when no reranker model is provided
func RerankStandard(

View File

@@ -0,0 +1,205 @@
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nlp
import (
"math"
"testing"
)
// TestNormalizeRerankScores_OutOfRange_Rescaled covers the central bug fix:
// uncalibrated reranker output (e.g. NVIDIA logits) is min-max rescaled
// onto [0, 1] so a negative logit weighted by vtWeight=0.7 cannot sink a
// relevant chunk below pure keyword matches.
func TestNormalizeRerankScores_OutOfRange_Rescaled(t *testing.T) {
cases := []struct {
name string
in []float64
want []float64
}{
{"unbounded mixed-sign logits", []float64{10.0, -3.0, 0.0}, []float64{1.0, 0.0, 3.0 / 13.0}},
{"large positive logits", []float64{100.0, 50.0, 75.0}, []float64{1.0, 0.0, 0.5}},
{"negative-only logits", []float64{-1.0, -5.0, -3.0}, []float64{1.0, 0.0, 0.5}},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := NormalizeRerankScores(tc.in)
if !floatsClose(got, tc.want, 1e-9) {
t.Errorf("got %v, want %v", got, tc.want)
}
if minOf(got) < 0.0 || maxOf(got) > 1.0 {
t.Errorf("scores escaped [0, 1]: %v", got)
}
})
}
}
// TestNormalizeRerankScores_InRange_Preserved pins the calibrated-provider
// guarantee: Cohere/Jina/Voyage-style scores in [0, 1] are returned verbatim,
// so similarity_threshold semantics and the reported vector_similarity keep
// their absolute magnitudes.
func TestNormalizeRerankScores_InRange_Preserved(t *testing.T) {
cases := []struct {
name string
in []float64
}{
{"spread relevance", []float64{0.9, 0.1, 0.5}},
{"all-equal but valid", []float64{0.8, 0.8, 0.8}},
{"single candidate", []float64{1.0}},
{"already spanning the full range", []float64{0.0, 1.0, 0.42}},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := NormalizeRerankScores(tc.in)
if !floatsClose(got, tc.in, 1e-9) {
t.Errorf("got %v, want %v (must be preserved)", got, tc.in)
}
})
}
}
// TestNormalizeRerankScores_PreservesOrdering ensures rescaling does not
// scramble the relative ranking; this is the property downstream code relies
// on when sorting by rerank score.
func TestNormalizeRerankScores_PreservesOrdering(t *testing.T) {
in := []float64{-5.0, 12.0, 3.0, -1.0}
got := NormalizeRerankScores(in)
wantOrder := argsortDesc(in)
gotOrder := argsortDesc(got)
if !intsEqual(wantOrder, gotOrder) {
t.Errorf("ordering changed: want %v, got %v", wantOrder, gotOrder)
}
}
// TestNormalizeRerankScores_SpreadlessOutOfRange_Clamped covers the
// degenerate but realistic case of a single rerank candidate or a flat
// batch of out-of-range values: clamped per element, never zeroed, never
// NaN. A lone high logit would otherwise be silently dropped and
// contaminate the blend with NaN if divided by ~0.
func TestNormalizeRerankScores_SpreadlessOutOfRange_Clamped(t *testing.T) {
cases := []struct {
name string
in []float64
want []float64
}{
{"single out-of-range high", []float64{5.0}, []float64{1.0}},
{"single out-of-range negative", []float64{-3.0}, []float64{0.0}},
{"flat out-of-range high batch", []float64{5.0, 5.0, 5.0}, []float64{1.0, 1.0, 1.0}},
{"flat out-of-range low batch", []float64{-2.0, -2.0, -2.0}, []float64{0.0, 0.0, 0.0}},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := NormalizeRerankScores(tc.in)
if !floatsClose(got, tc.want, 1e-9) {
t.Errorf("got %v, want %v", got, tc.want)
}
for _, s := range got {
if math.IsNaN(s) {
t.Fatalf("NaN leaked into normalized scores: %v", got)
}
}
})
}
}
// TestNormalizeRerankScores_Empty covers the empty-input contract: returned
// verbatim, no allocation, no panic.
func TestNormalizeRerankScores_Empty(t *testing.T) {
got := NormalizeRerankScores(nil)
if len(got) != 0 {
t.Errorf("nil in -> expected empty out, got %v", got)
}
got = NormalizeRerankScores([]float64{})
if len(got) != 0 {
t.Errorf("[] in -> expected empty out, got %v", got)
}
}
// TestNormalizeRerankScores_InPlace pins the in-place guarantee: the input
// slice's backing array is what gets returned, so the RerankByModel call
// site stays allocation-free.
func TestNormalizeRerankScores_InPlace(t *testing.T) {
in := []float64{10.0, -3.0, 0.0}
got := NormalizeRerankScores(in)
if &got[0] != &in[0] {
t.Errorf("NormalizeRerankScores must mutate in place; got a new backing array")
}
}
func floatsClose(a, b []float64, tol float64) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if math.Abs(a[i]-b[i]) > tol {
return false
}
}
return true
}
func intsEqual(a, b []int) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func minOf(s []float64) float64 {
if len(s) == 0 {
return 0
}
m := s[0]
for _, v := range s[1:] {
if v < m {
m = v
}
}
return m
}
func maxOf(s []float64) float64 {
if len(s) == 0 {
return 0
}
m := s[0]
for _, v := range s[1:] {
if v > m {
m = v
}
}
return m
}
// argsortDesc returns the indices of s sorted by value in descending order,
// matching how a downstream consumer would compare rerank scores.
func argsortDesc(s []float64) []int {
idx := make([]int, len(s))
for i := range idx {
idx[i] = i
}
// Insertion sort keeps it dependency-free; len is small (batch size).
for i := 1; i < len(idx); i++ {
for j := i; j > 0 && s[idx[j]] > s[idx[j-1]]; j-- {
idx[j], idx[j-1] = idx[j-1], idx[j]
}
}
return idx
}