Files
ragflow/test/unit_test/api/apps/sdk/test_dify_retrieval.py
Wang Qi 4cbe597d7e Refactor: consolidate to use @login_required (#15652)
Refactor: consolidate to use @login_required
2026-06-05 11:35:00 +08:00

266 lines
9.9 KiB
Python

#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Regression tests for retrieval in api/apps/restful_apis/dify_retrieval_api.py.
Issue #15027: cross-tenant knowledge-base access via POST /api/v1/dify/retrieval.
The handler authenticated the caller and resolved tenant_id, but then fetched
the requested knowledge_id with no tenant filter, allowing any valid caller to
retrieve chunks from any other tenant's KB by id. The fix adds a
KnowledgebaseService.accessible(...) check immediately after the lookup.
"""
import asyncio
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _PassthroughManager:
def route(self, *_args, **_kwargs):
return lambda func: func
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
def _stub(monkeypatch, name, **attrs):
mod = ModuleType(name)
for key, value in attrs.items():
setattr(mod, key, value)
monkeypatch.setitem(sys.modules, name, mod)
# If `name` is a submodule, also overwrite the attribute on the parent
# package. Otherwise `from <parent> import <child>` resolves to the
# already-cached real submodule via attribute lookup, bypassing our
# sys.modules entry and our stub.
if "." in name:
parent_name, _, child_name = name.rpartition(".")
parent_mod = sys.modules.get(parent_name)
if parent_mod is not None:
monkeypatch.setattr(parent_mod, child_name, mod, raising=False)
return mod
class _FakeRetriever:
def __init__(self, chunks=None):
self._chunks = chunks if chunks is not None else []
self.retrieval_calls = []
async def retrieval(self, question, embd_mdl, tenant_id, kb_ids, **kwargs):
self.retrieval_calls.append({"question": question, "tenant_id": tenant_id, "kb_ids": list(kb_ids)})
return {"chunks": list(self._chunks)}
def retrieval_by_children(self, chunks, _tenant_ids):
return chunks
class _FakeKGRetriever:
async def retrieval(self, *_a, **_k):
return {"content_with_weight": ""}
def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, tenant_id, chunks=None):
    """Load dify_retrieval_api.py with minimum stubs to exercise the retrieval handler.

    Every dependency module listed below is replaced in ``sys.modules`` through
    ``monkeypatch``, so the stubs are undone after the test. The handler module
    is then executed directly from its source file.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used for all patching.
        kb: Value returned by the stubbed ``KnowledgebaseService.get_by_id``.
            The tests pass a ``(found, kb)`` tuple.
        accessible: Either a callable used as ``KnowledgebaseService.accessible``,
            or a plain value the stub returns for every call.
        request_body: Value the awaited ``get_request_json()`` resolves to.
        tenant_id: Caller identity. It is exposed as ``current_user.id`` and
            injected by the stubbed ``add_tenant_id_to_kwargs`` decorator.
        chunks: Chunks the fake retriever returns (``None`` means no chunks).

    Returns:
        The loaded module. The fake retriever is attached as
        ``module._fake_retriever`` so tests can inspect its recorded calls.
    """
    from functools import wraps

    def _add_tenant_id_to_kwargs(func):
        # Inject the caller's tenant id as a keyword argument; ``wraps`` keeps
        # the handler's name and docstring like a well-behaved decorator.
        @wraps(func)
        async def wrapper(**kwargs):
            kwargs["tenant_id"] = tenant_id
            return await func(**kwargs)

        return wrapper

    _stub(
        monkeypatch,
        "api.apps",
        current_user=SimpleNamespace(id=tenant_id),
        login_required=lambda func: func,
    )
    _stub(
        monkeypatch,
        "api.utils.api_utils",
        add_tenant_id_to_kwargs=_add_tenant_id_to_kwargs,
        build_error_result=lambda message="", code=0, data=False: {"code": code, "message": message, "data": data},
        get_request_json=lambda: _AwaitableValue(request_body),
        get_json_result=lambda code=0, message="", data=None: {"code": code, "message": message, "data": data},
    )
    _stub(
        monkeypatch,
        "api.db.services.document_service",
        DocumentService=SimpleNamespace(
            get_by_id=lambda _id: (True, SimpleNamespace(id=_id, meta_fields={})),
            get_by_ids=lambda ids, cols=None: [SimpleNamespace(id=doc_id, meta_fields={}) for doc_id in ids],
        ),
    )
    _stub(
        monkeypatch,
        "api.db.services.doc_metadata_service",
        DocMetadataService=SimpleNamespace(get_flatted_meta_by_kbs=lambda _ids: {}),
    )

    # Accept either a ready-made predicate or a constant answer.
    acc_fn = accessible if callable(accessible) else (lambda *_a, **_k: accessible)
    _stub(
        monkeypatch,
        "api.db.services.knowledgebase_service",
        KnowledgebaseService=SimpleNamespace(get_by_id=lambda _id: kb, accessible=acc_fn),
    )
    _stub(monkeypatch, "api.db.services.llm_service", LLMBundle=lambda *_a, **_k: SimpleNamespace())
    _stub(
        monkeypatch,
        "api.db.joint_services.tenant_model_service",
        get_tenant_default_model_by_type=lambda *_a, **_k: {},
        get_model_config_from_provider_instance=lambda *_a, **_k: {},
    )
    _stub(
        monkeypatch,
        "common.metadata_utils",
        meta_filter=lambda *_a, **_k: [],
        convert_conditions=lambda c: c,
    )
    _stub(monkeypatch, "rag.app.tag", label_question=lambda *_a, **_k: {})

    fake_retriever = _FakeRetriever(chunks=chunks)
    _stub(
        monkeypatch,
        "common.settings",
        retriever=fake_retriever,
        kg_retriever=_FakeKGRetriever(),
    )
    # Route quart through the same helper as every other dependency so the
    # sys.modules registration lives in one place.
    _stub(
        monkeypatch,
        "quart",
        request=SimpleNamespace(method="POST", args={}),
        jsonify=lambda payload: payload,
    )

    # This file lives at <repo>/test/unit_test/api/apps/sdk/, so parents[5]
    # is the repository root.
    repo_root = Path(__file__).resolve().parents[5]
    module_path = repo_root / "api" / "apps" / "restful_apis" / "dify_retrieval_api.py"
    spec = importlib.util.spec_from_file_location("test_dify_retrieval_module", module_path)
    module = importlib.util.module_from_spec(spec)
    # Provide ``manager`` before executing the source.
    # NOTE(review): presumably the module's route decorators reference a
    # ``manager`` global; the passthrough keeps the handlers callable directly.
    module.manager = _PassthroughManager()
    # Register before executing (importlib "import a source file directly" recipe).
    monkeypatch.setitem(sys.modules, "test_dify_retrieval_module", module)
    spec.loader.exec_module(module)
    module._fake_retriever = fake_retriever
    return module
@pytest.mark.p1
class TestDifyRetrievalTenantCheck:
    """Regression for #15027: cross-tenant KB exposure via /dify/retrieval."""

    @pytest.mark.p1
    def test_cross_tenant_request_is_rejected(self, monkeypatch, caplog):
        """A caller whose tenant does NOT own the requested KB must be denied.

        Also verifies that the denial is recorded via the module logger so
        operators can audit cross-tenant access attempts after the fact.
        """
        import logging

        owner_kb = SimpleNamespace(id="kb-victim", tenant_id="tenant-owner", tenant_embd_id="", embd_id="bge")
        request_body = {
            "knowledge_id": "kb-victim",
            "query": "VICTIM_SECRET",
            "retrieval_setting": {"top_k": 10, "score_threshold": 0.0},
        }

        def _accessible_only_for_owner(kb_id, user_id):
            # Only the owning tenant passes the access check.
            return user_id == "tenant-owner"

        module = _load_dify_retrieval(
            monkeypatch,
            kb=(True, owner_kb),
            accessible=_accessible_only_for_owner,
            request_body=request_body,
            tenant_id="tenant-attacker",
            chunks=[{"doc_id": "d1", "content_with_weight": "VICTIM_SECRET ...", "similarity": 0.9, "docnm_kwd": "doc.txt"}],
        )
        caplog.set_level(logging.WARNING, logger=module.__name__)

        result = asyncio.run(module.retrieval())

        assert result["code"] == 109, f"expected AUTHENTICATION_ERROR (109), got {result}"
        msg = result["message"].lower()
        assert "authorization" in msg or "authentication" in msg
        assert "records" not in result, "cross-tenant request leaked records"
        assert module._fake_retriever.retrieval_calls == [], "retriever invoked despite denial"

        # The denial must leave an audit trail naming the caller and the KB,
        # without echoing the request's query text.
        denial_logs = [r for r in caplog.records if r.levelno == logging.WARNING and "cross-tenant" in r.getMessage()]
        assert denial_logs, "denial branch must emit a WARNING audit log"
        rendered = denial_logs[0].getMessage()
        assert "tenant-attacker" in rendered, "caller tenant must appear in the audit log"
        assert "kb-victim" in rendered, "denied knowledge_id must appear in the audit log"
        assert "VICTIM_SECRET" not in rendered, "audit log must not leak request payload contents"

    @pytest.mark.p1
    def test_same_tenant_request_succeeds(self, monkeypatch):
        """When the caller's tenant owns the KB, retrieval proceeds normally."""
        owner_kb = SimpleNamespace(id="kb-owner", tenant_id="tenant-owner", tenant_embd_id="", embd_id="bge")
        request_body = {
            "knowledge_id": "kb-owner",
            "query": "hello",
            "retrieval_setting": {"top_k": 5, "score_threshold": 0.0},
        }
        module = _load_dify_retrieval(
            monkeypatch,
            kb=(True, owner_kb),
            accessible=lambda _id, _u: True,
            request_body=request_body,
            tenant_id="tenant-owner",
            chunks=[{"doc_id": "d1", "content_with_weight": "hello world", "similarity": 0.8, "docnm_kwd": "doc.txt"}],
        )

        result = asyncio.run(module.retrieval())

        assert "records" in result
        assert len(result["records"]) == 1
        assert result["records"][0]["content"] == "hello world"
        calls = module._fake_retriever.retrieval_calls
        assert calls, "retriever was not called on legitimate request"
        # Even on the happy path, the search must stay scoped to the requested
        # KB and to its owning tenant (the caller here is that same tenant).
        for call in calls:
            assert set(call["kb_ids"]) == {"kb-owner"}, f"retrieval escaped the requested KB: {call}"
            assert call["tenant_id"] == "tenant-owner", f"retrieval ran under the wrong tenant: {call}"

    @pytest.mark.p1
    def test_missing_knowledge_base_returns_not_found(self, monkeypatch):
        """KB id that does not exist returns 404 before the access check fires."""
        request_body = {"knowledge_id": "kb-does-not-exist", "query": "hello"}

        def _accessible_should_not_be_called(*_a, **_k):
            raise AssertionError("accessible() must not be called for a missing KB")

        module = _load_dify_retrieval(
            monkeypatch,
            kb=(False, None),
            accessible=_accessible_should_not_be_called,
            request_body=request_body,
            tenant_id="tenant-attacker",
        )

        result = asyncio.run(module.retrieval())

        assert result["code"] == 404
        assert "not found" in result["message"].lower()
        assert module._fake_retriever.retrieval_calls == [], "retriever invoked for a missing KB"