Files
ragflow/internal/harness/core/tools_node.go
Yingfeng 706e0d2d06 Refactor harness framework (#16271)
### What problem does this PR solve?

- Tools management
- Pregel engine wrapper for better usage
- UT race
- Coding style

### Type of change

- [x] Refactoring
2026-06-23 20:18:04 +08:00

461 lines
14 KiB
Go

package core
import (
"context"
"crypto/md5"
"encoding/json"
"fmt"
"sync"
"ragflow/internal/harness/core/schema"
)
// ToolsNodeConfig configures the tools node for a ReActAgent.
type ToolsNodeConfig struct {
// Tools is the list of tools available for execution.
Tools []Tool
// Registry provides centralized tool management with aliases, categories,
// and filtering. When set, tools are loaded from the registry first,
// then any tools in the Tools slice are added on top.
Registry *ToolRegistry
// ReturnDirectly specifies tool names that cause the agent to return immediately.
ReturnDirectly map[string]bool
// ToolInvokeMiddlewares are middleware wrappers using ToolInvocationContext.
// Applied before tool execution in a chain (outermost first).
ToolInvokeMiddlewares []ToolInvokeMiddleware
// EmitInternalEvents enables forwarding internal events from AgentTool children.
EmitInternalEvents bool
// LoopGuard prevents infinite loops by detecting repeated tool calls
// with identical arguments or consecutive failures. If nil, no guard is applied.
LoopGuard *LoopGuard
// UnknownToolHandler handles tool calls for tools that are not registered.
// If nil, an error message is returned to the model when a tool is not found.
// The function receives the tool name and arguments JSON string.
UnknownToolHandler func(ctx context.Context, name, arguments string) (string, error)
// ArgumentsAliases maps tool names to their argument field aliases.
// Key = canonical tool name, value = map[canonicalArgumentKey][]alias.
// When a tool call's JSON contains an alias key instead of the canonical key,
// it is remapped before execution.
// Example: {"get_weather": {"query": ["q", "search_term"]}}
ArgumentsAliases map[string]map[string][]string
}
// ToolsNode handles tool extraction from model output, dispatching to tools,
// collecting results, and applying middleware chains.
type ToolsNode[M MessageType] struct {
config *ToolsNodeConfig
toolMap map[string]Tool
}
// NewToolsNode creates a new ToolsNode with the given configuration.
// Tools are loaded from the Registry first (if set), then any Tools slice
// entries are added on top (taking precedence on name conflict).
func NewToolsNode[M MessageType](cfg *ToolsNodeConfig) *ToolsNode[M] {
tn := &ToolsNode[M]{config: cfg}
capacity := len(cfg.Tools)
if cfg.Registry != nil {
capacity = max(capacity, len(cfg.Registry.tools))
}
tn.toolMap = make(map[string]Tool, capacity)
if cfg.Registry != nil {
for _, t := range cfg.Registry.AllTools() {
tn.toolMap[t.Name()] = t
}
}
for _, t := range cfg.Tools {
tn.toolMap[t.Name()] = t
}
return tn
}
// Execute processes all tool calls found in the model response.
// It returns the list of tool result messages to append to state,
// and any agent action (e.g., return-directly) that should be handled.
//
// When multiple independent tool calls are present, Execute runs them concurrently
// using a bounded goroutine pool (default max concurrency = 10), reducing total
// latency from O(sum) to O(max). For a single tool call, no goroutine is spawned.
func (tn *ToolsNode[M]) Execute(ctx context.Context, resp M, state *TypedReActAgentState[M], _ interface{}) ([]M, *AgentAction, error) {
toolCalls := extractToolCalls(resp)
if len(toolCalls) == 0 {
return nil, nil, nil
}
if len(toolCalls) == 1 {
// Fast path: single tool call, no goroutine overhead.
tc := toolCalls[0]
var action *AgentAction
if tn.config.ReturnDirectly != nil && tn.config.ReturnDirectly[tc.Function.Name] {
action = NewExitAction()
}
toolMsg, err := tn.executeOne(ctx, tc)
if err != nil {
return nil, action, fmt.Errorf("tool '%s': %w", tc.Function.Name, err)
}
return []M{toolMsg}, action, nil
}
// Multi-tool path: plan execution batches by capability, then execute.
batches := tn.planBatches(toolCalls)
var action *AgentAction
var mu sync.Mutex
var firstErr error
var results []M
for _, batch := range batches {
if batch.mode == batchParallel {
// Execute parallel-safe tools concurrently.
const maxConcurrency = 10
sem := make(chan struct{}, maxConcurrency)
parResults := make([]M, len(batch.calls))
var wg sync.WaitGroup
for i, tc := range batch.calls {
if tn.config.ReturnDirectly != nil && tn.config.ReturnDirectly[tc.Function.Name] {
mu.Lock()
action = NewExitAction()
mu.Unlock()
}
wg.Add(1)
go func(idx int, call schema.ToolCall) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
msg, err := tn.executeOne(ctx, call)
mu.Lock()
defer mu.Unlock()
if err != nil && firstErr == nil {
firstErr = fmt.Errorf("tool '%s': %w", call.Function.Name, err)
return
}
parResults[idx] = msg
}(i, tc)
}
wg.Wait()
for _, r := range parResults {
if !isNilMessage(r) {
results = append(results, r)
}
}
} else {
// Execute serial tools one by one.
for _, tc := range batch.calls {
if tn.config.ReturnDirectly != nil && tn.config.ReturnDirectly[tc.Function.Name] {
action = NewExitAction()
}
msg, err := tn.executeOne(ctx, tc)
if err != nil && firstErr == nil {
firstErr = fmt.Errorf("tool '%s': %w", tc.Function.Name, err)
}
if !isNilMessage(msg) {
results = append(results, msg)
}
}
}
}
if firstErr != nil {
return nil, action, firstErr
}
return results, action, nil
}
func (tn *ToolsNode[M]) executeOne(ctx context.Context, tc schema.ToolCall) (msg M, err error) {
// Panic recovery: tool.Invoke may panic, catch and convert to tool result message.
defer func() {
if r := recover(); r != nil {
msg = tn.makeToolMsg(fmt.Sprintf("Error: tool '%s' panicked: %v", tc.Function.Name, r), tc.ID)
err = nil // do not propagate Go error; captured as tool result text
}
}()
// LoopGuard: detect repeated calls with identical arguments.
if lg := tn.getLoopGuard(); lg != nil {
if err := lg.CheckSameArgs(tc.Function.Name, tc.Function.Arguments); err != nil {
return tn.makeToolMsg(fmt.Sprintf("Error: %v", err), tc.ID), nil
}
}
tool, ok := tn.toolMap[tc.Function.Name]
if !ok {
if tn.config.UnknownToolHandler != nil {
result, err := tn.config.UnknownToolHandler(ctx, tc.Function.Name, tc.Function.Arguments)
if err != nil {
return tn.makeToolMsg(fmt.Sprintf("Error: %v", err), tc.ID), nil
}
return tn.makeToolMsg(result, tc.ID), nil
}
errMsg := fmt.Sprintf("tool '%s' not found", tc.Function.Name)
return tn.makeToolMsg(errMsg, tc.ID), nil
}
return tn.executeWithNewChain(ctx, tc, tool)
}
func (tn *ToolsNode[M]) executeWithNewChain(ctx context.Context, tc schema.ToolCall, tool Tool) (M, error) {
// Remap argument aliases if configured.
argsJSON := tc.Function.Arguments
if aliases, ok := tn.config.ArgumentsAliases[tc.Function.Name]; ok && len(aliases) > 0 {
argsJSON = remapToolArgs(tc.Function.Arguments, aliases)
}
args := &schema.ToolArgument{
Name: tc.Function.Name,
Arguments: argsJSON,
CallID: tc.ID,
}
ictx := &ToolInvocationContext{
Name: tc.Function.Name,
CallID: tc.ID,
Arguments: args,
}
var invokeFn InvokeTool
if et, ok := tool.(EnhancedTool); ok {
invokeFn = EnhancedToolToInvokeFn(et)
} else {
invokeFn = ToolToInvokeFn(tool)
}
chained := ToolWrapperChain(invokeFn, tn.config.ToolInvokeMiddlewares...)
result, err := chained(ctx, ictx)
if err != nil {
// Detect tool-level interrupt: save state to context for resume.
if tie, ok := IsToolInterrupt(err); ok {
ctx = setToolInterruptState(ctx, tie)
ctx = AppendAddressSegment(ctx, AddressSegmentTool, tc.ID)
addr := getAddressSegments(ctx)
addrCopy := make(Address, len(addr))
copy(addrCopy, addr)
return tn.makeToolMsg(fmt.Sprintf("[interrupted: %v]", tie.Info), tc.ID),
&interruptResult{tie: tie, toolAddress: addrCopy}
}
return tn.makeToolMsg(fmt.Sprintf("Error: %v", err), tc.ID), nil
}
content := result.Content
if result.Error != "" && content == "" {
content = fmt.Sprintf("Error: %s", result.Error)
}
return tn.makeToolMsg(content, tc.ID), nil
}
// interruptResult wraps a tool interrupt for propagation up the call chain.
type interruptResult struct {
tie *ToolInterruptError
toolAddress Address // address segments at time of interrupt; preserved for resume routing
}
func (e *interruptResult) Error() string { return fmt.Sprintf("interrupt: %v", e.tie.Info) }
// remapToolArgs replaces alias keys in JSON arguments with canonical keys.
func remapToolArgs(argsJSON string, aliases map[string][]string) string {
if len(aliases) == 0 || argsJSON == "" {
return argsJSON
}
var raw map[string]json.RawMessage
if err := json.Unmarshal([]byte(argsJSON), &raw); err != nil {
return argsJSON
}
changed := false
for canonical, aliasList := range aliases {
for _, alias := range aliasList {
if v, ok := raw[alias]; ok {
if _, exists := raw[canonical]; !exists {
raw[canonical] = v
delete(raw, alias)
changed = true
}
}
}
}
if !changed {
return argsJSON
}
b, _ := json.Marshal(raw)
return string(b)
}
func (tn *ToolsNode[M]) makeToolMsg(content, callID string) M {
var zero M
switch any(zero).(type) {
case *schema.AgenticMessage:
return any(&schema.AgenticMessage{
Role: schema.AgenticRoleUser,
Content: content,
ContentBlocks: []schema.ContentBlock{
{Type: "tool_result", ToolResult: &schema.ToolResult{
ToolCallID: callID, Content: content,
}},
},
}).(M)
default:
return any(schema.ToolMessage(content, callID)).(M)
}
}
// ---- Helper: convert tool results for event emission ----
func toolResultToEvent[M MessageType](msg M, roleName string) *TypedAgentEvent[M] {
if m, ok := any(msg).(*schema.Message); ok {
return any(typedEventFromMessage(m, nil, schema.RoleTool, roleName)).(*TypedAgentEvent[M])
}
return nil
}
// ---- JSON helpers ----
func parseToolArgs(argsJSON string, target any) error {
if err := json.Unmarshal([]byte(argsJSON), target); err != nil {
return fmt.Errorf("invalid tool arguments JSON: %w", err)
}
return nil
}
// ---- LoopGuard: detect repeated tool calls with same args ----
// LoopGuard prevents infinite loops where the model repeatedly calls a tool
// with identical parameters. It tracks consecutive same-args calls per tool.
type LoopGuard struct {
mu sync.Mutex
sameArgs map[string]int // key = toolName+"|"+argsHash
failures map[string]int // key = toolName
maxSame int
maxFails int
}
// NewLoopGuard creates a LoopGuard with the given thresholds.
func NewLoopGuard(maxSame, maxFails int) *LoopGuard {
return &LoopGuard{
sameArgs: make(map[string]int),
failures: make(map[string]int),
maxSame: maxSame,
maxFails: maxFails,
}
}
// CheckSameArgs returns an error if the same tool+args pair is called too many times.
func (g *LoopGuard) CheckSameArgs(toolName, argsJSON string) error {
if g == nil || g.maxSame <= 0 {
return nil
}
g.mu.Lock()
defer g.mu.Unlock()
hash := fmt.Sprintf("%s|%x", toolName, md5.Sum([]byte(argsJSON)))
g.sameArgs[hash]++
if g.sameArgs[hash] >= g.maxSame {
return fmt.Errorf("loop guard: tool '%s' called %d times with identical arguments", toolName, g.sameArgs[hash])
}
return nil
}
// RecordFailure tracks consecutive failures for a tool.
// Returns an error if the failure limit is exceeded.
func (g *LoopGuard) RecordFailure(toolName string) error {
if g == nil || g.maxFails <= 0 {
return nil
}
g.mu.Lock()
defer g.mu.Unlock()
g.failures[toolName]++
cnt := g.failures[toolName]
if cnt >= g.maxFails {
return fmt.Errorf("loop guard: tool '%s' failed %d consecutive times", toolName, cnt)
}
return nil
}
// Reset clears tracking for a tool (called on success or different args).
func (g *LoopGuard) Reset(toolName string) {
if g == nil {
return
}
g.mu.Lock()
defer g.mu.Unlock()
// Remove all same-args entries for this tool
for k := range g.sameArgs {
if len(k) > len(toolName) && k[:len(toolName)] == toolName {
delete(g.sameArgs, k)
}
}
delete(g.failures, toolName)
}
// ---- Tool capability and batch planning ----
// toolCapFromTool returns the capability of a tool.
func toolCapFromTool(t Tool) ToolCapability {
if ct, ok := t.(CapableTool); ok {
return ct.Capability()
}
return ToolCapWritesFiles // default: conservative serial
}
// executionBatch represents a group of tool calls to execute together.
type executionBatch struct {
mode batchMode
calls []schema.ToolCall
}
type batchMode int
const (
batchParallel batchMode = iota
batchSerial
)
// planBatches groups tool calls into parallel/serial batches based on capability.
// Read-only tools are grouped for parallel execution; others run serially.
func (tn *ToolsNode[M]) planBatches(tcs []schema.ToolCall) []executionBatch {
var batches []executionBatch
var currentParallel []schema.ToolCall
flushParallel := func() {
if len(currentParallel) > 0 {
batches = append(batches, executionBatch{mode: batchParallel, calls: currentParallel})
currentParallel = nil
}
}
for _, tc := range tcs {
tool, ok := tn.toolMap[tc.Function.Name]
if !ok {
// Unknown tool - treat as serial to be safe.
flushParallel()
batches = append(batches, executionBatch{mode: batchSerial, calls: []schema.ToolCall{tc}})
continue
}
cap := toolCapFromTool(tool)
if cap == ToolCapReadOnly {
currentParallel = append(currentParallel, tc)
} else {
flushParallel()
batches = append(batches, executionBatch{mode: batchSerial, calls: []schema.ToolCall{tc}})
}
}
flushParallel()
return batches
}
// ---- LoopGuard integration in executeOne ----
// getLoopGuard returns the LoopGuard from the ToolsNode if configured.
// It is stored on the ToolsNode to share state across invocation cycles.
func (tn *ToolsNode[M]) getLoopGuard() *LoopGuard {
if tn.config != nil {
return tn.config.LoopGuard
}
return nil
}
// clearLoopGuard writes the LoopGuard config back (no-op, config is shared by pointer).
func (tn *ToolsNode[M]) clearLoopGuard() {}