From 742188c3bb37af091c21d4d2a0b6bacecf66bf61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96ndery?= Date: Thu, 2 Jul 2026 04:35:28 +0300 Subject: [PATCH] feat(agent): report accurate aggregated token usage and propagate session/user + input/output to Langfuse for agent runs (#16420) ### What problem does this PR solve? _Briefly describe what this PR aims to solve. Include background context that will help reviewers understand the purpose of the PR._ ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) - [x] Other (please describe): ## Summary Agent (Canvas) runs previously did not surface token usage in the SSE stream, and RAGFlow's own Langfuse generations for agent runs were missing the prompt/completion split and the session/user correlation. This made it impossible for an external caller (or Langfuse) to reconcile an agent turn's cost with the upstream provider (e.g. OpenRouter), because a single turn can issue several distinct LLM calls (query rewriting / cross-language translation, multi-round tool reasoning, nested sub-agents, and the final answer). This PR introduces a per-run token usage sink so that **every** LLM call in a run is aggregated and reported once, and enriches Langfuse generations with the prompt/completion split plus session/user attributes. ## What changes ### 1. Per-run token usage sink (`common/token_utils.py`) - Adds two `contextvars`: `token_usage_sink` (a mutable per-run accumulator) and `langfuse_run_attrs` (session_id/user_id for the run). - Adds `record_run_token_usage(...)` (thread-safe via a lock, because `thread_pool_exec` copies the context into worker threads that share the sink dict) and `usage_from_response(...)` which extracts a `{prompt_tokens, completion_tokens, total_tokens}` split from OpenAI/OpenRouter-style responses. ### 2. 
Provider layer captures the prompt/completion split (`rag/llm/chat_model.py`) - `LiteLLMBase` and `Base` now store `self.last_usage` (prompt/completion/total) for the most recent chat call, in both the plain and tool-calling paths. - Streaming requests set `stream_options.include_usage = True` (LiteLLM path) so the authoritative usage arrives on the final chunk; this is read even on the usage-only chunk that carries no `choices`. - Fixes a multi-round accounting bug in `*_with_tools`: token totals were **overwritten** by each round (`total_tokens = tol`) instead of accumulated, undercounting multi-round tool conversations. Each round is now committed to a running aggregate. ### 3. LLMBundle reports usage once, per call (`api/db/services/llm_service.py`) - New `_report_usage(total_tokens)` records the call's usage into the active run sink and returns the prompt/completion/total split for Langfuse. The split is only used when it is consistent with the authoritative total; otherwise only the total is reported. - All three chat entry points (`async_chat`, `async_chat_streamly`, `async_chat_streamly_delta`) now emit `usage_details` with `input`/`output`/`total` instead of total-only. - `_start_langfuse_observation` now applies `session_id`/`user_id` from the per-run context (`langfuse_run_attrs`) so agent-run generations are correctly grouped, even though agent LLMBundles are constructed without those attributes. ### 4. Canvas installs the sink and emits the aggregate (`agent/canvas.py`) - `Canvas.run()` installs a fresh `token_usage_sink` and `langfuse_run_attrs` (from `user_id`/`session_id`) at the start of every turn. - The terminal `workflow_finished` event now includes an aggregated `usage` object: `{prompt_tokens, completion_tokens, total_tokens, calls}` covering all LLM calls in the run. It is intentionally not attached to `message_end`, which is built once per Message component and would repeat the cumulative total in multi-Message graphs. ### 5. Pass session id into the run (`api/db/services/canvas_service.py`) - `completion()` forwards `session_id` to `Canvas.run()` for Langfuse session correlation. 
## Why a context variable LLM calls in an agent run originate from many places that each build their own `LLMBundle` (e.g. `cross_languages`/`keyword_extraction` helpers, the Agent component, and nested sub-agents invoked as tools). A run-scoped context variable is the only non-invasive chokepoint that captures all of them exactly once, including nested agents (which run in the same async context) and thread-pool tools (the current context is copied explicitly into the executor call, because `run_in_executor` does not propagate context variables on its own). ## Behavior / compatibility - No public API or wire-format removal: `workflow_finished` gains an additional optional `usage` field; existing consumers are unaffected. - When a provider does not return authoritative usage, behavior falls back to the previous token estimate (total only, no split). - Non-agent flows (Dataflow `Pipeline`, sync `Graph.run`) are untouched. ## Testing - [x] Simple agent answer: `workflow_finished.usage.total_tokens` matches provider usage. - [x] Agent with cross-language retrieval: aggregate equals the sum of both provider calls. - [x] Tool-calling agent (multi-round): total accumulates across rounds. - [x] Nested agent (agent-as-tool): sub-agent tokens included in the parent run total. - [x] Langfuse: agent generations show input/output split and are grouped by session/user. 
--------- Co-authored-by: yzc Co-authored-by: Cursor --- agent/canvas.py | 315 +++++++++++++++++------------ api/apps/restful_apis/agent_api.py | 244 ++++++++++------------ api/db/services/canvas_service.py | 137 +++++-------- api/db/services/llm_service.py | 81 ++++++-- common/token_utils.py | 88 +++++++- rag/llm/chat_model.py | 212 ++++++++++++++----- 6 files changed, 655 insertions(+), 422 deletions(-) diff --git a/agent/canvas.py b/agent/canvas.py index 15f5cf0449..0a9d8cd486 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -15,6 +15,7 @@ # import asyncio import base64 +import contextvars import datetime import inspect import json @@ -36,52 +37,54 @@ from api.db.joint_services.tenant_model_service import get_tenant_default_model_ from common.constants import LLMType from common.misc_utils import get_uuid, hash_str2int from common.exceptions import TaskCanceledException +from common.token_utils import token_usage_sink, langfuse_run_attrs from rag.prompts.generator import chunks_format from rag.utils.redis_conn import REDIS_CONN from rag.utils.tts_cache import synthesize_with_cache _logger = logging.getLogger(__name__) + class Graph: """ - dsl = { - "components": { - "begin": { - "obj":{ - "component_name": "Begin", - "params": {}, - }, - "downstream": ["answer_0"], - "upstream": [], + dsl = { + "components": { + "begin": { + "obj":{ + "component_name": "Begin", + "params": {}, }, - "retrieval_0": { - "obj": { - "component_name": "Retrieval", - "params": {} - }, - "downstream": ["generate_0"], - "upstream": ["answer_0"], - }, - "generate_0": { - "obj": { - "component_name": "Generate", - "params": {} - }, - "downstream": ["answer_0"], - "upstream": ["retrieval_0"], - } + "downstream": ["answer_0"], + "upstream": [], }, - "history": [], - "path": ["begin"], - "retrieval": {"chunks": [], "doc_aggs": []}, - "globals": { - "sys.query": "", - "sys.user_id": tenant_id, - "sys.conversation_turns": 0, - "sys.files": [] + "retrieval_0": { + "obj": { + 
"component_name": "Retrieval", + "params": {} + }, + "downstream": ["generate_0"], + "upstream": ["answer_0"], + }, + "generate_0": { + "obj": { + "component_name": "Generate", + "params": {} + }, + "downstream": ["answer_0"], + "upstream": ["retrieval_0"], } + }, + "history": [], + "path": ["begin"], + "retrieval": {"chunks": [], "doc_aggs": []}, + "globals": { + "sys.query": "", + "sys.user_id": tenant_id, + "sys.conversation_turns": 0, + "sys.files": [] } - """ + } + """ def __init__(self, dsl: str, tenant_id=None, task_id=None, custom_header=None): self.path = [] @@ -115,9 +118,7 @@ class Graph: def __str__(self): self.dsl["path"] = self.path self.dsl["task_id"] = self.task_id - dsl = { - "components": {} - } + dsl = {"components": {}} for k in self.dsl.keys(): if k in ["components"]: continue @@ -139,11 +140,13 @@ class Graph: except Exception as e: logging.warning("Graph.__str__: deepcopy failed for component '%s' key '%s' (type=%s): %s. Using shallow reference.", k, c, type(cpn[c]).__name__, e) dsl["components"][k][c] = cpn[c] + def _serialize_default(obj): if callable(obj): return None logging.warning("Graph.__str__: JSON fallback via str() for type=%s", type(obj).__name__) return str(obj) + return json.dumps(dsl, ensure_ascii=False, default=_serialize_default) def reset(self): @@ -180,13 +183,13 @@ class Graph: def get_tenant_id(self): return self._tenant_id - def get_value_with_variable(self,value: str) -> Any: + def get_value_with_variable(self, value: str) -> Any: pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*") out_parts = [] last = 0 for m in pat.finditer(value): - out_parts.append(value[last:m.start()]) + out_parts.append(value[last : m.start()]) key = m.group(1) v = self.get_variable_value(key) if v is None: @@ -205,7 +208,7 @@ class Graph: last = m.end() out_parts.append(value[last:]) - return("".join(out_parts)) + return "".join(out_parts) def get_variable_value(self, exp: str) -> Any: exp 
= exp.strip("{").strip("}").strip(" ").strip("{").strip("}") @@ -222,13 +225,13 @@ class Graph: if not rest: return root_val - return self.get_variable_param_value(root_val,rest) + return self.get_variable_param_value(root_val, rest) def get_variable_param_value(self, obj: Any, path: str) -> Any: cur = obj if not path: return cur - for key in path.split('.'): + for key in path.split("."): if cur is None: return None @@ -253,7 +256,7 @@ class Graph: cur = getattr(cur, key, None) return cur - def set_variable_value(self, exp: str,value): + def set_variable_value(self, exp: str, value): exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}") if exp.find("@") < 0: self.globals[exp] = value @@ -271,11 +274,11 @@ class Graph: root_val = cpn["obj"].output(root_key) if not root_val: root_val = {} - cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val,rest,value)) + cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val, rest, value)) def set_variable_param_value(self, obj: Any, path: str, value) -> Any: cur = obj - keys = path.split('.') + keys = path.split(".") if not path: return value for key in keys[:-1]: @@ -298,7 +301,6 @@ class Graph: class Canvas(Graph): - def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None, custom_header=None): self.globals = { "sys.query": "", @@ -306,9 +308,14 @@ class Canvas(Graph): "sys.conversation_turns": 0, "sys.files": [], "sys.history": [], - "sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + "sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), } self.variables = {} + # Aggregated provider token usage (prompt/completion/total) across every LLM + # call in a single run — query rewriting, cross-language translation, tool + # reasoning and the final answer. Populated via the token_usage_sink context + # variable that each LLMBundle chat call writes to. Reset at run() start. 
+ self._run_token_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "calls": 0} super().__init__(dsl, tenant_id, task_id, custom_header=custom_header) self._id = canvas_id @@ -323,13 +330,13 @@ class Canvas(Graph): self.globals["sys.date"] = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S") else: self.globals = { - "sys.query": "", - "sys.user_id": "", - "sys.conversation_turns": 0, - "sys.files": [], - "sys.history": [], - "sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S") - } + "sys.query": "", + "sys.user_id": "", + "sys.conversation_turns": 0, + "sys.files": [], + "sys.history": [], + "sys.date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + } if "variables" in self.dsl: self.variables = self.dsl["variables"] else: @@ -393,6 +400,38 @@ class Canvas(Graph): self.globals[k] = "" async def run(self, **kwargs): + # Install a fresh per-run token usage sink and Langfuse correlation context, + # and guarantee both are torn down when the run ends (even on early return or + # exception) so later LLM calls in the same task never inherit a previous + # run's sink or session/user attributes. + self._run_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "calls": 0} + _lf_attrs = {} + _user_id = kwargs.get("user_id") + if _user_id: + _lf_attrs["user_id"] = str(_user_id)[:200] + _session_id = kwargs.get("session_id") or self._id + if _session_id: + _lf_attrs["session_id"] = str(_session_id)[:200] + sink_token = token_usage_sink.set(self._run_token_usage) + attrs_token = langfuse_run_attrs.set(_lf_attrs) + try: + async for ev in self._run_impl(**kwargs): + yield ev + finally: + # reset() can raise if the generator is closed from a different context + # (e.g. client disconnect); fall back to clearing the values in that case. 
+ try: + token_usage_sink.reset(sink_token) + except ValueError: + logging.debug("Failed to reset token usage ContextVar", exc_info=True) + token_usage_sink.set(None) + try: + langfuse_run_attrs.reset(attrs_token) + except ValueError: + logging.debug("Failed to reset Langfuse run attributes ContextVar", exc_info=True) + langfuse_run_attrs.set(None) + + async def _run_impl(self, **kwargs): self.globals["sys.date"] = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S") st = time.perf_counter() self._loop = asyncio.get_running_loop() @@ -406,7 +445,7 @@ class Canvas(Graph): if kwargs.get("webhook_payload"): for k, cpn in self.components.items(): - if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook": + if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook": payload = kwargs.get("webhook_payload", {}) if "input" in payload: self.components[k]["obj"].set_input_value("request", payload["input"]) @@ -427,7 +466,7 @@ class Canvas(Graph): self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k], layout_recognize) else: self.globals[f"sys.{k}"] = kwargs[k] - if not self.globals["sys.conversation_turns"] : + if not self.globals["sys.conversation_turns"]: self.globals["sys.conversation_turns"] = 0 self.globals["sys.conversation_turns"] += 1 is_resume = bool(self.path) and self.path[0].lower().find("userfillup") >= 0 @@ -436,11 +475,11 @@ class Canvas(Graph): nonlocal created_at return { "event": event, - #"conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115", + # "conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115", "message_id": self.message_id, "created_at": created_at, "task_id": self.task_id, - "data": dt + "data": dt, } if not is_resume: @@ -476,7 +515,11 @@ class Canvas(Graph): if use_async: await cpn_obj.invoke_async(**(call_kwargs or {})) return - await loop.run_in_executor(self._thread_pool, 
partial(sync_fn, **(call_kwargs or {}))) + # run_in_executor does not propagate context variables; copy the + # current context so the token usage sink / Langfuse attributes set + # by run() remain visible to LLMBundle calls inside sync components. + ctx = contextvars.copy_context() + await loop.run_in_executor(self._thread_pool, lambda: ctx.run(partial(sync_fn, **(call_kwargs or {})))) i = f while i < t: @@ -525,16 +568,19 @@ class Canvas(Graph): json.dumps(outputs, ensure_ascii=False, default=str)[:500], cpn_obj.error(), ) - return decorate("node_finished",{ - "inputs": cpn_obj.get_input_values(), - "outputs": outputs, - "component_id": cpn_obj._id, - "component_name": self.get_component_name(cpn_obj._id), - "component_type": self.get_component_type(cpn_obj._id), - "error": cpn_obj.error(), - "elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"), - "created_at": cpn_obj.output("_created_time"), - }) + return decorate( + "node_finished", + { + "inputs": cpn_obj.get_input_values(), + "outputs": outputs, + "component_id": cpn_obj._id, + "component_name": self.get_component_name(cpn_obj._id), + "component_type": self.get_component_type(cpn_obj._id), + "error": cpn_obj.error(), + "elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"), + "created_at": cpn_obj.output("_created_time"), + }, + ) self.error = "" idx = 0 if is_resume else len(self.path) - 1 @@ -543,13 +589,17 @@ class Canvas(Graph): while idx < len(self.path): to = len(self.path) for i in range(idx, to): - yield decorate("node_started", { - "inputs": None, "created_at": int(time.time()), - "component_id": self.path[i], - "component_name": self.get_component_name(self.path[i]), - "component_type": self.get_component_type(self.path[i]), - "thoughts": self.get_component_thoughts(self.path[i]) - }) + yield decorate( + "node_started", + { + "inputs": None, + "created_at": int(time.time()), + "component_id": self.path[i], + "component_name": self.get_component_name(self.path[i]), + 
"component_type": self.get_component_type(self.path[i]), + "thoughts": self.get_component_thoughts(self.path[i]), + }, + ) await _run_batch(idx, to) to = len(self.path) # post-processing of components invocation @@ -564,6 +614,7 @@ class Canvas(Graph): _m = "" buff_m = "" stream = cpn_obj.output("content")() + async def _process_stream(m): nonlocal buff_m, _m, tts_mdl if not m: @@ -578,13 +629,7 @@ class Canvas(Graph): _m += m if len(buff_m) > 16: - ev = decorate( - "message", - { - "content": m, - "audio_binary": self.tts(tts_mdl, buff_m) - } - ) + ev = decorate("message", {"content": m, "audio_binary": self.tts(tts_mdl, buff_m)}) buff_m = "" return ev @@ -592,12 +637,12 @@ class Canvas(Graph): if inspect.isasyncgen(stream): async for m in stream: - ev= await _process_stream(m) + ev = await _process_stream(m) if ev: yield ev else: for m in stream: - ev= await _process_stream(m) + ev = await _process_stream(m) if ev: yield ev if buff_m: @@ -629,7 +674,7 @@ class Canvas(Graph): else: self.error = cpn_obj.error() - if cpn_obj.component_name.lower() not in ("iteration","loop"): + if cpn_obj.component_name.lower() not in ("iteration", "loop"): if isinstance(cpn_obj.output("content"), partial): if self.error: cpn_obj.set_output("content", None) @@ -654,7 +699,7 @@ class Canvas(Graph): for cpn_id in cpn_ids: _append_path(cpn_id) - if cpn_obj.component_name.lower() in ("iterationitem","loopitem") and cpn_obj.end(): + if cpn_obj.component_name.lower() in ("iterationitem", "loopitem") and cpn_obj.end(): iter = cpn_obj.get_parent() yield _node_finished(iter) _extend_path(self.get_component(cpn["parent_id"])["downstream"]) @@ -683,10 +728,7 @@ class Canvas(Graph): o = self.get_component_obj(c) if o.component_name.lower() == "userfillup": o.invoke() - another_inputs.update({ - k: v for k, v in o.get_input_elements().items() - if not self._is_input_field_satisfied(v) - }) + another_inputs.update({k: v for k, v in o.get_input_elements().items() if not 
self._is_input_field_satisfied(v)}) if o.get_param("enable_tips"): tips = o.output("tips") if not another_inputs: @@ -696,23 +738,30 @@ class Canvas(Graph): return self.path = self.path[:idx] if not self.error: - yield decorate("workflow_finished", - { - "inputs": kwargs.get("inputs"), - "outputs": self.get_component_obj(self.path[-1]).output(), - "elapsed_time": time.perf_counter() - st, - "created_at": st, - }) + yield decorate( + "workflow_finished", + { + "inputs": kwargs.get("inputs"), + "outputs": self.get_component_obj(self.path[-1]).output(), + "elapsed_time": time.perf_counter() - st, + "created_at": st, + # Run-level total of all LLM calls — emitted once here. + "usage": self._run_usage_payload(), + }, + ) self.history.append(("assistant", self.get_component_obj(self.path[-1]).output())) self.globals["sys.history"].append(f"{self.history[-1][0]}: {self.history[-1][1]}") elif "Task has been canceled" in self.error: - yield decorate("workflow_finished", - { - "inputs": kwargs.get("inputs"), - "outputs": "Task has been canceled", - "elapsed_time": time.perf_counter() - st, - "created_at": st, - }) + yield decorate( + "workflow_finished", + { + "inputs": kwargs.get("inputs"), + "outputs": "Task has been canceled", + "elapsed_time": time.perf_counter() - st, + "created_at": st, + "usage": self._run_usage_payload(), + }, + ) def is_reff(self, exp: str) -> bool: exp = exp.strip("{").strip("}") @@ -725,8 +774,7 @@ class Canvas(Graph): return False return True - - def tts(self,tts_mdl, text): + def tts(self, tts_mdl, text): def clean_tts_text(text: str) -> str: if not text: return "" @@ -736,15 +784,8 @@ class Canvas(Graph): text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text) emoji_pattern = re.compile( - "[\U0001F600-\U0001F64F" - "\U0001F300-\U0001F5FF" - "\U0001F680-\U0001F6FF" - "\U0001F1E0-\U0001F1FF" - "\U00002700-\U000027BF" - "\U0001F900-\U0001F9FF" - "\U0001FA70-\U0001FAFF" - "\U0001FAD0-\U0001FAFF]+", - flags=re.UNICODE + 
"[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff\U0001f1e0-\U0001f1ff\U00002700-\U000027bf\U0001f900-\U0001f9ff\U0001fa70-\U0001faff\U0001fad0-\U0001faff]+", + flags=re.UNICODE, ) text = emoji_pattern.sub("", text) @@ -755,6 +796,7 @@ class Canvas(Graph): text = text[:MAX_LEN] return text + if not tts_mdl or not text: return None text = clean_tts_text(text) @@ -766,7 +808,7 @@ class Canvas(Graph): convs = [] if window_size <= 0: return convs - for role, obj in self.history[window_size * -2:]: + for role, obj in self.history[window_size * -2 :]: if isinstance(obj, dict): convs.append({"role": role, "content": obj.get("content", "")}) else: @@ -815,17 +857,19 @@ class Canvas(Graph): async def get_files_async(self, files: Union[None, list[dict]], layout_recognize: str = None) -> list[str]: if not files: - return [] + return [] + def image_to_base64(file): - return "data:{};base64,{}".format(file["mime_type"], - base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) + return "data:{};base64,{}".format(file["mime_type"], base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) + def parse_file(file): blob = FileService.get_blob(file["created_by"], file["id"]) return FileService.parse(file["name"], blob, True, file["created_by"], layout_recognize) + loop = asyncio.get_running_loop() tasks = [] for file in files: - if file["mime_type"].find("image") >=0: + if file["mime_type"].find("image") >= 0: tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file)) continue tasks.append(loop.run_in_executor(self._thread_pool, parse_file, file)) @@ -844,7 +888,7 @@ class Canvas(Graph): def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): agent_ids = agent_id.split("-->") agent_name = self.get_component_name(agent_ids[0]) - path = agent_name if len(agent_ids) < 2 else agent_name+"-->"+"-->".join(agent_ids[1:]) + path = agent_name if 
len(agent_ids) < 2 else agent_name + "-->" + "-->".join(agent_ids[1:]) try: bin = REDIS_CONN.get(f"{self.task_id}-{self.message_id}-logs") if bin: @@ -852,16 +896,10 @@ class Canvas(Graph): if obj[-1]["component_id"] == agent_ids[0]: obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}) else: - obj.append({ - "component_id": agent_ids[0], - "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}] - }) + obj.append({"component_id": agent_ids[0], "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]}) else: - obj = [{ - "component_id": agent_ids[0], - "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}] - }] - REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60*10) + obj = [{"component_id": agent_ids[0], "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]}] + REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60 * 10) except Exception as e: logging.exception(e) @@ -899,9 +937,22 @@ class Canvas(Graph): message_end["attachment"] = cpn_obj.output("attachment") if self._has_reference(): message_end["reference"] = self.get_reference() + # NOTE: aggregated run token usage is intentionally NOT attached here. + # _build_message_end runs once per Message component, so a multi-Message graph + # would emit cumulative usage repeatedly and double count. The run total is + # emitted exactly once on the terminal workflow_finished event instead. 
return message_end - def add_memory(self, user:str, assist:str, summ: str): + def _run_usage_payload(self) -> dict: + usage = getattr(self, "_run_token_usage", None) or {} + return { + "prompt_tokens": usage.get("prompt_tokens", 0), + "completion_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + "calls": usage.get("calls", 0), + } + + def add_memory(self, user: str, assist: str, summ: str): self.memory.append((user, assist, summ)) def get_memory(self) -> list[Tuple]: diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py index 8ae6723709..ae28cd0414 100644 --- a/api/apps/restful_apis/agent_api.py +++ b/api/apps/restful_apis/agent_api.py @@ -88,29 +88,32 @@ def _canvas_json_default(obj): def _require_canvas_access_sync(func): @wraps(func) def wrapper(*args, **kwargs): - if not UserCanvasService.accessible(kwargs.get('agent_id'), kwargs.get('tenant_id')): + if not UserCanvasService.accessible(kwargs.get("agent_id"), kwargs.get("tenant_id")): return get_json_result(data=False, message="Make sure you have permission to access the agent.", code=RetCode.OPERATING_ERROR) return func(*args, **kwargs) + return wrapper def _require_canvas_access_async(func): @wraps(func) async def wrapper(*args, **kwargs): - agent_id = kwargs.get('agent_id') - tenant_id = kwargs.get('tenant_id') + agent_id = kwargs.get("agent_id") + tenant_id = kwargs.get("tenant_id") if not await thread_pool_exec(UserCanvasService.accessible, agent_id, tenant_id): return get_json_result(data=False, message="Make sure you have permission to access the agent.", code=RetCode.OPERATING_ERROR) return await func(*args, **kwargs) + return wrapper def _require_canvas_owner_sync(func): @wraps(func) def wrapper(*args, **kwargs): - if not UserCanvasService.query(user_id=kwargs.get('tenant_id'), id=kwargs.get('agent_id')): + if not UserCanvasService.query(user_id=kwargs.get("tenant_id"), id=kwargs.get("agent_id")): return 
get_json_result(data=False, message="Only the owner of the agent is authorized for this operation.", code=RetCode.OPERATING_ERROR) return func(*args, **kwargs) + return wrapper @@ -261,9 +264,7 @@ async def _run_workflow_session( if "chunks" in workflow_conv["reference"]: workflow_conv["reference"] = [workflow_conv["reference"]] else: - workflow_conv["reference"] = [ - value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0])) - ] + workflow_conv["reference"] = [value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))] elif not isinstance(workflow_conv.get("reference"), list): workflow_conv["reference"] = [] workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]] @@ -344,22 +345,16 @@ async def _run_workflow_session( # bare [DONE] (fixes #15169). logging.info( "empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)", - agent_id, session_id, True, - ) - yield ( - "data:" - + json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False) - + "\n\n" + agent_id, + session_id, + True, ) + yield ("data:" + json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False) + "\n\n") await persist_workflow_session() except Exception as exc: logging.exception(exc) canvas.cancel_task() - yield ( - "data:" - + json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False) - + "\n\n" - ) + yield ("data:" + json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False) + "\n\n") finally: if not done_sent: done_sent = True @@ -400,7 +395,9 @@ async def _run_workflow_session( # (fixes #15169). 
logging.info( "empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)", - agent_id, session_id, False, + agent_id, + session_id, + False, ) await commit_runtime_replica() return get_result(data={"session_id": session_id}) @@ -559,16 +556,13 @@ async def delete_agent_session(tenant_id, agent_id): if errors: if success_count > 0: - return get_result(data={"success_count": success_count, "errors": errors}, - message=f"Partially deleted {success_count} sessions with {len(errors)} errors") + return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors") else: return get_error_data_result(message="; ".join(errors)) if duplicate_messages: if success_count > 0: - return get_result( - message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", - data={"success_count": success_count, "errors": duplicate_messages}) + return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages}) else: return get_error_data_result(message=";".join(duplicate_messages)) @@ -611,8 +605,8 @@ async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace yield ans continue - if event in ["message", "message_end", "user_inputs", "workflow_finished"]: - if event in ["user_inputs", "workflow_finished"]: + if event in ["message", "message_end", "user_inputs"]: + if event == "user_inputs": logging.debug( "Forwarding session completion event: tenant_id=%s agent_id=%s event=%s", tenant_id, @@ -620,6 +614,22 @@ async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace event, ) yield ans + continue + + if event == "workflow_finished": + # Forward only the run-level aggregated token usage, not the whole terminal + # payload (inputs/outputs), so the session completion stream surface stays + # limited to 
what the usage contract needs. + logging.debug( + "Forwarding session completion event: tenant_id=%s agent_id=%s event=%s", + tenant_id, + agent_id, + event, + ) + usage = ans.get("data", {}).get("usage") + if usage is not None: + yield {**ans, "data": {"usage": usage}} + continue @manager.route("/agents/templates", methods=["GET"]) # noqa: F821 @@ -760,7 +770,7 @@ async def update_agent_tags(tenant_id, canvas_id): @add_tenant_id_to_kwargs async def create_agent(tenant_id): req = {k: v for k, v in (await get_request_json()).items() if v is not None} - req["canvas_type"] = req.get("canvas_type","") + req["canvas_type"] = req.get("canvas_type", "") req["user_id"] = tenant_id req["canvas_category"] = req.get("canvas_category") or CanvasCategory.Agent req["release"] = bool(req.get("release", "")) @@ -837,13 +847,9 @@ async def upload_agent_file(agent_id, tenant_id): ) try: if len(file_objs) == 1: - uploaded = await thread_pool_exec( - FileService.upload_info, tenant_id, file_objs[0], request.args.get("url") - ) + uploaded = await thread_pool_exec(FileService.upload_info, tenant_id, file_objs[0], request.args.get("url")) return get_json_result(data=uploaded) - results = await asyncio.gather( - *(thread_pool_exec(FileService.upload_info, tenant_id, file_obj) for file_obj in file_objs) - ) + results = await asyncio.gather(*(thread_pool_exec(FileService.upload_info, tenant_id, file_obj) for file_obj in file_objs)) return get_json_result(data=results) except Exception as exc: logging.exception( @@ -1015,7 +1021,7 @@ def delete_agent(agent_id, tenant_id): @_require_canvas_access_async async def update_agent(agent_id, tenant_id): req = {k: v for k, v in (await get_request_json()).items() if v is not None} - req["canvas_type"] = req.get("canvas_type","") + req["canvas_type"] = req.get("canvas_type", "") req["release"] = bool(req.get("release", "")) if req.get("dsl") is not None: @@ -1038,10 +1044,7 @@ async def update_agent(agent_id, tenant_id): return 
get_data_error_result(message=f"{req['title']} already exists.") agent_title_for_version = req.get("title") or (current_agent.title if current_agent else "") - canvas_category = ( - req.get("canvas_category") - or (current_agent.canvas_category if current_agent else CanvasCategory.Agent) - ) + canvas_category = req.get("canvas_category") or (current_agent.canvas_category if current_agent else CanvasCategory.Agent) owner_nickname = _get_user_nickname(tenant_id) UserCanvasService.update_by_id(agent_id, req) @@ -1153,13 +1156,19 @@ async def test_db_connection(): except ValueError as exc: logging.warning( "Rejected test_db_connection: unsafe host %r (db_type=%s, user=%s): %s", - req.get("host"), req.get("db_type"), current_user.id, exc, + req.get("host"), + req.get("db_type"), + current_user.id, + exc, ) return get_data_error_result(message=str(exc)) except OSError as exc: logging.warning( "Rejected test_db_connection: cannot resolve host %r (db_type=%s, user=%s): %s", - req.get("host"), req.get("db_type"), current_user.id, exc, + req.get("host"), + req.get("db_type"), + current_user.id, + exc, ) logging.debug("Full resolver exception for host %r", req.get("host"), exc_info=True) return get_data_error_result(message=f"Could not resolve host {req.get('host')!r}.") @@ -1198,13 +1207,7 @@ async def test_db_connection(): elif req["db_type"] == "mssql": import pyodbc - connection_string = ( - f"DRIVER={{ODBC Driver 17 for SQL Server}};" - f"SERVER={safe_host},{req['port']};" - f"DATABASE={req['database']};" - f"UID={req['username']};" - f"PWD={req['password']};" - ) + connection_string = f"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={safe_host},{req['port']};DATABASE={req['database']};UID={req['username']};PWD={req['password']};" db = pyodbc.connect(connection_string) try: cursor = db.cursor() @@ -1217,14 +1220,7 @@ async def test_db_connection(): elif req["db_type"] == "IBM DB2": import ibm_db - conn_str = ( - f"DATABASE={req['database']};" - f"HOSTNAME={safe_host};" 
- f"PORT={req['port']};" - f"PROTOCOL=TCPIP;" - f"UID={req['username']};" - f"PWD={req['password']};" - ) + conn_str = f"DATABASE={req['database']};HOSTNAME={safe_host};PORT={req['port']};PROTOCOL=TCPIP;UID={req['username']};PWD={req['password']};" logging.info( "DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=****;", req["database"], @@ -1387,9 +1383,7 @@ async def agent_chat_completion(tenant_id, agent_id=None): if "chunks" in workflow_conv["reference"]: workflow_conv["reference"] = [workflow_conv["reference"]] else: - workflow_conv["reference"] = [ - value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0])) - ] + workflow_conv["reference"] = [value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0]))] elif not isinstance(workflow_conv.get("reference"), list): workflow_conv["reference"] = [] workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]] @@ -1598,13 +1592,11 @@ async def agent_chat_completion(tenant_id, agent_id=None): # seeing only a bare [DONE] (fixes #15169). 
logging.info( "empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)", - agent_id, session_id, True, - ) - yield ( - "data:" - + json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False) - + "\n\n" + agent_id, + session_id, + True, ) + yield ("data:" + json.dumps({"session_id": session_id, "data": {}}, ensure_ascii=False) + "\n\n") yield "data:[DONE]\n\n" return _build_sse_response(generate()) @@ -1614,6 +1606,7 @@ async def agent_chat_completion(tenant_id, agent_id=None): final_ans = {} trace_items = [] structured_output = {} + run_usage = None async for ans in _iter_session_completion_events(tenant_id, agent_id, req, return_trace): try: if ans["event"] == "message": @@ -1633,6 +1626,11 @@ async def agent_chat_completion(tenant_id, agent_id=None): "trace": [copy.deepcopy(data)], } ) + if ans.get("event") == "workflow_finished": + # Capture the run-level usage but keep message_end/user_inputs as + # final_ans so the non-stream response shape stays unchanged. + run_usage = ans.get("data", {}).get("usage") + continue if ans.get("event") == "message_end": final_ans = ans elif ans.get("event") == "user_inputs" and not final_ans: @@ -1647,7 +1645,9 @@ async def agent_chat_completion(tenant_id, agent_id=None): # (fixes #15169). logging.info( "empty agent output - returning session_id (agent_id=%s session_id=%s stream=%s)", - agent_id, session_id, False, + agent_id, + session_id, + False, ) return get_result(data={"session_id": session_id}) @@ -1655,6 +1655,8 @@ async def agent_chat_completion(tenant_id, agent_id=None): final_ans["data"] = {} final_ans["data"]["content"] = full_content final_ans["data"]["reference"] = reference + if run_usage: + final_ans["data"]["usage"] = run_usage if structured_output: final_ans["data"]["structured"] = structured_output if return_trace and final_ans: @@ -1688,16 +1690,16 @@ async def _webhook_impl(agent_id: str, is_test: bool): # 1. 
Fetch canvas by agent_id exists, cvs = UserCanvasService.get_by_id(agent_id) if not exists: - return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST + return get_data_error_result(code=RetCode.BAD_REQUEST, message="Canvas not found."), RetCode.BAD_REQUEST # 2. Check canvas category if cvs.canvas_category == CanvasCategory.DataFlow: - return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST + return get_data_error_result(code=RetCode.BAD_REQUEST, message="Dataflow can not be triggered by webhook."), RetCode.BAD_REQUEST # 3. Load DSL from canvas dsl = getattr(cvs, "dsl", None) if not isinstance(dsl, dict): - return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST + return get_data_error_result(code=RetCode.BAD_REQUEST, message="Invalid DSL format."), RetCode.BAD_REQUEST # 4. Check webhook configuration in DSL webhook_cfg = {} @@ -1708,15 +1710,13 @@ async def _webhook_impl(agent_id: str, is_test: bool): webhook_cfg = cpn_obj["params"] if not webhook_cfg: - return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST + return get_data_error_result(code=RetCode.BAD_REQUEST, message="Webhook not configured for this agent."), RetCode.BAD_REQUEST # 5. Validate request method against webhook_cfg.methods allowed_methods = webhook_cfg.get("methods", []) request_method = request.method.upper() if allowed_methods and request_method not in allowed_methods: - return get_data_error_result( - code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook." 
- ),RetCode.BAD_REQUEST + return get_data_error_result(code=RetCode.BAD_REQUEST, message=f"HTTP method '{request_method}' not allowed for this webhook."), RetCode.BAD_REQUEST async def validate_webhook_security(security_cfg: dict): """Validate webhook security rules based on security configuration.""" @@ -1795,7 +1795,6 @@ async def _webhook_impl(agent_id: str, is_test: bool): client_ip = request.remote_addr - for rule in whitelist: if "/" in rule: # CIDR notation @@ -1854,7 +1853,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): def _validate_token_auth(security_cfg): """Validate header-based token authentication.""" - token_cfg = security_cfg.get("token",{}) + token_cfg = security_cfg.get("token", {}) header = token_cfg.get("token_header") token_value = token_cfg.get("token_value") @@ -1883,7 +1882,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): if not auth_header.startswith("Bearer "): raise Exception("Missing Bearer token") - token = auth_header[len("Bearer "):].strip() + token = auth_header[len("Bearer ") :].strip() if not token: raise Exception("Empty Bearer token") @@ -1922,10 +1921,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): else: required_claims = [] - required_claims = [ - c for c in required_claims - if isinstance(c, str) and c.strip() - ] + required_claims = [c for c in required_claims if isinstance(c, str) and c.strip()] RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"} for claim in required_claims: @@ -1939,10 +1935,10 @@ async def _webhook_impl(agent_id: str, is_test: bool): return decoded try: - security_config=webhook_cfg.get("security", {}) + security_config = webhook_cfg.get("security", {}) await validate_webhook_security(security_config) except Exception as e: - return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST + return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST if not isinstance(cvs.dsl, str): dsl = 
json.dumps(cvs.dsl, ensure_ascii=False) try: @@ -1950,7 +1946,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id) except Exception as e: - resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)) + resp = get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)) resp.status_code = RetCode.BAD_REQUEST return resp @@ -1967,9 +1963,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): # 3. Body ctype = request.headers.get("Content-Type", "").split(";")[0].strip() if ctype and ctype != content_type: - raise ValueError( - f"Invalid Content-Type: expect '{content_type}', got '{ctype}'" - ) + raise ValueError(f"Invalid Content-Type: expect '{content_type}', got '{ctype}'") body_data: dict = {} @@ -1991,11 +1985,11 @@ async def _webhook_impl(agent_id: str, is_test: bool): raise Exception("Too many uploaded files") for key, file in files.items(): desc = FileService.upload_info( - cvs.user_id, # user - file, # FileStorage - None # url (None for webhook) + cvs.user_id, # user + file, # FileStorage + None, # url (None for webhook) ) - file_parsed= await canvas.get_files_async([desc]) + file_parsed = await canvas.get_files_async([desc]) body_data[key] = file_parsed elif ctype == "application/x-www-form-urlencoded": @@ -2057,15 +2051,12 @@ async def _webhook_impl(agent_id: str, is_test: bool): # 4. 
Type validation if not validate_type(value, field_type): - raise Exception( - f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}" - ) + raise Exception(f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}") extracted[field] = value return extracted - def default_for_type(t): """Return default value for the given schema type.""" if t == "file": @@ -2145,7 +2136,6 @@ async def _webhook_impl(agent_id: str, is_test: bool): # Default: do nothing return value - def validate_type(value, t): """Validate value type against schema type t.""" if t == "file": @@ -2179,28 +2169,24 @@ async def _webhook_impl(agent_id: str, is_test: bool): return True return True + parsed = await parse_webhook_request(webhook_cfg.get("content_types")) SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}}) # Extract strictly by schema try: - query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query") + query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query") header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers") - body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body") + body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body") except Exception as e: - return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST + return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST - clean_request = { - "query": query_clean, - "headers": header_clean, - "body": body_clean, - "input": parsed - } + clean_request = {"query": query_clean, "headers": header_clean, "body": body_clean, "input": parsed} execution_mode = webhook_cfg.get("execution_mode", "Immediately") response_cfg = webhook_cfg.get("response", {}) - def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600): + def 
append_webhook_trace(agent_id: str, start_ts: float, event: dict, ttl=600): from rag.utils.redis_conn import REDIS_CONN key = f"webhook-trace-{agent_id}-logs" @@ -2208,15 +2194,9 @@ async def _webhook_impl(agent_id: str, is_test: bool): raw = REDIS_CONN.get(key) obj = json.loads(raw) if raw else {"webhooks": {}} - ws = obj["webhooks"].setdefault( - str(start_ts), - {"start_ts": start_ts, "events": []} - ) + ws = obj["webhooks"].setdefault(str(start_ts), {"start_ts": start_ts, "events": []}) - ws["events"].append({ - "ts": time.time(), - **event - }) + ws["events"].append({"ts": time.time(), **event}) REDIS_CONN.set_obj(key, obj, ttl) @@ -2225,10 +2205,10 @@ async def _webhook_impl(agent_id: str, is_test: bool): try: status = int(status) except (TypeError, ValueError): - return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST + return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(f"Invalid response status code: {status}")), RetCode.BAD_REQUEST if not (200 <= status <= 399): - return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST + return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(f"Invalid response status code: {status}, must be between 200 and 399")), RetCode.BAD_REQUEST body_tpl = response_cfg.get("body_template", "") @@ -2242,7 +2222,6 @@ async def _webhook_impl(agent_id: str, is_test: bool): except (json.JSONDecodeError, TypeError): return body, "text/plain" - body, content_type = parse_body(body_tpl) resp = Response( json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body, @@ -2252,11 +2231,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): async def background_run(): try: - async for ans in canvas.run( - query="", - user_id=cvs.user_id, - webhook_payload=clean_request - ): + async for ans in canvas.run(query="", 
user_id=cvs.user_id, webhook_payload=clean_request): if is_test: append_webhook_trace(agent_id, start_ts, ans) @@ -2268,7 +2243,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): "event": "finished", "elapsed_time": time.time() - start_ts, "success": True, - } + }, ) cvs.dsl = json.loads(str(canvas)) @@ -2285,7 +2260,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): "event": "error", "message": str(e), "error_type": type(e).__name__, - } + }, ) append_webhook_trace( agent_id, @@ -2294,7 +2269,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): "event": "finished", "elapsed_time": time.time() - start_ts, "success": False, - } + }, ) except Exception: logging.exception("Failed to append webhook trace") @@ -2305,6 +2280,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): task.add_done_callback(_background_tasks.discard) return resp else: + async def sse(): nonlocal canvas contents: list[str] = [] @@ -2326,11 +2302,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): if ans["event"] == "message_end": status = int(ans["data"].get("status", status)) if is_test: - append_webhook_trace( - agent_id, - start_ts, - ans - ) + append_webhook_trace(agent_id, start_ts, ans) if is_test: append_webhook_trace( agent_id, @@ -2339,13 +2311,13 @@ async def _webhook_impl(agent_id: str, is_test: bool): "event": "finished", "elapsed_time": time.time() - start_ts, "success": True, - } + }, ) final_content = "".join(contents) return { "message": final_content, "success": True, - "code": status, + "code": status, } except Exception as e: @@ -2357,7 +2329,7 @@ async def _webhook_impl(agent_id: str, is_test: bool): "event": "error", "message": str(e), "error_type": type(e).__name__, - } + }, ) append_webhook_trace( agent_id, @@ -2366,9 +2338,9 @@ async def _webhook_impl(agent_id: str, is_test: bool): "event": "finished", "elapsed_time": time.time() - start_ts, "success": False, - } + }, ) - return {"code": 400, "message": str(e),"success":False} + 
return {"code": 400, "message": str(e), "success": False} result = await sse() return Response( @@ -2401,6 +2373,7 @@ async def webhook_trace(agent_id: str): if encode_webhook_id(ts) == enc_id: return ts return None + since_ts = request.args.get("since_ts", type=float) webhook_id = request.args.get("webhook_id") @@ -2434,9 +2407,7 @@ async def webhook_trace(agent_id: str): webhooks = obj.get("webhooks", {}) if webhook_id is None: - candidates = [ - float(k) for k in webhooks.keys() if float(k) > since_ts - ] + candidates = [float(k) for k in webhooks.keys() if float(k) > since_ts] if not candidates: return get_json_result( @@ -2492,6 +2463,7 @@ async def webhook_trace(agent_id: str): } ) + @manager.route("/agents/attachments//download", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 561535259c..5e3498be00 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -34,10 +34,12 @@ from peewee import fn class CanvasTemplateService(CommonService): model = CanvasTemplate + class DataFlowTemplateService(CommonService): """ Alias of CanvasTemplateService """ + model = CanvasTemplate @@ -46,8 +48,7 @@ class UserCanvasService(CommonService): @classmethod @DB.connection_context() - def get_list(cls, tenant_id, - page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent): + def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent): agents = cls.model.select() if id: agents = agents.where(cls.model.id == id) @@ -68,20 +69,9 @@ class UserCanvasService(CommonService): @DB.connection_context() def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id): # will get all permitted agents, be cautious - fields = [ - cls.model.id, - cls.model.avatar, - cls.model.title, - cls.model.permission, - cls.model.canvas_type, - cls.model.canvas_category - 
] + fields = [cls.model.id, cls.model.avatar, cls.model.title, cls.model.permission, cls.model.canvas_type, cls.model.canvas_category] # find team agents and owned agents - agents = cls.model.select(*fields).where( - (cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | ( - cls.model.user_id == user_id - ) - ) + agents = cls.model.select(*fields).where((cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)) # sort by create_time, asc agents.order_by(cls.model.create_time.asc()) # maybe cause slow query by deep paginate, optimize later @@ -100,7 +90,6 @@ class UserCanvasService(CommonService): @DB.connection_context() def get_by_canvas_id(cls, pid): try: - fields = [ cls.model.id, cls.model.avatar, @@ -115,11 +104,9 @@ class UserCanvasService(CommonService): cls.model.update_date, cls.model.canvas_category, User.nickname, - User.avatar.alias('tenant_avatar'), + User.avatar.alias("tenant_avatar"), ] - agents = cls.model.select(*fields) \ - .join(User, on=(cls.model.user_id == User.id)) \ - .where(cls.model.id == pid) + agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(cls.model.id == pid) # obj = cls.model.query(id=pid)[0] return True, agents.dicts()[0] except Exception as e: @@ -129,14 +116,7 @@ class UserCanvasService(CommonService): @classmethod @DB.connection_context() def get_basic_info_by_canvas_ids(cls, canvas_id): - fields = [ - cls.model.id, - cls.model.avatar, - cls.model.user_id, - cls.model.title, - cls.model.permission, - cls.model.canvas_category - ] + fields = [cls.model.id, cls.model.avatar, cls.model.user_id, cls.model.title, cls.model.permission, cls.model.canvas_category] return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts() @classmethod @@ -162,20 +142,26 @@ class UserCanvasService(CommonService): cls.model.permission, cls.model.user_id.alias("tenant_id"), User.nickname, - 
User.avatar.alias('tenant_avatar'), + User.avatar.alias("tenant_avatar"), cls.model.update_time, cls.model.canvas_type, cls.model.canvas_category, cls.model.tags, ] if keywords: - agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( - (((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)), - (fn.LOWER(cls.model.title).contains(keywords.lower())) + agents = ( + cls.model.select(*fields) + .join(User, on=(cls.model.user_id == User.id)) + .where( + (((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)), + (fn.LOWER(cls.model.title).contains(keywords.lower())), + ) ) else: - agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( - (((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)) + agents = ( + cls.model.select(*fields) + .join(User, on=(cls.model.user_id == User.id)) + .where((((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))) ) if canvas_category: agents = agents.where(cls.model.canvas_category == canvas_category) @@ -201,7 +187,7 @@ class UserCanvasService(CommonService): # Get latest release time for each canvas if agents_list: - canvas_ids = [a['id'] for a in agents_list] + canvas_ids = [a["id"] for a in agents_list] release_times = ( UserCanvasVersion.select(UserCanvasVersion.user_canvas_id, fn.MAX(UserCanvasVersion.create_time).alias("release_time")) .where((UserCanvasVersion.user_canvas_id.in_(canvas_ids)) & (UserCanvasVersion.release)) @@ -210,7 +196,7 @@ class UserCanvasService(CommonService): release_time_map = {r.user_canvas_id: r.release_time for r in release_times} for agent in agents_list: - agent['release_time'] = release_time_map.get(agent['id']) + 
agent["release_time"] = release_time_map.get(agent["id"]) return agents_list, count @@ -218,9 +204,7 @@ class UserCanvasService(CommonService): @DB.connection_context() def list_tags(cls, joined_tenant_ids, user_id, canvas_category=None): """Return {tag: agent_count} aggregated across agents visible to the user.""" - query = cls.model.select(cls.model.tags).where( - ((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id) - ) + query = cls.model.select(cls.model.tags).where(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)) if canvas_category: query = query.where(cls.model.canvas_category == canvas_category) @@ -281,6 +265,7 @@ class UserCanvasService(CommonService): @DB.connection_context() def accessible(cls, canvas_id, tenant_id): from api.db.services.user_service import UserTenantService + e, c = UserCanvasService.get_by_canvas_id(canvas_id) if not e: return False @@ -345,18 +330,15 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): conv = API4Conversation(**conv) message_id = str(uuid4()) - conv.message.append({ - "role": "user", - "content": query, - "id": message_id, - "files": files - }) + conv.message.append({"role": "user", "content": query, "id": message_id, "files": files}) txt = "" run_kwargs = { "query": query, "files": files, "user_id": user_id, "inputs": inputs, + # Used by Canvas.run to correlate RAGFlow's Langfuse generations by session. 
+ "session_id": session_id, } if chat_template_kwargs is not None: run_kwargs["chat_template_kwargs"] = chat_template_kwargs @@ -394,14 +376,7 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre if stream: completion_tokens = 0 try: - async for ans in completion( - tenant_id=tenant_id, - agent_id=agent_id, - session_id=session_id, - query=question, - user_id=user_id, - **kwargs - ): + async for ans in completion(tenant_id=tenant_id, agent_id=agent_id, session_id=session_id, query=question, user_id=user_id, **kwargs): if isinstance(ans, str): try: ans = json.loads(ans[5:]) # remove "data:" @@ -417,14 +392,7 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre completion_tokens += len(tiktoken_encoder.encode(content_piece)) - openai_data = get_data_openai( - id=session_id or str(uuid4()), - model=agent_id, - content=content_piece, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - stream=True - ) + openai_data = get_data_openai(id=session_id or str(uuid4()), model=agent_id, content=content_piece, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, stream=True) if ans.get("data", {}).get("reference", None): openai_data["choices"][0]["delta"]["reference"] = ans["data"]["reference"] @@ -435,32 +403,29 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre except Exception as e: logging.exception(e) - yield "data: " + json.dumps( - get_data_openai( - id=session_id or str(uuid4()), - model=agent_id, - content=f"**ERROR**: {str(e)}", - finish_reason="stop", - prompt_tokens=prompt_tokens, - completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")), - stream=True - ), - ensure_ascii=False - ) + "\n\n" + yield ( + "data: " + + json.dumps( + get_data_openai( + id=session_id or str(uuid4()), + model=agent_id, + content=f"**ERROR**: {str(e)}", + finish_reason="stop", + prompt_tokens=prompt_tokens, + 
completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")), + stream=True, + ), + ensure_ascii=False, + ) + + "\n\n" + ) yield "data: [DONE]\n\n" else: try: all_content = "" reference = {} - async for ans in completion( - tenant_id=tenant_id, - agent_id=agent_id, - session_id=session_id, - query=question, - user_id=user_id, - **kwargs - ): + async for ans in completion(tenant_id=tenant_id, agent_id=agent_id, session_id=session_id, query=question, user_id=user_id, **kwargs): if isinstance(ans, str): ans = json.loads(ans[5:]) if ans.get("event") not in ["message", "message_end"]: @@ -475,13 +440,7 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre completion_tokens = len(tiktoken_encoder.encode(all_content)) openai_data = get_data_openai( - id=session_id or str(uuid4()), - model=agent_id, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - content=all_content, - finish_reason="stop", - param=None + id=session_id or str(uuid4()), model=agent_id, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, content=all_content, finish_reason="stop", param=None ) if reference: @@ -497,5 +456,5 @@ async def completion_openai(tenant_id, agent_id, question, session_id=None, stre completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")), content=f"**ERROR**: {str(e)}", finish_reason="stop", - param=None + param=None, ) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 6ff5b7ce9f..c1fe749bee 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -27,7 +27,7 @@ from langfuse import propagate_attributes from api.db.db_models import LLM from api.db.services.common_service import CommonService from api.db.services.tenant_llm_service import LLM4Tenant -from common.token_utils import num_tokens_from_string +from common.token_utils import num_tokens_from_string, record_run_token_usage, langfuse_run_attrs class LLMService(CommonService): @@ 
-39,11 +39,49 @@ class LLMBundle(LLM4Tenant): super().__init__(tenant_id, model_config, lang, **kwargs) def _start_langfuse_observation(self, **kwargs): + # Correlating attributes (session_id/user_id) let Langfuse group all of a + # turn's generations. They may come from this bundle (chat/dialog path) or, + # for agent runs whose bundles are created without them, from the per-run + # context installed by Canvas.run. + attrs = {} if self.langfuse_session_id: - with propagate_attributes(session_id=self.langfuse_session_id): + attrs["session_id"] = self.langfuse_session_id + run_attrs = langfuse_run_attrs.get() + if run_attrs: + for k in ("session_id", "user_id"): + if run_attrs.get(k) and k not in attrs: + attrs[k] = run_attrs[k] + if attrs: + with propagate_attributes(**attrs): return self.langfuse.start_observation(**kwargs) return self.langfuse.start_observation(**kwargs) + def _reset_last_usage(self) -> None: + """Clear the model's per-call usage so a failed call that returns before + updating it cannot leak the previous call's usage into this run.""" + if hasattr(self.mdl, "last_usage"): + self.mdl.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + + def _report_usage(self, total_tokens: int) -> dict: + """Record a chat call's usage to the active agent run and return the + prompt/completion/total split for Langfuse. + + ``total_tokens`` is the authoritative total from the call. The prompt/completion + split is taken from the provider response (``mdl.last_usage``) only when it is + consistent with ``total_tokens`` (i.e. produced by this same call); otherwise the + split is reported as 0 while the total still aggregates correctly. 
+ """ + split = getattr(self.mdl, "last_usage", None) or {} + prompt = int(split.get("prompt_tokens", 0) or 0) + completion = int(split.get("completion_tokens", 0) or 0) + if not total_tokens: + total_tokens = int(split.get("total_tokens", 0) or 0) + if (prompt + completion) != total_tokens: + # Stale or inconsistent split — keep the total, drop the unreliable split. + prompt, completion = 0, 0 + record_run_token_usage(prompt, completion, total_tokens) + return {"input": prompt, "output": completion, "total": total_tokens} + def close(self): """Release resources held by this LLMBundle instance.""" super().close() @@ -139,7 +177,9 @@ class LLMBundle(LLM4Tenant): def similarity(self, query: str, texts: list): if self.langfuse: - generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts}) + generation = self._start_langfuse_observation( + trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts} + ) sim, used_tokens = self.mdl.similarity(query, texts) logging.info("LLMBundle.similarity used_tokens: %d", used_tokens) @@ -165,7 +205,9 @@ class LLMBundle(LLM4Tenant): def describe_with_prompt(self, image, prompt): if self.langfuse: - generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt}) + generation = self._start_langfuse_observation( + trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt} + ) txt, used_tokens = self.mdl.describe_with_prompt(image, prompt) logging.info("LLMBundle.describe_with_prompt used_tokens: %d", used_tokens) @@ -194,7 +236,8 @@ class LLMBundle(LLM4Tenant): supports_stream = 
hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription")) if supports_stream: if self.langfuse: - generation = self._start_langfuse_observation(as_type="generation", + generation = self._start_langfuse_observation( + as_type="generation", trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.model_config["llm_name"]}, @@ -228,7 +271,8 @@ class LLMBundle(LLM4Tenant): return if self.langfuse: - generation = self._start_langfuse_observation(as_type="generation", + generation = self._start_langfuse_observation( + as_type="generation", trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.model_config["llm_name"]}, @@ -377,11 +421,14 @@ class LLMBundle(LLM4Tenant): generation = None if self.langfuse: - generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history}) + generation = self._start_langfuse_observation( + trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history} + ) chat_partial = partial(base_fn, system, history, gen_conf) use_kwargs = self._clean_param(chat_partial, **kwargs) + self._reset_last_usage() try: txt, used_tokens = await chat_partial(**use_kwargs) except Exception as e: @@ -397,8 +444,10 @@ class LLMBundle(LLM4Tenant): if used_tokens: logging.info("LLMBundle.async_chat used_tokens: %d", used_tokens) + usage_details = self._report_usage(used_tokens) + if generation: - generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) + generation.update(output={"output": txt}, usage_details=usage_details) generation.end() return txt @@ -418,11 +467,14 @@ class LLMBundle(LLM4Tenant): generation = None if self.langfuse: - generation = self._start_langfuse_observation(trace_context=self.trace_context, 
as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}) + generation = self._start_langfuse_observation( + trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history} + ) if stream_fn: chat_partial = partial(stream_fn, system, history, gen_conf) use_kwargs = self._clean_param(chat_partial, **kwargs) + self._reset_last_usage() try: async for txt in chat_partial(**use_kwargs): if isinstance(txt, int): @@ -444,8 +496,9 @@ class LLMBundle(LLM4Tenant): raise if total_tokens: logging.info("LLMBundle.async_chat_streamly used_tokens: %d", total_tokens) + usage_details = self._report_usage(total_tokens) if generation: - generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens}) + generation.update(output={"output": ans}, usage_details=usage_details) generation.end() return @@ -461,11 +514,14 @@ class LLMBundle(LLM4Tenant): generation = None if self.langfuse: - generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}) + generation = self._start_langfuse_observation( + trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history} + ) if stream_fn: chat_partial = partial(stream_fn, system, history, gen_conf) use_kwargs = self._clean_param(chat_partial, **kwargs) + self._reset_last_usage() try: async for txt in chat_partial(**use_kwargs): if isinstance(txt, int): @@ -487,7 +543,8 @@ class LLMBundle(LLM4Tenant): raise if total_tokens: logging.info("LLMBundle.async_chat_streamly_delta used_tokens: %d", total_tokens) + usage_details = self._report_usage(total_tokens) if generation: - generation.update(output={"output": ans}, 
usage_details={"total_tokens": total_tokens}) + generation.update(output={"output": ans}, usage_details=usage_details) generation.end() return diff --git a/common/token_utils.py b/common/token_utils.py index 3a24e6cdc5..75027737e6 100644 --- a/common/token_utils.py +++ b/common/token_utils.py @@ -14,9 +14,12 @@ # limitations under the License. # +import contextvars import hashlib +import logging import os import shutil +import threading import tiktoken from common.file_utils import get_project_base_directory @@ -42,6 +45,84 @@ os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir encoder = tiktoken.get_encoding("cl100k_base") +# Per-run token usage sink. An agent run (Canvas.run) installs a mutable dict here +# at the start of each turn; every LLMBundle chat call adds its provider-reported +# usage to it. This is the single chokepoint that aggregates token usage across all +# LLM calls in a run (query rewriting, cross-language translation, tool reasoning, +# and the final streamed answer) regardless of which component or helper issued the +# call. Default None means "not inside a tracked run" and callers must no-op. +token_usage_sink: contextvars.ContextVar = contextvars.ContextVar("ragflow_token_usage_sink", default=None) + +# Per-run Langfuse correlating attributes (e.g. {"session_id": ..., "user_id": ...}). +# Installed by Canvas.run so RAGFlow's own Langfuse generations can be grouped by +# session and user even though the agent's LLMBundles are created without them. +langfuse_run_attrs: contextvars.ContextVar = contextvars.ContextVar("ragflow_langfuse_run_attrs", default=None) + + +# Guards sink mutations: concurrent tool calls (asyncio.gather + thread_pool_exec, +# which copies the context so worker threads share the same sink dict) can otherwise +# race on the read-modify-write of the counters. 
+_sink_lock = threading.Lock() + + +def record_run_token_usage(prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = 0) -> None: + """Add a single LLM call's token usage to the active run sink, if any. + + Safe to call from anywhere: when no run sink is installed it does nothing. + """ + sink = token_usage_sink.get() + if sink is None: + return + try: + with _sink_lock: + sink["prompt_tokens"] += int(prompt_tokens or 0) + sink["completion_tokens"] += int(completion_tokens or 0) + sink["total_tokens"] += int(total_tokens or 0) + sink["calls"] += 1 + except Exception: + # Never let usage bookkeeping break a request; log at debug so a malformed + # sink or token value is still traceable without adding noise. + logging.debug("Failed to record run token usage", exc_info=True) + + +def usage_from_response(resp) -> dict: + """Extract a {prompt_tokens, completion_tokens, total_tokens} split from an LLM response. + + Handles OpenAI/OpenRouter-style ``resp.usage`` objects and dict variants. Missing + fields default to 0; ``total_tokens`` falls back to prompt+completion when absent. 
+ """ + out = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + if resp is None: + return out + + usage = None + try: + usage = getattr(resp, "usage", None) + if usage is None and isinstance(resp, dict): + usage = resp.get("usage") + except Exception: + usage = None + if usage is None: + return out + + def _get(obj, *names): + for n in names: + try: + v = obj.get(n) if isinstance(obj, dict) else getattr(obj, n, None) + except Exception: + v = None + if v: + return int(v) + return 0 + + out["prompt_tokens"] = _get(usage, "prompt_tokens", "input_tokens") + out["completion_tokens"] = _get(usage, "completion_tokens", "output_tokens") + out["total_tokens"] = _get(usage, "total_tokens") + if not out["total_tokens"]: + out["total_tokens"] = out["prompt_tokens"] + out["completion_tokens"] + return out + + def num_tokens_from_string(string: str) -> int: """Returns the number of tokens in a text string.""" try: @@ -50,6 +131,7 @@ def num_tokens_from_string(string: str) -> int: except Exception: return 0 + def total_token_count_from_response(resp): """ Extract token count from LLM response in various formats. 
@@ -78,19 +160,19 @@ def total_token_count_from_response(resp): except Exception: pass - if isinstance(resp, dict) and 'usage' in resp and 'total_tokens' in resp['usage']: + if isinstance(resp, dict) and "usage" in resp and "total_tokens" in resp["usage"]: try: return resp["usage"]["total_tokens"] except Exception: pass - if isinstance(resp, dict) and 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']: + if isinstance(resp, dict) and "usage" in resp and "input_tokens" in resp["usage"] and "output_tokens" in resp["usage"]: try: return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"] except Exception: pass - if isinstance(resp, dict) and 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']: + if isinstance(resp, dict) and "meta" in resp and "tokens" in resp["meta"] and "input_tokens" in resp["meta"]["tokens"] and "output_tokens" in resp["meta"]["tokens"]: try: return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"] except Exception: diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 169f37beda..c7ee06dea9 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -32,7 +32,7 @@ from openai import AsyncOpenAI, OpenAI from enum import StrEnum from common.misc_utils import thread_pool_exec -from common.token_utils import num_tokens_from_string, total_token_count_from_response +from common.token_utils import num_tokens_from_string, total_token_count_from_response, usage_from_response from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider from rag.llm.key_utils import _normalize_replicate_key from rag.llm.tool_decorator import FunctionToolSession, is_tool @@ -184,11 +184,7 @@ def _apply_model_family_policies( sanitized_gen_conf["n"] = 1 sanitized_gen_conf["presence_penalty"] = 0.0 sanitized_gen_conf["frequency_penalty"] = 0.0 - elif ( - provider == 
SupportedLiteLLMProvider.ZHIPU_AI - and "glm" in model_name_lower - and thinking_type - ): + elif provider == SupportedLiteLLMProvider.ZHIPU_AI and "glm" in model_name_lower and thinking_type: _pop_thinking_controls() sanitized_gen_conf["thinking"] = {"type": thinking_type} @@ -217,6 +213,7 @@ def _move_litellm_provider_body_fields(provider: SupportedLiteLLMProvider | str completion_args["extra_body"] = body return completion_args + class Base(ABC): def __init__(self, key, model_name, base_url, **kwargs): timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600)) @@ -230,6 +227,9 @@ class Base(ABC): self.is_tools = False self.tools = [] self.toolcall_sessions = {} + # Token usage split (prompt/completion/total) of the most recent chat call. + # Consumed by LLMBundle for accurate Langfuse reporting and run aggregation. + self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} def _get_delay(self): return self.base_delay * random.uniform(10, 150) @@ -313,6 +313,8 @@ class Base(ABC): gen_conf = self._clean_conf(gen_conf) ans = "" total_tokens = 0 + # Reset so a stale split from a previous call can't leak into this one. + self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} for attempt in range(self.max_retries + 1): try: @@ -478,6 +480,18 @@ class Base(ABC): ans = "" tk_count = 0 + # Aggregate prompt/completion/total across all tool-calling rounds. 
+ agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + + def _add_round_usage(resp): + nonlocal tk_count + u = usage_from_response(resp) + agg_usage["prompt_tokens"] += u["prompt_tokens"] + agg_usage["completion_tokens"] += u["completion_tokens"] + agg_usage["total_tokens"] += u["total_tokens"] or total_token_count_from_response(resp) + tk_count = agg_usage["total_tokens"] + self.last_usage = dict(agg_usage) + hist = deepcopy(history) for attempt in range(self.max_retries + 1): history = deepcopy(hist) @@ -485,7 +499,7 @@ class Base(ABC): for _ in range(self.max_rounds + 1): logging.info(f"{self.tools=}") response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf, **extra_request_kwargs) - tk_count += total_token_count_from_response(response) + _add_round_usage(response) if not response.choices or not response.choices[0].message: raise Exception(f"500 response structure error. Response: {response}") @@ -505,9 +519,7 @@ class Base(ABC): try: args = json_repair.loads(tc.function.arguments) if not isinstance(args, dict): - raise TypeError( - f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}" - ) + raise TypeError(f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}") if hasattr(self.toolcall_session, "tool_call_async"): result = await self.toolcall_session.tool_call_async(name, args) else: @@ -527,7 +539,13 @@ class Base(ABC): history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) response, token_count = await self._async_chat(history, gen_conf) ans += response - tk_count += token_count + # _async_chat set self.last_usage to its own call; fold it into the aggregate. 
+ _fb = getattr(self, "last_usage", None) or {} + agg_usage["prompt_tokens"] += int(_fb.get("prompt_tokens", 0) or 0) + agg_usage["completion_tokens"] += int(_fb.get("completion_tokens", 0) or 0) + agg_usage["total_tokens"] += int(_fb.get("total_tokens", 0) or token_count) + tk_count = agg_usage["total_tokens"] + self.last_usage = dict(agg_usage) return ans, tk_count except Exception as e: e = await self._exceptions_async(e, attempt) @@ -550,8 +568,23 @@ class Base(ABC): history.insert(0, {"role": "system", "content": system}) total_tokens = 0 + # Aggregate prompt/completion/total across all tool-calling rounds. The split is + # captured opportunistically when the provider reports usage on a chunk; otherwise + # only the (estimated) total accumulates. + agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} hist = deepcopy(history) + def _commit_round(round_usage, round_estimate): + nonlocal total_tokens + if round_usage and round_usage["total_tokens"]: + agg_usage["prompt_tokens"] += round_usage["prompt_tokens"] + agg_usage["completion_tokens"] += round_usage["completion_tokens"] + agg_usage["total_tokens"] += round_usage["total_tokens"] + else: + agg_usage["total_tokens"] += round_estimate + total_tokens = agg_usage["total_tokens"] + self.last_usage = dict(agg_usage) + for attempt in range(self.max_retries + 1): history = deepcopy(hist) try: @@ -559,12 +592,20 @@ class Base(ABC): reasoning_start = False logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in tools]}") - response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf, **extra_request_kwargs) + response = await self.async_client.chat.completions.create( + model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf, **extra_request_kwargs + ) final_tool_calls = {} answer = "" + round_estimate = 0 + 
round_usage = None async for resp in response: + _u = usage_from_response(resp) + if _u["total_tokens"]: + round_usage = _u + if not hasattr(resp, "choices") or not resp.choices: continue @@ -597,16 +638,17 @@ class Base(ABC): answer += delta.content yield delta.content - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(delta.content) - else: - total_tokens = tol + if not _u["total_tokens"]: + round_estimate += num_tokens_from_string(delta.content) finish_reason = getattr(resp.choices[0], "finish_reason", "") if finish_reason == "length": yield self._length_stop("") + # Commit this round's tokens (each round is a separate provider + # request — accumulate, never overwrite). + _commit_round(round_usage, round_estimate) + if answer and not final_tool_calls: logging.info(f"[ToolLoop] round={_round} completed with text response, exiting") yield total_tokens @@ -617,9 +659,7 @@ class Base(ABC): try: args = json_repair.loads(tc.function.arguments) if not isinstance(args, dict): - raise TypeError( - f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}" - ) + raise TypeError(f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}") if hasattr(self.toolcall_session, "tool_call_async"): result = await self.toolcall_session.tool_call_async(name, args) else: @@ -655,19 +695,22 @@ class Base(ABC): **extra_request_kwargs, ) + fb_estimate = 0 + fb_usage = None async for resp in response: + _u = usage_from_response(resp) + if _u["total_tokens"]: + fb_usage = _u if not hasattr(resp, "choices") or not resp.choices: continue delta = resp.choices[0].delta if not hasattr(delta, "content") or delta.content is None: continue - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(delta.content) - else: - total_tokens = tol + if not _u["total_tokens"]: + fb_estimate += num_tokens_from_string(delta.content) yield delta.content + _commit_round(fb_usage, 
fb_estimate) yield total_tokens return @@ -708,6 +751,8 @@ class Base(ABC): response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) + # Capture prompt/completion split for accurate Langfuse + run aggregation. + self.last_usage = usage_from_response(response) if not response.choices or not response.choices[0].message or not response.choices[0].message.content: return "", 0 ans = response.choices[0].message.content.strip() @@ -1075,9 +1120,7 @@ class ReplicateChat(Base): msgs = [{"role": "system", "content": system}] system_msg = msgs[0]["content"] if msgs and msgs[0].get("role") == "system" else "" - prompt = "\n".join( - [item["role"] + ":" + item["content"] for item in msgs[-5:] if item.get("role") != "system"] - ) + prompt = "\n".join([item["role"] + ":" + item["content"] for item in msgs[-5:] if item.get("role") != "system"]) try: response = self.client.run( self.model_name, @@ -1550,6 +1593,9 @@ class LiteLLMBase(ABC): self.is_tools = False self.tools = [] self.toolcall_sessions = {} + # Token usage split (prompt/completion/total) of the most recent chat call. + # Consumed by LLMBundle for accurate Langfuse reporting and run aggregation. + self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} # Factory specific fields if self.provider == SupportedLiteLLMProvider.OpenRouter: @@ -1638,6 +1684,9 @@ class LiteLLMBase(ABC): timeout=self.timeout, ) + # Capture the prompt/completion split for accurate per-call usage + # reporting (Langfuse + agent run aggregation). + self.last_usage = usage_from_response(response) if not response.choices or not response.choices[0].message or not response.choices[0].message.content: return "", 0 ans = response.choices[0].message.content.strip() @@ -1659,11 +1708,16 @@ class LiteLLMBase(ABC): gen_conf = self._clean_conf(gen_conf) reasoning_start = False total_tokens = 0 + # Reset so a stale split from a previous call can't leak into this one. 
+ self.last_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) stop = kwargs.get("stop") if stop: completion_args["stop"] = stop + # Ask the provider to include authoritative usage in the final streaming chunk. + # drop_params=True ensures this is silently ignored by providers that don't support it. + completion_args.setdefault("stream_options", {})["include_usage"] = True for attempt in range(self.max_retries + 1): try: @@ -1674,6 +1728,14 @@ class LiteLLMBase(ABC): ) async for resp in stream: + # Authoritative usage may arrive on a usage-only final chunk that + # carries no choices (OpenAI/OpenRouter with include_usage). Read it + # before the choices guard so the prompt/completion split is captured. + _usage = usage_from_response(resp) + if _usage["total_tokens"]: + total_tokens = _usage["total_tokens"] + self.last_usage = _usage + if not hasattr(resp, "choices") or not resp.choices: continue @@ -1692,10 +1754,9 @@ class LiteLLMBase(ABC): reasoning_start = False ans = delta.content - tol = total_token_count_from_response(resp) - if not tol: - tol = num_tokens_from_string(delta.content) - total_tokens += tol + if not _usage["total_tokens"]: + # No authoritative usage yet: keep a running estimate as fallback. 
+ total_tokens += num_tokens_from_string(delta.content) finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" if finish_reason == "length": @@ -1745,11 +1806,15 @@ class LiteLLMBase(ABC): return msg def _verbose_tool_use(self, name, args, res): - return "" + json.dumps( - {"name": name, "args": args, "result": str(res) if isinstance(res, Exception) else res}, - ensure_ascii=False, - indent=2, - ) + "" + return ( + "" + + json.dumps( + {"name": name, "args": args, "result": str(res) if isinstance(res, Exception) else res}, + ensure_ascii=False, + indent=2, + ) + + "" + ) def _append_history(self, hist, tool_call, tool_res, reasoning_content=None): assistant_msg = { @@ -1844,6 +1909,19 @@ class LiteLLMBase(ABC): ans = "" tk_count = 0 + # Aggregate prompt/completion/total across every tool-calling round so the + # whole multi-round exchange is reported once with a correct split. + agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + + def _add_usage(resp): + nonlocal tk_count + u = usage_from_response(resp) + agg_usage["prompt_tokens"] += u["prompt_tokens"] + agg_usage["completion_tokens"] += u["completion_tokens"] + agg_usage["total_tokens"] += u["total_tokens"] or total_token_count_from_response(resp) + tk_count = agg_usage["total_tokens"] + self.last_usage = dict(agg_usage) + hist = deepcopy(history) for attempt in range(self.max_retries + 1): history = deepcopy(hist) @@ -1858,7 +1936,7 @@ class LiteLLMBase(ABC): timeout=self.timeout, ) - tk_count += total_token_count_from_response(response) + _add_usage(response) if not hasattr(response, "choices") or not response.choices or not response.choices[0].message: raise Exception(f"500 response structure error. 
Response: {response}") @@ -1906,7 +1984,13 @@ class LiteLLMBase(ABC): response, token_count = await self.async_chat("", history, gen_conf) ans += response - tk_count += token_count + # self.async_chat set self.last_usage to its own call; fold it into the aggregate. + _fb = getattr(self, "last_usage", None) or {} + agg_usage["prompt_tokens"] += int(_fb.get("prompt_tokens", 0) or 0) + agg_usage["completion_tokens"] += int(_fb.get("completion_tokens", 0) or 0) + agg_usage["total_tokens"] += int(_fb.get("total_tokens", 0) or token_count) + tk_count = agg_usage["total_tokens"] + self.last_usage = dict(agg_usage) return ans, tk_count except Exception as e: @@ -1924,8 +2008,23 @@ class LiteLLMBase(ABC): history.insert(0, {"role": "system", "content": system}) total_tokens = 0 + # Aggregate usage across every tool-calling round (each round is a separate + # provider request). Committing per round avoids the previous bug where a later + # round's total overwrote earlier rounds. + agg_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} hist = deepcopy(history) + def _commit_round(round_usage, round_estimate): + nonlocal total_tokens + if round_usage and round_usage["total_tokens"]: + agg_usage["prompt_tokens"] += round_usage["prompt_tokens"] + agg_usage["completion_tokens"] += round_usage["completion_tokens"] + agg_usage["total_tokens"] += round_usage["total_tokens"] + else: + agg_usage["total_tokens"] += round_estimate + total_tokens = agg_usage["total_tokens"] + self.last_usage = dict(agg_usage) + for attempt in range(self.max_retries + 1): history = deepcopy(hist) try: @@ -1935,6 +2034,8 @@ class LiteLLMBase(ABC): logging.info(f"[ToolLoop] round={_round} model={self.model_name} tools={[t['function']['name'] for t in tools]}") completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) + # Request authoritative usage on the final streaming chunk. 
+ completion_args.setdefault("stream_options", {})["include_usage"] = True response = await litellm.acompletion( **completion_args, drop_params=True, @@ -1943,8 +2044,15 @@ class LiteLLMBase(ABC): final_tool_calls = {} answer = "" + round_usage = None + round_estimate = 0 async for resp in response: + # Usage-only final chunk may carry no choices — read it first. + _u = usage_from_response(resp) + if _u["total_tokens"]: + round_usage = _u + if not hasattr(resp, "choices") or not resp.choices: continue @@ -1979,16 +2087,16 @@ class LiteLLMBase(ABC): answer += delta.content yield delta.content - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(delta.content) - else: - total_tokens = tol + if not _u["total_tokens"]: + round_estimate += num_tokens_from_string(delta.content) finish_reason = getattr(resp.choices[0], "finish_reason", "") if finish_reason == "length": yield self._length_stop("") + # Commit this round's tokens to the running aggregate. 
+ _commit_round(round_usage, round_estimate) + if answer and not final_tool_calls: logging.info(f"[ToolLoop] round={_round} completed with text response, exiting") yield total_tokens @@ -2030,25 +2138,29 @@ class LiteLLMBase(ABC): history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) + completion_args.setdefault("stream_options", {})["include_usage"] = True response = await litellm.acompletion( **completion_args, drop_params=True, timeout=self.timeout, ) + fb_usage = None + fb_estimate = 0 async for resp in response: + _u = usage_from_response(resp) + if _u["total_tokens"]: + fb_usage = _u if not hasattr(resp, "choices") or not resp.choices: continue delta = resp.choices[0].delta if not hasattr(delta, "content") or delta.content is None: continue - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(delta.content) - else: - total_tokens = tol + if not _u["total_tokens"]: + fb_estimate += num_tokens_from_string(delta.content) yield delta.content + _commit_round(fb_usage, fb_estimate) yield total_tokens return @@ -2157,7 +2269,7 @@ class LiteLLMBase(ABC): if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers: extra_headers["Authorization"] = f"Bearer {self.api_key}" # MiniMax requires GroupId as a query parameter for API authentication - if self.provider == SupportedLiteLLMProvider.MiniMax and hasattr(self, 'group_id') and self.group_id: + if self.provider == SupportedLiteLLMProvider.MiniMax and hasattr(self, "group_id") and self.group_id: api_base = completion_args.get("api_base", self.base_url) separator = "&" if "?" in api_base else "?" completion_args["api_base"] = f"{api_base}{separator}GroupId={self.group_id}"