From d9a04ef7022c36e36cbdb41e3745ce0bc80c6510 Mon Sep 17 00:00:00 2001 From: euvre <93761161+euvre@users.noreply.github.com> Date: Mon, 8 Jun 2026 04:08:23 -0700 Subject: [PATCH] fix: support auto mode in table parser document metadata aggregation (#15780) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Table parser metadata aggregation previously only ran when `table_column_mode` was set to `manual`. In auto mode (default), all columns default to `"both"` role, meaning they should also be aggregated into document-level metadata for UI/chat filters. Additionally, the task snapshot could be stale — `table_column_names` are written to KB `parser_config` during `chunk()` but the task may have been created before that. Changes: - Renames `aggregate_table_manual_doc_metadata` → `aggregate_table_doc_metadata` - Supports both `"manual"` and `"auto"` `table_column_mode` (defaults to `"auto"`) - Reloads `table_column_names` from KB DB when missing from task snapshot - Removes the manual-only guard in `task_executor` and refactored `post_processor` - Updates all tests with new function name and adds auto mode test cases ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/svr/task_executor.py | 360 ++++++++---------- .../task_executor_refactor/post_processor.py | 12 +- rag/utils/table_es_metadata.py | 92 ++--- .../test_post_processor.py | 34 +- .../svr/test_table_metadata_aggregation.py | 68 +++- 5 files changed, 277 insertions(+), 289 deletions(-) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index bf162dd4b8..a27714d6ef 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -16,20 +16,21 @@ import argparse import time from rag.svr.task_executor_refactor.task_manager import TaskManager -from rag.svr.task_executor_refactor.recording_context import timed_with_recording, get_recording_context, \ - RecordingContext, set_recording_context, NullRecordingContext +from rag.svr.task_executor_refactor.recording_context import timed_with_recording, get_recording_context, RecordingContext, set_recording_context, NullRecordingContext start_ts = time.time() # LiteLLM fetches a model cost map from GitHub during import unless this is set. # Parser pods should not block startup on external network access. import os + os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") # no internet, save about 10s from common.misc_utils import thread_pool_exec import asyncio import socket + # from beartype import BeartypeConf # from beartype.claw import beartype_all # <-- you didn't sign up for this # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code @@ -56,8 +57,7 @@ from rag.utils.raptor_utils import ( from common.log_utils import init_root_logger from common.config_utils import show_configs from rag.graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache -from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text, \ - gen_metadata +from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text, gen_metadata import logging import os from datetime import datetime @@ -82,8 +82,7 @@ from api.db.services.file2document_service import File2DocumentService from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from common.versions import get_ragflow_version from api.db.db_models import close_connection -from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ - email, tag +from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, email, tag from rag.nlp import search, rag_tokenizer, add_positions from rag.raptor import ( RAPTOR_TREE_BUILDER, @@ -103,7 +102,7 @@ from rag.svr.task_executor_limiter import ( from common import settings from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME from rag.utils.table_es_metadata import ( - aggregate_table_manual_doc_metadata, + aggregate_table_doc_metadata, merge_table_parser_config_from_kb, table_parser_strip_doc_metadata_keys, ) @@ -128,7 +127,7 @@ FACTORY = { ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email, ParserType.KG.value: naive, - ParserType.TAG.value: tag + ParserType.TAG.value: tag, } TASK_TYPE_TO_PIPELINE_TASK_TYPE = { @@ -152,9 +151,10 @@ FAILED_TASKS = 0 CURRENT_TASKS = {} -WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120')) +WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get("WORKER_HEARTBEAT_TIMEOUT", "120")) stop_event = threading.Event() + def signal_handler(sig, frame): logging.info("Received interrupt signal, shutting down...") stop_event.set() @@ -269,8 +269,7 @@ async def get_storage_binary(bucket, name): @timeout(60 * 80, 1) async def build_chunks(task, progress_callback): if task["size"] > settings.DOC_MAXIMUM_SIZE: - set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" % - (int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024))) + set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" % (int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024))) get_recording_context().record("file_size_exceeded", True) return [] get_recording_context().record("file_size_exceeded", False) @@ -286,8 +285,7 @@ async def build_chunks(task, progress_callback): logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"])) except TimeoutError: progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") - logging.exception( - "Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"])) + logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"])) raise except Exception as e: if re.search("(No such file|not found)", str(e)): @@ -309,12 +307,11 @@ async def build_chunks(task, progress_callback): # Record chunk configuration for comparison from common.float_utils import normalize_overlapped_percent + chunk_config = { "parser_id": task["parser_id"], "chunk_token_num": parser_config_for_chunk.get("chunk_token_num", 128), - "overlapped_percent": normalize_overlapped_percent( - parser_config_for_chunk.get("overlapped_percent", 0) - ), + "overlapped_percent": normalize_overlapped_percent(parser_config_for_chunk.get("overlapped_percent", 0)), "delimiter": parser_config_for_chunk.get("delimiter", "\n!?。;!?"), "from_page": task["from_page"], "to_page": task["to_page"], @@ -357,21 +354,14 @@ async def build_chunks(task, progress_callback): if cks and cks[0].get("__outline__"): outline = cks[0].pop("__outline__") try: - ret = DocMetadataService.update_document_metadata( - task["doc_id"], - update_metadata_to({"outline": outline}, - DocMetadataService.get_document_metadata(task["doc_id"]) or {}) - ) + ret = DocMetadataService.update_document_metadata(task["doc_id"], update_metadata_to({"outline": outline}, DocMetadataService.get_document_metadata(task["doc_id"]) or {})) get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret) logging.info("Persisted PDF outline (%d entries) for doc %s", len(outline), task["doc_id"]) except Exception as e: logging.warning("Failed to persist PDF outline for doc %s: %s", task["doc_id"], e) docs = [] - doc = { - "doc_id": task["doc_id"], - "kb_id": str(task["kb_id"]) - } + doc = {"doc_id": task["doc_id"], "kb_id": str(task["kb_id"])} if task["pagerank"]: doc[PAGERANK_FLD] = int(task["pagerank"]) st = timer() @@ -381,8 +371,7 @@ async def build_chunks(task, progress_callback): try: d = copy.deepcopy(document) d.update(chunk) - d["id"] = xxhash.xxh64( - (chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() + d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() d["create_time"] = str(datetime.now()).replace("T", " ")[:19] d["create_timestamp_flt"] = datetime.now().timestamp() @@ -398,8 +387,7 @@ async def build_chunks(task, progress_callback): await image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"]) docs.append(d) except Exception: - logging.exception( - "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"])) + logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"])) raise tasks = [] @@ -442,8 +430,7 @@ async def build_chunks(task, progress_callback): tasks = [] for d in docs: - tasks.append( - asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]))) + tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]))) try: await asyncio.gather(*tasks, return_exceptions=False) except Exception as e: @@ -479,8 +466,7 @@ async def build_chunks(task, progress_callback): tasks = [] for d in docs: - tasks.append( - asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]))) + tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]))) try: await asyncio.gather(*tasks, return_exceptions=False) except Exception as e: @@ -519,18 +505,14 @@ async def build_chunks(task, progress_callback): metadata_conf = metadata_conf + built_in_metadata else: metadata_conf = built_in_metadata - cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", - metadata_conf) + cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", metadata_conf) if not cached: if has_canceled(task["id"]): progress_callback(-1, msg="Task has been canceled.") return async with chat_limiter: - cached = await gen_metadata(chat_mdl, - turn2jsonschema(metadata_conf), - d["content_with_weight"]) - set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", - metadata_conf) + cached = await gen_metadata(chat_mdl, turn2jsonschema(metadata_conf), d["content_with_weight"]) + set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", metadata_conf) if cached: d["metadata_obj"] = cached @@ -584,8 +566,7 @@ async def build_chunks(task, progress_callback): if task_canceled: progress_callback(-1, msg="Task has been canceled.") return None - if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len( - d[TAG_FLD]) > 0: + if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0: examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]}) else: docs_to_tag.append(d) @@ -598,7 +579,7 @@ async def build_chunks(task, progress_callback): return picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples if not picked_examples: - picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}}) + picked_examples.append({"content": "This is an example", TAG_FLD: {"example": 1}}) async with chat_limiter: cached = await content_tagging( chat_mdl, @@ -643,13 +624,15 @@ def build_TOC(task, docs, progress_callback): progress_callback(msg="Start to generate table of content ...") chat_model_config = get_model_config_from_provider_instance(task["tenant_id"], LLMType.CHAT, task["llm_id"]) chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"]) - docs = sorted(docs, key=lambda d: ( - d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), - d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0) - )) - toc: list[dict] = asyncio.run( - run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback)) - logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' ')) + docs = sorted( + docs, + key=lambda d: ( + d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), + d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0), + ), + ) + toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback)) + logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=" ")) for ii, item in enumerate(toc): try: chunk_val = item.pop("chunk_id", None) @@ -680,8 +663,7 @@ def build_TOC(task, docs, progress_callback): d["toc_kwd"] = "toc" d["available_int"] = 0 d["page_num_int"] = [100000000] - d["id"] = xxhash.xxh64( - (d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() + d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() return d return None @@ -722,7 +704,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None): cnts_batches = [] for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await thread_pool_exec(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE]) + vts, c = await thread_pool_exec(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE]) cnts_batches.append(vts) tk_count += c callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="") @@ -775,8 +757,7 @@ async def run_dataflow(task: dict): if not chunks: get_recording_context().record("pipeline_output_count", 0) get_recording_context().record("pipeline_output_type", "empty") - ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, - task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) + ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret) return @@ -806,8 +787,7 @@ async def run_dataflow(task: dict): # An empty normalized payload means "nothing parsed", so stop before embedding/indexing. if not chunks: - ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, - task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) + ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret) return @@ -831,7 +811,7 @@ async def run_dataflow(task: dict): prog = 0.8 for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await thread_pool_exec(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE]) + vts, c = await thread_pool_exec(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE]) vects_batches.append(vts) embedding_token_consumption += c prog += delta @@ -849,8 +829,7 @@ async def run_dataflow(task: dict): raise except Exception as e: set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}") - ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, - task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) + ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret) return @@ -900,24 +879,21 @@ async def run_dataflow(task: dict): set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...") e = await insert_chunks(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000)) if not e: - ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, - task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) + ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret) return time_cost = timer() - start_ts task_time_cost = timer() - task_start_ts - set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost)) - ret = DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), - task_time_cost) + set_progress(task_id, prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost)) + ret = DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost) get_recording_context().save_func_return_value("DocumentService.increment_chunk_num", ret) - logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, - task_time_cost)) + logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost)) get_recording_context().record("dataflow_chunks", chunks) - ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, - dsl=str(pipeline)) + ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret) + RAPTOR_METHOD_SEARCH_LIMIT = 10000 @@ -928,11 +904,7 @@ async def get_raptor_chunk_field_map(doc_id: str, tenant_id: str, kb_id: str) -> async def search_fields(fields: list[str], condition: dict, order_by=None): """Search chunk fields in the current knowledge base.""" - res = await thread_pool_exec( - settings.docStoreConn.search, - fields, [], condition, [], order_by or OrderByExpr(), - 0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id] - ) + res = await thread_pool_exec(settings.docStoreConn.search, fields, [], condition, [], order_by or OrderByExpr(), 0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id]) return settings.docStoreConn.get_fields(res, fields) primary = await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id, "raptor_kwd": ["raptor"]}) @@ -963,12 +935,17 @@ async def get_raptor_chunk_methods(doc_id: str, tenant_id: str, kb_id: str) -> s if methods: logging.info( "Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist", - doc_id, tenant_id, kb_id, sorted(methods), + doc_id, + tenant_id, + kb_id, + sorted(methods), ) else: logging.info( "Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)", - doc_id, tenant_id, kb_id, + doc_id, + tenant_id, + kb_id, ) return methods except Exception: @@ -987,7 +964,9 @@ async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_met if keep_method is None: logging.info( "delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)", - doc_id, tenant_id, kb_id, + doc_id, + tenant_id, + kb_id, ) ret = await thread_pool_exec( settings.docStoreConn.delete, @@ -1003,13 +982,20 @@ async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_met if not chunk_ids: logging.debug( "delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)", - doc_id, tenant_id, kb_id, keep_method, + doc_id, + tenant_id, + kb_id, + keep_method, ) return 0 logging.info( "delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)", - len(chunk_ids), doc_id, tenant_id, kb_id, keep_method, + len(chunk_ids), + doc_id, + tenant_id, + kb_id, + keep_method, ) ret = await thread_pool_exec( settings.docStoreConn.delete, @@ -1073,6 +1059,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si nonlocal tk_count, res logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did) from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor # Lazy load, save around 8s + raptor = Raptor( raptor_config.get("max_cluster", 64), chat_mdl, @@ -1132,7 +1119,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si for x, doc_id in enumerate(doc_ids): if skip_raptor_doc(doc_id): - callback(prog=(x + 1.) / len(doc_ids)) + callback(prog=(x + 1.0) / len(doc_ids)) continue # CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"]) @@ -1142,16 +1129,14 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si schedule_raptor_cleanup(doc_id, tree_builder) callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.") callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.") - callback(prog=(x + 1.) / len(doc_ids)) + callback(prog=(x + 1.0) / len(doc_ids)) continue if existing_methods: callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.") chunks = [] skipped_chunks = 0 - for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], - fields=["content_with_weight", vctr_nm], - sort_by_position=True): + for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm], sort_by_position=True): # Skip chunks that don't have the required vector field (may have been indexed with different embedding model) if vctr_nm not in d or d[vctr_nm] is None: skipped_chunks += 1 @@ -1173,7 +1158,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si has_file_level_target = True if existing_methods: schedule_raptor_cleanup(doc_id, tree_builder) - callback(prog=(x + 1.) / len(doc_ids)) + callback(prog=(x + 1.0) / len(doc_ids)) if remove_dataset_summaries: if has_file_level_target: @@ -1213,9 +1198,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si for doc_id in doc_ids: if doc_id in skipped_doc_ids: continue - for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], - fields=["content_with_weight", vctr_nm], - sort_by_position=True): + for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm], sort_by_position=True): # Skip chunks that don't have the required vector field if vctr_nm not in d or d[vctr_nm] is None: skipped_chunks += 1 @@ -1283,14 +1266,17 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre mom_ck["available_int"] = 0 flds = list(mom_ck.keys()) for fld in flds: - if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", - "position_int", "create_timestamp_flt", "page_num_int", "top_int"]: + if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int", "create_timestamp_flt", "page_num_int", "top_int"]: del mom_ck[fld] mothers.append(mom_ck) for b in range(0, len(mothers), settings.DOC_BULK_SIZE): - ret = await thread_pool_exec(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE], - search.index_name(task_tenant_id), task_dataset_id, ) + ret = await thread_pool_exec( + settings.docStoreConn.insert, + mothers[b : b + settings.DOC_BULK_SIZE], + search.index_name(task_tenant_id), + task_dataset_id, + ) get_recording_context().save_func_return_value("docStoreConn.insert", ret) task_canceled = has_canceled(task_id) if task_canceled: @@ -1298,17 +1284,18 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre return False for b in range(0, len(chunks), settings.DOC_BULK_SIZE): - doc_store_result = await thread_pool_exec(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE], - search.index_name(task_tenant_id), task_dataset_id, ) + doc_store_result = await thread_pool_exec( + settings.docStoreConn.insert, + chunks[b : b + settings.DOC_BULK_SIZE], + search.index_name(task_tenant_id), + task_dataset_id, + ) get_recording_context().save_func_return_value("docStoreConn.insert", doc_store_result) task_canceled = has_canceled(task_id) if task_canceled: # Roll back partial RAPTOR summary inserts so the next run is not # mistaken for a completed checkpoint by get_raptor_chunk_methods. - raptor_ids_to_rollback = [ - c["id"] for c in chunks[:b + settings.DOC_BULK_SIZE] - if c.get("raptor_kwd") == "raptor" - ] + raptor_ids_to_rollback = [c["id"] for c in chunks[: b + settings.DOC_BULK_SIZE] if c.get("raptor_kwd") == "raptor"] if raptor_ids_to_rollback: try: ret = await thread_pool_exec( @@ -1320,7 +1307,8 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre get_recording_context().save_func_return_value("docStoreConn.delete", ret) logging.info( "insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)", - len(raptor_ids_to_rollback), task_id, + len(raptor_ids_to_rollback), + task_id, ) except Exception: logging.exception( @@ -1335,15 +1323,19 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" progress_callback(-1, msg=error_message) raise Exception(error_message) - chunk_ids = [chunk["id"] for chunk in chunks[:b + settings.DOC_BULK_SIZE]] + chunk_ids = [chunk["id"] for chunk in chunks[: b + settings.DOC_BULK_SIZE]] chunk_ids_str = " ".join(chunk_ids) try: TaskService.update_chunk_ids(task_id, chunk_ids_str) get_recording_context().save_func_return_value("TaskService.update_chunk_ids", None) except DoesNotExist: logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.") - doc_store_result = await thread_pool_exec(settings.docStoreConn.delete, {"id": chunk_ids}, - search.index_name(task_tenant_id), task_dataset_id, ) + doc_store_result = await thread_pool_exec( + settings.docStoreConn.delete, + {"id": chunk_ids}, + search.index_name(task_tenant_id), + task_dataset_id, + ) get_recording_context().save_func_return_value("docStoreConn.delete", doc_store_result) tasks = [] for chunk_id in chunk_ids: @@ -1383,8 +1375,8 @@ async def do_handle_task(task): if not task.get("language"): logging.warning("Task %s has no language set, falling back to Chinese", task_id) doc_task_llm_id = task["parser_config"].get("llm_id") or task["llm_id"] - kb_task_llm_id = task['kb_parser_config'].get("llm_id") or task["llm_id"] - task['llm_id'] = kb_task_llm_id + kb_task_llm_id = task["kb_parser_config"].get("llm_id") or task["llm_id"] + task["llm_id"] = kb_task_llm_id task_dataset_id = task["kb_id"] task_doc_id = task["doc_id"] task_document_name = task["name"] @@ -1411,14 +1403,14 @@ async def do_handle_task(task): vts, _ = embedding_model.encode(["ok"]) vector_size = len(vts[0]) except Exception as e: - error_message = f'Fail to bind embedding model: {str(e)}' + error_message = f"Fail to bind embedding model: {str(e)}" progress_callback(-1, msg=error_message) logging.exception(error_message) raise init_kb(task, vector_size) - if task_type[:len("dataflow")] == "dataflow": + if task_type[: len("dataflow")] == "dataflow": await run_dataflow(task) return @@ -1517,7 +1509,8 @@ async def do_handle_task(task): with_community = graphrag_conf.get("community", False) async with kg_limiter: # await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback) - from rag.graphrag.general.index import run_graphrag_for_kb # Lazy load, save around 2s + from rag.graphrag.general.index import run_graphrag_for_kb # Lazy load, save around 2s + result = await run_graphrag_for_kb( row=task, doc_ids=task.get("doc_ids", []), @@ -1539,7 +1532,7 @@ async def do_handle_task(task): return else: # Standard chunking methods - task['llm_id'] = doc_task_llm_id + task["llm_id"] = doc_task_llm_id start_ts = timer() chunks = await build_chunks(task, progress_callback) get_recording_context().record("chunks", chunks) @@ -1549,7 +1542,7 @@ async def do_handle_task(task): # Record chunks array for content comparison (first, middle, last, random) logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts)) if not chunks: - progress_callback(1., msg=f"No chunk built from {task_document_name}") + progress_callback(1.0, msg=f"No chunk built from {task_document_name}") return progress_callback(msg="Generate {} chunks".format(len(chunks))) start_ts = timer() @@ -1605,52 +1598,45 @@ async def do_handle_task(task): if cleaned_chunks: progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.") - logging.info( - "Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format( - task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts - ) - ) + logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts)) ret = DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) get_recording_context().save_func_return_value("DocumentService.increment_chunk_num", ret) - # Table parser (manual): push metadata/both column values to document-level metadata for UI / chat filters + # Table parser: push metadata/both column values to document-level metadata for UI / chat filters if task.get("parser_id", "").lower() == "table": eff_pc = merge_table_parser_config_from_kb(task) - logging.debug( - f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}" - ) - if eff_pc.get("table_column_mode") == "manual": + logging.debug(f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}") + try: + agg = aggregate_table_doc_metadata(chunks, task) + logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}") + strip_keys = table_parser_strip_doc_metadata_keys(eff_pc) + existing = DocMetadataService.get_document_metadata(task_doc_id) + existing = existing if isinstance(existing, dict) else {} + preserved = {k: v for k, v in existing.items() if k not in strip_keys} + merged = update_metadata_to(dict(preserved), agg) + logging.debug( + f"[TABLE_META_DEBUG] calling update_document_metadata for doc_id={task_doc_id}, " + f"meta_fields keys={list(merged.keys())}, " + f"table_strip_key_count={len(strip_keys)}, agg_keys={list(agg.keys())}" + ) try: - agg = aggregate_table_manual_doc_metadata(chunks, task) - logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}") - strip_keys = table_parser_strip_doc_metadata_keys(eff_pc) - existing = DocMetadataService.get_document_metadata(task_doc_id) - existing = existing if isinstance(existing, dict) else {} - preserved = {k: v for k, v in existing.items() if k not in strip_keys} - merged = update_metadata_to(dict(preserved), agg) - logging.debug( - f"[TABLE_META_DEBUG] calling update_document_metadata for doc_id={task_doc_id}, " - f"meta_fields keys={list(merged.keys())}, " - f"table_strip_key_count={len(strip_keys)}, agg_keys={list(agg.keys())}" - ) - try: - ret = DocMetadataService.update_document_metadata(task_doc_id, merged) - get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret) - logging.debug("[TABLE_META_DEBUG] update_document_metadata succeeded") - except Exception as ue: - logging.error( - "update_document_metadata failed (table parser, doc_id=%s): %s", - task_doc_id, - ue, - exc_info=True, - ) - except Exception as e: - logging.exception( - "Table parser document metadata aggregation failed (doc_id=%s): %s", + ret = DocMetadataService.update_document_metadata(task_doc_id, merged) + get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret) + logging.debug("[TABLE_META_DEBUG] update_document_metadata succeeded") + except Exception as ue: + logging.error( + "update_document_metadata failed (table parser, doc_id=%s): %s", task_doc_id, - e, + ue, + exc_info=True, ) + except Exception as e: + logging.exception( + "Table parser document metadata aggregation failed (doc_id=%s): %s", + task_doc_id, + e, + ) progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts)) @@ -1672,11 +1658,7 @@ async def do_handle_task(task): task_time_cost = timer() - task_start_ts get_recording_context().record("task_status", "completed") progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost)) - logging.info( - "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format( - task_document_name, task_from_page, task_to_page, len(chunks), token_count, task_time_cost - ) - ) + logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, task_time_cost)) finally: if toc_thread is not None and not toc_thread.done(): @@ -1697,8 +1679,7 @@ async def do_handle_task(task): ) get_recording_context().save_func_return_value("docStoreConn.delete", ret) except Exception as e: - logging.exception( - f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled, exception: {e}") + logging.exception(f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled, exception: {e}") async def handle_task(): @@ -1709,8 +1690,7 @@ async def handle_task(): return task_type = task["task_type"] - pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, - PipelineTaskType.PARSE) or PipelineTaskType.PARSE + pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE task_id = task["id"] try: CURRENT_TASKS[task["id"]] = copy.deepcopy(task) @@ -1718,20 +1698,18 @@ async def handle_task(): logging.info(f"TE_RUN_MODE is {run_mode}") # Check if dry-run comparison is enabled via environment variable - if run_mode == "1": # dry run mode - compare + if run_mode == "1": # dry run mode - compare set_recording_context(RecordingContext()) - await do_handle_task(task) # original execution + await do_handle_task(task) # original execution # dry run mode logging.info(f"-----dry run task:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}") - await TaskManager.dry_run_task(task, get_recording_context(), chat_limiter, minio_limiter, chunk_limiter, - embed_limiter,kg_limiter, set_progress, has_canceled) - elif run_mode == "0": # use refactor-ed version + await TaskManager.dry_run_task(task, get_recording_context(), chat_limiter, minio_limiter, chunk_limiter, embed_limiter, kg_limiter, set_progress, has_canceled) + elif run_mode == "0": # use refactor-ed version # switch to refactor-ed version logging.info(f"-----run refactor-ed task executor:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}") set_recording_context(NullRecordingContext()) - await TaskManager.run_refactored_task(task, chat_limiter, minio_limiter, chunk_limiter, - embed_limiter,kg_limiter, set_progress, has_canceled) - else: # original version + await TaskManager.run_refactored_task(task, chat_limiter, minio_limiter, chunk_limiter, embed_limiter, kg_limiter, set_progress, has_canceled) + else: # original version logging.info(f"-----run original task executor:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}") set_recording_context(NullRecordingContext()) await do_handle_task(task) @@ -1742,9 +1720,7 @@ async def handle_task(): except TaskCanceledException as e: DONE_TASKS += 1 CURRENT_TASKS.pop(task_id, None) - logging.info( - f"handle_task canceled for task {task_id}: {getattr(e, 'msg', str(e))}" - ) + logging.info(f"handle_task canceled for task {task_id}: {getattr(e, 'msg', str(e))}") except Exception as e: FAILED_TASKS += 1 CURRENT_TASKS.pop(task_id, None) @@ -1752,7 +1728,7 @@ async def handle_task(): err_msg = str(e) while isinstance(e, exceptiongroup.ExceptionGroup): e = e.exceptions[0] - err_msg += ' -- ' + str(e) + err_msg += " -- " + str(e) set_progress(task_id, prog=-1, msg=f"[Exception]: {err_msg}") except Exception as e: logging.exception(f"[Exception]: {str(e)}") @@ -1763,9 +1739,9 @@ async def handle_task(): referred_document_id = None if task_type in ["graphrag", "raptor", "mindmap"]: referred_document_id = task["doc_ids"][0] - ret = PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", - task_type=pipeline_task_type, - task_id=task_id, referred_document_id=referred_document_id) + ret = PipelineOperationLogService.record_pipeline_operation( + document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, task_id=task_id, referred_document_id=referred_document_id + ) get_recording_context().save_func_return_value("PipelineOperationLogService.record_pipeline_operation", ret) redis_msg.ack() @@ -1779,7 +1755,7 @@ async def get_server_ip() -> str: return s.getsockname()[0] except Exception as e: logging.error(str(e)) - return 'Unknown' + return "Unknown" async def report_status(): @@ -1804,18 +1780,20 @@ async def report_status(): LAG_TASKS = int(group_info.get("lag", 0)) current = copy.deepcopy(CURRENT_TASKS) - heartbeat = json.dumps({ - "ip_address": ip_address, - "pid": pid, - "name": CONSUMER_NAME, - "now": now.astimezone().isoformat(timespec="milliseconds"), - "boot_at": BOOT_AT, - "pending": PENDING_TASKS, - "lag": LAG_TASKS, - "done": DONE_TASKS, - "failed": FAILED_TASKS, - "current": current, - }) + heartbeat = json.dumps( + { + "ip_address": ip_address, + "pid": pid, + "name": CONSUMER_NAME, + "now": now.astimezone().isoformat(timespec="milliseconds"), + "boot_at": BOOT_AT, + "pending": PENDING_TASKS, + "lag": LAG_TASKS, + "done": DONE_TASKS, + "failed": FAILED_TASKS, + "current": current, + } + ) # Report heartbeat to Redis try: @@ -1890,17 +1868,17 @@ async def main(): /___/_/ /_/\__, /\___/____/\__/_/\____/_/ /_/ /____/\___/_/ |___/\___/_/ /____/ """) - logging.info(f'RAGFlow ingestion version: {get_ragflow_version()}') - logging.info(f"ENABLE_DRY_RUN_COMPARISON: {os.environ.get("ENABLE_DRY_RUN_COMPARISON", "0")}") + logging.info(f"RAGFlow ingestion version: {get_ragflow_version()}") + logging.info(f"ENABLE_DRY_RUN_COMPARISON: {os.environ.get('ENABLE_DRY_RUN_COMPARISON', '0')}") show_configs() settings.init_settings() settings.check_and_install_torch() - logging.info(f'default embedding config: {settings.EMBEDDING_CFG}') + logging.info(f"default embedding config: {settings.EMBEDDING_CFG}") settings.print_rag_settings() if sys.platform != "win32": signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot) signal.signal(signal.SIGUSR2, stop_tracemalloc) - TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0")) + TRACE_MALLOC_ENABLED = int(os.environ.get("TRACE_MALLOC_ENABLED", "0")) if TRACE_MALLOC_ENABLED: start_tracemalloc_and_snapshot(None, None) @@ -1927,16 +1905,16 @@ async def main(): if __name__ == "__main__": # Parse command line arguments (consistent with SAAS version) - parser = argparse.ArgumentParser(description='Task Executor') - parser.add_argument("-i", "--index", type=str, default='0') + parser = argparse.ArgumentParser(description="Task Executor") + parser.add_argument("-i", "--index", type=str, default="0") parser.add_argument("-t", "--type", type=str, default="common", help="[common, graphrag, raptor, resume]") args = parser.parse_args() - + # Update global variables TASK_TYPE = args.type TE_IDX = args.index CONSUMER_NAME = f"task_executor_{TASK_TYPE}_{TE_IDX}" - + faulthandler.enable() init_root_logger(CONSUMER_NAME) try: diff --git a/rag/svr/task_executor_refactor/post_processor.py b/rag/svr/task_executor_refactor/post_processor.py index 31a73a5bed..73be2cd610 100644 --- a/rag/svr/task_executor_refactor/post_processor.py +++ b/rag/svr/task_executor_refactor/post_processor.py @@ -28,11 +28,12 @@ from api.db.services.doc_metadata_service import DocMetadataService from common.metadata_utils import update_metadata_to from rag.svr.task_executor_refactor.task_context import TaskContext from rag.utils.table_es_metadata import ( - aggregate_table_manual_doc_metadata, + aggregate_table_doc_metadata, merge_table_parser_config_from_kb, table_parser_strip_doc_metadata_keys, ) + class PostProcessor: """Service for post-indexing operations. @@ -69,15 +70,10 @@ class PostProcessor: return eff_pc = merge_table_parser_config_from_kb(ctx.raw_task) - logging.debug( - f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}" - ) - - if eff_pc.get("table_column_mode") != "manual": - return + logging.debug(f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}") try: - agg = aggregate_table_manual_doc_metadata(chunks, ctx.raw_task) + agg = aggregate_table_doc_metadata(chunks, ctx.raw_task) logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}") strip_keys = table_parser_strip_doc_metadata_keys(eff_pc) diff --git a/rag/utils/table_es_metadata.py b/rag/utils/table_es_metadata.py index 18edfc4696..82baddf5be 100644 --- a/rag/utils/table_es_metadata.py +++ b/rag/utils/table_es_metadata.py @@ -102,9 +102,7 @@ def _probe_es_typed_key_for_column(col: str, sample_chunk: dict) -> str | None: return None -def _resolve_es_chunk_field_key( - col: str, field_map: dict, sample_chunk: dict | None -) -> tuple[str | None, str]: +def _resolve_es_chunk_field_key(col: str, field_map: dict, sample_chunk: dict | None) -> tuple[str | None, str]: """Prefer field_map when key exists on chunk; else probe by suffix (matches table.py naming).""" tk_fm = _field_map_typed_key_for_column(field_map, col) if field_map else None if sample_chunk: @@ -153,35 +151,44 @@ def _es_field_value_to_doc_metadata(val, *, from_tks_fallback: bool) -> str | No return _value_to_meta_string(val) -def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict: +def aggregate_table_doc_metadata(chunks: list, task: dict) -> dict: """ Collect unique values per metadata/both column across chunks for document-level metadata. - Used when table_column_mode == manual (parallel to LLM gen_metadata, no schema required). + Works for both table_column_mode == manual and auto (where all columns default to "both"). """ - logging.debug( - f"[TABLE_META_DEBUG] aggregate_table_manual_doc_metadata called with {len(chunks)} chunks" - ) + logging.debug(f"[TABLE_META_DEBUG] aggregate_table_doc_metadata called with {len(chunks)} chunks") eff = merge_table_parser_config_from_kb(task) - if eff.get("table_column_mode") != "manual": - logging.debug( - f"[TABLE_META_DEBUG] skip aggregate: table_column_mode={eff.get('table_column_mode')!r}" - ) + mode = eff.get("table_column_mode") or "auto" + if mode not in ("manual", "auto"): + logging.debug(f"[TABLE_META_DEBUG] skip aggregate: table_column_mode={mode!r}") return {} roles = eff.get("table_column_roles") or {} table_column_names = eff.get("table_column_names") or [] + # Reload table_column_names from KB if empty (chunk() writes them during parse, + # but the task snapshot may be stale) + if not table_column_names: + kb_id = task.get("kb_id") + if kb_id: + try: + KBS = _knowledgebase_service_cls() + ok, kb = KBS.get_by_id(kb_id) + if ok and kb: + fresh_names = (kb.parser_config or {}).get("table_column_names") or [] + if fresh_names: + table_column_names = fresh_names + logging.debug(f"[TABLE_META_DEBUG] reloaded table_column_names from DB: {fresh_names}") + except Exception as e: + logging.debug( + "[TABLE_META_DEBUG] failed to reload table_column_names from DB: %s", + e, + exc_info=True, + ) if table_column_names: - meta_cols = [ - col - for col in table_column_names - if roles.get(col, "both") in ("metadata", "both") - ] + meta_cols = [col for col in table_column_names if roles.get(col, "both") in ("metadata", "both")] else: meta_cols = [c for c, r in roles.items() if r in ("metadata", "both")] if not meta_cols: - logging.debug( - "[TABLE_META_DEBUG] skip aggregate: no metadata/both columns " - f"(table_column_names_present={bool(table_column_names)})" - ) + logging.debug(f"[TABLE_META_DEBUG] skip aggregate: no metadata/both columns (table_column_names_present={bool(table_column_names)})") return {} fm = (task.get("kb_parser_config") or {}).get("field_map") or {} kb_id = task.get("kb_id") @@ -194,14 +201,9 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict: reloaded = fresh_pc.get("field_map") or {} if reloaded: fm = reloaded - logging.debug( - f"[TABLE_META_DEBUG] reloaded field_map from DB: {len(fm)} entries" - ) + logging.debug(f"[TABLE_META_DEBUG] reloaded field_map from DB: {len(fm)} entries") else: - logging.debug( - "[TABLE_META_DEBUG] KB reload: parser_config has no field_map yet; " - "will use ES key probe on chunk dicts if applicable" - ) + logging.debug("[TABLE_META_DEBUG] KB reload: parser_config has no field_map yet; will use ES key probe on chunk dicts if applicable") except Exception as e: logging.debug( "[TABLE_META_DEBUG] failed to reload field_map from DB: %s", @@ -209,21 +211,11 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict: exc_info=True, ) if not fm and not (settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE): - logging.debug( - "[TABLE_META_DEBUG] field_map empty on task snapshot — will use ES key probe on chunk dicts; " - f"kb_parser_config keys={list((task.get('kb_parser_config') or {}).keys())}" - ) - logging.debug( - f"[TABLE_META_DEBUG] meta_cols={meta_cols}, field_map entries={len(fm)}, " - f"infinity={settings.DOC_ENGINE_INFINITY}, oceanbase={settings.DOC_ENGINE_OCEANBASE}" - ) + logging.debug(f"[TABLE_META_DEBUG] field_map empty on task snapshot — will use ES key probe on chunk dicts; kb_parser_config keys={list((task.get('kb_parser_config') or {}).keys())}") + logging.debug(f"[TABLE_META_DEBUG] meta_cols={meta_cols}, field_map entries={len(fm)}, infinity={settings.DOC_ENGINE_INFINITY}, oceanbase={settings.DOC_ENGINE_OCEANBASE}") sample_ck = next((c for c in chunks if isinstance(c, dict)), None) if sample_ck: - sk = [ - k - for k in sample_ck.keys() - if not (str(k).startswith("q_") and str(k).endswith("_vec")) - ][:50] + sk = [k for k in sample_ck.keys() if not (str(k).startswith("q_") and str(k).endswith("_vec"))][:50] logging.debug(f"[TABLE_META_DEBUG] first chunk non-vector keys (sample): {sk}") es_col_keys: dict[str, tuple[str | None, str]] = {} @@ -231,9 +223,7 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict: for col in meta_cols: tk, src = _resolve_es_chunk_field_key(col, fm, sample_ck) es_col_keys[col] = (tk, src) - logging.debug( - f"[TABLE_META_DEBUG] column '{col}' -> ES key {tk!r} (source={src})" - ) + logging.debug(f"[TABLE_META_DEBUG] column '{col}' -> ES key {tk!r} (source={src})") acc: dict[str, list] = {c: [] for c in meta_cols} @@ -255,9 +245,7 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict: tk, _src = es_col_keys.get(col, (None, "none")) if not tk: if i == 0: - logging.debug( - f"[TABLE_META_DEBUG] no resolved ES key for column '{col}'" - ) + logging.debug(f"[TABLE_META_DEBUG] no resolved ES key for column '{col}'") continue raw_k = _es_raw_field_key_from_typed(tk) val = None @@ -269,10 +257,7 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict: from_tks = tk.endswith("_tks") else: if i == 0: - logging.debug( - f"[TABLE_META_DEBUG] chunk missing ES field {tk!r}" - f"{' and ' + raw_k + ' (raw)' if raw_k else ''} for column '{col}'" - ) + logging.debug(f"[TABLE_META_DEBUG] chunk missing ES field {tk!r}{' and ' + raw_k + ' (raw)' if raw_k else ''} for column '{col}'") continue s = _es_field_value_to_doc_metadata(val, from_tks_fallback=from_tks) if s is not None: @@ -289,8 +274,5 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict: for col, vals in acc.items(): if vals: out[col] = dedupe_list(vals) - logging.debug( - f"[TABLE_META_DEBUG] aggregated metadata dict keys={list(out.keys())}, " - f"sizes={[len(v) for v in out.values()]}" - ) + logging.debug(f"[TABLE_META_DEBUG] aggregated metadata dict keys={list(out.keys())}, sizes={[len(v) for v in out.values()]}") return out diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_post_processor.py b/test/unit_test/rag/svr/task_executor_refactor/test_post_processor.py index 7e42187693..c3022503c9 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_post_processor.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_post_processor.py @@ -69,12 +69,13 @@ class TestPostProcessorProcessTableParserMetadata: ctx.write_interceptor = None chunks = [{"col_key": "val"}] - with patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge, \ - patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_manual_doc_metadata") as mock_agg, \ - patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip, \ - patch("rag.svr.task_executor_refactor.post_processor.update_metadata_to") as mock_update, \ - patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta: - + with ( + patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge, + patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_doc_metadata") as mock_agg, + patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip, + patch("rag.svr.task_executor_refactor.post_processor.update_metadata_to") as mock_update, + patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta, + ): mock_merge.return_value = {"table_column_mode": "manual"} mock_agg.return_value = {"col_key": ["val1", "val2"]} mock_strip.return_value = set() @@ -95,11 +96,12 @@ class TestPostProcessorProcessTableParserMetadata: ctx.raw_task = {} ctx.write_interceptor = MagicMock() - with patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge, \ - patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_manual_doc_metadata") as mock_agg, \ - patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip, \ - patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta: - + with ( + patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge, + patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_doc_metadata") as mock_agg, + patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip, + patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta, + ): mock_merge.return_value = {"table_column_mode": "manual"} mock_agg.return_value = {"key": ["v"]} mock_strip.return_value = set() @@ -108,9 +110,7 @@ class TestPostProcessorProcessTableParserMetadata: service = PostProcessor(ctx=ctx) await service.process_table_parser_metadata("doc_1", []) - ctx.write_interceptor.intercept.assert_called_once_with( - "DocMetadataService.update_document_metadata" - ) + ctx.write_interceptor.intercept.assert_called_once_with("DocMetadataService.update_document_metadata") class TestPostProcessorInsertTocChunk: @@ -160,9 +160,7 @@ class TestPostProcessorInsertTocChunk: result = await service.insert_toc_chunk(toc_chunk, chunk_service) assert result is True - chunk_service.insert_chunks.assert_called_once_with( - "task_1", "tenant_1", "kb_1", [toc_chunk] - ) + chunk_service.insert_chunks.assert_called_once_with("task_1", "tenant_1", "kb_1", [toc_chunk]) @pytest.mark.asyncio async def test_handles_insert_failure(self): @@ -179,4 +177,4 @@ class TestPostProcessorInsertTocChunk: result = await service.insert_toc_chunk(toc_chunk, chunk_service) - assert result is False \ No newline at end of file + assert result is False diff --git a/test/unit_test/rag/svr/test_table_metadata_aggregation.py b/test/unit_test/rag/svr/test_table_metadata_aggregation.py index 59d2f7ee47..9406d1fe40 100644 --- a/test/unit_test/rag/svr/test_table_metadata_aggregation.py +++ b/test/unit_test/rag/svr/test_table_metadata_aggregation.py @@ -14,11 +14,11 @@ # limitations under the License. # -"""Unit tests for aggregate_table_manual_doc_metadata.""" +"""Unit tests for aggregate_table_doc_metadata.""" import pytest -from rag.utils.table_es_metadata import aggregate_table_manual_doc_metadata, merge_table_parser_config_from_kb +from rag.utils.table_es_metadata import aggregate_table_doc_metadata, merge_table_parser_config_from_kb @pytest.fixture @@ -73,30 +73,64 @@ class TestAggregateTableManualDocMetadata: "category_tks": "y", }, ] - out = aggregate_table_manual_doc_metadata(chunks, task) + out = aggregate_table_doc_metadata(chunks, task) assert out["country"] == ["Brazil", "Turkey"] assert out["category"] == ["Economy", "Disaster"] - def test_aggregate_auto_mode_returns_empty(self, es_engine): + def test_aggregate_auto_mode_returns_data(self, es_engine): task = { "parser_id": "table", "parser_config": {}, "kb_parser_config": { "table_column_mode": "auto", - "table_column_roles": {"country": "metadata"}, + "table_column_names": ["country"], + "field_map": {"country_tks": "country"}, }, } - assert aggregate_table_manual_doc_metadata([{"country_tks": "x"}], task) == {} + chunks = [{"country_raw": "Brazil", "country_tks": "x"}] + out = aggregate_table_doc_metadata(chunks, task) + assert out == {"country": ["Brazil"]} - def test_aggregate_no_mode_returns_empty(self, es_engine): + def test_aggregate_auto_mode_all_columns_both(self, es_engine): task = { "parser_id": "table", "parser_config": {}, "kb_parser_config": { - "table_column_roles": {"country": "metadata"}, + "table_column_mode": "auto", + "table_column_names": ["country", "category"], + "field_map": {"country_tks": "country", "category_tks": "category"}, }, } - assert aggregate_table_manual_doc_metadata([{}], task) == {} + chunks = [ + {"country_raw": "Brazil", "country_tks": "x", "category_raw": "Economy", "category_tks": "y"}, + {"country_raw": "Turkey", "country_tks": "x", "category_raw": "Disaster", "category_tks": "y"}, + ] + out = aggregate_table_doc_metadata(chunks, task) + assert out["country"] == ["Brazil", "Turkey"] + assert out["category"] == ["Economy", "Disaster"] + + def test_aggregate_no_mode_defaults_to_auto(self, es_engine): + """When table_column_mode is missing, it defaults to 'auto'.""" + task = { + "parser_id": "table", + "parser_config": {}, + "kb_parser_config": { + "table_column_names": ["country"], + "field_map": {"country_tks": "country"}, + }, + } + chunks = [{"country_raw": "Brazil", "country_tks": "x"}] + out = aggregate_table_doc_metadata(chunks, task) + assert out == {"country": ["Brazil"]} + + def test_aggregate_no_mode_no_columns_returns_empty(self, es_engine): + """No mode and no column names/roles -> empty (nothing to aggregate).""" + task = { + "parser_id": "table", + "parser_config": {}, + "kb_parser_config": {}, + } + assert aggregate_table_doc_metadata([{}], task) == {} def test_aggregate_no_metadata_columns(self, es_engine): task = { @@ -108,14 +142,14 @@ class TestAggregateTableManualDocMetadata: "table_column_names": ["country"], }, } - assert aggregate_table_manual_doc_metadata([{"country_tks": "x"}], task) == {} + assert aggregate_table_doc_metadata([{"country_tks": "x"}], task) == {} def test_aggregate_prefers_raw_over_tks(self, es_engine): task = _table_task() task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"} task["kb_parser_config"]["table_column_names"] = ["country"] chunks = [{"country_raw": "Brazil", "country_tks": ["brazil"]}] - out = aggregate_table_manual_doc_metadata(chunks, task) + out = aggregate_table_doc_metadata(chunks, task) assert out == {"country": ["Brazil"]} def test_aggregate_tks_fallback(self, es_engine): @@ -123,7 +157,7 @@ class TestAggregateTableManualDocMetadata: task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"} task["kb_parser_config"]["table_column_names"] = ["country"] chunks = [{"country_tks": ["brazil"]}] - out = aggregate_table_manual_doc_metadata(chunks, task) + out = aggregate_table_doc_metadata(chunks, task) assert out == {"country": ["brazil"]} def test_aggregate_partial_roles_defaults_to_both(self, es_engine): @@ -138,7 +172,7 @@ class TestAggregateTableManualDocMetadata: }, } chunks = [{"city_raw": "SP", "city_tks": "t", "country_tks": "x"}] - out = aggregate_table_manual_doc_metadata(chunks, task) + out = aggregate_table_doc_metadata(chunks, task) assert out == {"city": ["SP"]} assert "country" not in out @@ -156,7 +190,7 @@ class TestAggregateTableManualDocMetadata: chunks = [ {"country_raw": "BR", "city_raw": "SP", "country_tks": "x", "city_tks": "y"}, ] - out = aggregate_table_manual_doc_metadata(chunks, task) + out = aggregate_table_doc_metadata(chunks, task) assert "country" in out and "city" in out def test_aggregate_deduplicates_values(self, es_engine): @@ -168,7 +202,7 @@ class TestAggregateTableManualDocMetadata: {"country_raw": "UK", "country_tks": "y"}, {"country_raw": "US", "country_tks": "x"}, ] - out = aggregate_table_manual_doc_metadata(chunks, task) + out = aggregate_table_doc_metadata(chunks, task) assert out["country"] == ["US", "UK"] def test_aggregate_kb_reload_field_map(self, es_engine, monkeypatch): @@ -197,7 +231,7 @@ class TestAggregateTableManualDocMetadata: "kb_id": "kb-1", } chunks = [{"country_raw": "X", "country_tks": "t"}] - out = aggregate_table_manual_doc_metadata(chunks, task) + out = aggregate_table_doc_metadata(chunks, task) assert out == {"country": ["X"]} def test_merge_infinity_chunk_data(self, infinity_engine): @@ -214,7 +248,7 @@ class TestAggregateTableManualDocMetadata: {"chunk_data": {"country": "US"}}, {"chunk_data": {"country": "UK"}}, ] - out = aggregate_table_manual_doc_metadata(chunks, task) + out = aggregate_table_doc_metadata(chunks, task) assert out == {"country": ["US", "UK"]}