Files
ragflow/internal/harness/core/react_graph.go
Yingfeng 706e0d2d06 Refactor harness framework (#16271)
### What problem does this PR solve?

- Tools management
- Pregel engine wrapper for better usage
- UT race
- Coding style

### Type of change

- [x] Refactoring
2026-06-23 20:18:04 +08:00

508 lines
16 KiB
Go

// Package core provides a graph-level ReAct loop using the project's own
// StateGraph engine, with built-in checkpoint/interrupt/resume support at each
// iteration boundary.
//
// The ReActGraph wraps a TypedChatModelAgent's loop into StateGraph nodes so that
// the graph engine's checkpointing (via graph.WithCheckpointer) and interrupt/resume
// (via graph.WithInterrupts) apply at each superstep automatically. This replaces
// the simple for-loop in chatmodel_react.go with the full Pregel execution engine.
//
// Key features:
// - Checkpoint at every model_generate and execute_tools node boundary
// - Interrupt before execute_tools for human-in-the-loop tool approval
// - Resume from interrupt via graph checkpoint restoration
// - Full middleware chain (BeforeAgent, BeforeModelRewrite, AfterModelRewrite, AfterAgent)
// - ToolsNode integration with ToolCallMiddlewares
// - Streaming events via pregel.StreamManager
// - Generic: supports both *schema.Message and *schema.AgenticMessage
package core
import (
"context"
"errors"
"fmt"
"ragflow/internal/harness/core/schema"
"ragflow/internal/harness/graph/channels"
"ragflow/internal/harness/graph/constants"
"ragflow/internal/harness/graph/graph"
"ragflow/internal/harness/graph/pregel"
"ragflow/internal/harness/graph/types"
)
func init() {
schema.RegisterType("_harness_react_graph_state", func() any { return &ReActGraphState{} })
}
// ReActGraphState is the shared state for the graph-level ReAct loop.
// It persists across supersteps, enabling checkpoint and interrupt/resume.
// Every graph node receives the same *ReActGraphState, mutates it in place and
// returns it.
type ReActGraphState struct {
// Messages is the conversation history. model_generate appends the model
// response to it and execute_tools appends tool result messages.
Messages []*schema.Message
// ToolInfos is a copy of the graph's merged tool infos, filled in by
// buildInitialState.
ToolInfos []*schema.ToolInfo
// IterationsLeft is the remaining model-call budget. model_generate
// decrements it before each call; execute_tools sets it to 0 when a tool
// action requests exit.
IterationsLeft int
// MaxIterations is the budget the run started with (the configured value,
// or 10 when the configured value is <= 0).
MaxIterations int
// AgentName is the name of the agent that owns the graph.
AgentName string
// Instruction starts as the configured instruction and is overwritten in
// prepare_input with whatever the BeforeAgent middleware chain returns.
Instruction string
HasToolCall bool // signals whether the last model output had tool calls
// ToolExecutedCache caches completed tool call results for interrupt/resume.
// Key = tool call ID, value = result content string.
// After successful completion of all tools in a superstep, this is cleared.
// It is also reset at the start of every model_generate iteration.
// On interrupt, it persists via the Pregel checkpoint and allows skipping
// already-executed tools on resume (equivalent to Eino's ToolsInterruptAndRerunExtra).
ToolExecutedCache map[string]string
}
// ReActGraph wraps a ChatModelAgent's loop into a StateGraph with automatic
// checkpoint at each iteration and interrupt before tool execution.
// Instances are created by NewReActGraph.
type ReActGraph struct {
// compiled is the compiled graph that Invoke/Stream/Resume delegate to.
compiled *graph.CompiledGraph
// config is the agent's configuration (same pointer as agent.config at
// construction time).
config *ReActConfig[*schema.Message]
// agent is the agent this graph was built from.
agent *ReActAgent[*schema.Message]
// allInfos is the tool metadata handed to the model; it is the allToolInfos
// argument of NewReActGraph, or the fallback derived from the config tools
// when that argument was nil.
allInfos []*schema.ToolInfo // merged config + contributor tool infos
allTools []Tool // merged config + contributor tools
}
// ReActGraphConfig holds options for building a ReActGraph.
// A nil *ReActGraphConfig passed to NewReActGraph is treated as the zero value.
type ReActGraphConfig struct {
// Checkpointer, when non-nil, is handed to the graph compiler through
// graph.WithCheckpointer. When nil no checkpointer option is set.
Checkpointer graph.Checkpointer
// InterruptBefore is forwarded, one name at a time, to graph.WithInterrupts.
// An empty slice falls back to a single entry, "execute_tools".
InterruptBefore []string // node names to interrupt before (default: "execute_tools")
// RecursionLimit is passed to graph.WithRecursionLimit; values <= 0 are
// replaced by constants.DefaultRecursionLimit.
RecursionLimit int
}
// NewReActGraph builds and compiles the StateGraph that drives the ReAct loop:
//
//	START → prepare_input → model_generate ⇄ execute_tools
//	                              ↘ check_done → END
//
// Node responsibilities:
//   - prepare_input: runs every middleware's BeforeAgent hook once and stores
//     the resulting instruction in the state.
//   - model_generate: BeforeModelRewrite → StateModifier → model call →
//     AfterModelRewrite, then records whether the response carries tool calls.
//   - execute_tools: runs the not-yet-cached tool calls of the last message one
//     at a time and loops back to model_generate.
//   - check_done: runs AfterAgent and, when OutputKey is set, stores the last
//     message in the session.
//
// Parameters:
//   - agent: the agent to wrap; it and its config must be non-nil.
//   - cfg: graph options; nil is treated as an empty ReActGraphConfig. An empty
//     InterruptBefore defaults to "execute_tools", a non-positive
//     RecursionLimit defaults to constants.DefaultRecursionLimit, and a
//     checkpointer is only installed when cfg.Checkpointer is non-nil.
//   - allToolInfos: tool metadata exposed to the model; when nil it is derived
//     from the agent's Tools and ToolsConfig.Tools.
//
// It returns an error when agent (or its config) is nil or when the graph
// fails to compile.
func NewReActGraph(agent *ReActAgent[*schema.Message], cfg *ReActGraphConfig, allToolInfos []*schema.ToolInfo) (*ReActGraph, error) {
	// Report a missing agent as an error instead of a nil-pointer panic.
	if agent == nil || agent.config == nil {
		return nil, errors.New("react graph: agent and its config must not be nil")
	}
	if cfg == nil {
		cfg = &ReActGraphConfig{}
	}
	agentCfg := agent.config

	// Fallback: when allToolInfos is nil, derive from the agent's tools
	// so model wrappers always have tool metadata.
	if allToolInfos == nil {
		allToolInfos = toolsToInfosTyped[*schema.Message](agentCfg.Tools)
		if agentCfg.ToolsConfig != nil {
			allToolInfos = append(allToolInfos, toolsToInfosTyped[*schema.Message](agentCfg.ToolsConfig.Tools)...)
		}
	}

	sg := graph.NewStateGraph(&ReActGraphState{})

	// Register channels for state fields used by the graph engine.
	sg.AddChannel("messages", channels.NewLastValue([]*schema.Message{}))
	sg.AddChannel("iterations_left", channels.NewLastValue(0))
	sg.AddChannel("has_tool_call", channels.NewLastValue(false))
	sg.AddChannel("tool_cache", channels.NewLastValue(map[string]string{}))

	// Merge config tools / return-directly flags with the ones contributed
	// through the agent's execution context. Contributor entries win on
	// duplicate return-directly keys because they are written last.
	allTools := make([]Tool, 0, len(agentCfg.Tools))
	allTools = append(allTools, agentCfg.Tools...)
	allRD := make(map[string]bool)
	for k, v := range agentCfg.ReturnDirectly {
		allRD[k] = v
	}
	if ec := agent.exeCtx; ec != nil {
		allTools = append(allTools, ec.contribTools...)
		for k, v := range ec.contribReturnDirectly {
			allRD[k] = v
		}
	}

	// --- Node: prepare_input ---
	// Runs once at the start. Applies BeforeAgent middleware.
	sg.AddNode("prepare_input", func(ctx context.Context, state any) (any, error) {
		s := state.(*ReActGraphState)
		rc := &ReActAgentContext{
			Instruction:    s.Instruction,
			Tools:          allTools,
			ReturnDirectly: allRD,
		}
		for _, mw := range agentCfg.Middlewares {
			if mw == nil {
				continue
			}
			var err error
			ctx, rc, err = mw.BeforeAgent(ctx, rc)
			if err != nil {
				return nil, fmt.Errorf("BeforeAgent: %w", err)
			}
		}
		s.Instruction = rc.Instruction
		return s, nil
	})

	// --- Node: model_generate ---
	// Calls the LLM with the current message history. Applies BeforeModelRewrite
	// and AfterModelRewrite middleware chains.
	// Clears the tool cache so each iteration starts fresh.
	sg.AddNode("model_generate", func(ctx context.Context, state any) (any, error) {
		s := state.(*ReActGraphState)
		// Budget exhausted: pass the state through; the conditional edge then
		// routes to check_done.
		if s.IterationsLeft <= 0 {
			return s, nil
		}
		s.IterationsLeft--

		// Clear tool cache at start of each iteration.
		s.ToolExecutedCache = nil

		model := BuildModelWrapperChain(agentCfg.Model, nil, agentCfg, allToolInfos)
		// The agent state sees the budget as it was before this iteration's
		// decrement.
		agentState := NewReActAgentState(
			messageSliceToAny(s.Messages),
			allToolInfos,
			s.IterationsLeft+1,
		)
		typedState := (*TypedReActAgentState[*schema.Message])(agentState)
		mc := &TypedModelContext[*schema.Message]{
			Tools:               allToolInfos,
			ModelRetryConfig:    agentCfg.RetryConfig,
			ModelFailoverConfig: agentCfg.FailoverConfig,
		}

		// BeforeModelRewrite middleware chain.
		for _, mw := range agentCfg.Middlewares {
			if mw == nil {
				continue
			}
			var err error
			ctx, typedState, err = mw.BeforeModelRewrite(ctx, typedState, mc)
			if err != nil {
				return nil, fmt.Errorf("BeforeModelRewrite: %w", err)
			}
		}
		s.Messages = typedState.Messages

		// StateModifier hook (e.g., context window trimming).
		if agentCfg.StateModifier != nil {
			var err error
			typedState, err = agentCfg.StateModifier(ctx, typedState)
			if err != nil {
				return nil, fmt.Errorf("StateModifier: %w", err)
			}
			s.Messages = typedState.Messages
		}

		// Build model input (via GenModelInput or default).
		var modelMsgs []*schema.Message
		if agentCfg.GenModelInput != nil {
			var err error
			modelMsgs, err = agentCfg.GenModelInput(ctx, s.Instruction,
				&TypedAgentInput[*schema.Message]{Messages: s.Messages})
			if err != nil {
				return nil, fmt.Errorf("GenModelInput: %w", err)
			}
		} else {
			modelMsgs = buildModelInputFromState(s.Messages, s.Instruction)
		}

		// Call model.
		resp, err := model.Generate(ctx, modelMsgs)
		if err != nil {
			return nil, fmt.Errorf("model: %w", err)
		}
		s.Messages = append(s.Messages, resp)

		// AfterModelRewrite middleware chain.
		typedState.Messages = s.Messages
		for _, mw := range agentCfg.Middlewares {
			if mw == nil {
				continue
			}
			var err error
			ctx, typedState, err = mw.AfterModelRewrite(ctx, typedState, mc)
			if err != nil {
				return nil, fmt.Errorf("AfterModelRewrite: %w", err)
			}
		}
		s.Messages = typedState.Messages

		// Detect if the model produced tool calls.
		toolCalls := extractToolCalls(resp)
		s.HasToolCall = len(toolCalls) > 0
		return s, nil
	})

	// --- Node: execute_tools ---
	// Executes tool calls found in the last model response using ToolsNode.
	// Supports interrupt/resume via ToolExecutedCache: on interrupt, completed
	// tool results are saved to the cache (persisted via Pregel channel checkpoint).
	// On resume, already-cached tools are skipped.
	sg.AddNode("execute_tools", func(ctx context.Context, state any) (any, error) {
		s := state.(*ReActGraphState)
		if len(s.Messages) == 0 {
			return s, nil
		}
		last := s.Messages[len(s.Messages)-1]
		toolCalls := extractToolCalls(last)
		if len(toolCalls) == 0 {
			return s, nil
		}

		// Restore or initialize the tool execution cache.
		cache := s.ToolExecutedCache
		if cache == nil {
			cache = make(map[string]string)
		}

		// Filter out already-cached (previously completed) tool calls.
		var pendingCalls []schema.ToolCall
		for _, tc := range toolCalls {
			if _, done := cache[tc.ID]; !done {
				pendingCalls = append(pendingCalls, tc)
			}
		}
		if len(pendingCalls) == 0 {
			return s, nil
		}

		agentState := NewReActAgentState(
			messageSliceToAny(s.Messages),
			s.ToolInfos,
			s.IterationsLeft,
		)
		typedState := (*TypedReActAgentState[*schema.Message])(agentState)

		// Build ToolsNode from config ToolsConfig (for ToolInvokeMiddlewares etc.)
		// but overlay allTools (config + contributor).
		tnCfg := &ToolsNodeConfig{}
		if agentCfg.ToolsConfig != nil {
			*tnCfg = *agentCfg.ToolsConfig
		}
		tnCfg.Tools = allTools
		tnCfg.ReturnDirectly = allRD
		tn := NewToolsNode[*schema.Message](tnCfg)

		// record appends the result messages of one tool call to the history
		// and remembers non-empty content in the cache. Nil results are
		// skipped so the history never contains a nil message.
		record := func(callID string, results []*schema.Message) {
			for _, tr := range results {
				if tr == nil {
					continue
				}
				s.Messages = append(s.Messages, tr)
				if tr.Content != "" {
					cache[callID] = tr.Content
				}
			}
		}

		// Execute pending calls one at a time so we can track per-call results.
		// Each call uses a single-tool-call message to keep tracking simple.
		var (
			execErr         error
			toolInterrupted bool
		)
		for _, tc := range pendingCalls {
			// Build a fresh message containing only this tool call.
			singleMsg := &schema.Message{
				Role:      schema.RoleAssistant,
				Content:   "",
				ToolCalls: []schema.ToolCall{tc},
			}
			var (
				action      *AgentAction
				toolResults []*schema.Message
				err         error
			)
			toolResults, action, err = tn.Execute(ctx, singleMsg, typedState, nil)
			if err != nil {
				// A tool interrupt is not a failure: keep whatever the tool
				// produced and stop processing further calls.
				var ir *interruptResult
				if errors.As(err, &ir) {
					toolInterrupted = true
					record(tc.ID, toolResults)
				} else {
					execErr = err
				}
				break
			}
			record(tc.ID, toolResults)
			if action != nil && action.Exit {
				// The tool asked to finish: zero the budget so the next
				// model_generate is a no-op and the graph routes to check_done.
				s.IterationsLeft = 0
				s.HasToolCall = false
				break
			}
		}

		if execErr != nil {
			// Keep the results of the calls that did complete alongside the error.
			s.ToolExecutedCache = cache
			return s, fmt.Errorf("tools: %w", execErr)
		}
		if toolInterrupted {
			// Save cache and return cleanly so Pregel engine checkpoints the state.
			s.ToolExecutedCache = cache
			return s, nil
		}
		// All tools completed successfully — clear cache for next iteration.
		s.ToolExecutedCache = nil
		return s, nil
	})

	// --- Node: check_done ---
	// Emits AfterAgent middleware and writes the final output.
	sg.AddNode("check_done", func(ctx context.Context, state any) (any, error) {
		s := state.(*ReActGraphState)
		agentState := NewReActAgentState(
			messageSliceToAny(s.Messages),
			s.ToolInfos,
			s.IterationsLeft,
		)
		typedState := (*TypedReActAgentState[*schema.Message])(agentState)
		for _, mw := range agentCfg.Middlewares {
			if mw == nil {
				continue
			}
			var err error
			ctx, err = mw.AfterAgent(ctx, typedState)
			if err != nil {
				return nil, fmt.Errorf("AfterAgent: %w", err)
			}
		}
		// Store output in session if configured.
		if agentCfg.OutputKey != "" && len(s.Messages) > 0 {
			last := s.Messages[len(s.Messages)-1]
			setOutputToSession(ctx, last, agentCfg.OutputKey)
		}
		return s, nil
	})

	// --- Edges ---
	sg.AddEdge(constants.Start, "prepare_input")
	sg.AddEdge("prepare_input", "model_generate")
	// Conditional: if no tool calls → check_done (which goes to end), else → execute_tools
	sg.AddConditionalEdges("model_generate", func(ctx context.Context, state any) (any, error) {
		s := state.(*ReActGraphState)
		if s.IterationsLeft <= 0 || !s.HasToolCall {
			return "check_done", nil
		}
		return "execute_tools", nil
	}, map[string]string{
		"check_done":    "check_done",
		"execute_tools": "execute_tools",
	})
	sg.AddEdge("execute_tools", "model_generate") // loop back for next iteration
	sg.AddEdge("check_done", constants.End)       // terminal node

	// --- Compile with checkpoint and interrupt ---
	interrupts := cfg.InterruptBefore
	if len(interrupts) == 0 {
		interrupts = []string{"execute_tools"}
	}
	rl := cfg.RecursionLimit
	if rl <= 0 {
		rl = constants.DefaultRecursionLimit
	}
	compileOpts := []graph.CompileOption{
		graph.WithRecursionLimit(rl),
	}
	if cfg.Checkpointer != nil {
		compileOpts = append(compileOpts, graph.WithCheckpointer(cfg.Checkpointer))
	}
	for _, name := range interrupts {
		compileOpts = append(compileOpts, graph.WithInterrupts(name))
	}

	compiled, err := sg.Compile(compileOpts...)
	if err != nil {
		return nil, fmt.Errorf("compile ReAct graph: %w", err)
	}
	return &ReActGraph{
		compiled: compiled,
		config:   agentCfg,
		agent:    agent,
		allInfos: allToolInfos,
		allTools: allTools,
	}, nil
}
// Invoke runs the graph-level ReAct loop synchronously on the compiled graph.
//
// A non-nil input is turned into a fresh ReActGraphState via buildInitialState.
// A nil input passes a nil state to the engine, which is the resume path (the
// same call Resume makes).
//
// It returns the final *ReActGraphState, the engine's error unchanged, or an
// error when the engine result is not a *ReActGraphState.
func (rg *ReActGraph) Invoke(ctx context.Context, input *AgentInput, config *types.RunnableConfig) (*ReActGraphState, error) {
	// initial stays an untyped nil interface unless there is real input.
	var initial any
	if input != nil {
		initial = rg.buildInitialState(input)
	}

	raw, err := rg.compiled.Invoke(ctx, initial, config)
	if err != nil {
		return nil, err
	}

	if final, ok := raw.(*ReActGraphState); ok {
		return final, nil
	}
	return nil, fmt.Errorf("unexpected result type %T from graph", raw)
}
// Stream runs the graph-level ReAct loop with streaming events and returns the
// output and error channels produced by the compiled graph's Stream.
// NOTE(review): the original documentation states the output channel yields
// pregel.StreamEvent values (checkpoint, task start/end, values, final state);
// confirm against the graph package.
//
// A non-nil input is turned into a fresh ReActGraphState. A nil input is
// handled like Invoke's nil-input path and like ResumeStream: a nil state is
// passed to the engine instead of dereferencing the missing input.
func (rg *ReActGraph) Stream(ctx context.Context, input *AgentInput, config *types.RunnableConfig, mode types.StreamMode) (<-chan interface{}, <-chan error) {
	// Keep state an untyped nil interface on the nil-input path; wrapping a nil
	// *ReActGraphState would yield a non-nil interface value.
	var state any
	if input != nil {
		state = rg.buildInitialState(input)
	}
	return rg.compiled.Stream(ctx, state, mode, config)
}
// Resume continues a previously interrupted graph execution. It invokes the
// compiled graph with a nil state and the caller's config.
// NOTE(review): the config is presumably what identifies the checkpoint to
// restore — confirm against the graph package.
//
// It returns the final *ReActGraphState, the engine's error unchanged, or an
// error when the engine result is not a *ReActGraphState.
func (rg *ReActGraph) Resume(ctx context.Context, config *types.RunnableConfig) (*ReActGraphState, error) {
	raw, err := rg.compiled.Invoke(ctx, nil, config)
	if err != nil {
		return nil, err
	}

	switch final := raw.(type) {
	case *ReActGraphState:
		return final, nil
	default:
		return nil, fmt.Errorf("unexpected result type %T from resumed graph", raw)
	}
}
// ResumeStream is the streaming counterpart of Resume: it calls the compiled
// graph's Stream with a nil state and returns its output and error channels.
func (rg *ReActGraph) ResumeStream(ctx context.Context, config *types.RunnableConfig, mode types.StreamMode) (<-chan interface{}, <-chan error) {
	outputCh, errCh := rg.compiled.Stream(ctx, nil, mode, config)
	return outputCh, errCh
}
// Compile exposes the compiled graph held by rg so callers can use it
// directly. It does not compile anything; compilation happens in NewReActGraph.
func (rg *ReActGraph) Compile() *graph.CompiledGraph {
	return rg.compiled
}
// ---- helpers ----

// buildInitialState creates the starting ReActGraphState for a new run.
//
// The iteration budget is the configured MaxIterations, or 10 when that value
// is not positive. Messages is taken from input as-is (the slice is shared, not
// copied), while ToolInfos is a fresh copy of the merged tool infos.
// input must be non-nil.
func (rg *ReActGraph) buildInitialState(input *AgentInput) *ReActGraphState {
	budget := rg.config.MaxIterations
	if budget < 1 {
		budget = 10
	}

	initial := &ReActGraphState{
		Messages:       input.Messages,
		IterationsLeft: budget,
		MaxIterations:  budget,
		AgentName:      rg.agent.name,
		Instruction:    rg.config.Instruction,
	}

	// Use merged tool infos (config + contributor), copied so the state owns
	// its own slice.
	infos := make([]*schema.ToolInfo, len(rg.allInfos))
	copy(infos, rg.allInfos)
	initial.ToolInfos = infos
	return initial
}
// messageSliceToAny converts a slice of *schema.Message into a []Message of
// the same length and order. The result is always non-nil, even for a nil or
// empty input.
func messageSliceToAny(msgs []*schema.Message) []Message {
	converted := make([]Message, 0, len(msgs))
	for idx := range msgs {
		converted = append(converted, msgs[idx])
	}
	return converted
}
// Reference pregel.Engine so this file uses the pregel import and the package
// is linked in.
// NOTE(review): the original comment says the import is needed for side
// effects (engine registration); that cannot be confirmed from this file.
var _ = pregel.Engine{}