mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix: BaiduYiyan API key validation fails in set_api_key (#15828)
### What problem does this PR solve?
When setting the API key for the BaiduYiyan provider, all model
validations fail with the error "Fail to access model using this api
key. No valid response received".
**Root cause:**
1. `BaiduYiyanChat` in `rag/llm/chat_model.py` does not override
`async_chat_streamly()`. The `verify_api_key()` function uses
`mdl.async_chat_streamly()` to validate, but `BaiduYiyanChat` inherits
`Base.async_chat_streamly()` which uses the OpenAI client, not the Baidu
Qianfan SDK (qianfan). Since BaiduYiyan has no OpenAI-compatible
base_url, validation always fails.
2. `verify_api_key()` in `provider_api_service.py` does not format the
raw API key string into the JSON format (`{"yiyan_ak": "...",
"yiyan_sk": "..."}`) that `BaiduYiyanChat.__init__()` expects via
`json.loads(key)`.
**Fix:**
1. Add `async_chat_streamly()` method to `BaiduYiyanChat` using the
qianfan SDK, consistent with the existing `chat_streamly()` method.
2. In `verify_api_key()` (`provider_api_service.py`), wrap a raw, non-JSON
BaiduYiyan API key string as `{"yiyan_ak": <key>, "yiyan_sk": ""}` and
serialize it to a JSON string, matching the format expected by
`BaiduYiyanChat.__init__()`.
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
This commit is contained in:
@@ -391,6 +391,12 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
|
||||
timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
|
||||
extra = {"provider": provider_name}
|
||||
msg = ""
|
||||
if provider_name == "BaiduYiyan":
|
||||
if isinstance(api_key, str):
|
||||
try:
|
||||
json.loads(api_key)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
api_key = {"yiyan_ak": api_key, "yiyan_sk": ""}
|
||||
api_key_str = api_key if isinstance(api_key, str) else json.dumps(api_key)
|
||||
for llm in factory_llms:
|
||||
if not embd_passed and llm["model_type"] == LLMType.EMBEDDING.value:
|
||||
|
||||
@@ -1043,6 +1043,34 @@ class BaiduYiyanChat(Base):
|
||||
|
||||
yield total_tokens
|
||||
|
||||
async def async_chat_streamly(self, system, history, gen_conf: dict | None = None, **kwargs):
    """Stream a chat completion from ``self.client`` as an async generator.

    The blocking ``self.client.do(..., stream=True)`` call and the iteration
    over its response run in a worker thread via ``asyncio.to_thread`` so the
    event loop is not blocked while the request is in flight.

    Args:
        system: System prompt. When empty, the content of a leading
            ``system`` entry in ``history`` is used instead.
        history: Chat messages as dicts with ``role``/``content`` keys.
            Entries whose role is ``system`` are not sent as messages.
        gen_conf: Optional generation settings. A copy is modified, so the
            caller's dict is left untouched. ``max_tokens`` is dropped and
            ``penalty_score`` is derived from ``presence_penalty`` and
            ``frequency_penalty``.
        **kwargs: Accepted for signature compatibility; unused.

    Yields:
        Every streamed ``result`` text in arrival order, then an
        ``**ERROR**: ...`` string if the request failed, and finally the
        total token count as an ``int``.
    """
    # Copy first: the settings below must not leak into the caller's dict.
    gen_conf = dict(gen_conf or {})
    gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
    gen_conf.pop("max_tokens", None)

    history = history or []
    # Prefer the explicit ``system`` argument; fall back to a system message
    # placed at the head of the history.
    system_msg = system or ""
    if not system_msg and history and history[0].get("role") == "system":
        system_msg = history[0]["content"]
    msgs = [h for h in history if h.get("role") != "system"]

    def _do_chat():
        """Run the blocking streamed request; return (chunks, tokens, error)."""
        chunks = []
        total_tokens = 0
        try:
            response = self.client.do(model=self.model_name, messages=msgs, system=system_msg, stream=True, **gen_conf)
            for resp in response:
                body = resp.body
                # NOTE(review): assumes each streamed ``result`` is an
                # incremental piece of the answer, so all of them are kept
                # rather than only the last one - confirm against the SDK.
                chunks.append(body["result"])
                total_tokens = total_token_count_from_response(body)
        except Exception as e:
            # Hand the failure back to the event-loop side together with
            # whatever text was received before it happened.
            return chunks, total_tokens, e
        return chunks, total_tokens, None

    chunks, total_tokens, error = await asyncio.to_thread(_do_chat)
    for chunk in chunks:
        yield chunk
    if error is not None:
        yield f"**ERROR**: {error}"
    yield total_tokens
|
||||
|
||||
|
||||
class GoogleChat(Base):
|
||||
_FACTORY_NAME = "Google Cloud"
|
||||
|
||||
Reference in New Issue
Block a user