mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Feat: enable sync deleted file for Discord (#14451)
### What problem does this PR solve? Feat: enable sync deleted file for Discord ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -13,8 +13,14 @@ from discord.message import Message as DiscordMessage
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
|
||||
from common.data_source.exceptions import ConnectorMissingCredentialError
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||
from common.data_source.models import Document, GenerateDocumentsOutput, TextSection
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
|
||||
from common.data_source.models import (
|
||||
Document,
|
||||
GenerateDocumentsOutput,
|
||||
GenerateSlimDocumentOutput,
|
||||
SlimDocument,
|
||||
TextSection,
|
||||
)
|
||||
|
||||
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
|
||||
_SNIPPET_LENGTH = 30
|
||||
@@ -94,8 +100,12 @@ async def _fetch_filtered_channels(
|
||||
async def _fetch_documents_from_channel(
|
||||
channel: TextChannel,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
) -> AsyncIterable[Document]:
|
||||
) -> AsyncIterable[DiscordMessage]:
|
||||
"""Yield raw Discord messages for one channel and its threads.
|
||||
|
||||
This stays at the message layer so callers can decide whether they need
|
||||
full Document construction or only lightweight ID accounting.
|
||||
"""
|
||||
# Discord's epoch starts at 2015-01-01
|
||||
discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
|
||||
if start_time and start_time < discord_epoch:
|
||||
@@ -109,39 +119,23 @@ async def _fetch_documents_from_channel(
|
||||
async for channel_message in channel.history(
|
||||
limit=None,
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if channel_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections: list[TextSection] = [
|
||||
TextSection(
|
||||
text=channel_message.content,
|
||||
link=channel_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(channel_message, sections)
|
||||
yield channel_message
|
||||
|
||||
for active_thread in channel.threads:
|
||||
async for thread_message in active_thread.history(
|
||||
limit=None,
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if thread_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections = [
|
||||
TextSection(
|
||||
text=thread_message.content,
|
||||
link=thread_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(thread_message, sections)
|
||||
yield thread_message
|
||||
|
||||
async for archived_thread in channel.archived_threads(
|
||||
limit=None,
|
||||
@@ -149,20 +143,12 @@ async def _fetch_documents_from_channel(
|
||||
async for thread_message in archived_thread.history(
|
||||
limit=None,
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if thread_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections = [
|
||||
TextSection(
|
||||
text=thread_message.content,
|
||||
link=thread_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(thread_message, sections)
|
||||
yield thread_message
|
||||
|
||||
|
||||
def _manage_async_retrieval(
|
||||
@@ -171,20 +157,23 @@ def _manage_async_retrieval(
|
||||
channel_names: list[str],
|
||||
server_ids: list[int],
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> Iterable[Document]:
|
||||
) -> Iterable[DiscordMessage]:
|
||||
"""Bridge the async Discord client into a synchronous iterator.
|
||||
|
||||
`start` is only used as a lower bound for the underlying fetch. Callers
|
||||
that need a narrower time window should apply their own filtering while
|
||||
iterating so the same full scan can also support deleted-file sync.
|
||||
"""
|
||||
# parse requested_start_date_string to datetime
|
||||
pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
|
||||
|
||||
# Set start_time to the most recent of start and pull_date, or whichever is provided
|
||||
# Keep the configured start date as the full-scan lower bound.
|
||||
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
|
||||
|
||||
end_time: datetime | None = end
|
||||
proxy_url: str | None = os.environ.get("https_proxy") or os.environ.get("http_proxy")
|
||||
if proxy_url:
|
||||
logging.info(f"Using proxy for Discord: {proxy_url}")
|
||||
|
||||
async def _async_fetch() -> AsyncIterable[Document]:
|
||||
async def _async_fetch() -> AsyncIterable[DiscordMessage]:
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents, proxy=proxy_url) as cli:
|
||||
@@ -198,15 +187,13 @@ def _manage_async_retrieval(
|
||||
)
|
||||
|
||||
for channel in filtered_channels:
|
||||
async for doc in _fetch_documents_from_channel(
|
||||
async for message in _fetch_documents_from_channel(
|
||||
channel=channel,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
):
|
||||
print(doc)
|
||||
yield doc
|
||||
yield message
|
||||
|
||||
def run_and_yield() -> Iterable[Document]:
|
||||
def run_and_yield() -> Iterable[DiscordMessage]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
@@ -228,7 +215,7 @@ def _manage_async_retrieval(
|
||||
return run_and_yield()
|
||||
|
||||
|
||||
class DiscordConnector(LoadConnector, PollConnector):
|
||||
class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Discord connector for accessing Discord messages and channels"""
|
||||
|
||||
def __init__(
|
||||
@@ -251,12 +238,28 @@ class DiscordConnector(LoadConnector, PollConnector):
|
||||
raise ConnectorMissingCredentialError("Discord")
|
||||
return self._discord_bot_token
|
||||
|
||||
def _manage_doc_batching(
|
||||
def _iter_merged_documents(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch = []
|
||||
"""Build merged Discord documents for the requested polling window."""
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
def _message_created_at(message: DiscordMessage) -> datetime:
|
||||
created_at = message.created_at
|
||||
if created_at.tzinfo is None:
|
||||
return created_at.replace(tzinfo=timezone.utc)
|
||||
return created_at.astimezone(timezone.utc)
|
||||
|
||||
def _is_in_window(message: DiscordMessage) -> bool:
|
||||
created_at = _message_created_at(message)
|
||||
if start is not None and created_at < start:
|
||||
return False
|
||||
if end is not None and created_at >= end:
|
||||
return False
|
||||
return True
|
||||
|
||||
def merge_batch():
|
||||
nonlocal doc_batch
|
||||
id = doc_batch[0].id
|
||||
@@ -280,14 +283,23 @@ class DiscordConnector(LoadConnector, PollConnector):
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
|
||||
for doc in _manage_async_retrieval(
|
||||
for message in _manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
channel_names=self.channel_names,
|
||||
server_ids=self.server_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if not _is_in_window(message):
|
||||
continue
|
||||
|
||||
sections = [
|
||||
TextSection(
|
||||
text=message.content,
|
||||
link=message.jump_url,
|
||||
)
|
||||
]
|
||||
doc = _convert_message_to_document(message, sections)
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield [merge_batch()]
|
||||
@@ -296,6 +308,13 @@ class DiscordConnector(LoadConnector, PollConnector):
|
||||
if doc_batch:
|
||||
yield [merge_batch()]
|
||||
|
||||
def _manage_doc_batching(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
yield from self._iter_merged_documents(start=start, end=end)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._discord_bot_token = credentials["discord_bot_token"]
|
||||
return None
|
||||
@@ -316,6 +335,41 @@ class DiscordConnector(LoadConnector, PollConnector):
|
||||
"""Load messages from Discord state"""
|
||||
return self._manage_doc_batching(None, None)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
callback: Any = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
del callback
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
full_scan_batch_size = 0
|
||||
full_scan_batch_first_id: str | None = None
|
||||
|
||||
for message in _manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
channel_names=self.channel_names,
|
||||
server_ids=self.server_ids,
|
||||
start=None,
|
||||
):
|
||||
if full_scan_batch_first_id is None:
|
||||
full_scan_batch_first_id = f"{_DISCORD_DOC_ID_PREFIX}{message.id}"
|
||||
full_scan_batch_size += 1
|
||||
|
||||
if full_scan_batch_size >= self.batch_size:
|
||||
slim_doc_batch.append(SlimDocument(id=full_scan_batch_first_id))
|
||||
full_scan_batch_size = 0
|
||||
full_scan_batch_first_id = None
|
||||
|
||||
if len(slim_doc_batch) >= self.batch_size:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
|
||||
if full_scan_batch_first_id is not None:
|
||||
slim_doc_batch.append(SlimDocument(id=full_scan_batch_first_id))
|
||||
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
@@ -248,7 +248,20 @@ class SyncBase:
|
||||
prefix = self._get_source_prefix()
|
||||
prefix = f"{prefix} " if prefix else ""
|
||||
next_update_info = self._format_window_boundary(next_update)
|
||||
if file_list == []:
|
||||
expects_deleted_file_snapshot = (
|
||||
task.get("reindex") != "1"
|
||||
and task.get("poll_range_start")
|
||||
and self.conf.get("sync_deleted_files")
|
||||
)
|
||||
if expects_deleted_file_snapshot and file_list is None:
|
||||
logging.warning(
|
||||
"%s deleted-file snapshot retrieval failed "
|
||||
"(connector_id=%s, kb_id=%s)",
|
||||
self.SOURCE_NAME,
|
||||
task["connector_id"],
|
||||
task["kb_id"],
|
||||
)
|
||||
elif file_list == []:
|
||||
logging.warning(
|
||||
"%s deleted-file sync skipped because the snapshot was empty "
|
||||
"(connector_id=%s, kb_id=%s)",
|
||||
@@ -340,9 +353,7 @@ class _BlobLikeBase(SyncBase):
|
||||
_begin_info,
|
||||
)
|
||||
)
|
||||
if file_list is not None:
|
||||
return document_batch_generator, file_list
|
||||
return document_batch_generator
|
||||
return document_batch_generator, file_list
|
||||
|
||||
|
||||
class S3(_BlobLikeBase):
|
||||
@@ -508,9 +519,7 @@ class Notion(SyncBase):
|
||||
_begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
|
||||
task["poll_range_start"])
|
||||
self.log_connection("Notion", f"root({self.conf['root_page_id']})", task)
|
||||
if file_list is not None:
|
||||
return document_generator, file_list
|
||||
return document_generator
|
||||
return document_generator, file_list
|
||||
|
||||
|
||||
class Discord(SyncBase):
|
||||
@@ -528,17 +537,26 @@ class Discord(SyncBase):
|
||||
batch_size=self.conf.get("batch_size", 1024),
|
||||
)
|
||||
self.connector.load_credentials(self.conf["credentials"])
|
||||
file_list = None
|
||||
document_generator = (
|
||||
self.connector.load_from_state()
|
||||
if task["reindex"] == "1" or not task["poll_range_start"]
|
||||
else self.connector.poll_source(task["poll_range_start"].timestamp(),
|
||||
datetime.now(timezone.utc).timestamp())
|
||||
)
|
||||
if (
|
||||
task["reindex"] != "1"
|
||||
and task["poll_range_start"]
|
||||
and self.conf.get("sync_deleted_files")
|
||||
):
|
||||
file_list = []
|
||||
for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync():
|
||||
file_list.extend(slim_batch)
|
||||
|
||||
_begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
|
||||
task["poll_range_start"])
|
||||
self.log_connection("Discord", f"servers({server_ids}), channel({channel_names})", task)
|
||||
return document_generator
|
||||
return document_generator, file_list
|
||||
|
||||
|
||||
class Gmail(SyncBase):
|
||||
@@ -847,9 +865,7 @@ class Jira(SyncBase):
|
||||
f"overlap_buffer_s={getattr(self.connector, 'time_buffer_seconds', connector_kwargs.get('time_buffer_seconds'))}"
|
||||
),
|
||||
)
|
||||
if file_list is not None:
|
||||
return document_batches(), file_list
|
||||
return document_batches()
|
||||
return document_batches(), file_list
|
||||
|
||||
@staticmethod
|
||||
def _normalize_list(values: Any) -> list[str] | None:
|
||||
@@ -979,9 +995,7 @@ class BOX(SyncBase):
|
||||
)
|
||||
_begin_info = f"from {poll_start}"
|
||||
self.log_connection("Box", f"folder_id({self.conf['folder_id']})", task)
|
||||
if file_list is not None:
|
||||
return document_generator, file_list
|
||||
return document_generator
|
||||
return document_generator, file_list
|
||||
|
||||
|
||||
class Airtable(SyncBase):
|
||||
@@ -1028,9 +1042,7 @@ class Airtable(SyncBase):
|
||||
task,
|
||||
)
|
||||
|
||||
if file_list is not None:
|
||||
return document_generator, file_list
|
||||
return document_generator
|
||||
return document_generator, file_list
|
||||
|
||||
class Asana(SyncBase):
|
||||
SOURCE_NAME: str = FileSource.ASANA
|
||||
|
||||
@@ -54,7 +54,9 @@ type DataSourceFeatureVisibility = {
|
||||
|
||||
type DataSourceFormValues = Record<string, any>;
|
||||
|
||||
export const DataSourceFeatureVisibilityMap = {
|
||||
export const DataSourceFeatureVisibilityMap: Partial<
|
||||
Record<DataSourceKey, DataSourceFeatureVisibility>
|
||||
> = {
|
||||
[DataSourceKey.GITHUB]: {
|
||||
syncDeletedFiles: true,
|
||||
},
|
||||
@@ -91,6 +93,9 @@ export const DataSourceFeatureVisibilityMap = {
|
||||
[DataSourceKey.NOTION]: {
|
||||
syncDeletedFiles: true,
|
||||
},
|
||||
[DataSourceKey.DISCORD]: {
|
||||
syncDeletedFiles: true,
|
||||
},
|
||||
[DataSourceKey.JIRA]: {
|
||||
syncDeletedFiles: true,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user