mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix(api): check kb ownership in /dify/retrieval (#15028)
POST /api/v1/dify/retrieval resolved the caller via @apikey_required (injecting tenant_id) but then fetched the requested knowledge_id with no tenant filter and ran the full retrieval pipeline against kb.tenant_id (the owner). Any valid Dify-compatible API key could retrieve chunks from any tenant whose KB UUID was known. Adds the missing ownership check. ## Root Cause api/apps/sdk/dify_retrieval.py line 253: KnowledgebaseService.get_by_id(kb_id) fetched the KB by id alone, then the handler used kb.tenant_id (the OWNER) to build the embedding model and call the retriever. The caller tenant_id was only used downstream at line 278 for retrieval_by_children, well after cross-tenant data was already retrieved. grep confirmed there was no KnowledgebaseService.accessible call anywhere in the handler. ## Fix A guard immediately after the existing get_by_id lookup, mirroring the pattern PR #14749 lands for the sibling sdk/doc.py routes (download, parse, stop_parsing, retrieval_test): e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND) + if not KnowledgebaseService.accessible(kb_id, tenant_id): + return build_error_result(message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) if kb.tenant_embd_id: ... The landed diff additionally emits a logger.warning audit record (caller tenant_id and knowledge_id only) inside the guard before returning. KnowledgebaseService.accessible already handles solo-tenant ownership, team membership via TenantService.get_joined_tenants_by_user_id, and the permission=ME distinction. No behavior change for legitimate callers; cross-tenant callers now receive RetCode.AUTHENTICATION_ERROR (109). 
## Test Plan - [x] Regression test added: test/unit_test/api/apps/sdk/test_dify_retrieval.py - test_cross_tenant_request_is_rejected -- attacker tenant calling owner tenant KB gets 109; retriever is not invoked - test_same_tenant_request_succeeds -- owner tenant gets the records back - test_missing_knowledge_base_returns_not_found -- missing KB returns 404 BEFORE the access check fires (legit callers see the clearer message) - [x] All 3 tests pass after the fix - [x] Cross-tenant test FAILS on pre-fix main (KeyError on result[code] because handler leaks records dict instead of returning auth error) - [x] ruff check clean on both changed files - [x] No drive-by reformatting in dify_retrieval.py -- only the 2 added lines ### Post-fix output test_cross_tenant_request_is_rejected PASSED [ 33%] test_same_tenant_request_succeeds PASSED [ 66%] test_missing_knowledge_base_returns_not_found PASSED [100%] ============================== 3 passed in 0.04s =============================== Closes #15027
This commit is contained in:
@@ -253,6 +253,13 @@ async def retrieval(tenant_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
|
||||
if not KnowledgebaseService.accessible(kb_id, tenant_id):
|
||||
logger.warning(
|
||||
"Rejected /dify/retrieval cross-tenant access: caller_tenant=%s knowledge_id=%s",
|
||||
tenant_id,
|
||||
kb_id,
|
||||
)
|
||||
return build_error_result(message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
if kb.tenant_embd_id:
|
||||
model_config = get_model_config_by_id(kb.tenant_embd_id)
|
||||
else:
|
||||
|
||||
@@ -273,6 +273,7 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch):
|
||||
monkeypatch.setattr(module, "jsonify", lambda payload: payload)
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [{"doc_id": "doc-1"}])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True)
|
||||
monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", []))
|
||||
monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: [])
|
||||
|
||||
@@ -319,6 +320,7 @@ def test_retrieval_not_found_exception_mapping(monkeypatch):
|
||||
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True)
|
||||
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
|
||||
|
||||
class _BrokenRetriever:
|
||||
@@ -338,6 +340,7 @@ def test_retrieval_generic_exception_mapping(monkeypatch):
|
||||
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True)
|
||||
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
|
||||
|
||||
class _BrokenRetriever:
|
||||
|
||||
248
test/unit_test/api/apps/sdk/test_dify_retrieval.py
Normal file
248
test/unit_test/api/apps/sdk/test_dify_retrieval.py
Normal file
@@ -0,0 +1,248 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Regression tests for retrieval in api/apps/sdk/dify_retrieval.py.
|
||||
|
||||
Issue #15027: cross-tenant knowledge-base access via POST /api/v1/dify/retrieval.
|
||||
The handler authenticated the caller via @apikey_required (resolving
|
||||
tenant_id) but then fetched the requested knowledge_id with no tenant
|
||||
filter, allowing any valid API key to retrieve chunks from any other
|
||||
tenant's KB by id. The fix adds a KnowledgebaseService.accessible(...)
|
||||
check immediately after the lookup.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _PassthroughManager:
    """Stand-in for the route manager whose ``route`` decorator does nothing.

    Installed as ``module.manager`` before the handler module is executed so
    that route registration leaves the handler function untouched.
    """

    def route(self, *_args, **_kwargs):
        """Return a decorator that hands the decorated function back as-is.

        All positional and keyword arguments are accepted and ignored.
        """
        return lambda func: func
|
||||
|
||||
|
||||
class _AwaitableValue:
    """Wrap a plain value so that ``await wrapper`` evaluates to that value."""

    def __init__(self, value):
        """Remember ``value``; it is handed back unchanged when awaited."""
        self._value = value

    def __await__(self):
        """Return the await-iterator of a coroutine that returns the value.

        A new coroutine is created on every call, so the same wrapper can be
        awaited more than once.
        """

        async def _co():
            return self._value

        return _co().__await__()
|
||||
|
||||
|
||||
def _stub(monkeypatch, name, **attrs):
    """Register a synthetic module called ``name`` in ``sys.modules``.

    Args:
        monkeypatch: object providing ``setitem`` and ``setattr`` (pytest's
            ``monkeypatch`` fixture in this file); used for both
            registrations.
        name: dotted module name to register.
        **attrs: attributes to set on the new module.

    Returns:
        The newly created ``ModuleType`` instance.
    """
    mod = ModuleType(name)
    for key, value in attrs.items():
        setattr(mod, key, value)
    monkeypatch.setitem(sys.modules, name, mod)
    # If `name` is a submodule, also overwrite the attribute on the parent
    # package. Otherwise `from <parent> import <child>` resolves to the
    # already-cached real submodule via attribute lookup, bypassing our
    # sys.modules entry and our stub.
    if "." in name:
        parent_name, _, child_name = name.rpartition(".")
        parent_mod = sys.modules.get(parent_name)
        if parent_mod is not None:
            monkeypatch.setattr(parent_mod, child_name, mod, raising=False)
    return mod
|
||||
|
||||
|
||||
class _FakeRetriever:
    """Retriever double that returns canned chunks and records every call."""

    def __init__(self, chunks=None):
        """Store the canned ``chunks`` (empty list when omitted or ``None``)."""
        self._chunks = chunks if chunks is not None else []
        # One dict per retrieval() call, so tests can assert whether and how
        # the retriever was invoked.
        self.retrieval_calls = []

    async def retrieval(self, question, embd_mdl, tenant_id, kb_ids, **kwargs):
        """Record the call and return ``{"chunks": [...]}``.

        Only ``question``, ``tenant_id`` and a list copy of ``kb_ids`` are
        recorded; ``embd_mdl`` and any extra keyword arguments are ignored.
        The returned list is a fresh copy of the canned chunks.
        """
        self.retrieval_calls.append({"question": question, "tenant_id": tenant_id, "kb_ids": list(kb_ids)})
        return {"chunks": list(self._chunks)}

    def retrieval_by_children(self, chunks, _tenant_ids):
        """Return ``chunks`` unchanged; ``_tenant_ids`` is ignored."""
        return chunks
|
||||
|
||||
|
||||
class _FakeKGRetriever:
    """Knowledge-graph retriever double that always yields empty content."""

    async def retrieval(self, *_a, **_k):
        """Ignore all arguments and return a result with empty content."""
        return {"content_with_weight": ""}
|
||||
|
||||
|
||||
def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=None):
    """Load dify_retrieval.py with minimum stubs to exercise the retrieval handler.

    Every module stubbed here is replaced in ``sys.modules`` before the
    handler source file is executed under the name
    ``test_dify_retrieval_module``.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture; all stubs are installed
            through it.
        kb: value returned by the stubbed ``KnowledgebaseService.get_by_id``
            (the tests in this file pass a ``(found, kb)`` tuple).
        accessible: either a callable used directly as
            ``KnowledgebaseService.accessible``, or a constant that the
            stubbed ``accessible`` returns for any arguments.
        request_body: value produced by awaiting the stubbed
            ``get_request_json()``.
        chunks: canned chunks returned by the fake retriever.

    Returns:
        The loaded handler module, with the fake retriever attached as
        ``module._fake_retriever`` so tests can inspect recorded calls.

    Raises:
        ImportError: if no import spec/loader can be built for the handler
            source file.
    """
    _stub(
        monkeypatch,
        "api.utils.api_utils",
        apikey_required=lambda func: func,
        build_error_result=lambda message="", code=0, data=False: {"code": code, "message": message, "data": data},
        get_request_json=lambda: _AwaitableValue(request_body),
        get_json_result=lambda code=0, message="", data=None: {"code": code, "message": message, "data": data},
    )

    _stub(
        monkeypatch,
        "api.db.services.document_service",
        DocumentService=SimpleNamespace(get_by_id=lambda _id: (True, SimpleNamespace(meta_fields={}))),
    )
    _stub(
        monkeypatch,
        "api.db.services.doc_metadata_service",
        DocMetadataService=SimpleNamespace(get_flatted_meta_by_kbs=lambda _ids: {}),
    )

    # Accept either a ready-made predicate or a constant verdict.
    acc_fn = accessible if callable(accessible) else (lambda *_a, **_k: accessible)
    _stub(
        monkeypatch,
        "api.db.services.knowledgebase_service",
        KnowledgebaseService=SimpleNamespace(get_by_id=lambda _id: kb, accessible=acc_fn),
    )

    _stub(monkeypatch, "api.db.services.llm_service", LLMBundle=lambda *_a, **_k: SimpleNamespace())

    _stub(
        monkeypatch,
        "api.db.joint_services.tenant_model_service",
        get_model_config_by_id=lambda *_a, **_k: {},
        get_model_config_by_type_and_name=lambda *_a, **_k: {},
        get_tenant_default_model_by_type=lambda *_a, **_k: {},
    )

    _stub(
        monkeypatch,
        "common.metadata_utils",
        meta_filter=lambda *_a, **_k: [],
        convert_conditions=lambda c: c,
    )

    _stub(monkeypatch, "rag.app.tag", label_question=lambda *_a, **_k: {})

    fake_retriever = _FakeRetriever(chunks=chunks)
    _stub(
        monkeypatch,
        "common.settings",
        retriever=fake_retriever,
        kg_retriever=_FakeKGRetriever(),
    )

    # "quart" has no parent package, so _stub only registers it in
    # sys.modules -- the same effect as building the module by hand.
    _stub(
        monkeypatch,
        "quart",
        request=SimpleNamespace(method="POST", args={}),
        jsonify=lambda payload: payload,
    )

    # parents[5] climbs from test/unit_test/api/apps/sdk/<this file> to the
    # repository root.
    repo_root = Path(__file__).resolve().parents[5]
    module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py"
    spec = importlib.util.spec_from_file_location("test_dify_retrieval_module", module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {module_path}")
    module = importlib.util.module_from_spec(spec)
    # `manager` must exist in the module namespace before its body runs.
    module.manager = _PassthroughManager()
    monkeypatch.setitem(sys.modules, "test_dify_retrieval_module", module)
    spec.loader.exec_module(module)
    module._fake_retriever = fake_retriever
    return module
|
||||
|
||||
|
||||
@pytest.mark.p1
class TestDifyRetrievalTenantCheck:
    """Regression for #15027: cross-tenant KB exposure via /dify/retrieval."""

    @pytest.mark.p1
    def test_cross_tenant_request_is_rejected(self, monkeypatch, caplog):
        """A caller whose tenant does NOT own the requested KB must be denied.

        Also verifies that the denial is recorded via the module logger so
        operators can audit cross-tenant access attempts after the fact.
        """
        import logging

        owner_kb = SimpleNamespace(id="kb-victim", tenant_id="tenant-owner", tenant_embd_id="", embd_id="bge")
        request_body = {
            "knowledge_id": "kb-victim",
            "query": "VICTIM_SECRET",
            "retrieval_setting": {"top_k": 10, "score_threshold": 0.0},
        }

        def _accessible_only_for_owner(kb_id, user_id):
            return user_id == "tenant-owner"

        module = _load_dify_retrieval(
            monkeypatch,
            kb=(True, owner_kb),
            accessible=_accessible_only_for_owner,
            request_body=request_body,
            chunks=[{"doc_id": "d1", "content_with_weight": "VICTIM_SECRET ...", "similarity": 0.9, "docnm_kwd": "doc.txt"}],
        )

        caplog.set_level(logging.WARNING, logger=module.__name__)
        result = asyncio.run(module.retrieval(tenant_id="tenant-attacker"))

        # The handler must answer with the auth error and nothing else.
        assert result["code"] == 109, f"expected AUTHENTICATION_ERROR (109), got {result}"
        msg = result["message"].lower()
        assert "authorization" in msg or "authentication" in msg
        assert "records" not in result, "cross-tenant request leaked records"
        assert module._fake_retriever.retrieval_calls == [], "retriever invoked despite denial"

        # The denial must be auditable without echoing the query text.
        denial_logs = [r for r in caplog.records if r.levelno == logging.WARNING and "cross-tenant" in r.getMessage()]
        assert denial_logs, "denial branch must emit a WARNING audit log"
        rendered = denial_logs[0].getMessage()
        assert "tenant-attacker" in rendered, "caller tenant must appear in the audit log"
        assert "kb-victim" in rendered, "denied knowledge_id must appear in the audit log"
        assert "VICTIM_SECRET" not in rendered, "audit log must not leak request payload contents"

    @pytest.mark.p1
    def test_same_tenant_request_succeeds(self, monkeypatch):
        """When the caller's tenant owns the KB, retrieval proceeds normally."""
        owner_kb = SimpleNamespace(id="kb-owner", tenant_id="tenant-owner", tenant_embd_id="", embd_id="bge")
        request_body = {
            "knowledge_id": "kb-owner",
            "query": "hello",
            "retrieval_setting": {"top_k": 5, "score_threshold": 0.0},
        }

        module = _load_dify_retrieval(
            monkeypatch,
            kb=(True, owner_kb),
            accessible=lambda _id, _u: True,
            request_body=request_body,
            chunks=[{"doc_id": "d1", "content_with_weight": "hello world", "similarity": 0.8, "docnm_kwd": "doc.txt"}],
        )

        result = asyncio.run(module.retrieval(tenant_id="tenant-owner"))

        assert "records" in result
        assert len(result["records"]) == 1
        assert result["records"][0]["content"] == "hello world"
        assert module._fake_retriever.retrieval_calls, "retriever was not called on legitimate request"

    @pytest.mark.p1
    def test_missing_knowledge_base_returns_not_found(self, monkeypatch):
        """KB id that does not exist returns 404 before the access check fires."""
        request_body = {"knowledge_id": "kb-does-not-exist", "query": "hello"}

        # Raising here turns any call to accessible() into a test failure.
        def _accessible_should_not_be_called(*_a, **_k):
            raise AssertionError("accessible() must not be called for a missing KB")

        module = _load_dify_retrieval(
            monkeypatch,
            kb=(False, None),
            accessible=_accessible_should_not_be_called,
            request_body=request_body,
        )

        result = asyncio.run(module.retrieval(tenant_id="tenant-attacker"))

        assert result["code"] == 404
        assert "not found" in result["message"].lower()
|
||||
Reference in New Issue
Block a user