fix: support auto mode in table parser document metadata aggregation (#15780)

### What problem does this PR solve?

Table parser metadata aggregation previously ran only when
`table_column_mode` was set to `manual`. In auto mode (the default), all
columns default to the `"both"` role, so they should also be aggregated
into document-level metadata for UI/chat filters. Additionally, the task
snapshot could be stale: `table_column_names` are written to the KB
`parser_config` during `chunk()`, but the task may have been created
before that.

Changes:
- Renames `aggregate_table_manual_doc_metadata` →
`aggregate_table_doc_metadata`
- Supports both `"manual"` and `"auto"` values of `table_column_mode`
(defaulting to `"auto"`)
- Reloads `table_column_names` from KB DB when missing from task
snapshot
- Removes the manual-only guard in `task_executor` and in the refactored
`post_processor`
- Updates all tests to use the new function name and adds auto-mode test cases

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
euvre
2026-06-08 04:08:23 -07:00
committed by GitHub
parent 2c64febc93
commit d9a04ef702
5 changed files with 277 additions and 289 deletions

View File

@@ -16,20 +16,21 @@ import argparse
import time import time
from rag.svr.task_executor_refactor.task_manager import TaskManager from rag.svr.task_executor_refactor.task_manager import TaskManager
from rag.svr.task_executor_refactor.recording_context import timed_with_recording, get_recording_context, \ from rag.svr.task_executor_refactor.recording_context import timed_with_recording, get_recording_context, RecordingContext, set_recording_context, NullRecordingContext
RecordingContext, set_recording_context, NullRecordingContext
start_ts = time.time() start_ts = time.time()
# LiteLLM fetches a model cost map from GitHub during import unless this is set. # LiteLLM fetches a model cost map from GitHub during import unless this is set.
# Parser pods should not block startup on external network access. # Parser pods should not block startup on external network access.
import os import os
os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") # no internet, save about 10s os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") # no internet, save about 10s
from common.misc_utils import thread_pool_exec from common.misc_utils import thread_pool_exec
import asyncio import asyncio
import socket import socket
# from beartype import BeartypeConf # from beartype import BeartypeConf
# from beartype.claw import beartype_all # <-- you didn't sign up for this # from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
@@ -56,8 +57,7 @@ from rag.utils.raptor_utils import (
from common.log_utils import init_root_logger from common.log_utils import init_root_logger
from common.config_utils import show_configs from common.config_utils import show_configs
from rag.graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache from rag.graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text, \ from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text, gen_metadata
gen_metadata
import logging import logging
import os import os
from datetime import datetime from datetime import datetime
@@ -82,8 +82,7 @@ from api.db.services.file2document_service import File2DocumentService
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance
from common.versions import get_ragflow_version from common.versions import get_ragflow_version
from api.db.db_models import close_connection from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, email, tag
email, tag
from rag.nlp import search, rag_tokenizer, add_positions from rag.nlp import search, rag_tokenizer, add_positions
from rag.raptor import ( from rag.raptor import (
RAPTOR_TREE_BUILDER, RAPTOR_TREE_BUILDER,
@@ -103,7 +102,7 @@ from rag.svr.task_executor_limiter import (
from common import settings from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
from rag.utils.table_es_metadata import ( from rag.utils.table_es_metadata import (
aggregate_table_manual_doc_metadata, aggregate_table_doc_metadata,
merge_table_parser_config_from_kb, merge_table_parser_config_from_kb,
table_parser_strip_doc_metadata_keys, table_parser_strip_doc_metadata_keys,
) )
@@ -128,7 +127,7 @@ FACTORY = {
ParserType.AUDIO.value: audio, ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email, ParserType.EMAIL.value: email,
ParserType.KG.value: naive, ParserType.KG.value: naive,
ParserType.TAG.value: tag ParserType.TAG.value: tag,
} }
TASK_TYPE_TO_PIPELINE_TASK_TYPE = { TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
@@ -152,9 +151,10 @@ FAILED_TASKS = 0
CURRENT_TASKS = {} CURRENT_TASKS = {}
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120')) WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get("WORKER_HEARTBEAT_TIMEOUT", "120"))
stop_event = threading.Event() stop_event = threading.Event()
def signal_handler(sig, frame): def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...") logging.info("Received interrupt signal, shutting down...")
stop_event.set() stop_event.set()
@@ -269,8 +269,7 @@ async def get_storage_binary(bucket, name):
@timeout(60 * 80, 1) @timeout(60 * 80, 1)
async def build_chunks(task, progress_callback): async def build_chunks(task, progress_callback):
if task["size"] > settings.DOC_MAXIMUM_SIZE: if task["size"] > settings.DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" % set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" % (int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024)))
(int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024)))
get_recording_context().record("file_size_exceeded", True) get_recording_context().record("file_size_exceeded", True)
return [] return []
get_recording_context().record("file_size_exceeded", False) get_recording_context().record("file_size_exceeded", False)
@@ -286,8 +285,7 @@ async def build_chunks(task, progress_callback):
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"])) logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
except TimeoutError: except TimeoutError:
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
logging.exception( logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
raise raise
except Exception as e: except Exception as e:
if re.search("(No such file|not found)", str(e)): if re.search("(No such file|not found)", str(e)):
@@ -309,12 +307,11 @@ async def build_chunks(task, progress_callback):
# Record chunk configuration for comparison # Record chunk configuration for comparison
from common.float_utils import normalize_overlapped_percent from common.float_utils import normalize_overlapped_percent
chunk_config = { chunk_config = {
"parser_id": task["parser_id"], "parser_id": task["parser_id"],
"chunk_token_num": parser_config_for_chunk.get("chunk_token_num", 128), "chunk_token_num": parser_config_for_chunk.get("chunk_token_num", 128),
"overlapped_percent": normalize_overlapped_percent( "overlapped_percent": normalize_overlapped_percent(parser_config_for_chunk.get("overlapped_percent", 0)),
parser_config_for_chunk.get("overlapped_percent", 0)
),
"delimiter": parser_config_for_chunk.get("delimiter", "\n!?。;!?"), "delimiter": parser_config_for_chunk.get("delimiter", "\n!?。;!?"),
"from_page": task["from_page"], "from_page": task["from_page"],
"to_page": task["to_page"], "to_page": task["to_page"],
@@ -357,21 +354,14 @@ async def build_chunks(task, progress_callback):
if cks and cks[0].get("__outline__"): if cks and cks[0].get("__outline__"):
outline = cks[0].pop("__outline__") outline = cks[0].pop("__outline__")
try: try:
ret = DocMetadataService.update_document_metadata( ret = DocMetadataService.update_document_metadata(task["doc_id"], update_metadata_to({"outline": outline}, DocMetadataService.get_document_metadata(task["doc_id"]) or {}))
task["doc_id"],
update_metadata_to({"outline": outline},
DocMetadataService.get_document_metadata(task["doc_id"]) or {})
)
get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret) get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret)
logging.info("Persisted PDF outline (%d entries) for doc %s", len(outline), task["doc_id"]) logging.info("Persisted PDF outline (%d entries) for doc %s", len(outline), task["doc_id"])
except Exception as e: except Exception as e:
logging.warning("Failed to persist PDF outline for doc %s: %s", task["doc_id"], e) logging.warning("Failed to persist PDF outline for doc %s: %s", task["doc_id"], e)
docs = [] docs = []
doc = { doc = {"doc_id": task["doc_id"], "kb_id": str(task["kb_id"])}
"doc_id": task["doc_id"],
"kb_id": str(task["kb_id"])
}
if task["pagerank"]: if task["pagerank"]:
doc[PAGERANK_FLD] = int(task["pagerank"]) doc[PAGERANK_FLD] = int(task["pagerank"])
st = timer() st = timer()
@@ -381,8 +371,7 @@ async def build_chunks(task, progress_callback):
try: try:
d = copy.deepcopy(document) d = copy.deepcopy(document)
d.update(chunk) d.update(chunk)
d["id"] = xxhash.xxh64( d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
(chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19] d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp() d["create_timestamp_flt"] = datetime.now().timestamp()
@@ -398,8 +387,7 @@ async def build_chunks(task, progress_callback):
await image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"]) await image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
docs.append(d) docs.append(d)
except Exception: except Exception:
logging.exception( logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
raise raise
tasks = [] tasks = []
@@ -442,8 +430,7 @@ async def build_chunks(task, progress_callback):
tasks = [] tasks = []
for d in docs: for d in docs:
tasks.append( tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
try: try:
await asyncio.gather(*tasks, return_exceptions=False) await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e: except Exception as e:
@@ -479,8 +466,7 @@ async def build_chunks(task, progress_callback):
tasks = [] tasks = []
for d in docs: for d in docs:
tasks.append( tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
try: try:
await asyncio.gather(*tasks, return_exceptions=False) await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e: except Exception as e:
@@ -519,18 +505,14 @@ async def build_chunks(task, progress_callback):
metadata_conf = metadata_conf + built_in_metadata metadata_conf = metadata_conf + built_in_metadata
else: else:
metadata_conf = built_in_metadata metadata_conf = built_in_metadata
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", metadata_conf)
metadata_conf)
if not cached: if not cached:
if has_canceled(task["id"]): if has_canceled(task["id"]):
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return return
async with chat_limiter: async with chat_limiter:
cached = await gen_metadata(chat_mdl, cached = await gen_metadata(chat_mdl, turn2jsonschema(metadata_conf), d["content_with_weight"])
turn2jsonschema(metadata_conf), set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", metadata_conf)
d["content_with_weight"])
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata",
metadata_conf)
if cached: if cached:
d["metadata_obj"] = cached d["metadata_obj"] = cached
@@ -584,8 +566,7 @@ async def build_chunks(task, progress_callback):
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return None return None
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len( if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]}) examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else: else:
docs_to_tag.append(d) docs_to_tag.append(d)
@@ -598,7 +579,7 @@ async def build_chunks(task, progress_callback):
return return
picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples
if not picked_examples: if not picked_examples:
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}}) picked_examples.append({"content": "This is an example", TAG_FLD: {"example": 1}})
async with chat_limiter: async with chat_limiter:
cached = await content_tagging( cached = await content_tagging(
chat_mdl, chat_mdl,
@@ -643,13 +624,15 @@ def build_TOC(task, docs, progress_callback):
progress_callback(msg="Start to generate table of content ...") progress_callback(msg="Start to generate table of content ...")
chat_model_config = get_model_config_from_provider_instance(task["tenant_id"], LLMType.CHAT, task["llm_id"]) chat_model_config = get_model_config_from_provider_instance(task["tenant_id"], LLMType.CHAT, task["llm_id"])
chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])
docs = sorted(docs, key=lambda d: ( docs = sorted(
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), docs,
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0) key=lambda d: (
)) d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
toc: list[dict] = asyncio.run( d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0),
run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback)) ),
logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' ')) )
toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=" "))
for ii, item in enumerate(toc): for ii, item in enumerate(toc):
try: try:
chunk_val = item.pop("chunk_id", None) chunk_val = item.pop("chunk_id", None)
@@ -680,8 +663,7 @@ def build_TOC(task, docs, progress_callback):
d["toc_kwd"] = "toc" d["toc_kwd"] = "toc"
d["available_int"] = 0 d["available_int"] = 0
d["page_num_int"] = [100000000] d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64( d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
(d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d return d
return None return None
@@ -722,7 +704,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
cnts_batches = [] cnts_batches = []
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE): for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter: async with embed_limiter:
vts, c = await thread_pool_exec(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE]) vts, c = await thread_pool_exec(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
cnts_batches.append(vts) cnts_batches.append(vts)
tk_count += c tk_count += c
callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="") callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
@@ -775,8 +757,7 @@ async def run_dataflow(task: dict):
if not chunks: if not chunks:
get_recording_context().record("pipeline_output_count", 0) get_recording_context().record("pipeline_output_count", 0)
get_recording_context().record("pipeline_output_type", "empty") get_recording_context().record("pipeline_output_type", "empty")
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret) get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
return return
@@ -806,8 +787,7 @@ async def run_dataflow(task: dict):
# An empty normalized payload means "nothing parsed", so stop before embedding/indexing. # An empty normalized payload means "nothing parsed", so stop before embedding/indexing.
if not chunks: if not chunks:
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret) get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
return return
@@ -831,7 +811,7 @@ async def run_dataflow(task: dict):
prog = 0.8 prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter: async with embed_limiter:
vts, c = await thread_pool_exec(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE]) vts, c = await thread_pool_exec(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
vects_batches.append(vts) vects_batches.append(vts)
embedding_token_consumption += c embedding_token_consumption += c
prog += delta prog += delta
@@ -849,8 +829,7 @@ async def run_dataflow(task: dict):
raise raise
except Exception as e: except Exception as e:
set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}") set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret) get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
return return
@@ -900,24 +879,21 @@ async def run_dataflow(task: dict):
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...") set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_chunks(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000)) e = await insert_chunks(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e: if not e:
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret) get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
return return
time_cost = timer() - start_ts time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost)) set_progress(task_id, prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
ret = DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), ret = DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
task_time_cost)
get_recording_context().save_func_return_value("DocumentService.increment_chunk_num", ret) get_recording_context().save_func_return_value("DocumentService.increment_chunk_num", ret)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
task_time_cost))
get_recording_context().record("dataflow_chunks", chunks) get_recording_context().record("dataflow_chunks", chunks)
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
dsl=str(pipeline))
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret) get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
RAPTOR_METHOD_SEARCH_LIMIT = 10000 RAPTOR_METHOD_SEARCH_LIMIT = 10000
@@ -928,11 +904,7 @@ async def get_raptor_chunk_field_map(doc_id: str, tenant_id: str, kb_id: str) ->
async def search_fields(fields: list[str], condition: dict, order_by=None): async def search_fields(fields: list[str], condition: dict, order_by=None):
"""Search chunk fields in the current knowledge base.""" """Search chunk fields in the current knowledge base."""
res = await thread_pool_exec( res = await thread_pool_exec(settings.docStoreConn.search, fields, [], condition, [], order_by or OrderByExpr(), 0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id])
settings.docStoreConn.search,
fields, [], condition, [], order_by or OrderByExpr(),
0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id]
)
return settings.docStoreConn.get_fields(res, fields) return settings.docStoreConn.get_fields(res, fields)
primary = await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id, "raptor_kwd": ["raptor"]}) primary = await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id, "raptor_kwd": ["raptor"]})
@@ -963,12 +935,17 @@ async def get_raptor_chunk_methods(doc_id: str, tenant_id: str, kb_id: str) -> s
if methods: if methods:
logging.info( logging.info(
"Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist", "Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist",
doc_id, tenant_id, kb_id, sorted(methods), doc_id,
tenant_id,
kb_id,
sorted(methods),
) )
else: else:
logging.info( logging.info(
"Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)", "Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)",
doc_id, tenant_id, kb_id, doc_id,
tenant_id,
kb_id,
) )
return methods return methods
except Exception: except Exception:
@@ -987,7 +964,9 @@ async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_met
if keep_method is None: if keep_method is None:
logging.info( logging.info(
"delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)", "delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)",
doc_id, tenant_id, kb_id, doc_id,
tenant_id,
kb_id,
) )
ret = await thread_pool_exec( ret = await thread_pool_exec(
settings.docStoreConn.delete, settings.docStoreConn.delete,
@@ -1003,13 +982,20 @@ async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_met
if not chunk_ids: if not chunk_ids:
logging.debug( logging.debug(
"delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)", "delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)",
doc_id, tenant_id, kb_id, keep_method, doc_id,
tenant_id,
kb_id,
keep_method,
) )
return 0 return 0
logging.info( logging.info(
"delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)", "delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)",
len(chunk_ids), doc_id, tenant_id, kb_id, keep_method, len(chunk_ids),
doc_id,
tenant_id,
kb_id,
keep_method,
) )
ret = await thread_pool_exec( ret = await thread_pool_exec(
settings.docStoreConn.delete, settings.docStoreConn.delete,
@@ -1073,6 +1059,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
nonlocal tk_count, res nonlocal tk_count, res
logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did) logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did)
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor # Lazy load, save around 8s from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor # Lazy load, save around 8s
raptor = Raptor( raptor = Raptor(
raptor_config.get("max_cluster", 64), raptor_config.get("max_cluster", 64),
chat_mdl, chat_mdl,
@@ -1132,7 +1119,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
for x, doc_id in enumerate(doc_ids): for x, doc_id in enumerate(doc_ids):
if skip_raptor_doc(doc_id): if skip_raptor_doc(doc_id):
callback(prog=(x + 1.) / len(doc_ids)) callback(prog=(x + 1.0) / len(doc_ids))
continue continue
# CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store # CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store
existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"]) existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
@@ -1142,16 +1129,14 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
schedule_raptor_cleanup(doc_id, tree_builder) schedule_raptor_cleanup(doc_id, tree_builder)
callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.") callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.")
callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.") callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.")
callback(prog=(x + 1.) / len(doc_ids)) callback(prog=(x + 1.0) / len(doc_ids))
continue continue
if existing_methods: if existing_methods:
callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.") callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.")
chunks = [] chunks = []
skipped_chunks = 0 skipped_chunks = 0
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm], sort_by_position=True):
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
# Skip chunks that don't have the required vector field (may have been indexed with different embedding model) # Skip chunks that don't have the required vector field (may have been indexed with different embedding model)
if vctr_nm not in d or d[vctr_nm] is None: if vctr_nm not in d or d[vctr_nm] is None:
skipped_chunks += 1 skipped_chunks += 1
@@ -1173,7 +1158,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
has_file_level_target = True has_file_level_target = True
if existing_methods: if existing_methods:
schedule_raptor_cleanup(doc_id, tree_builder) schedule_raptor_cleanup(doc_id, tree_builder)
callback(prog=(x + 1.) / len(doc_ids)) callback(prog=(x + 1.0) / len(doc_ids))
if remove_dataset_summaries: if remove_dataset_summaries:
if has_file_level_target: if has_file_level_target:
@@ -1213,9 +1198,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
for doc_id in doc_ids: for doc_id in doc_ids:
if doc_id in skipped_doc_ids: if doc_id in skipped_doc_ids:
continue continue
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm], sort_by_position=True):
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
# Skip chunks that don't have the required vector field # Skip chunks that don't have the required vector field
if vctr_nm not in d or d[vctr_nm] is None: if vctr_nm not in d or d[vctr_nm] is None:
skipped_chunks += 1 skipped_chunks += 1
@@ -1283,14 +1266,17 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
mom_ck["available_int"] = 0 mom_ck["available_int"] = 0
flds = list(mom_ck.keys()) flds = list(mom_ck.keys())
for fld in flds: for fld in flds:
if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int", "create_timestamp_flt", "page_num_int", "top_int"]:
"position_int", "create_timestamp_flt", "page_num_int", "top_int"]:
del mom_ck[fld] del mom_ck[fld]
mothers.append(mom_ck) mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE): for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
ret = await thread_pool_exec(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE], ret = await thread_pool_exec(
search.index_name(task_tenant_id), task_dataset_id, ) settings.docStoreConn.insert,
mothers[b : b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id),
task_dataset_id,
)
get_recording_context().save_func_return_value("docStoreConn.insert", ret) get_recording_context().save_func_return_value("docStoreConn.insert", ret)
task_canceled = has_canceled(task_id) task_canceled = has_canceled(task_id)
if task_canceled: if task_canceled:
@@ -1298,17 +1284,18 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
return False return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE): for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await thread_pool_exec(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE], doc_store_result = await thread_pool_exec(
search.index_name(task_tenant_id), task_dataset_id, ) settings.docStoreConn.insert,
chunks[b : b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id),
task_dataset_id,
)
get_recording_context().save_func_return_value("docStoreConn.insert", doc_store_result) get_recording_context().save_func_return_value("docStoreConn.insert", doc_store_result)
task_canceled = has_canceled(task_id) task_canceled = has_canceled(task_id)
if task_canceled: if task_canceled:
# Roll back partial RAPTOR summary inserts so the next run is not # Roll back partial RAPTOR summary inserts so the next run is not
# mistaken for a completed checkpoint by get_raptor_chunk_methods. # mistaken for a completed checkpoint by get_raptor_chunk_methods.
raptor_ids_to_rollback = [ raptor_ids_to_rollback = [c["id"] for c in chunks[: b + settings.DOC_BULK_SIZE] if c.get("raptor_kwd") == "raptor"]
c["id"] for c in chunks[:b + settings.DOC_BULK_SIZE]
if c.get("raptor_kwd") == "raptor"
]
if raptor_ids_to_rollback: if raptor_ids_to_rollback:
try: try:
ret = await thread_pool_exec( ret = await thread_pool_exec(
@@ -1320,7 +1307,8 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
get_recording_context().save_func_return_value("docStoreConn.delete", ret) get_recording_context().save_func_return_value("docStoreConn.delete", ret)
logging.info( logging.info(
"insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)", "insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)",
len(raptor_ids_to_rollback), task_id, len(raptor_ids_to_rollback),
task_id,
) )
except Exception: except Exception:
logging.exception( logging.exception(
@@ -1335,15 +1323,19 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message) progress_callback(-1, msg=error_message)
raise Exception(error_message) raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + settings.DOC_BULK_SIZE]] chunk_ids = [chunk["id"] for chunk in chunks[: b + settings.DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids) chunk_ids_str = " ".join(chunk_ids)
try: try:
TaskService.update_chunk_ids(task_id, chunk_ids_str) TaskService.update_chunk_ids(task_id, chunk_ids_str)
get_recording_context().save_func_return_value("TaskService.update_chunk_ids", None) get_recording_context().save_func_return_value("TaskService.update_chunk_ids", None)
except DoesNotExist: except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.") logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await thread_pool_exec(settings.docStoreConn.delete, {"id": chunk_ids}, doc_store_result = await thread_pool_exec(
search.index_name(task_tenant_id), task_dataset_id, ) settings.docStoreConn.delete,
{"id": chunk_ids},
search.index_name(task_tenant_id),
task_dataset_id,
)
get_recording_context().save_func_return_value("docStoreConn.delete", doc_store_result) get_recording_context().save_func_return_value("docStoreConn.delete", doc_store_result)
tasks = [] tasks = []
for chunk_id in chunk_ids: for chunk_id in chunk_ids:
@@ -1383,8 +1375,8 @@ async def do_handle_task(task):
if not task.get("language"): if not task.get("language"):
logging.warning("Task %s has no language set, falling back to Chinese", task_id) logging.warning("Task %s has no language set, falling back to Chinese", task_id)
doc_task_llm_id = task["parser_config"].get("llm_id") or task["llm_id"] doc_task_llm_id = task["parser_config"].get("llm_id") or task["llm_id"]
kb_task_llm_id = task['kb_parser_config'].get("llm_id") or task["llm_id"] kb_task_llm_id = task["kb_parser_config"].get("llm_id") or task["llm_id"]
task['llm_id'] = kb_task_llm_id task["llm_id"] = kb_task_llm_id
task_dataset_id = task["kb_id"] task_dataset_id = task["kb_id"]
task_doc_id = task["doc_id"] task_doc_id = task["doc_id"]
task_document_name = task["name"] task_document_name = task["name"]
@@ -1411,14 +1403,14 @@ async def do_handle_task(task):
vts, _ = embedding_model.encode(["ok"]) vts, _ = embedding_model.encode(["ok"])
vector_size = len(vts[0]) vector_size = len(vts[0])
except Exception as e: except Exception as e:
error_message = f'Fail to bind embedding model: {str(e)}' error_message = f"Fail to bind embedding model: {str(e)}"
progress_callback(-1, msg=error_message) progress_callback(-1, msg=error_message)
logging.exception(error_message) logging.exception(error_message)
raise raise
init_kb(task, vector_size) init_kb(task, vector_size)
if task_type[:len("dataflow")] == "dataflow": if task_type[: len("dataflow")] == "dataflow":
await run_dataflow(task) await run_dataflow(task)
return return
@@ -1517,7 +1509,8 @@ async def do_handle_task(task):
with_community = graphrag_conf.get("community", False) with_community = graphrag_conf.get("community", False)
async with kg_limiter: async with kg_limiter:
# await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback) # await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
from rag.graphrag.general.index import run_graphrag_for_kb # Lazy load, save around 2s from rag.graphrag.general.index import run_graphrag_for_kb # Lazy load, save around 2s
result = await run_graphrag_for_kb( result = await run_graphrag_for_kb(
row=task, row=task,
doc_ids=task.get("doc_ids", []), doc_ids=task.get("doc_ids", []),
@@ -1539,7 +1532,7 @@ async def do_handle_task(task):
return return
else: else:
# Standard chunking methods # Standard chunking methods
task['llm_id'] = doc_task_llm_id task["llm_id"] = doc_task_llm_id
start_ts = timer() start_ts = timer()
chunks = await build_chunks(task, progress_callback) chunks = await build_chunks(task, progress_callback)
get_recording_context().record("chunks", chunks) get_recording_context().record("chunks", chunks)
@@ -1549,7 +1542,7 @@ async def do_handle_task(task):
# Record chunks array for content comparison (first, middle, last, random) # Record chunks array for content comparison (first, middle, last, random)
logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts)) logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
if not chunks: if not chunks:
progress_callback(1., msg=f"No chunk built from {task_document_name}") progress_callback(1.0, msg=f"No chunk built from {task_document_name}")
return return
progress_callback(msg="Generate {} chunks".format(len(chunks))) progress_callback(msg="Generate {} chunks".format(len(chunks)))
start_ts = timer() start_ts = timer()
@@ -1605,52 +1598,45 @@ async def do_handle_task(task):
if cleaned_chunks: if cleaned_chunks:
progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.") progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.")
logging.info( logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts))
"Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(
task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts
)
)
ret = DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) ret = DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
get_recording_context().save_func_return_value("DocumentService.increment_chunk_num", ret) get_recording_context().save_func_return_value("DocumentService.increment_chunk_num", ret)
# Table parser (manual): push metadata/both column values to document-level metadata for UI / chat filters # Table parser: push metadata/both column values to document-level metadata for UI / chat filters
if task.get("parser_id", "").lower() == "table": if task.get("parser_id", "").lower() == "table":
eff_pc = merge_table_parser_config_from_kb(task) eff_pc = merge_table_parser_config_from_kb(task)
logging.debug( logging.debug(f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}")
f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}" try:
) agg = aggregate_table_doc_metadata(chunks, task)
if eff_pc.get("table_column_mode") == "manual": logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}")
strip_keys = table_parser_strip_doc_metadata_keys(eff_pc)
existing = DocMetadataService.get_document_metadata(task_doc_id)
existing = existing if isinstance(existing, dict) else {}
preserved = {k: v for k, v in existing.items() if k not in strip_keys}
merged = update_metadata_to(dict(preserved), agg)
logging.debug(
f"[TABLE_META_DEBUG] calling update_document_metadata for doc_id={task_doc_id}, "
f"meta_fields keys={list(merged.keys())}, "
f"table_strip_key_count={len(strip_keys)}, agg_keys={list(agg.keys())}"
)
try: try:
agg = aggregate_table_manual_doc_metadata(chunks, task) ret = DocMetadataService.update_document_metadata(task_doc_id, merged)
logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}") get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret)
strip_keys = table_parser_strip_doc_metadata_keys(eff_pc) logging.debug("[TABLE_META_DEBUG] update_document_metadata succeeded")
existing = DocMetadataService.get_document_metadata(task_doc_id) except Exception as ue:
existing = existing if isinstance(existing, dict) else {} logging.error(
preserved = {k: v for k, v in existing.items() if k not in strip_keys} "update_document_metadata failed (table parser, doc_id=%s): %s",
merged = update_metadata_to(dict(preserved), agg)
logging.debug(
f"[TABLE_META_DEBUG] calling update_document_metadata for doc_id={task_doc_id}, "
f"meta_fields keys={list(merged.keys())}, "
f"table_strip_key_count={len(strip_keys)}, agg_keys={list(agg.keys())}"
)
try:
ret = DocMetadataService.update_document_metadata(task_doc_id, merged)
get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret)
logging.debug("[TABLE_META_DEBUG] update_document_metadata succeeded")
except Exception as ue:
logging.error(
"update_document_metadata failed (table parser, doc_id=%s): %s",
task_doc_id,
ue,
exc_info=True,
)
except Exception as e:
logging.exception(
"Table parser document metadata aggregation failed (doc_id=%s): %s",
task_doc_id, task_doc_id,
e, ue,
exc_info=True,
) )
except Exception as e:
logging.exception(
"Table parser document metadata aggregation failed (doc_id=%s): %s",
task_doc_id,
e,
)
progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts)) progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts))
@@ -1672,11 +1658,7 @@ async def do_handle_task(task):
task_time_cost = timer() - task_start_ts task_time_cost = timer() - task_start_ts
get_recording_context().record("task_status", "completed") get_recording_context().record("task_status", "completed")
progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost)) progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost))
logging.info( logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, task_time_cost))
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(
task_document_name, task_from_page, task_to_page, len(chunks), token_count, task_time_cost
)
)
finally: finally:
if toc_thread is not None and not toc_thread.done(): if toc_thread is not None and not toc_thread.done():
@@ -1697,8 +1679,7 @@ async def do_handle_task(task):
) )
get_recording_context().save_func_return_value("docStoreConn.delete", ret) get_recording_context().save_func_return_value("docStoreConn.delete", ret)
except Exception as e: except Exception as e:
logging.exception( logging.exception(f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled, exception: {e}")
f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled, exception: {e}")
async def handle_task(): async def handle_task():
@@ -1709,8 +1690,7 @@ async def handle_task():
return return
task_type = task["task_type"] task_type = task["task_type"]
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
PipelineTaskType.PARSE) or PipelineTaskType.PARSE
task_id = task["id"] task_id = task["id"]
try: try:
CURRENT_TASKS[task["id"]] = copy.deepcopy(task) CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
@@ -1718,20 +1698,18 @@ async def handle_task():
logging.info(f"TE_RUN_MODE is {run_mode}") logging.info(f"TE_RUN_MODE is {run_mode}")
# Check if dry-run comparison is enabled via environment variable # Check if dry-run comparison is enabled via environment variable
if run_mode == "1": # dry run mode - compare if run_mode == "1": # dry run mode - compare
set_recording_context(RecordingContext()) set_recording_context(RecordingContext())
await do_handle_task(task) # original execution await do_handle_task(task) # original execution
# dry run mode # dry run mode
logging.info(f"-----dry run task:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}") logging.info(f"-----dry run task:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}")
await TaskManager.dry_run_task(task, get_recording_context(), chat_limiter, minio_limiter, chunk_limiter, await TaskManager.dry_run_task(task, get_recording_context(), chat_limiter, minio_limiter, chunk_limiter, embed_limiter, kg_limiter, set_progress, has_canceled)
embed_limiter,kg_limiter, set_progress, has_canceled) elif run_mode == "0": # use refactor-ed version
elif run_mode == "0": # use refactor-ed version
# switch to refactor-ed version # switch to refactor-ed version
logging.info(f"-----run refactor-ed task executor:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}") logging.info(f"-----run refactor-ed task executor:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}")
set_recording_context(NullRecordingContext()) set_recording_context(NullRecordingContext())
await TaskManager.run_refactored_task(task, chat_limiter, minio_limiter, chunk_limiter, await TaskManager.run_refactored_task(task, chat_limiter, minio_limiter, chunk_limiter, embed_limiter, kg_limiter, set_progress, has_canceled)
embed_limiter,kg_limiter, set_progress, has_canceled) else: # original version
else: # original version
logging.info(f"-----run original task executor:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}") logging.info(f"-----run original task executor:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}")
set_recording_context(NullRecordingContext()) set_recording_context(NullRecordingContext())
await do_handle_task(task) await do_handle_task(task)
@@ -1742,9 +1720,7 @@ async def handle_task():
except TaskCanceledException as e: except TaskCanceledException as e:
DONE_TASKS += 1 DONE_TASKS += 1
CURRENT_TASKS.pop(task_id, None) CURRENT_TASKS.pop(task_id, None)
logging.info( logging.info(f"handle_task canceled for task {task_id}: {getattr(e, 'msg', str(e))}")
f"handle_task canceled for task {task_id}: {getattr(e, 'msg', str(e))}"
)
except Exception as e: except Exception as e:
FAILED_TASKS += 1 FAILED_TASKS += 1
CURRENT_TASKS.pop(task_id, None) CURRENT_TASKS.pop(task_id, None)
@@ -1752,7 +1728,7 @@ async def handle_task():
err_msg = str(e) err_msg = str(e)
while isinstance(e, exceptiongroup.ExceptionGroup): while isinstance(e, exceptiongroup.ExceptionGroup):
e = e.exceptions[0] e = e.exceptions[0]
err_msg += ' -- ' + str(e) err_msg += " -- " + str(e)
set_progress(task_id, prog=-1, msg=f"[Exception]: {err_msg}") set_progress(task_id, prog=-1, msg=f"[Exception]: {err_msg}")
except Exception as e: except Exception as e:
logging.exception(f"[Exception]: {str(e)}") logging.exception(f"[Exception]: {str(e)}")
@@ -1763,9 +1739,9 @@ async def handle_task():
referred_document_id = None referred_document_id = None
if task_type in ["graphrag", "raptor", "mindmap"]: if task_type in ["graphrag", "raptor", "mindmap"]:
referred_document_id = task["doc_ids"][0] referred_document_id = task["doc_ids"][0]
ret = PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", ret = PipelineOperationLogService.record_pipeline_operation(
task_type=pipeline_task_type, document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, task_id=task_id, referred_document_id=referred_document_id
task_id=task_id, referred_document_id=referred_document_id) )
get_recording_context().save_func_return_value("PipelineOperationLogService.record_pipeline_operation", ret) get_recording_context().save_func_return_value("PipelineOperationLogService.record_pipeline_operation", ret)
redis_msg.ack() redis_msg.ack()
@@ -1779,7 +1755,7 @@ async def get_server_ip() -> str:
return s.getsockname()[0] return s.getsockname()[0]
except Exception as e: except Exception as e:
logging.error(str(e)) logging.error(str(e))
return 'Unknown' return "Unknown"
async def report_status(): async def report_status():
@@ -1804,18 +1780,20 @@ async def report_status():
LAG_TASKS = int(group_info.get("lag", 0)) LAG_TASKS = int(group_info.get("lag", 0))
current = copy.deepcopy(CURRENT_TASKS) current = copy.deepcopy(CURRENT_TASKS)
heartbeat = json.dumps({ heartbeat = json.dumps(
"ip_address": ip_address, {
"pid": pid, "ip_address": ip_address,
"name": CONSUMER_NAME, "pid": pid,
"now": now.astimezone().isoformat(timespec="milliseconds"), "name": CONSUMER_NAME,
"boot_at": BOOT_AT, "now": now.astimezone().isoformat(timespec="milliseconds"),
"pending": PENDING_TASKS, "boot_at": BOOT_AT,
"lag": LAG_TASKS, "pending": PENDING_TASKS,
"done": DONE_TASKS, "lag": LAG_TASKS,
"failed": FAILED_TASKS, "done": DONE_TASKS,
"current": current, "failed": FAILED_TASKS,
}) "current": current,
}
)
# Report heartbeat to Redis # Report heartbeat to Redis
try: try:
@@ -1890,17 +1868,17 @@ async def main():
/___/_/ /_/\__, /\___/____/\__/_/\____/_/ /_/ /____/\___/_/ |___/\___/_/ /___/_/ /_/\__, /\___/____/\__/_/\____/_/ /_/ /____/\___/_/ |___/\___/_/
/____/ /____/
""") """)
logging.info(f'RAGFlow ingestion version: {get_ragflow_version()}') logging.info(f"RAGFlow ingestion version: {get_ragflow_version()}")
logging.info(f"ENABLE_DRY_RUN_COMPARISON: {os.environ.get("ENABLE_DRY_RUN_COMPARISON", "0")}") logging.info(f"ENABLE_DRY_RUN_COMPARISON: {os.environ.get('ENABLE_DRY_RUN_COMPARISON', '0')}")
show_configs() show_configs()
settings.init_settings() settings.init_settings()
settings.check_and_install_torch() settings.check_and_install_torch()
logging.info(f'default embedding config: {settings.EMBEDDING_CFG}') logging.info(f"default embedding config: {settings.EMBEDDING_CFG}")
settings.print_rag_settings() settings.print_rag_settings()
if sys.platform != "win32": if sys.platform != "win32":
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot) signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
signal.signal(signal.SIGUSR2, stop_tracemalloc) signal.signal(signal.SIGUSR2, stop_tracemalloc)
TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0")) TRACE_MALLOC_ENABLED = int(os.environ.get("TRACE_MALLOC_ENABLED", "0"))
if TRACE_MALLOC_ENABLED: if TRACE_MALLOC_ENABLED:
start_tracemalloc_and_snapshot(None, None) start_tracemalloc_and_snapshot(None, None)
@@ -1927,16 +1905,16 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
# Parse command line arguments (consistent with SAAS version) # Parse command line arguments (consistent with SAAS version)
parser = argparse.ArgumentParser(description='Task Executor') parser = argparse.ArgumentParser(description="Task Executor")
parser.add_argument("-i", "--index", type=str, default='0') parser.add_argument("-i", "--index", type=str, default="0")
parser.add_argument("-t", "--type", type=str, default="common", help="[common, graphrag, raptor, resume]") parser.add_argument("-t", "--type", type=str, default="common", help="[common, graphrag, raptor, resume]")
args = parser.parse_args() args = parser.parse_args()
# Update global variables # Update global variables
TASK_TYPE = args.type TASK_TYPE = args.type
TE_IDX = args.index TE_IDX = args.index
CONSUMER_NAME = f"task_executor_{TASK_TYPE}_{TE_IDX}" CONSUMER_NAME = f"task_executor_{TASK_TYPE}_{TE_IDX}"
faulthandler.enable() faulthandler.enable()
init_root_logger(CONSUMER_NAME) init_root_logger(CONSUMER_NAME)
try: try:

View File

@@ -28,11 +28,12 @@ from api.db.services.doc_metadata_service import DocMetadataService
from common.metadata_utils import update_metadata_to from common.metadata_utils import update_metadata_to
from rag.svr.task_executor_refactor.task_context import TaskContext from rag.svr.task_executor_refactor.task_context import TaskContext
from rag.utils.table_es_metadata import ( from rag.utils.table_es_metadata import (
aggregate_table_manual_doc_metadata, aggregate_table_doc_metadata,
merge_table_parser_config_from_kb, merge_table_parser_config_from_kb,
table_parser_strip_doc_metadata_keys, table_parser_strip_doc_metadata_keys,
) )
class PostProcessor: class PostProcessor:
"""Service for post-indexing operations. """Service for post-indexing operations.
@@ -69,15 +70,10 @@ class PostProcessor:
return return
eff_pc = merge_table_parser_config_from_kb(ctx.raw_task) eff_pc = merge_table_parser_config_from_kb(ctx.raw_task)
logging.debug( logging.debug(f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}")
f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}"
)
if eff_pc.get("table_column_mode") != "manual":
return
try: try:
agg = aggregate_table_manual_doc_metadata(chunks, ctx.raw_task) agg = aggregate_table_doc_metadata(chunks, ctx.raw_task)
logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}") logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}")
strip_keys = table_parser_strip_doc_metadata_keys(eff_pc) strip_keys = table_parser_strip_doc_metadata_keys(eff_pc)

View File

@@ -102,9 +102,7 @@ def _probe_es_typed_key_for_column(col: str, sample_chunk: dict) -> str | None:
return None return None
def _resolve_es_chunk_field_key( def _resolve_es_chunk_field_key(col: str, field_map: dict, sample_chunk: dict | None) -> tuple[str | None, str]:
col: str, field_map: dict, sample_chunk: dict | None
) -> tuple[str | None, str]:
"""Prefer field_map when key exists on chunk; else probe by suffix (matches table.py naming).""" """Prefer field_map when key exists on chunk; else probe by suffix (matches table.py naming)."""
tk_fm = _field_map_typed_key_for_column(field_map, col) if field_map else None tk_fm = _field_map_typed_key_for_column(field_map, col) if field_map else None
if sample_chunk: if sample_chunk:
@@ -153,35 +151,44 @@ def _es_field_value_to_doc_metadata(val, *, from_tks_fallback: bool) -> str | No
return _value_to_meta_string(val) return _value_to_meta_string(val)
def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict: def aggregate_table_doc_metadata(chunks: list, task: dict) -> dict:
""" """
Collect unique values per metadata/both column across chunks for document-level metadata. Collect unique values per metadata/both column across chunks for document-level metadata.
Used when table_column_mode == manual (parallel to LLM gen_metadata, no schema required). Works for both table_column_mode == manual and auto (where all columns default to "both").
""" """
logging.debug( logging.debug(f"[TABLE_META_DEBUG] aggregate_table_doc_metadata called with {len(chunks)} chunks")
f"[TABLE_META_DEBUG] aggregate_table_manual_doc_metadata called with {len(chunks)} chunks"
)
eff = merge_table_parser_config_from_kb(task) eff = merge_table_parser_config_from_kb(task)
if eff.get("table_column_mode") != "manual": mode = eff.get("table_column_mode") or "auto"
logging.debug( if mode not in ("manual", "auto"):
f"[TABLE_META_DEBUG] skip aggregate: table_column_mode={eff.get('table_column_mode')!r}" logging.debug(f"[TABLE_META_DEBUG] skip aggregate: table_column_mode={mode!r}")
)
return {} return {}
roles = eff.get("table_column_roles") or {} roles = eff.get("table_column_roles") or {}
table_column_names = eff.get("table_column_names") or [] table_column_names = eff.get("table_column_names") or []
# Reload table_column_names from KB if empty (chunk() writes them during parse,
# but the task snapshot may be stale)
if not table_column_names:
kb_id = task.get("kb_id")
if kb_id:
try:
KBS = _knowledgebase_service_cls()
ok, kb = KBS.get_by_id(kb_id)
if ok and kb:
fresh_names = (kb.parser_config or {}).get("table_column_names") or []
if fresh_names:
table_column_names = fresh_names
logging.debug(f"[TABLE_META_DEBUG] reloaded table_column_names from DB: {fresh_names}")
except Exception as e:
logging.debug(
"[TABLE_META_DEBUG] failed to reload table_column_names from DB: %s",
e,
exc_info=True,
)
if table_column_names: if table_column_names:
meta_cols = [ meta_cols = [col for col in table_column_names if roles.get(col, "both") in ("metadata", "both")]
col
for col in table_column_names
if roles.get(col, "both") in ("metadata", "both")
]
else: else:
meta_cols = [c for c, r in roles.items() if r in ("metadata", "both")] meta_cols = [c for c, r in roles.items() if r in ("metadata", "both")]
if not meta_cols: if not meta_cols:
logging.debug( logging.debug(f"[TABLE_META_DEBUG] skip aggregate: no metadata/both columns (table_column_names_present={bool(table_column_names)})")
"[TABLE_META_DEBUG] skip aggregate: no metadata/both columns "
f"(table_column_names_present={bool(table_column_names)})"
)
return {} return {}
fm = (task.get("kb_parser_config") or {}).get("field_map") or {} fm = (task.get("kb_parser_config") or {}).get("field_map") or {}
kb_id = task.get("kb_id") kb_id = task.get("kb_id")
@@ -194,14 +201,9 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
reloaded = fresh_pc.get("field_map") or {} reloaded = fresh_pc.get("field_map") or {}
if reloaded: if reloaded:
fm = reloaded fm = reloaded
logging.debug( logging.debug(f"[TABLE_META_DEBUG] reloaded field_map from DB: {len(fm)} entries")
f"[TABLE_META_DEBUG] reloaded field_map from DB: {len(fm)} entries"
)
else: else:
logging.debug( logging.debug("[TABLE_META_DEBUG] KB reload: parser_config has no field_map yet; will use ES key probe on chunk dicts if applicable")
"[TABLE_META_DEBUG] KB reload: parser_config has no field_map yet; "
"will use ES key probe on chunk dicts if applicable"
)
except Exception as e: except Exception as e:
logging.debug( logging.debug(
"[TABLE_META_DEBUG] failed to reload field_map from DB: %s", "[TABLE_META_DEBUG] failed to reload field_map from DB: %s",
@@ -209,21 +211,11 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
exc_info=True, exc_info=True,
) )
if not fm and not (settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE): if not fm and not (settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE):
logging.debug( logging.debug(f"[TABLE_META_DEBUG] field_map empty on task snapshot — will use ES key probe on chunk dicts; kb_parser_config keys={list((task.get('kb_parser_config') or {}).keys())}")
"[TABLE_META_DEBUG] field_map empty on task snapshot — will use ES key probe on chunk dicts; " logging.debug(f"[TABLE_META_DEBUG] meta_cols={meta_cols}, field_map entries={len(fm)}, infinity={settings.DOC_ENGINE_INFINITY}, oceanbase={settings.DOC_ENGINE_OCEANBASE}")
f"kb_parser_config keys={list((task.get('kb_parser_config') or {}).keys())}"
)
logging.debug(
f"[TABLE_META_DEBUG] meta_cols={meta_cols}, field_map entries={len(fm)}, "
f"infinity={settings.DOC_ENGINE_INFINITY}, oceanbase={settings.DOC_ENGINE_OCEANBASE}"
)
sample_ck = next((c for c in chunks if isinstance(c, dict)), None) sample_ck = next((c for c in chunks if isinstance(c, dict)), None)
if sample_ck: if sample_ck:
sk = [ sk = [k for k in sample_ck.keys() if not (str(k).startswith("q_") and str(k).endswith("_vec"))][:50]
k
for k in sample_ck.keys()
if not (str(k).startswith("q_") and str(k).endswith("_vec"))
][:50]
logging.debug(f"[TABLE_META_DEBUG] first chunk non-vector keys (sample): {sk}") logging.debug(f"[TABLE_META_DEBUG] first chunk non-vector keys (sample): {sk}")
es_col_keys: dict[str, tuple[str | None, str]] = {} es_col_keys: dict[str, tuple[str | None, str]] = {}
@@ -231,9 +223,7 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
for col in meta_cols: for col in meta_cols:
tk, src = _resolve_es_chunk_field_key(col, fm, sample_ck) tk, src = _resolve_es_chunk_field_key(col, fm, sample_ck)
es_col_keys[col] = (tk, src) es_col_keys[col] = (tk, src)
logging.debug( logging.debug(f"[TABLE_META_DEBUG] column '{col}' -> ES key {tk!r} (source={src})")
f"[TABLE_META_DEBUG] column '{col}' -> ES key {tk!r} (source={src})"
)
acc: dict[str, list] = {c: [] for c in meta_cols} acc: dict[str, list] = {c: [] for c in meta_cols}
@@ -255,9 +245,7 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
tk, _src = es_col_keys.get(col, (None, "none")) tk, _src = es_col_keys.get(col, (None, "none"))
if not tk: if not tk:
if i == 0: if i == 0:
logging.debug( logging.debug(f"[TABLE_META_DEBUG] no resolved ES key for column '{col}'")
f"[TABLE_META_DEBUG] no resolved ES key for column '{col}'"
)
continue continue
raw_k = _es_raw_field_key_from_typed(tk) raw_k = _es_raw_field_key_from_typed(tk)
val = None val = None
@@ -269,10 +257,7 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
from_tks = tk.endswith("_tks") from_tks = tk.endswith("_tks")
else: else:
if i == 0: if i == 0:
logging.debug( logging.debug(f"[TABLE_META_DEBUG] chunk missing ES field {tk!r}{' and ' + raw_k + ' (raw)' if raw_k else ''} for column '{col}'")
f"[TABLE_META_DEBUG] chunk missing ES field {tk!r}"
f"{' and ' + raw_k + ' (raw)' if raw_k else ''} for column '{col}'"
)
continue continue
s = _es_field_value_to_doc_metadata(val, from_tks_fallback=from_tks) s = _es_field_value_to_doc_metadata(val, from_tks_fallback=from_tks)
if s is not None: if s is not None:
@@ -289,8 +274,5 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
for col, vals in acc.items(): for col, vals in acc.items():
if vals: if vals:
out[col] = dedupe_list(vals) out[col] = dedupe_list(vals)
logging.debug( logging.debug(f"[TABLE_META_DEBUG] aggregated metadata dict keys={list(out.keys())}, sizes={[len(v) for v in out.values()]}")
f"[TABLE_META_DEBUG] aggregated metadata dict keys={list(out.keys())}, "
f"sizes={[len(v) for v in out.values()]}"
)
return out return out

View File

@@ -69,12 +69,13 @@ class TestPostProcessorProcessTableParserMetadata:
ctx.write_interceptor = None ctx.write_interceptor = None
chunks = [{"col_key": "val"}] chunks = [{"col_key": "val"}]
with patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge, \ with (
patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_manual_doc_metadata") as mock_agg, \ patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge,
patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip, \ patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_doc_metadata") as mock_agg,
patch("rag.svr.task_executor_refactor.post_processor.update_metadata_to") as mock_update, \ patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip,
patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta: patch("rag.svr.task_executor_refactor.post_processor.update_metadata_to") as mock_update,
patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta,
):
mock_merge.return_value = {"table_column_mode": "manual"} mock_merge.return_value = {"table_column_mode": "manual"}
mock_agg.return_value = {"col_key": ["val1", "val2"]} mock_agg.return_value = {"col_key": ["val1", "val2"]}
mock_strip.return_value = set() mock_strip.return_value = set()
@@ -95,11 +96,12 @@ class TestPostProcessorProcessTableParserMetadata:
ctx.raw_task = {} ctx.raw_task = {}
ctx.write_interceptor = MagicMock() ctx.write_interceptor = MagicMock()
with patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge, \ with (
patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_manual_doc_metadata") as mock_agg, \ patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge,
patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip, \ patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_doc_metadata") as mock_agg,
patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta: patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip,
patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta,
):
mock_merge.return_value = {"table_column_mode": "manual"} mock_merge.return_value = {"table_column_mode": "manual"}
mock_agg.return_value = {"key": ["v"]} mock_agg.return_value = {"key": ["v"]}
mock_strip.return_value = set() mock_strip.return_value = set()
@@ -108,9 +110,7 @@ class TestPostProcessorProcessTableParserMetadata:
service = PostProcessor(ctx=ctx) service = PostProcessor(ctx=ctx)
await service.process_table_parser_metadata("doc_1", []) await service.process_table_parser_metadata("doc_1", [])
ctx.write_interceptor.intercept.assert_called_once_with( ctx.write_interceptor.intercept.assert_called_once_with("DocMetadataService.update_document_metadata")
"DocMetadataService.update_document_metadata"
)
class TestPostProcessorInsertTocChunk: class TestPostProcessorInsertTocChunk:
@@ -160,9 +160,7 @@ class TestPostProcessorInsertTocChunk:
result = await service.insert_toc_chunk(toc_chunk, chunk_service) result = await service.insert_toc_chunk(toc_chunk, chunk_service)
assert result is True assert result is True
chunk_service.insert_chunks.assert_called_once_with( chunk_service.insert_chunks.assert_called_once_with("task_1", "tenant_1", "kb_1", [toc_chunk])
"task_1", "tenant_1", "kb_1", [toc_chunk]
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handles_insert_failure(self): async def test_handles_insert_failure(self):
@@ -179,4 +177,4 @@ class TestPostProcessorInsertTocChunk:
result = await service.insert_toc_chunk(toc_chunk, chunk_service) result = await service.insert_toc_chunk(toc_chunk, chunk_service)
assert result is False assert result is False

View File

@@ -14,11 +14,11 @@
# limitations under the License. # limitations under the License.
# #
"""Unit tests for aggregate_table_manual_doc_metadata.""" """Unit tests for aggregate_table_doc_metadata."""
import pytest import pytest
from rag.utils.table_es_metadata import aggregate_table_manual_doc_metadata, merge_table_parser_config_from_kb from rag.utils.table_es_metadata import aggregate_table_doc_metadata, merge_table_parser_config_from_kb
@pytest.fixture @pytest.fixture
@@ -73,30 +73,64 @@ class TestAggregateTableManualDocMetadata:
"category_tks": "y", "category_tks": "y",
}, },
] ]
out = aggregate_table_manual_doc_metadata(chunks, task) out = aggregate_table_doc_metadata(chunks, task)
assert out["country"] == ["Brazil", "Turkey"] assert out["country"] == ["Brazil", "Turkey"]
assert out["category"] == ["Economy", "Disaster"] assert out["category"] == ["Economy", "Disaster"]
def test_aggregate_auto_mode_returns_empty(self, es_engine): def test_aggregate_auto_mode_returns_data(self, es_engine):
task = { task = {
"parser_id": "table", "parser_id": "table",
"parser_config": {}, "parser_config": {},
"kb_parser_config": { "kb_parser_config": {
"table_column_mode": "auto", "table_column_mode": "auto",
"table_column_roles": {"country": "metadata"}, "table_column_names": ["country"],
"field_map": {"country_tks": "country"},
}, },
} }
assert aggregate_table_manual_doc_metadata([{"country_tks": "x"}], task) == {} chunks = [{"country_raw": "Brazil", "country_tks": "x"}]
out = aggregate_table_doc_metadata(chunks, task)
assert out == {"country": ["Brazil"]}
def test_aggregate_no_mode_returns_empty(self, es_engine): def test_aggregate_auto_mode_all_columns_both(self, es_engine):
task = { task = {
"parser_id": "table", "parser_id": "table",
"parser_config": {}, "parser_config": {},
"kb_parser_config": { "kb_parser_config": {
"table_column_roles": {"country": "metadata"}, "table_column_mode": "auto",
"table_column_names": ["country", "category"],
"field_map": {"country_tks": "country", "category_tks": "category"},
}, },
} }
assert aggregate_table_manual_doc_metadata([{}], task) == {} chunks = [
{"country_raw": "Brazil", "country_tks": "x", "category_raw": "Economy", "category_tks": "y"},
{"country_raw": "Turkey", "country_tks": "x", "category_raw": "Disaster", "category_tks": "y"},
]
out = aggregate_table_doc_metadata(chunks, task)
assert out["country"] == ["Brazil", "Turkey"]
assert out["category"] == ["Economy", "Disaster"]
def test_aggregate_no_mode_defaults_to_auto(self, es_engine):
"""When table_column_mode is missing, it defaults to 'auto'."""
task = {
"parser_id": "table",
"parser_config": {},
"kb_parser_config": {
"table_column_names": ["country"],
"field_map": {"country_tks": "country"},
},
}
chunks = [{"country_raw": "Brazil", "country_tks": "x"}]
out = aggregate_table_doc_metadata(chunks, task)
assert out == {"country": ["Brazil"]}
def test_aggregate_no_mode_no_columns_returns_empty(self, es_engine):
"""No mode and no column names/roles -> empty (nothing to aggregate)."""
task = {
"parser_id": "table",
"parser_config": {},
"kb_parser_config": {},
}
assert aggregate_table_doc_metadata([{}], task) == {}
def test_aggregate_no_metadata_columns(self, es_engine): def test_aggregate_no_metadata_columns(self, es_engine):
task = { task = {
@@ -108,14 +142,14 @@ class TestAggregateTableManualDocMetadata:
"table_column_names": ["country"], "table_column_names": ["country"],
}, },
} }
assert aggregate_table_manual_doc_metadata([{"country_tks": "x"}], task) == {} assert aggregate_table_doc_metadata([{"country_tks": "x"}], task) == {}
def test_aggregate_prefers_raw_over_tks(self, es_engine): def test_aggregate_prefers_raw_over_tks(self, es_engine):
task = _table_task() task = _table_task()
task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"} task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"}
task["kb_parser_config"]["table_column_names"] = ["country"] task["kb_parser_config"]["table_column_names"] = ["country"]
chunks = [{"country_raw": "Brazil", "country_tks": ["brazil"]}] chunks = [{"country_raw": "Brazil", "country_tks": ["brazil"]}]
out = aggregate_table_manual_doc_metadata(chunks, task) out = aggregate_table_doc_metadata(chunks, task)
assert out == {"country": ["Brazil"]} assert out == {"country": ["Brazil"]}
def test_aggregate_tks_fallback(self, es_engine): def test_aggregate_tks_fallback(self, es_engine):
@@ -123,7 +157,7 @@ class TestAggregateTableManualDocMetadata:
task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"} task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"}
task["kb_parser_config"]["table_column_names"] = ["country"] task["kb_parser_config"]["table_column_names"] = ["country"]
chunks = [{"country_tks": ["brazil"]}] chunks = [{"country_tks": ["brazil"]}]
out = aggregate_table_manual_doc_metadata(chunks, task) out = aggregate_table_doc_metadata(chunks, task)
assert out == {"country": ["brazil"]} assert out == {"country": ["brazil"]}
def test_aggregate_partial_roles_defaults_to_both(self, es_engine): def test_aggregate_partial_roles_defaults_to_both(self, es_engine):
@@ -138,7 +172,7 @@ class TestAggregateTableManualDocMetadata:
}, },
} }
chunks = [{"city_raw": "SP", "city_tks": "t", "country_tks": "x"}] chunks = [{"city_raw": "SP", "city_tks": "t", "country_tks": "x"}]
out = aggregate_table_manual_doc_metadata(chunks, task) out = aggregate_table_doc_metadata(chunks, task)
assert out == {"city": ["SP"]} assert out == {"city": ["SP"]}
assert "country" not in out assert "country" not in out
@@ -156,7 +190,7 @@ class TestAggregateTableManualDocMetadata:
chunks = [ chunks = [
{"country_raw": "BR", "city_raw": "SP", "country_tks": "x", "city_tks": "y"}, {"country_raw": "BR", "city_raw": "SP", "country_tks": "x", "city_tks": "y"},
] ]
out = aggregate_table_manual_doc_metadata(chunks, task) out = aggregate_table_doc_metadata(chunks, task)
assert "country" in out and "city" in out assert "country" in out and "city" in out
def test_aggregate_deduplicates_values(self, es_engine): def test_aggregate_deduplicates_values(self, es_engine):
@@ -168,7 +202,7 @@ class TestAggregateTableManualDocMetadata:
{"country_raw": "UK", "country_tks": "y"}, {"country_raw": "UK", "country_tks": "y"},
{"country_raw": "US", "country_tks": "x"}, {"country_raw": "US", "country_tks": "x"},
] ]
out = aggregate_table_manual_doc_metadata(chunks, task) out = aggregate_table_doc_metadata(chunks, task)
assert out["country"] == ["US", "UK"] assert out["country"] == ["US", "UK"]
def test_aggregate_kb_reload_field_map(self, es_engine, monkeypatch): def test_aggregate_kb_reload_field_map(self, es_engine, monkeypatch):
@@ -197,7 +231,7 @@ class TestAggregateTableManualDocMetadata:
"kb_id": "kb-1", "kb_id": "kb-1",
} }
chunks = [{"country_raw": "X", "country_tks": "t"}] chunks = [{"country_raw": "X", "country_tks": "t"}]
out = aggregate_table_manual_doc_metadata(chunks, task) out = aggregate_table_doc_metadata(chunks, task)
assert out == {"country": ["X"]} assert out == {"country": ["X"]}
def test_merge_infinity_chunk_data(self, infinity_engine): def test_merge_infinity_chunk_data(self, infinity_engine):
@@ -214,7 +248,7 @@ class TestAggregateTableManualDocMetadata:
{"chunk_data": {"country": "US"}}, {"chunk_data": {"country": "US"}},
{"chunk_data": {"country": "UK"}}, {"chunk_data": {"country": "UK"}},
] ]
out = aggregate_table_manual_doc_metadata(chunks, task) out = aggregate_table_doc_metadata(chunks, task)
assert out == {"country": ["US", "UK"]} assert out == {"country": ["US", "UK"]}