mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix retrieval API handling for omitted dataset IDs (#13990)
### What problem does this PR solve? This PR fixes a mismatch between the MCP retrieval contract and the backend retrieval API. `ragflow_retrieval` already describes `dataset_ids` as optional, but `/api/v1/retrieval` still rejected omitted or empty `dataset_ids` with `` `dataset_ids` is required. ``. That made MCP retrieval fail even though the tool schema promised that the request could search across all available datasets. This change updates `/api/v1/retrieval` to accept missing or empty `dataset_ids`, resolve all accessible datasets for the authenticated user, and keep the route schema aligned with the new runtime behavior. It also adds focused unit coverage for the fallback resolution path and the no-accessible-datasets case. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) Fixes: #13981 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved dataset resolution to reliably discover all accessible datasets through proper pagination, replacing the previous parsing method. * Enhanced error handling with clearer messaging when no datasets are available for retrieval. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -58,6 +58,7 @@ JSON_RESPONSE = True
|
||||
class RAGFlowConnector:
|
||||
_MAX_DATASET_CACHE = 32
|
||||
_CACHE_TTL = 300
|
||||
_DATASET_PAGE_SIZE = 1000
|
||||
|
||||
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
|
||||
_document_metadata_cache: OrderedDict[str, tuple[list[tuple[str, dict]], float | int]] = OrderedDict() # "dataset_id" -> ([(document_id, doc_metadata)], expiry_ts)
|
||||
@@ -127,35 +128,74 @@ class RAGFlowConnector:
|
||||
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
|
||||
self._document_metadata_cache.move_to_end(dataset_id)
|
||||
|
||||
async def list_datasets(
|
||||
async def _fetch_datasets_page(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
page: int = 1,
|
||||
page_size: int = 1000,
|
||||
page: int,
|
||||
page_size: int,
|
||||
orderby: str = "create_time",
|
||||
desc: bool = True,
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""Fetch one structured page of accessible datasets from the backend API."""
|
||||
params = {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}
|
||||
if id:
|
||||
params['id'] = id
|
||||
if name :
|
||||
params['name'] = name
|
||||
|
||||
params["id"] = id
|
||||
if name:
|
||||
params["name"] = name
|
||||
|
||||
res = await self._get("/datasets", params, api_key=api_key)
|
||||
if not res or res.status_code != 200:
|
||||
raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
|
||||
error_message = None
|
||||
if res is not None:
|
||||
try:
|
||||
error_message = res.json().get("message")
|
||||
except Exception:
|
||||
error_message = None
|
||||
raise Exception([types.TextContent(type="text", text=error_message or "Cannot process this operation.")])
|
||||
|
||||
res = res.json()
|
||||
if res.get("code") == 0:
|
||||
result_list = []
|
||||
for data in res["data"]:
|
||||
d = {"description": data["description"], "id": data["id"]}
|
||||
result_list.append(json.dumps(d, ensure_ascii=False))
|
||||
return "\n".join(result_list)
|
||||
return ""
|
||||
res_json = res.json()
|
||||
if res_json.get("code") != 0:
|
||||
raise Exception([types.TextContent(type="text", text=res_json.get("message", "Cannot process this operation."))])
|
||||
|
||||
return res_json
|
||||
|
||||
async def list_datasets(self, *, api_key: str, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
|
||||
"""Return accessible datasets as newline-delimited JSON for MCP tool descriptions."""
|
||||
res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=page_size, orderby=orderby, desc=desc, id=id, name=name)
|
||||
result_list = []
|
||||
for data in res_json["data"]:
|
||||
d = {"description": data["description"], "id": data["id"]}
|
||||
result_list.append(json.dumps(d, ensure_ascii=False))
|
||||
return "\n".join(result_list)
|
||||
|
||||
async def resolve_dataset_ids(self, *, api_key: str):
|
||||
"""Resolve all accessible dataset IDs for MCP retrieval fallback."""
|
||||
logging.info("Resolving accessible dataset IDs for MCP retrieval")
|
||||
dataset_ids = []
|
||||
page = 1
|
||||
|
||||
while True:
|
||||
logging.debug("resolve_dataset_ids fetching /datasets page=%s page_size=%s", page, self._DATASET_PAGE_SIZE)
|
||||
try:
|
||||
res_json = await self._fetch_datasets_page(api_key=api_key, page=page, page_size=self._DATASET_PAGE_SIZE)
|
||||
except Exception as exc:
|
||||
logging.warning("resolve_dataset_ids failed to fetch /datasets page=%s error=%s", page, exc)
|
||||
raise
|
||||
|
||||
datasets = res_json.get("data", [])
|
||||
logging.debug("resolve_dataset_ids received %s datasets from page=%s", len(datasets), page)
|
||||
dataset_ids.extend(data["id"] for data in datasets if data.get("id"))
|
||||
total = res_json.get("total", len(dataset_ids))
|
||||
if not datasets or len(dataset_ids) >= total:
|
||||
break
|
||||
page += 1
|
||||
|
||||
resolved = list(dict.fromkeys(dataset_ids))
|
||||
logging.info("resolve_dataset_ids resolved %s accessible dataset IDs", len(resolved))
|
||||
return resolved
|
||||
|
||||
async def retrieval(
|
||||
self,
|
||||
@@ -176,21 +216,12 @@ class RAGFlowConnector:
|
||||
if document_ids is None:
|
||||
document_ids = []
|
||||
|
||||
# If no dataset_ids provided or empty list, get all available dataset IDs
|
||||
if not dataset_ids:
|
||||
dataset_list_str = await self.list_datasets(api_key=api_key)
|
||||
dataset_ids = []
|
||||
|
||||
# Parse the dataset list to extract IDs
|
||||
if dataset_list_str:
|
||||
for line in dataset_list_str.strip().split("\n"):
|
||||
if line.strip():
|
||||
try:
|
||||
dataset_info = json.loads(line.strip())
|
||||
dataset_ids.append(dataset_info["id"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# Skip malformed lines
|
||||
continue
|
||||
logging.info("MCP retrieval omitted dataset_ids; resolving accessible datasets")
|
||||
dataset_ids = await self.resolve_dataset_ids(api_key=api_key)
|
||||
if not dataset_ids:
|
||||
logging.info("MCP retrieval found no accessible datasets for current user")
|
||||
raise Exception([types.TextContent(type="text", text="No accessible datasets found.")])
|
||||
|
||||
data_json = {
|
||||
"page": page,
|
||||
@@ -522,22 +553,6 @@ async def call_tool(
|
||||
rerank_id = arguments.get("rerank_id")
|
||||
force_refresh = arguments.get("force_refresh", False)
|
||||
|
||||
# If no dataset_ids provided or empty list, get all available dataset IDs
|
||||
if not dataset_ids:
|
||||
dataset_list_str = await connector.list_datasets(api_key=api_key)
|
||||
dataset_ids = []
|
||||
|
||||
# Parse the dataset list to extract IDs
|
||||
if dataset_list_str:
|
||||
for line in dataset_list_str.strip().split("\n"):
|
||||
if line.strip():
|
||||
try:
|
||||
dataset_info = json.loads(line.strip())
|
||||
dataset_ids.append(dataset_info["id"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# Skip malformed lines
|
||||
continue
|
||||
|
||||
return await connector.retrieval(
|
||||
api_key=api_key,
|
||||
dataset_ids=dataset_ids,
|
||||
|
||||
Reference in New Issue
Block a user