mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Update Rerank logic in GO (#15755)
### What problem does this PR solve? Sync the rerank logic in the following PR to GO. https://github.com/infiniflow/ragflow/pull/15429 https://github.com/infiniflow/ragflow/pull/15434 ### Type of change - [x] Refactoring
This commit is contained in:
@@ -2031,6 +2031,32 @@ func getDefaultSkillMapping() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
// rerankWindow returns the candidate-window size shared by retrieval's
|
||||
// block fetch and slice. Mirrors Dealer._rerank_window in rag/nlp/search.py.
|
||||
//
|
||||
// `size` is the per-page size; the window MUST be an exact multiple of it,
|
||||
// otherwise the block fetched (offset // window) and the in-block page slice
|
||||
// (offset % window) drift apart and deep pagination silently drops results.
|
||||
//
|
||||
// The window targets a provider-friendly pool of ~64 candidates, bounded by
|
||||
// `topK` when given (i.e. when an external reranker is active), and is always
|
||||
// rounded UP to a whole number of pages to preserve the alignment invariant.
|
||||
func rerankWindow(size, topK int) int {
|
||||
if size <= 1 {
|
||||
if topK > 0 {
|
||||
return min(30, topK)
|
||||
}
|
||||
return 30
|
||||
}
|
||||
window := ((64 + size - 1) / size) * size // ceil(64/size) * size
|
||||
if topK > 0 {
|
||||
if aligned := ((topK + size - 1) / size) * size; window > aligned {
|
||||
window = aligned
|
||||
}
|
||||
}
|
||||
return window
|
||||
}
|
||||
|
||||
// calculatePagination calculates offset and limit based on page, size and topK
|
||||
func calculatePagination(page, size, topK int) (int, int) {
|
||||
if page < 1 {
|
||||
@@ -2043,99 +2069,14 @@ func calculatePagination(page, size, topK int) (int, int) {
|
||||
topK = 1024
|
||||
}
|
||||
|
||||
RERANK_LIMIT := max(30, (64/size)*size)
|
||||
if RERANK_LIMIT < size {
|
||||
RERANK_LIMIT = size
|
||||
}
|
||||
if RERANK_LIMIT > topK {
|
||||
RERANK_LIMIT = topK
|
||||
}
|
||||
window := rerankWindow(size, topK)
|
||||
|
||||
offset := (page - 1) * RERANK_LIMIT
|
||||
offset := (page - 1) * window
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
return offset, RERANK_LIMIT
|
||||
}
|
||||
|
||||
// buildFilterClauses builds ES filter clauses from kb_ids and available_int
|
||||
// Reference: rag/utils/es_conn.py L60-L78
|
||||
// When available=0: available_int < 1
|
||||
// When available!=0: NOT (available_int < 1)
|
||||
func buildFilterClauses(datasetIDs []string, available int) []map[string]interface{} {
|
||||
var filters []map[string]interface{}
|
||||
|
||||
if len(datasetIDs) > 0 {
|
||||
filters = append(filters, map[string]interface{}{
|
||||
"terms": map[string]interface{}{"kb_id": datasetIDs},
|
||||
})
|
||||
}
|
||||
|
||||
// Add available_int filter
|
||||
// Reference: rag/utils/es_conn.py L63-L68
|
||||
if available == 0 {
|
||||
// available_int < 1
|
||||
filters = append(filters, map[string]interface{}{
|
||||
"range": map[string]interface{}{
|
||||
"available_int": map[string]interface{}{
|
||||
"lt": 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// must_not: available_int < 1 (i.e., available_int >= 1)
|
||||
filters = append(filters, map[string]interface{}{
|
||||
"bool": map[string]interface{}{
|
||||
"must_not": []map[string]interface{}{
|
||||
{
|
||||
"range": map[string]interface{}{
|
||||
"available_int": map[string]interface{}{
|
||||
"lt": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return filters
|
||||
}
|
||||
|
||||
// buildSkillFilterClauses builds ES filter clauses for skill index
|
||||
// Skill index uses 'status' field instead of 'available_int'
|
||||
func buildSkillFilterClauses() []map[string]interface{} {
|
||||
// Filter for active skills (status = "1")
|
||||
return []map[string]interface{}{
|
||||
{
|
||||
"term": map[string]interface{}{
|
||||
"status": "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// buildFilterFromMap converts a generic filter map to ES filter clauses
|
||||
func buildFilterFromMap(filter map[string]interface{}) []map[string]interface{} {
|
||||
var filters []map[string]interface{}
|
||||
for field, value := range filter {
|
||||
switch v := value.(type) {
|
||||
case []string:
|
||||
filters = append(filters, map[string]interface{}{
|
||||
"terms": map[string]interface{}{field: v},
|
||||
})
|
||||
case []interface{}:
|
||||
filters = append(filters, map[string]interface{}{
|
||||
"terms": map[string]interface{}{field: v},
|
||||
})
|
||||
default:
|
||||
filters = append(filters, map[string]interface{}{
|
||||
"term": map[string]interface{}{field: v},
|
||||
})
|
||||
}
|
||||
}
|
||||
return filters
|
||||
return offset, window
|
||||
}
|
||||
|
||||
// convertESResponse converts ES SearchResponse to unified chunks format
|
||||
|
||||
@@ -431,3 +431,106 @@ func TestBuildBoolQueryFromConditionIDFilter(t *testing.T) {
|
||||
"id": 42,
|
||||
}, []string{"id", "_id"})
|
||||
}
|
||||
|
||||
// paginationGRID mirrors the (page_size, top) grid from
|
||||
// rag/nlp/search.py::Dealer._rerank_window tests. It covers the common page
|
||||
// sizes that do NOT divide 64 (the exact case the legacy min(..., 64) clamp
|
||||
// broke) plus tiny / large / page-aligned tops.
|
||||
var paginationGRID = func() []struct{ size, topK int } {
|
||||
sizes := []int{1, 5, 7, 10, 30, 50, 64}
|
||||
tops := []int{0, 5, 30, 50, 55, 64, 100, 1024}
|
||||
out := make([]struct{ size, topK int }, 0, len(sizes)*len(tops))
|
||||
for _, s := range sizes {
|
||||
for _, t := range tops {
|
||||
out = append(out, struct{ size, topK int }{s, t})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}()
|
||||
|
||||
// paginate replays the (block-fetch + in-block slice) math that
|
||||
// calculatePagination's window is consumed by: for every page whose start is
|
||||
// inside the candidate pool, return the in-block page slice. The block is
|
||||
// window-aligned, so on the aligned invariant every page is full and the
|
||||
// concatenation reconstructs [0, cap).
|
||||
func paginate(total, size, topK int) (window, capN int, surfaced []int) {
|
||||
window = rerankWindow(size, topK)
|
||||
capN = total
|
||||
if topK > 0 && capN > topK {
|
||||
capN = topK
|
||||
}
|
||||
for page := 1; (page-1)*size < capN; page++ {
|
||||
globalOffset := (page - 1) * size
|
||||
blockIndex := globalOffset / window
|
||||
blockStart := blockIndex * window
|
||||
block := make([]int, 0, window)
|
||||
for i := blockStart; i < blockStart+window && i < capN; i++ {
|
||||
block = append(block, i)
|
||||
}
|
||||
begin := globalOffset % window
|
||||
end := begin + size
|
||||
if end > len(block) {
|
||||
end = len(block)
|
||||
}
|
||||
surfaced = append(surfaced, block[begin:end]...)
|
||||
}
|
||||
return window, capN, surfaced
|
||||
}
|
||||
|
||||
func TestRerankWindowIsPageAligned(t *testing.T) {
|
||||
for _, g := range paginationGRID {
|
||||
window := rerankWindow(g.size, g.topK)
|
||||
if window < 1 {
|
||||
t.Errorf("rerankWindow(%d, %d) = %d, want >= 1", g.size, g.topK, window)
|
||||
}
|
||||
if g.size > 1 && window%g.size != 0 {
|
||||
t.Errorf("rerankWindow(%d, %d) = %d, want multiple of %d", g.size, g.topK, window, g.size)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRerankWindowPaginationReconstructsPool(t *testing.T) {
|
||||
// Walking every page reconstructs the candidate pool exactly: in order,
|
||||
// no gaps, no duplicates, and no short interior pages.
|
||||
const total = 250
|
||||
for _, g := range paginationGRID {
|
||||
window, capN, surfaced := paginate(total, g.size, g.topK)
|
||||
if len(surfaced) != capN {
|
||||
t.Errorf("size=%d topK=%d: surfaced %d, want %d (window=%d)",
|
||||
g.size, g.topK, len(surfaced), capN, window)
|
||||
continue
|
||||
}
|
||||
for i, v := range surfaced {
|
||||
if v != i {
|
||||
t.Errorf("size=%d topK=%d: surfaced[%d] = %d, want %d (window=%d)",
|
||||
g.size, g.topK, i, v, i, window)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculatePaginationReportedRegression(t *testing.T) {
|
||||
// The reported case: size=10, topK=1024. Legacy min(..., 64) clamped the
|
||||
// window to 64 (not a multiple of 10), so page 7 (global offset 60) used
|
||||
// to return only 4 of 10 results. With the fix, the window is 70 and
|
||||
// page 7 is full and contiguous.
|
||||
_, limit := calculatePagination(7, 10, 1024)
|
||||
if limit != 70 {
|
||||
t.Fatalf("calculatePagination(7, 10, 1024) limit = %d, want 70", limit)
|
||||
}
|
||||
if limit%10 != 0 {
|
||||
t.Fatalf("calculatePagination(7, 10, 1024) limit = %d, want multiple of 10", limit)
|
||||
}
|
||||
|
||||
// And the simulated end-to-end page walk covers positions 60..69 fully.
|
||||
_, capN, surfaced := paginate(250, 10, 1024)
|
||||
if capN < 70 || len(surfaced) < 70 {
|
||||
t.Fatalf("paginate(250, 10, 1024) returned cap=%d surfaced=%d, want >= 70", capN, len(surfaced))
|
||||
}
|
||||
for i := 60; i < 70; i++ {
|
||||
if surfaced[i] != i {
|
||||
t.Errorf("page 7: surfaced[%d] = %d, want %d", i, surfaced[i], i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/engine/infinity"
|
||||
"ragflow/internal/engine/types"
|
||||
)
|
||||
|
||||
func TestNewQueryBuilder(t *testing.T) {
|
||||
@@ -229,7 +229,7 @@ func TestQueryBuilder_Question(t *testing.T) {
|
||||
tbl string
|
||||
minMatch float64
|
||||
expectNil bool
|
||||
checkExpr func(*infinity.MatchTextExpr) bool
|
||||
checkExpr func(*types.MatchTextExpr) bool
|
||||
checkKeywords func([]string) bool
|
||||
}{
|
||||
{
|
||||
@@ -237,7 +237,7 @@ func TestQueryBuilder_Question(t *testing.T) {
|
||||
txt: "请问如何安装软件",
|
||||
tbl: "test",
|
||||
minMatch: 0.5,
|
||||
checkExpr: func(expr *infinity.MatchTextExpr) bool {
|
||||
checkExpr: func(expr *types.MatchTextExpr) bool {
|
||||
// Should return a valid query expression with processed text
|
||||
return expr != nil && expr.MatchingText != ""
|
||||
},
|
||||
@@ -251,7 +251,7 @@ func TestQueryBuilder_Question(t *testing.T) {
|
||||
txt: "How to install software",
|
||||
tbl: "test",
|
||||
minMatch: 0.5,
|
||||
checkExpr: func(expr *infinity.MatchTextExpr) bool {
|
||||
checkExpr: func(expr *types.MatchTextExpr) bool {
|
||||
// Should return a valid query expression with processed text
|
||||
return expr != nil && expr.MatchingText != ""
|
||||
},
|
||||
@@ -265,7 +265,7 @@ func TestQueryBuilder_Question(t *testing.T) {
|
||||
txt: "hello世界",
|
||||
tbl: "test",
|
||||
minMatch: 0.5,
|
||||
checkExpr: func(expr *infinity.MatchTextExpr) bool {
|
||||
checkExpr: func(expr *types.MatchTextExpr) bool {
|
||||
// Should return a valid query expression with processed text
|
||||
return expr != nil && expr.MatchingText != ""
|
||||
},
|
||||
@@ -280,7 +280,7 @@ func TestQueryBuilder_Question(t *testing.T) {
|
||||
tbl: "test",
|
||||
minMatch: 0.5,
|
||||
expectNil: true,
|
||||
checkExpr: func(expr *infinity.MatchTextExpr) bool {
|
||||
checkExpr: func(expr *types.MatchTextExpr) bool {
|
||||
return expr == nil
|
||||
},
|
||||
checkKeywords: func(keywords []string) bool {
|
||||
|
||||
@@ -164,6 +164,16 @@ func RerankByModel(
|
||||
}
|
||||
}
|
||||
|
||||
// Reranker drivers do not agree on a score scale: Cohere/Jina/Voyage emit
|
||||
// calibrated [0, 1] relevance scores, but NVIDIA returns raw, often
|
||||
// negative logits. The hybrid blend below (tkWeight * tksim + vtWeight *
|
||||
// modelSim) lives on a fixed [0, 1] scale, so an un-normalized logit
|
||||
// weighted by vtWeight=0.7 can sink a relevant chunk below pure keyword
|
||||
// matches and dominate the blend. Centralize the normalization here so
|
||||
// every provider contributes on the same scale. See
|
||||
// NormalizeRerankScores for the contract.
|
||||
modelSim = NormalizeRerankScores(modelSim)
|
||||
|
||||
// Combine token similarity with model similarity
|
||||
// Model similarity is treated as vector similarity component
|
||||
sim = make([]float64, len(insTw))
|
||||
@@ -179,6 +189,68 @@ func RerankByModel(
|
||||
return sim, tsim, modelSim
|
||||
}
|
||||
|
||||
// NormalizeRerankScores rescales reranker scores into [0, 1] for the
|
||||
// hybrid blend in RerankByModel. Mirrors the contract enforced by
|
||||
// Base.similarity / _normalize_rank in rag/llm/rerank_model.py.
|
||||
//
|
||||
// Providers that already return calibrated [0, 1] relevance scores
|
||||
// (Cohere, Jina, Voyage, ...) are returned unchanged, so
|
||||
// similarity_threshold filtering and reported vector_similarity keep
|
||||
// their absolute magnitudes. Only out-of-range output (e.g. NVIDIA's
|
||||
// unbounded, often negative logits) is rescaled: a batch with usable
|
||||
// spread is min-max mapped onto [0, 1] (which stops a negative logit
|
||||
// from dragging a relevant chunk below pure keyword matches once
|
||||
// weighted by vtweight), while a spreadless batch (including a single
|
||||
// candidate) is clamped per element so a lone high score is not silently
|
||||
// zeroed and no NaN leaks into the blend.
|
||||
//
|
||||
// An empty input is returned verbatim. Mutates the input slice in place
|
||||
// to keep the RerankByModel call site allocation-free; the returned
|
||||
// slice is the same backing array.
|
||||
func NormalizeRerankScores(scores []float64) []float64 {
|
||||
n := len(scores)
|
||||
if n == 0 {
|
||||
return scores
|
||||
}
|
||||
minScore := scores[0]
|
||||
maxScore := scores[0]
|
||||
for _, s := range scores[1:] {
|
||||
if s < minScore {
|
||||
minScore = s
|
||||
}
|
||||
if s > maxScore {
|
||||
maxScore = s
|
||||
}
|
||||
}
|
||||
|
||||
// Already in [0, 1]? Keep absolute magnitudes so calibrated providers
|
||||
// and degenerate (but valid) batches are NOT collapsed to zero.
|
||||
if minScore >= 0.0 && maxScore <= 1.0 {
|
||||
return scores
|
||||
}
|
||||
|
||||
// Spreadless out-of-range batch: clamp per element instead of
|
||||
// collapsing to zero or dividing by ~0.
|
||||
span := maxScore - minScore
|
||||
if span < 1e-3 {
|
||||
for i, s := range scores {
|
||||
if s < 0.0 {
|
||||
scores[i] = 0.0
|
||||
} else if s > 1.0 {
|
||||
scores[i] = 1.0
|
||||
}
|
||||
}
|
||||
return scores
|
||||
}
|
||||
|
||||
// Min-max rescale onto [0, 1].
|
||||
invSpan := 1.0 / span
|
||||
for i, s := range scores {
|
||||
scores[i] = (s - minScore) * invSpan
|
||||
}
|
||||
return scores
|
||||
}
|
||||
|
||||
// RerankStandard performs standard reranking without a reranker model
|
||||
// Used for Elasticsearch when no reranker model is provided
|
||||
func RerankStandard(
|
||||
|
||||
205
internal/service/nlp/reranker_normalize_test.go
Normal file
205
internal/service/nlp/reranker_normalize_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
// Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package nlp
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNormalizeRerankScores_OutOfRange_Rescaled covers the central bug fix:
|
||||
// uncalibrated reranker output (e.g. NVIDIA logits) is min-max rescaled
|
||||
// onto [0, 1] so a negative logit weighted by vtWeight=0.7 cannot sink a
|
||||
// relevant chunk below pure keyword matches.
|
||||
func TestNormalizeRerankScores_OutOfRange_Rescaled(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in []float64
|
||||
want []float64
|
||||
}{
|
||||
{"unbounded mixed-sign logits", []float64{10.0, -3.0, 0.0}, []float64{1.0, 0.0, 3.0 / 13.0}},
|
||||
{"large positive logits", []float64{100.0, 50.0, 75.0}, []float64{1.0, 0.0, 0.5}},
|
||||
{"negative-only logits", []float64{-1.0, -5.0, -3.0}, []float64{1.0, 0.0, 0.5}},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := NormalizeRerankScores(tc.in)
|
||||
if !floatsClose(got, tc.want, 1e-9) {
|
||||
t.Errorf("got %v, want %v", got, tc.want)
|
||||
}
|
||||
if minOf(got) < 0.0 || maxOf(got) > 1.0 {
|
||||
t.Errorf("scores escaped [0, 1]: %v", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNormalizeRerankScores_InRange_Preserved pins the calibrated-provider
|
||||
// guarantee: Cohere/Jina/Voyage-style scores in [0, 1] are returned verbatim,
|
||||
// so similarity_threshold semantics and the reported vector_similarity keep
|
||||
// their absolute magnitudes.
|
||||
func TestNormalizeRerankScores_InRange_Preserved(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in []float64
|
||||
}{
|
||||
{"spread relevance", []float64{0.9, 0.1, 0.5}},
|
||||
{"all-equal but valid", []float64{0.8, 0.8, 0.8}},
|
||||
{"single candidate", []float64{1.0}},
|
||||
{"already spanning the full range", []float64{0.0, 1.0, 0.42}},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := NormalizeRerankScores(tc.in)
|
||||
if !floatsClose(got, tc.in, 1e-9) {
|
||||
t.Errorf("got %v, want %v (must be preserved)", got, tc.in)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNormalizeRerankScores_PreservesOrdering ensures rescaling does not
|
||||
// scramble the relative ranking; this is the property downstream code relies
|
||||
// on when sorting by rerank score.
|
||||
func TestNormalizeRerankScores_PreservesOrdering(t *testing.T) {
|
||||
in := []float64{-5.0, 12.0, 3.0, -1.0}
|
||||
got := NormalizeRerankScores(in)
|
||||
wantOrder := argsortDesc(in)
|
||||
gotOrder := argsortDesc(got)
|
||||
if !intsEqual(wantOrder, gotOrder) {
|
||||
t.Errorf("ordering changed: want %v, got %v", wantOrder, gotOrder)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNormalizeRerankScores_SpreadlessOutOfRange_Clamped covers the
|
||||
// degenerate but realistic case of a single rerank candidate or a flat
|
||||
// batch of out-of-range values: clamped per element, never zeroed, never
|
||||
// NaN. A lone high logit would otherwise be silently dropped and
|
||||
// contaminate the blend with NaN if divided by ~0.
|
||||
func TestNormalizeRerankScores_SpreadlessOutOfRange_Clamped(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in []float64
|
||||
want []float64
|
||||
}{
|
||||
{"single out-of-range high", []float64{5.0}, []float64{1.0}},
|
||||
{"single out-of-range negative", []float64{-3.0}, []float64{0.0}},
|
||||
{"flat out-of-range high batch", []float64{5.0, 5.0, 5.0}, []float64{1.0, 1.0, 1.0}},
|
||||
{"flat out-of-range low batch", []float64{-2.0, -2.0, -2.0}, []float64{0.0, 0.0, 0.0}},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := NormalizeRerankScores(tc.in)
|
||||
if !floatsClose(got, tc.want, 1e-9) {
|
||||
t.Errorf("got %v, want %v", got, tc.want)
|
||||
}
|
||||
for _, s := range got {
|
||||
if math.IsNaN(s) {
|
||||
t.Fatalf("NaN leaked into normalized scores: %v", got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNormalizeRerankScores_Empty covers the empty-input contract: returned
|
||||
// verbatim, no allocation, no panic.
|
||||
func TestNormalizeRerankScores_Empty(t *testing.T) {
|
||||
got := NormalizeRerankScores(nil)
|
||||
if len(got) != 0 {
|
||||
t.Errorf("nil in -> expected empty out, got %v", got)
|
||||
}
|
||||
got = NormalizeRerankScores([]float64{})
|
||||
if len(got) != 0 {
|
||||
t.Errorf("[] in -> expected empty out, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNormalizeRerankScores_InPlace pins the in-place guarantee: the input
|
||||
// slice's backing array is what gets returned, so the RerankByModel call
|
||||
// site stays allocation-free.
|
||||
func TestNormalizeRerankScores_InPlace(t *testing.T) {
|
||||
in := []float64{10.0, -3.0, 0.0}
|
||||
got := NormalizeRerankScores(in)
|
||||
if &got[0] != &in[0] {
|
||||
t.Errorf("NormalizeRerankScores must mutate in place; got a new backing array")
|
||||
}
|
||||
}
|
||||
|
||||
func floatsClose(a, b []float64, tol float64) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if math.Abs(a[i]-b[i]) > tol {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func intsEqual(a, b []int) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func minOf(s []float64) float64 {
|
||||
if len(s) == 0 {
|
||||
return 0
|
||||
}
|
||||
m := s[0]
|
||||
for _, v := range s[1:] {
|
||||
if v < m {
|
||||
m = v
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func maxOf(s []float64) float64 {
|
||||
if len(s) == 0 {
|
||||
return 0
|
||||
}
|
||||
m := s[0]
|
||||
for _, v := range s[1:] {
|
||||
if v > m {
|
||||
m = v
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// argsortDesc returns the indices of s sorted by value in descending order,
|
||||
// matching how a downstream consumer would compare rerank scores.
|
||||
func argsortDesc(s []float64) []int {
|
||||
idx := make([]int, len(s))
|
||||
for i := range idx {
|
||||
idx[i] = i
|
||||
}
|
||||
// Insertion sort keeps it dependency-free; len is small (batch size).
|
||||
for i := 1; i < len(idx); i++ {
|
||||
for j := i; j > 0 && s[idx[j]] > s[idx[j-1]]; j-- {
|
||||
idx[j], idx[j-1] = idx[j-1], idx[j]
|
||||
}
|
||||
}
|
||||
return idx
|
||||
}
|
||||
Reference in New Issue
Block a user