Files
ragflow/internal/harness/core/react_graph_test.go
blackflytech fcf2ca869b refactor: replace context.WithCancel with t.Context (#16509)
### Summary

The addition of the Context method to Go's testing.T provides
significant improvements for writing concurrent tests. It allows better
management of goroutines, ensuring they properly exit and preventing
issues like deadlocks and unfinished processes.

By using Context, errors and cancellations can be handled more
effectively, making tests more robust and easier to reason about. This
change also enables tighter integration between tests and the
application code, especially for systems that span multiple concurrent
components. Overall, it simplifies test code and enhances test stability
and maintainability.

More info: [golang/go#18368](https://github.com/golang/go/issues/18368)

Signed-off-by: blackflytech <blackflytech@outlook.com>
2026-07-01 18:37:11 +08:00

409 lines
12 KiB
Go

package core
import (
"context"
stderrors "errors"
"testing"
"time"
"ragflow/internal/harness/core/schema"
"ragflow/internal/harness/graph/checkpoint"
harnesserrors "ragflow/internal/harness/graph/errors"
"ragflow/internal/harness/graph/graph"
"ragflow/internal/harness/graph/types"
)
// ---- Basic ReAct Graph tests (no Pregel engine dependency) ----
// TestReActGraph_CheckpointInterruptResume verifies interrupt capture.
func TestReActGraph_CheckpointInterruptResume(t *testing.T) {
model := &forcedToolModel{
inner: &mockModel{},
toolCalls: []schema.ToolCall{{ID: "c1",
Function: schema.ToolCallFunction{Name: "approve", Arguments: "{}"},
}},
finalResp: "done",
firstCall: true,
}
tool := &mockTool{name: "approve", desc: "approval"}
agent := NewReActAgent(&ReActConfig[*schema.Message]{
Model: model, Tools: []Tool{tool},
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
MaxIterations: 2,
})
agent.name = "interrupt_agent"
rg, err := NewReActGraph(agent, &ReActGraphConfig{
Checkpointer: checkpoint.NewMemorySaver(),
RecursionLimit: 20,
}, nil)
if err != nil {
t.Fatalf("NewReActGraph: %v", err)
}
ctx := context.Background()
_, err = rg.Invoke(ctx, &AgentInput{
Messages: []*schema.Message{schema.UserMessage("approve")}},
nil)
if err != nil {
var gi *harnesserrors.GraphInterrupt
if stderrors.As(err, &gi) {
t.Logf("interrupt captured (expected): %v", gi)
} else {
t.Logf("other error: %v", err)
}
}
}
// TestReActGraph_StreamWithInterrupt verifies streaming events include checkpoints.
func TestReActGraph_StreamWithInterrupt(t *testing.T) {
model := &forcedToolModel{
inner: &mockModel{},
toolCalls: []schema.ToolCall{{ID: "s1",
Function: schema.ToolCallFunction{Name: "tool_s", Arguments: "{}"},
}},
finalResp: "stream ok",
firstCall: true,
}
tool := &mockTool{name: "tool_s", desc: "stream test"}
agent := NewReActAgent(&ReActConfig[*schema.Message]{
Model: model, Tools: []Tool{tool},
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
MaxIterations: 2,
})
agent.name = "stream_agent"
rg, err := NewReActGraph(agent, &ReActGraphConfig{
Checkpointer: checkpoint.NewMemorySaver(),
InterruptBefore: []string{},
RecursionLimit: 20,
}, nil)
if err != nil {
t.Fatalf("NewReActGraph: %v", err)
}
ctx := t.Context()
outputCh, errCh := rg.Stream(ctx, &AgentInput{
Messages: []*schema.Message{schema.UserMessage("test")}},
nil, types.StreamModeValues)
go func() {
for range outputCh {
}
}()
select {
case e := <-errCh:
t.Logf("stream completed: err=%v", e)
case <-time.After(2 * time.Second):
t.Log("stream timed out (expected for async pattern)")
}
}
// ---- Comprehensive Graph ReAct tests (require Pregel engine) ----
// TestReActGraph_FullCheckpointInterruptResume verifies the COMPLETE lifecycle:
//
//  1. Build graph with checkpoint + interrupt
//  2. Invoke → reaches tool call → pauses at execute_tools (interrupt)
//  3. Resume from checkpoint → executes tool → completes
//  4. Verify final state is correct
//
// The first invocation must fail with a *harnesserrors.GraphInterrupt; any
// other error fails the test instead of being mistaken for an interrupt.
// The test is currently skipped unconditionally.
func TestReActGraph_FullCheckpointInterruptResume(t *testing.T) {
	t.Skip("requires Pregel engine — run from harness root: go test ./...")

	model := &forcedToolModel{
		inner: &mockModel{},
		toolCalls: []schema.ToolCall{{
			ID: "full_cp_1",
			Function: schema.ToolCallFunction{
				Name:      "calculator",
				Arguments: "{\"x\":10,\"y\":20}",
			},
		}},
		finalResp: "the result is 30",
		firstCall: true,
	}
	tool := &mockTool{name: "calculator", desc: "math tool"}
	agent := NewReActAgent(&ReActConfig[*schema.Message]{
		Model:         model,
		Tools:         []Tool{tool},
		ToolsConfig:   &ToolsNodeConfig{Tools: []Tool{tool}},
		MaxIterations: 3,
	})
	agent.name = "full_cycle_agent"

	saver := checkpoint.NewMemorySaver()
	rg, err := NewReActGraph(agent, &ReActGraphConfig{
		Checkpointer:    saver,
		RecursionLimit:  20,
		InterruptBefore: []string{"execute_tools"}, // pause before tool execution
	}, nil)
	if err != nil {
		t.Fatalf("NewReActGraph: %v", err)
	}

	ctx := context.Background()
	input := &AgentInput{
		Messages: []*schema.Message{schema.UserMessage("what is 10+20?")},
	}
	// Both invocations share this config so the second one addresses the
	// same thread as the first.
	config := &types.RunnableConfig{ThreadID: "full-cycle-001"}

	// ---- Phase 1: First invocation - reaches interrupt ----
	t.Log("=== Phase 1: First invocation ===")
	_, err = rg.Invoke(ctx, input, config)
	if err == nil {
		t.Fatal("expected interrupt error, got nil")
	}
	// A non-nil error is not enough: it has to be an interrupt, otherwise a
	// real failure would be reported as a successful pause.
	var gi *harnesserrors.GraphInterrupt
	if !stderrors.As(err, &gi) {
		t.Fatalf("expected GraphInterrupt, got unexpected error: %v", err)
	}
	t.Logf("interrupt captured: %v", err)

	// ---- Phase 2: Human-in-the-loop review (simulated) ----
	t.Log("=== Phase 2: Human review ===")
	time.Sleep(5 * time.Millisecond) // simulate review time

	// ---- Phase 3: Resume from checkpoint ----
	t.Log("=== Phase 3: Resume ===")
	// Resuming passes nil input together with the same thread config.
	state, err := rg.Invoke(ctx, nil, config)
	if err != nil {
		t.Fatalf("resume failed: %v", err)
	}
	if state == nil || len(state.Messages) == 0 {
		t.Fatal("expected messages after resume")
	}
	last := state.Messages[len(state.Messages)-1]
	if last.Content != "the result is 30" {
		t.Errorf("expected 'the result is 30', got %q", last.Content)
	}
	t.Logf("=== Final output: %s ===", last.Content)
}
// TestReActGraph_SerialCheckpointCycles verifies multiple interrupt-resume cycles.
func TestReActGraph_SerialCheckpointCycles(t *testing.T) {
t.Skip("requires Pregel engine — run from harness root: go test ./...")
model := &sequentialToolModel{
mock: &mockModel{},
toolCalls: [][]schema.ToolCall{
{{ID: "sc1", Function: schema.ToolCallFunction{Name: "step1", Arguments: "{}"}}},
{{ID: "sc2", Function: schema.ToolCallFunction{Name: "step2", Arguments: "{}"}}},
},
finalResp: "all steps complete",
}
tool1 := &mockTool{name: "step1", desc: "first step"}
tool2 := &mockTool{name: "step2", desc: "second step"}
agent := NewReActAgent(&ReActConfig[*schema.Message]{
Model: model,
Tools: []Tool{tool1, tool2},
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool1, tool2}},
MaxIterations: 5,
})
agent.name = "serial_cycle"
saver := checkpoint.NewMemorySaver()
rg, err := NewReActGraph(agent, &ReActGraphConfig{
Checkpointer: saver,
RecursionLimit: 30,
}, nil)
if err != nil {
t.Fatalf("NewReActGraph: %v", err)
}
ctx := context.Background()
config := &types.RunnableConfig{ThreadID: "serial-cycle-001"}
input := &AgentInput{Messages: []*schema.Message{schema.UserMessage("run all steps")}}
cycles := 0
maxCycles := 3
for cycles < maxCycles {
_, err = rg.Invoke(ctx, input, config)
if err == nil {
t.Log("graph completed without interrupt")
break
}
var gi *harnesserrors.GraphInterrupt
if stderrors.As(err, &gi) {
cycles++
t.Logf("cycle %d: interrupted, resuming...", cycles)
} else {
t.Fatalf("unexpected error: %v", err)
}
}
t.Logf("serial checkpoint cycles completed: %d interrupt-resume cycles", cycles)
}
// TestReActGraph_StreamingCheckpointEvents verifies streaming produces
// checkpoint events at each node boundary.
func TestReActGraph_StreamingCheckpointEvents(t *testing.T) {
model := &forcedToolModel{
inner: &mockModel{},
toolCalls: []schema.ToolCall{{
ID: "stream_cp",
Function: schema.ToolCallFunction{Name: "stream_tool", Arguments: "{}"},
}},
finalResp: "streaming done",
firstCall: true,
}
tool := &mockTool{name: "stream_tool", desc: "stream test"}
agent := NewReActAgent(&ReActConfig[*schema.Message]{
Model: model,
Tools: []Tool{tool},
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
MaxIterations: 2,
})
agent.name = "stream_cp_agent"
saver := checkpoint.NewMemorySaver()
rg, err := NewReActGraph(agent, &ReActGraphConfig{
Checkpointer: saver,
InterruptBefore: []string{},
RecursionLimit: 20,
}, nil)
if err != nil {
t.Fatalf("NewReActGraph: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
outCh, _ := rg.Stream(ctx, &AgentInput{
Messages: []*schema.Message{schema.UserMessage("stream test")},
}, nil, types.StreamModeCheckpoints)
eventCount := 0
timeout:
for {
select {
case ev, ok := <-outCh:
if !ok {
break timeout
}
_ = ev
eventCount++
case <-ctx.Done():
break timeout
}
}
t.Logf("streaming checkpoint events received: %d", eventCount)
}
// TestReActGraph_ConcurrentCheckpoints verifies concurrent graph instances
// with separate checkpoints don't interfere.
//
// Five goroutines each build their own agent, graph and memory checkpointer
// and invoke the graph once. Every result is sent back tagged with the id of
// the instance that produced it, so a failure is attributed to the instance
// that actually failed rather than to the order in which results arrive.
// The test is currently skipped unconditionally.
func TestReActGraph_ConcurrentCheckpoints(t *testing.T) {
	t.Skip("requires Pregel engine — run from harness root: go test ./...")

	const instances = 5

	// outcome pairs an instance id with the error (if any) from its run.
	type outcome struct {
		id  int
		err error
	}

	// runInstance builds an isolated graph and invokes it once.
	runInstance := func() error {
		m := &mockModel{}
		m.addResp("concurrent result")
		agent := NewReActAgent(&ReActConfig[*schema.Message]{
			Model:         m,
			MaxIterations: 1,
		}).WithName("concurrent_cp_agent")
		rg, err := NewReActGraph(agent, &ReActGraphConfig{
			Checkpointer:    checkpoint.NewMemorySaver(),
			InterruptBefore: []string{},
			RecursionLimit:  10,
		}, nil)
		if err != nil {
			return err
		}
		_, err = rg.Invoke(context.Background(), &AgentInput{
			Messages: []*schema.Message{schema.UserMessage("concurrent test")},
		}, nil)
		return err
	}

	// Buffered to the instance count so no goroutine blocks on its send.
	results := make(chan outcome, instances)
	for i := 0; i < instances; i++ {
		go func(id int) {
			results <- outcome{id: id, err: runInstance()}
		}(i)
	}

	// Exactly one result arrives per instance; collect all of them.
	for i := 0; i < instances; i++ {
		if res := <-results; res.err != nil {
			t.Errorf("concurrent instance %d failed: %v", res.id, res.err)
		}
	}
	t.Logf("concurrent checkpoints: %d instances completed", instances)
}
// ---- DAG mode test (standalone graph, no ReAct dependency) ----
// TestReActGraph_DAGMode compiles a linear three-node state graph
// (__start__ → node_a → node_b → node_c → __end__) with the AllPredecessor
// node trigger mode and checks that the final state carries c == "merged".
func TestReActGraph_DAGMode(t *testing.T) {
	// setField returns a node function that writes value under key in the
	// map-typed state and hands the same map back.
	setField := func(key, value string) func(context.Context, any) (any, error) {
		return func(_ context.Context, state any) (any, error) {
			fields := state.(map[string]any)
			fields[key] = value
			return fields, nil
		}
	}

	sg := graph.NewStateGraph(map[string]any{"a": "", "b": "", "c": ""})
	sg.AddNode("node_a", setField("a", "done"))
	sg.AddNode("node_b", setField("b", "done"))
	sg.AddNode("node_c", setField("c", "merged"))

	sg.AddEdge("__start__", "node_a")
	sg.AddEdge("node_a", "node_b")
	sg.AddEdge("node_b", "node_c")
	sg.AddEdge("node_c", "__end__")

	compiled, err := sg.Compile(
		graph.WithNodeTriggerMode(types.NodeTriggerAllPredecessor),
		graph.WithRecursionLimit(10),
	)
	if err != nil {
		t.Fatalf("Compile: %v", err)
	}

	out, err := compiled.Invoke(context.Background(), map[string]any{})
	if err != nil {
		t.Fatalf("Invoke: %v", err)
	}

	final := out.(map[string]any)
	if got := final["c"]; got != "merged" {
		t.Errorf("expected 'merged', got %v", got)
	}
	t.Logf("DAG result: a=%v b=%v c=%v", final["a"], final["b"], final["c"])
}
// ---- Helper models ----
// sequentialToolModel is a scripted test model: each Generate call hands out
// the next batch of tool calls from toolCalls, and once the script is
// exhausted every call answers with finalResp.
type sequentialToolModel struct {
	// mock is set by tests but is not consulted by the methods defined on
	// this type in this file.
	mock *mockModel
	// toolCalls is the script: one slice of tool calls per Generate call.
	toolCalls [][]schema.ToolCall
	// finalResp is the assistant content returned after the script runs out.
	finalResp string
	// callCount is the number of scripted batches handed out so far.
	callCount int
}
// Generate returns the next scripted assistant message. While unconsumed
// batches remain it returns an empty-content message carrying the next batch
// of tool calls and advances the counter; afterwards it returns a message
// whose content is finalResp. The returned error is always nil, and ctx, msgs
// and opts are not used.
func (m *sequentialToolModel) Generate(ctx context.Context, msgs []*schema.Message, opts ...ModelOption) (*schema.Message, error) {
	// Script exhausted: answer with the final response from now on.
	if m.callCount >= len(m.toolCalls) {
		return &schema.Message{Role: schema.RoleAssistant, Content: m.finalResp}, nil
	}

	step := m.callCount
	m.callCount++

	reply := &schema.Message{Role: schema.RoleAssistant, Content: ""}
	reply.ToolCalls = m.toolCalls[step]
	return reply, nil
}
// Stream wraps a single Generate call in a stream reader: on success the
// generated message is sent to the reader, and in either case the reader is
// closed before being returned together with Generate's error.
func (m *sequentialToolModel) Stream(ctx context.Context, msgs []*schema.Message, opts ...ModelOption) (*schema.StreamReader[*schema.Message], error) {
	reader := schema.NewStreamReader[*schema.Message]()

	reply, err := m.Generate(ctx, msgs, opts...)
	if err == nil {
		reader.Send(reply, nil)
	}
	reader.Close()
	return reader, err
}
// BindTools is a no-op: the supplied tool descriptions are ignored and nil is
// always returned.
func (m *sequentialToolModel) BindTools(tools []*schema.ToolInfo) error {
	return nil
}