mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 01:01:56 +08:00
Port agent PRs to GO - 2 (#16565)
### Summary: Port the following PRs to Go in this PR: https://github.com/infiniflow/ragflow/pull/16420 and https://github.com/infiniflow/ragflow/pull/13295
This commit is contained in:
@@ -9,6 +9,9 @@ package canvas
|
||||
|
||||
import (
|
||||
"ragflow/internal/agent/runtime"
|
||||
"ragflow/internal/common"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// legacyNoOpNames is the set of component names that the Go port
|
||||
@@ -69,6 +72,64 @@ type CanvasComponentObj struct {
|
||||
Params map[string]any `json:"params"`
|
||||
}
|
||||
|
||||
// Close releases resources held by components referenced in the canvas
|
||||
// DSL. It walks every component's params map and calls Close() on any
|
||||
// value that implements a Close() method (MCPToolAdapters, HTTP
|
||||
// clients, etc.). Mirrors Python's Graph.close() in agent/canvas.py.
|
||||
//
|
||||
// In Go's architecture MCP sessions are per-invocation and auto-torn
|
||||
// down; Close() is a best-effort hook that ensures idle HTTP
|
||||
// connections are released even when adapters outlive a single call.
|
||||
func (c *Canvas) Close() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
seen := make(map[any]bool)
|
||||
for _, comp := range c.Components {
|
||||
for _, v := range comp.Obj.Params {
|
||||
walkAndClose(v, seen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// walkAndClose recursively walks a value and calls Close() on any
|
||||
// objects that implement a Close() method. Maps, slices, and pointers
|
||||
// are recursed into; other types are skipped. Already-seen objects
|
||||
// (by interface identity) are skipped to avoid double-close.
|
||||
func walkAndClose(v any, seen map[any]bool) {
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
if closer, ok := v.(interface{ Close() }); ok {
|
||||
if !seen[closer] {
|
||||
seen[closer] = true
|
||||
safeClose(closer)
|
||||
}
|
||||
return
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case map[string]any:
|
||||
for _, child := range val {
|
||||
walkAndClose(child, seen)
|
||||
}
|
||||
case []any:
|
||||
for _, child := range val {
|
||||
walkAndClose(child, seen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// safeClose calls Close() on a closer value, swallowing panics so a
|
||||
// misbehaving resource doesn't crash the canvas tear-down path.
|
||||
func safeClose(closer interface{ Close() }) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
common.Warn("canvas: Close() panicked", zap.Any("recover", rec))
|
||||
}
|
||||
}()
|
||||
closer.Close()
|
||||
}
|
||||
|
||||
// Component is an alias for runtime.Component — the minimal runtime
|
||||
// surface BuildWorkflow needs at sub-graph build time. The canonical
|
||||
// definition (and the SetDefaultFactory / DefaultFactory plumbing)
|
||||
|
||||
@@ -152,6 +152,19 @@ func (m *MCPToolAdapter) InvokableRun(ctx context.Context, argumentsInJSON strin
|
||||
return res.Text, nil
|
||||
}
|
||||
|
||||
// Close releases resources held by the adapter. In Go's architecture
|
||||
// MCP sessions are per-invocation (created and torn down within each
|
||||
// InvokableRun call), so there are no persistent connections to drain.
|
||||
// The primary resource is the http.Client's idle-connection pool;
|
||||
// calling Close explicitly drops those idle connections so they don't
|
||||
// accumulate across many adapter instances over long-running processes.
|
||||
// Mirrors Python's close_sync() in common/mcp_tool_call_conn.py.
|
||||
func (m *MCPToolAdapter) Close() {
|
||||
if m.httpClient != nil {
|
||||
m.httpClient.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
// BuildMCPToolAdapters wraps a slice of mcpclient.Tool descriptors as
|
||||
// eino InvokableTool. Returned slice is suitable for handing to
|
||||
// agenttool.NewRetrieverTool / NewMCPToolAdapter paths or directly to
|
||||
|
||||
@@ -25,6 +25,20 @@ import (
|
||||
"ragflow/internal/tokenizer"
|
||||
)
|
||||
|
||||
// recordUsageFromResponse records one chat call's token usage to both
|
||||
// cm.LastUsage and the context-level run sink (if installed). Callers
|
||||
// should invoke this after each ChatWithMessages / ChatStreamlyWithSender
|
||||
// call so the canvas-layer aggregator and Langfuse both see the split.
|
||||
func recordUsageFromResponse(ctx context.Context, cm *ChatModel) {
|
||||
if cm == nil {
|
||||
return
|
||||
}
|
||||
if cm.LastUsage == nil {
|
||||
return
|
||||
}
|
||||
tokenizer.RecordRunTokenUsage(ctx, cm.LastUsage.PromptTokens, cm.LastUsage.CompletionTokens, cm.LastUsage.TotalTokens)
|
||||
}
|
||||
|
||||
const (
|
||||
defaultMaxRetries = 3
|
||||
defaultMaxRounds = 5
|
||||
@@ -77,7 +91,32 @@ func (cm *ChatModel) ChatWithTools(ctx context.Context, system string, history [
|
||||
}
|
||||
|
||||
func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsList interface{}, chatCfg *ChatConfig, maxRounds int) (string, int, error) {
|
||||
// Aggregate prompt/completion/total across all tool-calling rounds.
|
||||
// Mirrors Python PR #16420 fix: previously the total was overwritten
|
||||
// each round; now we accumulate so multi-round tool conversations
|
||||
// report the correct grand total.
|
||||
// Reset stale per-call usage from a previous call so a response
|
||||
// without usage doesn't leak the prior call's data.
|
||||
cm.LastUsage = nil
|
||||
var totalTokens int
|
||||
aggUsage := &ChatUsage{}
|
||||
|
||||
addRoundUsage := func(resp *ChatResponse) {
|
||||
u := resp.Usage
|
||||
if u == nil {
|
||||
return
|
||||
}
|
||||
aggUsage.PromptTokens += u.PromptTokens
|
||||
aggUsage.CompletionTokens += u.CompletionTokens
|
||||
aggUsage.TotalTokens += u.TotalTokens
|
||||
totalTokens = aggUsage.TotalTokens
|
||||
// Store per-round delta (not cumulative) so RecordRunTokenUsage
|
||||
// records each round's contribution exactly once.
|
||||
cm.LastUsage = &ChatUsage{
|
||||
PromptTokens: u.PromptTokens, CompletionTokens: u.CompletionTokens, TotalTokens: u.TotalTokens,
|
||||
}
|
||||
recordUsageFromResponse(ctx, cm)
|
||||
}
|
||||
|
||||
for round := 0; round <= maxRounds; round++ {
|
||||
select {
|
||||
@@ -97,6 +136,7 @@ func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsLis
|
||||
if resp == nil {
|
||||
return "", totalTokens, fmt.Errorf("round %d: nil response", round)
|
||||
}
|
||||
addRoundUsage(resp)
|
||||
|
||||
if len(resp.ToolCalls) == 0 {
|
||||
answer := ""
|
||||
@@ -106,7 +146,11 @@ func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsLis
|
||||
if resp.ReasonContent != nil && *resp.ReasonContent != "" {
|
||||
answer = "<think>" + *resp.ReasonContent + "</think>" + answer
|
||||
}
|
||||
totalTokens += tokenizer.NumTokensFromString(answer)
|
||||
// Fallback: if the provider didn't return usage info,
|
||||
// estimate from the answer text using tiktoken.
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens == 0 {
|
||||
totalTokens += tokenizer.NumTokensFromString(answer)
|
||||
}
|
||||
return answer, totalTokens, nil
|
||||
}
|
||||
|
||||
@@ -126,7 +170,11 @@ func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsLis
|
||||
if resp == nil || resp.Answer == nil {
|
||||
return "", totalTokens, fmt.Errorf("final call: no answer")
|
||||
}
|
||||
totalTokens += tokenizer.NumTokensFromString(*resp.Answer)
|
||||
addRoundUsage(resp)
|
||||
// Fallback: use text-based estimation if no authoritative usage.
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens == 0 {
|
||||
totalTokens += tokenizer.NumTokensFromString(*resp.Answer)
|
||||
}
|
||||
return *resp.Answer, totalTokens, nil
|
||||
}
|
||||
|
||||
@@ -177,7 +225,37 @@ func (cm *ChatModel) ChatStreamlyWithTools(ctx context.Context, system string, h
|
||||
}
|
||||
|
||||
func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsList interface{}, chatCfg *ChatConfig, maxRounds int, sender func(*string, *string) error) (int, error) {
|
||||
// Aggregate token counts across every tool-calling round (each round is a
|
||||
// separate provider request). Committing per round avoids the previous
|
||||
// bug where a later round's total overwrote earlier rounds.
|
||||
// Reset stale per-call usage from a previous call.
|
||||
cm.LastUsage = nil
|
||||
var totalTokens int
|
||||
aggUsage := &ChatUsage{}
|
||||
|
||||
commitRound := func(cfg *ChatConfig, roundTokens int) {
|
||||
// Prefer the authoritative usage from the API (extracted via
|
||||
// stream_options.include_usage=true) over text-based token
|
||||
// counting. Mirrors Python's usage_from_response accumulation
|
||||
// in chat_model.py streaming handlers.
|
||||
// Track per-round delta so RecordRunTokenUsage records each
|
||||
// round's contribution exactly once (the sink does Add, not Set).
|
||||
var deltaPrompt, deltaCompletion, deltaTotal int
|
||||
if u := cfg.UsageResult; u != nil && u.TotalTokens > 0 {
|
||||
deltaPrompt, deltaCompletion, deltaTotal = u.PromptTokens, u.CompletionTokens, u.TotalTokens
|
||||
aggUsage.PromptTokens += deltaPrompt
|
||||
aggUsage.CompletionTokens += deltaCompletion
|
||||
aggUsage.TotalTokens += deltaTotal
|
||||
} else {
|
||||
deltaTotal = roundTokens
|
||||
aggUsage.TotalTokens += deltaTotal
|
||||
}
|
||||
totalTokens = aggUsage.TotalTokens
|
||||
cm.LastUsage = &ChatUsage{
|
||||
PromptTokens: deltaPrompt, CompletionTokens: deltaCompletion, TotalTokens: deltaTotal,
|
||||
}
|
||||
recordUsageFromResponse(ctx, cm)
|
||||
}
|
||||
|
||||
for round := 0; round <= maxRounds; round++ {
|
||||
select {
|
||||
@@ -192,10 +270,13 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to
|
||||
cfg.Stream = boolPtr(true)
|
||||
var tcs []map[string]interface{}
|
||||
cfg.ToolCallsResult = &tcs
|
||||
var roundUsage ChatUsage
|
||||
cfg.UsageResult = &roundUsage
|
||||
|
||||
reasoningStarted := false
|
||||
var answer string
|
||||
var pendingThinkClose bool
|
||||
var roundTokens int
|
||||
|
||||
err := cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, func(delta *string, reason *string) error {
|
||||
if reason != nil && *reason != "" {
|
||||
@@ -207,6 +288,7 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to
|
||||
}
|
||||
}
|
||||
pendingThinkClose = true
|
||||
roundTokens += tokenizer.NumTokensFromString(*reason)
|
||||
return sender(reason, nil)
|
||||
}
|
||||
// Reasoning ended, close the think block if open
|
||||
@@ -221,7 +303,7 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to
|
||||
if *delta == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
totalTokens += tokenizer.NumTokensFromString(*delta)
|
||||
roundTokens += tokenizer.NumTokensFromString(*delta)
|
||||
answer += *delta
|
||||
if e := sender(delta, nil); e != nil {
|
||||
return e
|
||||
@@ -237,6 +319,9 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to
|
||||
return totalTokens, e
|
||||
}
|
||||
}
|
||||
// Commit this round's token count to the running aggregate.
|
||||
// Prefer authoritative API usage over text-based estimation.
|
||||
commitRound(&cfg, roundTokens)
|
||||
if err != nil {
|
||||
return totalTokens, fmt.Errorf("round %d: %w", round, err)
|
||||
}
|
||||
@@ -263,7 +348,17 @@ func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, to
|
||||
})
|
||||
cfg := *chatCfg
|
||||
cfg.Stream = boolPtr(true)
|
||||
return totalTokens, cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, sender)
|
||||
var exceedUsage ChatUsage
|
||||
cfg.UsageResult = &exceedUsage
|
||||
var exceedTokens int
|
||||
err := cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, func(delta *string, reason *string) error {
|
||||
if delta != nil && *delta != "" && *delta != "[DONE]" {
|
||||
exceedTokens += tokenizer.NumTokensFromString(*delta)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
commitRound(&cfg, exceedTokens)
|
||||
return totalTokens, err
|
||||
}
|
||||
|
||||
// appendToolResults executes tool calls concurrently, appends the assistant
|
||||
|
||||
@@ -124,10 +124,23 @@ func (m *EinoChatModel) Generate(ctx context.Context, msgs []*schema.Message, op
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reset stale per-call usage before the call so that a response
|
||||
// without a usage block doesn't leak the previous call's data.
|
||||
// Mirrors Python's LLMBundle._reset_last_usage().
|
||||
m.inner.LastUsage = nil
|
||||
resp, err := m.inner.ModelDriver.ChatWithMessages(*m.inner.ModelName, internal, m.inner.APIConfig, m.chatCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: EinoChatModel.Generate(%s): %w", *m.inner.ModelName, err)
|
||||
}
|
||||
// Record the per-call token usage so the canvas-level aggregator (and
|
||||
// Langfuse) can compute the run total. Mirrors Python's
|
||||
// LLMBundle._report_usage() / self.mdl.last_usage pattern.
|
||||
if resp != nil && resp.Usage != nil {
|
||||
m.inner.LastUsage = &ChatUsage{
|
||||
PromptTokens: resp.Usage.PromptTokens, CompletionTokens: resp.Usage.CompletionTokens, TotalTokens: resp.Usage.TotalTokens,
|
||||
}
|
||||
recordUsageFromResponse(ctx, m.inner)
|
||||
}
|
||||
return fromInternalResponse(resp), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -211,6 +211,16 @@ func (o *OpenAIModel) ChatWithMessages(modelName string, messages []Message, api
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
|
||||
// Extract usage split (prompt/completion/total) from the raw API
|
||||
// response for accurate per-call token accounting. Non-OpenAI
|
||||
// providers that implement the OpenAI-compat API surface (DeepSeek,
|
||||
// Moonshot, etc.) also return a "usage" key with the same shape.
|
||||
if pt, ct, tt := extractUsageFromMap(result); tt > 0 {
|
||||
chatResponse.Usage = &ChatUsage{
|
||||
PromptTokens: pt, CompletionTokens: ct, TotalTokens: tt,
|
||||
}
|
||||
}
|
||||
|
||||
return chatResponse, nil
|
||||
}
|
||||
|
||||
@@ -247,11 +257,17 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag
|
||||
apiMessages[i] = apiMsg
|
||||
}
|
||||
|
||||
// Build request body with streaming on by default
|
||||
// Build request body with streaming on by default.
|
||||
// stream_options.include_usage asks the provider to attach a
|
||||
// usage block to the final streaming chunk (mirrors Python's
|
||||
// chat_model.py _stream_options / stream_options.include_usage).
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": apiMessages,
|
||||
"stream": true,
|
||||
"stream_options": map[string]interface{}{
|
||||
"include_usage": true,
|
||||
},
|
||||
}
|
||||
|
||||
if chatModelConfig != nil {
|
||||
@@ -316,6 +332,12 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag
|
||||
|
||||
sawTerminal := false
|
||||
accumulatedToolCalls := make(map[int]map[string]interface{})
|
||||
// Capture the authoritative usage block from the final streaming
|
||||
// chunk (when provider honours stream_options.include_usage=true).
|
||||
// The last chunk in the stream carries the "usage" key alongside
|
||||
// empty choices; we overwrite on every chunk so the final frame
|
||||
// wins, matching Python's chat_model.py usage_from_response loop.
|
||||
var streamUsage *ChatUsage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@@ -340,6 +362,13 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract usage from this chunk. When stream_options.include_usage
|
||||
// is true, the final chunk carries the full usage breakdown at the
|
||||
// top level of the event alongside (possibly empty) choices.
|
||||
if pt, ct, tt := extractUsageFromMap(event); tt > 0 {
|
||||
streamUsage = &ChatUsage{PromptTokens: pt, CompletionTokens: ct, TotalTokens: tt}
|
||||
}
|
||||
|
||||
choices, ok := event["choices"].([]interface{})
|
||||
if !ok || len(choices) == 0 {
|
||||
continue
|
||||
@@ -421,6 +450,11 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag
|
||||
chatModelConfig.ToolCallsResult = &tcs
|
||||
}
|
||||
|
||||
// Populate UsageResult with the authoritative usage from the stream.
|
||||
if streamUsage != nil && chatModelConfig != nil {
|
||||
chatModelConfig.UsageResult = streamUsage
|
||||
}
|
||||
|
||||
// Send the [DONE] marker for OpenAI compatibility
|
||||
endOfStream := "[DONE]"
|
||||
if err := sender(&endOfStream, nil); err != nil {
|
||||
@@ -1024,6 +1058,50 @@ func (o *OpenAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskRespon
|
||||
return nil, fmt.Errorf("%s, no such method", o.Name())
|
||||
}
|
||||
|
||||
// extractUsageFromMap pulls the token split out of an OpenAI-style
// response (or streaming event) map and returns it as
// (prompt, completion, total).
//
// The split is read from raw["usage"], which must be a
// map[string]interface{}; otherwise all three results are zero.
// Prompt tokens are taken from "prompt_tokens" or, failing that,
// "input_tokens"; completion tokens from "completion_tokens" or
// "output_tokens". Values may be float64, int, or an integral
// json.Number; anything else is skipped. When "total_tokens" is absent
// or zero the total falls back to prompt + completion.
func extractUsageFromMap(raw map[string]interface{}) (int, int, int) {
	// Indexing a nil map is legal and yields nil, so a nil raw, a
	// missing key and a wrongly-typed value all fail this assertion.
	usage, ok := raw["usage"].(map[string]interface{})
	if !ok {
		return 0, 0, 0
	}

	// lookup returns the first of names whose value is a usable number.
	lookup := func(names ...string) int {
		for _, name := range names {
			switch n := usage[name].(type) {
			case float64:
				return int(n)
			case int:
				return n
			case json.Number:
				if parsed, err := n.Int64(); err == nil {
					return int(parsed)
				}
			}
		}
		return 0
	}

	prompt := lookup("prompt_tokens", "input_tokens")
	completion := lookup("completion_tokens", "output_tokens")
	total := lookup("total_tokens")
	if total == 0 {
		total = prompt + completion
	}
	return prompt, completion, total
}
|
||||
|
||||
func cloneMap(m map[string]interface{}) map[string]interface{} {
|
||||
cp := make(map[string]interface{}, len(m))
|
||||
for k, v := range m {
|
||||
|
||||
@@ -60,6 +60,16 @@ type ChatResponse struct {
|
||||
Answer *string `json:"answer"`
|
||||
ReasonContent *string `json:"reason_content"`
|
||||
ToolCalls []map[string]interface{} `json:"tool_calls,omitempty"`
|
||||
Usage *ChatUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// ChatUsage is the token accounting for a single LLM call, split into
// prompt, completion and total counts. Mirrors the split produced by
// Python's common.token_utils.usage_from_response().
type ChatUsage struct {
	// PromptTokens is the number of tokens in the request.
	PromptTokens int `json:"prompt_tokens"`
	// CompletionTokens is the number of tokens generated in the reply.
	CompletionTokens int `json:"completion_tokens"`
	// TotalTokens is the total reported for the call.
	TotalTokens int `json:"total_tokens"`
}
|
||||
|
||||
type EmbeddingData struct {
|
||||
@@ -154,6 +164,12 @@ type ChatConfig struct {
|
||||
Tools interface{} `json:"tools,omitempty"`
|
||||
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||
ToolCallsResult *[]map[string]interface{} `json:"-"`
|
||||
// UsageResult receives the token usage extracted from the final
|
||||
// streaming chunk when stream_options.include_usage is true.
|
||||
// The ChatStreamlyWithSender driver writes to this pointer (if
|
||||
// non-nil) after the stream completes; callers read it the same
|
||||
// way they read ToolCallsResult.
|
||||
UsageResult *ChatUsage `json:"-"`
|
||||
}
|
||||
|
||||
type APIConfig struct {
|
||||
@@ -238,6 +254,10 @@ type ChatModel struct {
|
||||
ModelName *string
|
||||
APIConfig *APIConfig
|
||||
ToolConfig *ToolConfig
|
||||
// LastUsage holds the token usage (prompt/completion/total) of the most
|
||||
// recent chat call. Consumed by callers for accurate Langfuse reporting
|
||||
// and per-run token aggregation. Reset before each call.
|
||||
LastUsage *ChatUsage
|
||||
}
|
||||
|
||||
// NewChatModel creates a new ChatModel
|
||||
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/tokenizer"
|
||||
|
||||
dslpkg "ragflow/internal/agent/dsl"
|
||||
)
|
||||
@@ -835,11 +836,26 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Install a per-run token usage sink so every LLM call inside
|
||||
// this turn records its token usage (the sink is read at the end
|
||||
// and emitted in workflow_finished). Mirrors Python's
|
||||
// Canvas.run() installing token_usage_sink + langfuse_run_attrs.
|
||||
ctx = tokenizer.WithRunUsage(ctx)
|
||||
|
||||
// Extract the event channel + metadata injected by Runner.Run.
|
||||
events, _ := root["__events__"].(chan canvas.RunEvent)
|
||||
messageID, _ := root["__message_id__"].(string)
|
||||
taskID, _ := root["__task_id__"].(string)
|
||||
sessionID, _ := root["__session_id__"].(string)
|
||||
userID, _ := root["user_id"].(string)
|
||||
|
||||
// Install per-run Langfuse correlation attrs so LLM calls inside
|
||||
// this turn are grouped by session/user. Mirrors Python's
|
||||
// Canvas.run() setting langfuse_run_attrs.
|
||||
ctx = tokenizer.WithRunAttrs(ctx, &tokenizer.RunAttrs{
|
||||
SessionID: sessionID,
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
// Helper to build an SSE event with metadata.
|
||||
emit := func(typ, data string) {
|
||||
@@ -855,6 +871,22 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
|
||||
})
|
||||
}
|
||||
|
||||
// usagePayload returns the aggregated per-run token usage as a
|
||||
// JSON-serializable map, or nil when no sink was installed.
|
||||
usagePayload := func() map[string]int {
|
||||
sink := tokenizer.GetRunUsage(ctx)
|
||||
if sink == nil {
|
||||
return nil
|
||||
}
|
||||
pt, ct, tt, calls := sink.Snapshot()
|
||||
return map[string]int{
|
||||
"prompt_tokens": pt,
|
||||
"completion_tokens": ct,
|
||||
"total_tokens": tt,
|
||||
"calls": calls,
|
||||
}
|
||||
}
|
||||
|
||||
startedAt := float64(time.Now().UnixNano()) / 1e9
|
||||
|
||||
userInput := root["user_input"]
|
||||
@@ -883,7 +915,11 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
|
||||
meData, _ := json.Marshal(canvas.MessageEndEvent{})
|
||||
emit("message", string(msgData))
|
||||
emit("message_end", string(meData))
|
||||
wfData, _ := json.Marshal(map[string]any{"outputs": answer})
|
||||
wfPayload := map[string]any{"outputs": answer}
|
||||
if u := usagePayload(); u != nil {
|
||||
wfPayload["usage"] = u
|
||||
}
|
||||
wfData, _ := json.Marshal(wfPayload)
|
||||
emit("workflow_finished", string(wfData))
|
||||
return state, nil
|
||||
}
|
||||
@@ -894,6 +930,10 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
|
||||
s.markRunFailed(ctx, runID, "decode: "+err.Error())
|
||||
return nil, err
|
||||
}
|
||||
// Close MCP tool adapters and any other closeable resources
|
||||
// held by the canvas after execution completes. Mirrors
|
||||
// Python's finally: canvas.close() in canvas_service.py.
|
||||
defer c.Close()
|
||||
|
||||
// Store events channel + run metadata on the context so the
|
||||
// per-node statePre/statePost wrappers (in scheduler.go) can
|
||||
@@ -1080,12 +1120,16 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
|
||||
})
|
||||
emit("message_end", string(meData))
|
||||
|
||||
wfData, _ := json.Marshal(map[string]interface{}{
|
||||
wfPayload := map[string]interface{}{
|
||||
"inputs": map[string]any{"query": userInput},
|
||||
"outputs": answer,
|
||||
"elapsed_time": now - startedAt,
|
||||
"created_at": now,
|
||||
})
|
||||
}
|
||||
if u := usagePayload(); u != nil {
|
||||
wfPayload["usage"] = u
|
||||
}
|
||||
wfData, _ := json.Marshal(wfPayload)
|
||||
emit("workflow_finished", string(wfData))
|
||||
|
||||
s.markRunSucceeded(ctx2, runID)
|
||||
@@ -1107,13 +1151,18 @@ func (s *AgentService) buildRunFunc(canvasID string, versionRow *entity.UserCanv
|
||||
})
|
||||
emit("message_end", string(meData))
|
||||
|
||||
// Emit workflow_finished with the final outputs.
|
||||
wfData, _ := json.Marshal(map[string]interface{}{
|
||||
// Emit workflow_finished with the final outputs and aggregated
|
||||
// per-run token usage across all LLM calls in this turn.
|
||||
wfPayload := map[string]interface{}{
|
||||
"inputs": map[string]any{"query": userInput},
|
||||
"outputs": answer,
|
||||
"elapsed_time": now - startedAt,
|
||||
"created_at": now,
|
||||
})
|
||||
}
|
||||
if u := usagePayload(); u != nil {
|
||||
wfPayload["usage"] = u
|
||||
}
|
||||
wfData, _ := json.Marshal(wfPayload)
|
||||
emit("workflow_finished", string(wfData))
|
||||
|
||||
s.markRunSucceeded(ctx2, runID)
|
||||
|
||||
184
internal/tokenizer/usage.go
Normal file
184
internal/tokenizer/usage.go
Normal file
@@ -0,0 +1,184 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Package tokenizer — per-run token usage tracking.
|
||||
//
|
||||
// An agent run installs a mutable token usage accumulator on the context
|
||||
// (via WithRunUsage) at the start of each turn. Every LLM call inside
|
||||
// that run adds its usage (prompt/completion/total tokens) to the sink
|
||||
// via RecordRunTokenUsage. At the end of the run, the service layer
|
||||
// reads the accumulated totals and emits them in the workflow_finished
|
||||
// SSE event.
|
||||
//
|
||||
// This mirrors Python's common.token_utils:
|
||||
// - token_usage_sink ContextVar → context.Context + runUsageKey
|
||||
// - langfuse_run_attrs ContextVar → context.Context + runAttrsKey
|
||||
// - record_run_token_usage() → RecordRunTokenUsage(ctx, ...)
|
||||
// - usage_from_response() → UsageFromMap(raw)
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Context key types. Each is a distinct unexported empty struct, so
// code outside this package cannot construct a matching key and reach
// the stored values except through the accessor functions.
type (
	// runUsageKeyType keys the per-run *RunUsage sink.
	runUsageKeyType struct{}
	// runAttrsKeyType keys the per-run *RunAttrs.
	runAttrsKeyType struct{}
)
|
||||
|
||||
// RunUsage accumulates token usage over one run. A pointer to it is
// stored on the context so every LLM call made with that context (or a
// context derived from it) reports into the same sink.
//
// Add and Snapshot take mu, so they may be called from multiple
// goroutines. The counters are exported fields; reading or writing
// them directly bypasses the lock.
type RunUsage struct {
	mu               sync.Mutex
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
	Calls            int
}

// Add records one LLM call. Each count is added to its running total
// only when it is positive; Calls is incremented on every invocation,
// whatever the counts are. Add on a nil *RunUsage does nothing.
func (u *RunUsage) Add(prompt, completion, total int) {
	if u == nil {
		return
	}
	u.mu.Lock()
	defer u.mu.Unlock()

	u.Calls++
	if prompt > 0 {
		u.PromptTokens += prompt
	}
	if completion > 0 {
		u.CompletionTokens += completion
	}
	if total > 0 {
		u.TotalTokens += total
	}
}

// Snapshot returns the cumulative prompt, completion and total token
// counts together with the number of recorded calls, read under the
// lock. A nil *RunUsage yields all zeros.
func (u *RunUsage) Snapshot() (prompt, completion, total, calls int) {
	if u != nil {
		u.mu.Lock()
		prompt, completion, total, calls = u.PromptTokens, u.CompletionTokens, u.TotalTokens, u.Calls
		u.mu.Unlock()
	}
	return prompt, completion, total, calls
}
|
||||
|
||||
// RunAttrs carries the identifiers used to correlate the LLM calls of
// one run in Langfuse. The service layer stores it on the context via
// WithRunAttrs.
type RunAttrs struct {
	// SessionID identifies the session the run belongs to.
	SessionID string
	// UserID identifies the user the run is executed for.
	UserID string
}
|
||||
|
||||
// WithRunUsage installs a fresh RunUsage sink on ctx. Should be called
|
||||
// once at the start of a canvas turn.
|
||||
func WithRunUsage(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, runUsageKeyType{}, &RunUsage{})
|
||||
}
|
||||
|
||||
// GetRunUsage retrieves the per-run token usage sink from ctx.
|
||||
// Returns nil when no sink is installed (e.g. outside a canvas run).
|
||||
func GetRunUsage(ctx context.Context) *RunUsage {
|
||||
if v := ctx.Value(runUsageKeyType{}); v != nil {
|
||||
if sink, ok := v.(*RunUsage); ok {
|
||||
return sink
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithRunAttrs installs Langfuse correlation attributes on ctx.
|
||||
func WithRunAttrs(ctx context.Context, attrs *RunAttrs) context.Context {
|
||||
if attrs == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, runAttrsKeyType{}, attrs)
|
||||
}
|
||||
|
||||
// GetRunAttrs retrieves the per-run Langfuse attributes from ctx.
|
||||
func GetRunAttrs(ctx context.Context) *RunAttrs {
|
||||
if v := ctx.Value(runAttrsKeyType{}); v != nil {
|
||||
if attrs, ok := v.(*RunAttrs); ok {
|
||||
return attrs
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordRunTokenUsage adds a single LLM call's token usage to the
|
||||
// active run sink on ctx. Safe to call from anywhere; when no run sink
|
||||
// is installed it is a no-op.
|
||||
func RecordRunTokenUsage(ctx context.Context, promptTokens, completionTokens, totalTokens int) {
|
||||
sink := GetRunUsage(ctx)
|
||||
if sink == nil {
|
||||
return
|
||||
}
|
||||
sink.Add(promptTokens, completionTokens, totalTokens)
|
||||
}
|
||||
|
||||
// UsageFromMap extracts a token usage split from a raw API response map.
|
||||
// Handles OpenAI/OpenRouter-style resp["usage"] dicts. Missing fields
|
||||
// default to 0; total_tokens falls back to prompt+completion when absent.
|
||||
// Returns nil when no usage found.
|
||||
// Mirrors Python's common.token_utils.usage_from_response().
|
||||
func UsageFromMap(raw map[string]interface{}) (promptTokens, completionTokens, totalTokens int) {
|
||||
if raw == nil {
|
||||
return 0, 0, 0
|
||||
}
|
||||
usageRaw, ok := raw["usage"]
|
||||
if !ok {
|
||||
return 0, 0, 0
|
||||
}
|
||||
usage, ok := usageRaw.(map[string]interface{})
|
||||
if !ok {
|
||||
return 0, 0, 0
|
||||
}
|
||||
pt := getInt(usage, "prompt_tokens", "input_tokens")
|
||||
ct := getInt(usage, "completion_tokens", "output_tokens")
|
||||
tt := getInt(usage, "total_tokens")
|
||||
if tt == 0 {
|
||||
tt = pt + ct
|
||||
}
|
||||
return pt, ct, tt
|
||||
}
|
||||
|
||||
// getInt returns the value of the first key in keys that m maps to a
// usable number: a float64 (truncated to int), an int, or a json.Number
// that parses as an integer. Keys that are absent or hold any other
// kind of value are skipped. It returns 0 when no key qualifies.
func getInt(m map[string]interface{}, keys ...string) int {
	for _, key := range keys {
		// An absent key yields nil, which matches no case below.
		switch n := m[key].(type) {
		case float64:
			return int(n)
		case int:
			return n
		case json.Number:
			if parsed, err := n.Int64(); err == nil {
				return int(parsed)
			}
		}
	}
	return 0
}
|
||||
Reference in New Issue
Block a user