mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
## Summary
Decomposes the monolithic `task_executor.py` (1945 lines) into a 6-layer
architecture with clear separation of concerns. The refactored code is
functionally equivalent to the original, verified through 400 passing
tests and a production-vs-dry-run comparison framework.
## Architecture
```
entry (task_manager)
└─ orchestration (task_handler)
   ├─ services (chunk_service, embedding_service, dataflow_service, raptor_service, post_processor)
   │  └─ utilities (chunk_builder, chunk_post_processor, embedding_utils)
   └─ infrastructure (task_context, recording_context, interceptor)
```
Key design decisions:
- **TaskContext**: typed facade over the raw task dict; injects rate
limiters and callbacks via composition
- **RecordingContext + Comparator**: enables side-by-side production vs
dry-run execution for safe migration
- **NullRecordingContext**: zero-allocation no-op for production, uses
`__slots__`
- **WriteOperationInterceptor**: FIFO replay of the previous run's
function return values in comparison mode
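The last two design decisions can be illustrated with a minimal sketch. It is not the project's implementation: `FifoInterceptor` and its `save` method are hypothetical names, and only the method names `record`, `save_func_return_value`, and `intercept` are taken from calls visible in `task_handler.py`.

```python
from collections import defaultdict, deque
from typing import Any, Deque, Dict


class NullRecordingContext:
    """No-op recorder: every call is accepted and discarded."""

    # An empty __slots__ means instances carry no per-instance __dict__.
    __slots__ = ()

    def record(self, key: str, value: Any) -> None:
        pass

    def save_func_return_value(self, func_name: str, value: Any) -> None:
        pass


class FifoInterceptor:
    """Hypothetical interceptor: replays saved return values in call order."""

    def __init__(self) -> None:
        self._queues: Dict[str, Deque[Any]] = defaultdict(deque)

    def save(self, func_name: str, value: Any) -> None:
        # Called during the production run, once per real write operation.
        self._queues[func_name].append(value)

    def intercept(self, func_name: str) -> Any:
        # Called during the dry run instead of performing the write.
        queue = self._queues.get(func_name)
        if not queue:
            raise LookupError(f"no recorded return value left for {func_name!r}")
        return queue.popleft()
```

Replaying in FIFO order means the dry run must issue the same write calls in the same order as the production run. That ordering requirement is itself a useful equivalence check.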
## Migration Strategy
The original `handle_task()` in `task_executor.py` uses a 3-way switch
via `TE_RUN_MODE`:
- `TE_RUN_MODE=0` (default) → runs refactored code
- `TE_RUN_MODE=1` → runs both original + refactored, compares all
intermediate results
- `TE_RUN_MODE=2` → runs original code (fallback)
The comparison mode (`TE_RUN_MODE=1`) records ~40 intermediate values
(chunks, vectors, token counts, func return values) from the production
run and replays them during dry-run, then uses `ContextComparator` to
report mismatches.
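As a rough illustration of the switch and the comparison step, the sketch below reads `TE_RUN_MODE` and diffs two recordings. The helper names `read_run_mode` and `compare_recordings` are hypothetical, the rejection of unknown mode values is an assumption, and the real `ContextComparator` may compare values more leniently than plain `!=`.

```python
import os
from typing import Any, List, Mapping, Optional, Tuple

MISSING = object()  # sentinel for a key recorded by only one of the runs


def read_run_mode(env: Optional[Mapping[str, str]] = None) -> int:
    """Return 0 (refactored), 1 (compare both) or 2 (original)."""
    env = os.environ if env is None else env
    mode = int(env.get("TE_RUN_MODE", "0"))
    if mode not in (0, 1, 2):
        # Assumption: unknown values are rejected rather than silently mapped.
        raise ValueError(f"unsupported TE_RUN_MODE: {mode}")
    return mode


def compare_recordings(
    production: Mapping[str, Any], dry_run: Mapping[str, Any]
) -> List[Tuple[str, Any, Any]]:
    """List (key, production_value, dry_run_value) for every mismatch."""
    mismatches = []
    for key in sorted(set(production) | set(dry_run)):
        expected = production.get(key, MISSING)
        actual = dry_run.get(key, MISSING)
        if expected is MISSING or actual is MISSING or expected != actual:
            mismatches.append((key, expected, actual))
    return mismatches
```

An empty result from `compare_recordings` would correspond to a dry run that reproduced every recorded intermediate value.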
## Functional Equivalence Fixes
All divergences between original and refactored code were identified and
fixed:
- Timeout decorators (handle/build_chunks/raptor/embedding)
- NullRecordingContext leak in finally block causing RuntimeError
- MinIO None-binary check with proper FileNotFoundError
- Dataflow dispatch after embedding binding + init_kb
- Memory task missing return after processing
- RAPTOR checkpoint progress reporting
- Tag cache (get_tags_from_cache/set_tags_to_cache) restoration
- dataflow_id correction in _load_dsl
- Language default set to Chinese; dead-code guard removed
- embed_chunks made async with proper thread_pool_exec
- Full GraphRAG default configuration (10 parameters)
- Hardcoded q_768_vec fallback removal in RAPTOR
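The `embed_chunks` item above concerns moving a blocking encoder call off the event loop. A minimal sketch of that pattern follows, assuming a hypothetical synchronous `encode` function and using the standard library's `asyncio.to_thread` as a stand-in for the project's `thread_pool_exec` helper.

```python
import asyncio
from typing import List, Tuple


def encode(texts: List[str]) -> Tuple[List[List[float]], int]:
    """Hypothetical blocking encoder: one toy vector per text plus a token count."""
    vectors = [[float(len(t)), 0.0] for t in texts]
    tokens = sum(len(t.split()) for t in texts)
    return vectors, tokens


async def embed_in_batches(
    texts: List[str], batch_size: int = 2
) -> Tuple[List[List[float]], int]:
    """Encode texts batch by batch without blocking the event loop."""
    all_vectors: List[List[float]] = []
    total_tokens = 0
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        # The blocking call runs in a worker thread; other coroutines keep running.
        vectors, tokens = await asyncio.to_thread(encode, batch)
        all_vectors.extend(vectors)
        total_tokens += tokens
    return all_vectors, total_tokens
```

Awaiting between batches also gives a cancellation check or progress callback a natural place to run.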
## Test Changes
- 20 new tests covering table parser manual mode, tag cache, embedding
edge cases, RAPTOR checkpoint, dataflow_id correction, storage binary
None, cancel cleanup, metadata=None boundary
- Unified `make_task_context`/`make_task_dict` factories eliminated 10+
duplicated helpers
- DataflowService tests migrated from internal method mocks to IO
boundary mocks (real orchestration code executes)
- Parametrized duplicate build_chunks post-processor tests
- 7 raptor tests modernized to @pytest.mark.asyncio
- Mock count per test reduced through boundary-level mocking strategy
**Test count: 400 passing, 0 warnings, 0 skips**
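A shared factory of the kind described above might look like the following sketch. The dictionary keys and default values are assumptions chosen to mirror the `TaskContext` attributes used in `task_handler.py`; the real `make_task_dict` in `conftest.py` may differ.

```python
import copy
from typing import Any, Dict

# Assumed defaults; only the "Chinese" language default is stated in the source.
_TASK_DEFAULTS: Dict[str, Any] = {
    "id": "task-1",
    "tenant_id": "tenant-1",
    "kb_id": "kb-1",
    "doc_id": "doc-1",
    "task_type": "",
    "language": "Chinese",
    "parser_id": "naive",
    "parser_config": {},
}


def make_task_dict(**overrides: Any) -> Dict[str, Any]:
    """Build a raw task dict, overriding only the fields a test cares about."""
    task = copy.deepcopy(_TASK_DEFAULTS)  # deep copy keeps tests isolated
    task.update(overrides)
    return task
```

A test would then state only what it depends on, for example `make_task_dict(task_type="raptor")`.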
## Files Changed
| File | Change |
|------|--------|
| `rag/svr/task_executor.py` | +1 line (NullRecordingContext fix) |
| `rag/svr/task_executor_refactor/task_handler.py` | Orchestration layer, 8 logic fixes |
| `rag/svr/task_executor_refactor/chunk_service.py` | +timeout + None-check |
| `rag/svr/task_executor_refactor/embedding_service.py` | sync→async rewrite |
| `rag/svr/task_executor_refactor/dataflow_service.py` | dataflow_id fix + timeout |
| `rag/svr/task_executor_refactor/raptor_service.py` | checkpoint fix + assert |
| `rag/svr/task_executor_refactor/chunk_post_processor.py` | tag cache restore |
| `rag/svr/task_executor_refactor/task_context.py` | language default fix |
| `test/.../conftest.py` | +294 lines shared helpers |
| `test/.../*.py` | 15 test files refactored, 20 new tests |
---------
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
599 lines
25 KiB
Python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Task Handler Module.

Provides [`TaskHandler`](rag/svr/task_executor_refactor/task_handler.py:56) as the main entry point
for handling document processing tasks with refactored, testable methods.
"""

import asyncio
import logging
import json
import xxhash

from timeit import default_timer as timer
from typing import Callable, Dict, List, Optional

from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.joint_services.memory_message_service import handle_save_to_memory_task
from api.db.joint_services.tenant_model_service import (
    get_tenant_default_model_by_type,
    get_model_config_from_provider_instance
)
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID
from common.constants import LLMType
from common.exceptions import TaskCanceledException
from common.connection_utils import timeout
from common.misc_utils import thread_pool_exec
from rag.nlp import search
from rag.svr.task_executor_refactor.constants import CANVAS_DEBUG_DOC_ID
from rag.svr.task_executor_refactor.chunk_service import ChunkService
from rag.svr.task_executor_refactor.dataflow_service import BillingHook, DataflowService
from rag.svr.task_executor_refactor.embedding_service import EmbeddingService
from rag.svr.task_executor_refactor.post_processor import PostProcessor
from rag.svr.task_executor_refactor.raptor_service import RaptorService
from rag.svr.task_executor_refactor.raptor_utils import delete_raptor_chunks
from rag.svr.task_executor_refactor.recording_context import RecordingContext
from rag.svr.task_executor_refactor.task_context import TaskContext
from rag.graphrag.general.index import run_graphrag_for_kb
from api.db.services.file2document_service import File2DocumentService
from rag.prompts.generator import run_toc_from_text
from common import settings


class TaskHandler:
    """Main task handler for document processing.

    This class orchestrates the entire document processing pipeline:
    1. Task type detection (memory, dataflow, raptor, graphrag, standard)
    2. Model binding (embedding, chat)
    3. Chunk building or RAPTOR/GraphRAG execution
    4. Embedding
    5. Indexing
    6. Post-processing (TOC, table metadata)

    All intermediate results are recorded via RecordingContext for comparison.
    """

    def __init__(
        self,
        ctx: TaskContext,
        billing_hook: Optional[BillingHook] = None,
    ):
        """Initialize TaskHandler.

        Args:
            ctx: TaskContext containing task configuration and execution resources.
            billing_hook: Optional billing hook for pipeline success/error callbacks.
        """
        self._task_context = ctx
        self._billing_hook = billing_hook

    async def handle_task(self) -> None:
        try:
            await self.handle()
        finally:
            task_id = self._task_context.id
            task_tenant_id = self._task_context.tenant_id
            task_dataset_id = self._task_context.kb_id
            task_doc_id = self._task_context.doc_id
            if self._task_context.has_canceled_func(task_id):
                try:
                    exists = await thread_pool_exec(
                        settings.docStoreConn.index_exist,
                        search.index_name(task_tenant_id),
                        task_dataset_id,
                    )
                    if exists:
                        ret = await thread_pool_exec(
                            settings.docStoreConn.delete,
                            {"doc_id": task_doc_id},
                            search.index_name(task_tenant_id),
                            task_dataset_id,
                        )
                        self._task_context.recording_context.save_func_return_value("docStoreConn.delete", ret)
                except Exception as e:
                    logging.exception(
                        f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled, exception: {e}")

    @timeout(60 * 60 * 3, 1)
    async def handle(self) -> None:
        """Handle a document processing task."""
        ctx = self._task_context
        task_type = ctx.task_type
        task_id = ctx.id

        # Handle memory tasks
        if task_type == "memory":
            # ignore when it's dry run - no change on handle_save_to_memory_task when refactor
            if isinstance(ctx.write_interceptor, RecordingContext):
                logging.info(f"dry run, ignore handle_save_to_memory_task {task_id}")
            else:
                # actual run - not dry run
                await handle_save_to_memory_task(ctx.raw_task)
            return

        # Check if task is canceled
        if ctx.has_canceled_func(task_id):
            ctx.progress_cb(-1, msg="Task has been canceled.")
            return

        # Language defaults to "Chinese" via TaskContext._DEFAULTS, so it is safe to bind the model directly.
        # Bind embedding model (matching original do_handle_task order: bind + init_kb before routing)
        result = await self._bind_embedding_model()
        if result is None:
            return
        embedding_model, vector_size = result

        with embedding_model:
            self._init_kb(vector_size)

            # Handle dataflow tasks (after init_kb, matching original behavior)
            if task_type == "dataflow" and ctx.doc_id == CANVAS_DEBUG_DOC_ID:
                await self._run_dataflow()
                return

            if task_type.startswith("dataflow"):
                await self._run_dataflow()
                return

            # Route to appropriate handler
            if task_type == "raptor":
                await self._run_raptor(embedding_model, vector_size)
            elif task_type == "graphrag":
                await self._run_graphrag(embedding_model)
            elif task_type == "mindmap":
                ctx.progress_cb(1, "place holder")
            elif task_type == "evaluation":
                await self._run_evaluation()
            elif task_type == "reembedding":
                await self._run_reembedding()
            elif task_type == "clone":
                await self._run_clone()
            else:
                await self._run_standard_chunking(embedding_model)


    def _init_kb(self, vector_size: int) -> None:
        """Initialize knowledge base index."""
        ctx = self._task_context
        idxnm = search.index_name(ctx.tenant_id)
        parser_id = ctx.parser_id
        # Create index if not exists
        settings.docStoreConn.create_idx(idxnm, ctx.kb_id, vector_size, parser_id)

    async def _run_dataflow(self) -> None:
        """Run dataflow pipeline."""
        dataflow_service = DataflowService(
            ctx=self._task_context,
            billing_hook=self._billing_hook,
        )
        await dataflow_service.run_dataflow()

    async def _run_evaluation(self) -> None:
        """Run evaluation task."""
        ctx = self._task_context
        ctx.progress_cb(1, "Evaluation task placeholder")

    async def _run_reembedding(self) -> None:
        """Run reembedding task."""
        ctx = self._task_context
        ctx.progress_cb(1, "Reembedding task placeholder")

    async def _run_clone(self) -> None:
        """Run clone task."""
        ctx = self._task_context
        ctx.progress_cb(1, "Clone task placeholder")

    async def _bind_embedding_model(self) -> Optional[tuple]:
        """Bind embedding model to task.

        Returns:
            Tuple of (embedding_model, vector_size) on success.

        Raises:
            Exception: Re-raised after the failure is reported via the progress callback.
        """
        ctx = self._task_context
        task_tenant_id = ctx.tenant_id
        task_embedding_id = ctx.embd_id
        task_language = ctx.language

        try:
            if task_embedding_id:
                embd_model_config = get_model_config_from_provider_instance(
                    task_tenant_id, LLMType.EMBEDDING, task_embedding_id
                )
            else:
                embd_model_config = get_tenant_default_model_by_type(
                    task_tenant_id, LLMType.EMBEDDING
                )
            embedding_model = LLMBundle(task_tenant_id, embd_model_config, lang=task_language)
            # Encode a probe string to discover the embedding dimension.
            vts, _ = embedding_model.encode(["ok"])
            return embedding_model, len(vts[0])
        except Exception as e:
            error_message = f'Fail to bind embedding model: {str(e)}'
            ctx.progress_cb(-1, msg=error_message)
            logging.exception(error_message)
            raise

    async def _run_raptor(
        self,
        embedding_model: LLMBundle,
        vector_size: int,
    ) -> None:
        """Run RAPTOR summary generation."""
        ctx = self._task_context
        task_tenant_id = ctx.tenant_id
        task_dataset_id = ctx.kb_id
        kb_task_llm_id = ctx.kb_parser_config.get("llm_id") or ctx.llm_id

        ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
        if not ok:
            ctx.progress_cb(prog=-1.0, msg="Cannot found valid dataset for RAPTOR task")
            return

        kb_parser_config = kb.parser_config
        if not kb_parser_config.get("raptor", {}).get("use_raptor", False):
            kb_parser_config.update({
                "raptor": {
                    "use_raptor": True,
                    "prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
                    "max_token": 256,
                    "threshold": 0.1,
                    "max_cluster": 64,
                    "random_seed": 0,
                    "scope": "file",
                    "clustering_method": "gmm",
                    "tree_builder": "raptor",
                },
            })
            if ctx.write_interceptor:
                update_result = ctx.write_interceptor.intercept("KnowledgebaseService.update_by_id")
            else:
                update_result = KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config})

            if not update_result:
                ctx.progress_cb(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
                return

        # Bind LLM for raptor
        chat_model_config = get_model_config_from_provider_instance(
            task_tenant_id, LLMType.CHAT, kb_task_llm_id
        )
        with LLMBundle(task_tenant_id, chat_model_config, lang=ctx.language) as chat_model:

            # Run RAPTOR
            raptor_service = RaptorService(ctx=ctx)

            async with ctx.kg_limiter:
                chunks, token_count, raptor_cleanup_chunks = await raptor_service.run_raptor_for_kb(
                    kb_parser_config=kb_parser_config,
                    chat_mdl=chat_model,
                    embd_mdl=embedding_model,
                    vector_size=vector_size,
                    doc_ids=ctx.doc_ids,
                )

            ctx.recording_context.record("raptor_chunks", chunks)
            ctx.recording_context.record("raptor_token_count", token_count)

            # Resolve the target doc id up front: the stats update below
            # needs it even when no chunks were produced.
            task_doc_id = (ctx.doc_ids or [GRAPH_RAPTOR_FAKE_DOC_ID])[0]

            # Insert RAPTOR chunks
            if chunks:
                chunk_service = ChunkService(ctx=ctx)
                insert_result = await chunk_service.insert_chunks(ctx.id, task_tenant_id, task_dataset_id, chunks)
                if insert_result:
                    ctx.recording_context.record("insertion_result", "success")
                else:
                    ctx.recording_context.record("insertion_result", "failed")

            # Cleanup stale RAPTOR chunks
            cleaned_chunks = 0
            for cleanup_doc_id, keep_method in raptor_cleanup_chunks:
                ret = await self._delete_raptor_chunks(
                    cleanup_doc_id, task_tenant_id, task_dataset_id, keep_method
                )
                cleaned_chunks += ret

            if cleaned_chunks:
                ctx.progress_cb(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.")

            # Update document stats
            if ctx.write_interceptor:
                ctx.write_interceptor.intercept("DocumentService.increment_chunk_num")
            else:
                DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, len(chunks), 0)

            ctx.recording_context.record("task_status", "completed")
            ctx.progress_cb(prog=1.0, msg="RAPTOR done")

    async def _run_graphrag(
        self,
        embedding_model: LLMBundle
    ) -> None:
        """Run GraphRAG."""
        ctx = self._task_context
        task_tenant_id = ctx.tenant_id
        task_dataset_id = ctx.kb_id
        kb_task_llm_id = ctx.kb_parser_config.get("llm_id") or ctx.llm_id
        task_language = ctx.language

        ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
        if not ok:
            ctx.progress_cb(prog=-1.0, msg="Cannot found valid dataset for GraphRAG task")
            return

        kb_parser_config = kb.parser_config
        if not kb_parser_config.get("graphrag", {}).get("use_graphrag", False):
            kb_parser_config.update({
                "graphrag": {
                    "use_graphrag": True,
                    "entity_types": [
                        "organization",
                        "person",
                        "geo",
                        "event",
                        "category",
                    ],
                    "method": "light",
                    "batch_chunk_token_size": 4096,
                    "retry_attempts": 2,
                    "retry_backoff_seconds": 2.0,
                    "retry_backoff_max_seconds": 60.0,
                    "build_subgraph_timeout_per_chunk_seconds": 300,
                    "build_subgraph_min_timeout_seconds": 600,
                    "merge_timeout_seconds": 180,
                    "resolution_timeout_seconds": 1800,
                    "community_timeout_seconds": 1800,
                    "lock_acquire_timeout_seconds": 600,
                }
            })
            if ctx.write_interceptor:
                update_result = ctx.write_interceptor.intercept("KnowledgebaseService.update_by_id")
            else:
                update_result = KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config})
            if not update_result:
                ctx.progress_cb(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
                return

        graphrag_conf = kb_parser_config.get("graphrag", {})
        start_ts = timer()
        chat_model_config = get_model_config_from_provider_instance(
            task_tenant_id, LLMType.CHAT, kb_task_llm_id
        )
        with LLMBundle(task_tenant_id, chat_model_config, lang=task_language) as chat_model:

            with_resolution = graphrag_conf.get("resolution", False)
            with_community = graphrag_conf.get("community", False)

            async with ctx.kg_limiter:
                result = await run_graphrag_for_kb(
                    row=ctx.raw_task,
                    doc_ids=ctx.doc_ids,
                    language=task_language,
                    kb_parser_config=kb_parser_config,
                    chat_model=chat_model,
                    embedding_model=embedding_model,
                    callback=ctx.progress_cb,
                    with_resolution=with_resolution,
                    with_community=with_community,
                )
                logging.info(f"GraphRAG task result for task {ctx.raw_task}:\n{result}")

            ctx.recording_context.record("graphrag_result", result)
            ctx.progress_cb(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))

    async def _run_standard_chunking(
        self,
        embedding_model: LLMBundle
    ) -> None:
        """Run standard chunking pipeline."""
        ctx = self._task_context
        task_id = ctx.id
        task_tenant_id = ctx.tenant_id
        task_dataset_id = ctx.kb_id
        task_doc_id = ctx.doc_id
        task_start_ts = timer()
        doc_task_llm_id = ctx.parser_config.get("llm_id") or ctx.llm_id
        ctx.raw_task['llm_id'] = doc_task_llm_id

        # Build chunks
        start_ts = timer()
        chunk_service = ChunkService(ctx=ctx)

        # Get storage binary
        bucket, name = File2DocumentService.get_storage_address(doc_id=ctx.doc_id)
        binary = await self._get_storage_binary(bucket, name)
        if binary is None:
            raise FileNotFoundError(
                f"Can not find file <{ctx.name}> from minio. Could you try it again."
            )

        chunks = await chunk_service.build_chunks(binary)
        ctx.recording_context.record("chunks", chunks)
        chunk_ids = [c.get("id") for c in chunks if isinstance(c, dict) and "id" in c]
        ctx.recording_context.record("chunk_ids_count", len(chunk_ids))

        logging.info("Build document {}: {:.2f}s".format(ctx.name, timer() - start_ts))

        if not chunks:
            ctx.progress_cb(1., msg=f"No chunk built from {ctx.name}")
            return

        ctx.progress_cb(msg="Generate {} chunks".format(len(chunks)))

        # Embed chunks
        start_ts = timer()
        embedding_service = EmbeddingService(ctx=ctx)
        try:
            token_count, vector_size = await embedding_service.embed_chunks(
                chunks, embedding_model, ctx.parser_config
            )
        except TaskCanceledException:
            raise
        except Exception as e:
            error_message = "Generate embedding error:{}".format(str(e))
            ctx.progress_cb(-1, error_message)
            logging.exception(error_message)
            raise

        ctx.recording_context.record("token_count", token_count)
        ctx.recording_context.record("vector_size", vector_size)
        progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
        logging.info(progress_message)
        ctx.progress_cb(msg=progress_message)

        # Build TOC if needed
        toc_thread = None
        if ctx.parser_id.lower() == "naive" and ctx.parser_config.get("toc_extraction", False):
            toc_thread = asyncio.create_task(asyncio.to_thread(self._build_toc, ctx, chunks, ctx.progress_cb))

        # Insert chunks
        chunk_count = len(set([chunk["id"] for chunk in chunks]))
        start_ts = timer()

        chunk_service = ChunkService(ctx=ctx)

        if ctx.has_canceled_func(task_id):
            ctx.progress_cb(-1, msg="Task has been canceled.")
            return

        insert_result = await chunk_service.insert_chunks(
            task_id, task_tenant_id, task_dataset_id, chunks
        )

        if not insert_result:
            ctx.recording_context.record("insertion_result", "failed")
            return
        ctx.recording_context.record("insertion_result", "success")

        # Post-processing
        post_processor = PostProcessor(ctx=ctx)
        await post_processor.process_table_parser_metadata(task_doc_id, chunks)

        ctx.progress_cb(msg="Indexing done ({:.2f}s).".format(timer() - start_ts))

        toc_chunk = await self._process_toc_thread(toc_thread)
        if toc_chunk:
            ctx.recording_context.record("toc_chunk", [toc_chunk])
            await post_processor.insert_toc_chunk(toc_chunk, chunk_service)

        if ctx.has_canceled_func(task_id):
            ctx.progress_cb(-1, msg="Task has been canceled.")
            return

        # Update document stats
        if ctx.write_interceptor:
            ctx.write_interceptor.intercept("DocumentService.increment_chunk_num")
        else:
            DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)

        task_time_cost = timer() - task_start_ts
        ctx.recording_context.record("task_status", "completed")
        ctx.progress_cb(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost))

        logging.info(
            "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(
                ctx.name, ctx.from_page, ctx.to_page,
                len(chunks), token_count, task_time_cost
            )
        )


    async def _process_toc_thread(self, toc_thread):
        """Await the TOC task if one was started; cancel it if it is left unfinished."""
        try:
            if toc_thread:
                return await toc_thread
            else:
                return None
        finally:
            if toc_thread is not None and not toc_thread.done():
                toc_thread.cancel()

    @classmethod
    async def _get_storage_binary(cls, bucket: str, name: str) -> Optional[bytes]:
        """Get binary from storage; the caller treats a None result as a missing file."""
        return await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, name)

    @classmethod
    def _build_toc(cls, ctx: TaskContext, docs: List[Dict], progress_cb: Callable) -> Optional[Dict]:
        """Build table of contents."""
        progress_cb(msg="Start to generate table of content ...")
        chat_model_config = get_model_config_from_provider_instance(
            ctx.tenant_id, LLMType.CHAT, ctx.llm_id
        )
        with LLMBundle(ctx.tenant_id, chat_model_config, lang=ctx.language) as chat_mdl:

            docs = sorted(docs, key=lambda d: (
                d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
                d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
            ))

            # NOTE: asyncio.run() creates a new event loop in the worker thread
            # (this method is called via asyncio.to_thread), which is the
            # intended pattern for bridging sync -> async in a thread context.
            toc: list[dict] = asyncio.run(
                run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_cb)
            )
            logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))

            for ii, item in enumerate(toc):
                try:
                    chunk_val = item.pop("chunk_id", None)
                    if chunk_val is None or str(chunk_val).strip() == "":
                        logging.warning(f"Index {ii}: chunk_id is missing or empty. Skipping.")
                        continue
                    # Convert directly: `chunk_val or -1` would turn a valid integer 0 into -1.
                    curr_idx = int(chunk_val)
                    if curr_idx >= len(docs):
                        logging.error(f"Index {ii}: chunk_id {curr_idx} exceeds docs length {len(docs)}.")
                        continue
                    item["ids"] = [docs[curr_idx]["id"]]
                    if ii + 1 < len(toc):
                        next_chunk_val = toc[ii + 1].get("chunk_id", "")
                        if str(next_chunk_val).strip() != "":
                            next_idx = int(next_chunk_val)
                            for jj in range(curr_idx + 1, min(next_idx + 1, len(docs))):
                                item["ids"].append(docs[jj]["id"])
                        else:
                            logging.warning(f"Index {ii + 1}: next chunk_id is empty, range fill skipped.")
                except (ValueError, TypeError) as e:
                    logging.error(f"Index {ii}: Data conversion error - {e}")
                except Exception as e:
                    logging.exception(f"Index {ii}: Unexpected error - {e}")

            if toc:
                import copy
                d = copy.deepcopy(docs[-1])
                d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
                d["toc_kwd"] = "toc"
                d["available_int"] = 0
                d["page_num_int"] = [100000000]
                d["id"] = xxhash.xxh64(
                    (d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
                return d
        return None

    async def _delete_raptor_chunks(
        self, doc_id: str, tenant_id: str, kb_id: str, keep_method: Optional[str]
    ) -> int:
        """Delete RAPTOR chunks."""
        if self._task_context.write_interceptor:
            return self._task_context.write_interceptor.intercept("delete_raptor_chunks")
        else:
            return await delete_raptor_chunks(doc_id, tenant_id, kb_id, keep_method)