From da1ed6f0e7f298f1708fe957f77ee009bdcabe7a Mon Sep 17 00:00:00 2001 From: Idriss Sbaaoui <112825897+6ba3i@users.noreply.github.com> Date: Mon, 1 Jun 2026 11:02:40 +0800 Subject: [PATCH] Feat: add new tests and tescases for restful api suite (#15347) ### What problem does this PR solve? extend restful api suite ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Other (please describe): test --- .../restful_api/test_file_routes_unit.py | 297 ++++++- .../restful_api/test_llm_routes_unit.py | 742 ++++++++++++++++++ .../restful_api/test_message_routes_unit.py | 151 ++++ test/testcases/restful_api/test_retrieval.py | 11 +- test/testcases/restful_api/test_sessions.py | 502 +++++++++++- 5 files changed, 1692 insertions(+), 11 deletions(-) create mode 100644 test/testcases/restful_api/test_llm_routes_unit.py create mode 100644 test/testcases/restful_api/test_message_routes_unit.py diff --git a/test/testcases/restful_api/test_file_routes_unit.py b/test/testcases/restful_api/test_file_routes_unit.py index 39246e97a0..579e37eb07 100644 --- a/test/testcases/restful_api/test_file_routes_unit.py +++ b/test/testcases/restful_api/test_file_routes_unit.py @@ -16,6 +16,7 @@ import asyncio import importlib.util +import logging import sys from enum import Enum from pathlib import Path @@ -24,6 +25,9 @@ from types import ModuleType, SimpleNamespace import pytest +LOGGER = logging.getLogger(__name__) + + class _DummyManager: def route(self, *_args, **_kwargs): def decorator(func): @@ -370,12 +374,15 @@ class _DummyManager: class _DummyFile: - def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1): + def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1, tenant_id="tenant1", parent_id="pf1", source_type="user"): self.id = file_id self.type = file_type self.name = name self.location = location self.size = size + self.tenant_id = tenant_id + self.parent_id = parent_id + self.source_type = source_type class _FalsyFile(_DummyFile): @@ -630,3 +637,291 @@ def test_convert_branch_matrix_unit(monkeypatch): res = _run(module.convert()) assert res["code"] == 500 assert "convert boom" in res["message"] + + +def _load_file_api_service(monkeypatch): + LOGGER.debug("_load_file_api_service: entry") + repo_root = Path(__file__).resolve().parents[3] + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + LOGGER.debug("_load_file_api_service: mocked api package") + + common_pkg = ModuleType("api.common") + common_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "api.common", common_pkg) + + permission_mod = ModuleType("api.common.check_team_permission") + permission_mod.check_file_team_permission = lambda *_args, **_kwargs: True + monkeypatch.setitem(sys.modules, "api.common.check_team_permission", permission_mod) + common_pkg.check_team_permission = permission_mod + + db_pkg = ModuleType("api.db") + db_pkg.__path__ = [] + + class _ServiceFileType(Enum): + FOLDER = "folder" + VIRTUAL = "virtual" + DOC = "doc" + VISUAL = "visual" + + db_pkg.FileType = _ServiceFileType + monkeypatch.setitem(sys.modules, "api.db", db_pkg) + api_pkg.db = db_pkg + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [] + services_pkg.duplicate_name = lambda _query, **kwargs: kwargs.get("name", "") + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + document_service_mod = ModuleType("api.db.services.document_service") + document_service_mod.DocumentService = SimpleNamespace( + get_doc_count=lambda _uid: 0, + get_by_id=lambda doc_id: (True, SimpleNamespace(id=doc_id)), + get_tenant_id=lambda _doc_id: "tenant1", + remove_document=lambda *_args, **_kwargs: True, + update_by_id=lambda *_args, **_kwargs: True, + ) + monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod) + services_pkg.document_service = document_service_mod + + file2doc_mod = ModuleType("api.db.services.file2document_service") + file2doc_mod.File2DocumentService = SimpleNamespace( + get_by_file_id=lambda _file_id: [], + delete_by_file_id=lambda _file_id: None, + ) + monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2doc_mod) + services_pkg.file2document_service = file2doc_mod + + file_service_mod = ModuleType("api.db.services.file_service") + file_service_mod.FileService = SimpleNamespace( + get_root_folder=lambda _tenant_id: {"id": "root"}, + get_by_id=lambda file_id: (True, _DummyFile(file_id, _ServiceFileType.DOC.value)), + get_id_list_by_id=lambda _pf_id, _names, _idx, ids: ids, + create_folder=lambda _file, parent_id, _names, _len_id: SimpleNamespace(id=parent_id, name=str(parent_id)), + query=lambda **_kwargs: [], + insert=lambda data: SimpleNamespace(to_json=lambda: data, **data), + is_parent_folder_exist=lambda _pf_id: True, + get_by_pf_id=lambda *_args, **_kwargs: ([], 0), + get_parent_folder=lambda _file_id: SimpleNamespace(to_json=lambda: {"id": "root"}), + get_all_parent_folders=lambda _file_id: [], + list_all_files_by_parent_id=lambda _parent_id: [], + delete=lambda _file: True, + delete_by_id=lambda _file_id: True, + update_by_id=lambda *_args, **_kwargs: True, + get_by_ids=lambda file_ids: [_DummyFile(file_id, _ServiceFileType.DOC.value) for file_id in file_ids], + ) + monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod) + services_pkg.file_service = file_service_mod + LOGGER.debug("_load_file_api_service: mocked api.db.services.file_service") + + file_utils_mod = ModuleType("api.utils.file_utils") + file_utils_mod.filename_type = lambda _filename: _ServiceFileType.DOC.value + monkeypatch.setitem(sys.modules, "api.utils.file_utils", file_utils_mod) + + common_root_mod = ModuleType("common") + common_root_mod.__path__ = [str(repo_root / "common")] + common_root_mod.settings = SimpleNamespace( + STORAGE_IMPL=SimpleNamespace( + obj_exist=lambda *_args, **_kwargs: False, + put=lambda *_args, **_kwargs: None, + rm=lambda *_args, **_kwargs: None, + move=lambda *_args, **_kwargs: None, + ) + ) + monkeypatch.setitem(sys.modules, "common", common_root_mod) + + constants_mod = ModuleType("common.constants") + + class _FileSource: + KNOWLEDGEBASE = "knowledgebase" + + constants_mod.FileSource = _FileSource + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + misc_utils_mod = ModuleType("common.misc_utils") + misc_utils_mod.get_uuid = lambda: "uuid-1" + + async def thread_pool_exec(func, *args, **kwargs): + return func(*args, **kwargs) + + misc_utils_mod.thread_pool_exec = thread_pool_exec + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) + + module_path = repo_root / "api" / "apps" / "services" / "file_api_service.py" + spec = importlib.util.spec_from_file_location("api.apps.services.file_api_service", module_path) + module = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, "api.apps.services.file_api_service", module) + try: + spec.loader.exec_module(module) + except Exception: + LOGGER.exception( + "_load_file_api_service: spec.loader.exec_module(module) failed" + ) + raise + LOGGER.debug("_load_file_api_service: spec.loader.exec_module(module) completed") + return module + + +@pytest.mark.p2 +def test_upload_file_requires_existing_folder(monkeypatch): + module = _load_file_api_service(monkeypatch) + monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (False, None)) + + ok, message = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt")])) + assert ok is False + assert message == "Can't find this folder!" + + +@pytest.mark.p2 +def test_upload_file_respects_user_limit(monkeypatch): + module = _load_file_api_service(monkeypatch) + monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (True, SimpleNamespace(id="pf1", name="pf1"))) + monkeypatch.setattr(module.DocumentService, "get_doc_count", lambda _uid: 1) + monkeypatch.setenv("MAX_FILE_NUM_PER_USER", "1") + + ok, message = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt")])) + assert ok is False + assert message == "Exceed the maximum file number of a free user!" + monkeypatch.delenv("MAX_FILE_NUM_PER_USER", raising=False) + + +@pytest.mark.p2 +def test_upload_file_success_uses_new_service_layer(monkeypatch): + module = _load_file_api_service(monkeypatch) + storage_puts = [] + + monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (True, SimpleNamespace(id="pf1", name="pf1"))) + monkeypatch.setattr(module.FileService, "get_id_list_by_id", lambda *_args, **_kwargs: ["pf1"]) + monkeypatch.setattr( + module.FileService, + "create_folder", + lambda _file, parent_id, _names, _len_id, *_args: SimpleNamespace(id=parent_id), + ) + monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace( + obj_exist=lambda *_args, **_kwargs: False, + put=lambda bucket, location, blob: storage_puts.append((bucket, location, blob)), + rm=lambda *_args, **_kwargs: None, + move=lambda *_args, **_kwargs: None, + )) + + ok, data = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt", b"hello")])) + assert ok is True + assert data[0]["name"] == "a.txt" + assert storage_puts == [("pf1", "a.txt", b"hello")] + + +@pytest.mark.p2 +def test_create_folder_rejects_duplicate_name(monkeypatch): + module = _load_file_api_service(monkeypatch) + monkeypatch.setattr(module.FileService, "query", lambda **_kwargs: [SimpleNamespace(id="existing")]) + + ok, message = _run(module.create_folder("tenant1", "dup", "pf1", module.FileType.FOLDER.value)) + assert ok is False + assert message == "Duplicated folder name in the same folder." + + +@pytest.mark.p2 +def test_delete_files_checks_team_permission(monkeypatch): + module = _load_file_api_service(monkeypatch) + monkeypatch.setattr( + module.FileService, + "get_by_id", + lambda _file_id: (True, _DummyFile("file1", module.FileType.DOC.value)), + ) + monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: False) + + ok, message = _run(module.delete_files("tenant1", ["file1"])) + assert ok is False + assert message == {"success_count": 0, "errors": ["No authorization for file file1"]} + + +@pytest.mark.p2 +def test_move_files_rejects_extension_change_in_new_name(monkeypatch): + module = _load_file_api_service(monkeypatch) + monkeypatch.setattr( + module.FileService, + "get_by_ids", + lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, name="a.txt")], + ) + + ok, message = _run(module.move_files("tenant1", ["file1"], new_name="a.pdf")) + assert ok is False + assert message == "The extension of file can't be changed" + + +@pytest.mark.p2 +def test_move_files_handles_dest_and_storage_move(monkeypatch): + module = _load_file_api_service(monkeypatch) + moved = [] + updated = [] + + monkeypatch.setattr( + module.FileService, + "get_by_id", + lambda file_id: (False, None) if file_id == "missing" else (True, _DummyFile(file_id, module.FileType.FOLDER.value, name="dest")), + ) + monkeypatch.setattr( + module.FileService, + "get_by_ids", + lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, parent_id="src", location="old", name="a.txt")], + ) + monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace( + obj_exist=lambda *_args, **_kwargs: False, + put=lambda *_args, **_kwargs: None, + rm=lambda *_args, **_kwargs: None, + move=lambda old_bucket, old_loc, new_bucket, new_loc: moved.append((old_bucket, old_loc, new_bucket, new_loc)), + )) + monkeypatch.setattr(module.FileService, "update_by_id", lambda file_id, data: updated.append((file_id, data)) or True) + + ok, message = _run(module.move_files("tenant1", ["file1"], "missing")) + assert ok is False + assert message == "Parent folder not found!" + + ok, data = _run(module.move_files("tenant1", ["file1"], "dest")) + assert ok is True + assert data is True + assert moved == [("src", "old", "dest", "a.txt")] + assert updated == [("file1", {"parent_id": "dest", "location": "a.txt"})] + + +@pytest.mark.p2 +def test_move_files_renames_in_place_without_storage_move(monkeypatch): + module = _load_file_api_service(monkeypatch) + db_updates = [] + doc_updates = [] + + monkeypatch.setattr( + module.FileService, + "get_by_ids", + lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, parent_id="pf1", name="a.txt")], + ) + monkeypatch.setattr(module.FileService, "update_by_id", lambda file_id, data: db_updates.append((file_id, data)) or True) + monkeypatch.setattr( + module.File2DocumentService, + "get_by_file_id", + lambda _file_id: [SimpleNamespace(document_id="doc1")], + ) + monkeypatch.setattr(module.DocumentService, "update_by_id", lambda doc_id, data: doc_updates.append((doc_id, data)) or True) + + ok, data = _run(module.move_files("tenant1", ["file1"], new_name="b.txt")) + assert ok is True + assert data is True + assert db_updates == [("file1", {"name": "b.txt"})] + assert doc_updates == [("doc1", {"name": "b.txt"})] + + +@pytest.mark.p2 +def test_get_file_content_checks_permission(monkeypatch): + module = _load_file_api_service(monkeypatch) + monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: False) + + ok, message = module.get_file_content("tenant1", "file1") + assert ok is False + assert message == "No authorization." + + monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: True) + ok, file = module.get_file_content("tenant1", "file1") + assert ok is True + assert file.id == "file1" diff --git a/test/testcases/restful_api/test_llm_routes_unit.py b/test/testcases/restful_api/test_llm_routes_unit.py new file mode 100644 index 0000000000..a43dbac2f8 --- /dev/null +++ b/test/testcases/restful_api/test_llm_routes_unit.py @@ -0,0 +1,742 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest +import requests + +from configs import HOST_ADDRESS, VERSION +from test.testcases.conftest import login +from test.testcases.libs.auth import RAGFlowWebApiAuth + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _ExprField: + def __init__(self, name): + self.name = name + + def __eq__(self, other): + return (self.name, other) + + +class _StrEnum(str): + @property + def value(self): + return str(self) + + +class _DummyTenantLLMModel: + tenant_id = _ExprField("tenant_id") + llm_factory = _ExprField("llm_factory") + llm_name = _ExprField("llm_name") + + def __init__(self, id=None, **kwargs): + self.id = id + self.api_key = None + self.status = None + for key, value in kwargs.items(): + setattr(self, key, value) + + +class _TenantLLMRow: + def __init__( + self, + *, + id, + llm_name, + llm_factory, + model_type, + api_key="key", + status="1", + used_tokens=0, + api_base="", + max_tokens=8192, + ): + self.id = id + self.llm_name = llm_name + self.llm_factory = llm_factory + self.model_type = model_type + self.api_key = api_key + self.status = status + self.used_tokens = used_tokens + self.api_base = api_base + self.max_tokens = max_tokens + + def to_dict(self): + return { + "id": self.id, + "llm_name": self.llm_name, + "llm_factory": self.llm_factory, + "model_type": self.model_type, + "status": self.status, + "used_tokens": self.used_tokens, + "api_base": self.api_base, + "max_tokens": self.max_tokens, + } + + +class _LLMRow: + def __init__(self, *, llm_name, fid, model_type, status="1", max_tokens=2048): + self.llm_name = llm_name + self.fid = fid + self.model_type = model_type + self.status = status + self.max_tokens = max_tokens + + def to_dict(self): + return { + "llm_name": self.llm_name, + "fid": self.fid, + "model_type": self.model_type, + "status": self.status, + "max_tokens": self.max_tokens, + } + + +def _run(coro): + return asyncio.run(coro) + + +def _set_request_json(monkeypatch, module, payload): + async def _get_request_json(): + return dict(payload) + + monkeypatch.setattr(module, "get_request_json", _get_request_json) + + +def _load_llm_app(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + quart_mod = ModuleType("quart") + quart_mod.request = SimpleNamespace(args={}) + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + apps_mod = ModuleType("api.apps") + apps_mod.__path__ = [str(repo_root / "api" / "apps")] + apps_mod.login_required = lambda fn: fn + apps_mod.current_user = SimpleNamespace(id="tenant-1") + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + + tenant_llm_mod = ModuleType("api.db.services.tenant_llm_service") + + class _StubLLMFactoriesService: + @staticmethod + def query(**_kwargs): + return [] + + class _StubTenantLLMService: + @staticmethod + def ensure_mineru_from_env(_tenant_id): + return None + + @staticmethod + def ensure_opendataloader_from_env(_tenant_id): + return None + + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def get_my_llms(_tenant_id): + return [] + + @staticmethod + def save(**_kwargs): + return True + + @staticmethod + def filter_delete(_filters): + return True + + @staticmethod + def filter_update(_filters, _payload): + return True + + tenant_llm_mod.LLMFactoriesService = _StubLLMFactoriesService + tenant_llm_mod.TenantLLMService = _StubTenantLLMService + monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_mod) + + llm_service_mod = ModuleType("api.db.services.llm_service") + + class _StubLLMService: + @staticmethod + def get_all(): + return [] + + @staticmethod + def query(**_kwargs): + return [] + + llm_service_mod.LLMService = _StubLLMService + monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.get_allowed_llm_factories = lambda: [] + api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: { + "code": code, + "message": message, + "data": data, + } + api_utils_mod.get_json_result = lambda data=None, message="", code=0: { + "code": code, + "message": message, + "data": data, + } + + async def _get_request_json(): + return {} + + api_utils_mod.get_request_json = _get_request_json + api_utils_mod.server_error_response = lambda exc: {"code": 500, "message": str(exc), "data": None} + api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + constants_mod = ModuleType("common.constants") + constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value="1"), INVALID=SimpleNamespace(value="0")) + constants_mod.LLMType = SimpleNamespace( + CHAT=_StrEnum("chat"), + EMBEDDING=_StrEnum("embedding"), + SPEECH2TEXT=_StrEnum("speech2text"), + IMAGE2TEXT=_StrEnum("image2text"), + RERANK=_StrEnum("rerank"), + TTS=_StrEnum("tts"), + OCR=_StrEnum("ocr"), + ) + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + db_models_mod = ModuleType("api.db.db_models") + db_models_mod.TenantLLM = _DummyTenantLLMModel + monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod) + + base64_mod = ModuleType("rag.utils.base64_image") + base64_mod.test_image = b"image-bytes" + monkeypatch.setitem(sys.modules, "rag.utils.base64_image", base64_mod) + + rag_llm_mod = ModuleType("rag.llm") + rag_llm_mod.EmbeddingModel = {} + rag_llm_mod.ChatModel = {} + rag_llm_mod.RerankModel = {} + rag_llm_mod.CvModel = {} + rag_llm_mod.TTSModel = {} + rag_llm_mod.OcrModel = {} + rag_llm_mod.Seq2txtModel = {} + monkeypatch.setitem(sys.modules, "rag.llm", rag_llm_mod) + + module_path = repo_root / "api" / "apps" / "llm_app.py" + spec = importlib.util.spec_from_file_location("test_llm_routes_unit_module", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + spec.loader.exec_module(module) + return module + + +def _rag_llm_module(): + return sys.modules["rag.llm"] + + +@pytest.mark.p2 +def test_llm_list_app_grouping_availability_and_merge(monkeypatch): + module = _load_llm_app(monkeypatch) + + ensure_calls = [] + monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda tenant_id: ensure_calls.append(tenant_id)) + + tenant_rows = [ + _TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"), + _TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"), + _TenantLLMRow(id=3, llm_name="gpt-5.5", llm_factory="OpenAI", model_type="chat", api_key="k3", status="1"), + _TenantLLMRow(id=4, llm_name="gpt-5.4", llm_factory="OpenAI", model_type="chat", api_key="k4", status="1"), + ] + monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: tenant_rows) + + all_llms = [ + _LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"), + _LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"), + _LLMRow(llm_name="gpt-5.5", fid="OpenAI", model_type="chat", status="1"), + _LLMRow(llm_name="gpt-5.4", fid="OpenAI", model_type="chat", status="1"), + _LLMRow(llm_name="not-in-status", fid="Other", model_type="chat", status="1"), + ] + monkeypatch.setattr(module.LLMService, "get_all", lambda: all_llms) + + monkeypatch.setattr(module, "request", SimpleNamespace(args={})) + monkeypatch.setenv("COMPOSE_PROFILES", "tei-cpu") + monkeypatch.setenv("TEI_MODEL", "tei-embed") + + res = _run(module.list_app()) + assert res["code"] == 0, res["message"] + assert ensure_calls == ["tenant-1"] + + data = res["data"] + assert {"Builtin", "FastEmbed", "CustomFactory", "OpenAI"}.issubset(set(data.keys())) + assert data["Builtin"][0]["llm_name"] == "tei-embed" + assert data["Builtin"][0]["available"] is True + assert data["FastEmbed"][0]["llm_name"] == "fast-emb" + assert data["FastEmbed"][0]["available"] is True + assert data["CustomFactory"][0]["llm_name"] == "tenant-only" + assert data["CustomFactory"][0]["available"] is True + openai_names = {item["llm_name"] for item in data["OpenAI"]} + assert {"gpt-5.5", "gpt-5.4"}.issubset(openai_names) + + +@pytest.mark.p2 +def test_llm_list_app_model_type_filter(monkeypatch): + module = _load_llm_app(monkeypatch) + + monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: None) + monkeypatch.setattr( + module.TenantLLMService, + "query", + lambda **_kwargs: [ + _TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"), + _TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"), + ], + ) + monkeypatch.setattr( + module.LLMService, + "get_all", + lambda: [ + _LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"), + _LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"), + ], + ) + + monkeypatch.setattr(module, "request", SimpleNamespace(args={"model_type": "chat"})) + res = _run(module.list_app()) + assert res["code"] == 0, res["message"] + assert list(res["data"].keys()) == ["CustomFactory"] + assert res["data"]["CustomFactory"][0]["model_type"] == "chat" + + +@pytest.mark.p2 +def test_llm_list_app_exception_path(monkeypatch): + module = _load_llm_app(monkeypatch) + + monkeypatch.setattr(module, "request", SimpleNamespace(args={})) + monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: None) + monkeypatch.setattr( + module.TenantLLMService, + "query", + lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("query boom")), + ) + + res = _run(module.list_app()) + assert res["code"] == 500 + assert "query boom" in res["message"] + + +@pytest.mark.p2 +def test_llm_factories_route_success_and_exception_unit(monkeypatch): + module = _load_llm_app(monkeypatch) + + def _factory(name): + return SimpleNamespace(name=name, to_dict=lambda n=name: {"name": n}) + + monkeypatch.setattr( + module, + "get_allowed_llm_factories", + lambda: [ + _factory("OpenAI"), + _factory("CustomFactory"), + _factory("FastEmbed"), + _factory("Builtin"), + ], + ) + monkeypatch.setattr( + module.LLMService, + "get_all", + lambda: [ + _LLMRow(llm_name="m1", fid="OpenAI", model_type="chat", status="1"), + _LLMRow(llm_name="m2", fid="OpenAI", model_type="embedding", status="1"), + _LLMRow(llm_name="m3", fid="OpenAI", model_type="rerank", status="0"), + ], + ) + res = module.factories() + assert res["code"] == 0 + names = [item["name"] for item in res["data"]] + assert "FastEmbed" not in names + assert "Builtin" not in names + assert {"OpenAI", "CustomFactory"} == set(names) + openai = next(item for item in res["data"] if item["name"] == "OpenAI") + assert {"chat", "embedding"} == set(openai["model_types"]) + + monkeypatch.setattr(module, "get_allowed_llm_factories", lambda: (_ for _ in ()).throw(RuntimeError("factories boom"))) + res = module.factories() + assert res["code"] == 500 + assert "factories boom" in res["message"] + + +@pytest.mark.p2 +def test_add_llm_factory_specific_key_assembly_unit(monkeypatch): + module = _load_llm_app(monkeypatch) + + async def _wait_for(coro, *_args, **_kwargs): + return await coro + + async def _to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr(module.asyncio, "wait_for", _wait_for) + monkeypatch.setattr(module.asyncio, "to_thread", _to_thread) + + allowed = [ + "VolcEngine", + "Tencent Cloud", + "Bedrock", + "LocalAI", + "HuggingFace", + "OpenAI-API-Compatible", + "VLLM", + "XunFei Spark", + "BaiduYiyan", + "Fish Audio", + "Google Cloud", + "Azure-OpenAI", + "OpenRouter", + "MinerU", + "PaddleOCR", + ] + monkeypatch.setattr(module, "get_allowed_llm_factories", lambda: [SimpleNamespace(name=name) for name in allowed]) + + captured = {"filter_payloads": []} + + class _ChatOK: + def __init__(self, *_args, **_kwargs): + pass + + async def async_chat(self, *_args, **_kwargs): + return "ok", 1 + + async def async_chat_streamly(self, *_args, **_kwargs): + yield "ok" + yield 1 + + class _TTSOK: + def __init__(self, *_args, **_kwargs): + pass + + def tts(self, _text): + yield b"ok" + + monkeypatch.setattr(_rag_llm_module(), "ChatModel", {name: _ChatOK for name in allowed}) + monkeypatch.setattr(_rag_llm_module(), "TTSModel", {"XunFei Spark": _TTSOK}) + monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, payload: captured["filter_payloads"].append(dict(payload)) or True) + + reject_req = {"llm_factory": "NotAllowed", "llm_name": "x", "model_type": module.LLMType.CHAT.value} + _set_request_json(monkeypatch, module, reject_req) + res = _run(module.add_llm()) + assert res["code"] == 400 + assert "is not allowed" in res["message"] + + def _run_case(factory, *, model_type=module.LLMType.CHAT.value, extra=None): + req = {"llm_factory": factory, "llm_name": "model", "model_type": model_type, "api_key": "k", "api_base": "http://api"} + if extra: + req.update(extra) + _set_request_json(monkeypatch, module, req) + out = _run(module.add_llm()) + assert out["code"] == 0 + assert out["data"] is True + return captured["filter_payloads"][-1] + + volc = _run_case("VolcEngine", extra={"ark_api_key": "ak", "endpoint_id": "eid"}) + assert json.loads(volc["api_key"]) == {"ark_api_key": "ak", "endpoint_id": "eid"} + + bedrock = _run_case( + "Bedrock", + extra={"auth_mode": "iam", "bedrock_ak": "ak", "bedrock_sk": "sk", "bedrock_region": "r", "aws_role_arn": "arn"}, + ) + assert json.loads(bedrock["api_key"]) == { + "auth_mode": "iam", + "bedrock_ak": "ak", + "bedrock_sk": "sk", + "bedrock_region": "r", + "aws_role_arn": "arn", + } + + assert _run_case("LocalAI")["llm_name"] == "model___LocalAI" + assert _run_case("HuggingFace")["llm_name"] == "model___HuggingFace" + assert _run_case("OpenAI-API-Compatible")["llm_name"] == "model___OpenAI-API" + assert _run_case("VLLM")["llm_name"] == "model___VLLM" + + spark_chat = _run_case("XunFei Spark", extra={"spark_api_password": "spark-pass"}) + assert spark_chat["api_key"] == "spark-pass" + spark_tts = _run_case( + "XunFei Spark", + model_type=module.LLMType.TTS.value, + extra={"spark_app_id": "app", "spark_api_secret": "secret", "spark_api_key": "key"}, + ) + assert json.loads(spark_tts["api_key"]) == { + "spark_app_id": "app", + "spark_api_secret": "secret", + "spark_api_key": "key", + } + + assert json.loads(_run_case("BaiduYiyan", extra={"yiyan_ak": "ak", "yiyan_sk": "sk"})["api_key"]) == {"yiyan_ak": "ak", "yiyan_sk": "sk"} + assert json.loads(_run_case("Fish Audio", extra={"fish_audio_ak": "ak", "fish_audio_refid": "rid"})["api_key"]) == {"fish_audio_ak": "ak", "fish_audio_refid": "rid"} + assert json.loads( + _run_case("Google Cloud", extra={"google_project_id": "pid", "google_region": "us", "google_service_account_key": "sak"})["api_key"] + ) == { + "google_project_id": "pid", + "google_region": "us", + "google_service_account_key": "sak", + } + assert json.loads(_run_case("Azure-OpenAI", extra={"api_key": "real-key", "api_version": "2024-01-01"})["api_key"]) == { + "api_key": "real-key", + "api_version": "2024-01-01", + } + assert json.loads(_run_case("OpenRouter", extra={"api_key": "or-key", "provider_order": "a,b"})["api_key"]) == { + "api_key": "or-key", + "provider_order": "a,b", + } + assert json.loads(_run_case("MinerU", extra={"api_key": "m-key", "provider_order": "p1"})["api_key"]) == { + "api_key": "m-key", + "provider_order": "p1", + } + assert json.loads(_run_case("PaddleOCR", extra={"api_key": "p-key", "provider_order": "p2"})["api_key"]) == { + "api_key": "p-key", + "provider_order": "p2", + } + + tencent_req = { + "llm_factory": "Tencent Cloud", + "llm_name": "model", + "model_type": module.LLMType.CHAT.value, + "tencent_cloud_sid": "sid", + "tencent_cloud_sk": "sk", + } + + async def _tencent_request_json(): + return tencent_req + + monkeypatch.setattr(module, "get_request_json", _tencent_request_json) + delegated = {} + + async def _fake_set_api_key(): + delegated["api_key"] = tencent_req.get("api_key") + return {"code": 0, "data": "delegated"} + + monkeypatch.setattr(module, "set_api_key", _fake_set_api_key) + res = _run(module.add_llm()) + assert res["code"] == 0 + assert res["data"] == "delegated" + assert json.loads(delegated["api_key"]) == {"tencent_cloud_sid": "sid", "tencent_cloud_sk": "sk"} + + +@pytest.mark.p2 +def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch): + module = _load_llm_app(monkeypatch) + + async def _wait_for(coro, *_args, **_kwargs): + return await coro + + async def _to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr(module.asyncio, "wait_for", _wait_for) + monkeypatch.setattr(module.asyncio, "to_thread", _to_thread) + monkeypatch.setattr( + module, + "get_allowed_llm_factories", + lambda: [ + SimpleNamespace(name=name) + for name in [ + "FEmbFail", + "FEmbPass", + "FChatFail", + "FChatPass", + "FRKey", + "FRFail", + "FImgFail", + "FTTSFail", + "FOcrFail", + "FSttFail", + "FUnknown", + ] + ], + ) + + class _EmbeddingFail: + def __init__(self, *_args, **_kwargs): + pass + + def encode(self, _texts): + return [[]], 1 + + class _EmbeddingPass: + def __init__(self, *_args, **_kwargs): + pass + + def encode(self, _texts): + return [[0.5]], 1 + + class _ChatFail: + def __init__(self, *_args, **_kwargs): + pass + + async def async_chat(self, *_args, **_kwargs): + return "**ERROR**: chat failed", 0 + + async def async_chat_streamly(self, *_args, **_kwargs): + yield "**ERROR**: chat failed" + yield 0 + + class _ChatPass: + def __init__(self, *_args, **_kwargs): + pass + + async def async_chat(self, *_args, **_kwargs): + return "ok", 1 + + async def async_chat_streamly(self, *_args, **_kwargs): + yield "ok" + yield 1 + + class _RerankFail: + def __init__(self, *_args, **_kwargs): + pass + + def similarity(self, *_args, **_kwargs): + return [], 1 + + class _CvFail: + def __init__(self, *_args, **_kwargs): + pass + + def describe(self, _image_data): + return "**ERROR**: image failed", 0 + + class _TTSFail: + def __init__(self, *_args, **_kwargs): + pass + + def tts(self, _text): + raise RuntimeError("tts failed") + + class _OcrFail: + def __init__(self, *_args, **_kwargs): + pass + + def __call__(self, _img): + return None + + class _SttFail: + def __init__(self, *_args, **_kwargs): + pass + + def transcribe(self, _audio): + return "", 0 + + rag_llm_mod = _rag_llm_module() + monkeypatch.setattr(rag_llm_mod, "EmbeddingModel", {"FEmbFail": _EmbeddingFail, "FEmbPass": _EmbeddingPass}) + monkeypatch.setattr(rag_llm_mod, "ChatModel", {"FChatFail": _ChatFail, "FChatPass": _ChatPass}) + monkeypatch.setattr(rag_llm_mod, "RerankModel", {"FRFail": _RerankFail, "FRKey": _RerankFail}) + monkeypatch.setattr(rag_llm_mod, "CvModel", {"FImgFail": _CvFail}) + monkeypatch.setattr(rag_llm_mod, "TTSModel", {"FTTSFail": _TTSFail}) + monkeypatch.setattr(rag_llm_mod, "OcrModel", {"FOcrFail": _OcrFail}) + monkeypatch.setattr(rag_llm_mod, "Seq2txtModel", {"FSttFail": _SttFail}) + + saves = [] + monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, _payload: False) + monkeypatch.setattr(module.TenantLLMService, "save", lambda **kwargs: saves.append(kwargs) or True) + + monkeypatch.setattr( + module.LLMService, + "query", + lambda **kwargs: [] if kwargs.get("llm_factory") == "FUnknown" else [ + _LLMRow(llm_name="m", fid=kwargs.get("llm_factory"), model_type=kwargs.get("model_type", module.LLMType.CHAT.value), max_tokens=4096) + ], + ) + + _set_request_json(monkeypatch, module, {"llm_factory": "FUnknown", "llm_name": "m", "model_type": "unknown"}) + with pytest.raises(RuntimeError, match="Unknown model type: unknown"): + _run(module.add_llm()) + + cases = [ + ("FEmbFail", module.LLMType.EMBEDDING.value, 400, None, "embedding model"), + ("FEmbPass", module.LLMType.EMBEDDING.value, 0, True, ""), + ("FChatFail", module.LLMType.CHAT.value, 400, None, "No valid response received"), + ("FChatPass", module.LLMType.CHAT.value, 0, True, ""), + ("FRFail", module.LLMType.RERANK.value, 400, None, "Not known"), + ("FImgFail", module.LLMType.IMAGE2TEXT.value, 400, None, "image failed"), + ("FTTSFail", module.LLMType.TTS.value, 400, None, "tts"), + ("FOcrFail", module.LLMType.OCR.value, 400, None, "ocr"), + ("FSttFail", module.LLMType.SPEECH2TEXT.value, 0, True, ""), + ] + for factory, model_type, expected_code, expected_data, expected_fragment in cases: + _set_request_json(monkeypatch, module, {"llm_factory": factory, "llm_name": "m", "model_type": model_type, "api_key": "key"}) + res = _run(module.add_llm()) + assert res["code"] == expected_code + if expected_data is not None: + assert res["data"] is expected_data + if expected_fragment: + assert expected_fragment.lower() in res["message"].lower() + + assert any(item["llm_factory"] == "FEmbPass" for item in saves) + + +@pytest.mark.p2 +def test_llm_factories_live_auth_contract(): + llm_url = f"{HOST_ADDRESS}/{VERSION}/llm/factories" + invalid_auth_cases = [ + (None, 401, ""), + (RAGFlowWebApiAuth("invalid-token"), 401, ""), + ] + for auth_obj, expected_code, expected_message in invalid_auth_cases: + res = requests.get(llm_url, auth=auth_obj, timeout=30) + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == expected_code, payload + assert payload["message"] == expected_message, payload + + ok_res = requests.get(llm_url, auth=RAGFlowWebApiAuth(login()), timeout=30) + assert ok_res.status_code == 200 + ok_payload = ok_res.json() + assert ok_payload["code"] == 0, ok_payload + assert isinstance(ok_payload["data"], list), ok_payload + + +@pytest.mark.p2 +def test_llm_list_live_auth_contract(): + llm_url = f"{HOST_ADDRESS}/{VERSION}/llm/list" + invalid_auth_cases = [ + (None, 401, ""), + (RAGFlowWebApiAuth("invalid-token"), 401, ""), + ] + for auth_obj, expected_code, expected_message in invalid_auth_cases: + res = requests.get(llm_url, auth=auth_obj, timeout=30) + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == expected_code, payload + assert payload["message"] == expected_message, payload + + ok_res = requests.get(llm_url, auth=RAGFlowWebApiAuth(login()), timeout=30) + assert ok_res.status_code == 200 + ok_payload = ok_res.json() + assert ok_payload["code"] == 0, ok_payload + assert isinstance(ok_payload["data"], dict), ok_payload diff --git a/test/testcases/restful_api/test_message_routes_unit.py b/test/testcases/restful_api/test_message_routes_unit.py new file mode 100644 index 0000000000..b7890d31db --- /dev/null +++ b/test/testcases/restful_api/test_message_routes_unit.py @@ -0,0 +1,151 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import inspect +import sys +from copy import deepcopy +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _DummyArgs(dict): + def getlist(self, key): + value = self.get(key) + if value is None: + return [] + if isinstance(value, list): + return value + return [value] + + +class _DummyMemoryApiService: + async def add_message(self, *_args, **_kwargs): + return True, "ok" + + async def get_messages(self, *_args, **_kwargs): + return [] + + +def _run(coro): + return asyncio.run(coro) + + +def _load_memory_routes_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + apps_mod = ModuleType("api.apps") + apps_mod.__path__ = [str(repo_root / "api" / "apps")] + apps_mod.current_user = SimpleNamespace(id="user-1") + apps_mod.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + + services_mod = ModuleType("api.apps.services") + services_mod.memory_api_service = _DummyMemoryApiService() + monkeypatch.setitem(sys.modules, "api.apps.services", services_mod) + + module_name = "test_message_routes_unit_module" + module_path = repo_root / "api" / "apps" / "restful_apis" / "memory_api.py" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +def _set_request_json(monkeypatch, module, payload): + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload))) + + +@pytest.mark.p2 +def test_add_message_partial_failure_branch(monkeypatch): + module = _load_memory_routes_module(monkeypatch) + + _set_request_json( + monkeypatch, + module, + { + "memory_id": ["memory-1"], + "agent_id": "agent-1", + "session_id": "session-1", + "user_input": "hello", + "agent_response": "world", + }, + ) + + async def _add_message(_memory_ids, _message_dict): + return False, "cannot enqueue" + + monkeypatch.setattr(module.memory_api_service, "add_message", _add_message) + + res = _run(inspect.unwrap(module.add_message)()) + assert res["code"] == module.RetCode.SERVER_ERROR, res + assert "Some messages failed to add" in res["message"], res + + +@pytest.mark.p2 +def test_get_messages_csv_and_missing_memory_ids(monkeypatch): + module = _load_memory_routes_module(monkeypatch) + + monkeypatch.setattr(module, "request", SimpleNamespace(args=_DummyArgs({}))) + res = _run(inspect.unwrap(module.get_messages)()) + assert res["code"] == module.RetCode.ARGUMENT_ERROR, res + assert "memory_ids is required." in res["message"], res + + monkeypatch.setattr( + module, + "request", + SimpleNamespace(args=_DummyArgs({"memory_id": "m1,m2", "agent_id": "a1", "session_id": "s1", "limit": "5"})), + ) + + async def _get_messages(memory_ids, agent_id, session_id, limit): + assert memory_ids == ["m1", "m2"] + assert agent_id == "a1" + assert session_id == "s1" + assert limit == 5 + return [{"message_id": 1}] + + monkeypatch.setattr(module.memory_api_service, "get_messages", _get_messages) + res = _run(inspect.unwrap(module.get_messages)()) + assert res["code"] == module.RetCode.SUCCESS, res + assert isinstance(res["data"], list), res diff --git a/test/testcases/restful_api/test_retrieval.py b/test/testcases/restful_api/test_retrieval.py index 3c7eca7d10..350f277407 100644 --- a/test/testcases/restful_api/test_retrieval.py +++ b/test/testcases/restful_api/test_retrieval.py @@ -16,8 +16,7 @@ from concurrent.futures import ThreadPoolExecutor import pytest -import requests -from test.testcases.configs import HOST_ADDRESS, INVALID_API_TOKEN, VERSION +from test.testcases.configs import INVALID_API_TOKEN from test.testcases.restful_api.helpers.client import RestClient from test.testcases.utils import wait_for @@ -378,12 +377,8 @@ def test_deleted_chunks_batch_not_in_retrieval_contract(rest_client, create_docu @pytest.mark.p2 -def test_related_questions_contract(auth, rest_client, rest_client_noauth): - tokens_res = requests.get( - f"{HOST_ADDRESS}/api/{VERSION}/system/tokens", - headers={"Authorization": auth}, - timeout=30, - ) +def test_related_questions_contract(rest_client, rest_client_noauth): + tokens_res = rest_client.get("/system/tokens") assert tokens_res.status_code == 200, tokens_res.text tokens_payload = tokens_res.json() assert tokens_payload["code"] == 0, tokens_payload diff --git a/test/testcases/restful_api/test_sessions.py b/test/testcases/restful_api/test_sessions.py index ca1c8ea5c6..3d7d30a713 100644 --- a/test/testcases/restful_api/test_sessions.py +++ b/test/testcases/restful_api/test_sessions.py @@ -15,14 +15,36 @@ # import json +from concurrent.futures import ThreadPoolExecutor import pytest +from test.testcases.configs import INVALID_API_TOKEN, INVALID_ID_32, SESSION_WITH_CHAT_NAME_LIMIT +from test.testcases.restful_api.helpers.client import RestClient +from test.testcases.utils import is_sorted + def _sse_events(response_text: str) -> list[str]: return [line[5:] for line in response_text.splitlines() if line.startswith("data:")] +def _session_names(payload): + return [session["name"] for session in payload["data"]] + + +def _seed_sessions(rest_client, create_chat, prefix, count=5): + chat_id = create_chat(f"{prefix}_chat") + sessions = [] + for index in range(count): + name = f"{prefix}_{index}" + res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": name}) + assert res.status_code == 200, (prefix, index, res.text) + payload = res.json() + assert payload["code"] == 0, (prefix, index, payload) + sessions.append(payload["data"]) + return chat_id, sessions + + @pytest.mark.p1 def test_session_crud_cycle(rest_client, create_chat): chat_id = create_chat("restful_session_crud_chat") @@ -100,6 +122,475 @@ def test_session_update_blocks_messages_and_reference(rest_client, create_chat): assert "`reference` cannot be changed." in ref_payload["message"], ref_payload +@pytest.mark.p1 +def test_session_create_requires_auth_and_invalid_chat_contract(): + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.post("/chats/chat_id/sessions", json={"name": "x"}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + + +@pytest.mark.p2 +def test_session_create_validation_and_deleted_chat_contract(rest_client, create_chat): + chat_id = create_chat("restful_session_create_contract") + + empty_path_res = rest_client.post("/chats//sessions", json={"name": "valid_name"}) + assert empty_path_res.status_code == 200 + empty_path_payload = empty_path_res.json() + assert empty_path_payload["code"] == 100, empty_path_payload + assert empty_path_payload["message"] == "", empty_path_payload + + invalid_chat_res = rest_client.post("/chats/invalid_chat_assistant_id/sessions", json={"name": "valid_name"}) + assert invalid_chat_res.status_code == 200 + invalid_chat_payload = invalid_chat_res.json() + assert invalid_chat_payload["code"] == 109, invalid_chat_payload + assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload + + for scenario_name, payload in ( + ("valid", {"name": "valid_name"}), + ("empty", {"name": ""}), + ("space", {"name": " "}), + ("numeric", {"name": 1}), + ): + res = rest_client.post(f"/chats/{chat_id}/sessions", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + if scenario_name == "valid": + assert body["code"] == 0, (scenario_name, body) + assert body["data"]["name"] == "valid_name", (scenario_name, body) + assert body["data"]["chat_id"] == chat_id, (scenario_name, body) + else: + assert body["code"] == 102, (scenario_name, body) + assert body["message"] == "`name` can not be empty.", (scenario_name, body) + + duplicate_first = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "duplicated_name"}).json() + duplicate_second = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "duplicated_name"}).json() + assert duplicate_first["code"] == 0, duplicate_first + assert duplicate_second["code"] == 0, duplicate_second + assert duplicate_second["data"]["name"] == "duplicated_name", duplicate_second + + upper_case = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "CASE INSENSITIVE"}).json() + lower_case = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "case insensitive"}).json() + assert upper_case["code"] == 0, upper_case + assert lower_case["code"] == 0, lower_case + assert upper_case["data"]["name"] == "CASE INSENSITIVE", upper_case + assert lower_case["data"]["name"] == "case insensitive", lower_case + + long_name = "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1) + long_name_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": long_name}) + assert long_name_res.status_code == 200 + long_name_payload = long_name_res.json() + assert long_name_payload["code"] == 0, long_name_payload + assert long_name_payload["data"]["name"] == long_name[:SESSION_WITH_CHAT_NAME_LIMIT], long_name_payload + + delete_res = rest_client.delete("/chats", json={"ids": [chat_id]}) + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + + create_after_delete = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "after_delete"}) + assert create_after_delete.status_code == 200 + create_after_delete_payload = create_after_delete.json() + assert create_after_delete_payload["code"] == 109, create_after_delete_payload + assert create_after_delete_payload["message"] == "No authorization.", create_after_delete_payload + + +@pytest.mark.p2 +def test_session_create_concurrent_contract(rest_client, create_chat): + chat_id = create_chat("restful_session_create_concurrent") + + def _create(index): + return rest_client.post(f"/chats/{chat_id}/sessions", json={"name": f"session create {index}"}).json() + + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(_create, range(20))) + + assert len(results) == 20, results + assert all(result["code"] == 0 for result in results), results + + list_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30}) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert len(list_payload["data"]) == 20, list_payload + + +@pytest.mark.p1 +def test_session_delete_requires_auth_and_invalid_target_contract(rest_client, create_chat): + chat_id = create_chat("restful_session_delete_auth") + create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_a"}) + assert create_res.status_code == 200 + session_id = create_res.json()["data"]["id"] + + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.delete(f"/chats/{chat_id}/sessions", json={"ids": [session_id]}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + invalid_chat_res = rest_client.delete("/chats/invalid_chat_assistant_id/sessions", json={"ids": [session_id]}) + assert invalid_chat_res.status_code == 200 + invalid_chat_payload = invalid_chat_res.json() + assert invalid_chat_payload["code"] == 109, invalid_chat_payload + assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload + + +@pytest.mark.p2 +def test_session_delete_basic_scenarios(rest_client, create_chat): + cases = [ + ("none payload", None, 0, 5, {}), + ("invalid only", {"ids": ["invalid_id"]}, 102, 5, "The chat doesn't own the session invalid_id"), + ("not json", "not json", 100, 5, ""), + ("single id", lambda sessions: {"ids": [sessions[0]["id"]]}, 0, 4, True), + ("all ids", lambda sessions: {"ids": [session["id"] for session in sessions]}, 0, 0, True), + ("delete all", {"delete_all": True}, 0, 0, True), + ("empty ids", {"ids": []}, 0, 5, {}), + ] + + for scenario_name, payload, expected_code, expected_remaining, expected_data in cases: + chat_id, sessions = _seed_sessions(rest_client, create_chat, f"delete_basic_{scenario_name.replace(' ', '_')}") + if callable(payload): + payload = payload(sessions) + if scenario_name == "not json": + res = rest_client.delete( + f"/chats/{chat_id}/sessions", + headers={"Content-Type": "application/json"}, + data=payload, + ) + else: + res = rest_client.delete(f"/chats/{chat_id}/sessions", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code == 0: + assert body["data"] == expected_data, (scenario_name, body) + else: + assert body["message"] == expected_data, (scenario_name, body) + + list_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30}) + assert list_res.status_code == 200, (scenario_name, list_res.text) + list_payload = list_res.json() + assert list_payload["code"] == 0, (scenario_name, list_payload) + assert len(list_payload["data"]) == expected_remaining, (scenario_name, list_payload) + + +@pytest.mark.p2 +def test_session_delete_error_and_repeat_contract(rest_client, create_chat): + partial_cases = [ + ("invalid first", lambda sessions: {"ids": ["invalid_id"] + [session["id"] for session in sessions]}), + ("invalid middle", lambda sessions: {"ids": [sessions[0]["id"], "invalid_id", *[session["id"] for session in sessions[1:]]]}), + ("invalid last", lambda sessions: {"ids": [session["id"] for session in sessions] + ["invalid_id"]}), + ] + for scenario_name, payload_builder in partial_cases: + chat_id, sessions = _seed_sessions(rest_client, create_chat, f"delete_partial_{scenario_name.replace(' ', '_')}") + res = rest_client.delete(f"/chats/{chat_id}/sessions", json=payload_builder(sessions)) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 0, (scenario_name, payload) + assert payload["data"]["success_count"] == len(sessions), (scenario_name, payload) + assert payload["data"]["errors"] == ["The chat doesn't own the session invalid_id"], (scenario_name, payload) + remaining = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30}).json() + assert remaining["code"] == 0, (scenario_name, remaining) + assert remaining["data"] == [], (scenario_name, remaining) + + duplicate_chat_id, duplicate_sessions = _seed_sessions(rest_client, create_chat, "delete_duplicate") + duplicate_id = duplicate_sessions[0]["id"] + duplicate_res = rest_client.delete(f"/chats/{duplicate_chat_id}/sessions", json={"ids": [duplicate_id, duplicate_id]}) + assert duplicate_res.status_code == 200 + duplicate_payload = duplicate_res.json() + assert duplicate_payload["code"] == 0, duplicate_payload + assert duplicate_payload["data"]["success_count"] == 1, duplicate_payload + assert duplicate_payload["data"]["errors"] == [f"Duplicate session ids: {duplicate_id}"], duplicate_payload + + repeated_chat_id, repeated_sessions = _seed_sessions(rest_client, create_chat, "delete_repeated") + repeated_ids = [session["id"] for session in repeated_sessions] + first_res = rest_client.delete(f"/chats/{repeated_chat_id}/sessions", json={"ids": repeated_ids}) + assert first_res.status_code == 200 + first_payload = first_res.json() + assert first_payload["code"] == 0, first_payload + assert first_payload["data"] is True, first_payload + + second_res = rest_client.delete(f"/chats/{repeated_chat_id}/sessions", json={"ids": repeated_ids}) + assert second_res.status_code == 200 + second_payload = second_res.json() + assert second_payload["code"] == 102, second_payload + for session_id in repeated_ids: + assert f"The chat doesn't own the session {session_id}" in second_payload["message"], second_payload + + +@pytest.mark.p2 +def test_session_delete_concurrent_and_bulk_contract(rest_client, create_chat): + concurrent_chat_id, concurrent_sessions = _seed_sessions(rest_client, create_chat, "delete_concurrent", count=20) + + def _delete(session): + return rest_client.delete(f"/chats/{concurrent_chat_id}/sessions", json={"ids": [session["id"]]}).json() + + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(_delete, concurrent_sessions)) + + assert len(results) == 20, results + assert all(result["code"] == 0 for result in results), results + assert all(result["data"] is True for result in results), results + + list_after_concurrent = rest_client.get(f"/chats/{concurrent_chat_id}/sessions", params={"page_size": 30}).json() + assert list_after_concurrent["code"] == 0, list_after_concurrent + assert list_after_concurrent["data"] == [], list_after_concurrent + + bulk_chat_id, bulk_sessions = _seed_sessions(rest_client, create_chat, "delete_bulk", count=100) + bulk_res = rest_client.delete( + f"/chats/{bulk_chat_id}/sessions", + json={"ids": [session["id"] for session in bulk_sessions]}, + ) + assert bulk_res.status_code == 200 + bulk_payload = bulk_res.json() + assert bulk_payload["code"] == 0, bulk_payload + assert bulk_payload["data"] is True, bulk_payload + + +@pytest.mark.p1 +def test_session_list_requires_auth_and_invalid_target_contract(): + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.get("/chats/chat_id/sessions") + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + +@pytest.mark.p2 +def test_session_list_filter_and_deleted_chat_contract(rest_client, create_chat): + chat_id, sessions = _seed_sessions(rest_client, create_chat, "list_filter") + session_ids = [session["id"] for session in sessions] + session_names = [session["name"] for session in sessions] + + default_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30}) + assert default_res.status_code == 200 + default_payload = default_res.json() + assert default_payload["code"] == 0, default_payload + assert len(default_payload["data"]) == 5, default_payload + + for scenario_name, params, expected_names in ( + ("id none", {"id": None, "page_size": 30}, session_names), + ("id empty", {"id": "", "page_size": 30}, session_names), + ("valid id", {"id": session_ids[0], "page_size": 30}, [session_names[0]]), + ("unknown id", {"id": "unknown", "page_size": 30}, []), + ("name none", {"name": None, "page_size": 30}, session_names), + ("name empty", {"name": "", "page_size": 30}, session_names), + ("name exact", {"name": session_names[1], "page_size": 30}, [session_names[1]]), + ("name unknown", {"name": "unknown", "page_size": 30}, []), + ("name and id match", {"id": session_ids[0], "name": session_names[0], "page_size": 30}, [session_names[0]]), + ("name and id mismatch", {"id": session_ids[0], "name": "session_with_chat_assistant_100", "page_size": 30}, []), + ("name and invalid id", {"id": "id", "name": session_names[0], "page_size": 30}, []), + ("invalid params ignored", {"a": "b", "page_size": 30}, session_names), + ): + res = rest_client.get(f"/chats/{chat_id}/sessions", params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 0, (scenario_name, payload) + assert set(_session_names(payload)) == set(expected_names), (scenario_name, payload) + + invalid_chat_res = rest_client.get(f"/chats/{INVALID_ID_32}/sessions") + assert invalid_chat_res.status_code == 200 + invalid_chat_payload = invalid_chat_res.json() + assert invalid_chat_payload["code"] == 109, invalid_chat_payload + assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload + + delete_chat_res = rest_client.delete("/chats", json={"ids": [chat_id]}) + assert delete_chat_res.status_code == 200 + delete_chat_payload = delete_chat_res.json() + assert delete_chat_payload["code"] == 0, delete_chat_payload + + deleted_list_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30}) + assert deleted_list_res.status_code == 200 + deleted_list_payload = deleted_list_res.json() + assert deleted_list_payload["code"] == 109, deleted_list_payload + assert deleted_list_payload["message"] == "No authorization.", deleted_list_payload + + +@pytest.mark.p2 +def test_session_list_page_and_sort_contract(rest_client, create_chat): + chat_id, sessions = _seed_sessions(rest_client, create_chat, "list_page_sort") + created_names = [session["name"] for session in sessions] + descending_names = list(reversed(created_names)) + + page_cases = [ + ("page none", {"page": None, "page_size": 2}, 0, 2, ""), + ("page zero", {"page": 0, "page_size": 2}, 0, 2, ""), + ("page two", {"page": 2, "page_size": 2}, 0, 2, ""), + ("page three", {"page": 3, "page_size": 2}, 0, 1, ""), + ("page string", {"page": "3", "page_size": 2}, 0, 1, ""), + ("page negative", {"page": -1, "page_size": 2}, 100, 0, "ProgrammingError(1064"), + ("page alpha", {"page": "a", "page_size": 2}, 100, 0, "ValueError(\"invalid literal for int() with base 10: 'a'\")"), + ("page_size none", {"page_size": None}, 0, 5, ""), + ("page_size zero", {"page_size": 0}, 0, 0, ""), + ("page_size one", {"page_size": 1}, 0, 1, ""), + ("page_size six", {"page_size": 6}, 0, 5, ""), + ("page_size negative", {"page_size": -1}, 0, 5, ""), + ("page_size alpha", {"page_size": "a"}, 100, 0, "ValueError(\"invalid literal for int() with base 10: 'a'\")"), + ] + for scenario_name, params, expected_code, expected_count, expected_message in page_cases: + res = rest_client.get(f"/chats/{chat_id}/sessions", params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + assert len(payload["data"]) == expected_count, (scenario_name, payload) + else: + assert expected_message in payload["message"], (scenario_name, payload) + + sort_cases = [ + ("orderby none", {"orderby": None, "page_size": 30}, "create_time", True, descending_names, ""), + ("orderby create", {"orderby": "create_time", "page_size": 30}, "create_time", True, descending_names, ""), + ("orderby update", {"orderby": "update_time", "page_size": 30}, "update_time", True, descending_names, ""), + ("orderby name ascending", {"orderby": "name", "desc": "False", "page_size": 30}, "name", False, created_names, ""), + ("orderby unknown", {"orderby": "unknown", "page_size": 30}, None, None, None, "AttributeError(\"type object 'Conversation' has no attribute 'unknown'\")"), + ("desc none", {"desc": None, "page_size": 30}, "create_time", True, descending_names, ""), + ("desc true", {"desc": "true", "page_size": 30}, "create_time", True, descending_names, ""), + ("desc True", {"desc": "True", "page_size": 30}, "create_time", True, descending_names, ""), + ("desc false", {"desc": "false", "page_size": 30}, "create_time", False, created_names, ""), + ("desc False", {"desc": "False", "page_size": 30}, "create_time", False, created_names, ""), + ("desc false update_time", {"desc": "False", "orderby": "update_time", "page_size": 30}, "update_time", False, created_names, ""), + ("desc unknown", {"desc": "unknown", "page_size": 30}, "create_time", True, descending_names, ""), + ] + for scenario_name, params, field, descending, expected_names, expected_message in sort_cases: + res = rest_client.get(f"/chats/{chat_id}/sessions", params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + expected_code = 0 if expected_names is not None else 100 + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + assert is_sorted(payload["data"], field, descending), (scenario_name, payload) + assert _session_names(payload) == expected_names, (scenario_name, payload) + else: + assert expected_message in payload["message"], (scenario_name, payload) + + +@pytest.mark.p2 +def test_session_list_concurrent_contract(rest_client, create_chat): + chat_id, _sessions = _seed_sessions(rest_client, create_chat, "list_concurrent") + + def _list(_): + return rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30}).json() + + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(_list, range(10))) + + assert len(results) == 10, results + assert all(result["code"] == 0 for result in results), results + assert all(len(result["data"]) == 5 for result in results), results + + +@pytest.mark.p1 +def test_session_update_requires_auth_and_invalid_target_contract(rest_client, create_chat): + chat_id = create_chat("restful_session_update_auth") + create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_update_auth"}) + assert create_res.status_code == 200 + session_id = create_res.json()["data"]["id"] + + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"name": "x"}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + invalid_chat_res = rest_client.patch(f"/chats/{INVALID_ID_32}/sessions/{session_id}", json={"name": "x"}) + assert invalid_chat_res.status_code == 200 + invalid_chat_payload = invalid_chat_res.json() + assert invalid_chat_payload["code"] == 109, invalid_chat_payload + assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload + + empty_session_res = rest_client.patch(f"/chats/{chat_id}/sessions/", json={"name": "x"}) + assert empty_session_res.status_code == 200 + empty_session_payload = empty_session_res.json() + assert empty_session_payload["code"] == 100, empty_session_payload + assert empty_session_payload["message"] == "", empty_session_payload + + invalid_session_res = rest_client.patch(f"/chats/{chat_id}/sessions/invalid_session_id", json={"name": "x"}) + assert invalid_session_res.status_code == 200 + invalid_session_payload = invalid_session_res.json() + assert invalid_session_payload["code"] == 102, invalid_session_payload + assert invalid_session_payload["message"] == "Session not found!", invalid_session_payload + + +@pytest.mark.p2 +def test_session_update_name_and_param_contract(rest_client, create_chat): + chat_id, sessions = _seed_sessions(rest_client, create_chat, "update_contract") + session_id = sessions[0]["id"] + + for scenario_name, payload, expected_code, expected_name_or_message in ( + ("valid", {"name": "valid_name"}, 0, "valid_name"), + ("empty", {"name": ""}, 102, "`name` can not be empty."), + ("numeric", {"name": 1}, 102, "`name` can not be empty."), + ("duplicate", {"name": "duplicated_name"}, 0, "duplicated_name"), + ("case insensitive upper", {"name": "CASE INSENSITIVE UPDATE"}, 0, "CASE INSENSITIVE UPDATE"), + ("case insensitive lower", {"name": "case insensitive update"}, 0, "case insensitive update"), + ("long name", {"name": "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1)}, 0, "a" * SESSION_WITH_CHAT_NAME_LIMIT), + ): + res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code == 0: + assert body["data"]["name"] == expected_name_or_message, (scenario_name, body) + else: + assert body["message"] == expected_name_or_message, (scenario_name, body) + + unknown_key_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"unknown_key": "unknown_value"}) + assert unknown_key_res.status_code == 200 + unknown_key_payload = unknown_key_res.json() + assert unknown_key_payload["code"] == 100, unknown_key_payload + assert 'Unrecognized field name: "unknown_key"' in unknown_key_payload["message"], unknown_key_payload + + for scenario_name, payload in (("empty payload", {}), ("none payload", None)): + res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == 0, (scenario_name, body) + assert body["data"]["id"] == session_id, (scenario_name, body) + + delete_res = rest_client.delete("/chats", json={"ids": [chat_id]}) + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + + update_after_delete_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"name": "after_delete"}) + assert update_after_delete_res.status_code == 200 + update_after_delete_payload = update_after_delete_res.json() + assert update_after_delete_payload["code"] == 109, update_after_delete_payload + assert update_after_delete_payload["message"] == "No authorization.", update_after_delete_payload + + +@pytest.mark.p2 +def test_session_update_repeated_and_concurrent_contract(rest_client, create_chat): + chat_id, sessions = _seed_sessions(rest_client, create_chat, "update_repeated") + session_ids = [session["id"] for session in sessions] + + first_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_ids[0]}", json={"name": "valid_name_1"}) + assert first_res.status_code == 200 + assert first_res.json()["code"] == 0, first_res.json() + + second_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_ids[0]}", json={"name": "valid_name_2"}) + assert second_res.status_code == 200 + assert second_res.json()["code"] == 0, second_res.json() + + def _update(index): + return rest_client.patch( + f"/chats/{chat_id}/sessions/{session_ids[index % len(session_ids)]}", + json={"name": f"update session test {index}"}, + ).json() + + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(_update, range(20))) + + assert len(results) == 20, results + assert all(result["code"] == 0 for result in results), results + + @pytest.mark.p2 def test_chat_recommendation_requires_question(rest_client): res = rest_client.post("/chat/recommendation", json={}) @@ -120,7 +611,11 @@ def test_related_questions_compatibility_requires_auth(rest_client_noauth): assert res.status_code == 200 payload = res.json() assert payload["code"] == 102, payload - assert "Authorization is not valid!" in payload["message"], payload + assert payload["message"] in { + "Authorization is not valid!", + 'Authentication error: API key is invalid!"', + "Authentication error: API key is invalid!", + }, payload @pytest.mark.p2 @@ -226,7 +721,10 @@ def test_chat_completion_validation_errors(rest_client, create_chat): assert missing_messages.status_code == 200 missing_messages_payload = missing_messages.json() assert missing_messages_payload["code"] == 101, missing_messages_payload - assert "required argument are missing: messages" in missing_messages_payload["message"], missing_messages_payload + assert missing_messages_payload["message"] in { + "required argument are missing: messages", + "messages: is required", + }, missing_messages_payload missing_chat_for_session = rest_client.post( "/chat/completions",