mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-30 07:51:10 +08:00
## Summary
Aligns the **Go agent runtime/canvas/components/tools** behavior with
the **Python `agent/` implementation** so the same stored canvas DSL
produces the same execution result on either side. Every component,
tool, and runtime primitive in `internal/agent/` is now driven by the
same semantics as its Python counterpart — variable resolution, template
substitution, control flow, error reporting, retry/cancel, and stream
event shapes.
The **retrieval component is the one explicit exception** in this PR. It
is being reworked in a separate change and is excluded from this
alignment pass; the wrapper slot (`universe_a_wrappers.go →
newRetrievalComponent`) is preserved.
## Scope of alignment
### Components (all aligned with `agent/component/`)
`Begin` · `Message` · `LLM` (incl. ChatTemplateKwargs,
MessageHistoryWindowSize, VisualFiles, Cite, OutputStructure,
JSONOutput, TopP, MaxRetries, DelayAfterError, credentials) · `Agent`
(react + tool artifact capture + `Reset()` interface-assert) · `Switch`
(12/12 operators, Python-equivalent semantics) · `Categorize` · `Invoke`
· `Iteration` · `Loop` (macro-expansion through `workflowx.AddLoopNode`)
· `UserFillUp` (Python-equivalent interrupt/resume via eino
`compose.Interrupt`/`ResumeWithData`) · `FillUp` · `DataOperations` ·
`ListOperations` · `StringTransform` · `VariableAggregator` ·
`VariableAssigner` · `Browser` (full stagehand runtime parity) ·
`DocsGenerator` · `ExcelProcessor`.
### Tools (all aligned with `agent/tools/`)
`Retrieval` (wrapper slot only — logic out of scope) · `MCPToolAdapter`
(streamable-HTTP) · `CodeExec` (sandbox bridge with
`code_exec_contract.go` matching Python contract) · `AkShare` · `ArXiv`
· `Crawler` · `DeepL` · `DuckDuckGo` · `Email` · `ExeSQL` · `GitHub` ·
`Google` · `GoogleScholar` · `Jin10` · `PubMed` · `QWeather` · `SearXNG`
· `Tavily` · `Tushare` · `Wencai` · `Wikipedia` · `YahooFinance` —
uniform `eino tool.InvokableTool` interface, SSRF protection, shared
HTTP client.
### Canvas execution engine (`internal/agent/canvas/`)
Aligned with Python's `agent/canvas.py`:
- **Scheduler** (`scheduler.go`): state pre/post handlers, node lambdas,
per-component timeout resolver (4-level: per-class env → per-class table
→ uniform env → 600s fallback), `legacyNoOpNames`.
- **Loop subgraph** (`loop_subgraph.go`): Python-equivalent
`AddLoopNode` macro expansion + condition translation.
- **Multibranch** (`multibranch.go`): `Switch` / `Categorize` routing
via `compose.NewGraphMultiBranch` — same branch selection semantics as
Python.
- **Parallel subgraph** (`parallel_subgraph.go`): matches Python's
parallel fan-out contract.
- **Interrupt/Resume** (`interrupt_resume.go`): `UserFillUpNodeBody` /
`IsInterruptError` / `ExtractInterruptContexts` — replaces the
deprecated Python sentinel chain with eino's native interrupt API,
preserving the same external behavior.
- **Checkpoint** (`checkpoint_store.go`): `RedisCheckPointStore`
Get/Set/Delete, with business metadata (status / canvas_id /
parent_run_id) on a parallel Redis Hash.
- **RunTracker** (`run_tracker.go`): Start / MarkSucceeded / MarkFailed
/ MarkCancelled / AttachCheckpoint — same lifecycle as the Python run
record.
- **Cancel** (`cancel.go`): Redis pub/sub watch.
- **Stream** (`stream.go`): SSE channel with `messages` / `waiting` /
`errors` / `done` events, same shape as Python's `agent.canvas.RunEvent`
payload.
### DSL bridge (`internal/agent/dsl/`)
- `normalize.go`: v1↔v2 collapsed into a single wire format — Python and
Go consume the same stored JSON.
- `reset.go`: per-run state reset matches Python's `Canvas.reset()`
semantics.
- Testdata mirrors Python's `agent_msg.json` / `all.json` / etc.
### Runtime (`internal/agent/runtime/`)
- `CanvasState` / `NewCanvasState` / `GetVar` / `SetVar` / `ReadVars`:
same `{{cpn_id@param}}` resolution model.
- `ResolveTemplate` (regex fast path + gonja fallback) — Python
Jinja-style semantics.
- `selector.go`, `metrics.go`, `component.go`: shared runtime contracts.
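
As a rough illustration of the regex fast path for `{{cpn_id@param}}` and `{{sys.query}}` references, the sketch below substitutes references from a flat map. The pattern, the `resolveRefs` name, the flat `map[string]any` store, and the choice to return an error on a missing reference are assumptions for this example; the actual `ResolveTemplate` and its gonja fallback may behave differently.

```go
package main

import (
	"fmt"
	"regexp"
)

// refPattern matches {{name}} and {{cpn_id@param}} style references.
// The accepted character set is an assumption for this sketch.
var refPattern = regexp.MustCompile(`\{\{\s*([A-Za-z0-9_.:\-]+(?:@[A-Za-z0-9_.\-]+)?)\s*\}\}`)

// resolveRefs replaces every reference with its value from vars and
// reports the first reference that has no value.
func resolveRefs(tpl string, vars map[string]any) (string, error) {
	var missing string
	out := refPattern.ReplaceAllStringFunc(tpl, func(m string) string {
		key := refPattern.FindStringSubmatch(m)[1]
		v, ok := vars[key]
		if !ok {
			if missing == "" {
				missing = key
			}
			return m // leave the text untouched; the error is reported below
		}
		return fmt.Sprint(v)
	})
	if missing != "" {
		return "", fmt.Errorf("unresolved reference %q", missing)
	}
	return out, nil
}

func main() {
	vars := map[string]any{
		"sys.query":     "hi",
		"llm_0@content": "hello",
	}
	out, err := resolveRefs("Q: {{sys.query}} A: {{ llm_0@content }}", vars)
	fmt.Println(out, err)
}
```

A template that contains control structures rather than plain references would not match this pattern, which is where a full Jinja-style engine would have to take over.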
## Out of scope (intentionally)
- **`Retrieval` component logic** — wrapped only; full parity lands in a
follow-up PR.
- **Frontend** — only minor dsl-bridge / canvas UX fixes ride along.
- **CLI / admin / model registry** — orthogonal to agent behavior.
## How alignment is verified
`internal/service/agent_run_e2e_test.go` exercises the **full production
chain** against real Python-shaped DSL fixtures:
```
loadCanvasForUser → versionDAO.GetLatest → decodeCanvasFromDSL →
canvas.Compile → cc.Workflow.Invoke → answer extraction
```
using in-memory SQLite + miniredis (no Docker). Covers:
- `TestRunAgent_RealCanvas_BeginMessage` — happy path, `{{sys.query}}`
resolution
- `TestRunAgent_RealCanvas_WaitForUserResume` — two-run resume cycle
(Python-equivalent)
- `TestRunAgent_RealCanvas_CompileFails` — unknown component name →
sanitized error (Python-equivalent)
- `TestRunAgent_RealCanvas_InvokeFails` — unresolvable template ref
(Python-equivalent)
- `TestRunAgent_RunTracker_AttachCheckpoint_CallSequence` —
Start→AttachCheckpoint→MarkSucceeded lifecycle
`internal/handler/agent_test.go` — SSE streaming parity (`Content-Type:
text/event-stream`, `data: {…}\n\n`, trailing `data: [DONE]\n\n`,
OpenAI-compatible non-stream `choices`).
`internal/agent/canvas/fixture_compile_test.go` + per-component tests
pin the Python-equivalent outputs.
```
go test -count=1 -v -run 'TestRunAgent_RealCanvas|TestRunAgent_RunTracker' ./internal/service/
```
## Design reference
`docs/develop/agent-go-port-design.md` (1329 lines, last cross-checked
2026-06-17) — module layout, per-component / per-tool inventory,
corner-case catalogue, and the actionable backlog (Section 14, including
the retrieval alignment follow-up).
---------
Co-authored-by: Claude <noreply@anthropic.com>
813 lines
28 KiB
Go
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

// Package workflowx parallel extension.
//
// AddParallelNode is a zero-intrusion helper that runs a sub-workflow
// once per input item, with bounded concurrency, and supports
// per-item interrupt / resume. The shape mirrors AddLoopNode:
// the outer workflow sees a single node; the fan-out is entirely
// inside the lambda body.
//
// The first release is invoke-only on the outer lambda; the inner
// per-item sub-workflow is invoked via runner.Invoke.
//
// See .claude/plans/eino-workflow-parallel.md (and
// .omc/autopilot/spec.md) for the design rationale.
package workflowx

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"sync"

	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/schema"
)

// ParallelAddressSegment is the per-item address segment used when
// addressing interrupts. It mirrors the batch node's
// AddressSegmentBatchProcess.
const ParallelAddressSegment compose.AddressSegmentType = "workflowx-parallel"

// Sentinel errors for the parallel extension. Tests use errors.Is
// to assert these.
var (
	// ErrParallelCompileFailed wraps a compile-time failure of the
	// inner sub-workflow. Match it with errors.Is; the original
	// error from sub.Compile is formatted into the message with %v
	// and is therefore not reachable via errors.Unwrap.
	ErrParallelCompileFailed = errors.New("workflowx: parallel sub-workflow compile failed")

	// ErrParallelResumeStateInvalid is returned when a resume is
	// requested but the persisted state is missing, malformed, or
	// has an empty Inputs slice.
	ErrParallelResumeStateInvalid = errors.New("workflowx: parallel resume state invalid")
)

// ParallelOption configures AddParallelNode. Follows the
// functional-options pattern.
type ParallelOption func(*parallelOptions)

type parallelOptions struct {
	maxConcurrency      int
	compileOpts         []compose.GraphCompileOption
	runOpts             []compose.Option
	checkpointBuilder   func(nodeKey string, index int) string
	enableSubCheckpoint bool
	contextBuilder      func(ctx context.Context, item any, index int) context.Context
}

// WithParallelMaxConcurrency caps the number of per-item sub-workflow
// invocations that run concurrently.
//
//   - n <= 1: sequential execution on the calling goroutine (no
//     goroutines are spawned for any input length).
//   - n > 1: bounded fan-out using a semaphore of size n; the first
//     item still runs on the main goroutine.
//
// The default is 0 (sequential). Negative values are ignored.
func WithParallelMaxConcurrency(n int) ParallelOption {
	return func(o *parallelOptions) {
		if n >= 0 {
			o.maxConcurrency = n
		}
	}
}

// WithParallelCompileOptions appends compile options to the inner
// sub-workflow's Compile call. Useful for wiring a Serializer or a
// caller-managed CheckPointStore on the inner sub-graph.
//
// Note: when sub-checkpoint is enabled, the parallel extension
// appends its own bridge CheckPointStore after the options supplied
// here, so a store set via this option does not replace the bridge
// store.
func WithParallelCompileOptions(opts ...compose.GraphCompileOption) ParallelOption {
	return func(o *parallelOptions) {
		o.compileOpts = append(o.compileOpts, opts...)
	}
}

// WithParallelRunOptions appends run options to every per-item
// sub-workflow Invoke call. Use this to forward run-level options
// such as per-item callbacks.
func WithParallelRunOptions(opts ...compose.Option) ParallelOption {
	return func(o *parallelOptions) {
		o.runOpts = append(o.runOpts, opts...)
	}
}

// WithParallelCheckpointIDBuilder supplies a deterministic checkpoint
// ID for each per-item sub-workflow invocation. eino does not expose
// the active outer checkpoint ID through ctx, so the extension
// cannot derive child IDs by itself.
//
// The default is a reserved-namespace builder
// (workflowx-parallel:<nodeKey>:<index>) which is deterministic and
// stable across resumes. Callers that need a shared prefix can
// capture it in the closure.
//
// The builder is invoked on the first run AND on resume, with stable
// (nodeKey, index) arguments. An empty return is treated as "skip
// the per-item WithCheckPointID" so the inner task does not get a
// bad namespace.
func WithParallelCheckpointIDBuilder(b func(nodeKey string, index int) string) ParallelOption {
	return func(o *parallelOptions) {
		if b != nil {
			o.checkpointBuilder = b
		}
	}
}

// WithParallelEnableSubCheckpoint opts the parallel node into
// passing compose.WithCheckPointID(...) and an internal bridge
// store to the sub-workflow on every per-item Invoke call.
//
// The default is true. Disabling it is only useful when the caller
// explicitly wants the smaller/no-sub-checkpoint behavior.
func WithParallelEnableSubCheckpoint(enable bool) ParallelOption {
	return func(o *parallelOptions) {
		o.enableSubCheckpoint = enable
	}
}

// WithParallelContextBuilder decorates the per-item sub-workflow
// context before Invoke. This lets callers attach item-scoped runtime
// state without changing the outer []I -> []O parallel API.
func WithParallelContextBuilder(
	b func(ctx context.Context, item any, index int) context.Context,
) ParallelOption {
	return func(o *parallelOptions) {
		if b != nil {
			o.contextBuilder = b
		}
	}
}

// defaultParallelCheckpointBuilder returns a deterministic per-item
// checkpoint ID. Unlike the loop extension, the parallel extension
// does not need a UUID in the default because the same item index
// is naturally re-derived from the persisted InterruptedIndices on
// resume, so the same ID is reused.
func defaultParallelCheckpointBuilder(nodeKey string, index int) string {
	return fmt.Sprintf("workflowx-parallel:%s:%d", nodeKey, index)
}

func getParallelOptions(opts []ParallelOption) *parallelOptions {
	o := &parallelOptions{
		checkpointBuilder:   defaultParallelCheckpointBuilder,
		enableSubCheckpoint: true,
	}
	for _, opt := range opts {
		opt(o)
	}
	return o
}

// ParallelInterruptState is the parallel-local checkpoint payload.
// It is persisted as the state argument of
// compose.CompositeInterrupt so a resumed run can continue from the
// interrupted items rather than restart.
//
// The struct mirrors the reference batch node's NodeInterruptState,
// with one important adaptation: the original inputs are stored in
// OriginalInputsJSON as a JSON byte slice (not []any) so the
// parallel extension can re-decode them with the original Go types
// on resume. JSON's default behaviour of decoding numbers into
// float64 would otherwise break integer and other typed inputs.
type ParallelInterruptState struct {
	// OriginalInputsJSON is the JSON encoding of the input slice
	// as seen by the parallel lambda on first run. On resume
	// the lambda input is replaced by a zero value by eino's
	// rerun mechanism; this byte slice is the source of truth.
	OriginalInputsJSON []byte `json:"original_inputs_json"`

	// CompletedResults carries every index that already produced
	// a value on a previous (interrupted) run.
	CompletedResults map[int]any `json:"completed_results"`

	// InterruptedIndices is the list of indices that were not
	// durably confirmed completed at the interrupt boundary.
	// In the common case this equals "the items whose sub-workflow
	// Invoke returned an interrupt". Under concurrent execution,
	// however, any item that is not present in CompletedResults is
	// treated conservatively as needing replay / resume, because its
	// precise execution state may be unknown when the outer node
	// returns a CompositeInterrupt.
	InterruptedIndices []int `json:"interrupted_indices"`

	// TotalCount is the size of the input slice. It is the source
	// of truth for the output slice length on resume.
	TotalCount int `json:"total_count"`

	// ItemCheckpoints is the per-item bridge-store payload captured
	// at interrupt time. Keys are the per-item child checkpoint
	// IDs (whatever the configured builder produced).
	ItemCheckpoints map[string][]byte `json:"item_checkpoints,omitempty"`
}

// Compilable is the input type accepted by AddParallelNode. Both
// *compose.Graph[I, O] and *compose.Workflow[I, O] satisfy it.
type Compilable[I, O any] interface {
	Compile(ctx context.Context, opts ...compose.GraphCompileOption) (compose.Runnable[I, O], error)
}

// AddParallelNode appends a parallel-fanout node to the outer
// workflow. The fan-out is inside the lambda body; the outer graph
// sees one node.
//
// The lambda is invoke-only in v1; its Stream handler returns a
// documented error. Callers that need outer-stream parallelism
// should treat that as a future v2 plan.
//
// AddParallelNode compiles the sub-workflow immediately. Compile-
// time failures are returned as an error and the outer workflow
// is not modified.
func AddParallelNode[I, O any](
	ctx context.Context,
	wf *compose.Workflow[[]I, []O],
	key string,
	sub Compilable[I, O],
	opts ...ParallelOption,
) (*compose.WorkflowNode, error) {
	if wf == nil {
		return nil, errors.New("workflowx: outer workflow is nil")
	}
	if sub == nil {
		return nil, errors.New("workflowx: sub workflow is nil")
	}
	options := getParallelOptions(opts)

	// Build a fresh per-node bridge store. It is captured in the
	// lambda's closure and rehydrated from ItemCheckpoints on
	// resume.
	bridgeState := newParallelBridgeState(nil)

	compileOpts := append([]compose.GraphCompileOption{}, options.compileOpts...)
	if options.enableSubCheckpoint {
		compileOpts = append(compileOpts, compose.WithCheckPointStore(bridgeState.store()))
	}
	compiled, err := sub.Compile(ctx, compileOpts...)
	if err != nil {
		return nil, fmt.Errorf("%w: %s: %v", ErrParallelCompileFailed, key, err)
	}

	lambda, err := compose.AnyLambda[[]I, []O, struct{}](
		func(ctx context.Context, items []I, _ ...struct{}) ([]O, error) {
			return runParallelInvoke(ctx, key, compiled, items, options, bridgeState)
		},
		func(ctx context.Context, items []I, _ ...struct{}) (*schema.StreamReader[[]O], error) {
			return nil, errParallelOuterStreamUnsupported
		},
		nil,
		nil,
	)
	if err != nil {
		return nil, fmt.Errorf("workflowx: build parallel lambda: %w", err)
	}

	return wf.AddLambdaNode(key, lambda), nil
}

// errParallelOuterStreamUnsupported is the documented v1 error
// returned from the outer Stream handler. Surfaced as a sentinel
// for tests to assert against via errors.Is.
var errParallelOuterStreamUnsupported = errors.New("workflowx: parallel node does not support outer stream in v1")

// ErrParallelOuterStreamUnsupported is exported so external tests
// can assert on it. The lambda's Stream handler returns this
// (wrapped) error.
var ErrParallelOuterStreamUnsupported = errParallelOuterStreamUnsupported

// runParallelInvoke is the body of the parallel lambda's Invoke
// handler. It implements the documented state machine:
//
//   - On a fresh run: process every item 0..len(items)-1 with an
//     empty CompletedResults map.
//   - On a resume: process exactly prev.InterruptedIndices, with
//     prev.CompletedResults pre-populated into the output slice.
//     On resume the lambda's items input is replaced by a zero
//     value by eino's rerun mechanism; the canonical inputs come
//     from prev.OriginalInputsJSON.
//   - If any items interrupt, return a single CompositeInterrupt
//     carrying all per-item interrupt errors and a state that lets
//     a resumed run re-enter deterministically.
//   - On a non-interrupt error, return the first one (wrapped per
//     "item %d: %w") and discard the other items' results.
func runParallelInvoke[I, O any](
	ctx context.Context,
	nodeKey string,
	sub compose.Runnable[I, O],
	items []I,
	options *parallelOptions,
	defaultBridge *parallelBridgeState,
) ([]O, error) {
	prev, isResume, resumeErr := loadParallelSnapshot(ctx)
	if resumeErr != nil {
		return nil, resumeErr
	}
	// On a resume, eino's rerun mechanism passes a zero-value
	// items slice to the lambda. The canonical inputs come from
	// the persisted state. On a fresh run, items is the user's
	// input and the persisted state is empty.
	effectiveItems := items
	if isResume && prev != nil {
		var restored []I
		if rErr := json.Unmarshal(prev.OriginalInputsJSON, &restored); rErr != nil {
			return nil, fmt.Errorf("%w: decode original_inputs_json: %v", ErrParallelResumeStateInvalid, rErr)
		}
		effectiveItems = restored
	}
	if len(effectiveItems) == 0 {
		return []O{}, nil
	}

	// Allocate output slice. On resume, the total count is the
	// persisted value; on first run, it is the input length.
	totalCount := len(effectiveItems)
	indicesToProcess := make([]int, len(effectiveItems))
	for i := range effectiveItems {
		indicesToProcess[i] = i
	}

	bridgeState := defaultBridge
	outputs := make([]O, totalCount)

	if isResume && prev != nil {
		totalCount = prev.TotalCount
		if totalCount < 0 {
			return nil, fmt.Errorf("%w: negative total_count", ErrParallelResumeStateInvalid)
		}
		outputs = make([]O, totalCount)
		// Replay completed results into the correct output slots.
		// The persisted value came through a JSON round-trip so
		// numeric types are float64; we coerce to O via an
		// intermediate any round-trip.
		for idx, v := range prev.CompletedResults {
			if idx < 0 || idx >= totalCount {
				return nil, fmt.Errorf("%w: completed index %d out of range", ErrParallelResumeStateInvalid, idx)
			}
			typed, ok := coerceAnyToO[O](v)
			if !ok {
				return nil, fmt.Errorf("%w: cannot coerce completed result at index %d (type %T) to target type", ErrParallelResumeStateInvalid, idx, v)
			}
			outputs[idx] = typed
		}
		// Only re-invoke the previously-interrupted indices.
		indicesToProcess = append([]int(nil), prev.InterruptedIndices...)
		// Rehydrate the bridge store from persisted ItemCheckpoints.
		bridgeState = newParallelBridgeState(prev.ItemCheckpoints)
	}

	// Run all items. The sequential / semaphore-bounded fan-out
	// is delegated to runParallelFanout.
	results := runParallelFanout(ctx, nodeKey, sub, effectiveItems, indicesToProcess, options, bridgeState)

	// Drain the result channel, categorising each entry.
	var normalErr error
	var interruptErrs []error
	completedResults := make(map[int]any)
	if isResume && prev != nil {
		// Carry forward results from earlier runs so that a
		// second interrupt does not list them as pending again.
		for idx, v := range prev.CompletedResults {
			completedResults[idx] = v
		}
	}
	for r := range results {
		if r.err == nil {
			if r.index >= 0 && r.index < len(outputs) {
				if typed, ok := r.output.(O); ok {
					outputs[r.index] = typed
				}
			}
			completedResults[r.index] = r.output
			continue
		}
		if isInterruptError(r.err) {
			interruptErrs = append(interruptErrs, r.err)
			continue
		}
		// First non-interrupt error wins; we keep draining so
		// goroutines do not leak, but the caller will see this
		// normalErr and discard the rest.
		if normalErr == nil {
			normalErr = fmt.Errorf("item %d: %w", r.index, r.err)
		}
	}

	// Non-interrupt error: discard every other result, return the
	// first one (wrapped). No state is persisted.
	if normalErr != nil {
		return nil, normalErr
	}

	// Interrupt case: persist state and rethrow via CompositeInterrupt.
	// We store every non-completed index, not only the indices that
	// explicitly surfaced an interrupt. This preserves correctness if
	// a future implementation changes the fan-out to short-circuit or
	// cancel in-flight work at the first interrupt boundary.
	if len(interruptErrs) > 0 {
		inputsJSON, jErr := json.Marshal(effectiveItems)
		if jErr != nil {
			return nil, fmt.Errorf("workflowx: marshal original inputs: %w", jErr)
		}
		interruptedIndices := buildPendingIndices(totalCount, completedResults)
		state, sErr := encodeParallelState(ParallelInterruptState{
			OriginalInputsJSON: inputsJSON,
			CompletedResults:   completedResults,
			InterruptedIndices: interruptedIndices,
			TotalCount:         totalCount,
			ItemCheckpoints:    bridgeState.snapshot(),
		})
		if sErr != nil {
			return nil, fmt.Errorf("workflowx: encode parallel interrupt state: %w", sErr)
		}
		return nil, compose.CompositeInterrupt(ctx, nil, state, interruptErrs...)
	}

	return outputs, nil
}

// coerceAnyToO adapts a JSON-roundtripped any to the typed
// output O. JSON decoding maps numeric types to float64 by
// default, so a value that originated as int, int64, float32,
// etc. comes back as float64. This helper covers the common
// numeric conversions; for non-numeric O, the direct assertion
// is used.
func coerceAnyToO[O any](v any) (O, bool) {
	var zero O
	if v == nil {
		return zero, false
	}
	if typed, ok := v.(O); ok {
		return typed, true
	}
	// JSON-decode coercion: float64 -> O when O is one of the
	// common numeric types. We use reflect-free type switches.
	switch any(zero).(type) {
	case int:
		if f, ok := v.(float64); ok {
			return any(int(f)).(O), true
		}
	case int64:
		if f, ok := v.(float64); ok {
			return any(int64(f)).(O), true
		}
	case int32:
		if f, ok := v.(float64); ok {
			return any(int32(f)).(O), true
		}
	case float32:
		if f, ok := v.(float64); ok {
			return any(float32(f)).(O), true
		}
	case float64:
		if f, ok := v.(float64); ok {
			return any(f).(O), true
		}
	case uint:
		if f, ok := v.(float64); ok {
			return any(uint(f)).(O), true
		}
	case uint64:
		if f, ok := v.(float64); ok {
			return any(uint64(f)).(O), true
		}
	case uint32:
		if f, ok := v.(float64); ok {
			return any(uint32(f)).(O), true
		}
	}
	return zero, false
}

// parallelResumeBackdoorKey is a context key used by unit tests
// to drive the resume path without going through eino's
// framework-managed checkpoint store. The production resume
// path uses compose.GetInterruptState; this is a test-only
// backdoor. Set via context.WithValue(ctx,
// parallelResumeBackdoorKey{}, payload) where payload is the
// JSON-encoded ParallelInterruptState.
type parallelResumeBackdoorKey struct{}

// loadParallelSnapshot reads the persisted parallel state from
// ctx if the current run is a resume. On a fresh run it returns
// (nil, false, nil).
//
// The loader first checks for a test-injected payload via
// parallelResumeBackdoorKey (so unit tests can drive the resume
// path directly), then falls back to eino's
// compose.GetInterruptState. The production resume path always
// goes through the second branch.
func loadParallelSnapshot(ctx context.Context) (*ParallelInterruptState, bool, error) {
	// Test backdoor: a hand-injected payload takes priority so
	// unit tests can drive resume without a real checkpoint
	// store. Production code never sets this key.
	if raw, ok := ctx.Value(parallelResumeBackdoorKey{}).([]byte); ok && len(raw) > 0 {
		var st ParallelInterruptState
		if err := json.Unmarshal(raw, &st); err != nil {
			return nil, false, fmt.Errorf("%w: decode state: %v", ErrParallelResumeStateInvalid, err)
		}
		if st.TotalCount < 0 {
			return nil, false, fmt.Errorf("%w: negative total_count", ErrParallelResumeStateInvalid)
		}
		if err := validateParallelSnapshot(&st); err != nil {
			return nil, false, err
		}
		return &st, true, nil
	}
	wasInterrupted, hasState, payload := compose.GetInterruptState[[]byte](ctx)
	if !wasInterrupted || !hasState {
		return nil, false, nil
	}
	var st ParallelInterruptState
	if err := json.Unmarshal(payload, &st); err != nil {
		return nil, false, fmt.Errorf("%w: decode state: %v", ErrParallelResumeStateInvalid, err)
	}
	if st.TotalCount < 0 {
		return nil, false, fmt.Errorf("%w: negative total_count", ErrParallelResumeStateInvalid)
	}
	if err := validateParallelSnapshot(&st); err != nil {
		return nil, false, err
	}
	return &st, true, nil
}

// encodeParallelState marshals the parallel state to the
// persistable form. Go's encoding/json natively encodes
// map[int]any with integer keys as JSON object string keys, so
// the round-trip preserves the integer keys. The map values still
// decode as generic JSON types (see coerceAnyToO).
func encodeParallelState(s ParallelInterruptState) ([]byte, error) {
	return json.Marshal(s)
}

// buildPendingIndices returns the resume set for an interrupted run:
// every index in [0,totalCount) that is not durably present in
// CompletedResults. The returned slice is intentionally the full
// non-completed complement for safety: under concurrent execution,
// an item whose goroutine was still in-flight at the interrupt
// boundary is treated as needing replay.
func buildPendingIndices(totalCount int, completedResults map[int]any) []int {
	if totalCount <= 0 {
		return nil
	}
	pending := make([]int, 0, totalCount-len(completedResults))
	for idx := 0; idx < totalCount; idx++ {
		if _, ok := completedResults[idx]; ok {
			continue
		}
		pending = append(pending, idx)
	}
	return pending
}

// validateParallelSnapshot enforces the resume invariant:
// CompletedResults and InterruptedIndices must form a partition of
// [0,totalCount). Any hole means the resumed run cannot know whether
// the missing item never started, partially ran, or already caused
// side effects, so the state is rejected as invalid.
func validateParallelSnapshot(st *ParallelInterruptState) error {
	if st == nil {
		return nil
	}
	covered := make([]bool, st.TotalCount)
	for idx := range st.CompletedResults {
		if idx < 0 || idx >= st.TotalCount {
			return fmt.Errorf("%w: completed index %d out of range", ErrParallelResumeStateInvalid, idx)
		}
		if covered[idx] {
			return fmt.Errorf("%w: duplicate index %d across completed/interrupted sets", ErrParallelResumeStateInvalid, idx)
		}
		covered[idx] = true
	}
	for _, idx := range st.InterruptedIndices {
		if idx < 0 || idx >= st.TotalCount {
			return fmt.Errorf("%w: interrupted index %d out of range", ErrParallelResumeStateInvalid, idx)
		}
		if covered[idx] {
			return fmt.Errorf("%w: duplicate index %d across completed/interrupted sets", ErrParallelResumeStateInvalid, idx)
		}
		covered[idx] = true
	}
	for idx, ok := range covered {
		if !ok {
			return fmt.Errorf("%w: missing index %d from completed/interrupted partition", ErrParallelResumeStateInvalid, idx)
		}
	}
	return nil
}

// parallelTaskResult is the per-item outcome that the fan-out
// goroutines send back to the main loop. `output` is any so the
// fan-out helper can be shared by runParallelInvoke callers of
// arbitrary I, O; the consumer type-asserts back to O when filling
// the output slice.
type parallelTaskResult struct {
	index  int
	output any
	err    error
}

// runParallelFanout executes the per-item sub-workflow calls
// according to the configured concurrency policy and returns a
// channel of results. The channel is closed once every item has
// reported (success, interrupt, or error).
//
// Concurrency policy:
//   - maxConcurrency <= 1: strictly sequential, no goroutines
//     spawned (matches plan §"Concurrency policy" and the P0
//     acceptance criterion "no goroutine spawns for 0 or 1").
//   - maxConcurrency > 1: bounded fan-out via a buffered channel
//     semaphore of size maxConcurrency. The first item runs on
//     the main goroutine; subsequent items run in worker
//     goroutines that acquire the semaphore before invoking.
//
// Per-item panics are recovered and surfaced as a normal error
// wrapped with "item %d:" so the outer lambda never crashes.
func runParallelFanout[I, O any](
	ctx context.Context,
	nodeKey string,
	sub compose.Runnable[I, O],
	items []I,
	indices []int,
	options *parallelOptions,
	bridgeState *parallelBridgeState,
) <-chan parallelTaskResult {
	resultCh := make(chan parallelTaskResult, len(indices))
	if len(indices) == 0 {
		close(resultCh)
		return resultCh
	}

	runOne := func(idx int) {
		// Derive the per-item checkpoint ID. The builder is
		// invoked on the first run AND on resume. An empty
		// return is treated as "no per-item id"; the option
		// is skipped.
		var cpID string
		if options.enableSubCheckpoint {
			cpID = options.checkpointBuilder(nodeKey, idx)
		}

		// Per-item address segment.
		subCtx := compose.AppendAddressSegment(ctx, ParallelAddressSegment, strconv.Itoa(idx))

		// Bridge store wiring for this item.
		subCtx = withParallelBridgeState(subCtx, bridgeState)
		if options.contextBuilder != nil {
			subCtx = options.contextBuilder(subCtx, items[idx], idx)
		}

		invokeOpts := make([]compose.Option, 0, len(options.runOpts)+1)
		if options.enableSubCheckpoint && cpID != "" {
			invokeOpts = append(invokeOpts, compose.WithCheckPointID(cpID))
		}
		invokeOpts = append(invokeOpts, options.runOpts...)

		var out O
		var err error
		func() {
			defer func() {
				if r := recover(); r != nil {
					err = fmt.Errorf("item %d panic: %v", idx, r)
				}
			}()
			out, err = sub.Invoke(subCtx, items[idx], invokeOpts...)
		}()
		resultCh <- parallelTaskResult{index: idx, output: out, err: err}
	}

// Strictly sequential path: no goroutines, regardless of
|
|
// input length.
|
|
if options.maxConcurrency <= 1 {
|
|
for _, idx := range indices {
|
|
runOne(idx)
|
|
}
|
|
close(resultCh)
|
|
return resultCh
|
|
}
|
|
|
|
// Concurrent path. Use a buffered channel semaphore.
|
|
sem := make(chan struct{}, options.maxConcurrency)
|
|
var wg sync.WaitGroup
|
|
for i, idx := range indices {
|
|
wg.Add(1)
|
|
idx := idx
|
|
if i == 0 {
|
|
// First task runs on the main goroutine.
|
|
runOne(idx)
|
|
wg.Done()
|
|
continue
|
|
}
|
|
go func() {
|
|
defer wg.Done()
|
|
sem <- struct{}{}
|
|
defer func() { <-sem }()
|
|
runOne(idx)
|
|
}()
|
|
}
|
|
go func() {
|
|
wg.Wait()
|
|
close(resultCh)
|
|
}()
|
|
return resultCh
|
|
}
|
|
|
|
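The fan-out policy described above can be tried in isolation. The following is a minimal sketch, not the code from this change: `itemResult`, `fanOut`, and `collect` are hypothetical names, the eino `compose.Runnable` is replaced by a plain function, and the "first item on the calling goroutine" detail is omitted for brevity. It keeps the three ideas that matter: a sequential path for `maxConcurrency <= 1`, a buffered-channel semaphore otherwise, and per-item panic recovery.

```go
package main

import (
	"fmt"
	"sync"
)

// itemResult is a hypothetical stand-in for parallelTaskResult.
type itemResult struct {
	index  int
	output any
	err    error
}

// fanOut is a simplified, hypothetical analogue of runParallelFanout.
// work replaces the sub-workflow Invoke call.
func fanOut[I, O any](items []I, maxConcurrency int, work func(I) (O, error)) <-chan itemResult {
	// One buffer slot per item, so sends never block.
	resultCh := make(chan itemResult, len(items))

	runOne := func(idx int) {
		var out O
		var err error
		func() {
			// Turn a panic in work into an ordinary error result.
			defer func() {
				if r := recover(); r != nil {
					err = fmt.Errorf("item %d panic: %v", idx, r)
				}
			}()
			out, err = work(items[idx])
		}()
		resultCh <- itemResult{index: idx, output: out, err: err}
	}

	// Sequential path: no goroutines are spawned.
	if maxConcurrency <= 1 {
		for idx := range items {
			runOne(idx)
		}
		close(resultCh)
		return resultCh
	}

	// Concurrent path: at most maxConcurrency workers run work at once.
	sem := make(chan struct{}, maxConcurrency)
	var wg sync.WaitGroup
	for idx := range items {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			sem <- struct{}{}
			defer func() { <-sem }()
			runOne(idx)
		}(idx)
	}
	go func() {
		wg.Wait()
		close(resultCh)
	}()
	return resultCh
}

// collect drains the channel and stores each output at its item index,
// type-asserting from any back to O.
func collect[O any](n int, ch <-chan itemResult) ([]O, []error) {
	outs := make([]O, n)
	errs := make([]error, n)
	for r := range ch {
		if r.err != nil {
			errs[r.index] = r.err
			continue
		}
		if v, ok := r.output.(O); ok {
			outs[r.index] = v
		}
	}
	return outs, errs
}

func main() {
	square := func(n int) (int, error) {
		if n == 3 {
			panic("boom")
		}
		return n * n, nil
	}
	outs, errs := collect[int](4, fanOut([]int{1, 2, 3, 4}, 2, square))
	fmt.Println(outs, errs[2])
}
```

Because results carry their own index, arrival order on the channel does not matter; the consumer can rebuild the output slice in item order regardless of which worker finishes first.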
// parallelBridgeStoreKey is the context key for the per-run
// parallel bridge state.
type parallelBridgeStoreKey struct{}

// parallelBridgeState is the in-memory map backing the per-item
// child checkpoints. It is owned by AddParallelNode and passed
// through ctx so the parallelBridgeStore can find it.
type parallelBridgeState struct {
	mu   sync.RWMutex
	data map[string][]byte
}

func newParallelBridgeState(data map[string][]byte) *parallelBridgeState {
	cloned := cloneCheckpointMap(data)
	if cloned == nil {
		cloned = make(map[string][]byte)
	}
	return &parallelBridgeState{data: cloned}
}

func (s *parallelBridgeState) get(id string) ([]byte, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	v, ok := s.data[id]
	if !ok {
		return nil, false
	}
	buf := make([]byte, len(v))
	copy(buf, v)
	return buf, true
}

func (s *parallelBridgeState) set(id string, payload []byte) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.data == nil {
		s.data = make(map[string][]byte)
	}
	buf := make([]byte, len(payload))
	copy(buf, payload)
	s.data[id] = buf
}

func (s *parallelBridgeState) snapshot() map[string][]byte {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return cloneCheckpointMap(s.data)
}

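Both `get` and `set` copy the byte slice rather than storing or returning it directly. A small sketch of why that matters follows; `byteStore` is a hypothetical, trimmed-down analogue of `parallelBridgeState`, written only to show that later mutations by the caller do not leak into the stored checkpoint.

```go
package main

import (
	"fmt"
	"sync"
)

// byteStore is a hypothetical, reduced analogue of parallelBridgeState.
type byteStore struct {
	mu   sync.RWMutex
	data map[string][]byte
}

// set stores a private copy of payload under id.
func (s *byteStore) set(id string, payload []byte) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.data == nil {
		s.data = make(map[string][]byte)
	}
	buf := make([]byte, len(payload))
	copy(buf, payload)
	s.data[id] = buf
}

// get returns a copy of the stored bytes, so callers cannot
// modify the store's internal slice.
func (s *byteStore) get(id string) ([]byte, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	v, ok := s.data[id]
	if !ok {
		return nil, false
	}
	buf := make([]byte, len(v))
	copy(buf, v)
	return buf, true
}

func main() {
	s := &byteStore{}
	payload := []byte("abc")
	s.set("item-0", payload)

	payload[0] = 'X' // mutate the caller's slice after set

	got, _ := s.get("item-0")
	got[1] = 'Y' // mutate the slice returned by get

	again, _ := s.get("item-0")
	fmt.Println(string(again)) // the stored value is unaffected
}
```

The copy costs one allocation per call, which is usually acceptable for checkpoint payloads that are written once per item and read once on resume.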
// parallelBridgeStore is the CheckPointStore implementation that
// reads/writes the parallel bridge state from ctx. It is registered
// on the inner sub-workflow's Compile call (when
// WithParallelEnableSubCheckpoint(true) is in effect).
type parallelBridgeStore struct{}

func newParallelBridgeStore() *parallelBridgeStore {
	return &parallelBridgeStore{}
}

func (s *parallelBridgeStore) Get(ctx context.Context, checkPointID string) ([]byte, bool, error) {
	state, ok := ctx.Value(parallelBridgeStoreKey{}).(*parallelBridgeState)
	if !ok || state == nil {
		return nil, false, nil
	}
	payload, found := state.get(checkPointID)
	return payload, found, nil
}

func (s *parallelBridgeStore) Set(ctx context.Context, checkPointID string, checkPoint []byte) error {
	state, ok := ctx.Value(parallelBridgeStoreKey{}).(*parallelBridgeState)
	if !ok || state == nil {
		return nil
	}
	state.set(checkPointID, checkPoint)
	return nil
}

// withParallelBridgeState wires the per-run bridge state into ctx.
func withParallelBridgeState(ctx context.Context, state *parallelBridgeState) context.Context {
	return context.WithValue(ctx, parallelBridgeStoreKey{}, state)
}

// store returns the per-run CheckPointStore that the inner
// sub-workflow should use. It is captured by closure when the
// inner sub-workflow is compiled.
func (s *parallelBridgeState) store() *parallelBridgeStore {
	return newParallelBridgeStore()
}