Port agent PRs to GO - 2 (#16565)

### Summary

Port the following PRs to GO in this PR

https://github.com/infiniflow/ragflow/pull/16420
https://github.com/infiniflow/ragflow/pull/13295
This commit is contained in:
qinling0210
2026-07-02 20:20:11 +08:00
committed by GitHub
parent 24118ac0d1
commit dcbd0d260c
8 changed files with 524 additions and 11 deletions

View File

@@ -9,6 +9,9 @@ package canvas
import (
"ragflow/internal/agent/runtime"
"ragflow/internal/common"
"go.uber.org/zap"
)
// legacyNoOpNames is the set of component names that the Go port
@@ -69,6 +72,64 @@ type CanvasComponentObj struct {
Params map[string]any `json:"params"`
}
// Close releases resources held by components referenced in the canvas
// DSL. It walks every component's params map and calls Close() on any
// value that implements a Close() method (MCPToolAdapters, HTTP
// clients, etc.). Mirrors Python's Graph.close() in agent/canvas.py.
//
// In Go's architecture MCP sessions are per-invocation and auto-torn
// down; Close() is a best-effort hook that ensures idle HTTP
// connections are released even when adapters outlive a single call.
func (c *Canvas) Close() {
if c == nil {
return
}
seen := make(map[any]bool)
for _, comp := range c.Components {
for _, v := range comp.Obj.Params {
walkAndClose(v, seen)
}
}
}
// walkAndClose recursively walks a value and calls Close() on any
// objects that implement a Close() method. Maps, slices, and pointers
// are recursed into; other types are skipped. Already-seen objects
// (by interface identity) are skipped to avoid double-close.
func walkAndClose(v any, seen map[any]bool) {
if v == nil {
return
}
if closer, ok := v.(interface{ Close() }); ok {
if !seen[closer] {
seen[closer] = true
safeClose(closer)
}
return
}
switch val := v.(type) {
case map[string]any:
for _, child := range val {
walkAndClose(child, seen)
}
case []any:
for _, child := range val {
walkAndClose(child, seen)
}
}
}
// safeClose calls Close() on a closer value, swallowing panics so a
// misbehaving resource doesn't crash the canvas tear-down path.
func safeClose(closer interface{ Close() }) {
defer func() {
if rec := recover(); rec != nil {
common.Warn("canvas: Close() panicked", zap.Any("recover", rec))
}
}()
closer.Close()
}
// Component is an alias for runtime.Component — the minimal runtime
// surface BuildWorkflow needs at sub-graph build time. The canonical
// definition (and the SetDefaultFactory / DefaultFactory plumbing)

View File

@@ -152,6 +152,19 @@ func (m *MCPToolAdapter) InvokableRun(ctx context.Context, argumentsInJSON strin
return res.Text, nil
}
// Close releases resources held by the adapter. In Go's architecture
// MCP sessions are per-invocation (created and torn down within each
// InvokableRun call), so there are no persistent connections to drain.
// The primary resource is the http.Client's idle-connection pool;
// calling Close explicitly drops those idle connections so they don't
// accumulate across many adapter instances over long-running processes.
// Mirrors Python's close_sync() in common/mcp_tool_call_conn.py.
func (m *MCPToolAdapter) Close() {
if m.httpClient != nil {
m.httpClient.CloseIdleConnections()
}
}
// BuildMCPToolAdapters wraps a slice of mcpclient.Tool descriptors as
// eino InvokableTool. Returned slice is suitable for handing to
// agenttool.NewRetrieverTool / NewMCPToolAdapter paths or directly to

View File

@@ -25,6 +25,20 @@ import (
"ragflow/internal/tokenizer"
)
// recordUsageFromResponse records one chat call's token usage to both
// cm.LastUsage and the context-level run sink (if installed). Callers
// should invoke this after each ChatWithMessages / ChatStreamlyWithSender
// call so the canvas-layer aggregator and Langfuse both see the split.
func recordUsageFromResponse(ctx context.Context, cm *ChatModel) {
if cm == nil {
return
}
if cm.LastUsage == nil {
return
}
tokenizer.RecordRunTokenUsage(ctx, cm.LastUsage.PromptTokens, cm.LastUsage.CompletionTokens, cm.LastUsage.TotalTokens)
}
const (
defaultMaxRetries = 3
defaultMaxRounds = 5
@@ -77,7 +91,32 @@ func (cm *ChatModel) ChatWithTools(ctx context.Context, system string, history [
}
func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsList interface{}, chatCfg *ChatConfig, maxRounds int) (string, int, error) {
// Aggregate prompt/completion/total across all tool-calling rounds.
// Mirrors Python PR #16420 fix: previously the total was overwritten
// each round; now we accumulate so multi-round tool conversations
// report the correct grand total.
// Reset stale per-call usage from a previous call so a response
// without usage doesn't leak the prior call's data.
cm.LastUsage = nil
var totalTokens int
aggUsage := &ChatUsage{}
addRoundUsage := func(resp *ChatResponse) {
u := resp.Usage
if u == nil {
return
}
aggUsage.PromptTokens += u.PromptTokens
aggUsage.CompletionTokens += u.CompletionTokens
aggUsage.TotalTokens += u.TotalTokens
totalTokens = aggUsage.TotalTokens
// Store per-round delta (not cumulative) so RecordRunTokenUsage
// records each round's contribution exactly once.
cm.LastUsage = &ChatUsage{
PromptTokens: u.PromptTokens, CompletionTokens: u.CompletionTokens, TotalTokens: u.TotalTokens,
}
recordUsageFromResponse(ctx, cm)
}
for round := 0; round <= maxRounds; round++ {
select {
@@ -97,6 +136,7 @@ func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsLis
if resp == nil {
return "", totalTokens, fmt.Errorf("round %d: nil response", round)
}
addRoundUsage(resp)
if len(resp.ToolCalls) == 0 {
answer := ""
@@ -106,7 +146,11 @@ func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsLis
if resp.ReasonContent != nil && *resp.ReasonContent != "" {
answer = "<think>" + *resp.ReasonContent + "</think>" + answer
}
totalTokens += tokenizer.NumTokensFromString(answer)
// Fallback: if the provider didn't return usage info,
// estimate from the answer text using tiktoken.
if resp.Usage == nil || resp.Usage.TotalTokens == 0 {
totalTokens += tokenizer.NumTokensFromString(answer)
}
return answer, totalTokens, nil
}
@@ -126,7 +170,11 @@ func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsLis
if resp == nil || resp.Answer == nil {
return "", totalTokens, fmt.Errorf("final call: no answer")
}
totalTokens += tokenizer.NumTokensFromString(*resp.Answer)
addRoundUsage(resp)
// Fallback: use text-based estimation if no authoritative usage.
if resp.Usage == nil || resp.Usage.TotalTokens == 0 {
totalTokens += tokenizer.NumTokensFromString(*resp.Answer)
}
return *resp.Answer, totalTokens, nil
}
@@ -177,7 +225,37 @@ func (cm *ChatModel) ChatStreamlyWithTools(ctx context.Context, system string, h
}
func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsList interface{}, chatCfg *ChatConfig, maxRounds int, sender func(*string, *string) error) (int, error) {
// Aggregate token counts across every tool-calling round (each round is a
// separate provider request). Committing per round avoids the previous
// bug where a later round's total overwrote earlier rounds.
// Reset stale per-call usage from a previous call.
cm.LastUsage = nil
var totalTokens int
aggUsage := &ChatUsage{}
commitRound := func(cfg *ChatConfig, roundTokens int) {
// Prefer the authoritative usage from the API (extracted via
// stream_options.include_usage=true) over text-based token
// counting. Mirrors Python's usage_from_response accumulation
// in chat_model.py streaming handlers.
// Track per-round delta so RecordRunTokenUsage records each
// round's contribution exactly once (the sink does Add, not Set).
var deltaPrompt, deltaCompletion, deltaTotal int
if u := cfg.UsageResult; u != nil && u.TotalTokens > 0 {
deltaPrompt, deltaCompletion, deltaTotal = u.PromptTokens, u.CompletionTokens, u.TotalTokens
aggUsage.PromptTokens += deltaPrompt
aggUsage.CompletionTokens += deltaCompletion
aggUsage.TotalTokens += deltaTotal
} else {
deltaTotal = roundTokens
aggUsage.TotalTokens += deltaTotal
}
totalTokens = aggUsage.TotalTokens
cm.LastUsage = &ChatUsage{
PromptTokens: deltaPrompt, CompletionTokens: deltaCompletion, TotalTokens: deltaTotal,
}
recordUsageFromResponse(ctx, cm)
}
for round := 0; round <= maxRounds; round++ {
select {
@@ -192,10 +270,13 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to
cfg.Stream = boolPtr(true)
var tcs []map[string]interface{}
cfg.ToolCallsResult = &tcs
var roundUsage ChatUsage
cfg.UsageResult = &roundUsage
reasoningStarted := false
var answer string
var pendingThinkClose bool
var roundTokens int
err := cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, func(delta *string, reason *string) error {
if reason != nil && *reason != "" {
@@ -207,6 +288,7 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to
}
}
pendingThinkClose = true
roundTokens += tokenizer.NumTokensFromString(*reason)
return sender(reason, nil)
}
// Reasoning ended, close the think block if open
@@ -221,7 +303,7 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to
if *delta == "[DONE]" {
return nil
}
totalTokens += tokenizer.NumTokensFromString(*delta)
roundTokens += tokenizer.NumTokensFromString(*delta)
answer += *delta
if e := sender(delta, nil); e != nil {
return e
@@ -237,6 +319,9 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to
return totalTokens, e
}
}
// Commit this round's token count to the running aggregate.
// Prefer authoritative API usage over text-based estimation.
commitRound(&cfg, roundTokens)
if err != nil {
return totalTokens, fmt.Errorf("round %d: %w", round, err)
}
@@ -263,7 +348,17 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to
})
cfg := *chatCfg
cfg.Stream = boolPtr(true)
return totalTokens, cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, sender)
var exceedUsage ChatUsage
cfg.UsageResult = &exceedUsage
var exceedTokens int
err := cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, func(delta *string, reason *string) error {
if delta != nil && *delta != "" && *delta != "[DONE]" {
exceedTokens += tokenizer.NumTokensFromString(*delta)
}
return nil
})
commitRound(&cfg, exceedTokens)
return totalTokens, err
}
// appendToolResults executes tool calls concurrently, appends the assistant

View File

@@ -124,10 +124,23 @@ func (m *EinoChatModel) Generate(ctx context.Context, msgs []*schema.Message, op
if err := ctx.Err(); err != nil {
return nil, err
}
// Reset stale per-call usage before the call so that a response
// without a usage block doesn't leak the previous call's data.
// Mirrors Python's LLMBundle._reset_last_usage().
m.inner.LastUsage = nil
resp, err := m.inner.ModelDriver.ChatWithMessages(*m.inner.ModelName, internal, m.inner.APIConfig, m.chatCfg)
if err != nil {
return nil, fmt.Errorf("models: EinoChatModel.Generate(%s): %w", *m.inner.ModelName, err)
}
// Record the per-call token usage so the canvas-level aggregator (and
// Langfuse) can compute the run total. Mirrors Python's
// LLMBundle._report_usage() / self.mdl.last_usage pattern.
if resp != nil && resp.Usage != nil {
m.inner.LastUsage = &ChatUsage{
PromptTokens: resp.Usage.PromptTokens, CompletionTokens: resp.Usage.CompletionTokens, TotalTokens: resp.Usage.TotalTokens,
}
recordUsageFromResponse(ctx, m.inner)
}
return fromInternalResponse(resp), nil
}

View File

@@ -211,6 +211,16 @@ func (o *OpenAIModel) ChatWithMessages(modelName string, messages []Message, api
ToolCalls: toolCalls,
}
// Extract usage split (prompt/completion/total) from the raw API
// response for accurate per-call token accounting. Non-OpenAI
// providers that implement the OpenAI-compat API surface (DeepSeek,
// Moonshot, etc.) also return a "usage" key with the same shape.
if pt, ct, tt := extractUsageFromMap(result); tt > 0 {
chatResponse.Usage = &ChatUsage{
PromptTokens: pt, CompletionTokens: ct, TotalTokens: tt,
}
}
return chatResponse, nil
}
@@ -247,11 +257,17 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag
apiMessages[i] = apiMsg
}
// Build request body with streaming on by default
// Build request body with streaming on by default.
// stream_options.include_usage asks the provider to attach a
// usage block to the final streaming chunk (mirrors Python's
// chat_model.py _stream_options / stream_options.include_usage).
reqBody := map[string]interface{}{
"model": modelName,
"messages": apiMessages,
"stream": true,
"stream_options": map[string]interface{}{
"include_usage": true,
},
}
if chatModelConfig != nil {
@@ -316,6 +332,12 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag
sawTerminal := false
accumulatedToolCalls := make(map[int]map[string]interface{})
// Capture the authoritative usage block from the final streaming
// chunk (when provider honours stream_options.include_usage=true).
// The last chunk in the stream carries the "usage" key alongside
// empty choices; we overwrite on every chunk so the final frame
// wins, matching Python's chat_model.py usage_from_response loop.
var streamUsage *ChatUsage
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
@@ -340,6 +362,13 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag
continue
}
// Extract usage from this chunk. When stream_options.include_usage
// is true, the final chunk carries the full usage breakdown at the
// top level of the event alongside (possibly empty) choices.
if pt, ct, tt := extractUsageFromMap(event); tt > 0 {
streamUsage = &ChatUsage{PromptTokens: pt, CompletionTokens: ct, TotalTokens: tt}
}
choices, ok := event["choices"].([]interface{})
if !ok || len(choices) == 0 {
continue
@@ -421,6 +450,11 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag
chatModelConfig.ToolCallsResult = &tcs
}
// Populate UsageResult with the authoritative usage from the stream.
if streamUsage != nil && chatModelConfig != nil {
chatModelConfig.UsageResult = streamUsage
}
// Send the [DONE] marker for OpenAI compatibility
endOfStream := "[DONE]"
if err := sender(&endOfStream, nil); err != nil {
@@ -1024,6 +1058,50 @@ func (o *OpenAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskRespon
return nil, fmt.Errorf("%s, no such method", o.Name())
}
// extractUsageFromMap reads the "usage" key from an OpenAI-style API
// response and returns (prompt_tokens, completion_tokens, total_tokens).
// All return values are zero when the response carries no usage block.
func extractUsageFromMap(raw map[string]interface{}) (int, int, int) {
if raw == nil {
return 0, 0, 0
}
ru, ok := raw["usage"]
if !ok {
return 0, 0, 0
}
usage, ok := ru.(map[string]interface{})
if !ok {
return 0, 0, 0
}
get := func(keys ...string) int {
for _, k := range keys {
v, ok := usage[k]
if !ok {
continue
}
switch val := v.(type) {
case float64:
return int(val)
case int:
return val
case json.Number:
n, err := val.Int64()
if err == nil {
return int(n)
}
}
}
return 0
}
pt := get("prompt_tokens", "input_tokens")
ct := get("completion_tokens", "output_tokens")
tt := get("total_tokens")
if tt == 0 {
tt = pt + ct
}
return pt, ct, tt
}
func cloneMap(m map[string]interface{}) map[string]interface{} {
cp := make(map[string]interface{}, len(m))
for k, v := range m {

View File

@@ -60,6 +60,16 @@ type ChatResponse struct {
Answer *string `json:"answer"`
ReasonContent *string `json:"reason_content"`
ToolCalls []map[string]interface{} `json:"tool_calls,omitempty"`
Usage *ChatUsage `json:"usage,omitempty"`
}
// ChatUsage holds token usage split for one LLM call. Consumed by
// LLMBundle for accurate Langfuse reporting and run aggregation.
// Mirrors Python's common.token_utils.usage_from_response() split.
type ChatUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type EmbeddingData struct {
@@ -154,6 +164,12 @@ type ChatConfig struct {
Tools interface{} `json:"tools,omitempty"`
ToolChoice *string `json:"tool_choice,omitempty"`
ToolCallsResult *[]map[string]interface{} `json:"-"`
// UsageResult receives the token usage extracted from the final
// streaming chunk when stream_options.include_usage is true.
// The ChatStreamlyWithSender driver writes to this pointer (if
// non-nil) after the stream completes; callers read it the same
// way they read ToolCallsResult.
UsageResult *ChatUsage `json:"-"`
}
type APIConfig struct {
@@ -238,6 +254,10 @@ type ChatModel struct {
ModelName *string
APIConfig *APIConfig
ToolConfig *ToolConfig
// LastUsage holds the token usage (prompt/completion/total) of the most
// recent chat call. Consumed by callers for accurate Langfuse reporting
// and per-run token aggregation. Reset before each call.
LastUsage *ChatUsage
}
// NewChatModel creates a new ChatModel

View File

@@ -37,6 +37,7 @@ import (
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
"ragflow/internal/tokenizer"
dslpkg "ragflow/internal/agent/dsl"
)
@@ -835,11 +836,26 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
return nil, err
}
// Install a per-run token usage sink so every LLM call inside
// this turn records its token usage (the sink is read at the end
// and emitted in workflow_finished). Mirrors Python's
// Canvas.run() installing token_usage_sink + langfuse_run_attrs.
ctx = tokenizer.WithRunUsage(ctx)
// Extract the event channel + metadata injected by Runner.Run.
events, _ := root["__events__"].(chan canvas.RunEvent)
messageID, _ := root["__message_id__"].(string)
taskID, _ := root["__task_id__"].(string)
sessionID, _ := root["__session_id__"].(string)
userID, _ := root["user_id"].(string)
// Install per-run Langfuse correlation attrs so LLM calls inside
// this turn are grouped by session/user. Mirrors Python's
// Canvas.run() setting langfuse_run_attrs.
ctx = tokenizer.WithRunAttrs(ctx, &tokenizer.RunAttrs{
SessionID: sessionID,
UserID: userID,
})
// Helper to build an SSE event with metadata.
emit := func(typ, data string) {
@@ -855,6 +871,22 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
})
}
// usagePayload returns the aggregated per-run token usage as a
// JSON-serializable map, or nil when no sink was installed.
usagePayload := func() map[string]int {
sink := tokenizer.GetRunUsage(ctx)
if sink == nil {
return nil
}
pt, ct, tt, calls := sink.Snapshot()
return map[string]int{
"prompt_tokens": pt,
"completion_tokens": ct,
"total_tokens": tt,
"calls": calls,
}
}
startedAt := float64(time.Now().UnixNano()) / 1e9
userInput := root["user_input"]
@@ -883,7 +915,11 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
meData, _ := json.Marshal(canvas.MessageEndEvent{})
emit("message", string(msgData))
emit("message_end", string(meData))
wfData, _ := json.Marshal(map[string]any{"outputs": answer})
wfPayload := map[string]any{"outputs": answer}
if u := usagePayload(); u != nil {
wfPayload["usage"] = u
}
wfData, _ := json.Marshal(wfPayload)
emit("workflow_finished", string(wfData))
return state, nil
}
@@ -894,6 +930,10 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
s.markRunFailed(ctx, runID, "decode: "+err.Error())
return nil, err
}
// Close MCP tool adapters and any other closeable resources
// held by the canvas after execution completes. Mirrors
// Python's finally: canvas.close() in canvas_service.py.
defer c.Close()
// Store events channel + run metadata on the context so the
// per-node statePre/statePost wrappers (in scheduler.go) can
@@ -1080,12 +1120,16 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
})
emit("message_end", string(meData))
wfData, _ := json.Marshal(map[string]interface{}{
wfPayload := map[string]interface{}{
"inputs": map[string]any{"query": userInput},
"outputs": answer,
"elapsed_time": now - startedAt,
"created_at": now,
})
}
if u := usagePayload(); u != nil {
wfPayload["usage"] = u
}
wfData, _ := json.Marshal(wfPayload)
emit("workflow_finished", string(wfData))
s.markRunSucceeded(ctx2, runID)
@@ -1107,13 +1151,18 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
})
emit("message_end", string(meData))
// Emit workflow_finished with the final outputs.
wfData, _ := json.Marshal(map[string]interface{}{
// Emit workflow_finished with the final outputs and aggregated
// per-run token usage across all LLM calls in this turn.
wfPayload := map[string]interface{}{
"inputs": map[string]any{"query": userInput},
"outputs": answer,
"elapsed_time": now - startedAt,
"created_at": now,
})
}
if u := usagePayload(); u != nil {
wfPayload["usage"] = u
}
wfData, _ := json.Marshal(wfPayload)
emit("workflow_finished", string(wfData))
s.markRunSucceeded(ctx2, runID)

184
internal/tokenizer/usage.go Normal file
View File

@@ -0,0 +1,184 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Package tokenizer — per-run token usage tracking.
//
// An agent run installs a mutable token usage accumulator on the context
// (via WithRunUsage) at the start of each turn. Every LLM call inside
// that run adds its usage (prompt/completion/total tokens) to the sink
// via RecordRunTokenUsage. At the end of the run, the service layer
// reads the accumulated totals and emits them in the workflow_finished
// SSE event.
//
// This mirrors Python's common.token_utils:
// - token_usage_sink ContextVar → context.Context + runUsageKey
// - langfuse_run_attrs ContextVar → context.Context + runAttrsKey
// - record_run_token_usage() → RecordRunTokenUsage(ctx, ...)
// - usage_from_response() → UsageFromMap(raw)
package tokenizer
import (
"context"
"encoding/json"
"sync"
)
// Context key types — unexported to prevent direct external access.
type runUsageKeyType struct{}
type runAttrsKeyType struct{}
// RunUsage is the mutable per-run token usage accumulator installed on
// the context by the service layer at the start of a canvas turn.
// All fields are guarded by the embedded mutex because concurrent
// tool-calling goroutines (run_in_executor copies the context, so
// workers share the same sink) can race on read/modify/write.
type RunUsage struct {
mu sync.Mutex
PromptTokens int
CompletionTokens int
TotalTokens int
Calls int
}
// Add atomically adds a single LLM call's token counts to the sink.
// Safe to call concurrently from multiple goroutines.
func (u *RunUsage) Add(prompt, completion, total int) {
if u == nil {
return
}
u.mu.Lock()
defer u.mu.Unlock()
if prompt > 0 {
u.PromptTokens += prompt
}
if completion > 0 {
u.CompletionTokens += completion
}
if total > 0 {
u.TotalTokens += total
}
u.Calls++
}
// Snapshot returns a copy of the current cumulative counts.
func (u *RunUsage) Snapshot() (prompt, completion, total, calls int) {
if u == nil {
return 0, 0, 0, 0
}
u.mu.Lock()
defer u.mu.Unlock()
return u.PromptTokens, u.CompletionTokens, u.TotalTokens, u.Calls
}
// RunAttrs holds per-run Langfuse correlating attributes (session_id,
// user_id) installed on the context by the service layer.
type RunAttrs struct {
SessionID string
UserID string
}
// WithRunUsage installs a fresh RunUsage sink on ctx. Should be called
// once at the start of a canvas turn.
func WithRunUsage(ctx context.Context) context.Context {
return context.WithValue(ctx, runUsageKeyType{}, &RunUsage{})
}
// GetRunUsage retrieves the per-run token usage sink from ctx.
// Returns nil when no sink is installed (e.g. outside a canvas run).
func GetRunUsage(ctx context.Context) *RunUsage {
if v := ctx.Value(runUsageKeyType{}); v != nil {
if sink, ok := v.(*RunUsage); ok {
return sink
}
}
return nil
}
// WithRunAttrs installs Langfuse correlation attributes on ctx.
func WithRunAttrs(ctx context.Context, attrs *RunAttrs) context.Context {
if attrs == nil {
return ctx
}
return context.WithValue(ctx, runAttrsKeyType{}, attrs)
}
// GetRunAttrs retrieves the per-run Langfuse attributes from ctx.
func GetRunAttrs(ctx context.Context) *RunAttrs {
if v := ctx.Value(runAttrsKeyType{}); v != nil {
if attrs, ok := v.(*RunAttrs); ok {
return attrs
}
}
return nil
}
// RecordRunTokenUsage adds a single LLM call's token usage to the
// active run sink on ctx. Safe to call from anywhere; when no run sink
// is installed it is a no-op.
func RecordRunTokenUsage(ctx context.Context, promptTokens, completionTokens, totalTokens int) {
sink := GetRunUsage(ctx)
if sink == nil {
return
}
sink.Add(promptTokens, completionTokens, totalTokens)
}
// UsageFromMap extracts a token usage split from a raw API response map.
// Handles OpenAI/OpenRouter-style resp["usage"] dicts. Missing fields
// default to 0; total_tokens falls back to prompt+completion when absent.
// Returns nil when no usage found.
// Mirrors Python's common.token_utils.usage_from_response().
func UsageFromMap(raw map[string]interface{}) (promptTokens, completionTokens, totalTokens int) {
if raw == nil {
return 0, 0, 0
}
usageRaw, ok := raw["usage"]
if !ok {
return 0, 0, 0
}
usage, ok := usageRaw.(map[string]interface{})
if !ok {
return 0, 0, 0
}
pt := getInt(usage, "prompt_tokens", "input_tokens")
ct := getInt(usage, "completion_tokens", "output_tokens")
tt := getInt(usage, "total_tokens")
if tt == 0 {
tt = pt + ct
}
return pt, ct, tt
}
func getInt(m map[string]interface{}, keys ...string) int {
for _, k := range keys {
v, ok := m[k]
if !ok {
continue
}
switch val := v.(type) {
case float64:
return int(val)
case int:
return val
case json.Number:
n, err := val.Int64()
if err == nil {
return int(n)
}
}
}
return 0
}