diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index 05885c380b..b4520685ca 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -253,6 +253,13 @@ async def retrieval(tenant_id): e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND) + if not KnowledgebaseService.accessible(kb_id, tenant_id): + logger.warning( + "Rejected /dify/retrieval cross-tenant access: caller_tenant=%s knowledge_id=%s", + tenant_id, + kb_id, + ) + return build_error_result(message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) if kb.tenant_embd_id: model_config = get_model_config_by_id(kb.tenant_embd_id) else: diff --git a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py index 3187846a7e..01b23e6107 100644 --- a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py +++ b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py @@ -273,6 +273,7 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch): monkeypatch.setattr(module, "jsonify", lambda payload: payload) monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [{"doc_id": "doc-1"}]) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True) monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", [])) monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: []) @@ -319,6 +320,7 @@ def test_retrieval_not_found_exception_mapping(monkeypatch): _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"}) monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: []) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True) monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) class _BrokenRetriever: @@ -338,6 +340,7 @@ def test_retrieval_generic_exception_mapping(monkeypatch): _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"}) monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: []) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True) monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) class _BrokenRetriever: diff --git a/test/unit_test/api/apps/sdk/test_dify_retrieval.py b/test/unit_test/api/apps/sdk/test_dify_retrieval.py new file mode 100644 index 0000000000..113ff139f0 --- /dev/null +++ b/test/unit_test/api/apps/sdk/test_dify_retrieval.py @@ -0,0 +1,248 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Regression tests for retrieval in api/apps/sdk/dify_retrieval.py. + +Issue #15027: cross-tenant knowledge-base access via POST /api/v1/dify/retrieval. +The handler authenticated the caller via @apikey_required (resolving +tenant_id) but then fetched the requested knowledge_id with no tenant +filter, allowing any valid API key to retrieve chunks from any other +tenant's KB by id. The fix adds a KnowledgebaseService.accessible(...) +check immediately after the lookup. +""" + +import asyncio +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _PassthroughManager: + def route(self, *_args, **_kwargs): + return lambda func: func + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +def _stub(monkeypatch, name, **attrs): + mod = ModuleType(name) + for key, value in attrs.items(): + setattr(mod, key, value) + monkeypatch.setitem(sys.modules, name, mod) + # If `name` is a submodule, also overwrite the attribute on the parent + # package. Otherwise `from import ` resolves to the + # already-cached real submodule via attribute lookup, bypassing our + # sys.modules entry and our stub. + if "." in name: + parent_name, _, child_name = name.rpartition(".") + parent_mod = sys.modules.get(parent_name) + if parent_mod is not None: + monkeypatch.setattr(parent_mod, child_name, mod, raising=False) + return mod + + +class _FakeRetriever: + def __init__(self, chunks=None): + self._chunks = chunks if chunks is not None else [] + self.retrieval_calls = [] + + async def retrieval(self, question, embd_mdl, tenant_id, kb_ids, **kwargs): + self.retrieval_calls.append({"question": question, "tenant_id": tenant_id, "kb_ids": list(kb_ids)}) + return {"chunks": list(self._chunks)} + + def retrieval_by_children(self, chunks, _tenant_ids): + return chunks + + +class _FakeKGRetriever: + async def retrieval(self, *_a, **_k): + return {"content_with_weight": ""} + + +def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=None): + """Load dify_retrieval.py with minimum stubs to exercise the retrieval handler.""" + _stub( + monkeypatch, + "api.utils.api_utils", + apikey_required=lambda func: func, + build_error_result=lambda message="", code=0, data=False: {"code": code, "message": message, "data": data}, + get_request_json=lambda: _AwaitableValue(request_body), + get_json_result=lambda code=0, message="", data=None: {"code": code, "message": message, "data": data}, + ) + + _stub( + monkeypatch, + "api.db.services.document_service", + DocumentService=SimpleNamespace(get_by_id=lambda _id: (True, SimpleNamespace(meta_fields={}))), + ) + _stub( + monkeypatch, + "api.db.services.doc_metadata_service", + DocMetadataService=SimpleNamespace(get_flatted_meta_by_kbs=lambda _ids: {}), + ) + + acc_fn = accessible if callable(accessible) else (lambda *_a, **_k: accessible) + _stub( + monkeypatch, + "api.db.services.knowledgebase_service", + KnowledgebaseService=SimpleNamespace(get_by_id=lambda _id: kb, accessible=acc_fn), + ) + + _stub(monkeypatch, "api.db.services.llm_service", LLMBundle=lambda *_a, **_k: SimpleNamespace()) + + _stub( + monkeypatch, + "api.db.joint_services.tenant_model_service", + get_model_config_by_id=lambda *_a, **_k: {}, + get_model_config_by_type_and_name=lambda *_a, **_k: {}, + get_tenant_default_model_by_type=lambda *_a, **_k: {}, + ) + + _stub( + monkeypatch, + "common.metadata_utils", + meta_filter=lambda *_a, **_k: [], + convert_conditions=lambda c: c, + ) + + _stub(monkeypatch, "rag.app.tag", label_question=lambda *_a, **_k: {}) + + fake_retriever = _FakeRetriever(chunks=chunks) + _stub( + monkeypatch, + "common.settings", + retriever=fake_retriever, + kg_retriever=_FakeKGRetriever(), + ) + + quart_stub = ModuleType("quart") + quart_stub.request = SimpleNamespace(method="POST", args={}) + quart_stub.jsonify = lambda payload: payload + monkeypatch.setitem(sys.modules, "quart", quart_stub) + + repo_root = Path(__file__).resolve().parents[5] + module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py" + spec = importlib.util.spec_from_file_location("test_dify_retrieval_module", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _PassthroughManager() + monkeypatch.setitem(sys.modules, "test_dify_retrieval_module", module) + spec.loader.exec_module(module) + module._fake_retriever = fake_retriever + return module + + +@pytest.mark.p1 +class TestDifyRetrievalTenantCheck: + """Regression for #15027: cross-tenant KB exposure via /dify/retrieval.""" + + @pytest.mark.p1 + def test_cross_tenant_request_is_rejected(self, monkeypatch, caplog): + """A caller whose tenant does NOT own the requested KB must be denied. + + Also verifies that the denial is recorded via the module logger so + operators can audit cross-tenant access attempts after the fact. + """ + import logging + + owner_kb = SimpleNamespace(id="kb-victim", tenant_id="tenant-owner", tenant_embd_id="", embd_id="bge") + request_body = { + "knowledge_id": "kb-victim", + "query": "VICTIM_SECRET", + "retrieval_setting": {"top_k": 10, "score_threshold": 0.0}, + } + + def _accessible_only_for_owner(kb_id, user_id): + return user_id == "tenant-owner" + + module = _load_dify_retrieval( + monkeypatch, + kb=(True, owner_kb), + accessible=_accessible_only_for_owner, + request_body=request_body, + chunks=[{"doc_id": "d1", "content_with_weight": "VICTIM_SECRET ...", "similarity": 0.9, "docnm_kwd": "doc.txt"}], + ) + + caplog.set_level(logging.WARNING, logger=module.__name__) + result = asyncio.run(module.retrieval(tenant_id="tenant-attacker")) + + assert result["code"] == 109, f"expected AUTHENTICATION_ERROR (109), got {result}" + msg = result["message"].lower() + assert "authorization" in msg or "authentication" in msg + assert "records" not in result, "cross-tenant request leaked records" + assert module._fake_retriever.retrieval_calls == [], "retriever invoked despite denial" + + denial_logs = [r for r in caplog.records if r.levelno == logging.WARNING and "cross-tenant" in r.getMessage()] + assert denial_logs, "denial branch must emit a WARNING audit log" + rendered = denial_logs[0].getMessage() + assert "tenant-attacker" in rendered, "caller tenant must appear in the audit log" + assert "kb-victim" in rendered, "denied knowledge_id must appear in the audit log" + assert "VICTIM_SECRET" not in rendered, "audit log must not leak request payload contents" + + @pytest.mark.p1 + def test_same_tenant_request_succeeds(self, monkeypatch): + """When the caller's tenant owns the KB, retrieval proceeds normally.""" + owner_kb = SimpleNamespace(id="kb-owner", tenant_id="tenant-owner", tenant_embd_id="", embd_id="bge") + request_body = { + "knowledge_id": "kb-owner", + "query": "hello", + "retrieval_setting": {"top_k": 5, "score_threshold": 0.0}, + } + + module = _load_dify_retrieval( + monkeypatch, + kb=(True, owner_kb), + accessible=lambda _id, _u: True, + request_body=request_body, + chunks=[{"doc_id": "d1", "content_with_weight": "hello world", "similarity": 0.8, "docnm_kwd": "doc.txt"}], + ) + + result = asyncio.run(module.retrieval(tenant_id="tenant-owner")) + + assert "records" in result + assert len(result["records"]) == 1 + assert result["records"][0]["content"] == "hello world" + assert module._fake_retriever.retrieval_calls, "retriever was not called on legitimate request" + + @pytest.mark.p1 + def test_missing_knowledge_base_returns_not_found(self, monkeypatch): + """KB id that does not exist returns 404 before the access check fires.""" + request_body = {"knowledge_id": "kb-does-not-exist", "query": "hello"} + + def _accessible_should_not_be_called(*_a, **_k): + raise AssertionError("accessible() must not be called for a missing KB") + + module = _load_dify_retrieval( + monkeypatch, + kb=(False, None), + accessible=_accessible_should_not_be_called, + request_body=request_body, + ) + + result = asyncio.run(module.retrieval(tenant_id="tenant-attacker")) + + assert result["code"] == 404 + assert "not found" in result["message"].lower()