From dcbd0d260cd764dc173f53a1bd1c00e3478ec9a4 Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Thu, 2 Jul 2026 20:20:11 +0800 Subject: [PATCH] Port agent PRs to GO - 2 (#16565) ### Summary Port the following PRs to GO in this PR https://github.com/infiniflow/ragflow/pull/16420 https://github.com/infiniflow/ragflow/pull/13295 --- internal/agent/canvas/canvas.go | 61 +++++++++ internal/agent/tool/mcp.go | 13 ++ internal/entity/models/chat_tools.go | 103 ++++++++++++++- internal/entity/models/llm.go | 13 ++ internal/entity/models/openai.go | 80 +++++++++++- internal/entity/models/types.go | 20 +++ internal/service/agent.go | 61 ++++++++- internal/tokenizer/usage.go | 184 +++++++++++++++++++++++++++ 8 files changed, 524 insertions(+), 11 deletions(-) create mode 100644 internal/tokenizer/usage.go diff --git a/internal/agent/canvas/canvas.go b/internal/agent/canvas/canvas.go index 1348751a8a..839d0e6670 100644 --- a/internal/agent/canvas/canvas.go +++ b/internal/agent/canvas/canvas.go @@ -9,6 +9,9 @@ package canvas import ( "ragflow/internal/agent/runtime" + "ragflow/internal/common" + + "go.uber.org/zap" ) // legacyNoOpNames is the set of component names that the Go port @@ -69,6 +72,64 @@ type CanvasComponentObj struct { Params map[string]any `json:"params"` } +// Close releases resources held by components referenced in the canvas +// DSL. It walks every component's params map and calls Close() on any +// value that implements a Close() method (MCPToolAdapters, HTTP +// clients, etc.). Mirrors Python's Graph.close() in agent/canvas.py. +// +// In Go's architecture MCP sessions are per-invocation and auto-torn +// down; Close() is a best-effort hook that ensures idle HTTP +// connections are released even when adapters outlive a single call. +func (c *Canvas) Close() { + if c == nil { + return + } + seen := make(map[any]bool) + for _, comp := range c.Components { + for _, v := range comp.Obj.Params { + walkAndClose(v, seen) + } + } +} + +// walkAndClose recursively walks a value and calls Close() on any +// objects that implement a Close() method. Maps, slices, and pointers +// are recursed into; other types are skipped. Already-seen objects +// (by interface identity) are skipped to avoid double-close. +func walkAndClose(v any, seen map[any]bool) { + if v == nil { + return + } + if closer, ok := v.(interface{ Close() }); ok { + if !seen[closer] { + seen[closer] = true + safeClose(closer) + } + return + } + switch val := v.(type) { + case map[string]any: + for _, child := range val { + walkAndClose(child, seen) + } + case []any: + for _, child := range val { + walkAndClose(child, seen) + } + } +} + +// safeClose calls Close() on a closer value, swallowing panics so a +// misbehaving resource doesn't crash the canvas tear-down path. +func safeClose(closer interface{ Close() }) { + defer func() { + if rec := recover(); rec != nil { + common.Warn("canvas: Close() panicked", zap.Any("recover", rec)) + } + }() + closer.Close() +} + // Component is an alias for runtime.Component — the minimal runtime // surface BuildWorkflow needs at sub-graph build time. The canonical // definition (and the SetDefaultFactory / DefaultFactory plumbing) diff --git a/internal/agent/tool/mcp.go b/internal/agent/tool/mcp.go index b9c005933d..9c64d5ecf5 100644 --- a/internal/agent/tool/mcp.go +++ b/internal/agent/tool/mcp.go @@ -152,6 +152,19 @@ func (m *MCPToolAdapter) InvokableRun(ctx context.Context, argumentsInJSON strin return res.Text, nil } +// Close releases resources held by the adapter. In Go's architecture +// MCP sessions are per-invocation (created and torn down within each +// InvokableRun call), so there are no persistent connections to drain. +// The primary resource is the http.Client's idle-connection pool; +// calling Close explicitly drops those idle connections so they don't +// accumulate across many adapter instances over long-running processes. +// Mirrors Python's close_sync() in common/mcp_tool_call_conn.py. +func (m *MCPToolAdapter) Close() { + if m.httpClient != nil { + m.httpClient.CloseIdleConnections() + } +} + // BuildMCPToolAdapters wraps a slice of mcpclient.Tool descriptors as // eino InvokableTool. Returned slice is suitable for handing to // agenttool.NewRetrieverTool / NewMCPToolAdapter paths or directly to diff --git a/internal/entity/models/chat_tools.go b/internal/entity/models/chat_tools.go index 09308c71ee..b04c1f5748 100644 --- a/internal/entity/models/chat_tools.go +++ b/internal/entity/models/chat_tools.go @@ -25,6 +25,20 @@ import ( "ragflow/internal/tokenizer" ) +// recordUsageFromResponse records one chat call's token usage to both +// cm.LastUsage and the context-level run sink (if installed). Callers +// should invoke this after each ChatWithMessages / ChatStreamlyWithSender +// call so the canvas-layer aggregator and Langfuse both see the split. +func recordUsageFromResponse(ctx context.Context, cm *ChatModel) { + if cm == nil { + return + } + if cm.LastUsage == nil { + return + } + tokenizer.RecordRunTokenUsage(ctx, cm.LastUsage.PromptTokens, cm.LastUsage.CompletionTokens, cm.LastUsage.TotalTokens) +} + const ( defaultMaxRetries = 3 defaultMaxRounds = 5 @@ -77,7 +91,32 @@ func (cm *ChatModel) ChatWithTools(ctx context.Context, system string, history [ } func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsList interface{}, chatCfg *ChatConfig, maxRounds int) (string, int, error) { + // Aggregate prompt/completion/total across all tool-calling rounds. + // Mirrors Python PR #16420 fix: previously the total was overwritten + // each round; now we accumulate so multi-round tool conversations + // report the correct grand total. + // Reset stale per-call usage from a previous call so a response + // without usage doesn't leak the prior call's data. + cm.LastUsage = nil var totalTokens int + aggUsage := &ChatUsage{} + + addRoundUsage := func(resp *ChatResponse) { + u := resp.Usage + if u == nil { + return + } + aggUsage.PromptTokens += u.PromptTokens + aggUsage.CompletionTokens += u.CompletionTokens + aggUsage.TotalTokens += u.TotalTokens + totalTokens = aggUsage.TotalTokens + // Store per-round delta (not cumulative) so RecordRunTokenUsage + // records each round's contribution exactly once. + cm.LastUsage = &ChatUsage{ + PromptTokens: u.PromptTokens, CompletionTokens: u.CompletionTokens, TotalTokens: u.TotalTokens, + } + recordUsageFromResponse(ctx, cm) + } for round := 0; round <= maxRounds; round++ { select { @@ -97,6 +136,7 @@ func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsLis if resp == nil { return "", totalTokens, fmt.Errorf("round %d: nil response", round) } + addRoundUsage(resp) if len(resp.ToolCalls) == 0 { answer := "" @@ -106,7 +146,11 @@ func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsLis if resp.ReasonContent != nil && *resp.ReasonContent != "" { answer = "" + *resp.ReasonContent + "" + answer } - totalTokens += tokenizer.NumTokensFromString(answer) + // Fallback: if the provider didn't return usage info, + // estimate from the answer text using tiktoken. + if resp.Usage == nil || resp.Usage.TotalTokens == 0 { + totalTokens += tokenizer.NumTokensFromString(answer) + } return answer, totalTokens, nil } @@ -126,7 +170,11 @@ func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsLis if resp == nil || resp.Answer == nil { return "", totalTokens, fmt.Errorf("final call: no answer") } - totalTokens += tokenizer.NumTokensFromString(*resp.Answer) + addRoundUsage(resp) + // Fallback: use text-based estimation if no authoritative usage. + if resp.Usage == nil || resp.Usage.TotalTokens == 0 { + totalTokens += tokenizer.NumTokensFromString(*resp.Answer) + } return *resp.Answer, totalTokens, nil } @@ -177,7 +225,37 @@ func (cm *ChatModel) ChatStreamlyWithTools(ctx context.Context, system string, h } func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsList interface{}, chatCfg *ChatConfig, maxRounds int, sender func(*string, *string) error) (int, error) { + // Aggregate token counts across every tool-calling round (each round is a + // separate provider request). Committing per round avoids the previous + // bug where a later round's total overwrote earlier rounds. + // Reset stale per-call usage from a previous call. + cm.LastUsage = nil var totalTokens int + aggUsage := &ChatUsage{} + + commitRound := func(cfg *ChatConfig, roundTokens int) { + // Prefer the authoritative usage from the API (extracted via + // stream_options.include_usage=true) over text-based token + // counting. Mirrors Python's usage_from_response accumulation + // in chat_model.py streaming handlers. + // Track per-round delta so RecordRunTokenUsage records each + // round's contribution exactly once (the sink does Add, not Set). + var deltaPrompt, deltaCompletion, deltaTotal int + if u := cfg.UsageResult; u != nil && u.TotalTokens > 0 { + deltaPrompt, deltaCompletion, deltaTotal = u.PromptTokens, u.CompletionTokens, u.TotalTokens + aggUsage.PromptTokens += deltaPrompt + aggUsage.CompletionTokens += deltaCompletion + aggUsage.TotalTokens += deltaTotal + } else { + deltaTotal = roundTokens + aggUsage.TotalTokens += deltaTotal + } + totalTokens = aggUsage.TotalTokens + cm.LastUsage = &ChatUsage{ + PromptTokens: deltaPrompt, CompletionTokens: deltaCompletion, TotalTokens: deltaTotal, + } + recordUsageFromResponse(ctx, cm) + } for round := 0; round <= maxRounds; round++ { select { @@ -192,10 +270,13 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to cfg.Stream = boolPtr(true) var tcs []map[string]interface{} cfg.ToolCallsResult = &tcs + var roundUsage ChatUsage + cfg.UsageResult = &roundUsage reasoningStarted := false var answer string var pendingThinkClose bool + var roundTokens int err := cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, func(delta *string, reason *string) error { if reason != nil && *reason != "" { @@ -207,6 +288,7 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to } } pendingThinkClose = true + roundTokens += tokenizer.NumTokensFromString(*reason) return sender(reason, nil) } // Reasoning ended, close the think block if open @@ -221,7 +303,7 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to if *delta == "[DONE]" { return nil } - totalTokens += tokenizer.NumTokensFromString(*delta) + roundTokens += tokenizer.NumTokensFromString(*delta) answer += *delta if e := sender(delta, nil); e != nil { return e @@ -237,6 +319,9 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to return totalTokens, e } } + // Commit this round's token count to the running aggregate. + // Prefer authoritative API usage over text-based estimation. + commitRound(&cfg, roundTokens) if err != nil { return totalTokens, fmt.Errorf("round %d: %w", round, err) } @@ -263,7 +348,17 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to }) cfg := *chatCfg cfg.Stream = boolPtr(true) - return totalTokens, cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, sender) + var exceedUsage ChatUsage + cfg.UsageResult = &exceedUsage + var exceedTokens int + err := cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, func(delta *string, reason *string) error { + if delta != nil && *delta != "" && *delta != "[DONE]" { + exceedTokens += tokenizer.NumTokensFromString(*delta) + } + return nil + }) + commitRound(&cfg, exceedTokens) + return totalTokens, err } // appendToolResults executes tool calls concurrently, appends the assistant diff --git a/internal/entity/models/llm.go b/internal/entity/models/llm.go index 5d2c687c00..c78ad02862 100644 --- a/internal/entity/models/llm.go +++ b/internal/entity/models/llm.go @@ -124,10 +124,23 @@ func (m *EinoChatModel) Generate(ctx context.Context, msgs []*schema.Message, op if err := ctx.Err(); err != nil { return nil, err } + // Reset stale per-call usage before the call so that a response + // without a usage block doesn't leak the previous call's data. + // Mirrors Python's LLMBundle._reset_last_usage(). + m.inner.LastUsage = nil resp, err := m.inner.ModelDriver.ChatWithMessages(*m.inner.ModelName, internal, m.inner.APIConfig, m.chatCfg) if err != nil { return nil, fmt.Errorf("models: EinoChatModel.Generate(%s): %w", *m.inner.ModelName, err) } + // Record the per-call token usage so the canvas-level aggregator (and + // Langfuse) can compute the run total. Mirrors Python's + // LLMBundle._report_usage() / self.mdl.last_usage pattern. + if resp != nil && resp.Usage != nil { + m.inner.LastUsage = &ChatUsage{ + PromptTokens: resp.Usage.PromptTokens, CompletionTokens: resp.Usage.CompletionTokens, TotalTokens: resp.Usage.TotalTokens, + } + recordUsageFromResponse(ctx, m.inner) + } return fromInternalResponse(resp), nil } diff --git a/internal/entity/models/openai.go b/internal/entity/models/openai.go index 74f7d725dc..c8d2e5f37d 100644 --- a/internal/entity/models/openai.go +++ b/internal/entity/models/openai.go @@ -211,6 +211,16 @@ func (o *OpenAIModel) ChatWithMessages(modelName string, messages []Message, api ToolCalls: toolCalls, } + // Extract usage split (prompt/completion/total) from the raw API + // response for accurate per-call token accounting. Non-OpenAI + // providers that implement the OpenAI-compat API surface (DeepSeek, + // Moonshot, etc.) also return a "usage" key with the same shape. + if pt, ct, tt := extractUsageFromMap(result); tt > 0 { + chatResponse.Usage = &ChatUsage{ + PromptTokens: pt, CompletionTokens: ct, TotalTokens: tt, + } + } + return chatResponse, nil } @@ -247,11 +257,17 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag apiMessages[i] = apiMsg } - // Build request body with streaming on by default + // Build request body with streaming on by default. + // stream_options.include_usage asks the provider to attach a + // usage block to the final streaming chunk (mirrors Python's + // chat_model.py _stream_options / stream_options.include_usage). reqBody := map[string]interface{}{ "model": modelName, "messages": apiMessages, "stream": true, + "stream_options": map[string]interface{}{ + "include_usage": true, + }, } if chatModelConfig != nil { @@ -316,6 +332,12 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag sawTerminal := false accumulatedToolCalls := make(map[int]map[string]interface{}) + // Capture the authoritative usage block from the final streaming + // chunk (when provider honours stream_options.include_usage=true). + // The last chunk in the stream carries the "usage" key alongside + // empty choices; we overwrite on every chunk so the final frame + // wins, matching Python's chat_model.py usage_from_response loop. + var streamUsage *ChatUsage scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { line := scanner.Text() @@ -340,6 +362,13 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag continue } + // Extract usage from this chunk. When stream_options.include_usage + // is true, the final chunk carries the full usage breakdown at the + // top level of the event alongside (possibly empty) choices. + if pt, ct, tt := extractUsageFromMap(event); tt > 0 { + streamUsage = &ChatUsage{PromptTokens: pt, CompletionTokens: ct, TotalTokens: tt} + } + choices, ok := event["choices"].([]interface{}) if !ok || len(choices) == 0 { continue @@ -421,6 +450,11 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag chatModelConfig.ToolCallsResult = &tcs } + // Populate UsageResult with the authoritative usage from the stream. + if streamUsage != nil && chatModelConfig != nil { + chatModelConfig.UsageResult = streamUsage + } + // Send the [DONE] marker for OpenAI compatibility endOfStream := "[DONE]" if err := sender(&endOfStream, nil); err != nil { @@ -1024,6 +1058,50 @@ func (o *OpenAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskRespon return nil, fmt.Errorf("%s, no such method", o.Name()) } +// extractUsageFromMap reads the "usage" key from an OpenAI-style API +// response and returns (prompt_tokens, completion_tokens, total_tokens). +// All return values are zero when the response carries no usage block. +func extractUsageFromMap(raw map[string]interface{}) (int, int, int) { + if raw == nil { + return 0, 0, 0 + } + ru, ok := raw["usage"] + if !ok { + return 0, 0, 0 + } + usage, ok := ru.(map[string]interface{}) + if !ok { + return 0, 0, 0 + } + get := func(keys ...string) int { + for _, k := range keys { + v, ok := usage[k] + if !ok { + continue + } + switch val := v.(type) { + case float64: + return int(val) + case int: + return val + case json.Number: + n, err := val.Int64() + if err == nil { + return int(n) + } + } + } + return 0 + } + pt := get("prompt_tokens", "input_tokens") + ct := get("completion_tokens", "output_tokens") + tt := get("total_tokens") + if tt == 0 { + tt = pt + ct + } + return pt, ct, tt +} + func cloneMap(m map[string]interface{}) map[string]interface{} { cp := make(map[string]interface{}, len(m)) for k, v := range m { diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index 36d2ae3610..f13ee7751e 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -60,6 +60,16 @@ type ChatResponse struct { Answer *string `json:"answer"` ReasonContent *string `json:"reason_content"` ToolCalls []map[string]interface{} `json:"tool_calls,omitempty"` + Usage *ChatUsage `json:"usage,omitempty"` +} + +// ChatUsage holds token usage split for one LLM call. Consumed by +// LLMBundle for accurate Langfuse reporting and run aggregation. +// Mirrors Python's common.token_utils.usage_from_response() split. +type ChatUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` } type EmbeddingData struct { @@ -154,6 +164,12 @@ type ChatConfig struct { Tools interface{} `json:"tools,omitempty"` ToolChoice *string `json:"tool_choice,omitempty"` ToolCallsResult *[]map[string]interface{} `json:"-"` + // UsageResult receives the token usage extracted from the final + // streaming chunk when stream_options.include_usage is true. + // The ChatStreamlyWithSender driver writes to this pointer (if + // non-nil) after the stream completes; callers read it the same + // way they read ToolCallsResult. + UsageResult *ChatUsage `json:"-"` } type APIConfig struct { @@ -238,6 +254,10 @@ type ChatModel struct { ModelName *string APIConfig *APIConfig ToolConfig *ToolConfig + // LastUsage holds the token usage (prompt/completion/total) of the most + // recent chat call. Consumed by callers for accurate Langfuse reporting + // and per-run token aggregation. Reset before each call. + LastUsage *ChatUsage } // NewChatModel creates a new ChatModel diff --git a/internal/service/agent.go b/internal/service/agent.go index e663c6692a..a8e41b6d8e 100644 --- a/internal/service/agent.go +++ b/internal/service/agent.go @@ -37,6 +37,7 @@ import ( "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" + "ragflow/internal/tokenizer" dslpkg "ragflow/internal/agent/dsl" ) @@ -835,11 +836,26 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv return nil, err } + // Install a per-run token usage sink so every LLM call inside + // this turn records its token usage (the sink is read at the end + // and emitted in workflow_finished). Mirrors Python's + // Canvas.run() installing token_usage_sink + langfuse_run_attrs. + ctx = tokenizer.WithRunUsage(ctx) + // Extract the event channel + metadata injected by Runner.Run. events, _ := root["__events__"].(chan canvas.RunEvent) messageID, _ := root["__message_id__"].(string) taskID, _ := root["__task_id__"].(string) sessionID, _ := root["__session_id__"].(string) + userID, _ := root["user_id"].(string) + + // Install per-run Langfuse correlation attrs so LLM calls inside + // this turn are grouped by session/user. Mirrors Python's + // Canvas.run() setting langfuse_run_attrs. + ctx = tokenizer.WithRunAttrs(ctx, &tokenizer.RunAttrs{ + SessionID: sessionID, + UserID: userID, + }) // Helper to build an SSE event with metadata. emit := func(typ, data string) { @@ -855,6 +871,22 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv }) } + // usagePayload returns the aggregated per-run token usage as a + // JSON-serializable map, or nil when no sink was installed. + usagePayload := func() map[string]int { + sink := tokenizer.GetRunUsage(ctx) + if sink == nil { + return nil + } + pt, ct, tt, calls := sink.Snapshot() + return map[string]int{ + "prompt_tokens": pt, + "completion_tokens": ct, + "total_tokens": tt, + "calls": calls, + } + } + startedAt := float64(time.Now().UnixNano()) / 1e9 userInput := root["user_input"] @@ -883,7 +915,11 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv meData, _ := json.Marshal(canvas.MessageEndEvent{}) emit("message", string(msgData)) emit("message_end", string(meData)) - wfData, _ := json.Marshal(map[string]any{"outputs": answer}) + wfPayload := map[string]any{"outputs": answer} + if u := usagePayload(); u != nil { + wfPayload["usage"] = u + } + wfData, _ := json.Marshal(wfPayload) emit("workflow_finished", string(wfData)) return state, nil } @@ -894,6 +930,10 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv s.markRunFailed(ctx, runID, "decode: "+err.Error()) return nil, err } + // Close MCP tool adapters and any other closeable resources + // held by the canvas after execution completes. Mirrors + // Python's finally: canvas.close() in canvas_service.py. + defer c.Close() // Store events channel + run metadata on the context so the // per-node statePre/statePost wrappers (in scheduler.go) can @@ -1080,12 +1120,16 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv }) emit("message_end", string(meData)) - wfData, _ := json.Marshal(map[string]interface{}{ + wfPayload := map[string]interface{}{ "inputs": map[string]any{"query": userInput}, "outputs": answer, "elapsed_time": now - startedAt, "created_at": now, - }) + } + if u := usagePayload(); u != nil { + wfPayload["usage"] = u + } + wfData, _ := json.Marshal(wfPayload) emit("workflow_finished", string(wfData)) s.markRunSucceeded(ctx2, runID) @@ -1107,13 +1151,18 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv }) emit("message_end", string(meData)) - // Emit workflow_finished with the final outputs. - wfData, _ := json.Marshal(map[string]interface{}{ + // Emit workflow_finished with the final outputs and aggregated + // per-run token usage across all LLM calls in this turn. + wfPayload := map[string]interface{}{ "inputs": map[string]any{"query": userInput}, "outputs": answer, "elapsed_time": now - startedAt, "created_at": now, - }) + } + if u := usagePayload(); u != nil { + wfPayload["usage"] = u + } + wfData, _ := json.Marshal(wfPayload) emit("workflow_finished", string(wfData)) s.markRunSucceeded(ctx2, runID) diff --git a/internal/tokenizer/usage.go b/internal/tokenizer/usage.go new file mode 100644 index 0000000000..891c7fea0c --- /dev/null +++ b/internal/tokenizer/usage.go @@ -0,0 +1,184 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package tokenizer — per-run token usage tracking. +// +// An agent run installs a mutable token usage accumulator on the context +// (via WithRunUsage) at the start of each turn. Every LLM call inside +// that run adds its usage (prompt/completion/total tokens) to the sink +// via RecordRunTokenUsage. At the end of the run, the service layer +// reads the accumulated totals and emits them in the workflow_finished +// SSE event. +// +// This mirrors Python's common.token_utils: +// - token_usage_sink ContextVar → context.Context + runUsageKey +// - langfuse_run_attrs ContextVar → context.Context + runAttrsKey +// - record_run_token_usage() → RecordRunTokenUsage(ctx, ...) +// - usage_from_response() → UsageFromMap(raw) +package tokenizer + +import ( + "context" + "encoding/json" + "sync" +) + +// Context key types — unexported to prevent direct external access. +type runUsageKeyType struct{} +type runAttrsKeyType struct{} + +// RunUsage is the mutable per-run token usage accumulator installed on +// the context by the service layer at the start of a canvas turn. +// All fields are guarded by the embedded mutex because concurrent +// tool-calling goroutines (run_in_executor copies the context, so +// workers share the same sink) can race on read/modify/write. +type RunUsage struct { + mu sync.Mutex + PromptTokens int + CompletionTokens int + TotalTokens int + Calls int +} + +// Add atomically adds a single LLM call's token counts to the sink. +// Safe to call concurrently from multiple goroutines. +func (u *RunUsage) Add(prompt, completion, total int) { + if u == nil { + return + } + u.mu.Lock() + defer u.mu.Unlock() + if prompt > 0 { + u.PromptTokens += prompt + } + if completion > 0 { + u.CompletionTokens += completion + } + if total > 0 { + u.TotalTokens += total + } + u.Calls++ +} + +// Snapshot returns a copy of the current cumulative counts. +func (u *RunUsage) Snapshot() (prompt, completion, total, calls int) { + if u == nil { + return 0, 0, 0, 0 + } + u.mu.Lock() + defer u.mu.Unlock() + return u.PromptTokens, u.CompletionTokens, u.TotalTokens, u.Calls +} + +// RunAttrs holds per-run Langfuse correlating attributes (session_id, +// user_id) installed on the context by the service layer. +type RunAttrs struct { + SessionID string + UserID string +} + +// WithRunUsage installs a fresh RunUsage sink on ctx. Should be called +// once at the start of a canvas turn. +func WithRunUsage(ctx context.Context) context.Context { + return context.WithValue(ctx, runUsageKeyType{}, &RunUsage{}) +} + +// GetRunUsage retrieves the per-run token usage sink from ctx. +// Returns nil when no sink is installed (e.g. outside a canvas run). +func GetRunUsage(ctx context.Context) *RunUsage { + if v := ctx.Value(runUsageKeyType{}); v != nil { + if sink, ok := v.(*RunUsage); ok { + return sink + } + } + return nil +} + +// WithRunAttrs installs Langfuse correlation attributes on ctx. +func WithRunAttrs(ctx context.Context, attrs *RunAttrs) context.Context { + if attrs == nil { + return ctx + } + return context.WithValue(ctx, runAttrsKeyType{}, attrs) +} + +// GetRunAttrs retrieves the per-run Langfuse attributes from ctx. +func GetRunAttrs(ctx context.Context) *RunAttrs { + if v := ctx.Value(runAttrsKeyType{}); v != nil { + if attrs, ok := v.(*RunAttrs); ok { + return attrs + } + } + return nil +} + +// RecordRunTokenUsage adds a single LLM call's token usage to the +// active run sink on ctx. Safe to call from anywhere; when no run sink +// is installed it is a no-op. +func RecordRunTokenUsage(ctx context.Context, promptTokens, completionTokens, totalTokens int) { + sink := GetRunUsage(ctx) + if sink == nil { + return + } + sink.Add(promptTokens, completionTokens, totalTokens) +} + +// UsageFromMap extracts a token usage split from a raw API response map. +// Handles OpenAI/OpenRouter-style resp["usage"] dicts. Missing fields +// default to 0; total_tokens falls back to prompt+completion when absent. +// Returns nil when no usage found. +// Mirrors Python's common.token_utils.usage_from_response(). +func UsageFromMap(raw map[string]interface{}) (promptTokens, completionTokens, totalTokens int) { + if raw == nil { + return 0, 0, 0 + } + usageRaw, ok := raw["usage"] + if !ok { + return 0, 0, 0 + } + usage, ok := usageRaw.(map[string]interface{}) + if !ok { + return 0, 0, 0 + } + pt := getInt(usage, "prompt_tokens", "input_tokens") + ct := getInt(usage, "completion_tokens", "output_tokens") + tt := getInt(usage, "total_tokens") + if tt == 0 { + tt = pt + ct + } + return pt, ct, tt +} + +func getInt(m map[string]interface{}, keys ...string) int { + for _, k := range keys { + v, ok := m[k] + if !ok { + continue + } + switch val := v.(type) { + case float64: + return int(val) + case int: + return val + case json.Number: + n, err := val.Int64() + if err == nil { + return int(n) + } + } + } + return 0 +}