mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
@@ -25,7 +25,7 @@ from api.db.joint_services.tenant_model_service import get_model_config_from_pro
|
||||
from api.db.services.tenant_model_provider_service import TenantModelProviderService
|
||||
from api.db.services.tenant_model_instance_service import TenantModelInstanceService
|
||||
from api.db.services.tenant_model_service import TenantModelService
|
||||
from rag.llm import ChatModel, EmbeddingModel, ModelMeta, OcrModel, RerankModel
|
||||
from rag.llm import ChatModel, EmbeddingModel, ModelMeta, OcrModel, RerankModel, TTSModel
|
||||
|
||||
|
||||
def _to_int(v, default=500):
|
||||
@@ -402,7 +402,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
return False, f"No valid models found for provider '{provider_name}'"
|
||||
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed, ocr_passed = False, False, False, False
|
||||
chat_passed, embd_passed, rerank_passed, ocr_passed, tts_passed = False, False, False, False, False
|
||||
timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
|
||||
extra = {"provider": provider_name}
|
||||
msg = ""
|
||||
@@ -496,11 +496,31 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
llm["llm_name"],
|
||||
)
|
||||
msg += f"\nFail to access model({provider_name}/{llm['llm_name']})." + str(e)
|
||||
if any([embd_passed, chat_passed, rerank_passed, ocr_passed]):
|
||||
elif not tts_passed and LLMType.TTS.value in model_types:
|
||||
assert provider_name in TTSModel, f"TTS model from {provider_name} is not supported yet."
|
||||
mdl = TTSModel[provider_name](key=api_key_str, model_name=llm["llm_name"], base_url=base_url)
|
||||
try:
|
||||
def drain_tts():
|
||||
for _ in mdl.tts("Hello~ RAGFlower!"):
|
||||
pass
|
||||
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(drain_tts),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
tts_passed = True
|
||||
except Exception as e:
|
||||
logging.exception(
|
||||
"Fail to access TTS model for provider=%s model=%s",
|
||||
provider_name,
|
||||
llm["llm_name"],
|
||||
)
|
||||
msg += f"\nFail to access model({provider_name}/{llm['llm_name']})." + str(e)
|
||||
if any([embd_passed, chat_passed, rerank_passed, ocr_passed, tts_passed]):
|
||||
msg = ""
|
||||
break
|
||||
|
||||
success = any([embd_passed, chat_passed, rerank_passed, ocr_passed])
|
||||
success = any([embd_passed, chat_passed, rerank_passed, ocr_passed, tts_passed])
|
||||
return success, "success" if success else msg
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user