feat(agent): report accurate aggregated token usage and propagate session/user + input/output to Langfuse for agent runs (#16420)

### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Other (please describe):

## Summary

Agent (Canvas) runs previously did not surface token usage in the SSE
stream, and RAGFlow's own Langfuse generations for agent runs were
missing the prompt/completion split and the session/user correlation.
This made it impossible for an external caller (or Langfuse) to
reconcile an agent turn's cost with the upstream provider (e.g.
OpenRouter), because a single turn can issue several distinct LLM calls
(query rewriting / cross-language translation, multi-round tool
reasoning, nested sub-agents, and the final answer).

This PR introduces a per-run token usage sink so that **every** LLM call
in a run is aggregated and reported once, and enriches Langfuse
generations with the prompt/completion split plus session/user
attributes.

## What changes

### 1. Per-run token usage sink (`common/token_utils.py`)

- Adds two `contextvars`: `token_usage_sink` (a mutable per-run
accumulator) and `langfuse_run_attrs` (session_id/user_id for the run).
- Adds `record_run_token_usage(...)` (thread-safe via a lock, because
`thread_pool_exec` copies the context into worker threads that share the
sink dict) and `usage_from_response(...)` which extracts a
`{prompt_tokens, completion_tokens, total_tokens}` split from
OpenAI/OpenRouter-style responses.

### 2. Provider layer captures the prompt/completion split
(`rag/llm/chat_model.py`)

- `LiteLLMBase` and `Base` now store `self.last_usage`
(prompt/completion/total) for the most recent chat call, in both the
plain and tool-calling paths.
- Streaming requests set `stream_options.include_usage = True` (LiteLLM
path) so the authoritative usage arrives on the final chunk; this is
read even on the usage-only chunk that carries no `choices`.
- Fixes a multi-round accounting bug in `*_with_tools`: token totals
were **overwritten** by each round (`total_tokens = tol`) instead of
accumulated, undercounting multi-round tool conversations. Each round is
now committed to a running aggregate.

### 3. LLMBundle reports usage once, per call
(`api/db/services/llm_service.py`)

- New `_report_usage(total_tokens)` records the call's usage into the
active run sink and returns the prompt/completion/total split for
Langfuse. The split is only used when it is consistent with the
authoritative total; otherwise only the total is reported.
- All three chat entry points (`async_chat`, `async_chat_streamly`,
`async_chat_streamly_delta`) now emit `usage_details` with
`input`/`output`/`total` instead of total-only.
- `_start_langfuse_observation` now applies `session_id`/`user_id` from
the per-run context (`langfuse_run_attrs`) so agent-run generations are
correctly grouped, even though agent LLMBundles are constructed without
those attributes.

### 4. Canvas installs the sink and emits the aggregate
(`agent/canvas.py`)

- `Canvas.run()` installs a fresh `token_usage_sink` and
`langfuse_run_attrs` (from `user_id`/`session_id`) at the start of every
turn.
- `message_end` now includes an aggregated `usage` object:
`{prompt_tokens, completion_tokens, total_tokens, calls}` covering all
LLM calls in the run.

### 5. Pass session id into the run
(`api/db/services/canvas_service.py`)

- `completion()` forwards `session_id` to `Canvas.run()` for Langfuse
session correlation.

## Why a context variable

LLM calls in an agent run originate from many places that each build
their own `LLMBundle` (e.g. `cross_languages`/`keyword_extraction`
helpers, the Agent component, and nested sub-agents invoked as tools). A
run-scoped context variable is the only non-invasive chokepoint that
captures all of them exactly once, including nested agents (which run in
the same async context) and thread-pool tools (the executor copies the
context).

## Behavior / compatibility

- No public API or wire-format removal: `message_end` gains an
additional optional `usage` field; existing consumers are unaffected.
- When a provider does not return authoritative usage, behavior falls
back to the previous token estimate (total only, no split).
- Non-agent flows (Dataflow `Pipeline`, sync `Graph.run`) are untouched.

## Testing
- [x] Simple agent answer: `message_end.usage.total_tokens` matches
provider usage.
- [x] Agent with cross-language retrieval: aggregate equals the sum of
both provider calls.
- [x] Tool-calling agent (multi-round): total accumulates across rounds.
- [x] Nested agent (agent-as-tool): sub-agent tokens included in the
parent run total.
- [x] Langfuse: agent generations show input/output split and are
grouped by session/user.

---------

Co-authored-by: yzc <yuzhichang@gmail.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Öndery
2026-07-02 04:35:28 +03:00
committed by GitHub
parent 42a0faad18
commit 742188c3bb
6 changed files with 655 additions and 422 deletions

View File

@@ -15,6 +15,7 @@
#
import asyncio
import base64
import contextvars
import datetime
import inspect
import json
@@ -36,52 +37,54 @@ from api.db.joint_services.tenant_model_service import get_tenant_default_model_
from common.constants import LLMType
from common.misc_utils import get_uuid, hash_str2int
from common.exceptions import TaskCanceledException
from common.token_utils import token_usage_sink, langfuse_run_attrs
from rag.prompts.generator import chunks_format
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.tts_cache import synthesize_with_cache
_logger = logging.getLogger(__name__)
class Graph:
"""
dsl = {
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {},
},
"downstream": ["answer_0"],
"upstream": [],
dsl = {
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {},
},
"retrieval_0": {
"obj": {
"component_name": "Retrieval",
"params": {}
},
"downstream": ["generate_0"],
"upstream": ["answer_0"],
},
"generate_0": {
"obj": {
"component_name": "Generate",
"params": {}
},
"downstream": ["answer_0"],
"upstream": ["retrieval_0"],
}
"downstream": ["answer_0"],
"upstream": [],
},
"history": [],
"path": ["begin"],
"retrieval": {"chunks": [], "doc_aggs": []},
"globals": {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
"retrieval_0": {
"obj": {
"component_name": "Retrieval",
"params": {}
},
"downstream": ["generate_0"],
"upstream": ["answer_0"],
},
"generate_0": {
"obj": {
"component_name": "Generate",
"params": {}
},
"downstream": ["answer_0"],
"upstream": ["retrieval_0"],
}
},
"history": [],
"path": ["begin"],
"retrieval": {"chunks": [], "doc_aggs": []},
"globals": {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
}
"""
}
"""
def __init__(self, dsl: str, tenant_id=None, task_id=None, custom_header=None):
self.path = []
@@ -115,9 +118,7 @@ class Graph:
def __str__(self):
self.dsl["path"] = self.path
self.dsl["task_id"] = self.task_id
dsl = {
"components": {}
}
dsl = {"components": {}}
for k in self.dsl.keys():
if k in ["components"]:
continue
@@ -139,11 +140,13 @@ class Graph:
except Exception as e:
logging.warning("Graph.__str__: deepcopy failed for component '%s' key '%s' (type=%s): %s. Using shallow reference.", k, c, type(cpn[c]).__name__, e)
dsl["components"][k][c] = cpn[c]
def _serialize_default(obj):
if callable(obj):
return None
logging.warning("Graph.__str__: JSON fallback via str() for type=%s", type(obj).__name__)
return str(obj)
return json.dumps(dsl, ensure_ascii=False, default=_serialize_default)
def reset(self):
@@ -180,13 +183,13 @@ class Graph:
def get_tenant_id(self):
return self._tenant_id
def get_value_with_variable(self,value: str) -> Any:
def get_value_with_variable(self, value: str) -> Any:
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
out_parts = []
last = 0
for m in pat.finditer(value):
out_parts.append(value[last:m.start()])
out_parts.append(value[last : m.start()])
key = m.group(1)
v = self.get_variable_value(key)
if v is None:
@@ -205,7 +208,7 @@ class Graph:
last = m.end()
out_parts.append(value[last:])
return("".join(out_parts))
return "".join(out_parts)
def get_variable_value(self, exp: str) -> Any:
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
@@ -222,13 +225,13 @@ class Graph:
if not rest:
return root_val
return self.get_variable_param_value(root_val,rest)
return self.get_variable_param_value(root_val, rest)
def get_variable_param_value(self, obj: Any, path: str) -> Any:
cur = obj
if not path:
return cur
for key in path.split('.'):
for key in path.split("."):
if cur is None:
return None
@@ -253,7 +256,7 @@ class Graph:
cur = getattr(cur, key, None)
return cur
def set_variable_value(self, exp: str,value):
def set_variable_value(self, exp: str, value):
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
if exp.find("@") < 0:
self.globals[exp] = value
@@ -271,11 +274,11 @@ class Graph:
root_val = cpn["obj"].output(root_key)
if not root_val:
root_val = {}
cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val,rest,value))
cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val, rest, value))
def set_variable_param_value(self, obj: Any, path: str, value) -> Any:
cur = obj
keys = path.split('.')
keys = path.split(".")
if not path:
return value
for key in keys[:-1]:
@@ -298,7 +301,6 @@ class Graph:
class Canvas(Graph):
def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None, custom_header=None):
self.globals = {
"sys.query": "",
@@ -306,9 +308,14 @@ class Canvas(Graph):
"sys.conversation_turns": 0,
"sys.files": [],
"sys.history": [],
"sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
"sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}
self.variables = {}
# Aggregated provider token usage (prompt/completion/total) across every LLM
# call in a single run — query rewriting, cross-language translation, tool
# reasoning and the final answer. Populated via the token_usage_sink context
# variable that each LLMBundle chat call writes to. Reset at run() start.
self._run_token_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "calls": 0}
super().__init__(dsl, tenant_id, task_id, custom_header=custom_header)
self._id = canvas_id
@@ -323,13 +330,13 @@ class Canvas(Graph):
self.globals["sys.date"] = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
else:
self.globals = {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": [],
"sys.history": [],
"sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
}
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": [],
"sys.history": [],
"sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}
if "variables" in self.dsl:
self.variables = self.dsl["variables"]
else:
@@ -393,6 +400,38 @@ class Canvas(Graph):
self.globals[k] = ""
async def run(self, **kwargs):
# Install a fresh per-run token usage sink and Langfuse correlation context,
# and guarantee both are torn down when the run ends (even on early return or
# exception) so later LLM calls in the same task never inherit a previous
# run's sink or session/user attributes.
self._run_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "calls": 0}
_lf_attrs = {}
_user_id = kwargs.get("user_id")
if _user_id:
_lf_attrs["user_id"] = str(_user_id)[:200]
_session_id = kwargs.get("session_id") or self._id
if _session_id:
_lf_attrs["session_id"] = str(_session_id)[:200]
sink_token = token_usage_sink.set(self._run_token_usage)
attrs_token = langfuse_run_attrs.set(_lf_attrs)
try:
async for ev in self._run_impl(**kwargs):
yield ev
finally:
# reset() can raise if the generator is closed from a different context
# (e.g. client disconnect); fall back to clearing the values in that case.
try:
token_usage_sink.reset(sink_token)
except ValueError:
logging.debug("Failed to reset token usage ContextVar", exc_info=True)
token_usage_sink.set(None)
try:
langfuse_run_attrs.reset(attrs_token)
except ValueError:
logging.debug("Failed to reset Langfuse run attributes ContextVar", exc_info=True)
langfuse_run_attrs.set(None)
async def _run_impl(self, **kwargs):
self.globals["sys.date"] = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
st = time.perf_counter()
self._loop = asyncio.get_running_loop()
@@ -406,7 +445,7 @@ class Canvas(Graph):
if kwargs.get("webhook_payload"):
for k, cpn in self.components.items():
if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
payload = kwargs.get("webhook_payload", {})
if "input" in payload:
self.components[k]["obj"].set_input_value("request", payload["input"])
@@ -427,7 +466,7 @@ class Canvas(Graph):
self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k], layout_recognize)
else:
self.globals[f"sys.{k}"] = kwargs[k]
if not self.globals["sys.conversation_turns"] :
if not self.globals["sys.conversation_turns"]:
self.globals["sys.conversation_turns"] = 0
self.globals["sys.conversation_turns"] += 1
is_resume = bool(self.path) and self.path[0].lower().find("userfillup") >= 0
@@ -436,11 +475,11 @@ class Canvas(Graph):
nonlocal created_at
return {
"event": event,
#"conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115",
# "conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115",
"message_id": self.message_id,
"created_at": created_at,
"task_id": self.task_id,
"data": dt
"data": dt,
}
if not is_resume:
@@ -476,7 +515,11 @@ class Canvas(Graph):
if use_async:
await cpn_obj.invoke_async(**(call_kwargs or {}))
return
await loop.run_in_executor(self._thread_pool, partial(sync_fn, **(call_kwargs or {})))
# run_in_executor does not propagate context variables; copy the
# current context so the token usage sink / Langfuse attributes set
# by run() remain visible to LLMBundle calls inside sync components.
ctx = contextvars.copy_context()
await loop.run_in_executor(self._thread_pool, lambda: ctx.run(partial(sync_fn, **(call_kwargs or {}))))
i = f
while i < t:
@@ -525,16 +568,19 @@ class Canvas(Graph):
json.dumps(outputs, ensure_ascii=False, default=str)[:500],
cpn_obj.error(),
)
return decorate("node_finished",{
"inputs": cpn_obj.get_input_values(),
"outputs": outputs,
"component_id": cpn_obj._id,
"component_name": self.get_component_name(cpn_obj._id),
"component_type": self.get_component_type(cpn_obj._id),
"error": cpn_obj.error(),
"elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
"created_at": cpn_obj.output("_created_time"),
})
return decorate(
"node_finished",
{
"inputs": cpn_obj.get_input_values(),
"outputs": outputs,
"component_id": cpn_obj._id,
"component_name": self.get_component_name(cpn_obj._id),
"component_type": self.get_component_type(cpn_obj._id),
"error": cpn_obj.error(),
"elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
"created_at": cpn_obj.output("_created_time"),
},
)
self.error = ""
idx = 0 if is_resume else len(self.path) - 1
@@ -543,13 +589,17 @@ class Canvas(Graph):
while idx < len(self.path):
to = len(self.path)
for i in range(idx, to):
yield decorate("node_started", {
"inputs": None, "created_at": int(time.time()),
"component_id": self.path[i],
"component_name": self.get_component_name(self.path[i]),
"component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i])
})
yield decorate(
"node_started",
{
"inputs": None,
"created_at": int(time.time()),
"component_id": self.path[i],
"component_name": self.get_component_name(self.path[i]),
"component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i]),
},
)
await _run_batch(idx, to)
to = len(self.path)
# post-processing of components invocation
@@ -564,6 +614,7 @@ class Canvas(Graph):
_m = ""
buff_m = ""
stream = cpn_obj.output("content")()
async def _process_stream(m):
nonlocal buff_m, _m, tts_mdl
if not m:
@@ -578,13 +629,7 @@ class Canvas(Graph):
_m += m
if len(buff_m) > 16:
ev = decorate(
"message",
{
"content": m,
"audio_binary": self.tts(tts_mdl, buff_m)
}
)
ev = decorate("message", {"content": m, "audio_binary": self.tts(tts_mdl, buff_m)})
buff_m = ""
return ev
@@ -592,12 +637,12 @@ class Canvas(Graph):
if inspect.isasyncgen(stream):
async for m in stream:
ev= await _process_stream(m)
ev = await _process_stream(m)
if ev:
yield ev
else:
for m in stream:
ev= await _process_stream(m)
ev = await _process_stream(m)
if ev:
yield ev
if buff_m:
@@ -629,7 +674,7 @@ class Canvas(Graph):
else:
self.error = cpn_obj.error()
if cpn_obj.component_name.lower() not in ("iteration","loop"):
if cpn_obj.component_name.lower() not in ("iteration", "loop"):
if isinstance(cpn_obj.output("content"), partial):
if self.error:
cpn_obj.set_output("content", None)
@@ -654,7 +699,7 @@ class Canvas(Graph):
for cpn_id in cpn_ids:
_append_path(cpn_id)
if cpn_obj.component_name.lower() in ("iterationitem","loopitem") and cpn_obj.end():
if cpn_obj.component_name.lower() in ("iterationitem", "loopitem") and cpn_obj.end():
iter = cpn_obj.get_parent()
yield _node_finished(iter)
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
@@ -683,10 +728,7 @@ class Canvas(Graph):
o = self.get_component_obj(c)
if o.component_name.lower() == "userfillup":
o.invoke()
another_inputs.update({
k: v for k, v in o.get_input_elements().items()
if not self._is_input_field_satisfied(v)
})
another_inputs.update({k: v for k, v in o.get_input_elements().items() if not self._is_input_field_satisfied(v)})
if o.get_param("enable_tips"):
tips = o.output("tips")
if not another_inputs:
@@ -696,23 +738,30 @@ class Canvas(Graph):
return
self.path = self.path[:idx]
if not self.error:
yield decorate("workflow_finished",
{
"inputs": kwargs.get("inputs"),
"outputs": self.get_component_obj(self.path[-1]).output(),
"elapsed_time": time.perf_counter() - st,
"created_at": st,
})
yield decorate(
"workflow_finished",
{
"inputs": kwargs.get("inputs"),
"outputs": self.get_component_obj(self.path[-1]).output(),
"elapsed_time": time.perf_counter() - st,
"created_at": st,
# Run-level total of all LLM calls — emitted once here.
"usage": self._run_usage_payload(),
},
)
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
self.globals["sys.history"].append(f"{self.history[-1][0]}: {self.history[-1][1]}")
elif "Task has been canceled" in self.error:
yield decorate("workflow_finished",
{
"inputs": kwargs.get("inputs"),
"outputs": "Task has been canceled",
"elapsed_time": time.perf_counter() - st,
"created_at": st,
})
yield decorate(
"workflow_finished",
{
"inputs": kwargs.get("inputs"),
"outputs": "Task has been canceled",
"elapsed_time": time.perf_counter() - st,
"created_at": st,
"usage": self._run_usage_payload(),
},
)
def is_reff(self, exp: str) -> bool:
exp = exp.strip("{").strip("}")
@@ -725,8 +774,7 @@ class Canvas(Graph):
return False
return True
def tts(self,tts_mdl, text):
def tts(self, tts_mdl, text):
def clean_tts_text(text: str) -> str:
if not text:
return ""
@@ -736,15 +784,8 @@ class Canvas(Graph):
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
"[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF"
"\U00002700-\U000027BF"
"\U0001F900-\U0001F9FF"
"\U0001FA70-\U0001FAFF"
"\U0001FAD0-\U0001FAFF]+",
flags=re.UNICODE
"[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff\U0001f1e0-\U0001f1ff\U00002700-\U000027bf\U0001f900-\U0001f9ff\U0001fa70-\U0001faff\U0001fad0-\U0001faff]+",
flags=re.UNICODE,
)
text = emoji_pattern.sub("", text)
@@ -755,6 +796,7 @@ class Canvas(Graph):
text = text[:MAX_LEN]
return text
if not tts_mdl or not text:
return None
text = clean_tts_text(text)
@@ -766,7 +808,7 @@ class Canvas(Graph):
convs = []
if window_size <= 0:
return convs
for role, obj in self.history[window_size * -2:]:
for role, obj in self.history[window_size * -2 :]:
if isinstance(obj, dict):
convs.append({"role": role, "content": obj.get("content", "")})
else:
@@ -815,17 +857,19 @@ class Canvas(Graph):
async def get_files_async(self, files: Union[None, list[dict]], layout_recognize: str = None) -> list[str]:
if not files:
return []
return []
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
return "data:{};base64,{}".format(file["mime_type"], base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
def parse_file(file):
blob = FileService.get_blob(file["created_by"], file["id"])
return FileService.parse(file["name"], blob, True, file["created_by"], layout_recognize)
loop = asyncio.get_running_loop()
tasks = []
for file in files:
if file["mime_type"].find("image") >=0:
if file["mime_type"].find("image") >= 0:
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
continue
tasks.append(loop.run_in_executor(self._thread_pool, parse_file, file))
@@ -844,7 +888,7 @@ class Canvas(Graph):
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
agent_ids = agent_id.split("-->")
agent_name = self.get_component_name(agent_ids[0])
path = agent_name if len(agent_ids) < 2 else agent_name+"-->"+"-->".join(agent_ids[1:])
path = agent_name if len(agent_ids) < 2 else agent_name + "-->" + "-->".join(agent_ids[1:])
try:
bin = REDIS_CONN.get(f"{self.task_id}-{self.message_id}-logs")
if bin:
@@ -852,16 +896,10 @@ class Canvas(Graph):
if obj[-1]["component_id"] == agent_ids[0]:
obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time})
else:
obj.append({
"component_id": agent_ids[0],
"trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]
})
obj.append({"component_id": agent_ids[0], "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]})
else:
obj = [{
"component_id": agent_ids[0],
"trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]
}]
REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60*10)
obj = [{"component_id": agent_ids[0], "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]}]
REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60 * 10)
except Exception as e:
logging.exception(e)
@@ -899,9 +937,22 @@ class Canvas(Graph):
message_end["attachment"] = cpn_obj.output("attachment")
if self._has_reference():
message_end["reference"] = self.get_reference()
# NOTE: aggregated run token usage is intentionally NOT attached here.
# _build_message_end runs once per Message component, so a multi-Message graph
# would emit cumulative usage repeatedly and double count. The run total is
# emitted exactly once on the terminal workflow_finished event instead.
return message_end
def add_memory(self, user:str, assist:str, summ: str):
def _run_usage_payload(self) -> dict:
usage = getattr(self, "_run_token_usage", None) or {}
return {
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
"calls": usage.get("calls", 0),
}
def add_memory(self, user: str, assist: str, summ: str):
self.memory.append((user, assist, summ))
def get_memory(self) -> list[Tuple]:

View File

@@ -88,29 +88,32 @@ def _canvas_json_default(obj):
def _require_canvas_access_sync(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not UserCanvasService.accessible(kwargs.get('agent_id'), kwargs.get('tenant_id')):
if not UserCanvasService.accessible(kwargs.get("agent_id"), kwargs.get("tenant_id")):
return get_json_result(data=False, message="Make sure you have permission to access the agent.", code=RetCode.OPERATING_ERROR)
return func(*args, **kwargs)
return wrapper
def _require_canvas_access_async(func):
@wraps(func)
async def wrapper(*args, **kwargs):
agent_id = kwargs.get('agent_id')
tenant_id = kwargs.get('tenant_id')
agent_id = kwargs.get("agent_id")
tenant_id = kwargs.get("tenant_id")
if not await thread_pool_exec(UserCanvasService.accessible, agent_id, tenant_id):
return get_json_result(data=False, message="Make sure you have permission to access the agent.", code=RetCode.OPERATING_ERROR)
return await func(*args, **kwargs)
return wrapper
def _require_canvas_owner_sync(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not UserCanvasService.query(user_id=kwargs.get('tenant_id'), id=kwargs.get('agent_id')):
if not UserCanvasService.query(user_id=kwargs.get("tenant_id"), id=kwargs.get("agent_id")):
return get_json_result(data=False, message="Only the owner of the agent is authorized for this operation.", code=RetCode.OPERATING_ERROR)
return func(*args, **kwargs)
return wrapper
@@ -261,9 +264,7 @@ async def _run_workflow_session(
if "chunks" in workflow_conv["reference"]:
workflow_conv["reference"] = [workflow_conv["reference"]]
else:
workflow_conv["reference"] = [
value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))
]
workflow_conv["reference"] = [value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))]
elif not isinstance(workflow_conv.get("reference"), list):
workflow_conv["reference"] = []
workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]]
@@ -344,22 +345,16 @@ async def _run_workflow_session(
# bare [DONE] (fixes #15169).
logging.info(
"empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)",
agent_id, session_id, True,
)
yield (
"data:"
+ json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False)
+ "\n\n"
agent_id,
session_id,
True,
)
yield ("data:" + json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False) + "\n\n")
await persist_workflow_session()
except Exception as exc:
logging.exception(exc)
canvas.cancel_task()
yield (
"data:"
+ json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False)
+ "\n\n"
)
yield ("data:" + json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False) + "\n\n")
finally:
if not done_sent:
done_sent = True
@@ -400,7 +395,9 @@ async def _run_workflow_session(
# (fixes #15169).
logging.info(
"empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)",
agent_id, session_id, False,
agent_id,
session_id,
False,
)
await commit_runtime_replica()
return get_result(data={"session_id": session_id})
@@ -559,16 +556,13 @@ async def delete_agent_session(tenant_id, agent_id):
if errors:
if success_count > 0:
return get_result(data={"success_count": success_count, "errors": errors},
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
else:
return get_error_data_result(message="; ".join(errors))
if duplicate_messages:
if success_count > 0:
return get_result(
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
data={"success_count": success_count, "errors": duplicate_messages})
return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
else:
return get_error_data_result(message=";".join(duplicate_messages))
@@ -611,8 +605,8 @@ async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace
yield ans
continue
if event in ["message", "message_end", "user_inputs", "workflow_finished"]:
if event in ["user_inputs", "workflow_finished"]:
if event in ["message", "message_end", "user_inputs"]:
if event == "user_inputs":
logging.debug(
"Forwarding session completion event: tenant_id=%s agent_id=%s event=%s",
tenant_id,
@@ -620,6 +614,22 @@ async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace
event,
)
yield ans
continue
if event == "workflow_finished":
# Forward only the run-level aggregated token usage, not the whole terminal
# payload (inputs/outputs), so the session completion stream surface stays
# limited to what the usage contract needs.
logging.debug(
"Forwarding session completion event: tenant_id=%s agent_id=%s event=%s",
tenant_id,
agent_id,
event,
)
usage = ans.get("data", {}).get("usage")
if usage is not None:
yield {**ans, "data": {"usage": usage}}
continue
@manager.route("/agents/templates", methods=["GET"]) # noqa: F821
@@ -760,7 +770,7 @@ async def update_agent_tags(tenant_id, canvas_id):
@add_tenant_id_to_kwargs
async def create_agent(tenant_id):
req = {k: v for k, v in (await get_request_json()).items() if v is not None}
req["canvas_type"] = req.get("canvas_type","")
req["canvas_type"] = req.get("canvas_type", "")
req["user_id"] = tenant_id
req["canvas_category"] = req.get("canvas_category") or CanvasCategory.Agent
req["release"] = bool(req.get("release", ""))
@@ -837,13 +847,9 @@ async def upload_agent_file(agent_id, tenant_id):
)
try:
if len(file_objs) == 1:
uploaded = await thread_pool_exec(
FileService.upload_info, tenant_id, file_objs[0], request.args.get("url")
)
uploaded = await thread_pool_exec(FileService.upload_info, tenant_id, file_objs[0], request.args.get("url"))
return get_json_result(data=uploaded)
results = await asyncio.gather(
*(thread_pool_exec(FileService.upload_info, tenant_id, file_obj) for file_obj in file_objs)
)
results = await asyncio.gather(*(thread_pool_exec(FileService.upload_info, tenant_id, file_obj) for file_obj in file_objs))
return get_json_result(data=results)
except Exception as exc:
logging.exception(
@@ -1015,7 +1021,7 @@ def delete_agent(agent_id, tenant_id):
@_require_canvas_access_async
async def update_agent(agent_id, tenant_id):
req = {k: v for k, v in (await get_request_json()).items() if v is not None}
req["canvas_type"] = req.get("canvas_type","")
req["canvas_type"] = req.get("canvas_type", "")
req["release"] = bool(req.get("release", ""))
if req.get("dsl") is not None:
@@ -1038,10 +1044,7 @@ async def update_agent(agent_id, tenant_id):
return get_data_error_result(message=f"{req['title']} already exists.")
agent_title_for_version = req.get("title") or (current_agent.title if current_agent else "")
canvas_category = (
req.get("canvas_category")
or (current_agent.canvas_category if current_agent else CanvasCategory.Agent)
)
canvas_category = req.get("canvas_category") or (current_agent.canvas_category if current_agent else CanvasCategory.Agent)
owner_nickname = _get_user_nickname(tenant_id)
UserCanvasService.update_by_id(agent_id, req)
@@ -1153,13 +1156,19 @@ async def test_db_connection():
except ValueError as exc:
logging.warning(
"Rejected test_db_connection: unsafe host %r (db_type=%s, user=%s): %s",
req.get("host"), req.get("db_type"), current_user.id, exc,
req.get("host"),
req.get("db_type"),
current_user.id,
exc,
)
return get_data_error_result(message=str(exc))
except OSError as exc:
logging.warning(
"Rejected test_db_connection: cannot resolve host %r (db_type=%s, user=%s): %s",
req.get("host"), req.get("db_type"), current_user.id, exc,
req.get("host"),
req.get("db_type"),
current_user.id,
exc,
)
logging.debug("Full resolver exception for host %r", req.get("host"), exc_info=True)
return get_data_error_result(message=f"Could not resolve host {req.get('host')!r}.")
@@ -1198,13 +1207,7 @@ async def test_db_connection():
elif req["db_type"] == "mssql":
import pyodbc
connection_string = (
f"DRIVER={{ODBC Driver 17 for SQL Server}};"
f"SERVER={safe_host},{req['port']};"
f"DATABASE={req['database']};"
f"UID={req['username']};"
f"PWD={req['password']};"
)
connection_string = f"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={safe_host},{req['port']};DATABASE={req['database']};UID={req['username']};PWD={req['password']};"
db = pyodbc.connect(connection_string)
try:
cursor = db.cursor()
@@ -1217,14 +1220,7 @@ async def test_db_connection():
elif req["db_type"] == "IBM DB2":
import ibm_db
conn_str = (
f"DATABASE={req['database']};"
f"HOSTNAME={safe_host};"
f"PORT={req['port']};"
f"PROTOCOL=TCPIP;"
f"UID={req['username']};"
f"PWD={req['password']};"
)
conn_str = f"DATABASE={req['database']};HOSTNAME={safe_host};PORT={req['port']};PROTOCOL=TCPIP;UID={req['username']};PWD={req['password']};"
logging.info(
"DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=****;",
req["database"],
@@ -1387,9 +1383,7 @@ async def agent_chat_completion(tenant_id, agent_id=None):
if "chunks" in workflow_conv["reference"]:
workflow_conv["reference"] = [workflow_conv["reference"]]
else:
workflow_conv["reference"] = [
value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))
]
workflow_conv["reference"] = [value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))]
elif not isinstance(workflow_conv.get("reference"), list):
workflow_conv["reference"] = []
workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]]
@@ -1598,13 +1592,11 @@ async def agent_chat_completion(tenant_id, agent_id=None):
# seeing only a bare [DONE] (fixes #15169).
logging.info(
"empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)",
agent_id, session_id, True,
)
yield (
"data:"
+ json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False)
+ "\n\n"
agent_id,
session_id,
True,
)
yield ("data:" + json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False) + "\n\n")
yield "data:[DONE]\n\n"
return _build_sse_response(generate())
@@ -1614,6 +1606,7 @@ async def agent_chat_completion(tenant_id, agent_id=None):
final_ans = {}
trace_items = []
structured_output = {}
run_usage = None
async for ans in _iter_session_completion_events(tenant_id, agent_id, req, return_trace):
try:
if ans["event"] == "message":
@@ -1633,6 +1626,11 @@ async def agent_chat_completion(tenant_id, agent_id=None):
"trace": [copy.deepcopy(data)],
}
)
if ans.get("event") == "workflow_finished":
# Capture the run-level usage but keep message_end/user_inputs as
# final_ans so the non-stream response shape stays unchanged.
run_usage = ans.get("data", {}).get("usage")
continue
if ans.get("event") == "message_end":
final_ans = ans
elif ans.get("event") == "user_inputs" and not final_ans:
@@ -1647,7 +1645,9 @@ async def agent_chat_completion(tenant_id, agent_id=None):
# (fixes #15169).
logging.info(
"empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)",
agent_id, session_id, False,
agent_id,
session_id,
False,
)
return get_result(data={"session_id": session_id})
@@ -1655,6 +1655,8 @@ async def agent_chat_completion(tenant_id, agent_id=None):
final_ans["data"] = {}
final_ans["data"]["content"] = full_content
final_ans["data"]["reference"] = reference
if run_usage:
final_ans["data"]["usage"] = run_usage
if structured_output:
final_ans["data"]["structured"] = structured_output
if return_trace and final_ans:
@@ -1688,16 +1690,16 @@ async def _webhook_impl(agent_id: str, is_test: bool):
# 1. Fetch canvas by agent_id
exists, cvs = UserCanvasService.get_by_id(agent_id)
if not exists:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
return get_data_error_result(code=RetCode.BAD_REQUEST, message="Canvas not found."), RetCode.BAD_REQUEST
# 2. Check canvas category
if cvs.canvas_category == CanvasCategory.DataFlow:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
return get_data_error_result(code=RetCode.BAD_REQUEST, message="Dataflow can not be triggered by webhook."), RetCode.BAD_REQUEST
# 3. Load DSL from canvas
dsl = getattr(cvs, "dsl", None)
if not isinstance(dsl, dict):
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
return get_data_error_result(code=RetCode.BAD_REQUEST, message="Invalid DSL format."), RetCode.BAD_REQUEST
# 4. Check webhook configuration in DSL
webhook_cfg = {}
@@ -1708,15 +1710,13 @@ async def _webhook_impl(agent_id: str, is_test: bool):
webhook_cfg = cpn_obj["params"]
if not webhook_cfg:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
return get_data_error_result(code=RetCode.BAD_REQUEST, message="Webhook not configured for this agent."), RetCode.BAD_REQUEST
# 5. Validate request method against webhook_cfg.methods
allowed_methods = webhook_cfg.get("methods", [])
request_method = request.method.upper()
if allowed_methods and request_method not in allowed_methods:
return get_data_error_result(
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
),RetCode.BAD_REQUEST
return get_data_error_result(code=RetCode.BAD_REQUEST, message=f"HTTP method '{request_method}' not allowed for this webhook."), RetCode.BAD_REQUEST
async def validate_webhook_security(security_cfg: dict):
"""Validate webhook security rules based on security configuration."""
@@ -1795,7 +1795,6 @@ async def _webhook_impl(agent_id: str, is_test: bool):
client_ip = request.remote_addr
for rule in whitelist:
if "/" in rule:
# CIDR notation
@@ -1854,7 +1853,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
def _validate_token_auth(security_cfg):
"""Validate header-based token authentication."""
token_cfg = security_cfg.get("token",{})
token_cfg = security_cfg.get("token", {})
header = token_cfg.get("token_header")
token_value = token_cfg.get("token_value")
@@ -1883,7 +1882,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
if not auth_header.startswith("Bearer "):
raise Exception("Missing Bearer token")
token = auth_header[len("Bearer "):].strip()
token = auth_header[len("Bearer ") :].strip()
if not token:
raise Exception("Empty Bearer token")
@@ -1922,10 +1921,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
else:
required_claims = []
required_claims = [
c for c in required_claims
if isinstance(c, str) and c.strip()
]
required_claims = [c for c in required_claims if isinstance(c, str) and c.strip()]
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
for claim in required_claims:
@@ -1939,10 +1935,10 @@ async def _webhook_impl(agent_id: str, is_test: bool):
return decoded
try:
security_config=webhook_cfg.get("security", {})
security_config = webhook_cfg.get("security", {})
await validate_webhook_security(security_config)
except Exception as e:
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST
if not isinstance(cvs.dsl, str):
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
try:
@@ -1950,7 +1946,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id)
except Exception as e:
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
resp = get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e))
resp.status_code = RetCode.BAD_REQUEST
return resp
@@ -1967,9 +1963,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
# 3. Body
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
if ctype and ctype != content_type:
raise ValueError(
f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
)
raise ValueError(f"Invalid Content-Type: expect '{content_type}', got '{ctype}'")
body_data: dict = {}
@@ -1991,11 +1985,11 @@ async def _webhook_impl(agent_id: str, is_test: bool):
raise Exception("Too many uploaded files")
for key, file in files.items():
desc = FileService.upload_info(
cvs.user_id, # user
file, # FileStorage
None # url (None for webhook)
cvs.user_id, # user
file, # FileStorage
None, # url (None for webhook)
)
file_parsed= await canvas.get_files_async([desc])
file_parsed = await canvas.get_files_async([desc])
body_data[key] = file_parsed
elif ctype == "application/x-www-form-urlencoded":
@@ -2057,15 +2051,12 @@ async def _webhook_impl(agent_id: str, is_test: bool):
# 4. Type validation
if not validate_type(value, field_type):
raise Exception(
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
)
raise Exception(f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}")
extracted[field] = value
return extracted
def default_for_type(t):
"""Return default value for the given schema type."""
if t == "file":
@@ -2145,7 +2136,6 @@ async def _webhook_impl(agent_id: str, is_test: bool):
# Default: do nothing
return value
def validate_type(value, t):
"""Validate value type against schema type t."""
if t == "file":
@@ -2179,28 +2169,24 @@ async def _webhook_impl(agent_id: str, is_test: bool):
return True
return True
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
# Extract strictly by schema
try:
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
except Exception as e:
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST
clean_request = {
"query": query_clean,
"headers": header_clean,
"body": body_clean,
"input": parsed
}
clean_request = {"query": query_clean, "headers": header_clean, "body": body_clean, "input": parsed}
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
response_cfg = webhook_cfg.get("response", {})
def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600):
def append_webhook_trace(agent_id: str, start_ts: float, event: dict, ttl=600):
from rag.utils.redis_conn import REDIS_CONN
key = f"webhook-trace-{agent_id}-logs"
@@ -2208,15 +2194,9 @@ async def _webhook_impl(agent_id: str, is_test: bool):
raw = REDIS_CONN.get(key)
obj = json.loads(raw) if raw else {"webhooks": {}}
ws = obj["webhooks"].setdefault(
str(start_ts),
{"start_ts": start_ts, "events": []}
)
ws = obj["webhooks"].setdefault(str(start_ts), {"start_ts": start_ts, "events": []})
ws["events"].append({
"ts": time.time(),
**event
})
ws["events"].append({"ts": time.time(), **event})
REDIS_CONN.set_obj(key, obj, ttl)
@@ -2225,10 +2205,10 @@ async def _webhook_impl(agent_id: str, is_test: bool):
try:
status = int(status)
except (TypeError, ValueError):
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(f"Invalid response status code: {status}")), RetCode.BAD_REQUEST
if not (200 <= status <= 399):
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(f"Invalid response status code: {status}, must be between 200 and 399")), RetCode.BAD_REQUEST
body_tpl = response_cfg.get("body_template", "")
@@ -2242,7 +2222,6 @@ async def _webhook_impl(agent_id: str, is_test: bool):
except (json.JSONDecodeError, TypeError):
return body, "text/plain"
body, content_type = parse_body(body_tpl)
resp = Response(
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
@@ -2252,11 +2231,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
async def background_run():
try:
async for ans in canvas.run(
query="",
user_id=cvs.user_id,
webhook_payload=clean_request
):
async for ans in canvas.run(query="", user_id=cvs.user_id, webhook_payload=clean_request):
if is_test:
append_webhook_trace(agent_id, start_ts, ans)
@@ -2268,7 +2243,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": True,
}
},
)
cvs.dsl = json.loads(str(canvas))
@@ -2285,7 +2260,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
"event": "error",
"message": str(e),
"error_type": type(e).__name__,
}
},
)
append_webhook_trace(
agent_id,
@@ -2294,7 +2269,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": False,
}
},
)
except Exception:
logging.exception("Failed to append webhook trace")
@@ -2305,6 +2280,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
task.add_done_callback(_background_tasks.discard)
return resp
else:
async def sse():
nonlocal canvas
contents: list[str] = []
@@ -2326,11 +2302,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
if ans["event"] == "message_end":
status = int(ans["data"].get("status", status))
if is_test:
append_webhook_trace(
agent_id,
start_ts,
ans
)
append_webhook_trace(agent_id, start_ts, ans)
if is_test:
append_webhook_trace(
agent_id,
@@ -2339,13 +2311,13 @@ async def _webhook_impl(agent_id: str, is_test: bool):
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": True,
}
},
)
final_content = "".join(contents)
return {
"message": final_content,
"success": True,
"code": status,
"code": status,
}
except Exception as e:
@@ -2357,7 +2329,7 @@ async def _webhook_impl(agent_id: str, is_test: bool):
"event": "error",
"message": str(e),
"error_type": type(e).__name__,
}
},
)
append_webhook_trace(
agent_id,
@@ -2366,9 +2338,9 @@ async def _webhook_impl(agent_id: str, is_test: bool):
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": False,
}
},
)
return {"code": 400, "message": str(e),"success":False}
return {"code": 400, "message": str(e), "success": False}
result = await sse()
return Response(
@@ -2401,6 +2373,7 @@ async def webhook_trace(agent_id: str):
if encode_webhook_id(ts) == enc_id:
return ts
return None
since_ts = request.args.get("since_ts", type=float)
webhook_id = request.args.get("webhook_id")
@@ -2434,9 +2407,7 @@ async def webhook_trace(agent_id: str):
webhooks = obj.get("webhooks", {})
if webhook_id is None:
candidates = [
float(k) for k in webhooks.keys() if float(k) > since_ts
]
candidates = [float(k) for k in webhooks.keys() if float(k) > since_ts]
if not candidates:
return get_json_result(
@@ -2492,6 +2463,7 @@ async def webhook_trace(agent_id: str):
}
)
@manager.route("/agents/attachments/<attachment_id>/download", methods=["GET"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs

View File

@@ -34,10 +34,12 @@ from peewee import fn
class CanvasTemplateService(CommonService):
model = CanvasTemplate
class DataFlowTemplateService(CommonService):
"""
Alias of CanvasTemplateService
"""
model = CanvasTemplate
@@ -46,8 +48,7 @@ class UserCanvasService(CommonService):
@classmethod
@DB.connection_context()
def get_list(cls, tenant_id,
page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent):
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent):
agents = cls.model.select()
if id:
agents = agents.where(cls.model.id == id)
@@ -68,20 +69,9 @@ class UserCanvasService(CommonService):
@DB.connection_context()
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
# will get all permitted agents, be cautious
fields = [
cls.model.id,
cls.model.avatar,
cls.model.title,
cls.model.permission,
cls.model.canvas_type,
cls.model.canvas_category
]
fields = [cls.model.id, cls.model.avatar, cls.model.title, cls.model.permission, cls.model.canvas_type, cls.model.canvas_category]
# find team agents and owned agents
agents = cls.model.select(*fields).where(
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
cls.model.user_id == user_id
)
)
agents = cls.model.select(*fields).where((cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
# sort by create_time, asc
agents.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
@@ -100,7 +90,6 @@ class UserCanvasService(CommonService):
@DB.connection_context()
def get_by_canvas_id(cls, pid):
try:
fields = [
cls.model.id,
cls.model.avatar,
@@ -115,11 +104,9 @@ class UserCanvasService(CommonService):
cls.model.update_date,
cls.model.canvas_category,
User.nickname,
User.avatar.alias('tenant_avatar'),
User.avatar.alias("tenant_avatar"),
]
agents = cls.model.select(*fields) \
.join(User, on=(cls.model.user_id == User.id)) \
.where(cls.model.id == pid)
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(cls.model.id == pid)
# obj = cls.model.query(id=pid)[0]
return True, agents.dicts()[0]
except Exception as e:
@@ -129,14 +116,7 @@ class UserCanvasService(CommonService):
@classmethod
@DB.connection_context()
def get_basic_info_by_canvas_ids(cls, canvas_id):
fields = [
cls.model.id,
cls.model.avatar,
cls.model.user_id,
cls.model.title,
cls.model.permission,
cls.model.canvas_category
]
fields = [cls.model.id, cls.model.avatar, cls.model.user_id, cls.model.title, cls.model.permission, cls.model.canvas_category]
return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts()
@classmethod
@@ -162,20 +142,26 @@ class UserCanvasService(CommonService):
cls.model.permission,
cls.model.user_id.alias("tenant_id"),
User.nickname,
User.avatar.alias('tenant_avatar'),
User.avatar.alias("tenant_avatar"),
cls.model.update_time,
cls.model.canvas_type,
cls.model.canvas_category,
cls.model.tags,
]
if keywords:
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
(fn.LOWER(cls.model.title).contains(keywords.lower()))
agents = (
cls.model.select(*fields)
.join(User, on=(cls.model.user_id == User.id))
.where(
(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
(fn.LOWER(cls.model.title).contains(keywords.lower())),
)
)
else:
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
agents = (
cls.model.select(*fields)
.join(User, on=(cls.model.user_id == User.id))
.where((((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)))
)
if canvas_category:
agents = agents.where(cls.model.canvas_category == canvas_category)
@@ -201,7 +187,7 @@ class UserCanvasService(CommonService):
# Get latest release time for each canvas
if agents_list:
canvas_ids = [a['id'] for a in agents_list]
canvas_ids = [a["id"] for a in agents_list]
release_times = (
UserCanvasVersion.select(UserCanvasVersion.user_canvas_id, fn.MAX(UserCanvasVersion.create_time).alias("release_time"))
.where((UserCanvasVersion.user_canvas_id.in_(canvas_ids)) & (UserCanvasVersion.release))
@@ -210,7 +196,7 @@ class UserCanvasService(CommonService):
release_time_map = {r.user_canvas_id: r.release_time for r in release_times}
for agent in agents_list:
agent['release_time'] = release_time_map.get(agent['id'])
agent["release_time"] = release_time_map.get(agent["id"])
return agents_list, count
@@ -218,9 +204,7 @@ class UserCanvasService(CommonService):
@DB.connection_context()
def list_tags(cls, joined_tenant_ids, user_id, canvas_category=None):
"""Return {tag: agent_count} aggregated across agents visible to the user."""
query = cls.model.select(cls.model.tags).where(
((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)
)
query = cls.model.select(cls.model.tags).where(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
if canvas_category:
query = query.where(cls.model.canvas_category == canvas_category)
@@ -281,6 +265,7 @@ class UserCanvasService(CommonService):
@DB.connection_context()
def accessible(cls, canvas_id, tenant_id):
from api.db.services.user_service import UserTenantService
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
if not e:
return False
@@ -345,18 +330,15 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
conv = API4Conversation(**conv)
message_id = str(uuid4())
conv.message.append({
"role": "user",
"content": query,
"id": message_id,
"files": files
})
conv.message.append({"role": "user", "content": query, "id": message_id, "files": files})
txt = ""
run_kwargs = {
"query": query,
"files": files,
"user_id": user_id,
"inputs": inputs,
# Used by Canvas.run to correlate RAGFlow's Langfuse generations by session.
"session_id": session_id,
}
if chat_template_kwargs is not None:
run_kwargs["chat_template_kwargs"] = chat_template_kwargs
@@ -394,14 +376,7 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre
if stream:
completion_tokens = 0
try:
async for ans in completion(
tenant_id=tenant_id,
agent_id=agent_id,
session_id=session_id,
query=question,
user_id=user_id,
**kwargs
):
async for ans in completion(tenant_id=tenant_id, agent_id=agent_id, session_id=session_id, query=question, user_id=user_id, **kwargs):
if isinstance(ans, str):
try:
ans = json.loads(ans[5:]) # remove "data:"
@@ -417,14 +392,7 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre
completion_tokens += len(tiktoken_encoder.encode(content_piece))
openai_data = get_data_openai(
id=session_id or str(uuid4()),
model=agent_id,
content=content_piece,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
stream=True
)
openai_data = get_data_openai(id=session_id or str(uuid4()), model=agent_id, content=content_piece, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, stream=True)
if ans.get("data", {}).get("reference", None):
openai_data["choices"][0]["delta"]["reference"] = ans["data"]["reference"]
@@ -435,32 +403,29 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre
except Exception as e:
logging.exception(e)
yield "data: " + json.dumps(
get_data_openai(
id=session_id or str(uuid4()),
model=agent_id,
content=f"**ERROR**: {str(e)}",
finish_reason="stop",
prompt_tokens=prompt_tokens,
completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
stream=True
),
ensure_ascii=False
) + "\n\n"
yield (
"data: "
+ json.dumps(
get_data_openai(
id=session_id or str(uuid4()),
model=agent_id,
content=f"**ERROR**: {str(e)}",
finish_reason="stop",
prompt_tokens=prompt_tokens,
completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
stream=True,
),
ensure_ascii=False,
)
+ "\n\n"
)
yield "data: [DONE]\n\n"
else:
try:
all_content = ""
reference = {}
async for ans in completion(
tenant_id=tenant_id,
agent_id=agent_id,
session_id=session_id,
query=question,
user_id=user_id,
**kwargs
):
async for ans in completion(tenant_id=tenant_id, agent_id=agent_id, session_id=session_id, query=question, user_id=user_id, **kwargs):
if isinstance(ans, str):
ans = json.loads(ans[5:])
if ans.get("event") not in ["message", "message_end"]:
@@ -475,13 +440,7 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre
completion_tokens = len(tiktoken_encoder.encode(all_content))
openai_data = get_data_openai(
id=session_id or str(uuid4()),
model=agent_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
content=all_content,
finish_reason="stop",
param=None
id=session_id or str(uuid4()), model=agent_id, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, content=all_content, finish_reason="stop", param=None
)
if reference:
@@ -497,5 +456,5 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre
completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
content=f"**ERROR**: {str(e)}",
finish_reason="stop",
param=None
param=None,
)

View File

@@ -27,7 +27,7 @@ from langfuse import propagate_attributes
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant
from common.token_utils import num_tokens_from_string
from common.token_utils import num_tokens_from_string, record_run_token_usage, langfuse_run_attrs
class LLMService(CommonService):
@@ -39,11 +39,49 @@ class LLMBundle(LLM4Tenant):
super().__init__(tenant_id, model_config, lang, **kwargs)
def _start_langfuse_observation(self, **kwargs):
# Correlating attributes (session_id/user_id) let Langfuse group all of a
# turn's generations. They may come from this bundle (chat/dialog path) or,
# for agent runs whose bundles are created without them, from the per-run
# context installed by Canvas.run.
attrs = {}
if self.langfuse_session_id:
with propagate_attributes(session_id=self.langfuse_session_id):
attrs["session_id"] = self.langfuse_session_id
run_attrs = langfuse_run_attrs.get()
if run_attrs:
for k in ("session_id", "user_id"):
if run_attrs.get(k) and k not in attrs:
attrs[k] = run_attrs[k]
if attrs:
with propagate_attributes(**attrs):
return self.langfuse.start_observation(**kwargs)
return self.langfuse.start_observation(**kwargs)
def _reset_last_usage(self) -> None:
"""Clear the model's per-call usage so a failed call that returns before
updating it cannot leak the previous call's usage into this run."""
if hasattr(self.mdl, "last_usage"):
self.mdl.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
def _report_usage(self, total_tokens: int) -> dict:
"""Record a chat call's usage to the active agent run and return the
prompt/completion/total split for Langfuse.
``total_tokens`` is the authoritative total from the call. The prompt/completion
split is taken from the provider response (``mdl.last_usage``) only when it is
consistent with ``total_tokens`` (i.e. produced by this same call); otherwise the
split is reported as 0 while the total still aggregates correctly.
"""
split = getattr(self.mdl, "last_usage", None) or {}
prompt = int(split.get("prompt_tokens", 0) or 0)
completion = int(split.get("completion_tokens", 0) or 0)
if not total_tokens:
total_tokens = int(split.get("total_tokens", 0) or 0)
if (prompt + completion) != total_tokens:
# Stale or inconsistent split — keep the total, drop the unreliable split.
prompt, completion = 0, 0
record_run_token_usage(prompt, completion, total_tokens)
return {"input": prompt, "output": completion, "total": total_tokens}
def close(self):
"""Release resources held by this LLMBundle instance."""
super().close()
@@ -139,7 +177,9 @@ class LLMBundle(LLM4Tenant):
def similarity(self, query: str, texts: list):
if self.langfuse:
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
generation = self._start_langfuse_observation(
trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts}
)
sim, used_tokens = self.mdl.similarity(query, texts)
logging.info("LLMBundle.similarity used_tokens: %d", used_tokens)
@@ -165,7 +205,9 @@ class LLMBundle(LLM4Tenant):
def describe_with_prompt(self, image, prompt):
if self.langfuse:
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
generation = self._start_langfuse_observation(
trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt}
)
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
logging.info("LLMBundle.describe_with_prompt used_tokens: %d", used_tokens)
@@ -194,7 +236,8 @@ class LLMBundle(LLM4Tenant):
supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription"))
if supports_stream:
if self.langfuse:
generation = self._start_langfuse_observation(as_type="generation",
generation = self._start_langfuse_observation(
as_type="generation",
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.model_config["llm_name"]},
@@ -228,7 +271,8 @@ class LLMBundle(LLM4Tenant):
return
if self.langfuse:
generation = self._start_langfuse_observation(as_type="generation",
generation = self._start_langfuse_observation(
as_type="generation",
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.model_config["llm_name"]},
@@ -377,11 +421,14 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
generation = self._start_langfuse_observation(
trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history}
)
chat_partial = partial(base_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
self._reset_last_usage()
try:
txt, used_tokens = await chat_partial(**use_kwargs)
except Exception as e:
@@ -397,8 +444,10 @@ class LLMBundle(LLM4Tenant):
if used_tokens:
logging.info("LLMBundle.async_chat used_tokens: %d", used_tokens)
usage_details = self._report_usage(used_tokens)
if generation:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.update(output={"output": txt}, usage_details=usage_details)
generation.end()
return txt
@@ -418,11 +467,14 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
generation = self._start_langfuse_observation(
trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}
)
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
self._reset_last_usage()
try:
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
@@ -444,8 +496,9 @@ class LLMBundle(LLM4Tenant):
raise
if total_tokens:
logging.info("LLMBundle.async_chat_streamly used_tokens: %d", total_tokens)
usage_details = self._report_usage(total_tokens)
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.update(output={"output": ans}, usage_details=usage_details)
generation.end()
return
@@ -461,11 +514,14 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
generation = self._start_langfuse_observation(
trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}
)
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
self._reset_last_usage()
try:
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
@@ -487,7 +543,8 @@ class LLMBundle(LLM4Tenant):
raise
if total_tokens:
logging.info("LLMBundle.async_chat_streamly_delta used_tokens: %d", total_tokens)
usage_details = self._report_usage(total_tokens)
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.update(output={"output": ans}, usage_details=usage_details)
generation.end()
return

View File

@@ -14,9 +14,12 @@
# limitations under the License.
#
import contextvars
import hashlib
import logging
import os
import shutil
import threading
import tiktoken
from common.file_utils import get_project_base_directory
@@ -42,6 +45,84 @@ os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
encoder = tiktoken.get_encoding("cl100k_base")
# Per-run token usage sink. An agent run (Canvas.run) installs a mutable dict here
# at the start of each turn; every LLMBundle chat call adds its provider-reported
# usage to it. This is the single chokepoint that aggregates token usage across all
# LLM calls in a run (query rewriting, cross-language translation, tool reasoning,
# and the final streamed answer) regardless of which component or helper issued the
# call. Default None means "not inside a tracked run" and callers must no-op.
token_usage_sink: contextvars.ContextVar = contextvars.ContextVar("ragflow_token_usage_sink", default=None)
# Per-run Langfuse correlating attributes (e.g. {"session_id": ..., "user_id": ...}).
# Installed by Canvas.run so RAGFlow's own Langfuse generations can be grouped by
# session and user even though the agent's LLMBundles are created without them.
langfuse_run_attrs: contextvars.ContextVar = contextvars.ContextVar("ragflow_langfuse_run_attrs", default=None)
# Guards sink mutations: concurrent tool calls (asyncio.gather + thread_pool_exec,
# which copies the context so worker threads share the same sink dict) can otherwise
# race on the read-modify-write of the counters.
_sink_lock = threading.Lock()
def record_run_token_usage(prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = 0) -> None:
"""Add a single LLM call's token usage to the active run sink, if any.
Safe to call from anywhere: when no run sink is installed it does nothing.
"""
sink = token_usage_sink.get()
if sink is None:
return
try:
with _sink_lock:
sink["prompt_tokens"] += int(prompt_tokens or 0)
sink["completion_tokens"] += int(completion_tokens or 0)
sink["total_tokens"] += int(total_tokens or 0)
sink["calls"] += 1
except Exception:
# Never let usage bookkeeping break a request; log at debug so a malformed
# sink or token value is still traceable without adding noise.
logging.debug("Failed to record run token usage", exc_info=True)
def usage_from_response(resp) -> dict:
"""Extract a {prompt_tokens, completion_tokens, total_tokens} split from an LLM response.
Handles OpenAI/OpenRouter-style ``resp.usage`` objects and dict variants. Missing
fields default to 0; ``total_tokens`` falls back to prompt+completion when absent.
"""
out = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
if resp is None:
return out
usage = None
try:
usage = getattr(resp, "usage", None)
if usage is None and isinstance(resp, dict):
usage = resp.get("usage")
except Exception:
usage = None
if usage is None:
return out
def _get(obj, *names):
for n in names:
try:
v = obj.get(n) if isinstance(obj, dict) else getattr(obj, n, None)
except Exception:
v = None
if v:
return int(v)
return 0
out["prompt_tokens"] = _get(usage, "prompt_tokens", "input_tokens")
out["completion_tokens"] = _get(usage, "completion_tokens", "output_tokens")
out["total_tokens"] = _get(usage, "total_tokens")
if not out["total_tokens"]:
out["total_tokens"] = out["prompt_tokens"] + out["completion_tokens"]
return out
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
try:
@@ -50,6 +131,7 @@ def num_tokens_from_string(string: str) -> int:
except Exception:
return 0
def total_token_count_from_response(resp):
"""
Extract token count from LLM response in various formats.
@@ -78,19 +160,19 @@ def total_token_count_from_response(resp):
except Exception:
pass
if isinstance(resp, dict) and 'usage' in resp and 'total_tokens' in resp['usage']:
if isinstance(resp, dict) and "usage" in resp and "total_tokens" in resp["usage"]:
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
if isinstance(resp, dict) and 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']:
if isinstance(resp, dict) and "usage" in resp and "input_tokens" in resp["usage"] and "output_tokens" in resp["usage"]:
try:
return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]
except Exception:
pass
if isinstance(resp, dict) and 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']:
if isinstance(resp, dict) and "meta" in resp and "tokens" in resp["meta"] and "input_tokens" in resp["meta"]["tokens"] and "output_tokens" in resp["meta"]["tokens"]:
try:
return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]
except Exception:

View File

@@ -32,7 +32,7 @@ from openai import AsyncOpenAI, OpenAI
from enum import StrEnum
from common.misc_utils import thread_pool_exec
from common.token_utils import num_tokens_from_string, total_token_count_from_response
from common.token_utils import num_tokens_from_string, total_token_count_from_response, usage_from_response
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.llm.key_utils import _normalize_replicate_key
from rag.llm.tool_decorator import FunctionToolSession, is_tool
@@ -184,11 +184,7 @@ def _apply_model_family_policies(
sanitized_gen_conf["n"] = 1
sanitized_gen_conf["presence_penalty"] = 0.0
sanitized_gen_conf["frequency_penalty"] = 0.0
elif (
provider == SupportedLiteLLMProvider.ZHIPU_AI
and "glm" in model_name_lower
and thinking_type
):
elif provider == SupportedLiteLLMProvider.ZHIPU_AI and "glm" in model_name_lower and thinking_type:
_pop_thinking_controls()
sanitized_gen_conf["thinking"] = {"type": thinking_type}
@@ -217,6 +213,7 @@ def _move_litellm_provider_body_fields(provider: SupportedLiteLLMProvider | str
completion_args["extra_body"] = body
return completion_args
class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
@@ -230,6 +227,9 @@ class Base(ABC):
self.is_tools = False
self.tools = []
self.toolcall_sessions = {}
# Token usage split (prompt/completion/total) of the most recent chat call.
# Consumed by LLMBundle for accurate Langfuse reporting and run aggregation.
self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
def _get_delay(self):
return self.base_delay * random.uniform(10, 150)
@@ -313,6 +313,8 @@ class Base(ABC):
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
# Reset so a stale split from a previous call can't leak into this one.
self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
for attempt in range(self.max_retries + 1):
try:
@@ -478,6 +480,18 @@ class Base(ABC):
ans = ""
tk_count = 0
# Aggregate prompt/completion/total across all tool-calling rounds.
agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
def _add_round_usage(resp):
nonlocal tk_count
u = usage_from_response(resp)
agg_usage["prompt_tokens"] += u["prompt_tokens"]
agg_usage["completion_tokens"] += u["completion_tokens"]
agg_usage["total_tokens"] += u["total_tokens"] or total_token_count_from_response(resp)
tk_count = agg_usage["total_tokens"]
self.last_usage = dict(agg_usage)
hist = deepcopy(history)
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
@@ -485,7 +499,7 @@ class Base(ABC):
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf, **extra_request_kwargs)
tk_count += total_token_count_from_response(response)
_add_round_usage(response)
if not response.choices or not response.choices[0].message:
raise Exception(f"500 response structure error. Response: {response}")
@@ -505,9 +519,7 @@ class Base(ABC):
try:
args = json_repair.loads(tc.function.arguments)
if not isinstance(args, dict):
raise TypeError(
f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}"
)
raise TypeError(f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}")
if hasattr(self.toolcall_session, "tool_call_async"):
result = await self.toolcall_session.tool_call_async(name, args)
else:
@@ -527,7 +539,13 @@ class Base(ABC):
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = await self._async_chat(history, gen_conf)
ans += response
tk_count += token_count
# _async_chat set self.last_usage to its own call; fold it into the aggregate.
_fb = getattr(self, "last_usage", None) or {}
agg_usage["prompt_tokens"] += int(_fb.get("prompt_tokens", 0) or 0)
agg_usage["completion_tokens"] += int(_fb.get("completion_tokens", 0) or 0)
agg_usage["total_tokens"] += int(_fb.get("total_tokens", 0) or token_count)
tk_count = agg_usage["total_tokens"]
self.last_usage = dict(agg_usage)
return ans, tk_count
except Exception as e:
e = await self._exceptions_async(e, attempt)
@@ -550,8 +568,23 @@ class Base(ABC):
history.insert(0, {"role": "system", "content": system})
total_tokens = 0
# Aggregate prompt/completion/total across all tool-calling rounds. The split is
# captured opportunistically when the provider reports usage on a chunk; otherwise
# only the (estimated) total accumulates.
agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
hist = deepcopy(history)
def _commit_round(round_usage, round_estimate):
nonlocal total_tokens
if round_usage and round_usage["total_tokens"]:
agg_usage["prompt_tokens"] += round_usage["prompt_tokens"]
agg_usage["completion_tokens"] += round_usage["completion_tokens"]
agg_usage["total_tokens"] += round_usage["total_tokens"]
else:
agg_usage["total_tokens"] += round_estimate
total_tokens = agg_usage["total_tokens"]
self.last_usage = dict(agg_usage)
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
try:
@@ -559,12 +592,20 @@ class Base(ABC):
reasoning_start = False
logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in tools]}")
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf, **extra_request_kwargs)
response = await self.async_client.chat.completions.create(
model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf, **extra_request_kwargs
)
final_tool_calls = {}
answer = ""
round_estimate = 0
round_usage = None
async for resp in response:
_u = usage_from_response(resp)
if _u["total_tokens"]:
round_usage = _u
if not hasattr(resp, "choices") or not resp.choices:
continue
@@ -597,16 +638,17 @@ class Base(ABC):
answer += delta.content
yield delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
if not _u["total_tokens"]:
round_estimate += num_tokens_from_string(delta.content)
finish_reason = getattr(resp.choices[0], "finish_reason", "")
if finish_reason == "length":
yield self._length_stop("")
# Commit this round's tokens (each round is a separate provider
# request — accumulate, never overwrite).
_commit_round(round_usage, round_estimate)
if answer and not final_tool_calls:
logging.info(f"[ToolLoop] round={_round} completed with text response, exiting")
yield total_tokens
@@ -617,9 +659,7 @@ class Base(ABC):
try:
args = json_repair.loads(tc.function.arguments)
if not isinstance(args, dict):
raise TypeError(
f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}"
)
raise TypeError(f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}")
if hasattr(self.toolcall_session, "tool_call_async"):
result = await self.toolcall_session.tool_call_async(name, args)
else:
@@ -655,19 +695,22 @@ class Base(ABC):
**extra_request_kwargs,
)
fb_estimate = 0
fb_usage = None
async for resp in response:
_u = usage_from_response(resp)
if _u["total_tokens"]:
fb_usage = _u
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
if not _u["total_tokens"]:
fb_estimate += num_tokens_from_string(delta.content)
yield delta.content
_commit_round(fb_usage, fb_estimate)
yield total_tokens
return
@@ -708,6 +751,8 @@ class Base(ABC):
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
# Capture prompt/completion split for accurate Langfuse + run aggregation.
self.last_usage = usage_from_response(response)
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
@@ -1075,9 +1120,7 @@ class ReplicateChat(Base):
msgs = [{"role": "system", "content": system}]
system_msg = msgs[0]["content"] if msgs and msgs[0].get("role") == "system" else ""
prompt = "\n".join(
[item["role"] + ":" + item["content"] for item in msgs[-5:] if item.get("role") != "system"]
)
prompt = "\n".join([item["role"] + ":" + item["content"] for item in msgs[-5:] if item.get("role") != "system"])
try:
response = self.client.run(
self.model_name,
@@ -1550,6 +1593,9 @@ class LiteLLMBase(ABC):
self.is_tools = False
self.tools = []
self.toolcall_sessions = {}
# Token usage split (prompt/completion/total) of the most recent chat call.
# Consumed by LLMBundle for accurate Langfuse reporting and run aggregation.
self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
# Factory specific fields
if self.provider == SupportedLiteLLMProvider.OpenRouter:
@@ -1638,6 +1684,9 @@ class LiteLLMBase(ABC):
timeout=self.timeout,
)
# Capture the prompt/completion split for accurate per-call usage
# reporting (Langfuse + agent run aggregation).
self.last_usage = usage_from_response(response)
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
@@ -1659,11 +1708,16 @@ class LiteLLMBase(ABC):
gen_conf = self._clean_conf(gen_conf)
reasoning_start = False
total_tokens = 0
# Reset so a stale split from a previous call can't leak into this one.
self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop")
if stop:
completion_args["stop"] = stop
# Ask the provider to include authoritative usage in the final streaming chunk.
# drop_params=True ensures this is silently ignored by providers that don't support it.
completion_args.setdefault("stream_options", {})["include_usage"] = True
for attempt in range(self.max_retries + 1):
try:
@@ -1674,6 +1728,14 @@ class LiteLLMBase(ABC):
)
async for resp in stream:
# Authoritative usage may arrive on a usage-only final chunk that
# carries no choices (OpenAI/OpenRouter with include_usage). Read it
# before the choices guard so the prompt/completion split is captured.
_usage = usage_from_response(resp)
if _usage["total_tokens"]:
total_tokens = _usage["total_tokens"]
self.last_usage = _usage
if not hasattr(resp, "choices") or not resp.choices:
continue
@@ -1692,10 +1754,9 @@ class LiteLLMBase(ABC):
reasoning_start = False
ans = delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(delta.content)
total_tokens += tol
if not _usage["total_tokens"]:
# No authoritative usage yet: keep a running estimate as fallback.
total_tokens += num_tokens_from_string(delta.content)
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
@@ -1745,11 +1806,15 @@ class LiteLLMBase(ABC):
return msg
def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps(
{"name": name, "args": args, "result": str(res) if isinstance(res, Exception) else res},
ensure_ascii=False,
indent=2,
) + "</tool_call>"
return (
"<tool_call>"
+ json.dumps(
{"name": name, "args": args, "result": str(res) if isinstance(res, Exception) else res},
ensure_ascii=False,
indent=2,
)
+ "</tool_call>"
)
def _append_history(self, hist, tool_call, tool_res, reasoning_content=None):
assistant_msg = {
@@ -1844,6 +1909,19 @@ class LiteLLMBase(ABC):
ans = ""
tk_count = 0
# Aggregate prompt/completion/total across every tool-calling round so the
# whole multi-round exchange is reported once with a correct split.
agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
def _add_usage(resp):
nonlocal tk_count
u = usage_from_response(resp)
agg_usage["prompt_tokens"] += u["prompt_tokens"]
agg_usage["completion_tokens"] += u["completion_tokens"]
agg_usage["total_tokens"] += u["total_tokens"] or total_token_count_from_response(resp)
tk_count = agg_usage["total_tokens"]
self.last_usage = dict(agg_usage)
hist = deepcopy(history)
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
@@ -1858,7 +1936,7 @@ class LiteLLMBase(ABC):
timeout=self.timeout,
)
tk_count += total_token_count_from_response(response)
_add_usage(response)
if not hasattr(response, "choices") or not response.choices or not response.choices[0].message:
raise Exception(f"500 response structure error. Response: {response}")
@@ -1906,7 +1984,13 @@ class LiteLLMBase(ABC):
response, token_count = await self.async_chat("", history, gen_conf)
ans += response
tk_count += token_count
# self.async_chat set self.last_usage to its own call; fold it into the aggregate.
_fb = getattr(self, "last_usage", None) or {}
agg_usage["prompt_tokens"] += int(_fb.get("prompt_tokens", 0) or 0)
agg_usage["completion_tokens"] += int(_fb.get("completion_tokens", 0) or 0)
agg_usage["total_tokens"] += int(_fb.get("total_tokens", 0) or token_count)
tk_count = agg_usage["total_tokens"]
self.last_usage = dict(agg_usage)
return ans, tk_count
except Exception as e:
@@ -1924,8 +2008,23 @@ class LiteLLMBase(ABC):
history.insert(0, {"role": "system", "content": system})
total_tokens = 0
# Aggregate usage across every tool-calling round (each round is a separate
# provider request). Committing per round avoids the previous bug where a later
# round's total overwrote earlier rounds.
agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
hist = deepcopy(history)
def _commit_round(round_usage, round_estimate):
nonlocal total_tokens
if round_usage and round_usage["total_tokens"]:
agg_usage["prompt_tokens"] += round_usage["prompt_tokens"]
agg_usage["completion_tokens"] += round_usage["completion_tokens"]
agg_usage["total_tokens"] += round_usage["total_tokens"]
else:
agg_usage["total_tokens"] += round_estimate
total_tokens = agg_usage["total_tokens"]
self.last_usage = dict(agg_usage)
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
try:
@@ -1935,6 +2034,8 @@ class LiteLLMBase(ABC):
logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in tools]}")
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
# Request authoritative usage on the final streaming chunk.
completion_args.setdefault("stream_options", {})["include_usage"] = True
response = await litellm.acompletion(
**completion_args,
drop_params=True,
@@ -1943,8 +2044,15 @@ class LiteLLMBase(ABC):
final_tool_calls = {}
answer = ""
round_usage = None
round_estimate = 0
async for resp in response:
# Usage-only final chunk may carry no choices — read it first.
_u = usage_from_response(resp)
if _u["total_tokens"]:
round_usage = _u
if not hasattr(resp, "choices") or not resp.choices:
continue
@@ -1979,16 +2087,16 @@ class LiteLLMBase(ABC):
answer += delta.content
yield delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
if not _u["total_tokens"]:
round_estimate += num_tokens_from_string(delta.content)
finish_reason = getattr(resp.choices[0], "finish_reason", "")
if finish_reason == "length":
yield self._length_stop("")
# Commit this round's tokens to the running aggregate.
_commit_round(round_usage, round_estimate)
if answer and not final_tool_calls:
logging.info(f"[ToolLoop] round={_round} completed with text response, exiting")
yield total_tokens
@@ -2030,25 +2138,29 @@ class LiteLLMBase(ABC):
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
completion_args.setdefault("stream_options", {})["include_usage"] = True
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
fb_usage = None
fb_estimate = 0
async for resp in response:
_u = usage_from_response(resp)
if _u["total_tokens"]:
fb_usage = _u
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
if not _u["total_tokens"]:
fb_estimate += num_tokens_from_string(delta.content)
yield delta.content
_commit_round(fb_usage, fb_estimate)
yield total_tokens
return
@@ -2157,7 +2269,7 @@ class LiteLLMBase(ABC):
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
extra_headers["Authorization"] = f"Bearer {self.api_key}"
# MiniMax requires GroupId as a query parameter for API authentication
if self.provider == SupportedLiteLLMProvider.MiniMax and hasattr(self, 'group_id') and self.group_id:
if self.provider == SupportedLiteLLMProvider.MiniMax and hasattr(self, "group_id") and self.group_id:
api_base = completion_args.get("api_base", self.base_url)
separator = "&" if "?" in api_base else "?"
completion_args["api_base"] = f"{api_base}{separator}GroupId={self.group_id}"