mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve? core module for agent layer built on top of graph engine #16039 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
187 lines
5.7 KiB
Go
187 lines
5.7 KiB
Go
package core
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"ragflow/internal/harness/core/schema"
|
|
"ragflow/internal/harness/graph/constants"
|
|
"ragflow/internal/harness/graph/graph"
|
|
)
|
|
|
|
// TestSubAgentNode_Simple verifies a basic sub-agent node in a StateGraph.
|
|
func TestSubAgentNode_Simple(t *testing.T) {
|
|
m := &mockModel{}
|
|
m.addResp("sub-agent response")
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{Model: m}).WithName("worker")
|
|
|
|
sg := graph.NewStateGraph(map[string]interface{}{"Messages": []interface{}{}})
|
|
node := NewSubAgentNode(agent)
|
|
sg.AddNode("worker", node)
|
|
sg.AddEdge(constants.Start, "worker")
|
|
sg.AddEdge("worker", constants.End)
|
|
|
|
cg, err := sg.Compile(graph.WithRecursionLimit(10))
|
|
if err != nil {
|
|
t.Fatalf("Compile: %v", err)
|
|
}
|
|
|
|
result, err := cg.Invoke(context.Background(), map[string]interface{}{
|
|
"Messages": []interface{}{schema.UserMessage("hello from sub-agent test")},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Invoke: %v", err)
|
|
}
|
|
_ = result
|
|
t.Logf("sub-agent node result: %T", result)
|
|
}
|
|
|
|
// TestSubAgentNode_SequentialChain verifies two sub-agent nodes in sequence.
|
|
func TestSubAgentNode_SequentialChain(t *testing.T) {
|
|
m1 := &mockModel{}
|
|
m1.addResp("agent one")
|
|
m2 := &mockModel{}
|
|
m2.addResp("agent two")
|
|
|
|
a1 := NewReActAgent(&ReActConfig[*schema.Message]{Model: m1}).WithName("agent_a")
|
|
a2 := NewReActAgent(&ReActConfig[*schema.Message]{Model: m2}).WithName("agent_b")
|
|
|
|
sg := graph.NewStateGraph(map[string]interface{}{"Messages": []interface{}{}})
|
|
sg.AddNode("agent_a", NewSubAgentNode(a1))
|
|
sg.AddNode("agent_b", NewSubAgentNode(a2))
|
|
sg.AddEdge(constants.Start, "agent_a")
|
|
sg.AddEdge("agent_a", "agent_b")
|
|
sg.AddEdge("agent_b", constants.End)
|
|
|
|
cg, err := sg.Compile(graph.WithRecursionLimit(10))
|
|
if err != nil {
|
|
t.Fatalf("Compile: %v", err)
|
|
}
|
|
|
|
_, err = cg.Invoke(context.Background(), map[string]interface{}{
|
|
"Messages": []interface{}{schema.UserMessage("chain test")},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Invoke: %v", err)
|
|
}
|
|
t.Log("sequential sub-agent chain completed")
|
|
}
|
|
|
|
// TestSubAgentNode_WithFieldMapping verifies field-level input/output projection.
|
|
func TestSubAgentNode_WithFieldMapping(t *testing.T) {
|
|
m := &mockModel{}
|
|
m.addResp("projected result")
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{Model: m}).WithName("projector")
|
|
|
|
sg := graph.NewStateGraph(map[string]interface{}{"query": "", "response": "", "Messages": []interface{}{}})
|
|
node := NewSubAgentNode(agent,
|
|
WithSubAgentInput("query", "input"),
|
|
WithSubAgentOutput("response", "response"),
|
|
)
|
|
sg.AddNode("projector", node)
|
|
sg.AddEdge(constants.Start, "projector")
|
|
sg.AddEdge("projector", constants.End)
|
|
|
|
cg, err := sg.Compile(graph.WithRecursionLimit(10))
|
|
if err != nil {
|
|
t.Fatalf("Compile: %v", err)
|
|
}
|
|
|
|
result, err := cg.Invoke(context.Background(), map[string]interface{}{
|
|
"query": "what is go?",
|
|
"response": "",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Invoke: %v", err)
|
|
}
|
|
st, ok := result.(map[string]interface{})
|
|
if !ok {
|
|
t.Fatal("expected map result")
|
|
}
|
|
resp, ok := st["response"].(string)
|
|
if !ok || resp == "" {
|
|
t.Error("expected response field to be populated (OutputMapping should project agent output to state)")
|
|
}
|
|
t.Logf("sub-agent with field mapping: response=%q", resp)
|
|
}
|
|
|
|
// TestSubAgentNode_BuilderCompile verifies SubAgentGraphBuilder compilation
|
|
// with manual edge wiring.
|
|
func TestSubAgentNode_BuilderCompile(t *testing.T) {
|
|
m := &mockModel{}
|
|
m.addResp("builder test")
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{Model: m}).WithName("builder_agent")
|
|
|
|
sg := graph.NewStateGraph(map[string]interface{}{"Messages": []interface{}{}})
|
|
sg.AddNode("node1", NewSubAgentNode(agent))
|
|
sg.AddEdge(constants.Start, "node1")
|
|
sg.AddEdge("node1", constants.End)
|
|
|
|
cg, err := sg.Compile(graph.WithRecursionLimit(10))
|
|
if err != nil {
|
|
t.Fatalf("Compile: %v", err)
|
|
}
|
|
if cg == nil {
|
|
t.Fatal("expected non-nil compiled graph")
|
|
}
|
|
t.Log("builder compile passed")
|
|
}
|
|
|
|
// TestSubAgentNode_WithSubAgentName verifies name override.
|
|
func TestSubAgentNode_WithSubAgentName(t *testing.T) {
|
|
m := &mockModel{}
|
|
m.addResp("named agent")
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{Model: m}).WithName("original_name")
|
|
|
|
sg := graph.NewStateGraph(map[string]interface{}{"Messages": []interface{}{}})
|
|
node := NewSubAgentNode(agent, WithSubAgentName("custom_name"))
|
|
sg.AddNode("custom_name", node)
|
|
sg.AddEdge(constants.Start, "custom_name")
|
|
sg.AddEdge("custom_name", constants.End)
|
|
|
|
cg, err := sg.Compile(graph.WithRecursionLimit(10))
|
|
if err != nil {
|
|
t.Fatalf("Compile: %v", err)
|
|
}
|
|
|
|
_, err = cg.Invoke(context.Background(), map[string]interface{}{
|
|
"Messages": []interface{}{schema.UserMessage("name test")},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Invoke: %v", err)
|
|
}
|
|
t.Log("named sub-agent node completed")
|
|
}
|
|
|
|
// TestSubAgentNode_CustomExtractor verifies custom input extractor.
|
|
func TestSubAgentNode_CustomExtractor(t *testing.T) {
|
|
m := &mockModel{}
|
|
m.addResp("custom extractor ok")
|
|
agent := NewReActAgent(&ReActConfig[*schema.Message]{Model: m}).WithName("extractor_test")
|
|
|
|
sg := graph.NewStateGraph(map[string]interface{}{"data": "", "Messages": []interface{}{}})
|
|
node := NewSubAgentNode(agent,
|
|
WithSubAgentExtractor(func(ctx context.Context, state interface{}) (*AgentInput, error) {
|
|
return &AgentInput{
|
|
Messages: []*schema.Message{schema.UserMessage("custom input")},
|
|
}, nil
|
|
}),
|
|
)
|
|
sg.AddNode("extractor", node)
|
|
sg.AddEdge(constants.Start, "extractor")
|
|
sg.AddEdge("extractor", constants.End)
|
|
|
|
cg, err := sg.Compile(graph.WithRecursionLimit(10))
|
|
if err != nil {
|
|
t.Fatalf("Compile: %v", err)
|
|
}
|
|
|
|
_, err = cg.Invoke(context.Background(), map[string]interface{}{
|
|
"data": "some data",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Invoke: %v", err)
|
|
}
|
|
t.Log("custom extractor sub-agent completed")
|
|
}
|