mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 00:05:43 +08:00
### What problem does this PR solve? - Tools management - Pregel engine wrapper for better usage - UT race - Coding style ### Type of change - [x] Refactoring
508 lines
16 KiB
Go
508 lines
16 KiB
Go
// Package core provides a graph-level ReAct loop using the project's own
|
|
// StateGraph engine, with built-in checkpoint/interrupt/resume support at each
|
|
// iteration boundary.
|
|
//
|
|
// The ReActGraph wraps a TypedChatModelAgent's loop into StateGraph nodes so that
|
|
// the graph engine's checkpointing (via graph.WithCheckpointer) and interrupt/resume
|
|
// (via graph.WithInterrupts) apply at each superstep automatically. This replaces
|
|
// the simple for-loop in chatmodel_react.go with the full Pregel execution engine.
|
|
//
|
|
// Key features:
|
|
// - Checkpoint at every model_generate and execute_tools node boundary
|
|
// - Interrupt before execute_tools for human-in-the-loop tool approval
|
|
// - Resume from interrupt via graph checkpoint restoration
|
|
// - Full middleware chain (BeforeAgent, BeforeModelRewrite, AfterModelRewrite, AfterAgent)
|
|
// - ToolsNode integration with ToolCallMiddlewares
|
|
// - Streaming events via pregel.StreamManager
|
|
// - Generic: supports both *schema.Message and *schema.AgenticMessage
|
|
package core
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"ragflow/internal/harness/core/schema"
|
|
"ragflow/internal/harness/graph/channels"
|
|
"ragflow/internal/harness/graph/constants"
|
|
"ragflow/internal/harness/graph/graph"
|
|
"ragflow/internal/harness/graph/pregel"
|
|
"ragflow/internal/harness/graph/types"
|
|
)
|
|
|
|
func init() {
|
|
schema.RegisterType("_harness_react_graph_state", func() any { return &ReActGraphState{} })
|
|
}
|
|
|
|
// ReActGraphState is the shared state for the graph-level ReAct loop.
|
|
// It persists across supersteps, enabling checkpoint and interrupt/resume.
|
|
type ReActGraphState struct {
|
|
Messages []*schema.Message
|
|
ToolInfos []*schema.ToolInfo
|
|
IterationsLeft int
|
|
MaxIterations int
|
|
AgentName string
|
|
Instruction string
|
|
HasToolCall bool // signals whether the last model output had tool calls
|
|
|
|
// ToolExecutedCache caches completed tool call results for interrupt/resume.
|
|
// Key = tool call ID, value = result content string.
|
|
// After successful completion of all tools in a superstep, this is cleared.
|
|
// On interrupt, it persists via the Pregel checkpoint and allows skipping
|
|
// already-executed tools on resume (equivalent to Eino's ToolsInterruptAndRerunExtra).
|
|
ToolExecutedCache map[string]string
|
|
}
|
|
|
|
// ReActGraph wraps a ChatModelAgent's loop into a StateGraph with automatic
|
|
// checkpoint at each iteration and interrupt before tool execution.
|
|
type ReActGraph struct {
|
|
compiled *graph.CompiledGraph
|
|
config *ReActConfig[*schema.Message]
|
|
agent *ReActAgent[*schema.Message]
|
|
allInfos []*schema.ToolInfo // merged config + contributor tool infos
|
|
allTools []Tool // merged config + contributor tools
|
|
}
|
|
|
|
// ReActGraphConfig holds options for building a ReActGraph.
|
|
type ReActGraphConfig struct {
|
|
Checkpointer graph.Checkpointer
|
|
InterruptBefore []string // node names to interrupt before (default: "execute_tools")
|
|
RecursionLimit int
|
|
}
|
|
|
|
// NewReActGraph builds a StateGraph with nodes:
|
|
//
|
|
// prepare_input → model_generate → execute_tools → check_done
|
|
// ↘ [end]
|
|
//
|
|
// Interrupt is set at "execute_tools" by default. With a Checkpointer, each node
|
|
// transition automatically saves a checkpoint via the Pregel engine.
|
|
//
|
|
// The graph applies the full middleware chain:
|
|
// - prepare_input: BeforeAgent
|
|
// - model_generate: BeforeModelRewrite → model call → AfterModelRewrite
|
|
// - check_done (on exit): AfterAgent
|
|
func NewReActGraph(agent *ReActAgent[*schema.Message], cfg *ReActGraphConfig, allToolInfos []*schema.ToolInfo) (*ReActGraph, error) {
|
|
if cfg == nil {
|
|
cfg = &ReActGraphConfig{}
|
|
}
|
|
agentCfg := agent.config
|
|
// Fallback: when allToolInfos is nil, derive from the agent's tools
|
|
// so model wrappers always have tool metadata.
|
|
if allToolInfos == nil {
|
|
allToolInfos = toolsToInfosTyped[*schema.Message](agentCfg.Tools)
|
|
if agentCfg.ToolsConfig != nil {
|
|
allToolInfos = append(allToolInfos, toolsToInfosTyped[*schema.Message](agentCfg.ToolsConfig.Tools)...)
|
|
}
|
|
}
|
|
sg := graph.NewStateGraph(&ReActGraphState{})
|
|
|
|
// Register channels for state fields used by the graph engine.
|
|
sg.AddChannel("messages", channels.NewLastValue([]*schema.Message{}))
|
|
sg.AddChannel("iterations_left", channels.NewLastValue(0))
|
|
sg.AddChannel("has_tool_call", channels.NewLastValue(false))
|
|
sg.AddChannel("tool_cache", channels.NewLastValue(map[string]string{}))
|
|
|
|
// --- Node: prepare_input ---
|
|
// Runs once at the start. Applies BeforeAgent middleware.
|
|
// Build allTools and allRD from config + contributor tools.
|
|
allTools := make([]Tool, 0, len(agent.config.Tools))
|
|
allTools = append(allTools, agent.config.Tools...)
|
|
allRD := make(map[string]bool)
|
|
for k, v := range agent.config.ReturnDirectly {
|
|
allRD[k] = v
|
|
}
|
|
if ec := agent.exeCtx; ec != nil {
|
|
allTools = append(allTools, ec.contribTools...)
|
|
for k, v := range ec.contribReturnDirectly {
|
|
allRD[k] = v
|
|
}
|
|
}
|
|
|
|
sg.AddNode("prepare_input", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(*ReActGraphState)
|
|
rc := &ReActAgentContext{
|
|
Instruction: s.Instruction,
|
|
Tools: allTools,
|
|
ReturnDirectly: allRD,
|
|
}
|
|
for _, mw := range agentCfg.Middlewares {
|
|
if mw == nil {
|
|
continue
|
|
}
|
|
var err error
|
|
ctx, rc, err = mw.BeforeAgent(ctx, rc)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("BeforeAgent: %w", err)
|
|
}
|
|
}
|
|
s.Instruction = rc.Instruction
|
|
return s, nil
|
|
})
|
|
|
|
// --- Node: model_generate ---
|
|
// Calls the LLM with the current message history. Applies BeforeModelRewrite
|
|
// and AfterModelRewrite middleware chains.
|
|
// Clears the tool cache so each iteration starts fresh.
|
|
sg.AddNode("model_generate", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(*ReActGraphState)
|
|
if s.IterationsLeft <= 0 {
|
|
return s, nil
|
|
}
|
|
s.IterationsLeft--
|
|
// Clear tool cache at start of each iteration.
|
|
s.ToolExecutedCache = nil
|
|
|
|
model := BuildModelWrapperChain(agentCfg.Model, nil, agentCfg, allToolInfos)
|
|
|
|
agentState := NewReActAgentState(
|
|
messageSliceToAny(s.Messages),
|
|
allToolInfos,
|
|
s.IterationsLeft+1,
|
|
)
|
|
typedState := (*TypedReActAgentState[*schema.Message])(agentState)
|
|
mc := &TypedModelContext[*schema.Message]{
|
|
Tools: allToolInfos,
|
|
ModelRetryConfig: agentCfg.RetryConfig,
|
|
ModelFailoverConfig: agentCfg.FailoverConfig,
|
|
}
|
|
|
|
// BeforeModelRewrite middleware chain.
|
|
for _, mw := range agentCfg.Middlewares {
|
|
if mw == nil {
|
|
continue
|
|
}
|
|
var err error
|
|
ctx, typedState, err = mw.BeforeModelRewrite(ctx, typedState, mc)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("BeforeModelRewrite: %w", err)
|
|
}
|
|
}
|
|
s.Messages = typedState.Messages
|
|
|
|
// StateModifier hook (e.g., context window trimming).
|
|
if agentCfg.StateModifier != nil {
|
|
var err error
|
|
typedState, err = agentCfg.StateModifier(ctx, typedState)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("StateModifier: %w", err)
|
|
}
|
|
s.Messages = typedState.Messages
|
|
}
|
|
|
|
// Build model input (via GenModelInput or default).
|
|
var modelMsgs []*schema.Message
|
|
if agentCfg.GenModelInput != nil {
|
|
var err error
|
|
modelMsgs, err = agentCfg.GenModelInput(ctx, s.Instruction,
|
|
&TypedAgentInput[*schema.Message]{Messages: s.Messages})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GenModelInput: %w", err)
|
|
}
|
|
} else {
|
|
modelMsgs = buildModelInputFromState(s.Messages, s.Instruction)
|
|
}
|
|
|
|
// Call model.
|
|
resp, err := model.Generate(ctx, modelMsgs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("model: %w", err)
|
|
}
|
|
s.Messages = append(s.Messages, resp)
|
|
|
|
// AfterModelRewrite middleware chain.
|
|
typedState.Messages = s.Messages
|
|
for _, mw := range agentCfg.Middlewares {
|
|
if mw == nil {
|
|
continue
|
|
}
|
|
var err error
|
|
ctx, typedState, err = mw.AfterModelRewrite(ctx, typedState, mc)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("AfterModelRewrite: %w", err)
|
|
}
|
|
}
|
|
s.Messages = typedState.Messages
|
|
|
|
// Detect if the model produced tool calls.
|
|
toolCalls := extractToolCalls(resp)
|
|
s.HasToolCall = len(toolCalls) > 0
|
|
|
|
return s, nil
|
|
})
|
|
|
|
// --- Node: execute_tools ---
|
|
// Executes tool calls found in the last model response using ToolsNode.
|
|
// Supports interrupt/resume via ToolExecutedCache: on interrupt, completed
|
|
// tool results are saved to the cache (persisted via Pregel channel checkpoint).
|
|
// On resume, already-cached tools are skipped.
|
|
sg.AddNode("execute_tools", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(*ReActGraphState)
|
|
if len(s.Messages) == 0 {
|
|
return s, nil
|
|
}
|
|
last := s.Messages[len(s.Messages)-1]
|
|
toolCalls := extractToolCalls(last)
|
|
if len(toolCalls) == 0 {
|
|
return s, nil
|
|
}
|
|
|
|
// Restore or initialize the tool execution cache.
|
|
cache := s.ToolExecutedCache
|
|
if cache == nil {
|
|
cache = make(map[string]string)
|
|
}
|
|
|
|
// Filter out already-cached (previously completed) tool calls.
|
|
var pendingCalls []schema.ToolCall
|
|
for _, tc := range toolCalls {
|
|
if _, done := cache[tc.ID]; !done {
|
|
pendingCalls = append(pendingCalls, tc)
|
|
}
|
|
}
|
|
if len(pendingCalls) == 0 {
|
|
return s, nil
|
|
}
|
|
|
|
agentState := NewReActAgentState(
|
|
messageSliceToAny(s.Messages),
|
|
s.ToolInfos,
|
|
s.IterationsLeft,
|
|
)
|
|
typedState := (*TypedReActAgentState[*schema.Message])(agentState)
|
|
|
|
// Build ToolsNode from config ToolsConfig (for ToolInvokeMiddlewares etc.)
|
|
// but overlay allTools (config + contributor).
|
|
tnCfg := &ToolsNodeConfig{}
|
|
if agentCfg.ToolsConfig != nil {
|
|
*tnCfg = *agentCfg.ToolsConfig
|
|
}
|
|
tnCfg.Tools = allTools
|
|
tnCfg.ReturnDirectly = allRD
|
|
tn := NewToolsNode[*schema.Message](tnCfg)
|
|
|
|
// Execute pending calls one at a time so we can track per-call results.
|
|
// Each call uses a single-tool-call message to keep tracking simple.
|
|
var firstErr error
|
|
var toolInterrupted bool
|
|
for _, tc := range pendingCalls {
|
|
// Build a fresh message containing only this tool call.
|
|
singleMsg := &schema.Message{
|
|
Role: schema.RoleAssistant,
|
|
Content: "",
|
|
ToolCalls: []schema.ToolCall{tc},
|
|
}
|
|
var action *AgentAction
|
|
var toolResults []*schema.Message
|
|
toolResults, action, firstErr = tn.Execute(ctx, singleMsg, typedState, nil)
|
|
if firstErr != nil {
|
|
// Check if this is a tool interrupt (not a real error).
|
|
var ir *interruptResult
|
|
if errors.As(firstErr, &ir) {
|
|
toolInterrupted = true
|
|
firstErr = nil
|
|
// Tool interrupted — still save its message to state.
|
|
for _, tr := range toolResults {
|
|
s.Messages = append(s.Messages, tr)
|
|
if tr != nil && tr.Content != "" {
|
|
cache[tc.ID] = tr.Content
|
|
}
|
|
}
|
|
break
|
|
}
|
|
// Real error — stop.
|
|
break
|
|
}
|
|
for _, tr := range toolResults {
|
|
s.Messages = append(s.Messages, tr)
|
|
if tr != nil && tr.Content != "" {
|
|
cache[tc.ID] = tr.Content
|
|
}
|
|
}
|
|
if action != nil && action.Exit {
|
|
s.IterationsLeft = 0
|
|
s.HasToolCall = false
|
|
break
|
|
}
|
|
}
|
|
|
|
if firstErr != nil {
|
|
s.ToolExecutedCache = cache
|
|
return s, fmt.Errorf("tools: %w", firstErr)
|
|
}
|
|
|
|
if toolInterrupted {
|
|
// Save cache and return cleanly so Pregel engine checkpoints the state.
|
|
s.ToolExecutedCache = cache
|
|
return s, nil
|
|
}
|
|
|
|
// All tools completed successfully — clear cache for next iteration.
|
|
s.ToolExecutedCache = nil
|
|
return s, nil
|
|
})
|
|
|
|
// --- Node: check_done ---
|
|
// Emits AfterAgent middleware and writes the final output.
|
|
sg.AddNode("check_done", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(*ReActGraphState)
|
|
agentState := NewReActAgentState(
|
|
messageSliceToAny(s.Messages),
|
|
s.ToolInfos,
|
|
s.IterationsLeft,
|
|
)
|
|
typedState := (*TypedReActAgentState[*schema.Message])(agentState)
|
|
|
|
for _, mw := range agentCfg.Middlewares {
|
|
if mw == nil {
|
|
continue
|
|
}
|
|
var err error
|
|
ctx, err = mw.AfterAgent(ctx, typedState)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("AfterAgent: %w", err)
|
|
}
|
|
}
|
|
|
|
// Store output in session if configured.
|
|
if agentCfg.OutputKey != "" && len(s.Messages) > 0 {
|
|
last := s.Messages[len(s.Messages)-1]
|
|
setOutputToSession(ctx, last, agentCfg.OutputKey)
|
|
}
|
|
return s, nil
|
|
})
|
|
|
|
// --- Edges ---
|
|
sg.AddEdge(constants.Start, "prepare_input")
|
|
sg.AddEdge("prepare_input", "model_generate")
|
|
|
|
// Conditional: if no tool calls → check_done (which goes to end), else → execute_tools
|
|
sg.AddConditionalEdges("model_generate", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(*ReActGraphState)
|
|
if s.IterationsLeft <= 0 || !s.HasToolCall {
|
|
return "check_done", nil
|
|
}
|
|
return "execute_tools", nil
|
|
}, map[string]string{
|
|
"check_done": "check_done",
|
|
"execute_tools": "execute_tools",
|
|
})
|
|
|
|
sg.AddEdge("execute_tools", "model_generate") // loop back for next iteration
|
|
sg.AddEdge("check_done", constants.End) // terminal node
|
|
|
|
// --- Compile with checkpoint and interrupt ---
|
|
interrupts := cfg.InterruptBefore
|
|
if len(interrupts) == 0 {
|
|
interrupts = []string{"execute_tools"}
|
|
}
|
|
rl := cfg.RecursionLimit
|
|
if rl <= 0 {
|
|
rl = constants.DefaultRecursionLimit
|
|
}
|
|
|
|
compileOpts := []graph.CompileOption{
|
|
graph.WithRecursionLimit(rl),
|
|
}
|
|
if cfg.Checkpointer != nil {
|
|
compileOpts = append(compileOpts, graph.WithCheckpointer(cfg.Checkpointer))
|
|
}
|
|
for _, name := range interrupts {
|
|
compileOpts = append(compileOpts, graph.WithInterrupts(name))
|
|
}
|
|
|
|
compiled, err := sg.Compile(compileOpts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("compile ReAct graph: %w", err)
|
|
}
|
|
|
|
return &ReActGraph{
|
|
compiled: compiled,
|
|
config: agentCfg,
|
|
agent: agent,
|
|
allInfos: allToolInfos,
|
|
allTools: allTools,
|
|
}, nil
|
|
}
|
|
|
|
// Invoke runs the graph-level ReAct loop synchronously via the Pregel engine.
|
|
// When input is nil (resume path), the graph restores state from the checkpoint;
|
|
// buildInitialState returns nil to let the engine handle it.
|
|
func (rg *ReActGraph) Invoke(ctx context.Context, input *AgentInput, config *types.RunnableConfig) (*ReActGraphState, error) {
|
|
var state interface{}
|
|
if input != nil {
|
|
state = rg.buildInitialState(input)
|
|
}
|
|
|
|
result, err := rg.compiled.Invoke(ctx, state, config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
outState, ok := result.(*ReActGraphState)
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected result type %T from graph", result)
|
|
}
|
|
return outState, nil
|
|
}
|
|
|
|
// Stream runs the graph-level ReAct loop with streaming events via Pregel.
|
|
// Returns (outputCh, errCh). The outputCh yields pregel.StreamEvent values
|
|
// including checkpoint, task start/end, values, and final state.
|
|
func (rg *ReActGraph) Stream(ctx context.Context, input *AgentInput, config *types.RunnableConfig, mode types.StreamMode) (<-chan interface{}, <-chan error) {
|
|
state := rg.buildInitialState(input)
|
|
return rg.compiled.Stream(ctx, state, mode, config)
|
|
}
|
|
|
|
// Resume resumes a previously interrupted graph execution from its checkpoint.
|
|
func (rg *ReActGraph) Resume(ctx context.Context, config *types.RunnableConfig) (*ReActGraphState, error) {
|
|
// Pass config so Pregel engine can restore the correct checkpoint.
|
|
result, err := rg.compiled.Invoke(ctx, nil, config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
outState, ok := result.(*ReActGraphState)
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected result type %T from resumed graph", result)
|
|
}
|
|
return outState, nil
|
|
}
|
|
|
|
// ResumeStream resumes a previously interrupted graph with streaming.
|
|
func (rg *ReActGraph) ResumeStream(ctx context.Context, config *types.RunnableConfig, mode types.StreamMode) (<-chan interface{}, <-chan error) {
|
|
return rg.compiled.Stream(ctx, nil, mode, config)
|
|
}
|
|
|
|
// Compile returns the underlying compiled graph for direct access.
|
|
func (rg *ReActGraph) Compile() *graph.CompiledGraph { return rg.compiled }
|
|
|
|
// ---- helpers ----
|
|
|
|
func (rg *ReActGraph) buildInitialState(input *AgentInput) *ReActGraphState {
|
|
maxIter := rg.config.MaxIterations
|
|
if maxIter <= 0 {
|
|
maxIter = 10
|
|
}
|
|
state := &ReActGraphState{
|
|
Messages: input.Messages,
|
|
IterationsLeft: maxIter,
|
|
MaxIterations: maxIter,
|
|
AgentName: rg.agent.name,
|
|
Instruction: rg.config.Instruction,
|
|
}
|
|
// Use merged tool infos (config + contributor).
|
|
state.ToolInfos = make([]*schema.ToolInfo, len(rg.allInfos))
|
|
copy(state.ToolInfos, rg.allInfos)
|
|
return state
|
|
}
|
|
|
|
func messageSliceToAny(msgs []*schema.Message) []Message {
|
|
r := make([]Message, len(msgs))
|
|
for i, m := range msgs {
|
|
r[i] = m
|
|
}
|
|
return r
|
|
}
|
|
|
|
// Ensure pregel is imported for side effects (engine registration).
|
|
var _ = pregel.Engine{}
|