diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index ef5f9f00b3..360a9ea6a5 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -287,24 +287,27 @@ class DialogService(CommonService): async def async_chat_solo(dialog, messages, stream=True, session_id=None): - llm_types = get_model_type_by_name(dialog.tenant_id, dialog.llm_id) attachments = "" image_attachments = [] image_files = [] - if "files" in messages[-1]: - if "chat" in llm_types: - text_attachments, image_attachments = split_file_attachments(messages[-1]["files"]) - else: - text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True) - attachments = "\n\n".join(text_attachments) if dialog.llm_id: - model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + llm_types = get_model_type_by_name(dialog.tenant_id, dialog.llm_id) + if "chat" in llm_types: + model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + else: + model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(dialog.tenant_id, model_config, langfuse_session_id=session_id) factory = model_config.get("llm_factory", "") if model_config else "" + if "files" in messages[-1]: + if model_config["model_type"] == "chat": + text_attachments, image_attachments = split_file_attachments(messages[-1]["files"]) + else: + text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True) + attachments = "\n\n".join(text_attachments) prompt_config = dialog.prompt_config tts_mdl = None @@ -314,10 +317,10 @@ async def async_chat_solo(dialog, messages, stream=True, session_id=None): msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"] if attachments and msg: msg[-1]["content"] += attachments - if "chat" in llm_types and image_attachments: + if model_config["model_type"] == "chat" and image_attachments: convert_last_user_msg_to_multimodal(msg, image_attachments, factory) if stream: - if "chat" in llm_types: + if model_config["model_type"] == "chat": stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting) else: stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting, images=image_files) @@ -328,7 +331,7 @@ async def async_chat_solo(dialog, messages, stream=True, session_id=None): continue yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False} else: - if "chat" in llm_types: + if model_config["model_type"] == "chat": answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting) else: answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting, images=image_files) @@ -552,10 +555,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs): chat_start_ts = timer() if dialog.llm_id: llm_types = get_model_type_by_name(dialog.tenant_id, dialog.llm_id) - if "image2text" in llm_types: - llm_model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) - else: + if "chat" in llm_types: llm_model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + else: + llm_model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: llm_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)