diff --git a/api/apps/restful_apis/provider_api.py b/api/apps/restful_apis/provider_api.py index d86628a1ce..b8aedfafae 100644 --- a/api/apps/restful_apis/provider_api.py +++ b/api/apps/restful_apis/provider_api.py @@ -336,11 +336,11 @@ async def create_provider_instance(tenant_id: str = None, provider_id_or_name: s type: object """ data = await request.get_json() - if not data or "instance_name" not in data or "api_key" not in data: + if not data or "instance_name" not in data or ("api_key" not in data and provider_id_or_name != "VLLM"): return get_error_argument_result(message="instance_name and api_key are required") instance_name = data["instance_name"] - api_key = data["api_key"] + api_key = data.get("api_key", "") base_url = data.get("base_url", "") region = data.get("region", "") model_info = data.get("model_info", []) @@ -405,11 +405,11 @@ async def verify_provider_api_key(provider_id_or_name: str = None): type: object """ data = await request.get_json() - if not data or "api_key" not in data: + if not data or ("api_key" not in data and provider_id_or_name != "VLLM"): return get_error_argument_result(message="api_key is required") base_url = data.get("base_url", "") - api_key = data["api_key"] + api_key = data.get("api_key", "") region = data.get("region", "default") model_info = data.get("model_info", []) diff --git a/api/apps/services/provider_api_service.py b/api/apps/services/provider_api_service.py index 031cdca773..4016f7a862 100644 --- a/api/apps/services/provider_api_service.py +++ b/api/apps/services/provider_api_service.py @@ -51,6 +51,11 @@ def _normalize_provider_base_url(provider_name: str, base_url: str | None): return base_url +def _normalize_provider_api_key(provider_name: str, api_key: str | dict | None): + if provider_name == "VLLM" and not api_key: + return "x" + return api_key + def _factory_llm_name(llm: dict) -> str: return llm.get("name") or llm.get("llm_name", "") @@ -219,6 +224,7 @@ async def list_provider_models(provider_id_or_name: str, api_key: str = None, ba factory_info = [f for f in FACTORY_LLM_INFOS if f["name"] == provider_name] if not factory_info: return False, f"Provider '{provider_id_or_name}' not found" + api_key = _normalize_provider_api_key(provider_name, api_key) static_llms = [{ "name": _factory_llm_name(llm), "max_tokens": llm["max_tokens"], @@ -318,6 +324,7 @@ async def create_provider_instance(tenant_id: str, provider_id_or_name: str, ins provider_name = provider_obj.provider_name base_url = _normalize_provider_base_url(provider_name, base_url) + api_key = _normalize_provider_api_key(provider_name, api_key) if instance_name == "default": return False, "Instance name cannot be 'default'" @@ -411,6 +418,7 @@ async def verify_api_key(provider_id_or_name: str, api_key: str|dict, base_url: provider_name = provider_obj.provider_name if provider_obj else provider_id_or_name base_url = _normalize_provider_base_url(provider_name, base_url) + api_key = _normalize_provider_api_key(provider_name, api_key) if region and region == "intl" and provider_name.lower() == "siliconflow": target_factory_name = "siliconflow_intl"