Files
ragflow/internal/harness/core/react_graph_test.go
Yingfeng 706e0d2d06 Refactor harness framework (#16271)
### What problem does this PR solve?

- Tools management
- Pregel engine wrapper for better usage
- UT race
- Coding style

### Type of change

- [x] Refactoring
2026-06-23 20:18:04 +08:00

410 lines
12 KiB
Go

package core
import (
"context"
stderrors "errors"
"testing"
"time"
"ragflow/internal/harness/core/schema"
"ragflow/internal/harness/graph/checkpoint"
harnesserrors "ragflow/internal/harness/graph/errors"
"ragflow/internal/harness/graph/graph"
"ragflow/internal/harness/graph/types"
)
// ---- Basic ReAct Graph tests (no Pregel engine dependency) ----
// TestReActGraph_CheckpointInterruptResume runs a ReAct graph whose model is
// forced to request the "approve" tool and reports whether the resulting
// error, if any, is a GraphInterrupt. Invocation errors are only logged —
// only a failure to build the graph fails the test.
func TestReActGraph_CheckpointInterruptResume(t *testing.T) {
	forced := &forcedToolModel{
		inner: &mockModel{},
		toolCalls: []schema.ToolCall{{
			ID:       "c1",
			Function: schema.ToolCallFunction{Name: "approve", Arguments: "{}"},
		}},
		finalResp: "done",
		firstCall: true,
	}
	approveTool := &mockTool{name: "approve", desc: "approval"}
	agent := NewReActAgent(&ReActConfig[*schema.Message]{
		Model:         forced,
		Tools:         []Tool{approveTool},
		ToolsConfig:   &ToolsNodeConfig{Tools: []Tool{approveTool}},
		MaxIterations: 2,
	})
	agent.name = "interrupt_agent"

	rg, err := NewReActGraph(agent, &ReActGraphConfig{
		Checkpointer:   checkpoint.NewMemorySaver(),
		RecursionLimit: 20,
	}, nil)
	if err != nil {
		t.Fatalf("NewReActGraph: %v", err)
	}

	input := &AgentInput{Messages: []*schema.Message{schema.UserMessage("approve")}}
	_, err = rg.Invoke(context.Background(), input, nil)
	if err == nil {
		return
	}

	var gi *harnesserrors.GraphInterrupt
	if stderrors.As(err, &gi) {
		t.Logf("interrupt captured (expected): %v", gi)
		return
	}
	t.Logf("other error: %v", err)
}
// TestReActGraph_StreamWithInterrupt verifies streaming events include checkpoints.
func TestReActGraph_StreamWithInterrupt(t *testing.T) {
model := &forcedToolModel{
inner: &mockModel{},
toolCalls: []schema.ToolCall{{ID: "s1",
Function: schema.ToolCallFunction{Name: "tool_s", Arguments: "{}"},
}},
finalResp: "stream ok",
firstCall: true,
}
tool := &mockTool{name: "tool_s", desc: "stream test"}
agent := NewReActAgent(&ReActConfig[*schema.Message]{
Model: model, Tools: []Tool{tool},
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
MaxIterations: 2,
})
agent.name = "stream_agent"
rg, err := NewReActGraph(agent, &ReActGraphConfig{
Checkpointer: checkpoint.NewMemorySaver(),
InterruptBefore: []string{},
RecursionLimit: 20,
}, nil)
if err != nil {
t.Fatalf("NewReActGraph: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
outputCh, errCh := rg.Stream(ctx, &AgentInput{
Messages: []*schema.Message{schema.UserMessage("test")}},
nil, types.StreamModeValues)
go func() {
for range outputCh {
}
}()
select {
case e := <-errCh:
t.Logf("stream completed: err=%v", e)
case <-time.After(2 * time.Second):
t.Log("stream timed out (expected for async pattern)")
}
}
// ---- Comprehensive Graph ReAct tests (require Pregel engine) ----
// TestReActGraph_FullCheckpointInterruptResume verifies the complete
// interrupt/resume lifecycle on a single thread:
//
//  1. Build a graph with a MemorySaver and InterruptBefore "execute_tools".
//  2. Invoke with the user message. The forced tool call makes the run pause
//     before execute_tools, so Invoke must return a GraphInterrupt.
//  3. Invoke again with nil input and the same ThreadID to resume from the
//     checkpoint. The tool runs and the model produces its final reply.
//  4. Check that the last message is that final reply.
func TestReActGraph_FullCheckpointInterruptResume(t *testing.T) {
	t.Skip("requires Pregel engine — run from harness root: go test ./...")
	model := &forcedToolModel{
		inner: &mockModel{},
		toolCalls: []schema.ToolCall{{
			ID: "full_cp_1",
			Function: schema.ToolCallFunction{
				Name:      "calculator",
				Arguments: "{\"x\":10,\"y\":20}",
			},
		}},
		finalResp: "the result is 30",
		firstCall: true,
	}
	tool := &mockTool{name: "calculator", desc: "math tool"}
	agent := NewReActAgent(&ReActConfig[*schema.Message]{
		Model:         model,
		Tools:         []Tool{tool},
		ToolsConfig:   &ToolsNodeConfig{Tools: []Tool{tool}},
		MaxIterations: 3,
	})
	agent.name = "full_cycle_agent"
	saver := checkpoint.NewMemorySaver()
	rg, err := NewReActGraph(agent, &ReActGraphConfig{
		Checkpointer:    saver,
		RecursionLimit:  20,
		InterruptBefore: []string{"execute_tools"}, // pause before tool execution
	}, nil)
	if err != nil {
		t.Fatalf("NewReActGraph: %v", err)
	}
	ctx := context.Background()
	input := &AgentInput{
		Messages: []*schema.Message{schema.UserMessage("what is 10+20?")},
	}
	config := &types.RunnableConfig{ThreadID: "full-cycle-001"}

	// ---- Phase 1: first invocation must stop at the interrupt ----
	t.Log("=== Phase 1: First invocation ===")
	_, err = rg.Invoke(ctx, input, config)
	if err == nil {
		t.Fatal("expected interrupt error, got nil")
	}
	// A failing node or any other error must not be mistaken for the
	// requested pause.
	var gi *harnesserrors.GraphInterrupt
	if !stderrors.As(err, &gi) {
		t.Fatalf("expected *GraphInterrupt, got %T: %v", err, err)
	}
	t.Logf("interrupt captured: %v", err)

	// ---- Phase 2: human-in-the-loop review (simulated) ----
	// No wall-clock delay is needed: the paused run is held by the
	// checkpointer under config.ThreadID until it is resumed.
	t.Log("=== Phase 2: Human review ===")

	// ---- Phase 3: resume from checkpoint (nil input, same ThreadID) ----
	t.Log("=== Phase 3: Resume ===")
	state, err := rg.Invoke(ctx, nil, config)
	if err != nil {
		t.Fatalf("resume failed: %v", err)
	}
	if state == nil || len(state.Messages) == 0 {
		t.Fatal("expected messages after resume")
	}
	last := state.Messages[len(state.Messages)-1]
	if last.Content != "the result is 30" {
		t.Errorf("expected 'the result is 30', got %q", last.Content)
	}
	t.Logf("=== Final output: %s ===", last.Content)
}
// TestReActGraph_SerialCheckpointCycles verifies repeated interrupt/resume
// cycles on a single thread. The model requests two tool steps in sequence,
// and the graph pauses before every execute_tools node. The run therefore
// has to be resumed from its checkpoint once per step before it completes.
func TestReActGraph_SerialCheckpointCycles(t *testing.T) {
	t.Skip("requires Pregel engine — run from harness root: go test ./...")
	model := &sequentialToolModel{
		mock: &mockModel{},
		toolCalls: [][]schema.ToolCall{
			{{ID: "sc1", Function: schema.ToolCallFunction{Name: "step1", Arguments: "{}"}}},
			{{ID: "sc2", Function: schema.ToolCallFunction{Name: "step2", Arguments: "{}"}}},
		},
		finalResp: "all steps complete",
	}
	tool1 := &mockTool{name: "step1", desc: "first step"}
	tool2 := &mockTool{name: "step2", desc: "second step"}
	agent := NewReActAgent(&ReActConfig[*schema.Message]{
		Model:         model,
		Tools:         []Tool{tool1, tool2},
		ToolsConfig:   &ToolsNodeConfig{Tools: []Tool{tool1, tool2}},
		MaxIterations: 5,
	})
	agent.name = "serial_cycle"
	saver := checkpoint.NewMemorySaver()
	rg, err := NewReActGraph(agent, &ReActGraphConfig{
		Checkpointer:    saver,
		RecursionLimit:  30,
		InterruptBefore: []string{"execute_tools"}, // pause before every tool step
	}, nil)
	if err != nil {
		t.Fatalf("NewReActGraph: %v", err)
	}
	ctx := context.Background()
	config := &types.RunnableConfig{ThreadID: "serial-cycle-001"}
	input := &AgentInput{Messages: []*schema.Message{schema.UserMessage("run all steps")}}

	const maxCycles = 3
	cycles := 0
	completed := false
	for attempt := 0; attempt <= maxCycles; attempt++ {
		// Only the first call carries the user message. Later calls pass nil
		// input with the same ThreadID, which is the resume convention used by
		// TestReActGraph_FullCheckpointInterruptResume. Re-sending input
		// would start over instead of resuming. A literal nil is passed (not
		// a nil *AgentInput) so the argument is a true nil whatever the
		// parameter type.
		if attempt == 0 {
			_, err = rg.Invoke(ctx, input, config)
		} else {
			_, err = rg.Invoke(ctx, nil, config)
		}
		if err == nil {
			completed = true
			break
		}
		var gi *harnesserrors.GraphInterrupt
		if !stderrors.As(err, &gi) {
			t.Fatalf("unexpected error: %v", err)
		}
		cycles++
		t.Logf("cycle %d: interrupted, resuming...", cycles)
	}
	if !completed {
		t.Fatalf("graph did not complete after %d interrupt-resume cycles", cycles)
	}
	if cycles == 0 {
		t.Error("expected at least one interrupt before completion")
	}
	t.Logf("serial checkpoint cycles completed: %d interrupt-resume cycles", cycles)
}
// TestReActGraph_StreamingCheckpointEvents streams a ReAct graph run in
// checkpoint mode and counts the events received before the output channel
// closes or the 3-second context deadline expires. The event count and any
// value available on the stream's error channel are logged; neither fails
// the test.
func TestReActGraph_StreamingCheckpointEvents(t *testing.T) {
	model := &forcedToolModel{
		inner: &mockModel{},
		toolCalls: []schema.ToolCall{{
			ID:       "stream_cp",
			Function: schema.ToolCallFunction{Name: "stream_tool", Arguments: "{}"},
		}},
		finalResp: "streaming done",
		firstCall: true,
	}
	tool := &mockTool{name: "stream_tool", desc: "stream test"}
	agent := NewReActAgent(&ReActConfig[*schema.Message]{
		Model:         model,
		Tools:         []Tool{tool},
		ToolsConfig:   &ToolsNodeConfig{Tools: []Tool{tool}},
		MaxIterations: 2,
	})
	agent.name = "stream_cp_agent"
	saver := checkpoint.NewMemorySaver()
	rg, err := NewReActGraph(agent, &ReActGraphConfig{
		Checkpointer:    saver,
		InterruptBefore: []string{},
		RecursionLimit:  20,
	}, nil)
	if err != nil {
		t.Fatalf("NewReActGraph: %v", err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	outCh, errCh := rg.Stream(ctx, &AgentInput{
		Messages: []*schema.Message{schema.UserMessage("stream test")},
	}, nil, types.StreamModeCheckpoints)

	eventCount := 0
collect:
	for {
		select {
		case _, ok := <-outCh:
			if !ok {
				break collect
			}
			eventCount++
		case <-ctx.Done():
			break collect
		}
	}

	// Report whatever the stream put on its error channel instead of
	// discarding it. The receive is non-blocking so a stream that never
	// writes to errCh cannot stall the test.
	select {
	case streamErr := <-errCh:
		t.Logf("stream error channel: %v", streamErr)
	default:
	}
	t.Logf("streaming checkpoint events received: %d", eventCount)
}
// TestReActGraph_ConcurrentCheckpoints runs several independent ReAct graphs
// in parallel, each with its own MemorySaver, and checks that every
// invocation succeeds. Each failure is reported with the id of the goroutine
// that produced it, not the order in which results arrived.
func TestReActGraph_ConcurrentCheckpoints(t *testing.T) {
	t.Skip("requires Pregel engine — run from harness root: go test ./...")
	const instances = 5

	// instanceResult pairs an outcome with the instance that produced it;
	// results arrive in completion order, not launch order.
	type instanceResult struct {
		id  int
		err error
	}
	results := make(chan instanceResult, instances)
	for i := 0; i < instances; i++ {
		go func(id int) {
			m := &mockModel{}
			m.addResp("concurrent result")
			agent := NewReActAgent(&ReActConfig[*schema.Message]{
				Model:         m,
				MaxIterations: 1,
			}).WithName("concurrent_cp_agent")
			rg, err := NewReActGraph(agent, &ReActGraphConfig{
				Checkpointer:    checkpoint.NewMemorySaver(),
				InterruptBefore: []string{},
				RecursionLimit:  10,
			}, nil)
			if err != nil {
				results <- instanceResult{id: id, err: err}
				return
			}
			ctx := context.Background()
			_, err = rg.Invoke(ctx, &AgentInput{
				Messages: []*schema.Message{schema.UserMessage("concurrent test")},
			}, nil)
			results <- instanceResult{id: id, err: err}
		}(i)
	}
	for i := 0; i < instances; i++ {
		if r := <-results; r.err != nil {
			t.Errorf("concurrent instance %d failed: %v", r.id, r.err)
		}
	}
	t.Logf("concurrent checkpoints: %d instances completed", instances)
}
// ---- DAG mode test (standalone graph, no ReAct dependency) ----
// TestReActGraph_DAGMode compiles a StateGraph with the
// NodeTriggerAllPredecessor trigger mode and checks that each node's write
// is present in the final state. The graph is a linear chain
// (node_a → node_b → node_c), so every node has exactly one predecessor;
// fan-in is not exercised here.
func TestReActGraph_DAGMode(t *testing.T) {
	sg := graph.NewStateGraph(map[string]interface{}{"a": "", "b": "", "c": ""})
	sg.AddNode("node_a", func(ctx context.Context, state interface{}) (interface{}, error) {
		s := state.(map[string]interface{})
		s["a"] = "done"
		return s, nil
	})
	sg.AddNode("node_b", func(ctx context.Context, state interface{}) (interface{}, error) {
		s := state.(map[string]interface{})
		s["b"] = "done"
		return s, nil
	})
	sg.AddNode("node_c", func(ctx context.Context, state interface{}) (interface{}, error) {
		s := state.(map[string]interface{})
		s["c"] = "merged"
		return s, nil
	})
	sg.AddEdge("__start__", "node_a")
	sg.AddEdge("node_a", "node_b")
	sg.AddEdge("node_b", "node_c")
	sg.AddEdge("node_c", "__end__")
	cg, err := sg.Compile(
		graph.WithNodeTriggerMode(types.NodeTriggerAllPredecessor),
		graph.WithRecursionLimit(10),
	)
	if err != nil {
		t.Fatalf("Compile: %v", err)
	}
	result, err := cg.Invoke(context.Background(), map[string]interface{}{})
	if err != nil {
		t.Fatalf("Invoke: %v", err)
	}
	// Checked assertion: an unexpected result type fails the test with a
	// message instead of panicking.
	m, ok := result.(map[string]interface{})
	if !ok {
		t.Fatalf("expected map[string]interface{} result, got %T", result)
	}
	// Check every node's write in a fixed order so failures are reported
	// deterministically.
	for _, want := range []struct{ key, value string }{
		{"a", "done"},
		{"b", "done"},
		{"c", "merged"},
	} {
		if m[want.key] != want.value {
			t.Errorf("state[%q]: expected %q, got %v", want.key, want.value, m[want.key])
		}
	}
	t.Logf("DAG result: a=%v b=%v c=%v", m["a"], m["b"], m["c"])
}
// ---- Helper models ----
// sequentialToolModel is a scripted model for multi-step tool tests. Each
// Generate call consumes the next batch in toolCalls and returns an
// assistant message carrying those tool calls. Once the script is exhausted,
// every further call returns finalResp as plain assistant content.
// NOTE(review): callCount is updated without synchronization, so one instance
// should not be shared across goroutines.
type sequentialToolModel struct {
	mock      *mockModel          // not read by any method of this type
	toolCalls [][]schema.ToolCall // scripted tool-call batches, one per Generate call
	finalResp string              // content returned after the script runs out
	callCount int                 // number of scripted batches already handed out
}

// Generate returns the next scripted tool-call batch, or finalResp once all
// batches have been consumed. It never returns an error.
func (s *sequentialToolModel) Generate(ctx context.Context, msgs []*schema.Message, opts ...ModelOption) (*schema.Message, error) {
	if s.callCount >= len(s.toolCalls) {
		return &schema.Message{Role: schema.RoleAssistant, Content: s.finalResp}, nil
	}
	batch := s.toolCalls[s.callCount]
	s.callCount++
	reply := &schema.Message{Role: schema.RoleAssistant, Content: ""}
	reply.ToolCalls = batch
	return reply, nil
}

// Stream wraps Generate in a single-item stream: it sends the generated
// message and closes the reader. If Generate fails, the reader is closed
// without sending anything and is returned together with the error.
func (s *sequentialToolModel) Stream(ctx context.Context, msgs []*schema.Message, opts ...ModelOption) (*schema.StreamReader[*schema.Message], error) {
	reader := schema.NewStreamReader[*schema.Message]()
	defer reader.Close()
	reply, err := s.Generate(ctx, msgs, opts...)
	if err != nil {
		return reader, err
	}
	reader.Send(reply, nil)
	return reader, nil
}

// BindTools accepts any tool list and does nothing; the script is fixed.
func (s *sequentialToolModel) BindTools(_ []*schema.ToolInfo) error { return nil }