diff --git a/api/apps/services/provider_api_service.py b/api/apps/services/provider_api_service.py
index 909517bbc1..69df40fd67 100644
--- a/api/apps/services/provider_api_service.py
+++ b/api/apps/services/provider_api_service.py
@@ -391,6 +391,12 @@ async def verify_api_key(provider_name: str, api_key: str|dict, base_url: str=No
     timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
     extra = {"provider": provider_name}
     msg = ""
+    if provider_name == "BaiduYiyan":
+        if isinstance(api_key, str):
+            try:
+                json.loads(api_key)
+            except (json.JSONDecodeError, TypeError):
+                api_key = {"yiyan_ak": api_key, "yiyan_sk": ""}
     api_key_str = api_key if isinstance(api_key, str) else json.dumps(api_key)
     for llm in factory_llms:
         if not embd_passed and llm["model_type"] == LLMType.EMBEDDING.value:
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 20d5a12dcd..0396d4a267 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -1043,6 +1043,44 @@ class BaiduYiyanChat(Base):
 
         yield total_tokens
 
+    async def async_chat_streamly(self, system, history, gen_conf: dict | None = None, **kwargs):
+        """Run a streaming Yiyan completion in a worker thread and yield the reply.
+
+        Yields the complete answer text once, followed by the total token count
+        (an int). If the request fails, yields a single "**ERROR**: ..." string.
+        """
+        # Work on a copy so the caller's gen_conf is never mutated.
+        gen_conf = dict(gen_conf or {})
+        gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
+        # max_tokens is dropped before the request is sent.
+        gen_conf.pop("max_tokens", None)
+
+        def _do_chat():
+            # Blocking SDK call; executed off the event loop via asyncio.to_thread.
+            # Prefer a system message embedded in history, else the explicit argument.
+            system_msg = history[0]["content"] if history and history[0].get("role") == "system" else (system or "")
+            msgs = [h for h in history if h.get("role") != "system"]
+            try:
+                response = self.client.do(model=self.model_name, messages=msgs, system=system_msg, stream=True, **gen_conf)
+                result_text = ""
+                total_tokens = 0
+                for resp in response:
+                    resp = resp.body
+                    # In stream mode each chunk's "result" is only a fragment of the
+                    # answer, so fragments are concatenated rather than overwritten.
+                    result_text += resp["result"] or ""
+                    total_tokens = total_token_count_from_response(resp)
+                return result_text, total_tokens, None
+            except Exception as e:
+                return "", 0, e
+
+        result_text, total_tokens, error = await asyncio.to_thread(_do_chat)
+        if error:
+            yield f"**ERROR**: {error}"
+        else:
+            yield result_text
+            yield total_tokens
+
 
 class GoogleChat(Base):
     _FACTORY_NAME = "Google Cloud"