Fix discord async issue (#15054)

### What problem does this PR solve?

RuntimeError: Cannot run the event loop while another loop is running

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Wang Qi
2026-05-20 19:21:19 +08:00
committed by GitHub
parent b28e134944
commit 6ce76e6799

View File

@@ -3,7 +3,9 @@
import asyncio
import logging
import os
from contextlib import suppress
from datetime import datetime, timezone
from queue import Queue
from typing import Any, AsyncIterable, Iterable
from discord import Client, MessageType
@@ -151,14 +153,14 @@ async def _fetch_documents_from_channel(
yield thread_message
def _manage_async_retrieval(
async def _manage_async_retrieval(
token: str,
requested_start_date_string: str,
channel_names: list[str],
server_ids: list[int],
start: datetime | None = None,
) -> Iterable[DiscordMessage]:
"""Bridge the async Discord client into a synchronous iterator.
) -> AsyncIterable[DiscordMessage]:
"""Fetch Discord messages with the async Discord client.
`start` is only used as a lower bound for the underlying fetch. Callers
that need a narrower time window should apply their own filtering while
@@ -173,11 +175,11 @@ def _manage_async_retrieval(
if proxy_url:
logging.info(f"Using proxy for Discord: {proxy_url}")
async def _async_fetch() -> AsyncIterable[DiscordMessage]:
intents = Intents.default()
intents.message_content = True
async with Client(intents=intents, proxy=proxy_url) as cli:
asyncio.create_task(coro=cli.start(token))
intents = Intents.default()
intents.message_content = True
async with Client(intents=intents, proxy=proxy_url) as cli:
client_task = asyncio.create_task(cli.start(token))
try:
await cli.wait_until_ready()
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
@@ -192,27 +194,41 @@ def _manage_async_retrieval(
start_time=start_time,
):
yield message
def run_and_yield() -> Iterable[DiscordMessage]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
async_gen = _async_fetch()
# Convert to AsyncIterator
async_iter = async_gen.__aiter__()
while True:
try:
# Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter)
# Run the coroutine to get the next document
doc = loop.run_until_complete(next_coro)
yield doc
except StopAsyncIteration:
break
finally:
loop.close()
await cli.close()
client_task.cancel()
with suppress(asyncio.CancelledError):
await client_task
return run_and_yield()
def _iterate_async_messages(async_messages: AsyncIterable[DiscordMessage]) -> Iterable[DiscordMessage]:
"""Expose async Discord retrieval to the existing synchronous connector API."""
item_queue: Queue[DiscordMessage | BaseException | None] = Queue()
async def consume_messages() -> None:
async for message in async_messages:
item_queue.put(message)
def run_consumer() -> None:
try:
asyncio.run(consume_messages())
except BaseException as exc:
item_queue.put(exc)
finally:
item_queue.put(None)
consumer_thread = Thread(target=run_consumer, name="discord-connector-retrieval", daemon=True)
consumer_thread.start()
while True:
item = item_queue.get()
if item is None:
break
if isinstance(item, BaseException):
raise item
yield item
consumer_thread.join()
class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
@@ -283,12 +299,14 @@ class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
size_bytes=size_bytes,
)
for message in _manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,
channel_names=self.channel_names,
server_ids=self.server_ids,
start=start,
for message in _iterate_async_messages(
_manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,
channel_names=self.channel_names,
server_ids=self.server_ids,
start=start,
)
):
if not _is_in_window(message):
continue
@@ -321,7 +339,7 @@ class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
def validate_connector_settings(self) -> None:
"""Validate Discord connector settings"""
if not self.discord_client:
if not self.discord_bot_token:
raise ConnectorMissingCredentialError("Discord")
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
@@ -344,12 +362,14 @@ class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
full_scan_batch_size = 0
full_scan_batch_first_id: str | None = None
for message in _manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,
channel_names=self.channel_names,
server_ids=self.server_ids,
start=None,
for message in _iterate_async_messages(
_manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,
channel_names=self.channel_names,
server_ids=self.server_ids,
start=None,
)
):
if full_scan_batch_first_id is None:
full_scan_batch_first_id = f"{_DISCORD_DOC_ID_PREFIX}{message.id}"