mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
fix: support auto mode in table parser document metadata aggregation (#15780)
### What problem does this PR solve?

Table parser metadata aggregation previously only ran when `table_column_mode` was set to `manual`. In auto mode (the default), every column defaults to the `"both"` role, so those columns should also be aggregated into document-level metadata for UI/chat filters.

Additionally, the task snapshot could be stale: `table_column_names` is written to the KB `parser_config` during `chunk()`, but the task may have been created before that write happened.

Changes:

- Renames `aggregate_table_manual_doc_metadata` → `aggregate_table_doc_metadata`
- Supports both `"manual"` and `"auto"` `table_column_mode` (defaults to `"auto"`)
- Reloads `table_column_names` from KB DB when missing from task snapshot
- Removes the manual-only guard in `task_executor` and in the refactored `post_processor`
- Updates all tests with new function name and adds auto mode test cases

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -16,20 +16,21 @@ import argparse
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from rag.svr.task_executor_refactor.task_manager import TaskManager
|
from rag.svr.task_executor_refactor.task_manager import TaskManager
|
||||||
from rag.svr.task_executor_refactor.recording_context import timed_with_recording, get_recording_context, \
|
from rag.svr.task_executor_refactor.recording_context import timed_with_recording, get_recording_context, RecordingContext, set_recording_context, NullRecordingContext
|
||||||
RecordingContext, set_recording_context, NullRecordingContext
|
|
||||||
|
|
||||||
start_ts = time.time()
|
start_ts = time.time()
|
||||||
|
|
||||||
# LiteLLM fetches a model cost map from GitHub during import unless this is set.
|
# LiteLLM fetches a model cost map from GitHub during import unless this is set.
|
||||||
# Parser pods should not block startup on external network access.
|
# Parser pods should not block startup on external network access.
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") # no internet, save about 10s
|
os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") # no internet, save about 10s
|
||||||
|
|
||||||
from common.misc_utils import thread_pool_exec
|
from common.misc_utils import thread_pool_exec
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
# from beartype import BeartypeConf
|
# from beartype import BeartypeConf
|
||||||
# from beartype.claw import beartype_all # <-- you didn't sign up for this
|
# from beartype.claw import beartype_all # <-- you didn't sign up for this
|
||||||
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
||||||
@@ -56,8 +57,7 @@ from rag.utils.raptor_utils import (
|
|||||||
from common.log_utils import init_root_logger
|
from common.log_utils import init_root_logger
|
||||||
from common.config_utils import show_configs
|
from common.config_utils import show_configs
|
||||||
from rag.graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
from rag.graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
||||||
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text, \
|
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text, gen_metadata
|
||||||
gen_metadata
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -82,8 +82,7 @@ from api.db.services.file2document_service import File2DocumentService
|
|||||||
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance
|
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance
|
||||||
from common.versions import get_ragflow_version
|
from common.versions import get_ragflow_version
|
||||||
from api.db.db_models import close_connection
|
from api.db.db_models import close_connection
|
||||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, email, tag
|
||||||
email, tag
|
|
||||||
from rag.nlp import search, rag_tokenizer, add_positions
|
from rag.nlp import search, rag_tokenizer, add_positions
|
||||||
from rag.raptor import (
|
from rag.raptor import (
|
||||||
RAPTOR_TREE_BUILDER,
|
RAPTOR_TREE_BUILDER,
|
||||||
@@ -103,7 +102,7 @@ from rag.svr.task_executor_limiter import (
|
|||||||
from common import settings
|
from common import settings
|
||||||
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
|
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
|
||||||
from rag.utils.table_es_metadata import (
|
from rag.utils.table_es_metadata import (
|
||||||
aggregate_table_manual_doc_metadata,
|
aggregate_table_doc_metadata,
|
||||||
merge_table_parser_config_from_kb,
|
merge_table_parser_config_from_kb,
|
||||||
table_parser_strip_doc_metadata_keys,
|
table_parser_strip_doc_metadata_keys,
|
||||||
)
|
)
|
||||||
@@ -128,7 +127,7 @@ FACTORY = {
|
|||||||
ParserType.AUDIO.value: audio,
|
ParserType.AUDIO.value: audio,
|
||||||
ParserType.EMAIL.value: email,
|
ParserType.EMAIL.value: email,
|
||||||
ParserType.KG.value: naive,
|
ParserType.KG.value: naive,
|
||||||
ParserType.TAG.value: tag
|
ParserType.TAG.value: tag,
|
||||||
}
|
}
|
||||||
|
|
||||||
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
|
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
|
||||||
@@ -152,9 +151,10 @@ FAILED_TASKS = 0
|
|||||||
|
|
||||||
CURRENT_TASKS = {}
|
CURRENT_TASKS = {}
|
||||||
|
|
||||||
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
|
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get("WORKER_HEARTBEAT_TIMEOUT", "120"))
|
||||||
stop_event = threading.Event()
|
stop_event = threading.Event()
|
||||||
|
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
logging.info("Received interrupt signal, shutting down...")
|
logging.info("Received interrupt signal, shutting down...")
|
||||||
stop_event.set()
|
stop_event.set()
|
||||||
@@ -269,8 +269,7 @@ async def get_storage_binary(bucket, name):
|
|||||||
@timeout(60 * 80, 1)
|
@timeout(60 * 80, 1)
|
||||||
async def build_chunks(task, progress_callback):
|
async def build_chunks(task, progress_callback):
|
||||||
if task["size"] > settings.DOC_MAXIMUM_SIZE:
|
if task["size"] > settings.DOC_MAXIMUM_SIZE:
|
||||||
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
|
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" % (int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
||||||
(int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
|
||||||
get_recording_context().record("file_size_exceeded", True)
|
get_recording_context().record("file_size_exceeded", True)
|
||||||
return []
|
return []
|
||||||
get_recording_context().record("file_size_exceeded", False)
|
get_recording_context().record("file_size_exceeded", False)
|
||||||
@@ -286,8 +285,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
|
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
||||||
logging.exception(
|
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
|
||||||
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
|
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if re.search("(No such file|not found)", str(e)):
|
if re.search("(No such file|not found)", str(e)):
|
||||||
@@ -309,12 +307,11 @@ async def build_chunks(task, progress_callback):
|
|||||||
|
|
||||||
# Record chunk configuration for comparison
|
# Record chunk configuration for comparison
|
||||||
from common.float_utils import normalize_overlapped_percent
|
from common.float_utils import normalize_overlapped_percent
|
||||||
|
|
||||||
chunk_config = {
|
chunk_config = {
|
||||||
"parser_id": task["parser_id"],
|
"parser_id": task["parser_id"],
|
||||||
"chunk_token_num": parser_config_for_chunk.get("chunk_token_num", 128),
|
"chunk_token_num": parser_config_for_chunk.get("chunk_token_num", 128),
|
||||||
"overlapped_percent": normalize_overlapped_percent(
|
"overlapped_percent": normalize_overlapped_percent(parser_config_for_chunk.get("overlapped_percent", 0)),
|
||||||
parser_config_for_chunk.get("overlapped_percent", 0)
|
|
||||||
),
|
|
||||||
"delimiter": parser_config_for_chunk.get("delimiter", "\n!?。;!?"),
|
"delimiter": parser_config_for_chunk.get("delimiter", "\n!?。;!?"),
|
||||||
"from_page": task["from_page"],
|
"from_page": task["from_page"],
|
||||||
"to_page": task["to_page"],
|
"to_page": task["to_page"],
|
||||||
@@ -357,21 +354,14 @@ async def build_chunks(task, progress_callback):
|
|||||||
if cks and cks[0].get("__outline__"):
|
if cks and cks[0].get("__outline__"):
|
||||||
outline = cks[0].pop("__outline__")
|
outline = cks[0].pop("__outline__")
|
||||||
try:
|
try:
|
||||||
ret = DocMetadataService.update_document_metadata(
|
ret = DocMetadataService.update_document_metadata(task["doc_id"], update_metadata_to({"outline": outline}, DocMetadataService.get_document_metadata(task["doc_id"]) or {}))
|
||||||
task["doc_id"],
|
|
||||||
update_metadata_to({"outline": outline},
|
|
||||||
DocMetadataService.get_document_metadata(task["doc_id"]) or {})
|
|
||||||
)
|
|
||||||
get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret)
|
get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret)
|
||||||
logging.info("Persisted PDF outline (%d entries) for doc %s", len(outline), task["doc_id"])
|
logging.info("Persisted PDF outline (%d entries) for doc %s", len(outline), task["doc_id"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("Failed to persist PDF outline for doc %s: %s", task["doc_id"], e)
|
logging.warning("Failed to persist PDF outline for doc %s: %s", task["doc_id"], e)
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
doc = {
|
doc = {"doc_id": task["doc_id"], "kb_id": str(task["kb_id"])}
|
||||||
"doc_id": task["doc_id"],
|
|
||||||
"kb_id": str(task["kb_id"])
|
|
||||||
}
|
|
||||||
if task["pagerank"]:
|
if task["pagerank"]:
|
||||||
doc[PAGERANK_FLD] = int(task["pagerank"])
|
doc[PAGERANK_FLD] = int(task["pagerank"])
|
||||||
st = timer()
|
st = timer()
|
||||||
@@ -381,8 +371,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
try:
|
try:
|
||||||
d = copy.deepcopy(document)
|
d = copy.deepcopy(document)
|
||||||
d.update(chunk)
|
d.update(chunk)
|
||||||
d["id"] = xxhash.xxh64(
|
d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
|
||||||
(chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
|
|
||||||
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||||
d["create_timestamp_flt"] = datetime.now().timestamp()
|
d["create_timestamp_flt"] = datetime.now().timestamp()
|
||||||
|
|
||||||
@@ -398,8 +387,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
await image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
|
await image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
|
||||||
docs.append(d)
|
docs.append(d)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception(
|
logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
|
||||||
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
@@ -442,8 +430,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
for d in docs:
|
for d in docs:
|
||||||
tasks.append(
|
tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
|
||||||
asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
|
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(*tasks, return_exceptions=False)
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -479,8 +466,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
for d in docs:
|
for d in docs:
|
||||||
tasks.append(
|
tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
|
||||||
asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
|
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(*tasks, return_exceptions=False)
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -519,18 +505,14 @@ async def build_chunks(task, progress_callback):
|
|||||||
metadata_conf = metadata_conf + built_in_metadata
|
metadata_conf = metadata_conf + built_in_metadata
|
||||||
else:
|
else:
|
||||||
metadata_conf = built_in_metadata
|
metadata_conf = built_in_metadata
|
||||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata",
|
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", metadata_conf)
|
||||||
metadata_conf)
|
|
||||||
if not cached:
|
if not cached:
|
||||||
if has_canceled(task["id"]):
|
if has_canceled(task["id"]):
|
||||||
progress_callback(-1, msg="Task has been canceled.")
|
progress_callback(-1, msg="Task has been canceled.")
|
||||||
return
|
return
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
cached = await gen_metadata(chat_mdl,
|
cached = await gen_metadata(chat_mdl, turn2jsonschema(metadata_conf), d["content_with_weight"])
|
||||||
turn2jsonschema(metadata_conf),
|
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", metadata_conf)
|
||||||
d["content_with_weight"])
|
|
||||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata",
|
|
||||||
metadata_conf)
|
|
||||||
if cached:
|
if cached:
|
||||||
d["metadata_obj"] = cached
|
d["metadata_obj"] = cached
|
||||||
|
|
||||||
@@ -584,8 +566,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
if task_canceled:
|
if task_canceled:
|
||||||
progress_callback(-1, msg="Task has been canceled.")
|
progress_callback(-1, msg="Task has been canceled.")
|
||||||
return None
|
return None
|
||||||
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(
|
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
|
||||||
d[TAG_FLD]) > 0:
|
|
||||||
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
|
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
|
||||||
else:
|
else:
|
||||||
docs_to_tag.append(d)
|
docs_to_tag.append(d)
|
||||||
@@ -598,7 +579,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
return
|
return
|
||||||
picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples
|
picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples
|
||||||
if not picked_examples:
|
if not picked_examples:
|
||||||
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
|
picked_examples.append({"content": "This is an example", TAG_FLD: {"example": 1}})
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
cached = await content_tagging(
|
cached = await content_tagging(
|
||||||
chat_mdl,
|
chat_mdl,
|
||||||
@@ -643,13 +624,15 @@ def build_TOC(task, docs, progress_callback):
|
|||||||
progress_callback(msg="Start to generate table of content ...")
|
progress_callback(msg="Start to generate table of content ...")
|
||||||
chat_model_config = get_model_config_from_provider_instance(task["tenant_id"], LLMType.CHAT, task["llm_id"])
|
chat_model_config = get_model_config_from_provider_instance(task["tenant_id"], LLMType.CHAT, task["llm_id"])
|
||||||
chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])
|
chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])
|
||||||
docs = sorted(docs, key=lambda d: (
|
docs = sorted(
|
||||||
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
|
docs,
|
||||||
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
|
key=lambda d: (
|
||||||
))
|
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
|
||||||
toc: list[dict] = asyncio.run(
|
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0),
|
||||||
run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
|
),
|
||||||
logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
|
)
|
||||||
|
toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
|
||||||
|
logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=" "))
|
||||||
for ii, item in enumerate(toc):
|
for ii, item in enumerate(toc):
|
||||||
try:
|
try:
|
||||||
chunk_val = item.pop("chunk_id", None)
|
chunk_val = item.pop("chunk_id", None)
|
||||||
@@ -680,8 +663,7 @@ def build_TOC(task, docs, progress_callback):
|
|||||||
d["toc_kwd"] = "toc"
|
d["toc_kwd"] = "toc"
|
||||||
d["available_int"] = 0
|
d["available_int"] = 0
|
||||||
d["page_num_int"] = [100000000]
|
d["page_num_int"] = [100000000]
|
||||||
d["id"] = xxhash.xxh64(
|
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
|
||||||
(d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
|
|
||||||
return d
|
return d
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -722,7 +704,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
|||||||
cnts_batches = []
|
cnts_batches = []
|
||||||
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
|
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
|
||||||
async with embed_limiter:
|
async with embed_limiter:
|
||||||
vts, c = await thread_pool_exec(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
|
vts, c = await thread_pool_exec(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
|
||||||
cnts_batches.append(vts)
|
cnts_batches.append(vts)
|
||||||
tk_count += c
|
tk_count += c
|
||||||
callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
|
callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
|
||||||
@@ -775,8 +757,7 @@ async def run_dataflow(task: dict):
|
|||||||
if not chunks:
|
if not chunks:
|
||||||
get_recording_context().record("pipeline_output_count", 0)
|
get_recording_context().record("pipeline_output_count", 0)
|
||||||
get_recording_context().record("pipeline_output_type", "empty")
|
get_recording_context().record("pipeline_output_type", "empty")
|
||||||
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
|
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||||
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
|
||||||
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
|
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -806,8 +787,7 @@ async def run_dataflow(task: dict):
|
|||||||
|
|
||||||
# An empty normalized payload means "nothing parsed", so stop before embedding/indexing.
|
# An empty normalized payload means "nothing parsed", so stop before embedding/indexing.
|
||||||
if not chunks:
|
if not chunks:
|
||||||
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
|
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||||
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
|
||||||
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
|
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -831,7 +811,7 @@ async def run_dataflow(task: dict):
|
|||||||
prog = 0.8
|
prog = 0.8
|
||||||
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
|
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
|
||||||
async with embed_limiter:
|
async with embed_limiter:
|
||||||
vts, c = await thread_pool_exec(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
|
vts, c = await thread_pool_exec(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
|
||||||
vects_batches.append(vts)
|
vects_batches.append(vts)
|
||||||
embedding_token_consumption += c
|
embedding_token_consumption += c
|
||||||
prog += delta
|
prog += delta
|
||||||
@@ -849,8 +829,7 @@ async def run_dataflow(task: dict):
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
|
set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
|
||||||
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
|
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||||
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
|
||||||
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
|
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -900,24 +879,21 @@ async def run_dataflow(task: dict):
|
|||||||
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
|
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
|
||||||
e = await insert_chunks(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
|
e = await insert_chunks(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
|
||||||
if not e:
|
if not e:
|
||||||
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
|
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||||
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
|
||||||
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
|
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
|
||||||
return
|
return
|
||||||
|
|
||||||
time_cost = timer() - start_ts
|
time_cost = timer() - start_ts
|
||||||
task_time_cost = timer() - task_start_ts
|
task_time_cost = timer() - task_start_ts
|
||||||
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
|
set_progress(task_id, prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
|
||||||
ret = DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks),
|
ret = DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
|
||||||
task_time_cost)
|
|
||||||
get_recording_context().save_func_return_value("DocumentService.increment_chunk_num", ret)
|
get_recording_context().save_func_return_value("DocumentService.increment_chunk_num", ret)
|
||||||
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption,
|
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
|
||||||
task_time_cost))
|
|
||||||
get_recording_context().record("dataflow_chunks", chunks)
|
get_recording_context().record("dataflow_chunks", chunks)
|
||||||
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE,
|
ret = PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||||
dsl=str(pipeline))
|
|
||||||
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
|
get_recording_context().save_func_return_value("PipelineOperationLogService.create", ret)
|
||||||
|
|
||||||
|
|
||||||
RAPTOR_METHOD_SEARCH_LIMIT = 10000
|
RAPTOR_METHOD_SEARCH_LIMIT = 10000
|
||||||
|
|
||||||
|
|
||||||
@@ -928,11 +904,7 @@ async def get_raptor_chunk_field_map(doc_id: str, tenant_id: str, kb_id: str) ->
|
|||||||
|
|
||||||
async def search_fields(fields: list[str], condition: dict, order_by=None):
|
async def search_fields(fields: list[str], condition: dict, order_by=None):
|
||||||
"""Search chunk fields in the current knowledge base."""
|
"""Search chunk fields in the current knowledge base."""
|
||||||
res = await thread_pool_exec(
|
res = await thread_pool_exec(settings.docStoreConn.search, fields, [], condition, [], order_by or OrderByExpr(), 0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id])
|
||||||
settings.docStoreConn.search,
|
|
||||||
fields, [], condition, [], order_by or OrderByExpr(),
|
|
||||||
0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id]
|
|
||||||
)
|
|
||||||
return settings.docStoreConn.get_fields(res, fields)
|
return settings.docStoreConn.get_fields(res, fields)
|
||||||
|
|
||||||
primary = await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id, "raptor_kwd": ["raptor"]})
|
primary = await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id, "raptor_kwd": ["raptor"]})
|
||||||
@@ -963,12 +935,17 @@ async def get_raptor_chunk_methods(doc_id: str, tenant_id: str, kb_id: str) -> s
|
|||||||
if methods:
|
if methods:
|
||||||
logging.info(
|
logging.info(
|
||||||
"Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist",
|
"Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist",
|
||||||
doc_id, tenant_id, kb_id, sorted(methods),
|
doc_id,
|
||||||
|
tenant_id,
|
||||||
|
kb_id,
|
||||||
|
sorted(methods),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info(
|
logging.info(
|
||||||
"Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)",
|
"Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)",
|
||||||
doc_id, tenant_id, kb_id,
|
doc_id,
|
||||||
|
tenant_id,
|
||||||
|
kb_id,
|
||||||
)
|
)
|
||||||
return methods
|
return methods
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -987,7 +964,9 @@ async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_met
|
|||||||
if keep_method is None:
|
if keep_method is None:
|
||||||
logging.info(
|
logging.info(
|
||||||
"delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)",
|
"delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)",
|
||||||
doc_id, tenant_id, kb_id,
|
doc_id,
|
||||||
|
tenant_id,
|
||||||
|
kb_id,
|
||||||
)
|
)
|
||||||
ret = await thread_pool_exec(
|
ret = await thread_pool_exec(
|
||||||
settings.docStoreConn.delete,
|
settings.docStoreConn.delete,
|
||||||
@@ -1003,13 +982,20 @@ async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_met
|
|||||||
if not chunk_ids:
|
if not chunk_ids:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)",
|
"delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)",
|
||||||
doc_id, tenant_id, kb_id, keep_method,
|
doc_id,
|
||||||
|
tenant_id,
|
||||||
|
kb_id,
|
||||||
|
keep_method,
|
||||||
)
|
)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
"delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)",
|
"delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)",
|
||||||
len(chunk_ids), doc_id, tenant_id, kb_id, keep_method,
|
len(chunk_ids),
|
||||||
|
doc_id,
|
||||||
|
tenant_id,
|
||||||
|
kb_id,
|
||||||
|
keep_method,
|
||||||
)
|
)
|
||||||
ret = await thread_pool_exec(
|
ret = await thread_pool_exec(
|
||||||
settings.docStoreConn.delete,
|
settings.docStoreConn.delete,
|
||||||
@@ -1073,6 +1059,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
|||||||
nonlocal tk_count, res
|
nonlocal tk_count, res
|
||||||
logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did)
|
logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did)
|
||||||
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor # Lazy load, save around 8s
|
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor # Lazy load, save around 8s
|
||||||
|
|
||||||
raptor = Raptor(
|
raptor = Raptor(
|
||||||
raptor_config.get("max_cluster", 64),
|
raptor_config.get("max_cluster", 64),
|
||||||
chat_mdl,
|
chat_mdl,
|
||||||
@@ -1132,7 +1119,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
|||||||
|
|
||||||
for x, doc_id in enumerate(doc_ids):
|
for x, doc_id in enumerate(doc_ids):
|
||||||
if skip_raptor_doc(doc_id):
|
if skip_raptor_doc(doc_id):
|
||||||
callback(prog=(x + 1.) / len(doc_ids))
|
callback(prog=(x + 1.0) / len(doc_ids))
|
||||||
continue
|
continue
|
||||||
# CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store
|
# CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store
|
||||||
existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
|
existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
|
||||||
@@ -1142,16 +1129,14 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
|||||||
schedule_raptor_cleanup(doc_id, tree_builder)
|
schedule_raptor_cleanup(doc_id, tree_builder)
|
||||||
callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.")
|
callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.")
|
||||||
callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.")
|
callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.")
|
||||||
callback(prog=(x + 1.) / len(doc_ids))
|
callback(prog=(x + 1.0) / len(doc_ids))
|
||||||
continue
|
continue
|
||||||
if existing_methods:
|
if existing_methods:
|
||||||
callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.")
|
callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.")
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
skipped_chunks = 0
|
skipped_chunks = 0
|
||||||
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
|
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm], sort_by_position=True):
|
||||||
fields=["content_with_weight", vctr_nm],
|
|
||||||
sort_by_position=True):
|
|
||||||
# Skip chunks that don't have the required vector field (may have been indexed with different embedding model)
|
# Skip chunks that don't have the required vector field (may have been indexed with different embedding model)
|
||||||
if vctr_nm not in d or d[vctr_nm] is None:
|
if vctr_nm not in d or d[vctr_nm] is None:
|
||||||
skipped_chunks += 1
|
skipped_chunks += 1
|
||||||
@@ -1173,7 +1158,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
|||||||
has_file_level_target = True
|
has_file_level_target = True
|
||||||
if existing_methods:
|
if existing_methods:
|
||||||
schedule_raptor_cleanup(doc_id, tree_builder)
|
schedule_raptor_cleanup(doc_id, tree_builder)
|
||||||
callback(prog=(x + 1.) / len(doc_ids))
|
callback(prog=(x + 1.0) / len(doc_ids))
|
||||||
|
|
||||||
if remove_dataset_summaries:
|
if remove_dataset_summaries:
|
||||||
if has_file_level_target:
|
if has_file_level_target:
|
||||||
@@ -1213,9 +1198,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
|||||||
for doc_id in doc_ids:
|
for doc_id in doc_ids:
|
||||||
if doc_id in skipped_doc_ids:
|
if doc_id in skipped_doc_ids:
|
||||||
continue
|
continue
|
||||||
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
|
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm], sort_by_position=True):
|
||||||
fields=["content_with_weight", vctr_nm],
|
|
||||||
sort_by_position=True):
|
|
||||||
# Skip chunks that don't have the required vector field
|
# Skip chunks that don't have the required vector field
|
||||||
if vctr_nm not in d or d[vctr_nm] is None:
|
if vctr_nm not in d or d[vctr_nm] is None:
|
||||||
skipped_chunks += 1
|
skipped_chunks += 1
|
||||||
@@ -1283,14 +1266,17 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
|
|||||||
mom_ck["available_int"] = 0
|
mom_ck["available_int"] = 0
|
||||||
flds = list(mom_ck.keys())
|
flds = list(mom_ck.keys())
|
||||||
for fld in flds:
|
for fld in flds:
|
||||||
if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int",
|
if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int", "create_timestamp_flt", "page_num_int", "top_int"]:
|
||||||
"position_int", "create_timestamp_flt", "page_num_int", "top_int"]:
|
|
||||||
del mom_ck[fld]
|
del mom_ck[fld]
|
||||||
mothers.append(mom_ck)
|
mothers.append(mom_ck)
|
||||||
|
|
||||||
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
|
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
|
||||||
ret = await thread_pool_exec(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
|
ret = await thread_pool_exec(
|
||||||
search.index_name(task_tenant_id), task_dataset_id, )
|
settings.docStoreConn.insert,
|
||||||
|
mothers[b : b + settings.DOC_BULK_SIZE],
|
||||||
|
search.index_name(task_tenant_id),
|
||||||
|
task_dataset_id,
|
||||||
|
)
|
||||||
get_recording_context().save_func_return_value("docStoreConn.insert", ret)
|
get_recording_context().save_func_return_value("docStoreConn.insert", ret)
|
||||||
task_canceled = has_canceled(task_id)
|
task_canceled = has_canceled(task_id)
|
||||||
if task_canceled:
|
if task_canceled:
|
||||||
@@ -1298,17 +1284,18 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
|
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
|
||||||
doc_store_result = await thread_pool_exec(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
|
doc_store_result = await thread_pool_exec(
|
||||||
search.index_name(task_tenant_id), task_dataset_id, )
|
settings.docStoreConn.insert,
|
||||||
|
chunks[b : b + settings.DOC_BULK_SIZE],
|
||||||
|
search.index_name(task_tenant_id),
|
||||||
|
task_dataset_id,
|
||||||
|
)
|
||||||
get_recording_context().save_func_return_value("docStoreConn.insert", doc_store_result)
|
get_recording_context().save_func_return_value("docStoreConn.insert", doc_store_result)
|
||||||
task_canceled = has_canceled(task_id)
|
task_canceled = has_canceled(task_id)
|
||||||
if task_canceled:
|
if task_canceled:
|
||||||
# Roll back partial RAPTOR summary inserts so the next run is not
|
# Roll back partial RAPTOR summary inserts so the next run is not
|
||||||
# mistaken for a completed checkpoint by get_raptor_chunk_methods.
|
# mistaken for a completed checkpoint by get_raptor_chunk_methods.
|
||||||
raptor_ids_to_rollback = [
|
raptor_ids_to_rollback = [c["id"] for c in chunks[: b + settings.DOC_BULK_SIZE] if c.get("raptor_kwd") == "raptor"]
|
||||||
c["id"] for c in chunks[:b + settings.DOC_BULK_SIZE]
|
|
||||||
if c.get("raptor_kwd") == "raptor"
|
|
||||||
]
|
|
||||||
if raptor_ids_to_rollback:
|
if raptor_ids_to_rollback:
|
||||||
try:
|
try:
|
||||||
ret = await thread_pool_exec(
|
ret = await thread_pool_exec(
|
||||||
@@ -1320,7 +1307,8 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
|
|||||||
get_recording_context().save_func_return_value("docStoreConn.delete", ret)
|
get_recording_context().save_func_return_value("docStoreConn.delete", ret)
|
||||||
logging.info(
|
logging.info(
|
||||||
"insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)",
|
"insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)",
|
||||||
len(raptor_ids_to_rollback), task_id,
|
len(raptor_ids_to_rollback),
|
||||||
|
task_id,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception(
|
logging.exception(
|
||||||
@@ -1335,15 +1323,19 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
|
|||||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||||
progress_callback(-1, msg=error_message)
|
progress_callback(-1, msg=error_message)
|
||||||
raise Exception(error_message)
|
raise Exception(error_message)
|
||||||
chunk_ids = [chunk["id"] for chunk in chunks[:b + settings.DOC_BULK_SIZE]]
|
chunk_ids = [chunk["id"] for chunk in chunks[: b + settings.DOC_BULK_SIZE]]
|
||||||
chunk_ids_str = " ".join(chunk_ids)
|
chunk_ids_str = " ".join(chunk_ids)
|
||||||
try:
|
try:
|
||||||
TaskService.update_chunk_ids(task_id, chunk_ids_str)
|
TaskService.update_chunk_ids(task_id, chunk_ids_str)
|
||||||
get_recording_context().save_func_return_value("TaskService.update_chunk_ids", None)
|
get_recording_context().save_func_return_value("TaskService.update_chunk_ids", None)
|
||||||
except DoesNotExist:
|
except DoesNotExist:
|
||||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
|
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
|
||||||
doc_store_result = await thread_pool_exec(settings.docStoreConn.delete, {"id": chunk_ids},
|
doc_store_result = await thread_pool_exec(
|
||||||
search.index_name(task_tenant_id), task_dataset_id, )
|
settings.docStoreConn.delete,
|
||||||
|
{"id": chunk_ids},
|
||||||
|
search.index_name(task_tenant_id),
|
||||||
|
task_dataset_id,
|
||||||
|
)
|
||||||
get_recording_context().save_func_return_value("docStoreConn.delete", doc_store_result)
|
get_recording_context().save_func_return_value("docStoreConn.delete", doc_store_result)
|
||||||
tasks = []
|
tasks = []
|
||||||
for chunk_id in chunk_ids:
|
for chunk_id in chunk_ids:
|
||||||
@@ -1383,8 +1375,8 @@ async def do_handle_task(task):
|
|||||||
if not task.get("language"):
|
if not task.get("language"):
|
||||||
logging.warning("Task %s has no language set, falling back to Chinese", task_id)
|
logging.warning("Task %s has no language set, falling back to Chinese", task_id)
|
||||||
doc_task_llm_id = task["parser_config"].get("llm_id") or task["llm_id"]
|
doc_task_llm_id = task["parser_config"].get("llm_id") or task["llm_id"]
|
||||||
kb_task_llm_id = task['kb_parser_config'].get("llm_id") or task["llm_id"]
|
kb_task_llm_id = task["kb_parser_config"].get("llm_id") or task["llm_id"]
|
||||||
task['llm_id'] = kb_task_llm_id
|
task["llm_id"] = kb_task_llm_id
|
||||||
task_dataset_id = task["kb_id"]
|
task_dataset_id = task["kb_id"]
|
||||||
task_doc_id = task["doc_id"]
|
task_doc_id = task["doc_id"]
|
||||||
task_document_name = task["name"]
|
task_document_name = task["name"]
|
||||||
@@ -1411,14 +1403,14 @@ async def do_handle_task(task):
|
|||||||
vts, _ = embedding_model.encode(["ok"])
|
vts, _ = embedding_model.encode(["ok"])
|
||||||
vector_size = len(vts[0])
|
vector_size = len(vts[0])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f'Fail to bind embedding model: {str(e)}'
|
error_message = f"Fail to bind embedding model: {str(e)}"
|
||||||
progress_callback(-1, msg=error_message)
|
progress_callback(-1, msg=error_message)
|
||||||
logging.exception(error_message)
|
logging.exception(error_message)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
init_kb(task, vector_size)
|
init_kb(task, vector_size)
|
||||||
|
|
||||||
if task_type[:len("dataflow")] == "dataflow":
|
if task_type[: len("dataflow")] == "dataflow":
|
||||||
await run_dataflow(task)
|
await run_dataflow(task)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1517,7 +1509,8 @@ async def do_handle_task(task):
|
|||||||
with_community = graphrag_conf.get("community", False)
|
with_community = graphrag_conf.get("community", False)
|
||||||
async with kg_limiter:
|
async with kg_limiter:
|
||||||
# await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
|
# await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
|
||||||
from rag.graphrag.general.index import run_graphrag_for_kb # Lazy load, save around 2s
|
from rag.graphrag.general.index import run_graphrag_for_kb # Lazy load, save around 2s
|
||||||
|
|
||||||
result = await run_graphrag_for_kb(
|
result = await run_graphrag_for_kb(
|
||||||
row=task,
|
row=task,
|
||||||
doc_ids=task.get("doc_ids", []),
|
doc_ids=task.get("doc_ids", []),
|
||||||
@@ -1539,7 +1532,7 @@ async def do_handle_task(task):
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
# Standard chunking methods
|
# Standard chunking methods
|
||||||
task['llm_id'] = doc_task_llm_id
|
task["llm_id"] = doc_task_llm_id
|
||||||
start_ts = timer()
|
start_ts = timer()
|
||||||
chunks = await build_chunks(task, progress_callback)
|
chunks = await build_chunks(task, progress_callback)
|
||||||
get_recording_context().record("chunks", chunks)
|
get_recording_context().record("chunks", chunks)
|
||||||
@@ -1549,7 +1542,7 @@ async def do_handle_task(task):
|
|||||||
# Record chunks array for content comparison (first, middle, last, random)
|
# Record chunks array for content comparison (first, middle, last, random)
|
||||||
logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
|
logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
|
||||||
if not chunks:
|
if not chunks:
|
||||||
progress_callback(1., msg=f"No chunk built from {task_document_name}")
|
progress_callback(1.0, msg=f"No chunk built from {task_document_name}")
|
||||||
return
|
return
|
||||||
progress_callback(msg="Generate {} chunks".format(len(chunks)))
|
progress_callback(msg="Generate {} chunks".format(len(chunks)))
|
||||||
start_ts = timer()
|
start_ts = timer()
|
||||||
@@ -1605,52 +1598,45 @@ async def do_handle_task(task):
|
|||||||
if cleaned_chunks:
|
if cleaned_chunks:
|
||||||
progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.")
|
progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.")
|
||||||
|
|
||||||
logging.info(
|
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts))
|
||||||
"Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(
|
|
||||||
task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
ret = DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
|
ret = DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
|
||||||
get_recording_context().save_func_return_value("DocumentService.increment_chunk_num", ret)
|
get_recording_context().save_func_return_value("DocumentService.increment_chunk_num", ret)
|
||||||
|
|
||||||
# Table parser (manual): push metadata/both column values to document-level metadata for UI / chat filters
|
# Table parser: push metadata/both column values to document-level metadata for UI / chat filters
|
||||||
if task.get("parser_id", "").lower() == "table":
|
if task.get("parser_id", "").lower() == "table":
|
||||||
eff_pc = merge_table_parser_config_from_kb(task)
|
eff_pc = merge_table_parser_config_from_kb(task)
|
||||||
logging.debug(
|
logging.debug(f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}")
|
||||||
f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}"
|
try:
|
||||||
)
|
agg = aggregate_table_doc_metadata(chunks, task)
|
||||||
if eff_pc.get("table_column_mode") == "manual":
|
logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}")
|
||||||
|
strip_keys = table_parser_strip_doc_metadata_keys(eff_pc)
|
||||||
|
existing = DocMetadataService.get_document_metadata(task_doc_id)
|
||||||
|
existing = existing if isinstance(existing, dict) else {}
|
||||||
|
preserved = {k: v for k, v in existing.items() if k not in strip_keys}
|
||||||
|
merged = update_metadata_to(dict(preserved), agg)
|
||||||
|
logging.debug(
|
||||||
|
f"[TABLE_META_DEBUG] calling update_document_metadata for doc_id={task_doc_id}, "
|
||||||
|
f"meta_fields keys={list(merged.keys())}, "
|
||||||
|
f"table_strip_key_count={len(strip_keys)}, agg_keys={list(agg.keys())}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
agg = aggregate_table_manual_doc_metadata(chunks, task)
|
ret = DocMetadataService.update_document_metadata(task_doc_id, merged)
|
||||||
logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}")
|
get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret)
|
||||||
strip_keys = table_parser_strip_doc_metadata_keys(eff_pc)
|
logging.debug("[TABLE_META_DEBUG] update_document_metadata succeeded")
|
||||||
existing = DocMetadataService.get_document_metadata(task_doc_id)
|
except Exception as ue:
|
||||||
existing = existing if isinstance(existing, dict) else {}
|
logging.error(
|
||||||
preserved = {k: v for k, v in existing.items() if k not in strip_keys}
|
"update_document_metadata failed (table parser, doc_id=%s): %s",
|
||||||
merged = update_metadata_to(dict(preserved), agg)
|
|
||||||
logging.debug(
|
|
||||||
f"[TABLE_META_DEBUG] calling update_document_metadata for doc_id={task_doc_id}, "
|
|
||||||
f"meta_fields keys={list(merged.keys())}, "
|
|
||||||
f"table_strip_key_count={len(strip_keys)}, agg_keys={list(agg.keys())}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
ret = DocMetadataService.update_document_metadata(task_doc_id, merged)
|
|
||||||
get_recording_context().save_func_return_value("DocMetadataService.update_document_metadata", ret)
|
|
||||||
logging.debug("[TABLE_META_DEBUG] update_document_metadata succeeded")
|
|
||||||
except Exception as ue:
|
|
||||||
logging.error(
|
|
||||||
"update_document_metadata failed (table parser, doc_id=%s): %s",
|
|
||||||
task_doc_id,
|
|
||||||
ue,
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(
|
|
||||||
"Table parser document metadata aggregation failed (doc_id=%s): %s",
|
|
||||||
task_doc_id,
|
task_doc_id,
|
||||||
e,
|
ue,
|
||||||
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(
|
||||||
|
"Table parser document metadata aggregation failed (doc_id=%s): %s",
|
||||||
|
task_doc_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts))
|
progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts))
|
||||||
|
|
||||||
@@ -1672,11 +1658,7 @@ async def do_handle_task(task):
|
|||||||
task_time_cost = timer() - task_start_ts
|
task_time_cost = timer() - task_start_ts
|
||||||
get_recording_context().record("task_status", "completed")
|
get_recording_context().record("task_status", "completed")
|
||||||
progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost))
|
progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost))
|
||||||
logging.info(
|
logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, task_time_cost))
|
||||||
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(
|
|
||||||
task_document_name, task_from_page, task_to_page, len(chunks), token_count, task_time_cost
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if toc_thread is not None and not toc_thread.done():
|
if toc_thread is not None and not toc_thread.done():
|
||||||
@@ -1697,8 +1679,7 @@ async def do_handle_task(task):
|
|||||||
)
|
)
|
||||||
get_recording_context().save_func_return_value("docStoreConn.delete", ret)
|
get_recording_context().save_func_return_value("docStoreConn.delete", ret)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(
|
logging.exception(f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled, exception: {e}")
|
||||||
f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled, exception: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_task():
|
async def handle_task():
|
||||||
@@ -1709,8 +1690,7 @@ async def handle_task():
|
|||||||
return
|
return
|
||||||
|
|
||||||
task_type = task["task_type"]
|
task_type = task["task_type"]
|
||||||
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type,
|
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
|
||||||
PipelineTaskType.PARSE) or PipelineTaskType.PARSE
|
|
||||||
task_id = task["id"]
|
task_id = task["id"]
|
||||||
try:
|
try:
|
||||||
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
|
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
|
||||||
@@ -1718,20 +1698,18 @@ async def handle_task():
|
|||||||
logging.info(f"TE_RUN_MODE is {run_mode}")
|
logging.info(f"TE_RUN_MODE is {run_mode}")
|
||||||
|
|
||||||
# Check if dry-run comparison is enabled via environment variable
|
# Check if dry-run comparison is enabled via environment variable
|
||||||
if run_mode == "1": # dry run mode - compare
|
if run_mode == "1": # dry run mode - compare
|
||||||
set_recording_context(RecordingContext())
|
set_recording_context(RecordingContext())
|
||||||
await do_handle_task(task) # original execution
|
await do_handle_task(task) # original execution
|
||||||
# dry run mode
|
# dry run mode
|
||||||
logging.info(f"-----dry run task:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}")
|
logging.info(f"-----dry run task:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}")
|
||||||
await TaskManager.dry_run_task(task, get_recording_context(), chat_limiter, minio_limiter, chunk_limiter,
|
await TaskManager.dry_run_task(task, get_recording_context(), chat_limiter, minio_limiter, chunk_limiter, embed_limiter, kg_limiter, set_progress, has_canceled)
|
||||||
embed_limiter,kg_limiter, set_progress, has_canceled)
|
elif run_mode == "0": # use refactor-ed version
|
||||||
elif run_mode == "0": # use refactor-ed version
|
|
||||||
# switch to refactor-ed version
|
# switch to refactor-ed version
|
||||||
logging.info(f"-----run refactor-ed task executor:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}")
|
logging.info(f"-----run refactor-ed task executor:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}")
|
||||||
set_recording_context(NullRecordingContext())
|
set_recording_context(NullRecordingContext())
|
||||||
await TaskManager.run_refactored_task(task, chat_limiter, minio_limiter, chunk_limiter,
|
await TaskManager.run_refactored_task(task, chat_limiter, minio_limiter, chunk_limiter, embed_limiter, kg_limiter, set_progress, has_canceled)
|
||||||
embed_limiter,kg_limiter, set_progress, has_canceled)
|
else: # original version
|
||||||
else: # original version
|
|
||||||
logging.info(f"-----run original task executor:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}")
|
logging.info(f"-----run original task executor:{task_id}, {task.get('name', '')}, doc id:{task.get('doc_id', '')}")
|
||||||
set_recording_context(NullRecordingContext())
|
set_recording_context(NullRecordingContext())
|
||||||
await do_handle_task(task)
|
await do_handle_task(task)
|
||||||
@@ -1742,9 +1720,7 @@ async def handle_task():
|
|||||||
except TaskCanceledException as e:
|
except TaskCanceledException as e:
|
||||||
DONE_TASKS += 1
|
DONE_TASKS += 1
|
||||||
CURRENT_TASKS.pop(task_id, None)
|
CURRENT_TASKS.pop(task_id, None)
|
||||||
logging.info(
|
logging.info(f"handle_task canceled for task {task_id}: {getattr(e, 'msg', str(e))}")
|
||||||
f"handle_task canceled for task {task_id}: {getattr(e, 'msg', str(e))}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
FAILED_TASKS += 1
|
FAILED_TASKS += 1
|
||||||
CURRENT_TASKS.pop(task_id, None)
|
CURRENT_TASKS.pop(task_id, None)
|
||||||
@@ -1752,7 +1728,7 @@ async def handle_task():
|
|||||||
err_msg = str(e)
|
err_msg = str(e)
|
||||||
while isinstance(e, exceptiongroup.ExceptionGroup):
|
while isinstance(e, exceptiongroup.ExceptionGroup):
|
||||||
e = e.exceptions[0]
|
e = e.exceptions[0]
|
||||||
err_msg += ' -- ' + str(e)
|
err_msg += " -- " + str(e)
|
||||||
set_progress(task_id, prog=-1, msg=f"[Exception]: {err_msg}")
|
set_progress(task_id, prog=-1, msg=f"[Exception]: {err_msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(f"[Exception]: {str(e)}")
|
logging.exception(f"[Exception]: {str(e)}")
|
||||||
@@ -1763,9 +1739,9 @@ async def handle_task():
|
|||||||
referred_document_id = None
|
referred_document_id = None
|
||||||
if task_type in ["graphrag", "raptor", "mindmap"]:
|
if task_type in ["graphrag", "raptor", "mindmap"]:
|
||||||
referred_document_id = task["doc_ids"][0]
|
referred_document_id = task["doc_ids"][0]
|
||||||
ret = PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="",
|
ret = PipelineOperationLogService.record_pipeline_operation(
|
||||||
task_type=pipeline_task_type,
|
document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, task_id=task_id, referred_document_id=referred_document_id
|
||||||
task_id=task_id, referred_document_id=referred_document_id)
|
)
|
||||||
get_recording_context().save_func_return_value("PipelineOperationLogService.record_pipeline_operation", ret)
|
get_recording_context().save_func_return_value("PipelineOperationLogService.record_pipeline_operation", ret)
|
||||||
|
|
||||||
redis_msg.ack()
|
redis_msg.ack()
|
||||||
@@ -1779,7 +1755,7 @@ async def get_server_ip() -> str:
|
|||||||
return s.getsockname()[0]
|
return s.getsockname()[0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(str(e))
|
logging.error(str(e))
|
||||||
return 'Unknown'
|
return "Unknown"
|
||||||
|
|
||||||
|
|
||||||
async def report_status():
|
async def report_status():
|
||||||
@@ -1804,18 +1780,20 @@ async def report_status():
|
|||||||
LAG_TASKS = int(group_info.get("lag", 0))
|
LAG_TASKS = int(group_info.get("lag", 0))
|
||||||
|
|
||||||
current = copy.deepcopy(CURRENT_TASKS)
|
current = copy.deepcopy(CURRENT_TASKS)
|
||||||
heartbeat = json.dumps({
|
heartbeat = json.dumps(
|
||||||
"ip_address": ip_address,
|
{
|
||||||
"pid": pid,
|
"ip_address": ip_address,
|
||||||
"name": CONSUMER_NAME,
|
"pid": pid,
|
||||||
"now": now.astimezone().isoformat(timespec="milliseconds"),
|
"name": CONSUMER_NAME,
|
||||||
"boot_at": BOOT_AT,
|
"now": now.astimezone().isoformat(timespec="milliseconds"),
|
||||||
"pending": PENDING_TASKS,
|
"boot_at": BOOT_AT,
|
||||||
"lag": LAG_TASKS,
|
"pending": PENDING_TASKS,
|
||||||
"done": DONE_TASKS,
|
"lag": LAG_TASKS,
|
||||||
"failed": FAILED_TASKS,
|
"done": DONE_TASKS,
|
||||||
"current": current,
|
"failed": FAILED_TASKS,
|
||||||
})
|
"current": current,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Report heartbeat to Redis
|
# Report heartbeat to Redis
|
||||||
try:
|
try:
|
||||||
@@ -1890,17 +1868,17 @@ async def main():
|
|||||||
/___/_/ /_/\__, /\___/____/\__/_/\____/_/ /_/ /____/\___/_/ |___/\___/_/
|
/___/_/ /_/\__, /\___/____/\__/_/\____/_/ /_/ /____/\___/_/ |___/\___/_/
|
||||||
/____/
|
/____/
|
||||||
""")
|
""")
|
||||||
logging.info(f'RAGFlow ingestion version: {get_ragflow_version()}')
|
logging.info(f"RAGFlow ingestion version: {get_ragflow_version()}")
|
||||||
logging.info(f"ENABLE_DRY_RUN_COMPARISON: {os.environ.get("ENABLE_DRY_RUN_COMPARISON", "0")}")
|
logging.info(f"ENABLE_DRY_RUN_COMPARISON: {os.environ.get('ENABLE_DRY_RUN_COMPARISON', '0')}")
|
||||||
show_configs()
|
show_configs()
|
||||||
settings.init_settings()
|
settings.init_settings()
|
||||||
settings.check_and_install_torch()
|
settings.check_and_install_torch()
|
||||||
logging.info(f'default embedding config: {settings.EMBEDDING_CFG}')
|
logging.info(f"default embedding config: {settings.EMBEDDING_CFG}")
|
||||||
settings.print_rag_settings()
|
settings.print_rag_settings()
|
||||||
if sys.platform != "win32":
|
if sys.platform != "win32":
|
||||||
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
|
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
|
||||||
signal.signal(signal.SIGUSR2, stop_tracemalloc)
|
signal.signal(signal.SIGUSR2, stop_tracemalloc)
|
||||||
TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0"))
|
TRACE_MALLOC_ENABLED = int(os.environ.get("TRACE_MALLOC_ENABLED", "0"))
|
||||||
if TRACE_MALLOC_ENABLED:
|
if TRACE_MALLOC_ENABLED:
|
||||||
start_tracemalloc_and_snapshot(None, None)
|
start_tracemalloc_and_snapshot(None, None)
|
||||||
|
|
||||||
@@ -1927,16 +1905,16 @@ async def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Parse command line arguments (consistent with SAAS version)
|
# Parse command line arguments (consistent with SAAS version)
|
||||||
parser = argparse.ArgumentParser(description='Task Executor')
|
parser = argparse.ArgumentParser(description="Task Executor")
|
||||||
parser.add_argument("-i", "--index", type=str, default='0')
|
parser.add_argument("-i", "--index", type=str, default="0")
|
||||||
parser.add_argument("-t", "--type", type=str, default="common", help="[common, graphrag, raptor, resume]")
|
parser.add_argument("-t", "--type", type=str, default="common", help="[common, graphrag, raptor, resume]")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Update global variables
|
# Update global variables
|
||||||
TASK_TYPE = args.type
|
TASK_TYPE = args.type
|
||||||
TE_IDX = args.index
|
TE_IDX = args.index
|
||||||
CONSUMER_NAME = f"task_executor_{TASK_TYPE}_{TE_IDX}"
|
CONSUMER_NAME = f"task_executor_{TASK_TYPE}_{TE_IDX}"
|
||||||
|
|
||||||
faulthandler.enable()
|
faulthandler.enable()
|
||||||
init_root_logger(CONSUMER_NAME)
|
init_root_logger(CONSUMER_NAME)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -28,11 +28,12 @@ from api.db.services.doc_metadata_service import DocMetadataService
|
|||||||
from common.metadata_utils import update_metadata_to
|
from common.metadata_utils import update_metadata_to
|
||||||
from rag.svr.task_executor_refactor.task_context import TaskContext
|
from rag.svr.task_executor_refactor.task_context import TaskContext
|
||||||
from rag.utils.table_es_metadata import (
|
from rag.utils.table_es_metadata import (
|
||||||
aggregate_table_manual_doc_metadata,
|
aggregate_table_doc_metadata,
|
||||||
merge_table_parser_config_from_kb,
|
merge_table_parser_config_from_kb,
|
||||||
table_parser_strip_doc_metadata_keys,
|
table_parser_strip_doc_metadata_keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PostProcessor:
|
class PostProcessor:
|
||||||
"""Service for post-indexing operations.
|
"""Service for post-indexing operations.
|
||||||
|
|
||||||
@@ -69,15 +70,10 @@ class PostProcessor:
|
|||||||
return
|
return
|
||||||
|
|
||||||
eff_pc = merge_table_parser_config_from_kb(ctx.raw_task)
|
eff_pc = merge_table_parser_config_from_kb(ctx.raw_task)
|
||||||
logging.debug(
|
logging.debug(f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}")
|
||||||
f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if eff_pc.get("table_column_mode") != "manual":
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agg = aggregate_table_manual_doc_metadata(chunks, ctx.raw_task)
|
agg = aggregate_table_doc_metadata(chunks, ctx.raw_task)
|
||||||
logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}")
|
logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}")
|
||||||
|
|
||||||
strip_keys = table_parser_strip_doc_metadata_keys(eff_pc)
|
strip_keys = table_parser_strip_doc_metadata_keys(eff_pc)
|
||||||
|
|||||||
@@ -102,9 +102,7 @@ def _probe_es_typed_key_for_column(col: str, sample_chunk: dict) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _resolve_es_chunk_field_key(
|
def _resolve_es_chunk_field_key(col: str, field_map: dict, sample_chunk: dict | None) -> tuple[str | None, str]:
|
||||||
col: str, field_map: dict, sample_chunk: dict | None
|
|
||||||
) -> tuple[str | None, str]:
|
|
||||||
"""Prefer field_map when key exists on chunk; else probe by suffix (matches table.py naming)."""
|
"""Prefer field_map when key exists on chunk; else probe by suffix (matches table.py naming)."""
|
||||||
tk_fm = _field_map_typed_key_for_column(field_map, col) if field_map else None
|
tk_fm = _field_map_typed_key_for_column(field_map, col) if field_map else None
|
||||||
if sample_chunk:
|
if sample_chunk:
|
||||||
@@ -153,35 +151,44 @@ def _es_field_value_to_doc_metadata(val, *, from_tks_fallback: bool) -> str | No
|
|||||||
return _value_to_meta_string(val)
|
return _value_to_meta_string(val)
|
||||||
|
|
||||||
|
|
||||||
def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
|
def aggregate_table_doc_metadata(chunks: list, task: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
Collect unique values per metadata/both column across chunks for document-level metadata.
|
Collect unique values per metadata/both column across chunks for document-level metadata.
|
||||||
Used when table_column_mode == manual (parallel to LLM gen_metadata, no schema required).
|
Works for both table_column_mode == manual and auto (where all columns default to "both").
|
||||||
"""
|
"""
|
||||||
logging.debug(
|
logging.debug(f"[TABLE_META_DEBUG] aggregate_table_doc_metadata called with {len(chunks)} chunks")
|
||||||
f"[TABLE_META_DEBUG] aggregate_table_manual_doc_metadata called with {len(chunks)} chunks"
|
|
||||||
)
|
|
||||||
eff = merge_table_parser_config_from_kb(task)
|
eff = merge_table_parser_config_from_kb(task)
|
||||||
if eff.get("table_column_mode") != "manual":
|
mode = eff.get("table_column_mode") or "auto"
|
||||||
logging.debug(
|
if mode not in ("manual", "auto"):
|
||||||
f"[TABLE_META_DEBUG] skip aggregate: table_column_mode={eff.get('table_column_mode')!r}"
|
logging.debug(f"[TABLE_META_DEBUG] skip aggregate: table_column_mode={mode!r}")
|
||||||
)
|
|
||||||
return {}
|
return {}
|
||||||
roles = eff.get("table_column_roles") or {}
|
roles = eff.get("table_column_roles") or {}
|
||||||
table_column_names = eff.get("table_column_names") or []
|
table_column_names = eff.get("table_column_names") or []
|
||||||
|
# Reload table_column_names from KB if empty (chunk() writes them during parse,
|
||||||
|
# but the task snapshot may be stale)
|
||||||
|
if not table_column_names:
|
||||||
|
kb_id = task.get("kb_id")
|
||||||
|
if kb_id:
|
||||||
|
try:
|
||||||
|
KBS = _knowledgebase_service_cls()
|
||||||
|
ok, kb = KBS.get_by_id(kb_id)
|
||||||
|
if ok and kb:
|
||||||
|
fresh_names = (kb.parser_config or {}).get("table_column_names") or []
|
||||||
|
if fresh_names:
|
||||||
|
table_column_names = fresh_names
|
||||||
|
logging.debug(f"[TABLE_META_DEBUG] reloaded table_column_names from DB: {fresh_names}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug(
|
||||||
|
"[TABLE_META_DEBUG] failed to reload table_column_names from DB: %s",
|
||||||
|
e,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
if table_column_names:
|
if table_column_names:
|
||||||
meta_cols = [
|
meta_cols = [col for col in table_column_names if roles.get(col, "both") in ("metadata", "both")]
|
||||||
col
|
|
||||||
for col in table_column_names
|
|
||||||
if roles.get(col, "both") in ("metadata", "both")
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
meta_cols = [c for c, r in roles.items() if r in ("metadata", "both")]
|
meta_cols = [c for c, r in roles.items() if r in ("metadata", "both")]
|
||||||
if not meta_cols:
|
if not meta_cols:
|
||||||
logging.debug(
|
logging.debug(f"[TABLE_META_DEBUG] skip aggregate: no metadata/both columns (table_column_names_present={bool(table_column_names)})")
|
||||||
"[TABLE_META_DEBUG] skip aggregate: no metadata/both columns "
|
|
||||||
f"(table_column_names_present={bool(table_column_names)})"
|
|
||||||
)
|
|
||||||
return {}
|
return {}
|
||||||
fm = (task.get("kb_parser_config") or {}).get("field_map") or {}
|
fm = (task.get("kb_parser_config") or {}).get("field_map") or {}
|
||||||
kb_id = task.get("kb_id")
|
kb_id = task.get("kb_id")
|
||||||
@@ -194,14 +201,9 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
|
|||||||
reloaded = fresh_pc.get("field_map") or {}
|
reloaded = fresh_pc.get("field_map") or {}
|
||||||
if reloaded:
|
if reloaded:
|
||||||
fm = reloaded
|
fm = reloaded
|
||||||
logging.debug(
|
logging.debug(f"[TABLE_META_DEBUG] reloaded field_map from DB: {len(fm)} entries")
|
||||||
f"[TABLE_META_DEBUG] reloaded field_map from DB: {len(fm)} entries"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logging.debug(
|
logging.debug("[TABLE_META_DEBUG] KB reload: parser_config has no field_map yet; will use ES key probe on chunk dicts if applicable")
|
||||||
"[TABLE_META_DEBUG] KB reload: parser_config has no field_map yet; "
|
|
||||||
"will use ES key probe on chunk dicts if applicable"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"[TABLE_META_DEBUG] failed to reload field_map from DB: %s",
|
"[TABLE_META_DEBUG] failed to reload field_map from DB: %s",
|
||||||
@@ -209,21 +211,11 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
if not fm and not (settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE):
|
if not fm and not (settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE):
|
||||||
logging.debug(
|
logging.debug(f"[TABLE_META_DEBUG] field_map empty on task snapshot — will use ES key probe on chunk dicts; kb_parser_config keys={list((task.get('kb_parser_config') or {}).keys())}")
|
||||||
"[TABLE_META_DEBUG] field_map empty on task snapshot — will use ES key probe on chunk dicts; "
|
logging.debug(f"[TABLE_META_DEBUG] meta_cols={meta_cols}, field_map entries={len(fm)}, infinity={settings.DOC_ENGINE_INFINITY}, oceanbase={settings.DOC_ENGINE_OCEANBASE}")
|
||||||
f"kb_parser_config keys={list((task.get('kb_parser_config') or {}).keys())}"
|
|
||||||
)
|
|
||||||
logging.debug(
|
|
||||||
f"[TABLE_META_DEBUG] meta_cols={meta_cols}, field_map entries={len(fm)}, "
|
|
||||||
f"infinity={settings.DOC_ENGINE_INFINITY}, oceanbase={settings.DOC_ENGINE_OCEANBASE}"
|
|
||||||
)
|
|
||||||
sample_ck = next((c for c in chunks if isinstance(c, dict)), None)
|
sample_ck = next((c for c in chunks if isinstance(c, dict)), None)
|
||||||
if sample_ck:
|
if sample_ck:
|
||||||
sk = [
|
sk = [k for k in sample_ck.keys() if not (str(k).startswith("q_") and str(k).endswith("_vec"))][:50]
|
||||||
k
|
|
||||||
for k in sample_ck.keys()
|
|
||||||
if not (str(k).startswith("q_") and str(k).endswith("_vec"))
|
|
||||||
][:50]
|
|
||||||
logging.debug(f"[TABLE_META_DEBUG] first chunk non-vector keys (sample): {sk}")
|
logging.debug(f"[TABLE_META_DEBUG] first chunk non-vector keys (sample): {sk}")
|
||||||
|
|
||||||
es_col_keys: dict[str, tuple[str | None, str]] = {}
|
es_col_keys: dict[str, tuple[str | None, str]] = {}
|
||||||
@@ -231,9 +223,7 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
|
|||||||
for col in meta_cols:
|
for col in meta_cols:
|
||||||
tk, src = _resolve_es_chunk_field_key(col, fm, sample_ck)
|
tk, src = _resolve_es_chunk_field_key(col, fm, sample_ck)
|
||||||
es_col_keys[col] = (tk, src)
|
es_col_keys[col] = (tk, src)
|
||||||
logging.debug(
|
logging.debug(f"[TABLE_META_DEBUG] column '{col}' -> ES key {tk!r} (source={src})")
|
||||||
f"[TABLE_META_DEBUG] column '{col}' -> ES key {tk!r} (source={src})"
|
|
||||||
)
|
|
||||||
|
|
||||||
acc: dict[str, list] = {c: [] for c in meta_cols}
|
acc: dict[str, list] = {c: [] for c in meta_cols}
|
||||||
|
|
||||||
@@ -255,9 +245,7 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
|
|||||||
tk, _src = es_col_keys.get(col, (None, "none"))
|
tk, _src = es_col_keys.get(col, (None, "none"))
|
||||||
if not tk:
|
if not tk:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
logging.debug(
|
logging.debug(f"[TABLE_META_DEBUG] no resolved ES key for column '{col}'")
|
||||||
f"[TABLE_META_DEBUG] no resolved ES key for column '{col}'"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
raw_k = _es_raw_field_key_from_typed(tk)
|
raw_k = _es_raw_field_key_from_typed(tk)
|
||||||
val = None
|
val = None
|
||||||
@@ -269,10 +257,7 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
|
|||||||
from_tks = tk.endswith("_tks")
|
from_tks = tk.endswith("_tks")
|
||||||
else:
|
else:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
logging.debug(
|
logging.debug(f"[TABLE_META_DEBUG] chunk missing ES field {tk!r}{' and ' + raw_k + ' (raw)' if raw_k else ''} for column '{col}'")
|
||||||
f"[TABLE_META_DEBUG] chunk missing ES field {tk!r}"
|
|
||||||
f"{' and ' + raw_k + ' (raw)' if raw_k else ''} for column '{col}'"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
s = _es_field_value_to_doc_metadata(val, from_tks_fallback=from_tks)
|
s = _es_field_value_to_doc_metadata(val, from_tks_fallback=from_tks)
|
||||||
if s is not None:
|
if s is not None:
|
||||||
@@ -289,8 +274,5 @@ def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict:
|
|||||||
for col, vals in acc.items():
|
for col, vals in acc.items():
|
||||||
if vals:
|
if vals:
|
||||||
out[col] = dedupe_list(vals)
|
out[col] = dedupe_list(vals)
|
||||||
logging.debug(
|
logging.debug(f"[TABLE_META_DEBUG] aggregated metadata dict keys={list(out.keys())}, sizes={[len(v) for v in out.values()]}")
|
||||||
f"[TABLE_META_DEBUG] aggregated metadata dict keys={list(out.keys())}, "
|
|
||||||
f"sizes={[len(v) for v in out.values()]}"
|
|
||||||
)
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -69,12 +69,13 @@ class TestPostProcessorProcessTableParserMetadata:
|
|||||||
ctx.write_interceptor = None
|
ctx.write_interceptor = None
|
||||||
chunks = [{"col_key": "val"}]
|
chunks = [{"col_key": "val"}]
|
||||||
|
|
||||||
with patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge, \
|
with (
|
||||||
patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_manual_doc_metadata") as mock_agg, \
|
patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge,
|
||||||
patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip, \
|
patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_doc_metadata") as mock_agg,
|
||||||
patch("rag.svr.task_executor_refactor.post_processor.update_metadata_to") as mock_update, \
|
patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip,
|
||||||
patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta:
|
patch("rag.svr.task_executor_refactor.post_processor.update_metadata_to") as mock_update,
|
||||||
|
patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta,
|
||||||
|
):
|
||||||
mock_merge.return_value = {"table_column_mode": "manual"}
|
mock_merge.return_value = {"table_column_mode": "manual"}
|
||||||
mock_agg.return_value = {"col_key": ["val1", "val2"]}
|
mock_agg.return_value = {"col_key": ["val1", "val2"]}
|
||||||
mock_strip.return_value = set()
|
mock_strip.return_value = set()
|
||||||
@@ -95,11 +96,12 @@ class TestPostProcessorProcessTableParserMetadata:
|
|||||||
ctx.raw_task = {}
|
ctx.raw_task = {}
|
||||||
ctx.write_interceptor = MagicMock()
|
ctx.write_interceptor = MagicMock()
|
||||||
|
|
||||||
with patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge, \
|
with (
|
||||||
patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_manual_doc_metadata") as mock_agg, \
|
patch("rag.svr.task_executor_refactor.post_processor.merge_table_parser_config_from_kb") as mock_merge,
|
||||||
patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip, \
|
patch("rag.svr.task_executor_refactor.post_processor.aggregate_table_doc_metadata") as mock_agg,
|
||||||
patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta:
|
patch("rag.svr.task_executor_refactor.post_processor.table_parser_strip_doc_metadata_keys") as mock_strip,
|
||||||
|
patch("rag.svr.task_executor_refactor.post_processor.DocMetadataService") as mock_meta,
|
||||||
|
):
|
||||||
mock_merge.return_value = {"table_column_mode": "manual"}
|
mock_merge.return_value = {"table_column_mode": "manual"}
|
||||||
mock_agg.return_value = {"key": ["v"]}
|
mock_agg.return_value = {"key": ["v"]}
|
||||||
mock_strip.return_value = set()
|
mock_strip.return_value = set()
|
||||||
@@ -108,9 +110,7 @@ class TestPostProcessorProcessTableParserMetadata:
|
|||||||
service = PostProcessor(ctx=ctx)
|
service = PostProcessor(ctx=ctx)
|
||||||
await service.process_table_parser_metadata("doc_1", [])
|
await service.process_table_parser_metadata("doc_1", [])
|
||||||
|
|
||||||
ctx.write_interceptor.intercept.assert_called_once_with(
|
ctx.write_interceptor.intercept.assert_called_once_with("DocMetadataService.update_document_metadata")
|
||||||
"DocMetadataService.update_document_metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestPostProcessorInsertTocChunk:
|
class TestPostProcessorInsertTocChunk:
|
||||||
@@ -160,9 +160,7 @@ class TestPostProcessorInsertTocChunk:
|
|||||||
result = await service.insert_toc_chunk(toc_chunk, chunk_service)
|
result = await service.insert_toc_chunk(toc_chunk, chunk_service)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
chunk_service.insert_chunks.assert_called_once_with(
|
chunk_service.insert_chunks.assert_called_once_with("task_1", "tenant_1", "kb_1", [toc_chunk])
|
||||||
"task_1", "tenant_1", "kb_1", [toc_chunk]
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handles_insert_failure(self):
|
async def test_handles_insert_failure(self):
|
||||||
@@ -179,4 +177,4 @@ class TestPostProcessorInsertTocChunk:
|
|||||||
|
|
||||||
result = await service.insert_toc_chunk(toc_chunk, chunk_service)
|
result = await service.insert_toc_chunk(toc_chunk, chunk_service)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|||||||
@@ -14,11 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
"""Unit tests for aggregate_table_manual_doc_metadata."""
|
"""Unit tests for aggregate_table_doc_metadata."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from rag.utils.table_es_metadata import aggregate_table_manual_doc_metadata, merge_table_parser_config_from_kb
|
from rag.utils.table_es_metadata import aggregate_table_doc_metadata, merge_table_parser_config_from_kb
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -73,30 +73,64 @@ class TestAggregateTableManualDocMetadata:
|
|||||||
"category_tks": "y",
|
"category_tks": "y",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
out = aggregate_table_manual_doc_metadata(chunks, task)
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
assert out["country"] == ["Brazil", "Turkey"]
|
assert out["country"] == ["Brazil", "Turkey"]
|
||||||
assert out["category"] == ["Economy", "Disaster"]
|
assert out["category"] == ["Economy", "Disaster"]
|
||||||
|
|
||||||
def test_aggregate_auto_mode_returns_empty(self, es_engine):
|
def test_aggregate_auto_mode_returns_data(self, es_engine):
|
||||||
task = {
|
task = {
|
||||||
"parser_id": "table",
|
"parser_id": "table",
|
||||||
"parser_config": {},
|
"parser_config": {},
|
||||||
"kb_parser_config": {
|
"kb_parser_config": {
|
||||||
"table_column_mode": "auto",
|
"table_column_mode": "auto",
|
||||||
"table_column_roles": {"country": "metadata"},
|
"table_column_names": ["country"],
|
||||||
|
"field_map": {"country_tks": "country"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert aggregate_table_manual_doc_metadata([{"country_tks": "x"}], task) == {}
|
chunks = [{"country_raw": "Brazil", "country_tks": "x"}]
|
||||||
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
|
assert out == {"country": ["Brazil"]}
|
||||||
|
|
||||||
def test_aggregate_no_mode_returns_empty(self, es_engine):
|
def test_aggregate_auto_mode_all_columns_both(self, es_engine):
|
||||||
task = {
|
task = {
|
||||||
"parser_id": "table",
|
"parser_id": "table",
|
||||||
"parser_config": {},
|
"parser_config": {},
|
||||||
"kb_parser_config": {
|
"kb_parser_config": {
|
||||||
"table_column_roles": {"country": "metadata"},
|
"table_column_mode": "auto",
|
||||||
|
"table_column_names": ["country", "category"],
|
||||||
|
"field_map": {"country_tks": "country", "category_tks": "category"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert aggregate_table_manual_doc_metadata([{}], task) == {}
|
chunks = [
|
||||||
|
{"country_raw": "Brazil", "country_tks": "x", "category_raw": "Economy", "category_tks": "y"},
|
||||||
|
{"country_raw": "Turkey", "country_tks": "x", "category_raw": "Disaster", "category_tks": "y"},
|
||||||
|
]
|
||||||
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
|
assert out["country"] == ["Brazil", "Turkey"]
|
||||||
|
assert out["category"] == ["Economy", "Disaster"]
|
||||||
|
|
||||||
|
def test_aggregate_no_mode_defaults_to_auto(self, es_engine):
|
||||||
|
"""When table_column_mode is missing, it defaults to 'auto'."""
|
||||||
|
task = {
|
||||||
|
"parser_id": "table",
|
||||||
|
"parser_config": {},
|
||||||
|
"kb_parser_config": {
|
||||||
|
"table_column_names": ["country"],
|
||||||
|
"field_map": {"country_tks": "country"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
chunks = [{"country_raw": "Brazil", "country_tks": "x"}]
|
||||||
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
|
assert out == {"country": ["Brazil"]}
|
||||||
|
|
||||||
|
def test_aggregate_no_mode_no_columns_returns_empty(self, es_engine):
|
||||||
|
"""No mode and no column names/roles -> empty (nothing to aggregate)."""
|
||||||
|
task = {
|
||||||
|
"parser_id": "table",
|
||||||
|
"parser_config": {},
|
||||||
|
"kb_parser_config": {},
|
||||||
|
}
|
||||||
|
assert aggregate_table_doc_metadata([{}], task) == {}
|
||||||
|
|
||||||
def test_aggregate_no_metadata_columns(self, es_engine):
|
def test_aggregate_no_metadata_columns(self, es_engine):
|
||||||
task = {
|
task = {
|
||||||
@@ -108,14 +142,14 @@ class TestAggregateTableManualDocMetadata:
|
|||||||
"table_column_names": ["country"],
|
"table_column_names": ["country"],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert aggregate_table_manual_doc_metadata([{"country_tks": "x"}], task) == {}
|
assert aggregate_table_doc_metadata([{"country_tks": "x"}], task) == {}
|
||||||
|
|
||||||
def test_aggregate_prefers_raw_over_tks(self, es_engine):
|
def test_aggregate_prefers_raw_over_tks(self, es_engine):
|
||||||
task = _table_task()
|
task = _table_task()
|
||||||
task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"}
|
task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"}
|
||||||
task["kb_parser_config"]["table_column_names"] = ["country"]
|
task["kb_parser_config"]["table_column_names"] = ["country"]
|
||||||
chunks = [{"country_raw": "Brazil", "country_tks": ["brazil"]}]
|
chunks = [{"country_raw": "Brazil", "country_tks": ["brazil"]}]
|
||||||
out = aggregate_table_manual_doc_metadata(chunks, task)
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
assert out == {"country": ["Brazil"]}
|
assert out == {"country": ["Brazil"]}
|
||||||
|
|
||||||
def test_aggregate_tks_fallback(self, es_engine):
|
def test_aggregate_tks_fallback(self, es_engine):
|
||||||
@@ -123,7 +157,7 @@ class TestAggregateTableManualDocMetadata:
|
|||||||
task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"}
|
task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"}
|
||||||
task["kb_parser_config"]["table_column_names"] = ["country"]
|
task["kb_parser_config"]["table_column_names"] = ["country"]
|
||||||
chunks = [{"country_tks": ["brazil"]}]
|
chunks = [{"country_tks": ["brazil"]}]
|
||||||
out = aggregate_table_manual_doc_metadata(chunks, task)
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
assert out == {"country": ["brazil"]}
|
assert out == {"country": ["brazil"]}
|
||||||
|
|
||||||
def test_aggregate_partial_roles_defaults_to_both(self, es_engine):
|
def test_aggregate_partial_roles_defaults_to_both(self, es_engine):
|
||||||
@@ -138,7 +172,7 @@ class TestAggregateTableManualDocMetadata:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
chunks = [{"city_raw": "SP", "city_tks": "t", "country_tks": "x"}]
|
chunks = [{"city_raw": "SP", "city_tks": "t", "country_tks": "x"}]
|
||||||
out = aggregate_table_manual_doc_metadata(chunks, task)
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
assert out == {"city": ["SP"]}
|
assert out == {"city": ["SP"]}
|
||||||
assert "country" not in out
|
assert "country" not in out
|
||||||
|
|
||||||
@@ -156,7 +190,7 @@ class TestAggregateTableManualDocMetadata:
|
|||||||
chunks = [
|
chunks = [
|
||||||
{"country_raw": "BR", "city_raw": "SP", "country_tks": "x", "city_tks": "y"},
|
{"country_raw": "BR", "city_raw": "SP", "country_tks": "x", "city_tks": "y"},
|
||||||
]
|
]
|
||||||
out = aggregate_table_manual_doc_metadata(chunks, task)
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
assert "country" in out and "city" in out
|
assert "country" in out and "city" in out
|
||||||
|
|
||||||
def test_aggregate_deduplicates_values(self, es_engine):
|
def test_aggregate_deduplicates_values(self, es_engine):
|
||||||
@@ -168,7 +202,7 @@ class TestAggregateTableManualDocMetadata:
|
|||||||
{"country_raw": "UK", "country_tks": "y"},
|
{"country_raw": "UK", "country_tks": "y"},
|
||||||
{"country_raw": "US", "country_tks": "x"},
|
{"country_raw": "US", "country_tks": "x"},
|
||||||
]
|
]
|
||||||
out = aggregate_table_manual_doc_metadata(chunks, task)
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
assert out["country"] == ["US", "UK"]
|
assert out["country"] == ["US", "UK"]
|
||||||
|
|
||||||
def test_aggregate_kb_reload_field_map(self, es_engine, monkeypatch):
|
def test_aggregate_kb_reload_field_map(self, es_engine, monkeypatch):
|
||||||
@@ -197,7 +231,7 @@ class TestAggregateTableManualDocMetadata:
|
|||||||
"kb_id": "kb-1",
|
"kb_id": "kb-1",
|
||||||
}
|
}
|
||||||
chunks = [{"country_raw": "X", "country_tks": "t"}]
|
chunks = [{"country_raw": "X", "country_tks": "t"}]
|
||||||
out = aggregate_table_manual_doc_metadata(chunks, task)
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
assert out == {"country": ["X"]}
|
assert out == {"country": ["X"]}
|
||||||
|
|
||||||
def test_merge_infinity_chunk_data(self, infinity_engine):
|
def test_merge_infinity_chunk_data(self, infinity_engine):
|
||||||
@@ -214,7 +248,7 @@ class TestAggregateTableManualDocMetadata:
|
|||||||
{"chunk_data": {"country": "US"}},
|
{"chunk_data": {"country": "US"}},
|
||||||
{"chunk_data": {"country": "UK"}},
|
{"chunk_data": {"country": "UK"}},
|
||||||
]
|
]
|
||||||
out = aggregate_table_manual_doc_metadata(chunks, task)
|
out = aggregate_table_doc_metadata(chunks, task)
|
||||||
assert out == {"country": ["US", "UK"]}
|
assert out == {"country": ["US", "UK"]}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user