diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 6fa9a8e3d8..b15d820c23 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -94,17 +94,21 @@ async def set_api_key(): assert factory in ChatModel, f"Chat model from {factory} is not supported yet." mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=base_url, **extra) try: - m, tc = await asyncio.wait_for( - mdl.async_chat( + async def check_streamly(): + async for chunk in mdl.async_chat_streamly( None, - [{"role": "user", "content": "Hello! How are you doing!"}], - {"temperature": 0.9, "max_tokens": 50}, - ), - timeout=timeout_seconds, - ) - if m.find("**ERROR**") >= 0: - raise Exception(m) - chat_passed = True + [{"role": "user", "content": "Hi"}], + {"temperature": 0.9}, + ): + if chunk and isinstance(chunk, str) and chunk.find("**ERROR**") < 0: + return True + return False + + result = await asyncio.wait_for(check_streamly(), timeout=timeout_seconds) + if result: + chat_passed = True + else: + raise Exception("No valid response received") except Exception as e: msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e) elif not rerank_passed and llm.model_type == LLMType.RERANK.value: @@ -127,7 +131,7 @@ async def set_api_key(): if req.get("verify", False): return get_json_result(data={"message": msg, "success": len(msg.strip())==0}) - + if msg: return get_data_error_result(message=msg) @@ -260,16 +264,19 @@ async def add_llm(): **extra, ) try: - m, tc = await asyncio.wait_for( - mdl.async_chat( + async def check_streamly(): + async for chunk in mdl.async_chat_streamly( None, - [{"role": "user", "content": "Hello! How are you doing!"}], + [{"role": "user", "content": "Hi"}], {"temperature": 0.9}, - ), - timeout=timeout_seconds, - ) - if not tc and m.find("**ERROR**:") >= 0: - raise Exception(m) + ): + if chunk and isinstance(chunk, str) and chunk.find("**ERROR**:") < 0: + return True + return False + + result = await asyncio.wait_for(check_streamly(), timeout=timeout_seconds) + if not result: + raise Exception("No valid response received") except Exception as e: msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) @@ -339,7 +346,7 @@ async def add_llm(): if req.get("verify", False): return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0}) - + if msg: return get_data_error_result(message=msg) diff --git a/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py b/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py index 18179b8592..6253a5ade0 100644 --- a/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py +++ b/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py @@ -514,6 +514,10 @@ def test_add_llm_factory_specific_key_assembly_unit(monkeypatch): async def async_chat(self, *_args, **_kwargs): return "ok", 1 + async def async_chat_streamly(self, *_args, **_kwargs): + yield "ok" + yield 1 + class _TTSOK: def __init__(self, key, model_name, base_url="", **_kwargs): captured["tts"].append((key, model_name, base_url)) @@ -679,6 +683,10 @@ def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch): async def async_chat(self, *_args, **_kwargs): return "**ERROR**: chat failed", 0 + async def async_chat_streamly(self, *_args, **_kwargs): + yield "**ERROR**: chat failed" + yield 0 + class _ChatPass: def __init__(self, *_args, **_kwargs): pass @@ -686,6 +694,10 @@ def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch): async def async_chat(self, *_args, **_kwargs): return "ok", 1 + async def async_chat_streamly(self, *_args, **_kwargs): + yield "ok" + yield 1 + class _RerankFail: def __init__(self, *_args, **_kwargs): pass