mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Feat: enable sync deleted files for RDBMS & fix remove last file issue (#14615)
### What problem does this PR solve?

Feat: enable sync deleted files for RDBMS & fix remove last file issue

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""RDBMS (MySQL/PostgreSQL) data source connector for importing data from relational databases."""
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
@@ -12,8 +13,13 @@ from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
)
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||
from common.data_source.models import Document
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync,
|
||||
)
|
||||
from common.data_source.models import Document, SlimDocument
|
||||
|
||||
|
||||
class DatabaseType(str, Enum):
|
||||
@@ -22,15 +28,18 @@ class DatabaseType(str, Enum):
|
||||
POSTGRESQL = "postgresql"
|
||||
|
||||
|
||||
class RDBMSConnector(LoadConnector, PollConnector):
|
||||
class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""
|
||||
RDBMS connector for importing data from MySQL and PostgreSQL databases.
|
||||
|
||||
This connector allows users to:
|
||||
1. Connect to a MySQL or PostgreSQL database
|
||||
2. Execute a SQL query to extract data
|
||||
3. Map columns to content (for vectorization) and metadata
|
||||
4. Sync data in batch or incremental mode using a timestamp column
|
||||
Import rows from MySQL or PostgreSQL into documents.
|
||||
|
||||
The flow is:
|
||||
1. Connect to the configured database.
|
||||
2. Read rows from a custom SQL query, or from every table when no query is provided.
|
||||
3. Build document content from the selected content columns.
|
||||
4. Copy the selected metadata columns into document metadata.
|
||||
5. Use the configured ID column as the stable document ID, or hash the content when no ID column is set.
|
||||
6. For incremental sync, treat the timestamp column as an ordered cursor and only compare values by size.
|
||||
7. For deleted-file sync, read a slim snapshot of current row IDs and let the sync worker remove stale documents.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
@@ -73,6 +82,9 @@ class RDBMSConnector(LoadConnector, PollConnector):
|
||||
|
||||
self._connection = None
|
||||
self._credentials: Dict[str, Any] = {}
|
||||
self._sync_connector_id: str | None = None
|
||||
self._sync_config: Dict[str, Any] | None = None
|
||||
self._pending_sync_cursor_value: Any = None
|
||||
|
||||
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
"""Load database credentials."""
|
||||
@@ -160,98 +172,175 @@ class RDBMSConnector(LoadConnector, PollConnector):
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def _build_query_with_time_filter(
|
||||
self,
|
||||
start: Optional[datetime] = None,
|
||||
end: Optional[datetime] = None,
|
||||
) -> str:
|
||||
"""Build the query with optional time filtering for incremental sync."""
|
||||
if not self.query:
|
||||
return "" # Will be handled by table discovery
|
||||
base_query = self.query.rstrip(";")
|
||||
|
||||
if not self.timestamp_column or (start is None and end is None):
|
||||
return base_query
|
||||
|
||||
has_where = "where" in base_query.lower()
|
||||
connector = " AND" if has_where else " WHERE"
|
||||
|
||||
time_conditions = []
|
||||
if start is not None:
|
||||
if self.db_type == DatabaseType.MYSQL:
|
||||
time_conditions.append(f"{self.timestamp_column} > '{start.strftime('%Y-%m-%d %H:%M:%S')}'")
|
||||
else:
|
||||
time_conditions.append(f"{self.timestamp_column} > '{start.isoformat()}'")
|
||||
|
||||
if end is not None:
|
||||
if self.db_type == DatabaseType.MYSQL:
|
||||
time_conditions.append(f"{self.timestamp_column} <= '{end.strftime('%Y-%m-%d %H:%M:%S')}'")
|
||||
else:
|
||||
time_conditions.append(f"{self.timestamp_column} <= '{end.isoformat()}'")
|
||||
|
||||
if time_conditions:
|
||||
return f"{base_query}{connector} {' AND '.join(time_conditions)}"
|
||||
|
||||
return base_query
|
||||
|
||||
def _row_to_document(self, row: Union[tuple, list, Dict[str, Any]], column_names: list) -> Document:
|
||||
"""Convert a database row to a Document."""
|
||||
row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row
|
||||
|
||||
def _get_base_queries(self) -> list[str]:
|
||||
if self.query:
|
||||
return [self.query.rstrip(";")]
|
||||
return [f"SELECT * FROM {table}" for table in self._get_tables()]
|
||||
|
||||
|
||||
def _wrap_query(self, base_query: str, select_clause: str = "*") -> str:
|
||||
return f"SELECT {select_clause} FROM ({base_query}) AS ragflow_src"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def serialize_cursor_value(value: Any) -> Any:
|
||||
# Example:
|
||||
# - int cursor 42 is stored as 42
|
||||
# - datetime cursor 2026-05-07T12:34:56+00:00 is stored as
|
||||
# {"__ragflow_rdbms_cursor_type__": "datetime", "value": "..."}
|
||||
# Only datetime needs wrapping because connector config is JSON.
|
||||
if isinstance(value, datetime):
|
||||
return {
|
||||
"__ragflow_rdbms_cursor_type__": "datetime",
|
||||
"value": value.isoformat(),
|
||||
}
|
||||
return value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def deserialize_cursor_value(value: Any) -> Any:
|
||||
# Reverse the datetime wrapper above.
|
||||
# Non-datetime cursors such as int/str/float are returned as-is.
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and value.get("__ragflow_rdbms_cursor_type__") == "datetime"
|
||||
):
|
||||
return datetime.fromisoformat(value["value"])
|
||||
return value
|
||||
|
||||
|
||||
def _format_sql_value(self, value: Any) -> str:
|
||||
if isinstance(value, datetime):
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
if self.db_type == DatabaseType.MYSQL:
|
||||
rendered = value.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
rendered = value.astimezone(timezone.utc).isoformat()
|
||||
return f"'{rendered}'"
|
||||
if isinstance(value, bool):
|
||||
if self.db_type == DatabaseType.POSTGRESQL:
|
||||
return "TRUE" if value else "FALSE"
|
||||
return "1" if value else "0"
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
if isinstance(value, str):
|
||||
return "'" + value.replace("'", "''") + "'"
|
||||
raise ConnectorValidationError(
|
||||
f"Unsupported timestamp cursor value type: {type(value).__name__}"
|
||||
)
|
||||
|
||||
|
||||
def _build_time_filtered_query(
|
||||
self,
|
||||
base_query: str,
|
||||
start: Any = None,
|
||||
end: Any = None,
|
||||
) -> str:
|
||||
if not self.timestamp_column or (start is None and end is None):
|
||||
return self._wrap_query(base_query)
|
||||
|
||||
conditions = []
|
||||
if start is not None:
|
||||
conditions.append(
|
||||
f"ragflow_src.{self.timestamp_column} > {self._format_sql_value(start)}"
|
||||
)
|
||||
if end is not None:
|
||||
conditions.append(
|
||||
f"ragflow_src.{self.timestamp_column} <= {self._format_sql_value(end)}"
|
||||
)
|
||||
|
||||
query = self._wrap_query(base_query)
|
||||
if conditions:
|
||||
query = f"{query} WHERE {' AND '.join(conditions)}"
|
||||
return query
|
||||
|
||||
|
||||
def _build_max_timestamp_query(self, base_query: str) -> str:
|
||||
return (
|
||||
f"SELECT MAX(ragflow_src.{self.timestamp_column}) "
|
||||
f"FROM ({base_query}) AS ragflow_src"
|
||||
)
|
||||
|
||||
|
||||
def _build_slim_query(self, base_query: str) -> str:
|
||||
columns = [self.id_column] if self.id_column else self.content_columns
|
||||
select_clause = ", ".join(f"ragflow_src.{column}" for column in columns)
|
||||
return self._wrap_query(base_query, select_clause)
|
||||
|
||||
|
||||
def _build_content(self, row_dict: Dict[str, Any]) -> str:
|
||||
content_parts = []
|
||||
for col in self.content_columns:
|
||||
if col in row_dict and row_dict[col] is not None:
|
||||
value = row_dict[col]
|
||||
if isinstance(value, (dict, list)):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
# Use brackets around field name and put value on a new line
|
||||
# so that TxtParser preserves field boundaries after chunking.
|
||||
content_parts.append(f"【{col}】:\n{value}")
|
||||
|
||||
content = "\n\n".join(content_parts)
|
||||
|
||||
if self.id_column and self.id_column in row_dict:
|
||||
doc_id = f"{self.db_type}:{self.database}:{row_dict[self.id_column]}"
|
||||
else:
|
||||
content_hash = hashlib.md5(content.encode()).hexdigest()
|
||||
doc_id = f"{self.db_type}:{self.database}:{content_hash}"
|
||||
|
||||
if col not in row_dict or row_dict[col] is None:
|
||||
continue
|
||||
value = row_dict[col]
|
||||
if isinstance(value, (dict, list)):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
content_parts.append(f"【{col}】:\n{value}")
|
||||
return "\n\n".join(content_parts)
|
||||
|
||||
|
||||
def _build_document_id_from_row(self, row_dict: Dict[str, Any]) -> str:
|
||||
if self.id_column and self.id_column in row_dict and row_dict[self.id_column] is not None:
|
||||
return f"{self.db_type}:{self.database}:{row_dict[self.id_column]}"
|
||||
content = self._build_content(row_dict)
|
||||
content_hash = hashlib.md5(content.encode()).hexdigest()
|
||||
return f"{self.db_type}:{self.database}:{content_hash}"
|
||||
|
||||
|
||||
def _row_to_document(
|
||||
self,
|
||||
row: Union[tuple, list, Dict[str, Any]],
|
||||
column_names: list[str],
|
||||
) -> Document:
|
||||
"""Convert a database row to a Document."""
|
||||
row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row
|
||||
content = self._build_content(row_dict)
|
||||
metadata = {}
|
||||
for col in self.metadata_columns:
|
||||
if col in row_dict and row_dict[col] is not None:
|
||||
value = row_dict[col]
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
elif isinstance(value, (dict, list)):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
value = str(value)
|
||||
metadata[col] = value
|
||||
|
||||
if col not in row_dict or row_dict[col] is None:
|
||||
continue
|
||||
value = row_dict[col]
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
elif isinstance(value, (dict, list)):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
value = str(value)
|
||||
metadata[col] = value
|
||||
|
||||
doc_updated_at = datetime.now(timezone.utc)
|
||||
if self.timestamp_column and self.timestamp_column in row_dict:
|
||||
if self.timestamp_column and self.timestamp_column in row_dict and row_dict[self.timestamp_column] is not None:
|
||||
ts_value = row_dict[self.timestamp_column]
|
||||
if isinstance(ts_value, datetime):
|
||||
if ts_value.tzinfo is None:
|
||||
doc_updated_at = ts_value.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
doc_updated_at = ts_value
|
||||
|
||||
first_content_col = self.content_columns[0] if self.content_columns else "record"
|
||||
semantic_id = str(row_dict.get(first_content_col, "database_record")).replace("\n", " ").replace("\r", " ").strip()[:100]
|
||||
doc_updated_at = ts_value.astimezone(timezone.utc)
|
||||
|
||||
first_content_col = self.content_columns[0] if self.content_columns else "record"
|
||||
semantic_id = (
|
||||
str(row_dict.get(first_content_col, "database_record"))
|
||||
.replace("\n", " ")
|
||||
.replace("\r", " ")
|
||||
.strip()[:100]
|
||||
)
|
||||
blob = content.encode("utf-8")
|
||||
|
||||
|
||||
return Document(
|
||||
id=doc_id,
|
||||
blob=content.encode("utf-8"),
|
||||
id=self._build_document_id_from_row(row_dict),
|
||||
blob=blob,
|
||||
source=DocumentSource(self.db_type.value),
|
||||
semantic_identifier=semantic_id,
|
||||
extension=".txt",
|
||||
doc_updated_at=doc_updated_at,
|
||||
size_bytes=len(content.encode("utf-8")),
|
||||
size_bytes=len(blob),
|
||||
metadata=metadata if metadata else None,
|
||||
)
|
||||
|
||||
|
||||
def _yield_documents_from_query(
|
||||
self,
|
||||
query: str,
|
||||
@@ -288,30 +377,146 @@ class RDBMSConnector(LoadConnector, PollConnector):
|
||||
pass
|
||||
cursor.close()
|
||||
|
||||
|
||||
def _yield_slim_documents_from_query(
    self,
    query: str,
) -> Generator[list[SlimDocument], None, None]:
    """Run ``query`` and yield its rows as batches of ``SlimDocument``.

    Each row becomes a ``SlimDocument`` whose ID comes from
    ``self._build_document_id_from_row``. Batches hold up to
    ``self.batch_size`` entries; a final partial batch is yielded if any
    rows remain.
    """
    cursor = self._get_connection().cursor()

    try:
        logging.debug(f"Executing slim query: {query[:200]}...")
        cursor.execute(query)
        column_names = [description[0] for description in cursor.description]

        pending: list[SlimDocument] = []
        for row in cursor:
            # Sequence rows are keyed by column name; other rows are used as-is.
            if isinstance(row, (list, tuple)):
                row = dict(zip(column_names, row))
            pending.append(SlimDocument(id=self._build_document_id_from_row(row)))
            if len(pending) < self.batch_size:
                continue
            yield pending
            pending = []

        if pending:
            yield pending
    finally:
        # Best effort: read any unfetched rows before closing the cursor.
        try:
            cursor.fetchall()
        except Exception:
            pass
        cursor.close()
|
||||
|
||||
|
||||
def get_max_cursor_value(self) -> Any:
    """Return the largest timestamp-column value across all base queries.

    Returns ``None`` when no timestamp column is configured or when every
    base query reports no value (no row, or a NULL maximum).
    """
    if not self.timestamp_column:
        return None

    highest = None
    cursor = self._get_connection().cursor()
    try:
        for base_query in self._get_base_queries():
            query = self._build_max_timestamp_query(base_query)
            logging.debug(f"Executing max timestamp query: {query[:200]}...")
            cursor.execute(query)
            row = cursor.fetchone()
            candidate = None if row is None else row[0]
            if candidate is None:
                continue
            if highest is None or candidate > highest:
                highest = candidate
    finally:
        cursor.close()

    return highest
|
||||
|
||||
|
||||
def _yield_documents(
|
||||
self,
|
||||
start: Optional[datetime] = None,
|
||||
end: Optional[datetime] = None,
|
||||
start: Any = None,
|
||||
end: Any = None,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
"""Generate documents from database query results."""
|
||||
if self.query:
|
||||
query = self._build_query_with_time_filter(start, end)
|
||||
yield from self._yield_documents_from_query(query)
|
||||
else:
|
||||
tables = self._get_tables()
|
||||
logging.info(f"No query specified. Loading all {len(tables)} tables: {tables}")
|
||||
for table in tables:
|
||||
query = f"SELECT * FROM {table}"
|
||||
logging.info(f"Loading table: {table}")
|
||||
base_queries = self._get_base_queries()
|
||||
if not self.query:
|
||||
logging.info(f"No query specified. Loading all {len(base_queries)} tables.")
|
||||
|
||||
try:
|
||||
for base_query in base_queries:
|
||||
query = self._build_time_filtered_query(base_query, start, end)
|
||||
yield from self._yield_documents_from_query(query)
|
||||
|
||||
self._close_connection()
|
||||
finally:
|
||||
self._close_connection()
|
||||
|
||||
|
||||
def load_from_state(self) -> Generator[list[Document], None, None]:
    """Return a generator over every document in the database (full sync)."""
    logging.debug(f"Loading all records from {self.db_type} database: {self.database}")
    all_documents = self._yield_documents()
    return all_documents
|
||||
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
    self,
    callback: Any = None,
) -> Generator[list[SlimDocument], None, None]:
    """Yield batches of slim documents for every row currently readable.

    One slim query is run per base query. ``callback`` is accepted but not
    used. The connection is closed when iteration finishes or is abandoned.
    """
    del callback

    base_queries = self._get_base_queries()
    if not self.query:
        logging.info(f"No query specified. Retrieving slim documents from all {len(base_queries)} tables.")

    try:
        # map() is lazy, so each slim query is built right before it runs.
        for slim_query in map(self._build_slim_query, base_queries):
            yield from self._yield_slim_documents_from_query(slim_query)
    finally:
        self._close_connection()
|
||||
|
||||
def prepare_sync_state(self, connector_id: str, config: Dict[str, Any]) -> None:
    """Record the sync context and capture the cursor value for this run.

    Stores ``connector_id`` and a deep copy of ``config``. The pending
    cursor value is the current maximum timestamp-column value when a
    timestamp column is configured, and ``None`` otherwise.
    """
    self._sync_connector_id = connector_id
    self._sync_config = copy.deepcopy(config)
    self._pending_sync_cursor_value = (
        self.get_max_cursor_value() if self.timestamp_column else None
    )
|
||||
|
||||
|
||||
def get_saved_sync_cursor_value(self) -> Any:
    """Return the cursor value stored under ``"sync_cursor_value"``.

    The stored value is passed through ``deserialize_cursor_value``.
    Returns ``None`` when no sync config has been recorded.
    """
    config = self._sync_config
    if config is not None:
        return self.deserialize_cursor_value(config.get("sync_cursor_value"))
    return None
|
||||
|
||||
|
||||
def persist_sync_state(self) -> None:
    """Save the pending cursor value into the stored connector config.

    Does nothing unless a timestamp column is configured and both the
    connector ID and the sync config have been recorded. Otherwise a deep
    copy of the config gets ``"sync_cursor_value"`` set to the serialized
    pending value, is written with ``ConnectorService.update_by_id``, and
    then replaces the in-memory config.
    """
    ready = (
        self.timestamp_column
        and self._sync_connector_id is not None
        and self._sync_config is not None
    )
    if not ready:
        return

    from api.db.services.connector_service import ConnectorService

    new_config = copy.deepcopy(self._sync_config)
    new_config["sync_cursor_value"] = self.serialize_cursor_value(
        self._pending_sync_cursor_value
    )
    ConnectorService.update_by_id(self._sync_connector_id, {"config": new_config})
    self._sync_config = new_config
|
||||
|
||||
|
||||
def load_from_cursor_range(
    self,
    start_value: Any = None,
    end_value: Any = None,
) -> Generator[list[Document], None, None]:
    """Return documents whose cursor value lies in ``(start_value, end_value]``.

    When ``end_value`` is ``None``, or is not greater than a non-``None``
    ``start_value``, the connection is closed and an empty iterator is
    returned. Otherwise the work is delegated to ``self._yield_documents``.
    """
    nothing_to_load = end_value is None or (
        start_value is not None and end_value <= start_value
    )
    if nothing_to_load:
        self._close_connection()
        return iter(())
    return self._yield_documents(start_value, end_value)
|
||||
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> Generator[list[Document], None, None]:
|
||||
@@ -322,16 +527,8 @@ class RDBMSConnector(LoadConnector, PollConnector):
|
||||
"Falling back to full sync."
|
||||
)
|
||||
return self.load_from_state()
|
||||
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
logging.debug(
|
||||
f"Polling {self.db_type} database {self.database} "
|
||||
f"from {start_datetime} to {end_datetime}"
|
||||
)
|
||||
|
||||
return self._yield_documents(start_datetime, end_datetime)
|
||||
return self._yield_documents(start, end)
|
||||
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate connector settings by testing the connection."""
|
||||
|
||||
@@ -253,6 +253,7 @@ class SyncBase:
|
||||
and task.get("poll_range_start")
|
||||
and self.conf.get("sync_deleted_files")
|
||||
)
|
||||
cleanup_errors = []
|
||||
if expects_deleted_file_snapshot and file_list is None:
|
||||
logging.warning(
|
||||
"%s deleted-file snapshot retrieval failed "
|
||||
@@ -261,16 +262,8 @@ class SyncBase:
|
||||
task["connector_id"],
|
||||
task["kb_id"],
|
||||
)
|
||||
elif file_list:
|
||||
logging.info(
|
||||
"[%s] Starting stale document reconciliation. Snapshot size: %d "
|
||||
"(connector_id=%s, kb_id=%s)",
|
||||
self.SOURCE_NAME,
|
||||
len(file_list),
|
||||
task["connector_id"],
|
||||
task["kb_id"],
|
||||
)
|
||||
removed_docs, _ = ConnectorService.cleanup_stale_documents_for_task(
|
||||
elif file_list is not None:
|
||||
removed_docs, cleanup_errors = ConnectorService.cleanup_stale_documents_for_task(
|
||||
task["id"],
|
||||
task["connector_id"],
|
||||
task["kb_id"],
|
||||
@@ -288,6 +281,13 @@ class SyncBase:
|
||||
summary = f"{summary}, skipped={failed_docs}"
|
||||
logging.info(summary)
|
||||
|
||||
if (
|
||||
isinstance(self, _RDBMSBase)
|
||||
and failed_docs == 0
|
||||
and (not expects_deleted_file_snapshot or file_list is not None)
|
||||
and not cleanup_errors
|
||||
):
|
||||
self.connector.persist_sync_state()
|
||||
SyncLogsService.done(task["id"], task["connector_id"])
|
||||
task["poll_range_start"] = next_update
|
||||
|
||||
@@ -937,14 +937,6 @@ class WebDAV(SyncBase):
|
||||
end_ts = datetime.now(timezone.utc).timestamp()
|
||||
if self.conf.get("sync_deleted_files"):
|
||||
file_list = []
|
||||
logging.info(
|
||||
"WebDAV: fetching slim snapshot for stale-document reconciliation "
|
||||
"(connector_id=%s, kb_id=%s, base_url=%s, path=%s)",
|
||||
task["connector_id"],
|
||||
task["kb_id"],
|
||||
self.conf["base_url"],
|
||||
self.conf.get("remote_path", "/"),
|
||||
)
|
||||
try:
|
||||
for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync():
|
||||
file_list.extend(slim_batch)
|
||||
@@ -1560,14 +1552,6 @@ class SeaFile(SyncBase):
|
||||
end_ts = datetime.now(timezone.utc).timestamp()
|
||||
if self.conf.get("sync_deleted_files"):
|
||||
file_list = []
|
||||
logging.info(
|
||||
"SeaFile: fetching slim snapshot for stale-document reconciliation "
|
||||
"(connector_id=%s, kb_id=%s, scope=%s)",
|
||||
task["connector_id"],
|
||||
task["kb_id"],
|
||||
conf.get("sync_scope")
|
||||
or SeafileSyncScope.ACCOUNT.value,
|
||||
)
|
||||
try:
|
||||
for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync():
|
||||
file_list.extend(slim_batch)
|
||||
@@ -1668,82 +1652,74 @@ class DingTalkAITable(SyncBase):
|
||||
return document_generator, file_list
|
||||
|
||||
|
||||
class MySQL(SyncBase):
|
||||
class _RDBMSBase(SyncBase):
    """Shared sync task for the relational-database sources.

    Subclasses supply ``DB_TYPE`` (passed to ``RDBMSConnector``),
    ``LOG_NAME`` (used in the connection log) and ``DEFAULT_PORT`` (used
    when the config has no ``port``).
    """

    DB_TYPE: str = ""
    LOG_NAME: str = ""
    DEFAULT_PORT: int = 0

    async def _generate(self, task: dict):
        """Build the connector and return ``(document_generator, file_list)``.

        ``file_list`` is the slim snapshot of current rows when this is an
        incremental run with ``sync_deleted_files`` enabled, else ``None``.
        Documents come from a full load on reindex, on the first run, or
        when no timestamp column is configured; otherwise from the range
        between the saved cursor and the cursor captured for this run.

        Raises:
            ValueError: if the config has no credentials.
        """
        conf = self.conf
        self.connector = RDBMSConnector(
            db_type=self.DB_TYPE,
            host=conf.get("host", "localhost"),
            port=int(conf.get("port", self.DEFAULT_PORT)),
            database=conf.get("database", ""),
            query=conf.get("query", ""),
            content_columns=conf.get("content_columns", ""),
            metadata_columns=conf.get("metadata_columns", ""),
            id_column=conf.get("id_column") or None,
            timestamp_column=conf.get("timestamp_column") or None,
            batch_size=conf.get("batch_size", INDEX_BATCH_SIZE),
        )

        credentials = conf.get("credentials")
        if not credentials:
            raise ValueError(f"{self.DB_TYPE} connector is missing credentials.")

        connector = self.connector
        connector.load_credentials(credentials)
        connector.validate_connector_settings()
        connector.prepare_sync_state(task["connector_id"], conf)

        full_sync = task["reindex"] == "1" or not task["poll_range_start"]

        file_list = None
        if not full_sync and conf.get("sync_deleted_files"):
            file_list = [
                slim_doc
                for slim_batch in connector.retrieve_all_slim_docs_perm_sync()
                for slim_doc in slim_batch
            ]

        if full_sync or not connector.timestamp_column:
            document_generator = connector.load_from_state()
        else:
            document_generator = connector.load_from_cursor_range(
                connector.get_saved_sync_cursor_value(),
                connector._pending_sync_cursor_value,
            )
        _begin_info = "totally" if full_sync else f"from {task['poll_range_start']}"

        self.log_connection(self.LOG_NAME, f"{conf.get('host')}:{conf.get('database')}", task)
        return document_generator, file_list
|
||||
|
||||
|
||||
class MySQL(_RDBMSBase):
|
||||
SOURCE_NAME: str = FileSource.MYSQL
|
||||
|
||||
async def _generate(self, task: dict):
|
||||
self.connector = RDBMSConnector(
|
||||
db_type="mysql",
|
||||
host=self.conf.get("host", "localhost"),
|
||||
port=int(self.conf.get("port", 3306)),
|
||||
database=self.conf.get("database", ""),
|
||||
query=self.conf.get("query", ""),
|
||||
content_columns=self.conf.get("content_columns", ""),
|
||||
metadata_columns=self.conf.get("metadata_columns", ""),
|
||||
id_column=self.conf.get("id_column") or None,
|
||||
timestamp_column=self.conf.get("timestamp_column") or None,
|
||||
batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE),
|
||||
)
|
||||
|
||||
credentials = self.conf.get("credentials")
|
||||
if not credentials:
|
||||
raise ValueError("MySQL connector is missing credentials.")
|
||||
|
||||
self.connector.load_credentials(credentials)
|
||||
self.connector.validate_connector_settings()
|
||||
|
||||
if task["reindex"] == "1" or not task["poll_range_start"]:
|
||||
document_generator = self.connector.load_from_state()
|
||||
_begin_info = "totally"
|
||||
else:
|
||||
poll_start = task["poll_range_start"]
|
||||
document_generator = self.connector.poll_source(
|
||||
poll_start.timestamp(),
|
||||
datetime.now(timezone.utc).timestamp()
|
||||
)
|
||||
_begin_info = f"from {poll_start}"
|
||||
|
||||
self.log_connection("MySQL", f"{self.conf.get('host')}:{self.conf.get('database')}", task)
|
||||
return document_generator
|
||||
DB_TYPE: str = "mysql"
|
||||
LOG_NAME: str = "MySQL"
|
||||
DEFAULT_PORT: int = 3306
|
||||
|
||||
|
||||
class PostgreSQL(SyncBase):
|
||||
class PostgreSQL(_RDBMSBase):
|
||||
SOURCE_NAME: str = FileSource.POSTGRESQL
|
||||
|
||||
async def _generate(self, task: dict):
|
||||
self.connector = RDBMSConnector(
|
||||
db_type="postgresql",
|
||||
host=self.conf.get("host", "localhost"),
|
||||
port=int(self.conf.get("port", 5432)),
|
||||
database=self.conf.get("database", ""),
|
||||
query=self.conf.get("query", ""),
|
||||
content_columns=self.conf.get("content_columns", ""),
|
||||
metadata_columns=self.conf.get("metadata_columns", ""),
|
||||
id_column=self.conf.get("id_column") or None,
|
||||
timestamp_column=self.conf.get("timestamp_column") or None,
|
||||
batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE),
|
||||
)
|
||||
|
||||
credentials = self.conf.get("credentials")
|
||||
if not credentials:
|
||||
raise ValueError("PostgreSQL connector is missing credentials.")
|
||||
|
||||
self.connector.load_credentials(credentials)
|
||||
self.connector.validate_connector_settings()
|
||||
|
||||
if task["reindex"] == "1" or not task["poll_range_start"]:
|
||||
document_generator = self.connector.load_from_state()
|
||||
_begin_info = "totally"
|
||||
else:
|
||||
poll_start = task["poll_range_start"]
|
||||
document_generator = self.connector.poll_source(
|
||||
poll_start.timestamp(),
|
||||
datetime.now(timezone.utc).timestamp()
|
||||
)
|
||||
_begin_info = f"from {poll_start}"
|
||||
|
||||
self.log_connection("PostgreSQL", f"{self.conf.get('host')}:{self.conf.get('database')}", task)
|
||||
return document_generator
|
||||
DB_TYPE: str = "postgresql"
|
||||
LOG_NAME: str = "PostgreSQL"
|
||||
DEFAULT_PORT: int = 5432
|
||||
|
||||
|
||||
func_factory = {
|
||||
|
||||
@@ -95,6 +95,18 @@ class _FakeSync(sync_data_source.SyncBase):
|
||||
return self._generate_output
|
||||
|
||||
|
||||
def _make_fake_doc(doc_id="doc-1", updated_at=None):
|
||||
return types.SimpleNamespace(
|
||||
id=doc_id,
|
||||
semantic_identifier=doc_id,
|
||||
extension=".txt",
|
||||
size_bytes=1,
|
||||
doc_updated_at=updated_at or datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
blob=b"x",
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_task():
|
||||
return {
|
||||
"id": "task-1",
|
||||
@@ -121,19 +133,35 @@ def _patch_common_dependencies(monkeypatch):
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.p2
|
||||
async def test_run_task_logic_skips_cleanup_for_empty_snapshot(monkeypatch):
|
||||
async def test_run_task_logic_cleans_up_for_empty_snapshot(monkeypatch):
|
||||
cleanup_calls = []
|
||||
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
|
||||
def _fake_cleanup(*args, **kwargs):
|
||||
cleanup_calls.append((args, kwargs))
|
||||
return 1, []
|
||||
|
||||
monkeypatch.setattr(
|
||||
sync_data_source.ConnectorService,
|
||||
"cleanup_stale_documents_for_task",
|
||||
lambda *_args, **_kwargs: cleanup_calls.append((_args, _kwargs)),
|
||||
_fake_cleanup,
|
||||
)
|
||||
|
||||
await _FakeSync((iter(()), []))._run_task_logic(_make_task())
|
||||
|
||||
assert cleanup_calls == []
|
||||
assert cleanup_calls == [
|
||||
(
|
||||
(
|
||||
"task-1",
|
||||
"connector-1",
|
||||
"kb-1",
|
||||
"tenant-1",
|
||||
[],
|
||||
),
|
||||
{},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -170,6 +198,203 @@ async def test_run_task_logic_cleans_up_for_non_empty_snapshot(monkeypatch):
|
||||
]
|
||||
|
||||
|
||||
class _FakeRDBMSConnector:
|
||||
instance = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_type,
|
||||
host,
|
||||
port,
|
||||
database,
|
||||
query,
|
||||
content_columns,
|
||||
metadata_columns=None,
|
||||
id_column=None,
|
||||
timestamp_column=None,
|
||||
batch_size=2,
|
||||
):
|
||||
self.db_type = db_type
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.database = database
|
||||
self.query = query
|
||||
self.content_columns = content_columns
|
||||
self.metadata_columns = metadata_columns
|
||||
self.id_column = id_column
|
||||
self.timestamp_column = timestamp_column
|
||||
self.batch_size = batch_size
|
||||
self.load_from_state_called = False
|
||||
self.retrieve_all_slim_docs_perm_sync_called = False
|
||||
self.prepare_sync_state_called = False
|
||||
self.load_from_cursor_range_called = False
|
||||
self.persist_sync_state_called = False
|
||||
self._pending_sync_cursor_value = None
|
||||
_FakeRDBMSConnector.instance = self
|
||||
|
||||
def load_credentials(self, credentials):
|
||||
self.credentials = credentials
|
||||
|
||||
def validate_connector_settings(self):
|
||||
return None
|
||||
|
||||
def prepare_sync_state(self, connector_id, config):
|
||||
self.prepare_sync_state_called = True
|
||||
self.prepare_sync_state_args = (connector_id, config)
|
||||
|
||||
def get_saved_sync_cursor_value(self):
|
||||
return None
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(self, callback=None):
|
||||
del callback
|
||||
self.retrieve_all_slim_docs_perm_sync_called = True
|
||||
yield [types.SimpleNamespace(id="row-1")]
|
||||
|
||||
def load_from_state(self):
|
||||
self.load_from_state_called = True
|
||||
return iter((["full-sync"],))
|
||||
|
||||
def load_from_cursor_range(self, start_value=None, end_value=None):
|
||||
self.load_from_cursor_range_called = True
|
||||
return iter(([ _make_fake_doc("incremental-doc") ],))
|
||||
|
||||
def persist_sync_state(self):
|
||||
self.persist_sync_state_called = True
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.p2
async def test_rdbms_generate_keeps_deleted_file_snapshot_without_timestamp_column(monkeypatch):
    """Without a timestamp column, an incremental run with deleted-file sync
    still collects the slim snapshot and loads documents via a full load."""
    monkeypatch.setattr(sync_data_source, "RDBMSConnector", _FakeRDBMSConnector)

    conf = {
        "host": "localhost",
        "port": 3306,
        "database": "db",
        "query": "SELECT * FROM t",
        "content_columns": "name",
        "credentials": {"username": "u", "password": "p"},
        "sync_deleted_files": True,
    }
    task = dict(
        _make_task(),
        reindex="0",
        poll_range_start=datetime(2026, 1, 1, tzinfo=timezone.utc),
        skip_connection_log=True,
    )

    document_generator, file_list = await sync_data_source.MySQL(conf)._generate(task)

    connector = _FakeRDBMSConnector.instance
    assert connector is not None
    assert connector.load_from_state_called is True
    assert connector.load_from_cursor_range_called is False
    assert connector.retrieve_all_slim_docs_perm_sync_called is True
    assert file_list is not None
    assert [doc.id for doc in file_list] == ["row-1"]
    assert list(document_generator) == [["full-sync"]]
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.p2
async def test_rdbms_cursor_persists_only_after_success(monkeypatch):
    """An incremental run whose batch is parsed without error persists the
    sync cursor."""
    monkeypatch.setattr(sync_data_source, "RDBMSConnector", _FakeRDBMSConnector)
    _patch_common_dependencies(monkeypatch)

    service_stubs = (
        (sync_data_source.KnowledgebaseService, "get_by_id", lambda *_args, **_kwargs: (True, object())),
        (sync_data_source.SyncLogsService, "increase_docs", lambda *_args, **_kwargs: None),
        (sync_data_source.SyncLogsService, "duplicate_and_parse", lambda *_args, **_kwargs: ([], ["parsed-doc-id"])),
    )
    for owner, attribute, replacement in service_stubs:
        monkeypatch.setattr(owner, attribute, replacement)

    conf = {
        "host": "localhost",
        "port": 3306,
        "database": "db",
        "query": "SELECT * FROM t",
        "content_columns": "name",
        "timestamp_column": "ts",
        "credentials": {"username": "u", "password": "p"},
        "sync_deleted_files": False,
    }
    task = dict(
        _make_task(),
        reindex="0",
        poll_range_start=datetime(2026, 1, 1, tzinfo=timezone.utc),
        skip_connection_log=True,
    )

    await sync_data_source.MySQL(conf)._run_task_logic(task)

    connector = _FakeRDBMSConnector.instance
    assert connector is not None
    assert connector.persist_sync_state_called is True
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.p2
async def test_rdbms_cursor_does_not_persist_when_batch_is_skipped(monkeypatch):
    """When parsing a batch raises, the sync cursor is not persisted."""
    monkeypatch.setattr(sync_data_source, "RDBMSConnector", _FakeRDBMSConnector)
    _patch_common_dependencies(monkeypatch)

    def _raise_in_duplicate_and_parse(*_args, **_kwargs):
        raise RuntimeError("batch failed")

    service_stubs = (
        (sync_data_source.KnowledgebaseService, "get_by_id", lambda *_args, **_kwargs: (True, object())),
        (sync_data_source.SyncLogsService, "increase_docs", lambda *_args, **_kwargs: None),
        (sync_data_source.SyncLogsService, "duplicate_and_parse", _raise_in_duplicate_and_parse),
    )
    for owner, attribute, replacement in service_stubs:
        monkeypatch.setattr(owner, attribute, replacement)

    conf = {
        "host": "localhost",
        "port": 3306,
        "database": "db",
        "query": "SELECT * FROM t",
        "content_columns": "name",
        "timestamp_column": "ts",
        "credentials": {"username": "u", "password": "p"},
        "sync_deleted_files": False,
    }
    task = dict(
        _make_task(),
        reindex="0",
        poll_range_start=datetime(2026, 1, 1, tzinfo=timezone.utc),
        skip_connection_log=True,
    )

    await sync_data_source.MySQL(conf)._run_task_logic(task)

    connector = _FakeRDBMSConnector.instance
    assert connector is not None
    assert connector.persist_sync_state_called is False
|
||||
|
||||
|
||||
class _FakeDropboxConnector:
|
||||
instance = None
|
||||
|
||||
|
||||
@@ -126,6 +126,12 @@ export const DataSourceFeatureVisibilityMap: Partial<
|
||||
[DataSourceKey.RSS]: {
|
||||
syncDeletedFiles: true,
|
||||
},
|
||||
[DataSourceKey.MYSQL]: {
|
||||
syncDeletedFiles: true,
|
||||
},
|
||||
[DataSourceKey.POSTGRESQL]: {
|
||||
syncDeletedFiles: true,
|
||||
},
|
||||
};
|
||||
|
||||
const isDataSourceFeatureVisible = (
|
||||
|
||||
Reference in New Issue
Block a user