mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Refactor: tidy up ThreadPoolExecutor lifecycle in file_service and task executor (#14668)
## Summary

- Wrap the `ThreadPoolExecutor` instances in `FileService.parse_docs` and `FileService.get_files` with `with ... as exe:` blocks for deterministic cleanup
- Replace the `concurrent.futures.ThreadPoolExecutor` in `do_handle_task` with `asyncio.create_task(asyncio.to_thread(build_TOC, ...))`, preserving the existing parallelism with chunk insertion while leveraging the surrounding async context
- Drop the now-unused `import concurrent` and the `executor.shutdown(wait=False)` call in the `finally` block

Closes #14622. No behavioral change, no public API change. Net diff: ~19 insertions / 25 deletions across two files.

## Test plan

- [ ] `uv run ruff check api/db/services/file_service.py rag/svr/task_executor.py` passes
- [ ] Upload a multi-file batch through the chat/file endpoint and confirm `FileService.parse_docs` still returns combined parsed text
- [ ] Trigger `FileService.get_files` via the chat reference flow with a mix of image and non-image files; verify both `raw=True` and `raw=False` paths return correctly
- [ ] Run a `naive`-parser document task with `toc_extraction: true` and confirm the TOC chunk is generated and inserted exactly as before
- [ ] Run a `naive`-parser document task with `toc_extraction: false` and confirm the path with `toc_thread = None` is unaffected
- [ ] Cancel a running task to exercise the `finally` block and confirm cleanup still works without the executor shutdown call

---------

Co-authored-by: web-dev0521 <jasonpette1783@gmail.com>
Co-authored-by: Wang Qi <wangq8@outlook.com>
This commit is contained in:
@@ -561,14 +561,9 @@ class FileService(CommonService):
|
||||
|
||||
@staticmethod
|
||||
def parse_docs(file_objs, user_id):
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
threads = []
|
||||
for file in file_objs:
|
||||
threads.append(exe.submit(FileService.parse, file.filename, file.read(), False))
|
||||
|
||||
res = []
|
||||
for th in threads:
|
||||
res.append(th.result())
|
||||
with ThreadPoolExecutor(max_workers=12) as exe:
|
||||
threads = [exe.submit(FileService.parse, file.filename, file.read(), False) for file in file_objs]
|
||||
res = [th.result() for th in threads]
|
||||
|
||||
return "\n\n".join(res)
|
||||
|
||||
@@ -793,19 +788,21 @@ class FileService(CommonService):
|
||||
def image_to_base64(file):
|
||||
return "data:{};base64,{}".format(file["mime_type"],
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
exe = ThreadPoolExecutor(max_workers=5)
|
||||
threads = []
|
||||
imgs = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
if raw:
|
||||
imgs.append(FileService.get_blob(file["created_by"], file["id"]))
|
||||
else:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
continue
|
||||
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"], layout_recognize))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as exe:
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
if raw:
|
||||
imgs.append(FileService.get_blob(file["created_by"], file["id"]))
|
||||
else:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
continue
|
||||
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"], layout_recognize))
|
||||
|
||||
results = [th.result() for th in threads]
|
||||
|
||||
if raw:
|
||||
return [th.result() for th in threads], imgs
|
||||
return results, imgs
|
||||
else:
|
||||
return [th.result() for th in threads]
|
||||
return results
|
||||
|
||||
@@ -22,7 +22,6 @@ start_ts = time.time()
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import concurrent
|
||||
# from beartype import BeartypeConf
|
||||
# from beartype.claw import beartype_all # <-- you didn't sign up for this
|
||||
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
||||
@@ -1089,7 +1088,6 @@ async def do_handle_task(task):
|
||||
task_parser_config = task["parser_config"]
|
||||
task_start_ts = timer()
|
||||
toc_thread = None
|
||||
executor = concurrent.futures.ThreadPoolExecutor()
|
||||
|
||||
# prepare the progress callback function
|
||||
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
|
||||
@@ -1251,7 +1249,7 @@ async def do_handle_task(task):
|
||||
logging.info(progress_message)
|
||||
progress_callback(msg=progress_message)
|
||||
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
|
||||
toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)
|
||||
toc_thread = asyncio.create_task(asyncio.to_thread(build_TOC, task, chunks, progress_callback))
|
||||
|
||||
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
||||
start_ts = timer()
|
||||
@@ -1318,7 +1316,7 @@ async def do_handle_task(task):
|
||||
progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts))
|
||||
|
||||
if toc_thread:
|
||||
d = toc_thread.result()
|
||||
d = await toc_thread
|
||||
if d:
|
||||
if not await _maybe_insert_chunks([d]):
|
||||
return
|
||||
@@ -1337,7 +1335,8 @@ async def do_handle_task(task):
|
||||
)
|
||||
|
||||
finally:
|
||||
executor.shutdown(wait=False)
|
||||
if toc_thread is not None and not toc_thread.done():
|
||||
toc_thread.cancel()
|
||||
if has_canceled(task_id):
|
||||
try:
|
||||
exists = await thread_pool_exec(
|
||||
|
||||
Reference in New Issue
Block a user