From 4cbe597d7ee98fc34bfc29bf44f9c1f036c92bef Mon Sep 17 00:00:00 2001 From: Wang Qi Date: Fri, 5 Jun 2026 11:35:00 +0800 Subject: [PATCH] Refactor: consolidate to use @login_required (#15652) Refactor: consolidate to use @login_required --- api/apps/restful_apis/dify_retrieval_api.py | 6 +- api/utils/api_utils.py | 89 +------------------ .../test_dify_retrieval_routes_unit.py | 10 +++ .../test_dify_retrieval_routes_unit.py | 17 +++- .../test_doc_sdk_routes_unit.py | 9 -- .../test_session_sdk_routes_unit.py | 1 - .../test_dataset_sdk_routes_unit.py | 14 --- .../api/apps/sdk/test_dify_retrieval.py | 35 +++++--- 8 files changed, 55 insertions(+), 126 deletions(-) diff --git a/api/apps/restful_apis/dify_retrieval_api.py b/api/apps/restful_apis/dify_retrieval_api.py index 132b5869a5..d7c2968606 100644 --- a/api/apps/restful_apis/dify_retrieval_api.py +++ b/api/apps/restful_apis/dify_retrieval_api.py @@ -29,7 +29,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from common.metadata_utils import meta_filter, convert_conditions -from api.utils.api_utils import apikey_required, build_error_result, get_request_json, get_json_result +from api.apps import login_required +from api.utils.api_utils import add_tenant_id_to_kwargs, build_error_result, get_request_json, get_json_result from rag.app.tag import label_question from common.constants import RetCode, LLMType from common import settings @@ -108,7 +109,8 @@ def _parse_retrieval_options(retrieval_setting): @manager.route('/dify/retrieval', methods=['POST', 'GET']) # noqa: F821 -@apikey_required +@login_required +@add_tenant_id_to_kwargs async def retrieval(tenant_id): """ Dify-compatible retrieval API diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 74d95a514f..21809aacbf 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -19,7 +19,6 @@ import functools import inspect import json import logging -import os import sys import time from copy import deepcopy @@ -32,7 +31,7 @@ from quart import ( request, has_app_context, ) -from werkzeug.exceptions import BadRequest as WerkzeugBadRequest, Unauthorized as WerkzeugUnauthorized +from werkzeug.exceptions import BadRequest as WerkzeugBadRequest try: from quart.exceptions import BadRequest as QuartBadRequest @@ -42,7 +41,6 @@ except ImportError: # pragma: no cover - optional dependency from peewee import OperationalError from common.constants import ActiveEnum, LLMType -from api.db.db_models import APIToken from api.utils.json_encode import CustomJSONEncoder from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from api.db.services.tenant_llm_service import LLMFactoriesService @@ -252,28 +250,6 @@ def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=Non return _safe_jsonify(response) -def apikey_required(func): - @wraps(func) - async def decorated_function(*args, **kwargs): - authorization = request.headers.get("Authorization") - if not authorization: - return build_error_result(message="Authorization header is missing!", code=RetCode.FORBIDDEN) - parts = authorization.split() - if len(parts) < 2: - return build_error_result(message="Please check your authorization format.", code=RetCode.FORBIDDEN) - token = parts[1] - objs = APIToken.query(token=token) - if not objs: - return build_error_result(message="API-KEY is invalid!", code=RetCode.FORBIDDEN) - kwargs["tenant_id"] = objs[0].tenant_id - if inspect.iscoroutinefunction(func): - return await func(*args, **kwargs) - - return func(*args, **kwargs) - - return decorated_function - - def build_error_result(code=RetCode.FORBIDDEN, message="success"): response = {"code": code, "message": message} response = _safe_jsonify(response) @@ -288,69 +264,6 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da return _safe_jsonify({"code": code, "message": message, "data": data}) -def token_required(func): - @wraps(func) - async def wrapper(*args, **kwargs): - # Validate the token (API Key) - if os.environ.get("DISABLE_SDK"): - err = WerkzeugUnauthorized(description="`Authorization` can't be empty") - err.code = RetCode.SUCCESS - raise err - - authorization_str = request.headers.get("Authorization") - if not authorization_str: - err = WerkzeugUnauthorized(description="`Authorization` can't be empty") - err.code = RetCode.SUCCESS - raise err - - authorization_list = authorization_str.split() - if len(authorization_list) < 2: - err = WerkzeugUnauthorized(description="Please check your authorization format.") - err.code = RetCode.AUTHENTICATION_ERROR - raise err - - token = authorization_list[1] - - # First try API token (explicit API token authentication) - objs = APIToken.query(token=token) - if objs: - # On success, inject tenant_id into the route function's kwargs - kwargs["tenant_id"] = objs[0].tenant_id - result = func(*args, **kwargs) - if inspect.iscoroutine(result): - return await result - return result - - # Fallback: try login token (for clients that use login token as API token) - # Login tokens are JWT-encoded (URLSafeTimedSerializer), need to decode to get raw access_token - from api.db.services.user_service import UserService - from common.constants import StatusEnum - from common import settings - from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer - try: - jwt = Serializer(secret_key=settings.get_secret_key()) - raw_token = str(jwt.loads(token)) - user = UserService.query(access_token=raw_token, status=StatusEnum.VALID.value) - if user: - # On success, inject tenant_id from user's tenant - from api.db.services.user_service import UserTenantService - tenants = UserTenantService.query(user_id=user[0].id) - if tenants: - kwargs["tenant_id"] = tenants[0].tenant_id - result = func(*args, **kwargs) - if inspect.iscoroutine(result): - return await result - return result - except Exception: - pass - - err = WerkzeugUnauthorized(description="Authentication error: API key is invalid!") - err.code = RetCode.AUTHENTICATION_ERROR - raise err - - return wrapper - - def get_result(code=RetCode.SUCCESS, message="", data=None, total=None): """ Standard API response format: diff --git a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py index 90ea5b781d..d7b00b6afa 100644 --- a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py +++ b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py @@ -67,6 +67,11 @@ def _run(coro): return asyncio.run(coro) +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + def _load_dify_retrieval_module(monkeypatch): repo_root = Path(__file__).resolve().parents[3] @@ -74,6 +79,11 @@ def _load_dify_retrieval_module(monkeypatch): common_pkg.__path__ = [str(repo_root / "common")] monkeypatch.setitem(sys.modules, "common", common_pkg) + api_apps_mod = ModuleType("api.apps") + api_apps_mod.current_user = SimpleNamespace(id="tenant-1") + api_apps_mod.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", api_apps_mod) + deepdoc_pkg = ModuleType("deepdoc") deepdoc_parser_pkg = ModuleType("deepdoc.parser") deepdoc_parser_pkg.__path__ = [] diff --git a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py index c08dd482d0..c51880c801 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py +++ b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py @@ -67,6 +67,11 @@ def _run(coro): return asyncio.run(coro) +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + def _load_dify_retrieval_module(monkeypatch): repo_root = Path(__file__).resolve().parents[4] @@ -74,6 +79,11 @@ def _load_dify_retrieval_module(monkeypatch): common_pkg.__path__ = [str(repo_root / "common")] monkeypatch.setitem(sys.modules, "common", common_pkg) + api_apps_mod = ModuleType("api.apps") + api_apps_mod.current_user = SimpleNamespace(id="tenant-1") + api_apps_mod.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", api_apps_mod) + deepdoc_pkg = ModuleType("deepdoc") deepdoc_parser_pkg = ModuleType("deepdoc.parser") deepdoc_parser_pkg.__path__ = [] @@ -284,6 +294,7 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch): monkeypatch.setattr(module, "jsonify", lambda payload: payload) monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [{"doc_id": "doc-1"}]) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True) monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", [])) monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: []) @@ -302,8 +313,8 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch): monkeypatch.setattr(module.settings, "kg_retriever", _DummyKgRetriever()) monkeypatch.setattr( module.DocumentService, - "get_by_id", - lambda doc_id: (True, SimpleNamespace(meta_fields={"origin": f"meta-{doc_id}"})), + "get_by_ids", + lambda doc_ids, cols=None: [SimpleNamespace(id=doc_id, meta_fields={"origin": f"meta-{doc_id}"}) for doc_id in doc_ids], ) monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) @@ -334,6 +345,7 @@ def test_retrieval_not_found_exception_mapping(monkeypatch): _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"}) monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: []) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True) monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) class _BrokenRetriever: @@ -353,6 +365,7 @@ def test_retrieval_generic_exception_mapping(monkeypatch): _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"}) monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: []) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _tenant_id: True) monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) class _BrokenRetriever: diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py index 66302fada8..7e6bd4128d 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py @@ -17,7 +17,6 @@ import asyncio import inspect import importlib.util import sys -from functools import wraps from pathlib import Path from types import ModuleType, SimpleNamespace @@ -249,14 +248,6 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): if value is not None } api_utils_mod.server_error_response = lambda e: {"code": 500, "message": str(e)} - def _token_required(func): - @wraps(func) - async def wrapper(*args, **kwargs): - return await func(*args, **kwargs) - - return wrapper - - api_utils_mod.token_required = _token_required monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) image_utils_mod = ModuleType("api.utils.image_utils") diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index 6c1aa4b961..0287fc75aa 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -275,7 +275,6 @@ def _load_session_module(monkeypatch): } api_utils_mod.get_request_json = lambda: _AwaitableValue({}) api_utils_mod.server_error_response = lambda e: {"code": _StubRetCode.SERVER_ERROR, "message": str(e)} - api_utils_mod.token_required = lambda func: func api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) diff --git a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py index 2311eb22dc..4a2c8a47a0 100644 --- a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py +++ b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py @@ -15,7 +15,6 @@ # import asyncio -import functools import importlib.util import inspect import json @@ -390,17 +389,6 @@ def _load_dataset_module(monkeypatch): def _get_error_permission_result(message=""): return _get_result(code=_RetCode.AUTHENTICATION_ERROR, message=message) - def _token_required(func): - @functools.wraps(func) - async def _async_wrapper(*args, **kwargs): - return await func(*args, **kwargs) - - @functools.wraps(func) - def _sync_wrapper(*args, **kwargs): - return func(*args, **kwargs) - - return _async_wrapper if asyncio.iscoroutinefunction(func) else _sync_wrapper - api_utils_mod.deep_merge = _deep_merge api_utils_mod.get_error_argument_result = _get_error_argument_result api_utils_mod.get_error_data_result = _get_error_data_result @@ -408,7 +396,6 @@ def _load_dataset_module(monkeypatch): api_utils_mod.get_parser_config = lambda _chunk_method, _unused: {"auto": True} api_utils_mod.get_result = _get_result api_utils_mod.remap_dictionary_keys = lambda data: data - api_utils_mod.token_required = _token_required api_utils_mod.add_tenant_id_to_kwargs = lambda func: func api_utils_mod.verify_embedding_availability = lambda _embd_id, _tenant_id: (True, None) monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) @@ -854,4 +841,3 @@ def test_delete_index_wipe_flag_unit(monkeypatch): assert res["code"] == module.RetCode.SUCCESS, res assert len(deleted) == 1, f"default wipe must call docStore.delete once: {deleted}" assert cleared_phase_markers == ["kb-1"], cleared_phase_markers - diff --git a/test/unit_test/api/apps/sdk/test_dify_retrieval.py b/test/unit_test/api/apps/sdk/test_dify_retrieval.py index 0b880308bd..fbf5b5225f 100644 --- a/test/unit_test/api/apps/sdk/test_dify_retrieval.py +++ b/test/unit_test/api/apps/sdk/test_dify_retrieval.py @@ -16,11 +16,10 @@ """Regression tests for retrieval in api/apps/restful_apis/dify_retrieval_api.py. Issue #15027: cross-tenant knowledge-base access via POST /api/v1/dify/retrieval. -The handler authenticated the caller via @apikey_required (resolving -tenant_id) but then fetched the requested knowledge_id with no tenant -filter, allowing any valid API key to retrieve chunks from any other -tenant's KB by id. The fix adds a KnowledgebaseService.accessible(...) -check immediately after the lookup. +The handler authenticated the caller and resolved tenant_id, but then fetched +the requested knowledge_id with no tenant filter, allowing any valid caller to +retrieve chunks from any other tenant's KB by id. The fix adds a +KnowledgebaseService.accessible(...) check immediately after the lookup. """ import asyncio @@ -83,12 +82,25 @@ class _FakeKGRetriever: return {"content_with_weight": ""} -def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=None): +def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, tenant_id, chunks=None): """Load dify_retrieval_api.py with minimum stubs to exercise the retrieval handler.""" + def _add_tenant_id_to_kwargs(func): + async def wrapper(**kwargs): + kwargs["tenant_id"] = tenant_id + return await func(**kwargs) + + return wrapper + + _stub( + monkeypatch, + "api.apps", + current_user=SimpleNamespace(id=tenant_id), + login_required=lambda func: func, + ) _stub( monkeypatch, "api.utils.api_utils", - apikey_required=lambda func: func, + add_tenant_id_to_kwargs=_add_tenant_id_to_kwargs, build_error_result=lambda message="", code=0, data=False: {"code": code, "message": message, "data": data}, get_request_json=lambda: _AwaitableValue(request_body), get_json_result=lambda code=0, message="", data=None: {"code": code, "message": message, "data": data}, @@ -185,11 +197,12 @@ class TestDifyRetrievalTenantCheck: kb=(True, owner_kb), accessible=_accessible_only_for_owner, request_body=request_body, + tenant_id="tenant-attacker", chunks=[{"doc_id": "d1", "content_with_weight": "VICTIM_SECRET ...", "similarity": 0.9, "docnm_kwd": "doc.txt"}], ) caplog.set_level(logging.WARNING, logger=module.__name__) - result = asyncio.run(module.retrieval(tenant_id="tenant-attacker")) + result = asyncio.run(module.retrieval()) assert result["code"] == 109, f"expected AUTHENTICATION_ERROR (109), got {result}" msg = result["message"].lower() @@ -219,10 +232,11 @@ class TestDifyRetrievalTenantCheck: kb=(True, owner_kb), accessible=lambda _id, _u: True, request_body=request_body, + tenant_id="tenant-owner", chunks=[{"doc_id": "d1", "content_with_weight": "hello world", "similarity": 0.8, "docnm_kwd": "doc.txt"}], ) - result = asyncio.run(module.retrieval(tenant_id="tenant-owner")) + result = asyncio.run(module.retrieval()) assert "records" in result assert len(result["records"]) == 1 @@ -242,9 +256,10 @@ class TestDifyRetrievalTenantCheck: kb=(False, None), accessible=_accessible_should_not_be_called, request_body=request_body, + tenant_id="tenant-attacker", ) - result = asyncio.run(module.retrieval(tenant_id="tenant-attacker")) + result = asyncio.run(module.retrieval()) assert result["code"] == 404 assert "not found" in result["message"].lower()