mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-05 02:55:48 +08:00
Refactor: enhance graphrag - part 2 (#14972)
### What problem does this PR solve? 1. Expose batch_chunk_token_size for configuration. 2. Retrieve a document's chunks when building the subgraph for that document, instead of retrieving the chunks of all documents at the beginning. 3. Retrieve all chunks for a document; the limit used to be hard-coded to 10000. 4. Delete the unused method run_graphrag. ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring Follow on: #14617
This commit is contained in:
@@ -439,6 +439,7 @@ def get_parser_config(chunk_method, parser_config):
|
||||
"category",
|
||||
],
|
||||
"method": "light",
|
||||
"batch_chunk_token_size": 4096,
|
||||
},
|
||||
"parent_child": {
|
||||
"use_parent_child": False,
|
||||
|
||||
@@ -362,6 +362,7 @@ class GraphragConfig(Base):
|
||||
method: Annotated[Literal["light", "general", "ner"], Field(default="light")]
|
||||
community: Annotated[bool, Field(default=False)]
|
||||
resolution: Annotated[bool, Field(default=False)]
|
||||
batch_chunk_token_size: Annotated[int, Field(default=4096, ge=512, le=8196)]
|
||||
|
||||
|
||||
class ParentChildConfig(Base):
|
||||
|
||||
@@ -54,6 +54,22 @@ from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
|
||||
|
||||
# Fallback value for the GraphRAG "batch_chunk_token_size" setting; passed as
# the default when that setting is read from the knowledge-base parser config.
DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE = 4096
|
||||
|
||||
|
||||
def _positive_int_config(config: dict, key: str, default: int) -> int:
|
||||
value = config.get(key, default)
|
||||
try:
|
||||
value = int(value)
|
||||
except (TypeError, ValueError):
|
||||
logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default)
|
||||
return default
|
||||
if value < 512 or value > 8196:
|
||||
logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default)
|
||||
return default
|
||||
return value
|
||||
|
||||
|
||||
def _select_extractor(graphrag_config: dict):
|
||||
"""Return the extractor class matching ``graphrag_config["method"]``.
|
||||
|
||||
@@ -121,100 +137,6 @@ async def load_subgraph_from_store(tenant_id: str, kb_id: str, doc_id: str):
|
||||
return None
|
||||
|
||||
|
||||
async def run_graphrag(
    row: dict,
    language,
    with_resolution: bool,
    with_community: bool,
    chat_model,
    embedding_model,
    callback,
):
    """Build the knowledge-graph subgraph for one document and merge it.

    The document's chunks are loaded from the retriever, a subgraph is
    generated from them, and the subgraph is merged while a per-knowledge-base
    distributed lock is held. Entity resolution and community extraction run
    afterwards when requested.

    Args:
        row: Task record. The keys read here are ``tenant_id``, ``kb_id``,
            ``doc_id``, ``kb_parser_config`` and ``id``.
        language: Passed through to ``generate_subgraph``.
        with_resolution: Run ``resolve_entities`` on the merged graph.
        with_community: Run ``extract_community`` on the merged graph.
        chat_model: Chat model handle passed to the extraction steps.
        embedding_model: Embedding model handle passed to the graph steps.
        callback: Progress reporter, invoked as ``callback(msg=...)``.

    Returns:
        None in every case.

    Raises:
        asyncio.TimeoutError: Re-raised after logging when subgraph generation
            exceeds the computed timeout.
    """
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    start = asyncio.get_running_loop().time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]

    chunks = [
        d["content_with_weight"]
        for d in settings.retriever.chunk_list(
            doc_id,
            tenant_id,
            [kb_id],
            max_count=10000,
            fields=["content_with_weight", "doc_id"],
            sort_by_position=True,
        )
    ]

    # The timeout scales with the chunk count only when the env switch is set;
    # otherwise it is effectively unbounded.
    if enable_timeout_assertion:
        timeout_sec = max(120, len(chunks) * 60 * 10)
    else:
        timeout_sec = 10000000000

    # Look the GraphRAG section up once so extractor selection and entity types
    # agree; previously a missing section raised KeyError on the second lookup.
    graphrag_config = row["kb_parser_config"].get("graphrag", {})

    try:
        subgraph = await asyncio.wait_for(
            generate_subgraph(
                _select_extractor(graphrag_config),
                tenant_id,
                kb_id,
                doc_id,
                chunks,
                language,
                graphrag_config.get("entity_types", []),
                chat_model,
                embedding_model,
                callback,
            ),
            timeout=timeout_sec,
        )
    except asyncio.TimeoutError:
        logging.error("generate_subgraph timeout")
        raise

    if not subgraph:
        return

    graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200)
    await graphrag_task_lock.spin_acquire()
    callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")

    try:
        subgraph_nodes = set(subgraph.nodes())
        new_graph = await merge_subgraph(
            tenant_id,
            kb_id,
            doc_id,
            subgraph,
            embedding_model,
            callback,
        )
        assert new_graph is not None

        # Nothing further was requested; the finally clause still releases the lock.
        if not with_resolution and not with_community:
            return

        if with_resolution:
            # NOTE(review): the lock is acquired again while already held;
            # presumably this refreshes its timeout before a long step —
            # confirm against RedisDistributedLock.spin_acquire.
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await resolve_entities(
                new_graph,
                subgraph_nodes,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
                task_id=row["id"],
            )
        if with_community:
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await extract_community(
                new_graph,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
                task_id=row["id"],
            )
    finally:
        graphrag_task_lock.release()
    now = asyncio.get_running_loop().time()
    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
    return
|
||||
|
||||
|
||||
async def run_graphrag_for_kb(
|
||||
row: dict,
|
||||
doc_ids: list[str],
|
||||
@@ -232,6 +154,8 @@ async def run_graphrag_for_kb(
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
start = asyncio.get_running_loop().time()
|
||||
fields_for_chunks = ["content_with_weight", "doc_id"]
|
||||
graphrag_config = kb_parser_config.get("graphrag", {})
|
||||
batch_chunk_token_size = _positive_int_config(graphrag_config, "batch_chunk_token_size", DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE)
|
||||
|
||||
if not doc_ids:
|
||||
logging.info(f"Fetching all docs for {kb_id}")
|
||||
@@ -259,21 +183,20 @@ async def run_graphrag_for_kb(
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
# DEBUG: Fetch all chunks first
|
||||
raw_chunks = list(settings.retriever.chunk_list(
|
||||
doc_id,
|
||||
tenant_id,
|
||||
[kb_id],
|
||||
max_count=10000, # FIX: Aumentar límite para procesar todos los chunks
|
||||
fields=fields_for_chunks,
|
||||
sort_by_position=True,
|
||||
retrieve_all=True
|
||||
))
|
||||
|
||||
callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")
|
||||
callback(msg=f"[GraphRAG] chunk_list returned {len(raw_chunks)} raw chunks for doc:{doc_id}")
|
||||
|
||||
for d in raw_chunks:
|
||||
content = d["content_with_weight"]
|
||||
if num_tokens_from_string(current_chunk + content) < 4096:
|
||||
if num_tokens_from_string(current_chunk + content) < batch_chunk_token_size:
|
||||
current_chunk += content
|
||||
else:
|
||||
if current_chunk:
|
||||
@@ -285,16 +208,7 @@ async def run_graphrag_for_kb(
|
||||
|
||||
return chunks
|
||||
|
||||
all_doc_chunks: dict[str, list[str]] = {}
|
||||
total_chunks = 0
|
||||
for doc_id in doc_ids:
|
||||
chunks = load_doc_chunks(doc_id)
|
||||
all_doc_chunks[doc_id] = chunks
|
||||
total_chunks += len(chunks)
|
||||
|
||||
if total_chunks == 0:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
||||
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_docs)
|
||||
|
||||
@@ -302,18 +216,13 @@ async def run_graphrag_for_kb(
|
||||
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
||||
|
||||
async def build_one(doc_id: str):
|
||||
nonlocal total_chunks
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled, stopping execution.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
chunks = all_doc_chunks.get(doc_id, [])
|
||||
if not chunks:
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
|
||||
return
|
||||
|
||||
kg_extractor = _select_extractor(kb_parser_config.get("graphrag", {}))
|
||||
|
||||
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||
kg_extractor = _select_extractor(graphrag_config)
|
||||
|
||||
async with semaphore:
|
||||
# CHECKPOINT: bounded by semaphore so doc-store lookups respect max_parallel_docs
|
||||
@@ -323,6 +232,13 @@ async def run_graphrag_for_kb(
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} subgraph found in store, skipping LLM extraction.")
|
||||
return
|
||||
try:
|
||||
chunks = load_doc_chunks(doc_id)
|
||||
total_chunks += len(chunks)
|
||||
if not chunks:
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
|
||||
return
|
||||
|
||||
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
|
||||
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
|
||||
|
||||
@@ -373,6 +289,10 @@ async def run_graphrag_for_kb(
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
if total_chunks == 0 and not subgraphs:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
||||
return {"ok_docs": [], "failed_docs": [(doc_id, "no available chunks") for doc_id in doc_ids], "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled after document processing.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
@@ -753,7 +753,13 @@ class Dealer:
|
||||
kb_ids: list[str], max_count=1024,
|
||||
offset=0,
|
||||
fields=["docnm_kwd", "content_with_weight", "img_id"],
|
||||
sort_by_position: bool = False):
|
||||
sort_by_position: bool = False,
|
||||
retrieve_all: bool = False):
|
||||
"""Return chunks for a document.
|
||||
|
||||
By default, preserve the historical max_count cap. When retrieve_all is
|
||||
True, keep paging until the doc store returns fewer rows than requested.
|
||||
"""
|
||||
condition = {"doc_id": doc_id}
|
||||
|
||||
fields_set = set(fields or [])
|
||||
@@ -771,8 +777,9 @@ class Dealer:
|
||||
|
||||
res = []
|
||||
bs = 128
|
||||
for p in range(offset, max_count, bs):
|
||||
limit = min(bs, max_count - p)
|
||||
p = offset
|
||||
while retrieve_all or p < max_count:
|
||||
limit = bs if retrieve_all else min(bs, max_count - p)
|
||||
if limit <= 0:
|
||||
break
|
||||
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, limit, index_name(tenant_id),
|
||||
@@ -785,6 +792,7 @@ class Dealer:
|
||||
chunk_count = len(dict_chunks)
|
||||
if chunk_count == 0 or chunk_count < limit:
|
||||
break
|
||||
p += limit
|
||||
return res
|
||||
|
||||
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
|
||||
|
||||
@@ -1390,6 +1390,7 @@ async def do_handle_task(task):
|
||||
"category",
|
||||
],
|
||||
"method": "light",
|
||||
"batch_chunk_token_size": 4096,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -65,6 +65,7 @@ DEFAULT_PARSER_CONFIG = {
|
||||
"category",
|
||||
],
|
||||
"method": "light",
|
||||
"batch_chunk_token_size": 4096,
|
||||
},
|
||||
"parent_child": {
|
||||
"use_parent_child": False,
|
||||
|
||||
@@ -387,6 +387,7 @@ DEFAULT_PARSER_CONFIG_FOR_TEST = {
|
||||
"category",
|
||||
],
|
||||
"method": "light",
|
||||
"batch_chunk_token_size": 4096,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -313,6 +313,7 @@ DEFAULT_PARSER_CONFIG_FOR_TEST = {
|
||||
"category",
|
||||
],
|
||||
"method": "light",
|
||||
"batch_chunk_token_size": 4096,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { FormLayout } from '@/constants/form';
|
||||
import { DocumentParserType } from '@/constants/knowledge';
|
||||
import { useTranslate } from '@/hooks/common-hooks';
|
||||
import { cn } from '@/lib/utils';
|
||||
@@ -12,6 +13,7 @@ import { useCallback, useMemo } from 'react';
|
||||
import { useFormContext, useWatch } from 'react-hook-form';
|
||||
import { EntityTypesFormField } from '../entity-types-form-field';
|
||||
import { FormContainer } from '../form-container';
|
||||
import { SliderInputFormField } from '../slider-input-form-field';
|
||||
import {
|
||||
FormControl,
|
||||
FormField,
|
||||
@@ -191,6 +193,19 @@ const GraphRagItems = ({
|
||||
)}
|
||||
/>
|
||||
|
||||
<SliderInputFormField
|
||||
name="parser_config.graphrag.batch_chunk_token_size"
|
||||
label={t('graphRagBatchChunkTokenSize')}
|
||||
tooltip={t('graphRagBatchChunkTokenSizeTip')}
|
||||
max={8196}
|
||||
min={512}
|
||||
step={1}
|
||||
defaultValue={4096}
|
||||
layout={FormLayout.Horizontal}
|
||||
sliderTestId="ds-settings-graph-batch-chunk-token-size-slider"
|
||||
numberInputTestId="ds-settings-graph-batch-chunk-token-size-input"
|
||||
></SliderInputFormField>
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="parser_config.graphrag.resolution"
|
||||
|
||||
@@ -89,6 +89,7 @@ interface Parentchild {
|
||||
}
|
||||
|
||||
interface Graphrag {
|
||||
batch_chunk_token_size?: number;
|
||||
entity_types: string[];
|
||||
method: string;
|
||||
use_graphrag: boolean;
|
||||
|
||||
@@ -67,6 +67,7 @@ interface Raptor {
|
||||
}
|
||||
|
||||
interface GraphRag {
|
||||
batch_chunk_token_size?: number;
|
||||
community?: boolean;
|
||||
entity_types?: string[];
|
||||
method?: string;
|
||||
|
||||
@@ -903,6 +903,9 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
|
||||
Light: (Default) Use prompts provided by github.com/HKUDS/LightRAG to extract entities and relationships. This option consumes fewer tokens, less memory, and fewer computational resources.</br>
|
||||
General: Use prompts provided by github.com/microsoft/graphrag to extract entities and relationships.</br>
|
||||
NER: Use spaCy NER and rule-based keyword extraction to extract entities and relationships. No LLM is required for extraction itself, making it fast and resource-efficient.`,
|
||||
graphRagBatchChunkTokenSize: 'Batch chunk token size',
|
||||
graphRagBatchChunkTokenSizeTip:
|
||||
'The token limit for each batch of chunks sent to the LLM for knowledge graph entity and relation extraction. Not applied to NER.',
|
||||
resolution: 'Entity resolution',
|
||||
resolutionTip: `An entity deduplication switch. When enabled, the LLM will combine similar entities - e.g., '2025' and 'the year of 2025', or 'IT' and 'Information Technology' - to construct a more accurate graph`,
|
||||
community: 'Community reports',
|
||||
|
||||
@@ -818,6 +818,9 @@ export default {
|
||||
graphRagMethodTip: `Light:实体和关系提取提示来自 GitHub - HKUDS/LightRAG:“LightRAG:简单快速的检索增强生成”<br>
|
||||
General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于图的模块化检索增强生成 (RAG) 系统<br>
|
||||
NER:使用 spaCy NER 和基于规则的关键词提取来抽取实体和关系,无需 LLM 参与提取过程,速度快且资源消耗低`,
|
||||
graphRagBatchChunkTokenSize: '批量chunk token 大小',
|
||||
graphRagBatchChunkTokenSizeTip:
|
||||
'发送给 LLM 进行知识图谱实体和关系抽取时,每批文本块的 token 上限。NER 不适用。',
|
||||
resolution: '实体归一化',
|
||||
resolutionTip: `解析过程会将具有相同含义的实体合并在一起,从而使知识图谱更简洁、更准确。应合并以下实体:特朗普总统、唐纳德·特朗普、唐纳德·J·特朗普、唐纳德·约翰·特朗普`,
|
||||
community: '社区报告生成',
|
||||
|
||||
@@ -70,6 +70,12 @@ export const formSchema = z
|
||||
method: z.string().optional(),
|
||||
resolution: z.boolean().optional(),
|
||||
community: z.boolean().optional(),
|
||||
batch_chunk_token_size: z
|
||||
.number()
|
||||
.int()
|
||||
.min(512)
|
||||
.max(8196)
|
||||
.optional(),
|
||||
})
|
||||
.refine(
|
||||
(data) => {
|
||||
|
||||
@@ -103,6 +103,7 @@ export default function DatasetSettings() {
|
||||
use_graphrag: true,
|
||||
entity_types: initialEntityTypes,
|
||||
method: MethodValue.Light,
|
||||
batch_chunk_token_size: 4096,
|
||||
},
|
||||
metadata: {
|
||||
type: 'object',
|
||||
|
||||
Reference in New Issue
Block a user