Refactor: enhance graphrag - part 2 (#14972)

### What problem does this PR solve?
1. expose batch_chunk_token_size for configuration
2. retrieve chunks when build subgraph for the doc, not retreive all
docs chunks at the begining
3. get all chunks for a document, used to be hard coded 10000
4. delete not used method run_graphrag

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring

Follow on: #14617
This commit is contained in:
Wang Qi
2026-05-18 16:10:21 +08:00
committed by GitHub
parent b12eaee38b
commit 13b422037f
15 changed files with 82 additions and 118 deletions

View File

@@ -439,6 +439,7 @@ def get_parser_config(chunk_method, parser_config):
"category",
],
"method": "light",
"batch_chunk_token_size": 4096,
},
"parent_child": {
"use_parent_child": False,

View File

@@ -362,6 +362,7 @@ class GraphragConfig(Base):
method: Annotated[Literal["light", "general", "ner"], Field(default="light")]
community: Annotated[bool, Field(default=False)]
resolution: Annotated[bool, Field(default=False)]
batch_chunk_token_size: Annotated[int, Field(default=4096, ge=512, le=8196)]
class ParentChildConfig(Base):

View File

@@ -54,6 +54,22 @@ from common import settings
from common.doc_store.doc_store_base import OrderByExpr
DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE = 4096
def _positive_int_config(config: dict, key: str, default: int) -> int:
value = config.get(key, default)
try:
value = int(value)
except (TypeError, ValueError):
logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default)
return default
if value < 512 or value > 8196:
logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default)
return default
return value
def _select_extractor(graphrag_config: dict):
"""Return the extractor class matching ``graphrag_config["method"]``.
@@ -121,100 +137,6 @@ async def load_subgraph_from_store(tenant_id: str, kb_id: str, doc_id: str):
return None
async def run_graphrag(
row: dict,
language,
with_resolution: bool,
with_community: bool,
chat_model,
embedding_model,
callback,
):
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = asyncio.get_running_loop().time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"])
timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
try:
subgraph = await asyncio.wait_for(
generate_subgraph(
_select_extractor(row["kb_parser_config"].get("graphrag", {})),
tenant_id,
kb_id,
doc_id,
chunks,
language,
row["kb_parser_config"]["graphrag"].get("entity_types", []),
chat_model,
embedding_model,
callback,
),
timeout=timeout_sec,
)
except asyncio.TimeoutError:
logging.error("generate_subgraph timeout")
raise
if not subgraph:
return
graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200)
await graphrag_task_lock.spin_acquire()
callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
try:
subgraph_nodes = set(subgraph.nodes())
new_graph = await merge_subgraph(
tenant_id,
kb_id,
doc_id,
subgraph,
embedding_model,
callback,
)
assert new_graph is not None
if not with_resolution and not with_community:
return
if with_resolution:
await graphrag_task_lock.spin_acquire()
callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
await resolve_entities(
new_graph,
subgraph_nodes,
tenant_id,
kb_id,
doc_id,
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
if with_community:
await graphrag_task_lock.spin_acquire()
callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
await extract_community(
new_graph,
tenant_id,
kb_id,
doc_id,
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
finally:
graphrag_task_lock.release()
now = asyncio.get_running_loop().time()
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
return
async def run_graphrag_for_kb(
row: dict,
doc_ids: list[str],
@@ -232,6 +154,8 @@ async def run_graphrag_for_kb(
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = asyncio.get_running_loop().time()
fields_for_chunks = ["content_with_weight", "doc_id"]
graphrag_config = kb_parser_config.get("graphrag", {})
batch_chunk_token_size = _positive_int_config(graphrag_config, "batch_chunk_token_size", DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE)
if not doc_ids:
logging.info(f"Fetching all docs for {kb_id}")
@@ -259,21 +183,20 @@ async def run_graphrag_for_kb(
chunks = []
current_chunk = ""
# DEBUG: Obtener todos los chunks primero
raw_chunks = list(settings.retriever.chunk_list(
doc_id,
tenant_id,
[kb_id],
max_count=10000, # FIX: Aumentar límite para procesar todos los chunks
fields=fields_for_chunks,
sort_by_position=True,
retrieve_all=True
))
callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")
callback(msg=f"[GraphRAG] chunk_list returned {len(raw_chunks)} raw chunks for doc:{doc_id}")
for d in raw_chunks:
content = d["content_with_weight"]
if num_tokens_from_string(current_chunk + content) < 4096:
if num_tokens_from_string(current_chunk + content) < batch_chunk_token_size:
current_chunk += content
else:
if current_chunk:
@@ -285,16 +208,7 @@ async def run_graphrag_for_kb(
return chunks
all_doc_chunks: dict[str, list[str]] = {}
total_chunks = 0
for doc_id in doc_ids:
chunks = load_doc_chunks(doc_id)
all_doc_chunks[doc_id] = chunks
total_chunks += len(chunks)
if total_chunks == 0:
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
semaphore = asyncio.Semaphore(max_parallel_docs)
@@ -302,18 +216,13 @@ async def run_graphrag_for_kb(
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
async def build_one(doc_id: str):
nonlocal total_chunks
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled, stopping execution.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
chunks = all_doc_chunks.get(doc_id, [])
if not chunks:
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
return
kg_extractor = _select_extractor(kb_parser_config.get("graphrag", {}))
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
kg_extractor = _select_extractor(graphrag_config)
async with semaphore:
# CHECKPOINT: bounded by semaphore so doc-store lookups respect max_parallel_docs
@@ -323,6 +232,13 @@ async def run_graphrag_for_kb(
callback(msg=f"[GraphRAG] doc:{doc_id} subgraph found in store, skipping LLM extraction.")
return
try:
chunks = load_doc_chunks(doc_id)
total_chunks += len(chunks)
if not chunks:
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
return
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
@@ -373,6 +289,10 @@ async def run_graphrag_for_kb(
await asyncio.gather(*tasks, return_exceptions=True)
raise
if total_chunks == 0 and not subgraphs:
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
return {"ok_docs": [], "failed_docs": [(doc_id, "no available chunks") for doc_id in doc_ids], "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled after document processing.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")

View File

@@ -753,7 +753,13 @@ class Dealer:
kb_ids: list[str], max_count=1024,
offset=0,
fields=["docnm_kwd", "content_with_weight", "img_id"],
sort_by_position: bool = False):
sort_by_position: bool = False,
retrieve_all: bool = False):
"""Return chunks for a document.
By default, preserve the historical max_count cap. When retrieve_all is
True, keep paging until the doc store returns fewer rows than requested.
"""
condition = {"doc_id": doc_id}
fields_set = set(fields or [])
@@ -771,8 +777,9 @@ class Dealer:
res = []
bs = 128
for p in range(offset, max_count, bs):
limit = min(bs, max_count - p)
p = offset
while retrieve_all or p < max_count:
limit = bs if retrieve_all else min(bs, max_count - p)
if limit <= 0:
break
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, limit, index_name(tenant_id),
@@ -785,6 +792,7 @@ class Dealer:
chunk_count = len(dict_chunks)
if chunk_count == 0 or chunk_count < limit:
break
p += limit
return res
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):

View File

@@ -1390,6 +1390,7 @@ async def do_handle_task(task):
"category",
],
"method": "light",
"batch_chunk_token_size": 4096,
}
}
)

View File

@@ -65,6 +65,7 @@ DEFAULT_PARSER_CONFIG = {
"category",
],
"method": "light",
"batch_chunk_token_size": 4096,
},
"parent_child": {
"use_parent_child": False,

View File

@@ -387,6 +387,7 @@ DEFAULT_PARSER_CONFIG_FOR_TEST = {
"category",
],
"method": "light",
"batch_chunk_token_size": 4096,
},
}

View File

@@ -313,6 +313,7 @@ DEFAULT_PARSER_CONFIG_FOR_TEST = {
"category",
],
"method": "light",
"batch_chunk_token_size": 4096,
},
}

View File

@@ -1,3 +1,4 @@
import { FormLayout } from '@/constants/form';
import { DocumentParserType } from '@/constants/knowledge';
import { useTranslate } from '@/hooks/common-hooks';
import { cn } from '@/lib/utils';
@@ -12,6 +13,7 @@ import { useCallback, useMemo } from 'react';
import { useFormContext, useWatch } from 'react-hook-form';
import { EntityTypesFormField } from '../entity-types-form-field';
import { FormContainer } from '../form-container';
import { SliderInputFormField } from '../slider-input-form-field';
import {
FormControl,
FormField,
@@ -191,6 +193,19 @@ const GraphRagItems = ({
)}
/>
<SliderInputFormField
name="parser_config.graphrag.batch_chunk_token_size"
label={t('graphRagBatchChunkTokenSize')}
tooltip={t('graphRagBatchChunkTokenSizeTip')}
max={8196}
min={512}
step={1}
defaultValue={4096}
layout={FormLayout.Horizontal}
sliderTestId="ds-settings-graph-batch-chunk-token-size-slider"
numberInputTestId="ds-settings-graph-batch-chunk-token-size-input"
></SliderInputFormField>
<FormField
control={form.control}
name="parser_config.graphrag.resolution"

View File

@@ -89,6 +89,7 @@ interface Parentchild {
}
interface Graphrag {
batch_chunk_token_size?: number;
entity_types: string[];
method: string;
use_graphrag: boolean;

View File

@@ -67,6 +67,7 @@ interface Raptor {
}
interface GraphRag {
batch_chunk_token_size?: number;
community?: boolean;
entity_types?: string[];
method?: string;

View File

@@ -903,6 +903,9 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
Light: (Default) Use prompts provided by github.com/HKUDS/LightRAG to extract entities and relationships. This option consumes fewer tokens, less memory, and fewer computational resources.</br>
General: Use prompts provided by github.com/microsoft/graphrag to extract entities and relationships.</br>
NER: Use spaCy NER and rule-based keyword extraction to extract entities and relationships. No LLM is required for extraction itself, making it fast and resource-efficient.`,
graphRagBatchChunkTokenSize: 'Batch chunk token size',
graphRagBatchChunkTokenSizeTip:
'The token limit for each batch of chunks sent to the LLM for knowledge graph entity and relation extraction. Not applied to NER.',
resolution: 'Entity resolution',
resolutionTip: `An entity deduplication switch. When enabled, the LLM will combine similar entities - e.g., '2025' and 'the year of 2025', or 'IT' and 'Information Technology' - to construct a more accurate graph`,
community: 'Community reports',

View File

@@ -818,6 +818,9 @@ export default {
graphRagMethodTip: `Light实体和关系提取提示来自 GitHub - HKUDS/LightRAG“LightRAG简单快速的检索增强生成”<br>
General实体和关系提取提示来自 GitHub - microsoft/graphrag基于图的模块化检索增强生成 (RAG) 系统<br>
NER使用 spaCy NER 和基于规则的关键词提取来抽取实体和关系,无需 LLM 参与提取过程,速度快且资源消耗低`,
graphRagBatchChunkTokenSize: '批量chunk token 大小',
graphRagBatchChunkTokenSizeTip:
'发送给 LLM 进行知识图谱实体和关系抽取时,每批文本块的 token 上限。NER 不适用。',
resolution: '实体归一化',
resolutionTip: `解析过程会将具有相同含义的实体合并在一起从而使知识图谱更简洁、更准确。应合并以下实体特朗普总统、唐纳德·特朗普、唐纳德·J·特朗普、唐纳德·约翰·特朗普`,
community: '社区报告生成',

View File

@@ -70,6 +70,12 @@ export const formSchema = z
method: z.string().optional(),
resolution: z.boolean().optional(),
community: z.boolean().optional(),
batch_chunk_token_size: z
.number()
.int()
.min(512)
.max(8196)
.optional(),
})
.refine(
(data) => {

View File

@@ -103,6 +103,7 @@ export default function DatasetSettings() {
use_graphrag: true,
entity_types: initialEntityTypes,
method: MethodValue.Light,
batch_chunk_token_size: 4096,
},
metadata: {
type: 'object',