mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve? - Tools management - Pregel engine wrapper for better usage - UT race - Coding style ### Type of change - [x] Refactoring
461 lines
14 KiB
Go
461 lines
14 KiB
Go
package core
|
|
|
|
import (
|
|
"context"
|
|
"crypto/md5"
|
|
"encoding/json"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"ragflow/internal/harness/core/schema"
|
|
)
|
|
|
|
// ToolsNodeConfig configures the tools node for a ReActAgent.
|
|
type ToolsNodeConfig struct {
|
|
// Tools is the list of tools available for execution.
|
|
Tools []Tool
|
|
|
|
// Registry provides centralized tool management with aliases, categories,
|
|
// and filtering. When set, tools are loaded from the registry first,
|
|
// then any tools in the Tools slice are added on top.
|
|
Registry *ToolRegistry
|
|
|
|
// ReturnDirectly specifies tool names that cause the agent to return immediately.
|
|
ReturnDirectly map[string]bool
|
|
|
|
// ToolInvokeMiddlewares are middleware wrappers using ToolInvocationContext.
|
|
// Applied before tool execution in a chain (outermost first).
|
|
ToolInvokeMiddlewares []ToolInvokeMiddleware
|
|
|
|
// EmitInternalEvents enables forwarding internal events from AgentTool children.
|
|
EmitInternalEvents bool
|
|
|
|
// LoopGuard prevents infinite loops by detecting repeated tool calls
|
|
// with identical arguments or consecutive failures. If nil, no guard is applied.
|
|
LoopGuard *LoopGuard
|
|
|
|
// UnknownToolHandler handles tool calls for tools that are not registered.
|
|
// If nil, an error message is returned to the model when a tool is not found.
|
|
// The function receives the tool name and arguments JSON string.
|
|
UnknownToolHandler func(ctx context.Context, name, arguments string) (string, error)
|
|
|
|
// ArgumentsAliases maps tool names to their argument field aliases.
|
|
// Key = canonical tool name, value = map[canonicalArgumentKey][]alias.
|
|
// When a tool call's JSON contains an alias key instead of the canonical key,
|
|
// it is remapped before execution.
|
|
// Example: {"get_weather": {"query": ["q", "search_term"]}}
|
|
ArgumentsAliases map[string]map[string][]string
|
|
}
|
|
|
|
// ToolsNode handles tool extraction from model output, dispatching to tools,
|
|
// collecting results, and applying middleware chains.
|
|
type ToolsNode[M MessageType] struct {
|
|
config *ToolsNodeConfig
|
|
toolMap map[string]Tool
|
|
}
|
|
|
|
// NewToolsNode creates a new ToolsNode with the given configuration.
|
|
// Tools are loaded from the Registry first (if set), then any Tools slice
|
|
// entries are added on top (taking precedence on name conflict).
|
|
func NewToolsNode[M MessageType](cfg *ToolsNodeConfig) *ToolsNode[M] {
|
|
tn := &ToolsNode[M]{config: cfg}
|
|
capacity := len(cfg.Tools)
|
|
if cfg.Registry != nil {
|
|
capacity = max(capacity, len(cfg.Registry.tools))
|
|
}
|
|
tn.toolMap = make(map[string]Tool, capacity)
|
|
if cfg.Registry != nil {
|
|
for _, t := range cfg.Registry.AllTools() {
|
|
tn.toolMap[t.Name()] = t
|
|
}
|
|
}
|
|
for _, t := range cfg.Tools {
|
|
tn.toolMap[t.Name()] = t
|
|
}
|
|
return tn
|
|
}
|
|
|
|
// Execute processes all tool calls found in the model response.
|
|
// It returns the list of tool result messages to append to state,
|
|
// and any agent action (e.g., return-directly) that should be handled.
|
|
//
|
|
// When multiple independent tool calls are present, Execute runs them concurrently
|
|
// using a bounded goroutine pool (default max concurrency = 10), reducing total
|
|
// latency from O(sum) to O(max). For a single tool call, no goroutine is spawned.
|
|
func (tn *ToolsNode[M]) Execute(ctx context.Context, resp M, state *TypedReActAgentState[M], _ interface{}) ([]M, *AgentAction, error) {
|
|
toolCalls := extractToolCalls(resp)
|
|
if len(toolCalls) == 0 {
|
|
return nil, nil, nil
|
|
}
|
|
|
|
if len(toolCalls) == 1 {
|
|
// Fast path: single tool call, no goroutine overhead.
|
|
tc := toolCalls[0]
|
|
var action *AgentAction
|
|
if tn.config.ReturnDirectly != nil && tn.config.ReturnDirectly[tc.Function.Name] {
|
|
action = NewExitAction()
|
|
}
|
|
toolMsg, err := tn.executeOne(ctx, tc)
|
|
if err != nil {
|
|
return nil, action, fmt.Errorf("tool '%s': %w", tc.Function.Name, err)
|
|
}
|
|
return []M{toolMsg}, action, nil
|
|
}
|
|
|
|
// Multi-tool path: plan execution batches by capability, then execute.
|
|
batches := tn.planBatches(toolCalls)
|
|
var action *AgentAction
|
|
var mu sync.Mutex
|
|
var firstErr error
|
|
var results []M
|
|
|
|
for _, batch := range batches {
|
|
if batch.mode == batchParallel {
|
|
// Execute parallel-safe tools concurrently.
|
|
const maxConcurrency = 10
|
|
sem := make(chan struct{}, maxConcurrency)
|
|
parResults := make([]M, len(batch.calls))
|
|
var wg sync.WaitGroup
|
|
|
|
for i, tc := range batch.calls {
|
|
if tn.config.ReturnDirectly != nil && tn.config.ReturnDirectly[tc.Function.Name] {
|
|
mu.Lock()
|
|
action = NewExitAction()
|
|
mu.Unlock()
|
|
}
|
|
wg.Add(1)
|
|
go func(idx int, call schema.ToolCall) {
|
|
defer wg.Done()
|
|
sem <- struct{}{}
|
|
defer func() { <-sem }()
|
|
msg, err := tn.executeOne(ctx, call)
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
if err != nil && firstErr == nil {
|
|
firstErr = fmt.Errorf("tool '%s': %w", call.Function.Name, err)
|
|
return
|
|
}
|
|
parResults[idx] = msg
|
|
}(i, tc)
|
|
}
|
|
wg.Wait()
|
|
for _, r := range parResults {
|
|
if !isNilMessage(r) {
|
|
results = append(results, r)
|
|
}
|
|
}
|
|
} else {
|
|
// Execute serial tools one by one.
|
|
for _, tc := range batch.calls {
|
|
if tn.config.ReturnDirectly != nil && tn.config.ReturnDirectly[tc.Function.Name] {
|
|
action = NewExitAction()
|
|
}
|
|
msg, err := tn.executeOne(ctx, tc)
|
|
if err != nil && firstErr == nil {
|
|
firstErr = fmt.Errorf("tool '%s': %w", tc.Function.Name, err)
|
|
}
|
|
if !isNilMessage(msg) {
|
|
results = append(results, msg)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if firstErr != nil {
|
|
return nil, action, firstErr
|
|
}
|
|
return results, action, nil
|
|
}
|
|
|
|
func (tn *ToolsNode[M]) executeOne(ctx context.Context, tc schema.ToolCall) (msg M, err error) {
|
|
// Panic recovery: tool.Invoke may panic, catch and convert to tool result message.
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
msg = tn.makeToolMsg(fmt.Sprintf("Error: tool '%s' panicked: %v", tc.Function.Name, r), tc.ID)
|
|
err = nil // do not propagate Go error; captured as tool result text
|
|
}
|
|
}()
|
|
|
|
// LoopGuard: detect repeated calls with identical arguments.
|
|
if lg := tn.getLoopGuard(); lg != nil {
|
|
if err := lg.CheckSameArgs(tc.Function.Name, tc.Function.Arguments); err != nil {
|
|
return tn.makeToolMsg(fmt.Sprintf("Error: %v", err), tc.ID), nil
|
|
}
|
|
}
|
|
|
|
tool, ok := tn.toolMap[tc.Function.Name]
|
|
if !ok {
|
|
if tn.config.UnknownToolHandler != nil {
|
|
result, err := tn.config.UnknownToolHandler(ctx, tc.Function.Name, tc.Function.Arguments)
|
|
if err != nil {
|
|
return tn.makeToolMsg(fmt.Sprintf("Error: %v", err), tc.ID), nil
|
|
}
|
|
return tn.makeToolMsg(result, tc.ID), nil
|
|
}
|
|
errMsg := fmt.Sprintf("tool '%s' not found", tc.Function.Name)
|
|
return tn.makeToolMsg(errMsg, tc.ID), nil
|
|
}
|
|
|
|
return tn.executeWithNewChain(ctx, tc, tool)
|
|
}
|
|
|
|
func (tn *ToolsNode[M]) executeWithNewChain(ctx context.Context, tc schema.ToolCall, tool Tool) (M, error) {
|
|
// Remap argument aliases if configured.
|
|
argsJSON := tc.Function.Arguments
|
|
if aliases, ok := tn.config.ArgumentsAliases[tc.Function.Name]; ok && len(aliases) > 0 {
|
|
argsJSON = remapToolArgs(tc.Function.Arguments, aliases)
|
|
}
|
|
|
|
args := &schema.ToolArgument{
|
|
Name: tc.Function.Name,
|
|
Arguments: argsJSON,
|
|
CallID: tc.ID,
|
|
}
|
|
|
|
ictx := &ToolInvocationContext{
|
|
Name: tc.Function.Name,
|
|
CallID: tc.ID,
|
|
Arguments: args,
|
|
}
|
|
|
|
var invokeFn InvokeTool
|
|
if et, ok := tool.(EnhancedTool); ok {
|
|
invokeFn = EnhancedToolToInvokeFn(et)
|
|
} else {
|
|
invokeFn = ToolToInvokeFn(tool)
|
|
}
|
|
|
|
chained := ToolWrapperChain(invokeFn, tn.config.ToolInvokeMiddlewares...)
|
|
result, err := chained(ctx, ictx)
|
|
if err != nil {
|
|
// Detect tool-level interrupt: save state to context for resume.
|
|
if tie, ok := IsToolInterrupt(err); ok {
|
|
ctx = setToolInterruptState(ctx, tie)
|
|
ctx = AppendAddressSegment(ctx, AddressSegmentTool, tc.ID)
|
|
addr := getAddressSegments(ctx)
|
|
addrCopy := make(Address, len(addr))
|
|
copy(addrCopy, addr)
|
|
return tn.makeToolMsg(fmt.Sprintf("[interrupted: %v]", tie.Info), tc.ID),
|
|
&interruptResult{tie: tie, toolAddress: addrCopy}
|
|
}
|
|
return tn.makeToolMsg(fmt.Sprintf("Error: %v", err), tc.ID), nil
|
|
}
|
|
|
|
content := result.Content
|
|
if result.Error != "" && content == "" {
|
|
content = fmt.Sprintf("Error: %s", result.Error)
|
|
}
|
|
return tn.makeToolMsg(content, tc.ID), nil
|
|
}
|
|
|
|
// interruptResult wraps a tool interrupt for propagation up the call chain.
|
|
type interruptResult struct {
|
|
tie *ToolInterruptError
|
|
toolAddress Address // address segments at time of interrupt; preserved for resume routing
|
|
}
|
|
|
|
func (e *interruptResult) Error() string { return fmt.Sprintf("interrupt: %v", e.tie.Info) }
|
|
|
|
// remapToolArgs renames alias keys in a JSON object of tool arguments to
// their canonical keys.
//
// aliases maps each canonical key to the alias names accepted for it. An
// alias is moved only when the canonical key is not already present; if it
// is, the alias entry is left untouched. The input is returned unchanged when
// there is nothing to remap, when it is empty, or when it does not decode as
// a JSON object. When a rename happens the object is re-encoded, so key order
// and whitespace may differ from the input.
func remapToolArgs(argsJSON string, aliases map[string][]string) string {
	if argsJSON == "" || len(aliases) == 0 {
		return argsJSON
	}

	var fields map[string]json.RawMessage
	if json.Unmarshal([]byte(argsJSON), &fields) != nil {
		return argsJSON
	}

	renamed := false
	for target, names := range aliases {
		for _, name := range names {
			value, present := fields[name]
			if !present {
				continue
			}
			if _, taken := fields[target]; taken {
				// The canonical key wins; keep the alias entry as it is.
				continue
			}
			fields[target] = value
			delete(fields, name)
			renamed = true
		}
	}
	if !renamed {
		return argsJSON
	}

	// The error is ignored on purpose: every value is a RawMessage that was
	// just produced by a successful Unmarshal, so re-encoding cannot fail.
	encoded, _ := json.Marshal(fields)
	return string(encoded)
}
|
|
|
|
func (tn *ToolsNode[M]) makeToolMsg(content, callID string) M {
|
|
var zero M
|
|
switch any(zero).(type) {
|
|
case *schema.AgenticMessage:
|
|
return any(&schema.AgenticMessage{
|
|
Role: schema.AgenticRoleUser,
|
|
Content: content,
|
|
ContentBlocks: []schema.ContentBlock{
|
|
{Type: "tool_result", ToolResult: &schema.ToolResult{
|
|
ToolCallID: callID, Content: content,
|
|
}},
|
|
},
|
|
}).(M)
|
|
default:
|
|
return any(schema.ToolMessage(content, callID)).(M)
|
|
}
|
|
}
|
|
|
|
// ---- Helper: convert tool results for event emission ----
|
|
|
|
func toolResultToEvent[M MessageType](msg M, roleName string) *TypedAgentEvent[M] {
|
|
if m, ok := any(msg).(*schema.Message); ok {
|
|
return any(typedEventFromMessage(m, nil, schema.RoleTool, roleName)).(*TypedAgentEvent[M])
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ---- JSON helpers ----
|
|
|
|
// parseToolArgs decodes the JSON text argsJSON into target, which must be a
// pointer accepted by json.Unmarshal. A decoding failure is returned wrapped
// with an "invalid tool arguments JSON" prefix.
func parseToolArgs(argsJSON string, target any) error {
	err := json.Unmarshal([]byte(argsJSON), target)
	if err == nil {
		return nil
	}
	return fmt.Errorf("invalid tool arguments JSON: %w", err)
}
|
|
|
|
// ---- LoopGuard: detect repeated tool calls with same args ----
|
|
|
|
// LoopGuard stops an agent from looping on a tool: it counts how often a
// tool is called with byte-identical arguments and how often a tool is
// reported as failed. Counters accumulate until Reset is called for the tool.
// All methods are safe on a nil receiver and serialize access through mu.
type LoopGuard struct {
	mu       sync.Mutex
	sameArgs map[string]int // key = toolName + "|" + hex MD5 of the arguments JSON
	failures map[string]int // key = toolName
	maxSame  int            // identical-call threshold; <= 0 disables the check
	maxFails int            // failure threshold; <= 0 disables the check
}

// NewLoopGuard returns a LoopGuard with the given thresholds. A threshold
// that is zero or negative disables the corresponding check.
func NewLoopGuard(maxSame, maxFails int) *LoopGuard {
	return &LoopGuard{
		sameArgs: make(map[string]int),
		failures: make(map[string]int),
		maxSame:  maxSame,
		maxFails: maxFails,
	}
}

// CheckSameArgs records one call of toolName with argsJSON and returns an
// error once that exact tool/arguments pair has been seen maxSame times or
// more (the call that reaches the threshold is itself rejected). Arguments
// are compared by the MD5 digest of their raw text, so differently formatted
// but equivalent JSON counts as different arguments.
func (g *LoopGuard) CheckSameArgs(toolName, argsJSON string) error {
	if g == nil || g.maxSame <= 0 {
		return nil
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	hash := fmt.Sprintf("%s|%x", toolName, md5.Sum([]byte(argsJSON)))
	g.sameArgs[hash]++
	if g.sameArgs[hash] >= g.maxSame {
		return fmt.Errorf("loop guard: tool '%s' called %d times with identical arguments", toolName, g.sameArgs[hash])
	}
	return nil
}

// RecordFailure records one failure of toolName and returns an error once
// the tool has accumulated maxFails failures or more since its last Reset.
func (g *LoopGuard) RecordFailure(toolName string) error {
	if g == nil || g.maxFails <= 0 {
		return nil
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	g.failures[toolName]++
	cnt := g.failures[toolName]
	if cnt >= g.maxFails {
		return fmt.Errorf("loop guard: tool '%s' failed %d consecutive times", toolName, cnt)
	}
	return nil
}

// Reset clears every counter kept for toolName: all of its identical-argument
// entries and its failure count. Counters of other tools are left alone, even
// when their name merely starts with toolName.
func (g *LoopGuard) Reset(toolName string) {
	if g == nil {
		return
	}
	g.mu.Lock()
	defer g.mu.Unlock()

	// A key belonging to toolName is exactly toolName + "|" + 32 hex digits.
	// Checking the separator and the total length, not just the name prefix,
	// keeps Reset("get") from wiping the counters of "get_weather".
	const digestLen = 2 * md5.Size
	nameLen := len(toolName)
	for k := range g.sameArgs {
		if len(k) == nameLen+1+digestLen && k[:nameLen] == toolName && k[nameLen] == '|' {
			delete(g.sameArgs, k)
		}
	}
	delete(g.failures, toolName)
}
|
|
|
|
// ---- Tool capability and batch planning ----
|
|
|
|
// toolCapFromTool returns the capability of a tool.
|
|
func toolCapFromTool(t Tool) ToolCapability {
|
|
if ct, ok := t.(CapableTool); ok {
|
|
return ct.Capability()
|
|
}
|
|
return ToolCapWritesFiles // default: conservative serial
|
|
}
|
|
|
|
// executionBatch represents a group of tool calls to execute together.
|
|
type executionBatch struct {
|
|
mode batchMode
|
|
calls []schema.ToolCall
|
|
}
|
|
|
|
// batchMode says how the calls of an executionBatch are run.
type batchMode int

const (
	// batchParallel runs the batch's calls concurrently.
	batchParallel batchMode = 0
	// batchSerial runs the batch's calls one after another.
	batchSerial batchMode = 1
)
|
|
|
|
// planBatches groups tool calls into parallel/serial batches based on capability.
|
|
// Read-only tools are grouped for parallel execution; others run serially.
|
|
func (tn *ToolsNode[M]) planBatches(tcs []schema.ToolCall) []executionBatch {
|
|
var batches []executionBatch
|
|
var currentParallel []schema.ToolCall
|
|
|
|
flushParallel := func() {
|
|
if len(currentParallel) > 0 {
|
|
batches = append(batches, executionBatch{mode: batchParallel, calls: currentParallel})
|
|
currentParallel = nil
|
|
}
|
|
}
|
|
|
|
for _, tc := range tcs {
|
|
tool, ok := tn.toolMap[tc.Function.Name]
|
|
if !ok {
|
|
// Unknown tool - treat as serial to be safe.
|
|
flushParallel()
|
|
batches = append(batches, executionBatch{mode: batchSerial, calls: []schema.ToolCall{tc}})
|
|
continue
|
|
}
|
|
cap := toolCapFromTool(tool)
|
|
if cap == ToolCapReadOnly {
|
|
currentParallel = append(currentParallel, tc)
|
|
} else {
|
|
flushParallel()
|
|
batches = append(batches, executionBatch{mode: batchSerial, calls: []schema.ToolCall{tc}})
|
|
}
|
|
}
|
|
flushParallel()
|
|
return batches
|
|
}
|
|
|
|
// ---- LoopGuard integration in executeOne ----
|
|
|
|
// getLoopGuard returns the LoopGuard from the ToolsNode if configured.
|
|
// It is stored on the ToolsNode to share state across invocation cycles.
|
|
func (tn *ToolsNode[M]) getLoopGuard() *LoopGuard {
|
|
if tn.config != nil {
|
|
return tn.config.LoopGuard
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// clearLoopGuard writes the LoopGuard config back (no-op, config is shared by pointer).
|
|
func (tn *ToolsNode[M]) clearLoopGuard() {}
|