Feat: add new tests and tescases for restful api suite (#15347)

### What problem does this PR solve?

extend restful api suite

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Other (please describe): test
This commit is contained in:
Idriss Sbaaoui
2026-06-01 11:02:40 +08:00
committed by GitHub
parent 4972af4367
commit da1ed6f0e7
5 changed files with 1692 additions and 11 deletions

View File

@@ -16,6 +16,7 @@
import asyncio
import importlib.util
import logging
import sys
from enum import Enum
from pathlib import Path
@@ -24,6 +25,9 @@ from types import ModuleType, SimpleNamespace
import pytest
LOGGER = logging.getLogger(__name__)
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
@@ -370,12 +374,15 @@ class _DummyManager:
class _DummyFile:
def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1):
def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1, tenant_id="tenant1", parent_id="pf1", source_type="user"):
self.id = file_id
self.type = file_type
self.name = name
self.location = location
self.size = size
self.tenant_id = tenant_id
self.parent_id = parent_id
self.source_type = source_type
class _FalsyFile(_DummyFile):
@@ -630,3 +637,291 @@ def test_convert_branch_matrix_unit(monkeypatch):
res = _run(module.convert())
assert res["code"] == 500
assert "convert boom" in res["message"]
def _load_file_api_service(monkeypatch):
LOGGER.debug("_load_file_api_service: entry")
repo_root = Path(__file__).resolve().parents[3]
api_pkg = ModuleType("api")
api_pkg.__path__ = [str(repo_root / "api")]
monkeypatch.setitem(sys.modules, "api", api_pkg)
LOGGER.debug("_load_file_api_service: mocked api package")
common_pkg = ModuleType("api.common")
common_pkg.__path__ = []
monkeypatch.setitem(sys.modules, "api.common", common_pkg)
permission_mod = ModuleType("api.common.check_team_permission")
permission_mod.check_file_team_permission = lambda *_args, **_kwargs: True
monkeypatch.setitem(sys.modules, "api.common.check_team_permission", permission_mod)
common_pkg.check_team_permission = permission_mod
db_pkg = ModuleType("api.db")
db_pkg.__path__ = []
class _ServiceFileType(Enum):
FOLDER = "folder"
VIRTUAL = "virtual"
DOC = "doc"
VISUAL = "visual"
db_pkg.FileType = _ServiceFileType
monkeypatch.setitem(sys.modules, "api.db", db_pkg)
api_pkg.db = db_pkg
services_pkg = ModuleType("api.db.services")
services_pkg.__path__ = []
services_pkg.duplicate_name = lambda _query, **kwargs: kwargs.get("name", "")
monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
document_service_mod = ModuleType("api.db.services.document_service")
document_service_mod.DocumentService = SimpleNamespace(
get_doc_count=lambda _uid: 0,
get_by_id=lambda doc_id: (True, SimpleNamespace(id=doc_id)),
get_tenant_id=lambda _doc_id: "tenant1",
remove_document=lambda *_args, **_kwargs: True,
update_by_id=lambda *_args, **_kwargs: True,
)
monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
services_pkg.document_service = document_service_mod
file2doc_mod = ModuleType("api.db.services.file2document_service")
file2doc_mod.File2DocumentService = SimpleNamespace(
get_by_file_id=lambda _file_id: [],
delete_by_file_id=lambda _file_id: None,
)
monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2doc_mod)
services_pkg.file2document_service = file2doc_mod
file_service_mod = ModuleType("api.db.services.file_service")
file_service_mod.FileService = SimpleNamespace(
get_root_folder=lambda _tenant_id: {"id": "root"},
get_by_id=lambda file_id: (True, _DummyFile(file_id, _ServiceFileType.DOC.value)),
get_id_list_by_id=lambda _pf_id, _names, _idx, ids: ids,
create_folder=lambda _file, parent_id, _names, _len_id: SimpleNamespace(id=parent_id, name=str(parent_id)),
query=lambda **_kwargs: [],
insert=lambda data: SimpleNamespace(to_json=lambda: data, **data),
is_parent_folder_exist=lambda _pf_id: True,
get_by_pf_id=lambda *_args, **_kwargs: ([], 0),
get_parent_folder=lambda _file_id: SimpleNamespace(to_json=lambda: {"id": "root"}),
get_all_parent_folders=lambda _file_id: [],
list_all_files_by_parent_id=lambda _parent_id: [],
delete=lambda _file: True,
delete_by_id=lambda _file_id: True,
update_by_id=lambda *_args, **_kwargs: True,
get_by_ids=lambda file_ids: [_DummyFile(file_id, _ServiceFileType.DOC.value) for file_id in file_ids],
)
monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)
services_pkg.file_service = file_service_mod
LOGGER.debug("_load_file_api_service: mocked api.db.services.file_service")
file_utils_mod = ModuleType("api.utils.file_utils")
file_utils_mod.filename_type = lambda _filename: _ServiceFileType.DOC.value
monkeypatch.setitem(sys.modules, "api.utils.file_utils", file_utils_mod)
common_root_mod = ModuleType("common")
common_root_mod.__path__ = [str(repo_root / "common")]
common_root_mod.settings = SimpleNamespace(
STORAGE_IMPL=SimpleNamespace(
obj_exist=lambda *_args, **_kwargs: False,
put=lambda *_args, **_kwargs: None,
rm=lambda *_args, **_kwargs: None,
move=lambda *_args, **_kwargs: None,
)
)
monkeypatch.setitem(sys.modules, "common", common_root_mod)
constants_mod = ModuleType("common.constants")
class _FileSource:
KNOWLEDGEBASE = "knowledgebase"
constants_mod.FileSource = _FileSource
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
misc_utils_mod = ModuleType("common.misc_utils")
misc_utils_mod.get_uuid = lambda: "uuid-1"
async def thread_pool_exec(func, *args, **kwargs):
return func(*args, **kwargs)
misc_utils_mod.thread_pool_exec = thread_pool_exec
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
module_path = repo_root / "api" / "apps" / "services" / "file_api_service.py"
spec = importlib.util.spec_from_file_location("api.apps.services.file_api_service", module_path)
module = importlib.util.module_from_spec(spec)
monkeypatch.setitem(sys.modules, "api.apps.services.file_api_service", module)
try:
spec.loader.exec_module(module)
except Exception:
LOGGER.exception(
"_load_file_api_service: spec.loader.exec_module(module) failed"
)
raise
LOGGER.debug("_load_file_api_service: spec.loader.exec_module(module) completed")
return module
@pytest.mark.p2
def test_upload_file_requires_existing_folder(monkeypatch):
module = _load_file_api_service(monkeypatch)
monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (False, None))
ok, message = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt")]))
assert ok is False
assert message == "Can't find this folder!"
@pytest.mark.p2
def test_upload_file_respects_user_limit(monkeypatch):
module = _load_file_api_service(monkeypatch)
monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (True, SimpleNamespace(id="pf1", name="pf1")))
monkeypatch.setattr(module.DocumentService, "get_doc_count", lambda _uid: 1)
monkeypatch.setenv("MAX_FILE_NUM_PER_USER", "1")
ok, message = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt")]))
assert ok is False
assert message == "Exceed the maximum file number of a free user!"
monkeypatch.delenv("MAX_FILE_NUM_PER_USER", raising=False)
@pytest.mark.p2
def test_upload_file_success_uses_new_service_layer(monkeypatch):
module = _load_file_api_service(monkeypatch)
storage_puts = []
monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (True, SimpleNamespace(id="pf1", name="pf1")))
monkeypatch.setattr(module.FileService, "get_id_list_by_id", lambda *_args, **_kwargs: ["pf1"])
monkeypatch.setattr(
module.FileService,
"create_folder",
lambda _file, parent_id, _names, _len_id, *_args: SimpleNamespace(id=parent_id),
)
monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(
obj_exist=lambda *_args, **_kwargs: False,
put=lambda bucket, location, blob: storage_puts.append((bucket, location, blob)),
rm=lambda *_args, **_kwargs: None,
move=lambda *_args, **_kwargs: None,
))
ok, data = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt", b"hello")]))
assert ok is True
assert data[0]["name"] == "a.txt"
assert storage_puts == [("pf1", "a.txt", b"hello")]
@pytest.mark.p2
def test_create_folder_rejects_duplicate_name(monkeypatch):
module = _load_file_api_service(monkeypatch)
monkeypatch.setattr(module.FileService, "query", lambda **_kwargs: [SimpleNamespace(id="existing")])
ok, message = _run(module.create_folder("tenant1", "dup", "pf1", module.FileType.FOLDER.value))
assert ok is False
assert message == "Duplicated folder name in the same folder."
@pytest.mark.p2
def test_delete_files_checks_team_permission(monkeypatch):
module = _load_file_api_service(monkeypatch)
monkeypatch.setattr(
module.FileService,
"get_by_id",
lambda _file_id: (True, _DummyFile("file1", module.FileType.DOC.value)),
)
monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: False)
ok, message = _run(module.delete_files("tenant1", ["file1"]))
assert ok is False
assert message == {"success_count": 0, "errors": ["No authorization for file file1"]}
@pytest.mark.p2
def test_move_files_rejects_extension_change_in_new_name(monkeypatch):
module = _load_file_api_service(monkeypatch)
monkeypatch.setattr(
module.FileService,
"get_by_ids",
lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, name="a.txt")],
)
ok, message = _run(module.move_files("tenant1", ["file1"], new_name="a.pdf"))
assert ok is False
assert message == "The extension of file can't be changed"
@pytest.mark.p2
def test_move_files_handles_dest_and_storage_move(monkeypatch):
module = _load_file_api_service(monkeypatch)
moved = []
updated = []
monkeypatch.setattr(
module.FileService,
"get_by_id",
lambda file_id: (False, None) if file_id == "missing" else (True, _DummyFile(file_id, module.FileType.FOLDER.value, name="dest")),
)
monkeypatch.setattr(
module.FileService,
"get_by_ids",
lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, parent_id="src", location="old", name="a.txt")],
)
monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(
obj_exist=lambda *_args, **_kwargs: False,
put=lambda *_args, **_kwargs: None,
rm=lambda *_args, **_kwargs: None,
move=lambda old_bucket, old_loc, new_bucket, new_loc: moved.append((old_bucket, old_loc, new_bucket, new_loc)),
))
monkeypatch.setattr(module.FileService, "update_by_id", lambda file_id, data: updated.append((file_id, data)) or True)
ok, message = _run(module.move_files("tenant1", ["file1"], "missing"))
assert ok is False
assert message == "Parent folder not found!"
ok, data = _run(module.move_files("tenant1", ["file1"], "dest"))
assert ok is True
assert data is True
assert moved == [("src", "old", "dest", "a.txt")]
assert updated == [("file1", {"parent_id": "dest", "location": "a.txt"})]
@pytest.mark.p2
def test_move_files_renames_in_place_without_storage_move(monkeypatch):
module = _load_file_api_service(monkeypatch)
db_updates = []
doc_updates = []
monkeypatch.setattr(
module.FileService,
"get_by_ids",
lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, parent_id="pf1", name="a.txt")],
)
monkeypatch.setattr(module.FileService, "update_by_id", lambda file_id, data: db_updates.append((file_id, data)) or True)
monkeypatch.setattr(
module.File2DocumentService,
"get_by_file_id",
lambda _file_id: [SimpleNamespace(document_id="doc1")],
)
monkeypatch.setattr(module.DocumentService, "update_by_id", lambda doc_id, data: doc_updates.append((doc_id, data)) or True)
ok, data = _run(module.move_files("tenant1", ["file1"], new_name="b.txt"))
assert ok is True
assert data is True
assert db_updates == [("file1", {"name": "b.txt"})]
assert doc_updates == [("doc1", {"name": "b.txt"})]
@pytest.mark.p2
def test_get_file_content_checks_permission(monkeypatch):
module = _load_file_api_service(monkeypatch)
monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: False)
ok, message = module.get_file_content("tenant1", "file1")
assert ok is False
assert message == "No authorization."
monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: True)
ok, file = module.get_file_content("tenant1", "file1")
assert ok is True
assert file.id == "file1"

View File

@@ -0,0 +1,742 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import json
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
import requests
from configs import HOST_ADDRESS, VERSION
from test.testcases.conftest import login
from test.testcases.libs.auth import RAGFlowWebApiAuth
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _ExprField:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return (self.name, other)
class _StrEnum(str):
@property
def value(self):
return str(self)
class _DummyTenantLLMModel:
tenant_id = _ExprField("tenant_id")
llm_factory = _ExprField("llm_factory")
llm_name = _ExprField("llm_name")
def __init__(self, id=None, **kwargs):
self.id = id
self.api_key = None
self.status = None
for key, value in kwargs.items():
setattr(self, key, value)
class _TenantLLMRow:
def __init__(
self,
*,
id,
llm_name,
llm_factory,
model_type,
api_key="key",
status="1",
used_tokens=0,
api_base="",
max_tokens=8192,
):
self.id = id
self.llm_name = llm_name
self.llm_factory = llm_factory
self.model_type = model_type
self.api_key = api_key
self.status = status
self.used_tokens = used_tokens
self.api_base = api_base
self.max_tokens = max_tokens
def to_dict(self):
return {
"id": self.id,
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"model_type": self.model_type,
"status": self.status,
"used_tokens": self.used_tokens,
"api_base": self.api_base,
"max_tokens": self.max_tokens,
}
class _LLMRow:
def __init__(self, *, llm_name, fid, model_type, status="1", max_tokens=2048):
self.llm_name = llm_name
self.fid = fid
self.model_type = model_type
self.status = status
self.max_tokens = max_tokens
def to_dict(self):
return {
"llm_name": self.llm_name,
"fid": self.fid,
"model_type": self.model_type,
"status": self.status,
"max_tokens": self.max_tokens,
}
def _run(coro):
return asyncio.run(coro)
def _set_request_json(monkeypatch, module, payload):
async def _get_request_json():
return dict(payload)
monkeypatch.setattr(module, "get_request_json", _get_request_json)
def _load_llm_app(monkeypatch):
repo_root = Path(__file__).resolve().parents[3]
quart_mod = ModuleType("quart")
quart_mod.request = SimpleNamespace(args={})
monkeypatch.setitem(sys.modules, "quart", quart_mod)
apps_mod = ModuleType("api.apps")
apps_mod.__path__ = [str(repo_root / "api" / "apps")]
apps_mod.login_required = lambda fn: fn
apps_mod.current_user = SimpleNamespace(id="tenant-1")
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
tenant_llm_mod = ModuleType("api.db.services.tenant_llm_service")
class _StubLLMFactoriesService:
@staticmethod
def query(**_kwargs):
return []
class _StubTenantLLMService:
@staticmethod
def ensure_mineru_from_env(_tenant_id):
return None
@staticmethod
def ensure_opendataloader_from_env(_tenant_id):
return None
@staticmethod
def query(**_kwargs):
return []
@staticmethod
def get_my_llms(_tenant_id):
return []
@staticmethod
def save(**_kwargs):
return True
@staticmethod
def filter_delete(_filters):
return True
@staticmethod
def filter_update(_filters, _payload):
return True
tenant_llm_mod.LLMFactoriesService = _StubLLMFactoriesService
tenant_llm_mod.TenantLLMService = _StubTenantLLMService
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_mod)
llm_service_mod = ModuleType("api.db.services.llm_service")
class _StubLLMService:
@staticmethod
def get_all():
return []
@staticmethod
def query(**_kwargs):
return []
llm_service_mod.LLMService = _StubLLMService
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
api_utils_mod = ModuleType("api.utils.api_utils")
api_utils_mod.get_allowed_llm_factories = lambda: []
api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
"code": code,
"message": message,
"data": data,
}
api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
"code": code,
"message": message,
"data": data,
}
async def _get_request_json():
return {}
api_utils_mod.get_request_json = _get_request_json
api_utils_mod.server_error_response = lambda exc: {"code": 500, "message": str(exc), "data": None}
api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
constants_mod = ModuleType("common.constants")
constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value="1"), INVALID=SimpleNamespace(value="0"))
constants_mod.LLMType = SimpleNamespace(
CHAT=_StrEnum("chat"),
EMBEDDING=_StrEnum("embedding"),
SPEECH2TEXT=_StrEnum("speech2text"),
IMAGE2TEXT=_StrEnum("image2text"),
RERANK=_StrEnum("rerank"),
TTS=_StrEnum("tts"),
OCR=_StrEnum("ocr"),
)
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
db_models_mod = ModuleType("api.db.db_models")
db_models_mod.TenantLLM = _DummyTenantLLMModel
monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
base64_mod = ModuleType("rag.utils.base64_image")
base64_mod.test_image = b"image-bytes"
monkeypatch.setitem(sys.modules, "rag.utils.base64_image", base64_mod)
rag_llm_mod = ModuleType("rag.llm")
rag_llm_mod.EmbeddingModel = {}
rag_llm_mod.ChatModel = {}
rag_llm_mod.RerankModel = {}
rag_llm_mod.CvModel = {}
rag_llm_mod.TTSModel = {}
rag_llm_mod.OcrModel = {}
rag_llm_mod.Seq2txtModel = {}
monkeypatch.setitem(sys.modules, "rag.llm", rag_llm_mod)
module_path = repo_root / "api" / "apps" / "llm_app.py"
spec = importlib.util.spec_from_file_location("test_llm_routes_unit_module", module_path)
module = importlib.util.module_from_spec(spec)
module.manager = _DummyManager()
spec.loader.exec_module(module)
return module
def _rag_llm_module():
return sys.modules["rag.llm"]
@pytest.mark.p2
def test_llm_list_app_grouping_availability_and_merge(monkeypatch):
module = _load_llm_app(monkeypatch)
ensure_calls = []
monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda tenant_id: ensure_calls.append(tenant_id))
tenant_rows = [
_TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
_TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
_TenantLLMRow(id=3, llm_name="gpt-5.5", llm_factory="OpenAI", model_type="chat", api_key="k3", status="1"),
_TenantLLMRow(id=4, llm_name="gpt-5.4", llm_factory="OpenAI", model_type="chat", api_key="k4", status="1"),
]
monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: tenant_rows)
all_llms = [
_LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"),
_LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"),
_LLMRow(llm_name="gpt-5.5", fid="OpenAI", model_type="chat", status="1"),
_LLMRow(llm_name="gpt-5.4", fid="OpenAI", model_type="chat", status="1"),
_LLMRow(llm_name="not-in-status", fid="Other", model_type="chat", status="1"),
]
monkeypatch.setattr(module.LLMService, "get_all", lambda: all_llms)
monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
monkeypatch.setenv("COMPOSE_PROFILES", "tei-cpu")
monkeypatch.setenv("TEI_MODEL", "tei-embed")
res = _run(module.list_app())
assert res["code"] == 0, res["message"]
assert ensure_calls == ["tenant-1"]
data = res["data"]
assert {"Builtin", "FastEmbed", "CustomFactory", "OpenAI"}.issubset(set(data.keys()))
assert data["Builtin"][0]["llm_name"] == "tei-embed"
assert data["Builtin"][0]["available"] is True
assert data["FastEmbed"][0]["llm_name"] == "fast-emb"
assert data["FastEmbed"][0]["available"] is True
assert data["CustomFactory"][0]["llm_name"] == "tenant-only"
assert data["CustomFactory"][0]["available"] is True
openai_names = {item["llm_name"] for item in data["OpenAI"]}
assert {"gpt-5.5", "gpt-5.4"}.issubset(openai_names)
@pytest.mark.p2
def test_llm_list_app_model_type_filter(monkeypatch):
module = _load_llm_app(monkeypatch)
monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: None)
monkeypatch.setattr(
module.TenantLLMService,
"query",
lambda **_kwargs: [
_TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
_TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
],
)
monkeypatch.setattr(
module.LLMService,
"get_all",
lambda: [
_LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"),
_LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"),
],
)
monkeypatch.setattr(module, "request", SimpleNamespace(args={"model_type": "chat"}))
res = _run(module.list_app())
assert res["code"] == 0, res["message"]
assert list(res["data"].keys()) == ["CustomFactory"]
assert res["data"]["CustomFactory"][0]["model_type"] == "chat"
@pytest.mark.p2
def test_llm_list_app_exception_path(monkeypatch):
module = _load_llm_app(monkeypatch)
monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: None)
monkeypatch.setattr(
module.TenantLLMService,
"query",
lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("query boom")),
)
res = _run(module.list_app())
assert res["code"] == 500
assert "query boom" in res["message"]
@pytest.mark.p2
def test_llm_factories_route_success_and_exception_unit(monkeypatch):
module = _load_llm_app(monkeypatch)
def _factory(name):
return SimpleNamespace(name=name, to_dict=lambda n=name: {"name": n})
monkeypatch.setattr(
module,
"get_allowed_llm_factories",
lambda: [
_factory("OpenAI"),
_factory("CustomFactory"),
_factory("FastEmbed"),
_factory("Builtin"),
],
)
monkeypatch.setattr(
module.LLMService,
"get_all",
lambda: [
_LLMRow(llm_name="m1", fid="OpenAI", model_type="chat", status="1"),
_LLMRow(llm_name="m2", fid="OpenAI", model_type="embedding", status="1"),
_LLMRow(llm_name="m3", fid="OpenAI", model_type="rerank", status="0"),
],
)
res = module.factories()
assert res["code"] == 0
names = [item["name"] for item in res["data"]]
assert "FastEmbed" not in names
assert "Builtin" not in names
assert {"OpenAI", "CustomFactory"} == set(names)
openai = next(item for item in res["data"] if item["name"] == "OpenAI")
assert {"chat", "embedding"} == set(openai["model_types"])
monkeypatch.setattr(module, "get_allowed_llm_factories", lambda: (_ for _ in ()).throw(RuntimeError("factories boom")))
res = module.factories()
assert res["code"] == 500
assert "factories boom" in res["message"]
@pytest.mark.p2
def test_add_llm_factory_specific_key_assembly_unit(monkeypatch):
module = _load_llm_app(monkeypatch)
async def _wait_for(coro, *_args, **_kwargs):
return await coro
async def _to_thread(fn, *args, **kwargs):
return fn(*args, **kwargs)
monkeypatch.setattr(module.asyncio, "wait_for", _wait_for)
monkeypatch.setattr(module.asyncio, "to_thread", _to_thread)
allowed = [
"VolcEngine",
"Tencent Cloud",
"Bedrock",
"LocalAI",
"HuggingFace",
"OpenAI-API-Compatible",
"VLLM",
"XunFei Spark",
"BaiduYiyan",
"Fish Audio",
"Google Cloud",
"Azure-OpenAI",
"OpenRouter",
"MinerU",
"PaddleOCR",
]
monkeypatch.setattr(module, "get_allowed_llm_factories", lambda: [SimpleNamespace(name=name) for name in allowed])
captured = {"filter_payloads": []}
class _ChatOK:
def __init__(self, *_args, **_kwargs):
pass
async def async_chat(self, *_args, **_kwargs):
return "ok", 1
async def async_chat_streamly(self, *_args, **_kwargs):
yield "ok"
yield 1
class _TTSOK:
def __init__(self, *_args, **_kwargs):
pass
def tts(self, _text):
yield b"ok"
monkeypatch.setattr(_rag_llm_module(), "ChatModel", {name: _ChatOK for name in allowed})
monkeypatch.setattr(_rag_llm_module(), "TTSModel", {"XunFei Spark": _TTSOK})
monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, payload: captured["filter_payloads"].append(dict(payload)) or True)
reject_req = {"llm_factory": "NotAllowed", "llm_name": "x", "model_type": module.LLMType.CHAT.value}
_set_request_json(monkeypatch, module, reject_req)
res = _run(module.add_llm())
assert res["code"] == 400
assert "is not allowed" in res["message"]
def _run_case(factory, *, model_type=module.LLMType.CHAT.value, extra=None):
req = {"llm_factory": factory, "llm_name": "model", "model_type": model_type, "api_key": "k", "api_base": "http://api"}
if extra:
req.update(extra)
_set_request_json(monkeypatch, module, req)
out = _run(module.add_llm())
assert out["code"] == 0
assert out["data"] is True
return captured["filter_payloads"][-1]
volc = _run_case("VolcEngine", extra={"ark_api_key": "ak", "endpoint_id": "eid"})
assert json.loads(volc["api_key"]) == {"ark_api_key": "ak", "endpoint_id": "eid"}
bedrock = _run_case(
"Bedrock",
extra={"auth_mode": "iam", "bedrock_ak": "ak", "bedrock_sk": "sk", "bedrock_region": "r", "aws_role_arn": "arn"},
)
assert json.loads(bedrock["api_key"]) == {
"auth_mode": "iam",
"bedrock_ak": "ak",
"bedrock_sk": "sk",
"bedrock_region": "r",
"aws_role_arn": "arn",
}
assert _run_case("LocalAI")["llm_name"] == "model___LocalAI"
assert _run_case("HuggingFace")["llm_name"] == "model___HuggingFace"
assert _run_case("OpenAI-API-Compatible")["llm_name"] == "model___OpenAI-API"
assert _run_case("VLLM")["llm_name"] == "model___VLLM"
spark_chat = _run_case("XunFei Spark", extra={"spark_api_password": "spark-pass"})
assert spark_chat["api_key"] == "spark-pass"
spark_tts = _run_case(
"XunFei Spark",
model_type=module.LLMType.TTS.value,
extra={"spark_app_id": "app", "spark_api_secret": "secret", "spark_api_key": "key"},
)
assert json.loads(spark_tts["api_key"]) == {
"spark_app_id": "app",
"spark_api_secret": "secret",
"spark_api_key": "key",
}
assert json.loads(_run_case("BaiduYiyan", extra={"yiyan_ak": "ak", "yiyan_sk": "sk"})["api_key"]) == {"yiyan_ak": "ak", "yiyan_sk": "sk"}
assert json.loads(_run_case("Fish Audio", extra={"fish_audio_ak": "ak", "fish_audio_refid": "rid"})["api_key"]) == {"fish_audio_ak": "ak", "fish_audio_refid": "rid"}
assert json.loads(
_run_case("Google Cloud", extra={"google_project_id": "pid", "google_region": "us", "google_service_account_key": "sak"})["api_key"]
) == {
"google_project_id": "pid",
"google_region": "us",
"google_service_account_key": "sak",
}
assert json.loads(_run_case("Azure-OpenAI", extra={"api_key": "real-key", "api_version": "2024-01-01"})["api_key"]) == {
"api_key": "real-key",
"api_version": "2024-01-01",
}
assert json.loads(_run_case("OpenRouter", extra={"api_key": "or-key", "provider_order": "a,b"})["api_key"]) == {
"api_key": "or-key",
"provider_order": "a,b",
}
assert json.loads(_run_case("MinerU", extra={"api_key": "m-key", "provider_order": "p1"})["api_key"]) == {
"api_key": "m-key",
"provider_order": "p1",
}
assert json.loads(_run_case("PaddleOCR", extra={"api_key": "p-key", "provider_order": "p2"})["api_key"]) == {
"api_key": "p-key",
"provider_order": "p2",
}
tencent_req = {
"llm_factory": "Tencent Cloud",
"llm_name": "model",
"model_type": module.LLMType.CHAT.value,
"tencent_cloud_sid": "sid",
"tencent_cloud_sk": "sk",
}
async def _tencent_request_json():
return tencent_req
monkeypatch.setattr(module, "get_request_json", _tencent_request_json)
delegated = {}
async def _fake_set_api_key():
delegated["api_key"] = tencent_req.get("api_key")
return {"code": 0, "data": "delegated"}
monkeypatch.setattr(module, "set_api_key", _fake_set_api_key)
res = _run(module.add_llm())
assert res["code"] == 0
assert res["data"] == "delegated"
assert json.loads(delegated["api_key"]) == {"tencent_cloud_sid": "sid", "tencent_cloud_sk": "sk"}
@pytest.mark.p2
def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch):
module = _load_llm_app(monkeypatch)
async def _wait_for(coro, *_args, **_kwargs):
return await coro
async def _to_thread(fn, *args, **kwargs):
return fn(*args, **kwargs)
monkeypatch.setattr(module.asyncio, "wait_for", _wait_for)
monkeypatch.setattr(module.asyncio, "to_thread", _to_thread)
monkeypatch.setattr(
module,
"get_allowed_llm_factories",
lambda: [
SimpleNamespace(name=name)
for name in [
"FEmbFail",
"FEmbPass",
"FChatFail",
"FChatPass",
"FRKey",
"FRFail",
"FImgFail",
"FTTSFail",
"FOcrFail",
"FSttFail",
"FUnknown",
]
],
)
class _EmbeddingFail:
def __init__(self, *_args, **_kwargs):
pass
def encode(self, _texts):
return [[]], 1
class _EmbeddingPass:
def __init__(self, *_args, **_kwargs):
pass
def encode(self, _texts):
return [[0.5]], 1
class _ChatFail:
def __init__(self, *_args, **_kwargs):
pass
async def async_chat(self, *_args, **_kwargs):
return "**ERROR**: chat failed", 0
async def async_chat_streamly(self, *_args, **_kwargs):
yield "**ERROR**: chat failed"
yield 0
class _ChatPass:
def __init__(self, *_args, **_kwargs):
pass
async def async_chat(self, *_args, **_kwargs):
return "ok", 1
async def async_chat_streamly(self, *_args, **_kwargs):
yield "ok"
yield 1
class _RerankFail:
def __init__(self, *_args, **_kwargs):
pass
def similarity(self, *_args, **_kwargs):
return [], 1
class _CvFail:
def __init__(self, *_args, **_kwargs):
pass
def describe(self, _image_data):
return "**ERROR**: image failed", 0
class _TTSFail:
def __init__(self, *_args, **_kwargs):
pass
def tts(self, _text):
raise RuntimeError("tts failed")
class _OcrFail:
def __init__(self, *_args, **_kwargs):
pass
def __call__(self, _img):
return None
class _SttFail:
def __init__(self, *_args, **_kwargs):
pass
def transcribe(self, _audio):
return "", 0
rag_llm_mod = _rag_llm_module()
monkeypatch.setattr(rag_llm_mod, "EmbeddingModel", {"FEmbFail": _EmbeddingFail, "FEmbPass": _EmbeddingPass})
monkeypatch.setattr(rag_llm_mod, "ChatModel", {"FChatFail": _ChatFail, "FChatPass": _ChatPass})
monkeypatch.setattr(rag_llm_mod, "RerankModel", {"FRFail": _RerankFail, "FRKey": _RerankFail})
monkeypatch.setattr(rag_llm_mod, "CvModel", {"FImgFail": _CvFail})
monkeypatch.setattr(rag_llm_mod, "TTSModel", {"FTTSFail": _TTSFail})
monkeypatch.setattr(rag_llm_mod, "OcrModel", {"FOcrFail": _OcrFail})
monkeypatch.setattr(rag_llm_mod, "Seq2txtModel", {"FSttFail": _SttFail})
saves = []
monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, _payload: False)
monkeypatch.setattr(module.TenantLLMService, "save", lambda **kwargs: saves.append(kwargs) or True)
monkeypatch.setattr(
module.LLMService,
"query",
lambda **kwargs: [] if kwargs.get("llm_factory") == "FUnknown" else [
_LLMRow(llm_name="m", fid=kwargs.get("llm_factory"), model_type=kwargs.get("model_type", module.LLMType.CHAT.value), max_tokens=4096)
],
)
_set_request_json(monkeypatch, module, {"llm_factory": "FUnknown", "llm_name": "m", "model_type": "unknown"})
with pytest.raises(RuntimeError, match="Unknown model type: unknown"):
_run(module.add_llm())
cases = [
("FEmbFail", module.LLMType.EMBEDDING.value, 400, None, "embedding model"),
("FEmbPass", module.LLMType.EMBEDDING.value, 0, True, ""),
("FChatFail", module.LLMType.CHAT.value, 400, None, "No valid response received"),
("FChatPass", module.LLMType.CHAT.value, 0, True, ""),
("FRFail", module.LLMType.RERANK.value, 400, None, "Not known"),
("FImgFail", module.LLMType.IMAGE2TEXT.value, 400, None, "image failed"),
("FTTSFail", module.LLMType.TTS.value, 400, None, "tts"),
("FOcrFail", module.LLMType.OCR.value, 400, None, "ocr"),
("FSttFail", module.LLMType.SPEECH2TEXT.value, 0, True, ""),
]
for factory, model_type, expected_code, expected_data, expected_fragment in cases:
_set_request_json(monkeypatch, module, {"llm_factory": factory, "llm_name": "m", "model_type": model_type, "api_key": "key"})
res = _run(module.add_llm())
assert res["code"] == expected_code
if expected_data is not None:
assert res["data"] is expected_data
if expected_fragment:
assert expected_fragment.lower() in res["message"].lower()
assert any(item["llm_factory"] == "FEmbPass" for item in saves)
@pytest.mark.p2
def test_llm_factories_live_auth_contract():
llm_url = f"{HOST_ADDRESS}/{VERSION}/llm/factories"
invalid_auth_cases = [
(None, 401, "<Unauthorized '401: Unauthorized'>"),
(RAGFlowWebApiAuth("invalid-token"), 401, "<Unauthorized '401: Unauthorized'>"),
]
for auth_obj, expected_code, expected_message in invalid_auth_cases:
res = requests.get(llm_url, auth=auth_obj, timeout=30)
assert res.status_code == 401
payload = res.json()
assert payload["code"] == expected_code, payload
assert payload["message"] == expected_message, payload
ok_res = requests.get(llm_url, auth=RAGFlowWebApiAuth(login()), timeout=30)
assert ok_res.status_code == 200
ok_payload = ok_res.json()
assert ok_payload["code"] == 0, ok_payload
assert isinstance(ok_payload["data"], list), ok_payload
@pytest.mark.p2
def test_llm_list_live_auth_contract():
llm_url = f"{HOST_ADDRESS}/{VERSION}/llm/list"
invalid_auth_cases = [
(None, 401, "<Unauthorized '401: Unauthorized'>"),
(RAGFlowWebApiAuth("invalid-token"), 401, "<Unauthorized '401: Unauthorized'>"),
]
for auth_obj, expected_code, expected_message in invalid_auth_cases:
res = requests.get(llm_url, auth=auth_obj, timeout=30)
assert res.status_code == 401
payload = res.json()
assert payload["code"] == expected_code, payload
assert payload["message"] == expected_message, payload
ok_res = requests.get(llm_url, auth=RAGFlowWebApiAuth(login()), timeout=30)
assert ok_res.status_code == 200
ok_payload = ok_res.json()
assert ok_payload["code"] == 0, ok_payload
assert isinstance(ok_payload["data"], dict), ok_payload

View File

@@ -0,0 +1,151 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import inspect
import sys
from copy import deepcopy
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _DummyArgs(dict):
def getlist(self, key):
value = self.get(key)
if value is None:
return []
if isinstance(value, list):
return value
return [value]
class _DummyMemoryApiService:
async def add_message(self, *_args, **_kwargs):
return True, "ok"
async def get_messages(self, *_args, **_kwargs):
return []
def _run(coro):
return asyncio.run(coro)
def _load_memory_routes_module(monkeypatch):
repo_root = Path(__file__).resolve().parents[3]
common_pkg = ModuleType("common")
common_pkg.__path__ = [str(repo_root / "common")]
monkeypatch.setitem(sys.modules, "common", common_pkg)
apps_mod = ModuleType("api.apps")
apps_mod.__path__ = [str(repo_root / "api" / "apps")]
apps_mod.current_user = SimpleNamespace(id="user-1")
apps_mod.login_required = lambda func: func
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
services_mod = ModuleType("api.apps.services")
services_mod.memory_api_service = _DummyMemoryApiService()
monkeypatch.setitem(sys.modules, "api.apps.services", services_mod)
module_name = "test_message_routes_unit_module"
module_path = repo_root / "api" / "apps" / "restful_apis" / "memory_api.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
module.manager = _DummyManager()
monkeypatch.setitem(sys.modules, module_name, module)
spec.loader.exec_module(module)
return module
def _set_request_json(monkeypatch, module, payload):
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload)))
@pytest.mark.p2
def test_add_message_partial_failure_branch(monkeypatch):
module = _load_memory_routes_module(monkeypatch)
_set_request_json(
monkeypatch,
module,
{
"memory_id": ["memory-1"],
"agent_id": "agent-1",
"session_id": "session-1",
"user_input": "hello",
"agent_response": "world",
},
)
async def _add_message(_memory_ids, _message_dict):
return False, "cannot enqueue"
monkeypatch.setattr(module.memory_api_service, "add_message", _add_message)
res = _run(inspect.unwrap(module.add_message)())
assert res["code"] == module.RetCode.SERVER_ERROR, res
assert "Some messages failed to add" in res["message"], res
@pytest.mark.p2
def test_get_messages_csv_and_missing_memory_ids(monkeypatch):
module = _load_memory_routes_module(monkeypatch)
monkeypatch.setattr(module, "request", SimpleNamespace(args=_DummyArgs({})))
res = _run(inspect.unwrap(module.get_messages)())
assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
assert "memory_ids is required." in res["message"], res
monkeypatch.setattr(
module,
"request",
SimpleNamespace(args=_DummyArgs({"memory_id": "m1,m2", "agent_id": "a1", "session_id": "s1", "limit": "5"})),
)
async def _get_messages(memory_ids, agent_id, session_id, limit):
assert memory_ids == ["m1", "m2"]
assert agent_id == "a1"
assert session_id == "s1"
assert limit == 5
return [{"message_id": 1}]
monkeypatch.setattr(module.memory_api_service, "get_messages", _get_messages)
res = _run(inspect.unwrap(module.get_messages)())
assert res["code"] == module.RetCode.SUCCESS, res
assert isinstance(res["data"], list), res

View File

@@ -16,8 +16,7 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
import requests
from test.testcases.configs import HOST_ADDRESS, INVALID_API_TOKEN, VERSION
from test.testcases.configs import INVALID_API_TOKEN
from test.testcases.restful_api.helpers.client import RestClient
from test.testcases.utils import wait_for
@@ -378,12 +377,8 @@ def test_deleted_chunks_batch_not_in_retrieval_contract(rest_client, create_docu
@pytest.mark.p2
def test_related_questions_contract(auth, rest_client, rest_client_noauth):
tokens_res = requests.get(
f"{HOST_ADDRESS}/api/{VERSION}/system/tokens",
headers={"Authorization": auth},
timeout=30,
)
def test_related_questions_contract(rest_client, rest_client_noauth):
tokens_res = rest_client.get("/system/tokens")
assert tokens_res.status_code == 200, tokens_res.text
tokens_payload = tokens_res.json()
assert tokens_payload["code"] == 0, tokens_payload

View File

@@ -15,14 +15,36 @@
#
import json
from concurrent.futures import ThreadPoolExecutor
import pytest
from test.testcases.configs import INVALID_API_TOKEN, INVALID_ID_32, SESSION_WITH_CHAT_NAME_LIMIT
from test.testcases.restful_api.helpers.client import RestClient
from test.testcases.utils import is_sorted
def _sse_events(response_text: str) -> list[str]:
return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
def _session_names(payload):
return [session["name"] for session in payload["data"]]
def _seed_sessions(rest_client, create_chat, prefix, count=5):
chat_id = create_chat(f"{prefix}_chat")
sessions = []
for index in range(count):
name = f"{prefix}_{index}"
res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": name})
assert res.status_code == 200, (prefix, index, res.text)
payload = res.json()
assert payload["code"] == 0, (prefix, index, payload)
sessions.append(payload["data"])
return chat_id, sessions
@pytest.mark.p1
def test_session_crud_cycle(rest_client, create_chat):
chat_id = create_chat("restful_session_crud_chat")
@@ -100,6 +122,475 @@ def test_session_update_blocks_messages_and_reference(rest_client, create_chat):
assert "`reference` cannot be changed." in ref_payload["message"], ref_payload
@pytest.mark.p1
def test_session_create_requires_auth_and_invalid_chat_contract():
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
res = client.post("/chats/chat_id/sessions", json={"name": "x"})
assert res.status_code == 401, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 401, (scenario_name, payload)
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
@pytest.mark.p2
def test_session_create_validation_and_deleted_chat_contract(rest_client, create_chat):
chat_id = create_chat("restful_session_create_contract")
empty_path_res = rest_client.post("/chats//sessions", json={"name": "valid_name"})
assert empty_path_res.status_code == 200
empty_path_payload = empty_path_res.json()
assert empty_path_payload["code"] == 100, empty_path_payload
assert empty_path_payload["message"] == "<MethodNotAllowed '405: Method Not Allowed'>", empty_path_payload
invalid_chat_res = rest_client.post("/chats/invalid_chat_assistant_id/sessions", json={"name": "valid_name"})
assert invalid_chat_res.status_code == 200
invalid_chat_payload = invalid_chat_res.json()
assert invalid_chat_payload["code"] == 109, invalid_chat_payload
assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload
for scenario_name, payload in (
("valid", {"name": "valid_name"}),
("empty", {"name": ""}),
("space", {"name": " "}),
("numeric", {"name": 1}),
):
res = rest_client.post(f"/chats/{chat_id}/sessions", json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
if scenario_name == "valid":
assert body["code"] == 0, (scenario_name, body)
assert body["data"]["name"] == "valid_name", (scenario_name, body)
assert body["data"]["chat_id"] == chat_id, (scenario_name, body)
else:
assert body["code"] == 102, (scenario_name, body)
assert body["message"] == "`name` can not be empty.", (scenario_name, body)
duplicate_first = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "duplicated_name"}).json()
duplicate_second = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "duplicated_name"}).json()
assert duplicate_first["code"] == 0, duplicate_first
assert duplicate_second["code"] == 0, duplicate_second
assert duplicate_second["data"]["name"] == "duplicated_name", duplicate_second
upper_case = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "CASE INSENSITIVE"}).json()
lower_case = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "case insensitive"}).json()
assert upper_case["code"] == 0, upper_case
assert lower_case["code"] == 0, lower_case
assert upper_case["data"]["name"] == "CASE INSENSITIVE", upper_case
assert lower_case["data"]["name"] == "case insensitive", lower_case
long_name = "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1)
long_name_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": long_name})
assert long_name_res.status_code == 200
long_name_payload = long_name_res.json()
assert long_name_payload["code"] == 0, long_name_payload
assert long_name_payload["data"]["name"] == long_name[:SESSION_WITH_CHAT_NAME_LIMIT], long_name_payload
delete_res = rest_client.delete("/chats", json={"ids": [chat_id]})
assert delete_res.status_code == 200
delete_payload = delete_res.json()
assert delete_payload["code"] == 0, delete_payload
create_after_delete = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "after_delete"})
assert create_after_delete.status_code == 200
create_after_delete_payload = create_after_delete.json()
assert create_after_delete_payload["code"] == 109, create_after_delete_payload
assert create_after_delete_payload["message"] == "No authorization.", create_after_delete_payload
@pytest.mark.p2
def test_session_create_concurrent_contract(rest_client, create_chat):
chat_id = create_chat("restful_session_create_concurrent")
def _create(index):
return rest_client.post(f"/chats/{chat_id}/sessions", json={"name": f"session create {index}"}).json()
with ThreadPoolExecutor(max_workers=5) as executor:
results = list(executor.map(_create, range(20)))
assert len(results) == 20, results
assert all(result["code"] == 0 for result in results), results
list_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30})
assert list_res.status_code == 200
list_payload = list_res.json()
assert list_payload["code"] == 0, list_payload
assert len(list_payload["data"]) == 20, list_payload
@pytest.mark.p1
def test_session_delete_requires_auth_and_invalid_target_contract(rest_client, create_chat):
chat_id = create_chat("restful_session_delete_auth")
create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_a"})
assert create_res.status_code == 200
session_id = create_res.json()["data"]["id"]
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
res = client.delete(f"/chats/{chat_id}/sessions", json={"ids": [session_id]})
assert res.status_code == 401, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 401, (scenario_name, payload)
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
invalid_chat_res = rest_client.delete("/chats/invalid_chat_assistant_id/sessions", json={"ids": [session_id]})
assert invalid_chat_res.status_code == 200
invalid_chat_payload = invalid_chat_res.json()
assert invalid_chat_payload["code"] == 109, invalid_chat_payload
assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload
@pytest.mark.p2
def test_session_delete_basic_scenarios(rest_client, create_chat):
cases = [
("none payload", None, 0, 5, {}),
("invalid only", {"ids": ["invalid_id"]}, 102, 5, "The chat doesn't own the session invalid_id"),
("not json", "not json", 100, 5, "<BadRequest '400: Bad Request'>"),
("single id", lambda sessions: {"ids": [sessions[0]["id"]]}, 0, 4, True),
("all ids", lambda sessions: {"ids": [session["id"] for session in sessions]}, 0, 0, True),
("delete all", {"delete_all": True}, 0, 0, True),
("empty ids", {"ids": []}, 0, 5, {}),
]
for scenario_name, payload, expected_code, expected_remaining, expected_data in cases:
chat_id, sessions = _seed_sessions(rest_client, create_chat, f"delete_basic_{scenario_name.replace(' ', '_')}")
if callable(payload):
payload = payload(sessions)
if scenario_name == "not json":
res = rest_client.delete(
f"/chats/{chat_id}/sessions",
headers={"Content-Type": "application/json"},
data=payload,
)
else:
res = rest_client.delete(f"/chats/{chat_id}/sessions", json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == expected_code, (scenario_name, body)
if expected_code == 0:
assert body["data"] == expected_data, (scenario_name, body)
else:
assert body["message"] == expected_data, (scenario_name, body)
list_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30})
assert list_res.status_code == 200, (scenario_name, list_res.text)
list_payload = list_res.json()
assert list_payload["code"] == 0, (scenario_name, list_payload)
assert len(list_payload["data"]) == expected_remaining, (scenario_name, list_payload)
@pytest.mark.p2
def test_session_delete_error_and_repeat_contract(rest_client, create_chat):
partial_cases = [
("invalid first", lambda sessions: {"ids": ["invalid_id"] + [session["id"] for session in sessions]}),
("invalid middle", lambda sessions: {"ids": [sessions[0]["id"], "invalid_id", *[session["id"] for session in sessions[1:]]]}),
("invalid last", lambda sessions: {"ids": [session["id"] for session in sessions] + ["invalid_id"]}),
]
for scenario_name, payload_builder in partial_cases:
chat_id, sessions = _seed_sessions(rest_client, create_chat, f"delete_partial_{scenario_name.replace(' ', '_')}")
res = rest_client.delete(f"/chats/{chat_id}/sessions", json=payload_builder(sessions))
assert res.status_code == 200, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 0, (scenario_name, payload)
assert payload["data"]["success_count"] == len(sessions), (scenario_name, payload)
assert payload["data"]["errors"] == ["The chat doesn't own the session invalid_id"], (scenario_name, payload)
remaining = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30}).json()
assert remaining["code"] == 0, (scenario_name, remaining)
assert remaining["data"] == [], (scenario_name, remaining)
duplicate_chat_id, duplicate_sessions = _seed_sessions(rest_client, create_chat, "delete_duplicate")
duplicate_id = duplicate_sessions[0]["id"]
duplicate_res = rest_client.delete(f"/chats/{duplicate_chat_id}/sessions", json={"ids": [duplicate_id, duplicate_id]})
assert duplicate_res.status_code == 200
duplicate_payload = duplicate_res.json()
assert duplicate_payload["code"] == 0, duplicate_payload
assert duplicate_payload["data"]["success_count"] == 1, duplicate_payload
assert duplicate_payload["data"]["errors"] == [f"Duplicate session ids: {duplicate_id}"], duplicate_payload
repeated_chat_id, repeated_sessions = _seed_sessions(rest_client, create_chat, "delete_repeated")
repeated_ids = [session["id"] for session in repeated_sessions]
first_res = rest_client.delete(f"/chats/{repeated_chat_id}/sessions", json={"ids": repeated_ids})
assert first_res.status_code == 200
first_payload = first_res.json()
assert first_payload["code"] == 0, first_payload
assert first_payload["data"] is True, first_payload
second_res = rest_client.delete(f"/chats/{repeated_chat_id}/sessions", json={"ids": repeated_ids})
assert second_res.status_code == 200
second_payload = second_res.json()
assert second_payload["code"] == 102, second_payload
for session_id in repeated_ids:
assert f"The chat doesn't own the session {session_id}" in second_payload["message"], second_payload
@pytest.mark.p2
def test_session_delete_concurrent_and_bulk_contract(rest_client, create_chat):
concurrent_chat_id, concurrent_sessions = _seed_sessions(rest_client, create_chat, "delete_concurrent", count=20)
def _delete(session):
return rest_client.delete(f"/chats/{concurrent_chat_id}/sessions", json={"ids": [session["id"]]}).json()
with ThreadPoolExecutor(max_workers=5) as executor:
results = list(executor.map(_delete, concurrent_sessions))
assert len(results) == 20, results
assert all(result["code"] == 0 for result in results), results
assert all(result["data"] is True for result in results), results
list_after_concurrent = rest_client.get(f"/chats/{concurrent_chat_id}/sessions", params={"page_size": 30}).json()
assert list_after_concurrent["code"] == 0, list_after_concurrent
assert list_after_concurrent["data"] == [], list_after_concurrent
bulk_chat_id, bulk_sessions = _seed_sessions(rest_client, create_chat, "delete_bulk", count=100)
bulk_res = rest_client.delete(
f"/chats/{bulk_chat_id}/sessions",
json={"ids": [session["id"] for session in bulk_sessions]},
)
assert bulk_res.status_code == 200
bulk_payload = bulk_res.json()
assert bulk_payload["code"] == 0, bulk_payload
assert bulk_payload["data"] is True, bulk_payload
@pytest.mark.p1
def test_session_list_requires_auth_and_invalid_target_contract():
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
res = client.get("/chats/chat_id/sessions")
assert res.status_code == 401, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 401, (scenario_name, payload)
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
@pytest.mark.p2
def test_session_list_filter_and_deleted_chat_contract(rest_client, create_chat):
chat_id, sessions = _seed_sessions(rest_client, create_chat, "list_filter")
session_ids = [session["id"] for session in sessions]
session_names = [session["name"] for session in sessions]
default_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30})
assert default_res.status_code == 200
default_payload = default_res.json()
assert default_payload["code"] == 0, default_payload
assert len(default_payload["data"]) == 5, default_payload
for scenario_name, params, expected_names in (
("id none", {"id": None, "page_size": 30}, session_names),
("id empty", {"id": "", "page_size": 30}, session_names),
("valid id", {"id": session_ids[0], "page_size": 30}, [session_names[0]]),
("unknown id", {"id": "unknown", "page_size": 30}, []),
("name none", {"name": None, "page_size": 30}, session_names),
("name empty", {"name": "", "page_size": 30}, session_names),
("name exact", {"name": session_names[1], "page_size": 30}, [session_names[1]]),
("name unknown", {"name": "unknown", "page_size": 30}, []),
("name and id match", {"id": session_ids[0], "name": session_names[0], "page_size": 30}, [session_names[0]]),
("name and id mismatch", {"id": session_ids[0], "name": "session_with_chat_assistant_100", "page_size": 30}, []),
("name and invalid id", {"id": "id", "name": session_names[0], "page_size": 30}, []),
("invalid params ignored", {"a": "b", "page_size": 30}, session_names),
):
res = rest_client.get(f"/chats/{chat_id}/sessions", params=params)
assert res.status_code == 200, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 0, (scenario_name, payload)
assert set(_session_names(payload)) == set(expected_names), (scenario_name, payload)
invalid_chat_res = rest_client.get(f"/chats/{INVALID_ID_32}/sessions")
assert invalid_chat_res.status_code == 200
invalid_chat_payload = invalid_chat_res.json()
assert invalid_chat_payload["code"] == 109, invalid_chat_payload
assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload
delete_chat_res = rest_client.delete("/chats", json={"ids": [chat_id]})
assert delete_chat_res.status_code == 200
delete_chat_payload = delete_chat_res.json()
assert delete_chat_payload["code"] == 0, delete_chat_payload
deleted_list_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30})
assert deleted_list_res.status_code == 200
deleted_list_payload = deleted_list_res.json()
assert deleted_list_payload["code"] == 109, deleted_list_payload
assert deleted_list_payload["message"] == "No authorization.", deleted_list_payload
@pytest.mark.p2
def test_session_list_page_and_sort_contract(rest_client, create_chat):
chat_id, sessions = _seed_sessions(rest_client, create_chat, "list_page_sort")
created_names = [session["name"] for session in sessions]
descending_names = list(reversed(created_names))
page_cases = [
("page none", {"page": None, "page_size": 2}, 0, 2, ""),
("page zero", {"page": 0, "page_size": 2}, 0, 2, ""),
("page two", {"page": 2, "page_size": 2}, 0, 2, ""),
("page three", {"page": 3, "page_size": 2}, 0, 1, ""),
("page string", {"page": "3", "page_size": 2}, 0, 1, ""),
("page negative", {"page": -1, "page_size": 2}, 100, 0, "ProgrammingError(1064"),
("page alpha", {"page": "a", "page_size": 2}, 100, 0, "ValueError(\"invalid literal for int() with base 10: 'a'\")"),
("page_size none", {"page_size": None}, 0, 5, ""),
("page_size zero", {"page_size": 0}, 0, 0, ""),
("page_size one", {"page_size": 1}, 0, 1, ""),
("page_size six", {"page_size": 6}, 0, 5, ""),
("page_size negative", {"page_size": -1}, 0, 5, ""),
("page_size alpha", {"page_size": "a"}, 100, 0, "ValueError(\"invalid literal for int() with base 10: 'a'\")"),
]
for scenario_name, params, expected_code, expected_count, expected_message in page_cases:
res = rest_client.get(f"/chats/{chat_id}/sessions", params=params)
assert res.status_code == 200, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == expected_code, (scenario_name, payload)
if expected_code == 0:
assert len(payload["data"]) == expected_count, (scenario_name, payload)
else:
assert expected_message in payload["message"], (scenario_name, payload)
sort_cases = [
("orderby none", {"orderby": None, "page_size": 30}, "create_time", True, descending_names, ""),
("orderby create", {"orderby": "create_time", "page_size": 30}, "create_time", True, descending_names, ""),
("orderby update", {"orderby": "update_time", "page_size": 30}, "update_time", True, descending_names, ""),
("orderby name ascending", {"orderby": "name", "desc": "False", "page_size": 30}, "name", False, created_names, ""),
("orderby unknown", {"orderby": "unknown", "page_size": 30}, None, None, None, "AttributeError(\"type object 'Conversation' has no attribute 'unknown'\")"),
("desc none", {"desc": None, "page_size": 30}, "create_time", True, descending_names, ""),
("desc true", {"desc": "true", "page_size": 30}, "create_time", True, descending_names, ""),
("desc True", {"desc": "True", "page_size": 30}, "create_time", True, descending_names, ""),
("desc false", {"desc": "false", "page_size": 30}, "create_time", False, created_names, ""),
("desc False", {"desc": "False", "page_size": 30}, "create_time", False, created_names, ""),
("desc false update_time", {"desc": "False", "orderby": "update_time", "page_size": 30}, "update_time", False, created_names, ""),
("desc unknown", {"desc": "unknown", "page_size": 30}, "create_time", True, descending_names, ""),
]
for scenario_name, params, field, descending, expected_names, expected_message in sort_cases:
res = rest_client.get(f"/chats/{chat_id}/sessions", params=params)
assert res.status_code == 200, (scenario_name, res.text)
payload = res.json()
expected_code = 0 if expected_names is not None else 100
assert payload["code"] == expected_code, (scenario_name, payload)
if expected_code == 0:
assert is_sorted(payload["data"], field, descending), (scenario_name, payload)
assert _session_names(payload) == expected_names, (scenario_name, payload)
else:
assert expected_message in payload["message"], (scenario_name, payload)
@pytest.mark.p2
def test_session_list_concurrent_contract(rest_client, create_chat):
chat_id, _sessions = _seed_sessions(rest_client, create_chat, "list_concurrent")
def _list(_):
return rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30}).json()
with ThreadPoolExecutor(max_workers=5) as executor:
results = list(executor.map(_list, range(10)))
assert len(results) == 10, results
assert all(result["code"] == 0 for result in results), results
assert all(len(result["data"]) == 5 for result in results), results
@pytest.mark.p1
def test_session_update_requires_auth_and_invalid_target_contract(rest_client, create_chat):
chat_id = create_chat("restful_session_update_auth")
create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_update_auth"})
assert create_res.status_code == 200
session_id = create_res.json()["data"]["id"]
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
res = client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"name": "x"})
assert res.status_code == 401, (scenario_name, res.text)
payload = res.json()
assert payload["code"] == 401, (scenario_name, payload)
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
invalid_chat_res = rest_client.patch(f"/chats/{INVALID_ID_32}/sessions/{session_id}", json={"name": "x"})
assert invalid_chat_res.status_code == 200
invalid_chat_payload = invalid_chat_res.json()
assert invalid_chat_payload["code"] == 109, invalid_chat_payload
assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload
empty_session_res = rest_client.patch(f"/chats/{chat_id}/sessions/", json={"name": "x"})
assert empty_session_res.status_code == 200
empty_session_payload = empty_session_res.json()
assert empty_session_payload["code"] == 100, empty_session_payload
assert empty_session_payload["message"] == "<MethodNotAllowed '405: Method Not Allowed'>", empty_session_payload
invalid_session_res = rest_client.patch(f"/chats/{chat_id}/sessions/invalid_session_id", json={"name": "x"})
assert invalid_session_res.status_code == 200
invalid_session_payload = invalid_session_res.json()
assert invalid_session_payload["code"] == 102, invalid_session_payload
assert invalid_session_payload["message"] == "Session not found!", invalid_session_payload
@pytest.mark.p2
def test_session_update_name_and_param_contract(rest_client, create_chat):
chat_id, sessions = _seed_sessions(rest_client, create_chat, "update_contract")
session_id = sessions[0]["id"]
for scenario_name, payload, expected_code, expected_name_or_message in (
("valid", {"name": "valid_name"}, 0, "valid_name"),
("empty", {"name": ""}, 102, "`name` can not be empty."),
("numeric", {"name": 1}, 102, "`name` can not be empty."),
("duplicate", {"name": "duplicated_name"}, 0, "duplicated_name"),
("case insensitive upper", {"name": "CASE INSENSITIVE UPDATE"}, 0, "CASE INSENSITIVE UPDATE"),
("case insensitive lower", {"name": "case insensitive update"}, 0, "case insensitive update"),
("long name", {"name": "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1)}, 0, "a" * SESSION_WITH_CHAT_NAME_LIMIT),
):
res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == expected_code, (scenario_name, body)
if expected_code == 0:
assert body["data"]["name"] == expected_name_or_message, (scenario_name, body)
else:
assert body["message"] == expected_name_or_message, (scenario_name, body)
unknown_key_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"unknown_key": "unknown_value"})
assert unknown_key_res.status_code == 200
unknown_key_payload = unknown_key_res.json()
assert unknown_key_payload["code"] == 100, unknown_key_payload
assert 'Unrecognized field name: "unknown_key"' in unknown_key_payload["message"], unknown_key_payload
for scenario_name, payload in (("empty payload", {}), ("none payload", None)):
res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json=payload)
assert res.status_code == 200, (scenario_name, res.text)
body = res.json()
assert body["code"] == 0, (scenario_name, body)
assert body["data"]["id"] == session_id, (scenario_name, body)
delete_res = rest_client.delete("/chats", json={"ids": [chat_id]})
assert delete_res.status_code == 200
delete_payload = delete_res.json()
assert delete_payload["code"] == 0, delete_payload
update_after_delete_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"name": "after_delete"})
assert update_after_delete_res.status_code == 200
update_after_delete_payload = update_after_delete_res.json()
assert update_after_delete_payload["code"] == 109, update_after_delete_payload
assert update_after_delete_payload["message"] == "No authorization.", update_after_delete_payload
@pytest.mark.p2
def test_session_update_repeated_and_concurrent_contract(rest_client, create_chat):
chat_id, sessions = _seed_sessions(rest_client, create_chat, "update_repeated")
session_ids = [session["id"] for session in sessions]
first_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_ids[0]}", json={"name": "valid_name_1"})
assert first_res.status_code == 200
assert first_res.json()["code"] == 0, first_res.json()
second_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_ids[0]}", json={"name": "valid_name_2"})
assert second_res.status_code == 200
assert second_res.json()["code"] == 0, second_res.json()
def _update(index):
return rest_client.patch(
f"/chats/{chat_id}/sessions/{session_ids[index % len(session_ids)]}",
json={"name": f"update session test {index}"},
).json()
with ThreadPoolExecutor(max_workers=5) as executor:
results = list(executor.map(_update, range(20)))
assert len(results) == 20, results
assert all(result["code"] == 0 for result in results), results
@pytest.mark.p2
def test_chat_recommendation_requires_question(rest_client):
res = rest_client.post("/chat/recommendation", json={})
@@ -120,7 +611,11 @@ def test_related_questions_compatibility_requires_auth(rest_client_noauth):
assert res.status_code == 200
payload = res.json()
assert payload["code"] == 102, payload
assert "Authorization is not valid!" in payload["message"], payload
assert payload["message"] in {
"Authorization is not valid!",
'Authentication error: API key is invalid!"',
"Authentication error: API key is invalid!",
}, payload
@pytest.mark.p2
@@ -226,7 +721,10 @@ def test_chat_completion_validation_errors(rest_client, create_chat):
assert missing_messages.status_code == 200
missing_messages_payload = missing_messages.json()
assert missing_messages_payload["code"] == 101, missing_messages_payload
assert "required argument are missing: messages" in missing_messages_payload["message"], missing_messages_payload
assert missing_messages_payload["message"] in {
"required argument are missing: messages",
"messages: is required",
}, missing_messages_payload
missing_chat_for_session = rest_client.post(
"/chat/completions",