From a9ddcae0b314268d77fac7f44d02abfbec04fd6c Mon Sep 17 00:00:00 2001
From: xu haiLong <42533614+xuhaiL@users.noreply.github.com>
Date: Thu, 18 Jun 2026 09:39:37 +0800
Subject: [PATCH] =?UTF-8?q?Fix:=20MCP=20dataset=20discovery=20fails=20due?=
 =?UTF-8?q?=20to=20REST=20API=20max=20page=20size=20limit=20=E2=80=A6=20(#?=
 =?UTF-8?q?16148)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fix #16146
---
 mcp/server/server.py                       |  86 ++++++++++----
 test/unit_test/mcp/test_server_datasets.py | 141 +++++++++++++++++++++
 2 files changed, 205 insertions(+), 22 deletions(-)
 create mode 100644 test/unit_test/mcp/test_server_datasets.py

diff --git a/mcp/server/server.py b/mcp/server/server.py
index 81c8da3f07..760b7121f3 100644
--- a/mcp/server/server.py
+++ b/mcp/server/server.py
@@ -58,7 +58,8 @@ JSON_RESPONSE = True
 class RAGFlowConnector:
     _MAX_DATASET_CACHE = 32
     _CACHE_TTL = 300
-    _DATASET_PAGE_SIZE = 1000
+    # Keep in sync with api.utils.pagination_utils.REST_API_MAX_PAGE_SIZE.
+    _REST_API_MAX_PAGE_SIZE = 100
     _dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict()  # "dataset_id" -> (metadata, expiry_ts)
     _document_metadata_cache: OrderedDict[str, tuple[list[tuple[str, dict]], float | int]] = OrderedDict()  # "dataset_id" -> ([(document_id, doc_metadata)], expiry_ts)
 
@@ -162,11 +163,64 @@ class RAGFlowConnector:
 
         return res_json
 
-    async def list_datasets(self, *, api_key: str, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
+    async def _fetch_all_datasets(
+        self,
+        *,
+        api_key: str,
+        orderby: str = "create_time",
+        desc: bool = True,
+        id: str | None = None,
+        name: str | None = None,
+    ):
+        """Fetch all accessible datasets without exceeding the REST API page-size limit.
+
+        Requests consecutive pages of ``_REST_API_MAX_PAGE_SIZE`` items and stops at
+        the first empty page, or once the collected count reaches the ``total``
+        reported in the response (when one is present).
+
+        Returns the datasets of every fetched page concatenated in page order.
+        Exceptions raised by ``_fetch_datasets_page`` propagate to the caller.
+        """
+        datasets = []
+        page = 1
+
+        while True:
+            logging.debug("fetching all /datasets page=%s page_size=%s", page, self._REST_API_MAX_PAGE_SIZE)
+            res_json = await self._fetch_datasets_page(
+                api_key=api_key,
+                page=page,
+                page_size=self._REST_API_MAX_PAGE_SIZE,
+                orderby=orderby,
+                desc=desc,
+                id=id,
+                name=name,
+            )
+            # A missing or null "data" field is treated as an empty page.
+            page_datasets = res_json.get("data") or []
+            logging.debug("received %s datasets from page=%s", len(page_datasets), page)
+            if not page_datasets:
+                break
+
+            datasets.extend(page_datasets)
+            total = res_json.get("total")
+            if total is not None and len(datasets) >= total:
+                break
+
+            page += 1
+
+        return datasets
+
+    async def list_datasets(self, *, api_key: str, page: int = 1, page_size: int = -1, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
         """Return accessible datasets as newline-delimited JSON for MCP tool descriptions."""
-        res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=page_size, orderby=orderby, desc=desc, id=id, name=name)
+        if page_size == -1:
+            datasets = await self._fetch_all_datasets(api_key=api_key, orderby=orderby, desc=desc, id=id, name=name)
+        else:
+            page_size = min(page_size, self._REST_API_MAX_PAGE_SIZE)
+            res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=page_size, orderby=orderby, desc=desc, id=id, name=name)
+            datasets = res_json.get("data") or []
+
         result_list = []
-        for data in res_json["data"]:
+        for data in datasets:
             d = {"description": data["description"], "id": data["id"]}
             result_list.append(json.dumps(d, ensure_ascii=False))
         return "\n".join(result_list)
@@ -174,25 +228,13 @@ class RAGFlowConnector:
     async def resolve_dataset_ids(self, *, api_key: str):
         """Resolve all accessible dataset IDs for MCP retrieval fallback."""
         logging.info("Resolving accessible dataset IDs for MCP retrieval")
-        dataset_ids = []
-        page = 1
-
-        while True:
-            logging.debug("resolve_dataset_ids fetching /datasets page=%s page_size=%s", page, self._DATASET_PAGE_SIZE)
-            try:
-                res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=self._DATASET_PAGE_SIZE)
-            except Exception as exc:
-                logging.warning("resolve_dataset_ids failed to fetch /datasets page=%s error=%s", page, exc)
-                raise
-
-            datasets = res_json.get("data", [])
-            logging.debug("resolve_dataset_ids received %s datasets from page=%s", len(datasets), page)
-            dataset_ids.extend(data["id"] for data in datasets if data.get("id"))
-            total = res_json.get("total", len(dataset_ids))
-            if not datasets or len(dataset_ids) >= total:
-                break
-            page += 1
+        try:
+            datasets = await self._fetch_all_datasets(api_key=api_key)
+        except Exception as exc:
+            logging.warning("resolve_dataset_ids failed to fetch /datasets error=%s", exc)
+            raise
 
+        dataset_ids = [data["id"] for data in datasets if data.get("id")]
         resolved = list(dict.fromkeys(dataset_ids))
         logging.info("resolve_dataset_ids resolved %s accessible dataset IDs", len(resolved))
         return resolved
diff --git a/test/unit_test/mcp/test_server_datasets.py b/test/unit_test/mcp/test_server_datasets.py
new file mode 100644
index 0000000000..bd75e157cb
--- /dev/null
+++ b/test/unit_test/mcp/test_server_datasets.py
@@ -0,0 +1,141 @@
+#
+#  Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
+import importlib.util
+import json
+from pathlib import Path
+
+import pytest
+
+
+def _load_mcp_server():
+    """Load mcp/server/server.py from its file path as a standalone module."""
+    server_path = Path(__file__).resolve().parents[3] / "mcp" / "server" / "server.py"
+    spec = importlib.util.spec_from_file_location("ragflow_mcp_server_unit", server_path)
+    # Fail with a clear message instead of an AttributeError on spec.loader.
+    if spec is None or spec.loader is None:
+        raise ImportError(f"cannot load MCP server module from {server_path}")
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+class _FakeResponse:
+    """Response double exposing only ``status_code`` and ``json()``."""
+
+    status_code = 200
+
+    def __init__(self, payload):
+        self._payload = payload
+
+    def json(self):
+        return self._payload
+
+
+def _datasets(count):
+    """Build ``count`` dataset records with sequential ids and descriptions."""
+    return [{"id": f"dataset-{idx}", "description": f"description-{idx}"} for idx in range(count)]
+
+
+@pytest.fixture()
+def mcp_server():
+    """Provide a freshly loaded MCP server module for each test."""
+    return _load_mcp_server()
+
+
+def _stub_dataset_pages(monkeypatch, connector, datasets):
+    """Replace ``connector._get`` with a paginating fake and return the request log."""
+    requests = []
+
+    async def _get(path, params=None, api_key=""):
+        requests.append({"path": path, "params": dict(params), "api_key": api_key})
+        page = params["page"]
+        page_size = params["page_size"]
+        # Serve the slice addressed by the 1-based page number.
+        start = (page - 1) * page_size
+        end = start + page_size
+        return _FakeResponse({"code": 0, "data": datasets[start:end], "total": len(datasets)})
+
+    monkeypatch.setattr(connector, "_get", _get)
+    return requests
+
+
+@pytest.mark.asyncio
+async def test_list_datasets_default_fetches_all_with_rest_page_size_limit(monkeypatch, mcp_server):
+    connector = mcp_server.RAGFlowConnector(base_url=mcp_server.BASE_URL)
+    requests = _stub_dataset_pages(monkeypatch, connector, _datasets(250))
+
+    result = await connector.list_datasets(api_key="unit-key")
+
+    rows = [json.loads(line) for line in result.splitlines()]
+    assert [row["id"] for row in rows] == [f"dataset-{idx}" for idx in range(250)]
+    assert [request["params"]["page"] for request in requests] == [1, 2, 3]
+    assert all(request["path"] == "/datasets" for request in requests)
+    assert all(request["params"]["page_size"] == 100 for request in requests)
+
+
+@pytest.mark.asyncio
+async def test_resolve_dataset_ids_fetches_all_pages_and_deduplicates(monkeypatch, mcp_server):
+    connector = mcp_server.RAGFlowConnector(base_url=mcp_server.BASE_URL)
+    datasets = _datasets(101) + [{"id": "dataset-100", "description": "duplicate"}]
+    requests = _stub_dataset_pages(monkeypatch, connector, datasets)
+
+    result = await connector.resolve_dataset_ids(api_key="unit-key")
+
+    assert result == [f"dataset-{idx}" for idx in range(101)]
+    assert [request["params"]["page"] for request in requests] == [1, 2]
+    assert all(request["params"]["page_size"] == 100 for request in requests)
+
+
+@pytest.mark.asyncio
+async def test_list_datasets_clamps_explicit_page_size_to_rest_limit_and_preserves_filters(monkeypatch, mcp_server):
+    connector = mcp_server.RAGFlowConnector(base_url=mcp_server.BASE_URL)
+    requests = _stub_dataset_pages(monkeypatch, connector, _datasets(150))
+
+    result = await connector.list_datasets(
+        api_key="unit-key",
+        page=1,
+        page_size=1000,
+        orderby="name",
+        desc=False,
+        id="dataset-1",
+        name="target",
+    )
+
+    assert len(result.splitlines()) == 100
+    assert len(requests) == 1
+    assert requests[0]["params"] == {
+        "page": 1,
+        "page_size": 100,
+        "orderby": "name",
+        "desc": False,
+        "id": "dataset-1",
+        "name": "target",
+    }
+
+
+@pytest.mark.asyncio
+async def test_list_datasets_forwards_explicit_page_size_below_rest_limit(monkeypatch, mcp_server):
+    connector = mcp_server.RAGFlowConnector(base_url=mcp_server.BASE_URL)
+    requests = _stub_dataset_pages(monkeypatch, connector, _datasets(30))
+
+    result = await connector.list_datasets(api_key="unit-key", page=2, page_size=10)
+
+    rows = [json.loads(line) for line in result.splitlines()]
+    assert [row["id"] for row in rows] == [f"dataset-{idx}" for idx in range(10, 20)]
+    assert len(requests) == 1
+    assert requests[0]["params"]["page"] == 2
+    assert requests[0]["params"]["page_size"] == 10