Refa: improve model verification ux (#13392)

### What problem does this PR solve?

Improve model verification UX. #13395 

### Type of change

- [x] Refactoring

---------

Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
Yongteng Lei
2026-03-05 17:23:47 +08:00
committed by GitHub
parent 3124fa955e
commit f13a1fb007
2 changed files with 39 additions and 20 deletions

View File

@@ -94,17 +94,21 @@ async def set_api_key():
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=base_url, **extra)
try:
m, tc = await asyncio.wait_for(
mdl.async_chat(
async def check_streamly():
async for chunk in mdl.async_chat_streamly(
None,
[{"role": "user", "content": "Hello! How are you doing!"}],
{"temperature": 0.9, "max_tokens": 50},
),
timeout=timeout_seconds,
)
if m.find("**ERROR**") >= 0:
raise Exception(m)
chat_passed = True
[{"role": "user", "content": "Hi"}],
{"temperature": 0.9},
):
if chunk and isinstance(chunk, str) and chunk.find("**ERROR**") < 0:
return True
return False
result = await asyncio.wait_for(check_streamly(), timeout=timeout_seconds)
if result:
chat_passed = True
else:
raise Exception("No valid response received")
except Exception as e:
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
@@ -127,7 +131,7 @@ async def set_api_key():
if req.get("verify", False):
return get_json_result(data={"message": msg, "success": len(msg.strip())==0})
if msg:
return get_data_error_result(message=msg)
@@ -260,16 +264,19 @@ async def add_llm():
**extra,
)
try:
m, tc = await asyncio.wait_for(
mdl.async_chat(
async def check_streamly():
async for chunk in mdl.async_chat_streamly(
None,
[{"role": "user", "content": "Hello! How are you doing!"}],
[{"role": "user", "content": "Hi"}],
{"temperature": 0.9},
),
timeout=timeout_seconds,
)
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
):
if chunk and isinstance(chunk, str) and chunk.find("**ERROR**:") < 0:
return True
return False
result = await asyncio.wait_for(check_streamly(), timeout=timeout_seconds)
if not result:
raise Exception("No valid response received")
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
@@ -339,7 +346,7 @@ async def add_llm():
if req.get("verify", False):
return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0})
if msg:
return get_data_error_result(message=msg)

View File

@@ -514,6 +514,10 @@ def test_add_llm_factory_specific_key_assembly_unit(monkeypatch):
async def async_chat(self, *_args, **_kwargs):
return "ok", 1
async def async_chat_streamly(self, *_args, **_kwargs):
yield "ok"
yield 1
class _TTSOK:
def __init__(self, key, model_name, base_url="", **_kwargs):
captured["tts"].append((key, model_name, base_url))
@@ -679,6 +683,10 @@ def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch):
async def async_chat(self, *_args, **_kwargs):
return "**ERROR**: chat failed", 0
async def async_chat_streamly(self, *_args, **_kwargs):
yield "**ERROR**: chat failed"
yield 0
class _ChatPass:
def __init__(self, *_args, **_kwargs):
pass
@@ -686,6 +694,10 @@ def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch):
async def async_chat(self, *_args, **_kwargs):
return "ok", 1
async def async_chat_streamly(self, *_args, **_kwargs):
yield "ok"
yield 1
class _RerankFail:
def __init__(self, *_args, **_kwargs):
pass