diff --git a/api/db/joint_services/tenant_model_service.py b/api/db/joint_services/tenant_model_service.py index f53f83ab95..8e745d8e08 100644 --- a/api/db/joint_services/tenant_model_service.py +++ b/api/db/joint_services/tenant_model_service.py @@ -57,6 +57,14 @@ def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_nam "api_base": embedding_cfg["base_url"], "model_type": LLMType.EMBEDDING.value, } + elif model_type_val == LLMType.CHAT.value: + # Retry as CHAT with pure_model_name first; then fall back to a multimodal model registered under IMAGE2TEXT. + model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.CHAT.value) + if not model_config: + model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.IMAGE2TEXT.value) + if not model_config: + raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found") + config_dict = model_config.to_dict() else: model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, model_type_val) if not model_config: @@ -67,7 +75,10 @@ def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_nam config_dict = model_config.to_dict() config_model_type = config_dict.get("model_type") config_model_type = config_model_type.value if hasattr(config_model_type, "value") else config_model_type - if config_model_type != model_type_val: + if config_model_type != model_type_val and not ( + model_type_val == LLMType.CHAT.value + and config_model_type == LLMType.IMAGE2TEXT.value + ): raise LookupError( f"Tenant Model with name {model_name} has type {config_model_type}, expected {model_type_val}" )