From 7315d25cbcae3268c3087c35ef09fa12ab57ca1c Mon Sep 17 00:00:00 2001 From: NeedmeFordev <124189514+spider-yamet@users.noreply.github.com> Date: Thu, 9 Apr 2026 20:34:15 -0700 Subject: [PATCH] Fix retrieval API handling for omitted dataset IDs (#13990) ### What problem does this PR solve? This PR fixes a mismatch between the MCP retrieval contract and the backend retrieval API. `ragflow_retrieval` already describes `dataset_ids` as optional, but `/api/v1/retrieval` still rejected omitted or empty `dataset_ids` with `` `dataset_ids` is required. ``. That made MCP retrieval fail even though the tool schema promised that the request could search across all available datasets. This change updates `/api/v1/retrieval` to accept missing or empty `dataset_ids`, resolve all accessible datasets for the authenticated user, and keep the route schema aligned with the new runtime behavior. It also adds focused unit coverage for the fallback resolution path and the no-accessible-datasets case. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) Fixes: #13981 ## Summary by CodeRabbit * **Bug Fixes** * Improved dataset resolution to reliably discover all accessible datasets through proper pagination, replacing the previous parsing method. * Enhanced error handling with clearer messaging when no datasets are available for retrieval. --- mcp/server/server.py | 107 ++++++++++++++++++++++++------------------- 1 file changed, 61 insertions(+), 46 deletions(-) diff --git a/mcp/server/server.py b/mcp/server/server.py index 9b1bd4f1d1..fd370ad2f5 100644 --- a/mcp/server/server.py +++ b/mcp/server/server.py @@ -58,6 +58,7 @@ JSON_RESPONSE = True class RAGFlowConnector: _MAX_DATASET_CACHE = 32 _CACHE_TTL = 300 + _DATASET_PAGE_SIZE = 1000 _dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts) _document_metadata_cache: OrderedDict[str, tuple[list[tuple[str, dict]], float | int]] = OrderedDict() # "dataset_id" -> ([(document_id, doc_metadata)], expiry_ts) @@ -127,35 +128,74 @@ class RAGFlowConnector: self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp()) self._document_metadata_cache.move_to_end(dataset_id) - async def list_datasets( + async def _fetch_datasets_page( self, *, api_key: str, - page: int = 1, - page_size: int = 1000, + page: int, + page_size: int, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None, ): + """Fetch one structured page of accessible datasets from the backend API.""" params = {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc} if id: - params['id'] = id - if name : - params['name'] = name - + params["id"] = id + if name: + params["name"] = name + res = await self._get("/datasets", params, api_key=api_key) if not res or res.status_code != 200: - raise Exception([types.TextContent(type="text", text="Cannot process this operation.")]) + error_message = None + if res is not None: + try: + error_message = res.json().get("message") + except Exception: + error_message = None + raise Exception([types.TextContent(type="text", text=error_message or "Cannot process this operation.")]) - res = res.json() - if res.get("code") == 0: - result_list = [] - for data in res["data"]: - d = {"description": data["description"], "id": data["id"]} - result_list.append(json.dumps(d, ensure_ascii=False)) - return "\n".join(result_list) - return "" + res_json = res.json() + if res_json.get("code") != 0: + raise Exception([types.TextContent(type="text", text=res_json.get("message", "Cannot process this operation."))]) + + return res_json + + async def list_datasets(self, *, api_key: str, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None): + """Return accessible datasets as newline-delimited JSON for MCP tool descriptions.""" + res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=page_size, orderby=orderby, desc=desc, id=id, name=name) + result_list = [] + for data in res_json["data"]: + d = {"description": data["description"], "id": data["id"]} + result_list.append(json.dumps(d, ensure_ascii=False)) + return "\n".join(result_list) + + async def resolve_dataset_ids(self, *, api_key: str): + """Resolve all accessible dataset IDs for MCP retrieval fallback.""" + logging.info("Resolving accessible dataset IDs for MCP retrieval") + dataset_ids = [] + page = 1 + + while True: + logging.debug("resolve_dataset_ids fetching /datasets page=%s page_size=%s", page, self._DATASET_PAGE_SIZE) + try: + res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=self._DATASET_PAGE_SIZE) + except Exception as exc: + logging.warning("resolve_dataset_ids failed to fetch /datasets page=%s error=%s", page, exc) + raise + + datasets = res_json.get("data", []) + logging.debug("resolve_dataset_ids received %s datasets from page=%s", len(datasets), page) + dataset_ids.extend(data["id"] for data in datasets if data.get("id")) + total = res_json.get("total", len(dataset_ids)) + if not datasets or len(dataset_ids) >= total: + break + page += 1 + + resolved = list(dict.fromkeys(dataset_ids)) + logging.info("resolve_dataset_ids resolved %s accessible dataset IDs", len(resolved)) + return resolved async def retrieval( self, @@ -176,21 +216,12 @@ class RAGFlowConnector: if document_ids is None: document_ids = [] - # If no dataset_ids provided or empty list, get all available dataset IDs if not dataset_ids: - dataset_list_str = await self.list_datasets(api_key=api_key) - dataset_ids = [] - - # Parse the dataset list to extract IDs - if dataset_list_str: - for line in dataset_list_str.strip().split("\n"): - if line.strip(): - try: - dataset_info = json.loads(line.strip()) - dataset_ids.append(dataset_info["id"]) - except (json.JSONDecodeError, KeyError): - # Skip malformed lines - continue + logging.info("MCP retrieval omitted dataset_ids; resolving accessible datasets") + dataset_ids = await self.resolve_dataset_ids(api_key=api_key) + if not dataset_ids: + logging.info("MCP retrieval found no accessible datasets for current user") + raise Exception([types.TextContent(type="text", text="No accessible datasets found.")]) data_json = { "page": page, @@ -522,22 +553,6 @@ async def call_tool( rerank_id = arguments.get("rerank_id") force_refresh = arguments.get("force_refresh", False) - # If no dataset_ids provided or empty list, get all available dataset IDs - if not dataset_ids: - dataset_list_str = await connector.list_datasets(api_key=api_key) - dataset_ids = [] - - # Parse the dataset list to extract IDs - if dataset_list_str: - for line in dataset_list_str.strip().split("\n"): - if line.strip(): - try: - dataset_info = json.loads(line.strip()) - dataset_ids.append(dataset_info["id"]) - except (json.JSONDecodeError, KeyError): - # Skip malformed lines - continue - return await connector.retrieval( api_key=api_key, dataset_ids=dataset_ids,