@manager.route("/connectors/<connector_id>/test", methods=["POST"])  # noqa: F821
@login_required
async def test_connector(connector_id):
    """Validate a saved connector's configuration without persisting changes or triggering sync.

    Only REST API connectors are supported. The stored configuration is passed
    to ``RestAPIConnector.validate_config`` together with the ``credentials``
    entry found inside that configuration (an empty dict when absent).

    Args:
        connector_id: Identifier of the connector row, taken from the URL path.

    Returns:
        A JSON result with ``data=True`` on success, or ``data=False`` plus an
        error code/message when the connector is not a REST API connector or
        validation fails. A data-error result is returned when the connector
        cannot be found.
    """
    # Imported lazily, as in the original handler.
    from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError
    from common.data_source.rest_api_connector import RestAPIConnector

    ok, conn = ConnectorService.get_by_id(connector_id)
    if not ok:
        return get_data_error_result(message="Can't find this Connector!")

    if conn.source != DocumentSource.REST_API:
        return get_json_result(
            code=RetCode.ARGUMENT_ERROR,
            message="Test endpoint currently supports only REST API connectors.",
            data=False,
        )

    config = conn.config or {}
    credentials = config.get("credentials") or {}

    try:
        # validate_config does blocking HTTP I/O, so run it off the event loop.
        await asyncio.to_thread(
            RestAPIConnector.validate_config,
            config=config,
            credentials=credentials,
        )
    except (ConnectorValidationError, ConnectorMissingCredentialError) as exc:
        # Expected, user-actionable failures: surface the message as-is.
        return get_json_result(
            code=RetCode.DATA_ERROR,
            message=str(exc),
            data=False,
        )
    except Exception as exc:
        # Unexpected failure: keep details in the log, return a generic message.
        logging.exception("REST API connector validation failed: %s", exc)
        return get_json_result(
            code=RetCode.SERVER_ERROR,
            message="REST API connector validation failed, please check logs.",
            data=False,
        )

    return get_json_result(data=True)
FileSource(StrEnum): RSS = "rss" S3 = "s3" NOTION = "notion" + REST_API = "rest_api" DISCORD = "discord" CONFLUENCE = "confluence" GMAIL = "gmail" diff --git a/common/data_source/__init__.py b/common/data_source/__init__.py index 301103652c..34bb467d9f 100644 --- a/common/data_source/__init__.py +++ b/common/data_source/__init__.py @@ -44,6 +44,7 @@ from .zendesk_connector import ZendeskConnector from .seafile_connector import SeaFileConnector from .rdbms_connector import RDBMSConnector from .webdav_connector import WebDAVConnector +from .rest_api_connector import RestAPIConnector from .config import BlobType, DocumentSource from .models import Document, TextSection, ImageSection, BasicExpertInfo from .exceptions import ( @@ -87,4 +88,5 @@ __all__ = [ "RDBMSConnector", "WebDAVConnector", "DingTalkAITableConnector", + "RestAPIConnector", ] diff --git a/common/data_source/config.py b/common/data_source/config.py index 2b512d4ce2..08d0209e03 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -43,6 +43,7 @@ class DocumentSource(str, Enum): RSS = "rss" S3 = "s3" NOTION = "notion" + REST_API = "rest_api" R2 = "r2" GOOGLE_CLOUD_STORAGE = "google_cloud_storage" OCI_STORAGE = "oci_storage" diff --git a/common/data_source/rest_api_connector.py b/common/data_source/rest_api_connector.py new file mode 100644 index 0000000000..8616be2730 --- /dev/null +++ b/common/data_source/rest_api_connector.py @@ -0,0 +1,1012 @@ +"""Generic, configuration-driven REST API data source connector. + +Connect any REST API as a RAGFlow data source without code changes. +All behaviour — URL, auth, pagination, field mapping — is controlled +via the ``RestAPIConnectorConfig`` schema exposed by the UI. 
+""" + +from __future__ import annotations + +import json +import logging +import re +import time +from datetime import datetime, timezone +from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional +from urllib.parse import parse_qs, urlparse, urlunparse + +import ipaddress +import socket +import requests +from pydantic import BaseModel, ConfigDict, Field, HttpUrl, ValidationError, field_validator + +logger = logging.getLogger(__name__) + +from api.utils.common import hash128 +from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, +) +from common.data_source.interfaces import ( + LoadConnector, + PollConnector, + SecondsSinceUnixEpoch, +) +from common.data_source.models import Document +from common.data_source.utils import rl_requests, retry_builder + +try: + from jsonpath import jsonpath as _jsonpath # type: ignore[import] +except Exception: # pragma: no cover + _jsonpath = None + +_FIELD_SEGMENT_RE = re.compile(r'^(?P[^\[\]]+)(\[(?P\d+|\*)\])?$') +_DEFAULT_MAX_PAGES = 1000 + + +class AuthType: + NONE = "none" + API_KEY_HEADER = "api_key_header" + BEARER = "bearer" + BASIC = "basic" + + +class PaginationType: + NONE = "none" + PAGE = "page" + OFFSET = "offset" + CURSOR = "cursor" + + +def _text_to_dict(v: Any) -> Dict[str, str]: + """Parse a dict, JSON string, or ``key=value`` text (one per line) into a dict. + + This is module-level because Pydantic ``@field_validator`` classmethods + on ``RestAPIConnectorConfig`` need to call it before any instance exists. 
+ """ + if v is None or v == "": + return {} + if isinstance(v, dict): + return {str(k): str(vv) for k, vv in v.items()} + if isinstance(v, str): + try: + parsed = json.loads(v) + if isinstance(parsed, dict): + return {str(k): str(vv) for k, vv in parsed.items()} + except Exception: + pass + result: Dict[str, str] = {} + for line in v.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" in line: + k, _, val = line.partition("=") + result[k.strip()] = val.strip() + return result + return {} + + +class RestAPIConnectorConfig(BaseModel): + """Validated schema for the REST API connector configuration.""" + + model_config = ConfigDict(extra="ignore") + + url: HttpUrl + method: str = "GET" + headers: Dict[str, str] = Field(default_factory=dict) + query_params: Dict[str, str] = Field(default_factory=dict) + + auth_type: str = AuthType.NONE + auth_config: Dict[str, Any] = Field(default_factory=dict) + + items_path: Optional[str] = None + id_field: Optional[str] = None + content_fields: List[str] = Field(default_factory=list) + metadata_fields: List[str] = Field(default_factory=list) + + pagination_type: str = PaginationType.NONE + pagination_config: Dict[str, Any] = Field(default_factory=dict) + + poll_timestamp_field: Optional[str] = None + request_body: Optional[Dict[str, Any]] = None + + field_type_hints: Dict[str, str] = Field(default_factory=dict) + field_default_values: Dict[str, Any] = Field(default_factory=dict) + content_template: Optional[str] = None + + batch_size: int = INDEX_BATCH_SIZE + max_pages: int = _DEFAULT_MAX_PAGES + request_delay: float = 0.5 + + @field_validator("headers", mode="before") + @classmethod + def _coerce_headers(cls, v: Any) -> Dict[str, str]: + return _text_to_dict(v) + + @field_validator("query_params", mode="before") + @classmethod + def _coerce_query_params(cls, v: Any) -> Dict[str, str]: + return _text_to_dict(v) + + @field_validator("content_fields", "metadata_fields", mode="before") + 
@staticmethod
def _validate_url_for_ssrf(url: str) -> None:
    """Reject URLs that point at localhost or non-public networks (SSRF guard).

    The URL must use http/https and carry a hostname. The hostname is
    resolved with ``socket.getaddrinfo`` and every resolved address is
    checked; a single disallowed address rejects the URL.

    Args:
        url: The endpoint URL configured for the connector.

    Raises:
        ConnectorValidationError: If the scheme is unsupported, the hostname
            is missing or malformed, or any resolved address is loopback,
            private, link-local, reserved, multicast or otherwise non-global.
    """
    parsed = urlparse(str(url))

    if parsed.scheme not in ("http", "https"):
        msg = f"Unsupported URL scheme for REST API connector: {parsed.scheme!r}. Only http/https are allowed."
        logger.warning(msg)
        raise ConnectorValidationError(msg)

    hostname = parsed.hostname
    if not hostname:
        msg = "REST API connector URL must include a hostname."
        logger.warning(msg)
        raise ConnectorValidationError(msg)

    # Quick checks for obvious localhost-style hostnames.
    lower_host = hostname.lower()
    if lower_host in ("localhost",):
        msg = f"REST API connector URL hostname {hostname!r} is not allowed (localhost is blocked)."
        logger.warning(msg)
        raise ConnectorValidationError(msg)

    try:
        addrinfo_list = socket.getaddrinfo(hostname, None)
    except UnicodeError as exc:
        # getaddrinfo IDNA-encodes the name and raises UnicodeError (not
        # OSError) for malformed labels; report it as a config problem
        # instead of letting a raw exception escape.
        msg = f"REST API connector URL hostname {hostname!r} is not a valid hostname."
        logger.warning(msg)
        raise ConnectorValidationError(msg) from exc
    except OSError as exc:
        # If resolution fails, log and let higher-level validation (if any) decide.
        # We do not treat this as an SSRF condition by itself.
        logger.info("DNS resolution failed for REST API connector URL %r: %s", url, exc)
        return

    for *_, sockaddr in addrinfo_list:
        ip_str = sockaddr[0]
        try:
            ip_obj = ipaddress.ip_address(ip_str)
        except ValueError:
            # Not an IP address we understand; skip.
            logger.debug("Skipping non-IP address resolved from %r: %r", hostname, ip_str)
            continue

        if (
            ip_obj.is_loopback
            or ip_obj.is_private
            or ip_obj.is_link_local
            or ip_obj.is_reserved
            or ip_obj.is_multicast
            # Catches non-public ranges that is_private misses, e.g. the
            # shared address space 100.64.0.0/10.
            or not ip_obj.is_global
        ):
            msg = (
                f"REST API connector URL {url!r} resolves to disallowed address {ip_str} "
                "(localhost, private, link-local, reserved, or multicast addresses are blocked)."
            )
            logger.warning(msg)
            raise ConnectorValidationError(msg)

    logger.debug("REST API connector URL %r passed SSRF safety validation.", url)
+ self._validate_url_for_ssrf(url) + + parsed = urlparse(str(url)) + self._base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, "", "", "")) + self._url_params: Dict[str, str] = {} + if parsed.query: + for k, v_list in parse_qs(parsed.query, keep_blank_values=True).items(): + self._url_params[k] = v_list[-1] + + self._explicit_query_params: Dict[str, str] = ( + _text_to_dict(query_params) if isinstance(query_params, str) else (query_params or {}) + ) + self.url = self._base_url + self.method = (method or "GET").upper() + self._base_headers: Dict[str, str] = ( + _text_to_dict(headers) if isinstance(headers, str) else (headers or {}) + ) + self.auth_type = auth_type or AuthType.NONE + self.auth_config: Dict[str, Any] = auth_config or {} + self.items_path = items_path + self.id_field = id_field + self.content_fields: List[str] = content_fields or [] + self.metadata_fields: List[str] = metadata_fields or [] + self.pagination_type = pagination_type or PaginationType.NONE + self.pagination_config: Dict[str, Any] = pagination_config or {} + self._static_request_body: Dict[str, Any] = ( + request_body if request_body is not None + else self.pagination_config.get("request_body") or {} + ) + self.poll_timestamp_field = poll_timestamp_field + self.batch_size = batch_size + self.max_pages = max_pages + self.request_delay = max(request_delay, 0.0) + self.field_type_hints: Dict[str, str] = field_type_hints or {} + self.field_default_values: Dict[str, Any] = field_default_values or {} + self.content_template = content_template + + self._credentials: Dict[str, Any] = {} + self._auth_headers: Dict[str, str] = {} + self._basic_auth: Optional[requests.auth.HTTPBasicAuth] = None + + # -- Credentials -------------------------------------------------------- + + def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None: + """Apply authentication credentials (no network call). + + Use ``validate_config()`` to perform a live connectivity check. 
def _build_auth(self) -> None:
    """Rebuild the auth headers / basic-auth object from the loaded credentials.

    Resets ``_auth_headers`` and ``_basic_auth`` first, then fills whichever
    one ``auth_type`` calls for. For each secret, the value in
    ``_credentials`` wins over the fallback stored in ``auth_config``.

    Raises:
        ConnectorMissingCredentialError: If a value required by the selected
            auth type is missing.
        ConnectorValidationError: If ``auth_type`` is not a known type.
    """
    self._auth_headers = {}
    self._basic_auth = None

    creds = self._credentials
    settings = self.auth_config
    mode = self.auth_type

    if mode == AuthType.NONE:
        logging.info("REST API auth_type=none, no authentication configured.")

    elif mode == AuthType.API_KEY_HEADER:
        header_name = settings.get("header_name")
        api_key = creds.get("api_key") or settings.get("api_key_value") or settings.get("api_key")
        if not (header_name and api_key):
            # Log only key names and presence flags, never the secret itself.
            logging.warning(
                "REST API auth setup failed: header_name=%s, api_key present=%s, "
                "credentials keys=%s, auth_config keys=%s",
                header_name, bool(api_key),
                list(creds.keys()), list(settings.keys()),
            )
            raise ConnectorMissingCredentialError(
                "REST API (api_key_header) requires 'header_name' in auth_config and 'api_key' in credentials"
            )
        self._auth_headers[header_name] = str(api_key)
        logging.info("REST API auth configured: header '%s' set.", header_name)

    elif mode == AuthType.BEARER:
        token = creds.get("token") or settings.get("token")
        if not token:
            raise ConnectorMissingCredentialError("REST API (bearer) requires 'token' in credentials")
        self._auth_headers["Authorization"] = f"Bearer {token}"
        logging.info("REST API auth configured: Bearer token set.")

    elif mode == AuthType.BASIC:
        username = creds.get("username") or settings.get("username")
        password = creds.get("password") or settings.get("password")
        if password is None or not username:
            raise ConnectorMissingCredentialError("REST API (basic) requires 'username' and 'password'")
        self._basic_auth = requests.auth.HTTPBasicAuth(str(username), str(password))
        logging.info("REST API auth configured: Basic auth for user '%s'.", username)

    else:
        raise ConnectorValidationError(f"Unsupported auth_type: {self.auth_type}")
and optionally perform a live API call. + + Args: + config: Raw config dict from the UI / database. + credentials: Optional credentials dict; when provided a live + connectivity check is performed. + + Returns: + The validated ``RestAPIConnectorConfig`` instance. + + Raises: + ConnectorValidationError: On schema or connectivity failure. + """ + cfg = cls.parse_storage_config(config) + validation_cap = min(cfg.max_pages, 10) + connector = cls.from_parsed_config(cfg, max_pages=validation_cap) + + if credentials is None and cfg.auth_type != AuthType.NONE: + return cfg + + if credentials is not None: + connector.load_credentials(credentials) + else: + connector._credentials = {} + connector._build_auth() + + try: + logging.info("Validating REST API connector by fetching first page") + _ = next(connector._page_iter_for_validation()) + except StopIteration: + pass + + return cfg + + # -- LoadConnector / PollConnector interface ----------------------------- + + def load_from_state(self) -> Generator[List[Document], None, None]: + """Full fetch with pagination.""" + return self._yield_documents(time_window=None) + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> Generator[List[Document], None, None]: + """Incremental fetch; filters by ``poll_timestamp_field`` if configured.""" + if not self.poll_timestamp_field: + logging.warning( + "poll_source called without poll_timestamp_field; " + "falling back to full fetch with in-memory filtering." 
+ ) + return self._yield_documents( + time_window=( + datetime.fromtimestamp(start, tz=timezone.utc), + datetime.fromtimestamp(end, tz=timezone.utc), + ) + ) + + # -- Document generation ------------------------------------------------ + + def _yield_documents( + self, + time_window: tuple[datetime, datetime] | None, + ) -> Generator[List[Document], None, None]: + batch: List[Document] = [] + for item in self._iter_items(): + try: + doc = self._item_to_document(item) + except Exception as exc: + logging.warning("Failed to convert REST API item to Document: %s", exc) + continue + + if time_window is not None and not self._doc_in_time_window(doc, *time_window): + continue + + batch.append(doc) + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + # -- Pagination & page fetching ----------------------------------------- + + def _iter_items(self) -> Iterable[Mapping[str, Any]]: + """Iterate over raw items across all pages.""" + page_count = 0 + + page = int(self.pagination_config.get("start_page", 1)) + per_page = self._resolve_page_size() + + offset = int(self.pagination_config.get("start_offset", 0)) + limit = int(self.pagination_config.get("limit", per_page)) + if limit <= 0: + limit = per_page + + cursor: Optional[str] = self.pagination_config.get("initial_cursor") + + while True: + if page_count >= self.max_pages: + logging.warning("REST API connector reached max_pages=%d, stopping.", self.max_pages) + break + + params: Dict[str, Any] = {} + if self.pagination_type == PaginationType.PAGE: + self._apply_page_pagination(params, page, per_page) + elif self.pagination_type == PaginationType.OFFSET: + self._apply_offset_pagination(params, offset, limit) + elif self.pagination_type == PaginationType.CURSOR and cursor is not None: + self._apply_cursor_pagination(params, cursor) + + if page_count > 0 and self.request_delay > 0: + time.sleep(self.request_delay) + + try: + response_json = self._fetch_page(params) + except 
@retry_builder(
    tries=5, delay=1, max_delay=30, backoff=2,
    exceptions=(requests.ConnectionError, requests.Timeout, requests.HTTPError),
)
def _fetch_page(self, params: Dict[str, Any]) -> Any:
    """Fetch a single page and return the decoded JSON body.

    Query parameters are merged in increasing priority: parameters parsed
    from the configured URL, explicit ``query_params``, then the per-request
    ``params`` (pagination). Parameters that match ``{key}`` placeholders in
    the URL are substituted into the URL instead of being sent as query
    parameters.

    The method is wrapped by ``retry_builder`` for connection errors,
    timeouts and ``requests.HTTPError``. 401/403 and other non-429 4xx
    responses are converted to connector exceptions, which are outside that
    exception list.

    Args:
        params: Per-request query parameters (typically pagination values).

    Returns:
        The parsed JSON response.

    Raises:
        ConnectorMissingCredentialError: On HTTP 401/403.
        ConnectorValidationError: On other non-429 4xx responses, an
            unsupported HTTP method, or a non-JSON response body.
        requests.HTTPError: On 429 and 5xx responses (re-raised unchanged).
    """
    headers = {**self._base_headers, **self._auth_headers}
    merged: Dict[str, Any] = {**self._url_params, **self._explicit_query_params, **params}

    url, query_params = self._build_url_with_templates(merged)

    # Redact secrets before logging. The api_key_header auth type sends the
    # key under a user-chosen header name, so every auth header name is
    # treated as sensitive in addition to the well-known ones.
    sensitive = {"authorization", "apikey", "api-key", "x-api-key", "api_key", "token", "access_token"}
    sensitive.update(str(name).lower() for name in self._auth_headers)

    def _redact(mapping: Dict[str, Any]) -> Dict[str, Any]:
        return {k: ("***" if str(k).lower() in sensitive else v) for k, v in mapping.items()}

    logging.debug(
        "REST API request: %s %s | params=%s | headers=%s",
        self.method, url, _redact(query_params), _redact(headers),
    )

    if self.method == "GET":
        resp = rl_requests.get(url, headers=headers, params=query_params, auth=self._basic_auth, timeout=60)
    elif self.method == "POST":
        resp = rl_requests.post(
            url, headers=headers, params=query_params,
            json=self._static_request_body or {}, auth=self._basic_auth, timeout=60,
        )
    else:
        raise ConnectorValidationError(f"Unsupported HTTP method: {self.method}")

    try:
        resp.raise_for_status()
    except requests.HTTPError as exc:
        status = exc.response.status_code if exc.response is not None else None
        if status in (401, 403):
            # Only header *names* are logged here, never their values.
            logging.warning(
                "REST API %d for %s %s | auth_type=%s | "
                "request header keys=%s | auth_header keys=%s",
                status, self.method, resp.url,
                self.auth_type,
                list(headers),
                list(self._auth_headers),
            )
            raise ConnectorMissingCredentialError(
                f"REST API authentication failed with status {status}"
            ) from exc
        if status is not None and 400 <= status < 500 and status != 429:
            logging.warning(
                "REST API client error %d for %s %s; not retrying.",
                status,
                self.method,
                resp.url,
            )
            raise ConnectorValidationError(
                f"REST API request failed with non-retriable client error status {status}"
            ) from exc
        # 429 and 5xx: propagate the HTTPError unchanged.
        raise

    try:
        return resp.json()
    except ValueError as exc:
        raise ConnectorValidationError("REST API response is not valid JSON") from exc
_build_url_with_templates(self, params: Dict[str, Any]) -> tuple[str, Dict[str, Any]]: + """Substitute ``{key}`` placeholders in the URL; return remaining query params.""" + url = self.url + query_params = dict(params) + used_keys: List[str] = [] + for key, value in list(query_params.items()): + placeholder = "{" + key + "}" + if placeholder in url: + url = url.replace(placeholder, str(value)) + used_keys.append(key) + for key in used_keys: + query_params.pop(key, None) + return url, query_params + + # -- Pagination helpers ------------------------------------------------- + + def _resolve_page_size(self) -> int: + """Determine per-page size from config, query params, or batch_size fallback. + + Priority: explicit ``page_size`` in pagination_config > value already + present in user query params for the same param name > batch_size. + """ + explicit = self.pagination_config.get("page_size") + if explicit is not None: + val = int(explicit) + if val > 0: + return val + + size_param = self.pagination_config.get("page_size_param") or self.pagination_config.get("limit_param") + if size_param: + for source in (self._explicit_query_params, self._url_params): + if size_param in source: + try: + val = int(source[size_param]) + if val > 0: + return val + except (ValueError, TypeError): + pass + + return self.batch_size + + def _apply_page_pagination(self, params: Dict[str, Any], page: int, per_page: int) -> None: + params[self.pagination_config.get("page_param", "page")] = page + size_param = self.pagination_config.get("page_size_param") + if size_param: + params[size_param] = per_page + + def _apply_offset_pagination(self, params: Dict[str, Any], offset: int, limit: int) -> None: + params[self.pagination_config.get("offset_param", "offset")] = offset + limit_param = self.pagination_config.get("limit_param") + if limit_param: + params[limit_param] = limit + + def _apply_cursor_pagination(self, params: Dict[str, Any], cursor: str) -> None: + 
params[self.pagination_config.get("cursor_param", "cursor")] = cursor + + # -- JSON extraction ---------------------------------------------------- + + def _extract_items(self, response_json: Any) -> List[Mapping[str, Any]]: + """Extract the items array from a JSON response.""" + if self.items_path and _jsonpath is not None: + try: + matches = _jsonpath(response_json, self.items_path) + except Exception as exc: + raise ConnectorValidationError( + f"Failed to apply items JSONPath '{self.items_path}': {exc}" + ) from exc + if not matches: + return [] + if len(matches) == 1 and isinstance(matches[0], list): + items = matches[0] + else: + items = matches + elif isinstance(response_json, list): + items = response_json + elif isinstance(response_json, dict): + items = [] + for key in ("items", "results", "data", "records"): + if key in response_json and isinstance(response_json[key], list): + items = response_json[key] + break + else: + for value in response_json.values(): + if isinstance(value, list): + items = value + break + else: + items = [] + + return [it for it in items if isinstance(it, Mapping)] + + def _extract_next_cursor(self, response_json: Any) -> Optional[str]: + """Extract cursor value for cursor-based pagination.""" + cursor_path = self.pagination_config.get("next_cursor_path") + if not cursor_path: + field = self.pagination_config.get("next_cursor_field") + if field and isinstance(response_json, Mapping): + value = response_json.get(field) + return str(value) if value is not None else None + return None + + if _jsonpath is None: + return None + + try: + matches = _jsonpath(response_json, cursor_path) + except Exception: + return None + + if not matches: + return None + return str(matches[0]) if matches[0] is not None else None + + # -- Item → Document mapping -------------------------------------------- + + def _item_to_document(self, item: Mapping[str, Any]) -> Document: + """Map a single API item to a ``Document``.""" + raw_id = 
self._get_typed_field_value(self.id_field, item) if self.id_field else None + if raw_id is None: + raw_id = hash128(f"rest_api_item:{repr(item)}") + doc_id = hash128(f"rest_api:{raw_id}") + + if self.content_template: + content_text = self._render_content_template(item) + else: + parts = [] + for field in self.content_fields: + val = self._get_typed_field_value(field, item) + if val is not None: + text = self._strip_html(self._coerce_to_text(val)) + if text: + parts.append(text) + content_text = "\n".join(parts) + blob = content_text.encode("utf-8") + + metadata: Dict[str, Any] = {} + for field in self.metadata_fields: + value = self._get_typed_field_value(field, item) + if value is not None: + metadata[field] = self._serialize_metadata_value(value) + + doc_updated_at = self._extract_timestamp(item) or datetime.now(timezone.utc) + + sem = str(self._extract_field(item, self.content_fields[0]) if self.content_fields else raw_id) + sem = self._strip_html(sem).replace("\n", " ").replace("\r", " ").strip()[:100] or str(doc_id) + + return Document( + id=doc_id, + source=DocumentSource.REST_API, + semantic_identifier=sem, + extension=".txt", + blob=blob, + doc_updated_at=doc_updated_at, + size_bytes=len(blob), + metadata=metadata or None, + ) + + # -- Field extraction --------------------------------------------------- + + def _extract_field(self, item: Mapping[str, Any], path: str) -> Any: + """Extract a value using dot-notation with optional array indexing. 
def _extract_field_values(self, item: Mapping[str, Any], path: str) -> List[Any]:
    """Collect every raw value reachable through a dot-notation field path.

    Each path segment is a key, optionally followed by ``[N]`` (one list
    element) or ``[*]`` (all list elements). Nodes that are not mappings,
    missing/``None`` children, non-list children under an index, and
    out-of-range indexes contribute nothing.

    Returns:
        The matched values in traversal order; an empty list when the path
        is empty, contains an empty segment, or matches nothing.
    """
    if not path:
        return []

    frontier: List[Any] = [item]
    for part in path.split("."):
        if not part:
            return []

        parsed = _FIELD_SEGMENT_RE.match(part)
        # A segment the pattern does not recognise is used verbatim as a key.
        name = parsed.group("key") if parsed else part
        selector: Optional[str] = parsed.group("index") if parsed else None

        collected: List[Any] = []
        for node in frontier:
            child = node.get(name) if isinstance(node, Mapping) else None
            if child is None:
                continue
            if selector is None:
                collected.append(child)
            elif isinstance(child, list):
                if selector == "*":
                    collected.extend(child)
                else:
                    try:
                        position = int(selector)
                    except ValueError:
                        continue
                    if 0 <= position < len(child):
                        collected.append(child[position])

        frontier = collected
        if not frontier:
            break

    return frontier
values] + non_null = [v for v in converted if v is not None] + if not non_null: + return None + if len(non_null) == 1: + return non_null[0] + return ", ".join(self._coerce_to_text(v) for v in non_null) + + # -- Timestamp parsing -------------------------------------------------- + + def _extract_timestamp(self, item: Mapping[str, Any]) -> Optional[datetime]: + """Extract and normalise a timestamp from ``poll_timestamp_field``.""" + if not self.poll_timestamp_field: + return None + + value = self._extract_field(item, self.poll_timestamp_field) + if isinstance(value, list) and value: + value = value[0] + return self._parse_datetime(value) + + @staticmethod + def _parse_datetime(value: Any) -> Optional[datetime]: + """Parse a raw value into a UTC datetime, or return None.""" + if value is None: + return None + + if isinstance(value, datetime): + return (value if value.tzinfo else value.replace(tzinfo=timezone.utc)).astimezone(timezone.utc) + + if isinstance(value, (int, float)): + try: + return datetime.fromtimestamp(float(value), tz=timezone.utc) + except Exception: + return None + + if isinstance(value, str): + ts = value.strip() + for fmt in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"): + try: + return datetime.strptime(ts, fmt).replace(tzinfo=timezone.utc) + except Exception: + continue + try: + dt = datetime.fromisoformat(ts.replace("Z", "+00:00").replace(" ", "T")) + return (dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)).astimezone(timezone.utc) + except Exception: + return None + + return None + + # -- Content template rendering ----------------------------------------- + + class _SafeDict(dict): + """Dict subclass that returns empty string for missing keys in format_map.""" + def __missing__(self, key: str) -> str: + return "" + + def _render_content_template(self, item: Mapping[str, Any]) -> str: + """Render content using a user-provided template with ``{field}`` placeholders.""" + template = self.content_template or 
"" + values: Dict[str, str] = {} + for field_path in set(self.content_fields + self.metadata_fields): + val = self._get_typed_field_value(field_path, item) + if val is None: + continue + name = re.sub(r"\[\d+\]|\[\*\]", "", field_path).replace(".", "_") + values[name] = self._coerce_to_text(val) + + try: + rendered = template.format_map(self._SafeDict(values)) + except Exception as exc: + logging.warning("Failed to render content template: %s", exc) + parts = [self._coerce_to_text(self._get_typed_field_value(f, item)) for f in self.content_fields] + rendered = "\n".join(p for p in parts if p) + + return self._strip_html(rendered) + + # -- Static helpers ----------------------------------------------------- + + @staticmethod + def _strip_html(text: str) -> str: + """Remove basic HTML tags and normalise whitespace.""" + if "<" not in text or ">" not in text: + return text + cleaned = re.sub(r"<[^>]+>", " ", text) + return re.sub(r"\s+", " ", cleaned).strip() + + @staticmethod + def _coerce_to_text(value: Any) -> str: + """Convert any value to a plain-text string.""" + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, (int, float, bool)): + return str(value) + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return str(value) + + @staticmethod + def _serialize_metadata_value(value: Any) -> Any: + """Serialise a metadata value for storage.""" + if isinstance(value, datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.isoformat() + if isinstance(value, (int, float, bool, str)): + return value + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return str(value) + + @staticmethod + def _doc_in_time_window(doc: Document, start: datetime, end: datetime) -> bool: + if not doc.doc_updated_at: + return False + dt = doc.doc_updated_at + dt = (dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)).astimezone(timezone.utc) + return start <= 
dt < end diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 92ab86b023..7d8ccdadc2 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -41,7 +41,9 @@ from api.db.services.connector_service import ConnectorService, SyncLogsService from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from common import settings +from common.constants import FileSource, TaskStatus from common.config_utils import show_configs +from common.data_source.config import INDEX_BATCH_SIZE from common.data_source import ( BlobStorageConnector, RSSConnector, @@ -58,9 +60,8 @@ from common.data_source import ( SeaFileConnector, RDBMSConnector, DingTalkAITableConnector, + RestAPIConnector, ) -from common.constants import FileSource, TaskStatus -from common.data_source.config import INDEX_BATCH_SIZE from common.data_source.models import ConnectorFailure, SeafileSyncScope from common.data_source.webdav_connector import WebDAVConnector from common.data_source.confluence_connector import ConfluenceConnector @@ -70,6 +71,7 @@ from common.data_source.github.connector import GithubConnector from common.data_source.gitlab_connector import GitlabConnector from common.data_source.bitbucket.connector import BitbucketConnector from common.data_source.interfaces import CheckpointOutputWrapper +from common.data_source.exceptions import ConnectorValidationError from common.log_utils import init_root_logger from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc from common.versions import get_ragflow_version @@ -1819,6 +1821,33 @@ class PostgreSQL(_RDBMSBase): DEFAULT_PORT: int = 5432 +class REST_API(SyncBase): + SOURCE_NAME: str = FileSource.REST_API + + async def _generate(self, task: dict): + try: + cfg = RestAPIConnector.parse_storage_config(self.conf) + except ConnectorValidationError as exc: + raise ValueError(str(exc)) from exc + + self.connector = 
RestAPIConnector.from_parsed_config(cfg) + self.connector.load_credentials(self.conf.get("credentials") or {}) + + poll_start = task.get("poll_range_start") + if task.get("reindex") == "1" or poll_start is None: + document_generator = self.connector.load_from_state() + begin_info = "totally" + else: + document_generator = self.connector.poll_source( + poll_start.timestamp(), + datetime.now(timezone.utc).timestamp(), + ) + begin_info = f"from {poll_start}" + + logging.info("Connect to REST API: %s %s %s", self.conf.get("method", "GET"), self.conf.get("url"), begin_info) + return document_generator + + func_factory = { FileSource.RSS: RSS, FileSource.S3: S3, @@ -1849,6 +1878,7 @@ func_factory = { FileSource.MYSQL: MySQL, FileSource.POSTGRESQL: PostgreSQL, FileSource.DINGTALK_AI_TABLE: DingTalkAITable, + FileSource.REST_API: REST_API, } diff --git a/test/unit_test/data_source/__init__.py b/test/unit_test/data_source/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/unit_test/data_source/conftest.py b/test/unit_test/data_source/conftest.py new file mode 100644 index 0000000000..0ba0f86071 --- /dev/null +++ b/test/unit_test/data_source/conftest.py @@ -0,0 +1,36 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Pre-register the ``common.data_source`` package namespace so that +importing individual sub-modules (config, exceptions, rest_api_connector, …) +does **not** trigger ``common/data_source/__init__.py``, which pulls in every +connector and their heavy transitive dependencies (numpy, xgboost, etc.). + +This file is executed by pytest before any test module in this directory is +collected, so the lightweight namespace is always in place. +""" + +import os +import sys +import types + +import common # lightweight top-level package + +if "common.data_source" not in sys.modules: + _pkg = types.ModuleType("common.data_source") + _pkg.__path__ = [os.path.join(p, "data_source") for p in common.__path__] + _pkg.__package__ = "common.data_source" + sys.modules["common.data_source"] = _pkg diff --git a/test/unit_test/data_source/test_rest_api_connector.py b/test/unit_test/data_source/test_rest_api_connector.py new file mode 100644 index 0000000000..1e7d737eaa --- /dev/null +++ b/test/unit_test/data_source/test_rest_api_connector.py @@ -0,0 +1,607 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from common.data_source import utils as _ds_utils +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, +) +from common.data_source.rest_api_connector import ( + AuthType, + PaginationType, + RestAPIConnector, + RestAPIConnectorConfig, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +VALID_URL = "https://api.example.com/v1/items" + +_MOCK_DNS_ADDRINFO = [(2, 1, 6, "", ("93.184.216.34", 0))] + + +@contextmanager +def _mocked_rest_api_requests_and_dns(): + """Block real DNS/TCP: mock SSRF getaddrinfo and HTTP at the class layer. + + `RestAPIConnector` calls `rl_requests.get` / `.post` on + `utils._RateLimitedRequest`. Replacing only module-level `rl_requests` is not + reliable everywhere (import/rebind quirks), so we patch the class methods + that wrap `requests.get` / `requests.post` and avoid retry backoff delays. 
+ """ + mock_rl = MagicMock() + with patch( + "common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=_MOCK_DNS_ADDRINFO, + ), patch.object(_ds_utils._RateLimitedRequest, "get", mock_rl.get), patch.object( + _ds_utils._RateLimitedRequest, + "post", + mock_rl.post, + ): + yield mock_rl + + +def _make_paged_connector(**overrides) -> RestAPIConnector: + defaults = dict( + url=VALID_URL, + content_fields=["title"], + pagination_type=PaginationType.PAGE, + pagination_config={"page_param": "page"}, + max_pages=100, + request_delay=0, + ) + defaults.update(overrides) + return RestAPIConnector(**defaults) + + +def _make_connector(**overrides) -> RestAPIConnector: + """Build a RestAPIConnector with sensible defaults, applying *overrides*.""" + defaults = dict( + url=VALID_URL, + content_fields=["title", "body"], + ) + defaults.update(overrides) + return RestAPIConnector(**defaults) + + +def _mock_response(json_data, status_code=200): + """Return a ``requests.Response``-like mock.""" + resp = MagicMock(spec=requests.Response) + resp.status_code = status_code + resp.url = VALID_URL + resp.json.return_value = json_data + + if status_code >= 400: + http_error = requests.HTTPError(response=resp) + resp.raise_for_status.side_effect = http_error + resp.status_code = status_code + else: + resp.raise_for_status.return_value = None + + return resp + + +# ===================================================================== # +# 1. 
Config schema validation # +# ===================================================================== # + +class TestRestAPIConfig: + """Test Pydantic RestAPIConnectorConfig schema validation.""" + + def test_missing_url_raises_validation_error(self): + """Missing url should fail Pydantic validation.""" + with pytest.raises(Exception): + RestAPIConnectorConfig(content_fields=["title"]) + + def test_missing_content_fields_detected(self): + """An empty content_fields list should be caught by ensure_required_fields.""" + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields=[]) + with pytest.raises(ConnectorValidationError): + cfg.ensure_required_fields() + + def test_valid_minimal_config(self): + """Minimal valid config: url + content_fields.""" + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields=["title"]) + assert str(cfg.url).startswith("https://api.example.com") + assert cfg.content_fields == ["title"] + + def test_auth_type_defaults_to_none(self): + """auth_type should default to 'none'.""" + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields=["t"]) + assert cfg.auth_type == AuthType.NONE + + def test_pagination_type_defaults_to_none(self): + """pagination_type should default to 'none'.""" + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields=["t"]) + assert cfg.pagination_type == PaginationType.NONE + + def test_string_to_dict_coercion_for_headers(self): + """A key=value string should be coerced to a dict.""" + cfg = RestAPIConnectorConfig( + url=VALID_URL, content_fields=["t"], headers="X-Custom=hello" + ) + assert cfg.headers == {"X-Custom": "hello"} + + def test_string_to_list_coercion_for_content_fields(self): + """A comma-separated string should be coerced to a list.""" + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields="title,content") + assert cfg.content_fields == ["title", "content"] + + +# ===================================================================== # +# 2. 
SSRF URL validation # +# ===================================================================== # + +class TestSSRFValidation: + """Test that unsafe URLs are blocked before any HTTP request is made.""" + + def test_localhost_blocked(self): + """localhost should be rejected.""" + with pytest.raises(ConnectorValidationError, match="localhost"): + _make_connector(url="http://localhost/api") + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo") + def test_loopback_ip_blocked(self, mock_dns): + """127.0.0.1 should be rejected.""" + mock_dns.return_value = [(2, 1, 6, "", ("127.0.0.1", 0))] + with pytest.raises(ConnectorValidationError, match="disallowed"): + _make_connector(url="http://127.0.0.1/api") + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo") + def test_cloud_metadata_ip_blocked(self, mock_dns): + """169.254.169.254 (cloud metadata endpoint) should be rejected.""" + mock_dns.return_value = [(2, 1, 6, "", ("169.254.169.254", 0))] + with pytest.raises(ConnectorValidationError, match="disallowed"): + _make_connector(url="http://169.254.169.254/latest/meta-data/") + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo") + def test_private_ip_192_blocked(self, mock_dns): + """192.168.x.x should be rejected.""" + mock_dns.return_value = [(2, 1, 6, "", ("192.168.1.1", 0))] + with pytest.raises(ConnectorValidationError, match="disallowed"): + _make_connector(url="http://192.168.1.1/api") + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo") + def test_private_ip_10_blocked(self, mock_dns): + """10.x.x.x should be rejected.""" + mock_dns.return_value = [(2, 1, 6, "", ("10.0.0.1", 0))] + with pytest.raises(ConnectorValidationError, match="disallowed"): + _make_connector(url="http://10.0.0.1/api") + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo") + def test_public_url_passes(self, mock_dns): + """A public IP should pass validation.""" + mock_dns.return_value = [(2, 1, 6, "", 
("93.184.216.34", 0))] + c = _make_connector(url="https://example.com/api") + assert c.url.startswith("https://") + + def test_ftp_scheme_blocked(self): + """ftp:// should be rejected.""" + with pytest.raises(ConnectorValidationError, match="scheme"): + _make_connector(url="ftp://example.com/file") + + def test_file_scheme_blocked(self): + """file:// should be rejected.""" + with pytest.raises(ConnectorValidationError, match="scheme"): + _make_connector(url="file:///etc/passwd") + + +# ===================================================================== # +# 3. Authentication setup # +# ===================================================================== # + +class TestAuthSetup: + """Test _build_auth produces the correct headers / auth objects.""" + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def test_auth_none(self, _dns): + """auth_type=none should produce no auth headers.""" + c = _make_connector(auth_type=AuthType.NONE) + c.load_credentials({}) + assert c._auth_headers == {} + assert c._basic_auth is None + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def test_api_key_header(self, _dns): + """api_key_header should set the specified header.""" + c = _make_connector( + auth_type=AuthType.API_KEY_HEADER, + auth_config={"header_name": "X-API-Key"}, + ) + c.load_credentials({"api_key": "secret123"}) + assert c._auth_headers == {"X-API-Key": "secret123"} + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def test_bearer_token(self, _dns): + """bearer should set Authorization: Bearer .""" + c = _make_connector(auth_type=AuthType.BEARER) + c.load_credentials({"token": "tok_abc"}) + assert c._auth_headers == {"Authorization": "Bearer tok_abc"} + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 
1, 6, "", ("93.184.216.34", 0))]) + def test_basic_auth(self, _dns): + """basic should produce an HTTPBasicAuth object.""" + c = _make_connector(auth_type=AuthType.BASIC) + c.load_credentials({"username": "user", "password": "pass"}) + assert c._basic_auth is not None + assert c._basic_auth.username == "user" + assert c._basic_auth.password == "pass" + + +# ===================================================================== # +# 4. Field extraction # +# ===================================================================== # + +class TestFieldExtraction: + """Test _extract_field / _extract_field_values dot-notation paths.""" + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def setup_method(self, method, _dns=None): + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + self.connector = _make_connector() + + def test_simple_field(self): + """Top-level field extraction.""" + assert self.connector._extract_field({"title": "Hello"}, "title") == "Hello" + + def test_dot_notation_nested(self): + """Dot-notation nested field.""" + item = {"country": {"name": "Kuwait"}} + assert self.connector._extract_field(item, "country.name") == "Kuwait" + + def test_array_wildcard(self): + """Wildcard [*] returns all array elements.""" + item = {"tags": [{"name": "A"}, {"name": "B"}]} + result = self.connector._extract_field(item, "tags[*].name") + assert result == ["A", "B"] + + def test_missing_field_returns_none(self): + """Missing field returns None.""" + assert self.connector._extract_field({"a": 1}, "nonexistent") is None + + def test_missing_field_with_default(self): + """Missing field returns configured default value.""" + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + c = _make_connector(field_default_values={"missing": "fallback"}) + result = 
c._get_typed_field_value("missing", {"other": 1}) + assert result == "fallback" + + def test_deeply_nested_path(self): + """Multi-level dot-notation path.""" + item = {"a": {"b": {"c": {"d": 42}}}} + assert self.connector._extract_field(item, "a.b.c.d") == 42 + + +# ===================================================================== # +# 5. Items array detection # +# ===================================================================== # + +class TestItemsArrayDetection: + """Test _extract_items auto-detection of the items array.""" + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def setup_method(self, method, _dns=None): + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + self.connector = _make_connector() + + def test_items_key(self): + """Detect 'items' key.""" + resp = {"items": [{"id": 1}]} + assert self.connector._extract_items(resp) == [{"id": 1}] + + def test_results_key(self): + """Detect 'results' key.""" + resp = {"results": [{"id": 2}]} + assert self.connector._extract_items(resp) == [{"id": 2}] + + def test_data_key(self): + """Detect 'data' key.""" + resp = {"data": [{"id": 3}]} + assert self.connector._extract_items(resp) == [{"id": 3}] + + def test_records_key(self): + """Detect 'records' key.""" + resp = {"records": [{"id": 4}]} + assert self.connector._extract_items(resp) == [{"id": 4}] + + def test_custom_key_fallback(self): + """Fall back to the first list value in the dict.""" + resp = {"totalCount": 5, "stories": [{"id": 5}]} + assert self.connector._extract_items(resp) == [{"id": 5}] + + def test_response_is_list(self): + """Response that is directly a list.""" + resp = [{"id": 6}, {"id": 7}] + assert self.connector._extract_items(resp) == [{"id": 6}, {"id": 7}] + + def test_empty_response(self): + """Empty dict returns empty list.""" + assert self.connector._extract_items({}) == [] + + 
def test_no_list_in_response(self): + """Dict with no list values returns empty list.""" + assert self.connector._extract_items({"count": 0}) == [] + + +# ===================================================================== # +# 6. HTML stripping # +# ===================================================================== # + +class TestHTMLStripping: + """Test the _strip_html static method.""" + + def test_basic_tag_removal(self): + """Remove simple HTML tags.""" + assert RestAPIConnector._strip_html("

Hello

") == "Hello" + + def test_whitespace_collapsing(self): + """Multiple whitespace chars collapse to single space.""" + assert RestAPIConnector._strip_html("

Hello

World

") == "Hello World" + + def test_empty_string(self): + """Empty input returns empty output.""" + assert RestAPIConnector._strip_html("") == "" + + def test_plain_text_passthrough(self): + """Text without HTML passes through unchanged.""" + assert RestAPIConnector._strip_html("Hello World") == "Hello World" + + def test_nested_tags(self): + """Nested HTML tags are all stripped.""" + result = RestAPIConnector._strip_html("

Bold text

") + assert result == "Bold text" + + def test_html_with_attributes(self): + """Tags with attributes are stripped.""" + result = RestAPIConnector._strip_html('Link') + assert result == "Link" + + +# ===================================================================== # +# 7. Document creation # +# ===================================================================== # + +class TestDocumentCreation: + """Test _item_to_document mapping.""" + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def setup_method(self, method, _dns=None): + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + self.connector = _make_connector( + id_field="id", + content_fields=["title", "body"], + metadata_fields=["author"], + ) + + def test_document_id_from_configured_field(self): + """Document ID uses the configured id_field.""" + item = {"id": "abc", "title": "T", "body": "B", "author": "A"} + doc = self.connector._item_to_document(item) + assert doc.id is not None and len(doc.id) > 0 + + def test_semantic_identifier_from_first_content_field(self): + """semantic_identifier comes from the first content field.""" + item = {"id": "1", "title": "My Title", "body": "Body", "author": "A"} + doc = self.connector._item_to_document(item) + assert "My Title" in doc.semantic_identifier + + def test_content_blob_contains_all_fields(self): + """Blob should contain both content fields.""" + item = {"id": "1", "title": "Title", "body": "Body text", "author": "A"} + doc = self.connector._item_to_document(item) + content = doc.blob.decode("utf-8") + assert "Title" in content + assert "Body text" in content + + def test_metadata_populated(self): + """Metadata dict is populated from configured metadata_fields.""" + item = {"id": "1", "title": "T", "body": "B", "author": "Jane"} + doc = self.connector._item_to_document(item) + assert doc.metadata is not None 
+ assert doc.metadata["author"] == "Jane" + + def test_html_stripped_from_content(self): + """HTML tags are removed from content fields.""" + item = {"id": "1", "title": "T", "body": "

Clean

", "author": "A"} + doc = self.connector._item_to_document(item) + content = doc.blob.decode("utf-8") + assert "

" not in content + assert "Clean" in content + + def test_extension_is_txt(self): + """Document extension should be .txt.""" + item = {"id": "1", "title": "T", "body": "B", "author": "A"} + doc = self.connector._item_to_document(item) + assert doc.extension == ".txt" + + def test_missing_content_fields_graceful(self): + """Missing content fields produce an empty blob gracefully.""" + item = {"id": "1", "author": "A"} + doc = self.connector._item_to_document(item) + assert doc.blob == b"" + + +# ===================================================================== # +# 8. Pagination behaviour # +# ===================================================================== # + +class TestPaginationBehavior: + """Test pagination iteration with mocked HTTP responses.""" + + def test_page_pagination_increments(self): + """Page-based pagination should increment the page param.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + page1 = _mock_response({"items": [{"title": "A"}, {"title": "B"}]}) + page2 = _mock_response({"items": []}) + mock_rl.get.side_effect = [page1, page2] + + c = _make_paged_connector() + items = list(c._iter_items()) + assert len(items) == 2 + assert mock_rl.get.call_count == 2 + + def test_offset_pagination_increments(self): + """Offset-based pagination should increment offset by limit.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + page1 = _mock_response({"items": [{"title": "A"}]}) + page2 = _mock_response({"items": []}) + mock_rl.get.side_effect = [page1, page2] + + c = _make_connector( + pagination_type=PaginationType.OFFSET, + pagination_config={ + "offset_param": "offset", + "limit_param": "limit", + "limit": 10, + }, + request_delay=0, + ) + items = list(c._iter_items()) + assert len(items) == 1 + + def test_stops_on_empty_results(self): + """Pagination stops when empty items are returned.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({"items": []}) + + c = 
_make_paged_connector() + items = list(c._iter_items()) + assert items == [] + assert mock_rl.get.call_count == 1 + + def test_stops_when_fewer_items_than_page_size(self): + """Pagination stops when fewer items than page_size are returned.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + page1 = _mock_response({"items": [{"title": "A"}]}) + mock_rl.get.return_value = page1 + + c = _make_paged_connector( + pagination_config={"page_param": "page", "page_size": 10}, + ) + items = list(c._iter_items()) + assert len(items) == 1 + assert mock_rl.get.call_count == 1 + + def test_max_pages_cap(self): + """Pagination respects the max_pages safety cap.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response( + {"items": [{"title": "A"}, {"title": "B"}]} + ) + + c = _make_paged_connector( + max_pages=3, + pagination_config={"page_param": "page", "page_size": 2}, + ) + list(c._iter_items()) + assert mock_rl.get.call_count == 3 + + def test_request_delay_applied(self): + """request_delay should cause a sleep between pages.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + with patch("common.data_source.rest_api_connector.time.sleep") as mock_sleep: + page1 = _mock_response({"items": [{"title": "A"}, {"title": "B"}]}) + page2 = _mock_response({"items": []}) + mock_rl.get.side_effect = [page1, page2] + + c = _make_paged_connector( + pagination_config={"page_param": "page", "page_size": 2}, + ) + c.request_delay = 1.5 + list(c._iter_items()) + mock_sleep.assert_called_once_with(1.5) + + +# ===================================================================== # +# 9. 
Non-retriable HTTP errors # +# ===================================================================== # + +class TestNonRetriableErrors: + """Test that HTTP errors are classified correctly in _fetch_page.""" + + def test_401_raises_credential_error(self): + """401 should raise ConnectorMissingCredentialError immediately.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=401) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(ConnectorMissingCredentialError): + c._fetch_page({}) + + def test_403_raises_credential_error(self): + """403 should raise ConnectorMissingCredentialError immediately.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=403) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(ConnectorMissingCredentialError): + c._fetch_page({}) + + def test_404_raises_validation_error(self): + """404 should raise ConnectorValidationError (no retry).""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=404) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(ConnectorValidationError, match="non-retriable"): + c._fetch_page({}) + + def test_400_raises_validation_error(self): + """400 should raise ConnectorValidationError (no retry).""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=400) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(ConnectorValidationError, match="non-retriable"): + c._fetch_page({}) + + def test_500_triggers_retry(self): + """500 should raise HTTPError (which the retry decorator catches).""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=500) + c = 
_make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(requests.HTTPError): + c._fetch_page({}) + + def test_429_triggers_retry(self): + """429 should raise HTTPError (retriable, not ConnectorValidationError).""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=429) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(requests.HTTPError): + c._fetch_page({}) diff --git a/test/unit_test/rag/conftest.py b/test/unit_test/rag/conftest.py new file mode 100644 index 0000000000..3ca5e289e9 --- /dev/null +++ b/test/unit_test/rag/conftest.py @@ -0,0 +1,58 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Restore the real ``common.data_source`` package before importing rag unit tests. + +``test/unit_test/data_source/conftest.py`` registers a lightweight +``sys.modules["common.data_source"]`` stub so submodule imports skip the heavy +package ``__init__.py``. Pytest collection order visits ``data_source/`` before +``rag/``, so without this hook ``rag.svr.sync_data_source`` fails on +``from common.data_source import BlobStorageConnector``. 
+""" + +from __future__ import annotations + +import importlib +import sys +import types + + +def _restore_common_data_source_package() -> None: + mod = sys.modules.get("common.data_source") + if mod is None: + return + # Stub is a bare types.ModuleType with __path__ and no __file__; real package has __init__.py. + if getattr(mod, "__file__", None) is not None: + return + if not isinstance(mod, types.ModuleType) or not getattr(mod, "__path__", None): + return + keys = [ + key + for key in sys.modules + if key == "common.data_source" or key.startswith("common.data_source.") + ] + for key in keys: + del sys.modules[key] + importlib.invalidate_caches() + try: + importlib.import_module("common.data_source") + except Exception as exc: # pragma: no cover + raise ImportError( + "conftest: failed to restore real common.data_source package" + ) from exc + + +_restore_common_data_source_package() diff --git a/web/src/assets/svg/data-source/rest-api.svg b/web/src/assets/svg/data-source/rest-api.svg new file mode 100644 index 0000000000..f7d3e6d213 --- /dev/null +++ b/web/src/assets/svg/data-source/rest-api.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/web/src/components/document-preview/ppt-preview.tsx b/web/src/components/document-preview/ppt-preview.tsx index 4895ad9fe3..d95e05c46d 100644 --- a/web/src/components/document-preview/ppt-preview.tsx +++ b/web/src/components/document-preview/ppt-preview.tsx @@ -41,7 +41,7 @@ export const PptPreviewer: React.FC = ({ }); pptxPrviewer.preview(arrayBuffer); } - } catch (err) { + } catch { message.error('ppt parse failed'); } }; diff --git a/web/src/components/dynamic-form.tsx b/web/src/components/dynamic-form.tsx index 5c9fff5eaf..0ef13df1c1 100644 --- a/web/src/components/dynamic-form.tsx +++ b/web/src/components/dynamic-form.tsx @@ -347,7 +347,6 @@ export const RenderField = ({ field: FormFieldConfig; labelClassName?: string; }) => { - const form = useFormContext(); if (field.render) { if (field.type === 
FormFieldType.Custom && field.hideLabel) { return

{field.render({})}
; diff --git a/web/src/components/floating-chat-widget-markdown.tsx b/web/src/components/floating-chat-widget-markdown.tsx index 3a4e4942c6..4237fa1758 100644 --- a/web/src/components/floating-chat-widget-markdown.tsx +++ b/web/src/components/floating-chat-widget-markdown.tsx @@ -296,9 +296,11 @@ const FloatingChatWidgetMarkdown = ({ className="text-sm leading-relaxed space-y-2 prose-sm max-w-full" components={ { - p: ({ children, node, ...props }: any) => ( -

{children}

- ), + p: (props: any) => { + const { children, node, ...rest } = props; + void node; + return

{children}

; + }, 'custom-typography': ({ children }: { children: string }) => renderReference(children), code(props: any) { diff --git a/web/src/components/ui/audio-button.tsx b/web/src/components/ui/audio-button.tsx index 313ef38739..15ac5b8c63 100644 --- a/web/src/components/ui/audio-button.tsx +++ b/web/src/components/ui/audio-button.tsx @@ -8,7 +8,6 @@ import { getAuthorization } from '@/utils/authorization-util'; import { chain, sum } from 'lodash'; import { Loader2, Mic, Square } from 'lucide-react'; import { useCallback, useEffect, useRef, useState } from 'react'; -import { useIsDarkTheme } from '../theme-provider'; import { Input } from './input'; import { Popover, PopoverContent, PopoverTrigger } from './popover'; @@ -18,7 +17,6 @@ const VoiceVisualizer = ({ isRecording }: { isRecording: boolean }) => { const analyserRef = useRef(null); const animationFrameRef = useRef(0); const streamRef = useRef(null); - const isDark = useIsDarkTheme(); const draw = useCallback(() => { const canvas = canvasRef.current; @@ -273,11 +271,9 @@ export const AudioButton = ({ // throw new Error('ReadableStream not supported in this browser'); // } - console.log('Response:', response); const { data, code } = await response.json(); if (code === 0 && data && data.text) { setTranscript(data.text); - console.log('Transcript:', data.text); onOk?.(data.text); } setPopoverOpen(false); @@ -341,7 +337,7 @@ export const AudioButton = ({ { - setPopoverOpen(true); + setPopoverOpen(open); }} > diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 0c43e37fdd..20d5fde23d 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -1365,6 +1365,42 @@ Example: Virtual Hosted Style`, 'Column to use as unique document ID. If not specified, a hash of the content will be used.', postgresqlTimestampColumnTip: 'Datetime/timestamp column for incremental sync. 
Only rows modified after the last sync will be fetched.', + rest_apiDescription: + 'Connect any REST API endpoint as a data source using a flexible, configuration-driven connector.', + restApiQueryParamsTip: + 'Key=value pairs (one per line) sent as URL query parameters. Use this instead of embedding params in the URL.', + restApiHeadersTip: + 'Optional JSON object of additional HTTP headers to send with every request.', + restApiItemsPathTip: + 'Field name or JSONPath to the array of items in the response. Leave empty to auto-detect (tries "items", "results", "data", etc.).', + restApiIdFieldTip: + 'Field path within each item used to build a stable document ID. Leave empty to auto-generate from content hash.', + restApiContentFieldsTip: + 'Comma-separated list of item fields to concatenate into the document content.', + restApiMetadataFieldsTip: + 'Comma-separated list of item fields to store as metadata.', + restApiNextCursorPathTip: + 'JSONPath expression that resolves to the next-page cursor in the API response.', + restApiPollTimestampFieldTip: + 'Field path in each item that represents the last updated time, used for incremental sync.', + restApiRequestBodyTip: + 'Optional JSON body to send for POST requests. Used together with query params and pagination.', + restApiRequestDelayTip: + 'Delay in seconds between consecutive page requests. Helps avoid rate limiting from the API. 
Set to 0 to disable.', + restApiValidationApiKeyRequired: + 'API key is required when Auth Type is API Key (Header).', + restApiValidationApiKeyHeaderNameRequired: + 'API key header name is required when Auth Type is API Key (Header).', + restApiValidationBearerTokenRequired: + 'Bearer token is required when Auth Type is Bearer Token.', + restApiValidationBasicUsernameRequired: + 'Username is required when Auth Type is Basic Auth.', + restApiValidationBasicPasswordRequired: + 'Password is required when Auth Type is Basic Auth.', + restApiTestConnection: 'Test connection', + restApiTestSuccess: 'REST API connector validated successfully.', + restApiTestFailed: + 'REST API connector validation failed. Please check your configuration and logs.', availableSourcesDescription: 'Select a data source to add', availableSources: 'Available sources', datasourceDescription: 'Manage your data source and connections', diff --git a/web/src/locales/tr.ts b/web/src/locales/tr.ts index 2ae7cf3a23..3c0ba20c7b 100644 --- a/web/src/locales/tr.ts +++ b/web/src/locales/tr.ts @@ -341,7 +341,7 @@ Prosedürel Bellek: Öğrenilen beceriler, alışkanlıklar ve otomatik prosedü config: { descriptionPlaceholder: 'Belleğinizi açıklayın', memorySizeTooltip: `Her mesajın içeriği + embedding vektörü için geçerlidir (≈ İçerik + Boyutlar × 8 Bayt). -Örnek: 1024 boyutlu embedding ile 1 KB\'lık bir mesaj ~9 KB kullanır. 5 MB varsayılan sınır ~500 mesaj tutar.`, +Örnek: 1024 boyutlu embedding ile 1 KB'lık bir mesaj ~9 KB kullanır. 5 MB varsayılan sınır ~500 mesaj tutar.`, avatar: 'Avatar', description: 'Açıklama', memorySize: 'Bellek boyutu', @@ -980,9 +980,9 @@ Bu otomatik etiketleme özelliği, mevcut datasete alanına özgü bilgi katman systemTip: 'LLM için istemleriniz veya talimatlarınız; rol, yanıtların uzunluğu, tonu ve dili dahil ancak bunlarla sınırlı değildir. 
Modeliniz doğal olarak akıl yürütmeyi destekliyorsa, akıl yürütmeyi durdurmak için isteme //no_thinking ekleyebilirsiniz.', topN: 'İlk N', - topNTip: `Benzerlik eşiğinin üzerindeki tüm parçalar LLM\'ye gönderilmeyecek. Bu, alınanlardan 'İlk N' parçayı seçer.`, + topNTip: `Benzerlik eşiğinin üzerindeki tüm parçalar LLM'ye gönderilmeyecek. Bu, alınanlardan 'İlk N' parçayı seçer.`, variable: 'Değişken', - variableTip: `RAGFlow\'nun sohbet asistanı yönetim API\'leri ile birlikte kullanılır.`, + variableTip: `RAGFlow'nun sohbet asistanı yönetim API'leri ile birlikte kullanılır.`, add: 'Ekle', key: 'Anahtar', optional: 'İsteğe bağlı', @@ -1154,7 +1154,7 @@ Bu otomatik etiketleme özelliği, mevcut datasete alanına özgü bilgi katman "Confluence örneğinizin temel URL'si (örn. https://your-domain.atlassian.net/wiki)", confluenceSpaceKeyTip: 'İsteğe bağlı: Belirli bir alanla senkronizasyonu sınırlamak için alan anahtarı belirtin.', - s3PrefixTip: `S3 bucket\'ınızdaki dosyaları almak için klasör yolunu belirtin.`, + s3PrefixTip: `S3 bucket'ınızdaki dosyaları almak için klasör yolunu belirtin.`, S3CompatibleEndpointUrlTip: `S3 uyumlu Depolama Kutusu için zorunludur.`, S3CompatibleAddressingStyleTip: `S3 uyumlu Depolama Kutusu için zorunludur.`, addDataSourceModalTitle: '{{name}} bağlayıcınızı oluşturun', @@ -1509,7 +1509,7 @@ Bu otomatik etiketleme özelliği, mevcut datasete alanına özgü bilgi katman 'Lütfen Google Cloud Hizmet Hesabı Anahtarını base64 formatında girin', addGoogleRegion: 'Google Cloud Bölgesi', GoogleRegionMessage: 'Lütfen Google Cloud Bölgesi girin', - modelProvidersWarn: `Lütfen önce Ayarlar > Model sağlayıcıları bölümünde hem embedding modelini hem de LLM\'yi ekleyin.`, + modelProvidersWarn: `Lütfen önce Ayarlar > Model sağlayıcıları bölümünde hem embedding modelini hem de LLM'yi ekleyin.`, apiVersion: 'API Sürümü', apiVersionMessage: 'Lütfen API sürümünü girin', add: 'Ekle', @@ -1794,7 +1794,7 @@ En uygun olduğu durumlar: Anlatı bütünlüğünün bitişik 
paragrafları bir beginDescription: 'Akışın başladığı yer.', answerDescription: `İnsan ve bot arasındaki arayüz olarak hizmet eden bir bileşen.`, retrievalDescription: `Belirtilen datasets içinden bilgi alan bir bileşen.`, - generateDescription: `LLM\'yi yanıt üretmeye yönlendiren bir bileşen.`, + generateDescription: `LLM'yi yanıt üretmeye yönlendiren bir bileşen.`, categorizeDescription: `Kullanıcı girişlerini önceden tanımlanmış kategorilere sınıflandırmak için LLM kullanan bir bileşen.`, relevantDescription: `Yukarı akış çıktısının kullanıcının son sorgusuna uygun olup olmadığını değerlendirmek için LLM kullanan bir bileşen.`, rewriteQuestionDescription: `Önceki diyalogların bağlamına dayanarak Etkileşim bileşeninden bir kullanıcı sorgusunu yeniden yazan bir bileşen.`, @@ -1802,7 +1802,7 @@ En uygun olduğu durumlar: Anlatı bütünlüğünün bitişik paragrafları bir 'Bu bileşen, önceden tanımlanmış mesaj içeriğiyle birlikte iş akışının nihai veri çıktısını döndürür.', keywordDescription: `Kullanıcının girdisinden en fazla N arama sonucunu alan bir bileşen.`, switchDescription: `Önceki bileşenlerin çıktısına göre koşulları değerlendiren ve yürütme akışını yönlendiren bir bileşen.`, - wikipediaDescription: `wikipedia.org\'dan arama yapan bir bileşen.`, + wikipediaDescription: `wikipedia.org'dan arama yapan bir bileşen.`, promptText: `Lütfen aşağıdaki paragrafları özetleyin. Sayılara dikkat edin, uydurma yapmayın. 
Paragraflar aşağıdaki gibidir: {input} Yukarısı özetlemeniz gereken içeriktir.`, @@ -1825,7 +1825,7 @@ En uygun olduğu durumlar: Anlatı bütünlüğünün bitişik paragrafları bir keywordExtract: 'Anahtar kelime', keywordExtractDescription: `Bir kullanıcı sorgusundan anahtar kelimeler çıkaran bir bileşen.`, baidu: 'Baidu', - baiduDescription: `baidu.com\'dan arama yapan bir bileşen.`, + baiduDescription: `baidu.com'dan arama yapan bir bileşen.`, duckDuckGo: 'DuckDuckGo', duckDuckGoDescription: "duckduckgo.com'dan arama yapan bir bileşen.", searXNG: 'SearXNG', @@ -2693,7 +2693,7 @@ Temel Talimatlar: changeStepModalContent: `

Şu anda bu aşamanın sonuçlarını düzenliyorsunuz.

Daha sonraki bir aşamaya geçerseniz değişiklikleriniz kaybolacak.

-

Korumak için lütfen Yeniden Çalıştır\'a tıklayın.

`, +

Korumak için lütfen Yeniden Çalıştır'a tıklayın.

`, changeStepModalConfirmText: 'Yine de Geç', changeStepModalCancelText: 'İptal', unlinkPipelineModalTitle: 'Alım hattı bağlantısını kes', diff --git a/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx b/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx index a68bfd4fd9..31e092ca3e 100644 --- a/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx +++ b/web/src/pages/agent/form/components/prompt-editor/variable-picker-plugin.tsx @@ -369,7 +369,7 @@ export default function VariablePickerMenuPlugin({ const filterStructuredOutput = useGetStructuredOutputByValue(); const testTriggerFn = React.useCallback((text: string) => { - const triggerRegex = /(^|\s|\()([/]((?:[^/\s\()])*))$/; + const triggerRegex = /(^|\s|\()([/]((?:[^/\s()])*))$/; const match = triggerRegex.exec(text); if (match !== null) { diff --git a/web/src/pages/agent/hooks/use-add-node.ts b/web/src/pages/agent/hooks/use-add-node.ts index 45f9179493..15b57c5c6e 100644 --- a/web/src/pages/agent/hooks/use-add-node.ts +++ b/web/src/pages/agent/hooks/use-add-node.ts @@ -75,7 +75,7 @@ const GroupStartNodeMap = { name: Operator.IterationStart, form: initialIterationStartValues, }, - extent: 'parent' as 'parent', + extent: 'parent' as const, }, [Operator.Loop]: { id: `${Operator.LoopStart}:${humanId()}`, @@ -86,7 +86,7 @@ const GroupStartNodeMap = { name: Operator.LoopStart, form: {}, }, - extent: 'parent' as 'parent', + extent: 'parent' as const, }, }; diff --git a/web/src/pages/dataflow-result/components/time-line/index.tsx b/web/src/pages/dataflow-result/components/time-line/index.tsx index 1e96eb216e..3c0a3ca678 100644 --- a/web/src/pages/dataflow-result/components/time-line/index.tsx +++ b/web/src/pages/dataflow-result/components/time-line/index.tsx @@ -55,7 +55,6 @@ export interface TimelineDataFlowProps { const TimelineDataFlow = ({ activeFunc, activeId, - data, timelineNodes, }: TimelineDataFlowProps) => { // const 
[timelineNodeArr,setTimelineNodeArr] = useState() diff --git a/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx b/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx index c0715c7846..2be5fc995e 100644 --- a/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx +++ b/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx @@ -27,13 +27,20 @@ export default function ChatBasicSetting() { const form = useFormContext(); const emptyResponseValue = form.watch('prompt_config.empty_response'); const prologueValue = form.watch('prompt_config.prologue'); - const kbIds = (useWatch({ control: form.control, name: 'dataset_ids' }) || - []) as string[]; + const rawDatasetIds = useWatch({ + control: form.control, + name: 'dataset_ids', + }); + const kbIds = useMemo( + () => (rawDatasetIds || []) as string[], + [rawDatasetIds], + ); const metadataInclude = useWatch({ control: form.control, name: 'prompt_config.reference_metadata.include', }); - const { data: metadataKeys } = useFetchKnowledgeMetadataKeys(kbIds); + const { data: metadataKeys, loading: metadataKeysLoading } = + useFetchKnowledgeMetadataKeys(kbIds); const metadataFieldOptions = useMemo(() => { return (metadataKeys || []).map((key) => ({ label: key, @@ -60,7 +67,7 @@ export default function ChatBasicSetting() { } else if (!metadataInclude) { form.setValue('prompt_config.reference_metadata.fields', undefined); } - }, [kbIds, metadataKeys, metadataInclude, form]); + }, [kbIds, metadataKeys, metadataKeysLoading, metadataInclude, form]); return (
@@ -176,4 +183,4 @@ export default function ChatBasicSetting() { )}
); -} +} \ No newline at end of file diff --git a/web/src/pages/next-search/markdown-content/index.tsx b/web/src/pages/next-search/markdown-content/index.tsx index a5522b7c7a..132f34d5df 100644 --- a/web/src/pages/next-search/markdown-content/index.tsx +++ b/web/src/pages/next-search/markdown-content/index.tsx @@ -90,10 +90,13 @@ const MarkdownContent = ({ chunk: IReferenceChunk, isPdf: boolean = false, documentUrl?: string, - ) => - () => { + ) => { + void isPdf; + void documentUrl; + return () => { clickDocumentButton?.(documentId, chunk); - }, + }; + }, [clickDocumentButton], ); @@ -250,9 +253,7 @@ const MarkdownContent = ({ remarkPlugins={[remarkGfm, remarkMath]} components={ { - p: ({ children, node, ...props }: any) => ( -

{children}

- ), + p: ({ children, ...props }: any) =>

{children}

, 'custom-typography': ({ children }: { children: string }) => renderReference(children), code(props: any) { diff --git a/web/src/pages/next-search/search-setting.tsx b/web/src/pages/next-search/search-setting.tsx index c3c812306d..697e7369e8 100644 --- a/web/src/pages/next-search/search-setting.tsx +++ b/web/src/pages/next-search/search-setting.tsx @@ -217,9 +217,8 @@ const SearchSetting: React.FC = ({ control: formMethods.control, name: 'search_config.reference_metadata.include', }); - const { data: metadataKeys } = useFetchKnowledgeMetadataKeys( - selectedKbIds || [], - ); + const { data: metadataKeys, loading: metadataKeysLoading } = + useFetchKnowledgeMetadataKeys(selectedKbIds || []); const metadataFieldOptions = useMemo(() => { return (metadataKeys || []).map((key) => ({ label: key, @@ -252,7 +251,13 @@ const SearchSetting: React.FC = ({ undefined, ); } - }, [selectedKbIds, metadataKeys, referenceMetadataEnabled, formMethods]); + }, [ + selectedKbIds, + metadataKeys, + metadataKeysLoading, + referenceMetadataEnabled, + formMethods, + ]); // Reset top_k to 1024 only when user actively disables rerank (from true to false) const prevRerankEnabled = useRef(undefined); diff --git a/web/src/pages/skills/components/markdown-viewer.tsx b/web/src/pages/skills/components/markdown-viewer.tsx index 12937ed32c..e50f56289f 100644 --- a/web/src/pages/skills/components/markdown-viewer.tsx +++ b/web/src/pages/skills/components/markdown-viewer.tsx @@ -61,20 +61,20 @@ const MarkdownViewer: React.FC = ({ content }) => { const language = match ? 
match[1] : ''; if (language) { - return ( - - {String(children).replace(/\n$/, '')} - - ); + return ( + + {String(children).replace(/\n$/, '')} + + ); } return ( diff --git a/web/src/pages/skills/components/skill-detail.tsx b/web/src/pages/skills/components/skill-detail.tsx index c378a0cb0f..1b915cb488 100644 --- a/web/src/pages/skills/components/skill-detail.tsx +++ b/web/src/pages/skills/components/skill-detail.tsx @@ -103,9 +103,11 @@ const SkillDetail: React.FC = ({ const [versionFiles, setVersionFiles] = useState([]); const [versionLoading, setVersionLoading] = useState(false); - // Check if skill has multiple versions - const hasVersions = skill?.versions && skill.versions.length > 0; - const availableVersions = skill?.versions || []; + const availableVersions = useMemo( + () => skill?.versions ?? [], + [skill?.versions], + ); + const hasVersions = availableVersions.length > 0; // Reset state when skill changes or drawer opens/closes useEffect(() => { @@ -128,20 +130,14 @@ const SkillDetail: React.FC = ({ setSelectedFile(null); setFileContent(''); } - }, [ - open, - skill?.id, - hasVersions, - skill?.metadata?.version, - availableVersions, - ]); + }, [open, skill, hasVersions, availableVersions]); const resolvedVersion = useMemo(() => { if (!skill) return ''; return ( selectedVersion || skill.metadata?.version || skill.versions?.[0] || '' ); - }, [selectedVersion, skill?.id, skill?.metadata?.version, skill?.versions]); + }, [selectedVersion, skill]); // Load files when version or skill changes useEffect(() => { @@ -204,16 +200,7 @@ const SkillDetail: React.FC = ({ return () => { isActive = false; }; - }, [ - skill?.id, - skill?.source_type, - skill?.metadata?.version, - skill?.versions, - (skill as any)?._folderId, - skill?.files, - resolvedVersion, - getVersionFiles, - ]); + }, [skill, resolvedVersion, getVersionFiles]); // Use version files if available, otherwise use skill.files const currentFiles = useMemo(() => { @@ -248,7 +235,7 @@ const SkillDetail: 
React.FC = ({ ); setFileContent(content || ''); } catch (error) { - console.error('Failed to load file content'); + console.error('Failed to load file content', error); } finally { setLoading(false); } @@ -274,7 +261,7 @@ const SkillDetail: React.FC = ({ handleSelect({ id: targetFile.path } as TreeDataItem); } } - }, [open, skill?.id, currentFiles.length]); + }, [open, skill, currentFiles, handleSelect, selectedFile]); const renderFileContent = () => { if (!selectedFile) { diff --git a/web/src/pages/user-setting/data-source/component/box-token-field.tsx b/web/src/pages/user-setting/data-source/component/box-token-field.tsx index eccdf2dda5..4ae6bb2cf7 100644 --- a/web/src/pages/user-setting/data-source/component/box-token-field.tsx +++ b/web/src/pages/user-setting/data-source/component/box-token-field.tsx @@ -127,7 +127,10 @@ const BoxTokenField = ({ value, onChange }: BoxTokenFieldProps) => { string, any >; - const { user_id: _userId, code, ...rest } = credentials; + const code = credentials.code; + const rest = { ...credentials }; + delete rest.user_id; + delete rest.code; const finalValue: Record = { ...rest, @@ -173,7 +176,7 @@ const BoxTokenField = ({ value, onChange }: BoxTokenFieldProps) => { setWebStatus('error'); setWebStatusMessage(errorMessage); clearWebState(); - } catch (_error) { + } catch { message.error('Unable to retrieve authorization result.'); setWebStatus('error'); setWebStatusMessage('Unable to retrieve authorization result.'); @@ -304,7 +307,7 @@ const BoxTokenField = ({ value, onChange }: BoxTokenFieldProps) => { } else { message.error(data.message || 'Failed to start Box authorization.'); } - } catch (_error) { + } catch { message.error('Failed to start Box authorization.'); } finally { setSubmitLoading(false); diff --git a/web/src/pages/user-setting/data-source/component/gmail-token-field.tsx b/web/src/pages/user-setting/data-source/component/gmail-token-field.tsx index 186281d918..c347ed649d 100644 --- 
a/web/src/pages/user-setting/data-source/component/gmail-token-field.tsx +++ b/web/src/pages/user-setting/data-source/component/gmail-token-field.tsx @@ -95,11 +95,7 @@ const withRedirectUri = (credentials: string, redirectUri: string): string => { }); }; -const GmailTokenField = ({ - value, - onChange, - placeholder, -}: GmailTokenFieldProps) => { +const GmailTokenField = ({ value, onChange }: GmailTokenFieldProps) => { const [files, setFiles] = useState([]); const [pendingCredentials, setPendingCredentials] = useState(''); const [redirectUri, setRedirectUri] = useState(''); @@ -195,7 +191,7 @@ const GmailTokenField = ({ } message.error(data.message || 'Authorization failed.'); clearWebState(); - } catch (err) { + } catch { message.error('Unable to retrieve authorization result.'); clearWebState(); } @@ -315,7 +311,7 @@ const GmailTokenField = ({ } else { message.error(data.message || 'Failed to start browser authorization.'); } - } catch (err) { + } catch { message.error('Failed to start browser authorization.'); } finally { setWebAuthLoading(false); diff --git a/web/src/pages/user-setting/data-source/component/google-drive-token-field.tsx b/web/src/pages/user-setting/data-source/component/google-drive-token-field.tsx index 8d182fdbba..53ef6a60bf 100644 --- a/web/src/pages/user-setting/data-source/component/google-drive-token-field.tsx +++ b/web/src/pages/user-setting/data-source/component/google-drive-token-field.tsx @@ -192,7 +192,7 @@ const GoogleDriveTokenField = ({ } message.error(data.message || 'Authorization failed.'); clearWebState(); - } catch (err) { + } catch { message.error('Unable to retrieve authorization result.'); clearWebState(); } @@ -312,7 +312,7 @@ const GoogleDriveTokenField = ({ } else { message.error(data.message || 'Failed to start browser authorization.'); } - } catch (err) { + } catch { message.error('Failed to start browser authorization.'); } finally { setWebAuthLoading(false); diff --git 
a/web/src/pages/user-setting/data-source/constant/index.tsx b/web/src/pages/user-setting/data-source/constant/index.tsx index 0a5eb8c429..d00db49e16 100644 --- a/web/src/pages/user-setting/data-source/constant/index.tsx +++ b/web/src/pages/user-setting/data-source/constant/index.tsx @@ -41,6 +41,7 @@ export enum DataSourceKey { SEAFILE = 'seafile', MYSQL = 'mysql', POSTGRESQL = 'postgresql', + REST_API = 'rest_api', RSS = 'rss', // SHAREPOINT = 'sharepoint', @@ -202,6 +203,11 @@ export const generateDataSourceInfo = (t: TFunction) => { description: t(`setting.${DataSourceKey.GMAIL}Description`), icon: , }, + [DataSourceKey.REST_API]: { + name: 'REST API', + description: t(`setting.${DataSourceKey.REST_API}Description`), + icon: , + }, [DataSourceKey.MOODLE]: { name: 'Moodle', description: t(`setting.${DataSourceKey.MOODLE}Description`), @@ -373,47 +379,6 @@ export const getCommonExtraDefaultValues = () => ({ }, }); -export const getDataSourceFieldsWithExtras = ( - source?: DataSourceKey, -): FormFieldConfig[] => { - if (!source) { - return []; - } - - const sourceFields = - DataSourceFormFields[source as keyof typeof DataSourceFormFields] || []; - const extraFields = getCommonExtraFields(source); - - if (source !== DataSourceKey.JIRA) { - return [...sourceFields, ...extraFields]; - } - - const modeFieldIndex = sourceFields.findIndex( - (field) => field.name === 'config.is_cloud', - ); - if (modeFieldIndex < 0) { - return [...sourceFields, ...extraFields]; - } - - const sharedFields = sourceFields.slice(0, modeFieldIndex); - const modeFields = sourceFields.slice(modeFieldIndex); - - const sharedCheckboxFieldIndex = sharedFields.findIndex( - (field) => field.type === FormFieldType.Checkbox, - ); - - if (sharedCheckboxFieldIndex < 0) { - return [...sharedFields, ...extraFields, ...modeFields]; - } - - return [ - ...sharedFields.slice(0, sharedCheckboxFieldIndex), - ...sharedFields.slice(sharedCheckboxFieldIndex), - ...extraFields, - ...modeFields, - ]; -}; - export 
const DataSourceFormFields = { [DataSourceKey.RSS]: [ { @@ -1123,6 +1088,286 @@ export const DataSourceFormFields = { tooltip: t('setting.postgresqlTimestampColumnTip'), }, ], + [DataSourceKey.REST_API]: [ + // ── Essential fields ────────────────────────────────────────────── + { + label: 'Base URL', + name: 'config.url', + type: FormFieldType.Text, + required: true, + placeholder: 'https://api.example.com/v1/resources', + }, + { + label: 'HTTP Method', + name: 'config.method', + type: FormFieldType.Select, + required: true, + options: [ + { label: 'GET', value: 'GET' }, + { label: 'POST', value: 'POST' }, + ], + defaultValue: 'GET', + }, + { + label: 'Query Parameters', + name: 'config.query_params', + type: FormFieldType.Textarea, + required: false, + placeholder: `key=value\none_per_line=true`, + tooltip: t('setting.restApiQueryParamsTip'), + }, + { + label: 'Items Path', + name: 'config.items_path', + type: FormFieldType.Text, + required: false, + placeholder: '$.items', + tooltip: t('setting.restApiItemsPathTip'), + }, + { + label: 'ID Field', + name: 'config.id_field', + type: FormFieldType.Text, + required: false, + placeholder: 'id', + tooltip: t('setting.restApiIdFieldTip'), + }, + { + label: 'Auth Type', + name: 'config.auth_type', + type: FormFieldType.Select, + required: true, + options: [ + { label: 'None', value: 'none' }, + { label: 'API Key (Header)', value: 'api_key_header' }, + { label: 'Bearer Token', value: 'bearer' }, + { label: 'Basic Auth', value: 'basic' }, + ], + defaultValue: 'none', + }, + { + label: 'API Key Header Name', + name: 'config.auth_config.header_name', + type: FormFieldType.Text, + required: false, + placeholder: 'X-API-Key', + shouldRender: (values: any) => + values?.config?.auth_type === 'api_key_header', + customValidate: (val: string, values: any) => { + if ( + values?.config?.auth_type === 'api_key_header' && + !(val != null && String(val).trim()) + ) { + return t('setting.restApiValidationApiKeyHeaderNameRequired'); + } 
+ return true; + }, + }, + { + label: 'API Key Value', + name: 'config.credentials.api_key', + type: FormFieldType.Password, + required: false, + shouldRender: (values: any) => + values?.config?.auth_type === 'api_key_header', + customValidate: (val: string, values: any) => { + if (values?.config?.auth_type === 'api_key_header' && !val) { + return t('setting.restApiValidationApiKeyRequired'); + } + return true; + }, + }, + { + label: 'Bearer Token', + name: 'config.credentials.token', + type: FormFieldType.Password, + required: false, + shouldRender: (values: any) => values?.config?.auth_type === 'bearer', + customValidate: (val: string, values: any) => { + if (values?.config?.auth_type === 'bearer' && !val) { + return t('setting.restApiValidationBearerTokenRequired'); + } + return true; + }, + }, + { + label: 'Username', + name: 'config.credentials.username', + type: FormFieldType.Text, + required: false, + shouldRender: (values: any) => values?.config?.auth_type === 'basic', + customValidate: (val: string, values: any) => { + if ( + values?.config?.auth_type === 'basic' && + !(val != null && String(val).trim()) + ) { + return t('setting.restApiValidationBasicUsernameRequired'); + } + return true; + }, + }, + { + label: 'Password', + name: 'config.credentials.password', + type: FormFieldType.Password, + required: false, + shouldRender: (values: any) => values?.config?.auth_type === 'basic', + customValidate: (val: string, values: any) => { + if (values?.config?.auth_type === 'basic' && !val) { + return t('setting.restApiValidationBasicPasswordRequired'); + } + return true; + }, + }, + { + label: 'Content Fields', + name: 'config.content_fields', + type: FormFieldType.Text, + required: true, + placeholder: 'title,body', + tooltip: t('setting.restApiContentFieldsTip'), + }, + { + label: 'Metadata Fields', + name: 'config.metadata_fields', + type: FormFieldType.Text, + required: false, + placeholder: 'author,category', + tooltip: 
t('setting.restApiMetadataFieldsTip'), + }, + { + label: 'Pagination Type', + name: 'config.pagination_type', + type: FormFieldType.Select, + required: true, + options: [ + { label: 'None', value: 'none' }, + { label: 'Page', value: 'page' }, + { label: 'Offset', value: 'offset' }, + { label: 'Cursor', value: 'cursor' }, + ], + defaultValue: 'none', + }, + { + label: 'Start Page', + name: 'config.pagination_config.start_page', + type: FormFieldType.Number, + required: false, + defaultValue: 1, + shouldRender: (values: any) => values?.config?.pagination_type === 'page', + }, + { + label: 'Offset Param', + name: 'config.pagination_config.offset_param', + type: FormFieldType.Text, + required: false, + defaultValue: 'offset', + shouldRender: (values: any) => + values?.config?.pagination_type === 'offset', + }, + { + label: 'Start Offset', + name: 'config.pagination_config.start_offset', + type: FormFieldType.Number, + required: false, + defaultValue: 0, + shouldRender: (values: any) => + values?.config?.pagination_type === 'offset', + }, + { + label: 'Cursor Param', + name: 'config.pagination_config.cursor_param', + type: FormFieldType.Text, + required: false, + defaultValue: 'cursor', + shouldRender: (values: any) => + values?.config?.pagination_type === 'cursor', + }, + { + label: 'Next Cursor JSONPath', + name: 'config.pagination_config.next_cursor_path', + type: FormFieldType.Text, + required: false, + placeholder: '$.next_cursor', + shouldRender: (values: any) => + values?.config?.pagination_type === 'cursor', + tooltip: t('setting.restApiNextCursorPathTip'), + }, + // ── Advanced settings toggle ────────────────────────────────────── + { + label: 'Advanced Settings', + name: 'config.show_advanced', + type: FormFieldType.Switch, + required: false, + defaultValue: false, + }, + // ── Advanced fields (hidden until toggled) ──────────────────────── + { + label: 'Custom Headers (JSON)', + name: 'config.headers', + type: FormFieldType.Textarea, + required: false, + 
placeholder: `{"X-Custom-Header": "value"}`, + tooltip: t('setting.restApiHeadersTip'), + shouldRender: (values: any) => !!values?.config?.show_advanced, + }, + { + label: 'Limit Param', + name: 'config.pagination_config.limit_param', + type: FormFieldType.Text, + required: false, + placeholder: 'limit (leave empty if already in Query Parameters)', + shouldRender: (values: any) => + !!values?.config?.show_advanced && + values?.config?.pagination_type === 'offset', + }, + { + label: 'Initial Cursor', + name: 'config.pagination_config.initial_cursor', + type: FormFieldType.Text, + required: false, + shouldRender: (values: any) => + !!values?.config?.show_advanced && + values?.config?.pagination_type === 'cursor', + }, + { + label: 'Max Pages', + name: 'config.max_pages', + type: FormFieldType.Number, + required: false, + defaultValue: 1000, + shouldRender: (values: any) => !!values?.config?.show_advanced, + }, + { + label: 'Request Delay (seconds)', + name: 'config.request_delay', + type: FormFieldType.Number, + required: false, + defaultValue: 0.5, + placeholder: '0.5', + tooltip: t('setting.restApiRequestDelayTip'), + shouldRender: (values: any) => !!values?.config?.show_advanced, + }, + { + label: 'Poll Timestamp Field', + name: 'config.poll_timestamp_field', + type: FormFieldType.Text, + required: false, + placeholder: 'updated_at', + tooltip: t('setting.restApiPollTimestampFieldTip'), + shouldRender: (values: any) => !!values?.config?.show_advanced, + }, + { + label: 'Request Body (POST) JSON', + name: 'config.request_body', + type: FormFieldType.Textarea, + required: false, + placeholder: `{"status": "published"}`, + tooltip: t('setting.restApiRequestBodyTip'), + shouldRender: (values: any) => + !!values?.config?.show_advanced && values?.config?.method === 'POST', + }, + ], }; export const DataSourceFormDefaultValues = { @@ -1477,4 +1722,74 @@ export const DataSourceFormDefaultValues = { }, }, }, + [DataSourceKey.REST_API]: { + name: '', + source: 
DataSourceKey.REST_API, + config: { + url: '', + method: 'GET', + query_params: '', + headers: '', + auth_type: 'none', + auth_config: {}, + items_path: '', + id_field: '', + content_fields: '', + metadata_fields: '', + pagination_type: 'none', + pagination_config: {}, + poll_timestamp_field: '', + request_body: '', + max_pages: 1000, + request_delay: 0.5, + show_advanced: false, + credentials: { + api_key: '', + token: '', + username: '', + password: '', + }, + }, + }, +}; + +export const getDataSourceFieldsWithExtras = ( + source?: DataSourceKey, +): FormFieldConfig[] => { + if (!source) { + return []; + } + + const sourceFields = + DataSourceFormFields[source as keyof typeof DataSourceFormFields] || []; + const extraFields = getCommonExtraFields(source); + + if (source !== DataSourceKey.JIRA) { + return [...sourceFields, ...extraFields]; + } + + const modeFieldIndex = sourceFields.findIndex( + (field) => field.name === 'config.is_cloud', + ); + if (modeFieldIndex < 0) { + return [...sourceFields, ...extraFields]; + } + + const sharedFields = sourceFields.slice(0, modeFieldIndex); + const modeFields = sourceFields.slice(modeFieldIndex); + + const sharedCheckboxFieldIndex = sharedFields.findIndex( + (field) => field.type === FormFieldType.Checkbox, + ); + + if (sharedCheckboxFieldIndex < 0) { + return [...sharedFields, ...extraFields, ...modeFields]; + } + + return [ + ...sharedFields.slice(0, sharedCheckboxFieldIndex), + ...sharedFields.slice(sharedCheckboxFieldIndex), + ...extraFields, + ...modeFields, + ]; }; diff --git a/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx b/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx index 1a4554abeb..dfeb7e0830 100644 --- a/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx +++ b/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx @@ -17,6 +17,7 @@ import { FieldValues } from 'react-hook-form'; import { DataSourceFormBaseFields, 
DataSourceFormDefaultValues, + DataSourceKey, getCommonExtraDefaultValues, getDataSourceFieldsWithExtras, mergeDataSourceFormValues, @@ -26,6 +27,7 @@ import { useAddDataSource, useDataSourceResume, useFetchDataSourceDetail, + useTestDataSource, } from '../hooks'; import { DataSourceLogsTable } from './log-table'; @@ -144,6 +146,7 @@ const SourceDetailPage = () => { }, [detail, runSchedule]); const { addLoading, handleAddOk } = useAddDataSource({ isEdit: true }); + const { loading: testLoading, handleTest } = useTestDataSource(); const onSubmit = useCallback(() => { formRef?.current?.submit(); @@ -187,7 +190,6 @@ const SourceDetailPage = () => { detail as FieldValues, ), }; - console.log('defaultValue', defaultValueTemp); setDefaultValues(defaultValueTemp); } }, [detail, customFields, onSubmit]); @@ -213,7 +215,18 @@ const SourceDetailPage = () => { defaultValues={defaultValues} /> -
+
+ {detail?.source === DataSourceKey.REST_API && ( + + )}