mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Feat: add new tests and tescases for restful api suite (#15347)
### What problem does this PR solve? extend restful api suite ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Other (please describe): test
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
@@ -24,6 +25,9 @@ from types import ModuleType, SimpleNamespace
|
||||
import pytest
|
||||
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
@@ -370,12 +374,15 @@ class _DummyManager:
|
||||
|
||||
|
||||
class _DummyFile:
|
||||
def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1):
|
||||
def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1, tenant_id="tenant1", parent_id="pf1", source_type="user"):
|
||||
self.id = file_id
|
||||
self.type = file_type
|
||||
self.name = name
|
||||
self.location = location
|
||||
self.size = size
|
||||
self.tenant_id = tenant_id
|
||||
self.parent_id = parent_id
|
||||
self.source_type = source_type
|
||||
|
||||
|
||||
class _FalsyFile(_DummyFile):
|
||||
@@ -630,3 +637,291 @@ def test_convert_branch_matrix_unit(monkeypatch):
|
||||
res = _run(module.convert())
|
||||
assert res["code"] == 500
|
||||
assert "convert boom" in res["message"]
|
||||
|
||||
|
||||
def _load_file_api_service(monkeypatch):
|
||||
LOGGER.debug("_load_file_api_service: entry")
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
api_pkg = ModuleType("api")
|
||||
api_pkg.__path__ = [str(repo_root / "api")]
|
||||
monkeypatch.setitem(sys.modules, "api", api_pkg)
|
||||
LOGGER.debug("_load_file_api_service: mocked api package")
|
||||
|
||||
common_pkg = ModuleType("api.common")
|
||||
common_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "api.common", common_pkg)
|
||||
|
||||
permission_mod = ModuleType("api.common.check_team_permission")
|
||||
permission_mod.check_file_team_permission = lambda *_args, **_kwargs: True
|
||||
monkeypatch.setitem(sys.modules, "api.common.check_team_permission", permission_mod)
|
||||
common_pkg.check_team_permission = permission_mod
|
||||
|
||||
db_pkg = ModuleType("api.db")
|
||||
db_pkg.__path__ = []
|
||||
|
||||
class _ServiceFileType(Enum):
|
||||
FOLDER = "folder"
|
||||
VIRTUAL = "virtual"
|
||||
DOC = "doc"
|
||||
VISUAL = "visual"
|
||||
|
||||
db_pkg.FileType = _ServiceFileType
|
||||
monkeypatch.setitem(sys.modules, "api.db", db_pkg)
|
||||
api_pkg.db = db_pkg
|
||||
|
||||
services_pkg = ModuleType("api.db.services")
|
||||
services_pkg.__path__ = []
|
||||
services_pkg.duplicate_name = lambda _query, **kwargs: kwargs.get("name", "")
|
||||
monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
|
||||
|
||||
document_service_mod = ModuleType("api.db.services.document_service")
|
||||
document_service_mod.DocumentService = SimpleNamespace(
|
||||
get_doc_count=lambda _uid: 0,
|
||||
get_by_id=lambda doc_id: (True, SimpleNamespace(id=doc_id)),
|
||||
get_tenant_id=lambda _doc_id: "tenant1",
|
||||
remove_document=lambda *_args, **_kwargs: True,
|
||||
update_by_id=lambda *_args, **_kwargs: True,
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
|
||||
services_pkg.document_service = document_service_mod
|
||||
|
||||
file2doc_mod = ModuleType("api.db.services.file2document_service")
|
||||
file2doc_mod.File2DocumentService = SimpleNamespace(
|
||||
get_by_file_id=lambda _file_id: [],
|
||||
delete_by_file_id=lambda _file_id: None,
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2doc_mod)
|
||||
services_pkg.file2document_service = file2doc_mod
|
||||
|
||||
file_service_mod = ModuleType("api.db.services.file_service")
|
||||
file_service_mod.FileService = SimpleNamespace(
|
||||
get_root_folder=lambda _tenant_id: {"id": "root"},
|
||||
get_by_id=lambda file_id: (True, _DummyFile(file_id, _ServiceFileType.DOC.value)),
|
||||
get_id_list_by_id=lambda _pf_id, _names, _idx, ids: ids,
|
||||
create_folder=lambda _file, parent_id, _names, _len_id: SimpleNamespace(id=parent_id, name=str(parent_id)),
|
||||
query=lambda **_kwargs: [],
|
||||
insert=lambda data: SimpleNamespace(to_json=lambda: data, **data),
|
||||
is_parent_folder_exist=lambda _pf_id: True,
|
||||
get_by_pf_id=lambda *_args, **_kwargs: ([], 0),
|
||||
get_parent_folder=lambda _file_id: SimpleNamespace(to_json=lambda: {"id": "root"}),
|
||||
get_all_parent_folders=lambda _file_id: [],
|
||||
list_all_files_by_parent_id=lambda _parent_id: [],
|
||||
delete=lambda _file: True,
|
||||
delete_by_id=lambda _file_id: True,
|
||||
update_by_id=lambda *_args, **_kwargs: True,
|
||||
get_by_ids=lambda file_ids: [_DummyFile(file_id, _ServiceFileType.DOC.value) for file_id in file_ids],
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)
|
||||
services_pkg.file_service = file_service_mod
|
||||
LOGGER.debug("_load_file_api_service: mocked api.db.services.file_service")
|
||||
|
||||
file_utils_mod = ModuleType("api.utils.file_utils")
|
||||
file_utils_mod.filename_type = lambda _filename: _ServiceFileType.DOC.value
|
||||
monkeypatch.setitem(sys.modules, "api.utils.file_utils", file_utils_mod)
|
||||
|
||||
common_root_mod = ModuleType("common")
|
||||
common_root_mod.__path__ = [str(repo_root / "common")]
|
||||
common_root_mod.settings = SimpleNamespace(
|
||||
STORAGE_IMPL=SimpleNamespace(
|
||||
obj_exist=lambda *_args, **_kwargs: False,
|
||||
put=lambda *_args, **_kwargs: None,
|
||||
rm=lambda *_args, **_kwargs: None,
|
||||
move=lambda *_args, **_kwargs: None,
|
||||
)
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "common", common_root_mod)
|
||||
|
||||
constants_mod = ModuleType("common.constants")
|
||||
|
||||
class _FileSource:
|
||||
KNOWLEDGEBASE = "knowledgebase"
|
||||
|
||||
constants_mod.FileSource = _FileSource
|
||||
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
|
||||
|
||||
misc_utils_mod = ModuleType("common.misc_utils")
|
||||
misc_utils_mod.get_uuid = lambda: "uuid-1"
|
||||
|
||||
async def thread_pool_exec(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
misc_utils_mod.thread_pool_exec = thread_pool_exec
|
||||
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
|
||||
|
||||
module_path = repo_root / "api" / "apps" / "services" / "file_api_service.py"
|
||||
spec = importlib.util.spec_from_file_location("api.apps.services.file_api_service", module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
monkeypatch.setitem(sys.modules, "api.apps.services.file_api_service", module)
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"_load_file_api_service: spec.loader.exec_module(module) failed"
|
||||
)
|
||||
raise
|
||||
LOGGER.debug("_load_file_api_service: spec.loader.exec_module(module) completed")
|
||||
return module
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_upload_file_requires_existing_folder(monkeypatch):
|
||||
module = _load_file_api_service(monkeypatch)
|
||||
monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (False, None))
|
||||
|
||||
ok, message = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt")]))
|
||||
assert ok is False
|
||||
assert message == "Can't find this folder!"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_upload_file_respects_user_limit(monkeypatch):
|
||||
module = _load_file_api_service(monkeypatch)
|
||||
monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (True, SimpleNamespace(id="pf1", name="pf1")))
|
||||
monkeypatch.setattr(module.DocumentService, "get_doc_count", lambda _uid: 1)
|
||||
monkeypatch.setenv("MAX_FILE_NUM_PER_USER", "1")
|
||||
|
||||
ok, message = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt")]))
|
||||
assert ok is False
|
||||
assert message == "Exceed the maximum file number of a free user!"
|
||||
monkeypatch.delenv("MAX_FILE_NUM_PER_USER", raising=False)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_upload_file_success_uses_new_service_layer(monkeypatch):
|
||||
module = _load_file_api_service(monkeypatch)
|
||||
storage_puts = []
|
||||
|
||||
monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (True, SimpleNamespace(id="pf1", name="pf1")))
|
||||
monkeypatch.setattr(module.FileService, "get_id_list_by_id", lambda *_args, **_kwargs: ["pf1"])
|
||||
monkeypatch.setattr(
|
||||
module.FileService,
|
||||
"create_folder",
|
||||
lambda _file, parent_id, _names, _len_id, *_args: SimpleNamespace(id=parent_id),
|
||||
)
|
||||
monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(
|
||||
obj_exist=lambda *_args, **_kwargs: False,
|
||||
put=lambda bucket, location, blob: storage_puts.append((bucket, location, blob)),
|
||||
rm=lambda *_args, **_kwargs: None,
|
||||
move=lambda *_args, **_kwargs: None,
|
||||
))
|
||||
|
||||
ok, data = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt", b"hello")]))
|
||||
assert ok is True
|
||||
assert data[0]["name"] == "a.txt"
|
||||
assert storage_puts == [("pf1", "a.txt", b"hello")]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_create_folder_rejects_duplicate_name(monkeypatch):
|
||||
module = _load_file_api_service(monkeypatch)
|
||||
monkeypatch.setattr(module.FileService, "query", lambda **_kwargs: [SimpleNamespace(id="existing")])
|
||||
|
||||
ok, message = _run(module.create_folder("tenant1", "dup", "pf1", module.FileType.FOLDER.value))
|
||||
assert ok is False
|
||||
assert message == "Duplicated folder name in the same folder."
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_delete_files_checks_team_permission(monkeypatch):
|
||||
module = _load_file_api_service(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
module.FileService,
|
||||
"get_by_id",
|
||||
lambda _file_id: (True, _DummyFile("file1", module.FileType.DOC.value)),
|
||||
)
|
||||
monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: False)
|
||||
|
||||
ok, message = _run(module.delete_files("tenant1", ["file1"]))
|
||||
assert ok is False
|
||||
assert message == {"success_count": 0, "errors": ["No authorization for file file1"]}
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_move_files_rejects_extension_change_in_new_name(monkeypatch):
|
||||
module = _load_file_api_service(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
module.FileService,
|
||||
"get_by_ids",
|
||||
lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, name="a.txt")],
|
||||
)
|
||||
|
||||
ok, message = _run(module.move_files("tenant1", ["file1"], new_name="a.pdf"))
|
||||
assert ok is False
|
||||
assert message == "The extension of file can't be changed"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_move_files_handles_dest_and_storage_move(monkeypatch):
|
||||
module = _load_file_api_service(monkeypatch)
|
||||
moved = []
|
||||
updated = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
module.FileService,
|
||||
"get_by_id",
|
||||
lambda file_id: (False, None) if file_id == "missing" else (True, _DummyFile(file_id, module.FileType.FOLDER.value, name="dest")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
module.FileService,
|
||||
"get_by_ids",
|
||||
lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, parent_id="src", location="old", name="a.txt")],
|
||||
)
|
||||
monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(
|
||||
obj_exist=lambda *_args, **_kwargs: False,
|
||||
put=lambda *_args, **_kwargs: None,
|
||||
rm=lambda *_args, **_kwargs: None,
|
||||
move=lambda old_bucket, old_loc, new_bucket, new_loc: moved.append((old_bucket, old_loc, new_bucket, new_loc)),
|
||||
))
|
||||
monkeypatch.setattr(module.FileService, "update_by_id", lambda file_id, data: updated.append((file_id, data)) or True)
|
||||
|
||||
ok, message = _run(module.move_files("tenant1", ["file1"], "missing"))
|
||||
assert ok is False
|
||||
assert message == "Parent folder not found!"
|
||||
|
||||
ok, data = _run(module.move_files("tenant1", ["file1"], "dest"))
|
||||
assert ok is True
|
||||
assert data is True
|
||||
assert moved == [("src", "old", "dest", "a.txt")]
|
||||
assert updated == [("file1", {"parent_id": "dest", "location": "a.txt"})]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_move_files_renames_in_place_without_storage_move(monkeypatch):
|
||||
module = _load_file_api_service(monkeypatch)
|
||||
db_updates = []
|
||||
doc_updates = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
module.FileService,
|
||||
"get_by_ids",
|
||||
lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, parent_id="pf1", name="a.txt")],
|
||||
)
|
||||
monkeypatch.setattr(module.FileService, "update_by_id", lambda file_id, data: db_updates.append((file_id, data)) or True)
|
||||
monkeypatch.setattr(
|
||||
module.File2DocumentService,
|
||||
"get_by_file_id",
|
||||
lambda _file_id: [SimpleNamespace(document_id="doc1")],
|
||||
)
|
||||
monkeypatch.setattr(module.DocumentService, "update_by_id", lambda doc_id, data: doc_updates.append((doc_id, data)) or True)
|
||||
|
||||
ok, data = _run(module.move_files("tenant1", ["file1"], new_name="b.txt"))
|
||||
assert ok is True
|
||||
assert data is True
|
||||
assert db_updates == [("file1", {"name": "b.txt"})]
|
||||
assert doc_updates == [("doc1", {"name": "b.txt"})]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_get_file_content_checks_permission(monkeypatch):
|
||||
module = _load_file_api_service(monkeypatch)
|
||||
monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: False)
|
||||
|
||||
ok, message = module.get_file_content("tenant1", "file1")
|
||||
assert ok is False
|
||||
assert message == "No authorization."
|
||||
|
||||
monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: True)
|
||||
ok, file = module.get_file_content("tenant1", "file1")
|
||||
assert ok is True
|
||||
assert file.id == "file1"
|
||||
|
||||
742
test/testcases/restful_api/test_llm_routes_unit.py
Normal file
742
test/testcases/restful_api/test_llm_routes_unit.py
Normal file
@@ -0,0 +1,742 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from configs import HOST_ADDRESS, VERSION
|
||||
from test.testcases.conftest import login
|
||||
from test.testcases.libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _ExprField:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.name, other)
|
||||
|
||||
|
||||
class _StrEnum(str):
|
||||
@property
|
||||
def value(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class _DummyTenantLLMModel:
|
||||
tenant_id = _ExprField("tenant_id")
|
||||
llm_factory = _ExprField("llm_factory")
|
||||
llm_name = _ExprField("llm_name")
|
||||
|
||||
def __init__(self, id=None, **kwargs):
|
||||
self.id = id
|
||||
self.api_key = None
|
||||
self.status = None
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class _TenantLLMRow:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id,
|
||||
llm_name,
|
||||
llm_factory,
|
||||
model_type,
|
||||
api_key="key",
|
||||
status="1",
|
||||
used_tokens=0,
|
||||
api_base="",
|
||||
max_tokens=8192,
|
||||
):
|
||||
self.id = id
|
||||
self.llm_name = llm_name
|
||||
self.llm_factory = llm_factory
|
||||
self.model_type = model_type
|
||||
self.api_key = api_key
|
||||
self.status = status
|
||||
self.used_tokens = used_tokens
|
||||
self.api_base = api_base
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"model_type": self.model_type,
|
||||
"status": self.status,
|
||||
"used_tokens": self.used_tokens,
|
||||
"api_base": self.api_base,
|
||||
"max_tokens": self.max_tokens,
|
||||
}
|
||||
|
||||
|
||||
class _LLMRow:
|
||||
def __init__(self, *, llm_name, fid, model_type, status="1", max_tokens=2048):
|
||||
self.llm_name = llm_name
|
||||
self.fid = fid
|
||||
self.model_type = model_type
|
||||
self.status = status
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"llm_name": self.llm_name,
|
||||
"fid": self.fid,
|
||||
"model_type": self.model_type,
|
||||
"status": self.status,
|
||||
"max_tokens": self.max_tokens,
|
||||
}
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
|
||||
async def _get_request_json():
|
||||
return dict(payload)
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", _get_request_json)
|
||||
|
||||
|
||||
def _load_llm_app(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
quart_mod = ModuleType("quart")
|
||||
quart_mod.request = SimpleNamespace(args={})
|
||||
monkeypatch.setitem(sys.modules, "quart", quart_mod)
|
||||
|
||||
apps_mod = ModuleType("api.apps")
|
||||
apps_mod.__path__ = [str(repo_root / "api" / "apps")]
|
||||
apps_mod.login_required = lambda fn: fn
|
||||
apps_mod.current_user = SimpleNamespace(id="tenant-1")
|
||||
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
|
||||
|
||||
tenant_llm_mod = ModuleType("api.db.services.tenant_llm_service")
|
||||
|
||||
class _StubLLMFactoriesService:
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return []
|
||||
|
||||
class _StubTenantLLMService:
|
||||
@staticmethod
|
||||
def ensure_mineru_from_env(_tenant_id):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def ensure_opendataloader_from_env(_tenant_id):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_my_llms(_tenant_id):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def save(**_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def filter_delete(_filters):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def filter_update(_filters, _payload):
|
||||
return True
|
||||
|
||||
tenant_llm_mod.LLMFactoriesService = _StubLLMFactoriesService
|
||||
tenant_llm_mod.TenantLLMService = _StubTenantLLMService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_mod)
|
||||
|
||||
llm_service_mod = ModuleType("api.db.services.llm_service")
|
||||
|
||||
class _StubLLMService:
|
||||
@staticmethod
|
||||
def get_all():
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return []
|
||||
|
||||
llm_service_mod.LLMService = _StubLLMService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
|
||||
|
||||
api_utils_mod = ModuleType("api.utils.api_utils")
|
||||
api_utils_mod.get_allowed_llm_factories = lambda: []
|
||||
api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
|
||||
"code": code,
|
||||
"message": message,
|
||||
"data": data,
|
||||
}
|
||||
api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
|
||||
"code": code,
|
||||
"message": message,
|
||||
"data": data,
|
||||
}
|
||||
|
||||
async def _get_request_json():
|
||||
return {}
|
||||
|
||||
api_utils_mod.get_request_json = _get_request_json
|
||||
api_utils_mod.server_error_response = lambda exc: {"code": 500, "message": str(exc), "data": None}
|
||||
api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
|
||||
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
|
||||
|
||||
constants_mod = ModuleType("common.constants")
|
||||
constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value="1"), INVALID=SimpleNamespace(value="0"))
|
||||
constants_mod.LLMType = SimpleNamespace(
|
||||
CHAT=_StrEnum("chat"),
|
||||
EMBEDDING=_StrEnum("embedding"),
|
||||
SPEECH2TEXT=_StrEnum("speech2text"),
|
||||
IMAGE2TEXT=_StrEnum("image2text"),
|
||||
RERANK=_StrEnum("rerank"),
|
||||
TTS=_StrEnum("tts"),
|
||||
OCR=_StrEnum("ocr"),
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
|
||||
|
||||
db_models_mod = ModuleType("api.db.db_models")
|
||||
db_models_mod.TenantLLM = _DummyTenantLLMModel
|
||||
monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
|
||||
|
||||
base64_mod = ModuleType("rag.utils.base64_image")
|
||||
base64_mod.test_image = b"image-bytes"
|
||||
monkeypatch.setitem(sys.modules, "rag.utils.base64_image", base64_mod)
|
||||
|
||||
rag_llm_mod = ModuleType("rag.llm")
|
||||
rag_llm_mod.EmbeddingModel = {}
|
||||
rag_llm_mod.ChatModel = {}
|
||||
rag_llm_mod.RerankModel = {}
|
||||
rag_llm_mod.CvModel = {}
|
||||
rag_llm_mod.TTSModel = {}
|
||||
rag_llm_mod.OcrModel = {}
|
||||
rag_llm_mod.Seq2txtModel = {}
|
||||
monkeypatch.setitem(sys.modules, "rag.llm", rag_llm_mod)
|
||||
|
||||
module_path = repo_root / "api" / "apps" / "llm_app.py"
|
||||
spec = importlib.util.spec_from_file_location("test_llm_routes_unit_module", module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _rag_llm_module():
|
||||
return sys.modules["rag.llm"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_llm_list_app_grouping_availability_and_merge(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
|
||||
ensure_calls = []
|
||||
monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda tenant_id: ensure_calls.append(tenant_id))
|
||||
|
||||
tenant_rows = [
|
||||
_TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
|
||||
_TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
|
||||
_TenantLLMRow(id=3, llm_name="gpt-5.5", llm_factory="OpenAI", model_type="chat", api_key="k3", status="1"),
|
||||
_TenantLLMRow(id=4, llm_name="gpt-5.4", llm_factory="OpenAI", model_type="chat", api_key="k4", status="1"),
|
||||
]
|
||||
monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: tenant_rows)
|
||||
|
||||
all_llms = [
|
||||
_LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"),
|
||||
_LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"),
|
||||
_LLMRow(llm_name="gpt-5.5", fid="OpenAI", model_type="chat", status="1"),
|
||||
_LLMRow(llm_name="gpt-5.4", fid="OpenAI", model_type="chat", status="1"),
|
||||
_LLMRow(llm_name="not-in-status", fid="Other", model_type="chat", status="1"),
|
||||
]
|
||||
monkeypatch.setattr(module.LLMService, "get_all", lambda: all_llms)
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
|
||||
monkeypatch.setenv("COMPOSE_PROFILES", "tei-cpu")
|
||||
monkeypatch.setenv("TEI_MODEL", "tei-embed")
|
||||
|
||||
res = _run(module.list_app())
|
||||
assert res["code"] == 0, res["message"]
|
||||
assert ensure_calls == ["tenant-1"]
|
||||
|
||||
data = res["data"]
|
||||
assert {"Builtin", "FastEmbed", "CustomFactory", "OpenAI"}.issubset(set(data.keys()))
|
||||
assert data["Builtin"][0]["llm_name"] == "tei-embed"
|
||||
assert data["Builtin"][0]["available"] is True
|
||||
assert data["FastEmbed"][0]["llm_name"] == "fast-emb"
|
||||
assert data["FastEmbed"][0]["available"] is True
|
||||
assert data["CustomFactory"][0]["llm_name"] == "tenant-only"
|
||||
assert data["CustomFactory"][0]["available"] is True
|
||||
openai_names = {item["llm_name"] for item in data["OpenAI"]}
|
||||
assert {"gpt-5.5", "gpt-5.4"}.issubset(openai_names)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_llm_list_app_model_type_filter(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: None)
|
||||
monkeypatch.setattr(
|
||||
module.TenantLLMService,
|
||||
"query",
|
||||
lambda **_kwargs: [
|
||||
_TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
|
||||
_TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
module.LLMService,
|
||||
"get_all",
|
||||
lambda: [
|
||||
_LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"),
|
||||
_LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"),
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(args={"model_type": "chat"}))
|
||||
res = _run(module.list_app())
|
||||
assert res["code"] == 0, res["message"]
|
||||
assert list(res["data"].keys()) == ["CustomFactory"]
|
||||
assert res["data"]["CustomFactory"][0]["model_type"] == "chat"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_llm_list_app_exception_path(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
|
||||
monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: None)
|
||||
monkeypatch.setattr(
|
||||
module.TenantLLMService,
|
||||
"query",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("query boom")),
|
||||
)
|
||||
|
||||
res = _run(module.list_app())
|
||||
assert res["code"] == 500
|
||||
assert "query boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_llm_factories_route_success_and_exception_unit(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
|
||||
def _factory(name):
|
||||
return SimpleNamespace(name=name, to_dict=lambda n=name: {"name": n})
|
||||
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_allowed_llm_factories",
|
||||
lambda: [
|
||||
_factory("OpenAI"),
|
||||
_factory("CustomFactory"),
|
||||
_factory("FastEmbed"),
|
||||
_factory("Builtin"),
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
module.LLMService,
|
||||
"get_all",
|
||||
lambda: [
|
||||
_LLMRow(llm_name="m1", fid="OpenAI", model_type="chat", status="1"),
|
||||
_LLMRow(llm_name="m2", fid="OpenAI", model_type="embedding", status="1"),
|
||||
_LLMRow(llm_name="m3", fid="OpenAI", model_type="rerank", status="0"),
|
||||
],
|
||||
)
|
||||
res = module.factories()
|
||||
assert res["code"] == 0
|
||||
names = [item["name"] for item in res["data"]]
|
||||
assert "FastEmbed" not in names
|
||||
assert "Builtin" not in names
|
||||
assert {"OpenAI", "CustomFactory"} == set(names)
|
||||
openai = next(item for item in res["data"] if item["name"] == "OpenAI")
|
||||
assert {"chat", "embedding"} == set(openai["model_types"])
|
||||
|
||||
monkeypatch.setattr(module, "get_allowed_llm_factories", lambda: (_ for _ in ()).throw(RuntimeError("factories boom")))
|
||||
res = module.factories()
|
||||
assert res["code"] == 500
|
||||
assert "factories boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_add_llm_factory_specific_key_assembly_unit(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
|
||||
async def _wait_for(coro, *_args, **_kwargs):
|
||||
return await coro
|
||||
|
||||
async def _to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(module.asyncio, "wait_for", _wait_for)
|
||||
monkeypatch.setattr(module.asyncio, "to_thread", _to_thread)
|
||||
|
||||
allowed = [
|
||||
"VolcEngine",
|
||||
"Tencent Cloud",
|
||||
"Bedrock",
|
||||
"LocalAI",
|
||||
"HuggingFace",
|
||||
"OpenAI-API-Compatible",
|
||||
"VLLM",
|
||||
"XunFei Spark",
|
||||
"BaiduYiyan",
|
||||
"Fish Audio",
|
||||
"Google Cloud",
|
||||
"Azure-OpenAI",
|
||||
"OpenRouter",
|
||||
"MinerU",
|
||||
"PaddleOCR",
|
||||
]
|
||||
monkeypatch.setattr(module, "get_allowed_llm_factories", lambda: [SimpleNamespace(name=name) for name in allowed])
|
||||
|
||||
captured = {"filter_payloads": []}
|
||||
|
||||
class _ChatOK:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
async def async_chat(self, *_args, **_kwargs):
|
||||
return "ok", 1
|
||||
|
||||
async def async_chat_streamly(self, *_args, **_kwargs):
|
||||
yield "ok"
|
||||
yield 1
|
||||
|
||||
class _TTSOK:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def tts(self, _text):
|
||||
yield b"ok"
|
||||
|
||||
monkeypatch.setattr(_rag_llm_module(), "ChatModel", {name: _ChatOK for name in allowed})
|
||||
monkeypatch.setattr(_rag_llm_module(), "TTSModel", {"XunFei Spark": _TTSOK})
|
||||
monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, payload: captured["filter_payloads"].append(dict(payload)) or True)
|
||||
|
||||
reject_req = {"llm_factory": "NotAllowed", "llm_name": "x", "model_type": module.LLMType.CHAT.value}
|
||||
_set_request_json(monkeypatch, module, reject_req)
|
||||
res = _run(module.add_llm())
|
||||
assert res["code"] == 400
|
||||
assert "is not allowed" in res["message"]
|
||||
|
||||
def _run_case(factory, *, model_type=module.LLMType.CHAT.value, extra=None):
|
||||
req = {"llm_factory": factory, "llm_name": "model", "model_type": model_type, "api_key": "k", "api_base": "http://api"}
|
||||
if extra:
|
||||
req.update(extra)
|
||||
_set_request_json(monkeypatch, module, req)
|
||||
out = _run(module.add_llm())
|
||||
assert out["code"] == 0
|
||||
assert out["data"] is True
|
||||
return captured["filter_payloads"][-1]
|
||||
|
||||
volc = _run_case("VolcEngine", extra={"ark_api_key": "ak", "endpoint_id": "eid"})
|
||||
assert json.loads(volc["api_key"]) == {"ark_api_key": "ak", "endpoint_id": "eid"}
|
||||
|
||||
bedrock = _run_case(
|
||||
"Bedrock",
|
||||
extra={"auth_mode": "iam", "bedrock_ak": "ak", "bedrock_sk": "sk", "bedrock_region": "r", "aws_role_arn": "arn"},
|
||||
)
|
||||
assert json.loads(bedrock["api_key"]) == {
|
||||
"auth_mode": "iam",
|
||||
"bedrock_ak": "ak",
|
||||
"bedrock_sk": "sk",
|
||||
"bedrock_region": "r",
|
||||
"aws_role_arn": "arn",
|
||||
}
|
||||
|
||||
assert _run_case("LocalAI")["llm_name"] == "model___LocalAI"
|
||||
assert _run_case("HuggingFace")["llm_name"] == "model___HuggingFace"
|
||||
assert _run_case("OpenAI-API-Compatible")["llm_name"] == "model___OpenAI-API"
|
||||
assert _run_case("VLLM")["llm_name"] == "model___VLLM"
|
||||
|
||||
spark_chat = _run_case("XunFei Spark", extra={"spark_api_password": "spark-pass"})
|
||||
assert spark_chat["api_key"] == "spark-pass"
|
||||
spark_tts = _run_case(
|
||||
"XunFei Spark",
|
||||
model_type=module.LLMType.TTS.value,
|
||||
extra={"spark_app_id": "app", "spark_api_secret": "secret", "spark_api_key": "key"},
|
||||
)
|
||||
assert json.loads(spark_tts["api_key"]) == {
|
||||
"spark_app_id": "app",
|
||||
"spark_api_secret": "secret",
|
||||
"spark_api_key": "key",
|
||||
}
|
||||
|
||||
assert json.loads(_run_case("BaiduYiyan", extra={"yiyan_ak": "ak", "yiyan_sk": "sk"})["api_key"]) == {"yiyan_ak": "ak", "yiyan_sk": "sk"}
|
||||
assert json.loads(_run_case("Fish Audio", extra={"fish_audio_ak": "ak", "fish_audio_refid": "rid"})["api_key"]) == {"fish_audio_ak": "ak", "fish_audio_refid": "rid"}
|
||||
assert json.loads(
|
||||
_run_case("Google Cloud", extra={"google_project_id": "pid", "google_region": "us", "google_service_account_key": "sak"})["api_key"]
|
||||
) == {
|
||||
"google_project_id": "pid",
|
||||
"google_region": "us",
|
||||
"google_service_account_key": "sak",
|
||||
}
|
||||
assert json.loads(_run_case("Azure-OpenAI", extra={"api_key": "real-key", "api_version": "2024-01-01"})["api_key"]) == {
|
||||
"api_key": "real-key",
|
||||
"api_version": "2024-01-01",
|
||||
}
|
||||
assert json.loads(_run_case("OpenRouter", extra={"api_key": "or-key", "provider_order": "a,b"})["api_key"]) == {
|
||||
"api_key": "or-key",
|
||||
"provider_order": "a,b",
|
||||
}
|
||||
assert json.loads(_run_case("MinerU", extra={"api_key": "m-key", "provider_order": "p1"})["api_key"]) == {
|
||||
"api_key": "m-key",
|
||||
"provider_order": "p1",
|
||||
}
|
||||
assert json.loads(_run_case("PaddleOCR", extra={"api_key": "p-key", "provider_order": "p2"})["api_key"]) == {
|
||||
"api_key": "p-key",
|
||||
"provider_order": "p2",
|
||||
}
|
||||
|
||||
tencent_req = {
|
||||
"llm_factory": "Tencent Cloud",
|
||||
"llm_name": "model",
|
||||
"model_type": module.LLMType.CHAT.value,
|
||||
"tencent_cloud_sid": "sid",
|
||||
"tencent_cloud_sk": "sk",
|
||||
}
|
||||
|
||||
async def _tencent_request_json():
|
||||
return tencent_req
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", _tencent_request_json)
|
||||
delegated = {}
|
||||
|
||||
async def _fake_set_api_key():
|
||||
delegated["api_key"] = tencent_req.get("api_key")
|
||||
return {"code": 0, "data": "delegated"}
|
||||
|
||||
monkeypatch.setattr(module, "set_api_key", _fake_set_api_key)
|
||||
res = _run(module.add_llm())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] == "delegated"
|
||||
assert json.loads(delegated["api_key"]) == {"tencent_cloud_sid": "sid", "tencent_cloud_sk": "sk"}
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
|
||||
async def _wait_for(coro, *_args, **_kwargs):
|
||||
return await coro
|
||||
|
||||
async def _to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(module.asyncio, "wait_for", _wait_for)
|
||||
monkeypatch.setattr(module.asyncio, "to_thread", _to_thread)
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_allowed_llm_factories",
|
||||
lambda: [
|
||||
SimpleNamespace(name=name)
|
||||
for name in [
|
||||
"FEmbFail",
|
||||
"FEmbPass",
|
||||
"FChatFail",
|
||||
"FChatPass",
|
||||
"FRKey",
|
||||
"FRFail",
|
||||
"FImgFail",
|
||||
"FTTSFail",
|
||||
"FOcrFail",
|
||||
"FSttFail",
|
||||
"FUnknown",
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
class _EmbeddingFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def encode(self, _texts):
|
||||
return [[]], 1
|
||||
|
||||
class _EmbeddingPass:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def encode(self, _texts):
|
||||
return [[0.5]], 1
|
||||
|
||||
class _ChatFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
async def async_chat(self, *_args, **_kwargs):
|
||||
return "**ERROR**: chat failed", 0
|
||||
|
||||
async def async_chat_streamly(self, *_args, **_kwargs):
|
||||
yield "**ERROR**: chat failed"
|
||||
yield 0
|
||||
|
||||
class _ChatPass:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
async def async_chat(self, *_args, **_kwargs):
|
||||
return "ok", 1
|
||||
|
||||
async def async_chat_streamly(self, *_args, **_kwargs):
|
||||
yield "ok"
|
||||
yield 1
|
||||
|
||||
class _RerankFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def similarity(self, *_args, **_kwargs):
|
||||
return [], 1
|
||||
|
||||
class _CvFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def describe(self, _image_data):
|
||||
return "**ERROR**: image failed", 0
|
||||
|
||||
class _TTSFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def tts(self, _text):
|
||||
raise RuntimeError("tts failed")
|
||||
|
||||
class _OcrFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, _img):
|
||||
return None
|
||||
|
||||
class _SttFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def transcribe(self, _audio):
|
||||
return "", 0
|
||||
|
||||
rag_llm_mod = _rag_llm_module()
|
||||
monkeypatch.setattr(rag_llm_mod, "EmbeddingModel", {"FEmbFail": _EmbeddingFail, "FEmbPass": _EmbeddingPass})
|
||||
monkeypatch.setattr(rag_llm_mod, "ChatModel", {"FChatFail": _ChatFail, "FChatPass": _ChatPass})
|
||||
monkeypatch.setattr(rag_llm_mod, "RerankModel", {"FRFail": _RerankFail, "FRKey": _RerankFail})
|
||||
monkeypatch.setattr(rag_llm_mod, "CvModel", {"FImgFail": _CvFail})
|
||||
monkeypatch.setattr(rag_llm_mod, "TTSModel", {"FTTSFail": _TTSFail})
|
||||
monkeypatch.setattr(rag_llm_mod, "OcrModel", {"FOcrFail": _OcrFail})
|
||||
monkeypatch.setattr(rag_llm_mod, "Seq2txtModel", {"FSttFail": _SttFail})
|
||||
|
||||
saves = []
|
||||
monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, _payload: False)
|
||||
monkeypatch.setattr(module.TenantLLMService, "save", lambda **kwargs: saves.append(kwargs) or True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
module.LLMService,
|
||||
"query",
|
||||
lambda **kwargs: [] if kwargs.get("llm_factory") == "FUnknown" else [
|
||||
_LLMRow(llm_name="m", fid=kwargs.get("llm_factory"), model_type=kwargs.get("model_type", module.LLMType.CHAT.value), max_tokens=4096)
|
||||
],
|
||||
)
|
||||
|
||||
_set_request_json(monkeypatch, module, {"llm_factory": "FUnknown", "llm_name": "m", "model_type": "unknown"})
|
||||
with pytest.raises(RuntimeError, match="Unknown model type: unknown"):
|
||||
_run(module.add_llm())
|
||||
|
||||
cases = [
|
||||
("FEmbFail", module.LLMType.EMBEDDING.value, 400, None, "embedding model"),
|
||||
("FEmbPass", module.LLMType.EMBEDDING.value, 0, True, ""),
|
||||
("FChatFail", module.LLMType.CHAT.value, 400, None, "No valid response received"),
|
||||
("FChatPass", module.LLMType.CHAT.value, 0, True, ""),
|
||||
("FRFail", module.LLMType.RERANK.value, 400, None, "Not known"),
|
||||
("FImgFail", module.LLMType.IMAGE2TEXT.value, 400, None, "image failed"),
|
||||
("FTTSFail", module.LLMType.TTS.value, 400, None, "tts"),
|
||||
("FOcrFail", module.LLMType.OCR.value, 400, None, "ocr"),
|
||||
("FSttFail", module.LLMType.SPEECH2TEXT.value, 0, True, ""),
|
||||
]
|
||||
for factory, model_type, expected_code, expected_data, expected_fragment in cases:
|
||||
_set_request_json(monkeypatch, module, {"llm_factory": factory, "llm_name": "m", "model_type": model_type, "api_key": "key"})
|
||||
res = _run(module.add_llm())
|
||||
assert res["code"] == expected_code
|
||||
if expected_data is not None:
|
||||
assert res["data"] is expected_data
|
||||
if expected_fragment:
|
||||
assert expected_fragment.lower() in res["message"].lower()
|
||||
|
||||
assert any(item["llm_factory"] == "FEmbPass" for item in saves)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_llm_factories_live_auth_contract():
|
||||
llm_url = f"{HOST_ADDRESS}/{VERSION}/llm/factories"
|
||||
invalid_auth_cases = [
|
||||
(None, 401, "<Unauthorized '401: Unauthorized'>"),
|
||||
(RAGFlowWebApiAuth("invalid-token"), 401, "<Unauthorized '401: Unauthorized'>"),
|
||||
]
|
||||
for auth_obj, expected_code, expected_message in invalid_auth_cases:
|
||||
res = requests.get(llm_url, auth=auth_obj, timeout=30)
|
||||
assert res.status_code == 401
|
||||
payload = res.json()
|
||||
assert payload["code"] == expected_code, payload
|
||||
assert payload["message"] == expected_message, payload
|
||||
|
||||
ok_res = requests.get(llm_url, auth=RAGFlowWebApiAuth(login()), timeout=30)
|
||||
assert ok_res.status_code == 200
|
||||
ok_payload = ok_res.json()
|
||||
assert ok_payload["code"] == 0, ok_payload
|
||||
assert isinstance(ok_payload["data"], list), ok_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_llm_list_live_auth_contract():
|
||||
llm_url = f"{HOST_ADDRESS}/{VERSION}/llm/list"
|
||||
invalid_auth_cases = [
|
||||
(None, 401, "<Unauthorized '401: Unauthorized'>"),
|
||||
(RAGFlowWebApiAuth("invalid-token"), 401, "<Unauthorized '401: Unauthorized'>"),
|
||||
]
|
||||
for auth_obj, expected_code, expected_message in invalid_auth_cases:
|
||||
res = requests.get(llm_url, auth=auth_obj, timeout=30)
|
||||
assert res.status_code == 401
|
||||
payload = res.json()
|
||||
assert payload["code"] == expected_code, payload
|
||||
assert payload["message"] == expected_message, payload
|
||||
|
||||
ok_res = requests.get(llm_url, auth=RAGFlowWebApiAuth(login()), timeout=30)
|
||||
assert ok_res.status_code == 200
|
||||
ok_payload = ok_res.json()
|
||||
assert ok_payload["code"] == 0, ok_payload
|
||||
assert isinstance(ok_payload["data"], dict), ok_payload
|
||||
151
test/testcases/restful_api/test_message_routes_unit.py
Normal file
151
test/testcases/restful_api/test_message_routes_unit.py
Normal file
@@ -0,0 +1,151 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import inspect
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _co():
|
||||
return self._value
|
||||
|
||||
return _co().__await__()
|
||||
|
||||
|
||||
class _DummyArgs(dict):
|
||||
def getlist(self, key):
|
||||
value = self.get(key)
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
return [value]
|
||||
|
||||
|
||||
class _DummyMemoryApiService:
|
||||
async def add_message(self, *_args, **_kwargs):
|
||||
return True, "ok"
|
||||
|
||||
async def get_messages(self, *_args, **_kwargs):
|
||||
return []
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _load_memory_routes_module(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
common_pkg = ModuleType("common")
|
||||
common_pkg.__path__ = [str(repo_root / "common")]
|
||||
monkeypatch.setitem(sys.modules, "common", common_pkg)
|
||||
|
||||
apps_mod = ModuleType("api.apps")
|
||||
apps_mod.__path__ = [str(repo_root / "api" / "apps")]
|
||||
apps_mod.current_user = SimpleNamespace(id="user-1")
|
||||
apps_mod.login_required = lambda func: func
|
||||
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
|
||||
|
||||
services_mod = ModuleType("api.apps.services")
|
||||
services_mod.memory_api_service = _DummyMemoryApiService()
|
||||
monkeypatch.setitem(sys.modules, "api.apps.services", services_mod)
|
||||
|
||||
module_name = "test_message_routes_unit_module"
|
||||
module_path = repo_root / "api" / "apps" / "restful_apis" / "memory_api.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
monkeypatch.setitem(sys.modules, module_name, module)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload)))
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_add_message_partial_failure_branch(monkeypatch):
|
||||
module = _load_memory_routes_module(monkeypatch)
|
||||
|
||||
_set_request_json(
|
||||
monkeypatch,
|
||||
module,
|
||||
{
|
||||
"memory_id": ["memory-1"],
|
||||
"agent_id": "agent-1",
|
||||
"session_id": "session-1",
|
||||
"user_input": "hello",
|
||||
"agent_response": "world",
|
||||
},
|
||||
)
|
||||
|
||||
async def _add_message(_memory_ids, _message_dict):
|
||||
return False, "cannot enqueue"
|
||||
|
||||
monkeypatch.setattr(module.memory_api_service, "add_message", _add_message)
|
||||
|
||||
res = _run(inspect.unwrap(module.add_message)())
|
||||
assert res["code"] == module.RetCode.SERVER_ERROR, res
|
||||
assert "Some messages failed to add" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_get_messages_csv_and_missing_memory_ids(monkeypatch):
|
||||
module = _load_memory_routes_module(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(args=_DummyArgs({})))
|
||||
res = _run(inspect.unwrap(module.get_messages)())
|
||||
assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
|
||||
assert "memory_ids is required." in res["message"], res
|
||||
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"request",
|
||||
SimpleNamespace(args=_DummyArgs({"memory_id": "m1,m2", "agent_id": "a1", "session_id": "s1", "limit": "5"})),
|
||||
)
|
||||
|
||||
async def _get_messages(memory_ids, agent_id, session_id, limit):
|
||||
assert memory_ids == ["m1", "m2"]
|
||||
assert agent_id == "a1"
|
||||
assert session_id == "s1"
|
||||
assert limit == 5
|
||||
return [{"message_id": 1}]
|
||||
|
||||
monkeypatch.setattr(module.memory_api_service, "get_messages", _get_messages)
|
||||
res = _run(inspect.unwrap(module.get_messages)())
|
||||
assert res["code"] == module.RetCode.SUCCESS, res
|
||||
assert isinstance(res["data"], list), res
|
||||
@@ -16,8 +16,7 @@
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import pytest
|
||||
import requests
|
||||
from test.testcases.configs import HOST_ADDRESS, INVALID_API_TOKEN, VERSION
|
||||
from test.testcases.configs import INVALID_API_TOKEN
|
||||
from test.testcases.restful_api.helpers.client import RestClient
|
||||
from test.testcases.utils import wait_for
|
||||
|
||||
@@ -378,12 +377,8 @@ def test_deleted_chunks_batch_not_in_retrieval_contract(rest_client, create_docu
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_related_questions_contract(auth, rest_client, rest_client_noauth):
|
||||
tokens_res = requests.get(
|
||||
f"{HOST_ADDRESS}/api/{VERSION}/system/tokens",
|
||||
headers={"Authorization": auth},
|
||||
timeout=30,
|
||||
)
|
||||
def test_related_questions_contract(rest_client, rest_client_noauth):
|
||||
tokens_res = rest_client.get("/system/tokens")
|
||||
assert tokens_res.status_code == 200, tokens_res.text
|
||||
tokens_payload = tokens_res.json()
|
||||
assert tokens_payload["code"] == 0, tokens_payload
|
||||
|
||||
@@ -15,14 +15,36 @@
|
||||
#
|
||||
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
|
||||
from test.testcases.configs import INVALID_API_TOKEN, INVALID_ID_32, SESSION_WITH_CHAT_NAME_LIMIT
|
||||
from test.testcases.restful_api.helpers.client import RestClient
|
||||
from test.testcases.utils import is_sorted
|
||||
|
||||
|
||||
def _sse_events(response_text: str) -> list[str]:
|
||||
return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
|
||||
|
||||
|
||||
def _session_names(payload):
|
||||
return [session["name"] for session in payload["data"]]
|
||||
|
||||
|
||||
def _seed_sessions(rest_client, create_chat, prefix, count=5):
|
||||
chat_id = create_chat(f"{prefix}_chat")
|
||||
sessions = []
|
||||
for index in range(count):
|
||||
name = f"{prefix}_{index}"
|
||||
res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": name})
|
||||
assert res.status_code == 200, (prefix, index, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, (prefix, index, payload)
|
||||
sessions.append(payload["data"])
|
||||
return chat_id, sessions
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_session_crud_cycle(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_session_crud_chat")
|
||||
@@ -100,6 +122,475 @@ def test_session_update_blocks_messages_and_reference(rest_client, create_chat):
|
||||
assert "`reference` cannot be changed." in ref_payload["message"], ref_payload
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_session_create_requires_auth_and_invalid_chat_contract():
|
||||
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
|
||||
res = client.post("/chats/chat_id/sessions", json={"name": "x"})
|
||||
assert res.status_code == 401, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, (scenario_name, payload)
|
||||
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_create_validation_and_deleted_chat_contract(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_session_create_contract")
|
||||
|
||||
empty_path_res = rest_client.post("/chats//sessions", json={"name": "valid_name"})
|
||||
assert empty_path_res.status_code == 200
|
||||
empty_path_payload = empty_path_res.json()
|
||||
assert empty_path_payload["code"] == 100, empty_path_payload
|
||||
assert empty_path_payload["message"] == "<MethodNotAllowed '405: Method Not Allowed'>", empty_path_payload
|
||||
|
||||
invalid_chat_res = rest_client.post("/chats/invalid_chat_assistant_id/sessions", json={"name": "valid_name"})
|
||||
assert invalid_chat_res.status_code == 200
|
||||
invalid_chat_payload = invalid_chat_res.json()
|
||||
assert invalid_chat_payload["code"] == 109, invalid_chat_payload
|
||||
assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload
|
||||
|
||||
for scenario_name, payload in (
|
||||
("valid", {"name": "valid_name"}),
|
||||
("empty", {"name": ""}),
|
||||
("space", {"name": " "}),
|
||||
("numeric", {"name": 1}),
|
||||
):
|
||||
res = rest_client.post(f"/chats/{chat_id}/sessions", json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
if scenario_name == "valid":
|
||||
assert body["code"] == 0, (scenario_name, body)
|
||||
assert body["data"]["name"] == "valid_name", (scenario_name, body)
|
||||
assert body["data"]["chat_id"] == chat_id, (scenario_name, body)
|
||||
else:
|
||||
assert body["code"] == 102, (scenario_name, body)
|
||||
assert body["message"] == "`name` can not be empty.", (scenario_name, body)
|
||||
|
||||
duplicate_first = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "duplicated_name"}).json()
|
||||
duplicate_second = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "duplicated_name"}).json()
|
||||
assert duplicate_first["code"] == 0, duplicate_first
|
||||
assert duplicate_second["code"] == 0, duplicate_second
|
||||
assert duplicate_second["data"]["name"] == "duplicated_name", duplicate_second
|
||||
|
||||
upper_case = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "CASE INSENSITIVE"}).json()
|
||||
lower_case = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "case insensitive"}).json()
|
||||
assert upper_case["code"] == 0, upper_case
|
||||
assert lower_case["code"] == 0, lower_case
|
||||
assert upper_case["data"]["name"] == "CASE INSENSITIVE", upper_case
|
||||
assert lower_case["data"]["name"] == "case insensitive", lower_case
|
||||
|
||||
long_name = "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1)
|
||||
long_name_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": long_name})
|
||||
assert long_name_res.status_code == 200
|
||||
long_name_payload = long_name_res.json()
|
||||
assert long_name_payload["code"] == 0, long_name_payload
|
||||
assert long_name_payload["data"]["name"] == long_name[:SESSION_WITH_CHAT_NAME_LIMIT], long_name_payload
|
||||
|
||||
delete_res = rest_client.delete("/chats", json={"ids": [chat_id]})
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 0, delete_payload
|
||||
|
||||
create_after_delete = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "after_delete"})
|
||||
assert create_after_delete.status_code == 200
|
||||
create_after_delete_payload = create_after_delete.json()
|
||||
assert create_after_delete_payload["code"] == 109, create_after_delete_payload
|
||||
assert create_after_delete_payload["message"] == "No authorization.", create_after_delete_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_create_concurrent_contract(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_session_create_concurrent")
|
||||
|
||||
def _create(index):
|
||||
return rest_client.post(f"/chats/{chat_id}/sessions", json={"name": f"session create {index}"}).json()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
results = list(executor.map(_create, range(20)))
|
||||
|
||||
assert len(results) == 20, results
|
||||
assert all(result["code"] == 0 for result in results), results
|
||||
|
||||
list_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30})
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert len(list_payload["data"]) == 20, list_payload
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_session_delete_requires_auth_and_invalid_target_contract(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_session_delete_auth")
|
||||
create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_a"})
|
||||
assert create_res.status_code == 200
|
||||
session_id = create_res.json()["data"]["id"]
|
||||
|
||||
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
|
||||
res = client.delete(f"/chats/{chat_id}/sessions", json={"ids": [session_id]})
|
||||
assert res.status_code == 401, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, (scenario_name, payload)
|
||||
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
|
||||
|
||||
invalid_chat_res = rest_client.delete("/chats/invalid_chat_assistant_id/sessions", json={"ids": [session_id]})
|
||||
assert invalid_chat_res.status_code == 200
|
||||
invalid_chat_payload = invalid_chat_res.json()
|
||||
assert invalid_chat_payload["code"] == 109, invalid_chat_payload
|
||||
assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_delete_basic_scenarios(rest_client, create_chat):
|
||||
cases = [
|
||||
("none payload", None, 0, 5, {}),
|
||||
("invalid only", {"ids": ["invalid_id"]}, 102, 5, "The chat doesn't own the session invalid_id"),
|
||||
("not json", "not json", 100, 5, "<BadRequest '400: Bad Request'>"),
|
||||
("single id", lambda sessions: {"ids": [sessions[0]["id"]]}, 0, 4, True),
|
||||
("all ids", lambda sessions: {"ids": [session["id"] for session in sessions]}, 0, 0, True),
|
||||
("delete all", {"delete_all": True}, 0, 0, True),
|
||||
("empty ids", {"ids": []}, 0, 5, {}),
|
||||
]
|
||||
|
||||
for scenario_name, payload, expected_code, expected_remaining, expected_data in cases:
|
||||
chat_id, sessions = _seed_sessions(rest_client, create_chat, f"delete_basic_{scenario_name.replace(' ', '_')}")
|
||||
if callable(payload):
|
||||
payload = payload(sessions)
|
||||
if scenario_name == "not json":
|
||||
res = rest_client.delete(
|
||||
f"/chats/{chat_id}/sessions",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=payload,
|
||||
)
|
||||
else:
|
||||
res = rest_client.delete(f"/chats/{chat_id}/sessions", json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == expected_code, (scenario_name, body)
|
||||
if expected_code == 0:
|
||||
assert body["data"] == expected_data, (scenario_name, body)
|
||||
else:
|
||||
assert body["message"] == expected_data, (scenario_name, body)
|
||||
|
||||
list_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30})
|
||||
assert list_res.status_code == 200, (scenario_name, list_res.text)
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, (scenario_name, list_payload)
|
||||
assert len(list_payload["data"]) == expected_remaining, (scenario_name, list_payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_delete_error_and_repeat_contract(rest_client, create_chat):
|
||||
partial_cases = [
|
||||
("invalid first", lambda sessions: {"ids": ["invalid_id"] + [session["id"] for session in sessions]}),
|
||||
("invalid middle", lambda sessions: {"ids": [sessions[0]["id"], "invalid_id", *[session["id"] for session in sessions[1:]]]}),
|
||||
("invalid last", lambda sessions: {"ids": [session["id"] for session in sessions] + ["invalid_id"]}),
|
||||
]
|
||||
for scenario_name, payload_builder in partial_cases:
|
||||
chat_id, sessions = _seed_sessions(rest_client, create_chat, f"delete_partial_{scenario_name.replace(' ', '_')}")
|
||||
res = rest_client.delete(f"/chats/{chat_id}/sessions", json=payload_builder(sessions))
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, (scenario_name, payload)
|
||||
assert payload["data"]["success_count"] == len(sessions), (scenario_name, payload)
|
||||
assert payload["data"]["errors"] == ["The chat doesn't own the session invalid_id"], (scenario_name, payload)
|
||||
remaining = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30}).json()
|
||||
assert remaining["code"] == 0, (scenario_name, remaining)
|
||||
assert remaining["data"] == [], (scenario_name, remaining)
|
||||
|
||||
duplicate_chat_id, duplicate_sessions = _seed_sessions(rest_client, create_chat, "delete_duplicate")
|
||||
duplicate_id = duplicate_sessions[0]["id"]
|
||||
duplicate_res = rest_client.delete(f"/chats/{duplicate_chat_id}/sessions", json={"ids": [duplicate_id, duplicate_id]})
|
||||
assert duplicate_res.status_code == 200
|
||||
duplicate_payload = duplicate_res.json()
|
||||
assert duplicate_payload["code"] == 0, duplicate_payload
|
||||
assert duplicate_payload["data"]["success_count"] == 1, duplicate_payload
|
||||
assert duplicate_payload["data"]["errors"] == [f"Duplicate session ids: {duplicate_id}"], duplicate_payload
|
||||
|
||||
repeated_chat_id, repeated_sessions = _seed_sessions(rest_client, create_chat, "delete_repeated")
|
||||
repeated_ids = [session["id"] for session in repeated_sessions]
|
||||
first_res = rest_client.delete(f"/chats/{repeated_chat_id}/sessions", json={"ids": repeated_ids})
|
||||
assert first_res.status_code == 200
|
||||
first_payload = first_res.json()
|
||||
assert first_payload["code"] == 0, first_payload
|
||||
assert first_payload["data"] is True, first_payload
|
||||
|
||||
second_res = rest_client.delete(f"/chats/{repeated_chat_id}/sessions", json={"ids": repeated_ids})
|
||||
assert second_res.status_code == 200
|
||||
second_payload = second_res.json()
|
||||
assert second_payload["code"] == 102, second_payload
|
||||
for session_id in repeated_ids:
|
||||
assert f"The chat doesn't own the session {session_id}" in second_payload["message"], second_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_delete_concurrent_and_bulk_contract(rest_client, create_chat):
|
||||
concurrent_chat_id, concurrent_sessions = _seed_sessions(rest_client, create_chat, "delete_concurrent", count=20)
|
||||
|
||||
def _delete(session):
|
||||
return rest_client.delete(f"/chats/{concurrent_chat_id}/sessions", json={"ids": [session["id"]]}).json()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
results = list(executor.map(_delete, concurrent_sessions))
|
||||
|
||||
assert len(results) == 20, results
|
||||
assert all(result["code"] == 0 for result in results), results
|
||||
assert all(result["data"] is True for result in results), results
|
||||
|
||||
list_after_concurrent = rest_client.get(f"/chats/{concurrent_chat_id}/sessions", params={"page_size": 30}).json()
|
||||
assert list_after_concurrent["code"] == 0, list_after_concurrent
|
||||
assert list_after_concurrent["data"] == [], list_after_concurrent
|
||||
|
||||
bulk_chat_id, bulk_sessions = _seed_sessions(rest_client, create_chat, "delete_bulk", count=100)
|
||||
bulk_res = rest_client.delete(
|
||||
f"/chats/{bulk_chat_id}/sessions",
|
||||
json={"ids": [session["id"] for session in bulk_sessions]},
|
||||
)
|
||||
assert bulk_res.status_code == 200
|
||||
bulk_payload = bulk_res.json()
|
||||
assert bulk_payload["code"] == 0, bulk_payload
|
||||
assert bulk_payload["data"] is True, bulk_payload
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_session_list_requires_auth_and_invalid_target_contract():
|
||||
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
|
||||
res = client.get("/chats/chat_id/sessions")
|
||||
assert res.status_code == 401, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, (scenario_name, payload)
|
||||
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_list_filter_and_deleted_chat_contract(rest_client, create_chat):
|
||||
chat_id, sessions = _seed_sessions(rest_client, create_chat, "list_filter")
|
||||
session_ids = [session["id"] for session in sessions]
|
||||
session_names = [session["name"] for session in sessions]
|
||||
|
||||
default_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30})
|
||||
assert default_res.status_code == 200
|
||||
default_payload = default_res.json()
|
||||
assert default_payload["code"] == 0, default_payload
|
||||
assert len(default_payload["data"]) == 5, default_payload
|
||||
|
||||
for scenario_name, params, expected_names in (
|
||||
("id none", {"id": None, "page_size": 30}, session_names),
|
||||
("id empty", {"id": "", "page_size": 30}, session_names),
|
||||
("valid id", {"id": session_ids[0], "page_size": 30}, [session_names[0]]),
|
||||
("unknown id", {"id": "unknown", "page_size": 30}, []),
|
||||
("name none", {"name": None, "page_size": 30}, session_names),
|
||||
("name empty", {"name": "", "page_size": 30}, session_names),
|
||||
("name exact", {"name": session_names[1], "page_size": 30}, [session_names[1]]),
|
||||
("name unknown", {"name": "unknown", "page_size": 30}, []),
|
||||
("name and id match", {"id": session_ids[0], "name": session_names[0], "page_size": 30}, [session_names[0]]),
|
||||
("name and id mismatch", {"id": session_ids[0], "name": "session_with_chat_assistant_100", "page_size": 30}, []),
|
||||
("name and invalid id", {"id": "id", "name": session_names[0], "page_size": 30}, []),
|
||||
("invalid params ignored", {"a": "b", "page_size": 30}, session_names),
|
||||
):
|
||||
res = rest_client.get(f"/chats/{chat_id}/sessions", params=params)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, (scenario_name, payload)
|
||||
assert set(_session_names(payload)) == set(expected_names), (scenario_name, payload)
|
||||
|
||||
invalid_chat_res = rest_client.get(f"/chats/{INVALID_ID_32}/sessions")
|
||||
assert invalid_chat_res.status_code == 200
|
||||
invalid_chat_payload = invalid_chat_res.json()
|
||||
assert invalid_chat_payload["code"] == 109, invalid_chat_payload
|
||||
assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload
|
||||
|
||||
delete_chat_res = rest_client.delete("/chats", json={"ids": [chat_id]})
|
||||
assert delete_chat_res.status_code == 200
|
||||
delete_chat_payload = delete_chat_res.json()
|
||||
assert delete_chat_payload["code"] == 0, delete_chat_payload
|
||||
|
||||
deleted_list_res = rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30})
|
||||
assert deleted_list_res.status_code == 200
|
||||
deleted_list_payload = deleted_list_res.json()
|
||||
assert deleted_list_payload["code"] == 109, deleted_list_payload
|
||||
assert deleted_list_payload["message"] == "No authorization.", deleted_list_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_list_page_and_sort_contract(rest_client, create_chat):
|
||||
chat_id, sessions = _seed_sessions(rest_client, create_chat, "list_page_sort")
|
||||
created_names = [session["name"] for session in sessions]
|
||||
descending_names = list(reversed(created_names))
|
||||
|
||||
page_cases = [
|
||||
("page none", {"page": None, "page_size": 2}, 0, 2, ""),
|
||||
("page zero", {"page": 0, "page_size": 2}, 0, 2, ""),
|
||||
("page two", {"page": 2, "page_size": 2}, 0, 2, ""),
|
||||
("page three", {"page": 3, "page_size": 2}, 0, 1, ""),
|
||||
("page string", {"page": "3", "page_size": 2}, 0, 1, ""),
|
||||
("page negative", {"page": -1, "page_size": 2}, 100, 0, "ProgrammingError(1064"),
|
||||
("page alpha", {"page": "a", "page_size": 2}, 100, 0, "ValueError(\"invalid literal for int() with base 10: 'a'\")"),
|
||||
("page_size none", {"page_size": None}, 0, 5, ""),
|
||||
("page_size zero", {"page_size": 0}, 0, 0, ""),
|
||||
("page_size one", {"page_size": 1}, 0, 1, ""),
|
||||
("page_size six", {"page_size": 6}, 0, 5, ""),
|
||||
("page_size negative", {"page_size": -1}, 0, 5, ""),
|
||||
("page_size alpha", {"page_size": "a"}, 100, 0, "ValueError(\"invalid literal for int() with base 10: 'a'\")"),
|
||||
]
|
||||
for scenario_name, params, expected_code, expected_count, expected_message in page_cases:
|
||||
res = rest_client.get(f"/chats/{chat_id}/sessions", params=params)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == expected_code, (scenario_name, payload)
|
||||
if expected_code == 0:
|
||||
assert len(payload["data"]) == expected_count, (scenario_name, payload)
|
||||
else:
|
||||
assert expected_message in payload["message"], (scenario_name, payload)
|
||||
|
||||
sort_cases = [
|
||||
("orderby none", {"orderby": None, "page_size": 30}, "create_time", True, descending_names, ""),
|
||||
("orderby create", {"orderby": "create_time", "page_size": 30}, "create_time", True, descending_names, ""),
|
||||
("orderby update", {"orderby": "update_time", "page_size": 30}, "update_time", True, descending_names, ""),
|
||||
("orderby name ascending", {"orderby": "name", "desc": "False", "page_size": 30}, "name", False, created_names, ""),
|
||||
("orderby unknown", {"orderby": "unknown", "page_size": 30}, None, None, None, "AttributeError(\"type object 'Conversation' has no attribute 'unknown'\")"),
|
||||
("desc none", {"desc": None, "page_size": 30}, "create_time", True, descending_names, ""),
|
||||
("desc true", {"desc": "true", "page_size": 30}, "create_time", True, descending_names, ""),
|
||||
("desc True", {"desc": "True", "page_size": 30}, "create_time", True, descending_names, ""),
|
||||
("desc false", {"desc": "false", "page_size": 30}, "create_time", False, created_names, ""),
|
||||
("desc False", {"desc": "False", "page_size": 30}, "create_time", False, created_names, ""),
|
||||
("desc false update_time", {"desc": "False", "orderby": "update_time", "page_size": 30}, "update_time", False, created_names, ""),
|
||||
("desc unknown", {"desc": "unknown", "page_size": 30}, "create_time", True, descending_names, ""),
|
||||
]
|
||||
for scenario_name, params, field, descending, expected_names, expected_message in sort_cases:
|
||||
res = rest_client.get(f"/chats/{chat_id}/sessions", params=params)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
expected_code = 0 if expected_names is not None else 100
|
||||
assert payload["code"] == expected_code, (scenario_name, payload)
|
||||
if expected_code == 0:
|
||||
assert is_sorted(payload["data"], field, descending), (scenario_name, payload)
|
||||
assert _session_names(payload) == expected_names, (scenario_name, payload)
|
||||
else:
|
||||
assert expected_message in payload["message"], (scenario_name, payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_list_concurrent_contract(rest_client, create_chat):
|
||||
chat_id, _sessions = _seed_sessions(rest_client, create_chat, "list_concurrent")
|
||||
|
||||
def _list(_):
|
||||
return rest_client.get(f"/chats/{chat_id}/sessions", params={"page_size": 30}).json()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
results = list(executor.map(_list, range(10)))
|
||||
|
||||
assert len(results) == 10, results
|
||||
assert all(result["code"] == 0 for result in results), results
|
||||
assert all(len(result["data"]) == 5 for result in results), results
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_session_update_requires_auth_and_invalid_target_contract(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_session_update_auth")
|
||||
create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_update_auth"})
|
||||
assert create_res.status_code == 200
|
||||
session_id = create_res.json()["data"]["id"]
|
||||
|
||||
for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))):
|
||||
res = client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"name": "x"})
|
||||
assert res.status_code == 401, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, (scenario_name, payload)
|
||||
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, payload)
|
||||
|
||||
invalid_chat_res = rest_client.patch(f"/chats/{INVALID_ID_32}/sessions/{session_id}", json={"name": "x"})
|
||||
assert invalid_chat_res.status_code == 200
|
||||
invalid_chat_payload = invalid_chat_res.json()
|
||||
assert invalid_chat_payload["code"] == 109, invalid_chat_payload
|
||||
assert invalid_chat_payload["message"] == "No authorization.", invalid_chat_payload
|
||||
|
||||
empty_session_res = rest_client.patch(f"/chats/{chat_id}/sessions/", json={"name": "x"})
|
||||
assert empty_session_res.status_code == 200
|
||||
empty_session_payload = empty_session_res.json()
|
||||
assert empty_session_payload["code"] == 100, empty_session_payload
|
||||
assert empty_session_payload["message"] == "<MethodNotAllowed '405: Method Not Allowed'>", empty_session_payload
|
||||
|
||||
invalid_session_res = rest_client.patch(f"/chats/{chat_id}/sessions/invalid_session_id", json={"name": "x"})
|
||||
assert invalid_session_res.status_code == 200
|
||||
invalid_session_payload = invalid_session_res.json()
|
||||
assert invalid_session_payload["code"] == 102, invalid_session_payload
|
||||
assert invalid_session_payload["message"] == "Session not found!", invalid_session_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_update_name_and_param_contract(rest_client, create_chat):
|
||||
chat_id, sessions = _seed_sessions(rest_client, create_chat, "update_contract")
|
||||
session_id = sessions[0]["id"]
|
||||
|
||||
for scenario_name, payload, expected_code, expected_name_or_message in (
|
||||
("valid", {"name": "valid_name"}, 0, "valid_name"),
|
||||
("empty", {"name": ""}, 102, "`name` can not be empty."),
|
||||
("numeric", {"name": 1}, 102, "`name` can not be empty."),
|
||||
("duplicate", {"name": "duplicated_name"}, 0, "duplicated_name"),
|
||||
("case insensitive upper", {"name": "CASE INSENSITIVE UPDATE"}, 0, "CASE INSENSITIVE UPDATE"),
|
||||
("case insensitive lower", {"name": "case insensitive update"}, 0, "case insensitive update"),
|
||||
("long name", {"name": "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1)}, 0, "a" * SESSION_WITH_CHAT_NAME_LIMIT),
|
||||
):
|
||||
res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == expected_code, (scenario_name, body)
|
||||
if expected_code == 0:
|
||||
assert body["data"]["name"] == expected_name_or_message, (scenario_name, body)
|
||||
else:
|
||||
assert body["message"] == expected_name_or_message, (scenario_name, body)
|
||||
|
||||
unknown_key_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"unknown_key": "unknown_value"})
|
||||
assert unknown_key_res.status_code == 200
|
||||
unknown_key_payload = unknown_key_res.json()
|
||||
assert unknown_key_payload["code"] == 100, unknown_key_payload
|
||||
assert 'Unrecognized field name: "unknown_key"' in unknown_key_payload["message"], unknown_key_payload
|
||||
|
||||
for scenario_name, payload in (("empty payload", {}), ("none payload", None)):
|
||||
res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json=payload)
|
||||
assert res.status_code == 200, (scenario_name, res.text)
|
||||
body = res.json()
|
||||
assert body["code"] == 0, (scenario_name, body)
|
||||
assert body["data"]["id"] == session_id, (scenario_name, body)
|
||||
|
||||
delete_res = rest_client.delete("/chats", json={"ids": [chat_id]})
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 0, delete_payload
|
||||
|
||||
update_after_delete_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"name": "after_delete"})
|
||||
assert update_after_delete_res.status_code == 200
|
||||
update_after_delete_payload = update_after_delete_res.json()
|
||||
assert update_after_delete_payload["code"] == 109, update_after_delete_payload
|
||||
assert update_after_delete_payload["message"] == "No authorization.", update_after_delete_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_update_repeated_and_concurrent_contract(rest_client, create_chat):
|
||||
chat_id, sessions = _seed_sessions(rest_client, create_chat, "update_repeated")
|
||||
session_ids = [session["id"] for session in sessions]
|
||||
|
||||
first_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_ids[0]}", json={"name": "valid_name_1"})
|
||||
assert first_res.status_code == 200
|
||||
assert first_res.json()["code"] == 0, first_res.json()
|
||||
|
||||
second_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_ids[0]}", json={"name": "valid_name_2"})
|
||||
assert second_res.status_code == 200
|
||||
assert second_res.json()["code"] == 0, second_res.json()
|
||||
|
||||
def _update(index):
|
||||
return rest_client.patch(
|
||||
f"/chats/{chat_id}/sessions/{session_ids[index % len(session_ids)]}",
|
||||
json={"name": f"update session test {index}"},
|
||||
).json()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
results = list(executor.map(_update, range(20)))
|
||||
|
||||
assert len(results) == 20, results
|
||||
assert all(result["code"] == 0 for result in results), results
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chat_recommendation_requires_question(rest_client):
|
||||
res = rest_client.post("/chat/recommendation", json={})
|
||||
@@ -120,7 +611,11 @@ def test_related_questions_compatibility_requires_auth(rest_client_noauth):
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert "Authorization is not valid!" in payload["message"], payload
|
||||
assert payload["message"] in {
|
||||
"Authorization is not valid!",
|
||||
'Authentication error: API key is invalid!"',
|
||||
"Authentication error: API key is invalid!",
|
||||
}, payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
@@ -226,7 +721,10 @@ def test_chat_completion_validation_errors(rest_client, create_chat):
|
||||
assert missing_messages.status_code == 200
|
||||
missing_messages_payload = missing_messages.json()
|
||||
assert missing_messages_payload["code"] == 101, missing_messages_payload
|
||||
assert "required argument are missing: messages" in missing_messages_payload["message"], missing_messages_payload
|
||||
assert missing_messages_payload["message"] in {
|
||||
"required argument are missing: messages",
|
||||
"messages: is required",
|
||||
}, missing_messages_payload
|
||||
|
||||
missing_chat_for_session = rest_client.post(
|
||||
"/chat/completions",
|
||||
|
||||
Reference in New Issue
Block a user