Fix retrieval API handling for omitted dataset IDs (#13990)

### What problem does this PR solve?

This PR fixes a mismatch between the MCP retrieval contract and the
backend retrieval API.

`ragflow_retrieval` already describes `dataset_ids` as optional, but
`/api/v1/retrieval` still rejected omitted or empty `dataset_ids` with the
error `` `dataset_ids` is required. `` As a result, MCP retrieval failed
even though the tool schema promised that the request could search across
all available datasets.

This change updates `/api/v1/retrieval` to accept missing or empty
`dataset_ids`, resolve all accessible datasets for the authenticated
user, and keep the route schema aligned with the new runtime behavior.
It also adds focused unit coverage for the fallback resolution path and
the no-accessible-datasets case.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Fixes: #13981

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
  * Improved dataset resolution to reliably discover all accessible
    datasets through proper pagination, replacing the previous approach of
    parsing the newline-delimited dataset listing.
  * Enhanced error handling with clearer messaging when no datasets are
    available for retrieval.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
NeedmeFordev
2026-04-09 20:34:15 -07:00
committed by GitHub
parent 27329b40ed
commit 7315d25cbc

View File

@@ -58,6 +58,7 @@ JSON_RESPONSE = True
class RAGFlowConnector:
_MAX_DATASET_CACHE = 32
_CACHE_TTL = 300
_DATASET_PAGE_SIZE = 1000
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
_document_metadata_cache: OrderedDict[str, tuple[list[tuple[str, dict]], float | int]] = OrderedDict() # "dataset_id" -> ([(document_id, doc_metadata)], expiry_ts)
@@ -127,35 +128,74 @@ class RAGFlowConnector:
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
self._document_metadata_cache.move_to_end(dataset_id)
async def list_datasets(
async def _fetch_datasets_page(
self,
*,
api_key: str,
page: int = 1,
page_size: int = 1000,
page: int,
page_size: int,
orderby: str = "create_time",
desc: bool = True,
id: str | None = None,
name: str | None = None,
):
"""Fetch one structured page of accessible datasets from the backend API."""
params = {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}
if id:
params['id'] = id
if name :
params['name'] = name
params["id"] = id
if name:
params["name"] = name
res = await self._get("/datasets", params, api_key=api_key)
if not res or res.status_code != 200:
raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
error_message = None
if res is not None:
try:
error_message = res.json().get("message")
except Exception:
error_message = None
raise Exception([types.TextContent(type="text", text=error_message or "Cannot process this operation.")])
res = res.json()
if res.get("code") == 0:
result_list = []
for data in res["data"]:
d = {"description": data["description"], "id": data["id"]}
result_list.append(json.dumps(d, ensure_ascii=False))
return "\n".join(result_list)
return ""
res_json = res.json()
if res_json.get("code") != 0:
raise Exception([types.TextContent(type="text", text=res_json.get("message", "Cannot process this operation."))])
return res_json
async def list_datasets(self, *, api_key: str, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
"""Return accessible datasets as newline-delimited JSON for MCP tool descriptions."""
res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=page_size, orderby=orderby, desc=desc, id=id, name=name)
result_list = []
for data in res_json["data"]:
d = {"description": data["description"], "id": data["id"]}
result_list.append(json.dumps(d, ensure_ascii=False))
return "\n".join(result_list)
async def resolve_dataset_ids(self, *, api_key: str):
    """Resolve all accessible dataset IDs for MCP retrieval fallback.

    Pages through ``_fetch_datasets_page`` with ``_DATASET_PAGE_SIZE`` items
    per request until the listing is exhausted.

    Args:
        api_key: API key forwarded to the backend for every page request.

    Returns:
        De-duplicated dataset IDs in the order they were received; an empty
        list when no dataset is returned.

    Raises:
        Exception: Re-raised unchanged from ``_fetch_datasets_page`` after a
            warning has been logged.
    """
    logging.info("Resolving accessible dataset IDs for MCP retrieval")
    page_size = self._DATASET_PAGE_SIZE
    dataset_ids = []
    page = 1
    while True:
        logging.debug("resolve_dataset_ids fetching /datasets page=%s page_size=%s", page, page_size)
        try:
            res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=page_size)
        except Exception as exc:
            logging.warning("resolve_dataset_ids failed to fetch /datasets page=%s error=%s", page, exc)
            raise
        # "data" may be missing or null on an empty listing.
        datasets = res_json.get("data") or []
        logging.debug("resolve_dataset_ids received %s datasets from page=%s", len(datasets), page)
        dataset_ids.extend(data["id"] for data in datasets if data.get("id"))
        if not datasets:
            break
        total = res_json.get("total")
        if isinstance(total, int) and not isinstance(total, bool):
            # The backend reported a total: trust it, so that a server-side
            # cap on page size cannot end the loop early.
            if len(dataset_ids) >= total:
                break
        elif len(datasets) < page_size:
            # No usable total: a short page is the only end-of-listing
            # signal. Falling back to len(dataset_ids) here would always
            # stop after the first page and silently truncate the result.
            break
        page += 1
    # dict.fromkeys drops duplicates while preserving first-seen order.
    resolved = list(dict.fromkeys(dataset_ids))
    logging.info("resolve_dataset_ids resolved %s accessible dataset IDs", len(resolved))
    return resolved
async def retrieval(
self,
@@ -176,21 +216,12 @@ class RAGFlowConnector:
if document_ids is None:
document_ids = []
# If no dataset_ids provided or empty list, get all available dataset IDs
if not dataset_ids:
dataset_list_str = await self.list_datasets(api_key=api_key)
dataset_ids = []
# Parse the dataset list to extract IDs
if dataset_list_str:
for line in dataset_list_str.strip().split("\n"):
if line.strip():
try:
dataset_info = json.loads(line.strip())
dataset_ids.append(dataset_info["id"])
except (json.JSONDecodeError, KeyError):
# Skip malformed lines
continue
logging.info("MCP retrieval omitted dataset_ids; resolving accessible datasets")
dataset_ids = await self.resolve_dataset_ids(api_key=api_key)
if not dataset_ids:
logging.info("MCP retrieval found no accessible datasets for current user")
raise Exception([types.TextContent(type="text", text="No accessible datasets found.")])
data_json = {
"page": page,
@@ -522,22 +553,6 @@ async def call_tool(
rerank_id = arguments.get("rerank_id")
force_refresh = arguments.get("force_refresh", False)
# If no dataset_ids provided or empty list, get all available dataset IDs
if not dataset_ids:
dataset_list_str = await connector.list_datasets(api_key=api_key)
dataset_ids = []
# Parse the dataset list to extract IDs
if dataset_list_str:
for line in dataset_list_str.strip().split("\n"):
if line.strip():
try:
dataset_info = json.loads(line.strip())
dataset_ids.append(dataset_info["id"])
except (json.JSONDecodeError, KeyError):
# Skip malformed lines
continue
return await connector.retrieval(
api_key=api_key,
dataset_ids=dataset_ids,