mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix discord async issue (#15054)
### What problem does this PR solve? RuntimeError: Cannot run the event loop while another loop is running ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -3,7 +3,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from datetime import datetime, timezone
|
||||
from queue import Queue
|
||||
from typing import Any, AsyncIterable, Iterable
|
||||
|
||||
from discord import Client, MessageType
|
||||
@@ -151,14 +153,14 @@ async def _fetch_documents_from_channel(
|
||||
yield thread_message
|
||||
|
||||
|
||||
def _manage_async_retrieval(
|
||||
async def _manage_async_retrieval(
|
||||
token: str,
|
||||
requested_start_date_string: str,
|
||||
channel_names: list[str],
|
||||
server_ids: list[int],
|
||||
start: datetime | None = None,
|
||||
) -> Iterable[DiscordMessage]:
|
||||
"""Bridge the async Discord client into a synchronous iterator.
|
||||
) -> AsyncIterable[DiscordMessage]:
|
||||
"""Fetch Discord messages with the async Discord client.
|
||||
|
||||
`start` is only used as a lower bound for the underlying fetch. Callers
|
||||
that need a narrower time window should apply their own filtering while
|
||||
@@ -173,11 +175,11 @@ def _manage_async_retrieval(
|
||||
if proxy_url:
|
||||
logging.info(f"Using proxy for Discord: {proxy_url}")
|
||||
|
||||
async def _async_fetch() -> AsyncIterable[DiscordMessage]:
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents, proxy=proxy_url) as cli:
|
||||
asyncio.create_task(coro=cli.start(token))
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents, proxy=proxy_url) as cli:
|
||||
client_task = asyncio.create_task(cli.start(token))
|
||||
try:
|
||||
await cli.wait_until_ready()
|
||||
|
||||
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
|
||||
@@ -192,27 +194,41 @@ def _manage_async_retrieval(
|
||||
start_time=start_time,
|
||||
):
|
||||
yield message
|
||||
|
||||
def run_and_yield() -> Iterable[DiscordMessage]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _async_fetch()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
# Run the coroutine to get the next document
|
||||
doc = loop.run_until_complete(next_coro)
|
||||
yield doc
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.close()
|
||||
await cli.close()
|
||||
client_task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await client_task
|
||||
|
||||
return run_and_yield()
|
||||
|
||||
def _iterate_async_messages(async_messages: AsyncIterable[DiscordMessage]) -> Iterable[DiscordMessage]:
|
||||
"""Expose async Discord retrieval to the existing synchronous connector API."""
|
||||
item_queue: Queue[DiscordMessage | BaseException | None] = Queue()
|
||||
|
||||
async def consume_messages() -> None:
|
||||
async for message in async_messages:
|
||||
item_queue.put(message)
|
||||
|
||||
def run_consumer() -> None:
|
||||
try:
|
||||
asyncio.run(consume_messages())
|
||||
except BaseException as exc:
|
||||
item_queue.put(exc)
|
||||
finally:
|
||||
item_queue.put(None)
|
||||
|
||||
consumer_thread = Thread(target=run_consumer, name="discord-connector-retrieval", daemon=True)
|
||||
consumer_thread.start()
|
||||
|
||||
while True:
|
||||
item = item_queue.get()
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, BaseException):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
consumer_thread.join()
|
||||
|
||||
|
||||
class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
@@ -283,12 +299,14 @@ class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
|
||||
for message in _manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
channel_names=self.channel_names,
|
||||
server_ids=self.server_ids,
|
||||
start=start,
|
||||
for message in _iterate_async_messages(
|
||||
_manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
channel_names=self.channel_names,
|
||||
server_ids=self.server_ids,
|
||||
start=start,
|
||||
)
|
||||
):
|
||||
if not _is_in_window(message):
|
||||
continue
|
||||
@@ -321,7 +339,7 @@ class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Discord connector settings"""
|
||||
if not self.discord_client:
|
||||
if not self.discord_bot_token:
|
||||
raise ConnectorMissingCredentialError("Discord")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
@@ -344,12 +362,14 @@ class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
full_scan_batch_size = 0
|
||||
full_scan_batch_first_id: str | None = None
|
||||
|
||||
for message in _manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
channel_names=self.channel_names,
|
||||
server_ids=self.server_ids,
|
||||
start=None,
|
||||
for message in _iterate_async_messages(
|
||||
_manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
channel_names=self.channel_names,
|
||||
server_ids=self.server_ids,
|
||||
start=None,
|
||||
)
|
||||
):
|
||||
if full_scan_batch_first_id is None:
|
||||
full_scan_batch_first_id = f"{_DISCORD_DOC_ID_PREFIX}{message.id}"
|
||||
|
||||
Reference in New Issue
Block a user