diff --git a/api/apps/restful_apis/connector_api.py b/api/apps/restful_apis/connector_api.py index 21ab7fd4d0..89287a706d 100644 --- a/api/apps/restful_apis/connector_api.py +++ b/api/apps/restful_apis/connector_api.py @@ -53,17 +53,34 @@ async def update_connector(connector_id): return _connector_auth_error(connector_id, current_user.id) req = await get_request_json() + if isinstance(req, dict) and isinstance(req.get("data"), dict): + req = req["data"] + e, conn = ConnectorService.get_by_id(connector_id) if not e: return get_data_error_result(message="Can't find this Connector!") + should_sleep = False if req: - conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req} - conn["id"] = connector_id - ConnectorService.update_by_id(connector_id, conn) + update_fields = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req} + if update_fields: + update_fields["id"] = connector_id + ConnectorService.update_by_id(connector_id, update_fields) + should_sleep = True - await asyncio.sleep(1) + if req.get("reschedule"): + ConnectorService.cancel_tasks(connector_id) + ConnectorService.schedule_tasks(connector_id) + elif req.get("status") in [TaskStatus.CANCEL, "CANCEL"]: + ConnectorService.cancel_tasks(connector_id) + elif req.get("status") in [TaskStatus.SCHEDULE, "SCHEDULE"]: + ConnectorService.schedule_tasks(connector_id) + + if should_sleep: + await asyncio.sleep(1) e, conn = ConnectorService.get_by_id(connector_id) + if not e: + return get_data_error_result(message="Can't find this Connector!") return get_json_result(data=conn.to_dict()) @@ -83,9 +100,9 @@ async def create_connector(): "input_type": InputType.POLL, "config": req["config"], "refresh_freq": int(req.get("refresh_freq", 5)), - "prune_freq": int(req.get("prune_freq", 720)), + "prune_freq": int(req.get("prune_freq", 5)), "timeout_secs": int(req.get("timeout_secs", 60 * 29)), - "status": TaskStatus.SCHEDULE, + "status": TaskStatus.UNSTART, } ConnectorService.save(**conn) @@ -127,21 +144,6 @@ def list_logs(connector_id): return get_json_result(data={"total": total, "logs": arr}) -@manager.route("/connectors//resume", methods=["POST"]) # noqa: F821 -@login_required -async def resume(connector_id): - """Resume or cancel sync for an accessible connector.""" - if not ConnectorService.accessible(connector_id, current_user.id): - return _connector_auth_error(connector_id, current_user.id) - - req = await get_request_json() - if req.get("resume"): - ConnectorService.resume(connector_id, TaskStatus.SCHEDULE) - else: - ConnectorService.resume(connector_id, TaskStatus.CANCEL) - return get_json_result(data=True) - - @manager.route("/connectors//rebuild", methods=["POST"]) # noqa: F821 @login_required async def rebuild(connector_id): @@ -166,7 +168,7 @@ def rm_connector(connector_id): if not ConnectorService.accessible(connector_id, current_user.id): return _connector_auth_error(connector_id, current_user.id) - ConnectorService.resume(connector_id, TaskStatus.CANCEL) + ConnectorService.cancel_tasks(connector_id) ConnectorService.delete_by_id(connector_id) return get_json_result(data=True) diff --git a/api/db/db_models.py b/api/db/db_models.py index 3ed32ed3f2..a207b00788 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -1224,6 +1224,7 @@ class DateTimeTzField(CharField): class SyncLogs(DataBaseModel): id = CharField(max_length=32, primary_key=True) connector_id = CharField(max_length=32, index=True) + task_type = CharField(max_length=32, null=False, default="sync", index=True) status = CharField(max_length=128, null=False, help_text="Processing status", index=True) from_beginning = CharField(max_length=1, null=True, help_text="", default="0", index=False) new_docs_indexed = IntegerField(default=0, index=False) @@ -1632,6 +1633,7 @@ def migrate_db(): alter_db_add_column(migrator, "llm_factories", "rank", IntegerField(default=0, index=False)) alter_db_add_column(migrator, "api_4_conversation", "name", CharField(max_length=255, null=True, help_text="conversation name", index=False)) alter_db_add_column(migrator, "api_4_conversation", "exp_user_id", CharField(max_length=255, null=True, help_text="exp_user_id", index=True)) + alter_db_add_column(migrator, "sync_logs", "task_type", CharField(max_length=32, null=False, default="sync", index=True)) # Migrate system_settings.value from CharField to TextField for longer sandbox configs alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)")) alter_db_add_column(migrator, "document", "content_hash", CharField(max_length=32, null=True, help_text="xxhash128 of document content for change detection", default="", index=True)) diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py index 4d9fd2258a..9fa868c603 100644 --- a/api/db/services/connector_service.py +++ b/api/db/services/connector_service.py @@ -28,7 +28,7 @@ from api.db.services.document_service import DocumentService from api.db.services.document_service import DocMetadataService from api.utils.common import hash128 from common.misc_utils import get_uuid -from common.constants import TaskStatus +from common.constants import ConnectorTaskType, TaskStatus from common.settings import TIMEZONE from common.time_utils import current_timestamp, timestamp_to_date @@ -38,6 +38,33 @@ LOGGER = logging.getLogger(__name__) class ConnectorService(CommonService): model = Connector + @classmethod + def cancel_tasks(cls, connector_id): + e, conn = cls.get_by_id(connector_id) + if not e: + return + + logging.info( + "[Connector] stop connector=%s(%s)", + conn.name, + connector_id, + ) + for c2k in Connector2KbService.query(connector_id=connector_id): + SyncLogsService.filter_update( + [ + SyncLogs.connector_id == connector_id, + SyncLogs.kb_id == c2k.kb_id, + SyncLogs.status.in_([TaskStatus.SCHEDULE, TaskStatus.RUNNING]), + ], + {"status": TaskStatus.CANCEL}, + ) + ConnectorService.update_by_id(connector_id, {"status": TaskStatus.CANCEL}) + logging.info( + "[Connector] connector=%s status updated to %s", + connector_id, + TaskStatus.CANCEL, + ) + @classmethod @DB.connection_context() def accessible(cls, connector_id: str, user_id: str) -> bool: @@ -64,25 +91,39 @@ class ConnectorService(CommonService): return has_access @classmethod - def resume(cls, connector_id, status): + def schedule_tasks(cls, connector_id): + e, conn = cls.get_by_id(connector_id) + if not e: + return + + logging.info("[Connector] schedule connector=%s(%s)", conn.name, connector_id) + prune_enabled = bool((conn.config or {}).get("sync_deleted_files")) for c2k in Connector2KbService.query(connector_id=connector_id): - task = SyncLogsService.get_latest_task(connector_id, c2k.kb_id) - if not task: - if status == TaskStatus.SCHEDULE: - SyncLogsService.schedule(connector_id, c2k.kb_id) - ConnectorService.update_by_id(connector_id, {"status": status}) - return + sync_task = SyncLogsService.get_latest_task( + connector_id, + c2k.kb_id, + ConnectorTaskType.SYNC, + ) + poll_range_start = None + total_docs_indexed = 0 + if sync_task and sync_task.status == TaskStatus.DONE: + poll_range_start = sync_task.poll_range_end + total_docs_indexed = sync_task.total_docs_indexed - if task.status == TaskStatus.DONE: - if status == TaskStatus.SCHEDULE: - SyncLogsService.schedule(connector_id, c2k.kb_id, task.poll_range_end, total_docs_indexed=task.total_docs_indexed) - ConnectorService.update_by_id(connector_id, {"status": status}) - return + SyncLogsService.schedule( + connector_id, + c2k.kb_id, + poll_range_start, + total_docs_indexed=total_docs_indexed, + task_type=ConnectorTaskType.SYNC, + ) - task = task.to_dict() - task["status"] = status - SyncLogsService.update_by_id(task["id"], task) - ConnectorService.update_by_id(connector_id, {"status": status}) + if prune_enabled: + SyncLogsService.schedule( + connector_id, + c2k.kb_id, + task_type=ConnectorTaskType.PRUNE, + ) @classmethod def list(cls, tenant_id): @@ -105,7 +146,9 @@ class ConnectorService(CommonService): SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id]) docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id) err = FileService.delete_docs([d.id for d in docs], tenant_id) - SyncLogsService.schedule(connector_id, kb_id, reindex=True) + SyncLogsService.schedule(connector_id, kb_id, reindex=True, task_type=ConnectorTaskType.SYNC) + if (conn.config or {}).get("sync_deleted_files"): + SyncLogsService.schedule(connector_id, kb_id, task_type=ConnectorTaskType.PRUNE) return err @classmethod @@ -170,30 +213,25 @@ class ConnectorService(CommonService): class SyncLogsService(CommonService): model = SyncLogs + @classmethod def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15) -> Tuple[List[dict], int]: fields = [ cls.model.id, cls.model.connector_id, + cls.model.task_type, cls.model.kb_id, cls.model.update_date, - cls.model.poll_range_start, - cls.model.poll_range_end, cls.model.new_docs_indexed, cls.model.total_docs_indexed, + cls.model.docs_removed_from_index, cls.model.error_msg, - cls.model.full_exception_trace, cls.model.error_count, - Connector.name, - Connector.source, - Connector.tenant_id, - Connector.timeout_secs, + cls.model.time_started.alias("time_started"), + Connector.refresh_freq.alias("refresh_freq"), + Connector.prune_freq.alias("prune_freq"), Knowledgebase.name.alias("kb_name"), - Knowledgebase.avatar.alias("kb_avatar"), - Connector2Kb.auto_parse, - cls.model.from_beginning.alias("reindex"), cls.model.status, - cls.model.update_time ] if not connector_id: fields.append(Connector.config) @@ -225,6 +263,80 @@ class SyncLogsService(CommonService): return list(query.dicts()), total + @classmethod + def list_due_sync_tasks(cls) -> List[dict]: + return cls._list_due_tasks_for_freq( + ConnectorTaskType.SYNC, + "refresh_freq", + ) + + @classmethod + def list_due_prune_tasks(cls) -> List[dict]: + tasks = cls._list_due_tasks_for_freq( + ConnectorTaskType.PRUNE, + "prune_freq", + ) + return [ + task for task in tasks + # Prune is opt-in at the connector config level; keep the scheduler + # blind to prune_freq until the flag is enabled. + if bool((task.get("config") or {}).get("sync_deleted_files")) + and int(task.get("prune_freq") or 0) > 0 + ] + + @classmethod + def _list_due_tasks_for_freq(cls, task_type: str, freq_field: str) -> List[dict]: + fields = [ + cls.model.id, + cls.model.connector_id, + cls.model.task_type, + cls.model.kb_id, + cls.model.update_date, + cls.model.poll_range_start, + cls.model.poll_range_end, + cls.model.new_docs_indexed, + cls.model.total_docs_indexed, + cls.model.error_msg, + cls.model.full_exception_trace, + cls.model.error_count, + Connector.name, + Connector.source, + Connector.tenant_id, + Connector.timeout_secs, + Connector.config, + Connector.refresh_freq, + Connector.prune_freq, + Knowledgebase.name.alias("kb_name"), + Knowledgebase.avatar.alias("kb_avatar"), + Connector2Kb.auto_parse, + cls.model.from_beginning.alias("reindex"), + cls.model.status, + cls.model.update_time, + ] + + query = cls.model.select(*fields)\ + .join(Connector, on=(cls.model.connector_id==Connector.id))\ + .join(Connector2Kb, on=(cls.model.kb_id==Connector2Kb.kb_id))\ + .join(Knowledgebase, on=(cls.model.kb_id==Knowledgebase.id)) + + query = query.where( + Connector.input_type == InputType.POLL, + Connector.status == TaskStatus.SCHEDULE, + cls.model.status == TaskStatus.SCHEDULE, + cls.model.task_type == task_type, + ) + + database_type = os.getenv("DB_TYPE", "mysql") + if "postgres" in database_type.lower(): + expr = SQL( + f"NOW() AT TIME ZONE '{TIMEZONE}' - make_interval(mins => t2.{freq_field})" + ) + else: + expr = SQL(f"NOW() - INTERVAL `t2`.`{freq_field}` MINUTE") + query = query.where(cls.model.update_date < expr) + + return list(query.distinct().order_by(cls.model.update_time.desc()).dicts()) + @classmethod def start(cls, id, connector_id): cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') }) @@ -236,7 +348,15 @@ class SyncLogsService(CommonService): ConnectorService.update_by_id(connector_id, {"status": TaskStatus.DONE}) @classmethod - def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0): + def schedule( + cls, + connector_id, + kb_id, + poll_range_start=None, + reindex=False, + total_docs_indexed=0, + task_type=ConnectorTaskType.SYNC, + ): try: if cls.model.select().where(cls.model.kb_id == kb_id, cls.model.connector_id == connector_id).count() > 100: rm_ids = [m.id for m in cls.model.select(cls.model.id).where(cls.model.kb_id == kb_id, cls.model.connector_id == connector_id).order_by(cls.model.update_time.asc()).limit(70)] @@ -246,21 +366,33 @@ class SyncLogsService(CommonService): logging.exception(e) try: - e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE) + e = cls.query( + kb_id=kb_id, + connector_id=connector_id, + status=TaskStatus.SCHEDULE, + task_type=task_type, + ) if e: - logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.") + logging.warning( + "%s--%s already has a scheduled %s task.", + kb_id, + connector_id, + task_type, + ) return None reindex = "1" if reindex else "0" ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE}) return cls.save(**{ "id": get_uuid(), "kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id, + "task_type": task_type, "poll_range_start": poll_range_start, "from_beginning": reindex, - "total_docs_indexed": total_docs_indexed + "total_docs_indexed": total_docs_indexed, + "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') }) except Exception as e: logging.exception(e) - task = cls.get_latest_task(connector_id, kb_id) + task = cls.get_latest_task(connector_id, kb_id, task_type) if task: cls.model.update(status=TaskStatus.SCHEDULE, poll_range_start=poll_range_start, @@ -337,11 +469,14 @@ class SyncLogsService(CommonService): return errs, doc_ids @classmethod - def get_latest_task(cls, connector_id, kb_id): - return cls.model.select().where( + def get_latest_task(cls, connector_id, kb_id, task_type=None): + query = cls.model.select().where( cls.model.connector_id==connector_id, cls.model.kb_id == kb_id - ).order_by(cls.model.update_time.desc()).first() + ) + if task_type is not None: + query = query.where(cls.model.task_type == task_type) + return query.order_by(cls.model.update_time.desc()).first() class Connector2KbService(CommonService): @@ -364,7 +499,10 @@ class Connector2KbService(CommonService): "kb_id": kb_id, "auto_parse": conn.get("auto_parse", "1") }) - SyncLogsService.schedule(conn_id, kb_id, reindex=True) + SyncLogsService.schedule(conn_id, kb_id, reindex=True, task_type=ConnectorTaskType.SYNC) + e, full_conn = ConnectorService.get_by_id(conn_id) + if e and (full_conn.config or {}).get("sync_deleted_files"): + SyncLogsService.schedule(conn_id, kb_id, task_type=ConnectorTaskType.PRUNE) errs = [] for conn_id in old_conn_ids: diff --git a/common/constants.py b/common/constants.py index c80735255a..c76dcdbb09 100644 --- a/common/constants.py +++ b/common/constants.py @@ -93,6 +93,11 @@ class TaskStatus(StrEnum): VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL, TaskStatus.SCHEDULE} +class ConnectorTaskType(StrEnum): + SYNC = "sync" + PRUNE = "prune" + + class ParserType(StrEnum): PRESENTATION = "presentation" LAWS = "laws" diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 8d397fc2d6..a5ba395820 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -41,7 +41,7 @@ from api.db.services.connector_service import ConnectorService, SyncLogsService from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from common import settings -from common.constants import FileSource, TaskStatus +from common.constants import ConnectorTaskType, FileSource, TaskStatus from common.config_utils import show_configs from common.data_source.config import INDEX_BATCH_SIZE from common.data_source import ( @@ -76,8 +76,6 @@ from common.log_utils import init_root_logger from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc from common.versions import get_ragflow_version from box_sdk_gen import BoxOAuth, OAuthConfig, AccessToken -from collections import namedtuple - MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5")) task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS) @@ -157,30 +155,37 @@ class SyncBase: }) return - SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"]) + task_type = task.get("task_type", ConnectorTaskType.SYNC) + if task_type == ConnectorTaskType.SYNC: + SyncLogsService.schedule( + task["connector_id"], + task["kb_id"], + task.get("poll_range_start"), + task_type=ConnectorTaskType.SYNC, + ) + elif task_type == ConnectorTaskType.PRUNE and self.conf.get("sync_deleted_files"): + SyncLogsService.schedule( + task["connector_id"], + task["kb_id"], + task_type=ConnectorTaskType.PRUNE, + ) async def _run_task_logic(self, task: dict): + task_type = task.get("task_type", ConnectorTaskType.SYNC) + if task_type == ConnectorTaskType.PRUNE: + await self._run_prune_task_logic(task) + return + await self._run_sync_task_logic(task) + + async def _run_sync_task_logic(self, task: dict): """ Executes the core synchronization pipeline for a data source task. - - This method retrieves documents from the external source via the `_generate` method, - parses and upserts them into the Knowledge Base (KB), and handles stale document - reconciliation (sync deletion) if a remote snapshot (`file_list`) is provided. """ - generate_output = await self._generate(task) - # `_generate()` currently supports two outputs: - # 1. `document_batch_generator` - # 2. `(document_batch_generator, file_list)` - if isinstance(generate_output, tuple): - document_batch_generator, file_list = generate_output - else: - document_batch_generator = generate_output - file_list = None + document_batch_generator = await self._generate(task) failed_docs = 0 added_docs = 0 updated_docs = 0 - removed_docs = 0 next_update = datetime(1970, 1, 1, tzinfo=timezone.utc) source_type = f"{self.SOURCE_NAME}/{task['connector_id']}" existing_doc_ids = { @@ -252,34 +257,12 @@ class SyncBase: prefix = self._get_source_prefix() prefix = f"{prefix} " if prefix else "" next_update_info = self._format_window_boundary(next_update) - expects_deleted_file_snapshot = ( - task.get("reindex") != "1" - and task.get("poll_range_start") - and self.conf.get("sync_deleted_files") - ) - cleanup_errors = [] - if expects_deleted_file_snapshot and file_list is None: - logging.warning( - "%s deleted-file snapshot retrieval failed " - "(connector_id=%s, kb_id=%s)", - self.SOURCE_NAME, - task["connector_id"], - task["kb_id"], - ) - elif file_list is not None: - removed_docs, cleanup_errors = ConnectorService.cleanup_stale_documents_for_task( - task["id"], - task["connector_id"], - task["kb_id"], - task["tenant_id"], - file_list, - ) - total_changed_docs = added_docs + updated_docs + removed_docs + total_changed_docs = added_docs + updated_docs summary = ( f"{prefix}sync summary till {next_update_info}: " f"total={total_changed_docs}, added={added_docs}, " - f"updated={updated_docs}, deleted={removed_docs}" + f"updated={updated_docs}" ) if failed_docs > 0: summary = f"{summary}, skipped={failed_docs}" @@ -288,19 +271,80 @@ class SyncBase: if ( isinstance(self, _RDBMSBase) and failed_docs == 0 - and (not expects_deleted_file_snapshot or file_list is not None) - and not cleanup_errors ): self.connector.persist_sync_state() SyncLogsService.done(task["id"], task["connector_id"]) task["poll_range_start"] = next_update + async def _run_prune_task_logic(self, task: dict): + if not self.conf.get("sync_deleted_files"): + SyncLogsService.done(task["id"], task["connector_id"]) + return + + await self._initialize_for_prune(task) + + file_list = self._collect_prune_snapshot(task) + if file_list is None: + logging.warning( + "%s prune snapshot retrieval failed (connector_id=%s, kb_id=%s)", + self.SOURCE_NAME, + task["connector_id"], + task["kb_id"], + ) + SyncLogsService.done(task["id"], task["connector_id"]) + return + + removed_docs, cleanup_errors = ConnectorService.cleanup_stale_documents_for_task( + task["id"], + task["connector_id"], + task["kb_id"], + task["tenant_id"], + file_list, + ) + logging.info( + "%s prune summary: deleted=%s, errors=%s", + self.SOURCE_NAME, + removed_docs, + len(cleanup_errors), + ) + SyncLogsService.done(task["id"], task["connector_id"]) + async def _generate(self, task: dict): raise NotImplementedError def _get_source_prefix(self): return "" + async def _initialize_for_prune(self, task: dict): + await self._generate(task) + + def _get_prune_snapshot_kwargs(self, task: dict) -> dict[str, Any]: + return {} + + def _collect_prune_snapshot(self, task: dict): + if not getattr(self, "connector", None): + return None + if not hasattr(self.connector, "retrieve_all_slim_docs_perm_sync"): + return None + + file_list = [] + snapshot_kwargs = self._get_prune_snapshot_kwargs(task) + try: + for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(**snapshot_kwargs): + file_list.extend(slim_batch) + except TypeError: + for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): + file_list.extend(slim_batch) + except Exception: + logging.exception( + "%s prune snapshot failed (connector_id=%s, kb_id=%s)", + self.SOURCE_NAME, + task["connector_id"], + task["kb_id"], + ) + return None + return file_list + class _BlobLikeBase(SyncBase): DEFAULT_BUCKET_TYPE: str = "s3" @@ -391,7 +435,6 @@ class _BlobLikeBase(SyncBase): self.connector.set_allow_images(self.conf.get("allow_images", False)) self.connector.load_credentials(self.conf["credentials"]) - file_list = None # Fingerprint-bypass path: skip GetObject for unchanged ETags. Disabled # on full reindex (we want to re-fetch everything in that case). use_fingerprint_path = task["reindex"] != "1" @@ -400,15 +443,6 @@ class _BlobLikeBase(SyncBase): else: document_batch_generator = self.connector.load_from_state() - if ( - task["reindex"] != "1" - and task["poll_range_start"] - and self.conf.get("sync_deleted_files") - ): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - _begin_info = ( "fingerprint-bypass" if use_fingerprint_path @@ -423,7 +457,7 @@ class _BlobLikeBase(SyncBase): _begin_info, ) ) - return document_batch_generator, file_list + return document_batch_generator class S3(_BlobLikeBase): @@ -461,28 +495,11 @@ class RSS(SyncBase): return self.connector.load_from_state() end_time = datetime.now(timezone.utc).timestamp() - file_list = None - if self.conf.get("sync_deleted_files"): - logging.info( - "[RSS] Syncing deleted files via slim snapshot (connector_id=%s)", - task["connector_id"], - ) - snapshot_start = time.perf_counter() - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - logging.info( - "[RSS] Slim snapshot fetched %d docs in %.2f seconds", - len(file_list), - time.perf_counter() - snapshot_start, - ) document_generator = self.connector.poll_source( task["poll_range_start"].timestamp(), end_time, ) - if file_list is not None: - return document_generator, file_list return document_generator @@ -525,16 +542,11 @@ class Confluence(SyncBase): credential_json=self.conf["credentials"]) self.connector.set_credentials_provider(credentials_provider) - file_list = None # Determine the time range for synchronization based on reindex or poll_range_start if task["reindex"] == "1" or not task["poll_range_start"]: start_time = 0.0 else: start_time = task["poll_range_start"].timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) end_time = datetime.now(timezone.utc).timestamp() @@ -580,7 +592,7 @@ class Confluence(SyncBase): yield batch self.log_connection("Confluence", self.conf["wiki_base"], task) - return wrapper(), file_list + return wrapper() class Notion(SyncBase): @@ -589,7 +601,6 @@ class Notion(SyncBase): async def _generate(self, task: dict): self.connector = NotionConnector(root_page_id=self.conf["root_page_id"]) self.connector.load_credentials(self.conf["credentials"]) - file_list = None document_generator = ( self.connector.load_from_state() if task["reindex"] == "1" or not task["poll_range_start"] @@ -597,19 +608,10 @@ class Notion(SyncBase): datetime.now(timezone.utc).timestamp()) ) - if ( - task["reindex"] != "1" - and task["poll_range_start"] - and self.conf.get("sync_deleted_files") - ): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - _begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format( task["poll_range_start"]) self.log_connection("Notion", f"root({self.conf['root_page_id']})", task) - return document_generator, file_list + return document_generator class Discord(SyncBase): @@ -627,26 +629,17 @@ class Discord(SyncBase): batch_size=self.conf.get("batch_size", 1024), ) self.connector.load_credentials(self.conf["credentials"]) - file_list = None document_generator = ( self.connector.load_from_state() if task["reindex"] == "1" or not task["poll_range_start"] else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp()) ) - if ( - task["reindex"] != "1" - and task["poll_range_start"] - and self.conf.get("sync_deleted_files") - ): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) _begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format( task["poll_range_start"]) self.log_connection("Discord", f"servers({server_ids}), channel({channel_names})", task) - return document_generator, file_list + return document_generator class Gmail(SyncBase): @@ -685,8 +678,6 @@ class Gmail(SyncBase): task["connector_id"], ) - file_list = None - # Decide between full reindex and incremental polling by time range. if task["reindex"] == "1" or not task.get("poll_range_start"): start_time = None @@ -706,17 +697,13 @@ class Gmail(SyncBase): end_time = datetime.now(timezone.utc).timestamp() _begin_info = f"from {poll_start}" document_generator = self.connector.poll_source(start_time, end_time) - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) try: admin_email = self.connector.primary_admin_email except RuntimeError: admin_email = "unknown" self.log_connection("Gmail", f"as {admin_email}", task) - return document_generator, file_list + return document_generator class Dropbox(SyncBase): @@ -726,22 +713,16 @@ class Dropbox(SyncBase): self.connector = DropboxConnector(batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE)) self.connector.load_credentials(self.conf["credentials"]) poll_start = task["poll_range_start"] - file_list = None - if task["reindex"] == "1" or not poll_start: document_generator = self.connector.load_from_state() _begin_info = "totally" else: end_time = datetime.now(timezone.utc).timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) document_generator = self.connector.poll_source(poll_start.timestamp(), end_time) _begin_info = f"from {poll_start}" self.log_connection("Dropbox", "workspace", task) - return document_generator, file_list + return document_generator class GoogleDrive(SyncBase): @@ -775,8 +756,6 @@ class GoogleDrive(SyncBase): if new_credentials: self._persist_rotated_credentials(task["connector_id"], new_credentials) - file_list = None - # Capture end_time BEFORE the snapshot to prevent the ingestion race condition end_time = datetime.now(timezone.utc).timestamp() @@ -786,18 +765,6 @@ class GoogleDrive(SyncBase): else: start_time = task["poll_range_start"].timestamp() _begin_info = f"from {task['poll_range_start']}" - - if self.conf.get("sync_deleted_files"): - file_list = [] - SlimDoc = namedtuple('SlimDoc', ['id']) - - # Add observability timing so operators can track the O(N) cost - snapshot_start = time.perf_counter() - - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(SlimDoc(doc.id) for doc in slim_batch) - - logging.info("Slim snapshot fetched %d files in %.2f seconds", len(file_list), time.perf_counter() - snapshot_start) raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE try: @@ -843,7 +810,7 @@ class GoogleDrive(SyncBase): admin_email = "unknown" self.log_connection("Google Drive", f"as {admin_email}", task) - return document_batches(), file_list + return document_batches() def _persist_rotated_credentials(self, connector_id: str, credentials: dict[str, Any]) -> None: """Saves refreshed OAuth credentials back to the database configuration.""" @@ -886,17 +853,12 @@ class Jira(SyncBase): self.connector.load_credentials(credentials) self.connector.validate_connector_settings() - file_list = None if task["reindex"] == "1" or not task["poll_range_start"]: start_time = 0.0 _begin_info = "totally" else: start_time = task["poll_range_start"].timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) _begin_info = f"from {task['poll_range_start']}" end_time = datetime.now(timezone.utc).timestamp() @@ -955,7 +917,7 @@ class Jira(SyncBase): f"overlap_buffer_s={getattr(self.connector, 'time_buffer_seconds', connector_kwargs.get('time_buffer_seconds'))}" ), ) - return document_batches(), file_list + return document_batches() @staticmethod def _normalize_list(values: Any) -> list[str] | None: @@ -1007,25 +969,11 @@ class WebDAV(SyncBase): self.connector.set_allow_images(self.conf.get("allow_images", False)) self.connector.load_credentials(self.conf["credentials"]) - file_list = None if task["reindex"] == "1" or not task["poll_range_start"]: document_batch_generator = self.connector.load_from_state() _begin_info = "totally" else: end_ts = datetime.now(timezone.utc).timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - try: - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - except Exception: - logging.exception( - "WebDAV slim snapshot failed; continuing without stale-document cleanup " - "(connector_id=%s, kb_id=%s)", - task["connector_id"], - task["kb_id"], - ) - file_list = None document_batch_generator = self.connector.poll_source( task["poll_range_start"].timestamp(), end_ts, @@ -1038,7 +986,7 @@ class WebDAV(SyncBase): for document_batch in document_batch_generator: yield document_batch - return wrapper(), file_list + return wrapper() class Moodle(SyncBase): @@ -1054,7 +1002,6 @@ class Moodle(SyncBase): # Determine the time range for synchronization based on reindex or poll_range_start poll_start = task.get("poll_range_start") - file_list = None if task["reindex"] == "1" or poll_start is None: document_generator = self.connector.load_from_state() @@ -1066,20 +1013,6 @@ class Moodle(SyncBase): # could be polled as new and at the same time be missing from # the slim list, which would mark it as stale and delete it. end_ts = datetime.now(timezone.utc).timestamp() - - if self.conf.get("sync_deleted_files"): - file_list = [] - try: - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - except Exception: - logging.exception( - "Moodle slim snapshot failed; skipping stale-document cleanup " - "(connector_id=%s, kb_id=%s)", - task.get("connector_id"), - task.get("kb_id"), - ) - file_list = None document_generator = self.connector.poll_source( poll_start.timestamp(), end_ts, @@ -1087,7 +1020,7 @@ class Moodle(SyncBase): _begin_info = f"from {poll_start}" self.log_connection("Moodle", self.conf["moodle_url"], task) - return document_generator, file_list + return document_generator class BOX(SyncBase): @@ -1115,23 +1048,18 @@ class BOX(SyncBase): self.connector.load_credentials(auth) poll_start = task["poll_range_start"] - file_list = None if task["reindex"] == "1" or poll_start is None: document_generator = self.connector.load_from_state() _begin_info = "totally" else: - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) document_generator = self.connector.poll_source( poll_start.timestamp(), datetime.now(timezone.utc).timestamp(), ) _begin_info = f"from {poll_start}" self.log_connection("Box", f"folder_id({self.conf['folder_id']})", task) - return document_generator, file_list + return document_generator class Airtable(SyncBase): @@ -1156,16 +1084,11 @@ class Airtable(SyncBase): ) poll_start = task.get("poll_range_start") - file_list = None if task.get("reindex") == "1" or poll_start is None: document_generator = self.connector.load_from_state() _begin_info = "totally" else: - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) document_generator = self.connector.poll_source( poll_start.timestamp(), datetime.now(timezone.utc).timestamp(), @@ -1178,7 +1101,7 @@ class Airtable(SyncBase): task, ) - return document_generator, file_list + return document_generator class Asana(SyncBase): SOURCE_NAME: str = FileSource.ASANA @@ -1198,17 +1121,12 @@ class Asana(SyncBase): ) poll_start = task.get("poll_range_start") - file_list = None if task.get("reindex") == "1" or not poll_start: document_generator = self.connector.load_from_state() _begin_info = "totally" else: end_time = datetime.now(timezone.utc).timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) document_generator = self.connector.poll_source( poll_start.timestamp(), end_time, @@ -1221,7 +1139,7 @@ class Asana(SyncBase): task, ) - return document_generator, file_list + return document_generator class Github(SyncBase): SOURCE_NAME: str = FileSource.GITHUB @@ -1247,15 +1165,10 @@ class Github(SyncBase): {"github_access_token": credentials["github_access_token"]} ) - file_list = None if task.get("reindex") == "1" or not task.get("poll_range_start"): start_time = datetime.fromtimestamp(0, tz=timezone.utc) else: start_time = task.get("poll_range_start") - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) end_time = datetime.now(timezone.utc) @@ -1292,7 +1205,7 @@ class Github(SyncBase): task, ) - return wrapper(), file_list + return wrapper() class IMAP(SyncBase): SOURCE_NAME: str = FileSource.IMAP @@ -1348,27 +1261,10 @@ class IMAP(SyncBase): task["connector_id"], ) - file_list = None - if ( - task["reindex"] != "1" - and task["poll_range_start"] - and self.conf.get("sync_deleted_files") - ): - file_list = [] - try: - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync( - start=initial_sync_start, - end=end_time, - ): - file_list.extend(slim_batch) - except Exception: - logging.exception( - "IMAP slim snapshot failed; continuing without stale-document cleanup " - "(connector_id=%s, kb_id=%s)", - task["connector_id"], - task["kb_id"], - ) - file_list = None + self._prune_snapshot_kwargs = { + "start": initial_sync_start, + "end": end_time, + } raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE try: @@ -1414,7 +1310,10 @@ class IMAP(SyncBase): f"host({self.conf['imap_host']}) port({self.conf['imap_port']}) user({self.conf['credentials']['imap_username']}) folder({self.conf['imap_mailbox']})", task, ) - return wrapper(), file_list + return wrapper() + + def _get_prune_snapshot_kwargs(self, task: dict) -> dict[str, Any]: + return getattr(self, "_prune_snapshot_kwargs", {}) class Zendesk(SyncBase): @@ -1424,26 +1323,11 @@ class Zendesk(SyncBase): self.connector.load_credentials(self.conf["credentials"]) end_time = datetime.now(timezone.utc).timestamp() - file_list = None if task["reindex"] == "1" or not task.get("poll_range_start"): start_time = 0 _begin_info = "totally" else: start_time = task["poll_range_start"].timestamp() - if self.conf.get("sync_deleted_files"): - logging.info( - "[Zendesk] Syncing deleted files via slim snapshot (connector_id=%s)", - task.get("connector_id"), - ) - snapshot_start = time.perf_counter() - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - logging.info( - "[Zendesk] Slim snapshot fetched %d docs in %.2f seconds", - len(file_list), - time.perf_counter() - snapshot_start, - ) _begin_info = f"from {task['poll_range_start']}" raw_batch_size = ( @@ -1504,9 +1388,6 @@ class Zendesk(SyncBase): yield batch self.log_connection("Zendesk", f"subdomain({self.conf['credentials'].get('zendesk_subdomain')})", task) - - if file_list is not None: - return wrapper(), file_list return wrapper() @@ -1533,7 +1414,6 @@ class Gitlab(SyncBase): } ) - file_list = None if task["reindex"] == "1" or not task["poll_range_start"]: document_generator = self.connector.load_from_state() _begin_info = "totally" @@ -1547,13 +1427,9 @@ class Gitlab(SyncBase): poll_start.timestamp(), datetime.now(timezone.utc).timestamp() ) - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) _begin_info = "from {}".format(poll_start) self.log_connection("Gitlab", f"({self.conf['project_name']})", task) - return document_generator, file_list + return document_generator class Bitbucket(SyncBase): @@ -1572,17 +1448,12 @@ class Bitbucket(SyncBase): "bitbucket_api_token": self.conf["credentials"].get("bitbucket_api_token"), } ) - file_list = None if task["reindex"] == "1" or not task["poll_range_start"]: start_time = datetime.fromtimestamp(0, tz=timezone.utc) _begin_info = "totally" else: start_time = task.get("poll_range_start") - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) _begin_info = f"from {start_time}" end_time = datetime.now(timezone.utc) @@ -1614,8 +1485,6 @@ class Bitbucket(SyncBase): yield batch self.log_connection("Bitbucket", f"workspace({self.conf.get('workspace')})", task) - if file_list is not None: - return wrapper(), file_list return wrapper() @@ -1642,26 +1511,12 @@ class SeaFile(SyncBase): ) self.connector.load_credentials(conf["credentials"]) - file_list = None poll_start = task.get("poll_range_start") if task["reindex"] == "1" or poll_start is None: document_generator = self.connector.load_from_state() _begin_info = "totally" else: end_ts = datetime.now(timezone.utc).timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - try: - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - except Exception: - logging.exception( - "SeaFile slim snapshot failed; continuing without stale-document cleanup " - "(connector_id=%s, kb_id=%s)", - task["connector_id"], - task["kb_id"], - ) - file_list = None document_generator = self.connector.poll_source( poll_start.timestamp(), end_ts, @@ -1676,7 +1531,7 @@ class SeaFile(SyncBase): extra += f" path={conf.get('sync_path')}" self.log_connection("SeaFile", f"{conf['seafile_url']} (scope={scope}{extra})", task) - return document_generator, file_list + return document_generator class DingTalkAITable(SyncBase): @@ -1709,33 +1564,12 @@ class DingTalkAITable(SyncBase): ) poll_start = task.get("poll_range_start") - file_list = None if task.get("reindex") == "1" or poll_start is None: document_generator = self.connector.load_from_state() _begin_info = "totally" else: end_ts = datetime.now(timezone.utc).timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - logging.info( - "DingTalk AI Table: fetching slim snapshot for stale-document reconciliation " - "(connector_id=%s, kb_id=%s, table_id=%s)", - task["connector_id"], - task["kb_id"], - self.conf.get("table_id"), - ) - try: - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - except Exception: - logging.exception( - "DingTalk AI Table slim snapshot failed; continuing without stale-document cleanup " - "(connector_id=%s, kb_id=%s)", - task["connector_id"], - task["kb_id"], - ) - file_list = None document_generator = self.connector.poll_source( poll_start.timestamp(), end_ts, @@ -1748,7 +1582,7 @@ class DingTalkAITable(SyncBase): task, ) - return document_generator, file_list + return document_generator class _RDBMSBase(SyncBase): @@ -1778,16 +1612,6 @@ class _RDBMSBase(SyncBase): self.connector.validate_connector_settings() self.connector.prepare_sync_state(task["connector_id"], self.conf) - file_list = None - if ( - task["reindex"] != "1" - and task["poll_range_start"] - and self.conf.get("sync_deleted_files") - ): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - if task["reindex"] == "1" or not task["poll_range_start"]: document_generator = self.connector.load_from_state() _begin_info = "totally" @@ -1804,7 +1628,7 @@ class _RDBMSBase(SyncBase): _begin_info = f"from {poll_start}" self.log_connection(self.LOG_NAME, f"{self.conf.get('host')}:{self.conf.get('database')}", task) - return document_generator, file_list + return document_generator class MySQL(_RDBMSBase): @@ -1886,14 +1710,17 @@ async def dispatch_tasks(): """Polls the database for pending synchronization tasks and dispatches them concurrently.""" while True: try: - list(SyncLogsService.list_sync_tasks()[0]) + SyncLogsService.list_due_sync_tasks() + SyncLogsService.list_due_prune_tasks() break except Exception as e: logging.warning(f"DB is not ready yet: {e}") await asyncio.sleep(3) + due_sync_tasks = SyncLogsService.list_due_sync_tasks() + due_prune_tasks = SyncLogsService.list_due_prune_tasks() tasks = [] - for task in SyncLogsService.list_sync_tasks()[0]: + for task in [*due_sync_tasks, *due_prune_tasks]: if task["poll_range_start"]: task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc) if task["poll_range_end"]: diff --git a/test/testcases/restful_api/test_connector_routes_unit.py b/test/testcases/restful_api/test_connector_routes_unit.py index 33c4d7a8f1..80cd5662a6 100644 --- a/test/testcases/restful_api/test_connector_routes_unit.py +++ b/test/testcases/restful_api/test_connector_routes_unit.py @@ -205,7 +205,7 @@ def _load_connector_app(monkeypatch): return True @staticmethod - def resume(*_args, **_kwargs): + def cancel_tasks(*_args, **_kwargs): return True @staticmethod @@ -252,7 +252,11 @@ def _load_connector_app(monkeypatch): PERMISSION_ERROR=403, AUTHENTICATION_ERROR=109, ) - constants_mod.TaskStatus = SimpleNamespace(SCHEDULE="schedule", CANCEL="cancel") + constants_mod.TaskStatus = SimpleNamespace( + UNSTART="unstart", + SCHEDULE="schedule", + CANCEL="cancel", + ) monkeypatch.setitem(sys.modules, "common.constants", constants_mod) config_mod = ModuleType("common.data_source.config") @@ -349,7 +353,7 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): records = {"conn-1": _FakeConnectorRecord({"id": "conn-1", "source": "drive"})} update_calls = [] save_calls = [] - resume_calls = [] + cancel_calls = [] delete_calls = [] monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda cid, payload: update_calls.append((cid, payload))) @@ -362,7 +366,7 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, records[cid])) monkeypatch.setattr(module.ConnectorService, "list", lambda tenant_id: [{"id": "listed", "tenant": tenant_id}]) monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda cid, page, page_size: ([{"id": "log-1"}], 9)) - monkeypatch.setattr(module.ConnectorService, "resume", lambda cid, status: resume_calls.append((cid, status))) + monkeypatch.setattr(module.ConnectorService, "cancel_tasks", lambda cid: cancel_calls.append(cid)) monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda cid: delete_calls.append(cid)) monkeypatch.setattr(module, "get_uuid", lambda: "generated-id") @@ -384,6 +388,7 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): assert save_calls[-1]["id"] == "generated-id" assert save_calls[-1]["tenant_id"] == "tenant-1" assert save_calls[-1]["input_type"] == module.InputType.POLL + assert save_calls[-1]["status"] == module.TaskStatus.UNSTART assert res["data"]["id"] == "generated-id" list_res = module.list_connector() @@ -401,14 +406,6 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): logs_res = module.list_logs("conn-log") assert logs_res["data"] == {"total": 9, "logs": [{"id": "log-1"}]} - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": True})) - assert _run(module.resume("conn-r1"))["data"] is True - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": False})) - assert _run(module.resume("conn-r2"))["data"] is True - assert ("conn-r1", module.TaskStatus.SCHEDULE) in resume_calls - assert ("conn-r2", module.TaskStatus.CANCEL) in resume_calls - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"kb_id": "kb-1"})) monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: "rebuild-failed") failed_rebuild = _run(module.rebuild("conn-rb")) @@ -421,7 +418,7 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): rm_res = module.rm_connector("conn-rm") assert rm_res["data"] is True - assert ("conn-rm", module.TaskStatus.CANCEL) in resume_calls + assert cancel_calls == ["conn-rm"] assert delete_calls == ["conn-rm"] @@ -434,14 +431,14 @@ def test_connector_by_id_routes_reject_cross_tenant_access(monkeypatch): monkeypatch.setattr(module.ConnectorService, "accessible", lambda cid, uid: False) monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda *_args: touched.append("get_by_id")) monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda *_args: touched.append("list_sync_tasks")) - monkeypatch.setattr(module.ConnectorService, "resume", lambda *_args: touched.append("resume")) + monkeypatch.setattr(module.ConnectorService, "cancel_tasks", lambda *_args: touched.append("cancel_tasks")) monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda *_args: touched.append("delete_by_id")) monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda *_args: touched.append("update_by_id")) monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: touched.append("rebuild")) def _get_request_json(): touched.append("get_request_json") - return _AwaitableValue({"resume": True, "config": {"x": 1}}) + return _AwaitableValue({"config": {"x": 1}}) monkeypatch.setattr(module, "get_request_json", _get_request_json) @@ -449,7 +446,6 @@ def test_connector_by_id_routes_reject_cross_tenant_access(monkeypatch): _run(module.update_connector("conn-victim")), module.get_connector("conn-victim"), module.list_logs("conn-victim"), - _run(module.resume("conn-victim")), _run(module.rebuild("conn-victim")), module.rm_connector("conn-victim"), _run(module.test_connector("conn-victim")), diff --git a/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py b/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py index 3807fb8e15..605ec415f1 100644 --- a/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py +++ b/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py @@ -205,7 +205,7 @@ def _load_connector_app(monkeypatch): return True @staticmethod - def resume(*_args, **_kwargs): + def cancel_tasks(*_args, **_kwargs): return True @staticmethod @@ -349,7 +349,7 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): records = {"conn-1": _FakeConnectorRecord({"id": "conn-1", "source": "drive"})} update_calls = [] save_calls = [] - resume_calls = [] + cancel_calls = [] delete_calls = [] monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda cid, payload: update_calls.append((cid, payload))) @@ -362,7 +362,7 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, records[cid])) monkeypatch.setattr(module.ConnectorService, "list", lambda tenant_id: [{"id": "listed", "tenant": tenant_id}]) monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda cid, page, page_size: ([{"id": "log-1"}], 9)) - monkeypatch.setattr(module.ConnectorService, "resume", lambda cid, status: resume_calls.append((cid, status))) + monkeypatch.setattr(module.ConnectorService, "cancel_tasks", lambda cid: cancel_calls.append(cid)) monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda cid: delete_calls.append(cid)) monkeypatch.setattr(module, "get_uuid", lambda: "generated-id") @@ -401,14 +401,6 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): logs_res = module.list_logs("conn-log") assert logs_res["data"] == {"total": 9, "logs": [{"id": "log-1"}]} - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": True})) - assert _run(module.resume("conn-r1"))["data"] is True - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": False})) - assert _run(module.resume("conn-r2"))["data"] is True - assert ("conn-r1", module.TaskStatus.SCHEDULE) in resume_calls - assert ("conn-r2", module.TaskStatus.CANCEL) in resume_calls - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"kb_id": "kb-1"})) monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: "rebuild-failed") failed_rebuild = _run(module.rebuild("conn-rb")) @@ -421,7 +413,7 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): rm_res = module.rm_connector("conn-rm") assert rm_res["data"] is True - assert ("conn-rm", module.TaskStatus.CANCEL) in resume_calls + assert cancel_calls == ["conn-rm"] assert delete_calls == ["conn-rm"] @@ -434,14 +426,14 @@ def test_connector_by_id_routes_reject_cross_tenant_access(monkeypatch): monkeypatch.setattr(module.ConnectorService, "accessible", lambda cid, uid: False) monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda *_args: touched.append("get_by_id")) monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda *_args: touched.append("list_sync_tasks")) - monkeypatch.setattr(module.ConnectorService, "resume", lambda *_args: touched.append("resume")) + monkeypatch.setattr(module.ConnectorService, "cancel_tasks", lambda *_args: touched.append("cancel_tasks")) monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda *_args: touched.append("delete_by_id")) monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda *_args: touched.append("update_by_id")) monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: touched.append("rebuild")) def _get_request_json(): touched.append("get_request_json") - return _AwaitableValue({"resume": True, "config": {"x": 1}}) + return _AwaitableValue({"config": {"x": 1}}) monkeypatch.setattr(module, "get_request_json", _get_request_json) @@ -449,7 +441,6 @@ def test_connector_by_id_routes_reject_cross_tenant_access(monkeypatch): _run(module.update_connector("conn-victim")), module.get_connector("conn-victim"), module.list_logs("conn-victim"), - _run(module.resume("conn-victim")), _run(module.rebuild("conn-victim")), module.rm_connector("conn-victim"), _run(module.test_connector("conn-victim")), diff --git a/test/unit_test/rag/test_sync_data_source.py b/test/unit_test/rag/test_sync_data_source.py index be9d89372a..8bb5e4cd43 100644 --- a/test/unit_test/rag/test_sync_data_source.py +++ b/test/unit_test/rag/test_sync_data_source.py @@ -133,7 +133,53 @@ def _patch_common_dependencies(monkeypatch): @pytest.mark.anyio @pytest.mark.p2 -async def test_run_task_logic_cleans_up_for_empty_snapshot(monkeypatch): +async def test_run_task_logic_skips_empty_sync_batches(monkeypatch): + _patch_common_dependencies(monkeypatch) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "increase_docs", + lambda *_args, **_kwargs: pytest.fail("increase_docs should not be called for empty batches"), + ) + monkeypatch.setattr( + sync_data_source.KnowledgebaseService, + "get_by_id", + lambda *_args, **_kwargs: pytest.fail("get_by_id should not be called for empty batches"), + ) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "duplicate_and_parse", + lambda *_args, **_kwargs: pytest.fail("duplicate_and_parse should not be called for empty batches"), + ) + + await _FakeSync(iter(([],)))._run_task_logic(_make_task()) + + +@pytest.mark.anyio +@pytest.mark.p2 +async def test_run_task_logic_skips_multiple_empty_sync_batches(monkeypatch): + _patch_common_dependencies(monkeypatch) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "increase_docs", + lambda *_args, **_kwargs: pytest.fail("increase_docs should not be called for empty batches"), + ) + monkeypatch.setattr( + sync_data_source.KnowledgebaseService, + "get_by_id", + lambda *_args, **_kwargs: pytest.fail("get_by_id should not be called for empty batches"), + ) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "duplicate_and_parse", + lambda *_args, **_kwargs: pytest.fail("duplicate_and_parse should not be called for empty batches"), + ) + + await _FakeSync(iter(([], [],)))._run_task_logic(_make_task()) + + +@pytest.mark.anyio +@pytest.mark.p2 +async def test_run_prune_task_logic_cleans_up_for_empty_snapshot(monkeypatch): cleanup_calls = [] _patch_common_dependencies(monkeypatch) @@ -148,7 +194,14 @@ async def test_run_task_logic_cleans_up_for_empty_snapshot(monkeypatch): _fake_cleanup, ) - await _FakeSync((iter(()), []))._run_task_logic(_make_task()) + task = {**_make_task(), "task_type": sync_data_source.ConnectorTaskType.PRUNE} + sync = _FakeSync(iter(())) + sync.conf["sync_deleted_files"] = True + sync.connector = types.SimpleNamespace( + retrieve_all_slim_docs_perm_sync=lambda: iter(([],)) + ) + + await sync._run_task_logic(task) assert cleanup_calls == [ ( @@ -166,7 +219,7 @@ async def test_run_task_logic_cleans_up_for_empty_snapshot(monkeypatch): @pytest.mark.anyio @pytest.mark.p2 -async def test_run_task_logic_cleans_up_for_non_empty_snapshot(monkeypatch): +async def test_run_prune_task_logic_cleans_up_for_non_empty_snapshot(monkeypatch): cleanup_calls = [] _patch_common_dependencies(monkeypatch) @@ -182,7 +235,14 @@ async def test_run_task_logic_cleans_up_for_non_empty_snapshot(monkeypatch): ) file_list = [types.SimpleNamespace(id="doc-1")] - await _FakeSync((iter(()), file_list))._run_task_logic(_make_task()) + task = {**_make_task(), "task_type": sync_data_source.ConnectorTaskType.PRUNE} + sync = _FakeSync(iter(())) + sync.conf["sync_deleted_files"] = True + sync.connector = types.SimpleNamespace( + retrieve_all_slim_docs_perm_sync=lambda: iter((file_list,)) + ) + + await sync._run_task_logic(task) assert cleanup_calls == [ ( @@ -285,12 +345,13 @@ async def test_rdbms_generate_keeps_deleted_file_snapshot_without_timestamp_colu } ) - document_generator, file_list = await sync._generate(task) + document_generator = await sync._generate(task) connector = _FakeRDBMSConnector.instance assert connector is not None assert connector.load_from_state_called is True assert connector.load_from_cursor_range_called is False + file_list = sync._collect_prune_snapshot(task) assert connector.retrieve_all_slim_docs_perm_sync_called is True assert file_list is not None assert [doc.id for doc in file_list] == ["row-1"] @@ -447,14 +508,15 @@ async def test_dropbox_generate_returns_snapshot_when_sync_deleted_enabled(monke } ) - document_generator, file_list = await sync._generate(task) + document_generator = await sync._generate(task) connector = _FakeDropboxConnector.instance assert list(document_generator) == [["poll-sync"]] + file_list = sync._collect_prune_snapshot(task) assert [doc.id for doc in file_list] == ["dropbox:id-1", "dropbox:id-2"] assert connector.credentials == {"dropbox_access_token": "token-1"} assert connector.retrieve_all_slim_docs_perm_sync_called is True - assert connector.snapshot_called_before_poll is True + assert connector.snapshot_called_before_poll is False assert connector.poll_source_call[0] == poll_start.timestamp() assert connector.poll_source_call[1] >= poll_start.timestamp() @@ -477,11 +539,12 @@ async def test_dropbox_generate_skips_snapshot_for_full_reindex(monkeypatch): } ) - document_generator, file_list = await sync._generate(task) + document_generator = await sync._generate(task) connector = _FakeDropboxConnector.instance assert list(document_generator) == [["full-sync"]] - assert file_list is None assert connector.load_from_state_called is True - assert connector.retrieve_all_slim_docs_perm_sync_called is False + file_list = sync._collect_prune_snapshot(task) + assert [doc.id for doc in file_list] == ["dropbox:id-1", "dropbox:id-2"] + assert connector.retrieve_all_slim_docs_perm_sync_called is True assert connector.poll_source_called is False diff --git a/web/src/components/dynamic-form.tsx b/web/src/components/dynamic-form.tsx index 0ef13df1c1..0920e2422e 100644 --- a/web/src/components/dynamic-form.tsx +++ b/web/src/components/dynamic-form.tsx @@ -111,10 +111,12 @@ interface DynamicFormProps { // Form ref interface export interface DynamicFormRef { submit: () => void; + isDirty: () => boolean; getValues: (name?: string) => any; reset: (values?: any) => void; trigger: UseFormTrigger; watch: (field: string, callback: (value: any) => void) => () => void; + watchDirty: (callback: (isDirty: boolean, values: any) => void) => () => void; updateFieldType: (fieldName: string, newType: FormFieldType) => void; onFieldUpdate: ( fieldName: string, @@ -809,6 +811,7 @@ const DynamicForm = { onSubmit(filteredValues); })(); }, + isDirty: () => form.formState.isDirty, getValues: form.getValues, reset: (values?: T) => { if (values) { @@ -828,6 +831,12 @@ const DynamicForm = { }); return unsubscribe; }, + watchDirty: (callback: (isDirty: boolean, values: any) => void) => { + const { unsubscribe } = form.watch((values: any) => { + callback(form.formState.isDirty, values); + }); + return unsubscribe; + }, onFieldUpdate: ( fieldName: string, diff --git a/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx b/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx index dfeb7e0830..a55c2af8ee 100644 --- a/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx +++ b/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx @@ -9,9 +9,9 @@ import { Button } from '@/components/ui/button'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { Input } from '@/components/ui/input'; import { Separator } from '@/components/ui/separator'; -import { RunningStatus } from '@/constants/knowledge'; +import { RunningStatus, RunningStatusOld } from '@/constants/knowledge'; import { t } from 'i18next'; -import { CirclePause, Repeat } from 'lucide-react'; +import { isEqual } from 'lodash'; import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { FieldValues } from 'react-hook-form'; import { @@ -25,9 +25,9 @@ import { } from '../constant'; import { useAddDataSource, - useDataSourceResume, useFetchDataSourceDetail, useTestDataSource, + useUpdateDataSourceStatus, } from '../hooks'; import { DataSourceLogsTable } from './log-table'; @@ -35,7 +35,8 @@ const SourceDetailPage = () => { const formRef = useRef(null); const { data: detail } = useFetchDataSourceDetail(); - const { handleResume } = useDataSourceResume(); + const { updateStatus, loading: statusUpdateLoading } = + useUpdateDataSourceStatus(); const { dataSourceInfo } = useDataSourceInfo(); const detailInfo = useMemo(() => { if (detail) { @@ -44,83 +45,52 @@ const SourceDetailPage = () => { }, [detail, dataSourceInfo]); const [fields, setFields] = useState([]); + const [isDirty, setIsDirty] = useState(false); const [defaultValues, setDefaultValues] = useState( DataSourceFormDefaultValues[ detail?.source as keyof typeof DataSourceFormDefaultValues ] as FieldValues, ); - const runSchedule = useCallback(() => { - handleResume({ - resume: - detail?.status === RunningStatus.RUNNING || - detail?.status === RunningStatus.SCHEDULE - ? false - : true, - }); - }, [detail, handleResume]); - const customFields = useMemo(() => { return [ + { + label: 'Prune Freq', + name: 'prune_freq', + type: FormFieldType.Number, + required: false, + shouldRender: (values: any) => !!values?.config?.sync_deleted_files, + render: (fieldProps: FormFieldConfig) => { + return ( + + {t('setting.minutes')} + + } + /> + ); + }, + }, { label: 'Refresh Freq', name: 'refresh_freq', type: FormFieldType.Number, required: false, render: (fieldProps: FormFieldConfig) => ( -
-
- - {t('setting.minutes')} - - } - /> -
- -
+ + {t('setting.minutes')} + + } + /> ), }, - { - label: 'Prune Freq', - name: 'prune_freq', - type: FormFieldType.Number, - required: false, - hidden: true, - render: (fieldProps: FormFieldConfig) => { - return ( -
-
- - hours - - } - /> -
-
- ); - }, - }, { label: 'Timeout Secs', name: 'timeout_secs', @@ -143,7 +113,7 @@ const SourceDetailPage = () => { ), }, ]; - }, [detail, runSchedule]); + }, []); const { addLoading, handleAddOk } = useAddDataSource({ isEdit: true }); const { loading: testLoading, handleTest } = useTestDataSource(); @@ -152,6 +122,54 @@ const SourceDetailPage = () => { formRef?.current?.submit(); }, []); + const isUnstarted = useMemo( + () => + detail?.status === RunningStatus.UNSTART || + detail?.status === RunningStatusOld.UNSTART, + [detail?.status], + ); + + const isConnectorActive = useMemo( + () => + detail?.status === RunningStatus.RUNNING || + detail?.status === RunningStatus.SCHEDULE || + detail?.status === RunningStatusOld.RUNNING || + detail?.status === RunningStatusOld.SCHEDULE, + [detail?.status], + ); + + const actionMode = useMemo(() => { + if (isDirty) { + return 'save' as const; + } + + if (isUnstarted) { + return 'save' as const; + } + + if (isConnectorActive) { + return 'stop' as const; + } + + return 'resume' as const; + }, [isConnectorActive, isDirty, isUnstarted]); + + const handlePrimaryAction = useCallback(() => { + if (actionMode === 'save') { + onSubmit(); + return; + } + updateStatus( + actionMode === 'resume' ? RunningStatus.SCHEDULE : RunningStatus.CANCEL, + ); + }, [actionMode, onSubmit, updateStatus]); + + const primaryActionLabel = useMemo(() => { + if (actionMode === 'stop') return 'Stop'; + if (actionMode === 'resume') return 'Resume'; + return 'Save'; + }, [actionMode]); + useEffect(() => { const baseFields = DataSourceFormBaseFields.map((field) => { if (field.name === 'name') { @@ -191,9 +209,20 @@ const SourceDetailPage = () => { ), }; setDefaultValues(defaultValueTemp); + setIsDirty(false); } }, [detail, customFields, onSubmit]); + useEffect(() => { + const instance = formRef.current; + if (!instance) return; + + setIsDirty(!isEqual(instance.getValues(), defaultValues)); + return instance.watchDirty((_nextIsDirty, values) => { + setIsDirty(!isEqual(values, defaultValues)); + }); + }, [defaultValues, fields]); + return (
@@ -229,22 +258,21 @@ const SourceDetailPage = () => { )}
{t('setting.log')}
- +
diff --git a/web/src/pages/user-setting/data-source/data-source-detail-page/log-table.tsx b/web/src/pages/user-setting/data-source/data-source-detail-page/log-table.tsx index 33bef3d902..c00301b793 100644 --- a/web/src/pages/user-setting/data-source/data-source-detail-page/log-table.tsx +++ b/web/src/pages/user-setting/data-source/data-source-detail-page/log-table.tsx @@ -1,6 +1,5 @@ import FileStatusBadge from '@/components/file-status-badge'; import { RAGFlowAvatar } from '@/components/ragflow-avatar'; -import { Button } from '@/components/ui/button'; import { RAGFlowPagination } from '@/components/ui/ragflow-pagination'; import { Table, @@ -14,11 +13,6 @@ import { RunningStatusMap } from '@/constants/knowledge'; import { RunningStatus } from '@/pages/dataset/dataset/constant'; import { Routes } from '@/routes'; import { formatDate } from '@/utils/date'; -import { - HoverCard, - HoverCardContent, - HoverCardTrigger, -} from '@radix-ui/react-hover-card'; import { ColumnDef, flexRender, @@ -30,15 +24,86 @@ import { } from '@tanstack/react-table'; import { t } from 'i18next'; import { pick } from 'lodash'; -import { Eye } from 'lucide-react'; -import { useCallback, useMemo } from 'react'; +import { useCallback, useEffect, useMemo, useState } from 'react'; import { useNavigate } from 'react-router'; import { useLogListDataSource } from '../hooks'; +import { IDataSourceLog } from '../interface'; + +const formatDuration = (seconds: number) => { + const safeSeconds = Math.max(0, seconds); + const hours = Math.floor(safeSeconds / 3600); + const minutes = Math.floor((safeSeconds % 3600) / 60); + const remainingSeconds = safeSeconds % 60; + + if (hours > 0) { + return `${hours}h ${minutes}m ${remainingSeconds}s`; + } + if (minutes > 0) { + return `${minutes}m ${remainingSeconds}s`; + } + return `${remainingSeconds}s`; +}; + +const getTaskCountdownSeconds = (row: IDataSourceLog, now: number) => { + const freqMinutes = + row.task_type === 'prune' + ? Number(row.prune_freq || 0) + : Number(row.refresh_freq || 0); + const scheduledAt = row.time_started + ? new Date(row.time_started).getTime() + : 0; + + if (!freqMinutes || !scheduledAt) { + return null; + } + + const nextRunAt = scheduledAt + freqMinutes * 60 * 1000; + return Math.ceil((nextRunAt - now) / 1000); +}; + +const TaskCountdown = ({ row, now }: { row: IDataSourceLog; now: number }) => { + const remainingSeconds = getTaskCountdownSeconds(row, now); + + if (remainingSeconds === null) { + return ''; + } + + return Task starts in {formatDuration(remainingSeconds)}; +}; + +const getSummary = (row: IDataSourceLog, now: number) => { + if (row.status === RunningStatus.SCHEDULE || row.status === '5') { + return ; + } + + if (row.status === RunningStatus.RUNNING || row.status === '1') { + return row.task_type === 'prune' ? 'Prune in progress' : 'Sync in progress'; + } + + if (row.status === RunningStatus.FAIL || row.status === '4') { + return row.error_msg || 'Task failed'; + } + + if (row.status === RunningStatus.CANCEL || row.status === '2') { + return ''; + } + + if (row.task_type === 'prune') { + return `deleted=${row.docs_removed_from_index || 0}, error=${row.error_count || 0}`; + } + + return `total=${row.total_docs_indexed || 0}, added=${row.new_docs_indexed || 0}, updated=${Math.max( + 0, + (row.total_docs_indexed || 0) - (row.new_docs_indexed || 0), + )}, error=${row.error_count || 0}`; +}; const columns = ({ handleToDataSetDetail, + now, }: { handleToDataSetDetail: (id: string) => void; + now: number; }) => { return [ { @@ -71,7 +136,6 @@ const columns = ({
{ - console.log('handleToDataSetDetail', row.original.kb_id); handleToDataSetDetail(row.original.kb_id); }} > @@ -86,39 +150,16 @@ const columns = ({ }, }, { - accessorKey: 'new_docs_indexed', - header: t('setting.newDocs'), + accessorKey: 'task_type', + header: 'Task Type', + cell: ({ row }) => row.original.task_type || 'sync', }, - { - id: 'operations', - header: t('setting.errorMsg'), + id: 'summary', + header: 'Summary', cell: ({ row }) => ( -
- {row.original.error_msg} - {row.original.error_msg && ( -
- - - - - -
- {row.original.full_exception_trace} -
-
-
-
- )} +
+ {getSummary(row.original as IDataSourceLog, now)}
), }, @@ -131,14 +172,22 @@ const columns = ({ // total: 0, // }; export const DataSourceLogsTable = ({ - refresh_freq, + autoRefresh, }: { - refresh_freq: number | false; + autoRefresh: boolean; }) => { - // const [pagination, setPagination] = useState(paginationInit); - const { data, pagination, setPagination } = - useLogListDataSource(refresh_freq); + const { data, pagination, setPagination } = useLogListDataSource(autoRefresh); const navigate = useNavigate(); + const [now, setNow] = useState(() => Date.now()); + + useEffect(() => { + const timer = window.setInterval(() => { + setNow(Date.now()); + }, 1000); + + return () => window.clearInterval(timer); + }, []); + const currentPagination = useMemo( () => ({ pageIndex: (pagination.current || 1) - 1, @@ -149,15 +198,14 @@ export const DataSourceLogsTable = ({ const handleToDataSetDetail = useCallback( (id: string) => { - console.log('handleToDataSetDetail', id); - navigate(`${Routes.DatasetBase}${Routes.DatasetBase}/${id}`); + navigate(`${Routes.Dataset}/${id}`); }, [navigate], ); const table = useReactTable({ data: data || [], - columns: columns({ handleToDataSetDetail }), + columns: columns({ handleToDataSetDetail, now }), manualPagination: true, getCoreRowModel: getCoreRowModel(), getPaginationRowModel: getPaginationRowModel(), diff --git a/web/src/pages/user-setting/data-source/hooks.ts b/web/src/pages/user-setting/data-source/hooks.ts index 686da32865..1fe074e794 100644 --- a/web/src/pages/user-setting/data-source/hooks.ts +++ b/web/src/pages/user-setting/data-source/hooks.ts @@ -1,9 +1,9 @@ import message from '@/components/ui/message'; +import { RunningStatus } from '@/constants/knowledge'; import { useSetModalState } from '@/hooks/common-hooks'; import { useGetPaginationWithRouter } from '@/hooks/logic-hooks'; import dataSourceService, { dataSourceRebuild, - dataSourceResume, dataSourceUpdate, deleteDataSource, featchDataSourceDetail, @@ -15,7 +15,12 @@ import { t } from 'i18next'; import { useCallback, useMemo, useState } from 'react'; import { useParams, useSearchParams } from 'react-router'; import { DataSourceKey, useDataSourceInfo } from './constant'; -import { IDataSorceInfo, IDataSource, IDataSourceBase } from './interface'; +import { + IDataSorceInfo, + IDataSource, + IDataSourceBase, + IDataSourceLog, +} from './interface'; export const useListDataSource = () => { const { dataSourceInfo } = useDataSourceInfo(); @@ -28,10 +33,8 @@ export const useListDataSource = () => { }); const categorizeDataBySource = (data: IDataSourceBase[]) => { - const categorizedData: Record = {} as Record< - DataSourceKey, - any[] - >; + const categorizedData: Partial> = + {}; data.forEach((item) => { const source = item.source; @@ -93,17 +96,29 @@ export const useAddDataSource = ({ isEdit = false }: { isEdit?: boolean }) => { async (data: any) => { setAddLoading(true); const { data: res } = isEdit - ? await dataSourceUpdate(data.id, data) + ? await dataSourceUpdate(data.id, { + ...data, + reschedule: true, + }) : await dataSourceService.dataSourceSet(data); console.log('🚀 ~ handleAddOk ~ code:', res.code); if (res.code === 0) { + if (isEdit && res.data?.id) { + queryClient.setQueryData( + ['data-source-detail', res.data.id], + res.data, + ); + queryClient.invalidateQueries({ + queryKey: ['data-source-detail', res.data.id], + }); + } queryClient.invalidateQueries({ queryKey: ['data-source'] }); message.success(t(`message.operated`)); hideAddingModal(); } setAddLoading(false); }, - [hideAddingModal, queryClient], + [hideAddingModal, isEdit, queryClient], ); return { @@ -117,24 +132,25 @@ export const useAddDataSource = ({ isEdit = false }: { isEdit?: boolean }) => { }; }; -export const useLogListDataSource = (refresh_freq: number | false) => { +export const useLogListDataSource = (autoRefresh: boolean) => { const { pagination, setPagination } = useGetPaginationWithRouter(); const [currentQueryParameters] = useSearchParams(); const id = currentQueryParameters.get('id'); - const { data, isFetching } = useQuery<{ logs: IDataSource[]; total: number }>( - { - queryKey: ['data-source-logs', id, pagination, refresh_freq], - refetchInterval: refresh_freq ? refresh_freq * 60 * 1000 : false, - queryFn: async () => { - const { data } = await getDataSourceLogs(id as string, { - page_size: pagination.pageSize, - page: pagination.current, - }); - return data.data; - }, + const { data, isFetching } = useQuery<{ + logs: IDataSourceLog[]; + total: number; + }>({ + queryKey: ['data-source-logs', id, pagination, autoRefresh], + refetchInterval: autoRefresh ? 15 * 1000 : false, + queryFn: async () => { + const { data } = await getDataSourceLogs(id as string, { + page_size: pagination.pageSize, + page: pagination.current, + }); + return data.data; }, - ); + }); return { data: data?.logs, isFetching, @@ -179,21 +195,49 @@ export const useFetchDataSourceDetail = () => { return { data }; }; -export const useDataSourceResume = () => { +export const useUpdateDataSourceStatus = () => { const [currentQueryParameters] = useSearchParams(); const id = currentQueryParameters.get('id'); const queryClient = useQueryClient(); - const handleResume = useCallback( - async (param: { resume: boolean }) => { - const { data } = await dataSourceResume(id as string, param); - if (data.code === 0) { - queryClient.invalidateQueries({ queryKey: ['data-source-detail', id] }); - message.success(t(`message.operated`)); + const [loading, setLoading] = useState(false); + const updateStatus = useCallback( + async (status: RunningStatus.SCHEDULE | RunningStatus.CANCEL) => { + if (!id) return; + + setLoading(true); + try { + const { data } = await dataSourceUpdate(id, { + status, + }); + if (data.code === 0) { + queryClient.setQueryData( + ['data-source-detail', id], + (previous?: IDataSource) => ({ + ...(previous || {}), + ...(data.data || {}), + status: data.data?.status ?? status, + }), + ); + + await Promise.all([ + queryClient.invalidateQueries({ + queryKey: ['data-source-detail', id], + }), + queryClient.invalidateQueries({ queryKey: ['data-source'] }), + queryClient.invalidateQueries({ + queryKey: ['data-source-logs', id], + }), + ]); + + message.success(t(`message.operated`)); + } + } finally { + setLoading(false); } }, [id, queryClient], ); - return { handleResume }; + return { updateStatus, loading }; }; export const useDataSourceRebuild = () => { diff --git a/web/src/pages/user-setting/data-source/interface.ts b/web/src/pages/user-setting/data-source/interface.ts index 5cca997487..812a2f8e0d 100644 --- a/web/src/pages/user-setting/data-source/interface.ts +++ b/web/src/pages/user-setting/data-source/interface.ts @@ -1,5 +1,5 @@ import { RunningStatus } from '@/constants/knowledge'; -import { DataSourceKey } from './contant'; +import { DataSourceKey } from './constant'; export interface IDataSorceInfo { id: DataSourceKey; @@ -28,20 +28,20 @@ export interface IDataSourceBase { export interface IDataSourceLog { connector_id: string; + docs_removed_from_index?: number; error_count: number; error_msg: string; id: string; kb_id: string; kb_name: string; - name: string; new_docs_indexed: number; - poll_range_end: null | string; - poll_range_start: null | string; - reindex: string; - source: DataSourceKey; + prune_freq?: number; + refresh_freq?: number; status: RunningStatus; - tenant_id: string; - timeout_secs: number; + task_type?: string; + time_started?: string | null; + total_docs_indexed?: number; + update_date: string; } interface IDataSourceInfoItem { diff --git a/web/src/services/data-source-service.ts b/web/src/services/data-source-service.ts index 7be85dce85..2118899340 100644 --- a/web/src/services/data-source-service.ts +++ b/web/src/services/data-source-service.ts @@ -20,15 +20,12 @@ const dataSourceService = registerServer( export const deleteDataSource = (id: string) => request.delete(api.dataSourceDel(id)); -export const dataSourceResume = (id: string, data: { resume: boolean }) => { - return request.post(api.dataSourceResume(id), { data }); -}; export const dataSourceRebuild = (id: string, data: { kb_id: string }) => { return request.post(api.dataSourceRebuild(id), { data }); }; -export const dataSourceUpdate = (id: string, data: { kb_id: string }) => { +export const dataSourceUpdate = (id: string, data: Record) => { return request.patch(api.dataSourceUpdate(id), { data }); }; diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts index 32a3d5bfd6..03b065cffe 100644 --- a/web/src/utils/api.ts +++ b/web/src/utils/api.ts @@ -39,7 +39,6 @@ export default { dataSourceSet: `${restAPIv1}/connectors`, dataSourceList: `${restAPIv1}/connectors`, dataSourceDel: (id: string) => `${restAPIv1}/connectors/${id}`, - dataSourceResume: (id: string) => `${restAPIv1}/connectors/${id}/resume`, dataSourceRebuild: (id: string) => `${restAPIv1}/connectors/${id}/rebuild`, dataSourceLogs: (id: string) => `${restAPIv1}/connectors/${id}/logs`, dataSourceDetail: (id: string) => `${restAPIv1}/connectors/${id}`,