2024-11-12 14:59:41 +08:00
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
2026-01-15 12:28:49 +08:00
import asyncio
2024-11-12 14:59:41 +08:00
import json
import os
import sys
import time
import argparse
from collections import defaultdict
2025-11-06 09:36:38 +08:00
from common import settings
2025-11-05 08:01:39 +08:00
from common . constants import LLMType
2024-11-12 14:59:41 +08:00
from api . db . services . llm_service import LLMBundle
from api . db . services . knowledgebase_service import KnowledgebaseService
2026-05-29 17:39:41 +08:00
from api . db . joint_services . tenant_model_service import get_model_config_from_provider_instance
2025-10-31 16:42:01 +08:00
from common . misc_utils import get_uuid
2024-11-12 14:59:41 +08:00
from rag . nlp import tokenize , search
from ranx import evaluate
2024-11-15 10:14:51 +08:00
from ranx import Qrels , Run
2024-11-12 14:59:41 +08:00
import pandas as pd
from tqdm import tqdm
# Upper bound on the number of documents indexed per dataset. The command-line
# entry point at the bottom of this file overwrites it with the <max_docs>
# argument; the default places no practical limit on indexing.
max_docs = sys.maxsize
2024-11-13 15:56:40 +08:00
2024-11-12 14:59:41 +08:00
class Benchmark:
    """Index a retrieval dataset into a knowledge base's doc store and score retrieval.

    Calling an instance with a dataset name indexes that dataset, runs every
    query through ``settings.retriever``, prints ndcg@10 / map@5 / mrr@10 and
    writes a per-query report plus the raw qrels/run JSON next to the dataset.
    """

    # Number of documents embedded and inserted per doc-store round trip.
    _BATCH_SIZE = 32
    # Metrics reported for each evaluated dataset.
    _METRICS = ["ndcg@10", "map@5", "mrr@10"]
    # Language codes probed under the MIRACL dataset/corpus directories.
    _MIRACL_LANGS = ["ar", "bn", "de", "en", "es", "fa", "fi", "fr", "hi", "id", "ja", "ko", "ru", "sw", "te", "th", "yo", "zh"]

    def __init__(self, kb_id):
        """Load the knowledge base ``kb_id`` and build its embedding model bundle.

        Raises:
            LookupError: if the knowledge base service returns no record.
        """
        self.kb_id = kb_id
        _, self.kb = KnowledgebaseService.get_by_id(kb_id)
        if self.kb is None:
            # Fail with a clear message instead of an AttributeError on the next line.
            raise LookupError(f"Knowledge base not found: {kb_id}")
        self.similarity_threshold = self.kb.similarity_threshold
        self.vector_similarity_weight = self.kb.vector_similarity_weight
        embd_model_config = get_model_config_from_provider_instance(self.kb.tenant_id, LLMType.EMBEDDING, self.kb.embd_id)
        self.embd_mdl = LLMBundle(self.kb.tenant_id, embd_model_config, lang=self.kb.language)
        # Both are set per dataset in __call__ before any indexing happens.
        self.tenant_id = ""
        self.index_name = ""
        self.initialized_index = False

    def _get_retrieval(self, qrels):
        """Retrieve chunks for every query in ``qrels`` and return the run mapping.

        Queries that retrieve no chunks are deleted from ``qrels`` in place so
        that qrels and run cover the same set of queries.

        Returns:
            defaultdict: ``{query: {chunk_id: similarity}}``.
        """
        # Need to wait for the ES and Infinity index to be ready
        time.sleep(20)
        run = defaultdict(dict)
        # Iterate over a snapshot of the keys because qrels is mutated below.
        for query in list(qrels.keys()):
            # NOTE(review): positional args 1, 30, 0.0 are presumably page,
            # page size and similarity threshold - confirm against the retriever.
            ranks = asyncio.run(settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, 0.0, self.vector_similarity_weight))
            if not ranks["chunks"]:
                print(f"deleted query: {query}")
                del qrels[query]
                continue
            for c in ranks["chunks"]:
                c.pop("vector", None)
                run[query][c["chunk_id"]] = c["similarity"]
        return run

    def embedding(self, docs):
        """Embed ``content_with_weight`` of each doc and attach it as ``q_<dim>_vec``.

        Args:
            docs: list of dicts, each carrying a ``content_with_weight`` entry.

        Returns:
            tuple: ``(docs, vector_size)`` where ``docs`` is the same list, mutated
            in place, and ``vector_size`` is the length of the last vector
            (0 when ``docs`` is empty).

        Raises:
            ValueError: if the model returns a different number of vectors than docs.
        """
        texts = [d["content_with_weight"] for d in docs]
        embeddings, _ = self.embd_mdl.encode(texts)
        if len(docs) != len(embeddings):
            # Explicit check rather than assert, which is stripped under -O.
            raise ValueError(f"Expected {len(docs)} embeddings, got {len(embeddings)}")
        vector_size = 0
        for d, v in zip(docs, embeddings):
            vector_size = len(v)
            d[f"q_{vector_size}_vec"] = v
        return docs, vector_size

    def init_index(self, vector_size: int):
        """Recreate the doc-store index for the current dataset, once per dataset.

        Any existing index with the same name is dropped first so every run
        starts from an empty index. Later calls are no-ops until
        ``initialized_index`` is reset.
        """
        if self.initialized_index:
            return
        if settings.docStoreConn.index_exist(self.index_name, self.kb_id):
            settings.docStoreConn.delete_idx(self.index_name, self.kb_id)
        settings.docStoreConn.create_idx(self.index_name, self.kb_id, vector_size)
        self.initialized_index = True

    def _new_doc(self, text):
        """Build a tokenized doc-store record for ``text`` with a fresh id."""
        d = {"id": get_uuid(), "kb_id": self.kb.id, "docnm_kwd": "xxxxx", "doc_id": "ksksks"}
        tokenize(d, text, "english")
        return d

    def _index_batch(self, docs):
        """Embed ``docs``, make sure the index exists and insert them.

        Does nothing for an empty batch, so a trailing flush can never create
        an index with vector size 0.
        """
        if not docs:
            return
        docs, vector_size = self.embedding(docs)
        self.init_index(vector_size)
        settings.docStoreConn.insert(docs, self.index_name, self.kb_id)

    def ms_marco_index(self, file_path, index_name):
        """Index the MS MARCO parquet files found in ``file_path``.

        Args:
            file_path: directory containing ``*.parquet`` files with ``query``
                and ``passages`` (``is_selected`` / ``passage_text``) columns.
            index_name: unused; the target index is ``self.index_name``.

        Returns:
            tuple: ``(qrels, texts)`` - ``{query: {doc_id: relevance}}`` and
            ``{doc_id: passage text}``.
        """
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        for fn in sorted(os.listdir(file_path)):
            if docs_count >= max_docs:
                break
            if not fn.endswith(".parquet"):
                continue
            data = pd.read_parquet(os.path.join(file_path, fn))
            for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
                if docs_count >= max_docs:
                    break
                row = data.iloc[i]
                query = row["query"]
                passages = row["passages"]
                for rel, text in zip(passages["is_selected"], passages["passage_text"]):
                    d = self._new_doc(text)
                    docs.append(d)
                    texts[d["id"]] = text
                    qrels[query][d["id"]] = int(rel)
                if len(docs) >= self._BATCH_SIZE:
                    docs_count += len(docs)
                    self._index_batch(docs)
                    docs = []
        # Flush whatever is left over from the last partial batch.
        self._index_batch(docs)
        return qrels, texts

    def trivia_qa_index(self, file_path, index_name):
        """Index the TriviaQA parquet files found in ``file_path``.

        Args:
            file_path: directory containing ``*.parquet`` files with ``question``
                and ``search_results`` (``rank`` / ``search_context``) columns.
            index_name: unused; the target index is ``self.index_name``.

        Returns:
            tuple: ``(qrels, texts)`` - ``{query: {doc_id: rank}}`` and
            ``{doc_id: search context text}``.
        """
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        for fn in sorted(os.listdir(file_path)):
            if docs_count >= max_docs:
                break
            if not fn.endswith(".parquet"):
                continue
            data = pd.read_parquet(os.path.join(file_path, fn))
            for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
                if docs_count >= max_docs:
                    break
                row = data.iloc[i]
                query = row["question"]
                results = row["search_results"]
                for rel, text in zip(results["rank"], results["search_context"]):
                    d = self._new_doc(text)
                    docs.append(d)
                    texts[d["id"]] = text
                    qrels[query][d["id"]] = int(rel)
                if len(docs) >= self._BATCH_SIZE:
                    docs_count += len(docs)
                    self._index_batch(docs)
                    docs = []
        # Flush whatever is left over from the last partial batch.
        self._index_batch(docs)
        return qrels, texts

    def miracl_index(self, file_path, corpus_path, index_name):
        """Index one MIRACL language split.

        Args:
            file_path: directory with ``topics`` and ``qrels`` sub-directories
                holding tab-separated files; files with "test" in their name
                are skipped.
            corpus_path: directory of JSON-lines corpus files with ``docid``
                and ``text`` fields.
            index_name: unused; the target index is ``self.index_name``.

        Returns:
            tuple: ``(qrels, texts)`` - ``{query: {doc_id: relevance}}`` and
            ``{doc_id: passage text}``.
        """
        # docid -> passage text, across every corpus file.
        corpus_total = {}
        for corpus_file in os.listdir(corpus_path):
            tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
            corpus_total.update(zip(tmp_data["docid"], tmp_data["text"]))

        # qid -> query text, across every non-test topics file.
        topics_total = {}
        topics_dir = os.path.join(file_path, "topics")
        for topics_file in os.listdir(topics_dir):
            if "test" in topics_file:
                continue
            tmp_data = pd.read_csv(os.path.join(topics_dir, topics_file), sep="\t", names=["qid", "query"])
            topics_total.update(zip(tmp_data["qid"], tmp_data["query"]))

        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        qrels_dir = os.path.join(file_path, "qrels")
        for qrels_file in os.listdir(qrels_dir):
            if "test" in qrels_file:
                continue
            if docs_count >= max_docs:
                break
            tmp_data = pd.read_csv(os.path.join(qrels_dir, qrels_file), sep="\t", names=["qid", "Q0", "docid", "relevance"])
            for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
                if docs_count >= max_docs:
                    break
                row = tmp_data.iloc[i]
                query = topics_total[row["qid"]]
                text = corpus_total[row["docid"]]
                rel = row["relevance"]
                d = self._new_doc(text)
                docs.append(d)
                texts[d["id"]] = text
                qrels[query][d["id"]] = int(rel)
                if len(docs) >= self._BATCH_SIZE:
                    docs_count += len(docs)
                    self._index_batch(docs)
                    docs = []
        # Flush whatever is left over from the last partial batch.
        self._index_batch(docs)
        return qrels, texts

    def save_results(self, qrels, run, texts, dataset, file_path):
        """Write the per-query report and the raw qrels/run JSON into ``file_path``.

        Produces ``<dataset>result.md`` (queries ordered from lowest to highest
        ndcg@10, each followed by its ten most similar retrieved chunks),
        ``<dataset>.qrels.json`` and ``<dataset>.run.json``.
        """
        keep_result = []
        for key in tqdm(list(run.keys()), desc="Calculating ndcg@10 for single query"):
            keep_result.append({"query": key, "qrel": qrels[key], "run": run[key], "ndcg@10": evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
        # Lowest-scoring queries first, so the weakest cases lead the report.
        keep_result.sort(key=lambda kk: kk["ndcg@10"])

        report_path = os.path.join(file_path, dataset + "result.md")
        with open(report_path, "w", encoding="utf-8") as f:
            f.write("## Score For Every Query\n")
            for entry in keep_result:
                f.write("### query: " + entry["query"] + " ndcg@10:" + str(entry["ndcg@10"]) + "\n")
                # Highest similarity first, so the slice below is the top 10 that
                # ndcg@10 actually scores.
                scores = sorted(entry["run"].items(), key=lambda kk: kk[1], reverse=True)
                for chunk_id, similarity in scores[:10]:
                    # NOTE(review): the value labelled "qrel" is the retrieval
                    # similarity from the run, not the relevance judgement.
                    f.write("- text: " + str(texts[chunk_id]) + "\t qrel: " + str(similarity) + "\n")

        # Context managers so both JSON files are flushed and closed.
        with open(os.path.join(file_path, dataset + ".qrels.json"), "w", encoding="utf-8") as f:
            json.dump(qrels, f, indent=2)
        with open(os.path.join(file_path, dataset + ".run.json"), "w", encoding="utf-8") as f:
            json.dump(run, f, indent=2)
        # Report the path that was really written.
        print(report_path, "Saved!")

    def _select_index(self, tenant_id):
        """Point the benchmark at ``tenant_id``'s index and force its re-creation."""
        self.tenant_id = tenant_id
        self.index_name = search.index_name(tenant_id)
        # Reset so a reused instance rebuilds the index for every dataset.
        self.initialized_index = False

    def _evaluate_and_save(self, qrels, texts, dataset, file_path):
        """Run retrieval for ``qrels``, print the aggregate metrics and save results."""
        run = self._get_retrieval(qrels)
        print(dataset, evaluate(Qrels(qrels), Run(run), self._METRICS))
        self.save_results(qrels, run, texts, dataset, file_path)

    def __call__(self, dataset, file_path, miracl_corpus=""):
        """Index and evaluate ``dataset`` located at ``file_path``.

        Args:
            dataset: one of ``ms_marco_v1.1``, ``trivia_qa`` or ``miracl``; any
                other value is ignored.
            file_path: dataset directory; results are written here as well.
            miracl_corpus: directory holding the ``miracl-corpus-v1.0-<lang>``
                folders; only used for ``miracl``.
        """
        if dataset == "ms_marco_v1.1":
            self._select_index("benchmark_ms_marco_v11")
            qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
            self._evaluate_and_save(qrels, texts, dataset, file_path)
        elif dataset == "trivia_qa":
            self._select_index("benchmark_trivia_qa")
            qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
            self._evaluate_and_save(qrels, texts, dataset, file_path)
        elif dataset == "miracl":
            for lang in self._MIRACL_LANGS:
                data_dir = os.path.join(file_path, "miracl-v1.0-" + lang)
                corpus_dir = os.path.join(miracl_corpus, "miracl-corpus-v1.0-" + lang)
                # Checked in this order; the first missing directory is reported.
                required = [data_dir, os.path.join(data_dir, "qrels"), os.path.join(data_dir, "topics"), corpus_dir]
                missing = next((p for p in required if not os.path.isdir(p)), None)
                if missing is not None:
                    print("Directory: " + missing + " not found!")
                    continue
                self._select_index("benchmark_miracl_" + lang)
                qrels, texts = self.miracl_index(data_dir, corpus_dir, "benchmark_miracl_" + lang)
                # NOTE(review): every language saves under the same dataset
                # name, so each one overwrites the previous language's files.
                self._evaluate_and_save(qrels, texts, dataset, file_path)
2026-07-03 12:53:39 +08:00
# Command-line entry point:
#   benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>]
if __name__ == "__main__":
    print("*****************RAGFlow Benchmark*****************")
    parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>])", description="RAGFlow Benchmark")
    parser.add_argument("max_docs", metavar="max_docs", type=int, help="max docs to evaluate")
    parser.add_argument("kb_id", metavar="kb_id", help="dataset id")
    parser.add_argument(
        "dataset",
        metavar="dataset",
        help="dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa>), miracl(https://huggingface.co/datasets/miracl/miracl",
    )
    parser.add_argument("dataset_path", metavar="dataset_path", help="dataset path")
    parser.add_argument("miracl_corpus_path", metavar="miracl_corpus_path", nargs="?", default="", help="miracl corpus path. Only needed when dataset is miracl")

    args = parser.parse_args()
    # Module-level limit read by the Benchmark indexing methods.
    max_docs = args.max_docs
    kb_id = args.kb_id
    ex = Benchmark(kb_id)

    dataset = args.dataset
    dataset_path = args.dataset_path

    if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
        ex(dataset, dataset_path)
    elif dataset == "miracl":
        # argparse returns a Namespace, which supports neither len() nor
        # indexing; the optional positional defaults to "" when omitted.
        if not args.miracl_corpus_path:
            print("Please input the correct parameters!")
            sys.exit(1)
        ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
    else:
        print("Dataset: ", dataset, " not supported!")