Files
ragflow/rag/graphrag/search.py
Yufeng He 0d836afd34 fix: keep max pagerank for repeated n-hop edges (#15696)
## Summary

Fixes #15695.

The Python GraphRAG path already accumulates similarity when several
N-hop paths produce the same edge, but PageRank was overwritten by the
last path. That makes ranking depend on path order for repeated edges.

This keeps the strongest PageRank seen for a repeated edge in the Python
implementation:

- `rag/graphrag/search.py`

The similarity score still accumulates exactly as before.

## To verify

- `python -m py_compile rag\graphrag\search.py`
- `git diff --check`
- `git diff --stat upstream/main` -> only `rag/graphrag/search.py`

I originally included the Go implementation too, but removed it after
maintainer feedback because the Go version is still under development
and not released yet.
2026-06-11 20:53:11 +08:00

351 lines
15 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
from collections import defaultdict
from copy import deepcopy
import json_repair
import pandas as pd
from common.misc_utils import get_uuid
from rag.graphrag.query_analyze_prompt import PROMPTS
from rag.graphrag.utils import get_entity_type2samples, get_llm_cache, set_llm_cache, get_relation
from common.token_utils import num_tokens_from_string
from rag.nlp.search import Dealer, index_name
from common.float_utils import get_float
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
class KGSearch(Dealer):
    """Knowledge-graph retrieval built on top of ``Dealer``.

    Entities, relations and community reports are read through
    ``self.dataStore``; an LLM is used to rewrite the question into
    answer-type keywords and entity mentions before searching.
    """

    @staticmethod
    def _plain_description_(raw):
        """Extract the ``"description"`` field from a JSON-encoded payload.

        Returns ``raw`` unchanged when it is not JSON or does not decode to an
        object with a ``get`` method, so plain-text descriptions pass through.
        """
        try:
            return json.loads(raw).get("description", "")
        except Exception:
            # Best effort: anything unparsable is shown as-is.
            return raw

    async def _chat(self, llm_bdl, system, history, gen_conf):
        """Ask the LLM, using the LLM cache before and after the call.

        Raises:
            Exception: when the model answer contains the ``**ERROR**`` marker;
                such answers are not written to the cache.
        """
        response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf)
        if response:
            return response
        response = await llm_bdl.async_chat(system, history, gen_conf)
        if "**ERROR**" in response:
            raise Exception(response)
        set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf)
        return response

    async def query_rewrite(self, llm, question, idxnms, kb_ids):
        """Turn ``question`` into answer-type keywords and entity mentions.

        The ``minirag_query2kwd`` prompt is filled with the question and the
        entity-type samples of the given indexes/knowledge bases, then sent to
        ``llm``.

        Returns:
            ``(type_keywords, entities_from_query)``; at most five entities
            are kept.

        Raises:
            Exception: whatever the fallback parsing raises, after logging it.
        """
        ty2ents = await get_entity_type2samples(idxnms, kb_ids)
        hint_prompt = PROMPTS["minirag_query2kwd"].format(
            query=question,
            TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
        result = await self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
        try:
            keywords_data = json_repair.loads(result)
            type_keywords = keywords_data.get("answer_type_keywords", [])
            entities_from_query = keywords_data.get("entities_from_query", [])[:5]
            return type_keywords, entities_from_query
        # NOTE(review): confirm that json_repair actually exposes
        # JSONDecodeError; this clause is only evaluated when loads() raises.
        except json_repair.JSONDecodeError:
            try:
                # Drop an echoed prompt and chat-role markers, then keep only
                # the first {...} span of what is left.
                result = result.replace(hint_prompt[:-1], '').replace('user', '').replace('model', '').strip()
                result = '{' + result.split('{')[1].split('}')[0] + '}'
                keywords_data = json_repair.loads(result)
                type_keywords = keywords_data.get("answer_type_keywords", [])
                entities_from_query = keywords_data.get("entities_from_query", [])[:5]
                return type_keywords, entities_from_query
            # Handle parsing error
            except Exception as e:
                logging.exception(f"JSON parsing error: {result} -> {e}")
                raise

    def _ent_info_from_(self, es_res, sim_thr=0.3):
        """Convert an entity search result into a dict keyed by entity name.

        Hits scoring below ``sim_thr`` are dropped, and so are hits that carry
        no entity name. Each value holds ``sim``, ``pagerank``, ``n_hop_ents``
        (decoded from ``n_hop_with_weight``) and ``description``.
        """
        res = {}
        flds = ["content_with_weight", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"]
        es_res = self.dataStore.get_fields(es_res, flds)
        for _, ent in es_res.items():
            # Treat explicit nulls like missing fields so the .get() defaults
            # below apply.
            for f in flds:
                if f in ent and ent[f] is None:
                    del ent[f]
            if get_float(ent.get("_score", 0)) < sim_thr:
                continue
            name = ent.get("entity_kwd")
            if isinstance(name, list):
                name = name[0] if name else None
            if name is None:
                # A null/empty name used to raise KeyError/IndexError here and
                # abort the whole retrieval; such a hit cannot be keyed.
                logging.warning("Skip knowledge-graph entity without entity_kwd.")
                continue
            ent["entity_kwd"] = name
            # n_hop_with_weight may be absent (older chunks) or an empty string
            # (the Infinity column default), neither of which json.loads handles.
            n_hop_raw = ent.get("n_hop_with_weight") or "[]"
            try:
                n_hop_ents = json.loads(n_hop_raw)
            except (json.JSONDecodeError, TypeError):
                logging.warning(f"Failed to parse n_hop_with_weight for entity {ent.get('entity_kwd')}: {n_hop_raw}")
                n_hop_ents = []
            res[name] = {
                "sim": get_float(ent.get("_score", 0)),
                "pagerank": get_float(ent.get("rank_flt", 0)),
                "n_hop_ents": n_hop_ents,
                "description": ent.get("content_with_weight", "{}")
            }
        return res

    def _relation_info_from_(self, es_res, sim_thr=0.3):
        """Convert a relation search result into a dict keyed by ``(f, t)``.

        The key is the sorted pair of endpoint names, so both directions of an
        edge share one entry. Hits scoring below ``sim_thr`` are dropped. Each
        value holds ``sim``, ``pagerank`` (taken from ``weight_int``) and
        ``description``.
        """
        res = {}
        es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
                                                    "weight_int"])
        for _, ent in es_res.items():
            if get_float(ent.get("_score", 0)) < sim_thr:
                continue
            f, t = ent["from_entity_kwd"], ent["to_entity_kwd"]
            # Unwrap list-valued endpoints *before* sorting: sorting a list
            # against a str raises TypeError.
            if isinstance(f, list):
                f = f[0]
            if isinstance(t, list):
                t = t[0]
            f, t = sorted([f, t])
            res[(f, t)] = {
                "sim": get_float(ent.get("_score", 0)),
                "pagerank": get_float(ent.get("weight_int", 0)),
                "description": ent["content_with_weight"]
            }
        return res

    def get_relevant_ents_by_keywords(self, keywords, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
        """Dense-vector search for entities matching the joined ``keywords``.

        Returns ``{}`` when ``keywords`` is empty; otherwise the
        ``_ent_info_from_`` view of at most ``N`` hits. ``filters`` is copied,
        not mutated.
        """
        if not keywords:
            return {}
        filters = deepcopy(filters)
        filters["knowledge_graph_kwd"] = "entity"
        # NOTE(review): 1024 is forwarded to get_vector; presumably the
        # candidate count. Confirm against Dealer.get_vector.
        matchDense = self.get_vector(", ".join(keywords), emb_mdl, 1024, sim_thr)
        es_res = self.dataStore.search(["content_with_weight", "entity_kwd", "rank_flt", "n_hop_with_weight"], [], filters, [matchDense],
                                       OrderByExpr(), 0, N,
                                       idxnms, kb_ids)
        return self._ent_info_from_(es_res, sim_thr)

    def get_relevant_relations_by_txt(self, txt, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
        """Dense-vector search for relations matching the text ``txt``.

        Returns ``{}`` when ``txt`` is empty; otherwise the
        ``_relation_info_from_`` view of at most ``N`` hits. ``filters`` is
        copied, not mutated.
        """
        if not txt:
            return {}
        filters = deepcopy(filters)
        filters["knowledge_graph_kwd"] = "relation"
        matchDense = self.get_vector(txt, emb_mdl, 1024, sim_thr)
        es_res = self.dataStore.search(
            ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"],
            [], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids)
        return self._relation_info_from_(es_res, sim_thr)

    def get_relevant_ents_by_types(self, types, filters, idxnms, kb_ids, N=56):
        """Fetch entities whose type is in ``types``, ordered by ``rank_flt`` descending.

        Returns ``{}`` when ``types`` is empty. No vector match is involved,
        so the similarity threshold is 0 and every returned hit is kept.
        """
        if not types:
            return {}
        filters = deepcopy(filters)
        filters["knowledge_graph_kwd"] = "entity"
        filters["entity_type_kwd"] = types
        ordr = OrderByExpr()
        ordr.desc("rank_flt")
        es_res = self.dataStore.search(["entity_kwd", "rank_flt"], [], filters, [], ordr, 0, N,
                                       idxnms, kb_ids)
        return self._ent_info_from_(es_res, 0)

    async def retrieval(self, question: str,
                        tenant_ids: str | list[str],
                        kb_ids: list[str],
                        emb_mdl,
                        llm,
                        max_token: int = 8196,
                        ent_topn: int = 6,
                        rel_topn: int = 6,
                        comm_topn: int = 1,
                        ent_sim_threshold: float = 0.3,
                        rel_sim_threshold: float = 0.3,
                        **kwargs
                        ):
        """Build one synthetic chunk of knowledge-graph context for ``question``.

        Steps: rewrite the question with the LLM, search entities (by keyword
        and by type) and relations, fold in the n-hop edges of the matched
        entities, rank entities and relations by ``sim * pagerank``, render the
        top ones as CSV and append the community reports.

        Args:
            question: the user question.
            tenant_ids: list of tenant ids, or one comma-separated string.
            kb_ids: knowledge-base ids to search.
            emb_mdl: embedding model used for the dense searches.
            llm: chat model used for the query rewrite.
            max_token: token budget consumed by the entity and relation rows.
            ent_topn / rel_topn / comm_topn: number of entities, relations and
                community reports to keep.
            ent_sim_threshold / rel_sim_threshold: minimum similarity of
                entity / relation hits.
            **kwargs: accepted and ignored.

        Returns:
            A chunk-shaped dict whose ``content_with_weight`` holds the
            rendered entities, relations and community reports.
        """
        qst = question
        filters = self.get_filters({"kb_ids": kb_ids})
        if isinstance(tenant_ids, str):
            tenant_ids = tenant_ids.split(",")
        idxnms = [index_name(tid) for tid in tenant_ids]
        ty_kwds = []
        try:
            ty_kwds, ents = await self.query_rewrite(llm, qst, idxnms, kb_ids)
            logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}")
        except Exception as e:
            # Rewrite is best effort: fall back to searching with the raw question.
            logging.exception(e)
            ents = [qst]

        ents_from_query = self.get_relevant_ents_by_keywords(ents, filters, idxnms, kb_ids, emb_mdl, ent_sim_threshold)
        ents_from_types = self.get_relevant_ents_by_types(ty_kwds, filters, idxnms, kb_ids, 10000)
        rels_from_txt = self.get_relevant_relations_by_txt(qst, filters, idxnms, kb_ids, emb_mdl, rel_sim_threshold)

        # Collect every edge on the n-hop paths of the matched entities.
        # NOTE(review): keys keep path order here, whereas relation keys are
        # sorted pairs; an edge walked as (B, A) never matches relation (A, B).
        nhop_pathes = defaultdict(dict)
        for _, ent in ents_from_query.items():
            nhops = ent.get("n_hop_ents", [])
            if not isinstance(nhops, list):
                logging.warning(f"Abnormal n_hop_ents: {nhops}")
                continue
            for nbr in nhops:
                path = nbr["path"]
                wts = nbr["weights"]
                for i in range(len(path) - 1):
                    f, t = path[i], path[i + 1]
                    # Similarity accumulates over repeated edges, damped by
                    # the hop distance from the matched entity.
                    if (f, t) in nhop_pathes:
                        nhop_pathes[(f, t)]["sim"] += ent["sim"] / (2 + i)
                    else:
                        nhop_pathes[(f, t)]["sim"] = ent["sim"] / (2 + i)
                    # Keep the strongest weight seen so the result does not
                    # depend on the order in which paths are visited.
                    nhop_pathes[(f, t)]["pagerank"] = max(
                        nhop_pathes[(f, t)].get("pagerank", 0), wts[i]
                    )

        logging.info("Retrieved entities: {}".format(list(ents_from_query.keys())))
        logging.info("Retrieved relations: {}".format(list(rels_from_txt.keys())))
        logging.info("Retrieved entities from types({}): {}".format(ty_kwds, list(ents_from_types.keys())))
        logging.info("Retrieved N-hops: {}".format(list(nhop_pathes.keys())))

        # P(E|Q) => P(E) * P(Q|E) => pagerank * sim
        for ent in ents_from_types.keys():
            if ent not in ents_from_query:
                continue
            ents_from_query[ent]["sim"] *= 2

        for (f, t) in rels_from_txt.keys():
            pair = tuple(sorted([f, t]))
            s = 0
            if pair in nhop_pathes:
                s += nhop_pathes[pair]["sim"]
                # Consumed here, so it is not added again as an n-hop-only relation.
                del nhop_pathes[pair]
            if f in ents_from_types:
                s += 1
            if t in ents_from_types:
                s += 1
            rels_from_txt[(f, t)]["sim"] *= s + 1

        # This is for the relations from n-hop but not by query search
        for (f, t) in nhop_pathes.keys():
            s = 0
            if f in ents_from_types:
                s += 1
            if t in ents_from_types:
                s += 1
            rels_from_txt[(f, t)] = {
                "sim": nhop_pathes[(f, t)]["sim"] * (s + 1),
                "pagerank": nhop_pathes[(f, t)]["pagerank"]
            }

        ents_from_query = sorted(ents_from_query.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
                          :ent_topn]
        rels_from_txt = sorted(rels_from_txt.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
                        :rel_topn]

        ents = []
        relas = []
        for n, ent in ents_from_query:
            ents.append({
                "Entity": n,
                "Score": "%.2f" % (ent["sim"] * ent["pagerank"]),
                # Same tolerant parsing as for relations: a payload that is
                # not JSON is shown verbatim instead of aborting retrieval.
                "Description": self._plain_description_(ent["description"]) if ent["description"] else ""
            })
            max_token -= num_tokens_from_string(str(ents[-1]))
            if max_token <= 0:
                # The row that overflowed the budget is dropped.
                ents = ents[:-1]
                break

        for (f, t), rel in rels_from_txt:
            if not rel.get("description"):
                # n-hop-only relations carry no description; look it up in the
                # first tenant that knows the edge.
                for tid in tenant_ids:
                    rela = await get_relation(tid, kb_ids, f, t)
                    if rela:
                        break
                else:
                    # No tenant returned the relation: skip it.
                    continue
                rel["description"] = rela["description"]
            desc = self._plain_description_(rel["description"])
            relas.append({
                "From Entity": f,
                "To Entity": t,
                "Score": "%.2f" % (rel["sim"] * rel["pagerank"]),
                "Description": desc
            })
            max_token -= num_tokens_from_string(str(relas[-1]))
            if max_token <= 0:
                relas = relas[:-1]
                break

        if ents:
            ents = "\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv())
        else:
            ents = ""
        if relas:
            relas = "\n---- Relations ----\n{}".format(pd.DataFrame(relas).to_csv())
        else:
            relas = ""

        return {
            "chunk_id": get_uuid(),
            "content_ltks": "",
            "content_with_weight": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms,
                                                                             comm_topn, max_token),
            "doc_id": "",
            "docnm_kwd": "Related content in Knowledge Graph",
            "kb_id": kb_ids,
            "important_kwd": [],
            "image_id": "",
            "similarity": 1.,
            "vector_similarity": 1.,
            "term_similarity": 0,
            "vector": [],
            "positions": [],
        }

    def _community_retrieval_(self, entities, condition, kb_ids, idxnms, topn, max_token):
        """Render the community reports that mention ``entities``.

        Up to ``topn`` reports are fetched, ordered by ``weight_flt``
        descending. Returns ``""`` when none is found.

        NOTE(review): ``max_token`` is decremented per report but never
        checked, so the budget does not limit this section.
        """
        ## Community retrieval
        fields = ["docnm_kwd", "content_with_weight"]
        odr = OrderByExpr()
        odr.desc("weight_flt")
        fltr = deepcopy(condition)
        fltr["knowledge_graph_kwd"] = "community_report"
        fltr["entities_kwd"] = entities
        comm_res = self.dataStore.search(fields, [], fltr, [],
                                         odr, 0, topn, idxnms, kb_ids)
        comm_res_fields = self.dataStore.get_fields(comm_res, fields)
        txts = []
        for ii, (_, row) in enumerate(comm_res_fields.items()):
            obj = json.loads(row["content_with_weight"])
            txts.append("# {}. {}\n## Content\n{}\n## Evidences\n{}\n".format(
                ii + 1, row["docnm_kwd"], obj["report"], obj["evidences"]))
            max_token -= num_tokens_from_string(str(txts[-1]))

        if not txts:
            return ""
        return "\n---- Community Report ----\n" + "\n".join(txts)
if __name__ == "__main__":
    # Manual smoke test: run one knowledge-graph retrieval for a tenant's
    # knowledge base and print the resulting chunk.
    import argparse
    from common.constants import LLMType
    from api.db.services.knowledgebase_service import KnowledgebaseService
    from api.db.services.llm_service import LLMBundle
    from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance

    settings.init_settings()
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
    parser.add_argument('-d', '--kb_id', default=False, help="Knowledge base ID", action='store', required=True)
    parser.add_argument('-q', '--question', default=False, help="Question", action='store', required=True)
    args = parser.parse_args()

    kb_id = args.kb_id

    # Chat model: the tenant's default; embedding model: the one the
    # knowledge base was configured with (kb.embd_id).
    llm_config = get_tenant_default_model_by_type(args.tenant_id, LLMType.CHAT)
    llm_bdl = LLMBundle(args.tenant_id, llm_config)
    _, kb = KnowledgebaseService.get_by_id(kb_id)
    embd_model_config = get_model_config_from_provider_instance(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
    embed_bdl = LLMBundle(args.tenant_id, embd_model_config)

    kg = KGSearch(settings.docStoreConn)
    # retrieval() takes the question text and the tenant id(s); it derives
    # the index names itself, so neither a request dict nor a precomputed
    # index name must be passed here.
    print(asyncio.run(kg.retrieval(args.question, kb.tenant_id, [kb_id], embed_bdl, llm_bdl)))