fix(api): check kb ownership in /dify/retrieval (#15028)

POST /api/v1/dify/retrieval resolved the caller via @apikey_required
(injecting tenant_id) but then fetched the requested knowledge_id with
no tenant filter and ran the full retrieval pipeline against
kb.tenant_id (the owner). Any valid Dify-compatible API key could
retrieve chunks from any tenant whose KB UUID was known. Adds the
missing ownership check.

## Root Cause
api/apps/sdk/dify_retrieval.py line 253:
KnowledgebaseService.get_by_id(kb_id) fetched the KB by id alone, then
the handler used kb.tenant_id (the OWNER) to build the embedding model
and call the retriever. The caller tenant_id was only used downstream at
line 278 for retrieval_by_children, well after cross-tenant data was
already retrieved.

grep confirmed there was no KnowledgebaseService.accessible call
anywhere in the handler.

## Fix
Two-line guard immediately after the existing get_by_id lookup (the
committed version additionally emits a WARNING audit log before
returning), mirroring the pattern PR #14749 lands for the sibling
sdk/doc.py routes (download, parse, stop_parsing, retrieval_test):

    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
+   if not KnowledgebaseService.accessible(kb_id, tenant_id):
+       return build_error_result(message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
    if kb.tenant_embd_id:
        ...

KnowledgebaseService.accessible already handles solo-tenant ownership,
team membership via TenantService.get_joined_tenants_by_user_id, and the
permission=ME distinction. No behavior change for legitimate callers;
cross-tenant callers now receive RetCode.AUTHENTICATION_ERROR (109).

## Test Plan
- [x] Regression test added:
test/unit_test/api/apps/sdk/test_dify_retrieval.py
- test_cross_tenant_request_is_rejected -- attacker tenant calling owner
tenant KB gets 109; retriever is not invoked
- test_same_tenant_request_succeeds -- owner tenant gets the records
back
- test_missing_knowledge_base_returns_not_found -- missing KB returns
404 BEFORE the access check fires (legit callers see the clearer
message)
- [x] All 3 tests pass after the fix
- [x] Cross-tenant test FAILS on pre-fix main (KeyError on result["code"]
because the handler leaks the records dict instead of returning an auth error)
- [x] ruff check clean on both changed files
- [x] No drive-by reformatting in dify_retrieval.py -- only the added
guard lines (the access check and its audit-log call)

### Post-fix output

    test_cross_tenant_request_is_rejected           PASSED [ 33%]
    test_same_tenant_request_succeeds               PASSED [ 66%]
    test_missing_knowledge_base_returns_not_found   PASSED [100%]

    ============================== 3 passed in 0.04s ===============================

Closes #15027
This commit is contained in:
dripsmvcp
2026-05-21 14:29:00 +09:00
committed by GitHub
parent 0c93161a14
commit 440153c378
3 changed files with 258 additions and 0 deletions

View File

@@ -253,6 +253,13 @@ async def retrieval(tenant_id):
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
if not KnowledgebaseService.accessible(kb_id, tenant_id):
logger.warning(
"Rejected /dify/retrieval cross-tenant access: caller_tenant=%s knowledge_id=%s",
tenant_id,
kb_id,
)
return build_error_result(message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
if kb.tenant_embd_id:
model_config = get_model_config_by_id(kb.tenant_embd_id)
else:

View File

@@ -273,6 +273,7 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch):
monkeypatch.setattr(module, "jsonify", lambda payload: payload)
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [{"doc_id": "doc-1"}])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True)
monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", []))
monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: [])
@@ -319,6 +320,7 @@ def test_retrieval_not_found_exception_mapping(monkeypatch):
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True)
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
class _BrokenRetriever:
@@ -338,6 +340,7 @@ def test_retrieval_generic_exception_mapping(monkeypatch):
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True)
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
class _BrokenRetriever:

View File

@@ -0,0 +1,248 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Regression tests for retrieval in api/apps/sdk/dify_retrieval.py.
Issue #15027: cross-tenant knowledge-base access via POST /api/v1/dify/retrieval.
The handler authenticated the caller via @apikey_required (resolving
tenant_id) but then fetched the requested knowledge_id with no tenant
filter, allowing any valid API key to retrieve chunks from any other
tenant's KB by id. The fix adds a KnowledgebaseService.accessible(...)
check immediately after the lookup.
"""
import asyncio
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _PassthroughManager:
    """Stand-in for the ``manager`` object the handler module decorates routes with.

    ``route(...)`` accepts and ignores any arguments and produces a decorator
    that returns the decorated function unchanged, so handlers remain plain
    callables in tests.
    """

    def route(self, *_args, **_kwargs):
        """Return an identity decorator; all routing arguments are ignored."""

        def _identity(func):
            return func

        return _identity
class _AwaitableValue:
    """Awaitable wrapper: ``await _AwaitableValue(x)`` evaluates to ``x``."""

    def __init__(self, value):
        self._value = value

    async def _resolve(self):
        """Coroutine that simply hands back the wrapped value."""
        return self._value

    def __await__(self):
        # Delegate to a fresh coroutine so the instance can be awaited repeatedly.
        return self._resolve().__await__()
def _stub(monkeypatch, name, **attrs):
    """Register a synthetic module named ``name`` that exposes ``attrs``.

    Args:
        monkeypatch: object providing ``setitem`` and ``setattr`` (the pytest
            ``monkeypatch`` fixture in this file's tests).
        name: dotted module name to register in ``sys.modules``.
        **attrs: attributes to set on the synthetic module.

    Returns:
        The newly created module object.
    """
    mod = ModuleType(name)
    mod.__dict__.update(attrs)
    monkeypatch.setitem(sys.modules, name, mod)

    parent_name, dot, child_name = name.rpartition(".")
    if not dot:
        # Top-level module: there is no parent package attribute to replace.
        return mod

    parent_mod = sys.modules.get(parent_name)
    if parent_mod is None:
        return mod

    # If `name` is a submodule, also overwrite the attribute on the parent
    # package. Otherwise `from <parent> import <child>` resolves to the
    # already-cached real submodule via attribute lookup, bypassing our
    # sys.modules entry and our stub.
    monkeypatch.setattr(parent_mod, child_name, mod, raising=False)
    return mod
class _FakeRetriever:
    """Retriever double that records every ``retrieval`` call and returns canned chunks.

    Attributes are plain instance state:
      * ``_chunks`` - the chunks returned (as a fresh list) by each ``retrieval`` call.
      * ``retrieval_calls`` - one dict per call with ``question``, ``tenant_id`` and ``kb_ids``.
    """

    def __init__(self, chunks=None):
        self._chunks = [] if chunks is None else chunks
        self.retrieval_calls = []

    async def retrieval(self, question, embd_mdl, tenant_id, kb_ids, **kwargs):
        """Log the call and return a ``{"chunks": [...]}`` payload; extra kwargs are ignored."""
        call_record = dict(question=question, tenant_id=tenant_id, kb_ids=[*kb_ids])
        self.retrieval_calls.append(call_record)
        return dict(chunks=[*self._chunks])

    def retrieval_by_children(self, chunks, _tenant_ids):
        """Return ``chunks`` untouched."""
        return chunks
class _FakeKGRetriever:
    """Knowledge-graph retriever double whose result always has empty content."""

    async def retrieval(self, *_a, **_k):
        """Ignore all arguments and return a result with empty ``content_with_weight``."""
        empty_result = dict(content_with_weight="")
        return empty_result
def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=None):
    """Load dify_retrieval.py with minimum stubs to exercise the retrieval handler.

    All replacement modules are registered in ``sys.modules`` before the handler
    source is executed, so imports of those names performed while it executes
    find the stubs.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used for every registration.
        kb: value returned verbatim by the stubbed ``KnowledgebaseService.get_by_id``
            (the tests in this file pass a ``(found, kb_object)`` tuple).
        accessible: either a callable installed as ``KnowledgebaseService.accessible``,
            or a plain value that the stubbed ``accessible`` returns for any arguments.
        request_body: value produced by awaiting the stubbed ``get_request_json()``.
        chunks: optional canned chunks handed to the fake retriever.

    Returns:
        The executed handler module, with the fake retriever attached as
        ``module._fake_retriever`` so tests can inspect recorded calls.
    """

    def _error_result(message="", code=0, data=False):
        return {"code": code, "message": message, "data": data}

    def _json_result(code=0, message="", data=None):
        return {"code": code, "message": message, "data": data}

    def _request_json():
        return _AwaitableValue(request_body)

    if callable(accessible):
        access_check = accessible
    else:

        def access_check(*_a, **_k):
            return accessible

    document_service = SimpleNamespace(get_by_id=lambda _id: (True, SimpleNamespace(meta_fields={})))
    metadata_service = SimpleNamespace(get_flatted_meta_by_kbs=lambda _ids: {})
    kb_service = SimpleNamespace(get_by_id=lambda _id: kb, accessible=access_check)
    fake_retriever = _FakeRetriever(chunks=chunks)

    _stub(
        monkeypatch,
        "api.utils.api_utils",
        apikey_required=lambda func: func,
        build_error_result=_error_result,
        get_request_json=_request_json,
        get_json_result=_json_result,
    )
    _stub(monkeypatch, "api.db.services.document_service", DocumentService=document_service)
    _stub(monkeypatch, "api.db.services.doc_metadata_service", DocMetadataService=metadata_service)
    _stub(monkeypatch, "api.db.services.knowledgebase_service", KnowledgebaseService=kb_service)
    _stub(monkeypatch, "api.db.services.llm_service", LLMBundle=lambda *_a, **_k: SimpleNamespace())
    _stub(
        monkeypatch,
        "api.db.joint_services.tenant_model_service",
        get_model_config_by_id=lambda *_a, **_k: {},
        get_model_config_by_type_and_name=lambda *_a, **_k: {},
        get_tenant_default_model_by_type=lambda *_a, **_k: {},
    )
    _stub(
        monkeypatch,
        "common.metadata_utils",
        meta_filter=lambda *_a, **_k: [],
        convert_conditions=lambda c: c,
    )
    _stub(monkeypatch, "rag.app.tag", label_question=lambda *_a, **_k: {})
    _stub(
        monkeypatch,
        "common.settings",
        retriever=fake_retriever,
        kg_retriever=_FakeKGRetriever(),
    )
    # "quart" has no dot in its name, so _stub only touches sys.modules here.
    _stub(
        monkeypatch,
        "quart",
        request=SimpleNamespace(method="POST", args={}),
        jsonify=lambda payload: payload,
    )

    # This file sits five directories below the repository root
    # (test/unit_test/api/apps/sdk/), hence parents[5].
    handler_path = Path(__file__).resolve().parents[5] / "api" / "apps" / "sdk" / "dify_retrieval.py"
    module_name = "test_dify_retrieval_module"
    spec = importlib.util.spec_from_file_location(module_name, handler_path)
    module = importlib.util.module_from_spec(spec)
    # `manager` is placed in the module namespace before the source is executed.
    module.manager = _PassthroughManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    module._fake_retriever = fake_retriever
    return module
@pytest.mark.p1
class TestDifyRetrievalTenantCheck:
    """Regression for #15027: cross-tenant KB exposure via /dify/retrieval."""

    @pytest.mark.p1
    def test_cross_tenant_request_is_rejected(self, monkeypatch, caplog):
        """A caller whose tenant does NOT own the requested KB must be denied.

        Also verifies that the denial is recorded via the module logger so
        operators can audit cross-tenant access attempts after the fact.
        """
        import logging

        victim_kb = SimpleNamespace(id="kb-victim", tenant_id="tenant-owner", tenant_embd_id="", embd_id="bge")
        payload = {
            "knowledge_id": "kb-victim",
            "query": "VICTIM_SECRET",
            "retrieval_setting": {"top_k": 10, "score_threshold": 0.0},
        }
        secret_chunk = {
            "doc_id": "d1",
            "content_with_weight": "VICTIM_SECRET ...",
            "similarity": 0.9,
            "docnm_kwd": "doc.txt",
        }

        def _accessible_only_for_owner(kb_id, user_id):
            return user_id == "tenant-owner"

        module = _load_dify_retrieval(
            monkeypatch,
            kb=(True, victim_kb),
            accessible=_accessible_only_for_owner,
            request_body=payload,
            chunks=[secret_chunk],
        )
        caplog.set_level(logging.WARNING, logger=module.__name__)

        result = asyncio.run(module.retrieval(tenant_id="tenant-attacker"))

        # The response must be an auth error that carries no retrieved data.
        assert result["code"] == 109, f"expected AUTHENTICATION_ERROR (109), got {result}"
        msg = result["message"].lower()
        assert any(word in msg for word in ("authorization", "authentication"))
        assert "records" not in result, "cross-tenant request leaked records"
        assert module._fake_retriever.retrieval_calls == [], "retriever invoked despite denial"

        # The denial must leave a WARNING audit trail naming caller and KB only.
        denial_logs = []
        for record in caplog.records:
            if record.levelno == logging.WARNING and "cross-tenant" in record.getMessage():
                denial_logs.append(record)
        assert denial_logs, "denial branch must emit a WARNING audit log"
        rendered = denial_logs[0].getMessage()
        assert "tenant-attacker" in rendered, "caller tenant must appear in the audit log"
        assert "kb-victim" in rendered, "denied knowledge_id must appear in the audit log"
        assert "VICTIM_SECRET" not in rendered, "audit log must not leak request payload contents"

    @pytest.mark.p1
    def test_same_tenant_request_succeeds(self, monkeypatch):
        """When the caller's tenant owns the KB, retrieval proceeds normally."""
        own_kb = SimpleNamespace(id="kb-owner", tenant_id="tenant-owner", tenant_embd_id="", embd_id="bge")
        payload = {
            "knowledge_id": "kb-owner",
            "query": "hello",
            "retrieval_setting": {"top_k": 5, "score_threshold": 0.0},
        }
        canned_chunk = {
            "doc_id": "d1",
            "content_with_weight": "hello world",
            "similarity": 0.8,
            "docnm_kwd": "doc.txt",
        }
        module = _load_dify_retrieval(
            monkeypatch,
            kb=(True, own_kb),
            accessible=lambda _id, _u: True,
            request_body=payload,
            chunks=[canned_chunk],
        )

        result = asyncio.run(module.retrieval(tenant_id="tenant-owner"))

        assert "records" in result
        records = result["records"]
        assert len(records) == 1
        assert records[0]["content"] == "hello world"
        assert module._fake_retriever.retrieval_calls, "retriever was not called on legitimate request"

    @pytest.mark.p1
    def test_missing_knowledge_base_returns_not_found(self, monkeypatch):
        """KB id that does not exist returns 404 before the access check fires."""

        def _accessible_should_not_be_called(*_a, **_k):
            raise AssertionError("accessible() must not be called for a missing KB")

        payload = {"knowledge_id": "kb-does-not-exist", "query": "hello"}
        module = _load_dify_retrieval(
            monkeypatch,
            kb=(False, None),
            accessible=_accessible_should_not_be_called,
            request_body=payload,
        )

        result = asyncio.run(module.retrieval(tenant_id="tenant-attacker"))

        assert result["code"] == 404
        assert "not found" in result["message"].lower()