Fix: no more @token_required (#15562)

Fix: no more @token_required
This commit is contained in:
Wang Qi
2026-06-03 16:24:08 +08:00
committed by GitHub
parent a678ed7b1f
commit d6fc50a469
5 changed files with 74 additions and 107 deletions

View File

@@ -51,7 +51,9 @@ from api.db.services.user_service import TenantService, UserService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.utils.api_utils import (
add_tenant_id_to_kwargs,
check_duplicate_ids,
get_data_error_result,
get_error_data_result,
get_json_result,
get_result,
get_request_json,
@@ -441,6 +443,61 @@ def delete_agent_session_item(agent_id, session_id, tenant_id):
return get_json_result(data=API4ConversationService.delete_by_id(session_id))
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
@_require_canvas_access_async
async def delete_agent_session(tenant_id, agent_id):
errors = []
success_count = 0
req = await get_request_json()
cvs = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
if not req:
return get_result()
ids = req.get("ids")
if not ids:
if req.get("delete_all") is True:
ids = [conv.id for conv in await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)]
if not ids:
return get_result()
else:
return get_result()
conv_list = ids
unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session")
conv_list = unique_conv_ids
for session_id in conv_list:
conv = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id)
if not conv:
errors.append(f"The agent doesn't own the session {session_id}")
continue
await thread_pool_exec(API4ConversationService.delete_by_id, session_id)
success_count += 1
if errors:
if success_count > 0:
return get_result(data={"success_count": success_count, "errors": errors},
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
else:
return get_error_data_result(message="; ".join(errors))
if duplicate_messages:
if success_count > 0:
return get_result(
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
data={"success_count": success_count, "errors": duplicate_messages})
else:
return get_error_data_result(message=";".join(duplicate_messages))
return get_result()
@manager.route("/agents/download", methods=["GET"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs

View File

@@ -26,7 +26,6 @@ from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap
from api.db.services.doc_metadata_service import DocMetadataService
@@ -36,9 +35,9 @@ from common.metadata_utils import apply_meta_data_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance
from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \
get_result, get_request_json, server_error_response, token_required, validate_request
from common.misc_utils import thread_pool_exec
from api.utils.api_utils import get_error_data_result, get_json_result, \
get_result, get_request_json, server_error_response, validate_request
from rag.app.tag import label_question
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, keyword_extraction
@@ -59,97 +58,6 @@ def _get_sdk_authorization_token():
return token[1]
@token_required
async def create_agent_session(tenant_id, agent_id):
req = await get_request_json()
user_id = req.get("user_id") or request.args.get("user_id", tenant_id)
release_mode = bool(req.get("release", request.args.get("release", False)))
if not await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id):
return get_error_data_result("You cannot access the agent.")
try:
cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode, tenant_id)
except LookupError:
return get_error_data_result("Agent not found.")
except PermissionError as e:
return get_error_data_result(str(e))
session_id = get_uuid()
canvas = Canvas(dsl, tenant_id, agent_id, canvas_id=cvs.id)
canvas.reset()
cvs.dsl = json.loads(str(canvas))
# Get the version title based on release_mode
version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode)
conv = {
"id": session_id,
"dialog_id": cvs.id,
"user_id": user_id,
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent",
"dsl": cvs.dsl,
"version_title": version_title
}
await thread_pool_exec(API4ConversationService.save, **conv)
conv["agent_id"] = conv.pop("dialog_id")
return get_result(data=conv)
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
async def delete_agent_session(tenant_id, agent_id):
errors = []
success_count = 0
req = await get_request_json()
cvs = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
if not req:
return get_result()
ids = req.get("ids")
if not ids:
if req.get("delete_all") is True:
ids = [conv.id for conv in await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)]
if not ids:
return get_result()
else:
return get_result()
conv_list = ids
unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session")
conv_list = unique_conv_ids
for session_id in conv_list:
conv = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id)
if not conv:
errors.append(f"The agent doesn't own the session {session_id}")
continue
await thread_pool_exec(API4ConversationService.delete_by_id, session_id)
success_count += 1
if errors:
if success_count > 0:
return get_result(data={"success_count": success_count, "errors": errors},
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
else:
return get_error_data_result(message="; ".join(errors))
if duplicate_messages:
if success_count > 0:
return get_result(
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
data={"success_count": success_count, "errors": duplicate_messages})
else:
return get_error_data_result(message=";".join(duplicate_messages))
return get_result()
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
async def chatbot_completions(dialog_id):
req = await get_request_json()

View File

@@ -44,7 +44,6 @@ from api.utils.api_utils import (
get_request_json,
get_result,
server_error_response,
token_required,
)
from api.utils.pagination_utils import validate_rest_api_page_size
from api.utils.image_utils import store_chunk_image
@@ -157,7 +156,8 @@ def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=No
@manager.route("/datasets/<dataset_id>/chunks", methods=["POST"]) # noqa: F821
@token_required
@login_required
@add_tenant_id_to_kwargs
async def parse(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
@@ -212,7 +212,8 @@ async def parse(tenant_id, dataset_id):
@manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"]) # noqa: F821
@token_required
@login_required
@add_tenant_id_to_kwargs
async def stop_parsing(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
@@ -252,7 +253,8 @@ async def stop_parsing(tenant_id, dataset_id):
@manager.route("/retrieval", methods=["POST"]) # noqa: F821
@token_required
@login_required
@add_tenant_id_to_kwargs
async def retrieval_test(tenant_id):
req = await get_request_json()
if not req.get("dataset_ids"):

View File

@@ -108,9 +108,8 @@ def test_retrieval_compatibility_requires_auth(rest_client_noauth):
res = rest_client_noauth.post("/retrieval", json={"question": "test", "dataset_ids": ["x"]})
assert res.status_code == 401
payload = res.json()
# token_required preserves legacy payload code/message while returning HTTP 401.
assert payload["code"] == 0, payload
assert payload["message"] == "`Authorization` can't be empty", payload
assert payload["code"] == 401, payload
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", payload
@wait_for(20, 1, "Retrieval indexing timeout in RESTful batch 10 tests")
@@ -151,14 +150,13 @@ def _retrieval_lacks_chunks(rest_client, dataset_id, question, chunk_ids):
@pytest.mark.p2
def test_retrieval_requires_auth_contract(ensure_parsed_document):
dataset_id, _ = ensure_parsed_document()
def test_retrieval_requires_auth_contract():
for scenario_name, token, expected_code, expected_message in (
("missing token", None, 0, "`Authorization` can't be empty"),
("invalid token", INVALID_API_TOKEN, 109, "Authentication error: API key is invalid!"),
("missing token", None, 401, "<Unauthorized '401: Unauthorized'>"),
("invalid token", INVALID_API_TOKEN, 401, "<Unauthorized '401: Unauthorized'>"),
):
client = RestClient(token=token)
res = client.post("/retrieval", json={"question": "chunk", "dataset_ids": [dataset_id]})
res = client.post("/retrieval", json={"question": "chunk", "dataset_ids": ["x"]})
assert res.status_code == 401, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == expected_code, (scenario_name, payload)

View File

@@ -70,7 +70,9 @@ def _load_agent_api(monkeypatch, get_by_id_result):
monkeypatch,
"api.utils.api_utils",
add_tenant_id_to_kwargs=lambda func: func,
check_duplicate_ids=lambda ids, _kind="item": (ids, []),
get_data_error_result=lambda message="Sorry": {"code": 102, "message": message, "data": None},
get_error_data_result=lambda message="Sorry": {"code": 102, "message": message, "data": None},
get_json_result=lambda code=0, message="", data=None: {"code": code, "message": message, "data": data},
get_result=lambda **kwargs: kwargs,
get_request_json=lambda: {},