Files

1832 lines
50 KiB
Go

// Package pregel provides the Pregel execution algorithm for graph processing.
package pregel
import (
"context"
"encoding/json"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"github.com/google/uuid"
"go.uber.org/zap"
"ragflow/internal/common"
"ragflow/internal/harness/graph/channels"
"ragflow/internal/harness/graph/checkpoint"
"ragflow/internal/harness/graph/constants"
"ragflow/internal/harness/graph/errors"
"ragflow/internal/harness/graph/interrupt"
"ragflow/internal/harness/graph/types"
)
// Engine implements the Pregel (bulk-synchronous parallel) execution model
// for StateGraph. It manages channel-based state communication, concurrent
// task execution via AsyncPipeline, streaming event emission, and checkpoint
// persistence.
//
// Create an Engine via NewEngine with option functions:
//
// engine := NewEngine(graph,
// WithCheckpointer(cp),
// WithRecursionLimit(50),
// )
type Engine struct {
graph types.StateGraph
checkpointer checkpoint.BaseCheckpointer
interrupts map[string]bool
interruptsAfter map[string]bool
recursionLimit int
debug bool
config *types.RunnableConfig
maxConcurrency int
retryPolicy *types.RetryPolicy
currentCheckpoint *checkpoint.Checkpoint
channelVersions map[string]int
versionsSeen map[string]map[string]int
cache Cache
backgroundExec *BackgroundExecutor
deferredCheckpoints []deferredCheckpoint // for DurabilityExit mode
}
// deferredCheckpoint stores checkpoint data for deferred saving (DurabilityExit mode)
type deferredCheckpoint struct {
ThreadID string
CheckpointID string
Step int
Checkpoint map[string]any
}
// NewEngine creates a new Pregel engine bound to a StateGraph.
// Options configure checkpointer, recursion limit, concurrency, retry, cache, etc.
//
// The engine is reusable across multiple Run calls. Each call creates its own
// background executor for isolation.
func NewEngine(g types.StateGraph, opts ...EngineOption) *Engine {
eng := &Engine{
graph: g,
interrupts: make(map[string]bool),
interruptsAfter: make(map[string]bool),
recursionLimit: 25,
debug: false,
config: types.NewRunnableConfig(),
maxConcurrency: 10,
retryPolicy: nil,
channelVersions: make(map[string]int),
versionsSeen: make(map[string]map[string]int),
cache: &NoopCache{},
}
for _, opt := range opts {
opt(eng)
}
// Initialize background executor if not already set
if eng.backgroundExec == nil {
eng.backgroundExec = NewBackgroundExecutor(eng.maxConcurrency, 100)
}
return eng
}
// EngineOption is an option for configuring the Pregel engine.
// Available options: WithCheckpointer, WithInterrupts, WithRecursionLimit,
// WithDebug, WithConfig, WithMaxConcurrency, WithRetryPolicy, WithCache,
// WithBackgroundExecutor.
type EngineOption func(*Engine)
// WithCheckpointer sets the checkpointer.
func WithCheckpointer(cp checkpoint.BaseCheckpointer) EngineOption {
return func(e *Engine) {
e.checkpointer = cp
}
}
// WithInterrupts sets the interrupt nodes.
func WithInterrupts(nodes ...string) EngineOption {
return func(e *Engine) {
for _, node := range nodes {
e.interrupts[node] = true
}
}
}
// WithInterruptsAfter sets the after-execution interrupt nodes.
func WithInterruptsAfter(nodes ...string) EngineOption {
return func(e *Engine) {
for _, node := range nodes {
e.interruptsAfter[node] = true
}
}
}
// WithRecursionLimit sets the recursion limit.
func WithRecursionLimit(limit int) EngineOption {
return func(e *Engine) {
e.recursionLimit = limit
}
}
// WithDebug enables debug mode.
func WithDebug(debug bool) EngineOption {
return func(e *Engine) {
e.debug = debug
}
}
// WithConfig sets the runnable config.
func WithConfig(cfg *types.RunnableConfig) EngineOption {
return func(e *Engine) {
e.config = cfg
}
}
// WithMaxConcurrency sets the maximum concurrency for node execution.
func WithMaxConcurrency(max int) EngineOption {
return func(e *Engine) {
if max > 0 {
e.maxConcurrency = max
}
}
}
// WithRetryPolicy sets the retry policy for node execution.
func WithRetryPolicy(policy *types.RetryPolicy) EngineOption {
return func(e *Engine) {
e.retryPolicy = policy
}
}
// WithCache sets the cache for the engine.
func WithCache(cache Cache) EngineOption {
return func(e *Engine) {
e.cache = cache
}
}
// WithBackgroundExecutor sets the background executor for the engine.
func WithBackgroundExecutor(exec *BackgroundExecutor) EngineOption {
return func(e *Engine) {
e.backgroundExec = exec
}
}
// ExecuteResult represents the result of graph execution.
type ExecuteResult struct {
// Final state of the graph.
State any
// Checkpoint ID for this execution.
CheckpointID string
// Metadata about the execution.
Metadata map[string]any
}
// Run executes the graph using the Pregel algorithm and returns streaming events.
// outputCh yields StreamEvent values (checkpoints, task start/end, state updates,
// and a final event with the complete state). errCh receives a single error on failure
// or nil on clean completion.
//
// The caller MUST read from outputCh until it is closed to prevent goroutine leaks.
// For synchronous execution, use RunSync instead.
func (e *Engine) Run(ctx context.Context, input any, mode types.StreamMode) (<-chan any, <-chan error) {
outputCh := make(chan any, 100)
errCh := make(chan error, 1)
go func() {
defer close(errCh)
// Create stream manager for event streaming
streamManager := NewStreamManager(mode, 100)
// WaitGroup ensures the forward goroutine exits before we close outputCh,
// preventing a data race between close(outputCh) and outputCh <- event.
var fwWg sync.WaitGroup
// Forward stream events to output channel
fwWg.Go(func() {
for event := range streamManager.Events() {
select {
case outputCh <- event:
case <-ctx.Done():
return
}
}
})
// Deferred cleanup: close streamManager first (unblocks forward goroutine),
// then wait for forward goroutine to exit, then close outputCh.
defer func() {
streamManager.Close()
fwWg.Wait()
close(outputCh)
}()
// Create async pipeline for concurrent task execution
retryPolicy := e.retryPolicy
if retryPolicy == nil {
defaultPolicy := types.DefaultRetryPolicy()
retryPolicy = &defaultPolicy
}
asyncPipeline := NewAsyncPipeline(e.maxConcurrency, retryPolicy)
pipelineCtx := asyncPipeline.Start(ctx)
defer asyncPipeline.Stop()
// Reset per-execution engine state.
// Without this, reusing the same Engine across multiple RunSync calls
// causes checkpoint maps and channel versions to accumulate indefinitely,
// leading to unbounded memory growth (soak tests exposed this).
e.currentCheckpoint = nil
e.channelVersions = make(map[string]int)
e.versionsSeen = make(map[string]map[string]int)
e.deferredCheckpoints = nil
// Initialize channels
channelRegistry := channels.NewRegistry()
graphChannels := e.getGraphChannels()
for name, ch := range graphChannels {
channelRegistry.Register(name, ch.Copy())
}
// Apply input to channels
if err := e.applyInput(channelRegistry, input); err != nil {
errCh <- fmt.Errorf("failed to apply input: %w", err)
return
}
// Get thread ID for checkpointing
threadID := e.getThreadID()
// Load checkpoint when one exists for this thread_id, even when
// input is non-nil (resume from a previous run). The canvas
// always passes a non-nil input ({"query": ...}) on resume, so
// a strict input==nil guard would prevent checkpoint recovery.
// We only skip checkpoint loading if the checkpointer reports
// no data (fresh start).
// When a checkpoint IS loaded, do NOT apply input — the
// channel values from the checkpoint already contain the
// state at the point of interruption.
var (
didLoadCheckpoint bool
cpCompletedTasks map[string]bool
cpLastCompletedNode string
cpData map[string]any
)
if e.checkpointer != nil {
var cpErr error
cpConfig := map[string]any{
constants.ConfigKeyThreadID: threadID,
}
// Support loading a specific checkpoint_id for replay/fork.
var requestedCPID string
if e.config != nil && e.config.Configurable != nil {
if cpid, ok := e.config.Configurable[constants.ConfigKeyCheckpointID]; ok {
if cpidStr, ok := cpid.(string); ok && cpidStr != "" {
cpConfig[constants.ConfigKeyCheckpointID] = cpidStr
requestedCPID = cpidStr
}
}
}
cpData, cpErr = e.checkpointer.Get(ctx, cpConfig)
// When a specific checkpoint_id was requested, fail on missing data.
if requestedCPID != "" && (cpErr != nil || cpData == nil) {
cpErrMsg := "checkpoint not found"
if cpErr != nil {
cpErrMsg = cpErr.Error()
}
errCh <- fmt.Errorf("requested checkpoint_id %s: %s", requestedCPID, cpErrMsg)
return
}
if cpErr == nil && cpData != nil {
didLoadCheckpoint = true
common.Debug("LOOP_CHECK: loaded checkpoint",
zap.String("thread", threadID),
zap.Bool("has_sub", cpData["__sub_state__"] != nil))
// Restore sub-state (e.g. Loop iteration, currentInput)
// and inject into interrupt context so Loop node can
// read it via loadLoopSnapshot on resume.
if raw, ok := cpData["__sub_state__"]; ok {
switch v := raw.(type) {
case []byte:
pipelineCtx = context.WithValue(pipelineCtx, interrupt.SubGraphStateCtxKey, v)
case string:
pipelineCtx = context.WithValue(pipelineCtx, interrupt.SubGraphStateCtxKey, []byte(v))
}
}
// Restore completed task tracking.
if raw, ok := cpData["__completed_tasks__"]; ok {
if str, ok := raw.(string); ok {
cpCompletedTasks = deserializeStringSet(str)
}
}
if raw, ok := cpData["__last_completed_node__"]; ok {
if str, ok := raw.(string); ok {
cpLastCompletedNode = str
}
}
// Only restore keys that correspond to registered channels.
filtered := make(map[string]any)
for key, val := range cpData {
if _, ok := channelRegistry.Get(key); ok {
filtered[key] = val
}
}
if len(filtered) > 0 {
if err := channelRegistry.RestoreFromCheckpoint(filtered); err != nil {
errCh <- fmt.Errorf("failed to restore from checkpoint: %w", err)
return
}
}
if cp, err := checkpoint.FromMap(cpData); err == nil {
e.currentCheckpoint = cp
}
}
}
// Apply input only when no checkpoint was loaded.
if !didLoadCheckpoint {
if err := e.applyInput(channelRegistry, input); err != nil {
errCh <- fmt.Errorf("failed to apply input: %w", err)
return
}
}
// Initialize new checkpoint if none exists
if e.currentCheckpoint == nil {
e.currentCheckpoint = checkpoint.NewCheckpoint(threadID, 0)
}
// Create per-run background executor (not shared, so concurrent calls are safe)
backgroundExec := NewBackgroundExecutor(e.maxConcurrency, 100)
backgroundExec.Start(ctx)
defer backgroundExec.Stop()
// Replace engine-level backgroundExec reference for use by async pipeline
e.backgroundExec = backgroundExec
// Execute Pregel loop
step := 0
var completedTasks map[string]bool
lastCompletedNode := cpLastCompletedNode
if didLoadCheckpoint && cpCompletedTasks != nil {
completedTasks = cpCompletedTasks
} else {
completedTasks = make(map[string]bool)
}
var lastState any
if didLoadCheckpoint {
if raw, ok := cpData["__last_state__"]; ok {
var jsonBytes []byte
switch val := raw.(type) {
case string:
jsonBytes = []byte(val)
case []byte:
jsonBytes = val
}
if jsonBytes != nil {
var decoded map[string]any
if json.Unmarshal(jsonBytes, &decoded) == nil {
lastState = decoded
}
}
}
} else {
lastState = input
}
for {
// Check context cancellation at each superstep.
select {
case <-ctx.Done():
errCh <- ctx.Err()
return
default:
}
// Check recursion limit
if step >= e.recursionLimit {
errCh <- &errors.GraphRecursionError{Limit: e.recursionLimit}
return
}
// Emit checkpoint event via stream manager
streamManager.EmitCheckpoint(step, channelRegistry.CreateCheckpoint())
// Determine next tasks
tasks, triggers, err := e.prepareNextTasks(ctx, channelRegistry, completedTasks, lastCompletedNode, lastState)
if err != nil {
errCh <- fmt.Errorf("failed to prepare next tasks: %w", err)
return
}
// Emit task start events
for _, task := range tasks {
streamManager.EmitTaskStart(step, task.Name, task.ID)
}
// If no tasks, we're done
if len(tasks) == 0 {
break
}
// Check for interrupts
interruptedTasks := e.shouldInterrupt(channelRegistry, tasks, triggers)
if len(interruptedTasks) > 0 {
// Save checkpoint
if e.checkpointer != nil {
checkpoint := channelRegistry.CreateCheckpoint()
if err := e.checkpointer.Put(ctx, map[string]any{
constants.ConfigKeyThreadID: threadID,
}, checkpoint); err != nil {
errCh <- fmt.Errorf("failed to save checkpoint: %w", err)
return
}
}
// Emit interrupt event
interruptNames := make([]string, len(interruptedTasks))
for i, task := range interruptedTasks {
interruptNames[i] = task.Name
}
streamManager.EmitInterrupt(step, interruptNames)
errCh <- &errors.GraphInterrupt{}
return
}
// Execute tasks using async pipeline
results, err := e.executeTasksAsync(pipelineCtx, tasks, channelRegistry, asyncPipeline, streamManager, step)
if err != nil {
errCh <- fmt.Errorf("failed to execute tasks: %w", err)
return
}
// Mark tasks as completed and track last state
allFailed := len(results) > 0
var interruptTaskNames []string
for _, result := range results {
if errors.IsGraphInterrupt(result.Err) {
interruptTaskNames = append(interruptTaskNames, result.Name)
continue
}
if result.Err == nil {
allFailed = false
completedTasks[result.Name] = true
lastCompletedNode = result.Name
// Merge result into lastState
lastState = e.mergeStates(lastState, result.Output)
}
}
// If any task was interrupted, handle the interrupt.
if len(interruptTaskNames) > 0 {
common.Debug("engine interrupt path",
zap.Int("step", step),
zap.Strings("tasks", interruptTaskNames),
zap.Bool("allFailed", allFailed))
// Save checkpoint with completed_tasks and sub_state.
if e.checkpointer != nil {
checkpointData := channelRegistry.CreateCheckpoint()
cpPayload := make(map[string]any, len(checkpointData)+4)
for key, val := range checkpointData {
cpPayload[key] = val
}
cpPayload["__completed_tasks__"] = serializeStringSet(completedTasks)
cpPayload["__last_completed_node__"] = lastCompletedNode
cpPayload["__step__"] = float64(step)
// Persist lastState as string (not []byte) to avoid
// JSON double-base64-encoding when the checkpointer
// adapter serializes the whole payload.
if lastState != nil {
if ls, err := json.Marshal(lastState); err == nil {
cpPayload["__last_state__"] = string(ls)
}
}
// Extract sub-state from GraphInterrupt value.
for _, r := range results {
if gi, ok := r.Err.(*errors.GraphInterrupt); ok && len(gi.Interrupts) > 0 {
if intr, ok := gi.Interrupts[0].(*types.Interrupt); ok && intr.Value != nil {
if b, e := json.Marshal(intr.Value); e == nil {
cpPayload["__sub_state__"] = b
}
}
break
}
}
if err := e.checkpointer.Put(ctx, map[string]any{
constants.ConfigKeyThreadID: threadID,
}, cpPayload); err != nil {
errCh <- fmt.Errorf("failed to save checkpoint on interrupt: %w", err)
return
}
}
streamManager.EmitInterrupt(step, interruptTaskNames)
// Preserve the first interrupted task's GraphInterrupt value
// (with Interrupts populated) instead of creating a bare one,
// so MustExtractInterruptContexts can extract the original
// UserFillUp spec / tips / cpn_id from it.
for _, r := range results {
if gi, ok := r.Err.(*errors.GraphInterrupt); ok && len(gi.Interrupts) > 0 {
errCh <- gi
return
}
}
errCh <- &errors.GraphInterrupt{}
return
}
// If every task in this step failed, the graph cannot make progress.
// Terminate immediately rather than infinitely re-scheduling the
// same failing nodes (e.g. a panicking node caught by recover()).
if allFailed {
var why string
for _, r := range results {
why += fmt.Sprintf(" %s=%T(%v)", r.Name, r.Err, r.Err)
}
common.Debug("allFailed",
zap.Int("step", step),
zap.String("results", why))
errCh <- fmt.Errorf("all %d tasks failed in step %d: %s", len(results), step, why)
return
}
// Apply writes to channels
_, err = e.applyWrites(channelRegistry, results, triggers)
if err != nil {
errCh <- fmt.Errorf("failed to apply writes: %w", err)
return
}
// Emit values event
if values, err := channelRegistry.GetValues(); err == nil {
streamManager.EmitValues(step, values)
}
// Save checkpoint based on durability mode
if e.checkpointer != nil {
checkpoint := channelRegistry.CreateCheckpoint()
checkpointID := uuid.New().String()
switch e.config.Durability {
case types.DurabilitySync:
// Synchronous save - block until complete
if err := e.saveCheckpoint(ctx, threadID, checkpointID, step, checkpoint); err != nil {
errCh <- fmt.Errorf("failed to save checkpoint: %w", err)
return
}
case types.DurabilityAsync:
// Asynchronous save - don't block next step
go func(cp map[string]any, cpID string, s int) {
if err := e.saveCheckpoint(context.Background(), threadID, cpID, s, cp); err != nil {
// Log async error but don't fail execution
common.Error("async checkpoint save failed", err, zap.String("thread_id", threadID), zap.String("checkpoint_id", cpID), zap.Int("step", s))
}
}(checkpoint, checkpointID, step)
case types.DurabilityExit:
// Defer save until exit - accumulate checkpoints in memory
// Will be saved in final state
e.deferCheckpoint(threadID, checkpointID, step, checkpoint)
default:
// Default to sync behavior
if err := e.saveCheckpoint(ctx, threadID, checkpointID, step, checkpoint); err != nil {
errCh <- fmt.Errorf("failed to save checkpoint: %w", err)
return
}
}
}
// Check for after-node interrupts. The checkpoint above already
// captures this step's output.
if e.shouldInterruptAfter(results) {
errCh <- &errors.GraphInterrupt{}
return
}
step++
}
// Get final state
finalState, err := e.buildOutput(channelRegistry, lastState)
if err != nil {
errCh <- fmt.Errorf("failed to build output: %w", err)
return
}
// Save deferred checkpoints for DurabilityExit mode
if e.config.Durability == types.DurabilityExit {
if err := e.saveDeferredCheckpoints(ctx); err != nil {
errCh <- fmt.Errorf("failed to save deferred checkpoints: %w", err)
return
}
}
// Emit final event
streamManager.EmitFinal(step, finalState)
}()
return outputCh, errCh
}
// prepareNextTasks determines which tasks to execute next.
// This is the standard version that prepares tasks for execution.
func (e *Engine) prepareNextTasks(
ctx context.Context,
registry *channels.Registry,
completedTasks map[string]bool,
lastCompletedNode string,
currentState any,
) ([]*Task, map[string]struct{}, error) {
return e.prepareNextTasksWithMode(ctx, registry, completedTasks, lastCompletedNode, currentState, true)
}
// prepareNextTasksWithMode determines which tasks to execute next with for_execution mode.
// When forExecution is true, tasks are prepared for actual execution.
// When forExecution is false, only task information is prepared (for inspection/planning).
//
// In AllPredecessor (DAG) mode, a node is triggered only when ALL of its incoming edges'
// source nodes have completed. In AnyPredecessor (Pregel/BSP) mode (default), a node is
// triggered when any predecessor completes. AllPredecessor does not support cycles.
func (e *Engine) prepareNextTasksWithMode(
ctx context.Context,
registry *channels.Registry,
completedTasks map[string]bool,
lastCompletedNode string,
currentState any,
forExecution bool,
) ([]*Task, map[string]struct{}, error) {
tasks := make([]*Task, 0)
triggerToNodes := make(map[string]struct{})
// If this is the first step
if len(completedTasks) == 0 {
entryPoint := e.getEntryPoint()
if entryPoint == "" {
return nil, nil, fmt.Errorf("no entry point set")
}
// Handle direct edge Start → End (empty/trivial graph)
if entryPoint == constants.End {
return tasks, triggerToNodes, nil
}
node := e.getNode(entryPoint)
if node == nil {
return nil, nil, &errors.NodeNotFoundError{NodeName: entryPoint}
}
// Pass node Triggers as task Channels so the first task reads from
// registered channels rather than receiving a nil state.
// When the entry point has no explicit triggers, use all registered
// channel names so it can read the initial input values. This is
// needed by systems (e.g. canvas) that route data via context
// rather than channels but still register input channels for
// the engine's input validation.
triggers := e.getTriggers(node)
if len(triggers) == 0 {
triggers = registry.Names()
}
task := e.createTask(node, currentState, triggers, []string{})
tasks = append(tasks, task)
triggerToNodes["__start__"] = struct{}{}
return tasks, triggerToNodes, nil
}
// AllPredecessor (DAG) mode: scan all uncompleted nodes and check if
// ALL of their incoming-edge source nodes have completed.
if e.graph.GetNodeTriggerMode() == types.NodeTriggerAllPredecessor {
return e.prepareNextTasksDAG(completedTasks, currentState, forExecution)
}
// AnyPredecessor (Pregel/BSP) mode: determine next nodes from the
// last completed node's outgoing edges.
nextNodes := e.getNextNodes(ctx, lastCompletedNode, currentState)
for nodeName := range nextNodes {
node := e.getNode(nodeName)
if node == nil {
continue
}
// Determine triggers for this node
triggers := e.getTriggers(node)
if len(triggers) == 0 {
triggers = registry.Names()
}
// BSP mode: always schedule, even if previously completed (supports loops).
var task *Task
if forExecution {
task = e.createTask(node, currentState, triggers, []string{})
} else {
task = e.createTaskInfo(node, currentState, triggers, []string{})
}
tasks = append(tasks, task)
// Build trigger to nodes mapping
for _, trigger := range triggers {
triggerToNodes[trigger] = struct{}{}
}
}
return tasks, triggerToNodes, nil
}
// prepareNextTasksDAG prepares tasks in DAG (AllPredecessor) mode.
// It scans all nodes and schedules those whose incoming-edge sources
// have all completed. This is O(n) per call but correct for fan-in patterns.
func (e *Engine) prepareNextTasksDAG(
completedTasks map[string]bool,
currentState any,
forExecution bool,
) ([]*Task, map[string]struct{}, error) {
tasks := make([]*Task, 0)
triggerToNodes := make(map[string]struct{})
// Build reverse adjacency: for each node, which nodes have edges TO it.
incomingEdges := e.buildIncomingEdges()
for _, node := range e.graph.GetNodes() {
n := e.getNode(node.Name)
if n == nil {
continue
}
if completedTasks[node.Name] {
continue
}
// Check if all incoming-edge sources have completed.
predecessors := incomingEdges[node.Name]
allDone := true
for _, pred := range predecessors {
// constants.Start and constants.End are always considered completed.
if pred == constants.Start || pred == constants.End {
continue
}
if !completedTasks[pred] {
allDone = false
break
}
}
// Nodes with no incoming edges (beyond start) can run.
if !allDone {
continue
}
triggers := e.getTriggers(n)
if len(triggers) == 0 {
chMap := e.graph.GetChannels()
triggers = make([]string, 0, len(chMap))
for name := range chMap {
triggers = append(triggers, name)
}
}
var task *Task
if forExecution {
task = e.createTask(n, currentState, triggers, []string{})
} else {
task = e.createTaskInfo(n, currentState, triggers, []string{})
}
tasks = append(tasks, task)
for _, trigger := range triggers {
triggerToNodes[trigger] = struct{}{}
}
}
// No tasks means all reachable nodes are done.
return tasks, triggerToNodes, nil
}
// buildIncomingEdges builds a reverse-adjacency map: node → list of nodes with edges TO it.
func (e *Engine) buildIncomingEdges() map[string][]string {
adj := make(map[string][]string)
for _, edge := range e.graph.GetEdges() {
adj[edge.To] = append(adj[edge.To], edge.From)
}
return adj
}
// shouldInterrupt checks if graph should be interrupted.
func (e *Engine) shouldInterrupt(
registry *channels.Registry,
tasks []*Task,
triggerToNodes map[string]struct{},
) []*Task {
interrupted := make([]*Task, 0)
if len(e.interrupts) == 0 {
return interrupted
}
interruptAll := e.interrupts[types.All]
for _, task := range tasks {
if interruptAll || e.interrupts[task.Name] {
interrupted = append(interrupted, task)
}
}
return interrupted
}
// shouldInterruptAfter checks if any SUCCESSFULLY completed task's node name
// is in interruptsAfter. Called AFTER execution and checkpoint save so the
// checkpoint already captures the node's output.
func (e *Engine) shouldInterruptAfter(results []*TaskResult) bool {
if len(e.interruptsAfter) == 0 {
return false
}
interruptAll := e.interruptsAfter[types.All]
for _, r := range results {
if r.Err != nil {
continue
}
if interruptAll || e.interruptsAfter[r.Name] {
return true
}
}
return false
}
// applyWrites applies task outputs to channels with version management and write merging.
func (e *Engine) applyWrites(
registry *channels.Registry,
results []*TaskResult,
triggerToNodes map[string]struct{},
) (map[string]struct{}, error) {
updatedChannels := make(map[string]struct{})
// Sort results for deterministic order
sort.Slice(results, func(i, j int) bool {
return results[i].Name < results[j].Name
})
// Group writes by channel with write merging
writesByChannel := make(map[string][]any)
pendingWrites := make(map[string]*checkpoint.PendingWrite)
for _, result := range results {
if result.Err != nil {
continue
}
// Skip nil outputs (node returned nil, nil — no state update)
if result.Output == nil {
continue
}
// Convert output to map of writes
outputMap, err := toMap(result.Output)
if err != nil {
return nil, fmt.Errorf("failed to convert output to map: %w", err)
}
// Apply FieldMapping if the node has field-level routing configured.
if node := e.getNode(result.Name); node != nil && len(node.FieldMapping) > 0 {
outputMap = applyFieldMapping(outputMap, node.FieldMapping)
}
for key, value := range outputMap {
// Skip nil values
if value == nil {
continue
}
// Check for Overwrite wrapper
overwrite := false
if ow, ok := value.(*types.Overwrite); ok {
value = ow.Value
overwrite = true
}
// Add to writes
writesByChannel[key] = append(writesByChannel[key], value)
// Track pending write
pendingWrites[key] = &checkpoint.PendingWrite{
Channel: key,
Value: value,
Overwrite: overwrite,
Node: result.Name,
}
}
}
// Apply writes to channels with version management
for channelName, values := range writesByChannel {
ch, ok := registry.Get(channelName)
if !ok {
// Auto-create a LastValue channel for map-based schemas where no
// channels were pre-configured (e.g. map[string]any{} schema).
newCh := channels.NewLastValue(nil)
registry.Register(channelName, newCh)
ch = newCh
}
// Filter out nil values
filtered := make([]any, 0, len(values))
for _, val := range values {
if val != nil {
filtered = append(filtered, val)
}
}
// When multiple values target a LastValue channel in the same step
// (star-topology pattern), keep only the last value to avoid channel
// conflict errors. BinaryOperatorAggregate and ReducerChannel handle
// multiple writes via their accumulator logic.
if len(filtered) > 1 {
_, isBO := ch.(*channels.BinaryOperatorAggregate)
_, isRC := ch.(*channels.ReducerChannel)
if !isBO && !isRC {
last := filtered[len(filtered)-1]
filtered = filtered[:1]
filtered[0] = last
}
}
// Update channel
updated, err := ch.Update(filtered)
if err != nil {
return nil, fmt.Errorf("failed to update channel %s: %w", channelName, err)
}
if updated && ch.IsAvailable() {
updatedChannels[channelName] = struct{}{}
// Increment channel version (engine-level tracking).
e.channelVersions[channelName]++
// Also bump the version on the channel itself for ChannelChangedTrigger.
if vc, ok := ch.(interface{ SetVersion(int) }); ok {
vc.SetVersion(e.channelVersions[channelName])
}
// Update checkpoint if available
if e.currentCheckpoint != nil {
e.currentCheckpoint.IncrementChannel(channelName)
}
}
}
// Store pending writes to checkpoint
if e.currentCheckpoint != nil {
for _, pw := range pendingWrites {
e.currentCheckpoint.AddPendingWrite(pw.Channel, pw.Value, pw.Overwrite, pw.Node)
}
}
// Mark channels as seen by nodes
for resultName := range writesByChannel {
if _, ok := triggerToNodes[resultName]; ok {
for channelName := range updatedChannels {
e.markSeen(resultName, channelName)
}
}
}
return updatedChannels, nil
}
// markSeen marks that a node has seen a channel's version.
func (e *Engine) markSeen(node, channel string) {
if e.versionsSeen[node] == nil {
e.versionsSeen[node] = make(map[string]int)
}
e.versionsSeen[node][channel] = e.channelVersions[channel]
if e.currentCheckpoint != nil {
e.currentCheckpoint.MarkSeen(node, channel)
}
}
// hasSeen checks if a node has seen a channel's current version.
func (e *Engine) hasSeen(node, channel string) bool {
if versions, ok := e.versionsSeen[node]; ok {
if version, ok := versions[channel]; ok {
return version == e.channelVersions[channel]
}
}
return false
}
// executeTasks executes the given tasks concurrently.
func (e *Engine) executeTasks(
ctx context.Context,
tasks []*Task,
registry *channels.Registry,
) ([]*TaskResult, error) {
results := make([]*TaskResult, len(tasks))
var wg sync.WaitGroup
var mu sync.Mutex
for i, task := range tasks {
wg.Add(1)
go func(idx int, t *Task) {
defer wg.Done()
result := e.executeTask(ctx, t, registry)
mu.Lock()
results[idx] = result
mu.Unlock()
}(i, task)
}
wg.Wait()
return results, nil
}
// executeTasksAsync executes tasks using async pipeline with streaming.
func (e *Engine) executeTasksAsync(
ctx context.Context,
tasks []*Task,
registry *channels.Registry,
asyncPipeline *AsyncPipeline,
streamManager *StreamManager,
step int,
) ([]*TaskResult, error) {
results := make([]*TaskResult, len(tasks))
var wg sync.WaitGroup
var mu sync.Mutex
for i, task := range tasks {
wg.Add(1)
go func(idx int, t *Task) {
defer wg.Done()
// Read input for this task
input, err := e.readTaskInput(registry, t)
if err != nil {
mu.Lock()
results[idx] = &TaskResult{
Name: t.Name,
Err: fmt.Errorf("failed to read task input: %w", err),
}
mu.Unlock()
return
}
// Convert map input to struct type if state schema is a struct
convertedInput := e.mapToStateSchema(input)
// Define the function to execute
executeFn := func(ctx context.Context) (any, error) {
return t.Func(ctx, convertedInput)
}
// Use task's retry policy or default
retryPolicy := t.RetryPolicy
if retryPolicy == nil {
defaultPolicy := types.DefaultRetryPolicy()
retryPolicy = &defaultPolicy
}
// Execute with async pipeline
resultCh := asyncPipeline.ExecuteNode(ctx, t.Name, executeFn, &RetryConfig{Policy: retryPolicy})
// Wait for result
select {
case <-ctx.Done():
mu.Lock()
results[idx] = &TaskResult{
Name: t.Name,
Err: ctx.Err(),
}
mu.Unlock()
case asyncResult, ok := <-resultCh:
if !ok {
mu.Lock()
results[idx] = &TaskResult{
Name: t.Name,
Err: fmt.Errorf("async result channel closed unexpectedly"),
}
mu.Unlock()
return
}
// Convert async result to task result
taskResult := &TaskResult{
Name: t.Name,
Output: asyncResult.Output,
Err: asyncResult.Err,
}
// Emit task end event
streamManager.EmitTaskEnd(step, t.Name, t.ID, asyncResult.Output, asyncResult.Duration, asyncResult.Err)
// Emit update event if successful
if asyncResult.Err == nil {
streamManager.EmitUpdate(step, t.Name, asyncResult.Output)
} else {
// Emit error event
streamManager.EmitError(step, asyncResult.Err, t.Name)
}
mu.Lock()
results[idx] = taskResult
mu.Unlock()
}
}(i, task)
}
wg.Wait()
return results, nil
}
// executeTask executes a single task with retry logic.
func (e *Engine) executeTask(
ctx context.Context,
task *Task,
registry *channels.Registry,
) *TaskResult {
// Read input for this task
input, err := e.readTaskInput(registry, task)
if err != nil {
return &TaskResult{
Name: task.Name,
Err: fmt.Errorf("failed to read task input: %w", err),
}
}
// Convert map input to struct type if the state schema is a struct
input = e.mapToStateSchema(input)
// Use RetryExecutor for retry logic
retryPolicy := task.RetryPolicy
if retryPolicy == nil {
defaultPolicy := types.DefaultRetryPolicy()
retryPolicy = &defaultPolicy
}
retryExecutor := NewRetryExecutor(retryPolicy)
// Define the function to execute
executeFn := func(ctx context.Context) (any, error) {
return task.Func(ctx, input)
}
// Execute with retry
output, err := retryExecutor.Execute(ctx, task.Name, executeFn)
if err != nil {
// Check if it's a retry exhausted error
if IsRetryExhausted(err) {
return &TaskResult{
Name: task.Name,
Err: fmt.Errorf("max retries exceeded: %w", err),
}
}
// Check for interrupt
if errors.IsGraphInterrupt(err) {
return &TaskResult{
Name: task.Name,
Err: err,
}
}
// Other errors
return &TaskResult{
Name: task.Name,
Err: err,
}
}
// Success
return &TaskResult{
Name: task.Name,
Output: output,
Err: nil,
}
}
// readTaskInput reads the input for a task from channels.
// mapToStateSchema converts a map[string]any state to the graph's state schema
// type if it is a struct (or pointer to struct). If the schema is a map or
// nil, the map input is returned as-is.
func (e *Engine) mapToStateSchema(input any) any {
if input == nil {
return nil
}
inputMap, ok := input.(map[string]any)
if !ok {
return input
}
schema := e.graph.GetStateSchema()
if schema == nil {
return inputMap
}
rv := reflect.ValueOf(schema)
for rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
if rv.Kind() != reflect.Struct {
return inputMap
}
// State schema is a struct (possibly wrapped in pointer): create a new
// instance and populate fields from the input map.
// Preserve whether the original schema was a pointer or value.
schemaVal := reflect.ValueOf(schema)
isPtr := schemaVal.Kind() == reflect.Ptr
structType := rv.Type() // underlying struct type
structPtr := reflect.New(structType)
structVal := structPtr.Elem()
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
if field.PkgPath != "" {
continue
}
if val, exists := inputMap[field.Name]; exists {
fv := structVal.Field(i)
if fv.CanSet() {
rvVal := reflect.ValueOf(val)
if rvVal.Type().AssignableTo(fv.Type()) {
fv.Set(rvVal)
} else if rvVal.Type().ConvertibleTo(fv.Type()) {
fv.Set(rvVal.Convert(fv.Type()))
}
}
}
}
if isPtr {
return structPtr.Interface() // *StructType
}
return structVal.Interface() // StructType (value)
}
func (e *Engine) readTaskInput(registry *channels.Registry, task *Task) (any, error) {
if len(task.Channels) == 0 {
// Return empty map instead of nil so that node functions expecting
// map[string]any receive a usable zero value rather than nil.
return map[string]any{}, nil
}
// Read values from specified channels
values := make(map[string]any)
for _, channelName := range task.Channels {
if ch, ok := registry.Get(channelName); ok {
value, err := ch.Get()
if err != nil {
if _, isEmpty := err.(*errors.EmptyChannelError); !isEmpty {
return nil, err
}
// Empty channels are OK
continue
}
values[channelName] = value
}
}
return values, nil
}
// Task represents a task to execute.
type Task struct {
ID string
Name string
Func types.NodeFunc
Channels []string
Path []string
Triggers map[string]struct{}
RetryPolicy *types.RetryPolicy
}
// TaskResult represents the result of executing a task.
type TaskResult struct {
Name string
Output any
Err error
Path []string // Task path for deterministic ordering (like Python's task_path)
}
// TaskPathStr generates a deterministic string representation of the task path.
// This corresponds to Python's task_path_str function in _algo.py
func TaskPathStr(path []string) string {
if len(path) == 0 {
return ""
}
// Join path components with separator for deterministic ordering
return strings.Join(path, "/")
}
// ParseTaskPath parses a task path string back into a path array.
func ParseTaskPath(pathStr string) []string {
if pathStr == "" {
return []string{}
}
return strings.Split(pathStr, "/")
}
// BuildTaskPath builds a task path from components.
// Supports nested paths like Python's tuple-based paths.
func BuildTaskPath(components ...any) []string {
path := make([]string, 0, len(components))
for _, comp := range components {
switch val := comp.(type) {
case string:
path = append(path, val)
case int:
path = append(path, fmt.Sprintf("%d", val))
case []string:
path = append(path, val...)
default:
if stringer, ok := val.(fmt.Stringer); ok {
path = append(path, stringer.String())
} else {
path = append(path, fmt.Sprintf("%v", val))
}
}
}
return path
}
// Helper methods that access the StateGraph
func (e *Engine) getGraphChannels() map[string]channels.Channel {
raw := e.graph.GetChannels()
result := make(map[string]channels.Channel, len(raw))
for k, v := range raw {
if ch, ok := v.(channels.Channel); ok {
result[k] = ch
}
}
return result
}
func (e *Engine) getEntryPoint() string {
return e.graph.GetEntryPoint()
}
func (e *Engine) getNode(name string) *types.Node {
node, _ := e.graph.GetNode(name)
return node
}
func (e *Engine) getNextNodes(ctx context.Context, node string, state any) map[string]bool {
common.Debug("getNextNodes",
zap.String("node", node),
zap.Any("state", state))
nextNodes := make(map[string]bool)
// (1) Check conditional edges. When a node has conditional edges,
// ONLY the matched target(s) are scheduled — the regular-edge
// fallback is skipped entirely so branchable nodes (Switch,
// Categorize) route exclusively via the _next value.
hasConditional := false
for _, condEdge := range e.graph.GetConditionalEdges() {
if condEdge.From != node {
continue
}
hasConditional = true
conditionResult, err := condEdge.Condition(ctx, state)
if err != nil {
common.Debug("conditional edge failed", zap.String("from", node), zap.Error(err))
}
conditionKey := fmt.Sprintf("%v", conditionResult)
targetNode, ok := condEdge.Mapping[conditionKey]
if !ok {
continue
}
if targetNode == constants.End {
return nextNodes
}
nextNodes[targetNode] = true
}
// (2) Regular edges: ONLY when this node has no conditional edges.
if !hasConditional && len(nextNodes) == 0 {
for _, edge := range e.graph.GetEdges() {
if edge.From == node {
if edge.To == constants.End {
return nextNodes
}
nextNodes[edge.To] = true
}
}
}
// (3) Resume fallback: when the last completed node has no outgoing
// edges but the graph state contains _next (persisted from a
// Switch/Categorize branch), route directly from _next. This
// happens on checkpoint resume because the conditional edge is
// registered on the Switch node, not on __loop_init__.
if len(nextNodes) == 0 {
if st, ok := state.(map[string]any); ok {
if raw, has := st["_next"]; has && raw != nil {
switch tv := raw.(type) {
case string:
if _, exists := e.graph.GetNode(tv); exists {
nextNodes[tv] = true
}
case []any:
if len(tv) > 0 {
if str, ok := tv[0].(string); ok {
if _, exists := e.graph.GetNode(str); exists {
nextNodes[str] = true
}
}
}
}
}
}
}
// (4) Branches: always included on top of whatever was scheduled.
for _, branch := range e.graph.GetBranches() {
if branch.From == node {
branchResult, err := branch.Condition(ctx, state)
if err != nil {
continue
}
targets := branch.Then(branchResult)
for _, target := range targets {
if target == constants.End {
continue
}
nextNodes[target] = true
}
}
}
return nextNodes
}
func (e *Engine) getTriggers(node *types.Node) []string {
if node == nil {
return []string{}
}
return node.Triggers
}
func (e *Engine) createTask(node *types.Node, state any, channels []string, triggers []string) *Task {
task := &Task{
ID: uuid.New().String(),
Name: node.Name,
Channels: channels,
Triggers: make(map[string]struct{}),
}
if node.Function != nil {
task.Func = node.Function
}
for _, trigger := range triggers {
task.Triggers[trigger] = struct{}{}
}
return task
}
// createTaskInfo creates a task info object for inspection/planning (for_execution=false mode).
// This is similar to Python's prepare_next_tasks with for_execution=False.
func (e *Engine) createTaskInfo(node *types.Node, state any, channels []string, triggers []string) *Task {
task := &Task{
ID: uuid.New().String(),
Name: node.Name,
Channels: channels,
Triggers: make(map[string]struct{}),
Func: nil,
}
for _, trigger := range triggers {
task.Triggers[trigger] = struct{}{}
}
return task
}
// PrepareNextTasksForInspection prepares tasks for inspection/planning only (for_execution=false).
// This corresponds to Python's prepare_next_tasks with for_execution=False.
func (e *Engine) PrepareNextTasksForInspection(
ctx context.Context,
registry *channels.Registry,
completedTasks map[string]bool,
lastCompletedNode string,
currentState any,
) ([]*Task, map[string]struct{}, error) {
return e.prepareNextTasksWithMode(ctx, registry, completedTasks, lastCompletedNode, currentState, false)
}
func (e *Engine) applyInput(registry *channels.Registry, input any) error {
inputMap, err := toMap(input)
if err != nil {
return err
}
// Auto-create channels for any input keys not yet registered, then write.
for key, value := range inputMap {
if _, ok := registry.Get(key); ok {
continue
}
guessed := caseFoldKey(registry, key)
if guessed != "" {
delete(inputMap, key)
inputMap[guessed] = value
} else {
registry.Register(key, channels.NewLastValue(value))
}
}
writes := make(map[string][]any, len(inputMap))
for key, value := range inputMap {
writes[key] = []any{value}
}
if len(writes) > 0 {
return registry.UpdateChannels(writes)
}
return nil
}
// caseFoldKey attempts to locate a registered channel whose name differs from
// key only by the case of the first character (e.g. struct field "Counter" vs
// input map key "counter"). Returns the matched channel name, or "".
func caseFoldKey(registry *channels.Registry, key string) string {
if len(key) == 0 {
return ""
}
// Try uppercase first (e.g. "counter" → "Counter")
bs := []byte(key)
if bs[0] >= 'a' && bs[0] <= 'z' {
bs[0] -= 32
candidate := string(bs)
if _, ok := registry.Get(candidate); ok {
return candidate
}
}
// Try lowercase first (e.g. "Counter" → "counter")
bs[0] = key[0]
if bs[0] >= 'A' && bs[0] <= 'Z' {
bs[0] += 32
candidate := string(bs)
if _, ok := registry.Get(candidate); ok {
return candidate
}
}
return ""
}
func (e *Engine) getThreadID() string {
if e.config != nil && e.config.Configurable != nil {
if tid, ok := e.config.Configurable["thread_id"].(string); ok {
return tid
}
}
return uuid.New().String()
}
func (e *Engine) buildOutput(registry *channels.Registry, lastState any) (any, error) {
values, err := registry.GetValues()
if err != nil {
return lastState, nil
}
if len(values) > 0 {
return values, nil
}
return lastState, nil
}
func (e *Engine) mergeStates(existing, next any) any {
if existing == nil {
return next
}
if next == nil {
return existing
}
// Try to merge maps
existingMap, ok1 := existing.(map[string]any)
nextMap, ok2 := next.(map[string]any)
if ok1 && ok2 {
result := make(map[string]any)
for key, val := range existingMap {
result[key] = val
}
for key, val := range nextMap {
result[key] = val
}
return result
}
return next
}
// toMap converts a struct or map to a map[string]any.
func toMap(val any) (map[string]any, error) {
if val == nil {
return nil, fmt.Errorf("nil value")
}
// If it's already a map
if m, ok := val.(map[string]any); ok {
return m, nil
}
// Use reflection to convert struct to map
rv := reflect.ValueOf(val)
if rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map {
return map[string]any{"__root__": val}, nil
}
result := make(map[string]any)
if rv.Kind() == reflect.Map {
for _, key := range rv.MapKeys() {
result[fmt.Sprintf("%v", key.Interface())] = rv.MapIndex(key).Interface()
}
return result, nil
}
// Struct
rt := rv.Type()
for i := 0; i < rv.NumField(); i++ {
field := rt.Field(i)
// Skip unexported fields
if field.PkgPath != "" {
continue
}
val := rv.Field(i).Interface()
// Use original field name to match channel registration
// (configureChannelsFromSchema registers channels with field.Name).
result[field.Name] = val
}
return result, nil
}
// saveCheckpoint saves a checkpoint to the checkpointer.
func (e *Engine) saveCheckpoint(ctx context.Context, threadID, checkpointID string, step int, checkpoint map[string]any) error {
if e.checkpointer == nil {
return nil
}
return e.checkpointer.Put(ctx, map[string]any{
constants.ConfigKeyThreadID: threadID,
constants.ConfigKeyCheckpointID: checkpointID,
"step": step,
}, checkpoint)
}
// deferCheckpoint defers a checkpoint save for DurabilityExit mode.
func (e *Engine) deferCheckpoint(threadID, checkpointID string, step int, checkpoint map[string]any) {
e.deferredCheckpoints = append(e.deferredCheckpoints, deferredCheckpoint{
ThreadID: threadID,
CheckpointID: checkpointID,
Step: step,
Checkpoint: checkpoint,
})
}
// saveDeferredCheckpoints saves all deferred checkpoints (called at exit for DurabilityExit mode).
func (e *Engine) saveDeferredCheckpoints(ctx context.Context) error {
if e.checkpointer == nil || len(e.deferredCheckpoints) == 0 {
return nil
}
var lastErr error
for _, dc := range e.deferredCheckpoints {
if err := e.saveCheckpoint(ctx, dc.ThreadID, dc.CheckpointID, dc.Step, dc.Checkpoint); err != nil {
lastErr = err
// Continue saving other checkpoints even if one fails
}
}
// Clear deferred checkpoints after attempting to save
e.deferredCheckpoints = nil
return lastErr
}
// RunSync executes the graph synchronously and returns the final state.
// This is a convenience wrapper around Run() for callers that want a blocking API.
//
// RunSync first drains all events from outputCh (reading until it is closed),
// then checks errCh for any execution error. This ordering avoids a race
// between the EventTypeFinal arriving on outputCh and errCh being closed
// (the defer calling close(errCh) runs AFTER close(outputCh)).
func (e *Engine) RunSync(ctx context.Context, input any) (any, error) {
outputCh, errCh := e.Run(ctx, input, types.StreamModeValues)
var finalState any
// Drain outputCh to capture the final state event.
// Must read until closed to avoid leaking the forward goroutine.
for result := range outputCh {
if se, ok := result.(*StreamEvent); ok && se.Type == EventTypeFinal {
if data, ok := se.Data.(map[string]any); ok {
if state, ok := data["state"]; ok {
finalState = state
}
}
}
}
// Check for execution errors (non-blocking; errCh is closed after outputCh).
select {
case err := <-errCh:
if err != nil {
return nil, err
}
default:
}
return finalState, nil
}
// applyFieldMapping filters and remaps an output map according to FieldMapping rules.
// If no mappings are specified, the entire output map is passed through unchanged.
// Each mapping specifies a source field path (From) and a target field path (To).
func applyFieldMapping(output map[string]any, mappings []types.FieldMapping) map[string]any {
if len(mappings) == 0 {
return output
}
result := make(map[string]any, len(mappings))
for _, mapping := range mappings {
val := getNestedField(output, mapping.From)
if val != nil {
setNestedField(result, mapping.To, val)
}
}
return result
}
// getNestedField retrieves a value from a nested map using a dot-separated path.
func getNestedField(m map[string]any, path string) any {
if path == "" {
return m // return entire map
}
parts := strings.Split(path, ".")
var cur any = m
for _, part := range parts {
cm, ok := cur.(map[string]any)
if !ok {
return nil
}
cur = cm[part]
if cur == nil {
return nil
}
}
return cur
}
// setNestedField sets a value in a nested map using a dot-separated path.
func setNestedField(m map[string]any, path string, val any) {
if path == "" {
for k, v := range val.(map[string]any) {
m[k] = v
}
return
}
parts := strings.Split(path, ".")
for i := 0; i < len(parts)-1; i++ {
sub, ok := m[parts[i]]
if !ok {
sub = make(map[string]any)
m[parts[i]] = sub
}
var ok2 bool
m, ok2 = sub.(map[string]any)
if !ok2 {
nm := make(map[string]any)
m[parts[i]] = nm
m = nm
}
}
m[parts[len(parts)-1]] = val
}
// serializeStringSet encodes a map[string]bool to a NUL-separated string
// for storage in the checkpoint payload.
func serializeStringSet(set map[string]bool) string {
if len(set) == 0 {
return ""
}
keys := make([]string, 0, len(set))
for key := range set {
keys = append(keys, key)
}
sort.Strings(keys)
out := make([]byte, 0, 256)
for i, key := range keys {
if i > 0 {
out = append(out, 0)
}
out = append(out, key...)
}
return string(out)
}
// deserializeStringSet decodes a NUL-separated string back to a
// map[string]bool.
func deserializeStringSet(encoded string) map[string]bool {
if encoded == "" {
return nil
}
parts := strings.Split(encoded, "\x00")
out := make(map[string]bool, len(parts))
for _, part := range parts {
out[part] = true
}
return out
}