From 6ce76e6799521bce60d6e4c825718bc3783495b3 Mon Sep 17 00:00:00 2001 From: Wang Qi Date: Wed, 20 May 2026 19:21:19 +0800 Subject: [PATCH] Fix Discord async issue (#15054) ### What problem does this PR solve? The Discord connector failed with `RuntimeError: Cannot run the event loop while another loop is running` because it drove the async Discord client with `loop.run_until_complete()` on the caller's thread. Retrieval now runs under `asyncio.run()` on a separate thread and passes messages back to the synchronous connector API through a queue. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- common/data_source/discord_connector.py | 100 ++++++++++++++---------- 1 file changed, 60 insertions(+), 40 deletions(-) diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py index 83b2b562f0..e047148f33 100644 --- a/common/data_source/discord_connector.py +++ b/common/data_source/discord_connector.py @@ -3,7 +3,9 @@ import asyncio import logging import os +from contextlib import suppress from datetime import datetime, timezone +from queue import Queue from typing import Any, AsyncIterable, Iterable from discord import Client, MessageType @@ -151,14 +153,14 @@ async def _fetch_documents_from_channel( yield thread_message -def _manage_async_retrieval( +async def _manage_async_retrieval( token: str, requested_start_date_string: str, channel_names: list[str], server_ids: list[int], start: datetime | None = None, -) -> Iterable[DiscordMessage]: - """Bridge the async Discord client into a synchronous iterator. +) -> AsyncIterable[DiscordMessage]: + """Fetch Discord messages with the async Discord client. `start` is only used as a lower bound for the underlying fetch. 
Callers that need a narrower time window should apply their own filtering while @@ -173,11 +175,11 @@ def _manage_async_retrieval( if proxy_url: logging.info(f"Using proxy for Discord: {proxy_url}") - async def _async_fetch() -> AsyncIterable[DiscordMessage]: - intents = Intents.default() - intents.message_content = True - async with Client(intents=intents, proxy=proxy_url) as cli: - asyncio.create_task(coro=cli.start(token)) + intents = Intents.default() + intents.message_content = True + async with Client(intents=intents, proxy=proxy_url) as cli: + client_task = asyncio.create_task(cli.start(token)) + try: await cli.wait_until_ready() filtered_channels: list[TextChannel] = await _fetch_filtered_channels( @@ -192,27 +194,41 @@ def _manage_async_retrieval( start_time=start_time, ): yield message - - def run_and_yield() -> Iterable[DiscordMessage]: - loop = asyncio.new_event_loop() - try: - # Get the async generator - async_gen = _async_fetch() - # Convert to AsyncIterator - async_iter = async_gen.__aiter__() - while True: - try: - # Create a coroutine by calling anext with the async iterator - next_coro = anext(async_iter) - # Run the coroutine to get the next document - doc = loop.run_until_complete(next_coro) - yield doc - except StopAsyncIteration: - break finally: - loop.close() + await cli.close() + client_task.cancel() + with suppress(asyncio.CancelledError): + await client_task - return run_and_yield() + +def _iterate_async_messages(async_messages: AsyncIterable[DiscordMessage]) -> Iterable[DiscordMessage]: + """Expose async Discord retrieval to the existing synchronous connector API.""" + item_queue: Queue[DiscordMessage | BaseException | None] = Queue() + + async def consume_messages() -> None: + async for message in async_messages: + item_queue.put(message) + + def run_consumer() -> None: + try: + asyncio.run(consume_messages()) + except BaseException as exc: + item_queue.put(exc) + finally: + item_queue.put(None) + + consumer_thread = 
Thread(target=run_consumer, name="discord-connector-retrieval", daemon=True) + consumer_thread.start() + + while True: + item = item_queue.get() + if item is None: + break + if isinstance(item, BaseException): + raise item + yield item + + consumer_thread.join() class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): @@ -283,12 +299,14 @@ class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): size_bytes=size_bytes, ) - for message in _manage_async_retrieval( - token=self.discord_bot_token, - requested_start_date_string=self.requested_start_date_string, - channel_names=self.channel_names, - server_ids=self.server_ids, - start=start, + for message in _iterate_async_messages( + _manage_async_retrieval( + token=self.discord_bot_token, + requested_start_date_string=self.requested_start_date_string, + channel_names=self.channel_names, + server_ids=self.server_ids, + start=start, + ) ): if not _is_in_window(message): continue @@ -321,7 +339,7 @@ class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): def validate_connector_settings(self) -> None: """Validate Discord connector settings""" - if not self.discord_client: + if not self.discord_bot_token: raise ConnectorMissingCredentialError("Discord") def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any: @@ -344,12 +362,14 @@ class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): full_scan_batch_size = 0 full_scan_batch_first_id: str | None = None - for message in _manage_async_retrieval( - token=self.discord_bot_token, - requested_start_date_string=self.requested_start_date_string, - channel_names=self.channel_names, - server_ids=self.server_ids, - start=None, + for message in _iterate_async_messages( + _manage_async_retrieval( + token=self.discord_bot_token, + requested_start_date_string=self.requested_start_date_string, + channel_names=self.channel_names, + server_ids=self.server_ids, + 
start=None, + ) ): if full_scan_batch_first_id is None: full_scan_batch_first_id = f"{_DISCORD_DOC_ID_PREFIX}{message.id}"