mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Feat: enable sync deleted files for RDBMS & fix remove last file issue (#14615)
### What problem does this PR solve?

Feat: enable sync deleted files for RDBMS & fix remove last file issue

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""RDBMS (MySQL/PostgreSQL) data source connector for importing data from relational databases."""
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
@@ -12,8 +13,13 @@ from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
)
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||
from common.data_source.models import Document
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync,
|
||||
)
|
||||
from common.data_source.models import Document, SlimDocument
|
||||
|
||||
|
||||
class DatabaseType(str, Enum):
|
||||
@@ -22,15 +28,18 @@ class DatabaseType(str, Enum):
|
||||
POSTGRESQL = "postgresql"
|
||||
|
||||
|
||||
class RDBMSConnector(LoadConnector, PollConnector):
|
||||
class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""
|
||||
RDBMS connector for importing data from MySQL and PostgreSQL databases.
|
||||
|
||||
This connector allows users to:
|
||||
1. Connect to a MySQL or PostgreSQL database
|
||||
2. Execute a SQL query to extract data
|
||||
3. Map columns to content (for vectorization) and metadata
|
||||
4. Sync data in batch or incremental mode using a timestamp column
|
||||
Import rows from MySQL or PostgreSQL into documents.
|
||||
|
||||
The flow is:
|
||||
1. Connect to the configured database.
|
||||
2. Read rows from a custom SQL query, or from every table when no query is provided.
|
||||
3. Build document content from the selected content columns.
|
||||
4. Copy the selected metadata columns into document metadata.
|
||||
5. Use the configured ID column as the stable document ID, or hash the content when no ID column is set.
|
||||
6. For incremental sync, treat the timestamp column as an ordered cursor and only compare values by size.
|
||||
7. For deleted-file sync, read a slim snapshot of current row IDs and let the sync worker remove stale documents.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
@@ -73,6 +82,9 @@ class RDBMSConnector(LoadConnector, PollConnector):
|
||||
|
||||
self._connection = None
|
||||
self._credentials: Dict[str, Any] = {}
|
||||
self._sync_connector_id: str | None = None
|
||||
self._sync_config: Dict[str, Any] | None = None
|
||||
self._pending_sync_cursor_value: Any = None
|
||||
|
||||
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
"""Load database credentials."""
|
||||
@@ -160,98 +172,175 @@ class RDBMSConnector(LoadConnector, PollConnector):
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def _build_query_with_time_filter(
|
||||
self,
|
||||
start: Optional[datetime] = None,
|
||||
end: Optional[datetime] = None,
|
||||
) -> str:
|
||||
"""Build the query with optional time filtering for incremental sync."""
|
||||
if not self.query:
|
||||
return "" # Will be handled by table discovery
|
||||
base_query = self.query.rstrip(";")
|
||||
|
||||
if not self.timestamp_column or (start is None and end is None):
|
||||
return base_query
|
||||
|
||||
has_where = "where" in base_query.lower()
|
||||
connector = " AND" if has_where else " WHERE"
|
||||
|
||||
time_conditions = []
|
||||
if start is not None:
|
||||
if self.db_type == DatabaseType.MYSQL:
|
||||
time_conditions.append(f"{self.timestamp_column} > '{start.strftime('%Y-%m-%d %H:%M:%S')}'")
|
||||
else:
|
||||
time_conditions.append(f"{self.timestamp_column} > '{start.isoformat()}'")
|
||||
|
||||
if end is not None:
|
||||
if self.db_type == DatabaseType.MYSQL:
|
||||
time_conditions.append(f"{self.timestamp_column} <= '{end.strftime('%Y-%m-%d %H:%M:%S')}'")
|
||||
else:
|
||||
time_conditions.append(f"{self.timestamp_column} <= '{end.isoformat()}'")
|
||||
|
||||
if time_conditions:
|
||||
return f"{base_query}{connector} {' AND '.join(time_conditions)}"
|
||||
|
||||
return base_query
|
||||
|
||||
def _row_to_document(self, row: Union[tuple, list, Dict[str, Any]], column_names: list) -> Document:
|
||||
"""Convert a database row to a Document."""
|
||||
row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row
|
||||
|
||||
def _get_base_queries(self) -> list[str]:
|
||||
if self.query:
|
||||
return [self.query.rstrip(";")]
|
||||
return [f"SELECT * FROM {table}" for table in self._get_tables()]
|
||||
|
||||
|
||||
def _wrap_query(self, base_query: str, select_clause: str = "*") -> str:
|
||||
return f"SELECT {select_clause} FROM ({base_query}) AS ragflow_src"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def serialize_cursor_value(value: Any) -> Any:
|
||||
# Example:
|
||||
# - int cursor 42 is stored as 42
|
||||
# - datetime cursor 2026-05-07T12:34:56+00:00 is stored as
|
||||
# {"__ragflow_rdbms_cursor_type__": "datetime", "value": "..."}
|
||||
# Only datetime needs wrapping because connector config is JSON.
|
||||
if isinstance(value, datetime):
|
||||
return {
|
||||
"__ragflow_rdbms_cursor_type__": "datetime",
|
||||
"value": value.isoformat(),
|
||||
}
|
||||
return value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def deserialize_cursor_value(value: Any) -> Any:
|
||||
# Reverse the datetime wrapper above.
|
||||
# Non-datetime cursors such as int/str/float are returned as-is.
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and value.get("__ragflow_rdbms_cursor_type__") == "datetime"
|
||||
):
|
||||
return datetime.fromisoformat(value["value"])
|
||||
return value
|
||||
|
||||
|
||||
def _format_sql_value(self, value: Any) -> str:
|
||||
if isinstance(value, datetime):
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
if self.db_type == DatabaseType.MYSQL:
|
||||
rendered = value.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
rendered = value.astimezone(timezone.utc).isoformat()
|
||||
return f"'{rendered}'"
|
||||
if isinstance(value, bool):
|
||||
if self.db_type == DatabaseType.POSTGRESQL:
|
||||
return "TRUE" if value else "FALSE"
|
||||
return "1" if value else "0"
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
if isinstance(value, str):
|
||||
return "'" + value.replace("'", "''") + "'"
|
||||
raise ConnectorValidationError(
|
||||
f"Unsupported timestamp cursor value type: {type(value).__name__}"
|
||||
)
|
||||
|
||||
|
||||
def _build_time_filtered_query(
|
||||
self,
|
||||
base_query: str,
|
||||
start: Any = None,
|
||||
end: Any = None,
|
||||
) -> str:
|
||||
if not self.timestamp_column or (start is None and end is None):
|
||||
return self._wrap_query(base_query)
|
||||
|
||||
conditions = []
|
||||
if start is not None:
|
||||
conditions.append(
|
||||
f"ragflow_src.{self.timestamp_column} > {self._format_sql_value(start)}"
|
||||
)
|
||||
if end is not None:
|
||||
conditions.append(
|
||||
f"ragflow_src.{self.timestamp_column} <= {self._format_sql_value(end)}"
|
||||
)
|
||||
|
||||
query = self._wrap_query(base_query)
|
||||
if conditions:
|
||||
query = f"{query} WHERE {' AND '.join(conditions)}"
|
||||
return query
|
||||
|
||||
|
||||
def _build_max_timestamp_query(self, base_query: str) -> str:
|
||||
return (
|
||||
f"SELECT MAX(ragflow_src.{self.timestamp_column}) "
|
||||
f"FROM ({base_query}) AS ragflow_src"
|
||||
)
|
||||
|
||||
|
||||
def _build_slim_query(self, base_query: str) -> str:
|
||||
columns = [self.id_column] if self.id_column else self.content_columns
|
||||
select_clause = ", ".join(f"ragflow_src.{column}" for column in columns)
|
||||
return self._wrap_query(base_query, select_clause)
|
||||
|
||||
|
||||
def _build_content(self, row_dict: Dict[str, Any]) -> str:
|
||||
content_parts = []
|
||||
for col in self.content_columns:
|
||||
if col in row_dict and row_dict[col] is not None:
|
||||
value = row_dict[col]
|
||||
if isinstance(value, (dict, list)):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
# Use brackets around field name and put value on a new line
|
||||
# so that TxtParser preserves field boundaries after chunking.
|
||||
content_parts.append(f"【{col}】:\n{value}")
|
||||
|
||||
content = "\n\n".join(content_parts)
|
||||
|
||||
if self.id_column and self.id_column in row_dict:
|
||||
doc_id = f"{self.db_type}:{self.database}:{row_dict[self.id_column]}"
|
||||
else:
|
||||
content_hash = hashlib.md5(content.encode()).hexdigest()
|
||||
doc_id = f"{self.db_type}:{self.database}:{content_hash}"
|
||||
|
||||
if col not in row_dict or row_dict[col] is None:
|
||||
continue
|
||||
value = row_dict[col]
|
||||
if isinstance(value, (dict, list)):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
content_parts.append(f"【{col}】:\n{value}")
|
||||
return "\n\n".join(content_parts)
|
||||
|
||||
|
||||
def _build_document_id_from_row(self, row_dict: Dict[str, Any]) -> str:
|
||||
if self.id_column and self.id_column in row_dict and row_dict[self.id_column] is not None:
|
||||
return f"{self.db_type}:{self.database}:{row_dict[self.id_column]}"
|
||||
content = self._build_content(row_dict)
|
||||
content_hash = hashlib.md5(content.encode()).hexdigest()
|
||||
return f"{self.db_type}:{self.database}:{content_hash}"
|
||||
|
||||
|
||||
def _row_to_document(
|
||||
self,
|
||||
row: Union[tuple, list, Dict[str, Any]],
|
||||
column_names: list[str],
|
||||
) -> Document:
|
||||
"""Convert a database row to a Document."""
|
||||
row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row
|
||||
content = self._build_content(row_dict)
|
||||
metadata = {}
|
||||
for col in self.metadata_columns:
|
||||
if col in row_dict and row_dict[col] is not None:
|
||||
value = row_dict[col]
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
elif isinstance(value, (dict, list)):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
value = str(value)
|
||||
metadata[col] = value
|
||||
|
||||
if col not in row_dict or row_dict[col] is None:
|
||||
continue
|
||||
value = row_dict[col]
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
elif isinstance(value, (dict, list)):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
value = str(value)
|
||||
metadata[col] = value
|
||||
|
||||
doc_updated_at = datetime.now(timezone.utc)
|
||||
if self.timestamp_column and self.timestamp_column in row_dict:
|
||||
if self.timestamp_column and self.timestamp_column in row_dict and row_dict[self.timestamp_column] is not None:
|
||||
ts_value = row_dict[self.timestamp_column]
|
||||
if isinstance(ts_value, datetime):
|
||||
if ts_value.tzinfo is None:
|
||||
doc_updated_at = ts_value.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
doc_updated_at = ts_value
|
||||
|
||||
first_content_col = self.content_columns[0] if self.content_columns else "record"
|
||||
semantic_id = str(row_dict.get(first_content_col, "database_record")).replace("\n", " ").replace("\r", " ").strip()[:100]
|
||||
doc_updated_at = ts_value.astimezone(timezone.utc)
|
||||
|
||||
first_content_col = self.content_columns[0] if self.content_columns else "record"
|
||||
semantic_id = (
|
||||
str(row_dict.get(first_content_col, "database_record"))
|
||||
.replace("\n", " ")
|
||||
.replace("\r", " ")
|
||||
.strip()[:100]
|
||||
)
|
||||
blob = content.encode("utf-8")
|
||||
|
||||
|
||||
return Document(
|
||||
id=doc_id,
|
||||
blob=content.encode("utf-8"),
|
||||
id=self._build_document_id_from_row(row_dict),
|
||||
blob=blob,
|
||||
source=DocumentSource(self.db_type.value),
|
||||
semantic_identifier=semantic_id,
|
||||
extension=".txt",
|
||||
doc_updated_at=doc_updated_at,
|
||||
size_bytes=len(content.encode("utf-8")),
|
||||
size_bytes=len(blob),
|
||||
metadata=metadata if metadata else None,
|
||||
)
|
||||
|
||||
|
||||
def _yield_documents_from_query(
|
||||
self,
|
||||
query: str,
|
||||
@@ -288,30 +377,146 @@ class RDBMSConnector(LoadConnector, PollConnector):
|
||||
pass
|
||||
cursor.close()
|
||||
|
||||
|
||||
def _yield_slim_documents_from_query(
    self,
    query: str,
) -> Generator[list[SlimDocument], None, None]:
    """Execute ``query`` and yield ID-only documents in batches.

    Each row is turned into a ``SlimDocument`` whose ID comes from
    ``_build_document_id_from_row``. Batches hold up to ``batch_size``
    entries; a final partial batch is yielded when rows remain.
    """
    db_cursor = self._get_connection().cursor()

    try:
        logging.debug(f"Executing slim query: {query[:200]}...")
        db_cursor.execute(query)
        headers = [column[0] for column in db_cursor.description]

        pending: list[SlimDocument] = []
        for record in db_cursor:
            # Sequence rows are paired with the column names; any other row
            # type is passed through unchanged.
            if isinstance(record, (list, tuple)):
                record = dict(zip(headers, record))
            pending.append(SlimDocument(id=self._build_document_id_from_row(record)))
            if len(pending) < self.batch_size:
                continue
            yield pending
            pending = []

        if pending:
            yield pending
    finally:
        # Best effort: consume any unread rows before closing the cursor;
        # failures while doing so are deliberately ignored.
        try:
            db_cursor.fetchall()
        except Exception:
            pass
        db_cursor.close()
|
||||
|
||||
|
||||
def get_max_cursor_value(self) -> Any:
    """Return the largest timestamp-column value across all base queries.

    Returns:
        ``None`` when no timestamp column is configured or when no base
        query produces a non-null maximum; otherwise the greatest value
        found.
    """
    if not self.timestamp_column:
        return None

    highest = None
    db_cursor = self._get_connection().cursor()

    try:
        for base_query in self._get_base_queries():
            max_sql = self._build_max_timestamp_query(base_query)
            logging.debug(f"Executing max timestamp query: {max_sql[:200]}...")
            db_cursor.execute(max_sql)
            result = db_cursor.fetchone()
            candidate = None if result is None else result[0]
            if candidate is None:
                # Nothing to compare for this query.
                continue
            if highest is None or candidate > highest:
                highest = candidate
    finally:
        db_cursor.close()

    return highest
|
||||
|
||||
|
||||
def _yield_documents(
|
||||
self,
|
||||
start: Optional[datetime] = None,
|
||||
end: Optional[datetime] = None,
|
||||
start: Any = None,
|
||||
end: Any = None,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
"""Generate documents from database query results."""
|
||||
if self.query:
|
||||
query = self._build_query_with_time_filter(start, end)
|
||||
yield from self._yield_documents_from_query(query)
|
||||
else:
|
||||
tables = self._get_tables()
|
||||
logging.info(f"No query specified. Loading all {len(tables)} tables: {tables}")
|
||||
for table in tables:
|
||||
query = f"SELECT * FROM {table}"
|
||||
logging.info(f"Loading table: {table}")
|
||||
base_queries = self._get_base_queries()
|
||||
if not self.query:
|
||||
logging.info(f"No query specified. Loading all {len(base_queries)} tables.")
|
||||
|
||||
try:
|
||||
for base_query in base_queries:
|
||||
query = self._build_time_filtered_query(base_query, start, end)
|
||||
yield from self._yield_documents_from_query(query)
|
||||
|
||||
self._close_connection()
|
||||
finally:
|
||||
self._close_connection()
|
||||
|
||||
|
||||
def load_from_state(self) -> Generator[list[Document], None, None]:
    """Load all documents from the database (full sync).

    Delegates to ``_yield_documents`` with no start or end bound.
    """
    logging.debug(f"Loading all records from {self.db_type} database: {self.database}")
    all_documents = self._yield_documents()
    return all_documents
|
||||
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
    self,
    callback: Any = None,
) -> Generator[list[SlimDocument], None, None]:
    """Yield batches of ID-only documents for every current row.

    Args:
        callback: Accepted but not used by this connector.

    The connection is closed once iteration finishes or is abandoned.
    """
    del callback

    base_queries = self._get_base_queries()
    if not self.query:
        logging.info(f"No query specified. Retrieving slim documents from all {len(base_queries)} tables.")

    try:
        for base_query in base_queries:
            slim_sql = self._build_slim_query(base_query)
            yield from self._yield_slim_documents_from_query(slim_sql)
    finally:
        self._close_connection()
|
||||
|
||||
def prepare_sync_state(self, connector_id: str, config: Dict[str, Any]) -> None:
    """Record the sync context and capture the upcoming cursor value.

    Args:
        connector_id: Identifier stored for later persistence.
        config: Connector config; a deep copy is kept so the caller's dict
            is not shared.

    The pending cursor value is the current maximum of the timestamp column,
    or ``None`` when no timestamp column is configured.
    """
    self._sync_connector_id = connector_id
    self._sync_config = copy.deepcopy(config)
    self._pending_sync_cursor_value = (
        self.get_max_cursor_value() if self.timestamp_column else None
    )
|
||||
|
||||
|
||||
def get_saved_sync_cursor_value(self) -> Any:
    """Return the cursor value stored in the prepared sync config.

    Returns:
        ``None`` when no sync config has been prepared; otherwise the
        deserialized ``"sync_cursor_value"`` entry of that config.
    """
    config = self._sync_config
    if config is None:
        return None
    stored = config.get("sync_cursor_value")
    return self.deserialize_cursor_value(stored)
|
||||
|
||||
|
||||
def persist_sync_state(self) -> None:
    """Save the pending cursor value into the connector's stored config.

    Does nothing unless a timestamp column is configured and both the
    connector ID and the sync config have been prepared. Otherwise a copy of
    the config with an updated ``"sync_cursor_value"`` is written through
    ``ConnectorService.update_by_id`` and kept as the current sync config.
    """
    ready = (
        bool(self.timestamp_column)
        and self._sync_connector_id is not None
        and self._sync_config is not None
    )
    if not ready:
        return

    # Imported here rather than at module level, as in the original code.
    from api.db.services.connector_service import ConnectorService

    new_config = copy.deepcopy(self._sync_config)
    new_config["sync_cursor_value"] = self.serialize_cursor_value(
        self._pending_sync_cursor_value
    )
    ConnectorService.update_by_id(self._sync_connector_id, {"config": new_config})
    self._sync_config = new_config
|
||||
|
||||
|
||||
def load_from_cursor_range(
    self,
    start_value: Any = None,
    end_value: Any = None,
) -> Generator[list[Document], None, None]:
    """Load documents whose cursor lies in ``(start_value, end_value]``.

    Args:
        start_value: Exclusive lower cursor bound, or ``None`` for no bound.
        end_value: Inclusive upper cursor bound.

    Returns:
        An empty iterator (after closing the connection) when there is no
        upper bound or when it does not exceed ``start_value``; otherwise
        the document batches produced by ``_yield_documents``.
    """
    range_is_empty = end_value is None or (
        start_value is not None and end_value <= start_value
    )
    if range_is_empty:
        self._close_connection()
        return iter(())
    return self._yield_documents(start_value, end_value)
|
||||
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> Generator[list[Document], None, None]:
|
||||
@@ -322,16 +527,8 @@ class RDBMSConnector(LoadConnector, PollConnector):
|
||||
"Falling back to full sync."
|
||||
)
|
||||
return self.load_from_state()
|
||||
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
logging.debug(
|
||||
f"Polling {self.db_type} database {self.database} "
|
||||
f"from {start_datetime} to {end_datetime}"
|
||||
)
|
||||
|
||||
return self._yield_documents(start_datetime, end_datetime)
|
||||
return self._yield_documents(start, end)
|
||||
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate connector settings by testing the connection."""
|
||||
|
||||
Reference in New Issue
Block a user