mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Feat: add new tests and tescases for restful api suite (#14996)
### What problem does this PR solve? extend restful api suite ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Other (please describe): test
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -70,6 +70,14 @@ def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _passthrough_login_required(func):
|
||||
async def _wrapper(*args, **kwargs):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
_wrapper.__wrapped__ = func
|
||||
return _wrapper
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
|
||||
async def _request_json():
|
||||
return payload
|
||||
@@ -1380,3 +1388,241 @@ def test_forget_reset_password_matrix_unit(monkeypatch):
|
||||
assert res["code"] == module.RetCode.SUCCESS, res
|
||||
assert res["auth"] == user.get_id(), res
|
||||
assert module.REDIS_CONN.get(v_key) is None, module.REDIS_CONN.store
|
||||
|
||||
|
||||
def _load_chat_routes_unit_module(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
module_name = "test_chat_restful_routes_unit_module_for_tenant"
|
||||
module_path = repo_root / "api" / "apps" / "restful_apis" / "chat_api.py"
|
||||
|
||||
quart_mod = ModuleType("quart")
|
||||
quart_mod.request = SimpleNamespace(args=SimpleNamespace(get=lambda _key, default=None: default, getlist=lambda _key: []))
|
||||
quart_mod.Response = type("_StubResponse", (), {})
|
||||
monkeypatch.setitem(sys.modules, "quart", quart_mod)
|
||||
|
||||
api_pkg = ModuleType("api")
|
||||
api_pkg.__path__ = [str(repo_root / "api")]
|
||||
monkeypatch.setitem(sys.modules, "api", api_pkg)
|
||||
|
||||
apps_pkg = ModuleType("api.apps")
|
||||
apps_pkg.__path__ = [str(repo_root / "api" / "apps")]
|
||||
apps_pkg.current_user = SimpleNamespace(id="tenant-1")
|
||||
apps_pkg.login_required = _passthrough_login_required
|
||||
monkeypatch.setitem(sys.modules, "api.apps", apps_pkg)
|
||||
api_pkg.apps = apps_pkg
|
||||
|
||||
common_pkg = ModuleType("common")
|
||||
common_pkg.__path__ = [str(repo_root / "common")]
|
||||
monkeypatch.setitem(sys.modules, "common", common_pkg)
|
||||
|
||||
settings_mod = ModuleType("common.settings")
|
||||
settings_mod.STORAGE_IMPL = type("_StorageImpl", (), {"rm": staticmethod(lambda *_args, **_kwargs: None)})()
|
||||
monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
|
||||
|
||||
constants_mod = ModuleType("common.constants")
|
||||
constants_mod.LLMType = SimpleNamespace(CHAT="chat", IMAGE2TEXT="image2text", RERANK="rerank", SPEECH2TEXT="speech2text", TTS="tts")
|
||||
constants_mod.RetCode = SimpleNamespace(SUCCESS=0, DATA_ERROR=102, OPERATING_ERROR=103, AUTHENTICATION_ERROR=109)
|
||||
constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value="1"), INVALID=SimpleNamespace(value="0"))
|
||||
from common.constants import MAXIMUM_PAGE_NUMBER as _MPN, MAXIMUM_TASK_PAGE_NUMBER as _MTPN
|
||||
constants_mod.MAXIMUM_PAGE_NUMBER = _MPN
|
||||
constants_mod.MAXIMUM_TASK_PAGE_NUMBER = _MTPN
|
||||
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
|
||||
|
||||
misc_utils_mod = ModuleType("common.misc_utils")
|
||||
misc_utils_mod.get_uuid = lambda: "generated-chat-id"
|
||||
async def _thread_pool_exec(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
misc_utils_mod.thread_pool_exec = _thread_pool_exec
|
||||
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
|
||||
|
||||
dialog_service_mod = ModuleType("api.db.services.dialog_service")
|
||||
class _DialogService:
|
||||
model = SimpleNamespace(_meta=SimpleNamespace(fields={
|
||||
"id": None,
|
||||
"tenant_id": None,
|
||||
"name": None,
|
||||
"description": None,
|
||||
"icon": None,
|
||||
"kb_ids": None,
|
||||
"llm_id": None,
|
||||
"llm_setting": None,
|
||||
"prompt_config": None,
|
||||
"similarity_threshold": None,
|
||||
"vector_similarity_weight": None,
|
||||
"top_n": None,
|
||||
"top_k": None,
|
||||
"rerank_id": None,
|
||||
"meta_data_filter": None,
|
||||
"created_by": None,
|
||||
"create_time": None,
|
||||
"create_date": None,
|
||||
"update_time": None,
|
||||
"update_date": None,
|
||||
"status": None,
|
||||
}))
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return []
|
||||
@staticmethod
|
||||
def save(**_kwargs):
|
||||
return True
|
||||
@staticmethod
|
||||
def get_by_id(_chat_id):
|
||||
return False, None
|
||||
@staticmethod
|
||||
def get_by_tenant_ids(*_args, **_kwargs):
|
||||
return [], 0
|
||||
dialog_service_mod.DialogService = _DialogService
|
||||
dialog_service_mod.async_ask = lambda *_args, **_kwargs: None
|
||||
dialog_service_mod.async_chat = lambda *_args, **_kwargs: None
|
||||
dialog_service_mod.gen_mindmap = lambda *_args, **_kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.dialog_service", dialog_service_mod)
|
||||
|
||||
conversation_service_mod = ModuleType("api.db.services.conversation_service")
|
||||
conversation_service_mod.ConversationService = type("ConversationService", (), {})
|
||||
conversation_service_mod.structure_answer = lambda *_args, **_kwargs: {}
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.conversation_service", conversation_service_mod)
|
||||
|
||||
kb_service_mod = ModuleType("api.db.services.knowledgebase_service")
|
||||
class _KB:
|
||||
def __init__(self):
|
||||
self.id = "kb-1"
|
||||
self.embd_id = "embd@factory"
|
||||
self.chunk_num = 1
|
||||
self.name = "Dataset A"
|
||||
self.status = "1"
|
||||
kb_service_mod.KnowledgebaseService = type('KnowledgebaseService', (), {
|
||||
'accessible': staticmethod(lambda **_kwargs: [SimpleNamespace(id='kb-1')]),
|
||||
'query': staticmethod(lambda **_kwargs: [_KB()]),
|
||||
'get_by_id': staticmethod(lambda _id: (True, _KB())),
|
||||
})
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
|
||||
|
||||
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
|
||||
tenant_llm_service_mod.TenantLLMService = type('TenantLLMService', (), {
|
||||
'split_model_name_and_factory': staticmethod(lambda model: (model.split('@', 1)[0], model.split('@', 1)[1] if '@' in model else None)),
|
||||
'query': staticmethod(lambda **_kwargs: [SimpleNamespace(id='llm-1')]),
|
||||
'get_api_key': staticmethod(lambda *_args, **_kwargs: SimpleNamespace(id=1)),
|
||||
})
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
|
||||
|
||||
llm_service_mod = ModuleType("api.db.services.llm_service")
|
||||
llm_service_mod.LLMBundle = lambda *_args, **_kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
|
||||
|
||||
search_service_mod = ModuleType("api.db.services.search_service")
|
||||
search_service_mod.SearchService = SimpleNamespace()
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod)
|
||||
|
||||
tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service")
|
||||
tenant_model_service_mod.get_model_config_by_type_and_name = lambda *_args, **_kwargs: {}
|
||||
tenant_model_service_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {}
|
||||
monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod)
|
||||
|
||||
user_service_mod = ModuleType("api.db.services.user_service")
|
||||
user_service_mod.UserService = type('UserService', (), {})
|
||||
user_service_mod.TenantService = type('TenantService', (), {
|
||||
'get_by_id': staticmethod(lambda _tenant_id: (True, SimpleNamespace(llm_id='glm-4'))),
|
||||
'get_joined_tenants_by_user_id': staticmethod(lambda _user_id: [{'tenant_id': 'tenant-1'}, {'tenant_id': 'team-tenant-2'}]),
|
||||
})
|
||||
user_service_mod.UserTenantService = type('UserTenantService', (), {'query': staticmethod(lambda **_kwargs: [])})
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
|
||||
|
||||
chunk_feedback_service_mod = ModuleType("api.db.services.chunk_feedback_service")
|
||||
chunk_feedback_service_mod.ChunkFeedbackService = type('ChunkFeedbackService', (), {'apply_feedback': staticmethod(lambda **_kwargs: {'success_count': 0, 'fail_count': 0, 'chunk_ids': []})})
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.chunk_feedback_service", chunk_feedback_service_mod)
|
||||
|
||||
api_utils_mod = ModuleType("api.utils.api_utils")
|
||||
api_utils_mod.check_duplicate_ids = lambda ids, _label: (list(dict.fromkeys(ids or [])), [])
|
||||
api_utils_mod.get_data_error_result = lambda message='': {'code': 102, 'data': None, 'message': message}
|
||||
api_utils_mod.get_json_result = lambda data=None, message='', code=0: {'code': code, 'data': data, 'message': message}
|
||||
api_utils_mod.server_error_response = lambda ex: {'code': 500, 'data': None, 'message': str(ex)}
|
||||
api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func)
|
||||
api_utils_mod.get_request_json = lambda: _AwaitableValue({})
|
||||
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
|
||||
|
||||
tenant_utils_mod = ModuleType("api.utils.tenant_utils")
|
||||
tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, req: req
|
||||
monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod)
|
||||
|
||||
rag_pkg = ModuleType("rag")
|
||||
rag_pkg.__path__ = [str(repo_root / 'rag')]
|
||||
monkeypatch.setitem(sys.modules, 'rag', rag_pkg)
|
||||
rag_prompts_pkg = ModuleType('rag.prompts')
|
||||
rag_prompts_pkg.__path__ = [str(repo_root / 'rag' / 'prompts')]
|
||||
monkeypatch.setitem(sys.modules, 'rag.prompts', rag_prompts_pkg)
|
||||
rag_prompts_generator_mod = ModuleType('rag.prompts.generator')
|
||||
rag_prompts_generator_mod.chunks_format = lambda reference: reference.get('chunks', []) if isinstance(reference, dict) else []
|
||||
monkeypatch.setitem(sys.modules, 'rag.prompts.generator', rag_prompts_generator_mod)
|
||||
rag_prompts_template_mod = ModuleType('rag.prompts.template')
|
||||
rag_prompts_template_mod.load_prompt = lambda *_args, **_kwargs: ''
|
||||
monkeypatch.setitem(sys.modules, 'rag.prompts.template', rag_prompts_template_mod)
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
monkeypatch.setitem(sys.modules, module_name, module)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_create_chat_uses_tenant_default_llm_when_llm_id_is_null_unit(monkeypatch):
|
||||
module = _load_chat_routes_unit_module(monkeypatch)
|
||||
saved = {}
|
||||
|
||||
async def _request_json():
|
||||
return {
|
||||
'name': 'chat-a',
|
||||
'dataset_ids': ['kb-1'],
|
||||
'llm_id': None,
|
||||
'llm_setting': {'temperature': 0.8},
|
||||
'prompt_config': {'system': 'Answer with {knowledge}', 'parameters': [{'key': 'knowledge', 'optional': False}]},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(module, 'get_request_json', _request_json)
|
||||
monkeypatch.setattr(module.DialogService, 'query', lambda **_kwargs: [])
|
||||
|
||||
def _save(**kwargs):
|
||||
saved.update(kwargs)
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(module.DialogService, 'save', _save)
|
||||
monkeypatch.setattr(module.DialogService, 'get_by_id', lambda _id: (True, SimpleNamespace(to_dict=lambda: saved)))
|
||||
|
||||
res = _run(module.create.__wrapped__())
|
||||
assert res['code'] == 0
|
||||
assert saved['llm_id'] == 'glm-4'
|
||||
assert saved['llm_setting']['temperature'] == 0.8
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_list_chats_authorized_multi_tenant_unit(monkeypatch):
|
||||
module = _load_chat_routes_unit_module(monkeypatch)
|
||||
captured = {}
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
'request',
|
||||
SimpleNamespace(
|
||||
args=SimpleNamespace(
|
||||
get=lambda key, default=None: {
|
||||
'keywords': '', 'page': '1', 'page_size': '10', 'orderby': 'create_time', 'desc': 'true', 'id': None, 'name': None,
|
||||
}.get(key, default),
|
||||
getlist=lambda key: ['tenant-1', 'team-tenant-2'] if key == 'owner_ids' else [],
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def _get_by_tenant_ids(owner_ids, user_id, *args, **kwargs):
|
||||
captured['owner_ids'] = owner_ids
|
||||
captured['user_id'] = user_id
|
||||
return ([{'id': 'c1', 'tenant_id': 'tenant-1'}, {'id': 'c2', 'tenant_id': 'team-tenant-2'}], 2)
|
||||
|
||||
monkeypatch.setattr(module.DialogService, 'get_by_tenant_ids', _get_by_tenant_ids)
|
||||
monkeypatch.setattr(module.KnowledgebaseService, 'get_by_id', lambda _id: (False, None))
|
||||
res = _run(module.list_chats.__wrapped__())
|
||||
assert res['code'] == 0
|
||||
assert res['data']['total'] == 2
|
||||
assert {c['id'] for c in res['data']['chats']} == {'c1', 'c2'}
|
||||
assert set(captured['owner_ids']) == {'tenant-1', 'team-tenant-2'}
|
||||
assert captured['user_id'] == 'tenant-1'
|
||||
|
||||
Reference in New Issue
Block a user