diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index e8b71a6afd..34776a6797 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -561,14 +561,9 @@ class FileService(CommonService): @staticmethod def parse_docs(file_objs, user_id): - exe = ThreadPoolExecutor(max_workers=12) - threads = [] - for file in file_objs: - threads.append(exe.submit(FileService.parse, file.filename, file.read(), False)) - - res = [] - for th in threads: - res.append(th.result()) + with ThreadPoolExecutor(max_workers=12) as exe: + threads = [exe.submit(FileService.parse, file.filename, file.read(), False) for file in file_objs] + res = [th.result() for th in threads] return "\n\n".join(res) @@ -793,19 +788,21 @@ class FileService(CommonService): def image_to_base64(file): return "data:{};base64,{}".format(file["mime_type"], base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) - exe = ThreadPoolExecutor(max_workers=5) threads = [] imgs = [] - for file in files: - if file["mime_type"].find("image") >=0: - if raw: - imgs.append(FileService.get_blob(file["created_by"], file["id"])) - else: - threads.append(exe.submit(image_to_base64, file)) - continue - threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"], layout_recognize)) - + with ThreadPoolExecutor(max_workers=5) as exe: + for file in files: + if file["mime_type"].find("image") >=0: + if raw: + imgs.append(FileService.get_blob(file["created_by"], file["id"])) + else: + threads.append(exe.submit(image_to_base64, file)) + continue + threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"], layout_recognize)) + + results = [th.result() for th in threads] + if raw: - return [th.result() for th in threads], imgs + return results, imgs else: - return [th.result() for th in threads] + return results diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 8ce913e79f..cb41366170 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -22,7 +22,6 @@ start_ts = time.time() import asyncio import socket -import concurrent # from beartype import BeartypeConf # from beartype.claw import beartype_all # <-- you didn't sign up for this # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code @@ -1089,7 +1088,6 @@ async def do_handle_task(task): task_parser_config = task["parser_config"] task_start_ts = timer() toc_thread = None - executor = concurrent.futures.ThreadPoolExecutor() # prepare the progress callback function progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) @@ -1251,7 +1249,7 @@ async def do_handle_task(task): logging.info(progress_message) progress_callback(msg=progress_message) if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False): - toc_thread = executor.submit(build_TOC, task, chunks, progress_callback) + toc_thread = asyncio.create_task(asyncio.to_thread(build_TOC, task, chunks, progress_callback)) chunk_count = len(set([chunk["id"] for chunk in chunks])) start_ts = timer() @@ -1318,7 +1316,7 @@ async def do_handle_task(task): progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts)) if toc_thread: - d = toc_thread.result() + d = await toc_thread if d: if not await _maybe_insert_chunks([d]): return @@ -1337,7 +1335,8 @@ async def do_handle_task(task): ) finally: - executor.shutdown(wait=False) + if toc_thread is not None and not toc_thread.done(): + toc_thread.cancel() if has_canceled(task_id): try: exists = await thread_pool_exec(