Files
ragflow/internal/harness/core/callback_test.go
Yingfeng 706e0d2d06 Refactor harness framework (#16271)
### What problem does this PR solve?

- Tools management
- Pregel engine wrapper for better usage
- UT race
- Coding style

### Type of change

- [x] Refactoring
2026-06-23 20:18:04 +08:00

312 lines
8.5 KiB
Go

package core
import (
"context"
"testing"
"ragflow/internal/harness/core/schema"
)
// ---- Callback infrastructure tests ----
//
// Callback initialization (initAgentCallbacks) is triggered via flowAgent.Run,
// not directly via ReActAgent.Run. These tests verify the infrastructure
// layer: context propagation, filtering, and option handling.
func TestInitAgentCallbacks_NoCallbacks(t *testing.T) {
ctx := initAgentCallbacks(context.Background(), "test_agent", "ReActAgent")
cbs := getCallbacks(ctx)
if cbs != nil {
t.Error("expected nil callbacks when no options provided")
}
}
// TestInitAgentCallbacks_WithCallbacks checks that a handler passed through
// WithCallbacks ends up in the callbacks collected by getCommonOptions.
//
// Only the option-collection step is exercised; initAgentCallbacks itself is
// not called by this test.
func TestInitAgentCallbacks_WithCallbacks(t *testing.T) {
	cb := callbackHandler{
		onStart: func(ctx context.Context, input *AgentCallbackInput) {},
	}

	// Simulate what initAgentCallbacks does: collect the options and inspect
	// the resulting callbacks.
	opts := []RunOption{WithCallbacks(cb)}
	o := getCommonOptions(nil, opts...)
	if got := len(o.callbacks); got != 1 {
		t.Errorf("expected 1 callback in options, got %d", got)
	}
}
// TestFilterCallbacks_AgentNameMatch runs filterOptions over a WithCallbacks
// option combined with WithAgentNames("my_agent") and checks how many callbacks
// remain for a matching and for a non-matching agent name.
func TestFilterCallbacks_AgentNameMatch(t *testing.T) {
	cb := callbackHandler{
		onStart: func(ctx context.Context, input *AgentCallbackInput) {},
	}
	opts := []RunOption{WithCallbacks(cb), WithAgentNames("my_agent")}

	// Matching agent name: the callback must come through the filter.
	matched := getCommonOptions(nil, filterOptions("my_agent", opts)...)
	if len(matched.callbacks) == 0 {
		t.Error("expected callbacks to pass through for matching agent")
	}

	// Non-matching agent name: this test expects the callback to survive,
	// because the WithCallbacks option itself does not carry agentNames.
	other := getCommonOptions(nil, filterOptions("other_agent", opts)...)
	if len(other.callbacks) != 1 {
		t.Error("WithCallbacks option doesn't carry agentNames, so it passes through all filters")
	}

	// Informational only: report whether the WithAgentNames option was dropped
	// for the non-matching agent. The success message is logged only when
	// agentNames is actually empty.
	// NOTE(review): this is deliberately not an assertion; tighten it to
	// t.Errorf once the intended filterOptions contract is confirmed.
	if n := len(other.agentNames); n == 0 {
		t.Log("agent name filter options are correctly filtered (agentNames removed)")
	} else {
		t.Logf("agentNames still present after filtering for a non-matching agent (%d entries)", n)
	}
}
// TestCallbacks_WithAgentNamesFilter_CallbackSavedAndFiltered runs a ReActAgent
// named "filtered_agent" with a callback option plus a matching WithAgentNames
// option and drains the returned event iterator.
//
// The events themselves are discarded; the test only requires that the run
// completes.
func TestCallbacks_WithAgentNamesFilter_CallbackSavedAndFiltered(t *testing.T) {
	model := &mockModel{}
	model.addResp("filter-test")

	agent := NewReActAgent(&ReActConfig[*schema.Message]{Model: model})
	agent.name = "filtered_agent"

	handler := callbackHandler{
		onStart: func(context.Context, *AgentCallbackInput) {},
	}
	opts := []RunOption{WithCallbacks(handler), WithAgentNames("filtered_agent")}

	// Callback is at the option level; it gets injected during flowAgent.Run
	iter := agent.Run(
		context.Background(),
		&AgentInput{Messages: []Message{schema.UserMessage("test")}},
		opts...,
	)
	for {
		if _, more := iter.Next(); !more {
			break
		}
	}
}
// TestCallbacks_EmptyCallbacks runs a ReActAgent named "no_cb" without any run
// options and drains the returned event iterator. Events are discarded; the
// test only requires that the run completes.
func TestCallbacks_EmptyCallbacks(t *testing.T) {
	model := &mockModel{}
	model.addResp("no-cb")

	agent := NewReActAgent(&ReActConfig[*schema.Message]{Model: model})
	agent.name = "no_cb"

	iter := agent.Run(
		context.Background(),
		&AgentInput{Messages: []Message{schema.UserMessage("test")}},
	)
	for {
		if _, more := iter.Next(); !more {
			break
		}
	}
}
// TestFilterOptions_Empty checks that filterOptions returns nil when given a
// nil option slice.
func TestFilterOptions_Empty(t *testing.T) {
	if got := filterOptions("test", nil); got != nil {
		t.Error("nil input should return nil")
	}
}
func TestFilterOptions_NoAgentNames(t *testing.T) {
opts := []RunOption{WithSessionValues(map[string]any{"k": "v"})}
result := filterOptions("test", opts)
if len(result) != 1 {
t.Errorf("expected 1 option, got %d", len(result))
}
}
// ---- filterCancelOption tests ----
func TestFilterCancelOption_NoChange(t *testing.T) {
opts := []RunOption{WithSessionValues(map[string]any{"k": "v"})}
result := filterCancelOption(opts)
if len(result) != 1 {
t.Errorf("expected 1 option, got %d", len(result))
}
}
// TestFilterCancelOption_RemovesCancelCtx checks that filterCancelOption drops
// the option produced by WithCancel, leaving an empty result.
func TestFilterCancelOption_RemovesCancelCtx(t *testing.T) {
	// Only the option is needed; the second value returned by WithCancel is
	// discarded.
	cancelOpt, _ := WithCancel()

	remaining := filterCancelOption([]RunOption{cancelOpt})
	if n := len(remaining); n != 0 {
		t.Errorf("expected 0 options, got %d", n)
	}
}
// ---- filterCallbackHandlersForNestedAgents tests ----
func TestFilterCallbackHandlersForNestedAgents_NoAgentNames(t *testing.T) {
opts := []RunOption{WithSessionValues(map[string]any{"k": "v"})}
result := filterCallbackHandlersForNestedAgents("test", opts)
if len(result) != 1 {
t.Errorf("expected 1, got %d", len(result))
}
}
// TestFilterCallbackHandlersForNestedAgents_MatchingAgent checks that
// filterCallbackHandlersForNestedAgents returns a non-empty result when the
// agent name passed to it equals the name given to WithAgentNames.
func TestFilterCallbackHandlersForNestedAgents_MatchingAgent(t *testing.T) {
	const agentName = "test"

	handler := callbackHandler{onStart: func(context.Context, *AgentCallbackInput) {}}
	opts := []RunOption{WithCallbacks(handler), WithAgentNames(agentName)}

	if kept := filterCallbackHandlersForNestedAgents(agentName, opts); len(kept) == 0 {
		t.Error("expected options to pass through for matching agent")
	}
}
// ---- RunLocalValue tests ----
// TestSetRunLocalValue_NotInAgentExec checks that SetRunLocalValue returns an
// error with a non-empty message when called with a plain background context.
//
// The first two checks are preconditions for the last one, so they stop the
// test with t.Fatal: continuing after either failure would leave aee nil and
// turn the aee.Message access into a nil-pointer panic.
func TestSetRunLocalValue_NotInAgentExec(t *testing.T) {
	err := SetRunLocalValue(context.Background(), "key", "value")
	if err == nil {
		t.Fatal("expected error when not in agent execution context")
	}

	var aee *AgentExecError
	if !AsAgentExecError(err, &aee) {
		t.Fatal("expected AgentExecError")
	}
	if aee.Message == "" {
		t.Error("expected non-empty error message")
	}
}
// TestGetRunLocalValue_NotInAgentExec checks that GetRunLocalValue returns an
// error when called with a plain background context.
func TestGetRunLocalValue_NotInAgentExec(t *testing.T) {
	if _, _, err := GetRunLocalValue(context.Background(), "key"); err == nil {
		t.Error("expected error when not in agent execution context")
	}
}
// TestDeleteRunLocalValue_NotInAgentExec checks that DeleteRunLocalValue
// returns an error when called with a plain background context.
func TestDeleteRunLocalValue_NotInAgentExec(t *testing.T) {
	if err := DeleteRunLocalValue(context.Background(), "key"); err == nil {
		t.Error("expected error when not in agent execution context")
	}
}
// TestSendEvent_NotInAgentExec checks that SendEvent returns an error when
// called with a plain background context and a nil event.
func TestSendEvent_NotInAgentExec(t *testing.T) {
	if err := SendEvent(context.Background(), nil); err == nil {
		t.Error("expected error when not in agent execution context")
	}
}
// TestCheckGobEncodability_StringValue checks that checkGobEncodability accepts
// a string value.
func TestCheckGobEncodability_StringValue(t *testing.T) {
	if err := checkGobEncodability("key", "string value"); err != nil {
		t.Errorf("string should be gob-encodable: %v", err)
	}
}
// TestCheckGobEncodability_IntValue checks that checkGobEncodability accepts an
// int value.
func TestCheckGobEncodability_IntValue(t *testing.T) {
	if err := checkGobEncodability("key", 42); err != nil {
		t.Errorf("int should be gob-encodable: %v", err)
	}
}
// TestCheckGobEncodability_StructValue checks that checkGobEncodability rejects
// a value of a struct type declared locally in this test.
func TestCheckGobEncodability_StructValue(t *testing.T) {
	type unregistered struct{ X int }

	if err := checkGobEncodability("key", unregistered{X: 1}); err == nil {
		t.Error("unregistered struct should fail gob encoding")
	}
}
// TestCheckGobEncodability_MapValue checks that checkGobEncodability rejects a
// map[string]int value.
func TestCheckGobEncodability_MapValue(t *testing.T) {
	value := map[string]int{"a": 1}
	if err := checkGobEncodability("key", value); err == nil {
		t.Error("map[string]int needs gob registration to be encodable as interface{}")
	}
}
// TestCheckGobEncodability_NilValue checks that checkGobEncodability accepts a
// nil value.
func TestCheckGobEncodability_NilValue(t *testing.T) {
	if err := checkGobEncodability("key", nil); err != nil {
		t.Errorf("nil should be gob-encodable: %v", err)
	}
}
// ---- AsAgentExecError helper ----
// AsAgentExecError is a test helper that converts any non-nil error into an
// *AgentExecError stored in *target and reports whether it did so.
//
// It does not inspect the dynamic type of err: every non-nil error produces a
// fresh AgentExecError whose Message is err.Error(). For a nil err it returns
// false and leaves *target untouched.
//
// NOTE(review): because of that, a true result does not prove that err was an
// AgentExecError; consider errors.As if a real type check is intended.
func AsAgentExecError(err error, target **AgentExecError) bool {
	if err != nil {
		*target = &AgentExecError{Message: err.Error()}
		return true
	}
	return false
}
// ---- RunOption tests ----
// TestWithSessionValues checks that a key/value pair passed to
// WithSessionValues is readable from the sessionValues collected by
// getCommonOptions.
func TestWithSessionValues(t *testing.T) {
	values := map[string]any{"k": "v"}
	collected := getCommonOptions(nil, WithSessionValues(values))
	if got := collected.sessionValues["k"]; got != "v" {
		t.Error("session value not set")
	}
}
// TestWithCheckPointID checks that WithCheckPointID("cp1") results in a
// checkPointID pointing at "cp1" after getCommonOptions.
//
// checkPointID is a pointer, so it is nil-checked before being dereferenced;
// an unset option then fails the test cleanly instead of panicking.
func TestWithCheckPointID(t *testing.T) {
	o := getCommonOptions(nil, WithCheckPointID("cp1"))
	if o.checkPointID == nil {
		t.Fatal("checkpoint ID not set")
	}
	if got := *o.checkPointID; got != "cp1" {
		t.Errorf("checkpoint ID = %q, want %q", got, "cp1")
	}
}
// TestWithSkipTransferMessages checks that WithSkipTransferMessages turns on
// the skipTransferMessages flag in the options built by getCommonOptions.
func TestWithSkipTransferMessages(t *testing.T) {
	collected := getCommonOptions(nil, WithSkipTransferMessages())
	if enabled := collected.skipTransferMessages; !enabled {
		t.Error("skipTransferMessages not set")
	}
}
// TestWithSharedParentSession checks that WithSharedParentSession turns on the
// sharedParentSession flag in the options built by getCommonOptions.
func TestWithSharedParentSession(t *testing.T) {
	collected := getCommonOptions(nil, WithSharedParentSession())
	if enabled := collected.sharedParentSession; !enabled {
		t.Error("sharedParentSession not set")
	}
}
// TestWithAfterToolCallsHook checks that a hook passed to
// WithAfterToolCallsHook is stored (non-nil) in the options built by
// getCommonOptions.
func TestWithAfterToolCallsHook(t *testing.T) {
	hook := func(context.Context) error { return nil }

	collected := getCommonOptions(nil, WithAfterToolCallsHook(hook))
	if collected.afterToolCallsHook == nil {
		t.Error("afterToolCallsHook not set")
	}
}
// TestWithCallbacks_Nil checks that WithCallbacks called without handlers
// leaves the callbacks collected by getCommonOptions empty.
func TestWithCallbacks_Nil(t *testing.T) {
	collected := getCommonOptions(nil, WithCallbacks())
	if n := len(collected.callbacks); n != 0 {
		t.Error("expected empty callbacks")
	}
}
// ---- getCallbacks/withCallbacks tests ----
// TestWithCallbacks_Context checks that one handler stored in a context via
// withCallbacks is returned by getCallbacks.
func TestWithCallbacks_Context(t *testing.T) {
	handlers := []callbackHandler{{}}
	stored := getCallbacks(withCallbacks(context.Background(), handlers))
	if n := len(stored); n != 1 {
		t.Errorf("expected 1 callback, got %d", n)
	}
}
// TestGetCallbacks_NoCallbacks checks that getCallbacks returns nil for a plain
// background context.
func TestGetCallbacks_NoCallbacks(t *testing.T) {
	if got := getCallbacks(context.Background()); got != nil {
		t.Error("expected nil")
	}
}
// TestWithCallbacks_Empty checks that withCallbacks, given nil handlers,
// returns a context equal to the background context it was given.
func TestWithCallbacks_Empty(t *testing.T) {
	base := context.Background()
	if got := withCallbacks(base, nil); got != base {
		t.Error("empty callbacks should return original context")
	}
}