mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat: support rss datasource (#13721)
### What problem does this PR solve? This PR adds support for public RSS/Atom feed URLs as data sources for RAGFlow. Related issue: https://github.com/infiniflow/ragflow/issues/12313 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -114,6 +114,7 @@ class ParserType(StrEnum):
|
||||
class FileSource(StrEnum):
|
||||
LOCAL = ""
|
||||
KNOWLEDGEBASE = "knowledgebase"
|
||||
RSS = "rss"
|
||||
S3 = "s3"
|
||||
NOTION = "notion"
|
||||
DISCORD = "discord"
|
||||
|
||||
@@ -24,6 +24,7 @@ SOFTWARE.
|
||||
"""
|
||||
|
||||
from .blob_connector import BlobStorageConnector
|
||||
from .rss_connector import RSSConnector
|
||||
from .slack_connector import SlackConnector
|
||||
from .gmail_connector import GmailConnector
|
||||
from .notion_connector import NotionConnector
|
||||
@@ -55,6 +56,7 @@ from .exceptions import (
|
||||
|
||||
__all__ = [
|
||||
"BlobStorageConnector",
|
||||
"RSSConnector",
|
||||
"SlackConnector",
|
||||
"GmailConnector",
|
||||
"NotionConnector",
|
||||
|
||||
@@ -40,6 +40,7 @@ class BlobType(str, Enum):
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
"""Document sources"""
|
||||
RSS = "rss"
|
||||
S3 = "s3"
|
||||
NOTION = "notion"
|
||||
R2 = "r2"
|
||||
|
||||
208
common/data_source/rss_connector.py
Normal file
208
common/data_source/rss_connector.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import hashlib
|
||||
import ipaddress
|
||||
import socket
|
||||
from datetime import datetime, timezone
|
||||
from email.utils import parsedate_to_datetime
|
||||
from time import struct_time
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
import feedparser
|
||||
import requests
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE, REQUEST_TIMEOUT_SECONDS, DocumentSource
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector
|
||||
from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
|
||||
|
||||
|
||||
def _is_private_ip(ip: str) -> bool:
|
||||
try:
|
||||
ip_obj = ipaddress.ip_address(ip)
|
||||
return ip_obj.is_private or ip_obj.is_link_local or ip_obj.is_loopback
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _validate_url_no_ssrf(url: str) -> None:
    """Reject URLs whose host resolves to a non-public address.

    Every address returned by the resolver (IPv4 and IPv6) is checked, so a
    hostname cannot pass by pairing a public record with an internal one.

    Args:
        url: Absolute URL to check.

    Raises:
        ValueError: If the URL has no hostname, the hostname cannot be
            resolved, or any resolved address is private/internal.
    """
    hostname = urlparse(url).hostname
    if not hostname:
        raise ValueError("URL must have a valid hostname")

    try:
        # Restricting to TCP avoids one duplicate result per socket type.
        addr_infos = socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP)
    except socket.gaierror as e:
        raise ValueError(f"Failed to resolve hostname: {hostname}") from e

    for addr_info in addr_infos:
        # sockaddr[0] is the textual address; scoped IPv6 addresses may carry
        # a "%zone" suffix that has to be removed before classification.
        ip = str(addr_info[4][0]).split("%", 1)[0]
        if _is_private_ip(ip):
            raise ValueError(f"URL resolves to private/internal IP address: {ip}")
|
||||
|
||||
|
||||
class RSSConnector(LoadConnector, PollConnector):
    """Load and poll the entries of a public RSS/Atom feed as text documents.

    A successfully parsed feed is cached on the instance, so settings
    validation and the following load/poll share one download.
    """

    # Maximum number of HTTP redirects followed while downloading the feed.
    _MAX_REDIRECTS = 5

    def __init__(self, feed_url: str, batch_size: int = INDEX_BATCH_SIZE) -> None:
        """Store the connector settings; no network access happens here.

        Args:
            feed_url: URL of the feed; surrounding whitespace is stripped.
            batch_size: Maximum number of documents per yielded batch.
        """
        self.feed_url = feed_url.strip()
        self.batch_size = batch_size
        self.credentials: dict[str, Any] = {}
        self._cached_feed: Any | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Remember *credentials* (an empty dict when falsy) and return None."""
        self.credentials = credentials or {}
        return None

    def validate_connector_settings(self) -> None:
        """Check the URL and batch size, then fetch the feed once.

        Raises:
            ValueError: If the URL is invalid or resolves to an internal
                address, ``batch_size`` is below 1, or the feed cannot be
                parsed or has no entries.
        """
        self._validate_feed_url()
        if self.batch_size < 1:
            raise ValueError("batch_size must be greater than 0")
        self._read_feed(require_entries=True)

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Yield every entry of the feed in batches."""
        yield from self._load_entries()

    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
        """Yield entries whose timestamp lies in the window ``(start, end]``."""
        yield from self._load_entries(start=start, end=end)

    def _load_entries(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
        """Yield feed entries as document batches of at most ``batch_size``.

        Args:
            start: Exclusive lower bound (epoch seconds); None disables it.
            end: Inclusive upper bound (epoch seconds); None disables it.
        """
        feed = self._read_feed(require_entries=False)
        batch: list[Document] = []

        for entry in feed.entries:
            updated_at = self._resolve_entry_time(entry)
            ts = updated_at.timestamp()

            if start is not None and ts <= start:
                continue
            if end is not None and ts > end:
                continue

            batch.append(self._build_document(entry, updated_at))

            if len(batch) >= self.batch_size:
                yield batch
                batch = []

        if batch:
            yield batch

    def _validate_feed_url(self) -> None:
        """Ensure ``feed_url`` is a non-empty http(s) URL with a public host.

        Raises:
            ValueError: If the URL is empty, is not http/https, has no network
                location, or fails the SSRF check.
        """
        if not self.feed_url:
            raise ValueError("feed_url is required")

        parsed = urlparse(self.feed_url)
        if parsed.scheme not in {"http", "https"} or not parsed.netloc:
            raise ValueError("feed_url must be a valid http or https URL")

        _validate_url_no_ssrf(self.feed_url)

    def _fetch_feed_response(self) -> Any:
        """Download the feed, validating every redirect target before following it.

        Redirects are followed manually so that a public URL cannot bounce
        the request to an internal host: each ``Location`` is resolved and
        checked before any request is sent to it.

        Returns:
            The final, non-redirect response.

        Raises:
            ValueError: If a redirect targets a non-http(s) URL or an internal
                address, or the redirect limit is exceeded.
            requests.HTTPError: If the final response has an error status.
        """
        from urllib.parse import urljoin

        url = self.feed_url
        for _ in range(self._MAX_REDIRECTS + 1):
            response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS, allow_redirects=False)
            # getattr keeps minimal response stand-ins (e.g. in tests) working.
            if not getattr(response, "is_redirect", False):
                response.raise_for_status()
                return response

            # Location may be relative; resolve it against the current URL.
            target = urljoin(url, response.headers.get("location", ""))
            if urlparse(target).scheme not in {"http", "https"}:
                raise ValueError("feed_url must be a valid http or https URL")
            _validate_url_no_ssrf(target)
            url = target

        raise ValueError(f"RSS feed exceeded the limit of {self._MAX_REDIRECTS} redirects")

    def _read_feed(self, require_entries: bool) -> Any:
        """Return the parsed feed, downloading it on first use.

        Args:
            require_entries: When True, an empty feed is an error.

        Raises:
            ValueError: If the feed cannot be parsed at all, or it is empty
                and ``require_entries`` is set.
        """
        if self._cached_feed is not None:
            if require_entries and not self._cached_feed.entries:
                raise ValueError("RSS feed contains no entries")
            return self._cached_feed

        self._validate_feed_url()
        response = self._fetch_feed_response()

        feed = feedparser.parse(response.content)
        # A "bozo" feed that still yielded entries is accepted as-is.
        if getattr(feed, "bozo", False) and not feed.entries:
            error = getattr(feed, "bozo_exception", None)
            if error:
                raise ValueError(f"Failed to parse RSS feed: {error}") from error
            raise ValueError("Failed to parse RSS feed")
        if require_entries and not feed.entries:
            raise ValueError("RSS feed contains no entries")

        self._cached_feed = feed
        return feed

    def _build_document(self, entry: Any, updated_at: datetime) -> Document:
        """Convert one feed entry into a plain-text ``Document``.

        The document id is the MD5 hex digest of the first non-empty value
        among the entry id, link, title and the feed URL, prefixed with
        ``rss:``.
        """
        link = (entry.get("link") or "").strip()
        title = (entry.get("title") or "").strip()
        stable_key = (entry.get("id") or link or title or self.feed_url).strip()
        semantic_identifier = title or link or stable_key
        content = self._build_content(entry, semantic_identifier)
        blob = content.encode("utf-8")

        metadata: dict[str, Any] = {"feed_url": self.feed_url}
        if link:
            metadata["link"] = link
        if entry.get("author"):
            metadata["author"] = entry.get("author")

        categories = []
        for tag in entry.get("tags", []):
            if not isinstance(tag, dict):
                continue
            term = tag.get("term")
            if isinstance(term, str) and term:
                categories.append(term)
        if categories:
            metadata["categories"] = categories

        return Document(
            id=f"rss:{hashlib.md5(stable_key.encode('utf-8')).hexdigest()}",
            source=DocumentSource.RSS,
            semantic_identifier=semantic_identifier,
            extension=".txt",
            blob=blob,
            doc_updated_at=updated_at,
            size_bytes=len(blob),
            metadata=metadata,
        )

    def _build_content(self, entry: Any, semantic_identifier: str) -> str:
        """Assemble the document text: identifier first, then the entry body.

        The ``content`` blocks are preferred. ``summary`` / ``description`` is
        used only when no content block produced any text.
        """
        parts = [semantic_identifier]
        content_blocks = entry.get("content") or []

        for block in content_blocks:
            value = block.get("value") if isinstance(block, dict) else None
            normalized = self._normalize_text(value)
            if normalized:
                parts.append(normalized)

        if len(parts) == 1:
            fallback = entry.get("summary") or entry.get("description") or ""
            normalized = self._normalize_text(fallback)
            if normalized:
                parts.append(normalized)

        return "\n\n".join(part for part in parts if part).strip()

    def _resolve_entry_time(self, entry: Any) -> datetime:
        """Return the entry's timestamp as an aware UTC datetime.

        Tries the pre-parsed ``updated_parsed`` / ``published_parsed`` tuples
        first, then the raw ``updated`` / ``published`` strings, and finally
        falls back to the current time. Values that cannot be converted are
        skipped rather than aborting the whole sync.
        """
        for field in ("updated_parsed", "published_parsed"):
            value = entry.get(field)
            if value:
                try:
                    return self._struct_time_to_utc(value)
                except (TypeError, ValueError, OverflowError):
                    # Malformed tuple (e.g. out-of-range field); try the next source.
                    continue

        for field in ("updated", "published"):
            value = entry.get(field)
            if isinstance(value, str) and value.strip():
                try:
                    parsed = parsedate_to_datetime(value)
                except (TypeError, ValueError, IndexError):
                    continue
                if parsed.tzinfo is None:
                    # Naive timestamps are taken to be UTC.
                    parsed = parsed.replace(tzinfo=timezone.utc)
                return parsed.astimezone(timezone.utc)

        return datetime.now(timezone.utc)

    @staticmethod
    def _normalize_text(value: Any) -> str:
        """Strip HTML markup from *value*; non-string input yields ``""``."""
        if not isinstance(value, str):
            return ""
        return bs4.BeautifulSoup(value, "html.parser").get_text("\n", strip=True)

    @staticmethod
    def _struct_time_to_utc(value: struct_time | tuple[Any, ...]) -> datetime:
        """Build an aware datetime from the first six fields of *value*.

        The fields (year through second) are interpreted as UTC.
        """
        return datetime(*value[:6], tzinfo=timezone.utc)
|
||||
@@ -31,6 +31,7 @@ dependencies = [
|
||||
"editdistance==0.8.1",
|
||||
"elasticsearch-dsl==8.12.0",
|
||||
"exceptiongroup>=1.3.0,<2.0.0",
|
||||
"feedparser>=6.0.11,<7.0.0",
|
||||
"extract-msg>=0.39.0",
|
||||
"ffmpeg-python>=0.2.0",
|
||||
"flasgger>=0.9.7.1,<0.10.0",
|
||||
|
||||
@@ -43,6 +43,7 @@ from common import settings
|
||||
from common.config_utils import show_configs
|
||||
from common.data_source import (
|
||||
BlobStorageConnector,
|
||||
RSSConnector,
|
||||
NotionConnector,
|
||||
DiscordConnector,
|
||||
GoogleDriveConnector,
|
||||
@@ -243,6 +244,26 @@ class GOOGLE_CLOUD_STORAGE(_BlobLikeBase):
|
||||
DEFAULT_BUCKET_TYPE: str = "google_cloud_storage"
|
||||
|
||||
|
||||
class RSS(SyncBase):
    """Sync task that pulls documents from an RSS/Atom feed."""

    SOURCE_NAME: str = FileSource.RSS

    async def _generate(self, task: dict):
        """Build and validate the connector, then return its document generator.

        Returns the full feed when a reindex is requested or no previous poll
        start exists; otherwise only entries from the last poll start up to now.
        """
        rss = RSSConnector(
            feed_url=self.conf["feed_url"],
            batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE),
        )
        self.connector = rss
        rss.load_credentials(self.conf.get("credentials", {}))
        rss.validate_connector_settings()

        incremental = task["reindex"] != "1" and bool(task["poll_range_start"])
        if not incremental:
            return rss.load_from_state()

        window_start = task["poll_range_start"].timestamp()
        window_end = datetime.now(timezone.utc).timestamp()
        return rss.poll_source(window_start, window_end)
|
||||
|
||||
|
||||
class Confluence(SyncBase):
|
||||
SOURCE_NAME: str = FileSource.CONFLUENCE
|
||||
|
||||
@@ -1347,6 +1368,7 @@ class PostgreSQL(SyncBase):
|
||||
|
||||
|
||||
func_factory = {
|
||||
FileSource.RSS: RSS,
|
||||
FileSource.S3: S3,
|
||||
FileSource.R2: R2,
|
||||
FileSource.OCI_STORAGE: OCI_STORAGE,
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
from datetime import datetime, timezone
|
||||
import importlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from types import ModuleType
|
||||
|
||||
import pytest
|
||||
|
||||
import common
|
||||
|
||||
# Register a bare stand-in for the ``common.data_source`` package that only
# carries ``__path__``, so its submodules can be imported individually below
# without executing the real package ``__init__``.
repo_root = Path(__file__).resolve().parents[3]
data_source_pkg = ModuleType("common.data_source")
data_source_pkg.__path__ = [str(repo_root / "common" / "data_source")]
sys.modules["common.data_source"] = data_source_pkg
common.data_source = data_source_pkg

_config_module = importlib.import_module("common.data_source.config")
_rss_module = importlib.import_module("common.data_source.rss_connector")
DocumentSource = _config_module.DocumentSource
RSSConnector = _rss_module.RSSConnector
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, content: bytes = b"feed") -> None:
|
||||
self.content = content
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _mock_feed(*entries, bozo=False, bozo_exception=None):
|
||||
return SimpleNamespace(
|
||||
entries=list(entries),
|
||||
bozo=bozo,
|
||||
bozo_exception=bozo_exception,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_connector_settings_rejects_invalid_feed_url():
    """A feed URL with a non-http(s) scheme is rejected by settings validation."""
    ftp_connector = RSSConnector(feed_url="ftp://example.com/feed.xml")

    with pytest.raises(ValueError, match="valid http or https URL"):
        ftp_connector.validate_connector_settings()
|
||||
|
||||
|
||||
def test_validate_connector_settings_rejects_empty_feed(monkeypatch):
    """A feed that parses cleanly but has no entries fails settings validation."""
    # Stub the DNS-based SSRF guard so the test never touches the network.
    monkeypatch.setattr("common.data_source.rss_connector._validate_url_no_ssrf", lambda _url: None)
    monkeypatch.setattr("common.data_source.rss_connector.requests.get", lambda *_args, **_kwargs: _FakeResponse())
    monkeypatch.setattr(
        "common.data_source.rss_connector.feedparser.parse",
        lambda _content: _mock_feed(),
    )

    connector = RSSConnector(feed_url="https://example.com/feed.xml")

    with pytest.raises(ValueError, match="contains no entries"):
        connector.validate_connector_settings()
|
||||
|
||||
|
||||
def test_load_from_state_builds_documents(monkeypatch):
    """A full load converts a feed entry into a text document with metadata."""
    # Stub the DNS-based SSRF guard so the test never touches the network.
    monkeypatch.setattr("common.data_source.rss_connector._validate_url_no_ssrf", lambda _url: None)
    monkeypatch.setattr("common.data_source.rss_connector.requests.get", lambda *_args, **_kwargs: _FakeResponse())
    monkeypatch.setattr(
        "common.data_source.rss_connector.feedparser.parse",
        lambda _content: _mock_feed(
            {
                "id": "entry-1",
                "link": "https://example.com/posts/1",
                "title": "Post One",
                "content": [{"value": "<p>Hello <b>world</b></p>"}],
                "author": "Alice",
                "tags": [{"term": "news"}, {"term": "product"}],
                "updated": "Tue, 02 Jan 2024 15:04:05 GMT",
            }
        ),
    )

    connector = RSSConnector(feed_url="https://example.com/feed.xml")
    batch = next(connector.load_from_state())

    assert len(batch) == 1
    doc = batch[0]
    assert doc.source == DocumentSource.RSS
    assert doc.semantic_identifier == "Post One"
    assert doc.extension == ".txt"
    assert doc.metadata == {
        "feed_url": "https://example.com/feed.xml",
        "link": "https://example.com/posts/1",
        "author": "Alice",
        "categories": ["news", "product"],
    }
    # HTML markup is stripped, but the text itself must survive.
    assert "Hello" in doc.blob.decode("utf-8")
    assert "world" in doc.blob.decode("utf-8")
|
||||
|
||||
|
||||
def test_poll_source_filters_entries_by_timestamp(monkeypatch):
    """Polling keeps only entries inside the (start, end] window."""
    # Stub the DNS-based SSRF guard so the test never touches the network.
    monkeypatch.setattr("common.data_source.rss_connector._validate_url_no_ssrf", lambda _url: None)
    monkeypatch.setattr("common.data_source.rss_connector.requests.get", lambda *_args, **_kwargs: _FakeResponse())
    monkeypatch.setattr(
        "common.data_source.rss_connector.feedparser.parse",
        lambda _content: _mock_feed(
            {
                "id": "entry-1",
                "title": "Older",
                "summary": "older summary",
                "updated": "Mon, 01 Jan 2024 00:00:00 GMT",
            },
            {
                "id": "entry-2",
                "title": "Newer",
                "summary": "new summary",
                "updated": "Tue, 02 Jan 2024 00:00:00 GMT",
            },
        ),
    )

    connector = RSSConnector(feed_url="https://example.com/feed.xml")
    # "Older" sits exactly on the exclusive start bound, "Newer" on the inclusive end bound.
    start = datetime(2024, 1, 1, tzinfo=timezone.utc).timestamp()
    end = datetime(2024, 1, 2, tzinfo=timezone.utc).timestamp()

    batches = list(connector.poll_source(start, end))

    assert len(batches) == 1
    assert [doc.semantic_identifier for doc in batches[0]] == ["Newer"]
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -6331,6 +6331,7 @@ dependencies = [
|
||||
{ name = "elasticsearch-dsl" },
|
||||
{ name = "exceptiongroup" },
|
||||
{ name = "extract-msg" },
|
||||
{ name = "feedparser" },
|
||||
{ name = "ffmpeg-python" },
|
||||
{ name = "flasgger" },
|
||||
{ name = "flask-cors" },
|
||||
@@ -6471,6 +6472,7 @@ requires-dist = [
|
||||
{ name = "elasticsearch-dsl", specifier = "==8.12.0" },
|
||||
{ name = "exceptiongroup", specifier = ">=1.3.0,<2.0.0" },
|
||||
{ name = "extract-msg", specifier = ">=0.39.0" },
|
||||
{ name = "feedparser", specifier = ">=6.0.11,<7.0.0" },
|
||||
{ name = "ffmpeg-python", specifier = ">=0.2.0" },
|
||||
{ name = "flasgger", specifier = ">=0.9.7.1,<0.10.0" },
|
||||
{ name = "flask-cors", specifier = "==6.0.2" },
|
||||
|
||||
@@ -1000,6 +1000,8 @@ Example: Virtual Hosted Style`,
|
||||
newDocs: 'New docs',
|
||||
timeStarted: 'Time started',
|
||||
log: 'Log',
|
||||
rssDescription:
|
||||
'Connect to a public RSS or Atom feed and sync feed entries into your knowledge base.',
|
||||
confluenceDescription:
|
||||
'Integrate your Confluence workspace to search documentation.',
|
||||
s3Description:
|
||||
|
||||
@@ -895,6 +895,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
newDocs: '新文档',
|
||||
timeStarted: '开始时间',
|
||||
log: '日志',
|
||||
rssDescription:
|
||||
'连接公开的 RSS 或 Atom feed,并将 feed 条目同步到知识库。',
|
||||
confluenceDescription: '连接你的 Confluence 工作区以搜索文档内容。',
|
||||
s3Description: ' 连接你的 AWS S3 存储桶以导入和同步文件。',
|
||||
google_cloud_storageDescription:
|
||||
|
||||
@@ -2,7 +2,7 @@ import { FormFieldType } from '@/components/dynamic-form';
|
||||
import { IconFontFill } from '@/components/icon-font';
|
||||
import SvgIcon from '@/components/svg-icon';
|
||||
import { t, TFunction } from 'i18next';
|
||||
import { Mail } from 'lucide-react';
|
||||
import { Mail, Rss } from 'lucide-react';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import BoxTokenField from '../component/box-token-field';
|
||||
@@ -15,6 +15,7 @@ import { S3Constant } from './s3-constant';
|
||||
import { seafileConstant } from './seafile-constant';
|
||||
|
||||
export enum DataSourceKey {
|
||||
RSS = 'rss',
|
||||
CONFLUENCE = 'confluence',
|
||||
S3 = 's3',
|
||||
NOTION = 'notion',
|
||||
@@ -47,6 +48,11 @@ export enum DataSourceKey {
|
||||
|
||||
export const generateDataSourceInfo = (t: TFunction) => {
|
||||
return {
|
||||
[DataSourceKey.RSS]: {
|
||||
name: 'RSS',
|
||||
description: t(`setting.${DataSourceKey.RSS}Description`),
|
||||
icon: <Rss className="text-text-primary" size={22} />,
|
||||
},
|
||||
[DataSourceKey.GOOGLE_CLOUD_STORAGE]: {
|
||||
name: 'Google Cloud Storage',
|
||||
description: t(
|
||||
@@ -222,6 +228,25 @@ export const DataSourceFormBaseFields = [
|
||||
},
|
||||
];
|
||||
export const DataSourceFormFields = {
|
||||
[DataSourceKey.RSS]: [
|
||||
{
|
||||
label: 'Feed URL',
|
||||
name: 'config.feed_url',
|
||||
type: FormFieldType.Text,
|
||||
required: true,
|
||||
placeholder: 'https://example.com/feed.xml',
|
||||
},
|
||||
{
|
||||
label: 'Batch Size',
|
||||
name: 'config.batch_size',
|
||||
type: FormFieldType.Number,
|
||||
required: false,
|
||||
validation: {
|
||||
min: 1,
|
||||
message: 'Batch Size must be at least 1',
|
||||
},
|
||||
},
|
||||
],
|
||||
[DataSourceKey.GOOGLE_CLOUD_STORAGE]: [
|
||||
{
|
||||
label: 'GCS Access Key ID',
|
||||
@@ -965,6 +990,14 @@ export const DataSourceFormFields = {
|
||||
};
|
||||
|
||||
export const DataSourceFormDefaultValues = {
|
||||
[DataSourceKey.RSS]: {
|
||||
name: '',
|
||||
source: DataSourceKey.RSS,
|
||||
config: {
|
||||
feed_url: '',
|
||||
batch_size: 2,
|
||||
},
|
||||
},
|
||||
[DataSourceKey.S3]: {
|
||||
name: '',
|
||||
source: DataSourceKey.S3,
|
||||
|
||||
Reference in New Issue
Block a user