# # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Task Handler Module. Provides [`TaskHandler`](rag/svr/task_executor_refactor/task_handler.py:56) as the main entry point for handling document processing tasks with refactored, testable methods. """ import asyncio import logging import json import xxhash from timeit import default_timer as timer from typing import Callable, Dict, List, Optional from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.joint_services.memory_message_service import handle_save_to_memory_task from api.db.joint_services.tenant_model_service import ( get_tenant_default_model_by_type, get_model_config_from_provider_instance ) from api.db.services.llm_service import LLMBundle from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID from common.constants import LLMType from common.exceptions import TaskCanceledException from common.connection_utils import timeout from common.misc_utils import thread_pool_exec from rag.nlp import search from rag.svr.task_executor_refactor.constants import CANVAS_DEBUG_DOC_ID from rag.svr.task_executor_refactor.chunk_service import ChunkService from rag.svr.task_executor_refactor.dataflow_service import BillingHook, DataflowService from rag.svr.task_executor_refactor.embedding_service import EmbeddingService from rag.svr.task_executor_refactor.post_processor import PostProcessor from rag.svr.task_executor_refactor.raptor_service import RaptorService from rag.svr.task_executor_refactor.raptor_utils import delete_raptor_chunks from rag.svr.task_executor_refactor.recording_context import RecordingContext from rag.svr.task_executor_refactor.task_context import TaskContext from rag.graphrag.general.index import run_graphrag_for_kb from api.db.services.file2document_service import File2DocumentService from rag.prompts.generator import run_toc_from_text from common import settings class TaskHandler: """Main task handler for document processing. This class orchestrates the entire document processing pipeline: 1. Task type detection (memory, dataflow, raptor, graphrag, standard) 2. Model binding (embedding, chat) 3. Chunk building or RAPTOR/GraphRAG execution 4. Embedding 5. Indexing 6. Post-processing (TOC, table metadata) All intermediate results are recorded via RecordingContext for comparison. """ def __init__( self, ctx: TaskContext, billing_hook: Optional[BillingHook] = None, ): """Initialize TaskHandler. Args: ctx: TaskContext containing task configuration and execution resources. billing_hook: Optional billing hook for pipeline success/error callbacks. """ self._task_context = ctx self._billing_hook = billing_hook async def handle_task(self) -> None: try: await self.handle() finally: task_id = self._task_context.id task_tenant_id = self._task_context.tenant_id task_dataset_id = self._task_context.kb_id task_doc_id = self._task_context.doc_id if self._task_context.has_canceled_func(task_id): try: exists = await thread_pool_exec( settings.docStoreConn.index_exist, search.index_name(task_tenant_id), task_dataset_id, ) if exists: ret = await thread_pool_exec( settings.docStoreConn.delete, {"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id, ) self._task_context.recording_context.save_func_return_value("docStoreConn.delete", ret) except Exception as e: logging.exception( f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled, exception: {e}") @timeout(60 * 60 * 3, 1) async def handle(self) -> None: """Handle a document processing task.""" ctx = self._task_context task_type = ctx.task_type task_id = ctx.id # Handle memory tasks if task_type == "memory": # ignore when it's dry run - no change on handle_save_to_memory_task when refactor if isinstance(ctx.write_interceptor, RecordingContext): logging.info(f"dry run, ignore handle_save_to_memory_task {task_id}") else: # actual run - not dry run await handle_save_to_memory_task(ctx.raw_task) return # Check if task is canceled if ctx.has_canceled_func(task_id): ctx.progress_cb(-1, msg="Task has been canceled.") return # Language defaults to "Chinese" via TaskContext._DEFAULTS — safe to bind model directly. # Bind embedding model (matching original do_handle_task order: bind + init_kb before routing) result = await self._bind_embedding_model() if result is None: return embedding_model, vector_size = result with embedding_model: self._init_kb(vector_size) # Handle dataflow tasks (after init_kb, matching original behavior) if task_type == "dataflow" and ctx.doc_id == CANVAS_DEBUG_DOC_ID: await self._run_dataflow() return if task_type.startswith("dataflow"): await self._run_dataflow() return # Route to appropriate handler if task_type == "raptor": await self._run_raptor(embedding_model, vector_size) elif task_type == "graphrag": await self._run_graphrag(embedding_model) elif task_type == "mindmap": ctx.progress_cb(1, "place holder") elif task_type == "evaluation": await self._run_evaluation() elif task_type == "reembedding": await self._run_reembedding() elif task_type == "clone": await self._run_clone() else: await self._run_standard_chunking(embedding_model) def _init_kb(self, vector_size: int) -> None: """Initialize knowledge base index.""" ctx = self._task_context idxnm = search.index_name(ctx.tenant_id) parser_id = ctx.parser_id # Create index if not exists settings.docStoreConn.create_idx(idxnm, ctx.kb_id, vector_size, parser_id) async def _run_dataflow(self) -> None: """Run dataflow pipeline.""" dataflow_service = DataflowService( ctx=self._task_context, billing_hook=self._billing_hook, ) await dataflow_service.run_dataflow() async def _run_evaluation(self) -> None: """Run evaluation task.""" ctx = self._task_context ctx.progress_cb(1, "Evaluation task placeholder") async def _run_reembedding(self) -> None: """Run reembedding task.""" ctx = self._task_context ctx.progress_cb(1, "Reembedding task placeholder") async def _run_clone(self) -> None: """Run clone task.""" ctx = self._task_context ctx.progress_cb(1, "Clone task placeholder") async def _bind_embedding_model(self) -> Optional[tuple]: """Bind embedding model to task. Returns: Tuple of (embedding_model, vector_size) on success, or None on failure. """ ctx = self._task_context task_tenant_id = ctx.tenant_id task_embedding_id = ctx.embd_id task_language = ctx.language try: if task_embedding_id: embd_model_config = get_model_config_from_provider_instance( task_tenant_id, LLMType.EMBEDDING, task_embedding_id ) else: embd_model_config = get_tenant_default_model_by_type( task_tenant_id, LLMType.EMBEDDING ) embedding_model = LLMBundle(task_tenant_id, embd_model_config, lang=task_language) vts, _ = embedding_model.encode(["ok"]) return embedding_model, len(vts[0]) except Exception as e: error_message = f'Fail to bind embedding model: {str(e)}' ctx.progress_cb(-1, msg=error_message) logging.exception(error_message) raise async def _run_raptor( self, embedding_model: LLMBundle, vector_size: int, ) -> None: """Run RAPTOR summary generation.""" ctx = self._task_context task_tenant_id = ctx.tenant_id task_dataset_id = ctx.kb_id kb_task_llm_id = ctx.kb_parser_config.get("llm_id") or ctx.llm_id ok, kb = KnowledgebaseService.get_by_id(task_dataset_id) if not ok: ctx.progress_cb(prog=-1.0, msg="Cannot found valid dataset for RAPTOR task") return kb_parser_config = kb.parser_config if not kb_parser_config.get("raptor", {}).get("use_raptor", False): kb_parser_config.update({ "raptor": { "use_raptor": True, "prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.", "max_token": 256, "threshold": 0.1, "max_cluster": 64, "random_seed": 0, "scope": "file", "clustering_method": "gmm", "tree_builder": "raptor", }, }) if ctx.write_interceptor: update_result = ctx.write_interceptor.intercept("KnowledgebaseService.update_by_id") else: update_result = KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}) if not update_result: ctx.progress_cb(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration") return # Bind LLM for raptor chat_model_config = get_model_config_from_provider_instance( task_tenant_id, LLMType.CHAT, kb_task_llm_id ) with LLMBundle(task_tenant_id, chat_model_config, lang=ctx.language) as chat_model: # Run RAPTOR raptor_service = RaptorService(ctx=ctx) async with ctx.kg_limiter: chunks, token_count, raptor_cleanup_chunks = await raptor_service.run_raptor_for_kb( kb_parser_config=kb_parser_config, chat_mdl=chat_model, embd_mdl=embedding_model, vector_size=vector_size, doc_ids=ctx.doc_ids, ) ctx.recording_context.record("raptor_chunks", chunks) ctx.recording_context.record("raptor_token_count", token_count) # Insert RAPTOR chunks if chunks: task_doc_id = (ctx.doc_ids or [GRAPH_RAPTOR_FAKE_DOC_ID])[0] chunk_service = ChunkService(ctx=ctx) insert_result = await chunk_service.insert_chunks(ctx.id, task_tenant_id, task_dataset_id, chunks) if insert_result: ctx.recording_context.record("insertion_result", "success") else: ctx.recording_context.record("insertion_result", "failed") # Cleanup stale RAPTOR chunks cleaned_chunks = 0 for cleanup_doc_id, keep_method in raptor_cleanup_chunks: ret = await self._delete_raptor_chunks( cleanup_doc_id, task_tenant_id, task_dataset_id, keep_method ) cleaned_chunks += ret if cleaned_chunks: ctx.progress_cb(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.") # Update document stats if ctx.write_interceptor: ctx.write_interceptor.intercept("DocumentService.increment_chunk_num") else: DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, len(chunks), 0) ctx.recording_context.record("task_status", "completed") ctx.progress_cb(prog=1.0, msg="RAPTOR done") async def _run_graphrag( self, embedding_model: LLMBundle ) -> None: """Run GraphRAG.""" ctx = self._task_context task_tenant_id = ctx.tenant_id task_dataset_id = ctx.kb_id kb_task_llm_id = ctx.kb_parser_config.get("llm_id") or ctx.llm_id task_language = ctx.language ok, kb = KnowledgebaseService.get_by_id(task_dataset_id) if not ok: ctx.progress_cb(prog=-1.0, msg="Cannot found valid dataset for GraphRAG task") return kb_parser_config = kb.parser_config if not kb_parser_config.get("graphrag", {}).get("use_graphrag", False): kb_parser_config.update({ "graphrag": { "use_graphrag": True, "entity_types": [ "organization", "person", "geo", "event", "category", ], "method": "light", "batch_chunk_token_size": 4096, "retry_attempts": 2, "retry_backoff_seconds": 2.0, "retry_backoff_max_seconds": 60.0, "build_subgraph_timeout_per_chunk_seconds": 300, "build_subgraph_min_timeout_seconds": 600, "merge_timeout_seconds": 180, "resolution_timeout_seconds": 1800, "community_timeout_seconds": 1800, "lock_acquire_timeout_seconds": 600, } }) if ctx.write_interceptor: update_result = ctx.write_interceptor.intercept("KnowledgebaseService.update_by_id") else: update_result = KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}) if not update_result: ctx.progress_cb(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration") return graphrag_conf = kb_parser_config.get("graphrag", {}) start_ts = timer() chat_model_config = get_model_config_from_provider_instance( task_tenant_id, LLMType.CHAT, kb_task_llm_id ) with LLMBundle(task_tenant_id, chat_model_config, lang=task_language) as chat_model: with_resolution = graphrag_conf.get("resolution", False) with_community = graphrag_conf.get("community", False) async with ctx.kg_limiter: result = await run_graphrag_for_kb( row=ctx.raw_task, doc_ids=ctx.doc_ids, language=task_language, kb_parser_config=kb_parser_config, chat_model=chat_model, embedding_model=embedding_model, callback=ctx.progress_cb, with_resolution=with_resolution, with_community=with_community, ) logging.info(f"GraphRAG task result for task {ctx.raw_task}:\n{result}") ctx.recording_context.record("graphrag_result", result) ctx.progress_cb(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts)) async def _run_standard_chunking( self, embedding_model: LLMBundle ) -> None: """Run standard chunking pipeline.""" ctx = self._task_context task_id = ctx.id task_tenant_id = ctx.tenant_id task_dataset_id = ctx.kb_id task_doc_id = ctx.doc_id task_start_ts = timer() doc_task_llm_id = ctx.parser_config.get("llm_id") or ctx.llm_id ctx.raw_task['llm_id'] = doc_task_llm_id # Build chunks start_ts = timer() chunk_service = ChunkService(ctx=ctx) # Get storage binary bucket, name = File2DocumentService.get_storage_address(doc_id=ctx.doc_id) binary = await self._get_storage_binary(bucket, name) if binary is None: raise FileNotFoundError( f"Can not find file <{ctx.name}> from minio. Could you try it again." ) chunks = await chunk_service.build_chunks(binary) ctx.recording_context.record("chunks", chunks) chunk_ids = [c.get("id") for c in chunks if isinstance(c, dict) and "id" in c] ctx.recording_context.record("chunk_ids_count", len(chunk_ids)) logging.info("Build document {}: {:.2f}s".format(ctx.name, timer() - start_ts)) if not chunks: ctx.progress_cb(1., msg=f"No chunk built from {ctx.name}") return ctx.progress_cb(msg="Generate {} chunks".format(len(chunks))) # Embed chunks start_ts = timer() embedding_service = EmbeddingService(ctx=ctx) try: token_count, vector_size = await embedding_service.embed_chunks( chunks, embedding_model, ctx.parser_config ) except TaskCanceledException: raise except Exception as e: error_message = "Generate embedding error:{}".format(str(e)) ctx.progress_cb(-1, error_message) logging.exception(error_message) raise ctx.recording_context.record("token_count", token_count) ctx.recording_context.record("vector_size", vector_size) progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts) logging.info(progress_message) ctx.progress_cb(msg=progress_message) # Build TOC if needed toc_thread = None if ctx.parser_id.lower() == "naive" and ctx.parser_config.get("toc_extraction", False): toc_thread = asyncio.create_task(asyncio.to_thread(self._build_toc, ctx, chunks, ctx.progress_cb)) # Insert chunks chunk_count = len(set([chunk["id"] for chunk in chunks])) start_ts = timer() chunk_service = ChunkService(ctx=ctx) if ctx.has_canceled_func(task_id): ctx.progress_cb(-1, msg="Task has been canceled.") return insert_result = await chunk_service.insert_chunks( task_id, task_tenant_id, task_dataset_id, chunks ) if not insert_result: ctx.recording_context.record("insertion_result", "failed") return ctx.recording_context.record("insertion_result", "success") # Post-processing post_processor = PostProcessor(ctx=ctx) await post_processor.process_table_parser_metadata(task_doc_id, chunks) ctx.progress_cb(msg="Indexing done ({:.2f}s).".format(timer() - start_ts)) toc_chunk = await self._process_toc_thread(toc_thread) if toc_chunk: ctx.recording_context.record("toc_chunk", [toc_chunk]) await post_processor.insert_toc_chunk(toc_chunk, chunk_service) if ctx.has_canceled_func(task_id): ctx.progress_cb(-1, msg="Task has been canceled.") return # Update document stats if ctx.write_interceptor: ctx.write_interceptor.intercept("DocumentService.increment_chunk_num") else: DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) task_time_cost = timer() - task_start_ts ctx.recording_context.record("task_status", "completed") ctx.progress_cb(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost)) logging.info( "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format( ctx.name, ctx.from_page, ctx.to_page, len(chunks), token_count, task_time_cost ) ) async def _process_toc_thread(self, toc_thread): try: if toc_thread: return await toc_thread else: return None finally: if toc_thread is not None and not toc_thread.done(): toc_thread.cancel() @classmethod async def _get_storage_binary(cls, bucket: str, name: str) -> bytes: from common import settings """Get binary from storage.""" return await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, name) @classmethod def _build_toc(cls, ctx: TaskContext, docs: List[Dict], progress_cb: Callable) -> Optional[Dict]: """Build table of contents.""" progress_cb(msg="Start to generate table of content ...") chat_model_config = get_model_config_from_provider_instance( ctx.tenant_id, LLMType.CHAT, ctx.llm_id ) with LLMBundle(ctx.tenant_id, chat_model_config, lang=ctx.language) as chat_mdl: docs = sorted(docs, key=lambda d: ( d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0) )) # NOTE: asyncio.run() creates a new event loop in the worker thread # (this method is called via asyncio.to_thread), which is the # intended pattern for bridging sync -> async in a thread context. toc: list[dict] = asyncio.run( run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_cb) ) logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' ')) for ii, item in enumerate(toc): try: chunk_val = item.pop("chunk_id", None) if chunk_val is None or str(chunk_val).strip() == "": logging.warning(f"Index {ii}: chunk_id is missing or empty. Skipping.") continue curr_idx = int(chunk_val or -1) if curr_idx >= len(docs): logging.error(f"Index {ii}: chunk_id {curr_idx} exceeds docs length {len(docs)}.") continue item["ids"] = [docs[curr_idx]["id"]] if ii + 1 < len(toc): next_chunk_val = toc[ii + 1].get("chunk_id", "") if str(next_chunk_val).strip() != "": next_idx = int(next_chunk_val) for jj in range(curr_idx + 1, min(next_idx + 1, len(docs))): item["ids"].append(docs[jj]["id"]) else: logging.warning(f"Index {ii + 1}: next chunk_id is empty, range fill skipped.") except (ValueError, TypeError) as e: logging.error(f"Index {ii}: Data conversion error - {e}") except Exception as e: logging.exception(f"Index {ii}: Unexpected error - {e}") if toc: import copy d = copy.deepcopy(docs[-1]) d["content_with_weight"] = json.dumps(toc, ensure_ascii=False) d["toc_kwd"] = "toc" d["available_int"] = 0 d["page_num_int"] = [100000000] d["id"] = xxhash.xxh64( (d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() return d return None async def _delete_raptor_chunks( self, doc_id: str, tenant_id: str, kb_id: str, keep_method: Optional[str] ) -> int: """Delete RAPTOR chunks.""" if self._task_context.write_interceptor: return self._task_context.write_interceptor.intercept("delete_raptor_chunks") else: return await delete_raptor_chunks(doc_id, tenant_id, kb_id, keep_method)