diff --git a/conf/models/replicate.json b/conf/models/replicate.json index a57ab2db2e..42a8255dc7 100644 --- a/conf/models/replicate.json +++ b/conf/models/replicate.json @@ -29,6 +29,13 @@ "model_types": [ "embedding" ] + }, + { + "name": "yxzwayne/bge-reranker-v2-m3:7f7c6e9d18336e2cbf07d88e9362d881d2fe4d6a9854ec1260f115cabc106a8c", + "max_tokens": 8192, + "model_types": [ + "rerank" + ] } ] } diff --git a/internal/entity/models/replicate.go b/internal/entity/models/replicate.go index c1c72ace54..135fcebf4f 100644 --- a/internal/entity/models/replicate.go +++ b/internal/entity/models/replicate.go @@ -25,6 +25,7 @@ import ( "io" "net/http" "net/url" + "sort" "strings" "time" ) @@ -716,10 +717,159 @@ func (r *ReplicateModel) Embed(modelName *string, texts []string, apiConfig *API return replicateEmbedOutputToVectors(prediction.Output, len(texts)) } -func (r *ReplicateModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s, no such method", r.Name()) +// replicateRerankInput shapes the request body for Replicate's +// canonical bge-style reranker schema. The documented input is a +// single string field `input_list` carrying a JSON-encoded list of +// `[query, passage]` pairs; the model returns a flat list of +// numeric scores, one per pair, in the same order. +// +// See yxzwayne/bge-reranker-v2-m3's openapi_schema + default_example +// at https://replicate.com/yxzwayne/bge-reranker-v2-m3. Other +// reranker models on Replicate (sesamo-srl/bge-reranker-v2-m3, +// ninehills/bge-reranker-large) follow compatible +// pair-list-in-string conventions; this driver targets the +// canonical shape and leaves model-specific adapters for future +// PRs if other schemas are needed. +func replicateRerankInput(query string, documents []string) (map[string]interface{}, error) { + if len(documents) == 0 { + return nil, fmt.Errorf("replicate: documents is empty") + } + pairs := make([][2]string, len(documents)) + for i, doc := range documents { + pairs[i] = [2]string{query, doc} + } + encoded, err := json.Marshal(pairs) + if err != nil { + return nil, fmt.Errorf("failed to encode input_list: %w", err) + } + return map[string]interface{}{"input_list": string(encoded)}, nil } +// replicateRerankOutputToScores normalizes Replicate's two observed +// rerank-output shapes into a []float64 aligned with the caller's +// document order: +// +// []float64 — flat scores array, used by +// yxzwayne/bge-reranker-v2-m3 (canonical) +// { "scores": [..] } — wrapped object, used by ninehills/bge-reranker-large +// +// Rejects mismatched cardinality and non-numeric scores rather than +// silently truncate, matching the defensive posture the Embed +// implementation already uses. +func replicateRerankOutputToScores(output interface{}, n int) ([]float64, error) { + if scores, ok := output.([]interface{}); ok { + return replicateScoresFromInterface(scores, n) + } + if obj, ok := output.(map[string]interface{}); ok { + raw, present := obj["scores"] + if !present { + return nil, fmt.Errorf("replicate: rerank output missing 'scores' field; got keys %v", replicateKeys(obj)) + } + arr, ok := raw.([]interface{}) + if !ok { + return nil, fmt.Errorf("replicate: rerank output.scores is %T, expected array", raw) + } + return replicateScoresFromInterface(arr, n) + } + return nil, fmt.Errorf("replicate: expected rerank output to be an array or object, got %T", output) +} + +func replicateScoresFromInterface(arr []interface{}, n int) ([]float64, error) { + if len(arr) != n { + return nil, fmt.Errorf("replicate: expected %d rerank scores, got %d", n, len(arr)) + } + out := make([]float64, n) + for i, v := range arr { + f, ok := v.(float64) + if !ok { + return nil, fmt.Errorf("replicate: rerank score %d is %T, expected number", i, v) + } + out[i] = f + } + return out, nil +} + +// Rerank scores a query against a list of documents via Replicate's +// prediction API. The driver targets bge-reranker-v2-m3-style models +// (the most widely-published rerank schema on Replicate) and reuses +// the existing createPrediction + waitForPrediction plumbing from +// the chat and embed paths. +// +// Replicate rerank model outputs are raw similarity scores — they +// are NOT normalized to [0, 1] like Cohere or Voyage rerank +// responses. Higher scores still indicate stronger relevance; the +// driver passes the raw value through without rescaling so callers +// can compare against per-model thresholds, but the RelevanceScore +// field should not be assumed to be a probability. +func (r *ReplicateModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + if len(documents) == 0 { + return &RerankResponse{}, nil + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || strings.TrimSpace(*modelName) == "" { + return nil, fmt.Errorf("model name is required") + } + + url, version, err := r.predictionEndpoint(apiConfig, *modelName) + if err != nil { + return nil, err + } + + input, err := replicateRerankInput(query, documents) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + prediction, err := r.createPrediction(ctx, url, version, input, false, *apiConfig.ApiKey, true) + if err != nil { + return nil, err + } + prediction, err = r.waitForPrediction(ctx, prediction, *apiConfig.ApiKey) + if err != nil { + return nil, err + } + if !replicatePredictionSucceeded(prediction.Status) { + return nil, fmt.Errorf("replicate: prediction ended with status %q", prediction.Status) + } + + scores, err := replicateRerankOutputToScores(prediction.Output, len(documents)) + if err != nil { + return nil, err + } + + // Build the canonical RerankResponse with one entry per input + // document. Optional top_n trimming sorts by score descending + // and keeps the highest-ranking documents; otherwise return all + // scores in original document order, matching how Voyage's + // driver in this package surfaces its results. + topN := len(documents) + if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN { + topN = rerankConfig.TopN + } + results := make([]RerankResult, len(documents)) + for i, score := range scores { + results[i] = RerankResult{Index: i, RelevanceScore: score} + } + if topN < len(results) { + // Sort by score descending, stable on index to keep deterministic + // ordering for ties. + sort.SliceStable(results, func(a, b int) bool { + if results[a].RelevanceScore == results[b].RelevanceScore { + return results[a].Index < results[b].Index + } + return results[a].RelevanceScore > results[b].RelevanceScore + }) + results = results[:topN] + } + return &RerankResponse{Data: results}, nil +} + + func (r *ReplicateModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { return nil, fmt.Errorf("%s, no such method", r.Name()) }