Files
ragflow/rag/svr/task_executor_refactor/task_manager.py
Jack f0cb7a544b Refactor: Task Executor (#15154)
### What problem does this PR solve?

1. Break the huge function into smaller pieces
2. Add unit tests for the smaller functions
3. Layered design
    a. infra layer - task_context.py, recording_context.py,
       write_operation_interceptor.py, ...
    b. service layer - *_service.py
    c. business layer - task_handler.py
4. Default behavior: use the refactored version - can switch back to the
original version by changing an environment variable

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
- [x] Performance Improvement

---------

Co-authored-by: Liu An <asiro@qq.com>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
2026-05-27 21:54:17 +08:00

177 lines
7.1 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Task Manager Module.
Provides [`TaskManager`](rag/svr/task_executor_refactor/task_manager.py:50) as the entry point
for executing document processing tasks, supporting both production and dry-run (comparison) modes.
"""
import logging
from typing import Any, Optional
from rag.svr.task_executor_refactor.comparator import ContextComparator
from rag.svr.task_executor_refactor.task_context import TaskCallbacks, TaskDict, TaskLimiters
from rag.svr.task_executor_refactor.dataflow_service import BillingHook
from rag.svr.task_executor_refactor.recording_context import (
BaseRecordingContext,
RecordingContext,
_NULL_RECORDING_CONTEXT,
set_recording_context, recording_context_manager,
)
from rag.svr.task_executor_refactor.task_context import TaskContext
from rag.svr.task_executor_refactor.task_handler import TaskHandler
from rag.svr.task_executor_refactor.write_operation_interceptor import (
WriteOperationInterceptor,
)
class TaskManager:
    """Entry point for executing document processing tasks.

    Two execution modes are exposed as classmethods:

    - ``run_refactored_task``: production execution.
    - ``dry_run_task``: dry-run execution whose recording is compared
      against the recording of a production run.

    Usage:
        await TaskManager.run_refactored_task(task, chat_limiter, ...)
        # or
        await TaskManager.dry_run_task(task, recording_ctx1, ...)
    """

    @staticmethod
    def _build_task_context(
        *,
        task: Any,
        recording_context: Any,
        chat_limiter: Any,
        minio_limiter: Any,
        chunk_limiter: Any,
        embed_limiter: Any,
        kg_limiter: Any,
        set_progress: Any,
        has_canceled: Any,
        write_interceptor: Optional[WriteOperationInterceptor] = None,
    ) -> TaskContext:
        """Assemble the TaskContext used by both execution modes.

        Args:
            task: Task configuration dictionary.
            recording_context: Recording context handed to the TaskContext.
            chat_limiter: Rate limiter for chat operations.
            minio_limiter: Rate limiter for MinIO operations.
            chunk_limiter: Rate limiter for chunking operations.
            embed_limiter: Rate limiter for embedding operations.
            kg_limiter: Rate limiter for knowledge graph operations.
            set_progress: Progress callback function.
            has_canceled: Function to check if task is canceled.
            write_interceptor: Optional interceptor; when omitted the
                argument is not passed at all.

        Returns:
            The constructed TaskContext.
        """
        # Only forward the interceptor when one was supplied, so the
        # production path keeps relying on TaskContext's own default.
        optional_kwargs = {}
        if write_interceptor is not None:
            optional_kwargs["write_interceptor"] = write_interceptor

        return TaskContext(
            task=task,
            limiters=TaskLimiters(
                chat=chat_limiter,
                minio=minio_limiter,
                chunk=chunk_limiter,
                embed=embed_limiter,
                kg=kg_limiter,
            ),
            callbacks=TaskCallbacks(
                progress=set_progress,
                has_canceled=has_canceled,
            ),
            recording_context=recording_context,
            **optional_kwargs,
        )

    @classmethod
    async def run_refactored_task(
        cls,
        task: dict,
        chat_limiter: Any,
        minio_limiter: Any,
        chunk_limiter: Any,
        embed_limiter: Any,
        kg_limiter: Any,
        set_progress: Any,
        has_canceled: Any,
        billing_hook: Optional[BillingHook] = None,
    ) -> None:
        """Run a document processing task in production mode.

        Exceptions raised while handling the task are not caught here.

        Args:
            task: Task configuration dictionary.
            chat_limiter: Rate limiter for chat operations.
            minio_limiter: Rate limiter for MinIO operations.
            chunk_limiter: Rate limiter for chunking operations.
            embed_limiter: Rate limiter for embedding operations.
            kg_limiter: Rate limiter for knowledge graph operations.
            set_progress: Progress callback function.
            has_canceled: Function to check if task is canceled.
            billing_hook: Optional billing hook for pipeline success/error callbacks.
        """
        with recording_context_manager(_NULL_RECORDING_CONTEXT):
            # Use NullRecordingContext in production to avoid memory allocation
            set_recording_context(_NULL_RECORDING_CONTEXT)

            task_context = cls._build_task_context(
                task=task,
                recording_context=_NULL_RECORDING_CONTEXT,
                chat_limiter=chat_limiter,
                minio_limiter=minio_limiter,
                chunk_limiter=chunk_limiter,
                embed_limiter=embed_limiter,
                kg_limiter=kg_limiter,
                set_progress=set_progress,
                has_canceled=has_canceled,
            )

            # Execute with TaskHandler
            handler = TaskHandler(ctx=task_context, billing_hook=billing_hook)
            await handler.handle_task()

    @classmethod
    async def dry_run_task(
        cls,
        task: TaskDict,
        recording_ctx1: BaseRecordingContext,
        chat_limiter: Any,
        minio_limiter: Any,
        chunk_limiter: Any,
        embed_limiter: Any,
        kg_limiter: Any,
        set_progress: Any,
        has_canceled: Any,
    ) -> None:
        """Run a document processing task in dry-run mode for comparison.

        The task is executed with a WriteOperationInterceptor built from
        ``recording_ctx1.get_all_func_return_values()`` and a fresh
        RecordingContext. Afterwards a ContextComparator compares
        ``recording_ctx1`` with the new recording and the outcome is logged.
        Exceptions raised while handling the task are not caught here.

        Args:
            task: Task configuration dictionary.
            recording_ctx1: RecordingContext from production execution.
            chat_limiter: Rate limiter for chat operations.
            minio_limiter: Rate limiter for MinIO operations.
            chunk_limiter: Rate limiter for chunking operations.
            embed_limiter: Rate limiter for embedding operations.
            kg_limiter: Rate limiter for knowledge graph operations.
            set_progress: Progress callback function.
            has_canceled: Function to check if task is canceled.
        """
        interceptor = WriteOperationInterceptor(recording_ctx1.get_all_func_return_values())
        recording_ctx2 = RecordingContext()

        # NOTE(review): the comparison below is kept inside the recording
        # scope; confirm this matches the intended extent of the `with` block.
        with recording_context_manager(recording_ctx2):
            set_recording_context(recording_ctx2)

            task_context = cls._build_task_context(
                task=task,
                recording_context=recording_ctx2,
                chat_limiter=chat_limiter,
                minio_limiter=minio_limiter,
                chunk_limiter=chunk_limiter,
                embed_limiter=embed_limiter,
                kg_limiter=kg_limiter,
                set_progress=set_progress,
                has_canceled=has_canceled,
                write_interceptor=interceptor,
            )

            # Execute with TaskHandler
            handler = TaskHandler(ctx=task_context)
            await handler.handle_task()

            # Compare the production recording with the dry-run recording
            comp: ContextComparator = ContextComparator()
            comp_result = comp.compare(task_context.id, recording_ctx1, recording_ctx2)
            logging.info(f"-------{task_context.name}, compare result:{comp_result.to_markdown()}")

            # A run "differs" when the interceptor still holds values or
            # the comparator reports mismatched keys.
            if interceptor.remaining_values_count() > 0 or comp_result.mismatched_keys > 0:
                logging.info(f"------task:{task_context.id} {task_context.name} differs, "
                             f"interceptor.remaining_values_count():{interceptor.remaining_values_count()}, "
                             f"mismatched_keys:{comp_result.mismatched_keys}")
                if interceptor.remaining_values_count() > 0:
                    logging.info(f"------task:{task_context.id}, remaining values:{interceptor.remaining_values()}")
                if comp_result.mismatched_keys > 0:
                    logging.info(f"-------compare result:{comp_result.details}")
            else:
                logging.info(f"------task:{task_context.id} {task_context.name} same result for prod and dry run ")