mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve?
When setting the API key for the BaiduYiyan provider, all model
validations fail with the error "Fail to access model using this api
key. No valid response received".
**Root cause:**
1. `BaiduYiyanChat` in `rag/llm/chat_model.py` does not override
`async_chat_streamly()`. The `verify_api_key()` function uses
`mdl.async_chat_streamly()` to validate, but `BaiduYiyanChat` inherits
`Base.async_chat_streamly()` which uses the OpenAI client, not the Baidu
Qianfan SDK (qianfan). Since BaiduYiyan has no OpenAI-compatible
base_url, validation always fails.
2. `verify_api_key()` in `provider_api_service.py` does not format the
raw API key string into the JSON format (`{"yiyan_ak": "...",
"yiyan_sk": "..."}`) that `BaiduYiyanChat.__init__()` expects via
`json.loads(key)`.
**Fix:**
1. Add `async_chat_streamly()` method to `BaiduYiyanChat` using the
qianfan SDK, consistent with the existing `chat_streamly()` method.
2. Add BaiduYiyan API key formatting in `provider_api_service.py`
`verify_api_key()` to match the format expected by
`BaiduYiyanChat.__init__()`.
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
2030 lines
83 KiB
Python
2030 lines
83 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
import time
|
|
from abc import ABC
|
|
from copy import deepcopy
|
|
from urllib.parse import urljoin
|
|
|
|
import json_repair
|
|
from json.decoder import JSONDecodeError
|
|
import litellm
|
|
import openai
|
|
from openai import AsyncOpenAI, OpenAI
|
|
from enum import StrEnum
|
|
|
|
from common.misc_utils import thread_pool_exec
|
|
from common.token_utils import num_tokens_from_string, total_token_count_from_response
|
|
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
|
|
from rag.llm.tool_decorator import FunctionToolSession, is_tool
|
|
from rag.nlp import is_chinese, is_english
|
|
|
|
|
|
class LLMErrorCode(StrEnum):
|
|
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
|
|
ERROR_AUTHENTICATION = "AUTH_ERROR"
|
|
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
|
|
ERROR_SERVER = "SERVER_ERROR"
|
|
ERROR_TIMEOUT = "TIMEOUT"
|
|
ERROR_CONNECTION = "CONNECTION_ERROR"
|
|
ERROR_MODEL = "MODEL_ERROR"
|
|
ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS"
|
|
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
|
|
ERROR_QUOTA = "QUOTA_EXCEEDED"
|
|
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
|
|
ERROR_GENERIC = "GENERIC_ERROR"
|
|
|
|
|
|
class ReActMode(StrEnum):
|
|
FUNCTION_CALL = "function_call"
|
|
REACT = "react"
|
|
|
|
|
|
# Prefix used when an error message is returned in place of a model answer.
ERROR_PREFIX = "**ERROR**"
# Notices appended to an answer whose finish reason is "length" (Chinese / English variants).
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."

# Generation parameters that are safe to forward to the underlying completion
# call. `gen_conf` originates from a chat assistant's `llm_setting`, which can
# also carry RAGFlow-internal metadata (e.g. `model_type`). Anything outside
# this set is dropped so providers don't reject the request with errors like
# "Extra inputs are not permitted" / "Unknown parameter: 'model_type'" (#15427).
ALLOWED_GEN_CONF_KEYS = frozenset(
    (
        # sampling / length
        "temperature",
        "max_completion_tokens",
        "top_p",
        # streaming
        "stream",
        "stream_options",
        "stop",
        "n",
        "presence_penalty",
        "frequency_penalty",
        # legacy function calling
        "functions",
        "function_call",
        "logit_bias",
        "user",
        "response_format",
        "seed",
        # tool calling
        "tools",
        "tool_choice",
        "logprobs",
        "top_logprobs",
        "extra_headers",
    )
)

# LiteLLM additionally understands reasoning-control parameters that the
# model-family policies may inject into `gen_conf` (e.g. `thinking` for
# Anthropic / Kimi reasoning models, `reasoning_effort` for OpenAI o-series).
LITELLM_ALLOWED_GEN_CONF_KEYS = ALLOWED_GEN_CONF_KEYS.union(
    (
        "thinking",
        "reasoning_effort",
        "extra_body",
    )
)
|
|
|
|
|
|
def _apply_model_family_policies(
    model_name: str,
    *,
    backend: str,
    provider: SupportedLiteLLMProvider | str | None = None,
    gen_conf: dict | None = None,
    request_kwargs: dict | None = None,
):
    """Adjust generation settings and request kwargs for specific model families.

    The inputs are never mutated: ``gen_conf`` is deep-copied and
    ``request_kwargs`` is shallow-copied before any change is applied.

    Args:
        model_name: Model identifier; matching is case-insensitive. ``None`` or
            an empty string matches no family.
        backend: ``"litellm"`` enables the provider-specific rules below; any
            other value (e.g. ``"base"``) only applies the Qwen3 rule.
        provider: LiteLLM provider, consulted only when ``backend == "litellm"``.
        gen_conf: Generation parameters to sanitize.
        request_kwargs: Extra keyword arguments destined for the completion call.

    Returns:
        A ``(sanitized_gen_conf, sanitized_kwargs)`` tuple of new dicts.
    """
    model_name_lower = (model_name or "").lower()
    sanitized_gen_conf = deepcopy(gen_conf) if gen_conf else {}
    sanitized_kwargs = dict(request_kwargs) if request_kwargs else {}

    # Qwen3 family disables thinking by extra_body on non-stream chat requests.
    # NOTE: this replaces any ``extra_body`` already present in request_kwargs.
    if "qwen3" in model_name_lower:
        sanitized_kwargs["extra_body"] = {"enable_thinking": False}

    # Every remaining rule is LiteLLM-specific.
    if backend != "litellm":
        return sanitized_gen_conf, sanitized_kwargs

    if provider in {SupportedLiteLLMProvider.OpenAI, SupportedLiteLLMProvider.Azure_OpenAI} and "gpt-5" in model_name_lower:
        for key in ("temperature", "top_p", "logprobs", "top_logprobs"):
            sanitized_gen_conf.pop(key, None)
            sanitized_kwargs.pop(key, None)
    elif provider == SupportedLiteLLMProvider.Anthropic and model_name_lower in {"claude-opus-4-7", "claude-opus-4-8"}:
        for key in ("temperature", "top_p", "top_k"):
            sanitized_gen_conf.pop(key, None)
            sanitized_kwargs.pop(key, None)

    if provider == SupportedLiteLLMProvider.HunYuan:
        for key in ("presence_penalty", "frequency_penalty"):
            sanitized_gen_conf.pop(key, None)
    elif "kimi-k2.5" in model_name_lower or "kimi-k2.6" in model_name_lower:
        # Thinking is on unless the caller passed an explicit falsy ``reasoning``.
        # (The former validation branch for ``thinking`` could never run because
        # the value had just been assigned a valid dict; it has been removed.)
        reasoning = sanitized_gen_conf.pop("reasoning", None)
        thinking_enabled = True if reasoning is None else bool(reasoning)
        sanitized_gen_conf["thinking"] = {"type": "enabled" if thinking_enabled else "disabled"}

        # Fixed sampling parameters are forced for these Kimi models.
        sanitized_gen_conf["temperature"] = 1.0 if thinking_enabled else 0.6
        sanitized_gen_conf["top_p"] = 0.95
        sanitized_gen_conf["n"] = 1
        sanitized_gen_conf["presence_penalty"] = 0.0
        sanitized_gen_conf["frequency_penalty"] = 0.0

    return sanitized_gen_conf, sanitized_kwargs
|
|
|
|
|
|
class Base(ABC):
|
|
def __init__(self, key, model_name, base_url, **kwargs):
|
|
timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
|
|
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
|
|
self.async_client = AsyncOpenAI(api_key=key, base_url=base_url, timeout=timeout)
|
|
self.model_name = model_name
|
|
# Configure retry parameters
|
|
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
|
|
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
|
|
self.max_rounds = kwargs.get("max_rounds", 5)
|
|
self.is_tools = False
|
|
self.tools = []
|
|
self.toolcall_sessions = {}
|
|
|
|
def _get_delay(self):
|
|
return self.base_delay * random.uniform(10, 150)
|
|
|
|
def _classify_error(self, error):
|
|
error_str = str(error).lower()
|
|
|
|
keywords_mapping = [
|
|
(["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
|
|
(["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
|
|
(["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
|
|
(["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
|
|
(["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
|
|
(["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
|
|
(["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
|
|
(["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
|
|
(["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
|
|
(["max rounds"], LLMErrorCode.ERROR_MODEL),
|
|
]
|
|
for words, code in keywords_mapping:
|
|
if re.search("({})".format("|".join(words)), error_str):
|
|
return code
|
|
|
|
return LLMErrorCode.ERROR_GENERIC
|
|
|
|
def _clean_conf(self, gen_conf):
|
|
gen_conf, _ = _apply_model_family_policies(
|
|
self.model_name,
|
|
backend="base",
|
|
gen_conf=gen_conf,
|
|
)
|
|
|
|
if "max_tokens" in gen_conf:
|
|
del gen_conf["max_tokens"]
|
|
|
|
gen_conf = {k: v for k, v in gen_conf.items() if k in ALLOWED_GEN_CONF_KEYS}
|
|
return gen_conf
|
|
|
|
async def _async_chat_streamly(self, history, gen_conf, **kwargs):
|
|
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
|
reasoning_start = False
|
|
|
|
request_kwargs = {"model": self.model_name, "messages": history, "stream": True, **gen_conf}
|
|
stop = kwargs.get("stop")
|
|
if stop:
|
|
request_kwargs["stop"] = stop
|
|
|
|
response = await self.async_client.chat.completions.create(**request_kwargs)
|
|
async for resp in response:
|
|
if not resp.choices:
|
|
continue
|
|
if not resp.choices[0].delta.content:
|
|
resp.choices[0].delta.content = ""
|
|
_reasoning = getattr(resp.choices[0].delta, "reasoning_content", None) or getattr(resp.choices[0].delta, "reasoning", None)
|
|
if kwargs.get("with_reasoning", True) and _reasoning:
|
|
ans = ""
|
|
if not reasoning_start:
|
|
reasoning_start = True
|
|
ans = "<think>"
|
|
ans += _reasoning + "</think>"
|
|
else:
|
|
reasoning_start = False
|
|
ans = resp.choices[0].delta.content
|
|
tol = total_token_count_from_response(resp)
|
|
if not tol:
|
|
tol = num_tokens_from_string(resp.choices[0].delta.content)
|
|
|
|
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
|
|
if finish_reason == "length":
|
|
if is_chinese(ans):
|
|
ans += LENGTH_NOTIFICATION_CN
|
|
else:
|
|
ans += LENGTH_NOTIFICATION_EN
|
|
yield ans, tol
|
|
|
|
async def async_chat_streamly(self, system, history, gen_conf: dict | None = None, **kwargs):
|
|
gen_conf = dict(gen_conf or {})
|
|
if system and history and history[0].get("role") != "system":
|
|
history.insert(0, {"role": "system", "content": system})
|
|
gen_conf = self._clean_conf(gen_conf)
|
|
ans = ""
|
|
total_tokens = 0
|
|
|
|
for attempt in range(self.max_retries + 1):
|
|
try:
|
|
async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs):
|
|
ans = delta_ans
|
|
total_tokens += tol
|
|
yield ans
|
|
|
|
yield total_tokens
|
|
return
|
|
except Exception as e:
|
|
e = await self._exceptions_async(e, attempt)
|
|
if e:
|
|
yield e
|
|
yield total_tokens
|
|
return
|
|
|
|
def _length_stop(self, ans):
|
|
if is_chinese([ans]):
|
|
return ans + LENGTH_NOTIFICATION_CN
|
|
return ans + LENGTH_NOTIFICATION_EN
|
|
|
|
@property
|
|
def _retryable_errors(self) -> set[str]:
|
|
return {
|
|
LLMErrorCode.ERROR_RATE_LIMIT,
|
|
LLMErrorCode.ERROR_SERVER,
|
|
}
|
|
|
|
def _should_retry(self, error_code: str) -> bool:
|
|
return error_code in self._retryable_errors
|
|
|
|
def _exceptions(self, e, attempt) -> str | None:
|
|
logging.exception("OpenAI chat_with_tools")
|
|
# Classify the error
|
|
error_code = self._classify_error(e)
|
|
if attempt == self.max_retries:
|
|
error_code = LLMErrorCode.ERROR_MAX_RETRIES
|
|
|
|
if self._should_retry(error_code):
|
|
delay = self._get_delay()
|
|
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
|
time.sleep(delay)
|
|
return None
|
|
|
|
msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
|
|
logging.error(f"sync base giving up: {msg}")
|
|
return msg
|
|
|
|
async def _exceptions_async(self, e, attempt):
|
|
logging.exception("OpenAI async completion")
|
|
error_code = self._classify_error(e)
|
|
if attempt == self.max_retries:
|
|
error_code = LLMErrorCode.ERROR_MAX_RETRIES
|
|
|
|
if self._should_retry(error_code):
|
|
delay = self._get_delay()
|
|
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
|
await asyncio.sleep(delay)
|
|
return None
|
|
|
|
msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
|
|
logging.error(f"async base giving up: {msg}")
|
|
return msg
|
|
|
|
def _verbose_tool_use(self, name, args, res):
|
|
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
|
|
|
|
def _append_history(self, hist, tool_call, tool_res):
|
|
hist.append(
|
|
{
|
|
"role": "assistant",
|
|
"tool_calls": [
|
|
{
|
|
"index": getattr(tool_call, "index", None),
|
|
"id": tool_call.id,
|
|
"function": {
|
|
"name": tool_call.function.name,
|
|
"arguments": tool_call.function.arguments,
|
|
},
|
|
"type": "function",
|
|
},
|
|
],
|
|
}
|
|
)
|
|
try:
|
|
if isinstance(tool_res, dict):
|
|
tool_res = json.dumps(tool_res, ensure_ascii=False)
|
|
finally:
|
|
hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
|
|
return hist
|
|
|
|
def _append_history_batch(self, hist, results):
|
|
"""
|
|
Append a batch of tool calls to history following the OpenAI protocol:
|
|
one assistant message containing all tool_calls, followed by one tool message per call.
|
|
results: list of (tool_call, name, args, result, error)
|
|
"""
|
|
hist.append(
|
|
{
|
|
"role": "assistant",
|
|
"tool_calls": [
|
|
{
|
|
"index": getattr(tc, "index", None),
|
|
"id": tc.id,
|
|
"function": {"name": tc.function.name, "arguments": tc.function.arguments},
|
|
"type": "function",
|
|
}
|
|
for tc, _, _, _, _ in results
|
|
],
|
|
}
|
|
)
|
|
for tc, _, _, result, err in results:
|
|
if err:
|
|
content = str(err)
|
|
elif isinstance(result, dict):
|
|
content = json.dumps(result, ensure_ascii=False)
|
|
else:
|
|
content = str(result)
|
|
hist.append({"role": "tool", "tool_call_id": tc.id, "content": content})
|
|
return hist
|
|
|
|
def bind_tools(self, toolcall_session=None, tools=None):
|
|
"""Register tools the LLM can call.
|
|
|
|
Two calling styles are accepted:
|
|
|
|
* Legacy: ``bind_tools(toolcall_session, tools_schemas)`` where
|
|
``toolcall_session`` implements :class:`ToolCallSession` and
|
|
``tools_schemas`` is a pre-built list of OpenAI function-schema
|
|
dicts (used by the agent/dialog layer).
|
|
* Decorator: ``bind_tools(tools=[fn1, fn2, ...])`` where each ``fn``
|
|
is decorated with :func:`rag.llm.tool_decorator.tool`. The session
|
|
and schemas are derived from the callables automatically.
|
|
"""
|
|
if tools is None and isinstance(toolcall_session, list):
|
|
tools, toolcall_session = toolcall_session, None
|
|
|
|
if tools and toolcall_session is None and all(is_tool(t) for t in tools):
|
|
session = FunctionToolSession(tools)
|
|
self.is_tools = True
|
|
self.toolcall_session = session
|
|
self.tools = session.schemas
|
|
return
|
|
|
|
if not (toolcall_session and tools):
|
|
return
|
|
self.is_tools = True
|
|
self.toolcall_session = toolcall_session
|
|
self.tools = tools
|
|
|
|
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
|
|
gen_conf = dict(gen_conf or {})
|
|
gen_conf = self._clean_conf(gen_conf)
|
|
if system and history and history[0].get("role") != "system":
|
|
history.insert(0, {"role": "system", "content": system})
|
|
|
|
ans = ""
|
|
tk_count = 0
|
|
hist = deepcopy(history)
|
|
for attempt in range(self.max_retries + 1):
|
|
history = deepcopy(hist)
|
|
try:
|
|
for _ in range(self.max_rounds + 1):
|
|
logging.info(f"{self.tools=}")
|
|
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
|
|
tk_count += total_token_count_from_response(response)
|
|
if not response.choices or not response.choices[0].message:
|
|
raise Exception(f"500 response structure error. Response: {response}")
|
|
|
|
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
|
|
_reasoning = getattr(response.choices[0].message, "reasoning_content", None) or getattr(response.choices[0].message, "reasoning", None)
|
|
if _reasoning:
|
|
ans += "<think>" + _reasoning + "</think>"
|
|
|
|
ans += response.choices[0].message.content
|
|
if response.choices[0].finish_reason == "length":
|
|
ans = self._length_stop(ans)
|
|
|
|
return ans, tk_count
|
|
|
|
async def _exec_tool(tc):
|
|
name = tc.function.name
|
|
try:
|
|
args = json_repair.loads(tc.function.arguments)
|
|
if not isinstance(args, dict):
|
|
raise TypeError(
|
|
f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}"
|
|
)
|
|
if hasattr(self.toolcall_session, "tool_call_async"):
|
|
result = await self.toolcall_session.tool_call_async(name, args)
|
|
else:
|
|
result = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
|
|
return tc, name, args, result, None
|
|
except Exception as e:
|
|
logging.exception(f"Tool call failed: {tc}")
|
|
return tc, name, {}, None, e
|
|
|
|
logging.info(f"Response tool_calls={response.choices[0].message.tool_calls}")
|
|
results = await asyncio.gather(*[_exec_tool(tc) for tc in response.choices[0].message.tool_calls])
|
|
history = self._append_history_batch(history, results)
|
|
for tc, name, args, result, err in results:
|
|
ans += self._verbose_tool_use(name, args, err if err else result)
|
|
|
|
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
|
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
|
response, token_count = await self._async_chat(history, gen_conf)
|
|
ans += response
|
|
tk_count += token_count
|
|
return ans, tk_count
|
|
except Exception as e:
|
|
e = await self._exceptions_async(e, attempt)
|
|
if e:
|
|
return e, tk_count
|
|
|
|
assert False, "Shouldn't be here."
|
|
|
|
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
|
|
gen_conf = dict(gen_conf or {})
|
|
gen_conf = self._clean_conf(gen_conf)
|
|
tools = self.tools
|
|
if system and history and history[0].get("role") != "system":
|
|
history.insert(0, {"role": "system", "content": system})
|
|
|
|
total_tokens = 0
|
|
hist = deepcopy(history)
|
|
|
|
for attempt in range(self.max_retries + 1):
|
|
history = deepcopy(hist)
|
|
try:
|
|
for _round in range(self.max_rounds + 1):
|
|
reasoning_start = False
|
|
logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in tools]}")
|
|
|
|
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
|
|
|
|
final_tool_calls = {}
|
|
answer = ""
|
|
|
|
async for resp in response:
|
|
if not hasattr(resp, "choices") or not resp.choices:
|
|
continue
|
|
|
|
delta = resp.choices[0].delta
|
|
|
|
if hasattr(delta, "tool_calls") and delta.tool_calls:
|
|
for tool_call in delta.tool_calls:
|
|
index = tool_call.index
|
|
if index not in final_tool_calls:
|
|
if not tool_call.function.arguments:
|
|
tool_call.function.arguments = ""
|
|
final_tool_calls[index] = tool_call
|
|
else:
|
|
final_tool_calls[index].function.arguments += tool_call.function.arguments or ""
|
|
continue
|
|
|
|
if not hasattr(delta, "content") or delta.content is None:
|
|
delta.content = ""
|
|
|
|
_reasoning = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
|
|
if _reasoning:
|
|
ans = ""
|
|
if not reasoning_start:
|
|
reasoning_start = True
|
|
ans = "<think>"
|
|
ans += _reasoning + "</think>"
|
|
yield ans
|
|
else:
|
|
reasoning_start = False
|
|
answer += delta.content
|
|
yield delta.content
|
|
|
|
tol = total_token_count_from_response(resp)
|
|
if not tol:
|
|
total_tokens += num_tokens_from_string(delta.content)
|
|
else:
|
|
total_tokens = tol
|
|
|
|
finish_reason = getattr(resp.choices[0], "finish_reason", "")
|
|
if finish_reason == "length":
|
|
yield self._length_stop("")
|
|
|
|
if answer and not final_tool_calls:
|
|
logging.info(f"[ToolLoop] round={_round} completed with text response, exiting")
|
|
yield total_tokens
|
|
return
|
|
|
|
async def _exec_tool(tc):
|
|
name = tc.function.name
|
|
try:
|
|
args = json_repair.loads(tc.function.arguments)
|
|
if not isinstance(args, dict):
|
|
raise TypeError(
|
|
f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}"
|
|
)
|
|
if hasattr(self.toolcall_session, "tool_call_async"):
|
|
result = await self.toolcall_session.tool_call_async(name, args)
|
|
else:
|
|
result = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
|
|
return tc, name, args, result, None
|
|
except Exception as e:
|
|
logging.exception(f"Tool call failed: {tc}")
|
|
return tc, name, {}, None, e
|
|
|
|
tcs = list(final_tool_calls.values())
|
|
logging.info(f"[ToolLoop] round={_round} executing {len(tcs)} tool(s): {[tc.function.name for tc in tcs]}")
|
|
for tc in tcs:
|
|
try:
|
|
args = json_repair.loads(tc.function.arguments)
|
|
except Exception:
|
|
args = {}
|
|
yield self._verbose_tool_use(tc.function.name, args, "Begin to call...")
|
|
results = await asyncio.gather(*[_exec_tool(tc) for tc in tcs])
|
|
history = self._append_history_batch(history, results)
|
|
for tc, name, args, result, err in results:
|
|
yield self._verbose_tool_use(name, args, err if err else result)
|
|
|
|
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
|
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
|
|
|
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
|
|
|
|
async for resp in response:
|
|
if not hasattr(resp, "choices") or not resp.choices:
|
|
continue
|
|
delta = resp.choices[0].delta
|
|
if not hasattr(delta, "content") or delta.content is None:
|
|
continue
|
|
tol = total_token_count_from_response(resp)
|
|
if not tol:
|
|
total_tokens += num_tokens_from_string(delta.content)
|
|
else:
|
|
total_tokens = tol
|
|
yield delta.content
|
|
|
|
yield total_tokens
|
|
return
|
|
|
|
except Exception as e:
|
|
e = await self._exceptions_async(e, attempt)
|
|
if e:
|
|
logging.error(f"async_chat_streamly failed: {e}")
|
|
yield e
|
|
yield total_tokens
|
|
return
|
|
|
|
assert False, "Shouldn't be here."
|
|
|
|
async def _async_chat(self, history, gen_conf, **kwargs):
|
|
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
|
if self.model_name.lower().find("qwq") >= 0:
|
|
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
|
|
|
|
final_ans = ""
|
|
tol_token = 0
|
|
async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
|
|
if delta.startswith("<think>") or delta.endswith("</think>"):
|
|
continue
|
|
final_ans += delta
|
|
tol_token = tol
|
|
|
|
if len(final_ans.strip()) == 0:
|
|
final_ans = "**ERROR**: Empty response from reasoning model"
|
|
|
|
return final_ans.strip(), tol_token
|
|
|
|
_, kwargs = _apply_model_family_policies(
|
|
self.model_name,
|
|
backend="base",
|
|
request_kwargs=kwargs,
|
|
)
|
|
|
|
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
|
|
|
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
|
return "", 0
|
|
ans = response.choices[0].message.content.strip()
|
|
if response.choices[0].finish_reason == "length":
|
|
ans = self._length_stop(ans)
|
|
return ans, total_token_count_from_response(response)
|
|
|
|
async def async_chat(self, system, history, gen_conf=None, **kwargs):
|
|
gen_conf = dict(gen_conf or {})
|
|
if system and history and history[0].get("role") != "system":
|
|
history.insert(0, {"role": "system", "content": system})
|
|
gen_conf = self._clean_conf(gen_conf)
|
|
|
|
for attempt in range(self.max_retries + 1):
|
|
try:
|
|
return await self._async_chat(history, gen_conf, **kwargs)
|
|
except Exception as e:
|
|
e = await self._exceptions_async(e, attempt)
|
|
if e:
|
|
return e, 0
|
|
assert False, "Shouldn't be here."
|
|
|
|
|
|
class XinferenceChat(Base):
    """Chat model for an Xinference server reached through its ``v1`` endpoint."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key=None, model_name="", base_url="", **kwargs):
        """Resolve ``v1`` against ``base_url`` and initialise the base clients.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if base_url:
            endpoint = urljoin(base_url, "v1")
            super().__init__(key, model_name, endpoint, **kwargs)
            return
        raise ValueError("Local llm url cannot be None")
|
|
|
|
|
|
class HuggingFaceChat(Base):
    """Chat model for a HuggingFace-served endpoint reached through ``v1``."""

    _FACTORY_NAME = "HuggingFace"

    def __init__(self, key=None, model_name="", base_url="", **kwargs):
        """Initialise the base clients for ``base_url`` joined with ``v1``.

        Only the part of ``model_name`` before the first ``___`` is used as
        the model identifier.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if base_url:
            endpoint = urljoin(base_url, "v1")
            bare_model_name = model_name.split("___")[0]
            super().__init__(key, bare_model_name, endpoint, **kwargs)
            return
        raise ValueError("Local llm url cannot be None")
|
|
|
|
|
|
class ModelScopeChat(Base):
    """Chat model for a ModelScope-served endpoint reached through ``v1``."""

    _FACTORY_NAME = "ModelScope"

    def __init__(self, key=None, model_name="", base_url="", **kwargs):
        """Initialise the base clients for ``base_url`` joined with ``v1``.

        Only the part of ``model_name`` before the first ``___`` is used as
        the model identifier.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if base_url:
            endpoint = urljoin(base_url, "v1")
            bare_model_name = model_name.split("___")[0]
            super().__init__(key, bare_model_name, endpoint, **kwargs)
            return
        raise ValueError("Local llm url cannot be None")
|
|
|
|
|
|
class BaiChuanChat(Base):
    """Chat model for the BaiChuan API; its sync requests enable BaiChuan's web-search tool."""

    _FACTORY_NAME = "BaiChuan"

    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
        """Initialise the base clients, using the public BaiChuan URL when ``base_url`` is empty."""
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)

    @staticmethod
    def _format_params(params):
        """Keep only ``temperature`` and ``top_p``, defaulting to 0.3 and 0.85."""
        return {
            "temperature": params.get("temperature", 0.3),
            "top_p": params.get("top_p", 0.85),
        }

    @staticmethod
    def _web_search_extra_body():
        """Return a fresh ``extra_body`` payload that enables the web-search tool."""
        return {"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]}

    def _clean_conf(self, gen_conf):
        """Reduce ``gen_conf`` to the two parameters sent to BaiChuan (see ``_format_params``)."""
        return self._format_params(gen_conf)

    def _chat(self, history, gen_conf=None, **kwargs):
        """Run one blocking completion and return ``(answer, token_count)``.

        ``gen_conf`` is forwarded as given; ``kwargs`` is accepted but unused.

        Raises:
            ValueError: if the response carries no choices.
        """
        gen_conf = dict(gen_conf or {})
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=history,
            extra_body=self._web_search_extra_body(),
            **gen_conf,
        )
        if not response.choices:
            raise ValueError("LLM returned empty response")  # pact: guard empty choices list
        # content may be None; treat that as an empty answer instead of failing on .strip().
        ans = (response.choices[0].message.content or "").strip()
        if response.choices[0].finish_reason == "length":
            if is_chinese([ans]):
                ans += LENGTH_NOTIFICATION_CN
            else:
                ans += LENGTH_NOTIFICATION_EN
        return ans, total_token_count_from_response(response)

    def chat_streamly(self, system, history, gen_conf=None, **kwargs):
        """Stream an answer synchronously.

        Yields text deltas; on failure yields the last delta followed by an
        error marker. The final item is always the token count (int).
        ``history`` is modified in place when the system prompt is inserted.
        """
        gen_conf = dict(gen_conf or {})
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            # _format_params forwards only temperature/top_p, so other keys
            # such as max_tokens never reach the request.
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body=self._web_search_extra_body(),
                stream=True,
                **self._format_params(gen_conf),
            )
            for resp in response:
                if not resp.choices:
                    continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans = resp.choices[0].delta.content
                tol = total_token_count_from_response(resp)
                if not tol:
                    # No usage on this chunk: estimate from the delta text.
                    total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                else:
                    total_tokens = tol
                if resp.choices[0].finish_reason == "length":
                    if is_chinese([ans]):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans

        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
|
|
|
|
|
|
class LocalAIChat(Base):
    """Chat model for a LocalAI server reached through its ``v1`` endpoint."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Validate ``base_url`` and build clients that all target ``<base_url>/v1``.

        The URL is checked before any client is created. Previously only the
        sync client received the ``v1``-joined URL and the ``___``-stripped
        model name, while the async client (used by the inherited async chat
        methods) kept the raw values.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = urljoin(base_url, "v1")
        super().__init__(key, model_name.split("___")[0], base_url, **kwargs)
        # The sync client keeps the placeholder key it has always used.
        self.client = OpenAI(api_key="empty", base_url=base_url)
|
|
|
|
|
|
class LocalLLM(Base):
    """Chat model backed by a Jina gRPC client on port 12345."""

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Initialise the base class, then replace ``self.client`` with a Jina client."""
        super().__init__(key, model_name, base_url=base_url, **kwargs)
        from jina import Client

        self.client = Client(port=12345, protocol="grpc", asyncio=True)

    def _prepare_prompt(self, system, history, gen_conf):
        """Build the ``Prompt`` document, inserting the system message into ``history`` if absent."""
        from rag.svr.jina_server import Prompt

        needs_system = bool(system and history and history[0].get("role") != "system")
        if needs_system:
            history.insert(0, {"role": "system", "content": system})
        return Prompt(message=history, gen_conf=gen_conf)

    def _stream_response(self, endpoint, prompt):
        """Yield each text received from ``endpoint``, then the token count of the last text.

        The async stream is driven from a private event loop. On failure an
        error marker appended to the last text is yielded before the count.
        """
        from rag.svr.jina_server import Generation

        latest = ""
        event_loop = asyncio.new_event_loop()
        try:
            stream = self.client.stream_doc(on=endpoint, inputs=prompt, return_type=Generation)
            try:
                while True:
                    doc = event_loop.run_until_complete(stream.__anext__())
                    latest = doc.text
                    yield latest
            except StopAsyncIteration:
                # Normal end of the stream.
                pass
        except Exception as e:
            yield latest + "\n**ERROR**: " + str(e)
        finally:
            event_loop.close()
        yield num_tokens_from_string(latest)

    def chat(self, system, history, gen_conf=None, **kwargs):
        """Return ``(answer, token_count)`` from the first two items of the ``/chat`` stream."""
        conf = dict(gen_conf or {})
        conf.pop("max_tokens", None)
        stream = self._stream_response("/chat", self._prepare_prompt(system, history, conf))
        ans = next(stream)
        total_tokens = next(stream)
        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf=None, **kwargs):
        """Return the generator of the ``/stream`` endpoint (texts, then a token count)."""
        conf = dict(gen_conf or {})
        conf.pop("max_tokens", None)
        return self._stream_response("/stream", self._prepare_prompt(system, history, conf))
|
|
|
|
|
|
class VolcEngineChat(Base):
    """Chat model for VolcEngine Ark, whose credentials are packed into the key field."""

    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
        """
        Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
        Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
        model_name is for display only

        When ``key`` is a JSON object, its ``ark_api_key`` becomes the API key
        and ``ep_id`` + ``endpoint_id`` becomes the model name. Any other key
        (not JSON, or JSON that is not an object) is used as-is together with
        the given ``model_name``.
        """
        base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
        # Parse once; a plain (non-JSON) key is a supported input.
        try:
            parsed = json.loads(key)
        except JSONDecodeError:
            parsed = None

        if isinstance(parsed, dict):
            key = parsed.get("ark_api_key", "")
            model_name = parsed.get("ep_id", "") + parsed.get("endpoint_id", "")
        super().__init__(key, model_name, base_url, **kwargs)
|
|
|
|
|
|
class MistralChat(Base):
    """Chat model that sends its sync requests through the Mistral SDK client."""

    _FACTORY_NAME = "Mistral"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Initialise the base class, then replace ``self.client`` with a ``MistralClient``."""
        super().__init__(key, model_name, base_url=base_url, **kwargs)

        from mistralai.client import MistralClient

        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def _clean_conf(self, gen_conf):
        """Strip every key except ``temperature``, ``top_p`` and ``max_tokens``.

        The dict is modified in place and also returned.
        """
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        return gen_conf

    def _chat(self, history, gen_conf=None, **kwargs):
        """Run one blocking completion and return ``(answer, token_count)``.

        Raises:
            ValueError: if the response carries no choices.
        """
        gen_conf = dict(gen_conf or {})
        gen_conf = self._clean_conf(gen_conf)
        response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
        if not response.choices:
            raise ValueError("LLM returned empty response")  # pact: guard empty choices list
        ans = response.choices[0].message.content
        if response.choices[0].finish_reason == "length":
            if is_chinese(ans):
                ans += LENGTH_NOTIFICATION_CN
            else:
                ans += LENGTH_NOTIFICATION_EN
        return ans, total_token_count_from_response(response)

    def chat_streamly(self, system, history, gen_conf=None, **kwargs):
        """Stream an answer synchronously.

        Yields text deltas; on failure yields the last delta followed by an
        error marker. The final item is always an int that counts the
        non-empty chunks received (not real tokens). ``history`` is modified
        in place when the system prompt is inserted.
        """
        gen_conf = dict(gen_conf or {})
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf, **kwargs)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans = resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                yield ans

        except Exception as e:
            # The request goes through MistralClient, so its failures are not
            # openai.APIError; report any failure in-band, as the other
            # chat_streamly implementations in this file do.
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
|
|
|
|
|
|
class LmStudioChat(Base):
    """Chat adapter for a local LM-Studio server."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url, **kwargs):
        """Point the client at ``<base_url>/v1``.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        endpoint = urljoin(base_url, "v1")
        super().__init__(key, model_name, endpoint, **kwargs)
        # The client is rebuilt here with the fixed placeholder key "lm-studio".
        self.client = OpenAI(api_key="lm-studio", base_url=endpoint)
        self.model_name = model_name
|
|
|
|
|
|
class OpenAI_APIChat(Base):
    """Chat adapter for VLLM and other OpenAI-API-compatible servers."""

    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, base_url, **kwargs):
        """Initialise with a mandatory ``base_url``.

        Only the part of ``model_name`` before the first ``___`` is used.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        bare_name, _, _ = model_name.partition("___")
        super().__init__(key, bare_name, base_url, **kwargs)
|
|
|
|
|
|
class LeptonAIChat(Base):
    """Chat adapter for LeptonAI deployments."""

    _FACTORY_NAME = "LeptonAI"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Default ``base_url`` to ``https://<model_name>.lepton.run/api/v1`` when none is given."""
        if not base_url:
            host = "https://" + model_name + ".lepton.run"
            base_url = urljoin(host, "api/v1")
        super().__init__(key, model_name, base_url, **kwargs)
|
|
|
|
|
|
class ReplicateChat(Base):
    """Chat adapter backed by the ``replicate`` client."""

    _FACTORY_NAME = "Replicate"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Run the base initialiser, then replace ``self.client`` with a Replicate ``Client``."""
        super().__init__(key, model_name, base_url=base_url, **kwargs)

        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def _chat(self, history, gen_conf=None, **kwargs):
        """Run the model once on a prompt built from the last five history items.

        A leading system message becomes ``system_prompt``; system-role items
        are left out of the prompt text.

        Returns:
            ``(answer, token_count)`` where the count is computed from the answer.
        """
        conf = dict(gen_conf or {})
        has_system = bool(history) and history[0]["role"] == "system"
        system = history[0]["content"] if has_system else ""
        turns = []
        for item in history[-5:]:
            if item["role"] == "system":
                continue
            turns.append(item["role"] + ":" + item["content"])
        prompt = "\n".join(turns)
        payload = {"system_prompt": system, "prompt": prompt, **conf}
        pieces = self.client.run(self.model_name, input=payload)
        answer = "".join(pieces)
        return answer, num_tokens_from_string(answer)

    def chat_streamly(self, system, history, gen_conf=None, **kwargs):
        """Stream the model output.

        Yields every piece returned by the client, then an int token count
        computed from the last piece. On failure a string containing
        ``**ERROR**`` is yielded before the count. ``max_tokens`` is not
        forwarded.
        """
        conf = dict(gen_conf or {})
        conf.pop("max_tokens", None)
        turns = [item["role"] + ":" + item["content"] for item in history[-5:]]
        prompt = "\n".join(turns)
        ans = ""
        try:
            payload = {"system_prompt": system, "prompt": prompt, **conf}
            for piece in self.client.run(self.model_name, input=payload):
                ans = piece
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield num_tokens_from_string(ans)
|
|
|
|
|
|
class SparkChat(Base):
    """Chat adapter for XunFei Spark."""

    _FACTORY_NAME = "XunFei Spark"

    def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs):
        """Resolve the Spark version id for ``model_name`` and initialise the base client.

        ``model_name`` may be a display name (mapped to its version id) or
        already a version id (used unchanged). Any other value fails the
        assertion below.
        """
        base_url = base_url or "https://spark-api-open.xf-yun.com/v1"
        model2version = {
            "Spark-Max": "generalv3.5",
            "Spark-Max-32K": "max-32k",
            "Spark-Lite": "lite",
            "Spark-Pro": "generalv3",
            "Spark-Pro-128K": "pro-128k",
            "Spark-4.0-Ultra": "4.0Ultra",
        }
        known_versions = set(model2version.values())
        assert model_name in model2version or model_name in known_versions, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
        # Display names are translated; version ids pass through as they are.
        model_version = model2version.get(model_name, model_name)
        super().__init__(key, model_version, base_url, **kwargs)
|
|
|
|
|
|
class BaiduYiyanChat(Base):
    """Chat adapter for Baidu Yiyan, backed by the ``qianfan`` SDK."""

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Build a ``qianfan.ChatCompletion`` client.

        Args:
            key: JSON string holding ``yiyan_ak`` and ``yiyan_sk``.
            model_name: model identifier; stored lower-cased.
        """
        super().__init__(key, model_name, base_url=base_url, **kwargs)

        import qianfan

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
        self.model_name = model_name.lower()

    def _clean_conf(self, gen_conf):
        """Add ``penalty_score`` (derived from the two OpenAI-style penalties) and drop ``max_tokens``, in place."""
        gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        return gen_conf

    @staticmethod
    def _split_system(system, history):
        """Return ``(system_prompt, messages)`` in the form ``client.do`` is called with.

        System-role entries are removed from the message list because the
        system prompt is passed through the separate ``system`` argument. A
        leading system message in ``history`` takes precedence over the
        ``system`` parameter.
        """
        history = history or []
        if history and history[0].get("role") == "system":
            system = history[0].get("content", "")
        messages = [h for h in history if h.get("role") != "system"]
        return system or "", messages

    def _chat(self, history, gen_conf=None, **kwargs):
        """Run one blocking completion and return ``(answer, total_tokens)``."""
        gen_conf = dict(gen_conf or {})
        system, messages = self._split_system("", history)
        response = self.client.do(model=self.model_name, messages=messages, system=system, **gen_conf).body
        ans = response["result"]
        return ans, total_token_count_from_response(response)

    def chat_streamly(self, system, history, gen_conf=None, **kwargs):
        """Stream a completion.

        Yields each chunk's ``result`` text, then an int with the token count
        reported by the last chunk. On failure a string containing
        ``**ERROR**`` is yielded before the count.
        """
        gen_conf = self._clean_conf(dict(gen_conf or {}))
        system_msg, messages = self._split_system(system, history)
        ans = ""
        total_tokens = 0

        try:
            response = self.client.do(model=self.model_name, messages=messages, system=system_msg, stream=True, **gen_conf)
            for resp in response:
                resp = resp.body
                ans = resp["result"]
                total_tokens = total_token_count_from_response(resp)

                yield ans

        except Exception as e:
            # Must be a yield: a `return <value>` inside a generator only ends
            # the iteration and the consumer would never see the error text.
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

    async def async_chat_streamly(self, system, history, gen_conf: dict | None = None, **kwargs):
        """Async counterpart of :meth:`chat_streamly`.

        The blocking SDK call runs in a worker thread. The whole reply is
        collected there and yielded as one string, followed by the int token
        count. On failure a string starting with ``**ERROR**`` is yielded
        instead of the reply, followed by ``0``.
        """
        gen_conf = self._clean_conf(dict(gen_conf or {}))
        system_msg, msgs = self._split_system(system, history)

        def _do_chat():
            try:
                response = self.client.do(model=self.model_name, messages=msgs, system=system_msg, stream=True, **gen_conf)
                result_text = ""
                total_tokens = 0
                for resp in response:
                    resp = resp.body
                    # Each chunk carries a fragment; keep them all rather than
                    # only the last one.
                    result_text += resp["result"]
                    total_tokens = total_token_count_from_response(resp)
                return result_text, total_tokens, None
            except Exception as e:
                return "", 0, e

        result_text, total_tokens, error = await asyncio.to_thread(_do_chat)
        if error is not None:
            yield f"**ERROR**: {error}"
        else:
            yield result_text
        yield total_tokens
|
|
|
|
|
|
class GoogleChat(Base):
    """Chat adapter for models served through Google Cloud.

    Model names containing ``"claude"`` use an ``AnthropicVertex`` client;
    all other names use a ``google-genai`` client in Vertex AI mode.
    """

    _FACTORY_NAME = "Google Cloud"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Create the provider client.

        Args:
            key: JSON string with ``google_service_account_key`` (base64 of a
                service-account JSON document), ``google_project_id`` and
                ``google_region``. When the service-account key is empty the
                client is built without explicit credentials.
            model_name: model identifier; selects the client as described on
                the class.
        """
        super().__init__(key, model_name, base_url=base_url, **kwargs)

        import base64

        from google.oauth2 import service_account

        key = json.loads(key)
        encoded_account = key.get("google_service_account_key", "")
        # Decoding an empty value and passing it to json.loads raises, which
        # made the credential-less branches below unreachable.
        access_token = json.loads(base64.b64decode(encoded_account)) if encoded_account else None
        project_id = key.get("google_project_id", "")
        region = key.get("google_region", "")

        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        self.model_name = model_name

        if "claude" in self.model_name:
            from anthropic import AnthropicVertex
            from google.auth.transport.requests import Request

            if access_token:
                credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
                request = Request()
                credits.refresh(request)
                token = credits.token
                self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token)
            else:
                self.client = AnthropicVertex(region=region, project_id=project_id)
        else:
            from google import genai

            if access_token:
                credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
                self.client = genai.Client(vertexai=True, project=project_id, location=region, credentials=credits)
            else:
                self.client = genai.Client(vertexai=True, project=project_id, location=region)

    def _clean_conf(self, gen_conf):
        """Normalise ``gen_conf`` in place for the selected backend.

        Claude: ``max_tokens`` is removed. Gemini: ``max_tokens`` is renamed
        to ``max_output_tokens`` and every key other than temperature, top_p
        and max_output_tokens is removed.
        """
        # NOTE(review): the Claude branch drops max_tokens entirely - confirm
        # the Anthropic endpoint accepts requests without it.
        if "claude" in self.model_name:
            if "max_tokens" in gen_conf:
                del gen_conf["max_tokens"]
        else:
            if "max_tokens" in gen_conf:
                gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
                del gen_conf["max_tokens"]
            for k in list(gen_conf.keys()):
                if k not in ["temperature", "top_p", "max_output_tokens"]:
                    del gen_conf[k]
        return gen_conf

    def _build_gemini_request(self, system, history, gen_conf):
        """Translate a chat request into google-genai ``(contents, config)``.

        ``gen_conf`` is modified in place (callers pass their own copy).
        ``thinking_budget`` defaults to 0 when absent. System-role history
        items are left out of ``contents``; the system prompt is sent as
        ``system_instruction``.

        Raises:
            ImportError: if ``google-genai`` cannot be imported.
        """
        thinking_budget = gen_conf.pop("thinking_budget", 0)
        gen_conf = self._clean_conf(gen_conf)

        try:
            from google.genai.types import Content, GenerateContentConfig, Part, ThinkingConfig
        except ImportError as e:
            logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
            raise

        config_dict = {}
        if system:
            config_dict["system_instruction"] = system
        for option in ("temperature", "top_p", "max_output_tokens"):
            if option in gen_conf:
                config_dict[option] = gen_conf[option]
        config_dict["thinking_config"] = ThinkingConfig(thinking_budget=thinking_budget)
        config = GenerateContentConfig(**config_dict)

        contents = []
        for item in history:
            if item["role"] == "system":
                continue
            # google-genai uses 'model' instead of 'assistant'
            role = "model" if item["role"] == "assistant" else item["role"]
            contents.append(Content(role=role, parts=[Part(text=item["content"])]))
        return contents, config

    def _chat(self, history, gen_conf=None, **kwargs):
        """Run one blocking completion and return ``(answer, total_tokens)``.

        A leading system message in ``history`` is used as the system prompt.
        """
        gen_conf = dict(gen_conf or {})
        system = history[0]["content"] if history and history[0]["role"] == "system" else ""

        if "claude" in self.model_name:
            gen_conf = self._clean_conf(gen_conf)
            response = self.client.messages.create(
                model=self.model_name,
                messages=[h for h in history if h["role"] != "system"],
                system=system,
                stream=False,
                **gen_conf,
            ).json()
            # .json() may hand back serialized text rather than a mapping.
            if isinstance(response, (str, bytes)):
                response = json.loads(response)
            ans = response["content"][0]["text"]
            if response["stop_reason"] == "max_tokens":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return (
                ans,
                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
            )

        # Gemini models with google-genai SDK
        contents, config = self._build_gemini_request(system, history, gen_conf)
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=config,
        )

        ans = response.text
        # Usage metadata may be missing; report 0 tokens in that case.
        try:
            total_tokens = response.usage_metadata.total_token_count
        except Exception:
            total_tokens = 0

        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf=None, **kwargs):
        """Stream a completion.

        Yields text fragments as strings and finally an int token count
        estimated from the yielded text. On failure a string containing
        ``**ERROR**`` is yielded before the count.
        """
        gen_conf = dict(gen_conf or {})
        ans = ""
        total_tokens = 0

        if "claude" in self.model_name:
            if "max_tokens" in gen_conf:
                del gen_conf["max_tokens"]
            try:
                response = self.client.messages.create(
                    model=self.model_name,
                    messages=[h for h in history if h["role"] != "system"],
                    system=system,
                    stream=True,
                    **gen_conf,
                )
                # NOTE(review): assumes the streaming response exposes
                # iter_lines() yielding raw "data: {...}" byte lines - confirm
                # against the installed anthropic SDK.
                for res in response.iter_lines():
                    res = res.decode("utf-8")
                    if "content_block_delta" in res and "data" in res:
                        text = json.loads(res[6:]).get("delta", {}).get("text")
                        if not text:
                            continue
                        ans = text
                        total_tokens += num_tokens_from_string(text)
                        # Without this yield the consumer received no text.
                        yield ans
            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)

            yield total_tokens
        else:
            # Gemini models with google-genai SDK
            if not system and history and history[0].get("role") == "system":
                system = history[0]["content"]
            contents, config = self._build_gemini_request(system, history, gen_conf)

            try:
                for chunk in self.client.models.generate_content_stream(
                    model=self.model_name,
                    contents=contents,
                    config=config,
                ):
                    # A chunk may carry no text; skip it.
                    text = chunk.text
                    if not text:
                        continue
                    ans = text
                    total_tokens += num_tokens_from_string(text)
                    yield ans

            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)

            yield total_tokens
|
|
|
|
|
|
class TokenPonyChat(Base):
    """Chat adapter for TokenPony."""

    _FACTORY_NAME = "TokenPony"

    def __init__(self, key, model_name, base_url="https://ragflow.vip-api.tokenpony.cn/v1", **kwargs):
        """Fall back to the TokenPony endpoint when ``base_url`` is empty."""
        endpoint = base_url or "https://ragflow.vip-api.tokenpony.cn/v1"
        super().__init__(key, model_name, endpoint, **kwargs)
|
|
|
|
|
|
class N1nChat(Base):
    """Chat adapter for n1n."""

    _FACTORY_NAME = "n1n"

    def __init__(self, key, model_name, base_url="https://api.n1n.ai/v1", **kwargs):
        """Fall back to the n1n endpoint when ``base_url`` is empty."""
        endpoint = base_url or "https://api.n1n.ai/v1"
        super().__init__(key, model_name, endpoint, **kwargs)
|
|
|
|
|
|
class AvianChat(Base):
    """Chat adapter for Avian."""

    _FACTORY_NAME = "Avian"

    def __init__(self, key, model_name, base_url="https://api.avian.io/v1", **kwargs):
        """Fall back to the Avian endpoint when ``base_url`` is empty."""
        endpoint = base_url or "https://api.avian.io/v1"
        super().__init__(key, model_name, endpoint, **kwargs)
|
|
|
|
|
|
class AstraflowChat(Base):
    """Chat adapter for Astraflow."""

    _FACTORY_NAME = "Astraflow"

    def __init__(self, key, model_name, base_url="https://api-us-ca.umodelverse.ai/v1", **kwargs):
        """Fall back to the Astraflow endpoint when ``base_url`` is empty."""
        endpoint = base_url or "https://api-us-ca.umodelverse.ai/v1"
        super().__init__(key, model_name, endpoint, **kwargs)
|
|
|
|
|
|
class AstraflowCNChat(Base):
    """Chat adapter for Astraflow-CN."""

    _FACTORY_NAME = "Astraflow-CN"

    def __init__(self, key, model_name, base_url="https://api.modelverse.cn/v1", **kwargs):
        """Fall back to the Astraflow-CN endpoint when ``base_url`` is empty."""
        endpoint = base_url or "https://api.modelverse.cn/v1"
        super().__init__(key, model_name, endpoint, **kwargs)
|
|
|
|
|
|
class FuturMixChat(Base):
    """Chat adapter for FuturMix."""

    _FACTORY_NAME = "FuturMix"

    def __init__(self, key, model_name, base_url="https://futurmix.ai/v1", **kwargs):
        """Fall back to the FuturMix endpoint when ``base_url`` is empty, then log the model in use."""
        endpoint = base_url or "https://futurmix.ai/v1"
        super().__init__(key, model_name, endpoint, **kwargs)
        logging.info("[FuturMix] Chat initialized with model %s", model_name)
|
|
|
|
|
|
class LiteLLMBase(ABC):
|
|
_FACTORY_NAME = [
|
|
"Tongyi-Qianwen",
|
|
"Bedrock",
|
|
"Moonshot",
|
|
"xAI",
|
|
"DeepInfra",
|
|
"Groq",
|
|
"Cohere",
|
|
"Gemini",
|
|
"DeepSeek",
|
|
"NVIDIA",
|
|
"TogetherAI",
|
|
"Anthropic",
|
|
"Ollama",
|
|
"LongCat",
|
|
"CometAPI",
|
|
"SILICONFLOW",
|
|
"OpenRouter",
|
|
"StepFun",
|
|
"PPIO",
|
|
"PerfXCloud",
|
|
"Upstage",
|
|
"NovitaAI",
|
|
"01.AI",
|
|
"GiteeAI",
|
|
"302.AI",
|
|
"Jiekou.AI",
|
|
"ZHIPU-AI",
|
|
"MiniMax",
|
|
"DeerAPI",
|
|
"GPUStack",
|
|
"OpenAI",
|
|
"Azure-OpenAI",
|
|
"Tencent Hunyuan",
|
|
]
|
|
|
|
def __init__(self, key, model_name, base_url=None, **kwargs):
    """Store connection, retry and tool-calling settings for a LiteLLM-backed model.

    Recognised ``kwargs``: ``provider``, ``max_retries``, ``retry_interval``
    and ``max_rounds``. Environment variables ``LLM_TIMEOUT_SECONDS``,
    ``LLM_MAX_RETRIES`` and ``LLM_BASE_DELAY`` supply defaults.

    For some providers ``key`` is a JSON object rather than a bare key:
    OpenRouter (``api_key``, ``provider_order``), Azure-OpenAI (``api_key``,
    ``api_version``) and MiniMax (``api_key``, ``group_id``).
    """
    provider = kwargs.get("provider", "")
    self.timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
    self.provider = provider
    self.prefix = LITELLM_PROVIDER_PREFIX.get(provider, "")
    self.model_name = f"{self.prefix}{model_name}"
    self.api_key = key
    self.base_url = (base_url or FACTORY_DEFAULT_BASE_URL.get(provider, "")).rstrip("/")

    # Configure retry parameters
    env_max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
    self.max_retries = kwargs.get("max_retries", env_max_retries)
    env_base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
    self.base_delay = kwargs.get("retry_interval", env_base_delay)
    self.max_rounds = kwargs.get("max_rounds", 5)
    self.is_tools = False
    self.tools = []
    self.toolcall_sessions = {}

    # Factory specific fields
    if provider == SupportedLiteLLMProvider.OpenRouter:
        # A key that is not JSON is used as a bare API key.
        try:
            packed = json.loads(key)
            self.api_key = packed.get("api_key", "")
            self.provider_order = packed.get("provider_order", "")
        except JSONDecodeError:
            self.api_key = key
            self.provider_order = ""
    elif provider == SupportedLiteLLMProvider.Azure_OpenAI:
        packed = json.loads(key)
        self.api_key = packed.get("api_key", "")
        self.api_version = packed.get("api_version", "2024-02-01")
    elif provider == SupportedLiteLLMProvider.MiniMax:
        # MiniMax requires GroupId as a query parameter for API authentication
        try:
            key_obj = json.loads(key) if isinstance(key, str) else key
        except (json.JSONDecodeError, TypeError):
            key_obj = None
        if isinstance(key_obj, dict):
            self.api_key = key_obj.get("api_key", key)
            self.group_id = key_obj.get("group_id", "")
        else:
            self.api_key = key
            self.group_id = ""
    else:
        self.group_id = ""
|
|
|
|
def _get_delay(self):
    """Return a randomised retry delay: ``base_delay`` times a uniform factor in [10, 150]."""
    factor = random.uniform(10, 150)
    return self.base_delay * factor
|
|
|
|
def _classify_error(self, error):
    """Map an error to an ``LLMErrorCode`` by keyword.

    The lower-cased ``str(error)`` is tested against each keyword group in
    order; the first group with any keyword present as a substring decides
    the code. ``ERROR_GENERIC`` is returned when nothing matches.
    """
    text = str(error).lower()

    # Order matters: earlier groups win when several would match.
    rules = (
        (("quota", "capacity", "credit", "billing", "balance", "欠费"), LLMErrorCode.ERROR_QUOTA),
        (("rate limit", "429", "tpm limit", "too many requests", "requests per minute"), LLMErrorCode.ERROR_RATE_LIMIT),
        (("auth", "key", "apikey", "401", "forbidden", "permission"), LLMErrorCode.ERROR_AUTHENTICATION),
        (("invalid", "bad request", "400", "format", "malformed", "parameter"), LLMErrorCode.ERROR_INVALID_REQUEST),
        (("server", "503", "502", "504", "500", "unavailable"), LLMErrorCode.ERROR_SERVER),
        (("timeout", "timed out"), LLMErrorCode.ERROR_TIMEOUT),
        (("connect", "network", "unreachable", "dns"), LLMErrorCode.ERROR_CONNECTION),
        (("filter", "content", "policy", "blocked", "safety", "inappropriate"), LLMErrorCode.ERROR_CONTENT_FILTER),
        (("model", "not found", "does not exist", "not available"), LLMErrorCode.ERROR_MODEL),
        (("max rounds",), LLMErrorCode.ERROR_MODEL),
    )
    # None of the keywords contains a regex metacharacter, so a plain
    # substring test matches exactly what an alternation pattern would.
    for keywords, code in rules:
        if any(keyword in text for keyword in keywords):
            return code

    return LLMErrorCode.ERROR_GENERIC
|
|
|
|
def _clean_conf(self, gen_conf):
    """Return the generation options that may be forwarded to LiteLLM.

    The model-family policies are applied first; ``max_tokens`` is then
    removed and only keys listed in ``LITELLM_ALLOWED_GEN_CONF_KEYS`` are kept.
    """
    adjusted, _ = _apply_model_family_policies(
        self.model_name,
        backend="litellm",
        provider=self.provider,
        gen_conf=gen_conf,
    )
    adjusted.pop("max_tokens", None)

    allowed = {}
    for option, value in adjusted.items():
        if option in LITELLM_ALLOWED_GEN_CONF_KEYS:
            allowed[option] = value
    return allowed
|
|
|
|
def _need_reasoning_content_back(self) -> bool:
    """Return True when the provider is DeepSeek.

    Callers use this to decide whether reasoning text is attached to the
    assistant message written back into the history.
    """
    is_deepseek = self.provider == SupportedLiteLLMProvider.DeepSeek
    return is_deepseek
|
|
|
|
async def async_chat(self, system, history, gen_conf, **kwargs):
    """Run one non-streaming completion, retrying on retryable errors.

    ``history`` is copied; ``system`` is inserted as the first message unless
    the copy already starts with a system message.

    Returns:
        ``(answer, total_tokens)``. An empty or missing reply gives
        ``("", 0)``. When the request is given up on, the formatted error
        message is returned in place of the answer, with ``0`` tokens.
    """
    messages = list(history) if history else []
    if system and (not messages or messages[0].get("role") != "system"):
        messages.insert(0, {"role": "system", "content": system})

    logging.info("[HISTORY]" + json.dumps(messages, ensure_ascii=False, indent=2))
    gen_conf = self._clean_conf(gen_conf)
    _, kwargs = _apply_model_family_policies(
        self.model_name,
        backend="litellm",
        provider=self.provider,
        request_kwargs=kwargs,
    )

    merged_options = {**gen_conf, **kwargs}
    completion_args = self._construct_completion_args(history=messages, stream=False, tools=False, **merged_options)

    for attempt in range(self.max_retries + 1):
        try:
            response = await litellm.acompletion(
                **completion_args,
                drop_params=True,
                timeout=self.timeout,
            )

            choices = response.choices
            if not choices or not choices[0].message or not choices[0].message.content:
                return "", 0
            first = choices[0]
            answer = first.message.content.strip()
            if first.finish_reason == "length":
                answer = self._length_stop(answer)

            return answer, total_token_count_from_response(response)
        except Exception as e:
            # A None result means "retry"; anything else is the final message.
            failure = await self._exceptions_async(e, attempt)
            if failure:
                return failure, 0

    assert False, "Shouldn't be here."
|
|
|
|
async def async_chat_streamly(self, system, history, gen_conf, **kwargs):
    """Stream a completion, retrying on retryable errors.

    Yields text fragments (reasoning fragments are wrapped in ``<think>``
    markers unless ``with_reasoning`` is passed as false) and finally an int
    token count. When the request is given up on, the formatted error
    message is yielded before the count.

    ``system`` is inserted into ``history`` in place when ``history`` is
    non-empty and does not already start with a system message. A ``stop``
    keyword argument is forwarded to the completion call.
    """
    if system and history and history[0].get("role") != "system":
        history.insert(0, {"role": "system", "content": system})
    logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
    gen_conf = self._clean_conf(gen_conf)
    in_reasoning = False
    total_tokens = 0

    completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
    stop = kwargs.get("stop")
    if stop:
        completion_args["stop"] = stop

    for attempt in range(self.max_retries + 1):
        try:
            stream = await litellm.acompletion(
                **completion_args,
                drop_params=True,
                timeout=self.timeout,
            )

            async for resp in stream:
                if not getattr(resp, "choices", None):
                    continue

                delta = resp.choices[0].delta
                if getattr(delta, "content", None) is None:
                    delta.content = ""

                reasoning = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
                if kwargs.get("with_reasoning", True) and reasoning:
                    # Only the first fragment of a reasoning run opens the tag.
                    opening = "" if in_reasoning else "<think>"
                    in_reasoning = True
                    ans = opening + reasoning + "</think>"
                else:
                    in_reasoning = False
                    ans = delta.content

                # Fall back to counting the fragment when no usage is reported.
                chunk_tokens = total_token_count_from_response(resp)
                if not chunk_tokens:
                    chunk_tokens = num_tokens_from_string(delta.content)
                total_tokens += chunk_tokens

                finish_reason = getattr(resp.choices[0], "finish_reason", "")
                if finish_reason == "length":
                    ans += LENGTH_NOTIFICATION_CN if is_chinese(ans) else LENGTH_NOTIFICATION_EN

                yield ans
            yield total_tokens
            return
        except Exception as e:
            # A None result means "retry"; anything else is the final message.
            failure = await self._exceptions_async(e, attempt)
            if failure:
                yield failure
                yield total_tokens
                return
|
|
|
|
def _length_stop(self, ans):
    """Return ``ans`` with the Chinese or English "stopped for length" notice appended."""
    notice = LENGTH_NOTIFICATION_CN if is_chinese([ans]) else LENGTH_NOTIFICATION_EN
    return ans + notice
|
|
|
|
@property
def _retryable_errors(self) -> set[str]:
    """Error codes after which a request is attempted again: rate limit and server errors."""
    retryable = set()
    retryable.add(LLMErrorCode.ERROR_RATE_LIMIT)
    retryable.add(LLMErrorCode.ERROR_SERVER)
    return retryable
|
|
|
|
def _should_retry(self, error_code: str) -> bool:
    """Return True when ``error_code`` is one of ``self._retryable_errors``."""
    retryable = self._retryable_errors
    return error_code in retryable
|
|
|
|
async def _exceptions_async(self, e, attempt):
    """Decide what to do with a failed completion attempt.

    Must be called from an ``except`` block, since the traceback is logged.

    Returns:
        ``None`` after sleeping for the retry delay when the error is
        retryable, otherwise the formatted error message. On the final
        attempt the error is classified as ``ERROR_MAX_RETRIES``.
    """
    logging.exception("LiteLLMBase async completion")
    error_code = self._classify_error(e)
    if attempt == self.max_retries:
        error_code = LLMErrorCode.ERROR_MAX_RETRIES

    if not self._should_retry(error_code):
        msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
        logging.error(f"async_chat_streamly giving up: {msg}")
        return msg

    delay = self._get_delay()
    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
    await asyncio.sleep(delay)
    return None
|
|
|
|
def _verbose_tool_use(self, name, args, res):
    """Render one tool invocation as a ``<tool_call>...</tool_call>`` block.

    The block holds pretty-printed JSON with ``name``, ``args`` and
    ``result``; an exception result is shown as its string form.
    """
    result = str(res) if isinstance(res, Exception) else res
    record = {"name": name, "args": args, "result": result}
    payload = json.dumps(record, ensure_ascii=False, indent=2)
    return f"<tool_call>{payload}</tool_call>"
|
|
|
|
def _append_history(self, hist, tool_call, tool_res, reasoning_content=None):
    """Append one tool call and its result to ``hist``.

    Two messages are added: an assistant message carrying the tool call
    (plus ``reasoning_content`` when given) and a tool message carrying the
    result. A dict result is JSON-encoded; if it cannot be encoded, or for
    any other type, its ``str()`` form is used.

    Returns:
        ``hist``, modified in place.
    """
    assistant_msg = {
        "role": "assistant",
        "tool_calls": [
            {
                "index": getattr(tool_call, "index", None),
                "id": tool_call.id,
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
                "type": "function",
            },
        ],
    }
    if reasoning_content:
        assistant_msg["reasoning_content"] = reasoning_content
    hist.append(assistant_msg)

    content = tool_res
    if isinstance(tool_res, dict):
        try:
            content = json.dumps(tool_res, ensure_ascii=False)
        except (TypeError, ValueError):
            # An unserialisable result must not abort the tool loop after the
            # assistant message was already recorded; fall back to str().
            content = tool_res
    hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(content)})
    return hist
|
|
|
|
def _append_history_batch(self, hist, results, reasoning_content=None):
    """
    Record a batch of tool calls in ``hist`` following the OpenAI protocol:
    one assistant message listing every tool call, then one tool message per
    call. A tool message holds the error text when the call failed, the
    JSON-encoded result when it is a dict, and ``str(result)`` otherwise.

    results: list of (tool_call, name, args, result, error)

    Returns ``hist``, modified in place.
    """
    calls = []
    for entry in results:
        tc = entry[0]
        calls.append(
            {
                "index": getattr(tc, "index", None),
                "id": tc.id,
                "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                "type": "function",
            }
        )
    assistant_msg = {"role": "assistant", "tool_calls": calls}
    if reasoning_content:
        assistant_msg["reasoning_content"] = reasoning_content
    hist.append(assistant_msg)

    for tc, _, _, result, err in results:
        if err:
            content = str(err)
        else:
            content = json.dumps(result, ensure_ascii=False) if isinstance(result, dict) else str(result)
        hist.append({"role": "tool", "tool_call_id": tc.id, "content": content})
    return hist
|
|
|
|
def bind_tools(self, toolcall_session=None, tools=None):
    """Attach the tools the model is allowed to call.

    Accepted call shapes:

    * ``bind_tools(session, schemas)`` - an explicit tool-call session plus a
      ready-made list of function-schema dicts.
    * ``bind_tools(tools=[fn, ...])`` or ``bind_tools([fn, ...])`` - a list in
      which every item passes ``is_tool``; a ``FunctionToolSession`` is built
      from it and its ``schemas`` become the tool list.

    Any other combination leaves the instance untouched.
    """
    # A list passed positionally as the first argument is the tool list.
    if tools is None and isinstance(toolcall_session, list):
        tools, toolcall_session = toolcall_session, None

    if tools and toolcall_session is None and all(is_tool(t) for t in tools):
        derived = FunctionToolSession(tools)
        self.is_tools = True
        self.toolcall_session = derived
        self.tools = derived.schemas
    elif toolcall_session and tools:
        self.is_tools = True
        self.toolcall_session = toolcall_session
        self.tools = tools
|
|
|
|
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
    """Run a non-streaming chat in which the model may call the bound tools.

    Each round sends the conversation, executes every requested tool call
    concurrently, appends the calls and results to the conversation and asks
    again, until the model answers without tool calls. After
    ``max_rounds + 1`` rounds a final plain completion is requested instead.

    ``system`` is inserted into ``history`` in place when ``history`` is
    non-empty and does not already start with a system message.

    Returns:
        ``(answer, token_count)``. The answer contains a ``<tool_call>``
        block per executed tool followed by the model's text. When the
        request is given up on, the formatted error message is returned in
        place of the answer.
    """
    gen_conf = dict(gen_conf or {})
    gen_conf = self._clean_conf(gen_conf)
    if system and history and history[0].get("role") != "system":
        history.insert(0, {"role": "system", "content": system})

    async def _exec_tool(tc):
        """Execute one tool call; return ``(tc, name, args, result, error)``."""
        name = tc.function.name
        try:
            args = json_repair.loads(tc.function.arguments)
            if not isinstance(args, dict):
                raise TypeError(f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}")
            if hasattr(self.toolcall_session, "tool_call_async"):
                result = await self.toolcall_session.tool_call_async(name, args)
            else:
                result = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
            return tc, name, args, result, None
        except Exception as e:
            logging.exception(f"Tool call failed: {tc}")
            return tc, name, {}, None, e

    tk_count = 0
    hist = deepcopy(history)
    for attempt in range(self.max_retries + 1):
        # Every attempt replays the conversation from the pristine copy, so
        # the answer must restart too; otherwise tool-call traces from a
        # failed attempt would be duplicated in the final text.
        ans = ""
        history = deepcopy(hist)
        try:
            for _ in range(self.max_rounds + 1):
                logging.info(f"{self.tools=}")

                completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
                response = await litellm.acompletion(
                    **completion_args,
                    drop_params=True,
                    timeout=self.timeout,
                )

                tk_count += total_token_count_from_response(response)

                if not hasattr(response, "choices") or not response.choices or not response.choices[0].message:
                    raise Exception(f"500 response structure error. Response: {response}")

                message = response.choices[0].message
                reasoning_content = None
                if self._need_reasoning_content_back():
                    reasoning_content = getattr(message, "reasoning_content", None) or getattr(message, "reasoning", None)

                if not hasattr(message, "tool_calls") or not message.tool_calls:
                    # No tool requested: this is the final answer.
                    if reasoning_content:
                        ans += f"<think>{reasoning_content}</think>"
                    ans += message.content or ""
                    if response.choices[0].finish_reason == "length":
                        ans = self._length_stop(ans)
                    return ans, tk_count

                logging.info(f"Response tool_calls={message.tool_calls}")
                results = await asyncio.gather(*[_exec_tool(tc) for tc in message.tool_calls])
                history = self._append_history_batch(
                    history,
                    results,
                    reasoning_content=reasoning_content,
                )
                for tc, name, args, result, err in results:
                    ans += self._verbose_tool_use(name, args, err if err else result)

            logging.warning(f"Exceed max rounds: {self.max_rounds}")
            history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})

            response, token_count = await self.async_chat("", history, gen_conf)
            ans += response
            tk_count += token_count
            return ans, tk_count

        except Exception as e:
            # A None result means "retry"; anything else is the final message.
            e = await self._exceptions_async(e, attempt)
            if e:
                return e, tk_count

    assert False, "Shouldn't be here."
|
|
|
|
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
    """Stream a chat completion, running model-requested tools between rounds.

    Each round streams one completion. Text deltas are yielded as they
    arrive, reasoning deltas are yielded wrapped in ``<think>...</think>``,
    and tool-call deltas are collected by index. When a round ends with
    text and no tool calls, the generator finishes. Otherwise the collected
    tools are executed concurrently, their results are appended to the
    history, and another round starts. After ``self.max_rounds + 1`` rounds
    a last completion is streamed with an "Exceed max rounds" user message.

    Args:
        system: System prompt. It is inserted at the front of ``history``
            (mutating the caller's list) when ``history`` is non-empty and
            does not already start with a system message.
        history: Chat messages as a list of role/content dicts.
        gen_conf: Optional generation settings; copied and passed through
            ``self._clean_conf`` before use.

    Yields:
        ``str`` chunks (content, reasoning, tool-use descriptions), then the
        total token count as the final item. If ``self._exceptions_async``
        returns a truthy value for a failure, that value is yielded, followed
        by the token count, and the generator stops; a falsy value starts the
        next attempt.
    """
    gen_conf = dict(gen_conf or {})
    gen_conf = self._clean_conf(gen_conf)
    tools = self.tools
    if system and history and history[0].get("role") != "system":
        history.insert(0, {"role": "system", "content": system})

    total_tokens = 0
    # Pristine copy so that every retry restarts from the same conversation.
    hist = deepcopy(history)

    async def _exec_tool(tc):
        """Run one tool call; return (tc, name, args, result, error) without raising."""
        name = tc.function.name
        try:
            args = json_repair.loads(tc.function.arguments)
            if not isinstance(args, dict):
                raise TypeError(f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}")
            if hasattr(self.toolcall_session, "tool_call_async"):
                result = await self.toolcall_session.tool_call_async(name, args)
            else:
                result = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
            return tc, name, args, result, None
        except Exception as e:
            logging.exception(f"Tool call failed: {tc}")
            return tc, name, {}, None, e

    for attempt in range(self.max_retries + 1):
        history = deepcopy(hist)
        try:
            for _round in range(self.max_rounds + 1):
                reasoning_start = False
                reasoning_content = ""
                # Tolerate an unset tool list so the log line itself cannot fail.
                logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in (tools or [])]}")

                completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
                response = await litellm.acompletion(
                    **completion_args,
                    drop_params=True,
                    timeout=self.timeout,
                )

                final_tool_calls = {}
                answer = ""
                # Provider-reported usage covers only the current request, so
                # it is added on top of what earlier rounds already consumed
                # instead of replacing the running total.
                round_base_tokens = total_tokens

                async for resp in response:
                    if not hasattr(resp, "choices") or not resp.choices:
                        continue

                    delta = resp.choices[0].delta

                    if hasattr(delta, "tool_calls") and delta.tool_calls:
                        # Tool calls arrive in fragments; stitch the argument
                        # strings together per tool-call index.
                        for tool_call in delta.tool_calls:
                            index = tool_call.index
                            if index not in final_tool_calls:
                                if not tool_call.function.arguments:
                                    tool_call.function.arguments = ""
                                final_tool_calls[index] = tool_call
                            else:
                                final_tool_calls[index].function.arguments += tool_call.function.arguments or ""
                        continue

                    if not hasattr(delta, "content") or delta.content is None:
                        delta.content = ""

                    _reasoning = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
                    if _reasoning:
                        if self._need_reasoning_content_back():
                            reasoning_content += _reasoning
                        ans = ""
                        if not reasoning_start:
                            reasoning_start = True
                            ans = "<think>"
                        ans += _reasoning + "</think>"
                        yield ans
                    else:
                        reasoning_start = False
                        answer += delta.content
                        yield delta.content

                    tol = total_token_count_from_response(resp)
                    if not tol:
                        # No usage on this chunk: estimate from the text.
                        total_tokens += num_tokens_from_string(delta.content)
                    else:
                        total_tokens = round_base_tokens + tol

                    finish_reason = getattr(resp.choices[0], "finish_reason", "")
                    if finish_reason == "length":
                        yield self._length_stop("")

                # NOTE(review): a round with neither text nor tool calls falls
                # through to the tool step with an empty batch - confirm that
                # this is intended.
                if answer and not final_tool_calls:
                    logging.info(f"[ToolLoop] round={_round} completed with text response, exiting")
                    yield total_tokens
                    return

                tcs = list(final_tool_calls.values())
                logging.info(f"[ToolLoop] round={_round} executing {len(tcs)} tool(s): {[tc.function.name for tc in tcs]}")
                # Announce every call before any of them runs.
                for tc in tcs:
                    try:
                        args = json_repair.loads(tc.function.arguments)
                    except Exception:
                        args = {}
                    yield self._verbose_tool_use(tc.function.name, args, "Begin to call...")
                results = await asyncio.gather(*[_exec_tool(tc) for tc in tcs])
                history = self._append_history_batch(
                    history,
                    results,
                    reasoning_content=reasoning_content if self._need_reasoning_content_back() else None,
                )
                for tc, name, args, result, err in results:
                    yield self._verbose_tool_use(name, args, err if err else result)

            logging.warning(f"Exceed max rounds: {self.max_rounds}")
            history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})

            completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
            response = await litellm.acompletion(
                **completion_args,
                drop_params=True,
                timeout=self.timeout,
            )

            # Same accounting rule as inside the rounds: keep earlier usage.
            final_base_tokens = total_tokens
            async for resp in response:
                if not hasattr(resp, "choices") or not resp.choices:
                    continue
                delta = resp.choices[0].delta
                if not hasattr(delta, "content") or delta.content is None:
                    continue
                tol = total_token_count_from_response(resp)
                if not tol:
                    total_tokens += num_tokens_from_string(delta.content)
                else:
                    total_tokens = final_base_tokens + tol
                yield delta.content

            yield total_tokens
            return

        except Exception as e:
            e = await self._exceptions_async(e, attempt)
            if e:
                yield e
                yield total_tokens
                return

    assert False, "Shouldn't be here."
|
|
|
|
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
    """Assemble the keyword arguments for a LiteLLM completion call.

    Starts from the model name, messages, API key and retry count, lets
    ``kwargs`` override or extend those, then layers on streaming, tool and
    provider-specific settings.

    Args:
        history: Messages sent as the ``messages`` argument.
        stream: When truthy, ``stream`` is set in the result.
        tools: When truthy and ``self.tools`` is non-empty, the tools are
            attached with ``tool_choice`` set to ``"auto"``.
        **kwargs: Extra completion arguments merged into the result.

    Returns:
        dict: The completed argument mapping.

    Raises:
        ValueError: For Bedrock, when the key JSON carries no ``auth_mode``.
    """
    completion_args = {
        "model": self.model_name,
        "messages": history,
        "api_key": self.api_key,
        "num_retries": self.max_retries,
        **kwargs,
    }
    if stream:
        completion_args["stream"] = stream
    if tools and self.tools:
        completion_args["tools"] = self.tools
        completion_args["tool_choice"] = "auto"

    provider = self.provider
    if provider in FACTORY_DEFAULT_BASE_URL:
        completion_args["api_base"] = self.base_url
    elif provider == SupportedLiteLLMProvider.Bedrock:
        import boto3

        # Bedrock authenticates with AWS credentials, not an API key/base URL.
        completion_args.pop("api_key", None)
        completion_args.pop("api_base", None)

        credentials = json.loads(self.api_key)
        auth_mode = credentials.get("auth_mode")
        if not auth_mode:
            logging.error("Bedrock auth_mode is not provided in the key")
            raise ValueError("Bedrock auth_mode must be provided in the key")

        region = credentials.get("bedrock_region")
        # Every auth mode needs the region; the remaining keys depend on the mode.
        completion_args["aws_region_name"] = region
        if auth_mode == "access_key_secret":
            completion_args["aws_access_key_id"] = credentials.get("bedrock_ak")
            completion_args["aws_secret_access_key"] = credentials.get("bedrock_sk")
        elif auth_mode == "iam_role":
            role_arn = credentials.get("aws_role_arn")
            sts = boto3.client("sts", region_name=region)
            assumed = sts.assume_role(RoleArn=role_arn, RoleSessionName="BedrockSession")
            temporary = assumed["Credentials"]
            completion_args["aws_access_key_id"] = temporary["AccessKeyId"]
            completion_args["aws_secret_access_key"] = temporary["SecretAccessKey"]
            completion_args["aws_session_token"] = temporary["SessionToken"]
        # Any other mode (assume_role): rely on the default credential chain
        # (IRSA, instance profile, etc.), so the region alone is enough.
    elif provider == SupportedLiteLLMProvider.OpenRouter:
        raw_order = self.provider_order
        if raw_order:
            # Normalise a comma-separated string or a list/tuple into a list
            # of non-empty, stripped names; anything else yields no entries.
            if isinstance(raw_order, str):
                order = [piece.strip() for piece in raw_order.split(",") if piece.strip()]
            elif isinstance(raw_order, (list, tuple)):
                order = [str(item).strip() for item in raw_order if str(item).strip()]
            else:
                order = []
            completion_args["extra_body"] = {
                "provider": {
                    "order": order,
                    "allow_fallbacks": False,
                },
            }
    elif provider == SupportedLiteLLMProvider.GPUStack:
        completion_args["api_base"] = urljoin(self.base_url, "v1")
    elif provider == SupportedLiteLLMProvider.Azure_OpenAI:
        completion_args.pop("api_key", None)
        completion_args.pop("api_base", None)
        completion_args["api_key"] = self.api_key
        completion_args["api_base"] = self.base_url
        completion_args["api_version"] = self.api_version

    # Ollama is often deployed behind a reverse proxy that enforces Bearer
    # auth: add the Authorization header when an API key is configured,
    # without overriding a header the caller already supplied. #11350
    extra_headers = deepcopy(completion_args.get("extra_headers") or {})
    if provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
        extra_headers["Authorization"] = f"Bearer {self.api_key}"

    # MiniMax requires GroupId as a query parameter for API authentication.
    if provider == SupportedLiteLLMProvider.MiniMax and getattr(self, "group_id", None):
        api_base = completion_args.get("api_base", self.base_url)
        joiner = "&" if "?" in api_base else "?"
        completion_args["api_base"] = f"{api_base}{joiner}GroupId={self.group_id}"

    if extra_headers:
        completion_args["extra_headers"] = extra_headers
    return completion_args
|
|
|
|
|
|
class RAGconChat(Base):
    """
    Chat provider for RAGcon, routed through a LiteLLM proxy.

    All model types are served by one unified LiteLLM endpoint.
    Default Base URL: https://connect.ragcon.com/v1
    """

    _FACTORY_NAME = "RAGcon"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Initialise the provider, using the RAGcon endpoint when no base URL is given."""
        # An empty or missing base_url selects the hosted RAGcon endpoint.
        super().__init__(key, model_name, base_url or "https://connect.ragcon.com/v1", **kwargs)
|