diff --git a/api/apps/services/provider_api_service.py b/api/apps/services/provider_api_service.py index 0663eac7ec..110961ea7d 100644 --- a/api/apps/services/provider_api_service.py +++ b/api/apps/services/provider_api_service.py @@ -25,7 +25,7 @@ from api.db.joint_services.tenant_model_service import get_model_config_from_pro from api.db.services.tenant_model_provider_service import TenantModelProviderService from api.db.services.tenant_model_instance_service import TenantModelInstanceService from api.db.services.tenant_model_service import TenantModelService -from rag.llm import ChatModel, EmbeddingModel, ModelMeta, OcrModel, RerankModel +from rag.llm import ChatModel, EmbeddingModel, ModelMeta, OcrModel, RerankModel, TTSModel def _to_int(v, default=500): @@ -402,7 +402,7 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No return False, f"No valid models found for provider '{provider_name}'" # test if api key works - chat_passed, embd_passed, rerank_passed, ocr_passed = False, False, False, False + chat_passed, embd_passed, rerank_passed, ocr_passed, tts_passed = False, False, False, False, False timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10)) extra = {"provider": provider_name} msg = "" @@ -496,11 +496,31 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No llm["llm_name"], ) msg += f"\nFail to access model({provider_name}/{llm['llm_name']})." + str(e) - if any([embd_passed, chat_passed, rerank_passed, ocr_passed]): + elif not tts_passed and LLMType.TTS.value in model_types: + assert provider_name in TTSModel, f"TTS model from {provider_name} is not supported yet." + mdl = TTSModel[provider_name](key=api_key_str, model_name=llm["llm_name"], base_url=base_url) + try: + def drain_tts(): + for _ in mdl.tts("Hello~ RAGFlower!"): + pass + + await asyncio.wait_for( + asyncio.to_thread(drain_tts), + timeout=timeout_seconds, + ) + tts_passed = True + except Exception as e: + logging.exception( + "Fail to access TTS model for provider=%s model=%s", + provider_name, + llm["llm_name"], + ) + msg += f"\nFail to access model({provider_name}/{llm['llm_name']})." + str(e) + if any([embd_passed, chat_passed, rerank_passed, ocr_passed, tts_passed]): msg = "" break - success = any([embd_passed, chat_passed, rerank_passed, ocr_passed]) + success = any([embd_passed, chat_passed, rerank_passed, ocr_passed, tts_passed]) return success, "success" if success else msg