diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index 3037e0c4a3..6896ce7a43 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -265,12 +265,15 @@ async def cross_languages(tenant_id, llm_id, query, languages=[]): from common.constants import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService - from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name + from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text": chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.IMAGE2TEXT, llm_id) else: - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) + if not llm_id: + chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + else: + chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) chat_mdl = LLMBundle(tenant_id, chat_model_config) rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render() rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query,