Files
ragflow/internal/harness/core/react_agent.go
Yingfeng 706e0d2d06 Refactor harness framework (#16271)
### What problem does this PR solve?

- Tools management
- Pregel engine wrapper for better usage
- UT race
- Coding style

### Type of change

- [x] Refactoring
2026-06-23 20:18:04 +08:00

696 lines
21 KiB
Go

package core
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"ragflow/internal/harness/core/internal"
"ragflow/internal/harness/core/schema"
"ragflow/internal/harness/graph/graph"
)
// ReActConfig holds the configuration consumed by ReActAgent.
//
// NewReActAgent keeps the pointer it is given rather than copying the struct,
// and may fill in ToolsConfig on it.
type ReActConfig[M MessageType] struct {
// Model is the chat model the agent calls.
Model Model[M]
// Tools lists the tools offered to the model. When ToolsConfig is nil and
// Tools is non-empty, NewReActAgent derives ToolsConfig from Tools and
// ReturnDirectly.
Tools []Tool
// Instruction is the system instruction. An empty value is replaced by
// internal.DefaultSystemPrompt when the execution context is prepared.
Instruction string
// MaxIterations is handed to NewReActAgentState; the graph-based path uses
// MaxIterations*2 as its recursion limit.
MaxIterations int
// Middlewares are invoked in slice order at each hook point; nil entries
// are skipped.
Middlewares []TypedReActMiddleware[M]
// RetryConfig is exposed to middlewares through TypedModelContext.
RetryConfig *TypedModelRetryConfig[M]
// FailoverConfig is exposed to middlewares through TypedModelContext.
FailoverConfig *FailoverConfig[M]
// ReturnDirectly is keyed by tool name and merged with the entries
// contributed by middlewares (contributor entries win on conflict).
// NOTE(review): presumably marks tools whose result ends the run — confirm.
ReturnDirectly map[string]bool
// OutputKey, when non-empty, is the session key under which the text of the
// final model response is stored.
OutputKey string
// GenModelInput is an optional builder for the model input.
// NOTE(review): not referenced in this part of the file — confirm where it
// is applied.
GenModelInput TypedGenModelInput[M]
// StateModifier, when set, transforms the agent state before the model call.
StateModifier StateModifier[M]
// ToolsConfig configures the tools node; see Tools for how it is defaulted.
ToolsConfig *ToolsNodeConfig
// EmitInternalEvents is copied into the per-agent execution context.
EmitInternalEvents bool
// GraphReAct enables graph-based ReAct execution using the project's own
// StateGraph/Pregel engine. When true, each ReAct iteration runs as a graph
// node, providing automatic checkpoint, interrupt, and resume via the engine.
// Default: false (uses the simple for-loop in chatmodel_react.go).
GraphReAct bool
// GraphReActCheckpointer is the checkpointer used when GraphReAct is enabled.
// If nil, no checkpointing is performed (but interrupt is still available
// via WithInterrupts).
GraphReActCheckpointer graph.Checkpointer
// GraphReActInterruptBefore lists node names to interrupt before.
// Default: ["execute_tools"] (pause before tool execution for human approval).
GraphReActInterruptBefore []string
}
// DefaultReActConfig returns a fresh configuration with MaxIterations set to
// 10 and Instruction set to internal.DefaultSystemPrompt. Every other field is
// left at its zero value.
func DefaultReActConfig[M MessageType]() *ReActConfig[M] {
	cfg := new(ReActConfig[M])
	cfg.Instruction = internal.DefaultSystemPrompt
	cfg.MaxIterations = 10
	return cfg
}
// ReActAgentResumeData holds data provided during resume to modify agent behavior.
// Resume looks for a *ReActAgentResumeData in ResumeInfo.ResumeData.
type ReActAgentResumeData struct {
// HistoryModifier, when non-nil, is copied into the run parameters of the
// resumed execution. It receives the message history and returns the
// history to use instead.
HistoryModifier func(ctx context.Context, messages []Message) []Message
}
// ReActAgent implements the ReAct (Reasoning + Acting) pattern.
//
// Production features:
// - freeze-once: after first Run/Resume, configuration is frozen (atomic)
// - ToolsNode abstraction with middleware chain support
// - Enhanced Tool (4 endpoint types) support via handler interface
// - DeferredToolInfos for server-side tool search
// - EmitInternalEvents for AgentTool event forwarding
// - AfterToolCallsHook for AgentLoop integration
// - ResumeWithData / HistoryModifier for resume customization
// - gob encodability check on SetRunLocalValue
type ReActAgent[M MessageType] struct {
// name is reported by Name and set through WithName.
name string
// desc is reported by Description and set through WithDescription.
desc string
// config is the caller-supplied configuration (shared, not copied).
config *ReActConfig[M]
// once guards the one-time construction of run and exeCtx in buildRunFunc.
once sync.Once
// frozen is 1 once the agent has been frozen; accessed via sync/atomic.
frozen uint32
// run is the execution strategy selected inside once.Do.
run typedRunFunc[M]
// exeCtx is the execution context prepared inside once.Do.
exeCtx *execContext
}
// Compile-time checks that ReActAgent satisfies the resumable-agent interfaces
// for both supported message types.
var _ ResumableAgent = &ReActAgent[*schema.Message]{}
var _ TypedResumableAgent[*schema.AgenticMessage] = &ReActAgent[*schema.AgenticMessage]{}
// TypedGenModelInput builds the list of messages sent to the model from the
// instruction text and the agent input.
type TypedGenModelInput[M MessageType] func(ctx context.Context, instruction string, input *TypedAgentInput[M]) ([]M, error)
// StateModifier allows transforming the agent state before model invocation.
// It returns the state to use from then on; a non-nil error aborts the run.
type StateModifier[M MessageType] func(ctx context.Context, state *TypedReActAgentState[M]) (*TypedReActAgentState[M], error)
// defaultGenModelInput builds the model input for an untyped agent: an optional
// system message (the instruction with session placeholders resolved) followed
// by the input messages in their original order. It never returns an error.
func defaultGenModelInput(ctx context.Context, instruction string, input *AgentInput) ([]Message, error) {
	out := make([]Message, 0, 1+len(input.Messages))
	if len(instruction) > 0 {
		out = append(out, schema.SystemMessage(resolveTemplate(instruction, ctx)))
	}
	return append(out, input.Messages...), nil
}
// resolveTemplate substitutes "{key}" placeholders in tmpl with the string
// values stored in the session attached to ctx.
//
// Session values that are not strings are ignored, and placeholders without a
// matching key are left untouched. When ctx carries no session, tmpl is
// returned unchanged.
//
// All substitutions happen in one pass over tmpl, so text inserted for one
// placeholder is never re-scanned for further placeholders. This keeps the
// result independent of map iteration order (the previous key-by-key
// ReplaceAll produced different output from run to run when a value itself
// contained another placeholder).
func resolveTemplate(tmpl string, ctx context.Context) string {
	s := getSession(ctx)
	if s == nil || len(s.Values) == 0 {
		return tmpl
	}
	pairs := make([]string, 0, 2*len(s.Values))
	for k, v := range s.Values {
		if sv, ok := v.(string); ok {
			pairs = append(pairs, fmt.Sprintf("{%s}", k), sv)
		}
	}
	if len(pairs) == 0 {
		return tmpl
	}
	return strings.NewReplacer(pairs...).Replace(tmpl)
}
// NewReActAgent creates an agent named "react_agent" around cfg. A nil cfg is
// replaced by DefaultReActConfig.
//
// The agent keeps the cfg pointer it is given. When cfg has tools but no
// ToolsConfig, a ToolsConfig built from cfg.Tools and cfg.ReturnDirectly is
// written back into cfg.
func NewReActAgent[M MessageType](cfg *ReActConfig[M]) *ReActAgent[M] {
	if cfg == nil {
		cfg = DefaultReActConfig[M]()
	}
	if missingToolsConfig := cfg.ToolsConfig == nil; missingToolsConfig && len(cfg.Tools) > 0 {
		cfg.ToolsConfig = &ToolsNodeConfig{
			Tools:          cfg.Tools,
			ReturnDirectly: cfg.ReturnDirectly,
		}
	}
	return &ReActAgent[M]{
		name:   "react_agent",
		desc:   "ReAct agent using a chat model",
		config: cfg,
	}
}
// WithName sets the agent name and returns the receiver for chaining.
func (a *ReActAgent[M]) WithName(n string) *ReActAgent[M] {
	a.name = n
	return a
}

// WithDescription sets the agent description and returns the receiver for
// chaining.
func (a *ReActAgent[M]) WithDescription(d string) *ReActAgent[M] {
	a.desc = d
	return a
}

// Name returns the agent name; the context is not used.
func (a *ReActAgent[M]) Name(_ context.Context) string {
	return a.name
}

// Description returns the agent description; the context is not used.
func (a *ReActAgent[M]) Description(_ context.Context) string {
	return a.desc
}

// GetType returns the constant type tag "ReActAgent".
func (a *ReActAgent[M]) GetType() string {
	const typeTag = "ReActAgent"
	return typeTag
}
// ---- Freeze mechanism ----
// IsFrozen reports whether freeze has been called on this agent. The flag is
// read atomically.
func (a *ReActAgent[M]) IsFrozen() bool {
	flag := atomic.LoadUint32(&a.frozen)
	return flag == 1
}

// freeze atomically sets the frozen flag. Calling it again has no further
// effect.
func (a *ReActAgent[M]) freeze() {
	atomic.StoreUint32(&a.frozen, 1)
}
// ---- Run / Resume ----
// Run starts the agent on input in a new goroutine and returns an iterator
// over the events it produces. The iterator's generator is always closed when
// the goroutine ends.
//
// A panic raised while building or executing the run function is recovered
// and delivered as a final event whose Err is "panic: <value>".
//
// The agent is frozen when the run ends, whether it returned normally or
// panicked. opts is accepted for interface compatibility; this method does
// not read it.
func (a *ReActAgent[M]) Run(ctx context.Context, input *TypedAgentInput[M], opts ...RunOption) *AsyncIterator[*TypedAgentEvent[M]] {
	it, gen := NewAsyncIteratorPair[*TypedAgentEvent[M]]()
	go func() {
		defer func() {
			if r := recover(); r != nil {
				gen.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("panic: %v", r)})
			}
			gen.Close()
		}()
		// Deferred (rather than called after runFunc) so that a panicking run
		// still freezes the agent. Runs before the recover/Close defer above.
		defer a.freeze()

		runFunc := a.buildRunFunc(ctx)
		runFunc(ctx, &typedRunParams[M]{input: input, generator: gen})
	}()
	return it
}
// Resume continues a previously interrupted run in a new goroutine and returns
// an iterator over the events it produces.
//
// The run is resumed only when info reports WasInterrupted and carries a
// non-nil *TypedReActAgentState[M] as InterruptState; the saved messages become
// the input of the resumed run. In every other case (including a nil info) a
// single error event is emitted and nothing is executed.
//
// If info.ResumeData is a *ReActAgentResumeData with a HistoryModifier, the
// modifier is forwarded to the run parameters.
//
// Panics are recovered and reported as a "panic: <value>" error event. The
// agent is frozen once a resumed run ends, normally or by panic. opts is
// accepted for interface compatibility; this method does not read it.
func (a *ReActAgent[M]) Resume(ctx context.Context, info *ResumeInfo, opts ...RunOption) *AsyncIterator[*TypedAgentEvent[M]] {
	it, gen := NewAsyncIteratorPair[*TypedAgentEvent[M]]()
	go func() {
		defer func() {
			if r := recover(); r != nil {
				gen.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("panic: %v", r)})
			}
			gen.Close()
		}()

		// A failed assertion leaves state nil, as does a typed-nil pointer
		// stored in InterruptState; both are rejected below.
		var state *TypedReActAgentState[M]
		if info != nil && info.WasInterrupted {
			state, _ = info.InterruptState.(*TypedReActAgentState[M])
		}
		if state == nil {
			gen.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("resume called but agent was not interrupted or state is invalid")})
			return
		}

		// Deferred so that a panicking run still freezes the agent.
		defer a.freeze()

		params := &typedRunParams[M]{
			input:          &TypedAgentInput[M]{Messages: state.Messages, EnableStreaming: info.EnableStreaming},
			generator:      gen,
			interruptState: state,
			resumeInfo:     info,
		}
		if rd, ok := info.ResumeData.(*ReActAgentResumeData); ok && rd != nil && rd.HistoryModifier != nil {
			params.historyModifier = rd.HistoryModifier
		}

		runFunc := a.buildRunFunc(ctx)
		runFunc(ctx, params)
	}()
	return it
}
// ---- Internal types ----
// typedRunFunc is one execution strategy of the agent (no-tools, for-loop
// ReAct, or graph ReAct). It reports results and errors through p.generator.
type typedRunFunc[M MessageType] func(ctx context.Context, p *typedRunParams[M])
// typedRunParams carries everything a typedRunFunc needs for a single run.
type typedRunParams[M MessageType] struct {
// input holds the messages to run on; on resume it is rebuilt from the
// interrupted state.
input *TypedAgentInput[M]
// generator is the sending side of the event iterator returned to the caller.
generator *AsyncGenerator[*TypedAgentEvent[M]]
// interruptState is the saved state when resuming; nil for a fresh run.
interruptState *TypedReActAgentState[M]
// resumeInfo is the resume request; nil for a fresh run.
resumeInfo *ResumeInfo
// historyModifier is taken from ReActAgentResumeData on resume, if provided.
historyModifier func(context.Context, []Message) []Message
// afterToolCallsHook is an optional hook; the graph-based path invokes it
// after emitting the final model response.
afterToolCallsHook func(ctx context.Context) error
}
// reActExecCtx carries per-execution state for event sending, cancellation,
// retry signal propagation, and after-tool-calls hooks.
//
// It is fixed to *schema.Message events; see send.
type reActExecCtx struct {
// generator receives events passed to send; send is a no-op when it is nil.
generator *AsyncGenerator[*TypedAgentEvent[*schema.Message]]
// cancelCtx is the cancellation state for this execution.
cancelCtx *cancelContext
// suppressEventSend is not consulted by send in this file.
// NOTE(review): presumably checked by callers before sending — confirm.
suppressEventSend bool
// retrySignal propagates retry information for this execution.
retrySignal *retrySignal
// failoverLastModel is the model recorded for failover handling.
// NOTE(review): exact semantics are defined outside this file — confirm.
failoverLastModel Model[*schema.Message]
// afterToolCallsHook is an optional hook run after tool calls.
afterToolCallsHook func(ctx context.Context) error
}
// send forwards ev to the generator when ev is a
// *TypedAgentEvent[*schema.Message]. Values of any other type are dropped
// without notice, and the call is a no-op on a nil receiver or nil generator.
func (ec *reActExecCtx) send(ev any) {
	if ec == nil || ec.generator == nil {
		return
	}
	typed, isEvent := ev.(*TypedAgentEvent[*schema.Message])
	if !isEvent {
		return
	}
	ec.generator.Send(typed)
}
// execContext is the per-agent execution context computed once by
// prepareExecContext and reused by every run.
type execContext struct {
// instruction is the effective system instruction (never empty).
instruction string
// returnDirectly is the configured map merged with contributor entries.
returnDirectly map[string]bool
toolInfos []*schema.ToolInfo // from config.Tools + contributor ToolInfos
// deferredToolInfos is not populated by prepareExecContext in this file.
deferredToolInfos []*schema.ToolInfo
// toolSearchTool is not populated by prepareExecContext in this file.
toolSearchTool *schema.ToolInfo
// emitInternalEvents mirrors ReActConfig.EmitInternalEvents.
emitInternalEvents bool
// ToolContributor results (collected once in once.Do).
contribTools []Tool
contribToolInfos []*schema.ToolInfo
contribReturnDirectly map[string]bool
}
// ---- Run function builder ----
// buildRunFunc returns the agent's run function, constructing it on the first
// call only (guarded by a.once). The ctx of that first call is the one used to
// prepare the execution context.
//
// Strategy selection:
//   - no tools from config.Tools, config.ToolsConfig or middleware
//     contributors: single model call (buildNoToolsRunFunc)
//   - tools and config.GraphReAct: graph-based loop (buildGraphReActRunFunc)
//   - otherwise: for-loop ReAct (buildReActRunFunc)
//
// If preparing the execution context fails, the returned function emits that
// error as an event on every run instead of silently producing an empty stream.
func (a *ReActAgent[M]) buildRunFunc(ctx context.Context) typedRunFunc[M] {
	a.once.Do(func() {
		ec, err := a.prepareExecContext(ctx)
		if err != nil {
			prepErr := fmt.Errorf("prepare exec context: %w", err)
			a.run = func(_ context.Context, p *typedRunParams[M]) {
				if p != nil && p.generator != nil {
					p.generator.Send(&TypedAgentEvent[M]{Err: prepErr})
				}
			}
			return
		}
		a.exeCtx = ec

		// Check for tools: config.Tools + ToolsConfig.Tools + ToolContributor tools.
		hasTools := len(a.config.Tools) > 0 ||
			(a.config.ToolsConfig != nil && len(a.config.ToolsConfig.Tools) > 0) ||
			len(ec.contribTools) > 0

		switch {
		case !hasTools:
			a.run = a.buildNoToolsRunFunc()
		case a.config.GraphReAct:
			a.run = a.buildGraphReActRunFunc()
		default:
			a.run = a.buildReActRunFunc()
		}
	})
	return a.run
}
// prepareExecContext derives the execution context from the agent config and
// from the tools, tool infos and return-directly entries contributed by the
// middlewares.
//
//   - instruction falls back to internal.DefaultSystemPrompt when empty
//   - returnDirectly is the configured map overlaid with contributor entries
//   - toolInfos are the base tool infos followed by contributor tool infos
//
// Base tool infos come from exactly one source to avoid duplicates:
// ToolsConfig.Tools when it is non-empty (NewReActAgent copies config.Tools
// into it when ToolsConfig was nil), otherwise config.Tools.
//
// The error result is always nil in the current implementation.
func (a *ReActAgent[M]) prepareExecContext(ctx context.Context) (*execContext, error) {
	cfg := a.config
	ec := &execContext{
		instruction:        cfg.Instruction,
		emitInternalEvents: cfg.EmitInternalEvents,
	}
	if ec.instruction == "" {
		ec.instruction = internal.DefaultSystemPrompt
	}

	// Collect from ToolContributor middlewares.
	ec.contribTools = collectContributorTools(ctx, cfg.Middlewares)
	ec.contribToolInfos = collectContributorToolInfos(ctx, cfg.Middlewares)
	ec.contribReturnDirectly = collectContributorReturnDirectly(ctx, cfg.Middlewares)

	// Overlay order matters: contributor entries override configured ones.
	ec.returnDirectly = make(map[string]bool, len(cfg.ReturnDirectly)+len(ec.contribReturnDirectly))
	for _, layer := range []map[string]bool{cfg.ReturnDirectly, ec.contribReturnDirectly} {
		for toolName, direct := range layer {
			ec.returnDirectly[toolName] = direct
		}
	}

	baseTools := cfg.Tools
	if tc := cfg.ToolsConfig; tc != nil && len(tc.Tools) > 0 {
		baseTools = tc.Tools
	}
	baseInfos := toolsToInfosTyped[M](baseTools)

	ec.toolInfos = make([]*schema.ToolInfo, 0, len(baseInfos)+len(ec.contribToolInfos))
	ec.toolInfos = append(ec.toolInfos, baseInfos...)
	ec.toolInfos = append(ec.toolInfos, ec.contribToolInfos...)

	return ec, nil
}
// ---- No-tools run function ----
// buildNoToolsRunFunc returns the run function used when the agent has no
// tools: a single model call surrounded by the middleware hooks.
//
// Sequence: BeforeAgent -> build state -> BeforeModelRewrite -> StateModifier
// -> model.Generate -> emit response -> AfterModelRewrite -> store output under
// OutputKey -> AfterAgent. Any failing step emits an error event and stops the
// run; hook errors are emitted by the run* helpers themselves.
//
// A nil input, or a nil state handed back by BeforeModelRewrite or
// StateModifier, is reported as an explicit error event rather than being left
// to surface as a recovered nil-pointer panic.
func (a *ReActAgent[M]) buildNoToolsRunFunc() typedRunFunc[M] {
	return func(ctx context.Context, p *typedRunParams[M]) {
		if p.input == nil {
			p.generator.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("react agent: nil input")})
			return
		}

		// Build allTools: config.Tools + contribTools.
		allTools := make([]Tool, 0, len(a.config.Tools)+len(a.exeCtx.contribTools))
		allTools = append(allTools, a.config.Tools...)
		allTools = append(allTools, a.exeCtx.contribTools...)

		// BeforeAgent middleware.
		rc := &ReActAgentContext{Instruction: a.exeCtx.instruction, Tools: allTools, ReturnDirectly: a.exeCtx.returnDirectly}
		if err := a.runBeforeAgent(&ctx, rc, p.generator); err != nil {
			return
		}

		model := BuildModelWrapperChain(a.config.Model, nil, a.config, a.exeCtx.toolInfos)
		state := NewReActAgentState(p.input.Messages, a.exeCtx.toolInfos, a.config.MaxIterations)

		// BeforeModelRewrite middleware.
		mc := &TypedModelContext[M]{Tools: state.ToolInfos, ModelRetryConfig: a.config.RetryConfig, ModelFailoverConfig: a.config.FailoverConfig}
		if err := a.runBeforeModelRewrite(&ctx, &state, mc, p.generator); err != nil {
			return
		}
		if a.config.StateModifier != nil && state != nil {
			var err error
			state, err = a.config.StateModifier(ctx, state)
			if err != nil {
				p.generator.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("StateModifier: %w", err)})
				return
			}
		}
		if state == nil {
			p.generator.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("react agent: state is nil after BeforeModelRewrite/StateModifier")})
			return
		}

		// rc.Instruction (not exeCtx.instruction) so that in-place edits made by
		// BeforeAgent middlewares are honored.
		modelMsgs := buildModelInputFromState[M](state.Messages, rc.Instruction)
		resp, err := model.Generate(ctx, modelMsgs)
		if err != nil {
			p.generator.Send(&TypedAgentEvent[M]{Err: err})
			return
		}
		p.generator.Send(typedModelOutputEvent(resp, nil))
		state.Messages = append(state.Messages, resp)

		// AfterModelRewrite middleware.
		if err := a.runAfterModelRewrite(&ctx, &state, mc, p.generator); err != nil {
			return
		}
		if a.config.OutputKey != "" && !isNilMessage(resp) {
			setOutputToSession(ctx, resp, a.config.OutputKey)
		}

		// AfterAgent middleware.
		a.runAfterAgent(&ctx, state, p.generator)
	}
}
// runBeforeAgent executes the BeforeAgent middleware chain in slice order,
// skipping nil middlewares.
//
// Each middleware receives the context and agent context produced by the
// previous one. *ctx is updated with every returned context. If a middleware
// returns a different *ReActAgentContext, its contents are copied into *rc at
// the end, so the caller (which only holds rc) observes the replacement; a nil
// returned agent context keeps the previous one.
//
// On the first middleware error, an event wrapping it as "BeforeAgent: ..." is
// sent on gen and the original error is returned. Returns nil when every
// middleware succeeds.
func (a *ReActAgent[M]) runBeforeAgent(ctx *context.Context, rc *ReActAgentContext, gen *AsyncGenerator[*TypedAgentEvent[M]]) error {
	cur := rc
	for _, mw := range a.config.Middlewares {
		if mw == nil {
			continue
		}
		nextCtx, nextRC, err := mw.BeforeAgent(*ctx, cur)
		*ctx = nextCtx
		if err != nil {
			gen.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("BeforeAgent: %w", err)})
			return err
		}
		if nextRC != nil {
			cur = nextRC
		}
	}
	// Propagate a replaced agent context through the caller's pointer; the
	// signature cannot hand back a new pointer.
	if rc != nil && cur != rc {
		*rc = *cur
	}
	return nil
}
// runBeforeModelRewrite runs every non-nil middleware's BeforeModelRewrite hook
// in slice order. *ctx and *state are overwritten with whatever each hook
// returns, including on failure. The first hook error is emitted on gen
// wrapped as "BeforeModelRewrite: ..." and returned unwrapped; otherwise the
// result is nil.
func (a *ReActAgent[M]) runBeforeModelRewrite(ctx *context.Context, state **TypedReActAgentState[M], mc *TypedModelContext[M], gen *AsyncGenerator[*TypedAgentEvent[M]]) error {
	chain := a.config.Middlewares
	for i := range chain {
		hook := chain[i]
		if hook == nil {
			continue
		}
		nextCtx, nextState, hookErr := hook.BeforeModelRewrite(*ctx, *state, mc)
		*ctx, *state = nextCtx, nextState
		if hookErr == nil {
			continue
		}
		gen.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("BeforeModelRewrite: %w", hookErr)})
		return hookErr
	}
	return nil
}
// runAfterModelRewrite runs every non-nil middleware's AfterModelRewrite hook
// in slice order. *ctx and *state are overwritten with whatever each hook
// returns, including on failure. The first hook error is emitted on gen
// wrapped as "AfterModelRewrite: ..." and returned unwrapped; otherwise the
// result is nil.
func (a *ReActAgent[M]) runAfterModelRewrite(ctx *context.Context, state **TypedReActAgentState[M], mc *TypedModelContext[M], gen *AsyncGenerator[*TypedAgentEvent[M]]) error {
	chain := a.config.Middlewares
	for i := range chain {
		hook := chain[i]
		if hook == nil {
			continue
		}
		nextCtx, nextState, hookErr := hook.AfterModelRewrite(*ctx, *state, mc)
		*ctx, *state = nextCtx, nextState
		if hookErr == nil {
			continue
		}
		gen.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("AfterModelRewrite: %w", hookErr)})
		return hookErr
	}
	return nil
}
// runAfterAgent runs every non-nil middleware's AfterAgent hook in slice
// order, overwriting *ctx with each returned context. It stops at the first
// hook error, which is emitted on gen wrapped as "AfterAgent: ..."; nothing is
// returned to the caller.
func (a *ReActAgent[M]) runAfterAgent(ctx *context.Context, state *TypedReActAgentState[M], gen *AsyncGenerator[*TypedAgentEvent[M]]) {
	chain := a.config.Middlewares
	for i := range chain {
		hook := chain[i]
		if hook == nil {
			continue
		}
		nextCtx, hookErr := hook.AfterAgent(*ctx, state)
		*ctx = nextCtx
		if hookErr == nil {
			continue
		}
		gen.Send(&TypedAgentEvent[M]{Err: fmt.Errorf("AfterAgent: %w", hookErr)})
		return
	}
}
// ---- Helpers ----
// buildModelInputFromState returns the messages to send to the model: a system
// message built from instruction (when non-empty) followed by messages in
// order. The result is a new slice, or nil when there is nothing to send.
//
// NOTE(review): the system message is created with schema.SystemMessage and
// converted to M with an unchecked type assertion, which panics when M is not
// the type SystemMessage returns (e.g. an agent typed on
// *schema.AgenticMessage with a non-empty instruction).
func buildModelInputFromState[M MessageType](messages []M, instruction string) []M {
	var out []M
	if instruction != "" {
		system := any(schema.SystemMessage(instruction)).(M)
		out = append(out, system)
	}
	return append(out, messages...)
}
// setOutputToSession stores the text content of msg in the session attached to
// ctx under key. It does nothing when msg is nil or ctx carries no session.
func setOutputToSession[M MessageType](ctx context.Context, msg M, key string) {
	if isNilMessage(msg) {
		return
	}
	session := getSession(ctx)
	if session == nil {
		return
	}
	session.Values[key] = extractTextContent(msg)
}
// toolsToInfosTyped maps each tool to its *schema.ToolInfo, preserving order.
// A tool implementing ToolInfoProvider supplies its own info; any other tool
// gets an info holding only its name and description. The result is always a
// non-nil slice of len(tools) entries. The type parameter M is not used.
func toolsToInfosTyped[M MessageType](tools []Tool) []*schema.ToolInfo {
	infos := make([]*schema.ToolInfo, len(tools))
	for i, tool := range tools {
		provider, hasInfo := tool.(ToolInfoProvider)
		if hasInfo {
			infos[i] = provider.ToolInfo()
			continue
		}
		infos[i] = &schema.ToolInfo{
			Name:        tool.Name(),
			Description: tool.Description(),
		}
	}
	return infos
}
// extractTextContent returns the plain text carried by msg:
//   - *schema.Message: its Content field
//   - *schema.AgenticMessage: the Text of every content block of type "text",
//     joined with newlines
//   - anything else: the empty string
func extractTextContent[M MessageType](msg M) string {
	boxed := any(msg)
	if plain, ok := boxed.(*schema.Message); ok {
		return plain.Content
	}
	agentic, ok := boxed.(*schema.AgenticMessage)
	if !ok {
		return ""
	}
	var sb strings.Builder
	wroteAny := false
	for _, block := range agentic.ContentBlocks {
		if block.Type != "text" {
			continue
		}
		if wroteAny {
			sb.WriteByte('\n')
		}
		sb.WriteString(block.Text)
		wroteAny = true
	}
	return sb.String()
}
// findTool returns the first tool in tools whose Name equals name, or nil when
// no tool matches.
func findTool(tools []Tool, name string) Tool {
	for i := range tools {
		if candidate := tools[i]; candidate.Name() == name {
			return candidate
		}
	}
	return nil
}
// extractToolCalls returns the tool calls requested by a model response.
//
// For *schema.Message it returns the ToolCalls field (nil when empty). For
// *schema.AgenticMessage it converts every "tool_use" content block that has a
// tool call with a non-empty ID and Name; blocks failing that check are
// ignored. Any other message type yields nil.
func extractToolCalls[M MessageType](resp M) []schema.ToolCall {
	switch msg := any(resp).(type) {
	case *schema.AgenticMessage:
		var calls []schema.ToolCall
		for _, block := range msg.ContentBlocks {
			if block.Type != "tool_use" || block.ToolCall == nil {
				continue
			}
			if block.ToolCall.ID == "" || block.ToolCall.Name == "" {
				continue
			}
			calls = append(calls, schema.ToolCall{
				ID: block.ToolCall.ID,
				Function: schema.ToolCallFunction{
					Name:      block.ToolCall.Name,
					Arguments: block.ToolCall.Arguments,
				},
			})
		}
		return calls
	case *schema.Message:
		if len(msg.ToolCalls) == 0 {
			return nil
		}
		return msg.ToolCalls
	}
	return nil
}
// streamWithCancel wraps a streaming model call with cancel detection.
//
// With a nil cc the source stream s is returned as is. If cc.immediateChan is
// already signalled, s is closed and a stream carrying only ErrStreamCanceled
// is returned. Otherwise a new stream is returned that relays the chunks of s
// until s reports an error or cancellation is signalled; on cancellation the
// new stream receives ErrStreamCanceled. Both streams are closed when relaying
// stops.
//
// Two goroutines are used: a pump that blocks on s.Recv and pushes results
// into a buffered channel, and a relay that selects between that channel and
// the cancel signal so cancellation is noticed even while Recv is blocked.
func streamWithCancel[M MessageType](s *schema.StreamReader[M], cc *cancelContext) *schema.StreamReader[M] {
	if cc == nil {
		return s
	}
	select {
	case <-cc.immediateChan:
		s.Close()
		r := schema.NewStreamReader[M]()
		var zero M
		r.Send(zero, ErrStreamCanceled)
		r.Close()
		return r
	default:
	}
	r := schema.NewStreamReader[M]()
	go func() {
		defer r.Close()
		defer s.Close()
		ch := make(chan struct {
			Data M
			Err  error
		}, 64)
		// Pump: reads s and hands every result to the relay below. It closes ch
		// when it stops, either after delivering an error or on cancellation.
		go func() {
			defer close(ch)
			for {
				select {
				case <-cc.immediateChan:
					return
				default:
				}
				d, e := s.Recv()
				select {
				case <-cc.immediateChan:
					return
				default:
				}
				select {
				case ch <- struct {
					Data M
					Err  error
				}{d, e}:
				case <-cc.immediateChan:
					return
				}
				if e != nil {
					return
				}
			}
		}()
		for {
			select {
			case <-cc.immediateChan:
				var z M
				r.Send(z, ErrStreamCanceled)
				return
			case v, ok := <-ch:
				if !ok {
					// The pump only closes ch without a preceding error item
					// when it observed cancellation. Without this check the
					// zero value read from the closed channel would be relayed
					// as a bogus empty chunk.
					var z M
					r.Send(z, ErrStreamCanceled)
					return
				}
				if v.Err != nil {
					// NOTE(review): the upstream error is not forwarded; r is
					// just closed. Fine if Recv only errors at end of stream —
					// confirm, otherwise real errors are hidden from readers.
					return
				}
				r.Send(v.Data, nil)
			}
		}
	}()
	return r
}
// getChatModelExecCtx returns the *reActExecCtx stored in the run session
// under the "__exec_ctx" key. It returns nil when ctx has no run context or
// when the session holds no value of that type under the key.
func getChatModelExecCtx(ctx context.Context) *reActExecCtx {
	runCtx := getRunCtx(ctx)
	if runCtx == nil {
		return nil
	}
	// A failed assertion leaves execCtx nil, which is the "not found" result.
	execCtx, _ := runCtx.Session.Values["__exec_ctx"].(*reActExecCtx)
	return execCtx
}
// getReActExecCtx is the generic-named entry point for fetching the execution
// context from ctx. It delegates to getChatModelExecCtx; the type parameter M
// does not influence the result.
func getReActExecCtx[M MessageType](ctx context.Context) *reActExecCtx {
	execCtx := getChatModelExecCtx(ctx)
	return execCtx
}
// CheckpointDataVersion is the version of checkpoint data format for forward compatibility.
type CheckpointDataVersion int
// CheckpointDataV1 is the first (and, in this file, only) checkpoint data version.
const CheckpointDataV1 CheckpointDataVersion = 1
// preprocessCheckpointData is the hook for forward-compatible migration of
// resume data. No migration exists yet, so data is returned unchanged.
func preprocessCheckpointData(data any) any {
	migrated := data
	return migrated
}
// WithGraphReAct enables the graph-based ReAct execution engine for a ReActConfig.
// When enabled, each ReAct iteration runs as a StateGraph node with the Pregel engine,
// providing automatic checkpoint, interrupt before tool execution, and resume support.
//
// cptr may be nil, in which case no checkpointing is performed. A nil cfg is
// ignored.
//
// Usage:
//
//	cfg := DefaultReActConfig[*schema.Message]()
//	WithGraphReAct(cfg, checkpoint.NewMemorySaver()) // or nil for no checkpointing
//
// which is equivalent to setting cfg.GraphReAct and cfg.GraphReActCheckpointer
// directly.
func WithGraphReAct[M MessageType](cfg *ReActConfig[M], cptr graph.Checkpointer) {
	if cfg == nil {
		return
	}
	cfg.GraphReAct = true
	cfg.GraphReActCheckpointer = cptr
}
// WithGraphReActInterrupt sets which graph nodes to interrupt before.
// Default: ["execute_tools"]. Use this to customize interrupt behavior.
//
// The node names are copied, so later changes to a slice passed with
// names... do not alter the configuration. Calling it without names clears
// the list. A nil cfg is ignored.
func WithGraphReActInterrupt[M MessageType](cfg *ReActConfig[M], interruptBefore ...string) {
	if cfg == nil {
		return
	}
	// append onto a nil slice yields nil for zero names, a fresh copy otherwise.
	cfg.GraphReActInterruptBefore = append([]string(nil), interruptBefore...)
}
// ---- Graph-based ReAct run function ----
//
// When GraphReAct is enabled, the ReAct loop runs as a StateGraph with the
// Pregel engine. Each iteration is a superstep, providing:
// - Automatic checkpoint at every node boundary (via graph.WithCheckpointer)
// - Interrupt before tool execution (via graph.WithInterrupts)
// - Resume from checkpoint on restart (via graph.Invoke with same config)
// - Streaming events via pregel.StreamManager
//
// buildGraphReActRunFunc returns the run function for that mode. Agents whose
// message type is not *schema.Message delegate to the for-loop implementation
// (buildReActRunFunc). Otherwise a ReAct graph is built per run, invoked with
// the input messages, and the last message of the resulting state is emitted
// as the model output event; afterToolCallsHook, if set, runs last. Every
// failure is reported as an error event on the generator.
func (a *ReActAgent[M]) buildGraphReActRunFunc() typedRunFunc[M] {
	return func(ctx context.Context, p *typedRunParams[M]) {
		fail := func(err error) {
			p.generator.Send(&TypedAgentEvent[M]{Err: err})
		}

		// Graph-based ReAct currently supports *schema.Message only; any other
		// message type uses the simple for-loop.
		var probe M
		if _, graphCapable := any(probe).(*schema.Message); !graphCapable {
			a.buildReActRunFunc()(ctx, p)
			return
		}

		msgAgent, ok := any(a).(*ReActAgent[*schema.Message])
		if !ok {
			fail(fmt.Errorf("graph ReAct: agent type assertion failed"))
			return
		}

		rg, err := NewReActGraph(msgAgent, &ReActGraphConfig{
			Checkpointer:    a.config.GraphReActCheckpointer,
			InterruptBefore: a.config.GraphReActInterruptBefore,
			// Each iteration is two nodes, hence twice MaxIterations.
			RecursionLimit: a.config.MaxIterations * 2,
		}, a.exeCtx.toolInfos)
		if err != nil {
			fail(fmt.Errorf("NewReActGraph: %w", err))
			return
		}

		// Run the graph synchronously on the converted input messages.
		state, err := rg.Invoke(ctx, &AgentInput{Messages: messageSliceToAny2(p.input.Messages)}, nil)
		if err != nil {
			fail(err)
			return
		}

		// Emit the final model response as an event.
		if count := len(state.Messages); count > 0 {
			if last := state.Messages[count-1]; !isNilMessage(last) {
				p.generator.Send(any(typedModelOutputEvent(last, nil)).(*TypedAgentEvent[M]))
			}
		}

		// Run afterToolCallsHook if configured.
		if hook := p.afterToolCallsHook; hook != nil {
			if hookErr := hook(ctx); hookErr != nil {
				fail(fmt.Errorf("after_tool_calls_hook: %w", hookErr))
			}
		}
	}
}
// messageSliceToAny2 converts a []M (MessageType) to []*schema.Message for
// graph ReAct. The result has the same length as msgs and positions are
// preserved: an element that is not a *schema.Message becomes a nil entry
// rather than being removed.
func messageSliceToAny2[M MessageType](msgs []M) []*schema.Message {
	converted := make([]*schema.Message, len(msgs))
	for i := range msgs {
		// A failed assertion yields nil, which is the value wanted for
		// non-Message items.
		converted[i], _ = any(msgs[i]).(*schema.Message)
	}
	return converted
}