mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 09:11:59 +08:00
### Summary The addition of the Context method to Go's testing.T provides significant improvements for writing concurrent tests. It allows better management of goroutines, ensuring they properly exit and preventing issues like deadlocks and unfinished processes. By using Context, errors and cancellations can be handled more effectively, making tests more robust and easier to reason about. This change also enables tighter integration between tests and the application code, especially for systems that span multiple concurrent components. Overall, it simplifies test code and enhances test stability and maintainability. More info: [golang/go#18368](https://github.com/golang/go/issues/18368) Signed-off-by: blackflytech <blackflytech@outlook.com>
409 lines
12 KiB
Go
409 lines
12 KiB
Go
package core
|
|
|
|
import (
|
|
"context"
|
|
stderrors "errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"ragflow/internal/harness/core/schema"
|
|
"ragflow/internal/harness/graph/checkpoint"
|
|
harnesserrors "ragflow/internal/harness/graph/errors"
|
|
"ragflow/internal/harness/graph/graph"
|
|
"ragflow/internal/harness/graph/types"
|
|
)
|
|
|
|
// ---- Basic ReAct Graph tests (no Pregel engine dependency) ----
|
|
|
|
// TestReActGraph_CheckpointInterruptResume verifies interrupt capture.
|
|
func TestReActGraph_CheckpointInterruptResume(t *testing.T) {
|
|
model := &forcedToolModel{
|
|
inner: &mockModel{},
|
|
toolCalls: []schema.ToolCall{{ID: "c1",
|
|
Function: schema.ToolCallFunction{Name: "approve", Arguments: "{}"},
|
|
}},
|
|
finalResp: "done",
|
|
firstCall: true,
|
|
}
|
|
tool := &mockTool{name: "approve", desc: "approval"}
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: model, Tools: []Tool{tool},
|
|
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
|
|
MaxIterations: 2,
|
|
})
|
|
agent.name = "interrupt_agent"
|
|
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: checkpoint.NewMemorySaver(),
|
|
RecursionLimit: 20,
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewReActGraph: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
_, err = rg.Invoke(ctx, &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("approve")}},
|
|
nil)
|
|
if err != nil {
|
|
var gi *harnesserrors.GraphInterrupt
|
|
if stderrors.As(err, &gi) {
|
|
t.Logf("interrupt captured (expected): %v", gi)
|
|
} else {
|
|
t.Logf("other error: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestReActGraph_StreamWithInterrupt verifies streaming events include checkpoints.
|
|
func TestReActGraph_StreamWithInterrupt(t *testing.T) {
|
|
model := &forcedToolModel{
|
|
inner: &mockModel{},
|
|
toolCalls: []schema.ToolCall{{ID: "s1",
|
|
Function: schema.ToolCallFunction{Name: "tool_s", Arguments: "{}"},
|
|
}},
|
|
finalResp: "stream ok",
|
|
firstCall: true,
|
|
}
|
|
tool := &mockTool{name: "tool_s", desc: "stream test"}
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: model, Tools: []Tool{tool},
|
|
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
|
|
MaxIterations: 2,
|
|
})
|
|
agent.name = "stream_agent"
|
|
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: checkpoint.NewMemorySaver(),
|
|
InterruptBefore: []string{},
|
|
RecursionLimit: 20,
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewReActGraph: %v", err)
|
|
}
|
|
|
|
ctx := t.Context()
|
|
outputCh, errCh := rg.Stream(ctx, &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("test")}},
|
|
nil, types.StreamModeValues)
|
|
go func() {
|
|
for range outputCh {
|
|
}
|
|
}()
|
|
select {
|
|
case e := <-errCh:
|
|
t.Logf("stream completed: err=%v", e)
|
|
case <-time.After(2 * time.Second):
|
|
t.Log("stream timed out (expected for async pattern)")
|
|
}
|
|
}
|
|
|
|
// ---- Comprehensive Graph ReAct tests (most require the Pregel engine and are skipped) ----
|
|
|
|
// TestReActGraph_FullCheckpointInterruptResume verifies the COMPLETE lifecycle:
|
|
//
|
|
// 1. Build graph with checkpoint + interrupt
|
|
// 2. Invoke → reaches tool call → pauses at execute_tools (interrupt)
|
|
// 3. Resume from checkpoint → executes tool → completes
|
|
// 4. Verify final state is correct
|
|
func TestReActGraph_FullCheckpointInterruptResume(t *testing.T) {
|
|
t.Skip("requires Pregel engine — run from harness root: go test ./...")
|
|
|
|
model := &forcedToolModel{
|
|
inner: &mockModel{},
|
|
toolCalls: []schema.ToolCall{{
|
|
ID: "full_cp_1",
|
|
Function: schema.ToolCallFunction{
|
|
Name: "calculator",
|
|
Arguments: "{\"x\":10,\"y\":20}",
|
|
},
|
|
}},
|
|
finalResp: "the result is 30",
|
|
firstCall: true,
|
|
}
|
|
tool := &mockTool{name: "calculator", desc: "math tool"}
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: model,
|
|
Tools: []Tool{tool},
|
|
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
|
|
MaxIterations: 3,
|
|
})
|
|
agent.name = "full_cycle_agent"
|
|
|
|
saver := checkpoint.NewMemorySaver()
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: saver,
|
|
RecursionLimit: 20,
|
|
InterruptBefore: []string{"execute_tools"}, // pause before tool execution
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewReActGraph: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
input := &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("what is 10+20?")},
|
|
}
|
|
config := &types.RunnableConfig{ThreadID: "full-cycle-001"}
|
|
|
|
// ---- Phase 1: First invocation - reaches interrupt ----
|
|
t.Log("=== Phase 1: First invocation ===")
|
|
_, err = rg.Invoke(ctx, input, config)
|
|
if err == nil {
|
|
t.Fatal("expected interrupt error, got nil")
|
|
}
|
|
t.Logf("interrupt captured: %v", err)
|
|
|
|
// ---- Phase 2: Human-in-the-loop review (simulated) ----
|
|
t.Log("=== Phase 2: Human review ===")
|
|
time.Sleep(5 * time.Millisecond) // simulate review time
|
|
|
|
// ---- Phase 3: Resume from checkpoint ----
|
|
t.Log("=== Phase 3: Resume ===")
|
|
state, err := rg.Invoke(ctx, nil, config)
|
|
if err != nil {
|
|
t.Fatalf("resume failed: %v", err)
|
|
}
|
|
if state == nil || len(state.Messages) == 0 {
|
|
t.Fatal("expected messages after resume")
|
|
}
|
|
last := state.Messages[len(state.Messages)-1]
|
|
if last.Content != "the result is 30" {
|
|
t.Errorf("expected 'the result is 30', got %q", last.Content)
|
|
}
|
|
t.Logf("=== Final output: %s ===", last.Content)
|
|
}
|
|
|
|
// TestReActGraph_SerialCheckpointCycles verifies multiple interrupt-resume cycles.
|
|
func TestReActGraph_SerialCheckpointCycles(t *testing.T) {
|
|
t.Skip("requires Pregel engine — run from harness root: go test ./...")
|
|
|
|
model := &sequentialToolModel{
|
|
mock: &mockModel{},
|
|
toolCalls: [][]schema.ToolCall{
|
|
{{ID: "sc1", Function: schema.ToolCallFunction{Name: "step1", Arguments: "{}"}}},
|
|
{{ID: "sc2", Function: schema.ToolCallFunction{Name: "step2", Arguments: "{}"}}},
|
|
},
|
|
finalResp: "all steps complete",
|
|
}
|
|
tool1 := &mockTool{name: "step1", desc: "first step"}
|
|
tool2 := &mockTool{name: "step2", desc: "second step"}
|
|
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: model,
|
|
Tools: []Tool{tool1, tool2},
|
|
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool1, tool2}},
|
|
MaxIterations: 5,
|
|
})
|
|
agent.name = "serial_cycle"
|
|
|
|
saver := checkpoint.NewMemorySaver()
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: saver,
|
|
RecursionLimit: 30,
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewReActGraph: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
config := &types.RunnableConfig{ThreadID: "serial-cycle-001"}
|
|
input := &AgentInput{Messages: []*schema.Message{schema.UserMessage("run all steps")}}
|
|
|
|
cycles := 0
|
|
maxCycles := 3
|
|
for cycles < maxCycles {
|
|
_, err = rg.Invoke(ctx, input, config)
|
|
if err == nil {
|
|
t.Log("graph completed without interrupt")
|
|
break
|
|
}
|
|
var gi *harnesserrors.GraphInterrupt
|
|
if stderrors.As(err, &gi) {
|
|
cycles++
|
|
t.Logf("cycle %d: interrupted, resuming...", cycles)
|
|
} else {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
t.Logf("serial checkpoint cycles completed: %d interrupt-resume cycles", cycles)
|
|
}
|
|
|
|
// TestReActGraph_StreamingCheckpointEvents verifies streaming produces
|
|
// checkpoint events at each node boundary.
|
|
func TestReActGraph_StreamingCheckpointEvents(t *testing.T) {
|
|
model := &forcedToolModel{
|
|
inner: &mockModel{},
|
|
toolCalls: []schema.ToolCall{{
|
|
ID: "stream_cp",
|
|
Function: schema.ToolCallFunction{Name: "stream_tool", Arguments: "{}"},
|
|
}},
|
|
finalResp: "streaming done",
|
|
firstCall: true,
|
|
}
|
|
tool := &mockTool{name: "stream_tool", desc: "stream test"}
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: model,
|
|
Tools: []Tool{tool},
|
|
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
|
|
MaxIterations: 2,
|
|
})
|
|
agent.name = "stream_cp_agent"
|
|
|
|
saver := checkpoint.NewMemorySaver()
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: saver,
|
|
InterruptBefore: []string{},
|
|
RecursionLimit: 20,
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewReActGraph: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
outCh, _ := rg.Stream(ctx, &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("stream test")},
|
|
}, nil, types.StreamModeCheckpoints)
|
|
|
|
eventCount := 0
|
|
timeout:
|
|
for {
|
|
select {
|
|
case ev, ok := <-outCh:
|
|
if !ok {
|
|
break timeout
|
|
}
|
|
_ = ev
|
|
eventCount++
|
|
case <-ctx.Done():
|
|
break timeout
|
|
}
|
|
}
|
|
t.Logf("streaming checkpoint events received: %d", eventCount)
|
|
}
|
|
|
|
// TestReActGraph_ConcurrentCheckpoints verifies concurrent graph instances
|
|
// with separate checkpoints don't interfere.
|
|
func TestReActGraph_ConcurrentCheckpoints(t *testing.T) {
|
|
t.Skip("requires Pregel engine — run from harness root: go test ./...")
|
|
|
|
const instances = 5
|
|
errs := make(chan error, instances)
|
|
|
|
for i := 0; i < instances; i++ {
|
|
go func(id int) {
|
|
m := &mockModel{}
|
|
m.addResp("concurrent result")
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: m,
|
|
MaxIterations: 1,
|
|
}).WithName("concurrent_cp_agent")
|
|
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: checkpoint.NewMemorySaver(),
|
|
InterruptBefore: []string{},
|
|
RecursionLimit: 10,
|
|
}, nil)
|
|
if err != nil {
|
|
errs <- err
|
|
return
|
|
}
|
|
|
|
ctx := context.Background()
|
|
_, err = rg.Invoke(ctx, &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("concurrent test")},
|
|
}, nil)
|
|
errs <- err
|
|
}(i)
|
|
}
|
|
|
|
for i := 0; i < instances; i++ {
|
|
if err := <-errs; err != nil {
|
|
t.Errorf("concurrent instance %d failed: %v", i, err)
|
|
}
|
|
}
|
|
t.Logf("concurrent checkpoints: %d instances completed", instances)
|
|
}
|
|
|
|
// ---- DAG mode test (standalone graph, no ReAct dependency) ----
|
|
|
|
// TestReActGraph_DAGMode verifies AllPredecessor trigger mode.
|
|
func TestReActGraph_DAGMode(t *testing.T) {
|
|
sg := graph.NewStateGraph(map[string]interface{}{"a": "", "b": "", "c": ""})
|
|
sg.AddNode("node_a", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(map[string]interface{})
|
|
s["a"] = "done"
|
|
return s, nil
|
|
})
|
|
sg.AddNode("node_b", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(map[string]interface{})
|
|
s["b"] = "done"
|
|
return s, nil
|
|
})
|
|
sg.AddNode("node_c", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(map[string]interface{})
|
|
s["c"] = "merged"
|
|
return s, nil
|
|
})
|
|
sg.AddEdge("__start__", "node_a")
|
|
sg.AddEdge("node_a", "node_b")
|
|
sg.AddEdge("node_b", "node_c")
|
|
sg.AddEdge("node_c", "__end__")
|
|
|
|
cg, err := sg.Compile(
|
|
graph.WithNodeTriggerMode(types.NodeTriggerAllPredecessor),
|
|
graph.WithRecursionLimit(10),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Compile: %v", err)
|
|
}
|
|
|
|
result, err := cg.Invoke(context.Background(), map[string]interface{}{})
|
|
if err != nil {
|
|
t.Fatalf("Invoke: %v", err)
|
|
}
|
|
m := result.(map[string]interface{})
|
|
if m["c"] != "merged" {
|
|
t.Errorf("expected 'merged', got %v", m["c"])
|
|
}
|
|
t.Logf("DAG result: a=%v b=%v c=%v", m["a"], m["b"], m["c"])
|
|
}
|
|
|
|
// ---- Helper models ----
|
|
|
|
// sequentialToolModel returns different tool calls on each Generate call,
|
|
// simulating a multi-step tool interaction.
|
|
type sequentialToolModel struct {
|
|
mock *mockModel
|
|
toolCalls [][]schema.ToolCall
|
|
finalResp string
|
|
callCount int
|
|
}
|
|
|
|
func (m *sequentialToolModel) Generate(ctx context.Context, msgs []*schema.Message, opts ...ModelOption) (*schema.Message, error) {
|
|
if m.callCount < len(m.toolCalls) {
|
|
tcs := m.toolCalls[m.callCount]
|
|
m.callCount++
|
|
msg := &schema.Message{Role: schema.RoleAssistant, Content: ""}
|
|
msg.ToolCalls = tcs
|
|
return msg, nil
|
|
}
|
|
return &schema.Message{Role: schema.RoleAssistant, Content: m.finalResp}, nil
|
|
}
|
|
|
|
func (m *sequentialToolModel) Stream(ctx context.Context, msgs []*schema.Message, opts ...ModelOption) (*schema.StreamReader[*schema.Message], error) {
|
|
r := schema.NewStreamReader[*schema.Message]()
|
|
msg, err := m.Generate(ctx, msgs, opts...)
|
|
if err != nil {
|
|
r.Close()
|
|
return r, err
|
|
}
|
|
r.Send(msg, nil)
|
|
r.Close()
|
|
return r, nil
|
|
}
|
|
|
|
func (m *sequentialToolModel) BindTools(tools []*schema.ToolInfo) error { return nil }
|