From f97d6396b47140d2f884fa3bc8fcf646cbd7aba3 Mon Sep 17 00:00:00 2001 From: euvre <93761161+euvre@users.noreply.github.com> Date: Tue, 9 Jun 2026 04:05:58 -0700 Subject: [PATCH] fix: BaiduYiyan API key validation fails in set_api_key (#15828) ### What problem does this PR solve? When setting the API key for the BaiduYiyan provider, all model validations fail with the error "Fail to access model using this api key. No valid response received". **Root cause:** 1. `BaiduYiyanChat` in `rag/llm/chat_model.py` does not override `async_chat_streamly()`. The `verify_api_key()` function uses `mdl.async_chat_streamly()` to validate, but `BaiduYiyanChat` inherits `Base.async_chat_streamly()` which uses the OpenAI client, not the Baidu Qianfan SDK (qianfan). Since BaiduYiyan has no OpenAI-compatible base_url, validation always fails. 2. `verify_api_key()` in `provider_api_service.py` does not format the raw API key string into the JSON format (`{"yiyan_ak": "...", "yiyan_sk": "..."}`) that `BaiduYiyanChat.__init__()` expects via `json.loads(key)`. **Fix:** 1. Add `async_chat_streamly()` method to `BaiduYiyanChat` using the qianfan SDK, consistent with the existing `chat_streamly()` method. 2. Add BaiduYiyan API key formatting in `provider_api_service.py` `verify_api_key()` to match the format expected by `BaiduYiyanChat.__init__()`. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [ ] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe): --- api/apps/services/provider_api_service.py | 6 +++++ rag/llm/chat_model.py | 28 +++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/api/apps/services/provider_api_service.py b/api/apps/services/provider_api_service.py index 909517bbc1..69df40fd67 100644 --- a/api/apps/services/provider_api_service.py +++ b/api/apps/services/provider_api_service.py @@ -391,6 +391,12 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10)) extra = {"provider": provider_name} msg = "" + if provider_name == "BaiduYiyan": + if isinstance(api_key, str): + try: + json.loads(api_key) + except (json.JSONDecodeError, TypeError): + api_key = {"yiyan_ak": api_key, "yiyan_sk": ""} api_key_str = api_key if isinstance(api_key, str) else json.dumps(api_key) for llm in factory_llms: if not embd_passed and llm["model_type"] == LLMType.EMBEDDING.value: diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 20d5a12dcd..0396d4a267 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1043,6 +1043,34 @@ class BaiduYiyanChat(Base): yield total_tokens + async def async_chat_streamly(self, system, history, gen_conf: dict | None = None, **kwargs): + gen_conf = dict(gen_conf or {}) + gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1 + if "max_tokens" in gen_conf: + del gen_conf["max_tokens"] + + def _do_chat(): + system_msg = history[0]["content"] if history and history[0].get("role") == "system" else "" + msgs = [h for h in history if h.get("role") != "system"] + try: + response = self.client.do(model=self.model_name, messages=msgs, system=system_msg, stream=True, **gen_conf) + result_text = "" + total_tokens = 0 + for resp in response: + resp = resp.body + result_text = resp["result"] + total_tokens = total_token_count_from_response(resp) + return result_text, total_tokens, None + except Exception as e: + return "", 0, e + + result_text, total_tokens, error = await asyncio.to_thread(_do_chat) + if error: + yield f"**ERROR**: {error}" + else: + yield result_text + yield total_tokens + class GoogleChat(Base): _FACTORY_NAME = "Google Cloud"