Refactor: consolidate to use @login_required (#15652)

Refactor: consolidate to use @login_required
This commit is contained in:
Wang Qi
2026-06-05 11:35:00 +08:00
committed by GitHub
parent 9f3e289b78
commit 4cbe597d7e
8 changed files with 55 additions and 126 deletions

View File

@@ -16,11 +16,10 @@
"""Regression tests for retrieval in api/apps/restful_apis/dify_retrieval_api.py.
Issue #15027: cross-tenant knowledge-base access via POST /api/v1/dify/retrieval.
The handler authenticated the caller via @apikey_required (resolving
tenant_id) but then fetched the requested knowledge_id with no tenant
filter, allowing any valid API key to retrieve chunks from any other
tenant's KB by id. The fix adds a KnowledgebaseService.accessible(...)
check immediately after the lookup.
The handler authenticated the caller and resolved tenant_id, but then fetched
the requested knowledge_id with no tenant filter, allowing any valid caller to
retrieve chunks from any other tenant's KB by id. The fix adds a
KnowledgebaseService.accessible(...) check immediately after the lookup.
"""
import asyncio
@@ -83,12 +82,25 @@ class _FakeKGRetriever:
return {"content_with_weight": ""}
def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=None):
def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, tenant_id, chunks=None):
"""Load dify_retrieval_api.py with minimum stubs to exercise the retrieval handler."""
def _add_tenant_id_to_kwargs(func):
async def wrapper(**kwargs):
kwargs["tenant_id"] = tenant_id
return await func(**kwargs)
return wrapper
_stub(
monkeypatch,
"api.apps",
current_user=SimpleNamespace(id=tenant_id),
login_required=lambda func: func,
)
_stub(
monkeypatch,
"api.utils.api_utils",
apikey_required=lambda func: func,
add_tenant_id_to_kwargs=_add_tenant_id_to_kwargs,
build_error_result=lambda message="", code=0, data=False: {"code": code, "message": message, "data": data},
get_request_json=lambda: _AwaitableValue(request_body),
get_json_result=lambda code=0, message="", data=None: {"code": code, "message": message, "data": data},
@@ -185,11 +197,12 @@ class TestDifyRetrievalTenantCheck:
kb=(True, owner_kb),
accessible=_accessible_only_for_owner,
request_body=request_body,
tenant_id="tenant-attacker",
chunks=[{"doc_id": "d1", "content_with_weight": "VICTIM_SECRET ...", "similarity": 0.9, "docnm_kwd": "doc.txt"}],
)
caplog.set_level(logging.WARNING, logger=module.__name__)
result = asyncio.run(module.retrieval(tenant_id="tenant-attacker"))
result = asyncio.run(module.retrieval())
assert result["code"] == 109, f"expected AUTHENTICATION_ERROR (109), got {result}"
msg = result["message"].lower()
@@ -219,10 +232,11 @@ class TestDifyRetrievalTenantCheck:
kb=(True, owner_kb),
accessible=lambda _id, _u: True,
request_body=request_body,
tenant_id="tenant-owner",
chunks=[{"doc_id": "d1", "content_with_weight": "hello world", "similarity": 0.8, "docnm_kwd": "doc.txt"}],
)
result = asyncio.run(module.retrieval(tenant_id="tenant-owner"))
result = asyncio.run(module.retrieval())
assert "records" in result
assert len(result["records"]) == 1
@@ -242,9 +256,10 @@ class TestDifyRetrievalTenantCheck:
kb=(False, None),
accessible=_accessible_should_not_be_called,
request_body=request_body,
tenant_id="tenant-attacker",
)
result = asyncio.run(module.retrieval(tenant_id="tenant-attacker"))
result = asyncio.run(module.retrieval())
assert result["code"] == 404
assert "not found" in result["message"].lower()