diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py index e65a632418..83b2b562f0 100644 --- a/common/data_source/discord_connector.py +++ b/common/data_source/discord_connector.py @@ -13,8 +13,14 @@ from discord.message import Message as DiscordMessage from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource from common.data_source.exceptions import ConnectorMissingCredentialError -from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch -from common.data_source.models import Document, GenerateDocumentsOutput, TextSection +from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync +from common.data_source.models import ( + Document, + GenerateDocumentsOutput, + GenerateSlimDocumentOutput, + SlimDocument, + TextSection, +) _DISCORD_DOC_ID_PREFIX = "DISCORD_" _SNIPPET_LENGTH = 30 @@ -94,8 +100,12 @@ async def _fetch_filtered_channels( async def _fetch_documents_from_channel( channel: TextChannel, start_time: datetime | None, - end_time: datetime | None, -) -> AsyncIterable[Document]: +) -> AsyncIterable[DiscordMessage]: + """Yield raw Discord messages for one channel and its threads. + + This stays at the message layer so callers can decide whether they need + full Document construction or only lightweight ID accounting. + """ # Discord's epoch starts at 2015-01-01 discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc) if start_time and start_time < discord_epoch: @@ -109,39 +119,23 @@ async def _fetch_documents_from_channel( async for channel_message in channel.history( limit=None, after=start_time, - before=end_time, ): # Skip messages that are not the default type if channel_message.type != MessageType.default: continue - sections: list[TextSection] = [ - TextSection( - text=channel_message.content, - link=channel_message.jump_url, - ) - ] - - yield _convert_message_to_document(channel_message, sections) + yield channel_message for active_thread in channel.threads: async for thread_message in active_thread.history( limit=None, after=start_time, - before=end_time, ): # Skip messages that are not the default type if thread_message.type != MessageType.default: continue - sections = [ - TextSection( - text=thread_message.content, - link=thread_message.jump_url, - ) - ] - - yield _convert_message_to_document(thread_message, sections) + yield thread_message async for archived_thread in channel.archived_threads( limit=None, @@ -149,20 +143,12 @@ async def _fetch_documents_from_channel( async for thread_message in archived_thread.history( limit=None, after=start_time, - before=end_time, ): # Skip messages that are not the default type if thread_message.type != MessageType.default: continue - sections = [ - TextSection( - text=thread_message.content, - link=thread_message.jump_url, - ) - ] - - yield _convert_message_to_document(thread_message, sections) + yield thread_message def _manage_async_retrieval( @@ -171,20 +157,23 @@ def _manage_async_retrieval( channel_names: list[str], server_ids: list[int], start: datetime | None = None, - end: datetime | None = None, -) -> Iterable[Document]: +) -> Iterable[DiscordMessage]: + """Bridge the async Discord client into a synchronous iterator. + + `start` is only used as a lower bound for the underlying fetch. Callers + that need a narrower time window should apply their own filtering while + iterating so the same full scan can also support deleted-file sync. + """ # parse requested_start_date_string to datetime pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None - # Set start_time to the most recent of start and pull_date, or whichever is provided + # Keep the configured start date as the full-scan lower bound. start_time = max(filter(None, [start, pull_date])) if start or pull_date else None - - end_time: datetime | None = end proxy_url: str | None = os.environ.get("https_proxy") or os.environ.get("http_proxy") if proxy_url: logging.info(f"Using proxy for Discord: {proxy_url}") - async def _async_fetch() -> AsyncIterable[Document]: + async def _async_fetch() -> AsyncIterable[DiscordMessage]: intents = Intents.default() intents.message_content = True async with Client(intents=intents, proxy=proxy_url) as cli: @@ -198,15 +187,13 @@ def _manage_async_retrieval( ) for channel in filtered_channels: - async for doc in _fetch_documents_from_channel( + async for message in _fetch_documents_from_channel( channel=channel, start_time=start_time, - end_time=end_time, ): - print(doc) - yield doc + yield message - def run_and_yield() -> Iterable[Document]: + def run_and_yield() -> Iterable[DiscordMessage]: loop = asyncio.new_event_loop() try: # Get the async generator @@ -228,7 +215,7 @@ def _manage_async_retrieval( return run_and_yield() -class DiscordConnector(LoadConnector, PollConnector): +class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """Discord connector for accessing Discord messages and channels""" def __init__( @@ -251,12 +238,28 @@ class DiscordConnector(LoadConnector, PollConnector): raise ConnectorMissingCredentialError("Discord") return self._discord_bot_token - def _manage_doc_batching( + def _iter_merged_documents( self, start: datetime | None = None, end: datetime | None = None, ) -> GenerateDocumentsOutput: - doc_batch = [] + """Build merged Discord documents for the requested polling window.""" + doc_batch: list[Document] = [] + + def _message_created_at(message: DiscordMessage) -> datetime: + created_at = message.created_at + if created_at.tzinfo is None: + return created_at.replace(tzinfo=timezone.utc) + return created_at.astimezone(timezone.utc) + + def _is_in_window(message: DiscordMessage) -> bool: + created_at = _message_created_at(message) + if start is not None and created_at < start: + return False + if end is not None and created_at >= end: + return False + return True + def merge_batch(): nonlocal doc_batch id = doc_batch[0].id @@ -280,14 +283,23 @@ class DiscordConnector(LoadConnector, PollConnector): size_bytes=size_bytes, ) - for doc in _manage_async_retrieval( + for message in _manage_async_retrieval( token=self.discord_bot_token, requested_start_date_string=self.requested_start_date_string, channel_names=self.channel_names, server_ids=self.server_ids, start=start, - end=end, ): + if not _is_in_window(message): + continue + + sections = [ + TextSection( + text=message.content, + link=message.jump_url, + ) + ] + doc = _convert_message_to_document(message, sections) doc_batch.append(doc) if len(doc_batch) >= self.batch_size: yield [merge_batch()] @@ -296,6 +308,13 @@ class DiscordConnector(LoadConnector, PollConnector): if doc_batch: yield [merge_batch()] + def _manage_doc_batching( + self, + start: datetime | None = None, + end: datetime | None = None, + ) -> GenerateDocumentsOutput: + yield from self._iter_merged_documents(start=start, end=end) + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self._discord_bot_token = credentials["discord_bot_token"] return None @@ -316,6 +335,41 @@ class DiscordConnector(LoadConnector, PollConnector): """Load messages from Discord state""" return self._manage_doc_batching(None, None) + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + del callback + slim_doc_batch: list[SlimDocument] = [] + full_scan_batch_size = 0 + full_scan_batch_first_id: str | None = None + + for message in _manage_async_retrieval( + token=self.discord_bot_token, + requested_start_date_string=self.requested_start_date_string, + channel_names=self.channel_names, + server_ids=self.server_ids, + start=None, + ): + if full_scan_batch_first_id is None: + full_scan_batch_first_id = f"{_DISCORD_DOC_ID_PREFIX}{message.id}" + full_scan_batch_size += 1 + + if full_scan_batch_size >= self.batch_size: + slim_doc_batch.append(SlimDocument(id=full_scan_batch_first_id)) + full_scan_batch_size = 0 + full_scan_batch_first_id = None + + if len(slim_doc_batch) >= self.batch_size: + yield slim_doc_batch + slim_doc_batch = [] + + if full_scan_batch_first_id is not None: + slim_doc_batch.append(SlimDocument(id=full_scan_batch_first_id)) + + if slim_doc_batch: + yield slim_doc_batch + if __name__ == "__main__": import os diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 5d36a957f5..a3afbba902 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -248,7 +248,20 @@ class SyncBase: prefix = self._get_source_prefix() prefix = f"{prefix} " if prefix else "" next_update_info = self._format_window_boundary(next_update) - if file_list == []: + expects_deleted_file_snapshot = ( + task.get("reindex") != "1" + and task.get("poll_range_start") + and self.conf.get("sync_deleted_files") + ) + if expects_deleted_file_snapshot and file_list is None: + logging.warning( + "%s deleted-file snapshot retrieval failed " + "(connector_id=%s, kb_id=%s)", + self.SOURCE_NAME, + task["connector_id"], + task["kb_id"], + ) + elif file_list == []: logging.warning( "%s deleted-file sync skipped because the snapshot was empty " "(connector_id=%s, kb_id=%s)", @@ -340,9 +353,7 @@ class _BlobLikeBase(SyncBase): _begin_info, ) ) - if file_list is not None: - return document_batch_generator, file_list - return document_batch_generator + return document_batch_generator, file_list class S3(_BlobLikeBase): @@ -508,9 +519,7 @@ class Notion(SyncBase): _begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format( task["poll_range_start"]) self.log_connection("Notion", f"root({self.conf['root_page_id']})", task) - if file_list is not None: - return document_generator, file_list - return document_generator + return document_generator, file_list class Discord(SyncBase): @@ -528,17 +537,26 @@ class Discord(SyncBase): batch_size=self.conf.get("batch_size", 1024), ) self.connector.load_credentials(self.conf["credentials"]) + file_list = None document_generator = ( self.connector.load_from_state() if task["reindex"] == "1" or not task["poll_range_start"] else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp()) ) + if ( + task["reindex"] != "1" + and task["poll_range_start"] + and self.conf.get("sync_deleted_files") + ): + file_list = [] + for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): + file_list.extend(slim_batch) _begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format( task["poll_range_start"]) self.log_connection("Discord", f"servers({server_ids}), channel({channel_names})", task) - return document_generator + return document_generator, file_list class Gmail(SyncBase): @@ -847,9 +865,7 @@ class Jira(SyncBase): f"overlap_buffer_s={getattr(self.connector, 'time_buffer_seconds', connector_kwargs.get('time_buffer_seconds'))}" ), ) - if file_list is not None: - return document_batches(), file_list - return document_batches() + return document_batches(), file_list @staticmethod def _normalize_list(values: Any) -> list[str] | None: @@ -979,9 +995,7 @@ class BOX(SyncBase): ) _begin_info = f"from {poll_start}" self.log_connection("Box", f"folder_id({self.conf['folder_id']})", task) - if file_list is not None: - return document_generator, file_list - return document_generator + return document_generator, file_list class Airtable(SyncBase): @@ -1028,9 +1042,7 @@ class Airtable(SyncBase): task, ) - if file_list is not None: - return document_generator, file_list - return document_generator + return document_generator, file_list class Asana(SyncBase): SOURCE_NAME: str = FileSource.ASANA diff --git a/web/src/pages/user-setting/data-source/constant/index.tsx b/web/src/pages/user-setting/data-source/constant/index.tsx index efe1c687e4..32619c05f0 100644 --- a/web/src/pages/user-setting/data-source/constant/index.tsx +++ b/web/src/pages/user-setting/data-source/constant/index.tsx @@ -54,7 +54,9 @@ type DataSourceFeatureVisibility = { type DataSourceFormValues = Record; -export const DataSourceFeatureVisibilityMap = { +export const DataSourceFeatureVisibilityMap: Partial< + Record +> = { [DataSourceKey.GITHUB]: { syncDeletedFiles: true, }, @@ -91,6 +93,9 @@ export const DataSourceFeatureVisibilityMap = { [DataSourceKey.NOTION]: { syncDeletedFiles: true, }, + [DataSourceKey.DISCORD]: { + syncDeletedFiles: true, + }, [DataSourceKey.JIRA]: { syncDeletedFiles: true, },