Refactor: tidy up ThreadPoolExecutor lifecycle in file_service and task executor (#14668)

## Summary
- Wrap the `ThreadPoolExecutor` instances in `FileService.parse_docs`
and `FileService.get_files` with `with ... as exe:` blocks for
deterministic cleanup
- Replace the `concurrent.futures.ThreadPoolExecutor` in
`do_handle_task` with `asyncio.create_task(asyncio.to_thread(build_TOC,
...))`, preserving the existing parallelism with chunk insertion while
leveraging the surrounding async context
- Drop the now-unused `import concurrent` and the
`executor.shutdown(wait=False)` call in the `finally` block

Closes #14622.

No behavioral change, no public API change. Net diff: ~19 insertions /
25 deletions across two files.

## Test plan
- [ ] `uv run ruff check api/db/services/file_service.py
rag/svr/task_executor.py` passes
- [ ] Upload a multi-file batch through the chat/file endpoint and
confirm `FileService.parse_docs` still returns combined parsed text
- [ ] Trigger `FileService.get_files` via the chat reference flow with a
mix of image and non-image files; verify both `raw=True` and `raw=False`
paths return correctly
- [ ] Run a `naive`-parser document task with `toc_extraction: true` and
confirm the TOC chunk is generated and inserted exactly as before
- [ ] Run a `naive`-parser document task with `toc_extraction: false`
and confirm the path with `toc_thread = None` is unaffected
- [ ] Cancel a running task to exercise the `finally` block and confirm
cleanup still works without the executor shutdown call

---------

Co-authored-by: web-dev0521 <jasonpette1783@gmail.com>
Co-authored-by: Wang Qi <wangq8@outlook.com>
This commit is contained in:
web-dev0521
2026-05-11 00:59:00 -04:00
committed by GitHub
parent 13e6554901
commit cc207b5b05
2 changed files with 21 additions and 25 deletions

View File

@@ -561,14 +561,9 @@ class FileService(CommonService):
@staticmethod
def parse_docs(file_objs, user_id):
exe = ThreadPoolExecutor(max_workers=12)
threads = []
for file in file_objs:
threads.append(exe.submit(FileService.parse, file.filename, file.read(), False))
res = []
for th in threads:
res.append(th.result())
with ThreadPoolExecutor(max_workers=12) as exe:
threads = [exe.submit(FileService.parse, file.filename, file.read(), False) for file in file_objs]
res = [th.result() for th in threads]
return "\n\n".join(res)
@@ -793,19 +788,21 @@ class FileService(CommonService):
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
exe = ThreadPoolExecutor(max_workers=5)
threads = []
imgs = []
for file in files:
if file["mime_type"].find("image") >=0:
if raw:
imgs.append(FileService.get_blob(file["created_by"], file["id"]))
else:
threads.append(exe.submit(image_to_base64, file))
continue
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"], layout_recognize))
with ThreadPoolExecutor(max_workers=5) as exe:
for file in files:
if file["mime_type"].find("image") >=0:
if raw:
imgs.append(FileService.get_blob(file["created_by"], file["id"]))
else:
threads.append(exe.submit(image_to_base64, file))
continue
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"], layout_recognize))
results = [th.result() for th in threads]
if raw:
return [th.result() for th in threads], imgs
return results, imgs
else:
return [th.result() for th in threads]
return results

View File

@@ -22,7 +22,6 @@ start_ts = time.time()
import asyncio
import socket
import concurrent
# from beartype import BeartypeConf
# from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
@@ -1089,7 +1088,6 @@ async def do_handle_task(task):
task_parser_config = task["parser_config"]
task_start_ts = timer()
toc_thread = None
executor = concurrent.futures.ThreadPoolExecutor()
# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -1251,7 +1249,7 @@ async def do_handle_task(task):
logging.info(progress_message)
progress_callback(msg=progress_message)
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)
toc_thread = asyncio.create_task(asyncio.to_thread(build_TOC, task, chunks, progress_callback))
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
@@ -1318,7 +1316,7 @@ async def do_handle_task(task):
progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts))
if toc_thread:
d = toc_thread.result()
d = await toc_thread
if d:
if not await _maybe_insert_chunks([d]):
return
@@ -1337,7 +1335,8 @@ async def do_handle_task(task):
)
finally:
executor.shutdown(wait=False)
if toc_thread is not None and not toc_thread.done():
toc_thread.cancel()
if has_canceled(task_id):
try:
exists = await thread_pool_exec(