mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 09:11:59 +08:00
feat(agent): report accurate aggregated token usage and propagate session/user + input/output to Langfuse for agent runs (#16420)
### What problem does this PR solve?
_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Other (please describe):
## Summary
Agent (Canvas) runs previously did not surface token usage in the SSE
stream, and RAGFlow's own Langfuse generations for agent runs were
missing the prompt/completion split and the session/user correlation.
This made it impossible for an external caller (or Langfuse) to
reconcile an agent turn's cost with the upstream provider (e.g.
OpenRouter), because a single turn can issue several distinct LLM calls
(query rewriting / cross-language translation, multi-round tool
reasoning, nested sub-agents, and the final answer).
This PR introduces a per-run token usage sink so that **every** LLM call
in a run is aggregated and reported once, and enriches Langfuse
generations with the prompt/completion split plus session/user
attributes.
## What changes
### 1. Per-run token usage sink (`common/token_utils.py`)
- Adds two `contextvars`: `token_usage_sink` (a mutable per-run
accumulator) and `langfuse_run_attrs` (session_id/user_id for the run).
- Adds `record_run_token_usage(...)` (thread-safe via a lock, because
`thread_pool_exec` copies the context into worker threads that share the
sink dict) and `usage_from_response(...)` which extracts a
`{prompt_tokens, completion_tokens, total_tokens}` split from
OpenAI/OpenRouter-style responses.
### 2. Provider layer captures the prompt/completion split
(`rag/llm/chat_model.py`)
- `LiteLLMBase` and `Base` now store `self.last_usage`
(prompt/completion/total) for the most recent chat call, in both the
plain and tool-calling paths.
- Streaming requests set `stream_options.include_usage = True` (LiteLLM
path) so the authoritative usage arrives on the final chunk; this is
read even on the usage-only chunk that carries no `choices`.
- Fixes a multi-round accounting bug in `*_with_tools`: token totals
were **overwritten** by each round (`total_tokens = tol`) instead of
accumulated, undercounting multi-round tool conversations. Each round is
now committed to a running aggregate.
### 3. LLMBundle reports usage once, per call
(`api/db/services/llm_service.py`)
- New `_report_usage(total_tokens)` records the call's usage into the
active run sink and returns the prompt/completion/total split for
Langfuse. The split is only used when it is consistent with the
authoritative total; otherwise only the total is reported.
- All three chat entry points (`async_chat`, `async_chat_streamly`,
`async_chat_streamly_delta`) now emit `usage_details` with
`input`/`output`/`total` instead of total-only.
- `_start_langfuse_observation` now applies `session_id`/`user_id` from
the per-run context (`langfuse_run_attrs`) so agent-run generations are
correctly grouped, even though agent LLMBundles are constructed without
those attributes.
### 4. Canvas installs the sink and emits the aggregate
(`agent/canvas.py`)
- `Canvas.run()` installs a fresh `token_usage_sink` and
`langfuse_run_attrs` (from `user_id`/`session_id`) at the start of every
turn.
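- The terminal `workflow_finished` event now includes an aggregated `usage`
object: `{prompt_tokens, completion_tokens, total_tokens, calls}` covering
all LLM calls in the run. It is deliberately not attached to `message_end`,
which is built once per Message component and would therefore repeat the
cumulative total in multi-Message graphs.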
### 5. Pass session id into the run
(`api/db/services/canvas_service.py`)
- `completion()` forwards `session_id` to `Canvas.run()` for Langfuse
session correlation.
## Why a context variable
LLM calls in an agent run originate from many places that each build
their own `LLMBundle` (e.g. `cross_languages`/`keyword_extraction`
helpers, the Agent component, and nested sub-agents invoked as tools). A
run-scoped context variable is the only non-invasive chokepoint that
captures all of them exactly once, including nested agents (which run in
the same async context) and thread-pool tools (the executor copies the
context).
## Behavior / compatibility
- No public API or wire-format removal: `workflow_finished` gains an
additional optional `usage` field; existing consumers are unaffected.
- When a provider does not return authoritative usage, behavior falls
back to the previous token estimate (total only, no split).
- Non-agent flows (Dataflow `Pipeline`, sync `Graph.run`) are untouched.
## Testing
- [x] Simple agent answer: `workflow_finished.usage.total_tokens` matches
provider usage.
- [x] Agent with cross-language retrieval: aggregate equals the sum of
both provider calls.
- [x] Tool-calling agent (multi-round): total accumulates across rounds.
- [x] Nested agent (agent-as-tool): sub-agent tokens included in the
parent run total.
- [x] Langfuse: agent generations show input/output split and are
grouped by session/user.
---------
Co-authored-by: yzc <yuzhichang@gmail.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
315
agent/canvas.py
315
agent/canvas.py
@@ -15,6 +15,7 @@
|
||||
#
|
||||
import asyncio
|
||||
import base64
|
||||
import contextvars
|
||||
import datetime
|
||||
import inspect
|
||||
import json
|
||||
@@ -36,52 +37,54 @@ from api.db.joint_services.tenant_model_service import get_tenant_default_model_
|
||||
from common.constants import LLMType
|
||||
from common.misc_utils import get_uuid, hash_str2int
|
||||
from common.exceptions import TaskCanceledException
|
||||
from common.token_utils import token_usage_sink, langfuse_run_attrs
|
||||
from rag.prompts.generator import chunks_format
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.tts_cache import synthesize_with_cache
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Graph:
|
||||
"""
|
||||
dsl = {
|
||||
"components": {
|
||||
"begin": {
|
||||
"obj":{
|
||||
"component_name": "Begin",
|
||||
"params": {},
|
||||
},
|
||||
"downstream": ["answer_0"],
|
||||
"upstream": [],
|
||||
dsl = {
|
||||
"components": {
|
||||
"begin": {
|
||||
"obj":{
|
||||
"component_name": "Begin",
|
||||
"params": {},
|
||||
},
|
||||
"retrieval_0": {
|
||||
"obj": {
|
||||
"component_name": "Retrieval",
|
||||
"params": {}
|
||||
},
|
||||
"downstream": ["generate_0"],
|
||||
"upstream": ["answer_0"],
|
||||
},
|
||||
"generate_0": {
|
||||
"obj": {
|
||||
"component_name": "Generate",
|
||||
"params": {}
|
||||
},
|
||||
"downstream": ["answer_0"],
|
||||
"upstream": ["retrieval_0"],
|
||||
}
|
||||
"downstream": ["answer_0"],
|
||||
"upstream": [],
|
||||
},
|
||||
"history": [],
|
||||
"path": ["begin"],
|
||||
"retrieval": {"chunks": [], "doc_aggs": []},
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": tenant_id,
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
"retrieval_0": {
|
||||
"obj": {
|
||||
"component_name": "Retrieval",
|
||||
"params": {}
|
||||
},
|
||||
"downstream": ["generate_0"],
|
||||
"upstream": ["answer_0"],
|
||||
},
|
||||
"generate_0": {
|
||||
"obj": {
|
||||
"component_name": "Generate",
|
||||
"params": {}
|
||||
},
|
||||
"downstream": ["answer_0"],
|
||||
"upstream": ["retrieval_0"],
|
||||
}
|
||||
},
|
||||
"history": [],
|
||||
"path": ["begin"],
|
||||
"retrieval": {"chunks": [], "doc_aggs": []},
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": tenant_id,
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
"""
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, dsl: str, tenant_id=None, task_id=None, custom_header=None):
|
||||
self.path = []
|
||||
@@ -115,9 +118,7 @@ class Graph:
|
||||
def __str__(self):
|
||||
self.dsl["path"] = self.path
|
||||
self.dsl["task_id"] = self.task_id
|
||||
dsl = {
|
||||
"components": {}
|
||||
}
|
||||
dsl = {"components": {}}
|
||||
for k in self.dsl.keys():
|
||||
if k in ["components"]:
|
||||
continue
|
||||
@@ -139,11 +140,13 @@ class Graph:
|
||||
except Exception as e:
|
||||
logging.warning("Graph.__str__: deepcopy failed for component '%s' key '%s' (type=%s): %s. Using shallow reference.", k, c, type(cpn[c]).__name__, e)
|
||||
dsl["components"][k][c] = cpn[c]
|
||||
|
||||
def _serialize_default(obj):
|
||||
if callable(obj):
|
||||
return None
|
||||
logging.warning("Graph.__str__: JSON fallback via str() for type=%s", type(obj).__name__)
|
||||
return str(obj)
|
||||
|
||||
return json.dumps(dsl, ensure_ascii=False, default=_serialize_default)
|
||||
|
||||
def reset(self):
|
||||
@@ -180,13 +183,13 @@ class Graph:
|
||||
def get_tenant_id(self):
|
||||
return self._tenant_id
|
||||
|
||||
def get_value_with_variable(self,value: str) -> Any:
|
||||
def get_value_with_variable(self, value: str) -> Any:
|
||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
||||
out_parts = []
|
||||
last = 0
|
||||
|
||||
for m in pat.finditer(value):
|
||||
out_parts.append(value[last:m.start()])
|
||||
out_parts.append(value[last : m.start()])
|
||||
key = m.group(1)
|
||||
v = self.get_variable_value(key)
|
||||
if v is None:
|
||||
@@ -205,7 +208,7 @@ class Graph:
|
||||
last = m.end()
|
||||
|
||||
out_parts.append(value[last:])
|
||||
return("".join(out_parts))
|
||||
return "".join(out_parts)
|
||||
|
||||
def get_variable_value(self, exp: str) -> Any:
|
||||
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
|
||||
@@ -222,13 +225,13 @@ class Graph:
|
||||
|
||||
if not rest:
|
||||
return root_val
|
||||
return self.get_variable_param_value(root_val,rest)
|
||||
return self.get_variable_param_value(root_val, rest)
|
||||
|
||||
def get_variable_param_value(self, obj: Any, path: str) -> Any:
|
||||
cur = obj
|
||||
if not path:
|
||||
return cur
|
||||
for key in path.split('.'):
|
||||
for key in path.split("."):
|
||||
if cur is None:
|
||||
return None
|
||||
|
||||
@@ -253,7 +256,7 @@ class Graph:
|
||||
cur = getattr(cur, key, None)
|
||||
return cur
|
||||
|
||||
def set_variable_value(self, exp: str,value):
|
||||
def set_variable_value(self, exp: str, value):
|
||||
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
self.globals[exp] = value
|
||||
@@ -271,11 +274,11 @@ class Graph:
|
||||
root_val = cpn["obj"].output(root_key)
|
||||
if not root_val:
|
||||
root_val = {}
|
||||
cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val,rest,value))
|
||||
cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val, rest, value))
|
||||
|
||||
def set_variable_param_value(self, obj: Any, path: str, value) -> Any:
|
||||
cur = obj
|
||||
keys = path.split('.')
|
||||
keys = path.split(".")
|
||||
if not path:
|
||||
return value
|
||||
for key in keys[:-1]:
|
||||
@@ -298,7 +301,6 @@ class Graph:
|
||||
|
||||
|
||||
class Canvas(Graph):
|
||||
|
||||
def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None, custom_header=None):
|
||||
self.globals = {
|
||||
"sys.query": "",
|
||||
@@ -306,9 +308,14 @@ class Canvas(Graph):
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": [],
|
||||
"sys.history": [],
|
||||
"sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
|
||||
"sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
self.variables = {}
|
||||
# Aggregated provider token usage (prompt/completion/total) across every LLM
|
||||
# call in a single run — query rewriting, cross-language translation, tool
|
||||
# reasoning and the final answer. Populated via the token_usage_sink context
|
||||
# variable that each LLMBundle chat call writes to. Reset at run() start.
|
||||
self._run_token_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "calls": 0}
|
||||
super().__init__(dsl, tenant_id, task_id, custom_header=custom_header)
|
||||
self._id = canvas_id
|
||||
|
||||
@@ -323,13 +330,13 @@ class Canvas(Graph):
|
||||
self.globals["sys.date"] = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
self.globals = {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": [],
|
||||
"sys.history": [],
|
||||
"sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": [],
|
||||
"sys.history": [],
|
||||
"sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
if "variables" in self.dsl:
|
||||
self.variables = self.dsl["variables"]
|
||||
else:
|
||||
@@ -393,6 +400,38 @@ class Canvas(Graph):
|
||||
self.globals[k] = ""
|
||||
|
||||
async def run(self, **kwargs):
|
||||
# Install a fresh per-run token usage sink and Langfuse correlation context,
|
||||
# and guarantee both are torn down when the run ends (even on early return or
|
||||
# exception) so later LLM calls in the same task never inherit a previous
|
||||
# run's sink or session/user attributes.
|
||||
self._run_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "calls": 0}
|
||||
_lf_attrs = {}
|
||||
_user_id = kwargs.get("user_id")
|
||||
if _user_id:
|
||||
_lf_attrs["user_id"] = str(_user_id)[:200]
|
||||
_session_id = kwargs.get("session_id") or self._id
|
||||
if _session_id:
|
||||
_lf_attrs["session_id"] = str(_session_id)[:200]
|
||||
sink_token = token_usage_sink.set(self._run_token_usage)
|
||||
attrs_token = langfuse_run_attrs.set(_lf_attrs)
|
||||
try:
|
||||
async for ev in self._run_impl(**kwargs):
|
||||
yield ev
|
||||
finally:
|
||||
# reset() can raise if the generator is closed from a different context
|
||||
# (e.g. client disconnect); fall back to clearing the values in that case.
|
||||
try:
|
||||
token_usage_sink.reset(sink_token)
|
||||
except ValueError:
|
||||
logging.debug("Failed to reset token usage ContextVar", exc_info=True)
|
||||
token_usage_sink.set(None)
|
||||
try:
|
||||
langfuse_run_attrs.reset(attrs_token)
|
||||
except ValueError:
|
||||
logging.debug("Failed to reset Langfuse run attributes ContextVar", exc_info=True)
|
||||
langfuse_run_attrs.set(None)
|
||||
|
||||
async def _run_impl(self, **kwargs):
|
||||
self.globals["sys.date"] = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
|
||||
st = time.perf_counter()
|
||||
self._loop = asyncio.get_running_loop()
|
||||
@@ -406,7 +445,7 @@ class Canvas(Graph):
|
||||
|
||||
if kwargs.get("webhook_payload"):
|
||||
for k, cpn in self.components.items():
|
||||
if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
|
||||
if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
|
||||
payload = kwargs.get("webhook_payload", {})
|
||||
if "input" in payload:
|
||||
self.components[k]["obj"].set_input_value("request", payload["input"])
|
||||
@@ -427,7 +466,7 @@ class Canvas(Graph):
|
||||
self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k], layout_recognize)
|
||||
else:
|
||||
self.globals[f"sys.{k}"] = kwargs[k]
|
||||
if not self.globals["sys.conversation_turns"] :
|
||||
if not self.globals["sys.conversation_turns"]:
|
||||
self.globals["sys.conversation_turns"] = 0
|
||||
self.globals["sys.conversation_turns"] += 1
|
||||
is_resume = bool(self.path) and self.path[0].lower().find("userfillup") >= 0
|
||||
@@ -436,11 +475,11 @@ class Canvas(Graph):
|
||||
nonlocal created_at
|
||||
return {
|
||||
"event": event,
|
||||
#"conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115",
|
||||
# "conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115",
|
||||
"message_id": self.message_id,
|
||||
"created_at": created_at,
|
||||
"task_id": self.task_id,
|
||||
"data": dt
|
||||
"data": dt,
|
||||
}
|
||||
|
||||
if not is_resume:
|
||||
@@ -476,7 +515,11 @@ class Canvas(Graph):
|
||||
if use_async:
|
||||
await cpn_obj.invoke_async(**(call_kwargs or {}))
|
||||
return
|
||||
await loop.run_in_executor(self._thread_pool, partial(sync_fn, **(call_kwargs or {})))
|
||||
# run_in_executor does not propagate context variables; copy the
|
||||
# current context so the token usage sink / Langfuse attributes set
|
||||
# by run() remain visible to LLMBundle calls inside sync components.
|
||||
ctx = contextvars.copy_context()
|
||||
await loop.run_in_executor(self._thread_pool, lambda: ctx.run(partial(sync_fn, **(call_kwargs or {}))))
|
||||
|
||||
i = f
|
||||
while i < t:
|
||||
@@ -525,16 +568,19 @@ class Canvas(Graph):
|
||||
json.dumps(outputs, ensure_ascii=False, default=str)[:500],
|
||||
cpn_obj.error(),
|
||||
)
|
||||
return decorate("node_finished",{
|
||||
"inputs": cpn_obj.get_input_values(),
|
||||
"outputs": outputs,
|
||||
"component_id": cpn_obj._id,
|
||||
"component_name": self.get_component_name(cpn_obj._id),
|
||||
"component_type": self.get_component_type(cpn_obj._id),
|
||||
"error": cpn_obj.error(),
|
||||
"elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
|
||||
"created_at": cpn_obj.output("_created_time"),
|
||||
})
|
||||
return decorate(
|
||||
"node_finished",
|
||||
{
|
||||
"inputs": cpn_obj.get_input_values(),
|
||||
"outputs": outputs,
|
||||
"component_id": cpn_obj._id,
|
||||
"component_name": self.get_component_name(cpn_obj._id),
|
||||
"component_type": self.get_component_type(cpn_obj._id),
|
||||
"error": cpn_obj.error(),
|
||||
"elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
|
||||
"created_at": cpn_obj.output("_created_time"),
|
||||
},
|
||||
)
|
||||
|
||||
self.error = ""
|
||||
idx = 0 if is_resume else len(self.path) - 1
|
||||
@@ -543,13 +589,17 @@ class Canvas(Graph):
|
||||
while idx < len(self.path):
|
||||
to = len(self.path)
|
||||
for i in range(idx, to):
|
||||
yield decorate("node_started", {
|
||||
"inputs": None, "created_at": int(time.time()),
|
||||
"component_id": self.path[i],
|
||||
"component_name": self.get_component_name(self.path[i]),
|
||||
"component_type": self.get_component_type(self.path[i]),
|
||||
"thoughts": self.get_component_thoughts(self.path[i])
|
||||
})
|
||||
yield decorate(
|
||||
"node_started",
|
||||
{
|
||||
"inputs": None,
|
||||
"created_at": int(time.time()),
|
||||
"component_id": self.path[i],
|
||||
"component_name": self.get_component_name(self.path[i]),
|
||||
"component_type": self.get_component_type(self.path[i]),
|
||||
"thoughts": self.get_component_thoughts(self.path[i]),
|
||||
},
|
||||
)
|
||||
await _run_batch(idx, to)
|
||||
to = len(self.path)
|
||||
# post-processing of components invocation
|
||||
@@ -564,6 +614,7 @@ class Canvas(Graph):
|
||||
_m = ""
|
||||
buff_m = ""
|
||||
stream = cpn_obj.output("content")()
|
||||
|
||||
async def _process_stream(m):
|
||||
nonlocal buff_m, _m, tts_mdl
|
||||
if not m:
|
||||
@@ -578,13 +629,7 @@ class Canvas(Graph):
|
||||
_m += m
|
||||
|
||||
if len(buff_m) > 16:
|
||||
ev = decorate(
|
||||
"message",
|
||||
{
|
||||
"content": m,
|
||||
"audio_binary": self.tts(tts_mdl, buff_m)
|
||||
}
|
||||
)
|
||||
ev = decorate("message", {"content": m, "audio_binary": self.tts(tts_mdl, buff_m)})
|
||||
buff_m = ""
|
||||
return ev
|
||||
|
||||
@@ -592,12 +637,12 @@ class Canvas(Graph):
|
||||
|
||||
if inspect.isasyncgen(stream):
|
||||
async for m in stream:
|
||||
ev= await _process_stream(m)
|
||||
ev = await _process_stream(m)
|
||||
if ev:
|
||||
yield ev
|
||||
else:
|
||||
for m in stream:
|
||||
ev= await _process_stream(m)
|
||||
ev = await _process_stream(m)
|
||||
if ev:
|
||||
yield ev
|
||||
if buff_m:
|
||||
@@ -629,7 +674,7 @@ class Canvas(Graph):
|
||||
else:
|
||||
self.error = cpn_obj.error()
|
||||
|
||||
if cpn_obj.component_name.lower() not in ("iteration","loop"):
|
||||
if cpn_obj.component_name.lower() not in ("iteration", "loop"):
|
||||
if isinstance(cpn_obj.output("content"), partial):
|
||||
if self.error:
|
||||
cpn_obj.set_output("content", None)
|
||||
@@ -654,7 +699,7 @@ class Canvas(Graph):
|
||||
for cpn_id in cpn_ids:
|
||||
_append_path(cpn_id)
|
||||
|
||||
if cpn_obj.component_name.lower() in ("iterationitem","loopitem") and cpn_obj.end():
|
||||
if cpn_obj.component_name.lower() in ("iterationitem", "loopitem") and cpn_obj.end():
|
||||
iter = cpn_obj.get_parent()
|
||||
yield _node_finished(iter)
|
||||
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
|
||||
@@ -683,10 +728,7 @@ class Canvas(Graph):
|
||||
o = self.get_component_obj(c)
|
||||
if o.component_name.lower() == "userfillup":
|
||||
o.invoke()
|
||||
another_inputs.update({
|
||||
k: v for k, v in o.get_input_elements().items()
|
||||
if not self._is_input_field_satisfied(v)
|
||||
})
|
||||
another_inputs.update({k: v for k, v in o.get_input_elements().items() if not self._is_input_field_satisfied(v)})
|
||||
if o.get_param("enable_tips"):
|
||||
tips = o.output("tips")
|
||||
if not another_inputs:
|
||||
@@ -696,23 +738,30 @@ class Canvas(Graph):
|
||||
return
|
||||
self.path = self.path[:idx]
|
||||
if not self.error:
|
||||
yield decorate("workflow_finished",
|
||||
{
|
||||
"inputs": kwargs.get("inputs"),
|
||||
"outputs": self.get_component_obj(self.path[-1]).output(),
|
||||
"elapsed_time": time.perf_counter() - st,
|
||||
"created_at": st,
|
||||
})
|
||||
yield decorate(
|
||||
"workflow_finished",
|
||||
{
|
||||
"inputs": kwargs.get("inputs"),
|
||||
"outputs": self.get_component_obj(self.path[-1]).output(),
|
||||
"elapsed_time": time.perf_counter() - st,
|
||||
"created_at": st,
|
||||
# Run-level total of all LLM calls — emitted once here.
|
||||
"usage": self._run_usage_payload(),
|
||||
},
|
||||
)
|
||||
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
|
||||
self.globals["sys.history"].append(f"{self.history[-1][0]}: {self.history[-1][1]}")
|
||||
elif "Task has been canceled" in self.error:
|
||||
yield decorate("workflow_finished",
|
||||
{
|
||||
"inputs": kwargs.get("inputs"),
|
||||
"outputs": "Task has been canceled",
|
||||
"elapsed_time": time.perf_counter() - st,
|
||||
"created_at": st,
|
||||
})
|
||||
yield decorate(
|
||||
"workflow_finished",
|
||||
{
|
||||
"inputs": kwargs.get("inputs"),
|
||||
"outputs": "Task has been canceled",
|
||||
"elapsed_time": time.perf_counter() - st,
|
||||
"created_at": st,
|
||||
"usage": self._run_usage_payload(),
|
||||
},
|
||||
)
|
||||
|
||||
def is_reff(self, exp: str) -> bool:
|
||||
exp = exp.strip("{").strip("}")
|
||||
@@ -725,8 +774,7 @@ class Canvas(Graph):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def tts(self,tts_mdl, text):
|
||||
def tts(self, tts_mdl, text):
|
||||
def clean_tts_text(text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
@@ -736,15 +784,8 @@ class Canvas(Graph):
|
||||
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||
|
||||
emoji_pattern = re.compile(
|
||||
"[\U0001F600-\U0001F64F"
|
||||
"\U0001F300-\U0001F5FF"
|
||||
"\U0001F680-\U0001F6FF"
|
||||
"\U0001F1E0-\U0001F1FF"
|
||||
"\U00002700-\U000027BF"
|
||||
"\U0001F900-\U0001F9FF"
|
||||
"\U0001FA70-\U0001FAFF"
|
||||
"\U0001FAD0-\U0001FAFF]+",
|
||||
flags=re.UNICODE
|
||||
"[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff\U0001f1e0-\U0001f1ff\U00002700-\U000027bf\U0001f900-\U0001f9ff\U0001fa70-\U0001faff\U0001fad0-\U0001faff]+",
|
||||
flags=re.UNICODE,
|
||||
)
|
||||
text = emoji_pattern.sub("", text)
|
||||
|
||||
@@ -755,6 +796,7 @@ class Canvas(Graph):
|
||||
text = text[:MAX_LEN]
|
||||
|
||||
return text
|
||||
|
||||
if not tts_mdl or not text:
|
||||
return None
|
||||
text = clean_tts_text(text)
|
||||
@@ -766,7 +808,7 @@ class Canvas(Graph):
|
||||
convs = []
|
||||
if window_size <= 0:
|
||||
return convs
|
||||
for role, obj in self.history[window_size * -2:]:
|
||||
for role, obj in self.history[window_size * -2 :]:
|
||||
if isinstance(obj, dict):
|
||||
convs.append({"role": role, "content": obj.get("content", "")})
|
||||
else:
|
||||
@@ -815,17 +857,19 @@ class Canvas(Graph):
|
||||
|
||||
async def get_files_async(self, files: Union[None, list[dict]], layout_recognize: str = None) -> list[str]:
|
||||
if not files:
|
||||
return []
|
||||
return []
|
||||
|
||||
def image_to_base64(file):
|
||||
return "data:{};base64,{}".format(file["mime_type"],
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
return "data:{};base64,{}".format(file["mime_type"], base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
|
||||
def parse_file(file):
|
||||
blob = FileService.get_blob(file["created_by"], file["id"])
|
||||
return FileService.parse(file["name"], blob, True, file["created_by"], layout_recognize)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
if file["mime_type"].find("image") >= 0:
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
|
||||
continue
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, parse_file, file))
|
||||
@@ -844,7 +888,7 @@ class Canvas(Graph):
|
||||
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
|
||||
agent_ids = agent_id.split("-->")
|
||||
agent_name = self.get_component_name(agent_ids[0])
|
||||
path = agent_name if len(agent_ids) < 2 else agent_name+"-->"+"-->".join(agent_ids[1:])
|
||||
path = agent_name if len(agent_ids) < 2 else agent_name + "-->" + "-->".join(agent_ids[1:])
|
||||
try:
|
||||
bin = REDIS_CONN.get(f"{self.task_id}-{self.message_id}-logs")
|
||||
if bin:
|
||||
@@ -852,16 +896,10 @@ class Canvas(Graph):
|
||||
if obj[-1]["component_id"] == agent_ids[0]:
|
||||
obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time})
|
||||
else:
|
||||
obj.append({
|
||||
"component_id": agent_ids[0],
|
||||
"trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]
|
||||
})
|
||||
obj.append({"component_id": agent_ids[0], "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]})
|
||||
else:
|
||||
obj = [{
|
||||
"component_id": agent_ids[0],
|
||||
"trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]
|
||||
}]
|
||||
REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60*10)
|
||||
obj = [{"component_id": agent_ids[0], "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]}]
|
||||
REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60 * 10)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@@ -899,9 +937,22 @@ class Canvas(Graph):
|
||||
message_end["attachment"] = cpn_obj.output("attachment")
|
||||
if self._has_reference():
|
||||
message_end["reference"] = self.get_reference()
|
||||
# NOTE: aggregated run token usage is intentionally NOT attached here.
|
||||
# _build_message_end runs once per Message component, so a multi-Message graph
|
||||
# would emit cumulative usage repeatedly and double count. The run total is
|
||||
# emitted exactly once on the terminal workflow_finished event instead.
|
||||
return message_end
|
||||
|
||||
def add_memory(self, user:str, assist:str, summ: str):
|
||||
def _run_usage_payload(self) -> dict:
|
||||
usage = getattr(self, "_run_token_usage", None) or {}
|
||||
return {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
"calls": usage.get("calls", 0),
|
||||
}
|
||||
|
||||
def add_memory(self, user: str, assist: str, summ: str):
|
||||
self.memory.append((user, assist, summ))
|
||||
|
||||
def get_memory(self) -> list[Tuple]:
|
||||
|
||||
@@ -88,29 +88,32 @@ def _canvas_json_default(obj):
|
||||
def _require_canvas_access_sync(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not UserCanvasService.accessible(kwargs.get('agent_id'), kwargs.get('tenant_id')):
|
||||
if not UserCanvasService.accessible(kwargs.get("agent_id"), kwargs.get("tenant_id")):
|
||||
return get_json_result(data=False, message="Make sure you have permission to access the agent.", code=RetCode.OPERATING_ERROR)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _require_canvas_access_async(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
agent_id = kwargs.get('agent_id')
|
||||
tenant_id = kwargs.get('tenant_id')
|
||||
agent_id = kwargs.get("agent_id")
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
if not await thread_pool_exec(UserCanvasService.accessible, agent_id, tenant_id):
|
||||
return get_json_result(data=False, message="Make sure you have permission to access the agent.", code=RetCode.OPERATING_ERROR)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _require_canvas_owner_sync(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not UserCanvasService.query(user_id=kwargs.get('tenant_id'), id=kwargs.get('agent_id')):
|
||||
if not UserCanvasService.query(user_id=kwargs.get("tenant_id"), id=kwargs.get("agent_id")):
|
||||
return get_json_result(data=False, message="Only the owner of the agent is authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -261,9 +264,7 @@ async def _run_workflow_session(
|
||||
if "chunks" in workflow_conv["reference"]:
|
||||
workflow_conv["reference"] = [workflow_conv["reference"]]
|
||||
else:
|
||||
workflow_conv["reference"] = [
|
||||
value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))
|
||||
]
|
||||
workflow_conv["reference"] = [value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))]
|
||||
elif not isinstance(workflow_conv.get("reference"), list):
|
||||
workflow_conv["reference"] = []
|
||||
workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]]
|
||||
@@ -344,22 +345,16 @@ async def _run_workflow_session(
|
||||
# bare [DONE] (fixes #15169).
|
||||
logging.info(
|
||||
"empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)",
|
||||
agent_id, session_id, True,
|
||||
)
|
||||
yield (
|
||||
"data:"
|
||||
+ json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False)
|
||||
+ "\n\n"
|
||||
agent_id,
|
||||
session_id,
|
||||
True,
|
||||
)
|
||||
yield ("data:" + json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False) + "\n\n")
|
||||
await persist_workflow_session()
|
||||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
canvas.cancel_task()
|
||||
yield (
|
||||
"data:"
|
||||
+ json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False)
|
||||
+ "\n\n"
|
||||
)
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False) + "\n\n")
|
||||
finally:
|
||||
if not done_sent:
|
||||
done_sent = True
|
||||
@@ -400,7 +395,9 @@ async def _run_workflow_session(
|
||||
# (fixes #15169).
|
||||
logging.info(
|
||||
"empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)",
|
||||
agent_id, session_id, False,
|
||||
agent_id,
|
||||
session_id,
|
||||
False,
|
||||
)
|
||||
await commit_runtime_replica()
|
||||
return get_result(data={"session_id": session_id})
|
||||
@@ -559,16 +556,13 @@ async def delete_agent_session(tenant_id, agent_id):
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
return get_result(data={"success_count": success_count, "errors": errors},
|
||||
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||
else:
|
||||
return get_error_data_result(message="; ".join(errors))
|
||||
|
||||
if duplicate_messages:
|
||||
if success_count > 0:
|
||||
return get_result(
|
||||
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
|
||||
data={"success_count": success_count, "errors": duplicate_messages})
|
||||
return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
|
||||
else:
|
||||
return get_error_data_result(message=";".join(duplicate_messages))
|
||||
|
||||
@@ -611,8 +605,8 @@ async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace
|
||||
yield ans
|
||||
continue
|
||||
|
||||
if event in ["message", "message_end", "user_inputs", "workflow_finished"]:
|
||||
if event in ["user_inputs", "workflow_finished"]:
|
||||
if event in ["message", "message_end", "user_inputs"]:
|
||||
if event == "user_inputs":
|
||||
logging.debug(
|
||||
"Forwarding session completion event: tenant_id=%s agent_id=%s event=%s",
|
||||
tenant_id,
|
||||
@@ -620,6 +614,22 @@ async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace
|
||||
event,
|
||||
)
|
||||
yield ans
|
||||
continue
|
||||
|
||||
if event == "workflow_finished":
|
||||
# Forward only the run-level aggregated token usage, not the whole terminal
|
||||
# payload (inputs/outputs), so the session completion stream surface stays
|
||||
# limited to what the usage contract needs.
|
||||
logging.debug(
|
||||
"Forwarding session completion event: tenant_id=%s agent_id=%s event=%s",
|
||||
tenant_id,
|
||||
agent_id,
|
||||
event,
|
||||
)
|
||||
usage = ans.get("data", {}).get("usage")
|
||||
if usage is not None:
|
||||
yield {**ans, "data": {"usage": usage}}
|
||||
continue
|
||||
|
||||
|
||||
@manager.route("/agents/templates", methods=["GET"]) # noqa: F821
|
||||
@@ -760,7 +770,7 @@ async def update_agent_tags(tenant_id, canvas_id):
|
||||
@add_tenant_id_to_kwargs
|
||||
async def create_agent(tenant_id):
|
||||
req = {k: v for k, v in (await get_request_json()).items() if v is not None}
|
||||
req["canvas_type"] = req.get("canvas_type","")
|
||||
req["canvas_type"] = req.get("canvas_type", "")
|
||||
req["user_id"] = tenant_id
|
||||
req["canvas_category"] = req.get("canvas_category") or CanvasCategory.Agent
|
||||
req["release"] = bool(req.get("release", ""))
|
||||
@@ -837,13 +847,9 @@ async def upload_agent_file(agent_id, tenant_id):
|
||||
)
|
||||
try:
|
||||
if len(file_objs) == 1:
|
||||
uploaded = await thread_pool_exec(
|
||||
FileService.upload_info, tenant_id, file_objs[0], request.args.get("url")
|
||||
)
|
||||
uploaded = await thread_pool_exec(FileService.upload_info, tenant_id, file_objs[0], request.args.get("url"))
|
||||
return get_json_result(data=uploaded)
|
||||
results = await asyncio.gather(
|
||||
*(thread_pool_exec(FileService.upload_info, tenant_id, file_obj) for file_obj in file_objs)
|
||||
)
|
||||
results = await asyncio.gather(*(thread_pool_exec(FileService.upload_info, tenant_id, file_obj) for file_obj in file_objs))
|
||||
return get_json_result(data=results)
|
||||
except Exception as exc:
|
||||
logging.exception(
|
||||
@@ -1015,7 +1021,7 @@ def delete_agent(agent_id, tenant_id):
|
||||
@_require_canvas_access_async
|
||||
async def update_agent(agent_id, tenant_id):
|
||||
req = {k: v for k, v in (await get_request_json()).items() if v is not None}
|
||||
req["canvas_type"] = req.get("canvas_type","")
|
||||
req["canvas_type"] = req.get("canvas_type", "")
|
||||
req["release"] = bool(req.get("release", ""))
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
@@ -1038,10 +1044,7 @@ async def update_agent(agent_id, tenant_id):
|
||||
return get_data_error_result(message=f"{req['title']} already exists.")
|
||||
|
||||
agent_title_for_version = req.get("title") or (current_agent.title if current_agent else "")
|
||||
canvas_category = (
|
||||
req.get("canvas_category")
|
||||
or (current_agent.canvas_category if current_agent else CanvasCategory.Agent)
|
||||
)
|
||||
canvas_category = req.get("canvas_category") or (current_agent.canvas_category if current_agent else CanvasCategory.Agent)
|
||||
owner_nickname = _get_user_nickname(tenant_id)
|
||||
UserCanvasService.update_by_id(agent_id, req)
|
||||
|
||||
@@ -1153,13 +1156,19 @@ async def test_db_connection():
|
||||
except ValueError as exc:
|
||||
logging.warning(
|
||||
"Rejected test_db_connection: unsafe host %r (db_type=%s, user=%s): %s",
|
||||
req.get("host"), req.get("db_type"), current_user.id, exc,
|
||||
req.get("host"),
|
||||
req.get("db_type"),
|
||||
current_user.id,
|
||||
exc,
|
||||
)
|
||||
return get_data_error_result(message=str(exc))
|
||||
except OSError as exc:
|
||||
logging.warning(
|
||||
"Rejected test_db_connection: cannot resolve host %r (db_type=%s, user=%s): %s",
|
||||
req.get("host"), req.get("db_type"), current_user.id, exc,
|
||||
req.get("host"),
|
||||
req.get("db_type"),
|
||||
current_user.id,
|
||||
exc,
|
||||
)
|
||||
logging.debug("Full resolver exception for host %r", req.get("host"), exc_info=True)
|
||||
return get_data_error_result(message=f"Could not resolve host {req.get('host')!r}.")
|
||||
@@ -1198,13 +1207,7 @@ async def test_db_connection():
|
||||
elif req["db_type"] == "mssql":
|
||||
import pyodbc
|
||||
|
||||
connection_string = (
|
||||
f"DRIVER={{ODBC Driver 17 for SQL Server}};"
|
||||
f"SERVER={safe_host},{req['port']};"
|
||||
f"DATABASE={req['database']};"
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
)
|
||||
connection_string = f"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={safe_host},{req['port']};DATABASE={req['database']};UID={req['username']};PWD={req['password']};"
|
||||
db = pyodbc.connect(connection_string)
|
||||
try:
|
||||
cursor = db.cursor()
|
||||
@@ -1217,14 +1220,7 @@ async def test_db_connection():
|
||||
elif req["db_type"] == "IBM DB2":
|
||||
import ibm_db
|
||||
|
||||
conn_str = (
|
||||
f"DATABASE={req['database']};"
|
||||
f"HOSTNAME={safe_host};"
|
||||
f"PORT={req['port']};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
)
|
||||
conn_str = f"DATABASE={req['database']};HOSTNAME={safe_host};PORT={req['port']};PROTOCOL=TCPIP;UID={req['username']};PWD={req['password']};"
|
||||
logging.info(
|
||||
"DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=****;",
|
||||
req["database"],
|
||||
@@ -1387,9 +1383,7 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
if "chunks" in workflow_conv["reference"]:
|
||||
workflow_conv["reference"] = [workflow_conv["reference"]]
|
||||
else:
|
||||
workflow_conv["reference"] = [
|
||||
value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))
|
||||
]
|
||||
workflow_conv["reference"] = [value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))]
|
||||
elif not isinstance(workflow_conv.get("reference"), list):
|
||||
workflow_conv["reference"] = []
|
||||
workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]]
|
||||
@@ -1598,13 +1592,11 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
# seeing only a bare [DONE] (fixes #15169).
|
||||
logging.info(
|
||||
"empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)",
|
||||
agent_id, session_id, True,
|
||||
)
|
||||
yield (
|
||||
"data:"
|
||||
+ json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False)
|
||||
+ "\n\n"
|
||||
agent_id,
|
||||
session_id,
|
||||
True,
|
||||
)
|
||||
yield ("data:" + json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False) + "\n\n")
|
||||
yield "data:[DONE]\n\n"
|
||||
|
||||
return _build_sse_response(generate())
|
||||
@@ -1614,6 +1606,7 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
final_ans = {}
|
||||
trace_items = []
|
||||
structured_output = {}
|
||||
run_usage = None
|
||||
async for ans in _iter_session_completion_events(tenant_id, agent_id, req, return_trace):
|
||||
try:
|
||||
if ans["event"] == "message":
|
||||
@@ -1633,6 +1626,11 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
if ans.get("event") == "workflow_finished":
|
||||
# Capture the run-level usage but keep message_end/user_inputs as
|
||||
# final_ans so the non-stream response shape stays unchanged.
|
||||
run_usage = ans.get("data", {}).get("usage")
|
||||
continue
|
||||
if ans.get("event") == "message_end":
|
||||
final_ans = ans
|
||||
elif ans.get("event") == "user_inputs" and not final_ans:
|
||||
@@ -1647,7 +1645,9 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
# (fixes #15169).
|
||||
logging.info(
|
||||
"empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)",
|
||||
agent_id, session_id, False,
|
||||
agent_id,
|
||||
session_id,
|
||||
False,
|
||||
)
|
||||
return get_result(data={"session_id": session_id})
|
||||
|
||||
@@ -1655,6 +1655,8 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
final_ans["data"] = {}
|
||||
final_ans["data"]["content"] = full_content
|
||||
final_ans["data"]["reference"] = reference
|
||||
if run_usage:
|
||||
final_ans["data"]["usage"] = run_usage
|
||||
if structured_output:
|
||||
final_ans["data"]["structured"] = structured_output
|
||||
if return_trace and final_ans:
|
||||
@@ -1688,16 +1690,16 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
# 1. Fetch canvas by agent_id
|
||||
exists, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not exists:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST, message="Canvas not found."), RetCode.BAD_REQUEST
|
||||
|
||||
# 2. Check canvas category
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST, message="Dataflow can not be triggered by webhook."), RetCode.BAD_REQUEST
|
||||
|
||||
# 3. Load DSL from canvas
|
||||
dsl = getattr(cvs, "dsl", None)
|
||||
if not isinstance(dsl, dict):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST, message="Invalid DSL format."), RetCode.BAD_REQUEST
|
||||
|
||||
# 4. Check webhook configuration in DSL
|
||||
webhook_cfg = {}
|
||||
@@ -1708,15 +1710,13 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
webhook_cfg = cpn_obj["params"]
|
||||
|
||||
if not webhook_cfg:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST, message="Webhook not configured for this agent."), RetCode.BAD_REQUEST
|
||||
|
||||
# 5. Validate request method against webhook_cfg.methods
|
||||
allowed_methods = webhook_cfg.get("methods", [])
|
||||
request_method = request.method.upper()
|
||||
if allowed_methods and request_method not in allowed_methods:
|
||||
return get_data_error_result(
|
||||
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
|
||||
),RetCode.BAD_REQUEST
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST, message=f"HTTP method '{request_method}' not allowed for this webhook."), RetCode.BAD_REQUEST
|
||||
|
||||
async def validate_webhook_security(security_cfg: dict):
|
||||
"""Validate webhook security rules based on security configuration."""
|
||||
@@ -1795,7 +1795,6 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
|
||||
client_ip = request.remote_addr
|
||||
|
||||
|
||||
for rule in whitelist:
|
||||
if "/" in rule:
|
||||
# CIDR notation
|
||||
@@ -1854,7 +1853,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
|
||||
def _validate_token_auth(security_cfg):
|
||||
"""Validate header-based token authentication."""
|
||||
token_cfg = security_cfg.get("token",{})
|
||||
token_cfg = security_cfg.get("token", {})
|
||||
header = token_cfg.get("token_header")
|
||||
token_value = token_cfg.get("token_value")
|
||||
|
||||
@@ -1883,7 +1882,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise Exception("Missing Bearer token")
|
||||
|
||||
token = auth_header[len("Bearer "):].strip()
|
||||
token = auth_header[len("Bearer ") :].strip()
|
||||
if not token:
|
||||
raise Exception("Empty Bearer token")
|
||||
|
||||
@@ -1922,10 +1921,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
else:
|
||||
required_claims = []
|
||||
|
||||
required_claims = [
|
||||
c for c in required_claims
|
||||
if isinstance(c, str) and c.strip()
|
||||
]
|
||||
required_claims = [c for c in required_claims if isinstance(c, str) and c.strip()]
|
||||
|
||||
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
|
||||
for claim in required_claims:
|
||||
@@ -1939,10 +1935,10 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
return decoded
|
||||
|
||||
try:
|
||||
security_config=webhook_cfg.get("security", {})
|
||||
security_config = webhook_cfg.get("security", {})
|
||||
await validate_webhook_security(security_config)
|
||||
except Exception as e:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST
|
||||
if not isinstance(cvs.dsl, str):
|
||||
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
try:
|
||||
@@ -1950,7 +1946,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
|
||||
canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id)
|
||||
except Exception as e:
|
||||
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
|
||||
resp = get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e))
|
||||
resp.status_code = RetCode.BAD_REQUEST
|
||||
return resp
|
||||
|
||||
@@ -1967,9 +1963,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
# 3. Body
|
||||
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
if ctype and ctype != content_type:
|
||||
raise ValueError(
|
||||
f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
|
||||
)
|
||||
raise ValueError(f"Invalid Content-Type: expect '{content_type}', got '{ctype}'")
|
||||
|
||||
body_data: dict = {}
|
||||
|
||||
@@ -1991,11 +1985,11 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
raise Exception("Too many uploaded files")
|
||||
for key, file in files.items():
|
||||
desc = FileService.upload_info(
|
||||
cvs.user_id, # user
|
||||
file, # FileStorage
|
||||
None # url (None for webhook)
|
||||
cvs.user_id, # user
|
||||
file, # FileStorage
|
||||
None, # url (None for webhook)
|
||||
)
|
||||
file_parsed= await canvas.get_files_async([desc])
|
||||
file_parsed = await canvas.get_files_async([desc])
|
||||
body_data[key] = file_parsed
|
||||
|
||||
elif ctype == "application/x-www-form-urlencoded":
|
||||
@@ -2057,15 +2051,12 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
|
||||
# 4. Type validation
|
||||
if not validate_type(value, field_type):
|
||||
raise Exception(
|
||||
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
|
||||
)
|
||||
raise Exception(f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}")
|
||||
|
||||
extracted[field] = value
|
||||
|
||||
return extracted
|
||||
|
||||
|
||||
def default_for_type(t):
|
||||
"""Return default value for the given schema type."""
|
||||
if t == "file":
|
||||
@@ -2145,7 +2136,6 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
# Default: do nothing
|
||||
return value
|
||||
|
||||
|
||||
def validate_type(value, t):
|
||||
"""Validate value type against schema type t."""
|
||||
if t == "file":
|
||||
@@ -2179,28 +2169,24 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
|
||||
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
|
||||
|
||||
# Extract strictly by schema
|
||||
try:
|
||||
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
|
||||
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
|
||||
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
|
||||
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
|
||||
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
|
||||
except Exception as e:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST
|
||||
|
||||
clean_request = {
|
||||
"query": query_clean,
|
||||
"headers": header_clean,
|
||||
"body": body_clean,
|
||||
"input": parsed
|
||||
}
|
||||
clean_request = {"query": query_clean, "headers": header_clean, "body": body_clean, "input": parsed}
|
||||
|
||||
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
|
||||
response_cfg = webhook_cfg.get("response", {})
|
||||
|
||||
def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600):
|
||||
def append_webhook_trace(agent_id: str, start_ts: float, event: dict, ttl=600):
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
key = f"webhook-trace-{agent_id}-logs"
|
||||
@@ -2208,15 +2194,9 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
raw = REDIS_CONN.get(key)
|
||||
obj = json.loads(raw) if raw else {"webhooks": {}}
|
||||
|
||||
ws = obj["webhooks"].setdefault(
|
||||
str(start_ts),
|
||||
{"start_ts": start_ts, "events": []}
|
||||
)
|
||||
ws = obj["webhooks"].setdefault(str(start_ts), {"start_ts": start_ts, "events": []})
|
||||
|
||||
ws["events"].append({
|
||||
"ts": time.time(),
|
||||
**event
|
||||
})
|
||||
ws["events"].append({"ts": time.time(), **event})
|
||||
|
||||
REDIS_CONN.set_obj(key, obj, ttl)
|
||||
|
||||
@@ -2225,10 +2205,10 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
try:
|
||||
status = int(status)
|
||||
except (TypeError, ValueError):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(f"Invalid response status code: {status}")), RetCode.BAD_REQUEST
|
||||
|
||||
if not (200 <= status <= 399):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(f"Invalid response status code: {status}, must be between 200 and 399")), RetCode.BAD_REQUEST
|
||||
|
||||
body_tpl = response_cfg.get("body_template", "")
|
||||
|
||||
@@ -2242,7 +2222,6 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return body, "text/plain"
|
||||
|
||||
|
||||
body, content_type = parse_body(body_tpl)
|
||||
resp = Response(
|
||||
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
|
||||
@@ -2252,11 +2231,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
|
||||
async def background_run():
|
||||
try:
|
||||
async for ans in canvas.run(
|
||||
query="",
|
||||
user_id=cvs.user_id,
|
||||
webhook_payload=clean_request
|
||||
):
|
||||
async for ans in canvas.run(query="", user_id=cvs.user_id, webhook_payload=clean_request):
|
||||
if is_test:
|
||||
append_webhook_trace(agent_id, start_ts, ans)
|
||||
|
||||
@@ -2268,7 +2243,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
@@ -2285,7 +2260,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
"event": "error",
|
||||
"message": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
},
|
||||
)
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
@@ -2294,7 +2269,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": False,
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Failed to append webhook trace")
|
||||
@@ -2305,6 +2280,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
return resp
|
||||
else:
|
||||
|
||||
async def sse():
|
||||
nonlocal canvas
|
||||
contents: list[str] = []
|
||||
@@ -2326,11 +2302,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
if ans["event"] == "message_end":
|
||||
status = int(ans["data"].get("status", status))
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
ans
|
||||
)
|
||||
append_webhook_trace(agent_id, start_ts, ans)
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
@@ -2339,13 +2311,13 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
final_content = "".join(contents)
|
||||
return {
|
||||
"message": final_content,
|
||||
"success": True,
|
||||
"code": status,
|
||||
"code": status,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -2357,7 +2329,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
"event": "error",
|
||||
"message": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
},
|
||||
)
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
@@ -2366,9 +2338,9 @@ async def _webhook_impl(agent_id: str, is_test: bool):
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": False,
|
||||
}
|
||||
},
|
||||
)
|
||||
return {"code": 400, "message": str(e),"success":False}
|
||||
return {"code": 400, "message": str(e), "success": False}
|
||||
|
||||
result = await sse()
|
||||
return Response(
|
||||
@@ -2401,6 +2373,7 @@ async def webhook_trace(agent_id: str):
|
||||
if encode_webhook_id(ts) == enc_id:
|
||||
return ts
|
||||
return None
|
||||
|
||||
since_ts = request.args.get("since_ts", type=float)
|
||||
webhook_id = request.args.get("webhook_id")
|
||||
|
||||
@@ -2434,9 +2407,7 @@ async def webhook_trace(agent_id: str):
|
||||
webhooks = obj.get("webhooks", {})
|
||||
|
||||
if webhook_id is None:
|
||||
candidates = [
|
||||
float(k) for k in webhooks.keys() if float(k) > since_ts
|
||||
]
|
||||
candidates = [float(k) for k in webhooks.keys() if float(k) > since_ts]
|
||||
|
||||
if not candidates:
|
||||
return get_json_result(
|
||||
@@ -2492,6 +2463,7 @@ async def webhook_trace(agent_id: str):
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/agents/attachments/<attachment_id>/download", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
|
||||
@@ -34,10 +34,12 @@ from peewee import fn
|
||||
class CanvasTemplateService(CommonService):
|
||||
model = CanvasTemplate
|
||||
|
||||
|
||||
class DataFlowTemplateService(CommonService):
|
||||
"""
|
||||
Alias of CanvasTemplateService
|
||||
"""
|
||||
|
||||
model = CanvasTemplate
|
||||
|
||||
|
||||
@@ -46,8 +48,7 @@ class UserCanvasService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, tenant_id,
|
||||
page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent):
|
||||
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent):
|
||||
agents = cls.model.select()
|
||||
if id:
|
||||
agents = agents.where(cls.model.id == id)
|
||||
@@ -68,20 +69,9 @@ class UserCanvasService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted agents, be cautious
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_type,
|
||||
cls.model.canvas_category
|
||||
]
|
||||
fields = [cls.model.id, cls.model.avatar, cls.model.title, cls.model.permission, cls.model.canvas_type, cls.model.canvas_category]
|
||||
# find team agents and owned agents
|
||||
agents = cls.model.select(*fields).where(
|
||||
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
|
||||
cls.model.user_id == user_id
|
||||
)
|
||||
)
|
||||
agents = cls.model.select(*fields).where((cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
|
||||
# sort by create_time, asc
|
||||
agents.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
@@ -100,7 +90,6 @@ class UserCanvasService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_by_canvas_id(cls, pid):
|
||||
try:
|
||||
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
@@ -115,11 +104,9 @@ class UserCanvasService(CommonService):
|
||||
cls.model.update_date,
|
||||
cls.model.canvas_category,
|
||||
User.nickname,
|
||||
User.avatar.alias('tenant_avatar'),
|
||||
User.avatar.alias("tenant_avatar"),
|
||||
]
|
||||
agents = cls.model.select(*fields) \
|
||||
.join(User, on=(cls.model.user_id == User.id)) \
|
||||
.where(cls.model.id == pid)
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(cls.model.id == pid)
|
||||
# obj = cls.model.query(id=pid)[0]
|
||||
return True, agents.dicts()[0]
|
||||
except Exception as e:
|
||||
@@ -129,14 +116,7 @@ class UserCanvasService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_basic_info_by_canvas_ids(cls, canvas_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
cls.model.user_id,
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_category
|
||||
]
|
||||
fields = [cls.model.id, cls.model.avatar, cls.model.user_id, cls.model.title, cls.model.permission, cls.model.canvas_category]
|
||||
return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts()
|
||||
|
||||
@classmethod
|
||||
@@ -162,20 +142,26 @@ class UserCanvasService(CommonService):
|
||||
cls.model.permission,
|
||||
cls.model.user_id.alias("tenant_id"),
|
||||
User.nickname,
|
||||
User.avatar.alias('tenant_avatar'),
|
||||
User.avatar.alias("tenant_avatar"),
|
||||
cls.model.update_time,
|
||||
cls.model.canvas_type,
|
||||
cls.model.canvas_category,
|
||||
cls.model.tags,
|
||||
]
|
||||
if keywords:
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||
(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
|
||||
(fn.LOWER(cls.model.title).contains(keywords.lower()))
|
||||
agents = (
|
||||
cls.model.select(*fields)
|
||||
.join(User, on=(cls.model.user_id == User.id))
|
||||
.where(
|
||||
(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
|
||||
(fn.LOWER(cls.model.title).contains(keywords.lower())),
|
||||
)
|
||||
)
|
||||
else:
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||
(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
|
||||
agents = (
|
||||
cls.model.select(*fields)
|
||||
.join(User, on=(cls.model.user_id == User.id))
|
||||
.where((((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)))
|
||||
)
|
||||
if canvas_category:
|
||||
agents = agents.where(cls.model.canvas_category == canvas_category)
|
||||
@@ -201,7 +187,7 @@ class UserCanvasService(CommonService):
|
||||
|
||||
# Get latest release time for each canvas
|
||||
if agents_list:
|
||||
canvas_ids = [a['id'] for a in agents_list]
|
||||
canvas_ids = [a["id"] for a in agents_list]
|
||||
release_times = (
|
||||
UserCanvasVersion.select(UserCanvasVersion.user_canvas_id, fn.MAX(UserCanvasVersion.create_time).alias("release_time"))
|
||||
.where((UserCanvasVersion.user_canvas_id.in_(canvas_ids)) & (UserCanvasVersion.release))
|
||||
@@ -210,7 +196,7 @@ class UserCanvasService(CommonService):
|
||||
release_time_map = {r.user_canvas_id: r.release_time for r in release_times}
|
||||
|
||||
for agent in agents_list:
|
||||
agent['release_time'] = release_time_map.get(agent['id'])
|
||||
agent["release_time"] = release_time_map.get(agent["id"])
|
||||
|
||||
return agents_list, count
|
||||
|
||||
@@ -218,9 +204,7 @@ class UserCanvasService(CommonService):
|
||||
@DB.connection_context()
|
||||
def list_tags(cls, joined_tenant_ids, user_id, canvas_category=None):
|
||||
"""Return {tag: agent_count} aggregated across agents visible to the user."""
|
||||
query = cls.model.select(cls.model.tags).where(
|
||||
((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)
|
||||
)
|
||||
query = cls.model.select(cls.model.tags).where(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
|
||||
if canvas_category:
|
||||
query = query.where(cls.model.canvas_category == canvas_category)
|
||||
|
||||
@@ -281,6 +265,7 @@ class UserCanvasService(CommonService):
|
||||
@DB.connection_context()
|
||||
def accessible(cls, canvas_id, tenant_id):
|
||||
from api.db.services.user_service import UserTenantService
|
||||
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return False
|
||||
@@ -345,18 +330,15 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
conv = API4Conversation(**conv)
|
||||
|
||||
message_id = str(uuid4())
|
||||
conv.message.append({
|
||||
"role": "user",
|
||||
"content": query,
|
||||
"id": message_id,
|
||||
"files": files
|
||||
})
|
||||
conv.message.append({"role": "user", "content": query, "id": message_id, "files": files})
|
||||
txt = ""
|
||||
run_kwargs = {
|
||||
"query": query,
|
||||
"files": files,
|
||||
"user_id": user_id,
|
||||
"inputs": inputs,
|
||||
# Used by Canvas.run to correlate RAGFlow's Langfuse generations by session.
|
||||
"session_id": session_id,
|
||||
}
|
||||
if chat_template_kwargs is not None:
|
||||
run_kwargs["chat_template_kwargs"] = chat_template_kwargs
|
||||
@@ -394,14 +376,7 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre
|
||||
if stream:
|
||||
completion_tokens = 0
|
||||
try:
|
||||
async for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
query=question,
|
||||
user_id=user_id,
|
||||
**kwargs
|
||||
):
|
||||
async for ans in completion(tenant_id=tenant_id, agent_id=agent_id, session_id=session_id, query=question, user_id=user_id, **kwargs):
|
||||
if isinstance(ans, str):
|
||||
try:
|
||||
ans = json.loads(ans[5:]) # remove "data:"
|
||||
@@ -417,14 +392,7 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre
|
||||
|
||||
completion_tokens += len(tiktoken_encoder.encode(content_piece))
|
||||
|
||||
openai_data = get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
content=content_piece,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
stream=True
|
||||
)
|
||||
openai_data = get_data_openai(id=session_id or str(uuid4()), model=agent_id, content=content_piece, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, stream=True)
|
||||
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
openai_data["choices"][0]["delta"]["reference"] = ans["data"]["reference"]
|
||||
@@ -435,32 +403,29 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield "data: " + json.dumps(
|
||||
get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
content=f"**ERROR**: {str(e)}",
|
||||
finish_reason="stop",
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
|
||||
stream=True
|
||||
),
|
||||
ensure_ascii=False
|
||||
) + "\n\n"
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps(
|
||||
get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
content=f"**ERROR**: {str(e)}",
|
||||
finish_reason="stop",
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
|
||||
stream=True,
|
||||
),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
else:
|
||||
try:
|
||||
all_content = ""
|
||||
reference = {}
|
||||
async for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
query=question,
|
||||
user_id=user_id,
|
||||
**kwargs
|
||||
):
|
||||
async for ans in completion(tenant_id=tenant_id, agent_id=agent_id, session_id=session_id, query=question, user_id=user_id, **kwargs):
|
||||
if isinstance(ans, str):
|
||||
ans = json.loads(ans[5:])
|
||||
if ans.get("event") not in ["message", "message_end"]:
|
||||
@@ -475,13 +440,7 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre
|
||||
completion_tokens = len(tiktoken_encoder.encode(all_content))
|
||||
|
||||
openai_data = get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
content=all_content,
|
||||
finish_reason="stop",
|
||||
param=None
|
||||
id=session_id or str(uuid4()), model=agent_id, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, content=all_content, finish_reason="stop", param=None
|
||||
)
|
||||
|
||||
if reference:
|
||||
@@ -497,5 +456,5 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre
|
||||
completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
|
||||
content=f"**ERROR**: {str(e)}",
|
||||
finish_reason="stop",
|
||||
param=None
|
||||
param=None,
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ from langfuse import propagate_attributes
|
||||
from api.db.db_models import LLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.tenant_llm_service import LLM4Tenant
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from common.token_utils import num_tokens_from_string, record_run_token_usage, langfuse_run_attrs
|
||||
|
||||
|
||||
class LLMService(CommonService):
|
||||
@@ -39,11 +39,49 @@ class LLMBundle(LLM4Tenant):
|
||||
super().__init__(tenant_id, model_config, lang, **kwargs)
|
||||
|
||||
def _start_langfuse_observation(self, **kwargs):
    """Start a Langfuse observation, tagging it with session/user attributes.

    The correlating attributes let Langfuse group all generations of one turn.
    ``session_id`` is taken from this bundle when it has one; otherwise the
    per-run context variable ``langfuse_run_attrs`` may supply ``session_id``
    and/or ``user_id``. A value already taken from the bundle is never
    overwritten by the per-run context.

    Args:
        **kwargs: Forwarded unchanged to ``self.langfuse.start_observation``.

    Returns:
        Whatever ``self.langfuse.start_observation`` returns.
    """
    correlation = {}
    if self.langfuse_session_id:
        correlation["session_id"] = self.langfuse_session_id

    # Agent runs create their bundles without a session id, so fall back to the
    # attributes installed in the per-run context (empty/None means "no run").
    per_run = langfuse_run_attrs.get() or {}
    for key in ("session_id", "user_id"):
        value = per_run.get(key)
        if value and key not in correlation:
            correlation[key] = value

    if not correlation:
        return self.langfuse.start_observation(**kwargs)
    with propagate_attributes(**correlation):
        return self.langfuse.start_observation(**kwargs)
|
||||
|
||||
def _reset_last_usage(self) -> None:
    """Zero the wrapped model's ``last_usage`` before a new call.

    Without the reset, a call that fails or returns before updating
    ``last_usage`` would expose the previous call's numbers to the usage
    reporting that follows. Models that do not define ``last_usage`` are
    left untouched.
    """
    model = self.mdl
    if not hasattr(model, "last_usage"):
        return
    model.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
|
||||
def _report_usage(self, total_tokens: int) -> dict:
    """Record one chat call's usage on the active run and build Langfuse usage details.

    Args:
        total_tokens: Total reported by the call itself; treated as
            authoritative. When falsy, the total stored in
            ``self.mdl.last_usage`` is used instead.

    Returns:
        ``{"input": prompt, "output": completion, "total": total}``. The
        prompt/completion split comes from ``self.mdl.last_usage`` and is kept
        only when it sums exactly to the total; otherwise both are reported as
        0 while the total is still recorded.
    """
    last = getattr(self.mdl, "last_usage", None) or {}
    prompt = int(last.get("prompt_tokens", 0) or 0)
    completion = int(last.get("completion_tokens", 0) or 0)
    total = total_tokens or int(last.get("total_tokens", 0) or 0)

    # A split that does not add up to the total is stale or inconsistent:
    # keep the total, drop the unreliable split.
    if prompt + completion != total:
        prompt = completion = 0

    record_run_token_usage(prompt, completion, total)
    return {"input": prompt, "output": completion, "total": total}
|
||||
|
||||
def close(self):
|
||||
"""Release resources held by this LLMBundle instance."""
|
||||
super().close()
|
||||
@@ -139,7 +177,9 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
|
||||
generation = self._start_langfuse_observation(
|
||||
trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts}
|
||||
)
|
||||
|
||||
sim, used_tokens = self.mdl.similarity(query, texts)
|
||||
logging.info("LLMBundle.similarity used_tokens: %d", used_tokens)
|
||||
@@ -165,7 +205,9 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def describe_with_prompt(self, image, prompt):
|
||||
if self.langfuse:
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
|
||||
generation = self._start_langfuse_observation(
|
||||
trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt}
|
||||
)
|
||||
|
||||
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
|
||||
logging.info("LLMBundle.describe_with_prompt used_tokens: %d", used_tokens)
|
||||
@@ -194,7 +236,8 @@ class LLMBundle(LLM4Tenant):
|
||||
supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription"))
|
||||
if supports_stream:
|
||||
if self.langfuse:
|
||||
generation = self._start_langfuse_observation(as_type="generation",
|
||||
generation = self._start_langfuse_observation(
|
||||
as_type="generation",
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.model_config["llm_name"]},
|
||||
@@ -228,7 +271,8 @@ class LLMBundle(LLM4Tenant):
|
||||
return
|
||||
|
||||
if self.langfuse:
|
||||
generation = self._start_langfuse_observation(as_type="generation",
|
||||
generation = self._start_langfuse_observation(
|
||||
as_type="generation",
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.model_config["llm_name"]},
|
||||
@@ -377,11 +421,14 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
generation = self._start_langfuse_observation(
|
||||
trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history}
|
||||
)
|
||||
|
||||
chat_partial = partial(base_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
|
||||
self._reset_last_usage()
|
||||
try:
|
||||
txt, used_tokens = await chat_partial(**use_kwargs)
|
||||
except Exception as e:
|
||||
@@ -397,8 +444,10 @@ class LLMBundle(LLM4Tenant):
|
||||
if used_tokens:
|
||||
logging.info("LLMBundle.async_chat used_tokens: %d", used_tokens)
|
||||
|
||||
usage_details = self._report_usage(used_tokens)
|
||||
|
||||
if generation:
|
||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||
generation.update(output={"output": txt}, usage_details=usage_details)
|
||||
generation.end()
|
||||
|
||||
return txt
|
||||
@@ -418,11 +467,14 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
generation = self._start_langfuse_observation(
|
||||
trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}
|
||||
)
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
self._reset_last_usage()
|
||||
try:
|
||||
async for txt in chat_partial(**use_kwargs):
|
||||
if isinstance(txt, int):
|
||||
@@ -444,8 +496,9 @@ class LLMBundle(LLM4Tenant):
|
||||
raise
|
||||
if total_tokens:
|
||||
logging.info("LLMBundle.async_chat_streamly used_tokens: %d", total_tokens)
|
||||
usage_details = self._report_usage(total_tokens)
|
||||
if generation:
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.update(output={"output": ans}, usage_details=usage_details)
|
||||
generation.end()
|
||||
return
|
||||
|
||||
@@ -461,11 +514,14 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
generation = self._start_langfuse_observation(
|
||||
trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}
|
||||
)
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
self._reset_last_usage()
|
||||
try:
|
||||
async for txt in chat_partial(**use_kwargs):
|
||||
if isinstance(txt, int):
|
||||
@@ -487,7 +543,8 @@ class LLMBundle(LLM4Tenant):
|
||||
raise
|
||||
if total_tokens:
|
||||
logging.info("LLMBundle.async_chat_streamly_delta used_tokens: %d", total_tokens)
|
||||
usage_details = self._report_usage(total_tokens)
|
||||
if generation:
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.update(output={"output": ans}, usage_details=usage_details)
|
||||
generation.end()
|
||||
return
|
||||
|
||||
@@ -14,9 +14,12 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import contextvars
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
import tiktoken
|
||||
|
||||
from common.file_utils import get_project_base_directory
|
||||
@@ -42,6 +45,84 @@ os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
|
||||
encoder = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
|
||||
# Per-run token usage sink. An agent run (Canvas.run) installs a mutable dict here
|
||||
# at the start of each turn; every LLMBundle chat call adds its provider-reported
|
||||
# usage to it. This is the single chokepoint that aggregates token usage across all
|
||||
# LLM calls in a run (query rewriting, cross-language translation, tool reasoning,
|
||||
# and the final streamed answer) regardless of which component or helper issued the
|
||||
# call. Default None means "not inside a tracked run" and callers must no-op.
|
||||
token_usage_sink: contextvars.ContextVar = contextvars.ContextVar("ragflow_token_usage_sink", default=None)
|
||||
|
||||
# Per-run Langfuse correlating attributes (e.g. {"session_id": ..., "user_id": ...}).
|
||||
# Installed by Canvas.run so RAGFlow's own Langfuse generations can be grouped by
|
||||
# session and user even though the agent's LLMBundles are created without them.
|
||||
langfuse_run_attrs: contextvars.ContextVar = contextvars.ContextVar("ragflow_langfuse_run_attrs", default=None)
|
||||
|
||||
|
||||
# Guards sink mutations: concurrent tool calls (asyncio.gather + thread_pool_exec,
|
||||
# which copies the context so worker threads share the same sink dict) can otherwise
|
||||
# race on the read-modify-write of the counters.
|
||||
_sink_lock = threading.Lock()
|
||||
|
||||
|
||||
def record_run_token_usage(prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = 0) -> None:
    """Add a single LLM call's token usage to the active run sink, if any.

    Safe to call from anywhere: when no run sink is installed it does nothing.
    The update is all-or-nothing: the incoming values are converted first and
    all new counters are computed before any is written, so a malformed value
    or a malformed sink never leaves the sink partially updated. Counters
    missing from the sink are treated as 0.

    Args:
        prompt_tokens: Prompt tokens used by the call (``None`` counts as 0).
        completion_tokens: Completion tokens used by the call (``None`` counts as 0).
        total_tokens: Total tokens used by the call (``None`` counts as 0).
    """
    sink = token_usage_sink.get()
    if sink is None:
        return
    try:
        # Convert outside the lock: a bad value must abort before any counter moves.
        deltas = {
            "prompt_tokens": int(prompt_tokens or 0),
            "completion_tokens": int(completion_tokens or 0),
            "total_tokens": int(total_tokens or 0),
            "calls": 1,
        }
        with _sink_lock:
            # Build every new value first; only then commit them together.
            updated = {key: sink.get(key, 0) + delta for key, delta in deltas.items()}
            sink.update(updated)
    except Exception:
        # Never let usage bookkeeping break a request; log at debug so a malformed
        # sink or token value is still traceable without adding noise.
        logging.debug("Failed to record run token usage", exc_info=True)
|
||||
|
||||
|
||||
def usage_from_response(resp) -> dict:
    """Extract a {prompt_tokens, completion_tokens, total_tokens} split from an LLM response.

    Handles OpenAI/OpenRouter-style ``resp.usage`` objects and dict variants.
    Missing, empty or non-numeric fields count as 0; ``total_tokens`` falls
    back to prompt+completion when absent. A malformed usage field never
    raises: this runs on every streamed chunk, so bad bookkeeping data must
    not abort the call.

    Args:
        resp: A response (or stream chunk) exposing ``usage`` as an attribute
            or dict key, or ``None``.

    Returns:
        A new dict with integer ``prompt_tokens``, ``completion_tokens`` and
        ``total_tokens``.
    """
    out = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    if resp is None:
        return out

    try:
        usage = getattr(resp, "usage", None)
        if usage is None and isinstance(resp, dict):
            usage = resp.get("usage")
    except Exception:
        usage = None
    if usage is None:
        return out

    def _get(obj, *names):
        """Return the first usable integer found under any of ``names``, else 0."""
        for name in names:
            try:
                value = obj.get(name) if isinstance(obj, dict) else getattr(obj, name, None)
            except Exception:
                # Attribute access on arbitrary response objects may raise.
                value = None
            if not value:
                continue
            try:
                return int(value)
            except (TypeError, ValueError, OverflowError):
                # Malformed count: try the next alias instead of failing the call.
                continue
        return 0

    out["prompt_tokens"] = _get(usage, "prompt_tokens", "input_tokens")
    out["completion_tokens"] = _get(usage, "completion_tokens", "output_tokens")
    out["total_tokens"] = _get(usage, "total_tokens")
    if not out["total_tokens"]:
        out["total_tokens"] = out["prompt_tokens"] + out["completion_tokens"]
    return out
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
|
||||
"""Returns the number of tokens in a text string."""
|
||||
try:
|
||||
@@ -50,6 +131,7 @@ def num_tokens_from_string(string: str) -> int:
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def total_token_count_from_response(resp):
|
||||
"""
|
||||
Extract token count from LLM response in various formats.
|
||||
@@ -78,19 +160,19 @@ def total_token_count_from_response(resp):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(resp, dict) and 'usage' in resp and 'total_tokens' in resp['usage']:
|
||||
if isinstance(resp, dict) and "usage" in resp and "total_tokens" in resp["usage"]:
|
||||
try:
|
||||
return resp["usage"]["total_tokens"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(resp, dict) and 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']:
|
||||
if isinstance(resp, dict) and "usage" in resp and "input_tokens" in resp["usage"] and "output_tokens" in resp["usage"]:
|
||||
try:
|
||||
return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(resp, dict) and 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']:
|
||||
if isinstance(resp, dict) and "meta" in resp and "tokens" in resp["meta"] and "input_tokens" in resp["meta"]["tokens"] and "output_tokens" in resp["meta"]["tokens"]:
|
||||
try:
|
||||
return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]
|
||||
except Exception:
|
||||
|
||||
@@ -32,7 +32,7 @@ from openai import AsyncOpenAI, OpenAI
|
||||
from enum import StrEnum
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from common.token_utils import num_tokens_from_string, total_token_count_from_response
|
||||
from common.token_utils import num_tokens_from_string, total_token_count_from_response, usage_from_response
|
||||
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
|
||||
from rag.llm.key_utils import _normalize_replicate_key
|
||||
from rag.llm.tool_decorator import FunctionToolSession, is_tool
|
||||
@@ -184,11 +184,7 @@ def _apply_model_family_policies(
|
||||
sanitized_gen_conf["n"] = 1
|
||||
sanitized_gen_conf["presence_penalty"] = 0.0
|
||||
sanitized_gen_conf["frequency_penalty"] = 0.0
|
||||
elif (
|
||||
provider == SupportedLiteLLMProvider.ZHIPU_AI
|
||||
and "glm" in model_name_lower
|
||||
and thinking_type
|
||||
):
|
||||
elif provider == SupportedLiteLLMProvider.ZHIPU_AI and "glm" in model_name_lower and thinking_type:
|
||||
_pop_thinking_controls()
|
||||
sanitized_gen_conf["thinking"] = {"type": thinking_type}
|
||||
|
||||
@@ -217,6 +213,7 @@ def _move_litellm_provider_body_fields(provider: SupportedLiteLLMProvider | str
|
||||
completion_args["extra_body"] = body
|
||||
return completion_args
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
|
||||
@@ -230,6 +227,9 @@ class Base(ABC):
|
||||
self.is_tools = False
|
||||
self.tools = []
|
||||
self.toolcall_sessions = {}
|
||||
# Token usage split (prompt/completion/total) of the most recent chat call.
|
||||
# Consumed by LLMBundle for accurate Langfuse reporting and run aggregation.
|
||||
self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
|
||||
def _get_delay(self):
|
||||
return self.base_delay * random.uniform(10, 150)
|
||||
@@ -313,6 +313,8 @@ class Base(ABC):
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
# Reset so a stale split from a previous call can't leak into this one.
|
||||
self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
@@ -478,6 +480,18 @@ class Base(ABC):
|
||||
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
# Aggregate prompt/completion/total across all tool-calling rounds.
|
||||
agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
|
||||
def _add_round_usage(resp):
|
||||
nonlocal tk_count
|
||||
u = usage_from_response(resp)
|
||||
agg_usage["prompt_tokens"] += u["prompt_tokens"]
|
||||
agg_usage["completion_tokens"] += u["completion_tokens"]
|
||||
agg_usage["total_tokens"] += u["total_tokens"] or total_token_count_from_response(resp)
|
||||
tk_count = agg_usage["total_tokens"]
|
||||
self.last_usage = dict(agg_usage)
|
||||
|
||||
hist = deepcopy(history)
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = deepcopy(hist)
|
||||
@@ -485,7 +499,7 @@ class Base(ABC):
|
||||
for _ in range(self.max_rounds + 1):
|
||||
logging.info(f"{self.tools=}")
|
||||
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf, **extra_request_kwargs)
|
||||
tk_count += total_token_count_from_response(response)
|
||||
_add_round_usage(response)
|
||||
if not response.choices or not response.choices[0].message:
|
||||
raise Exception(f"500 response structure error. Response: {response}")
|
||||
|
||||
@@ -505,9 +519,7 @@ class Base(ABC):
|
||||
try:
|
||||
args = json_repair.loads(tc.function.arguments)
|
||||
if not isinstance(args, dict):
|
||||
raise TypeError(
|
||||
f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}"
|
||||
)
|
||||
raise TypeError(f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}")
|
||||
if hasattr(self.toolcall_session, "tool_call_async"):
|
||||
result = await self.toolcall_session.tool_call_async(name, args)
|
||||
else:
|
||||
@@ -527,7 +539,13 @@ class Base(ABC):
|
||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||
response, token_count = await self._async_chat(history, gen_conf)
|
||||
ans += response
|
||||
tk_count += token_count
|
||||
# _async_chat set self.last_usage to its own call; fold it into the aggregate.
|
||||
_fb = getattr(self, "last_usage", None) or {}
|
||||
agg_usage["prompt_tokens"] += int(_fb.get("prompt_tokens", 0) or 0)
|
||||
agg_usage["completion_tokens"] += int(_fb.get("completion_tokens", 0) or 0)
|
||||
agg_usage["total_tokens"] += int(_fb.get("total_tokens", 0) or token_count)
|
||||
tk_count = agg_usage["total_tokens"]
|
||||
self.last_usage = dict(agg_usage)
|
||||
return ans, tk_count
|
||||
except Exception as e:
|
||||
e = await self._exceptions_async(e, attempt)
|
||||
@@ -550,8 +568,23 @@ class Base(ABC):
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
total_tokens = 0
|
||||
# Aggregate prompt/completion/total across all tool-calling rounds. The split is
|
||||
# captured opportunistically when the provider reports usage on a chunk; otherwise
|
||||
# only the (estimated) total accumulates.
|
||||
agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
hist = deepcopy(history)
|
||||
|
||||
def _commit_round(round_usage, round_estimate):
|
||||
nonlocal total_tokens
|
||||
if round_usage and round_usage["total_tokens"]:
|
||||
agg_usage["prompt_tokens"] += round_usage["prompt_tokens"]
|
||||
agg_usage["completion_tokens"] += round_usage["completion_tokens"]
|
||||
agg_usage["total_tokens"] += round_usage["total_tokens"]
|
||||
else:
|
||||
agg_usage["total_tokens"] += round_estimate
|
||||
total_tokens = agg_usage["total_tokens"]
|
||||
self.last_usage = dict(agg_usage)
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = deepcopy(hist)
|
||||
try:
|
||||
@@ -559,12 +592,20 @@ class Base(ABC):
|
||||
reasoning_start = False
|
||||
logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in tools]}")
|
||||
|
||||
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf, **extra_request_kwargs)
|
||||
response = await self.async_client.chat.completions.create(
|
||||
model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf, **extra_request_kwargs
|
||||
)
|
||||
|
||||
final_tool_calls = {}
|
||||
answer = ""
|
||||
round_estimate = 0
|
||||
round_usage = None
|
||||
|
||||
async for resp in response:
|
||||
_u = usage_from_response(resp)
|
||||
if _u["total_tokens"]:
|
||||
round_usage = _u
|
||||
|
||||
if not hasattr(resp, "choices") or not resp.choices:
|
||||
continue
|
||||
|
||||
@@ -597,16 +638,17 @@ class Base(ABC):
|
||||
answer += delta.content
|
||||
yield delta.content
|
||||
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(delta.content)
|
||||
else:
|
||||
total_tokens = tol
|
||||
if not _u["total_tokens"]:
|
||||
round_estimate += num_tokens_from_string(delta.content)
|
||||
|
||||
finish_reason = getattr(resp.choices[0], "finish_reason", "")
|
||||
if finish_reason == "length":
|
||||
yield self._length_stop("")
|
||||
|
||||
# Commit this round's tokens (each round is a separate provider
|
||||
# request — accumulate, never overwrite).
|
||||
_commit_round(round_usage, round_estimate)
|
||||
|
||||
if answer and not final_tool_calls:
|
||||
logging.info(f"[ToolLoop] round={_round} completed with text response, exiting")
|
||||
yield total_tokens
|
||||
@@ -617,9 +659,7 @@ class Base(ABC):
|
||||
try:
|
||||
args = json_repair.loads(tc.function.arguments)
|
||||
if not isinstance(args, dict):
|
||||
raise TypeError(
|
||||
f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}"
|
||||
)
|
||||
raise TypeError(f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}")
|
||||
if hasattr(self.toolcall_session, "tool_call_async"):
|
||||
result = await self.toolcall_session.tool_call_async(name, args)
|
||||
else:
|
||||
@@ -655,19 +695,22 @@ class Base(ABC):
|
||||
**extra_request_kwargs,
|
||||
)
|
||||
|
||||
fb_estimate = 0
|
||||
fb_usage = None
|
||||
async for resp in response:
|
||||
_u = usage_from_response(resp)
|
||||
if _u["total_tokens"]:
|
||||
fb_usage = _u
|
||||
if not hasattr(resp, "choices") or not resp.choices:
|
||||
continue
|
||||
delta = resp.choices[0].delta
|
||||
if not hasattr(delta, "content") or delta.content is None:
|
||||
continue
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(delta.content)
|
||||
else:
|
||||
total_tokens = tol
|
||||
if not _u["total_tokens"]:
|
||||
fb_estimate += num_tokens_from_string(delta.content)
|
||||
yield delta.content
|
||||
|
||||
_commit_round(fb_usage, fb_estimate)
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
@@ -708,6 +751,8 @@ class Base(ABC):
|
||||
|
||||
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
||||
|
||||
# Capture prompt/completion split for accurate Langfuse + run aggregation.
|
||||
self.last_usage = usage_from_response(response)
|
||||
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
||||
return "", 0
|
||||
ans = response.choices[0].message.content.strip()
|
||||
@@ -1075,9 +1120,7 @@ class ReplicateChat(Base):
|
||||
msgs = [{"role": "system", "content": system}]
|
||||
|
||||
system_msg = msgs[0]["content"] if msgs and msgs[0].get("role") == "system" else ""
|
||||
prompt = "\n".join(
|
||||
[item["role"] + ":" + item["content"] for item in msgs[-5:] if item.get("role") != "system"]
|
||||
)
|
||||
prompt = "\n".join([item["role"] + ":" + item["content"] for item in msgs[-5:] if item.get("role") != "system"])
|
||||
try:
|
||||
response = self.client.run(
|
||||
self.model_name,
|
||||
@@ -1550,6 +1593,9 @@ class LiteLLMBase(ABC):
|
||||
self.is_tools = False
|
||||
self.tools = []
|
||||
self.toolcall_sessions = {}
|
||||
# Token usage split (prompt/completion/total) of the most recent chat call.
|
||||
# Consumed by LLMBundle for accurate Langfuse reporting and run aggregation.
|
||||
self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
|
||||
# Factory specific fields
|
||||
if self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||
@@ -1638,6 +1684,9 @@ class LiteLLMBase(ABC):
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
# Capture the prompt/completion split for accurate per-call usage
|
||||
# reporting (Langfuse + agent run aggregation).
|
||||
self.last_usage = usage_from_response(response)
|
||||
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
||||
return "", 0
|
||||
ans = response.choices[0].message.content.strip()
|
||||
@@ -1659,11 +1708,16 @@ class LiteLLMBase(ABC):
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
reasoning_start = False
|
||||
total_tokens = 0
|
||||
# Reset so a stale split from a previous call can't leak into this one.
|
||||
self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
|
||||
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
|
||||
stop = kwargs.get("stop")
|
||||
if stop:
|
||||
completion_args["stop"] = stop
|
||||
# Ask the provider to include authoritative usage in the final streaming chunk.
|
||||
# drop_params=True ensures this is silently ignored by providers that don't support it.
|
||||
completion_args.setdefault("stream_options", {})["include_usage"] = True
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
@@ -1674,6 +1728,14 @@ class LiteLLMBase(ABC):
|
||||
)
|
||||
|
||||
async for resp in stream:
|
||||
# Authoritative usage may arrive on a usage-only final chunk that
|
||||
# carries no choices (OpenAI/OpenRouter with include_usage). Read it
|
||||
# before the choices guard so the prompt/completion split is captured.
|
||||
_usage = usage_from_response(resp)
|
||||
if _usage["total_tokens"]:
|
||||
total_tokens = _usage["total_tokens"]
|
||||
self.last_usage = _usage
|
||||
|
||||
if not hasattr(resp, "choices") or not resp.choices:
|
||||
continue
|
||||
|
||||
@@ -1692,10 +1754,9 @@ class LiteLLMBase(ABC):
|
||||
reasoning_start = False
|
||||
ans = delta.content
|
||||
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
tol = num_tokens_from_string(delta.content)
|
||||
total_tokens += tol
|
||||
if not _usage["total_tokens"]:
|
||||
# No authoritative usage yet: keep a running estimate as fallback.
|
||||
total_tokens += num_tokens_from_string(delta.content)
|
||||
|
||||
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
|
||||
if finish_reason == "length":
|
||||
@@ -1745,11 +1806,15 @@ class LiteLLMBase(ABC):
|
||||
return msg
|
||||
|
||||
def _verbose_tool_use(self, name, args, res):
    """Render one tool invocation as a ``<tool_call>``-wrapped JSON string.

    Args:
        name: Name of the tool that was called.
        args: Arguments the tool was called with (must be JSON-serializable).
        res: Result of the call. An ``Exception`` instance is replaced by
            its ``str()`` form; any other value is embedded as-is.

    Returns:
        ``"<tool_call>" + <pretty-printed JSON> + "</tool_call>"`` where the
        JSON object has the keys ``name``, ``args`` and ``result``.
    """
    # Exceptions are not JSON-serializable, so report their message text.
    result = str(res) if isinstance(res, Exception) else res
    # ensure_ascii=False keeps non-ASCII text readable instead of \u-escaped.
    payload = json.dumps(
        {"name": name, "args": args, "result": result},
        ensure_ascii=False,
        indent=2,
    )
    return "<tool_call>" + payload + "</tool_call>"
|
||||
|
||||
def _append_history(self, hist, tool_call, tool_res, reasoning_content=None):
|
||||
assistant_msg = {
|
||||
@@ -1844,6 +1909,19 @@ class LiteLLMBase(ABC):
|
||||
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
# Aggregate prompt/completion/total across every tool-calling round so the
|
||||
# whole multi-round exchange is reported once with a correct split.
|
||||
agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
|
||||
def _add_usage(resp):
|
||||
nonlocal tk_count
|
||||
u = usage_from_response(resp)
|
||||
agg_usage["prompt_tokens"] += u["prompt_tokens"]
|
||||
agg_usage["completion_tokens"] += u["completion_tokens"]
|
||||
agg_usage["total_tokens"] += u["total_tokens"] or total_token_count_from_response(resp)
|
||||
tk_count = agg_usage["total_tokens"]
|
||||
self.last_usage = dict(agg_usage)
|
||||
|
||||
hist = deepcopy(history)
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = deepcopy(hist)
|
||||
@@ -1858,7 +1936,7 @@ class LiteLLMBase(ABC):
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
tk_count += total_token_count_from_response(response)
|
||||
_add_usage(response)
|
||||
|
||||
if not hasattr(response, "choices") or not response.choices or not response.choices[0].message:
|
||||
raise Exception(f"500 response structure error. Response: {response}")
|
||||
@@ -1906,7 +1984,13 @@ class LiteLLMBase(ABC):
|
||||
|
||||
response, token_count = await self.async_chat("", history, gen_conf)
|
||||
ans += response
|
||||
tk_count += token_count
|
||||
# self.async_chat set self.last_usage to its own call; fold it into the aggregate.
|
||||
_fb = getattr(self, "last_usage", None) or {}
|
||||
agg_usage["prompt_tokens"] += int(_fb.get("prompt_tokens", 0) or 0)
|
||||
agg_usage["completion_tokens"] += int(_fb.get("completion_tokens", 0) or 0)
|
||||
agg_usage["total_tokens"] += int(_fb.get("total_tokens", 0) or token_count)
|
||||
tk_count = agg_usage["total_tokens"]
|
||||
self.last_usage = dict(agg_usage)
|
||||
return ans, tk_count
|
||||
|
||||
except Exception as e:
|
||||
@@ -1924,8 +2008,23 @@ class LiteLLMBase(ABC):
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
total_tokens = 0
|
||||
# Aggregate usage across every tool-calling round (each round is a separate
|
||||
# provider request). Committing per round avoids the previous bug where a later
|
||||
# round's total overwrote earlier rounds.
|
||||
agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
hist = deepcopy(history)
|
||||
|
||||
def _commit_round(round_usage, round_estimate):
|
||||
nonlocal total_tokens
|
||||
if round_usage and round_usage["total_tokens"]:
|
||||
agg_usage["prompt_tokens"] += round_usage["prompt_tokens"]
|
||||
agg_usage["completion_tokens"] += round_usage["completion_tokens"]
|
||||
agg_usage["total_tokens"] += round_usage["total_tokens"]
|
||||
else:
|
||||
agg_usage["total_tokens"] += round_estimate
|
||||
total_tokens = agg_usage["total_tokens"]
|
||||
self.last_usage = dict(agg_usage)
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = deepcopy(hist)
|
||||
try:
|
||||
@@ -1935,6 +2034,8 @@ class LiteLLMBase(ABC):
|
||||
logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in tools]}")
|
||||
|
||||
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
|
||||
# Request authoritative usage on the final streaming chunk.
|
||||
completion_args.setdefault("stream_options", {})["include_usage"] = True
|
||||
response = await litellm.acompletion(
|
||||
**completion_args,
|
||||
drop_params=True,
|
||||
@@ -1943,8 +2044,15 @@ class LiteLLMBase(ABC):
|
||||
|
||||
final_tool_calls = {}
|
||||
answer = ""
|
||||
round_usage = None
|
||||
round_estimate = 0
|
||||
|
||||
async for resp in response:
|
||||
# Usage-only final chunk may carry no choices — read it first.
|
||||
_u = usage_from_response(resp)
|
||||
if _u["total_tokens"]:
|
||||
round_usage = _u
|
||||
|
||||
if not hasattr(resp, "choices") or not resp.choices:
|
||||
continue
|
||||
|
||||
@@ -1979,16 +2087,16 @@ class LiteLLMBase(ABC):
|
||||
answer += delta.content
|
||||
yield delta.content
|
||||
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(delta.content)
|
||||
else:
|
||||
total_tokens = tol
|
||||
if not _u["total_tokens"]:
|
||||
round_estimate += num_tokens_from_string(delta.content)
|
||||
|
||||
finish_reason = getattr(resp.choices[0], "finish_reason", "")
|
||||
if finish_reason == "length":
|
||||
yield self._length_stop("")
|
||||
|
||||
# Commit this round's tokens to the running aggregate.
|
||||
_commit_round(round_usage, round_estimate)
|
||||
|
||||
if answer and not final_tool_calls:
|
||||
logging.info(f"[ToolLoop] round={_round} completed with text response, exiting")
|
||||
yield total_tokens
|
||||
@@ -2030,25 +2138,29 @@ class LiteLLMBase(ABC):
|
||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||
|
||||
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
|
||||
completion_args.setdefault("stream_options", {})["include_usage"] = True
|
||||
response = await litellm.acompletion(
|
||||
**completion_args,
|
||||
drop_params=True,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
fb_usage = None
|
||||
fb_estimate = 0
|
||||
async for resp in response:
|
||||
_u = usage_from_response(resp)
|
||||
if _u["total_tokens"]:
|
||||
fb_usage = _u
|
||||
if not hasattr(resp, "choices") or not resp.choices:
|
||||
continue
|
||||
delta = resp.choices[0].delta
|
||||
if not hasattr(delta, "content") or delta.content is None:
|
||||
continue
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(delta.content)
|
||||
else:
|
||||
total_tokens = tol
|
||||
if not _u["total_tokens"]:
|
||||
fb_estimate += num_tokens_from_string(delta.content)
|
||||
yield delta.content
|
||||
|
||||
_commit_round(fb_usage, fb_estimate)
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
@@ -2157,7 +2269,7 @@ class LiteLLMBase(ABC):
|
||||
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
|
||||
extra_headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
# MiniMax requires GroupId as a query parameter for API authentication
|
||||
if self.provider == SupportedLiteLLMProvider.MiniMax and hasattr(self, 'group_id') and self.group_id:
|
||||
if self.provider == SupportedLiteLLMProvider.MiniMax and hasattr(self, "group_id") and self.group_id:
|
||||
api_base = completion_args.get("api_base", self.base_url)
|
||||
separator = "&" if "?" in api_base else "?"
|
||||
completion_args["api_base"] = f"{api_base}{separator}GroupId={self.group_id}"
|
||||
|
||||
Reference in New Issue
Block a user