Refactor: consolidate to use @login_required (#15652)

Refactor: consolidate to use @login_required
This commit is contained in:
Wang Qi
2026-06-05 11:35:00 +08:00
committed by GitHub
parent 9f3e289b78
commit 4cbe597d7e
8 changed files with 55 additions and 126 deletions

View File

@@ -67,6 +67,11 @@ def _run(coro):
return asyncio.run(coro)
# Session-scoped, autouse fixture that performs no setup and returns None.
# NOTE(review): presumably shadows a shared `set_tenant_info` fixture (e.g. one
# defined in a conftest) so these unit tests skip real tenant setup — confirm
# against the rest of the test suite.
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
return None
def _load_dify_retrieval_module(monkeypatch):
repo_root = Path(__file__).resolve().parents[3]
@@ -74,6 +79,11 @@ def _load_dify_retrieval_module(monkeypatch):
common_pkg.__path__ = [str(repo_root / "common")]
monkeypatch.setitem(sys.modules, "common", common_pkg)
api_apps_mod = ModuleType("api.apps")
api_apps_mod.current_user = SimpleNamespace(id="tenant-1")
api_apps_mod.login_required = lambda func: func
monkeypatch.setitem(sys.modules, "api.apps", api_apps_mod)
deepdoc_pkg = ModuleType("deepdoc")
deepdoc_parser_pkg = ModuleType("deepdoc.parser")
deepdoc_parser_pkg.__path__ = []

View File

@@ -67,6 +67,11 @@ def _run(coro):
return asyncio.run(coro)
# Session-scoped, autouse fixture that performs no setup and returns None.
# NOTE(review): presumably shadows a shared `set_tenant_info` fixture (e.g. one
# defined in a conftest) so these unit tests skip real tenant setup — confirm
# against the rest of the test suite.
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
return None
def _load_dify_retrieval_module(monkeypatch):
repo_root = Path(__file__).resolve().parents[4]
@@ -74,6 +79,11 @@ def _load_dify_retrieval_module(monkeypatch):
common_pkg.__path__ = [str(repo_root / "common")]
monkeypatch.setitem(sys.modules, "common", common_pkg)
api_apps_mod = ModuleType("api.apps")
api_apps_mod.current_user = SimpleNamespace(id="tenant-1")
api_apps_mod.login_required = lambda func: func
monkeypatch.setitem(sys.modules, "api.apps", api_apps_mod)
deepdoc_pkg = ModuleType("deepdoc")
deepdoc_parser_pkg = ModuleType("deepdoc.parser")
deepdoc_parser_pkg.__path__ = []
@@ -284,6 +294,7 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch):
monkeypatch.setattr(module, "jsonify", lambda payload: payload)
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [{"doc_id": "doc-1"}])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True)
monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", []))
monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: [])
@@ -302,8 +313,8 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch):
monkeypatch.setattr(module.settings, "kg_retriever", _DummyKgRetriever())
monkeypatch.setattr(
module.DocumentService,
"get_by_id",
lambda doc_id: (True, SimpleNamespace(meta_fields={"origin": f"meta-{doc_id}"})),
"get_by_ids",
lambda doc_ids, cols=None: [SimpleNamespace(id=doc_id, meta_fields={"origin": f"meta-{doc_id}"}) for doc_id in doc_ids],
)
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
@@ -334,6 +345,7 @@ def test_retrieval_not_found_exception_mapping(monkeypatch):
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True)
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
class _BrokenRetriever:
@@ -353,6 +365,7 @@ def test_retrieval_generic_exception_mapping(monkeypatch):
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True)
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
class _BrokenRetriever:

View File

@@ -17,7 +17,6 @@ import asyncio
import inspect
import importlib.util
import sys
from functools import wraps
from pathlib import Path
from types import ModuleType, SimpleNamespace
@@ -249,14 +248,6 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"):
if value is not None
}
api_utils_mod.server_error_response = lambda e: {"code": 500, "message": str(e)}
# Stub installed as `api_utils_mod.token_required`: performs no token check and
# simply awaits the wrapped callable with the arguments it was given.
# The wrapper is always async, so `func` must return an awaitable.
def _token_required(func):
# functools.wraps copies the wrapped handler's name/docstring onto the wrapper.
@wraps(func)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)
return wrapper
api_utils_mod.token_required = _token_required
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
image_utils_mod = ModuleType("api.utils.image_utils")

View File

@@ -275,7 +275,6 @@ def _load_session_module(monkeypatch):
}
api_utils_mod.get_request_json = lambda: _AwaitableValue({})
api_utils_mod.server_error_response = lambda e: {"code": _StubRetCode.SERVER_ERROR, "message": str(e)}
api_utils_mod.token_required = lambda func: func
api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func)
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

View File

@@ -15,7 +15,6 @@
#
import asyncio
import functools
import importlib.util
import inspect
import json
@@ -390,17 +389,6 @@ def _load_dataset_module(monkeypatch):
def _get_error_permission_result(message=""):
return _get_result(code=_RetCode.AUTHENTICATION_ERROR, message=message)
# Pass-through stand-in installed as `api_utils_mod.token_required`: no
# authentication is performed, the decorated callable is invoked unchanged.
# Two wrappers are built so the decorated handler keeps its own calling
# convention (awaitable for coroutine functions, plain call otherwise).
def _token_required(func):
@functools.wraps(func)
async def _async_wrapper(*args, **kwargs):
return await func(*args, **kwargs)
@functools.wraps(func)
def _sync_wrapper(*args, **kwargs):
return func(*args, **kwargs)
# Choose the wrapper that matches how `func` was defined.
return _async_wrapper if asyncio.iscoroutinefunction(func) else _sync_wrapper
api_utils_mod.deep_merge = _deep_merge
api_utils_mod.get_error_argument_result = _get_error_argument_result
api_utils_mod.get_error_data_result = _get_error_data_result
@@ -408,7 +396,6 @@ def _load_dataset_module(monkeypatch):
api_utils_mod.get_parser_config = lambda _chunk_method, _unused: {"auto": True}
api_utils_mod.get_result = _get_result
api_utils_mod.remap_dictionary_keys = lambda data: data
api_utils_mod.token_required = _token_required
api_utils_mod.add_tenant_id_to_kwargs = lambda func: func
api_utils_mod.verify_embedding_availability = lambda _embd_id, _tenant_id: (True, None)
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
@@ -854,4 +841,3 @@ def test_delete_index_wipe_flag_unit(monkeypatch):
assert res["code"] == module.RetCode.SUCCESS, res
assert len(deleted) == 1, f"default wipe must call docStore.delete once: {deleted}"
assert cleared_phase_markers == ["kb-1"], cleared_phase_markers

View File

@@ -16,11 +16,10 @@
"""Regression tests for retrieval in api/apps/restful_apis/dify_retrieval_api.py.
Issue #15027: cross-tenant knowledge-base access via POST /api/v1/dify/retrieval.
The handler authenticated the caller via @apikey_required (resolving
tenant_id) but then fetched the requested knowledge_id with no tenant
filter, allowing any valid API key to retrieve chunks from any other
tenant's KB by id. The fix adds a KnowledgebaseService.accessible(...)
check immediately after the lookup.
The handler authenticated the caller and resolved tenant_id, but then fetched
the requested knowledge_id with no tenant filter, allowing any valid caller to
retrieve chunks from any other tenant's KB by id. The fix adds a
KnowledgebaseService.accessible(...) check immediately after the lookup.
"""
import asyncio
@@ -83,12 +82,25 @@ class _FakeKGRetriever:
return {"content_with_weight": ""}
def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=None):
def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, tenant_id, chunks=None):
"""Load dify_retrieval_api.py with minimum stubs to exercise the retrieval handler."""
# Test double passed to the `api.utils.api_utils` stub as
# `add_tenant_id_to_kwargs`: injects the `tenant_id` given to the enclosing
# loader into the handler's keyword arguments, then awaits the handler.
def _add_tenant_id_to_kwargs(func):
# The wrapper accepts keyword arguments only; a positional call raises TypeError.
async def wrapper(**kwargs):
# Overwrites any caller-supplied tenant_id with the loader's value.
kwargs["tenant_id"] = tenant_id
return await func(**kwargs)
return wrapper
_stub(
monkeypatch,
"api.apps",
current_user=SimpleNamespace(id=tenant_id),
login_required=lambda func: func,
)
_stub(
monkeypatch,
"api.utils.api_utils",
apikey_required=lambda func: func,
add_tenant_id_to_kwargs=_add_tenant_id_to_kwargs,
build_error_result=lambda message="", code=0, data=False: {"code": code, "message": message, "data": data},
get_request_json=lambda: _AwaitableValue(request_body),
get_json_result=lambda code=0, message="", data=None: {"code": code, "message": message, "data": data},
@@ -185,11 +197,12 @@ class TestDifyRetrievalTenantCheck:
kb=(True, owner_kb),
accessible=_accessible_only_for_owner,
request_body=request_body,
tenant_id="tenant-attacker",
chunks=[{"doc_id": "d1", "content_with_weight": "VICTIM_SECRET ...", "similarity": 0.9, "docnm_kwd": "doc.txt"}],
)
caplog.set_level(logging.WARNING, logger=module.__name__)
result = asyncio.run(module.retrieval(tenant_id="tenant-attacker"))
result = asyncio.run(module.retrieval())
assert result["code"] == 109, f"expected AUTHENTICATION_ERROR (109), got {result}"
msg = result["message"].lower()
@@ -219,10 +232,11 @@ class TestDifyRetrievalTenantCheck:
kb=(True, owner_kb),
accessible=lambda _id, _u: True,
request_body=request_body,
tenant_id="tenant-owner",
chunks=[{"doc_id": "d1", "content_with_weight": "hello world", "similarity": 0.8, "docnm_kwd": "doc.txt"}],
)
result = asyncio.run(module.retrieval(tenant_id="tenant-owner"))
result = asyncio.run(module.retrieval())
assert "records" in result
assert len(result["records"]) == 1
@@ -242,9 +256,10 @@ class TestDifyRetrievalTenantCheck:
kb=(False, None),
accessible=_accessible_should_not_be_called,
request_body=request_body,
tenant_id="tenant-attacker",
)
result = asyncio.run(module.retrieval(tenant_id="tenant-attacker"))
result = asyncio.run(module.retrieval())
assert result["code"] == 404
assert "not found" in result["message"].lower()