diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py index 0eda78f107..1dc8c7a8aa 100644 --- a/api/apps/restful_apis/agent_api.py +++ b/api/apps/restful_apis/agent_api.py @@ -51,7 +51,9 @@ from api.db.services.user_service import TenantService, UserService from api.db.services.user_canvas_version import UserCanvasVersionService from api.utils.api_utils import ( add_tenant_id_to_kwargs, + check_duplicate_ids, get_data_error_result, + get_error_data_result, get_json_result, get_result, get_request_json, @@ -441,6 +443,61 @@ def delete_agent_session_item(agent_id, session_id, tenant_id): return get_json_result(data=API4ConversationService.delete_by_id(session_id)) +@manager.route("/agents//sessions", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_async +async def delete_agent_session(tenant_id, agent_id): + errors = [] + success_count = 0 + req = await get_request_json() + cvs = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id) + if not cvs: + return get_error_data_result(f"You don't own the agent {agent_id}") + + if not req: + return get_result() + + ids = req.get("ids") + if not ids: + if req.get("delete_all") is True: + ids = [conv.id for conv in await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)] + if not ids: + return get_result() + else: + return get_result() + + conv_list = ids + + unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session") + conv_list = unique_conv_ids + + for session_id in conv_list: + conv = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id) + if not conv: + errors.append(f"The agent doesn't own the session {session_id}") + continue + await thread_pool_exec(API4ConversationService.delete_by_id, session_id) + success_count += 1 + + if errors: + if success_count > 0: + return get_result(data={"success_count": success_count, "errors": errors}, + message=f"Partially deleted {success_count} sessions with {len(errors)} errors") + else: + return get_error_data_result(message="; ".join(errors)) + + if duplicate_messages: + if success_count > 0: + return get_result( + message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", + data={"success_count": success_count, "errors": duplicate_messages}) + else: + return get_error_data_result(message=";".join(duplicate_messages)) + + return get_result() + + @manager.route("/agents/download", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs diff --git a/api/apps/restful_apis/bot_api.py b/api/apps/restful_apis/bot_api.py index c2d77a1060..ee4d0c68a0 100644 --- a/api/apps/restful_apis/bot_api.py +++ b/api/apps/restful_apis/bot_api.py @@ -26,7 +26,6 @@ from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService from api.db.services.canvas_service import completion as agent_completion -from api.db.services.user_canvas_version import UserCanvasVersionService from api.db.services.conversation_service import async_iframe_completion as iframe_completion from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap from api.db.services.doc_metadata_service import DocMetadataService @@ -36,9 +35,9 @@ from common.metadata_utils import apply_meta_data_filter from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance -from common.misc_utils import get_uuid, thread_pool_exec -from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \ - get_result, get_request_json, server_error_response, token_required, validate_request +from common.misc_utils import thread_pool_exec +from api.utils.api_utils import get_error_data_result, get_json_result, \ + get_result, get_request_json, server_error_response, validate_request from rag.app.tag import label_question from rag.prompts.template import load_prompt from rag.prompts.generator import cross_languages, keyword_extraction @@ -59,97 +58,6 @@ def _get_sdk_authorization_token(): return token[1] -@token_required -async def create_agent_session(tenant_id, agent_id): - req = await get_request_json() - user_id = req.get("user_id") or request.args.get("user_id", tenant_id) - release_mode = bool(req.get("release", request.args.get("release", False))) - - if not await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id): - return get_error_data_result("You cannot access the agent.") - - try: - cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode, tenant_id) - except LookupError: - return get_error_data_result("Agent not found.") - except PermissionError as e: - return get_error_data_result(str(e)) - - session_id = get_uuid() - canvas = Canvas(dsl, tenant_id, agent_id, canvas_id=cvs.id) - canvas.reset() - - cvs.dsl = json.loads(str(canvas)) - # Get the version title based on release_mode - version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode) - conv = { - "id": session_id, - "dialog_id": cvs.id, - "user_id": user_id, - "message": [{"role": "assistant", "content": canvas.get_prologue()}], - "source": "agent", - "dsl": cvs.dsl, - "version_title": version_title - } - await thread_pool_exec(API4ConversationService.save, **conv) - conv["agent_id"] = conv.pop("dialog_id") - return get_result(data=conv) - - -@manager.route("/agents//sessions", methods=["DELETE"]) # noqa: F821 -@token_required -async def delete_agent_session(tenant_id, agent_id): - errors = [] - success_count = 0 - req = await get_request_json() - cvs = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id) - if not cvs: - return get_error_data_result(f"You don't own the agent {agent_id}") - - if not req: - return get_result() - - ids = req.get("ids") - if not ids: - if req.get("delete_all") is True: - ids = [conv.id for conv in await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)] - if not ids: - return get_result() - else: - return get_result() - - conv_list = ids - - unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session") - conv_list = unique_conv_ids - - for session_id in conv_list: - conv = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id) - if not conv: - errors.append(f"The agent doesn't own the session {session_id}") - continue - await thread_pool_exec(API4ConversationService.delete_by_id, session_id) - success_count += 1 - - if errors: - if success_count > 0: - return get_result(data={"success_count": success_count, "errors": errors}, - message=f"Partially deleted {success_count} sessions with {len(errors)} errors") - else: - return get_error_data_result(message="; ".join(errors)) - - if duplicate_messages: - if success_count > 0: - return get_result( - message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", - data={"success_count": success_count, "errors": duplicate_messages}) - else: - return get_error_data_result(message=";".join(duplicate_messages)) - - return get_result() - - - @manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821 async def chatbot_completions(dialog_id): req = await get_request_json() diff --git a/api/apps/restful_apis/chunk_api.py b/api/apps/restful_apis/chunk_api.py index 59383e6f07..d2f803f7bc 100644 --- a/api/apps/restful_apis/chunk_api.py +++ b/api/apps/restful_apis/chunk_api.py @@ -44,7 +44,6 @@ from api.utils.api_utils import ( get_request_json, get_result, server_error_response, - token_required, ) from api.utils.pagination_utils import validate_rest_api_page_size from api.utils.image_utils import store_chunk_image @@ -157,7 +156,8 @@ def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=No @manager.route("/datasets//chunks", methods=["POST"]) # noqa: F821 -@token_required +@login_required +@add_tenant_id_to_kwargs async def parse(tenant_id, dataset_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") @@ -212,7 +212,8 @@ async def parse(tenant_id, dataset_id): @manager.route("/datasets//chunks", methods=["DELETE"]) # noqa: F821 -@token_required +@login_required +@add_tenant_id_to_kwargs async def stop_parsing(tenant_id, dataset_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") @@ -252,7 +253,8 @@ async def stop_parsing(tenant_id, dataset_id): @manager.route("/retrieval", methods=["POST"]) # noqa: F821 -@token_required +@login_required +@add_tenant_id_to_kwargs async def retrieval_test(tenant_id): req = await get_request_json() if not req.get("dataset_ids"): diff --git a/test/testcases/restful_api/test_retrieval.py b/test/testcases/restful_api/test_retrieval.py index 350f277407..d75f331651 100644 --- a/test/testcases/restful_api/test_retrieval.py +++ b/test/testcases/restful_api/test_retrieval.py @@ -108,9 +108,8 @@ def test_retrieval_compatibility_requires_auth(rest_client_noauth): res = rest_client_noauth.post("/retrieval", json={"question": "test", "dataset_ids": ["x"]}) assert res.status_code == 401 payload = res.json() - # token_required preserves legacy payload code/message while returning HTTP 401. - assert payload["code"] == 0, payload - assert payload["message"] == "`Authorization` can't be empty", payload + assert payload["code"] == 401, payload + assert payload["message"] == "", payload @wait_for(20, 1, "Retrieval indexing timeout in RESTful batch 10 tests") @@ -151,14 +150,13 @@ def _retrieval_lacks_chunks(rest_client, dataset_id, question, chunk_ids): @pytest.mark.p2 -def test_retrieval_requires_auth_contract(ensure_parsed_document): - dataset_id, _ = ensure_parsed_document() +def test_retrieval_requires_auth_contract(): for scenario_name, token, expected_code, expected_message in ( - ("missing token", None, 0, "`Authorization` can't be empty"), - ("invalid token", INVALID_API_TOKEN, 109, "Authentication error: API key is invalid!"), + ("missing token", None, 401, ""), + ("invalid token", INVALID_API_TOKEN, 401, ""), ): client = RestClient(token=token) - res = client.post("/retrieval", json={"question": "chunk", "dataset_ids": [dataset_id]}) + res = client.post("/retrieval", json={"question": "chunk", "dataset_ids": ["x"]}) assert res.status_code == 401, (scenario_name, res.text) payload = res.json() assert payload["code"] == expected_code, (scenario_name, payload) diff --git a/test/unit_test/api/apps/restful_apis/test_get_agent_session.py b/test/unit_test/api/apps/restful_apis/test_get_agent_session.py index 7106140655..8832d9fc83 100644 --- a/test/unit_test/api/apps/restful_apis/test_get_agent_session.py +++ b/test/unit_test/api/apps/restful_apis/test_get_agent_session.py @@ -70,7 +70,9 @@ def _load_agent_api(monkeypatch, get_by_id_result): monkeypatch, "api.utils.api_utils", add_tenant_id_to_kwargs=lambda func: func, + check_duplicate_ids=lambda ids, _kind="item": (ids, []), get_data_error_result=lambda message="Sorry": {"code": 102, "message": message, "data": None}, + get_error_data_result=lambda message="Sorry": {"code": 102, "message": message, "data": None}, get_json_result=lambda code=0, message="", data=None: {"code": code, "message": message, "data": data}, get_result=lambda **kwargs: kwargs, get_request_json=lambda: {},