mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve?

- Tools management
- Pregel engine wrapper for better usage
- UT race
- Coding style

### Type of change

- [x] Refactoring
410 lines
12 KiB
Go
410 lines
12 KiB
Go
package core
|
|
|
|
import (
|
|
"context"
|
|
stderrors "errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"ragflow/internal/harness/core/schema"
|
|
"ragflow/internal/harness/graph/checkpoint"
|
|
harnesserrors "ragflow/internal/harness/graph/errors"
|
|
"ragflow/internal/harness/graph/graph"
|
|
"ragflow/internal/harness/graph/types"
|
|
)
|
|
|
|
// ---- Basic ReAct Graph tests (no Pregel engine dependency) ----
|
|
|
|
// TestReActGraph_CheckpointInterruptResume verifies interrupt capture.
|
|
func TestReActGraph_CheckpointInterruptResume(t *testing.T) {
|
|
model := &forcedToolModel{
|
|
inner: &mockModel{},
|
|
toolCalls: []schema.ToolCall{{ID: "c1",
|
|
Function: schema.ToolCallFunction{Name: "approve", Arguments: "{}"},
|
|
}},
|
|
finalResp: "done",
|
|
firstCall: true,
|
|
}
|
|
tool := &mockTool{name: "approve", desc: "approval"}
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: model, Tools: []Tool{tool},
|
|
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
|
|
MaxIterations: 2,
|
|
})
|
|
agent.name = "interrupt_agent"
|
|
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: checkpoint.NewMemorySaver(),
|
|
RecursionLimit: 20,
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewReActGraph: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
_, err = rg.Invoke(ctx, &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("approve")}},
|
|
nil)
|
|
if err != nil {
|
|
var gi *harnesserrors.GraphInterrupt
|
|
if stderrors.As(err, &gi) {
|
|
t.Logf("interrupt captured (expected): %v", gi)
|
|
} else {
|
|
t.Logf("other error: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestReActGraph_StreamWithInterrupt verifies streaming events include checkpoints.
|
|
func TestReActGraph_StreamWithInterrupt(t *testing.T) {
|
|
model := &forcedToolModel{
|
|
inner: &mockModel{},
|
|
toolCalls: []schema.ToolCall{{ID: "s1",
|
|
Function: schema.ToolCallFunction{Name: "tool_s", Arguments: "{}"},
|
|
}},
|
|
finalResp: "stream ok",
|
|
firstCall: true,
|
|
}
|
|
tool := &mockTool{name: "tool_s", desc: "stream test"}
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: model, Tools: []Tool{tool},
|
|
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
|
|
MaxIterations: 2,
|
|
})
|
|
agent.name = "stream_agent"
|
|
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: checkpoint.NewMemorySaver(),
|
|
InterruptBefore: []string{},
|
|
RecursionLimit: 20,
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewReActGraph: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
outputCh, errCh := rg.Stream(ctx, &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("test")}},
|
|
nil, types.StreamModeValues)
|
|
go func() {
|
|
for range outputCh {
|
|
}
|
|
}()
|
|
select {
|
|
case e := <-errCh:
|
|
t.Logf("stream completed: err=%v", e)
|
|
case <-time.After(2 * time.Second):
|
|
t.Log("stream timed out (expected for async pattern)")
|
|
}
|
|
}
|
|
|
|
// ---- Comprehensive Graph ReAct tests (require Pregel engine) ----
|
|
|
|
// TestReActGraph_FullCheckpointInterruptResume verifies the COMPLETE lifecycle:
|
|
//
|
|
// 1. Build graph with checkpoint + interrupt
|
|
// 2. Invoke → reaches tool call → pauses at execute_tools (interrupt)
|
|
// 3. Resume from checkpoint → executes tool → completes
|
|
// 4. Verify final state is correct
|
|
func TestReActGraph_FullCheckpointInterruptResume(t *testing.T) {
|
|
t.Skip("requires Pregel engine — run from harness root: go test ./...")
|
|
|
|
model := &forcedToolModel{
|
|
inner: &mockModel{},
|
|
toolCalls: []schema.ToolCall{{
|
|
ID: "full_cp_1",
|
|
Function: schema.ToolCallFunction{
|
|
Name: "calculator",
|
|
Arguments: "{\"x\":10,\"y\":20}",
|
|
},
|
|
}},
|
|
finalResp: "the result is 30",
|
|
firstCall: true,
|
|
}
|
|
tool := &mockTool{name: "calculator", desc: "math tool"}
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: model,
|
|
Tools: []Tool{tool},
|
|
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
|
|
MaxIterations: 3,
|
|
})
|
|
agent.name = "full_cycle_agent"
|
|
|
|
saver := checkpoint.NewMemorySaver()
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: saver,
|
|
RecursionLimit: 20,
|
|
InterruptBefore: []string{"execute_tools"}, // pause before tool execution
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewReActGraph: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
input := &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("what is 10+20?")},
|
|
}
|
|
config := &types.RunnableConfig{ThreadID: "full-cycle-001"}
|
|
|
|
// ---- Phase 1: First invocation - reaches interrupt ----
|
|
t.Log("=== Phase 1: First invocation ===")
|
|
_, err = rg.Invoke(ctx, input, config)
|
|
if err == nil {
|
|
t.Fatal("expected interrupt error, got nil")
|
|
}
|
|
t.Logf("interrupt captured: %v", err)
|
|
|
|
// ---- Phase 2: Human-in-the-loop review (simulated) ----
|
|
t.Log("=== Phase 2: Human review ===")
|
|
time.Sleep(5 * time.Millisecond) // simulate review time
|
|
|
|
// ---- Phase 3: Resume from checkpoint ----
|
|
t.Log("=== Phase 3: Resume ===")
|
|
state, err := rg.Invoke(ctx, nil, config)
|
|
if err != nil {
|
|
t.Fatalf("resume failed: %v", err)
|
|
}
|
|
if state == nil || len(state.Messages) == 0 {
|
|
t.Fatal("expected messages after resume")
|
|
}
|
|
last := state.Messages[len(state.Messages)-1]
|
|
if last.Content != "the result is 30" {
|
|
t.Errorf("expected 'the result is 30', got %q", last.Content)
|
|
}
|
|
t.Logf("=== Final output: %s ===", last.Content)
|
|
}
|
|
|
|
// TestReActGraph_SerialCheckpointCycles verifies multiple interrupt-resume cycles.
|
|
func TestReActGraph_SerialCheckpointCycles(t *testing.T) {
|
|
t.Skip("requires Pregel engine — run from harness root: go test ./...")
|
|
|
|
model := &sequentialToolModel{
|
|
mock: &mockModel{},
|
|
toolCalls: [][]schema.ToolCall{
|
|
{{ID: "sc1", Function: schema.ToolCallFunction{Name: "step1", Arguments: "{}"}}},
|
|
{{ID: "sc2", Function: schema.ToolCallFunction{Name: "step2", Arguments: "{}"}}},
|
|
},
|
|
finalResp: "all steps complete",
|
|
}
|
|
tool1 := &mockTool{name: "step1", desc: "first step"}
|
|
tool2 := &mockTool{name: "step2", desc: "second step"}
|
|
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: model,
|
|
Tools: []Tool{tool1, tool2},
|
|
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool1, tool2}},
|
|
MaxIterations: 5,
|
|
})
|
|
agent.name = "serial_cycle"
|
|
|
|
saver := checkpoint.NewMemorySaver()
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: saver,
|
|
RecursionLimit: 30,
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewReActGraph: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
config := &types.RunnableConfig{ThreadID: "serial-cycle-001"}
|
|
input := &AgentInput{Messages: []*schema.Message{schema.UserMessage("run all steps")}}
|
|
|
|
cycles := 0
|
|
maxCycles := 3
|
|
for cycles < maxCycles {
|
|
_, err = rg.Invoke(ctx, input, config)
|
|
if err == nil {
|
|
t.Log("graph completed without interrupt")
|
|
break
|
|
}
|
|
var gi *harnesserrors.GraphInterrupt
|
|
if stderrors.As(err, &gi) {
|
|
cycles++
|
|
t.Logf("cycle %d: interrupted, resuming...", cycles)
|
|
} else {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
t.Logf("serial checkpoint cycles completed: %d interrupt-resume cycles", cycles)
|
|
}
|
|
|
|
// TestReActGraph_StreamingCheckpointEvents verifies streaming produces
|
|
// checkpoint events at each node boundary.
|
|
func TestReActGraph_StreamingCheckpointEvents(t *testing.T) {
|
|
model := &forcedToolModel{
|
|
inner: &mockModel{},
|
|
toolCalls: []schema.ToolCall{{
|
|
ID: "stream_cp",
|
|
Function: schema.ToolCallFunction{Name: "stream_tool", Arguments: "{}"},
|
|
}},
|
|
finalResp: "streaming done",
|
|
firstCall: true,
|
|
}
|
|
tool := &mockTool{name: "stream_tool", desc: "stream test"}
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: model,
|
|
Tools: []Tool{tool},
|
|
ToolsConfig: &ToolsNodeConfig{Tools: []Tool{tool}},
|
|
MaxIterations: 2,
|
|
})
|
|
agent.name = "stream_cp_agent"
|
|
|
|
saver := checkpoint.NewMemorySaver()
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: saver,
|
|
InterruptBefore: []string{},
|
|
RecursionLimit: 20,
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewReActGraph: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
outCh, _ := rg.Stream(ctx, &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("stream test")},
|
|
}, nil, types.StreamModeCheckpoints)
|
|
|
|
eventCount := 0
|
|
timeout:
|
|
for {
|
|
select {
|
|
case ev, ok := <-outCh:
|
|
if !ok {
|
|
break timeout
|
|
}
|
|
_ = ev
|
|
eventCount++
|
|
case <-ctx.Done():
|
|
break timeout
|
|
}
|
|
}
|
|
t.Logf("streaming checkpoint events received: %d", eventCount)
|
|
}
|
|
|
|
// TestReActGraph_ConcurrentCheckpoints verifies concurrent graph instances
|
|
// with separate checkpoints don't interfere.
|
|
func TestReActGraph_ConcurrentCheckpoints(t *testing.T) {
|
|
t.Skip("requires Pregel engine — run from harness root: go test ./...")
|
|
|
|
const instances = 5
|
|
errs := make(chan error, instances)
|
|
|
|
for i := 0; i < instances; i++ {
|
|
go func(id int) {
|
|
m := &mockModel{}
|
|
m.addResp("concurrent result")
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{
|
|
Model: m,
|
|
MaxIterations: 1,
|
|
}).WithName("concurrent_cp_agent")
|
|
|
|
rg, err := NewReActGraph(agent, &ReActGraphConfig{
|
|
Checkpointer: checkpoint.NewMemorySaver(),
|
|
InterruptBefore: []string{},
|
|
RecursionLimit: 10,
|
|
}, nil)
|
|
if err != nil {
|
|
errs <- err
|
|
return
|
|
}
|
|
|
|
ctx := context.Background()
|
|
_, err = rg.Invoke(ctx, &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("concurrent test")},
|
|
}, nil)
|
|
errs <- err
|
|
}(i)
|
|
}
|
|
|
|
for i := 0; i < instances; i++ {
|
|
if err := <-errs; err != nil {
|
|
t.Errorf("concurrent instance %d failed: %v", i, err)
|
|
}
|
|
}
|
|
t.Logf("concurrent checkpoints: %d instances completed", instances)
|
|
}
|
|
|
|
// ---- DAG mode test (standalone graph, no ReAct dependency) ----
|
|
|
|
// TestReActGraph_DAGMode verifies AllPredecessor trigger mode.
|
|
func TestReActGraph_DAGMode(t *testing.T) {
|
|
sg := graph.NewStateGraph(map[string]interface{}{"a": "", "b": "", "c": ""})
|
|
sg.AddNode("node_a", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(map[string]interface{})
|
|
s["a"] = "done"
|
|
return s, nil
|
|
})
|
|
sg.AddNode("node_b", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(map[string]interface{})
|
|
s["b"] = "done"
|
|
return s, nil
|
|
})
|
|
sg.AddNode("node_c", func(ctx context.Context, state interface{}) (interface{}, error) {
|
|
s := state.(map[string]interface{})
|
|
s["c"] = "merged"
|
|
return s, nil
|
|
})
|
|
sg.AddEdge("__start__", "node_a")
|
|
sg.AddEdge("node_a", "node_b")
|
|
sg.AddEdge("node_b", "node_c")
|
|
sg.AddEdge("node_c", "__end__")
|
|
|
|
cg, err := sg.Compile(
|
|
graph.WithNodeTriggerMode(types.NodeTriggerAllPredecessor),
|
|
graph.WithRecursionLimit(10),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Compile: %v", err)
|
|
}
|
|
|
|
result, err := cg.Invoke(context.Background(), map[string]interface{}{})
|
|
if err != nil {
|
|
t.Fatalf("Invoke: %v", err)
|
|
}
|
|
m := result.(map[string]interface{})
|
|
if m["c"] != "merged" {
|
|
t.Errorf("expected 'merged', got %v", m["c"])
|
|
}
|
|
t.Logf("DAG result: a=%v b=%v c=%v", m["a"], m["b"], m["c"])
|
|
}
|
|
|
|
// ---- Helper models ----
|
|
|
|
// sequentialToolModel returns different tool calls on each Generate call,
|
|
// simulating a multi-step tool interaction.
|
|
type sequentialToolModel struct {
|
|
mock *mockModel
|
|
toolCalls [][]schema.ToolCall
|
|
finalResp string
|
|
callCount int
|
|
}
|
|
|
|
func (m *sequentialToolModel) Generate(ctx context.Context, msgs []*schema.Message, opts ...ModelOption) (*schema.Message, error) {
|
|
if m.callCount < len(m.toolCalls) {
|
|
tcs := m.toolCalls[m.callCount]
|
|
m.callCount++
|
|
msg := &schema.Message{Role: schema.RoleAssistant, Content: ""}
|
|
msg.ToolCalls = tcs
|
|
return msg, nil
|
|
}
|
|
return &schema.Message{Role: schema.RoleAssistant, Content: m.finalResp}, nil
|
|
}
|
|
|
|
func (m *sequentialToolModel) Stream(ctx context.Context, msgs []*schema.Message, opts ...ModelOption) (*schema.StreamReader[*schema.Message], error) {
|
|
r := schema.NewStreamReader[*schema.Message]()
|
|
msg, err := m.Generate(ctx, msgs, opts...)
|
|
if err != nil {
|
|
r.Close()
|
|
return r, err
|
|
}
|
|
r.Send(msg, nil)
|
|
r.Close()
|
|
return r, nil
|
|
}
|
|
|
|
// BindTools is a no-op: the supplied tool descriptions are ignored because
// the tool calls this model emits are fixed by its script. It always
// returns nil.
func (m *sequentialToolModel) BindTools(tools []*schema.ToolInfo) error {
	return nil
}
|