//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package models
import (
"context"
"encoding/json"
"fmt"
"sync"
"ragflow/internal/tokenizer"
)
// recordUsageFromResponse forwards the usage currently held in cm.LastUsage
// (prompt, completion and total token counts) to
// tokenizer.RecordRunTokenUsage together with ctx. It only reads
// cm.LastUsage; callers are expected to populate that field for the chat
// call they just made and then invoke this helper once per call.
//
// The call is a no-op when cm is nil or no usage has been stored.
func recordUsageFromResponse(ctx context.Context, cm *ChatModel) {
	if cm == nil || cm.LastUsage == nil {
		return
	}
	usage := cm.LastUsage
	tokenizer.RecordRunTokenUsage(ctx, usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
}
const (
// defaultMaxRetries is the number of attempts ChatWithTools and
// ChatStreamlyWithTools make when ToolConfig.MaxRetries is <= 0.
defaultMaxRetries = 3
// defaultMaxRounds is the tool-calling round limit used when
// ToolConfig.MaxRounds is <= 0. The loops iterate with
// round <= maxRounds, i.e. up to maxRounds+1 provider requests, before
// issuing one last "Exceed max rounds" request.
defaultMaxRounds = 5
)
// ChatWithTools runs the non-streaming tool-calling loop.
//
// The tool definitions are parsed from cm.ToolConfig.Tools (JSON). When
// system is non-empty and history is non-empty without a leading "system"
// message, a system message is prepended. Non-positive MaxRounds and
// MaxRetries fall back to defaultMaxRounds and defaultMaxRetries.
//
// Every attempt runs runToolLoop on its own copy of the message slice, so
// tool messages appended by a failed attempt are not carried into the next.
//
// It returns the final answer and the token total reported by the
// successful attempt. If ctx is already done when an attempt is about to
// start, ctx.Err() is returned. When all attempts fail, the returned error
// wraps the last attempt's error so callers can examine the cause with
// errors.Is / errors.As.
func (cm *ChatModel) ChatWithTools(ctx context.Context, system string, history []Message, chatCfg *ChatConfig) (string, int, error) {
	tc := cm.ToolConfig
	if tc == nil {
		return "", 0, fmt.Errorf("ChatWithTools called without bound tools")
	}

	var toolsList interface{}
	if err := json.Unmarshal([]byte(tc.Tools), &toolsList); err != nil {
		return "", 0, fmt.Errorf("failed to parse tools JSON: %w", err)
	}

	maxRounds := tc.MaxRounds
	if maxRounds <= 0 {
		maxRounds = defaultMaxRounds
	}
	maxRetries := tc.MaxRetries
	if maxRetries <= 0 {
		maxRetries = defaultMaxRetries
	}

	if system != "" && len(history) > 0 && history[0].Role != "system" {
		history = append([]Message{{Role: "system", Content: system}}, history...)
	}

	// Snapshot the starting conversation; each attempt works on a copy.
	baseHistory := make([]Message, len(history))
	copy(baseHistory, history)

	// maxRetries is >= 1 here, so lastErr is always set before the final
	// return below is reached.
	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		if err := ctx.Err(); err != nil {
			return "", 0, err
		}

		h := make([]Message, len(baseHistory))
		copy(h, baseHistory)

		answer, tokens, err := runToolLoop(ctx, cm, h, toolsList, chatCfg, maxRounds)
		if err == nil {
			return answer, tokens, nil
		}
		lastErr = err
	}
	return "", 0, fmt.Errorf("ChatWithTools failed after %d retries: %w", maxRetries, lastErr)
}
// runToolLoop performs one attempt of the non-streaming tool-calling
// conversation.
//
// Each round sends history to the model with the tool list attached and
// tool_choice "auto". If the response carries tool calls, they are executed
// via appendToolResults and the loop continues; otherwise the response's
// answer (prefixed with its reasoning content, wrapped in think tags, when
// present) is returned. After maxRounds+1 rounds that all requested tools, a
// user message "Exceed max rounds: N" is appended and one last request is
// made without tools.
//
// Token accounting: usage reported by every round is summed into the
// returned total, and each round's own usage is stored in cm.LastUsage and
// passed to recordUsageFromResponse. If the answering response has no usage
// (or a zero total), the answer's token count is estimated with
// tokenizer.NumTokensFromString and added to the returned total only.
//
// The int result is the token total accumulated so far, also on error.
func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsList interface{}, chatCfg *ChatConfig, maxRounds int) (string, int, error) {
	// Delimiters placed around provider reasoning content. They are spelled
	// in pieces so markup-stripping tooling cannot silently erase them.
	// NOTE(review): assumed to match the think tags the consumers of this
	// answer expect — confirm against the rendering side.
	const (
		thinkOpenTag  = "<" + "think" + ">"
		thinkCloseTag = "</" + "think" + ">"
	)

	// The per-round copy below dereferences chatCfg; use the zero value
	// instead of panicking when the caller passed nil.
	if chatCfg == nil {
		chatCfg = &ChatConfig{}
	}

	// Reset stale per-call usage from a previous call so a response
	// without usage doesn't leak the prior call's data.
	cm.LastUsage = nil

	// Aggregate prompt/completion/total across all tool-calling rounds so
	// multi-round conversations report the grand total rather than the
	// last round's figure.
	var totalTokens int
	aggUsage := &ChatUsage{}
	addRoundUsage := func(resp *ChatResponse) {
		u := resp.Usage
		if u == nil {
			return
		}
		aggUsage.PromptTokens += u.PromptTokens
		aggUsage.CompletionTokens += u.CompletionTokens
		aggUsage.TotalTokens += u.TotalTokens
		totalTokens = aggUsage.TotalTokens
		// Store the per-round delta (not the cumulative figure) so each
		// round's contribution is recorded exactly once.
		cm.LastUsage = &ChatUsage{
			PromptTokens:     u.PromptTokens,
			CompletionTokens: u.CompletionTokens,
			TotalTokens:      u.TotalTokens,
		}
		recordUsageFromResponse(ctx, cm)
	}

	for round := 0; round <= maxRounds; round++ {
		if err := ctx.Err(); err != nil {
			return "", totalTokens, err
		}

		cfg := *chatCfg
		cfg.Tools = toolsList
		tcChoice := "auto"
		cfg.ToolChoice = &tcChoice

		resp, err := cm.ModelDriver.ChatWithMessages(*cm.ModelName, history, cm.APIConfig, &cfg)
		if err != nil {
			return "", totalTokens, fmt.Errorf("round %d: %w", round, err)
		}
		if resp == nil {
			return "", totalTokens, fmt.Errorf("round %d: nil response", round)
		}
		addRoundUsage(resp)

		if len(resp.ToolCalls) > 0 {
			history = appendToolResults(history, resp.ToolCalls, cm.ToolConfig.ToolCallSession)
			continue
		}

		// No tool calls: this response is the final answer.
		answer := ""
		if resp.Answer != nil {
			answer = *resp.Answer
		}
		if resp.ReasonContent != nil && *resp.ReasonContent != "" {
			answer = thinkOpenTag + *resp.ReasonContent + thinkCloseTag + answer
		}
		// Fallback: the provider returned no usable usage figures, so
		// estimate from the answer text.
		if resp.Usage == nil || resp.Usage.TotalTokens == 0 {
			totalTokens += tokenizer.NumTokensFromString(answer)
		}
		return answer, totalTokens, nil
	}

	// Exceeded max rounds: ask once more, without tools, for a final answer.
	history = append(history, Message{
		Role:    "user",
		Content: fmt.Sprintf("Exceed max rounds: %d", maxRounds),
	})
	cfg := *chatCfg
	resp, err := cm.ModelDriver.ChatWithMessages(*cm.ModelName, history, cm.APIConfig, &cfg)
	if err != nil {
		return "", totalTokens, fmt.Errorf("final call: %w", err)
	}
	if resp == nil {
		return "", totalTokens, fmt.Errorf("final call: no answer")
	}
	// Record usage before validating the answer: the request was made and
	// billed even if it produced no answer text.
	addRoundUsage(resp)
	if resp.Answer == nil {
		return "", totalTokens, fmt.Errorf("final call: no answer")
	}
	// Fallback: use text-based estimation if no authoritative usage.
	if resp.Usage == nil || resp.Usage.TotalTokens == 0 {
		totalTokens += tokenizer.NumTokensFromString(*resp.Answer)
	}
	return *resp.Answer, totalTokens, nil
}
// ChatStreamlyWithTools runs the streaming tool-calling loop.
//
// The tool definitions are parsed from cm.ToolConfig.Tools (JSON). When
// system is non-empty and history is non-empty without a leading "system"
// message, a system message is prepended. Non-positive MaxRounds and
// MaxRetries fall back to defaultMaxRounds and defaultMaxRetries.
//
// Output is delivered through sender. Every attempt calls
// runStreamToolLoop with the same sender and its own copy of the message
// slice, so sender may already have received partial output from an attempt
// that later failed and was retried.
//
// It returns the token total of the successful attempt. If ctx is already
// done when an attempt is about to start, ctx.Err() is returned. When all
// attempts fail, the returned error wraps the last attempt's error so
// callers can examine the cause with errors.Is / errors.As.
func (cm *ChatModel) ChatStreamlyWithTools(ctx context.Context, system string, history []Message, chatCfg *ChatConfig, sender func(*string, *string) error) (int, error) {
	tc := cm.ToolConfig
	if tc == nil {
		return 0, fmt.Errorf("ChatStreamlyWithTools called without bound tools")
	}

	var toolsList interface{}
	if err := json.Unmarshal([]byte(tc.Tools), &toolsList); err != nil {
		return 0, fmt.Errorf("failed to parse tools JSON: %w", err)
	}

	maxRounds := tc.MaxRounds
	if maxRounds <= 0 {
		maxRounds = defaultMaxRounds
	}
	maxRetries := tc.MaxRetries
	if maxRetries <= 0 {
		maxRetries = defaultMaxRetries
	}

	if system != "" && len(history) > 0 && history[0].Role != "system" {
		history = append([]Message{{Role: "system", Content: system}}, history...)
	}

	// Snapshot the starting conversation; each attempt works on a copy.
	baseHistory := make([]Message, len(history))
	copy(baseHistory, history)

	// maxRetries is >= 1 here, so lastErr is always set before the final
	// return below is reached.
	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		if err := ctx.Err(); err != nil {
			return 0, err
		}

		h := make([]Message, len(baseHistory))
		copy(h, baseHistory)

		totalTokens, err := runStreamToolLoop(ctx, cm, h, toolsList, chatCfg, maxRounds, sender)
		if err == nil {
			return totalTokens, nil
		}
		lastErr = err
	}
	return 0, fmt.Errorf("ChatStreamlyWithTools failed after %d retries: %w", maxRetries, lastErr)
}
// runStreamToolLoop performs one attempt of the streaming tool-calling
// conversation, forwarding model output to sender as it arrives.
//
// Each round streams a request with the tool list attached and tool_choice
// "auto". Reasoning chunks are forwarded between a think-open and a
// think-close marker; answer chunks are forwarded as-is; a "[DONE]" chunk is
// swallowed. When the round produced tool calls they are executed via
// appendToolResults and the loop continues. A round with answer text and no
// tool calls ends the loop successfully; a round with neither is an error.
// After maxRounds+1 tool-requesting rounds, a user message
// "Exceed max rounds: N" is appended and one last request without tools is
// streamed to sender.
//
// Token accounting: every round is committed separately. The usage written
// by the driver into cfg.UsageResult is preferred; when it is absent or has
// a zero total, the sum of tokenizer.NumTokensFromString over the streamed
// chunks is used instead. Each round's delta is stored in cm.LastUsage and
// passed to recordUsageFromResponse.
//
// The int result is the token total accumulated so far, also on error.
func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsList interface{}, chatCfg *ChatConfig, maxRounds int, sender func(*string, *string) error) (int, error) {
	// Markers placed around streamed reasoning content. They are spelled in
	// pieces so markup-stripping tooling cannot silently erase them.
	// NOTE(review): assumed to match the think tags the stream consumer
	// expects — confirm against the rendering side.
	const (
		thinkOpenTag  = "<" + "think" + ">"
		thinkCloseTag = "</" + "think" + ">"
		doneMarker    = "[DONE]"
	)

	// The per-round copy below dereferences chatCfg; use the zero value
	// instead of panicking when the caller passed nil.
	if chatCfg == nil {
		chatCfg = &ChatConfig{}
	}

	// Reset stale per-call usage from a previous call.
	cm.LastUsage = nil

	var totalTokens int
	aggUsage := &ChatUsage{}
	commitRound := func(cfg *ChatConfig, roundTokens int) {
		// Per-round delta: authoritative usage when the driver supplied
		// it, otherwise the text-based estimate as the total only.
		var deltaPrompt, deltaCompletion, deltaTotal int
		if u := cfg.UsageResult; u != nil && u.TotalTokens > 0 {
			deltaPrompt, deltaCompletion, deltaTotal = u.PromptTokens, u.CompletionTokens, u.TotalTokens
		} else {
			deltaTotal = roundTokens
		}
		aggUsage.PromptTokens += deltaPrompt
		aggUsage.CompletionTokens += deltaCompletion
		aggUsage.TotalTokens += deltaTotal
		totalTokens = aggUsage.TotalTokens
		// Store the delta, not the aggregate, so each round is recorded
		// exactly once.
		cm.LastUsage = &ChatUsage{
			PromptTokens:     deltaPrompt,
			CompletionTokens: deltaCompletion,
			TotalTokens:      deltaTotal,
		}
		recordUsageFromResponse(ctx, cm)
	}
	// emit sends a locally produced marker through sender.
	emit := func(text string) error {
		return sender(&text, nil)
	}

	for round := 0; round <= maxRounds; round++ {
		if err := ctx.Err(); err != nil {
			return totalTokens, err
		}

		cfg := *chatCfg
		cfg.Tools = toolsList
		tcChoice := "auto"
		cfg.ToolChoice = &tcChoice
		cfg.Stream = boolPtr(true)
		var tcs []map[string]interface{}
		cfg.ToolCallsResult = &tcs
		var roundUsage ChatUsage
		cfg.UsageResult = &roundUsage

		var (
			thinkOpen   bool // a think-open marker was sent and not yet closed
			gotAnswer   bool // at least one answer chunk was forwarded
			roundTokens int  // text-based token estimate for this round
		)
		closeThink := func() error {
			if !thinkOpen {
				return nil
			}
			thinkOpen = false
			return emit(thinkCloseTag)
		}

		err := cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, func(delta *string, reason *string) error {
			if reason != nil && *reason != "" {
				// Open a think block for every reasoning burst so open
				// and close markers always come in pairs.
				if !thinkOpen {
					if e := emit(thinkOpenTag); e != nil {
						return e
					}
					thinkOpen = true
				}
				roundTokens += tokenizer.NumTokensFromString(*reason)
				return sender(reason, nil)
			}
			// Any non-reasoning chunk ends the current think block.
			if e := closeThink(); e != nil {
				return e
			}
			if delta == nil || *delta == "" || *delta == doneMarker {
				return nil
			}
			roundTokens += tokenizer.NumTokensFromString(*delta)
			gotAnswer = true
			return sender(delta, nil)
		})

		// Commit first so the round's tokens are recorded even when the
		// stream or the sender failed.
		commitRound(&cfg, roundTokens)

		// Close a think block the stream left open.
		if e := closeThink(); e != nil {
			return totalTokens, e
		}
		if err != nil {
			return totalTokens, fmt.Errorf("round %d: %w", round, err)
		}

		var toolCalls []map[string]interface{}
		if cfg.ToolCallsResult != nil {
			toolCalls = *cfg.ToolCallsResult
		}
		if len(toolCalls) == 0 {
			if gotAnswer {
				return totalTokens, nil
			}
			return totalTokens, fmt.Errorf("round %d: no content and no tool_calls", round)
		}
		history = appendToolResults(history, toolCalls, cm.ToolConfig.ToolCallSession)
	}

	// Exceeded max rounds: stream one last answer without tools.
	history = append(history, Message{
		Role:    "user",
		Content: fmt.Sprintf("Exceed max rounds: %d", maxRounds),
	})
	cfg := *chatCfg
	cfg.Stream = boolPtr(true)
	var exceedUsage ChatUsage
	cfg.UsageResult = &exceedUsage
	var exceedTokens int
	err := cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, func(delta *string, reason *string) error {
		if delta == nil || *delta == "" || *delta == doneMarker {
			return nil
		}
		exceedTokens += tokenizer.NumTokensFromString(*delta)
		// sender is the only way output leaves this function, so the
		// final answer has to be forwarded like any other round.
		return sender(delta, nil)
	})
	commitRound(&cfg, exceedTokens)
	if err != nil {
		return totalTokens, fmt.Errorf("final call: %w", err)
	}
	return totalTokens, nil
}
// appendToolResults appends to history one assistant message carrying
// toolCalls, followed by one "tool" message per call in the same order as
// toolCalls, and returns the extended slice.
//
// With a non-nil session the calls are executed concurrently, one goroutine
// per call, through session.ToolCall; the function returns only after all
// of them have finished. Each call's "function" object supplies the tool
// name and a JSON-encoded "arguments" string; arguments that do not decode
// into an object are passed on as {"raw_arguments": <original string>}.
//
// Every tool message carries the call's "id" as ToolCallID. Its content is
// the tool's result, or an "Error: ..." string when the tool fails, when
// the call has no "function" object, or when session is nil.
func appendToolResults(history []Message, toolCalls []map[string]interface{}, session ToolCallSession) []Message {
	type toolResult struct {
		tcID    string
		content string
	}

	// execute runs a single tool call and never panics on a malformed map.
	execute := func(call map[string]interface{}) toolResult {
		var res toolResult
		// Read the id first so even a malformed call gets a tool message
		// that can be matched to its tool_call.
		res.tcID, _ = call["id"].(string)

		fn, ok := call["function"].(map[string]interface{})
		if !ok {
			res.content = "Error: malformed tool call: missing function"
			return res
		}
		name, _ := fn["name"].(string)
		argsStr, _ := fn["arguments"].(string)

		var args map[string]interface{}
		if err := json.Unmarshal([]byte(argsStr), &args); err != nil {
			// Hand the undecodable payload to the tool unchanged.
			args = map[string]interface{}{"raw_arguments": argsStr}
		}

		out, err := session.ToolCall(name, args)
		if err != nil {
			res.content = fmt.Sprintf("Error: %s", err.Error())
			return res
		}
		res.content = out
		return res
	}

	results := make([]toolResult, len(toolCalls))
	if session == nil {
		for i, call := range toolCalls {
			id, _ := call["id"].(string)
			results[i] = toolResult{tcID: id, content: "Error: no tool session configured"}
		}
	} else {
		// Each goroutine writes only its own slot and wg.Wait orders those
		// writes before the reads below, so no mutex is needed.
		var wg sync.WaitGroup
		for i, call := range toolCalls {
			wg.Add(1)
			go func(idx int, tcMap map[string]interface{}) {
				defer wg.Done()
				results[idx] = execute(tcMap)
			}(i, call)
		}
		wg.Wait()
	}

	// The assistant turn announcing the calls must precede their results.
	history = append(history, Message{
		Role:      "assistant",
		Content:   nil,
		ToolCalls: toolCalls,
	})
	for _, r := range results {
		history = append(history, Message{
			Role:       "tool",
			Content:    r.content,
			ToolCallID: r.tcID,
		})
	}
	return history
}
// boolPtr returns a pointer to a newly allocated bool holding the value b.
// It exists for optional *bool fields, where a literal cannot be addressed
// directly.
func boolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}