Files
ragflow/internal/harness/core/agent_loop_test.go
Yingfeng 706e0d2d06 Refactor harness framework (#16271)
### What problem does this PR solve?

- Tools management
- Pregel engine wrapper for better usage
- UT race
- Coding style

### Type of change

- [x] Refactoring
2026-06-23 20:18:04 +08:00

1672 lines
45 KiB
Go

package core
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"ragflow/internal/harness/core/schema"
)
// ======================== Helpers ========================
type turnLoopMockAgent struct {
name string
response string
GenerateFn func(ctx context.Context, msgs []Message) (Message, error)
captureCancel bool
canceled atomic.Bool
}
func (a *turnLoopMockAgent) Name(_ context.Context) string { return a.name }
func (a *turnLoopMockAgent) Description(_ context.Context) string { return "mock agent" }
func (a *turnLoopMockAgent) Run(ctx context.Context, input *AgentInput, opts ...RunOption) *AsyncIterator[*AgentEvent] {
m := &mockModel{}
response := a.response
if response == "" {
response = "mock"
}
m.addResp(response)
agent := NewReActAgent(&ReActConfig[*schema.Message]{Model: m})
agent.name = a.name
return agent.Run(ctx, input, opts...)
}
func (a *turnLoopMockAgent) GetType() string { return "ReActAgent" }
func (a *turnLoopMockAgent) new() Agent {
return &turnLoopMockAgent{
name: a.name,
response: a.response,
GenerateFn: a.GenerateFn,
captureCancel: a.captureCancel,
}
}
type turnLoopMockRunner struct {
responses []string
idx int
}
type turnCancellableModel struct {
inner Model[*schema.Message]
cancelDetected atomic.Bool
}
func (m *turnCancellableModel) Generate(ctx context.Context, msgs []Message, opts ...modelOption) (Message, error) {
select {
case <-ctx.Done():
m.cancelDetected.Store(true)
return nil, ctx.Err()
default:
}
return m.inner.Generate(ctx, msgs, opts...)
}
func (m *turnCancellableModel) Stream(ctx context.Context, msgs []Message, opts ...modelOption) (*schema.StreamReader[Message], error) {
return m.inner.Stream(ctx, msgs, opts...)
}
func (m *turnCancellableModel) BindTools(tools []*schema.ToolInfo) error {
return m.inner.BindTools(tools)
}
func newTurnCheckpointStore() *memStore { return &memStore{data: make(map[string][]byte)} }
// simpleTurnLoop creates a minimal AgentLoop for quick tests
func simpleTurnLoop(onEvents func(context.Context, *TurnContext[string], *AsyncIterator[*AgentEvent]) error) *AgentLoop[string] {
return NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
if len(items) == 0 {
return nil, nil
}
return &GenInputResult[string]{
Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
Consumed: items[:1], Remaining: items[1:],
}, nil
},
PrepareAgent: func(ctx context.Context, loop *AgentLoop[string], consumed []string) (Agent, error) {
m := &mockModel{}
m.addResp("Echo: " + consumed[0])
return NewReActAgent(&ReActConfig[*schema.Message]{Model: m}), nil
},
OnAgentEvents: onEvents,
})
}
// newAndRunTurnLoop creates and runs a AgentLoop in one call.
func newAndRunTurnLoop[T any](ctx context.Context, cfg AgentLoopConfig[T]) *AgentLoop[T] {
l := NewAgentLoop(cfg)
l.Run(ctx)
return l
}
// genInputConsumeAll consumes all items at once.
func genInputConsumeAll(_ context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
if len(items) == 0 {
return nil, nil
}
return &GenInputResult[string]{Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, Remaining: nil}, nil
}
// genInputConsumeFirst consumes the first item, leaves rest for later.
func genInputConsumeFirst(_ context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
if len(items) == 0 {
return nil, nil
}
return &GenInputResult[string]{
Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
Consumed: []string{items[0]},
Remaining: items[1:],
}, nil
}
// genInputConsumeAllWithMsg consumes all items and produces a user message.
func genInputConsumeAllWithMsg(_ context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
if len(items) == 0 {
return nil, nil
}
return &GenInputResult[string]{
Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
Consumed: items,
}, nil
}
// prepareTestAgent returns a default mock agent.
var prepareTestAgent = func(_ context.Context, _ *AgentLoop[string], _ []string) (Agent, error) {
m := &mockModel{}
m.addResp("test")
agent := NewReActAgent(&ReActConfig[*schema.Message]{Model: m})
agent.name = "test"
return agent, nil
}
func prepareAgent(a Agent) func(context.Context, *AgentLoop[string], []string) (Agent, error) {
return func(_ context.Context, _ *AgentLoop[string], _ []string) (Agent, error) {
return a, nil
}
}
func waitOrFail(t *testing.T, ch <-chan struct{}, msg string) {
t.Helper()
select {
case <-ch:
case <-time.After(2 * time.Second):
t.Fatal(msg)
}
}
func newTestStore() *memStore {
return &memStore{data: make(map[string][]byte)}
}
// turnLoopCancellableMockAgent is a mock Agent that supports cancel observation.
type turnLoopCancellableMockAgent struct {
name string
runFunc func(ctx context.Context, input *AgentInput) (*AgentOutput, error)
onCancel func(cc *cancelContext)
cancel context.CancelFunc
mu sync.Mutex
}
func (a *turnLoopCancellableMockAgent) Name(_ context.Context) string { return a.name }
func (a *turnLoopCancellableMockAgent) Description(_ context.Context) string { return "mock agent" }
func (a *turnLoopCancellableMockAgent) Run(ctx context.Context, input *AgentInput, opts ...RunOption) *AsyncIterator[*AgentEvent] {
iter, gen := NewAsyncIteratorPair[*AgentEvent]()
o := getCommonOptions(nil, opts...)
cc := o.cancelCtx
a.mu.Lock()
var cancelCtx context.Context
cancelCtx, a.cancel = context.WithCancel(ctx)
a.mu.Unlock()
go func() {
defer gen.Close()
if cc != nil {
go func() {
<-cc.cancelChan
if a.onCancel != nil {
a.onCancel(cc)
}
a.mu.Lock()
if a.cancel != nil {
a.cancel()
}
a.mu.Unlock()
}()
}
output, err := a.runFunc(cancelCtx, input)
if err != nil {
gen.Send(&AgentEvent{Err: err})
return
}
gen.Send(&AgentEvent{Output: output})
}()
return iter
}
// turnLoopStopModeProbeAgent allows inspecting cancel mode.
type turnLoopStopModeProbeAgent struct {
ccCh chan *cancelContext
}
func (a *turnLoopStopModeProbeAgent) Name(_ context.Context) string { return "probe" }
func (a *turnLoopStopModeProbeAgent) Description(_ context.Context) string { return "probe" }
func (a *turnLoopStopModeProbeAgent) Run(_ context.Context, _ *AgentInput, opts ...RunOption) *AsyncIterator[*AgentEvent] {
iter, gen := NewAsyncIteratorPair[*AgentEvent]()
o := getCommonOptions(nil, opts...)
cc := o.cancelCtx
a.ccCh <- cc
go func() {
defer gen.Close()
<-cc.cancelChan
for {
if cc.getMode() == CancelImmediate {
gen.Send(&AgentEvent{Err: cc.createError()})
return
}
time.Sleep(1 * time.Millisecond)
}
}()
return iter
}
// turnLoopInterruptAgent is an agent that produces a business interrupt.
type turnLoopInterruptAgent struct {
interruptInfo any
}
func (a *turnLoopInterruptAgent) Name(_ context.Context) string { return "InterruptAgent" }
func (a *turnLoopInterruptAgent) Description(_ context.Context) string {
return "agent that interrupts"
}
func (a *turnLoopInterruptAgent) Run(ctx context.Context, _ *AgentInput, _ ...RunOption) *AsyncIterator[*AgentEvent] {
iter, gen := NewAsyncIteratorPair[*AgentEvent]()
go func() {
defer gen.Close()
gen.Send(Interrupt(ctx, a.interruptInfo))
}()
return iter
}
func containsString(s, substr string) bool {
return len(s) >= len(substr) && s[:len(substr)] == substr
}
// ======================== NewAgentLoop & Panic Tests ========================
func TestTurnLoop_NewPanicsWithNilGenInput(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic")
}
}()
NewAgentLoop[string](AgentLoopConfig[string]{PrepareAgent: func(_ context.Context, _ *AgentLoop[string], _ []string) (Agent, error) { return nil, nil }})
}
func TestTurnLoop_NewPanicsWithNilPrepareAgent(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic")
}
}()
NewAgentLoop[string](AgentLoopConfig[string]{GenInput: func(_ context.Context, _ *AgentLoop[string], _ []string) (*GenInputResult[string], error) {
return nil, nil
}})
}
// ======================== Basic Push-Stop-Run ========================
func TestTurnLoop_PushRunAndWait(t *testing.T) {
tl := simpleTurnLoop(nil)
tl.Push("a")
tl.Push("b")
ctx := context.Background()
tl.Stop()
tl.Run(ctx)
result := tl.Wait()
if result == nil {
t.Fatal("nil result")
}
t.Logf("basic: unhandled=%d", len(result.UnhandledItems))
}
func TestTurnLoop_StopCause(t *testing.T) {
tl := simpleTurnLoop(nil)
tl.Push("x")
tl.Stop(WithStopCause("max_tokens"))
tl.Run(context.Background())
result := tl.Wait()
if result.StopCause != "max_tokens" {
t.Errorf("StopCause = %q", result.StopCause)
}
}
func TestTurnLoop_OnAgentEventsCalled(t *testing.T) {
var called atomic.Bool
tl := simpleTurnLoop(func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
called.Store(true)
for {
ev, ok := events.Next()
if !ok {
break
}
_ = ev
}
return nil
})
ctx := context.Background()
tl.Run(ctx)
tl.Push("ev")
time.Sleep(50 * time.Millisecond)
tl.Stop()
tl.Wait()
if !called.Load() {
t.Error("OnAgentEvents not called")
}
}
func TestTurnLoop_OnAgentEventsReturnsError(t *testing.T) {
tl := simpleTurnLoop(func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
return errors.New("custom_events_error")
})
ctx := context.Background()
tl.Run(ctx)
tl.Push("fail")
time.Sleep(50 * time.Millisecond)
tl.Stop()
result := tl.Wait()
if result.ExitReason == nil || !containsString(result.ExitReason.Error(), "custom_events_error") {
t.Errorf("expected custom_events_error, got %v", result.ExitReason)
}
}
// ======================== GenInput / PrepareAgent Errors ========================
func TestTurnLoop_GenInputErrors(t *testing.T) {
tl := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
if len(items) == 0 {
return nil, nil
}
return nil, errors.New("gen_input_err")
},
PrepareAgent: func(ctx context.Context, loop *AgentLoop[string], consumed []string) (Agent, error) {
return nil, nil
},
})
tl.Push("bad")
tl.Stop()
tl.Run(context.Background())
result := tl.Wait()
if result.ExitReason == nil {
t.Log("no exit error (may not reach GenInput before stop)")
}
}
func TestTurnLoop_PrepareAgentErrors(t *testing.T) {
tl := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
return &GenInputResult[string]{Consumed: items, Remaining: nil}, nil
},
PrepareAgent: func(ctx context.Context, loop *AgentLoop[string], consumed []string) (Agent, error) {
return nil, errors.New("prepare_err")
},
})
tl.Push("bad")
tl.Stop()
tl.Run(context.Background())
result := tl.Wait()
if result.ExitReason == nil {
t.Log("no exit error (may not reach PrepareAgent)")
}
}
// ======================== Multiple Items ========================
func TestTurnLoop_MultipleItems(t *testing.T) {
tl := simpleTurnLoop(nil)
for i := 0; i < 10; i++ {
tl.Push(fmt.Sprintf("item-%d", i))
}
tl.Stop()
tl.Run(context.Background())
result := tl.Wait()
t.Logf("10 items: unhandled=%d interrupted=%d", len(result.UnhandledItems), len(result.InterruptedItems))
}
func TestTurnLoop_ConcurrentPush(t *testing.T) {
tl := simpleTurnLoop(nil)
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() { defer wg.Done(); tl.Push("c") }()
}
wg.Wait()
tl.Stop()
tl.Run(context.Background())
result := tl.Wait()
t.Logf("50 concurrent: unhandled=%d", len(result.UnhandledItems))
}
// ======================== Checkpoint ========================
func TestTurnLoop_WithCheckpoint(t *testing.T) {
store := newTurnCheckpointStore()
tl := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
if len(items) == 0 {
return nil, nil
}
return &GenInputResult[string]{
Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
Consumed: items[:1], Remaining: items[1:],
}, nil
},
PrepareAgent: func(ctx context.Context, loop *AgentLoop[string], consumed []string) (Agent, error) {
m := &mockModel{}
m.addResp("cp:" + consumed[0])
return NewReActAgent(&ReActConfig[*schema.Message]{Model: m}), nil
},
Store: store,
})
tl.Push("cp1")
tl.Stop()
tl.Run(context.Background())
result := tl.Wait()
t.Logf("checkpoint: unhandled=%d", len(result.UnhandledItems))
}
// ======================== Stop Mode Tests ========================
func TestTurnLoop_ImmediateStop(t *testing.T) {
tl := simpleTurnLoop(nil)
tl.Push("urgent")
tl.Run(context.Background())
tl.Stop(WithImmediateStop(), WithSkipCheckpoint())
result := tl.Wait()
t.Logf("immediate: err=%v", result.ExitReason)
}
func TestTurnLoop_StopWithNoItems(t *testing.T) {
tl := simpleTurnLoop(nil)
tl.Stop(WithStopCause("empty"))
tl.Run(context.Background())
result := tl.Wait()
if result.StopCause != "empty" {
t.Errorf("StopCause = %q", result.StopCause)
}
}
func TestTurnLoop_StopMultipleTimes(t *testing.T) {
tl := simpleTurnLoop(nil)
tl.Push("x")
tl.Stop(WithStopCause("first"))
tl.Stop(WithStopCause("second"))
tl.Run(context.Background())
result := tl.Wait()
_ = result
}
// ======================== Context Cancel ========================
func TestTurnLoop_ContextCancel(t *testing.T) {
tl := simpleTurnLoop(nil)
tl.Push("task")
ctx, cancel := context.WithCancel(context.Background())
cancel()
tl.Stop()
tl.Run(ctx)
result := tl.Wait()
t.Logf("ctx cancel: err=%v", result.ExitReason)
}
// ======================== Items State ========================
func TestTurnLoop_PushAfterStop(t *testing.T) {
tl := simpleTurnLoop(nil)
tl.Push("a")
tl.Push("b")
tl.Stop()
tl.Push("c")
tl.Run(context.Background())
tl.Wait()
}
// ======================== AgentLoop with Tools ========================
func TestTurnLoop_WithToolAgent(t *testing.T) {
tool := &mockTool{name: "calc", desc: "calculator"}
tl := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
if len(items) == 0 {
return nil, nil
}
return &GenInputResult[string]{
Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
Consumed: items[:1], Remaining: items[1:],
}, nil
},
PrepareAgent: func(ctx context.Context, loop *AgentLoop[string], consumed []string) (Agent, error) {
wrapperModel := &forcedToolModel{
toolCalls: []schema.ToolCall{{ID: "c1", Function: schema.ToolCallFunction{Name: "calc", Arguments: "{}"}}},
finalResp: "Tool done", firstCall: true,
}
return NewReActAgent(&ReActConfig[*schema.Message]{
Model: wrapperModel, Tools: []Tool{tool},
}), nil
},
})
tl.Push("use tool")
tl.Stop()
ctx := context.Background()
tl.Run(ctx)
result := tl.Wait()
t.Logf("tool agent: unhandled=%d", len(result.UnhandledItems))
}
// ======================== GenInput variants ========================
func TestTurnLoop_GenInputAllConsumed(t *testing.T) {
tl := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
if len(items) == 0 {
return nil, nil
}
return &GenInputResult[string]{Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}}, Consumed: items, Remaining: nil}, nil
},
PrepareAgent: func(ctx context.Context, loop *AgentLoop[string], consumed []string) (Agent, error) {
m := &mockModel{}
m.addResp("all")
return NewReActAgent(&ReActConfig[*schema.Message]{Model: m}), nil
},
})
tl.Push("1")
tl.Push("2")
tl.Stop()
tl.Run(context.Background())
tl.Wait()
}
func TestTurnLoop_GenInputOneByOne(t *testing.T) {
tl := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
if len(items) == 0 {
return nil, nil
}
return &GenInputResult[string]{
Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
Consumed: items[:1], Remaining: items[1:],
}, nil
},
PrepareAgent: func(ctx context.Context, loop *AgentLoop[string], consumed []string) (Agent, error) {
m := &mockModel{}
m.addResp("one:" + consumed[0])
return NewReActAgent(&ReActConfig[*schema.Message]{Model: m}), nil
},
})
tl.Push("x")
tl.Push("y")
tl.Push("z")
tl.Stop()
tl.Run(context.Background())
result := tl.Wait()
t.Logf("stream: unhandled=%d", len(result.UnhandledItems))
}
func TestTurnLoop_GenInputConsumedNone(t *testing.T) {
tl := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
return &GenInputResult[string]{Consumed: nil, Remaining: items}, nil
},
PrepareAgent: func(ctx context.Context, loop *AgentLoop[string], consumed []string) (Agent, error) {
return nil, nil
},
})
tl.Push("x")
tl.Stop()
tl.Run(context.Background())
result := tl.Wait()
t.Logf("none consumed: unhandled=%d", len(result.UnhandledItems))
}
// ======================== OnStop / Intercepted Items ========================
func TestTurnLoop_InterceptedItems(t *testing.T) {
tl := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
if len(items) == 0 {
return nil, nil
}
return &GenInputResult[string]{
Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
Consumed: items[:1], Remaining: items[1:],
}, nil
},
PrepareAgent: func(ctx context.Context, loop *AgentLoop[string], consumed []string) (Agent, error) {
m := &mockModel{}
m.addResp("intercepted")
return NewReActAgent(&ReActConfig[*schema.Message]{Model: m}), nil
},
})
tl.Push("a")
tl.Run(context.Background())
tl.Stop(WithImmediateStop(), WithSkipCheckpoint())
result := tl.Wait()
_ = result
}
// ======================== Edge Cases ========================
func TestTurnLoop_NoPushBeforeRun(t *testing.T) {
tl := simpleTurnLoop(nil)
tl.Stop()
tl.Run(context.Background())
result := tl.Wait()
if result == nil {
t.Fatal("nil result")
}
}
func TestTurnLoop_DoubleRunPanics(t *testing.T) {
tl := simpleTurnLoop(nil)
tl.Push("x")
tl.Stop()
tl.Run(context.Background())
tl.Run(context.Background()) // should be no-op
tl.Wait()
}
func TestTurnLoop_RunThenStopThenWait(t *testing.T) {
tl := simpleTurnLoop(nil)
tl.Push("x")
ctx := context.Background()
tl.Run(ctx)
tl.Stop()
result := tl.Wait()
if result == nil {
t.Fatal("nil result")
}
}
// ======================== edge-case tests ========================
// TestTurnLoop_StopIsIdempotent verifies multiple Stop() calls are safe.
func TestTurnLoop_StopIsIdempotent(t *testing.T) {
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: genInputConsumeAll,
PrepareAgent: prepareTestAgent,
})
loop.Stop()
loop.Stop()
loop.Stop()
result := loop.Wait()
if result.ExitReason != nil {
t.Errorf("expected nil exit reason, got %v", result.ExitReason)
}
}
// TestTurnLoop_WaitMultipleGoroutines verifies Wait() is safe for concurrent callers.
func TestTurnLoop_WaitMultipleGoroutines(t *testing.T) {
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: genInputConsumeAll,
PrepareAgent: prepareTestAgent,
})
loop.Stop()
var wg sync.WaitGroup
results := make([]*AgentLoopState[string], 3)
for i := 0; i < 3; i++ {
i := i
wg.Add(1)
go func() {
defer wg.Done()
results[i] = loop.Wait()
}()
}
wg.Wait()
// All should return the same pointer
for i := 1; i < 3; i++ {
if results[0] != results[i] {
t.Errorf("Wait returned different results for goroutines")
}
}
}
// TestTurnLoop_GetAgentError verifies PrepareAgent errors propagate.
func TestTurnLoop_GetAgentError(t *testing.T) {
agentErr := errors.New("get agent error")
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: genInputConsumeAll,
PrepareAgent: func(ctx context.Context, _ *AgentLoop[string], consumed []string) (Agent, error) {
return nil, agentErr
},
})
loop.Push("msg1")
result := loop.Wait()
if !errors.Is(result.ExitReason, agentErr) {
t.Errorf("expected agentErr, got %v", result.ExitReason)
}
}
// TestTurnLoop_BatchProcessing verifies GenInput receives batched items.
func TestTurnLoop_BatchProcessing(t *testing.T) {
var batches [][]string
var mu sync.Mutex
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: func(ctx context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
mu.Lock()
batches = append(batches, items)
mu.Unlock()
return &GenInputResult[string]{
Input: &AgentInput{},
Consumed: items[:1],
Remaining: items[1:],
}, nil
},
PrepareAgent: prepareTestAgent,
})
loop.Push("msg1")
loop.Push("msg2")
loop.Push("msg3")
time.Sleep(200 * time.Millisecond)
loop.Stop()
loop.Wait()
mu.Lock()
defer mu.Unlock()
if len(batches) == 0 {
t.Error("should have processed at least one batch")
}
}
// TestTurnLoop_StopWithMode verifies Stop with WithGracefulStop works.
func TestTurnLoop_StopWithMode(t *testing.T) {
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: genInputConsumeAll,
PrepareAgent: prepareTestAgent,
})
loop.Stop(WithGracefulStop())
result := loop.Wait()
if result.ExitReason != nil {
t.Errorf("expected nil, got %v", result.ExitReason)
}
}
// ======================== Context Cancel Variants ========================
func TestTurnLoop_ContextDeadlineExceeded(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
loop := newAndRunTurnLoop(ctx, AgentLoopConfig[string]{
GenInput: func(ctx context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
select {
case <-time.After(100 * time.Millisecond):
return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil
case <-ctx.Done():
return nil, ctx.Err()
}
},
PrepareAgent: prepareTestAgent,
})
loop.Push("msg1")
result := loop.Wait()
if !errors.Is(result.ExitReason, context.DeadlineExceeded) {
t.Logf("expected DeadlineExceeded, got %v (may be nil if loop stopped before timeout)", result.ExitReason)
}
}
func TestTurnLoop_ContextCancelBeforeReceive(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
loop := NewAgentLoop(AgentLoopConfig[string]{
GenInput: genInputConsumeAll,
PrepareAgent: prepareTestAgent,
})
loop.Push("msg1")
loop.Run(ctx)
result := loop.Wait()
if !errors.Is(result.ExitReason, context.Canceled) {
t.Logf("expected Canceled, got %v", result.ExitReason)
}
}
func TestTurnLoop_ContextCancelDuringBlockingReceive(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
loop := newAndRunTurnLoop(ctx, AgentLoopConfig[string]{
GenInput: genInputConsumeAll,
PrepareAgent: prepareTestAgent,
})
time.Sleep(50 * time.Millisecond)
cancel()
result := loop.Wait()
if !errors.Is(result.ExitReason, context.Canceled) {
t.Logf("expected Canceled, got %v", result.ExitReason)
}
}
func TestTurnLoop_ContextCancelAfterGenInput_RecoverItems(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
genInputCount := 0
loop := newAndRunTurnLoop(ctx, AgentLoopConfig[string]{
GenInput: func(ctx context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
genInputCount++
if genInputCount == 1 {
cancel()
}
return &GenInputResult[string]{
Input: &AgentInput{},
Consumed: items[:1],
Remaining: items[1:],
}, nil
},
PrepareAgent: func(ctx context.Context, _ *AgentLoop[string], c []string) (Agent, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
return NewReActAgent(&ReActConfig[*schema.Message]{Model: &mockModel{}}), nil
},
})
loop.Push("msg1")
loop.Push("msg2")
result := loop.Wait()
if !errors.Is(result.ExitReason, context.Canceled) {
t.Logf("expected Canceled, got %v", result.ExitReason)
}
if len(result.UnhandledItems) == 0 {
t.Log("no unhandled items (may have been consumed before cancel)")
}
}
// ======================== OnAgentEvents Tests ========================
func TestTurnLoop_OnAgentEventsReceivesEvents(t *testing.T) {
var receivedEvents []*AgentEvent
var receivedConsumed []string
var mu sync.Mutex
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: genInputConsumeAllWithMsg,
PrepareAgent: prepareTestAgent,
OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
mu.Lock()
receivedConsumed = append(receivedConsumed, tc.Consumed...)
mu.Unlock()
for {
event, ok := events.Next()
if !ok {
break
}
mu.Lock()
receivedEvents = append(receivedEvents, event)
mu.Unlock()
}
return nil
},
})
loop.Push("msg1")
time.Sleep(100 * time.Millisecond)
loop.Stop()
result := loop.Wait()
mu.Lock()
defer mu.Unlock()
if result.ExitReason != nil {
t.Logf("exit reason: %v", result.ExitReason)
}
if len(receivedConsumed) == 0 {
t.Error("should have received consumed items")
}
}
// ======================== Stop with Checkpoint Cancel ========================
func TestTurnLoop_StopCheckPointIDInCancelError(t *testing.T) {
ctx := context.Background()
modelStarted := make(chan struct{}, 1)
checkpointID := "turn-loop-cancel-ckpt-1"
store := newTestStore()
slowModel := &cancelTestChatModel{
delayNs: int64(500 * time.Millisecond),
startedChan: modelStarted,
doneChan: make(chan struct{}, 1),
}
slowModel.addResp("Hello")
agent := NewReActAgent(&ReActConfig[*schema.Message]{
Instruction: "You are a test assistant",
Model: slowModel,
}).WithName("TestAgent").WithDescription("Test agent")
loop := newAndRunTurnLoop(ctx, AgentLoopConfig[string]{
Store: store,
CheckpointID: checkpointID,
GenInput: genInputConsumeAllWithMsg,
PrepareAgent: prepareAgent(agent),
})
loop.Push("msg1")
<-modelStarted
loop.Stop(WithImmediateStop())
result := loop.Wait()
t.Logf("exit reason: %v", result.ExitReason)
}
// ======================== CancelError Captured Independently ========================
func TestTurnLoop_CancelError_CapturedIndependentlyOfCallback(t *testing.T) {
ctx := context.Background()
modelStarted := make(chan struct{}, 1)
checkpointID := "cancel-capture-independent-1"
store := newTestStore()
slowModel := &cancelTestChatModel{
delayNs: int64(500 * time.Millisecond),
startedChan: modelStarted,
doneChan: make(chan struct{}, 1),
}
slowModel.addResp("Hello")
agent := NewReActAgent(&ReActConfig[*schema.Message]{
Instruction: "You are a test assistant",
Model: slowModel,
}).WithName("TestAgent").WithDescription("Test agent")
loop := newAndRunTurnLoop(ctx, AgentLoopConfig[string]{
Store: store,
CheckpointID: checkpointID,
GenInput: genInputConsumeAllWithMsg,
PrepareAgent: prepareAgent(agent),
OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
for {
_, ok := events.Next()
if !ok {
break
}
}
return nil // swallow everything
},
})
loop.Push("msg1")
<-modelStarted
loop.Stop(WithImmediateStop())
result := loop.Wait()
t.Logf("exit reason: %v", result.ExitReason)
}
// ======================== Stop Without CheckpointID ========================
func TestTurnLoop_StopWithoutCheckpointIDDoesNotPersist(t *testing.T) {
ctx := context.Background()
modelStarted := make(chan struct{}, 1)
store := newTestStore()
slowModel := &cancelTestChatModel{
delayNs: int64(500 * time.Millisecond),
startedChan: modelStarted,
doneChan: make(chan struct{}, 1),
}
slowModel.addResp("Hello")
agent := NewReActAgent(&ReActConfig[*schema.Message]{
Instruction: "You are a test assistant",
Model: slowModel,
}).WithName("TestAgent").WithDescription("Test agent")
loop := newAndRunTurnLoop(ctx, AgentLoopConfig[string]{
Store: store,
GenInput: genInputConsumeAllWithMsg,
PrepareAgent: prepareAgent(agent),
})
loop.Push("msg1")
<-modelStarted
loop.Stop(WithImmediateStop())
loop.Wait()
}
// ======================== Stop While Idle ========================
func TestTurnLoop_StopWhileIdle_SkipsCheckpoint(t *testing.T) {
ctx := context.Background()
store := newTestStore()
cpID := "idle-session"
loop := newAndRunTurnLoop(ctx, AgentLoopConfig[string]{
Store: store,
CheckpointID: cpID,
GenInput: genInputConsumeAll,
PrepareAgent: prepareTestAgent,
})
loop.Stop()
exit := loop.Wait()
if exit.ExitReason != nil {
t.Errorf("expected nil, got %v", exit.ExitReason)
}
}
// ======================== Stop Call From GenInput ========================
func TestTurnLoop_StopCallFromGenInput(t *testing.T) {
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
loop.Stop()
return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil
},
PrepareAgent: prepareTestAgent,
})
loop.Push("msg1")
result := loop.Wait()
if result.ExitReason != nil {
t.Errorf("expected nil, got %v", result.ExitReason)
}
}
// ======================== Push From OnAgentEvents ========================
func TestTurnLoop_PushFromOnAgentEvents(t *testing.T) {
pushCount := int32(0)
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: genInputConsumeFirst,
PrepareAgent: prepareTestAgent,
OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
for {
_, ok := events.Next()
if !ok {
break
}
}
count := atomic.AddInt32(&pushCount, 1)
if count == 1 {
tc.Loop.Push("follow-up")
} else {
tc.Loop.Stop()
}
return nil
},
})
loop.Push("initial")
result := loop.Wait()
if result.ExitReason != nil {
t.Errorf("expected nil, got %v", result.ExitReason)
}
if atomic.LoadInt32(&pushCount) != 2 {
t.Errorf("expected 2 pushes, got %d", atomic.LoadInt32(&pushCount))
}
}
// ======================== NewAgentLoop: Push Before Run ========================
func TestNewTurnLoop_PushBeforeRun(t *testing.T) {
var processedItems []string
var mu sync.Mutex
loop := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
mu.Lock()
processedItems = append(processedItems, items...)
mu.Unlock()
return &GenInputResult[string]{
Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
Consumed: items,
}, nil
},
PrepareAgent: prepareTestAgent,
})
ok, _ := loop.Push("msg1")
if !ok {
t.Error("Push returned false")
}
ok, _ = loop.Push("msg2")
if !ok {
t.Error("Push returned false")
}
loop.Run(context.Background())
time.Sleep(100 * time.Millisecond)
loop.Stop()
result := loop.Wait()
mu.Lock()
defer mu.Unlock()
if result.ExitReason != nil {
t.Errorf("expected nil, got %v", result.ExitReason)
}
if len(processedItems) == 0 {
t.Error("expected processed items")
}
}
// ======================== NewAgentLoop: Wait Before Run ========================
func TestNewTurnLoop_WaitBeforeRun(t *testing.T) {
loop := NewAgentLoop(AgentLoopConfig[string]{
GenInput: genInputConsumeAll,
PrepareAgent: prepareTestAgent,
})
waitDone := make(chan *AgentLoopState[string], 1)
go func() {
waitDone <- loop.Wait()
}()
select {
case <-waitDone:
t.Fatal("Wait returned before Run was called")
case <-time.After(50 * time.Millisecond):
}
loop.Push("msg1")
loop.Stop()
loop.Run(context.Background())
select {
case result := <-waitDone:
if result.ExitReason != nil {
t.Errorf("expected nil, got %v", result.ExitReason)
}
case <-time.After(1 * time.Second):
t.Fatal("Wait did not return after Run + Stop")
}
}
// ======================== NewAgentLoop: Run Is Idempotent ========================
func TestNewTurnLoop_RunIsIdempotent(t *testing.T) {
var genInputCalls int32
loop := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
atomic.AddInt32(&genInputCalls, 1)
return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil
},
PrepareAgent: prepareTestAgent,
})
loop.Push("msg1")
loop.Run(context.Background())
loop.Run(context.Background())
loop.Run(context.Background())
time.Sleep(100 * time.Millisecond)
loop.Stop()
result := loop.Wait()
if result.ExitReason != nil {
t.Errorf("expected nil, got %v", result.ExitReason)
}
if atomic.LoadInt32(&genInputCalls) < 1 {
t.Error("expected at least 1 GenInput call")
}
}
// ======================== NewAgentLoop: Concurrent Push And Run ========================
func TestNewTurnLoop_ConcurrentPushAndRun(t *testing.T) {
for i := 0; i < 50; i++ {
var count int32
loop := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
atomic.AddInt32(&count, int32(len(items)))
return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil
},
PrepareAgent: func(ctx context.Context, _ *AgentLoop[string], consumed []string) (Agent, error) {
return NewReActAgent(&ReActConfig[*schema.Message]{Model: &mockModel{}}), nil
},
})
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
loop.Push("item")
}()
go func() {
defer wg.Done()
loop.Run(context.Background())
}()
wg.Wait()
time.Sleep(50 * time.Millisecond)
loop.Stop()
result := loop.Wait()
processed := atomic.LoadInt32(&count)
unhandled := len(result.UnhandledItems)
if int(processed)+unhandled > 1 {
t.Errorf("total should not exceed pushed amount: processed=%d unhandled=%d", processed, unhandled)
}
}
}
// ======================== Context Propagation ========================
type turnCtxKey struct{}
// TestTurnLoop_CtxPropagation verifies the parent context is propagated to
// PrepareAgent, the agent run, and OnAgentEvents.
func TestTurnLoop_CtxPropagation(t *testing.T) {
const traceVal = "trace-123"
var prepareCtxVal, eventsCtxVal string
ctx := context.WithValue(context.Background(), turnCtxKey{}, traceVal)
cfg := AgentLoopConfig[string]{
GenInput: func(ctx context.Context, loop *AgentLoop[string], items []string) (*GenInputResult[string], error) {
return &GenInputResult[string]{
Input: &AgentInput{Messages: []Message{schema.UserMessage(items[0])}},
Consumed: items,
}, nil
},
PrepareAgent: func(ctx context.Context, loop *AgentLoop[string], consumed []string) (Agent, error) {
if v, ok := ctx.Value(turnCtxKey{}).(string); ok {
prepareCtxVal = v
}
return &turnLoopMockAgent{
name: "trace-agent",
GenerateFn: func(ctx context.Context, msgs []Message) (Message, error) {
return &schema.Message{Role: schema.RoleAssistant, Content: "done"}, nil
},
}, nil
},
OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
if v, ok := ctx.Value(turnCtxKey{}).(string); ok {
eventsCtxVal = v
}
for {
if _, ok := events.Next(); !ok {
break
}
}
tc.Loop.Stop()
return nil
},
}
loop := NewAgentLoop(cfg)
loop.Push("hello")
loop.Run(ctx)
result := loop.Wait()
if result.ExitReason != nil {
t.Errorf("expected nil, got %v", result.ExitReason)
}
if prepareCtxVal != traceVal {
t.Errorf("PrepareAgent should receive parent context: got %q", prepareCtxVal)
}
if eventsCtxVal != traceVal {
t.Errorf("OnAgentEvents should receive parent context: got %q", eventsCtxVal)
}
}
// ======================== TurnContext Stopped Channel ========================
func TestTurnLoop_TurnContext_StoppedChannel(t *testing.T) {
stoppedSeen := make(chan struct{})
agentStarted := make(chan struct{})
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: genInputConsumeAllWithMsg,
PrepareAgent: func(ctx context.Context, _ *AgentLoop[string], consumed []string) (Agent, error) {
return &turnLoopCancellableMockAgent{
name: "slow",
runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) {
<-ctx.Done()
return nil, ctx.Err()
},
}, nil
},
OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
close(agentStarted)
select {
case <-tc.Stopped:
close(stoppedSeen)
case <-time.After(5 * time.Second):
t.Error("timed out waiting for Stopped channel")
}
for {
if _, ok := events.Next(); !ok {
break
}
}
return nil
},
})
loop.Push("msg1")
<-agentStarted
loop.Stop(WithImmediateStop())
select {
case <-stoppedSeen:
// success
case <-time.After(5 * time.Second):
t.Fatal("stopped channel was never observed in OnAgentEvents")
}
loop.Wait()
}
// ======================== Stop With Skip Checkpoint ========================
func TestTurnLoop_StopWithSkipCheckpoint(t *testing.T) {
ctx := context.Background()
store := newTestStore()
cpID := "skip-cp-session"
loop := NewAgentLoop(AgentLoopConfig[string]{
Store: store,
CheckpointID: cpID,
GenInput: genInputConsumeAll,
PrepareAgent: prepareTestAgent,
})
loop.Push("a")
loop.Push("b")
loop.Stop(WithSkipCheckpoint())
loop.Run(ctx)
exit := loop.Wait()
if exit.ExitReason != nil {
t.Errorf("expected nil, got %v", exit.ExitReason)
}
}
// ======================== Stop With Stop Cause ========================
func TestTurnLoop_StopWithStopCause(t *testing.T) {
ctx := context.Background()
cause := "user session timeout"
loop := newAndRunTurnLoop(ctx, AgentLoopConfig[string]{
GenInput: genInputConsumeAll,
PrepareAgent: prepareTestAgent,
})
loop.Push("a")
loop.Stop(WithStopCause(cause))
exit := loop.Wait()
if exit.StopCause != cause {
t.Errorf("expected %q, got %q", cause, exit.StopCause)
}
}
func TestTurnLoop_StopCause_EmptyWhenNoStop(t *testing.T) {
ctx := context.Background()
loop := newAndRunTurnLoop(ctx, AgentLoopConfig[string]{
GenInput: genInputConsumeAll,
PrepareAgent: prepareTestAgent,
})
loop.Stop()
exit := loop.Wait()
if exit.StopCause != "" {
t.Errorf("expected empty, got %q", exit.StopCause)
}
}
func TestTurnLoop_StopCause_InTurnContext(t *testing.T) {
cause := "business shutdown"
gotCause := make(chan string, 1)
agentStarted := make(chan struct{})
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: genInputConsumeAllWithMsg,
PrepareAgent: func(ctx context.Context, _ *AgentLoop[string], consumed []string) (Agent, error) {
return &turnLoopCancellableMockAgent{
name: "slow",
runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) {
<-ctx.Done()
return nil, ctx.Err()
},
}, nil
},
OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
close(agentStarted)
select {
case <-tc.Stopped:
gotCause <- tc.StopCause()
case <-time.After(5 * time.Second):
t.Error("timed out waiting for Stopped channel")
}
for {
if _, ok := events.Next(); !ok {
break
}
}
return nil
},
})
loop.Push("msg1")
<-agentStarted
loop.Stop(WithImmediateStop(), WithStopCause(cause))
select {
case c := <-gotCause:
if c != cause {
t.Errorf("expected %q, got %q", cause, c)
}
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for StopCause in TurnContext")
}
exit := loop.Wait()
if exit.StopCause != cause {
t.Errorf("expected %q, got %q", cause, exit.StopCause)
}
}
func TestTurnLoop_StopCause_FirstNonEmptyWins(t *testing.T) {
agentStarted := make(chan struct{})
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: genInputConsumeAllWithMsg,
PrepareAgent: func(ctx context.Context, _ *AgentLoop[string], consumed []string) (Agent, error) {
return &turnLoopCancellableMockAgent{
name: "slow",
runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) {
<-ctx.Done()
return nil, ctx.Err()
},
}, nil
},
OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
close(agentStarted)
for {
if _, ok := events.Next(); !ok {
break
}
}
return nil
},
})
loop.Push("msg1")
<-agentStarted
loop.Stop(WithGracefulStop(), WithStopCause("first cause"))
loop.Stop(WithStopCause("second cause"))
exit := loop.Wait()
if exit.StopCause != "first cause" {
t.Errorf("expected 'first cause', got %q", exit.StopCause)
}
}
// ======================== Stop Before Run ========================
func TestTurnLoop_StopBeforeRun_PushThenStop(t *testing.T) {
loop := NewAgentLoop(AgentLoopConfig[string]{
GenInput: func(ctx context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
t.Fatal("GenInput should not be called when Stop is called before Run")
return nil, nil
},
PrepareAgent: func(ctx context.Context, _ *AgentLoop[string], consumed []string) (Agent, error) {
t.Fatal("PrepareAgent should not be called when Stop is called before Run")
return nil, nil
},
})
ok, _ := loop.Push("item1")
if !ok {
t.Error("Push returned false")
}
ok, _ = loop.Push("item2")
if !ok {
t.Error("Push returned false")
}
loop.Stop()
loop.Run(context.Background())
result := loop.Wait()
if result.ExitReason != nil {
t.Errorf("expected nil, got %v", result.ExitReason)
}
}
// ======================== Skip Checkpoint Sticky ========================
func TestTurnLoop_SkipCheckpoint_Sticky(t *testing.T) {
agentStarted := make(chan struct{})
store := newTestStore()
cpID := "sticky-skip-session"
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
Store: store,
CheckpointID: cpID,
GenInput: genInputConsumeAllWithMsg,
PrepareAgent: func(ctx context.Context, _ *AgentLoop[string], consumed []string) (Agent, error) {
return &turnLoopCancellableMockAgent{
name: "slow",
runFunc: func(ctx context.Context, input *AgentInput) (*AgentOutput, error) {
<-ctx.Done()
return nil, ctx.Err()
},
}, nil
},
OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
close(agentStarted)
for {
if _, ok := events.Next(); !ok {
break
}
}
return nil
},
})
loop.Push("msg1")
<-agentStarted
loop.Stop(WithGracefulStop(), WithSkipCheckpoint())
loop.Stop()
exit := loop.Wait()
_ = exit
t.Logf("skip checkpoint sticky: exit=%v", exit.ExitReason)
}
// ======================== GenInput Error Recovery ========================
func TestTurnLoop_GenInputError_RecoverItems(t *testing.T) {
genErr := errors.New("gen input error")
loop := newAndRunTurnLoop(context.Background(), AgentLoopConfig[string]{
GenInput: func(ctx context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
return nil, genErr
},
PrepareAgent: prepareTestAgent,
})
loop.Push("msg1")
loop.Push("msg2")
result := loop.Wait()
if !errors.Is(result.ExitReason, genErr) {
t.Errorf("expected genErr, got %v", result.ExitReason)
}
}
// ======================== Checkpoint Not Found ========================
func TestTurnLoop_CheckpointNotFound_FreshStart(t *testing.T) {
ctx := context.Background()
store := newTestStore()
var genInputCalled bool
loop := NewAgentLoop(AgentLoopConfig[string]{
Store: store,
CheckpointID: "nonexistent-id",
GenInput: func(ctx context.Context, _ *AgentLoop[string], items []string) (*GenInputResult[string], error) {
genInputCalled = true
return &GenInputResult[string]{Input: &AgentInput{}, Consumed: items}, nil
},
PrepareAgent: prepareTestAgent,
OnAgentEvents: func(ctx context.Context, tc *TurnContext[string], events *AsyncIterator[*AgentEvent]) error {
for {
if _, ok := events.Next(); !ok {
break
}
}
tc.Loop.Stop()
return nil
},
})
loop.Push("a")
loop.Run(ctx)
exit := loop.Wait()
if exit.ExitReason != nil {
t.Errorf("expected nil, got %v", exit.ExitReason)
}
if !genInputCalled {
t.Error("GenInput should be called when checkpoint not found")
}
}
// ======================== TurnBuffer Tests ========================
func TestAttack_TurnBuffer_WakeupDoesNotLoseItems(t *testing.T) {
tb := newTurnBuffer[string]()
tb.TrySend("a")
tb.TrySend("b")
tb.Wakeup()
tb.TrySend("c")
var got []string
for i := 0; i < 3; i++ {
val, ok := tb.Receive()
if !ok {
t.Fatal("expected ok")
}
got = append(got, val)
}
if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" {
t.Errorf("expected [a b c], got %v", got)
}
}
// ======================== AgentLoop Preempt During Planning ========================