Feat: enable sync deleted file for Discord (#14451)

### What problem does this PR solve?

Feat: enable sync deleted file for Discord

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Magicbook1108
2026-04-29 19:05:40 +08:00
committed by GitHub
parent 2bc8c6d35e
commit de8c6ad0f3
3 changed files with 136 additions and 65 deletions

View File

@@ -13,8 +13,14 @@ from discord.message import Message as DiscordMessage
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
from common.data_source.models import Document, GenerateDocumentsOutput, TextSection
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
from common.data_source.models import (
Document,
GenerateDocumentsOutput,
GenerateSlimDocumentOutput,
SlimDocument,
TextSection,
)
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
_SNIPPET_LENGTH = 30
@@ -94,8 +100,12 @@ async def _fetch_filtered_channels(
async def _fetch_documents_from_channel(
channel: TextChannel,
start_time: datetime | None,
end_time: datetime | None,
) -> AsyncIterable[Document]:
) -> AsyncIterable[DiscordMessage]:
"""Yield raw Discord messages for one channel and its threads.
This stays at the message layer so callers can decide whether they need
full Document construction or only lightweight ID accounting.
"""
# Discord's epoch starts at 2015-01-01
discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
if start_time and start_time < discord_epoch:
@@ -109,39 +119,23 @@ async def _fetch_documents_from_channel(
async for channel_message in channel.history(
limit=None,
after=start_time,
before=end_time,
):
# Skip messages that are not the default type
if channel_message.type != MessageType.default:
continue
sections: list[TextSection] = [
TextSection(
text=channel_message.content,
link=channel_message.jump_url,
)
]
yield _convert_message_to_document(channel_message, sections)
yield channel_message
for active_thread in channel.threads:
async for thread_message in active_thread.history(
limit=None,
after=start_time,
before=end_time,
):
# Skip messages that are not the default type
if thread_message.type != MessageType.default:
continue
sections = [
TextSection(
text=thread_message.content,
link=thread_message.jump_url,
)
]
yield _convert_message_to_document(thread_message, sections)
yield thread_message
async for archived_thread in channel.archived_threads(
limit=None,
@@ -149,20 +143,12 @@ async def _fetch_documents_from_channel(
async for thread_message in archived_thread.history(
limit=None,
after=start_time,
before=end_time,
):
# Skip messages that are not the default type
if thread_message.type != MessageType.default:
continue
sections = [
TextSection(
text=thread_message.content,
link=thread_message.jump_url,
)
]
yield _convert_message_to_document(thread_message, sections)
yield thread_message
def _manage_async_retrieval(
@@ -171,20 +157,23 @@ def _manage_async_retrieval(
channel_names: list[str],
server_ids: list[int],
start: datetime | None = None,
end: datetime | None = None,
) -> Iterable[Document]:
) -> Iterable[DiscordMessage]:
"""Bridge the async Discord client into a synchronous iterator.
`start` is only used as a lower bound for the underlying fetch. Callers
that need a narrower time window should apply their own filtering while
iterating so the same full scan can also support deleted-file sync.
"""
# parse requested_start_date_string to datetime
pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
# Set start_time to the most recent of start and pull_date, or whichever is provided
# Keep the configured start date as the full-scan lower bound.
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
end_time: datetime | None = end
proxy_url: str | None = os.environ.get("https_proxy") or os.environ.get("http_proxy")
if proxy_url:
logging.info(f"Using proxy for Discord: {proxy_url}")
async def _async_fetch() -> AsyncIterable[Document]:
async def _async_fetch() -> AsyncIterable[DiscordMessage]:
intents = Intents.default()
intents.message_content = True
async with Client(intents=intents, proxy=proxy_url) as cli:
@@ -198,15 +187,13 @@ def _manage_async_retrieval(
)
for channel in filtered_channels:
async for doc in _fetch_documents_from_channel(
async for message in _fetch_documents_from_channel(
channel=channel,
start_time=start_time,
end_time=end_time,
):
print(doc)
yield doc
yield message
def run_and_yield() -> Iterable[Document]:
def run_and_yield() -> Iterable[DiscordMessage]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
@@ -228,7 +215,7 @@ def _manage_async_retrieval(
return run_and_yield()
class DiscordConnector(LoadConnector, PollConnector):
class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Discord connector for accessing Discord messages and channels"""
def __init__(
@@ -251,12 +238,28 @@ class DiscordConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("Discord")
return self._discord_bot_token
def _manage_doc_batching(
def _iter_merged_documents(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateDocumentsOutput:
doc_batch = []
"""Build merged Discord documents for the requested polling window."""
doc_batch: list[Document] = []
def _message_created_at(message: DiscordMessage) -> datetime:
created_at = message.created_at
if created_at.tzinfo is None:
return created_at.replace(tzinfo=timezone.utc)
return created_at.astimezone(timezone.utc)
def _is_in_window(message: DiscordMessage) -> bool:
created_at = _message_created_at(message)
if start is not None and created_at < start:
return False
if end is not None and created_at >= end:
return False
return True
def merge_batch():
nonlocal doc_batch
id = doc_batch[0].id
@@ -280,14 +283,23 @@ class DiscordConnector(LoadConnector, PollConnector):
size_bytes=size_bytes,
)
for doc in _manage_async_retrieval(
for message in _manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,
channel_names=self.channel_names,
server_ids=self.server_ids,
start=start,
end=end,
):
if not _is_in_window(message):
continue
sections = [
TextSection(
text=message.content,
link=message.jump_url,
)
]
doc = _convert_message_to_document(message, sections)
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield [merge_batch()]
@@ -296,6 +308,13 @@ class DiscordConnector(LoadConnector, PollConnector):
if doc_batch:
yield [merge_batch()]
def _manage_doc_batching(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateDocumentsOutput:
yield from self._iter_merged_documents(start=start, end=end)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._discord_bot_token = credentials["discord_bot_token"]
return None
@@ -316,6 +335,41 @@ class DiscordConnector(LoadConnector, PollConnector):
"""Load messages from Discord state"""
return self._manage_doc_batching(None, None)
def retrieve_all_slim_docs_perm_sync(
self,
callback: Any = None,
) -> GenerateSlimDocumentOutput:
del callback
slim_doc_batch: list[SlimDocument] = []
full_scan_batch_size = 0
full_scan_batch_first_id: str | None = None
for message in _manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,
channel_names=self.channel_names,
server_ids=self.server_ids,
start=None,
):
if full_scan_batch_first_id is None:
full_scan_batch_first_id = f"{_DISCORD_DOC_ID_PREFIX}{message.id}"
full_scan_batch_size += 1
if full_scan_batch_size >= self.batch_size:
slim_doc_batch.append(SlimDocument(id=full_scan_batch_first_id))
full_scan_batch_size = 0
full_scan_batch_first_id = None
if len(slim_doc_batch) >= self.batch_size:
yield slim_doc_batch
slim_doc_batch = []
if full_scan_batch_first_id is not None:
slim_doc_batch.append(SlimDocument(id=full_scan_batch_first_id))
if slim_doc_batch:
yield slim_doc_batch
if __name__ == "__main__":
import os

View File

@@ -248,7 +248,20 @@ class SyncBase:
prefix = self._get_source_prefix()
prefix = f"{prefix} " if prefix else ""
next_update_info = self._format_window_boundary(next_update)
if file_list == []:
expects_deleted_file_snapshot = (
task.get("reindex") != "1"
and task.get("poll_range_start")
and self.conf.get("sync_deleted_files")
)
if expects_deleted_file_snapshot and file_list is None:
logging.warning(
"%s deleted-file snapshot retrieval failed "
"(connector_id=%s, kb_id=%s)",
self.SOURCE_NAME,
task["connector_id"],
task["kb_id"],
)
elif file_list == []:
logging.warning(
"%s deleted-file sync skipped because the snapshot was empty "
"(connector_id=%s, kb_id=%s)",
@@ -340,9 +353,7 @@ class _BlobLikeBase(SyncBase):
_begin_info,
)
)
if file_list is not None:
return document_batch_generator, file_list
return document_batch_generator
return document_batch_generator, file_list
class S3(_BlobLikeBase):
@@ -508,9 +519,7 @@ class Notion(SyncBase):
_begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
task["poll_range_start"])
self.log_connection("Notion", f"root({self.conf['root_page_id']})", task)
if file_list is not None:
return document_generator, file_list
return document_generator
return document_generator, file_list
class Discord(SyncBase):
@@ -528,17 +537,26 @@ class Discord(SyncBase):
batch_size=self.conf.get("batch_size", 1024),
)
self.connector.load_credentials(self.conf["credentials"])
file_list = None
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(),
datetime.now(timezone.utc).timestamp())
)
if (
task["reindex"] != "1"
and task["poll_range_start"]
and self.conf.get("sync_deleted_files")
):
file_list = []
for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync():
file_list.extend(slim_batch)
_begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
task["poll_range_start"])
self.log_connection("Discord", f"servers({server_ids}), channel({channel_names})", task)
return document_generator
return document_generator, file_list
class Gmail(SyncBase):
@@ -847,9 +865,7 @@ class Jira(SyncBase):
f"overlap_buffer_s={getattr(self.connector, 'time_buffer_seconds', connector_kwargs.get('time_buffer_seconds'))}"
),
)
if file_list is not None:
return document_batches(), file_list
return document_batches()
return document_batches(), file_list
@staticmethod
def _normalize_list(values: Any) -> list[str] | None:
@@ -979,9 +995,7 @@ class BOX(SyncBase):
)
_begin_info = f"from {poll_start}"
self.log_connection("Box", f"folder_id({self.conf['folder_id']})", task)
if file_list is not None:
return document_generator, file_list
return document_generator
return document_generator, file_list
class Airtable(SyncBase):
@@ -1028,9 +1042,7 @@ class Airtable(SyncBase):
task,
)
if file_list is not None:
return document_generator, file_list
return document_generator
return document_generator, file_list
class Asana(SyncBase):
SOURCE_NAME: str = FileSource.ASANA

View File

@@ -54,7 +54,9 @@ type DataSourceFeatureVisibility = {
type DataSourceFormValues = Record<string, any>;
export const DataSourceFeatureVisibilityMap = {
export const DataSourceFeatureVisibilityMap: Partial<
Record<DataSourceKey, DataSourceFeatureVisibility>
> = {
[DataSourceKey.GITHUB]: {
syncDeletedFiles: true,
},
@@ -91,6 +93,9 @@ export const DataSourceFeatureVisibilityMap = {
[DataSourceKey.NOTION]: {
syncDeletedFiles: true,
},
[DataSourceKey.DISCORD]: {
syncDeletedFiles: true,
},
[DataSourceKey.JIRA]: {
syncDeletedFiles: true,
},