ragflow/rag/svr/task_executor_refactor/dataflow_service.py
Jack b363146997 refactor: overhaul task executor with layered architecture and comprehensive test suite (#15471)
## Summary

Decomposes the monolithic `task_executor.py` (1945 lines) into a 6-layer
architecture with clear separation of concerns. The refactored code is
functionally equivalent to the original, verified through 400 passing
tests and a production-vs-dry-run comparison framework.

## Architecture

```
entry (task_manager)
  └─ orchestration (task_handler)
       ├─ services (chunk_service, embedding_service, dataflow_service, raptor_service, post_processor)
       │    └─ utilities (chunk_builder, chunk_post_processor, embedding_utils)
       └─ infrastructure (task_context, recording_context, interceptor)
```

Key design decisions:
- **TaskContext** — typed facade over raw task dict, injects rate
limiters + callbacks via composition
- **RecordingContext + Comparator** — enables side-by-side production vs
dry-run execution for safe migration
- **NullRecordingContext** — zero-allocation no-op for production, uses
`__slots__`
- **WriteOperationInterceptor** — FIFO replay of the previous run's function
return values for comparison mode
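
The recording and replay pieces listed above can be sketched roughly as follows. The class names come from the list, and `record` / `intercept` match how `dataflow_service.py` calls them, but everything else (the `values` dict, the `push` helper, the per-operation queues) is an assumption for illustration and not the refactored module's actual code.

```python
from collections import deque
from typing import Any, Deque, Dict


class RecordingContext:
    """Stores intermediate values by key so two runs can be compared later."""

    def __init__(self) -> None:
        self.values: Dict[str, Any] = {}

    def record(self, key: str, value: Any) -> None:
        self.values[key] = value


class NullRecordingContext:
    """No-op stand-in for production; empty __slots__ means no per-instance __dict__."""

    __slots__ = ()

    def record(self, key: str, value: Any) -> None:
        return None


class WriteOperationInterceptor:
    """Replays return values captured from a previous run in FIFO order (assumed design)."""

    def __init__(self) -> None:
        self._returns: Dict[str, Deque[Any]] = {}

    def push(self, op_name: str, value: Any) -> None:
        # Hypothetical helper: queue a return value captured during the production run.
        self._returns.setdefault(op_name, deque()).append(value)

    def intercept(self, op_name: str) -> Any:
        # Skip the real write and hand back the oldest queued return value, if any.
        queue = self._returns.get(op_name)
        return queue.popleft() if queue else None
```

With this shape, service code can call `ctx.recording_context.record(...)` unconditionally, and the choice between recording and doing nothing is made once, when the context is built.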

## Migration Strategy

The original `handle_task()` in `task_executor.py` uses a 3-way switch
via `TE_RUN_MODE`:
- `TE_RUN_MODE=0` (default) → runs refactored code
- `TE_RUN_MODE=1` → runs both original + refactored, compares all
intermediate results
- `TE_RUN_MODE=2` → runs original code (fallback)

The comparison mode (`TE_RUN_MODE=1`) records ~40 intermediate values
(chunks, vectors, token counts, func return values) from the production
run and replays them during dry-run, then uses `ContextComparator` to
report mismatches.
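
As a rough illustration of the switch and the comparison step, the sketch below parses `TE_RUN_MODE` and diffs two dictionaries of recorded values. The function names, the fallback to mode 0 on invalid input, and the plain `!=` comparison are assumptions; the real `ContextComparator` may well compare vectors and chunks differently.

```python
import os
from typing import Any, Dict, Mapping, Tuple

RUN_REFACTORED, RUN_COMPARE, RUN_ORIGINAL = 0, 1, 2


def read_run_mode(env: Mapping[str, str] = os.environ) -> int:
    """Parse TE_RUN_MODE; unset or invalid values fall back to 0 (assumed behaviour)."""
    try:
        mode = int(env.get("TE_RUN_MODE", "0"))
    except ValueError:
        return RUN_REFACTORED
    valid = (RUN_REFACTORED, RUN_COMPARE, RUN_ORIGINAL)
    return mode if mode in valid else RUN_REFACTORED


def find_mismatches(
    production: Dict[str, Any], dry_run: Dict[str, Any]
) -> Dict[str, Tuple[Any, Any]]:
    """Return {key: (production_value, dry_run_value)} for every differing key."""
    missing = object()  # sentinel so a recorded None is not confused with "absent"
    diffs: Dict[str, Tuple[Any, Any]] = {}
    for key in sorted(set(production) | set(dry_run)):
        left = production.get(key, missing)
        right = dry_run.get(key, missing)
        if left is missing or right is missing or left != right:
            diffs[key] = (
                None if left is missing else left,
                None if right is missing else right,
            )
    return diffs
```

An empty result from `find_mismatches` would correspond to a comparison run with no reported divergence.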

## Functional Equivalence Fixes

All divergences between original and refactored code were identified and
fixed:
- Timeout decorators (handle/build_chunks/raptor/embedding)
- NullRecordingContext leak in finally block causing RuntimeError
- MinIO None-binary check with proper FileNotFoundError
- Dataflow dispatch after embedding binding + init_kb
- Memory task missing return after processing
- RAPTOR checkpoint progress reporting
- Tag cache (get_tags_from_cache/set_tags_to_cache) restoration
- dataflow_id correction in _load_dsl
- Language default Chinese, dead code guard removal
- embed_chunks made async with proper thread_pool_exec
- Full GraphRAG default configuration (10 parameters)
- Hardcoded q_768_vec fallback removal in RAPTOR

## Test Changes

- 20 new tests covering table parser manual mode, tag cache, embedding
edge cases, RAPTOR checkpoint, dataflow_id correction, storage binary
None, cancel cleanup, metadata=None boundary
- Unified `make_task_context`/`make_task_dict` factories eliminated 10+
duplicated helpers
- DataflowService tests migrated from internal method mocks to IO
boundary mocks (real orchestration code executes)
- Parametrized duplicate build_chunks post-processor tests
- 7 raptor tests modernized to @pytest.mark.asyncio
- Mock count per test reduced through boundary-level mocking strategy
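
The boundary-level mocking idea can be shown with a small, self-contained sketch. Only the name `make_task_dict` appears in this change; its signature and default fields here are guesses, and `ChunkStore` and `index_chunks` are hypothetical stand-ins for an IO boundary and the orchestration code under test.

```python
from typing import Any, Dict, List
from unittest import mock


def make_task_dict(**overrides: Any) -> Dict[str, Any]:
    """Hypothetical factory: default task fields in one place, overridable per test."""
    task = {"id": "task-1", "doc_id": "doc-1", "kb_id": "kb-1", "task_type": "dataflow"}
    task.update(overrides)
    return task


class ChunkStore:
    """Stand-in IO boundary (for example, a document store client)."""

    def insert(self, chunks: List[Dict[str, Any]]) -> bool:
        raise RuntimeError("real IO is not available in tests")


def index_chunks(task: Dict[str, Any], chunks: List[Dict[str, Any]], store: ChunkStore) -> int:
    """Orchestration under test: tag each chunk with the doc id, then write them."""
    tagged = [{**chunk, "doc_id": task["doc_id"]} for chunk in chunks]
    return len(tagged) if store.insert(tagged) else 0


def run_example() -> int:
    # Mock only the IO boundary; the real orchestration logic still executes.
    store = mock.create_autospec(ChunkStore, instance=True)
    store.insert.return_value = True
    task = make_task_dict(doc_id="doc-42")
    count = index_chunks(task, [{"text": "a"}, {"text": "b"}], store)
    store.insert.assert_called_once_with(
        [{"text": "a", "doc_id": "doc-42"}, {"text": "b", "doc_id": "doc-42"}]
    )
    return count
```

Because the assertion targets what reaches the boundary, a test written this way keeps passing when internal helper methods are renamed or merged.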

**Test count: 400 passing, 0 warnings, 0 skips**

## Files Changed

| File | Change |
|------|--------|
| `rag/svr/task_executor.py` | +1 line (NullRecordingContext fix) |
| `rag/svr/task_executor_refactor/task_handler.py` | Orchestration layer, 8 logic fixes |
| `rag/svr/task_executor_refactor/chunk_service.py` | +timeout + None-check |
| `rag/svr/task_executor_refactor/embedding_service.py` | sync→async rewrite |
| `rag/svr/task_executor_refactor/dataflow_service.py` | dataflow_id fix + timeout |
| `rag/svr/task_executor_refactor/raptor_service.py` | checkpoint fix + assert |
| `rag/svr/task_executor_refactor/chunk_post_processor.py` | tag cache restore |
| `rag/svr/task_executor_refactor/task_context.py` | language default fix |
| `test/.../conftest.py` | +294 lines shared helpers |
| `test/.../*.py` | 15 test files refactored, 20 new tests |

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-03 17:18:31 +08:00

399 lines
16 KiB
Python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Dataflow Service Module.
Provides [`DataflowService`](rag/svr/task_executor_refactor/dataflow_service.py:42) for dataflow
pipeline execution.
"""
import abc
import copy
import logging
import re
from datetime import datetime
from timeit import default_timer as timer
from typing import Dict, List, Optional, Tuple
import numpy as np
import xxhash
from common import settings
from rag.svr.task_executor_refactor.embedding_utils import EmbeddingUtils
from rag.flow.pipeline import Pipeline
from api.db.services.canvas_service import UserCanvasService
from api.db.services.document_service import DocumentService
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance
from common.connection_utils import timeout
from common.constants import LLMType, PipelineTaskType
from common.metadata_utils import update_metadata_to
from common.misc_utils import thread_pool_exec
from rag.nlp import rag_tokenizer, add_positions
from rag.svr.task_executor_refactor.constants import CANVAS_DEBUG_DOC_ID
from rag.svr.task_executor_refactor.task_context import TaskContext


class BillingHook(abc.ABC):
    """Abstract base for billing hooks on pipeline success/error.

    Implementations override the no-op methods to integrate with billing
    systems (e.g., consume quota on success, release hold on error).
    """

    async def on_pipeline_success(self) -> None:
        """Called when the dataflow pipeline completes successfully."""

    async def on_pipeline_error(self) -> None:
        """Called when the dataflow pipeline encounters an error."""


class DataflowService:
    """Service for dataflow pipeline execution.

    This service handles:
    - Dataflow DSL loading and execution
    - Chunk embedding for dataflow output
    - Chunk metadata processing and indexing
    """

    def __init__(
        self,
        ctx: TaskContext,
        billing_hook: Optional[BillingHook] = None,
        embedding_batch_size: int = None,
        doc_bulk_size: int = None,
    ):
        """Initialize DataflowService.

        Args:
            ctx: TaskContext containing task configuration and execution resources.
            billing_hook: Optional billing hook for pipeline success/error callbacks.
            embedding_batch_size: Batch size for embedding operations.
            doc_bulk_size: Batch size for document store inserts.
        """
        self._task_context = ctx
        self._billing_hook = billing_hook
        self._embedding_batch_size = embedding_batch_size or self._get_default_embedding_batch_size()
        self._doc_bulk_size = doc_bulk_size or self._get_default_bulk_size()

    async def run_dataflow(self) -> None:
        """Run a dataflow pipeline."""
        ctx = self._task_context
        pipeline = None
        try:
            task_start_ts = timer()
            dataflow_id = ctx.dataflow_id
            doc_id = ctx.doc_id
            task_id = ctx.id
            task_dataset_id = ctx.kb_id

            # Load DSL
            dsl, corrected_id = await self._load_dsl(dataflow_id)
            if dsl is None:
                return
            dataflow_id = corrected_id

            # Run pipeline
            pipeline = Pipeline(
                dsl, tenant_id=ctx.tenant_id, doc_id=doc_id,
                task_id=task_id, flow_id=dataflow_id
            )
            chunks = await pipeline.run(file=ctx.file) if ctx.file else await pipeline.run()

            if doc_id == CANVAS_DEBUG_DOC_ID:
                ctx.recording_context.record("dataflow_debug_result", "canvas_debug_mode")
                ctx.recording_context.record("dataflow_chunks", chunks)
                return

            if not chunks:
                ctx.recording_context.record("pipeline_output_count", 0)
                ctx.recording_context.record("pipeline_output_type", "empty")
                self._record_pipeline_log(doc_id, dataflow_id, pipeline)
                return

            embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
            output_type = DataflowService._get_output_type(chunks)
            chunks = self._normalize_chunks(chunks)
            ctx.recording_context.record("pipeline_output_type", output_type)
            ctx.recording_context.record("pipeline_output_count", len(chunks))
            if not chunks:
                self._record_pipeline_log(doc_id, dataflow_id, pipeline)
                return

            # Embed chunks if needed
            keys = [k for o in chunks for k in list(o.keys())]
            if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
                chunks, embedding_token_consumption = await self._embed_chunks(
                    chunks, embedding_token_consumption
                )
                if chunks is None:
                    self._record_pipeline_log(doc_id, dataflow_id, pipeline)
                    return

            # Process chunks
            metadata = self._process_chunks(chunks)

            # Update document metadata
            if metadata:
                self._update_document_metadata(doc_id, metadata)

            # Insert chunks
            start_ts = timer()
            self._progress(prog=0.82, msg="[DOC Engine]:\nStart to index...")
            e = await self._insert_chunks(
                task_id, ctx.tenant_id, ctx.kb_id, chunks
            )
            if not e:
                self._record_pipeline_log(doc_id, dataflow_id, pipeline)
                return

            time_cost = timer() - start_ts
            task_time_cost = timer() - task_start_ts
            self._progress(
                prog=1.,
                msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost)
            )

            # Update document stats
            if ctx.write_interceptor:
                ctx.write_interceptor.intercept("DocumentService.increment_chunk_num")
            else:
                DocumentService.increment_chunk_num(
                    doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost
                )

            logging.info(
                "[Done], chunks({}), token({}), elapsed:{:.2f}".format(
                    len(chunks), embedding_token_consumption, task_time_cost
                )
            )
            ctx.recording_context.record("dataflow_chunks", chunks)
            self._record_pipeline_log(doc_id, dataflow_id, pipeline)

            # Billing hook: pipeline succeeded
            if self._billing_hook:
                await self._billing_hook.on_pipeline_success()
        except Exception:
            if self._billing_hook:
                await self._billing_hook.on_pipeline_error()
            raise

    async def _load_dsl(self, dataflow_id: str) -> tuple:
        """Load dataflow DSL from service.

        Returns:
            Tuple of (dsl, corrected_dataflow_id).
            When task_type is not 'dataflow', the dataflow_id is corrected
            from the pipeline log's pipeline_id.
        """
        ctx = self._task_context
        if ctx.task_type == "dataflow":
            e, cvs = UserCanvasService.get_by_id(dataflow_id)
            assert e, "User pipeline not found."
            return cvs.dsl, dataflow_id
        else:
            e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
            assert e, "Pipeline log not found."
            return pipeline_log.dsl, pipeline_log.pipeline_id

    @staticmethod
    def _get_output_type(chunks: Dict) -> str:
        """Determine output type from chunks dict."""
        if "chunks" in chunks:
            return "chunks"
        elif "json" in chunks:
            return "json"
        elif "markdown" in chunks:
            return "markdown"
        elif "text" in chunks:
            return "text"
        elif "html" in chunks:
            return "html"
        return "empty"

    @classmethod
    def _normalize_chunks(cls, chunks: Dict) -> List[Dict]:
        """Normalize chunks from various output formats."""
        if "chunks" in chunks:
            return copy.deepcopy(chunks["chunks"])
        elif "json" in chunks:
            return copy.deepcopy(chunks["json"])
        elif "markdown" in chunks:
            return [{"text": [chunks["markdown"]]}] if chunks["markdown"] else []
        elif "text" in chunks:
            return [{"text": [chunks["text"]]}] if chunks["text"] else []
        elif "html" in chunks:
            return [{"text": [chunks["html"]]}] if chunks["html"] else []
        return []

    @timeout(60)
    async def _embed_chunks(
        self, chunks: List[Dict], token_consumption: int
    ) -> Tuple[Optional[List[Dict]], int]:
        """Embed chunks using the embedding model."""
        ctx = self._task_context
        try:
            self._progress(prog=0.82, msg="\n-------------------------------------\nStart to embedding...")
            e, kb = self._get_kb_by_id(ctx.kb_id)
            embedding_id = kb.embd_id
            embd_model_config = get_model_config_from_provider_instance(
                ctx.tenant_id, LLMType.EMBEDDING, embedding_id
            )
            from api.db.services.llm_service import LLMBundle
            with LLMBundle(ctx.tenant_id, embd_model_config) as embedding_model:
                # Prepare texts for embedding using EmbeddingUtils
                texts = EmbeddingUtils.prepare_texts_for_dataflow_embedding(chunks)
                delta = 0.20 / (len(texts) // self._embedding_batch_size + 1)
                prog = 0.8

                # Batch encode using EmbeddingUtils
                vects_batches = []
                for i in range(0, len(texts), self._embedding_batch_size):
                    batch = texts[i: i + self._embedding_batch_size]
                    async with ctx.embed_limiter:
                        vts, c = await thread_pool_exec(
                            self._encode_batch, batch, embedding_model
                        )
                    vects_batches.append(vts)
                    token_consumption += c
                    prog += delta
                    if i % (len(texts) // self._embedding_batch_size / 100 + 1) == 1:
                        self._progress(
                            prog=prog,
                            msg=f"{i + 1} / {len(texts) // self._embedding_batch_size}"
                        )

                # Stack vectors using EmbeddingUtils
                vects = EmbeddingUtils.stack_vectors(vects_batches)
                if len(vects) != len(chunks):
                    raise ValueError(f"Vector count mismatch: {len(vects)} vs {len(chunks)}")

                # Attach vectors using EmbeddingUtils
                EmbeddingUtils.attach_vectors(chunks, vects)
            return chunks, token_consumption
        except Exception as e:
            ctx.progress_cb(prog=-1, msg=f"[ERROR]: {e}")
            return None, token_consumption
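
    @classmethod
    def _encode_batch(cls, txts: List[str], embedding_model) -> Tuple[np.ndarray, int]:
        """Batch encode texts using the embedding model with truncation."""
        # Kept synchronous because it is dispatched via thread_pool_exec. Assuming
        # that helper runs the callable in a worker thread, an `async def` here
        # would return an un-awaited coroutine instead of (vectors, token_count).
        truncated = EmbeddingUtils.truncate_texts(txts, embedding_model.max_length)
        return embedding_model.encode(truncated)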

    def _process_chunks(self, chunks: List[Dict]) -> Dict:
        """Process chunks for metadata and indexing."""
        ctx = self._task_context
        metadata = {}
        for ck in chunks:
            ck["doc_id"] = ctx.doc_id
            ck["kb_id"] = [str(ctx.kb_id)]
            ck["docnm_kwd"] = ctx.name
            ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            ck["create_timestamp_flt"] = datetime.now().timestamp()
            if not ck.get("id"):
                ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
            if "questions" in ck:
                if "question_tks" not in ck:
                    ck["question_kwd"] = ck["questions"].split("\n")
                    ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
                del ck["questions"]
            if "keywords" in ck:
                if "important_tks" not in ck:
                    ck["important_kwd"] = [k for k in re.split(r"[,;;、\r\n]+", ck["keywords"]) if k.strip()]
                    ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
                del ck["keywords"]
            if "summary" in ck:
                if "content_ltks" not in ck:
                    ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
                    ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                del ck["summary"]
            if "metadata" in ck:
                metadata = update_metadata_to(metadata, ck["metadata"])
                del ck["metadata"]
            if "content_with_weight" not in ck:
                ck["content_with_weight"] = ck["text"]
            del ck["text"]
            if "positions" in ck:
                add_positions(ck, ck["positions"])
                del ck["positions"]
        return metadata

    def _update_document_metadata(self, doc_id: str, metadata: Dict) -> None:
        """Update document metadata."""
        existing_meta = DocMetadataService.get_document_metadata(doc_id)
        existing_meta = existing_meta if isinstance(existing_meta, dict) else {}
        metadata = update_metadata_to(metadata, existing_meta)
        self._task_context.recording_context.record("run_dataflow_metadata", metadata)
        if self._task_context.write_interceptor:
            self._task_context.write_interceptor.intercept("DocMetadataService.update_document_metadata")
        else:
            DocMetadataService.update_document_metadata(doc_id, metadata)

    async def _insert_chunks(
        self, task_id: str, tenant_id: str, kb_id: str, chunks: List[Dict]
    ) -> bool:
        """Insert chunks into document store."""
        from rag.svr.task_executor_refactor.chunk_service import ChunkService
        chunk_service = ChunkService(self._task_context)
        return await chunk_service.insert_chunks(task_id, tenant_id, kb_id, chunks)

    def _record_pipeline_log(self, doc_id: str, dataflow_id: str, pipeline) -> None:
        """Record pipeline operation log."""
        if self._task_context.write_interceptor:
            self._task_context.write_interceptor.intercept("PipelineOperationLogService.create")
        else:
            PipelineOperationLogService.create(
                document_id=doc_id, pipeline_id=dataflow_id,
                task_type=PipelineTaskType.PARSE, dsl=str(pipeline)
            )

    @classmethod
    def _get_kb_by_id(cls, kb_id: str):
        """Get knowledge base by ID."""
        from api.db.services.knowledgebase_service import KnowledgebaseService
        return KnowledgebaseService.get_by_id(kb_id)

    def _progress(self, prog=None, msg=None):
        """Progress callback helper."""
        if prog is not None or msg is not None:
            self._task_context.progress_cb(prog=prog, msg=msg)

    @classmethod
    def _get_default_embedding_batch_size(cls) -> int:
        """Get default embedding batch size."""
        return settings.EMBEDDING_BATCH_SIZE

    @classmethod
    def _get_default_bulk_size(cls) -> int:
        """Get default bulk size."""
        return settings.DOC_BULK_SIZE