diff --git a/common/constants.py b/common/constants.py index 2742554010..b027908637 100644 --- a/common/constants.py +++ b/common/constants.py @@ -114,6 +114,7 @@ class ParserType(StrEnum): class FileSource(StrEnum): LOCAL = "" KNOWLEDGEBASE = "knowledgebase" + RSS = "rss" S3 = "s3" NOTION = "notion" DISCORD = "discord" diff --git a/common/data_source/__init__.py b/common/data_source/__init__.py index 022a107613..301103652c 100644 --- a/common/data_source/__init__.py +++ b/common/data_source/__init__.py @@ -24,6 +24,7 @@ SOFTWARE. """ from .blob_connector import BlobStorageConnector +from .rss_connector import RSSConnector from .slack_connector import SlackConnector from .gmail_connector import GmailConnector from .notion_connector import NotionConnector @@ -55,6 +56,7 @@ from .exceptions import ( __all__ = [ "BlobStorageConnector", + "RSSConnector", "SlackConnector", "GmailConnector", "NotionConnector", diff --git a/common/data_source/config.py b/common/data_source/config.py index 65338f34a6..c67d8b1a82 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -40,6 +40,7 @@ class BlobType(str, Enum): class DocumentSource(str, Enum): """Document sources""" + RSS = "rss" S3 = "s3" NOTION = "notion" R2 = "r2" diff --git a/common/data_source/rss_connector.py b/common/data_source/rss_connector.py new file mode 100644 index 0000000000..85471407ab --- /dev/null +++ b/common/data_source/rss_connector.py @@ -0,0 +1,208 @@ +import hashlib +import ipaddress +import socket +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime +from time import struct_time +from typing import Any +from urllib.parse import urlparse + +import bs4 +import feedparser +import requests + +from common.data_source.config import INDEX_BATCH_SIZE, REQUEST_TIMEOUT_SECONDS, DocumentSource +from common.data_source.interfaces import LoadConnector, PollConnector +from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch + + +def _is_private_ip(ip: str) -> bool: + try: + ip_obj = ipaddress.ip_address(ip) + return ip_obj.is_private or ip_obj.is_link_local or ip_obj.is_loopback + except ValueError: + return False + + +def _validate_url_no_ssrf(url: str) -> None: + parsed = urlparse(url) + hostname = parsed.hostname + if not hostname: + raise ValueError("URL must have a valid hostname") + + try: + ip = socket.gethostbyname(hostname) + if _is_private_ip(ip): + raise ValueError(f"URL resolves to private/internal IP address: {ip}") + except socket.gaierror as e: + raise ValueError(f"Failed to resolve hostname: {hostname}") from e + + +class RSSConnector(LoadConnector, PollConnector): + def __init__(self, feed_url: str, batch_size: int = INDEX_BATCH_SIZE) -> None: + self.feed_url = feed_url.strip() + self.batch_size = batch_size + self.credentials: dict[str, Any] = {} + self._cached_feed: Any | None = None + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.credentials = credentials or {} + return None + + def validate_connector_settings(self) -> None: + self._validate_feed_url() + if self.batch_size < 1: + raise ValueError("batch_size must be greater than 0") + self._read_feed(require_entries=True) + + def load_from_state(self) -> GenerateDocumentsOutput: + yield from self._load_entries() + + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput: + yield from self._load_entries(start=start, end=end) + + def _load_entries( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateDocumentsOutput: + feed = self._read_feed(require_entries=False) + batch: list[Document] = [] + + for entry in feed.entries: + updated_at = self._resolve_entry_time(entry) + ts = updated_at.timestamp() + + if start is not None and ts <= start: + continue + if end is not None and ts > end: + continue + + batch.append(self._build_document(entry, updated_at)) + + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + def _validate_feed_url(self) -> None: + if not self.feed_url: + raise ValueError("feed_url is required") + + parsed = urlparse(self.feed_url) + if parsed.scheme not in {"http", "https"} or not parsed.netloc: + raise ValueError("feed_url must be a valid http or https URL") + + _validate_url_no_ssrf(self.feed_url) + + def _read_feed(self, require_entries: bool) -> Any: + if self._cached_feed is not None: + if require_entries and not self._cached_feed.entries: + raise ValueError("RSS feed contains no entries") + return self._cached_feed + + self._validate_feed_url() + + response = requests.get(self.feed_url, timeout=REQUEST_TIMEOUT_SECONDS, allow_redirects=True) + response.raise_for_status() + + final_url = getattr(response, "url", self.feed_url) + if final_url != self.feed_url and urlparse(final_url).hostname: + _validate_url_no_ssrf(final_url) + + feed = feedparser.parse(response.content) + if getattr(feed, "bozo", False) and not feed.entries: + error = getattr(feed, "bozo_exception", None) + if error: + raise ValueError(f"Failed to parse RSS feed: {error}") from error + raise ValueError("Failed to parse RSS feed") + if require_entries and not feed.entries: + raise ValueError("RSS feed contains no entries") + + self._cached_feed = feed + return feed + + def _build_document(self, entry: Any, updated_at: datetime) -> Document: + link = (entry.get("link") or "").strip() + title = (entry.get("title") or "").strip() + stable_key = (entry.get("id") or link or title or self.feed_url).strip() + semantic_identifier = title or link or stable_key + content = self._build_content(entry, semantic_identifier) + blob = content.encode("utf-8") + + metadata: dict[str, Any] = {"feed_url": self.feed_url} + if link: + metadata["link"] = link + if entry.get("author"): + metadata["author"] = entry.get("author") + + categories = [] + for tag in entry.get("tags", []): + if not isinstance(tag, dict): + continue + term = tag.get("term") + if isinstance(term, str) and term: + categories.append(term) + if categories: + metadata["categories"] = categories + + return Document( + id=f"rss:{hashlib.md5(stable_key.encode('utf-8')).hexdigest()}", + source=DocumentSource.RSS, + semantic_identifier=semantic_identifier, + extension=".txt", + blob=blob, + doc_updated_at=updated_at, + size_bytes=len(blob), + metadata=metadata, + ) + + def _build_content(self, entry: Any, semantic_identifier: str) -> str: + parts = [semantic_identifier] + content_blocks = entry.get("content") or [] + + for block in content_blocks: + value = block.get("value") if isinstance(block, dict) else None + normalized = self._normalize_text(value) + if normalized: + parts.append(normalized) + + if len(parts) == 1: + fallback = entry.get("summary") or entry.get("description") or "" + normalized = self._normalize_text(fallback) + if normalized: + parts.append(normalized) + + return "\n\n".join(part for part in parts if part).strip() + + def _resolve_entry_time(self, entry: Any) -> datetime: + for field in ("updated_parsed", "published_parsed"): + value = entry.get(field) + if value: + return self._struct_time_to_utc(value) + + for field in ("updated", "published"): + value = entry.get(field) + if isinstance(value, str) and value.strip(): + try: + parsed = parsedate_to_datetime(value) + except (TypeError, ValueError, IndexError): + continue + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + + return datetime.now(timezone.utc) + + @staticmethod + def _normalize_text(value: Any) -> str: + if not isinstance(value, str): + return "" + return bs4.BeautifulSoup(value, "html.parser").get_text("\n", strip=True) + + @staticmethod + def _struct_time_to_utc(value: struct_time | tuple[Any, ...]) -> datetime: + dt = datetime(*value[:6], tzinfo=timezone.utc) + return dt.astimezone(timezone.utc) diff --git a/pyproject.toml b/pyproject.toml index 7003fb5e62..2b83466d17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "editdistance==0.8.1", "elasticsearch-dsl==8.12.0", "exceptiongroup>=1.3.0,<2.0.0", + "feedparser>=6.0.11,<7.0.0", "extract-msg>=0.39.0", "ffmpeg-python>=0.2.0", "flasgger>=0.9.7.1,<0.10.0", diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 9fde44222d..1e95c4b166 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -43,6 +43,7 @@ from common import settings from common.config_utils import show_configs from common.data_source import ( BlobStorageConnector, + RSSConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, @@ -243,6 +244,26 @@ class GOOGLE_CLOUD_STORAGE(_BlobLikeBase): DEFAULT_BUCKET_TYPE: str = "google_cloud_storage" +class RSS(SyncBase): + SOURCE_NAME: str = FileSource.RSS + + async def _generate(self, task: dict): + self.connector = RSSConnector( + feed_url=self.conf["feed_url"], + batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE), + ) + self.connector.load_credentials(self.conf.get("credentials", {})) + self.connector.validate_connector_settings() + + if task["reindex"] == "1" or not task["poll_range_start"]: + return self.connector.load_from_state() + + return self.connector.poll_source( + task["poll_range_start"].timestamp(), + datetime.now(timezone.utc).timestamp(), + ) + + class Confluence(SyncBase): SOURCE_NAME: str = FileSource.CONFLUENCE @@ -1347,6 +1368,7 @@ class PostgreSQL(SyncBase): func_factory = { + FileSource.RSS: RSS, FileSource.S3: S3, FileSource.R2: R2, FileSource.OCI_STORAGE: OCI_STORAGE, diff --git a/test/testcases/test_common_data_source/test_rss_connector_unit.py b/test/testcases/test_common_data_source/test_rss_connector_unit.py new file mode 100644 index 0000000000..39585bcbd1 --- /dev/null +++ b/test/testcases/test_common_data_source/test_rss_connector_unit.py @@ -0,0 +1,120 @@ +from datetime import datetime, timezone +import importlib +import sys +from pathlib import Path +from types import SimpleNamespace +from types import ModuleType + +import pytest + +import common + +repo_root = Path(__file__).resolve().parents[3] +data_source_pkg = ModuleType("common.data_source") +data_source_pkg.__path__ = [str(repo_root / "common" / "data_source")] +sys.modules["common.data_source"] = data_source_pkg +setattr(common, "data_source", data_source_pkg) + +DocumentSource = importlib.import_module("common.data_source.config").DocumentSource +RSSConnector = importlib.import_module("common.data_source.rss_connector").RSSConnector + + +class _FakeResponse: + def __init__(self, content: bytes = b"feed") -> None: + self.content = content + + def raise_for_status(self) -> None: + return None + + +def _mock_feed(*entries, bozo=False, bozo_exception=None): + return SimpleNamespace( + entries=list(entries), + bozo=bozo, + bozo_exception=bozo_exception, + ) + + +def test_validate_connector_settings_rejects_invalid_feed_url(): + connector = RSSConnector(feed_url="ftp://example.com/feed.xml") + + with pytest.raises(ValueError, match="valid http or https URL"): + connector.validate_connector_settings() + + +def test_validate_connector_settings_rejects_empty_feed(monkeypatch): + monkeypatch.setattr("common.data_source.rss_connector.requests.get", lambda *_args, **_kwargs: _FakeResponse()) + monkeypatch.setattr( + "common.data_source.rss_connector.feedparser.parse", + lambda _content: _mock_feed(), + ) + + connector = RSSConnector(feed_url="https://example.com/feed.xml") + + with pytest.raises(ValueError, match="contains no entries"): + connector.validate_connector_settings() + + +def test_load_from_state_builds_documents(monkeypatch): + monkeypatch.setattr("common.data_source.rss_connector.requests.get", lambda *_args, **_kwargs: _FakeResponse()) + monkeypatch.setattr( + "common.data_source.rss_connector.feedparser.parse", + lambda _content: _mock_feed( + { + "id": "entry-1", + "link": "https://example.com/posts/1", + "title": "Post One", + "content": [{"value": "

Hello world

"}], + "author": "Alice", + "tags": [{"term": "news"}, {"term": "product"}], + "updated": "Tue, 02 Jan 2024 15:04:05 GMT", + } + ), + ) + + connector = RSSConnector(feed_url="https://example.com/feed.xml") + batch = next(connector.load_from_state()) + + assert len(batch) == 1 + doc = batch[0] + assert doc.source == DocumentSource.RSS + assert doc.semantic_identifier == "Post One" + assert doc.extension == ".txt" + assert doc.metadata == { + "feed_url": "https://example.com/feed.xml", + "link": "https://example.com/posts/1", + "author": "Alice", + "categories": ["news", "product"], + } + assert "Hello" in doc.blob.decode("utf-8") + assert "world" in doc.blob.decode("utf-8") + + +def test_poll_source_filters_entries_by_timestamp(monkeypatch): + monkeypatch.setattr("common.data_source.rss_connector.requests.get", lambda *_args, **_kwargs: _FakeResponse()) + monkeypatch.setattr( + "common.data_source.rss_connector.feedparser.parse", + lambda _content: _mock_feed( + { + "id": "entry-1", + "title": "Older", + "summary": "older summary", + "updated": "Mon, 01 Jan 2024 00:00:00 GMT", + }, + { + "id": "entry-2", + "title": "Newer", + "summary": "new summary", + "updated": "Tue, 02 Jan 2024 00:00:00 GMT", + }, + ), + ) + + connector = RSSConnector(feed_url="https://example.com/feed.xml") + start = datetime(2024, 1, 1, tzinfo=timezone.utc).timestamp() + end = datetime(2024, 1, 2, tzinfo=timezone.utc).timestamp() + + batches = list(connector.poll_source(start, end)) + + assert len(batches) == 1 + assert [doc.semantic_identifier for doc in batches[0]] == ["Newer"] diff --git a/uv.lock b/uv.lock index 6fb93fe6c7..eaa5a6e0fb 100644 --- a/uv.lock +++ b/uv.lock @@ -6331,6 +6331,7 @@ dependencies = [ { name = "elasticsearch-dsl" }, { name = "exceptiongroup" }, { name = "extract-msg" }, + { name = "feedparser" }, { name = "ffmpeg-python" }, { name = "flasgger" }, { name = "flask-cors" }, @@ -6471,6 +6472,7 @@ requires-dist = [ { name = "elasticsearch-dsl", specifier = "==8.12.0" }, { name = "exceptiongroup", specifier = ">=1.3.0,<2.0.0" }, { name = "extract-msg", specifier = ">=0.39.0" }, + { name = "feedparser", specifier = ">=6.0.11,<7.0.0" }, { name = "ffmpeg-python", specifier = ">=0.2.0" }, { name = "flasgger", specifier = ">=0.9.7.1,<0.10.0" }, { name = "flask-cors", specifier = "==6.0.2" }, diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 49f288ce6d..e2a342b85c 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -1000,6 +1000,8 @@ Example: Virtual Hosted Style`, newDocs: 'New docs', timeStarted: 'Time started', log: 'Log', + rssDescription: + 'Connect to a public RSS or Atom feed and sync feed entries into your knowledge base.', confluenceDescription: 'Integrate your Confluence workspace to search documentation.', s3Description: diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 5d72eba45f..4f9893995c 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -895,6 +895,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 newDocs: '新文档', timeStarted: '开始时间', log: '日志', + rssDescription: + '连接公开的 RSS 或 Atom feed,并将 feed 条目同步到知识库。', confluenceDescription: '连接你的 Confluence 工作区以搜索文档内容。', s3Description: ' 连接你的 AWS S3 存储桶以导入和同步文件。', google_cloud_storageDescription: diff --git a/web/src/pages/user-setting/data-source/constant/index.tsx b/web/src/pages/user-setting/data-source/constant/index.tsx index b06d424b4c..01990ea4f3 100644 --- a/web/src/pages/user-setting/data-source/constant/index.tsx +++ b/web/src/pages/user-setting/data-source/constant/index.tsx @@ -2,7 +2,7 @@ import { FormFieldType } from '@/components/dynamic-form'; import { IconFontFill } from '@/components/icon-font'; import SvgIcon from '@/components/svg-icon'; import { t, TFunction } from 'i18next'; -import { Mail } from 'lucide-react'; +import { Mail, Rss } from 'lucide-react'; import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import BoxTokenField from '../component/box-token-field'; @@ -15,6 +15,7 @@ import { S3Constant } from './s3-constant'; import { seafileConstant } from './seafile-constant'; export enum DataSourceKey { + RSS = 'rss', CONFLUENCE = 'confluence', S3 = 's3', NOTION = 'notion', @@ -47,6 +48,11 @@ export enum DataSourceKey { export const generateDataSourceInfo = (t: TFunction) => { return { + [DataSourceKey.RSS]: { + name: 'RSS', + description: t(`setting.${DataSourceKey.RSS}Description`), + icon: , + }, [DataSourceKey.GOOGLE_CLOUD_STORAGE]: { name: 'Google Cloud Storage', description: t( @@ -222,6 +228,25 @@ export const DataSourceFormBaseFields = [ }, ]; export const DataSourceFormFields = { + [DataSourceKey.RSS]: [ + { + label: 'Feed URL', + name: 'config.feed_url', + type: FormFieldType.Text, + required: true, + placeholder: 'https://example.com/feed.xml', + }, + { + label: 'Batch Size', + name: 'config.batch_size', + type: FormFieldType.Number, + required: false, + validation: { + min: 1, + message: 'Batch Size must be at least 1', + }, + }, + ], [DataSourceKey.GOOGLE_CLOUD_STORAGE]: [ { label: 'GCS Access Key ID', @@ -965,6 +990,14 @@ export const DataSourceFormFields = { }; export const DataSourceFormDefaultValues = { + [DataSourceKey.RSS]: { + name: '', + source: DataSourceKey.RSS, + config: { + feed_url: '', + batch_size: 2, + }, + }, [DataSourceKey.S3]: { name: '', source: DataSourceKey.S3,