Files
ragflow/test/testcases/restful_api/test_chats.py

2020 lines
95 KiB
Python

#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import sys
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from functools import wraps
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
from test.testcases.configs import CHAT_ASSISTANT_NAME_LIMIT, INVALID_API_TOKEN
from test.testcases.restful_api.helpers.client import RestClient
from test.testcases.utils import encode_avatar
from test.testcases.utils.file_utils import create_image_file
# Expected default texts for a chat created without explicit prompt settings.
# NOTE(review): presumably these mirror the server-side defaults of the chat API;
# confirm against the server implementation before changing any character, since
# tests compare them verbatim. Not referenced in this chunk of the file.
DEFAULT_CHAT_EMPTY_RESPONSE = "Sorry! No relevant content was found in the knowledge base!"
DEFAULT_CHAT_PROLOGUE = "Hi! I'm your assistant. What can I do for you?"
# Implicitly concatenated literal; `{knowledge}` is a template placeholder kept as-is.
DEFAULT_CHAT_SYSTEM_PROMPT = (
'You are an intelligent assistant. Please summarize the content of the dataset to answer the question. '
'Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the '
'question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" '
"Answers need to consider chat history.\n"
" Here is the knowledge base:\n"
" {knowledge}\n"
" The above is the knowledge base."
)
def _get_nested(data, path):
    """Follow ``path`` through ``data`` with successive ``[key]`` lookups.

    An empty ``path`` returns ``data`` itself. A missing key or index raises
    whatever the container raises (e.g. ``KeyError`` / ``IndexError``).
    """
    for step in path:
        data = data[step]
    return data
def _chat_names(payload):
    """Return the ``name`` of each chat in a ``GET /chats`` payload, preserving order."""
    chats = payload["data"]["chats"]
    return list(map(lambda chat: chat["name"], chats))
def _reset_chat_batch(rest_client, prefix, count=5):
    """Delete every chat, then create ``count`` empty chats named ``{prefix}_{i}``.

    The wipe step accepts application code 0 or 102 (presumably 102 is what the
    server answers when there is nothing left to delete). Each creation must
    return HTTP 200 and code 0.

    Returns:
        The ids of the newly created chats, in creation order.
    """
    wipe = rest_client.delete("/chats", json={"ids": None, "delete_all": True})
    assert wipe.status_code == 200, wipe.text
    wipe_body = wipe.json()
    assert wipe_body["code"] in (0, 102), wipe_body

    def _create(position):
        response = rest_client.post("/chats", json={"name": f"{prefix}_{position}", "dataset_ids": []})
        assert response.status_code == 200, (prefix, position, response.text)
        body = response.json()
        assert body["code"] == 0, (prefix, position, body)
        return body["data"]["id"]

    return [_create(position) for position in range(count)]
@pytest.mark.p1
class TestChatsAuthorization:
    """Chat creation must be refused when the caller is not authenticated."""

    def test_create_requires_auth(self, rest_client_noauth):
        request_body = {"name": "chat_auth", "dataset_ids": []}
        response = rest_client_noauth.post("/chats", json=request_body)
        assert response.status_code == 401
@pytest.mark.p1
def test_chat_crud_cycle(rest_client, clear_chats):
    """Create, list, get, replace (PUT), patch (PATCH) and delete a single chat."""

    def _ok(response):
        # Every step must succeed at both the HTTP and the application level.
        assert response.status_code == 200
        body = response.json()
        assert body["code"] == 0, body
        return body

    created = _ok(rest_client.post("/chats", json={"name": "restful_chat_crud", "dataset_ids": []}))
    chat_id = created["data"]["id"]

    listed = _ok(rest_client.get("/chats", params={"id": chat_id}))
    assert len(listed["data"]["chats"]) == 1, listed
    assert listed["data"]["chats"][0]["id"] == chat_id, listed

    fetched = _ok(rest_client.get(f"/chats/{chat_id}"))
    assert fetched["data"]["id"] == chat_id, fetched

    replaced = _ok(rest_client.put(f"/chats/{chat_id}", json={"name": "restful_chat_crud_updated", "dataset_ids": []}))
    assert replaced["data"]["name"] == "restful_chat_crud_updated", replaced

    patched = _ok(rest_client.patch(f"/chats/{chat_id}", json={"name": "restful_chat_crud_patched"}))
    assert patched["data"]["name"] == "restful_chat_crud_patched", patched

    removed = _ok(rest_client.delete("/chats", json={"ids": [chat_id]}))
    assert removed["data"]["success_count"] == 1, removed

    # After deletion the id filter must match nothing.
    relisted = _ok(rest_client.get("/chats", params={"id": chat_id}))
    assert relisted["data"]["chats"] == [], relisted
@pytest.mark.p2
@pytest.mark.parametrize(
    "name, expected_fragment",
    [(blank_name, "`name` is required.") for blank_name in ("", " ")],
)
def test_chat_create_name_validation(rest_client, clear_chats, name, expected_fragment):
    """Empty or whitespace-only names are rejected with data-error code 102."""
    response = rest_client.post("/chats", json={"name": name, "dataset_ids": []})
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 102, body
    assert expected_fragment in body["message"], body
@pytest.mark.p2
def test_chat_duplicate_name_validation(rest_client, clear_chats):
    """A second chat with an already-used name is rejected with code 102."""

    def _create():
        response = rest_client.post("/chats", json={"name": "duplicate_chat_name", "dataset_ids": []})
        assert response.status_code == 200
        return response.json()

    original = _create()
    assert original["code"] == 0, original

    duplicate = _create()
    assert duplicate["code"] == 102, duplicate
    assert "Duplicated chat name" in duplicate["message"], duplicate
@pytest.mark.p2
def test_chat_list_pagination(rest_client, clear_chats):
    """A page of size 2 returns two chats while ``total`` still counts all of them."""
    for seq in range(3):
        created = rest_client.post("/chats", json={"name": f"chat_page_{seq}", "dataset_ids": []})
        assert created.status_code == 200
        created_body = created.json()
        assert created_body["code"] == 0, created_body

    query = {"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"}
    page = rest_client.get("/chats", params=query)
    assert page.status_code == 200
    page_body = page.json()
    assert page_body["code"] == 0, page_body
    assert len(page_body["data"]["chats"]) == 2, page_body
    assert page_body["data"]["total"] >= 3, page_body
@pytest.mark.p1
def test_chat_delete_requires_auth():
    """Deleting chats without a valid API token yields HTTP 401 and code 401."""
    clients_by_scenario = {
        "missing token": RestClient(token=None),
        "invalid token": RestClient(token=INVALID_API_TOKEN),
    }
    for scenario_name, client in clients_by_scenario.items():
        response = client.delete("/chats", json={"ids": []})
        assert response.status_code == 401, (scenario_name, response.text)
        body = response.json()
        assert body["code"] == 401, (scenario_name, body)
        assert body["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, body)
@pytest.mark.p2
def test_chat_delete_basic_scenarios(rest_client, clear_chats):
    """Delete by explicit ids, with an empty id list, and with ``delete_all``."""

    def _assert_no_chats_left():
        remaining = rest_client.get("/chats").json()
        assert remaining["code"] == 0, remaining
        assert remaining["data"]["chats"] == [], remaining

    # 1. Explicit ids: every listed chat is removed.
    batch = _reset_chat_batch(rest_client, "delete_basic")
    by_ids = rest_client.delete("/chats", json={"ids": batch})
    assert by_ids.status_code == 200
    by_ids_body = by_ids.json()
    assert by_ids_body["code"] == 0, by_ids_body
    assert by_ids_body["data"]["success_count"] == len(batch), by_ids_body
    _assert_no_chats_left()

    # 2. Empty id list: accepted as a successful no-op.
    noop = rest_client.delete("/chats", json={"ids": []})
    assert noop.status_code == 200
    noop_body = noop.json()
    assert noop_body["code"] == 0, noop_body
    assert noop_body["message"] == "success", noop_body

    # 3. delete_all with ids=None removes the whole batch.
    batch = _reset_chat_batch(rest_client, "delete_all")
    wipe = rest_client.delete("/chats", json={"ids": None, "delete_all": True})
    assert wipe.status_code == 200
    wipe_body = wipe.json()
    assert wipe_body["code"] == 0, wipe_body
    assert wipe_body["data"]["success_count"] == len(batch), wipe_body
    _assert_no_chats_left()
@pytest.mark.p2
def test_chat_delete_error_and_repeat_contract(rest_client, clear_chats):
    """Partial, duplicated, fully-missing and repeated deletions report per-id errors."""
    # A valid batch padded with one bogus id: the five real chats are removed and
    # the bogus id is reported as exactly one error.
    padded_scenarios = (
        ("partial invalid id", "invalid_id"),
        ("partial invalid punctuation id", "!@#$%^&*()"),
    )
    for scenario_name, bogus_id in padded_scenarios:
        ids = _reset_chat_batch(rest_client, "delete_partial_" + scenario_name.replace(" ", "_"))
        res = rest_client.delete("/chats", json={"ids": ids + [bogus_id]})
        assert res.status_code == 200, (scenario_name, res.text)
        outcome = res.json()
        assert outcome["code"] == 0, (scenario_name, outcome)
        assert len(outcome["data"]["errors"]) == 1, (scenario_name, outcome)
        assert outcome["data"]["success_count"] == 5, (scenario_name, outcome)
        remaining = rest_client.get("/chats").json()
        assert remaining["code"] == 0, (scenario_name, remaining)
        assert remaining["data"]["chats"] == [], (scenario_name, remaining)

    # Every id sent twice: each chat is deleted once; each repeat is an error.
    twice = _reset_chat_batch(rest_client, "delete_duplicate_all")
    res = rest_client.delete("/chats", json={"ids": twice + twice})
    assert res.status_code == 200
    outcome = res.json()
    assert outcome["code"] == 0, outcome
    assert outcome["data"]["success_count"] == 5, outcome
    assert len(outcome["data"]["errors"]) == 5, outcome
    assert all(error.startswith("Duplicate chat ids: ") for error in outcome["data"]["errors"]), outcome

    # A single id sent twice.
    target = _reset_chat_batch(rest_client, "delete_duplicate_one")[0]
    res = rest_client.delete("/chats", json={"ids": [target, target]})
    assert res.status_code == 200
    outcome = res.json()
    assert outcome["code"] == 0, outcome
    assert outcome["data"]["success_count"] == 1, outcome
    assert outcome["data"]["errors"] == [f"Duplicate chat ids: {target}"], outcome

    # Only unknown ids: the whole request fails with code 102, naming each id.
    res = rest_client.delete("/chats", json={"ids": ["missing-1", "missing-2"]})
    assert res.status_code == 200
    outcome = res.json()
    assert outcome["code"] == 102, outcome
    for missing in ("missing-1", "missing-2"):
        assert f"Chat({missing}) not found." in outcome["message"], outcome

    # Deleting the same batch twice: the second call finds nothing.
    batch = _reset_chat_batch(rest_client, "delete_repeated")
    res = rest_client.delete("/chats", json={"ids": batch})
    assert res.status_code == 200
    outcome = res.json()
    assert outcome["code"] == 0, outcome
    assert outcome["data"]["success_count"] == 5, outcome
    res = rest_client.delete("/chats", json={"ids": batch})
    assert res.status_code == 200
    outcome = res.json()
    assert outcome["code"] == 102, outcome
    for chat_id in batch:
        assert f"Chat({chat_id}) not found." in outcome["message"], outcome
@pytest.mark.p2
def test_chat_delete_concurrent_and_bulk_contract(rest_client, clear_chats):
    """Twenty single-id deletes run in parallel, then one 100-id bulk delete."""

    def _delete_single(chat_id):
        return rest_client.delete("/chats", json={"ids": [chat_id]}).json()

    targets = _reset_chat_batch(rest_client, "delete_concurrent", count=20)
    with ThreadPoolExecutor(max_workers=5) as pool:
        outcomes = list(pool.map(_delete_single, targets))
    assert len(outcomes) == 20, outcomes
    assert all(outcome["code"] == 0 for outcome in outcomes), outcomes
    assert all(outcome["data"]["success_count"] == 1 for outcome in outcomes), outcomes

    remaining = rest_client.get("/chats").json()
    assert remaining["code"] == 0, remaining
    assert remaining["data"]["chats"] == [], remaining

    bulk_ids = _reset_chat_batch(rest_client, "delete_bulk", count=100)
    bulk = rest_client.delete("/chats", json={"ids": bulk_ids})
    assert bulk.status_code == 200
    bulk_body = bulk.json()
    assert bulk_body["code"] == 0, bulk_body
    assert bulk_body["data"]["success_count"] == len(bulk_ids), bulk_body
@pytest.mark.p1
def test_chat_list_requires_auth():
    """Listing chats without a valid API token yields HTTP 401 and code 401."""
    clients_by_scenario = {
        "missing token": RestClient(token=None),
        "invalid token": RestClient(token=INVALID_API_TOKEN),
    }
    for scenario_name, client in clients_by_scenario.items():
        response = client.get("/chats")
        assert response.status_code == 401, (scenario_name, response.text)
        body = response.json()
        assert body["code"] == 401, (scenario_name, body)
        assert body["message"] == "<Unauthorized '401: Unauthorized'>", (scenario_name, body)
@pytest.mark.p1
def test_chat_list_default_get_and_separate_lookup_contract(rest_client, clear_chats):
    """Default listing, GET by id, and keyword listing behave as independent lookups."""
    ids = _reset_chat_batch(rest_client, "list_default")
    first_id = ids[0]

    listing = rest_client.get("/chats")
    assert listing.status_code == 200
    listing_body = listing.json()
    assert listing_body["code"] == 0, listing_body
    assert len(listing_body["data"]["chats"]) == 5, listing_body
    assert listing_body["data"]["total"] == 5, listing_body

    known = rest_client.get(f"/chats/{first_id}")
    assert known.status_code == 200
    known_body = known.json()
    assert known_body["code"] == 0, known_body
    assert known_body["data"]["id"] == first_id, known_body

    # An unknown chat id is reported as an authorization failure (code 109).
    unknown = rest_client.get("/chats/unknown")
    assert unknown.status_code == 200
    unknown_body = unknown.json()
    assert unknown_body["code"] == 109, unknown_body
    assert unknown_body["message"] == "No authorization.", unknown_body

    # The keyword filter applies independently of whichever id was fetched.
    lookups = (
        (first_id, "list_default_0", 1),
        (first_id, "list_default_1", 1),
        (first_id, "unknown", 0),
    )
    for chat_id, keywords, expected_count in lookups:
        by_id = rest_client.get(f"/chats/{chat_id}")
        by_keyword = rest_client.get("/chats", params={"keywords": keywords})
        assert by_id.status_code == 200, (keywords, by_id.text)
        assert by_keyword.status_code == 200, (keywords, by_keyword.text)
        by_id_body = by_id.json()
        by_keyword_body = by_keyword.json()
        assert by_id_body["code"] == 0, (keywords, by_id_body)
        assert by_keyword_body["code"] == 0, (keywords, by_keyword_body)
        assert len(by_keyword_body["data"]["chats"]) == expected_count, (keywords, by_keyword_body)
@pytest.mark.p2
def test_chat_list_keyword_and_invalid_param_contract(rest_client, clear_chats):
    """Keyword filtering plus tolerance of empty, missing and unknown query params."""
    _reset_chat_batch(rest_client, "list_keyword")
    # (scenario, query params, expected chat count, expected first name or None)
    cases = (
        ("keywords none", {"keywords": None}, 5, None),
        ("keywords empty", {"keywords": ""}, 5, None),
        ("keywords exact", {"keywords": "list_keyword_1"}, 1, "list_keyword_1"),
        ("keywords unknown", {"keywords": "unknown"}, 0, None),
        ("invalid params ignored", {"a": "b"}, 5, None),
    )
    for scenario_name, params, expected_count, expected_name in cases:
        res = rest_client.get("/chats", params=params)
        assert res.status_code == 200, (scenario_name, res.text)
        body = res.json()
        assert body["code"] == 0, (scenario_name, body)
        chats = body["data"]["chats"]
        assert len(chats) == expected_count, (scenario_name, body)
        if expected_name is None:
            continue
        assert chats[0]["name"] == expected_name, (scenario_name, body)
@pytest.mark.p2
def test_chat_list_page_and_page_size_contract(rest_client, clear_chats):
    """``page`` / ``page_size`` parsing: valid, zero, negative, string and non-numeric values."""

    def _everything(total):
        return total

    def _window(offset):
        # Size of a 2-item page that starts `offset` items into the listing.
        return lambda total: min(max(total - offset, 0), 2)

    int_parse_error = "ValueError(\"invalid literal for int() with base 10: 'a'\")"
    # (scenario, params, expected code, expected count as f(total) or None, expected message fragment)
    cases = [
        ("page none", {"page": None, "page_size": 2}, 0, _everything, ""),
        ("page zero", {"page": 0, "page_size": 2}, 0, _everything, ""),
        ("page two", {"page": 2, "page_size": 2}, 0, _window(2), ""),
        ("page three", {"page": 3, "page_size": 2}, 0, _window(4), ""),
        ("page string", {"page": "3", "page_size": 2}, 0, _window(4), ""),
        ("page negative", {"page": -1, "page_size": 2}, 100, None, "ProgrammingError(1064"),
        ("page alpha", {"page": "a", "page_size": 2}, 100, None, int_parse_error),
        ("page_size none", {"page_size": None}, 0, _everything, ""),
        ("page_size zero", {"page_size": 0}, 0, _everything, ""),
        ("page_size one", {"page_size": 1}, 0, _everything, ""),
        ("page_size six", {"page_size": 6}, 0, _everything, ""),
        ("page_size string", {"page_size": "1"}, 0, _everything, ""),
        ("page_size negative", {"page_size": -1}, 0, _everything, ""),
        ("page_size alpha", {"page_size": "a"}, 100, None, int_parse_error),
    ]
    for scenario_name, params, expected_code, expected_count_fn, expected_message in cases:
        _reset_chat_batch(rest_client, "list_page_" + scenario_name.replace(" ", "_"))
        baseline = rest_client.get("/chats").json()
        assert baseline["code"] == 0, (scenario_name, baseline)
        total = baseline["data"]["total"]

        res = rest_client.get("/chats", params=params)
        assert res.status_code == 200, (scenario_name, res.text)
        body = res.json()
        assert body["code"] == expected_code, (scenario_name, body)
        if expected_code != 0:
            assert expected_message in body["message"], (scenario_name, body)
            continue
        assert len(body["data"]["chats"]) == expected_count_fn(total), (scenario_name, body)
        assert body["data"]["total"] == total, (scenario_name, body)
@pytest.mark.p2
def test_chat_list_sorting_contract(rest_client, clear_chats):
    """``orderby`` / ``desc`` combinations; descending is the default."""
    _reset_chat_batch(rest_client, "list_sort")
    ascending = [f"list_sort_{i}" for i in range(5)]
    descending = ascending[::-1]
    missing_attribute = "AttributeError(\"type object 'Dialog' has no attribute 'unknown'\")"
    # (scenario, params, expected code, expected names or None, expected message fragment)
    cases = [
        ("orderby none", {"orderby": None}, 0, descending, ""),
        ("orderby create", {"orderby": "create_time"}, 0, descending, ""),
        ("orderby update", {"orderby": "update_time"}, 0, descending, ""),
        ("orderby name ascending", {"orderby": "name", "desc": "False"}, 0, ascending, ""),
        ("orderby unknown", {"orderby": "unknown"}, 100, None, missing_attribute),
        ("desc none", {"desc": None}, 0, descending, ""),
        ("desc true", {"desc": "true"}, 0, descending, ""),
        ("desc True", {"desc": "True"}, 0, descending, ""),
        ("desc bool true", {"desc": True}, 0, descending, ""),
        ("desc false", {"desc": "false"}, 0, ascending, ""),
        ("desc False", {"desc": "False"}, 0, ascending, ""),
        ("desc bool false", {"desc": False}, 0, ascending, ""),
        ("desc False update_time", {"desc": "False", "orderby": "update_time"}, 0, ascending, ""),
        ("desc unknown", {"desc": "unknown"}, 0, descending, ""),
    ]
    for scenario_name, params, expected_code, expected_names, expected_message in cases:
        res = rest_client.get("/chats", params=params)
        assert res.status_code == 200, (scenario_name, res.text)
        body = res.json()
        assert body["code"] == expected_code, (scenario_name, body)
        if expected_code == 0:
            assert _chat_names(body) == expected_names, (scenario_name, body)
        else:
            assert expected_message in body["message"], (scenario_name, body)
@pytest.mark.p2
def test_chat_list_concurrent_and_dataset_delete_contract(rest_client, clear_chats, ensure_parsed_document):
    """Parallel listings agree; a chat stays listed after its dataset is deleted."""
    _reset_chat_batch(rest_client, "list_concurrent")

    def _list_all(_unused):
        return rest_client.get("/chats").json()

    with ThreadPoolExecutor(max_workers=5) as pool:
        listings = list(pool.map(_list_all, range(10)))
    assert len(listings) == 10, listings
    assert all(listing["code"] == 0 for listing in listings), listings
    assert all(len(listing["data"]["chats"]) == 5 for listing in listings), listings

    dataset_id, _ = ensure_parsed_document()
    created = rest_client.post("/chats", json={"name": "list_after_dataset_delete", "dataset_ids": [dataset_id]})
    assert created.status_code == 200
    created_body = created.json()
    assert created_body["code"] == 0, created_body

    dropped = rest_client.delete("/datasets", json={"ids": [dataset_id]})
    assert dropped.status_code == 200
    dropped_body = dropped.json()
    assert dropped_body["code"] == 0, dropped_body

    found = rest_client.get("/chats", params={"keywords": "list_after_dataset_delete"})
    assert found.status_code == 200
    found_body = found.json()
    assert found_body["code"] == 0, found_body
    assert len(found_body["data"]["chats"]) == 1, found_body
class _DummyManager:
    """Stand-in route manager: ``route(...)`` yields a decorator that returns the handler unchanged."""

    def route(self, *_args, **_kwargs):
        return lambda func: func
class _AwaitableValue:
    """Awaitable that resolves immediately to a fixed value; it can be awaited repeatedly."""

    def __init__(self, value):
        self._value = value

    def __await__(self):
        async def _resolve(result):
            return result

        # A fresh coroutine per await, so the same object can be awaited more than once.
        return _resolve(self._value).__await__()
class _DummyArgs(dict):
    """Minimal stand-in for Quart's ``request.args`` multi-dict.

    ``get`` is inherited unchanged from ``dict``. ``getlist`` normalises the
    stored value to a list: a missing key or ``None`` gives ``[]``, a list is
    returned as-is, and any other value is wrapped in a one-element list.
    """

    def getlist(self, key):
        value = self.get(key)
        if value is None:
            return []
        return value if isinstance(value, list) else [value]
class _StubHeaders:
    """Append-only header store; ``get`` returns the most recently added value for a key."""

    def __init__(self):
        self._items = []

    def add_header(self, key, value):
        self._items.append((key, value))

    def get(self, key, default=None):
        matches = [value for name, value in self._items if name == key]
        return matches[-1] if matches else default
class _StubResponse:
    """Stand-in for Quart's ``Response``: stores its constructor arguments and owns a header store."""

    def __init__(self, body=None, mimetype=None, content_type=None):
        self.body, self.mimetype, self.content_type = body, mimetype, content_type
        self.headers = _StubHeaders()
class _DummyUploadFile:
    """Fake uploaded file whose async ``save`` only records the destination path."""

    def __init__(self, filename):
        self.filename = filename
        # Stays None until save() is awaited.
        self.saved_path = None

    async def save(self, path):
        self.saved_path = path
def _passthrough_login_required(func):
    """No-op replacement for ``login_required`` that simply awaits the coroutine ``func``.

    ``functools.wraps`` copies the metadata and sets ``__wrapped__``, which the
    unit tests use to reach the undecorated handler.
    """

    async def _call_through(*args, **kwargs):
        return await func(*args, **kwargs)

    return wraps(func)(_call_through)
class _DummyKB:
    """Plain knowledge-base record exposing the attributes read by the route tests."""

    def __init__(self, kid="kb-1", embd_id="embd@factory", chunk_num=1, name="Dataset A", status="1"):
        self.id, self.embd_id, self.chunk_num = kid, embd_id, chunk_num
        self.name, self.status = name, status
class _DummyDialogRecord:
    """Fake dialog (chat) row whose ``to_dict`` returns an independent deep copy.

    Args:
        data: Row contents. When omitted (``None``), a fully populated default
            chat ``chat-1`` owned by ``tenant-1`` is used. An explicitly passed
            empty dict is kept as-is; previously ``data or ...`` silently
            swapped it for the default record.
    """

    def __init__(self, data=None):
        if data is None:
            data = {
                "id": "chat-1",
                "name": "chat-name",
                "description": "desc",
                "icon": "icon.png",
                "kb_ids": ["kb-1"],
                "llm_id": "glm-4",
                "llm_setting": {"temperature": 0.1},
                "prompt_config": {
                    "system": "Answer with {knowledge}",
                    "parameters": [{"key": "knowledge", "optional": False}],
                    "prologue": "hello",
                    "quote": True,
                },
                "similarity_threshold": 0.2,
                "vector_similarity_weight": 0.3,
                "top_n": 6,
                "top_k": 1024,
                "rerank_id": "",
                "meta_data_filter": {},
                "tenant_id": "tenant-1",
            }
        self._data = data

    def to_dict(self):
        # Deep copy so callers cannot mutate the stored record through the result.
        return deepcopy(self._data)
def _run(coro):
    """Drive ``coro`` to completion on a fresh event loop and hand back its result."""
    outcome = asyncio.run(coro)
    return outcome
async def _collect_stream(body):
    """Drain a sync or async iterable into a list, decoding ``bytes`` items as UTF-8.

    Objects exposing ``__aiter__`` are consumed with ``async for``; anything
    else is iterated synchronously. Non-``bytes`` items are kept unchanged.
    """

    def _decode(item):
        return item.decode("utf-8") if isinstance(item, bytes) else item

    if hasattr(body, "__aiter__"):
        return [_decode(item) async for item in body]
    return [_decode(item) for item in body]
def _load_chat_routes_unit_module(monkeypatch):
# Import api/apps/restful_apis/chat_api.py in isolation for the route unit tests.
# Each dependency listed below is replaced by a small in-memory stub registered in
# `sys.modules` with `monkeypatch.setitem`, so pytest restores the original entries
# at teardown. Returns the freshly executed module object.
# NOTE(review): statement order matters. Every stub must be registered before
# `spec.loader.exec_module` runs, and the real page-limit constants are imported
# before the `common.constants` stub replaces that module.
# parents[3]: test/testcases/restful_api/<this file> -> repository root.
repo_root = Path(__file__).resolve().parents[3]
module_name = "test_chat_restful_routes_unit_module"
module_path = repo_root / "api" / "apps" / "restful_apis" / "chat_api.py"
# Quart stub: `request.args` is an empty _DummyArgs and `Response` only records its arguments.
quart_mod = ModuleType("quart")
quart_mod.request = SimpleNamespace(args=_DummyArgs())
quart_mod.Response = _StubResponse
monkeypatch.setitem(sys.modules, "quart", quart_mod)
# Package shells whose __path__ points at the real source tree, so submodules that
# are not stubbed here can still be imported from disk.
api_pkg = ModuleType("api")
api_pkg.__path__ = [str(repo_root / "api")]
monkeypatch.setitem(sys.modules, "api", api_pkg)
apps_pkg = ModuleType("api.apps")
apps_pkg.__path__ = [str(repo_root / "api" / "apps")]
# The caller is always the fake user "tenant-1"; authentication is a pass-through.
apps_pkg.current_user = SimpleNamespace(id="tenant-1")
apps_pkg.login_required = _passthrough_login_required
monkeypatch.setitem(sys.modules, "api.apps", apps_pkg)
api_pkg.apps = apps_pkg
common_pkg = ModuleType("common")
common_pkg.__path__ = [str(repo_root / "common")]
monkeypatch.setitem(sys.modules, "common", common_pkg)
# common.constants stub: minimal enum stand-ins; only the members listed here exist.
common_constants_mod = ModuleType("common.constants")
class _StubLLMType(str, Enum):
CHAT = "chat"
IMAGE2TEXT = "image2text"
RERANK = "rerank"
SPEECH2TEXT = "speech2text"
TTS = "tts"
class _StubRetCode(int, Enum):
SUCCESS = 0
DATA_ERROR = 102
OPERATING_ERROR = 103
AUTHENTICATION_ERROR = 109
class _StubStatusEnum(str, Enum):
VALID = "1"
INVALID = "0"
common_constants_mod.LLMType = _StubLLMType
common_constants_mod.RetCode = _StubRetCode
common_constants_mod.StatusEnum = _StubStatusEnum
# Copy the real page limits into the stub. This import runs while
# sys.modules["common.constants"] is still the real entry (or absent, in which case
# it resolves through common.__path__ to the real file); it must stay above the
# setitem call below.
from common.constants import MAXIMUM_PAGE_NUMBER as _MPN, MAXIMUM_TASK_PAGE_NUMBER as _MTPN
common_constants_mod.MAXIMUM_PAGE_NUMBER = _MPN
common_constants_mod.MAXIMUM_TASK_PAGE_NUMBER = _MTPN
monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod)
# Deterministic ids, and a thread_pool_exec that just calls `func` inline.
misc_utils_mod = ModuleType("common.misc_utils")
misc_utils_mod.get_uuid = lambda: "generated-chat-id"
async def _thread_pool_exec(func, *args, **kwargs):
return func(*args, **kwargs)
misc_utils_mod.thread_pool_exec = _thread_pool_exec
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
# Storage backend whose rm() does nothing.
settings_mod = ModuleType("common.settings")
settings_mod.STORAGE_IMPL = type("_StorageImpl", (), {"rm": staticmethod(lambda *_args, **_kwargs: None)})()
monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
# DialogService stub: `model._meta.fields` lists the Dialog column names (presumably
# read by the routes to filter payload keys - confirm in chat_api.py); every
# data-access method returns an empty / "not found" / success result. Individual
# tests override methods with monkeypatch.setattr.
dialog_service_mod = ModuleType("api.db.services.dialog_service")
class _StubDialogService:
model = SimpleNamespace(
_meta=SimpleNamespace(
fields={
"id": None,
"tenant_id": None,
"name": None,
"description": None,
"icon": None,
"kb_ids": None,
"llm_id": None,
"llm_setting": None,
"prompt_config": None,
"similarity_threshold": None,
"vector_similarity_weight": None,
"top_n": None,
"top_k": None,
"rerank_id": None,
"meta_data_filter": None,
"created_by": None,
"create_time": None,
"create_date": None,
"update_time": None,
"update_date": None,
"status": None,
}
)
)
@staticmethod
def query(**_kwargs):
return []
@staticmethod
def save(**_kwargs):
return True
@staticmethod
def get_by_id(_chat_id):
return False, None
@staticmethod
def update_by_id(_chat_id, _payload):
return True
@staticmethod
def get_by_tenant_ids(*_args, **_kwargs):
return [], 0
dialog_service_mod.DialogService = _StubDialogService
dialog_service_mod.async_ask = lambda *_args, **_kwargs: None
dialog_service_mod.async_chat = lambda *_args, **_kwargs: None
dialog_service_mod.gen_mindmap = lambda *_args, **_kwargs: None
monkeypatch.setitem(sys.modules, "api.db.services.dialog_service", dialog_service_mod)
# ConversationService stub: no sessions exist; writes report success.
conversation_service_mod = ModuleType("api.db.services.conversation_service")
class _StubConversationService:
@staticmethod
def query(**_kwargs):
return []
@staticmethod
def get_list(*_args, **_kwargs):
return []
@staticmethod
def get_by_id(_session_id):
return False, None
@staticmethod
def update_by_id(_session_id, _payload):
return True
@staticmethod
def delete_by_id(_session_id):
return True
@staticmethod
def save(**_kwargs):
return True
conversation_service_mod.ConversationService = _StubConversationService
conversation_service_mod.structure_answer = lambda *_args, **_kwargs: {}
monkeypatch.setitem(sys.modules, "api.db.services.conversation_service", conversation_service_mod)
# KnowledgebaseService stub: no dataset is accessible or found.
kb_service_mod = ModuleType("api.db.services.knowledgebase_service")
class _StubKnowledgebaseService:
@staticmethod
def accessible(**_kwargs):
return []
@staticmethod
def query(**_kwargs):
return []
@staticmethod
def get_by_id(_kb_id):
return False, None
kb_service_mod.KnowledgebaseService = _StubKnowledgebaseService
monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
llm_service_mod = ModuleType("api.db.services.llm_service")
llm_service_mod.LLMBundle = lambda *_args, **_kwargs: None
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
search_service_mod = ModuleType("api.db.services.search_service")
search_service_mod.SearchService = SimpleNamespace()
monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod)
# Model-config helpers: empty configs, a fixed API-key object, and a name splitter
# that returns (text before "@", "default", "factory") for any model string.
tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service")
tenant_model_service_mod.get_model_config_from_provider_instance = lambda *_args, **_kwargs: {}
tenant_model_service_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {}
tenant_model_service_mod.get_api_key = lambda *_args, **_kwargs: SimpleNamespace(id=1)
tenant_model_service_mod.split_model_name = lambda model: (model.split("@")[0],"default", "factory")
monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod)
# Every tenant exists with default LLM "glm-4"; the user belongs to "tenant-1" and "team-tenant-2".
user_service_mod = ModuleType("api.db.services.user_service")
class _StubTenantService:
@staticmethod
def get_by_id(_tenant_id):
return True, SimpleNamespace(llm_id="glm-4")
@staticmethod
def get_joined_tenants_by_user_id(_user_id):
return [{"tenant_id": "tenant-1"}, {"tenant_id": "team-tenant-2"}]
class _StubUserTenantService:
@staticmethod
def query(**_kwargs):
return []
user_service_mod.UserService = type("UserService", (), {})
user_service_mod.TenantService = _StubTenantService
user_service_mod.UserTenantService = _StubUserTenantService
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
chunk_feedback_service_mod = ModuleType("api.db.services.chunk_feedback_service")
chunk_feedback_service_mod.ChunkFeedbackService = type(
"ChunkFeedbackService",
(),
{"apply_feedback": staticmethod(lambda **_kwargs: {"success_count": 0, "fail_count": 0, "chunk_ids": []})},
)
monkeypatch.setitem(sys.modules, "api.db.services.chunk_feedback_service", chunk_feedback_service_mod)
# api_utils stub: results are plain dicts {"code", "data", "message"}, so tests can
# inspect them directly.
api_utils_mod = ModuleType("api.utils.api_utils")
# Returns (order-preserving de-duplicated ids, one "Duplicate <label> ids: <id>"
# message per repeated id). NOTE(review): presumably mirrors the real helper -
# confirm against api/utils/api_utils.py.
def _check_duplicate_ids(ids, label):
counts = {}
for item in ids or []:
counts[item] = counts.get(item, 0) + 1
duplicate_messages = [f"Duplicate {label} ids: {item}" for item, count in counts.items() if count > 1]
return list(dict.fromkeys(ids or [])), duplicate_messages
api_utils_mod.check_duplicate_ids = _check_duplicate_ids
api_utils_mod.get_data_error_result = lambda message="": {"code": 102, "data": None, "message": message}
api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "data": data, "message": message}
# Default request body is {}; tests replace it via _set_route_unit_request_json.
api_utils_mod.get_request_json = lambda: _AwaitableValue({})
api_utils_mod.server_error_response = lambda ex: {"code": 500, "data": None, "message": str(ex)}
api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func)
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
rag_pkg = ModuleType("rag")
rag_pkg.__path__ = [str(repo_root / "rag")]
monkeypatch.setitem(sys.modules, "rag", rag_pkg)
rag_prompts_pkg = ModuleType("rag.prompts")
rag_prompts_pkg.__path__ = [str(repo_root / "rag" / "prompts")]
monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg)
rag_prompts_generator_mod = ModuleType("rag.prompts.generator")
rag_prompts_generator_mod.chunks_format = lambda reference: reference.get("chunks", []) if isinstance(reference, dict) else []
monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_prompts_generator_mod)
rag_prompts_template_mod = ModuleType("rag.prompts.template")
rag_prompts_template_mod.load_prompt = lambda *_args, **_kwargs: ""
monkeypatch.setitem(sys.modules, "rag.prompts.template", rag_prompts_template_mod)
# Load the real route file under a private module name. `manager` is injected before
# execution because the module's route decorators presumably reference a global
# `manager` that the app loader normally provides (confirm in chat_api.py).
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
module.manager = _DummyManager()
monkeypatch.setitem(sys.modules, module_name, module)
spec.loader.exec_module(module)
return module
def _set_route_unit_request_json(monkeypatch, module, payload):
    """Make ``await module.get_request_json()`` yield a fresh deep copy of ``payload`` on every call."""

    def _fake_get_request_json():
        # Copy per call so a handler mutating its request body cannot affect later calls.
        return _AwaitableValue(deepcopy(payload))

    monkeypatch.setattr(module, "get_request_json", _fake_get_request_json)
@pytest.mark.p2
def test_chat_session_create_and_update_guard_matrix_unit(monkeypatch):
    """Walk the authorization and validation guards of ``create_session`` and ``update_session``."""
    module = _load_chat_routes_unit_module(monkeypatch)
    dialogs, conversations = module.DialogService, module.ConversationService

    def create():
        return _run(module.create_session.__wrapped__("chat-1"))

    def update():
        return _run(module.update_session.__wrapped__("chat-1", "session-1"))

    # create_session: no dialog matches the caller -> authorization error.
    _set_route_unit_request_json(monkeypatch, module, {"name": "session"})
    monkeypatch.setattr(dialogs, "query", lambda **_kwargs: [])
    assert create()["message"] == "No authorization."

    # create_session: the dialog exists, but the conversation cannot be fetched after saving.
    dialog = SimpleNamespace(prompt_config={"prologue": "hello"})
    monkeypatch.setattr(dialogs, "query", lambda **_kwargs: [dialog])
    monkeypatch.setattr(dialogs, "get_by_id", lambda _id: (True, dialog))
    monkeypatch.setattr(conversations, "save", lambda **_kwargs: None)
    monkeypatch.setattr(conversations, "get_by_id", lambda _id: (False, None))
    assert "Fail to create a session" in create()["message"]

    # update_session: unknown session.
    _set_route_unit_request_json(monkeypatch, module, {})
    monkeypatch.setattr(conversations, "query", lambda **_kwargs: [])
    assert update()["message"] == "Session not found!"

    # update_session: the session exists, but the chat is not owned by the caller.
    monkeypatch.setattr(conversations, "query", lambda **_kwargs: [SimpleNamespace(id="session-1")])
    monkeypatch.setattr(dialogs, "query", lambda **_kwargs: [])
    assert update()["message"] == "No authorization."

    # update_session: read-only or invalid fields in the body are rejected.
    monkeypatch.setattr(dialogs, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")])
    rejected_bodies = (
        ({"message": []}, "`messages` cannot be changed."),
        ({"reference": []}, "`reference` cannot be changed."),
        ({"name": ""}, "`name` can not be empty."),
    )
    for body, expected_fragment in rejected_bodies:
        _set_route_unit_request_json(monkeypatch, module, body)
        assert expected_fragment in update()["message"], body

    # update_session: a falsy update_by_id result is reported as a missing session.
    _set_route_unit_request_json(monkeypatch, module, {"name": "renamed"})
    monkeypatch.setattr(conversations, "update_by_id", lambda *_args, **_kwargs: False)
    assert update()["message"] == "Session not found!"
@pytest.mark.p2
def test_chat_session_list_projection_unit(monkeypatch):
    """``list_sessions`` exposes stored conversations under API field names; ``page_size`` 0 yields no rows."""
    module = _load_chat_routes_unit_module(monkeypatch)

    def use_query_args(page_size):
        query = {
            "page": 1,
            "page_size": page_size,
            "orderby": "create_time",
            "desc": "true",
            "id": None,
            "name": None,
            "user_id": None,
        }
        args = SimpleNamespace(get=lambda key, default=None: query.get(key, default))
        monkeypatch.setattr(module, "request", SimpleNamespace(args=args))

    def stored_sessions(*_args, **_kwargs):
        # Build fresh rows on every call so a route that mutates them cannot leak state between calls.
        return [
            {
                "id": "session-1",
                "dialog_id": "chat-1",
                "message": [{"role": "assistant", "content": "hello"}],
                "reference": [],
            }
        ]

    use_query_args(30)
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")])
    monkeypatch.setattr(module.ConversationService, "get_list", stored_sessions)
    first_page = _run(module.list_sessions.__wrapped__("chat-1"))["data"]
    # dialog_id -> chat_id, message -> messages
    assert first_page[0]["chat_id"] == "chat-1"
    assert first_page[0]["messages"][0]["content"] == "hello"

    use_query_args(0)
    assert _run(module.list_sessions.__wrapped__("chat-1"))["data"] == []
@pytest.mark.p2
def test_chat_session_delete_routes_partial_duplicate_unit(monkeypatch):
    """``delete_sessions`` reports ownership failures and duplicate ids as per-item errors."""
    module = _load_chat_routes_unit_module(monkeypatch)
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")])

    def delete():
        return _run(module.delete_sessions.__wrapped__("chat-1"))

    # An empty body succeeds.
    _set_route_unit_request_json(monkeypatch, module, {})
    assert delete()["code"] == 0

    monkeypatch.setattr(module.ConversationService, "delete_by_id", lambda *_args, **_kwargs: True)

    def owned_sessions(**kwargs):
        # A lookup by dialog only returns one seed session; a lookup by id knows only "ok".
        if "id" not in kwargs:
            return [SimpleNamespace(id="seed")] if "dialog_id" in kwargs else []
        return [SimpleNamespace(id="ok")] if kwargs["id"] == "ok" else []

    def passthrough_ids(ids, _kind):
        return ids, []

    monkeypatch.setattr(module.ConversationService, "query", owned_sessions)
    monkeypatch.setattr(module, "check_duplicate_ids", passthrough_ids)

    # One owned and one foreign id: partial success with a per-id error.
    _set_route_unit_request_json(monkeypatch, module, {"ids": ["ok", "bad"]})
    partial = delete()
    assert partial["code"] == 0
    assert partial["data"]["success_count"] == 1
    assert partial["data"]["errors"] == ["The chat doesn't own the session bad"]

    # Only foreign ids: the error becomes the top-level message.
    _set_route_unit_request_json(monkeypatch, module, {"ids": ["bad"]})
    assert delete()["message"] == "The chat doesn't own the session bad"

    # Duplicate ids are de-duplicated and reported alongside the success count.
    _set_route_unit_request_json(monkeypatch, module, {"ids": ["ok", "ok"]})
    monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (["ok"], ["Duplicate session ids: ok"]))
    deduped = delete()
    assert deduped["code"] == 0
    assert deduped["data"]["success_count"] == 1
    assert deduped["data"]["errors"] == ["Duplicate session ids: ok"]
@pytest.mark.p2
def test_chat_audio_transcription_routes_unit(monkeypatch):
    """Cover validation, model-lookup failures, and the sync/streaming ASR paths of ``transcription``."""
    module = _load_chat_routes_unit_module(monkeypatch)
    monkeypatch.setattr(module, "Response", _StubResponse)
    monkeypatch.setattr(module.tempfile, "mkstemp", lambda suffix: (11, f"/tmp/audio{suffix}"))
    monkeypatch.setattr(module.os, "close", lambda _fd: None)

    def post(stream, files):
        """Install a multipart request with the given ``stream`` flag and files, then call the route."""
        monkeypatch.setattr(
            module,
            "request",
            SimpleNamespace(form=_AwaitableValue({"stream": stream}), files=_AwaitableValue(files)),
        )
        return _run(module.transcription.__wrapped__())

    def wav():
        return {"file": _DummyUploadFile("audio.wav")}

    def raising(exc_type, message):
        """Return a stub that raises a fresh ``exc_type(message)`` on every call."""

        def _stub(*_args, **_kwargs):
            raise exc_type(message)

        return _stub

    def failing_remove(message):
        def _remove(_path):
            raise RuntimeError(message)

        return _remove

    class SyncASR:
        def transcription(self, _path):
            return "transcribed text"

        def stream_transcription(self, _path):
            return []

    class StreamASR:
        def transcription(self, _path):
            return ""

        def stream_transcription(self, _path):
            yield {"event": "partial", "text": "hello"}

    class ErrorASR:
        def transcription(self, _path):
            return ""

        def stream_transcription(self, _path):
            raise RuntimeError("stream asr boom")

    # Upload validation.
    assert "Missing 'file' in multipart form-data" in post("false", {})["message"]
    assert "Unsupported audio format: .txt" in post("false", {"file": _DummyUploadFile("bad.txt")})["message"]

    # Failures while resolving the default ASR model are surfaced verbatim.
    monkeypatch.setattr(module, "get_tenant_default_model_by_type", raising(LookupError, "Tenant not found!"))
    assert post("false", wav())["message"] == "Tenant not found!"
    monkeypatch.setattr(module, "get_tenant_default_model_by_type", raising(Exception, "No default ASR model is set"))
    assert post("false", wav())["message"] == "No default ASR model is set"

    # Non-streaming success; a failing temp-file cleanup does not break the response.
    monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: {"llm_name": "asr-x"})
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: SyncASR())
    monkeypatch.setattr(module.os, "remove", failing_remove("cleanup fail"))
    sync_res = post("false", wav())
    assert sync_res["code"] == 0
    assert sync_res["data"]["text"] == "transcribed text"

    # Streaming success yields SSE events.
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: StreamASR())
    monkeypatch.setattr(module.os, "remove", lambda _path: None)
    stream_resp = post("true", wav())
    assert isinstance(stream_resp, _StubResponse)
    assert stream_resp.content_type == "text/event-stream"
    stream_chunks = _run(_collect_stream(stream_resp.body))
    assert any('"event": "partial"' in chunk for chunk in stream_chunks)

    # A model error during streaming is reported inside the stream.
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: ErrorASR())
    monkeypatch.setattr(module.os, "remove", failing_remove("cleanup boom"))
    error_chunks = _run(_collect_stream(post("true", wav()).body))
    assert any("stream asr boom" in chunk for chunk in error_chunks)
@pytest.mark.p2
def test_chat_audio_speech_routes_unit(monkeypatch):
    """Cover model-lookup failures, streamed audio output and in-stream errors of the ``tts`` route.

    Errors raised by the TTS model while streaming are expected to appear in
    the stream as a chunk carrying ``"code": 500`` and ``**ERROR**: <message>``.
    """
    module = _load_chat_routes_unit_module(monkeypatch)
    monkeypatch.setattr(module, "Response", _StubResponse)
    _set_route_unit_request_json(monkeypatch, module, {"text": "A。B"})

    def speak():
        return _run(module.tts.__wrapped__())

    def raising(exc_type, message):
        """Return a stub that raises a fresh ``exc_type(message)`` on every call."""

        def _stub(*_args, **_kwargs):
            raise exc_type(message)

        return _stub

    # Failures while resolving the default TTS model are surfaced verbatim.
    for exc_type, message in ((LookupError, "Tenant not found!"), (Exception, "No default TTS model is set")):
        monkeypatch.setattr(module, "get_tenant_default_model_by_type", raising(exc_type, message))
        assert speak()["message"] == message

    class OkTTS:
        def tts(self, txt):
            # Generator stub: one encoded chunk per non-empty text segment and nothing
            # for empty ones. (The old `return []` here was dead code: a generator's
            # return value is never seen by a plain iteration.)
            if txt:
                yield f"chunk-{txt}".encode("utf-8")

    monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: {"llm_name": "tts-x"})
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: OkTTS())
    resp = speak()
    assert resp.mimetype == "audio/mpeg"
    expected_headers = {"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
    for header, value in expected_headers.items():
        assert resp.headers.get(header) == value, header
    chunks = _run(_collect_stream(resp.body))
    assert any("chunk-A" in chunk for chunk in chunks)
    assert any("chunk-B" in chunk for chunk in chunks)

    class FailingTTS:
        def tts(self, _txt):
            raise RuntimeError("tts boom")

    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: FailingTTS())
    error_chunks = _run(_collect_stream(speak().body))
    assert any('"code": 500' in chunk and "**ERROR**: tts boom" in chunk for chunk in error_chunks)
@pytest.mark.p1
def test_chat_create_accepts_provider_scoped_rerank_id_unit(monkeypatch):
    """``create`` accepts a provider-scoped ``rerank_id`` and validates it as a rerank model.

    The id is stored verbatim on the dialog. The route is expected to look it up
    through ``get_model_config_from_provider_instance`` with ``model_type="rerank"``.
    """
    module = _load_chat_routes_unit_module(monkeypatch)
    saved = {}
    query_calls = []
    _set_route_unit_request_json(
        monkeypatch,
        module,
        {
            "name": "chat-a",
            "icon": "icon.png",
            "dataset_ids": ["kb-1"],
            # NOTE(review): "@@" differs from the tenant default "glm-4@CI@ZHIPU-AI" below
            # and is not a key of `split_table`; confirm the double separator is intentional.
            "llm_id": "glm-4@@CI@ZHIPU-AI",
            "llm_setting": {"temperature": 0.8},
            "prompt_config": {
                "system": "Answer with {knowledge}",
                "parameters": [{"key": "knowledge", "optional": False}],
                "prologue": "Hi",
            },
            "rerank_id": "custom-reranker@OpenAI",
            "vector_similarity_weight": 0.25,
        },
    )
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4@CI@ZHIPU-AI")))
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")])
    monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()])
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB()))

    # Known ids map to (model_name, instance, factory) triples.
    split_table = {
        "glm-4@ZHIPU-AI": ("glm-4", "default", "ZHIPU-AI"),
        "glm-4@CI@ZHIPU-AI": ("glm-4", "CI", "ZHIPU-AI"),
        "custom-reranker@OpenAI": ("custom-reranker", "default", "OpenAI"),
    }

    def _split_model_name_and_factory(model_name):
        # Unknown ids fall back to a triple of the same arity. The previous 2-tuple
        # fallback would raise ValueError in any caller that unpacks three values.
        return split_table.get(model_name, (model_name, None, None))

    monkeypatch.setattr(module, "split_model_name", _split_model_name_and_factory)

    def _get_model_config_from_provider_instance(**kwargs):
        # Record every model lookup so the rerank validation can be asserted below.
        query_calls.append(kwargs)
        return {}

    monkeypatch.setattr(module, "get_model_config_from_provider_instance", _get_model_config_from_provider_instance)

    def _save(**kwargs):
        saved.update(kwargs)
        return True

    monkeypatch.setattr(module.DialogService, "save", _save)
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(saved)))
    res = _run(module.create.__wrapped__())
    assert res["code"] == 0, res
    assert saved["rerank_id"] == "custom-reranker@OpenAI"
    assert {
        "tenant_id": "tenant-1",
        "model_name": "custom-reranker@OpenAI",
        "model_type": "rerank",
    } in query_calls
@pytest.mark.p1
def test_chat_create_allows_default_knowledge_placeholder_without_sources_unit(monkeypatch):
    """Creating a chat from just a name keeps the default ``{knowledge}`` prompt with no datasets.

    The default prompt declares a required ``knowledge`` parameter. Creation must
    still succeed when ``kb_ids`` ends up empty.
    """
    module = _load_chat_routes_unit_module(monkeypatch)
    saved = {}
    _set_route_unit_request_json(monkeypatch, module, {"name": "chat-a"})
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4")))
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module, "get_api_key", lambda *_args, **_kwargs: SimpleNamespace(id=1))

    def _save(**kwargs):
        saved.update(kwargs)
        return True

    monkeypatch.setattr(module.DialogService, "save", _save)
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(saved)))
    res = _run(module.create.__wrapped__())
    assert res["code"] == 0, res
    assert saved["kb_ids"] == []
    # Substring membership is the idiomatic form of `str.find(...) >= 0`.
    assert "{knowledge}" in saved["prompt_config"]["system"]
    assert saved["prompt_config"]["parameters"] == [{"key": "knowledge", "optional": False}]
@pytest.mark.p2
def test_chat_create_uses_direct_chat_fields_unit(monkeypatch):
    """``create`` maps API fields onto stored fields and responds with the API shape only."""
    module = _load_chat_routes_unit_module(monkeypatch)
    saved = {}
    request_body = {
        "name": "chat-a",
        "icon": "icon.png",
        "dataset_ids": ["kb-1"],
        "llm_id": "glm-4",
        "llm_setting": {"temperature": 0.8},
        "prompt_config": {
            "system": "Answer with {knowledge}",
            "parameters": [{"key": "knowledge", "optional": False}],
            "prologue": "Hi",
        },
        "vector_similarity_weight": 0.25,
    }
    _set_route_unit_request_json(monkeypatch, module, request_body)

    kb_service = module.KnowledgebaseService
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4")))
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(kb_service, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")])
    monkeypatch.setattr(kb_service, "query", lambda **_kwargs: [_DummyKB()])
    monkeypatch.setattr(kb_service, "get_by_id", lambda _id: (True, _DummyKB()))
    monkeypatch.setattr(module, "split_model_name", lambda model: (model.partition("@")[0], "default", "factory"))

    def record_save(**fields):
        saved.update(fields)
        return True

    monkeypatch.setattr(module.DialogService, "save", record_save)
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(saved)))
    res = _run(module.create.__wrapped__())
    assert res["code"] == 0

    # What was persisted.
    assert saved["kb_ids"] == ["kb-1"]
    assert saved["prompt_config"]["prologue"] == "Hi"
    assert saved["llm_id"] == "glm-4"
    assert saved["llm_setting"]["temperature"] == 0.8

    # What was returned.
    data = res["data"]
    assert data["dataset_ids"] == ["kb-1"]
    assert data["kb_names"] == ["Dataset A"]
    for legacy_key in ("kb_ids", "prompt", "llm", "avatar"):
        assert legacy_key not in data, legacy_key
@pytest.mark.p2
def test_list_chats_passes_empty_owner_ids_when_omitted_unit(monkeypatch):
    """Without ``owner_ids`` query args, ``list_chats`` passes an empty owner list to the service."""
    module = _load_chat_routes_unit_module(monkeypatch)
    captured = {}
    query_args = {
        "keywords": "",
        "page": "1",
        "page_size": "10",
        "orderby": "create_time",
        "desc": "true",
        "id": None,
        "name": None,
    }
    request_args = SimpleNamespace(
        get=lambda key, default=None: query_args.get(key, default),
        getlist=lambda _key: [],
    )
    monkeypatch.setattr(module, "request", SimpleNamespace(args=request_args))

    def _get_by_tenant_ids(owner_ids, *_args, **_kwargs):
        captured["owner_ids"] = owner_ids
        return [], 0

    monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids)
    res = _run(module.list_chats.__wrapped__())
    assert res["code"] == 0
    assert captured["owner_ids"] == []
@pytest.mark.p2
def test_list_chats_filters_by_requested_owner_ids_unit(monkeypatch):
    """Explicit ``owner_ids`` are forwarded to the service, and only chats of those owners are returned."""
    module = _load_chat_routes_unit_module(monkeypatch)
    captured = {}
    query_args = {
        "keywords": "",
        "page": "1",
        "page_size": "10",
        "orderby": "create_time",
        "desc": "true",
        "id": None,
        "name": None,
    }

    def getlist(key):
        return ["team-tenant-2"] if key == "owner_ids" else []

    request_args = SimpleNamespace(get=lambda key, default=None: query_args.get(key, default), getlist=getlist)
    monkeypatch.setattr(module, "request", SimpleNamespace(args=request_args))

    def _get_by_tenant_ids(owner_ids, *_args, **_kwargs):
        captured["owner_ids"] = owner_ids
        # The service returns one chat from the requested tenant and one from the caller's own tenant.
        rows = [
            _DummyDialogRecord({"id": chat_id, "tenant_id": tenant_id, "name": name}).to_dict()
            for chat_id, tenant_id, name in (
                ("team-chat", "team-tenant-2", "team"),
                ("own-chat", "tenant-1", "own"),
            )
        ]
        return rows, 2

    monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids)
    res = _run(module.list_chats.__wrapped__())
    assert res["code"] == 0
    assert captured["owner_ids"] == ["team-tenant-2"]
    listed = res["data"]
    assert [chat["id"] for chat in listed["chats"]] == ["team-chat"]
    assert listed["total"] == 1
@pytest.mark.p2
def test_list_chats_returns_old_business_fields_unit(monkeypatch):
    """``list_chats`` returns the public chat shape and hides storage/legacy keys."""
    module = _load_chat_routes_unit_module(monkeypatch)
    query_args = {
        "keywords": "",
        "page": 1,
        "page_size": 20,
        "orderby": "create_time",
        "desc": "true",
    }
    request_args = SimpleNamespace(
        get=lambda key, default=None: query_args.get(key, default),
        getlist=lambda _key: [],
    )
    monkeypatch.setattr(module, "request", SimpleNamespace(args=request_args))
    monkeypatch.setattr(
        module.DialogService,
        "get_by_tenant_ids",
        lambda *_args, **_kwargs: ([_DummyDialogRecord().to_dict()], 1),
    )
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB()))

    res = _run(module.list_chats.__wrapped__())
    assert res["code"] == 0
    chat = res["data"]["chats"][0]

    expected_fields = {"icon": "icon.png", "dataset_ids": ["kb-1"], "kb_names": ["Dataset A"]}
    for field, value in expected_fields.items():
        assert chat[field] == value, field
    assert chat["prompt_config"]["prologue"] == "hello"
    for hidden_key in ("kb_ids", "dataset_names", "prompt", "llm"):
        assert hidden_key not in chat, hidden_key
@pytest.mark.p2
def test_patch_chat_drops_response_only_fields_before_update_unit(monkeypatch):
    """Sending a response-shaped payload (including ``kb_names``) to ``patch_chat`` must not persist ``kb_names``."""
    module = _load_chat_routes_unit_module(monkeypatch)
    updated = {}
    existing = _DummyDialogRecord().to_dict()

    # Rebuild the chat as the API would return it, then rename it.
    payload = {"name": "renamed-chat"}
    payload.update({key: existing[key] for key in ("description", "icon")})
    payload["dataset_ids"] = existing["kb_ids"]
    payload["kb_names"] = ["Dataset A"]  # response-only field
    payload.update(
        {
            key: existing[key]
            for key in (
                "llm_id",
                "llm_setting",
                "prompt_config",
                "similarity_threshold",
                "vector_similarity_weight",
                "top_n",
                "top_k",
                "rerank_id",
            )
        }
    )
    _set_route_unit_request_json(monkeypatch, module, payload)

    def dialog_query(**kwargs):
        # Name-uniqueness lookups find nothing; ownership lookups find the chat.
        return [] if "name" in kwargs else [SimpleNamespace(id="chat-1")]

    monkeypatch.setattr(module.DialogService, "query", dialog_query)
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(existing)))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4")))
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")])
    monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()])
    monkeypatch.setattr(module, "split_model_name", lambda model: (model.partition("@")[0], "default", "factory"))
    monkeypatch.setattr(module, "get_api_key", lambda *_args, **_kwargs: SimpleNamespace(id=1))

    def _update(_chat_id, req):
        updated.update(req)
        return True

    monkeypatch.setattr(module.DialogService, "update_by_id", _update)
    res = _run(module.patch_chat.__wrapped__("chat-1"))
    assert res["code"] == 0
    assert updated["name"] == "renamed-chat"
    assert "kb_names" not in updated
@pytest.mark.p2
def test_patch_chat_merges_prompt_and_llm_settings_unit(monkeypatch):
    """``patch_chat`` merges partial ``prompt_config``/``llm_setting`` into the stored values."""
    module = _load_chat_routes_unit_module(monkeypatch)
    existing = _DummyDialogRecord().to_dict()
    partial_body = {"prompt_config": {"prologue": "updated opener"}, "llm_setting": {"temperature": 0.9}}
    _set_route_unit_request_json(monkeypatch, module, partial_body)
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")])
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(existing)))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4")))
    updated = {}

    def _update(_chat_id, payload):
        updated.update(payload)
        return True

    monkeypatch.setattr(module.DialogService, "update_by_id", _update)
    res = _run(module.patch_chat.__wrapped__("chat-1"))
    assert res["code"] == 0

    merged_prompt = updated["prompt_config"]
    # The stored system prompt survives next to the patched prologue.
    assert merged_prompt["system"] == "Answer with {knowledge}"
    assert merged_prompt["prologue"] == "updated opener"
    assert updated["llm_setting"]["temperature"] == 0.9
@pytest.mark.p2
def test_update_chat_allows_knowledge_placeholder_without_sources_unit(monkeypatch):
    """``update_chat`` keeps a ``{knowledge}`` system prompt even when ``dataset_ids`` is empty."""
    module = _load_chat_routes_unit_module(monkeypatch)
    existing = _DummyDialogRecord().to_dict()
    system_prompt = "Answer with {knowledge}"
    full_body = {
        "name": "chat-name",
        "description": "desc",
        "icon": "icon.png",
        "dataset_ids": [],
        "llm_id": "glm-4",
        "llm_setting": {"temperature": 0.1},
        "prompt_config": {
            "system": system_prompt,
            "parameters": [{"key": "knowledge", "optional": False}],
            "prologue": "hello",
            "quote": True,
        },
        "similarity_threshold": 0.2,
        "vector_similarity_weight": 0.3,
        "top_n": 6,
        "top_k": 1024,
        "rerank_id": "",
    }
    _set_route_unit_request_json(monkeypatch, module, full_body)

    dialog_service = module.DialogService
    monkeypatch.setattr(dialog_service, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")])
    monkeypatch.setattr(dialog_service, "get_by_id", lambda _id: (True, _DummyDialogRecord(existing)))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4")))
    monkeypatch.setattr(module, "split_model_name", lambda model: (model.partition("@")[0], "default", "factory"))
    updated = {}

    def _update(_chat_id, payload):
        updated.update(payload)
        return True

    monkeypatch.setattr(dialog_service, "update_by_id", _update)
    res = _run(module.update_chat.__wrapped__("chat-1"))
    assert res["code"] == 0
    assert updated["prompt_config"]["system"] == system_prompt
@pytest.mark.p1
def test_chat_create_dataset_ids_contract(rest_client, clear_chats, ensure_parsed_document):
    """POST /chats accepts empty or owned parsed ``dataset_ids`` and rejects unknown ids and non-list values."""
    dataset_id, _ = ensure_parsed_document()
    scenarios = (
        # (label, dataset_ids, expected code, expected message, expected echoed dataset_ids)
        ("empty dataset_ids", [], 0, "", []),
        ("owned parsed dataset", [dataset_id], 0, "", [dataset_id]),
        ("invalid dataset id", ["invalid_dataset_id"], 102, "You don't own the dataset invalid_dataset_id", None),
        ("dataset_ids wrong type", "invalid_dataset_id", 102, "`dataset_ids` should be a list.", None),
    )
    for index, scenario in enumerate(scenarios, start=1):
        label, dataset_ids, want_code, want_message, want_dataset_ids = scenario
        res = rest_client.post(
            "/chats",
            json={"name": f"restful_chat_dataset_ids_{index}", "dataset_ids": dataset_ids},
        )
        assert res.status_code == 200, (label, res.text)
        body = res.json()
        assert body["code"] == want_code, (label, body)
        if want_code:
            assert body["message"] == want_message, (label, body)
        else:
            assert body["data"]["dataset_ids"] == want_dataset_ids, (label, body)
@pytest.mark.p2
def test_chat_create_avatar_contract(rest_client, clear_chats, tmp_path):
    """An image encoded with ``encode_avatar`` and sent as ``icon`` is echoed back unchanged."""
    encoded_avatar = encode_avatar(create_image_file(tmp_path / "restful_chat_avatar.png"))
    request_body = {"name": "restful_chat_avatar", "dataset_ids": [], "icon": encoded_avatar}
    res = rest_client.post("/chats", json=request_body)
    assert res.status_code == 200
    body = res.json()
    assert body["code"] == 0, body
    assert body["data"]["icon"] == encoded_avatar, body
@pytest.mark.p2
def test_chat_create_llm_contract(rest_client, clear_chats, ensure_parsed_document):
    """POST /chats validates ``llm_id`` but stores ``llm_setting`` values as sent, without range or type checks."""
    dataset_id, _ = ensure_parsed_document()
    default_llm = "glm-4-flash@CI@ZHIPU-AI"

    # (label, extra payload, expected code, expected message, expected llm_id, expected llm_setting)
    cases = [
        ("default llm", {}, 0, "", default_llm, {}),
        ("explicit llm_id", {"llm_id": "glm-4"}, 102, "`llm_id` glm-4 doesn't exist", None, None),
        ("unknown llm_id", {"llm_id": "unknown"}, 102, "`llm_id` unknown doesn't exist", None, None),
    ]
    # Each sampling setting is probed with the same spread of values; all are accepted and echoed back.
    probes = [("zero", 0), ("one", 1), ("negative one", -1), ("ten", 10), ("string", "a")]
    for setting in ("temperature", "top_p", "presence_penalty", "frequency_penalty"):
        for label, value in probes:
            cases.append((f"{setting} {label}", {"llm_setting": {setting: value}}, 0, "", default_llm, {setting: value}))
    max_token_probes = [("zero", 0), ("1024", 1024), ("negative one", -1), ("ten", 10), ("string", "a")]
    for label, value in max_token_probes:
        cases.append((f"max_token {label}", {"llm_setting": {"max_token": value}}, 0, "", default_llm, {"max_token": value}))
    cases.append(
        ("unknown llm setting key", {"llm_setting": {"unknown": "unknown"}}, 0, "", default_llm, {"unknown": "unknown"})
    )

    for index, (scenario, extra, want_code, want_message, want_llm_id, want_llm_setting) in enumerate(cases, start=1):
        request_body = {"name": f"restful_chat_llm_{index}", "dataset_ids": [dataset_id], **extra}
        res = rest_client.post("/chats", json=request_body)
        assert res.status_code == 200, (scenario, res.text)
        body = res.json()
        assert body["code"] == want_code, (scenario, body)
        if want_code != 0:
            assert body["message"] == want_message, (scenario, body)
            continue
        assert body["data"]["llm_id"] == want_llm_id, (scenario, body)
        assert body["data"]["llm_setting"] == want_llm_setting, (scenario, body)
@pytest.mark.p2
def test_chat_create_prompt_contract(rest_client, clear_chats):
    """POST /chats fills retrieval/prompt defaults and stores caller overrides for those fields.

    Each case maps a ``_get_nested`` key path in the response ``data`` to its
    expected value.
    """
    default_expectations = {
        ("similarity_threshold",): 0.1,
        ("vector_similarity_weight",): 0.3,
        ("top_n",): 6,
        ("rerank_id",): "",
        ("prompt_config", "parameters"): [{"key": "knowledge", "optional": False}],
        ("prompt_config", "empty_response"): DEFAULT_CHAT_EMPTY_RESPONSE,
        ("prompt_config", "prologue"): DEFAULT_CHAT_PROLOGUE,
        ("prompt_config", "quote"): True,
        ("prompt_config", "system"): DEFAULT_CHAT_SYSTEM_PROMPT,
    }
    cases = [("default prompt config", {}, default_expectations)]
    cases += [
        ("similarity_threshold zero", {"similarity_threshold": 0}, {("similarity_threshold",): 0}),
        ("similarity_threshold one", {"similarity_threshold": 1}, {("similarity_threshold",): 1}),
        ("similarity_threshold negative one", {"similarity_threshold": -1}, {("similarity_threshold",): -1.0}),
        ("similarity_threshold ten", {"similarity_threshold": 10}, {("similarity_threshold",): 10.0}),
        ("similarity_threshold string", {"similarity_threshold": "a"}, {("similarity_threshold",): 0.0}),
        ("vector_similarity_weight one", {"vector_similarity_weight": 1}, {("vector_similarity_weight",): 1}),
        ("vector_similarity_weight zero", {"vector_similarity_weight": 0}, {("vector_similarity_weight",): 0}),
        ("vector_similarity_weight two", {"vector_similarity_weight": 2}, {("vector_similarity_weight",): 2.0}),
        ("vector_similarity_weight negative nine", {"vector_similarity_weight": -9}, {("vector_similarity_weight",): -9.0}),
        ("vector_similarity_weight string", {"vector_similarity_weight": "a"}, {("vector_similarity_weight",): 0.0}),
        ("empty prompt parameters", {"prompt_config": {"parameters": []}}, {("prompt_config", "parameters"): []}),
        ("top_n zero", {"top_n": 0}, {("top_n",): 0}),
        ("top_n one", {"top_n": 1}, {("top_n",): 1}),
        ("top_n negative one", {"top_n": -1}, {("top_n",): -1}),
        ("top_n ten", {"top_n": 10}, {("top_n",): 10}),
        ("top_n string", {"top_n": "a"}, {("top_n",): 0}),
    ]
    # empty_response and prologue are probed with identical values, and each value is stored as sent.
    free_text_probes = [
        ("plain text", "Hello World"),
        ("empty string", ""),
        ("punctuation", "!@#$%^&*()"),
        ("chinese text", "中文测试"),
        ("integer", 123),
        ("boolean", True),
        ("space", " "),
    ]
    for field in ("empty_response", "prologue"):
        for label, value in free_text_probes:
            cases.append((f"{field} {label}", {"prompt_config": {field: value}}, {("prompt_config", field): value}))
    cases += [
        ("quote true", {"prompt_config": {"quote": True}}, {("prompt_config", "quote"): True}),
        ("quote false", {"prompt_config": {"quote": False}}, {("prompt_config", "quote"): False}),
        ("system prompt with knowledge prefix", {"prompt_config": {"system": "Hello World {knowledge}"}}, {("prompt_config", "system"): "Hello World {knowledge}"}),
        ("system prompt only knowledge", {"prompt_config": {"system": "{knowledge}"}}, {("prompt_config", "system"): "{knowledge}"}),
        ("system prompt punctuation", {"prompt_config": {"system": "!@#$%^&*() {knowledge}"}}, {("prompt_config", "system"): "!@#$%^&*() {knowledge}"}),
        ("system prompt chinese text", {"prompt_config": {"system": "中文测试 {knowledge}"}}, {("prompt_config", "system"): "中文测试 {knowledge}"}),
        ("system prompt plain text", {"prompt_config": {"system": "Hello World"}}, {("prompt_config", "system"): "Hello World"}),
        (
            "system prompt with explicit empty parameters",
            {"prompt_config": {"system": "Hello World", "parameters": []}},
            {("prompt_config", "system"): "Hello World", ("prompt_config", "parameters"): []},
        ),
        ("system prompt integer", {"prompt_config": {"system": 123}}, {("prompt_config", "system"): 123}),
        ("system prompt boolean", {"prompt_config": {"system": True}}, {("prompt_config", "system"): True}),
        ("unknown prompt_config key", {"prompt_config": {"unknown": "unknown"}}, {("prompt_config", "unknown"): "unknown"}),
    ]

    for index, (scenario, extra, expectations) in enumerate(cases, start=1):
        res = rest_client.post(
            "/chats",
            json={"name": f"restful_chat_prompt_{index}", "dataset_ids": [], **extra},
        )
        assert res.status_code == 200, (scenario, res.text)
        created = res.json()
        assert created["code"] == 0, (scenario, created)
        chat = created["data"]
        for path, expected in expectations.items():
            assert _get_nested(chat, path) == expected, (scenario, path, created)
@pytest.mark.p2
def test_chat_create_additional_guards_contract(rest_client, clear_chats):
    """POST /chats must reject a caller-supplied ``tenant_id`` and an unknown ``rerank_id``.

    Both rejections are expected as HTTP 200 with application code 102 and a
    message containing the listed fragment.
    """
    guard_cases = (
        ("reject tenant_id override", {"tenant_id": "tenant-should-not-pass"}, "`tenant_id` must not be provided."),
        ("reject unknown rerank_id", {"rerank_id": "unknown-rerank-model"}, "`rerank_id` unknown-rerank-model doesn't exist"),
    )
    for position, (label, overrides, message_fragment) in enumerate(guard_cases, start=1):
        request_body = {"name": f"restful_chat_guard_{position}", "dataset_ids": []}
        request_body.update(overrides)
        response = rest_client.post("/chats", json=request_body)
        assert response.status_code == 200, (label, response.text)
        body = response.json()
        # The rejection is reported in the JSON envelope, not via the HTTP status.
        assert body["code"] == 102, (label, body)
        assert message_fragment in body["message"], (label, body)
@pytest.mark.p2
def test_chat_create_rejects_unparsed_document(rest_client, clear_chats, create_document):
    """Creating a chat bound to a dataset without parsed files must fail with code 102.

    ``create_document`` presumably yields a dataset whose document has not been
    parsed (in contrast to ``ensure_parsed_document``) -- confirm against the
    fixture definition.
    """
    dataset_id, _ = create_document()
    res = rest_client.post(
        "/chats",
        json={"name": "restful_chat_unparsed_document", "dataset_ids": [dataset_id]},
    )
    # Attach the raw body so a transport-level failure is diagnosable, matching
    # the update-side twin of this test.
    assert res.status_code == 200, res.text
    payload = res.json()
    assert payload["code"] == 102, payload
    assert "doesn't own parsed file" in payload["message"], payload
@pytest.mark.p2
def test_chat_update_name_contract(rest_client, clear_chats):
    """PATCH /chats/{id} validates ``name``.

    The checks cover type, emptiness, length and case-insensitive uniqueness.
    A successful rename is confirmed by reading the chat back. A rejected
    rename must return code 102 with the exact error message.
    """

    def _create(chat_name):
        # Create a chat with no datasets and return the decoded response body.
        created = rest_client.post("/chats", json={"name": chat_name, "dataset_ids": []})
        assert created.status_code == 200
        created_body = created.json()
        assert created_body["code"] == 0, created_body
        return created_body

    # A pre-existing chat whose name the target later tries to reuse.
    _create("restful_chat_update_duplicate")
    chat_id = _create("restful_chat_update_name_target")["data"]["id"]

    over_limit = CHAT_ASSISTANT_NAME_LIMIT + 1
    scenarios = [
        ("valid name", {"name": "valid_name"}, 0, "", "valid_name"),
        (
            "name too long",
            {"name": "a" * over_limit},
            102,
            f"Chat name length is {over_limit} which is larger than {CHAT_ASSISTANT_NAME_LIMIT}.",
            None,
        ),
        ("name wrong type", {"name": 1}, 102, "Chat name must be a string.", None),
        ("name empty", {"name": ""}, 102, "`name` cannot be empty.", None),
        ("duplicate lowercase", {"name": "restful_chat_update_duplicate"}, 102, "Duplicated chat name.", None),
        # Uniqueness is expected to ignore letter case.
        ("duplicate uppercase", {"name": "RESTFUL_CHAT_UPDATE_DUPLICATE"}, 102, "Duplicated chat name.", None),
    ]
    for label, patch_body, want_code, want_message, want_name in scenarios:
        response = rest_client.patch(f"/chats/{chat_id}", json=patch_body)
        assert response.status_code == 200, (label, response.text)
        body = response.json()
        assert body["code"] == want_code, (label, body)
        if want_code != 0:
            assert body["message"] == want_message, (label, body)
            continue
        # Success path: read the chat back to confirm the rename was persisted.
        fetched = rest_client.get(f"/chats/{chat_id}")
        assert fetched.status_code == 200, (label, fetched.text)
        fetched_body = fetched.json()
        assert fetched_body["code"] == 0, (label, fetched_body)
        assert fetched_body["data"]["name"] == want_name, (label, fetched_body)
@pytest.mark.p2
def test_chat_update_dataset_ids_contract(rest_client, clear_chats, ensure_parsed_document):
    """PUT /chats/{id} dataset binding contract.

    The endpoint accepts an empty list or a list of owned, parsed datasets. It
    rejects unknown dataset ids and non-list values with code 102 and an exact
    message. Accepted updates are verified by reading the chat back.
    """
    parsed_dataset_id, _ = ensure_parsed_document()
    created = rest_client.post("/chats", json={"name": "restful_chat_update_dataset_target", "dataset_ids": []})
    assert created.status_code == 200
    created_body = created.json()
    assert created_body["code"] == 0, created_body
    chat_id = created_body["data"]["id"]

    scenarios = (
        ("empty dataset_ids", [], 0, "", []),
        ("owned parsed dataset", [parsed_dataset_id], 0, "", [parsed_dataset_id]),
        ("invalid dataset id", ["invalid_dataset_id"], 102, "You don't own the dataset invalid_dataset_id", None),
        ("dataset_ids wrong type", "invalid_dataset_id", 102, "`dataset_ids` should be a list.", None),
    )
    for label, ids_value, want_code, want_message, want_ids in scenarios:
        response = rest_client.put(
            f"/chats/{chat_id}",
            json={"name": "ragflow test", "dataset_ids": ids_value},
        )
        assert response.status_code == 200, (label, response.text)
        body = response.json()
        assert body["code"] == want_code, (label, body)
        if want_code != 0:
            assert body["message"] == want_message, (label, body)
            continue
        fetched = rest_client.get(f"/chats/{chat_id}")
        assert fetched.status_code == 200, (label, fetched.text)
        fetched_body = fetched.json()
        assert fetched_body["code"] == 0, (label, fetched_body)
        assert fetched_body["data"]["name"] == "ragflow test", (label, fetched_body)
        assert fetched_body["data"]["dataset_ids"] == want_ids, (label, fetched_body)
@pytest.mark.p2
def test_chat_update_avatar_contract(rest_client, clear_chats, ensure_parsed_document, tmp_path):
    """A PUT carrying an encoded image ``icon`` persists the icon, name and dataset binding.

    The chat is created without datasets. It is then updated in one PUT with a
    new name, an icon produced by ``encode_avatar`` and a parsed dataset. All
    three values must be visible on read-back.
    """
    dataset_id, _ = ensure_parsed_document()
    create_res = rest_client.post("/chats", json={"name": "restful_chat_update_avatar_target", "dataset_ids": []})
    # Every HTTP-level assertion carries the response text so failures are diagnosable.
    assert create_res.status_code == 200, create_res.text
    create_payload = create_res.json()
    assert create_payload["code"] == 0, create_payload
    chat_id = create_payload["data"]["id"]

    # Write a real image under tmp_path and encode it with the shared helper.
    # The stored icon must equal this exact encoding.
    image_path = create_image_file(tmp_path / "restful_chat_update_avatar.png")
    encoded_avatar = encode_avatar(image_path)
    res = rest_client.put(
        f"/chats/{chat_id}",
        json={"name": "avatar_test", "icon": encoded_avatar, "dataset_ids": [dataset_id]},
    )
    assert res.status_code == 200, res.text
    payload = res.json()
    assert payload["code"] == 0, payload

    get_res = rest_client.get(f"/chats/{chat_id}")
    assert get_res.status_code == 200, get_res.text
    get_payload = get_res.json()
    assert get_payload["code"] == 0, get_payload
    assert get_payload["data"]["name"] == "avatar_test", get_payload
    assert get_payload["data"]["icon"] == encoded_avatar, get_payload
    assert get_payload["data"]["dataset_ids"] == [dataset_id], get_payload
@pytest.mark.p2
def test_chat_update_llm_contract(rest_client, clear_chats, ensure_parsed_document):
    """PUT /chats/{id} LLM selection and ``llm_setting`` contract.

    Unknown ``llm_id`` values are rejected with code 102. Every ``llm_setting``
    probe, including out-of-range numbers, strings and unknown keys, is accepted
    and read back verbatim. Each case uses a freshly created chat so settings
    from one case cannot leak into the next.
    """
    dataset_id, _ = ensure_parsed_document()
    # Model id every successful case expects on read-back; presumably the CI
    # tenant's default chat model -- confirm against the CI configuration.
    default_llm = "glm-4-flash@CI@ZHIPU-AI"
    # (label suffix, value) probes shared by the sampling parameters.
    standard_probes = (("zero", 0), ("one", 1), ("negative one", -1), ("ten", 10), ("string", "a"))
    # max_token swaps the "one" probe for a realistic token budget.
    max_token_probes = (("zero", 0), ("1024", 1024), ("negative one", -1), ("ten", 10), ("string", "a"))

    cases = [
        ("default llm", {}, 0, "", default_llm, {}),
        ("explicit llm_id", {"llm_id": "glm-4"}, 102, "`llm_id` glm-4 doesn't exist", None, None),
        ("unknown llm_id", {"llm_id": "unknown"}, 102, "`llm_id` unknown doesn't exist", None, None),
    ]
    for setting_key, probes in (
        ("temperature", standard_probes),
        ("top_p", standard_probes),
        ("presence_penalty", standard_probes),
        ("frequency_penalty", standard_probes),
        ("max_token", max_token_probes),
    ):
        for suffix, value in probes:
            cases.append(
                (f"{setting_key} {suffix}", {"llm_setting": {setting_key: value}}, 0, "", default_llm, {setting_key: value})
            )
    cases.append(("unknown llm setting key", {"llm_setting": {"unknown": "unknown"}}, 0, "", default_llm, {"unknown": "unknown"}))

    for position, (label, overrides, want_code, want_message, want_llm_id, want_setting) in enumerate(cases, start=1):
        created = rest_client.post(
            "/chats",
            json={"name": f"restful_chat_update_llm_target_{position}", "dataset_ids": [dataset_id]},
        )
        assert created.status_code == 200, (label, created.text)
        created_body = created.json()
        assert created_body["code"] == 0, (label, created_body)
        chat_id = created_body["data"]["id"]

        new_name = f"llm_test_{position}"
        response = rest_client.put(
            f"/chats/{chat_id}",
            json={"name": new_name, "dataset_ids": [dataset_id], **overrides},
        )
        assert response.status_code == 200, (label, response.text)
        body = response.json()
        assert body["code"] == want_code, (label, body)
        if want_code != 0:
            assert body["message"] == want_message, (label, body)
            continue

        fetched = rest_client.get(f"/chats/{chat_id}")
        assert fetched.status_code == 200, (label, fetched.text)
        fetched_body = fetched.json()
        assert fetched_body["code"] == 0, (label, fetched_body)
        chat = fetched_body["data"]
        assert chat["name"] == new_name, (label, fetched_body)
        assert chat["llm_id"] == want_llm_id, (label, fetched_body)
        assert chat["llm_setting"] == want_setting, (label, fetched_body)
@pytest.mark.p2
def test_chat_update_prompt_contract(rest_client, clear_chats, ensure_parsed_document):
    """PUT /chats/{id} retrieval-knob and ``prompt_config`` contract.

    Every case is expected to succeed. The chat is then read back and each
    ``(path, value)`` expectation is checked with ``_get_nested``. Each case
    runs against a freshly created chat.
    """
    dataset_id, _ = ensure_parsed_document()

    # (label suffix, value sent, value expected on read-back). The expected
    # column records how the server currently stores each probe, e.g. a
    # non-numeric string reads back as zero.
    float_probes = (("zero", 0, 0), ("one", 1, 1), ("negative one", -1, -1.0), ("ten", 10, 10.0), ("string", "a", 0.0))
    top_n_probes = (("zero", 0, 0), ("one", 1, 1), ("negative one", -1, -1), ("ten", 10, 10), ("string", "a", 0))
    # (label suffix, value) probes for free-text prompt fields; stored verbatim.
    text_probes = (
        ("plain text", "Hello World"),
        ("empty string", ""),
        ("punctuation", "!@#$%^&*()"),
        ("chinese text", "中文测试"),
        ("integer", 123),
        ("boolean", True),
        ("space", " "),
    )

    cases = [
        (
            "default prompt config",
            {},
            {
                ("similarity_threshold",): 0.1,
                ("vector_similarity_weight",): 0.3,
                ("top_n",): 6,
                ("prompt_config", "parameters"): [{"key": "knowledge", "optional": False}],
                ("prompt_config", "empty_response"): DEFAULT_CHAT_EMPTY_RESPONSE,
                ("prompt_config", "prologue"): DEFAULT_CHAT_PROLOGUE,
                ("prompt_config", "quote"): True,
                ("prompt_config", "system"): DEFAULT_CHAT_SYSTEM_PROMPT,
            },
        ),
    ]
    for field in ("similarity_threshold", "vector_similarity_weight"):
        for suffix, sent, stored in float_probes:
            cases.append((f"{field} {suffix}", {field: sent}, {(field,): stored}))
    cases.append(("empty prompt parameters", {"prompt_config": {"parameters": []}}, {("prompt_config", "parameters"): []}))
    for suffix, sent, stored in top_n_probes:
        cases.append((f"top_n {suffix}", {"top_n": sent}, {("top_n",): stored}))
    for field in ("empty_response", "prologue"):
        for suffix, value in text_probes:
            cases.append((f"{field} {suffix}", {"prompt_config": {field: value}}, {("prompt_config", field): value}))
    cases.extend(
        [
            ("quote true", {"prompt_config": {"quote": True}}, {("prompt_config", "quote"): True}),
            ("quote false", {"prompt_config": {"quote": False}}, {("prompt_config", "quote"): False}),
            ("system prompt with knowledge prefix", {"prompt_config": {"system": "Hello World {knowledge}"}}, {("prompt_config", "system"): "Hello World {knowledge}"}),
            ("system prompt only knowledge", {"prompt_config": {"system": "{knowledge}"}}, {("prompt_config", "system"): "{knowledge}"}),
            ("system prompt punctuation", {"prompt_config": {"system": "!@#$%^&*() {knowledge}"}}, {("prompt_config", "system"): "!@#$%^&*() {knowledge}"}),
            ("system prompt chinese text", {"prompt_config": {"system": "中文测试 {knowledge}"}}, {("prompt_config", "system"): "中文测试 {knowledge}"}),
            ("system prompt plain text", {"prompt_config": {"system": "Hello World"}}, {("prompt_config", "system"): "Hello World"}),
            (
                "system prompt with explicit empty parameters",
                {"prompt_config": {"system": "Hello World", "parameters": []}},
                {("prompt_config", "system"): "Hello World", ("prompt_config", "parameters"): []},
            ),
            ("system prompt integer", {"prompt_config": {"system": 123}}, {("prompt_config", "system"): 123}),
            ("system prompt boolean", {"prompt_config": {"system": True}}, {("prompt_config", "system"): True}),
            # An unrecognised top-level key must not break the update; nothing else is checked.
            ("unknown prompt key", {"unknown": "unknown"}, {}),
        ]
    )

    for position, (label, overrides, expectations) in enumerate(cases, start=1):
        created = rest_client.post(
            "/chats",
            json={"name": f"restful_chat_update_prompt_target_{position}", "dataset_ids": [dataset_id]},
        )
        assert created.status_code == 200, (label, created.text)
        created_body = created.json()
        assert created_body["code"] == 0, (label, created_body)
        chat_id = created_body["data"]["id"]

        new_name = f"prompt_test_{position}"
        response = rest_client.put(
            f"/chats/{chat_id}",
            json={"name": new_name, "dataset_ids": [dataset_id], **overrides},
        )
        assert response.status_code == 200, (label, response.text)
        body = response.json()
        assert body["code"] == 0, (label, body)

        fetched = rest_client.get(f"/chats/{chat_id}")
        assert fetched.status_code == 200, (label, fetched.text)
        fetched_body = fetched.json()
        assert fetched_body["code"] == 0, (label, fetched_body)
        chat = fetched_body["data"]
        assert chat["name"] == new_name, (label, fetched_body)
        assert chat["dataset_ids"] == [dataset_id], (label, fetched_body)
        for path, want in expectations.items():
            assert _get_nested(chat, path) == want, (label, path, fetched_body)
@pytest.mark.p2
def test_chat_update_mapping_and_validation_branches_p2(rest_client, clear_chats):
    """Walk PATCH /chats/{id} through its authorization, mapping and validation branches.

    Each step is checked for HTTP 200 and the expected application ``code``.
    Rejected steps also pin the error message. A final GET confirms which
    successful patches were persisted.
    """

    def _checked(step, res, expected_code):
        # Shared envelope check. ``step`` labels each assertion and the raw
        # response text is attached, so a failure names the branch under test.
        assert res.status_code == 200, (step, res.text)
        body = res.json()
        assert body["code"] == expected_code, (step, body)
        return body

    _checked(
        "create duplicate",
        rest_client.post("/chats", json={"name": "restful_chat_update_mapping_duplicate", "dataset_ids": []}),
        0,
    )
    target = _checked(
        "create target",
        rest_client.post("/chats", json={"name": "restful_chat_update_mapping_target", "dataset_ids": []}),
        0,
    )
    chat_id = target["data"]["id"]

    # An unknown chat id is answered with code 109 "No authorization.".
    unauthorized = _checked(
        "unauthorized chat id",
        rest_client.patch("/chats/invalid-chat-id", json={"name": "anything"}),
        109,
    )
    assert unauthorized["message"] == "No authorization.", unauthorized

    # The PATCH response carries the updated prompt_config.
    quote = _checked(
        "quote false",
        rest_client.patch(f"/chats/{chat_id}", json={"prompt_config": {"quote": False}}),
        0,
    )
    assert quote["data"]["prompt_config"]["quote"] is False, quote

    invalid_llm = _checked(
        "unknown llm_id",
        rest_client.patch(
            f"/chats/{chat_id}",
            json={"llm_id": "unknown-llm-model", "llm_setting": {"model_type": "chat"}},
        ),
        102,
    )
    assert "`llm_id` unknown-llm-model doesn't exist" in invalid_llm["message"], invalid_llm

    invalid_rerank = _checked(
        "unknown rerank_id",
        rest_client.patch(f"/chats/{chat_id}", json={"rerank_id": "unknown-rerank-model"}),
        102,
    )
    assert "`rerank_id` unknown-rerank-model doesn't exist" in invalid_rerank["message"], invalid_rerank

    empty_name = _checked("empty name", rest_client.patch(f"/chats/{chat_id}", json={"name": ""}), 102)
    assert empty_name["message"] == "`name` cannot be empty.", empty_name

    duplicate_name = _checked(
        "duplicate name",
        rest_client.patch(f"/chats/{chat_id}", json={"name": "restful_chat_update_mapping_duplicate"}),
        102,
    )
    assert duplicate_name["message"] == "Duplicated chat name.", duplicate_name

    # A system prompt without a {knowledge} placeholder is currently accepted,
    # even though ``knowledge`` is declared as a non-optional parameter.
    _checked(
        "prompt without placeholder",
        rest_client.patch(
            f"/chats/{chat_id}",
            json={"prompt_config": {"system": "No required placeholder", "parameters": [{"key": "knowledge", "optional": False}]}},
        ),
        0,
    )

    # A plain string that is not an encoded image is accepted as the icon.
    _checked("raw icon", rest_client.patch(f"/chats/{chat_id}", json={"icon": "raw-avatar-value"}), 0)

    stored = _checked("read back", rest_client.get(f"/chats/{chat_id}"), 0)
    assert stored["data"]["prompt_config"]["system"] == "No required placeholder", stored
    assert stored["data"]["prompt_config"]["parameters"] == [{"key": "knowledge", "optional": False}], stored
    assert stored["data"]["icon"] == "raw-avatar-value", stored
@pytest.mark.p2
def test_chat_update_rejects_unparsed_document(rest_client, clear_chats, create_document):
    """Binding an existing chat to a dataset without parsed files must fail with code 102.

    ``create_document`` presumably yields an unparsed document's dataset --
    confirm against the fixture definition.
    """
    unparsed_dataset_id, _unused_document = create_document()
    created = rest_client.post("/chats", json={"name": "restful_chat_update_unparsed_target", "dataset_ids": []})
    assert created.status_code == 200, created.text
    created_body = created.json()
    assert created_body["code"] == 0, created_body
    target_id = created_body["data"]["id"]

    patched = rest_client.patch(f"/chats/{target_id}", json={"dataset_ids": [unparsed_dataset_id]})
    assert patched.status_code == 200, patched.text
    patched_body = patched.json()
    assert patched_body["code"] == 102, patched_body
    assert "doesn't own parsed file" in patched_body["message"], patched_body