diff --git a/mcp/server/server.py b/mcp/server/server.py
index 9b1bd4f1d1..fd370ad2f5 100644
--- a/mcp/server/server.py
+++ b/mcp/server/server.py
@@ -58,6 +58,7 @@ JSON_RESPONSE = True
 class RAGFlowConnector:
     _MAX_DATASET_CACHE = 32
     _CACHE_TTL = 300
+    _DATASET_PAGE_SIZE = 1000
 
     _dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict()  # "dataset_id" -> (metadata, expiry_ts)
     _document_metadata_cache: OrderedDict[str, tuple[list[tuple[str, dict]], float | int]] = OrderedDict()  # "dataset_id" -> ([(document_id, doc_metadata)], expiry_ts)
@@ -127,35 +128,90 @@ class RAGFlowConnector:
         self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
         self._document_metadata_cache.move_to_end(dataset_id)
 
-    async def list_datasets(
+    async def _fetch_datasets_page(
         self,
         *,
         api_key: str,
-        page: int = 1,
-        page_size: int = 1000,
+        page: int,
+        page_size: int,
         orderby: str = "create_time",
         desc: bool = True,
         id: str | None = None,
         name: str | None = None,
     ):
+        """Fetch one structured page of accessible datasets from the backend API.
+
+        Returns the decoded JSON body when the HTTP status is 200 and ``code`` is 0.
+        Raises ``Exception`` wrapping a ``types.TextContent`` message otherwise.
+        """
         params = {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}
         if id:
-            params['id'] = id
-        if name :
-            params['name'] = name
-        
+            params["id"] = id
+        if name:
+            params["name"] = name
+
         res = await self._get("/datasets", params, api_key=api_key)
         if not res or res.status_code != 200:
-            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
+            error_message = None
+            if res is not None:
+                try:
+                    error_message = res.json().get("message")
+                except Exception:
+                    error_message = None
+            raise Exception([types.TextContent(type="text", text=error_message or "Cannot process this operation.")])
 
-        res = res.json()
-        if res.get("code") == 0:
-            result_list = []
-            for data in res["data"]:
-                d = {"description": data["description"], "id": data["id"]}
-                result_list.append(json.dumps(d, ensure_ascii=False))
-            return "\n".join(result_list)
-        return ""
+        res_json = res.json()
+        if res_json.get("code") != 0:
+            raise Exception([types.TextContent(type="text", text=res_json.get("message") or "Cannot process this operation.")])
+
+        return res_json
+
+    async def list_datasets(self, *, api_key: str, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
+        """Return accessible datasets as newline-delimited JSON for MCP tool descriptions."""
+        res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=page_size, orderby=orderby, desc=desc, id=id, name=name)
+        result_list = []
+        # Tolerate a successful response that carries no (or a null) "data" field.
+        for data in res_json.get("data") or []:
+            d = {"description": data["description"], "id": data["id"]}
+            result_list.append(json.dumps(d, ensure_ascii=False))
+        return "\n".join(result_list)
+
+    async def resolve_dataset_ids(self, *, api_key: str):
+        """Resolve all accessible dataset IDs for MCP retrieval fallback.
+
+        Pages through ``/datasets`` and returns the unique IDs in first-seen order.
+        """
+        logging.info("Resolving accessible dataset IDs for MCP retrieval")
+        resolved: dict[str, None] = {}  # insertion-ordered set of dataset IDs
+        page = 1
+
+        while True:
+            logging.debug("resolve_dataset_ids fetching /datasets page=%s page_size=%s", page, self._DATASET_PAGE_SIZE)
+            try:
+                res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=self._DATASET_PAGE_SIZE)
+            except Exception as exc:
+                logging.warning("resolve_dataset_ids failed to fetch /datasets page=%s error=%s", page, exc)
+                raise
+
+            datasets = res_json.get("data") or []
+            logging.debug("resolve_dataset_ids received %s datasets from page=%s", len(datasets), page)
+            known = len(resolved)
+            resolved.update((data["id"], None) for data in datasets if data.get("id"))
+            if len(resolved) == known:
+                # Empty page, or only IDs already collected: stop rather than loop forever.
+                break
+            total = res_json.get("total")
+            if isinstance(total, int):
+                # Trust the reported count only when the response actually carries one.
+                if len(resolved) >= total:
+                    break
+            elif len(datasets) < self._DATASET_PAGE_SIZE:
+                # No usable total: a short page means this was the last one.
+                break
+            page += 1
+
+        logging.info("resolve_dataset_ids resolved %s accessible dataset IDs", len(resolved))
+        return list(resolved)
 
     async def retrieval(
         self,
@@ -176,21 +232,12 @@ class RAGFlowConnector:
         if document_ids is None:
             document_ids = []
 
-        # If no dataset_ids provided or empty list, get all available dataset IDs
         if not dataset_ids:
-            dataset_list_str = await self.list_datasets(api_key=api_key)
-            dataset_ids = []
-
-            # Parse the dataset list to extract IDs
-            if dataset_list_str:
-                for line in dataset_list_str.strip().split("\n"):
-                    if line.strip():
-                        try:
-                            dataset_info = json.loads(line.strip())
-                            dataset_ids.append(dataset_info["id"])
-                        except (json.JSONDecodeError, KeyError):
-                            # Skip malformed lines
-                            continue
+            logging.info("MCP retrieval omitted dataset_ids; resolving accessible datasets")
+            dataset_ids = await self.resolve_dataset_ids(api_key=api_key)
+            if not dataset_ids:
+                logging.info("MCP retrieval found no accessible datasets for current user")
+                raise Exception([types.TextContent(type="text", text="No accessible datasets found.")])
 
         data_json = {
             "page": page,
@@ -522,22 +569,6 @@ async def call_tool(
         rerank_id = arguments.get("rerank_id")
         force_refresh = arguments.get("force_refresh", False)
 
-        # If no dataset_ids provided or empty list, get all available dataset IDs
-        if not dataset_ids:
-            dataset_list_str = await connector.list_datasets(api_key=api_key)
-            dataset_ids = []
-
-            # Parse the dataset list to extract IDs
-            if dataset_list_str:
-                for line in dataset_list_str.strip().split("\n"):
-                    if line.strip():
-                        try:
-                            dataset_info = json.loads(line.strip())
-                            dataset_ids.append(dataset_info["id"])
-                        except (json.JSONDecodeError, KeyError):
-                            # Skip malformed lines
-                            continue
-
         return await connector.retrieval(
             api_key=api_key,
             dataset_ids=dataset_ids,