diff --git a/common/data_source/rdbms_connector.py b/common/data_source/rdbms_connector.py index 05628501c6..9811d2064d 100644 --- a/common/data_source/rdbms_connector.py +++ b/common/data_source/rdbms_connector.py @@ -1,5 +1,6 @@ """RDBMS (MySQL/PostgreSQL) data source connector for importing data from relational databases.""" +import copy import hashlib import json import logging @@ -12,8 +13,13 @@ from common.data_source.exceptions import ( ConnectorMissingCredentialError, ConnectorValidationError, ) -from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch -from common.data_source.models import Document +from common.data_source.interfaces import ( + LoadConnector, + PollConnector, + SecondsSinceUnixEpoch, + SlimConnectorWithPermSync, +) +from common.data_source.models import Document, SlimDocument class DatabaseType(str, Enum): @@ -22,15 +28,18 @@ class DatabaseType(str, Enum): POSTGRESQL = "postgresql" -class RDBMSConnector(LoadConnector, PollConnector): +class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """ - RDBMS connector for importing data from MySQL and PostgreSQL databases. - - This connector allows users to: - 1. Connect to a MySQL or PostgreSQL database - 2. Execute a SQL query to extract data - 3. Map columns to content (for vectorization) and metadata - 4. Sync data in batch or incremental mode using a timestamp column + Import rows from MySQL or PostgreSQL into documents. + + The flow is: + 1. Connect to the configured database. + 2. Read rows from a custom SQL query, or from every table when no query is provided. + 3. Build document content from the selected content columns. + 4. Copy the selected metadata columns into document metadata. + 5. Use the configured ID column as the stable document ID, or hash the content when no ID column is set. + 6. For incremental sync, treat the timestamp column as an ordered cursor and only compare values by size. + 7. For deleted-file sync, read a slim snapshot of current row IDs and let the sync worker remove stale documents. """ def __init__( self, @@ -73,6 +82,9 @@ class RDBMSConnector(LoadConnector, PollConnector): self._connection = None self._credentials: Dict[str, Any] = {} + self._sync_connector_id: str | None = None + self._sync_config: Dict[str, Any] | None = None + self._pending_sync_cursor_value: Any = None def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None: """Load database credentials.""" @@ -160,98 +172,175 @@ class RDBMSConnector(LoadConnector, PollConnector): finally: cursor.close() - def _build_query_with_time_filter( - self, - start: Optional[datetime] = None, - end: Optional[datetime] = None, - ) -> str: - """Build the query with optional time filtering for incremental sync.""" - if not self.query: - return "" # Will be handled by table discovery - base_query = self.query.rstrip(";") - - if not self.timestamp_column or (start is None and end is None): - return base_query - - has_where = "where" in base_query.lower() - connector = " AND" if has_where else " WHERE" - - time_conditions = [] - if start is not None: - if self.db_type == DatabaseType.MYSQL: - time_conditions.append(f"{self.timestamp_column} > '{start.strftime('%Y-%m-%d %H:%M:%S')}'") - else: - time_conditions.append(f"{self.timestamp_column} > '{start.isoformat()}'") - - if end is not None: - if self.db_type == DatabaseType.MYSQL: - time_conditions.append(f"{self.timestamp_column} <= '{end.strftime('%Y-%m-%d %H:%M:%S')}'") - else: - time_conditions.append(f"{self.timestamp_column} <= '{end.isoformat()}'") - - if time_conditions: - return f"{base_query}{connector} {' AND '.join(time_conditions)}" - - return base_query - def _row_to_document(self, row: Union[tuple, list, Dict[str, Any]], column_names: list) -> Document: - """Convert a database row to a Document.""" - row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row - + def _get_base_queries(self) -> list[str]: + if self.query: + return [self.query.rstrip(";")] + return [f"SELECT * FROM {table}" for table in self._get_tables()] + + + def _wrap_query(self, base_query: str, select_clause: str = "*") -> str: + return f"SELECT {select_clause} FROM ({base_query}) AS ragflow_src" + + + @staticmethod + def serialize_cursor_value(value: Any) -> Any: + # Example: + # - int cursor 42 is stored as 42 + # - datetime cursor 2026-05-07T12:34:56+00:00 is stored as + # {"__ragflow_rdbms_cursor_type__": "datetime", "value": "..."} + # Only datetime needs wrapping because connector config is JSON. + if isinstance(value, datetime): + return { + "__ragflow_rdbms_cursor_type__": "datetime", + "value": value.isoformat(), + } + return value + + + @staticmethod + def deserialize_cursor_value(value: Any) -> Any: + # Reverse the datetime wrapper above. + # Non-datetime cursors such as int/str/float are returned as-is. + if ( + isinstance(value, dict) + and value.get("__ragflow_rdbms_cursor_type__") == "datetime" + ): + return datetime.fromisoformat(value["value"]) + return value + + + def _format_sql_value(self, value: Any) -> str: + if isinstance(value, datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + if self.db_type == DatabaseType.MYSQL: + rendered = value.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + else: + rendered = value.astimezone(timezone.utc).isoformat() + return f"'{rendered}'" + if isinstance(value, bool): + if self.db_type == DatabaseType.POSTGRESQL: + return "TRUE" if value else "FALSE" + return "1" if value else "0" + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + return "'" + value.replace("'", "''") + "'" + raise ConnectorValidationError( + f"Unsupported timestamp cursor value type: {type(value).__name__}" + ) + + + def _build_time_filtered_query( + self, + base_query: str, + start: Any = None, + end: Any = None, + ) -> str: + if not self.timestamp_column or (start is None and end is None): + return self._wrap_query(base_query) + + conditions = [] + if start is not None: + conditions.append( + f"ragflow_src.{self.timestamp_column} > {self._format_sql_value(start)}" + ) + if end is not None: + conditions.append( + f"ragflow_src.{self.timestamp_column} <= {self._format_sql_value(end)}" + ) + + query = self._wrap_query(base_query) + if conditions: + query = f"{query} WHERE {' AND '.join(conditions)}" + return query + + + def _build_max_timestamp_query(self, base_query: str) -> str: + return ( + f"SELECT MAX(ragflow_src.{self.timestamp_column}) " + f"FROM ({base_query}) AS ragflow_src" + ) + + + def _build_slim_query(self, base_query: str) -> str: + columns = [self.id_column] if self.id_column else self.content_columns + select_clause = ", ".join(f"ragflow_src.{column}" for column in columns) + return self._wrap_query(base_query, select_clause) + + + def _build_content(self, row_dict: Dict[str, Any]) -> str: content_parts = [] for col in self.content_columns: - if col in row_dict and row_dict[col] is not None: - value = row_dict[col] - if isinstance(value, (dict, list)): - value = json.dumps(value, ensure_ascii=False) - # Use brackets around field name and put value on a new line - # so that TxtParser preserves field boundaries after chunking. - content_parts.append(f"【{col}】:\n{value}") - - content = "\n\n".join(content_parts) - - if self.id_column and self.id_column in row_dict: - doc_id = f"{self.db_type}:{self.database}:{row_dict[self.id_column]}" - else: - content_hash = hashlib.md5(content.encode()).hexdigest() - doc_id = f"{self.db_type}:{self.database}:{content_hash}" - + if col not in row_dict or row_dict[col] is None: + continue + value = row_dict[col] + if isinstance(value, (dict, list)): + value = json.dumps(value, ensure_ascii=False) + content_parts.append(f"【{col}】:\n{value}") + return "\n\n".join(content_parts) + + + def _build_document_id_from_row(self, row_dict: Dict[str, Any]) -> str: + if self.id_column and self.id_column in row_dict and row_dict[self.id_column] is not None: + return f"{self.db_type}:{self.database}:{row_dict[self.id_column]}" + content = self._build_content(row_dict) + content_hash = hashlib.md5(content.encode()).hexdigest() + return f"{self.db_type}:{self.database}:{content_hash}" + + + def _row_to_document( + self, + row: Union[tuple, list, Dict[str, Any]], + column_names: list[str], + ) -> Document: + """Convert a database row to a Document.""" + row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row + content = self._build_content(row_dict) metadata = {} for col in self.metadata_columns: - if col in row_dict and row_dict[col] is not None: - value = row_dict[col] - if isinstance(value, datetime): - value = value.isoformat() - elif isinstance(value, (dict, list)): - value = json.dumps(value, ensure_ascii=False) - else: - value = str(value) - metadata[col] = value - + if col not in row_dict or row_dict[col] is None: + continue + value = row_dict[col] + if isinstance(value, datetime): + value = value.isoformat() + elif isinstance(value, (dict, list)): + value = json.dumps(value, ensure_ascii=False) + else: + value = str(value) + metadata[col] = value + doc_updated_at = datetime.now(timezone.utc) - if self.timestamp_column and self.timestamp_column in row_dict: + if self.timestamp_column and self.timestamp_column in row_dict and row_dict[self.timestamp_column] is not None: ts_value = row_dict[self.timestamp_column] if isinstance(ts_value, datetime): if ts_value.tzinfo is None: doc_updated_at = ts_value.replace(tzinfo=timezone.utc) else: - doc_updated_at = ts_value - - first_content_col = self.content_columns[0] if self.content_columns else "record" - semantic_id = str(row_dict.get(first_content_col, "database_record")).replace("\n", " ").replace("\r", " ").strip()[:100] + doc_updated_at = ts_value.astimezone(timezone.utc) + + first_content_col = self.content_columns[0] if self.content_columns else "record" + semantic_id = ( + str(row_dict.get(first_content_col, "database_record")) + .replace("\n", " ") + .replace("\r", " ") + .strip()[:100] + ) + blob = content.encode("utf-8") - return Document( - id=doc_id, - blob=content.encode("utf-8"), + id=self._build_document_id_from_row(row_dict), + blob=blob, source=DocumentSource(self.db_type.value), semantic_identifier=semantic_id, extension=".txt", doc_updated_at=doc_updated_at, - size_bytes=len(content.encode("utf-8")), + size_bytes=len(blob), metadata=metadata if metadata else None, ) + def _yield_documents_from_query( self, query: str, @@ -288,30 +377,146 @@ class RDBMSConnector(LoadConnector, PollConnector): pass cursor.close() + + def _yield_slim_documents_from_query( + self, + query: str, + ) -> Generator[list[SlimDocument], None, None]: + connection = self._get_connection() + cursor = connection.cursor() + + try: + logging.debug(f"Executing slim query: {query[:200]}...") + cursor.execute(query) + column_names = [desc[0] for desc in cursor.description] + + batch: list[SlimDocument] = [] + for row in cursor: + row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row + batch.append(SlimDocument(id=self._build_document_id_from_row(row_dict))) + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + finally: + try: + cursor.fetchall() + except Exception: + pass + cursor.close() + + + def get_max_cursor_value(self) -> Any: + if not self.timestamp_column: + return None + + max_cursor_value = None + connection = self._get_connection() + cursor = connection.cursor() + + try: + for base_query in self._get_base_queries(): + query = self._build_max_timestamp_query(base_query) + logging.debug(f"Executing max timestamp query: {query[:200]}...") + cursor.execute(query) + row = cursor.fetchone() + if row is None or row[0] is None: + continue + if max_cursor_value is None or row[0] > max_cursor_value: + max_cursor_value = row[0] + finally: + cursor.close() + + return max_cursor_value + + def _yield_documents( self, - start: Optional[datetime] = None, - end: Optional[datetime] = None, + start: Any = None, + end: Any = None, ) -> Generator[list[Document], None, None]: """Generate documents from database query results.""" - if self.query: - query = self._build_query_with_time_filter(start, end) - yield from self._yield_documents_from_query(query) - else: - tables = self._get_tables() - logging.info(f"No query specified. Loading all {len(tables)} tables: {tables}") - for table in tables: - query = f"SELECT * FROM {table}" - logging.info(f"Loading table: {table}") + base_queries = self._get_base_queries() + if not self.query: + logging.info(f"No query specified. Loading all {len(base_queries)} tables.") + + try: + for base_query in base_queries: + query = self._build_time_filtered_query(base_query, start, end) yield from self._yield_documents_from_query(query) - - self._close_connection() + finally: + self._close_connection() + def load_from_state(self) -> Generator[list[Document], None, None]: """Load all documents from the database (full sync).""" logging.debug(f"Loading all records from {self.db_type} database: {self.database}") return self._yield_documents() + + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> Generator[list[SlimDocument], None, None]: + del callback + + base_queries = self._get_base_queries() + if not self.query: + logging.info(f"No query specified. Retrieving slim documents from all {len(base_queries)} tables.") + + try: + for base_query in base_queries: + yield from self._yield_slim_documents_from_query( + self._build_slim_query(base_query) + ) + finally: + self._close_connection() + + def prepare_sync_state(self, connector_id: str, config: Dict[str, Any]) -> None: + self._sync_connector_id = connector_id + self._sync_config = copy.deepcopy(config) + if not self.timestamp_column: + self._pending_sync_cursor_value = None + return + self._pending_sync_cursor_value = self.get_max_cursor_value() + + + def get_saved_sync_cursor_value(self) -> Any: + if self._sync_config is None: + return None + return self.deserialize_cursor_value(self._sync_config.get("sync_cursor_value")) + + + def persist_sync_state(self) -> None: + if not self.timestamp_column or self._sync_connector_id is None or self._sync_config is None: + return + + from api.db.services.connector_service import ConnectorService + + updated_conf = copy.deepcopy(self._sync_config) + updated_conf["sync_cursor_value"] = self.serialize_cursor_value( + self._pending_sync_cursor_value + ) + ConnectorService.update_by_id(self._sync_connector_id, {"config": updated_conf}) + self._sync_config = updated_conf + + + def load_from_cursor_range( + self, + start_value: Any = None, + end_value: Any = None, + ) -> Generator[list[Document], None, None]: + if end_value is None: + self._close_connection() + return iter(()) + if start_value is not None and end_value <= start_value: + self._close_connection() + return iter(()) + return self._yield_documents(start_value, end_value) + + def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> Generator[list[Document], None, None]: @@ -322,16 +527,8 @@ class RDBMSConnector(LoadConnector, PollConnector): "Falling back to full sync." ) return self.load_from_state() - - start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) - end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) - - logging.debug( - f"Polling {self.db_type} database {self.database} " - f"from {start_datetime} to {end_datetime}" - ) - - return self._yield_documents(start_datetime, end_datetime) + return self._yield_documents(start, end) + def validate_connector_settings(self) -> None: """Validate connector settings by testing the connection.""" diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index da16e318ea..697e3d5dee 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -253,6 +253,7 @@ class SyncBase: and task.get("poll_range_start") and self.conf.get("sync_deleted_files") ) + cleanup_errors = [] if expects_deleted_file_snapshot and file_list is None: logging.warning( "%s deleted-file snapshot retrieval failed " @@ -261,16 +262,8 @@ class SyncBase: task["connector_id"], task["kb_id"], ) - elif file_list: - logging.info( - "[%s] Starting stale document reconciliation. Snapshot size: %d " - "(connector_id=%s, kb_id=%s)", - self.SOURCE_NAME, - len(file_list), - task["connector_id"], - task["kb_id"], - ) - removed_docs, _ = ConnectorService.cleanup_stale_documents_for_task( + elif file_list is not None: + removed_docs, cleanup_errors = ConnectorService.cleanup_stale_documents_for_task( task["id"], task["connector_id"], task["kb_id"], @@ -288,6 +281,13 @@ class SyncBase: summary = f"{summary}, skipped={failed_docs}" logging.info(summary) + if ( + isinstance(self, _RDBMSBase) + and failed_docs == 0 + and (not expects_deleted_file_snapshot or file_list is not None) + and not cleanup_errors + ): + self.connector.persist_sync_state() SyncLogsService.done(task["id"], task["connector_id"]) task["poll_range_start"] = next_update @@ -937,14 +937,6 @@ class WebDAV(SyncBase): end_ts = datetime.now(timezone.utc).timestamp() if self.conf.get("sync_deleted_files"): file_list = [] - logging.info( - "WebDAV: fetching slim snapshot for stale-document reconciliation " - "(connector_id=%s, kb_id=%s, base_url=%s, path=%s)", - task["connector_id"], - task["kb_id"], - self.conf["base_url"], - self.conf.get("remote_path", "/"), - ) try: for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): file_list.extend(slim_batch) @@ -1560,14 +1552,6 @@ class SeaFile(SyncBase): end_ts = datetime.now(timezone.utc).timestamp() if self.conf.get("sync_deleted_files"): file_list = [] - logging.info( - "SeaFile: fetching slim snapshot for stale-document reconciliation " - "(connector_id=%s, kb_id=%s, scope=%s)", - task["connector_id"], - task["kb_id"], - conf.get("sync_scope") - or SeafileSyncScope.ACCOUNT.value, - ) try: for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): file_list.extend(slim_batch) @@ -1668,82 +1652,74 @@ class DingTalkAITable(SyncBase): return document_generator, file_list -class MySQL(SyncBase): +class _RDBMSBase(SyncBase): + DB_TYPE: str = "" + LOG_NAME: str = "" + DEFAULT_PORT: int = 0 + + async def _generate(self, task: dict): + self.connector = RDBMSConnector( + db_type=self.DB_TYPE, + host=self.conf.get("host", "localhost"), + port=int(self.conf.get("port", self.DEFAULT_PORT)), + database=self.conf.get("database", ""), + query=self.conf.get("query", ""), + content_columns=self.conf.get("content_columns", ""), + metadata_columns=self.conf.get("metadata_columns", ""), + id_column=self.conf.get("id_column") or None, + timestamp_column=self.conf.get("timestamp_column") or None, + batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE), + ) + + credentials = self.conf.get("credentials") + if not credentials: + raise ValueError(f"{self.DB_TYPE} connector is missing credentials.") + + self.connector.load_credentials(credentials) + self.connector.validate_connector_settings() + self.connector.prepare_sync_state(task["connector_id"], self.conf) + + file_list = None + if ( + task["reindex"] != "1" + and task["poll_range_start"] + and self.conf.get("sync_deleted_files") + ): + file_list = [] + for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): + file_list.extend(slim_batch) + + if task["reindex"] == "1" or not task["poll_range_start"]: + document_generator = self.connector.load_from_state() + _begin_info = "totally" + elif not self.connector.timestamp_column: + document_generator = self.connector.load_from_state() + _begin_info = f"from {task['poll_range_start']}" + else: + poll_start = task["poll_range_start"] + start_cursor_value = self.connector.get_saved_sync_cursor_value() + document_generator = self.connector.load_from_cursor_range( + start_cursor_value, + self.connector._pending_sync_cursor_value, + ) + _begin_info = f"from {poll_start}" + + self.log_connection(self.LOG_NAME, f"{self.conf.get('host')}:{self.conf.get('database')}", task) + return document_generator, file_list + + +class MySQL(_RDBMSBase): SOURCE_NAME: str = FileSource.MYSQL - - async def _generate(self, task: dict): - self.connector = RDBMSConnector( - db_type="mysql", - host=self.conf.get("host", "localhost"), - port=int(self.conf.get("port", 3306)), - database=self.conf.get("database", ""), - query=self.conf.get("query", ""), - content_columns=self.conf.get("content_columns", ""), - metadata_columns=self.conf.get("metadata_columns", ""), - id_column=self.conf.get("id_column") or None, - timestamp_column=self.conf.get("timestamp_column") or None, - batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE), - ) - - credentials = self.conf.get("credentials") - if not credentials: - raise ValueError("MySQL connector is missing credentials.") - - self.connector.load_credentials(credentials) - self.connector.validate_connector_settings() - - if task["reindex"] == "1" or not task["poll_range_start"]: - document_generator = self.connector.load_from_state() - _begin_info = "totally" - else: - poll_start = task["poll_range_start"] - document_generator = self.connector.poll_source( - poll_start.timestamp(), - datetime.now(timezone.utc).timestamp() - ) - _begin_info = f"from {poll_start}" - - self.log_connection("MySQL", f"{self.conf.get('host')}:{self.conf.get('database')}", task) - return document_generator + DB_TYPE: str = "mysql" + LOG_NAME: str = "MySQL" + DEFAULT_PORT: int = 3306 -class PostgreSQL(SyncBase): +class PostgreSQL(_RDBMSBase): SOURCE_NAME: str = FileSource.POSTGRESQL - - async def _generate(self, task: dict): - self.connector = RDBMSConnector( - db_type="postgresql", - host=self.conf.get("host", "localhost"), - port=int(self.conf.get("port", 5432)), - database=self.conf.get("database", ""), - query=self.conf.get("query", ""), - content_columns=self.conf.get("content_columns", ""), - metadata_columns=self.conf.get("metadata_columns", ""), - id_column=self.conf.get("id_column") or None, - timestamp_column=self.conf.get("timestamp_column") or None, - batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE), - ) - - credentials = self.conf.get("credentials") - if not credentials: - raise ValueError("PostgreSQL connector is missing credentials.") - - self.connector.load_credentials(credentials) - self.connector.validate_connector_settings() - - if task["reindex"] == "1" or not task["poll_range_start"]: - document_generator = self.connector.load_from_state() - _begin_info = "totally" - else: - poll_start = task["poll_range_start"] - document_generator = self.connector.poll_source( - poll_start.timestamp(), - datetime.now(timezone.utc).timestamp() - ) - _begin_info = f"from {poll_start}" - - self.log_connection("PostgreSQL", f"{self.conf.get('host')}:{self.conf.get('database')}", task) - return document_generator + DB_TYPE: str = "postgresql" + LOG_NAME: str = "PostgreSQL" + DEFAULT_PORT: int = 5432 func_factory = { diff --git a/test/unit_test/rag/test_sync_data_source.py b/test/unit_test/rag/test_sync_data_source.py index f513ec7a31..be9d89372a 100644 --- a/test/unit_test/rag/test_sync_data_source.py +++ b/test/unit_test/rag/test_sync_data_source.py @@ -95,6 +95,18 @@ class _FakeSync(sync_data_source.SyncBase): return self._generate_output +def _make_fake_doc(doc_id="doc-1", updated_at=None): + return types.SimpleNamespace( + id=doc_id, + semantic_identifier=doc_id, + extension=".txt", + size_bytes=1, + doc_updated_at=updated_at or datetime(2026, 1, 1, tzinfo=timezone.utc), + blob=b"x", + metadata=None, + ) + + def _make_task(): return { "id": "task-1", @@ -121,19 +133,35 @@ def _patch_common_dependencies(monkeypatch): @pytest.mark.anyio @pytest.mark.p2 -async def test_run_task_logic_skips_cleanup_for_empty_snapshot(monkeypatch): +async def test_run_task_logic_cleans_up_for_empty_snapshot(monkeypatch): cleanup_calls = [] _patch_common_dependencies(monkeypatch) + + def _fake_cleanup(*args, **kwargs): + cleanup_calls.append((args, kwargs)) + return 1, [] + monkeypatch.setattr( sync_data_source.ConnectorService, "cleanup_stale_documents_for_task", - lambda *_args, **_kwargs: cleanup_calls.append((_args, _kwargs)), + _fake_cleanup, ) await _FakeSync((iter(()), []))._run_task_logic(_make_task()) - assert cleanup_calls == [] + assert cleanup_calls == [ + ( + ( + "task-1", + "connector-1", + "kb-1", + "tenant-1", + [], + ), + {}, + ) + ] @pytest.mark.anyio @@ -170,6 +198,203 @@ async def test_run_task_logic_cleans_up_for_non_empty_snapshot(monkeypatch): ] +class _FakeRDBMSConnector: + instance = None + + def __init__( + self, + db_type, + host, + port, + database, + query, + content_columns, + metadata_columns=None, + id_column=None, + timestamp_column=None, + batch_size=2, + ): + self.db_type = db_type + self.host = host + self.port = port + self.database = database + self.query = query + self.content_columns = content_columns + self.metadata_columns = metadata_columns + self.id_column = id_column + self.timestamp_column = timestamp_column + self.batch_size = batch_size + self.load_from_state_called = False + self.retrieve_all_slim_docs_perm_sync_called = False + self.prepare_sync_state_called = False + self.load_from_cursor_range_called = False + self.persist_sync_state_called = False + self._pending_sync_cursor_value = None + _FakeRDBMSConnector.instance = self + + def load_credentials(self, credentials): + self.credentials = credentials + + def validate_connector_settings(self): + return None + + def prepare_sync_state(self, connector_id, config): + self.prepare_sync_state_called = True + self.prepare_sync_state_args = (connector_id, config) + + def get_saved_sync_cursor_value(self): + return None + + def retrieve_all_slim_docs_perm_sync(self, callback=None): + del callback + self.retrieve_all_slim_docs_perm_sync_called = True + yield [types.SimpleNamespace(id="row-1")] + + def load_from_state(self): + self.load_from_state_called = True + return iter((["full-sync"],)) + + def load_from_cursor_range(self, start_value=None, end_value=None): + self.load_from_cursor_range_called = True + return iter(([ _make_fake_doc("incremental-doc") ],)) + + def persist_sync_state(self): + self.persist_sync_state_called = True + + +@pytest.mark.anyio +@pytest.mark.p2 +async def test_rdbms_generate_keeps_deleted_file_snapshot_without_timestamp_column(monkeypatch): + monkeypatch.setattr(sync_data_source, "RDBMSConnector", _FakeRDBMSConnector) + + task = { + **_make_task(), + "reindex": "0", + "poll_range_start": datetime(2026, 1, 1, tzinfo=timezone.utc), + "skip_connection_log": True, + } + sync = sync_data_source.MySQL( + { + "host": "localhost", + "port": 3306, + "database": "db", + "query": "SELECT * FROM t", + "content_columns": "name", + "credentials": {"username": "u", "password": "p"}, + "sync_deleted_files": True, + } + ) + + document_generator, file_list = await sync._generate(task) + connector = _FakeRDBMSConnector.instance + + assert connector is not None + assert connector.load_from_state_called is True + assert connector.load_from_cursor_range_called is False + assert connector.retrieve_all_slim_docs_perm_sync_called is True + assert file_list is not None + assert [doc.id for doc in file_list] == ["row-1"] + assert list(document_generator) == [["full-sync"]] + + +@pytest.mark.anyio +@pytest.mark.p2 +async def test_rdbms_cursor_persists_only_after_success(monkeypatch): + monkeypatch.setattr(sync_data_source, "RDBMSConnector", _FakeRDBMSConnector) + _patch_common_dependencies(monkeypatch) + monkeypatch.setattr( + sync_data_source.KnowledgebaseService, + "get_by_id", + lambda *_args, **_kwargs: (True, object()), + ) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "increase_docs", + lambda *_args, **_kwargs: None, + ) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "duplicate_and_parse", + lambda *_args, **_kwargs: ([], ["parsed-doc-id"]), + ) + + task = { + **_make_task(), + "reindex": "0", + "poll_range_start": datetime(2026, 1, 1, tzinfo=timezone.utc), + "skip_connection_log": True, + } + sync = sync_data_source.MySQL( + { + "host": "localhost", + "port": 3306, + "database": "db", + "query": "SELECT * FROM t", + "content_columns": "name", + "timestamp_column": "ts", + "credentials": {"username": "u", "password": "p"}, + "sync_deleted_files": False, + } + ) + + await sync._run_task_logic(task) + + connector = _FakeRDBMSConnector.instance + assert connector is not None + assert connector.persist_sync_state_called is True + + +@pytest.mark.anyio +@pytest.mark.p2 +async def test_rdbms_cursor_does_not_persist_when_batch_is_skipped(monkeypatch): + monkeypatch.setattr(sync_data_source, "RDBMSConnector", _FakeRDBMSConnector) + _patch_common_dependencies(monkeypatch) + monkeypatch.setattr( + sync_data_source.KnowledgebaseService, + "get_by_id", + lambda *_args, **_kwargs: (True, object()), + ) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "increase_docs", + lambda *_args, **_kwargs: None, + ) + + def _raise_in_duplicate_and_parse(*_args, **_kwargs): + raise RuntimeError("batch failed") + + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "duplicate_and_parse", + _raise_in_duplicate_and_parse, + ) + + task = { + **_make_task(), + "reindex": "0", + "poll_range_start": datetime(2026, 1, 1, tzinfo=timezone.utc), + "skip_connection_log": True, + } + sync = sync_data_source.MySQL( + { + "host": "localhost", + "port": 3306, + "database": "db", + "query": "SELECT * FROM t", + "content_columns": "name", + "timestamp_column": "ts", + "credentials": {"username": "u", "password": "p"}, + "sync_deleted_files": False, + } + ) + + await sync._run_task_logic(task) + + connector = _FakeRDBMSConnector.instance + assert connector is not None + assert connector.persist_sync_state_called is False + + class _FakeDropboxConnector: instance = None diff --git a/web/src/pages/user-setting/data-source/constant/index.tsx b/web/src/pages/user-setting/data-source/constant/index.tsx index 50a0932b48..0aae8868c5 100644 --- a/web/src/pages/user-setting/data-source/constant/index.tsx +++ b/web/src/pages/user-setting/data-source/constant/index.tsx @@ -126,6 +126,12 @@ export const DataSourceFeatureVisibilityMap: Partial< [DataSourceKey.RSS]: { syncDeletedFiles: true, }, + [DataSourceKey.MYSQL]: { + syncDeletedFiles: true, + }, + [DataSourceKey.POSTGRESQL]: { + syncDeletedFiles: true, + }, }; const isDataSourceFeatureVisible = (