diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index b539272c0d..11960dcf65 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -42,13 +42,15 @@ from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_ from rag.app.tag import label_question from rag.prompts.template import load_prompt from rag.prompts.generator import cross_languages, keyword_extraction -from common.constants import RetCode, LLMType +from common.constants import RetCode, LLMType, StatusEnum from common import settings from api.utils.reference_metadata_utils import ( enrich_chunks_with_document_metadata, resolve_reference_metadata_preferences, ) +logger = logging.getLogger(__name__) + @token_required async def create_agent_session(tenant_id, agent_id): @@ -152,20 +154,69 @@ async def chatbot_completions(dialog_id): objs = APIToken.query(beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') + tenant_id = objs[0].tenant_id + exists, dialog = DialogService.get_by_id(dialog_id) + if (not exists + or getattr(dialog, "tenant_id", None) != tenant_id + or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value): + logger.warning( + "Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s", + "no access to this chatbot", + tenant_id, + dialog_id, + req.get("user_id"), + req.get("session_id"), + ) + return get_error_data_result(message="Authentication error: no access to this chatbot!") if "quote" not in req: req["quote"] = False + def _validate_iframe_access(): + if req.get("session_id"): + exists, conv = API4ConversationService.get_by_id(req.get("session_id")) + if not exists: + raise AssertionError("Session not found!") + if conv.dialog_id != dialog_id: + raise AssertionError("Session does not belong to this dialog") + if tenant_id and conv.user_id and conv.user_id != tenant_id: + raise AssertionError("Session does not belong to this tenant") + if req.get("stream", True): - resp = Response(iframe_completion(dialog_id, **req), mimetype="text/event-stream") + try: + _validate_iframe_access() + except AssertionError: + logger.warning( + "Denied chatbot completion stream: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s", + "no access to this chatbot", + tenant_id, + dialog_id, + req.get("user_id"), + req.get("session_id"), + ) + return get_error_data_result(message="Authentication error: no access to this chatbot!") + + resp = Response(iframe_completion(dialog_id, tenant_id=tenant_id, **req), mimetype="text/event-stream") resp.headers.add_header("Cache-control", "no-cache") resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - async for answer in iframe_completion(dialog_id, **req): - return get_result(data=answer) + try: + _validate_iframe_access() + async for answer in iframe_completion(dialog_id, tenant_id=tenant_id, **req): + return get_result(data=answer) + except AssertionError: + logger.warning( + "Denied chatbot completion: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s", + "no access to this chatbot", + tenant_id, + dialog_id, + req.get("user_id"), + req.get("session_id"), + ) + return get_error_data_result(message="Authentication error: no access to this chatbot!") return None @@ -178,11 +229,23 @@ async def chatbots_inputs(dialog_id): objs = APIToken.query(beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - - e, dialog = DialogService.get_by_id(dialog_id) - if not e: - return get_error_data_result(f"Can't find dialog by ID: {dialog_id}") - + tenant_id = objs[0].tenant_id + exists, dialog = DialogService.get_by_id(dialog_id) + if (not exists + or getattr(dialog, "tenant_id", None) != tenant_id + or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value): + request_args = getattr(request, "args", {}) or {} + request_user_id = request_args.get("user_id") if hasattr(request_args, "get") else None + request_session_id = request_args.get("session_id") if hasattr(request_args, "get") else None + logger.warning( + "Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s", + "no access to this chatbot", + tenant_id, + dialog_id, + request_user_id, + request_session_id, + ) + return get_error_data_result(message="Authentication error: no access to this chatbot!") return get_result( data={ "title": dialog.name, diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py index 5a205b1421..2603676e98 100644 --- a/api/db/services/conversation_service.py +++ b/api/db/services/conversation_service.py @@ -14,6 +14,7 @@ # limitations under the License. # import time +import logging from uuid import uuid4 from common.constants import StatusEnum from api.db.db_models import Conversation, DB @@ -26,6 +27,9 @@ import json from rag.prompts.generator import chunks_format +logger = logging.getLogger(__name__) + + class ConversationService(CommonService): model = Conversation @@ -201,9 +205,23 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses break yield answer -async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs): - e, dia = DialogService.get_by_id(dialog_id) - assert e, "Dialog not found" +async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, tenant_id=None, **kwargs): + if tenant_id: + exists, dia = DialogService.get_by_id(dialog_id) + if (not exists + or getattr(dia, "tenant_id", None) != tenant_id + or str(getattr(dia, "status", "")) != StatusEnum.VALID.value): + logger.warning( + "Dialog lookup failed for tenant-scoped iframe completion: " + "tenant_id=%s dialog_id=%s required_status=%s", + tenant_id, + dialog_id, + StatusEnum.VALID.value, + ) + raise AssertionError("Dialog not found") + else: + e, dia = DialogService.get_by_id(dialog_id) + assert e, "Dialog not found" if not session_id: session_id = get_uuid() conv = { @@ -228,6 +246,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T session_id = session_id e, conv = API4ConversationService.get_by_id(session_id) assert e, "Session not found!" + assert conv.dialog_id == dialog_id, "Session does not belong to this dialog" if not conv.message: conv.message = [] diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index de548361c4..77ac86232b 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -1292,6 +1292,13 @@ def test_chatbot_routes_auth_stream_nonstream_unit(monkeypatch): res = _run(inspect.unwrap(module.chatbot_completions)("dialog-1")) assert "API key is invalid" in res["message"] + monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer ok"})) + monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")]) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _dialog_id: (False, None)) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"stream": False})) + res = _run(inspect.unwrap(module.chatbot_completions)("dialog-unauthorized")) + assert res["message"] == "Authentication error: no access to this chatbot!" + stream_calls = [] async def _iframe_stream(dialog_id, **req): @@ -1301,6 +1308,11 @@ def test_chatbot_routes_auth_stream_nonstream_unit(monkeypatch): monkeypatch.setattr(module, "iframe_completion", _iframe_stream) monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer ok"})) monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")]) + monkeypatch.setattr( + module.DialogService, + "get_by_id", + lambda _dialog_id: (True, SimpleNamespace(id="dialog-1", tenant_id="tenant-1", status="1")), + ) monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"stream": True})) resp = _run(inspect.unwrap(module.chatbot_completions)("dialog-1")) assert isinstance(resp, _StubResponse) @@ -1308,11 +1320,17 @@ def test_chatbot_routes_auth_stream_nonstream_unit(monkeypatch): _run(_collect_stream(resp.body)) assert stream_calls[-1][0] == "dialog-1" assert stream_calls[-1][1]["quote"] is False + assert stream_calls[-1][1]["tenant_id"] == "tenant-1" async def _iframe_nonstream(_dialog_id, **_req): yield {"answer": "non-stream"} monkeypatch.setattr(module, "iframe_completion", _iframe_nonstream) + monkeypatch.setattr( + module.DialogService, + "get_by_id", + lambda _dialog_id: (True, SimpleNamespace(id="dialog-1", tenant_id="tenant-1", status="1")), + ) monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"stream": False, "quote": True})) res = _run(inspect.unwrap(module.chatbot_completions)("dialog-1")) assert res["data"]["answer"] == "non-stream" @@ -1329,8 +1347,27 @@ def test_chatbot_routes_auth_stream_nonstream_unit(monkeypatch): monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer ok"})) monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")]) monkeypatch.setattr(module.DialogService, "get_by_id", lambda _dialog_id: (False, None)) + res = _run(inspect.unwrap(module.chatbots_inputs)("dialog-404")) - assert res["message"] == "Can't find dialog by ID: dialog-404" + assert res["message"] == "Authentication error: no access to this chatbot!" + + # Happy path: valid token + owned dialog -> correct payload + stub_dialog = SimpleNamespace( + name="My Bot", + icon="avatar.png", + tenant_id="tenant-1", + status="1", + prompt_config={"prologue": "Hello!", "tavily_api_key": "key123"}, + ) + monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer ok"})) + monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")]) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _dialog_id: (True, stub_dialog)) + res = _run(inspect.unwrap(module.chatbots_inputs)("dialog-404")) + assert res["code"] == 0 + assert res["data"]["title"] == "My Bot" + assert res["data"]["avatar"] == "avatar.png" + assert res["data"]["prologue"] == "Hello!" + assert res["data"]["has_tavily_key"] is True @pytest.mark.p2