From 423fb6faaec7e45d84722b7b2559feefb53bb94a Mon Sep 17 00:00:00 2001 From: buua436 Date: Thu, 4 Jun 2026 17:57:51 +0800 Subject: [PATCH] fix: duplicate document ingest guard (#15638) ### What problem does this PR solve? When a document is rerun or updated concurrently, the previous unconditional update could overwrite a newer task state. This change adds an `update_time`-based optimistic lock so the update only succeeds if the record has not been modified by another flow in the meantime. It also rejects a run request when the document is already in the RUNNING or SCHEDULE state or still has an unfinished, non-canceled task, and it defers `cancel_all_task_of` until the guarded update has succeeded. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/restful_apis/document_api.py | 19 ++++++++++++++++--- api/db/services/common_service.py | 20 ++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/api/apps/restful_apis/document_api.py b/api/apps/restful_apis/document_api.py index d83ff442ea..ffb62f4016 100644 --- a/api/apps/restful_apis/document_api.py +++ b/api/apps/restful_apis/document_api.py @@ -38,7 +38,7 @@ from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService from api.common.check_team_permission import check_kb_team_permission -from api.db.services.task_service import TaskService, cancel_all_task_of +from api.db.services.task_service import TaskService, cancel_all_task_of, has_canceled from api.utils.api_utils import construct_json_result, get_data_error_result, get_error_data_result, get_result, get_json_result, \ server_error_response, add_tenant_id_to_kwargs, get_request_json, get_error_argument_result, check_duplicate_ids from api.utils.pagination_utils import validate_rest_api_page_size @@ -1397,17 +1397,30 @@ def _run_sync(user_id:str, req): if not e: return RetCode.DATA_ERROR, "Document not found!" 
+ if str(req["run"]) == TaskStatus.RUNNING.value: + tasks = list(TaskService.query(doc_id=doc_id)) + has_active_task = any((task.progress or 0) < 1 and not has_canceled(task.id) for task in tasks) + if str(doc.run) in [TaskStatus.RUNNING.value, TaskStatus.SCHEDULE.value] or has_active_task: + return RetCode.DATA_ERROR, "Document is already running" + + should_cancel = False if str(req["run"]) == TaskStatus.CANCEL.value: tasks = list(TaskService.query(doc_id=doc_id)) has_unfinished_task = any((task.progress or 0) < 1 for task in tasks) if str(doc.run) in [TaskStatus.RUNNING.value, TaskStatus.CANCEL.value] or has_unfinished_task: - cancel_all_task_of(doc_id) + should_cancel = True else: return RetCode.DATA_ERROR, "Cannot cancel a task that is not in RUNNING status" if all([rerun_with_delete, str(doc.run) == TaskStatus.DONE.value]): DocumentService.clear_chunk_num_when_rerun(doc_id) - DocumentService.update_by_id(doc_id, info) + affected_rows = DocumentService.update_by_id_if_update_time(doc_id, doc.update_time, info) + if not affected_rows: + return RetCode.DATA_ERROR, "Document is already running" + + if str(req["run"]) == TaskStatus.CANCEL.value and should_cancel: + cancel_all_task_of(doc_id) + if req.get("delete", False): TaskService.filter_delete([Task.doc_id == doc_id]) if settings.docStoreConn.index_exist(search.index_name(doc_tenant_id), doc.kb_id): diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py index 8ef4bb94b4..ccf896542e 100644 --- a/api/db/services/common_service.py +++ b/api/db/services/common_service.py @@ -276,6 +276,26 @@ class CommonService: num = cls.model.update(data).where(cls.model.id == pid).execute() return num + @classmethod + @DB.connection_context() + @retry_db_operation + def update_by_id_if_update_time(cls, pid, update_time, data): + # Update a single record by ID only if update_time matches the expected value. 
+ # Args: + # pid: Record ID + # update_time: Expected update_time value for optimistic locking + # data: Updated field values + # Returns: + # Number of records updated + data["update_time"] = current_timestamp() + data["update_date"] = datetime_format(datetime.now()) + num = ( + cls.model.update(data) + .where(cls.model.id == pid, cls.model.update_time == update_time) + .execute() + ) + return num + @classmethod @DB.connection_context() def get_by_id(cls, pid):