Files
ragflow/internal/harness/core/interrupt.go
Zhichang Yu 2362210caf refactor(log): unify Go logging to zap with rotation, strip per-package levels (#16261)
Refactor the Go agent port's logging so every log line — gin access,
agent canvas events, harness warnings, fatal boot errors — flows through
a single common.Logger (zap) backed by a rotated file, with structured
fields, level filtering, and configurable rotation.

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-06-23 16:21:46 +08:00

429 lines
13 KiB
Go

package core
import (
	"bytes"
	"context"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/gob"
	"errors"
	"fmt"
	"os"
	"sort"

	"ragflow/internal/common"
	"ragflow/internal/harness/core/schema"

	"go.uber.org/zap"
)
// ---- Resume types ----
// ResumeInfo is the resume-time information returned by loadCheckpoint and
// forwarded to a child agent by buildResumeInfo.
type ResumeInfo struct {
// EnableStreaming is restored from the checkpoint payload by loadCheckpoint.
EnableStreaming bool
// InterruptInfo is embedded, so Data and InterruptContexts are promoted.
// loadCheckpoint fills it from the stored checkpoint.
*InterruptInfo
// WasInterrupted is not set by loadCheckpoint; presumably callers set it to
// mark that the receiving agent had been interrupted — verify against callers.
WasInterrupted bool
// InterruptState is not populated anywhere in this file; presumably it holds
// the state saved with a stateful interrupt — verify against callers.
InterruptState any
// IsResumeTarget is copied as-is by buildResumeInfo; presumably it marks the
// agent the resume is addressed to — verify against callers.
IsResumeTarget bool
// ResumeData is copied as-is by buildResumeInfo; presumably caller-supplied
// data for the resumed agent — verify against callers.
ResumeData any
}
// InterruptInfo is the interrupt description attached to an agent action.
type InterruptInfo struct {
// Data is the opaque info value passed to the interrupt constructors.
Data any
// InterruptContexts is rebuilt by loadCheckpoint from the checkpoint's
// ID-to-address and ID-to-state maps; the constructors leave it unset.
InterruptContexts []*InterruptCtx
}
// ---- Address types ----
// Address is the ordered list of segments accumulated on a context by
// AppendAddressSegment; it is an alias, so it is interchangeable with
// []AddressSegment.
type Address = []AddressSegment
// AddressSegment is one element of an Address.
type AddressSegment struct {
// Type is the kind of component this segment names (see AddressSegmentType).
Type AddressSegmentType
// ID is the identifier passed to AppendAddressSegment for this segment.
ID string
}
// AddressSegmentType is the string-typed kind of an AddressSegment.
type AddressSegmentType string
// Known address segment kinds.
const (
// AddressSegmentAgent marks an agent segment; the resume routing helpers in
// this file select segments of this type.
AddressSegmentAgent AddressSegmentType = "agent"
// AddressSegmentTool marks a tool segment; it is not referenced elsewhere in
// this file.
AddressSegmentTool AddressSegmentType = "tool"
)
// InterruptCtx is the flat, per-interrupt record exposed through
// InterruptInfo.InterruptContexts.
type InterruptCtx struct {
// ID identifies the interrupt; it is the key of the checkpoint maps.
ID string
// Address is the address recorded for the interrupt.
Address Address
// Info is copied into InterruptSignal.Info by buildFromCtxs. It is not part
// of the checkpoint maps, so mapsToInterruptContexts leaves it unset.
Info any
// State is the opaque state restored from the checkpoint, when one was saved.
State any
}
// InterruptSignal is the internal, tree-shaped form of an interrupt that
// signalToMaps flattens when a checkpoint is saved.
type InterruptSignal struct {
// ID keys the signal in the checkpoint maps; signalToMaps skips signals whose
// ID is empty. The constructors in this file do not assign it.
ID string
// Address is the address captured from the context when the signal was built.
Address Address
// Info is the opaque info value supplied by the interrupting code.
Info any
// State is the opaque state to persist; nil means no state is stored.
State any
// Children holds nested signals (sub-interrupts of a composite interrupt, or
// the entries converted by buildFromCtxs).
Children []*InterruptSignal
}
// ---- Interrupt constructors ----
// Interrupt is TypedInterrupt specialized to *schema.Message: it returns an
// event whose action reports an interrupt carrying info.
func Interrupt(ctx context.Context, info any) *AgentEvent {
	event := TypedInterrupt[*schema.Message](ctx, info)
	return event
}
// TypedInterrupt returns an event whose action has Interrupted set to an
// InterruptInfo wrapping info. No internal signal (address or state) is
// attached, and ctx is accepted but not consulted.
func TypedInterrupt[M MessageType](ctx context.Context, info any) *TypedAgentEvent[M] {
	action := &AgentAction{
		Interrupted: &InterruptInfo{Data: info},
	}
	return &TypedAgentEvent[M]{Action: action}
}
// StatefulInterrupt is TypedStatefulInterrupt specialized to *schema.Message:
// an interrupt event that also records state and the current address.
func StatefulInterrupt(ctx context.Context, info, state any) *AgentEvent {
	event := TypedStatefulInterrupt[*schema.Message](ctx, info, state)
	return event
}
// TypedStatefulInterrupt returns an interrupt event that, besides the public
// InterruptInfo wrapping info, carries an internal InterruptSignal holding
// info, state and a copy of the address segments currently stored on ctx.
func TypedStatefulInterrupt[M MessageType](ctx context.Context, info, state any) *TypedAgentEvent[M] {
	here := captureAddress(ctx)
	return &TypedAgentEvent[M]{
		Action: &AgentAction{
			Interrupted: &InterruptInfo{Data: info},
			internalInterrupted: &InterruptSignal{
				Address: here,
				Info:    info,
				State:   state,
			},
		},
	}
}
// CompositeInterrupt is TypedCompositeInterrupt specialized to
// *schema.Message: an interrupt event that nests the given sub-signals.
func CompositeInterrupt(ctx context.Context, info, state any, subs ...*InterruptSignal) *AgentEvent {
	event := TypedCompositeInterrupt[*schema.Message](ctx, info, state, subs...)
	return event
}
// TypedCompositeInterrupt returns an interrupt event whose internal signal
// holds info, state, the address captured from ctx, and one child per
// sub-signal.
//
// Each sub-signal is shallow-copied, so later changes to the caller's
// top-level struct do not affect the event; the copies still share the
// original's Address and Children slices. Nil entries in subs are skipped
// instead of being dereferenced.
func TypedCompositeInterrupt[M MessageType](ctx context.Context, info, state any, subs ...*InterruptSignal) *TypedAgentEvent[M] {
	addr := captureAddress(ctx)
	children := make([]*InterruptSignal, 0, len(subs))
	for _, sub := range subs {
		if sub == nil {
			// A nil sub-signal carries nothing to resume; dropping it avoids a
			// nil-pointer panic on the copy below.
			continue
		}
		cp := *sub
		children = append(children, &cp)
	}
	return &TypedAgentEvent[M]{Action: &AgentAction{
		Interrupted: &InterruptInfo{Data: info},
		internalInterrupted: &InterruptSignal{
			Info: info, State: state, Address: addr, Children: children,
		},
	}}
}
// captureAddress returns an independent copy of the address segments stored
// on ctx, or nil when there are none.
func captureAddress(ctx context.Context) Address {
	if current := getAddressSegments(ctx); len(current) > 0 {
		snapshot := make(Address, len(current))
		copy(snapshot, current)
		return snapshot
	}
	return nil
}
// addrSegKey is the context key under which the []AddressSegment path is
// stored; an unexported struct type cannot collide with other packages' keys.
type addrSegKey struct{}
// AppendAddressSegment returns a child context whose address path is the
// parent's path followed by a new segment of type t and identifier id.
// The path is copied into a fresh slice, so the parent context's segments
// are never modified.
func AppendAddressSegment(ctx context.Context, t AddressSegmentType, id string) context.Context {
	inherited, _ := ctx.Value(addrSegKey{}).([]AddressSegment)
	extended := make([]AddressSegment, 0, len(inherited)+1)
	extended = append(extended, inherited...)
	extended = append(extended, AddressSegment{Type: t, ID: id})
	return context.WithValue(ctx, addrSegKey{}, extended)
}
// getAddressSegments returns the address path stored on ctx, or nil when no
// path has been attached. The returned slice is the stored one, not a copy.
func getAddressSegments(ctx context.Context) []AddressSegment {
	// A failed type assertion leaves path at its zero value, nil.
	path, _ := ctx.Value(addrSegKey{}).([]AddressSegment)
	return path
}
// FromInterruptContexts converts a flat slice of InterruptCtx into an
// InterruptSignal: an empty root (no ID, address, info or state) whose
// Children are the converted entries, in input order. It returns nil when
// ctxs is empty.
func FromInterruptContexts(ctxs []*InterruptCtx) *InterruptSignal {
	if len(ctxs) > 0 {
		root := new(InterruptSignal)
		buildFromCtxs(ctxs, root)
		return root
	}
	return nil
}
// buildFromCtxs appends one child signal to parent for every entry of ctxs,
// in order. Each child copies the entry's ID, Info and State and receives its
// own copy of the Address slice. Nil entries are skipped rather than
// dereferenced.
func buildFromCtxs(ctxs []*InterruptCtx, parent *InterruptSignal) {
	for _, c := range ctxs {
		if c == nil {
			// Nothing to convert; avoids a nil-pointer panic on c.ID below.
			continue
		}
		sig := &InterruptSignal{
			ID: c.ID, Address: make(Address, len(c.Address)),
			Info: c.Info, State: c.State,
		}
		copy(sig.Address, c.Address)
		parent.Children = append(parent.Children, sig)
	}
}
// ---- Checkpoint store ----
// CheckPointStore is the byte-level storage used for checkpoints.
type CheckPointStore interface {
// Get returns the bytes stored under key. loadCheckpoint treats a false
// second result as "checkpoint not found" and a non-nil error as a failed
// lookup.
Get(ctx context.Context, key string) ([]byte, bool, error)
// Set stores data under key. saveCheckpoint passes the HMAC-prefixed,
// gob-encoded checkpoint payload.
Set(ctx context.Context, key string, data []byte) error
}
// InterruptState wraps the opaque interrupt state for checkpoint serialization.
// Because State is an interface value, callers MUST register the concrete type
// stored in it (via schema.RegisterName or gob.Register) before saving a
// checkpoint; gob cannot encode or decode an unregistered interface type, so
// the checkpoint save or load would fail with an error.
type InterruptState struct{ State any }
// checkpointPayload is the gob-encoded body of a stored checkpoint; it is
// written by saveCheckpoint and read back by loadCheckpoint.
type checkpointPayload struct {
// RunCtx is the run context taken from the context at save time.
RunCtx *runContext
// Info is the interrupt info passed to saveCheckpoint.
Info *InterruptInfo
// EnableStreaming is the streaming flag passed to saveCheckpoint.
EnableStreaming bool
// InterruptID2Address maps each non-empty signal ID to its address.
InterruptID2Address map[string]Address
// InterruptID2State maps signal IDs to their state; signals whose state is
// nil have no entry.
InterruptID2State map[string]InterruptState
// TenantID is the tenant resolved at save time; loadCheckpoint compares it
// with the tenant of the resuming context.
TenantID string
}
func init() {
schema.RegisterType("agentcore_checkpoint", func() any { return &checkpointPayload{} })
schema.RegisterType("agentcore_interrupt_state", func() any { return &InterruptState{} })
}
// ---- Checkpoint tenant isolation ----
// checkpointTenantKey is the context key under which WithCheckpointTenant
// stores the tenant ID.
type checkpointTenantKey struct{}
// DefaultCheckpointTenantKey is the session-values key that
// extractCheckpointTenant reads when the context carries no explicit tenant.
const DefaultCheckpointTenantKey = "tenant_id"
// WithCheckpointTenant returns a context carrying tenantID for checkpoint
// tenant isolation. saveCheckpoint records the tenant in the payload, and
// loadCheckpoint rejects checkpoints whose stored tenant does not match.
func WithCheckpointTenant(ctx context.Context, tenantID string) context.Context {
	var key checkpointTenantKey
	return context.WithValue(ctx, key, tenantID)
}
// extractCheckpointTenant resolves the tenant ID for checkpoint isolation.
// A non-empty value set with WithCheckpointTenant wins; otherwise the string
// stored under DefaultCheckpointTenantKey in the run context's session values
// is used. It returns "" when neither source provides a tenant.
func extractCheckpointTenant(ctx context.Context) string {
	explicit, _ := ctx.Value(checkpointTenantKey{}).(string)
	if explicit != "" {
		return explicit
	}

	rc := getRunCtx(ctx)
	if rc == nil || rc.Session == nil {
		return ""
	}
	// A missing or non-string session value yields "".
	fromSession, _ := rc.Session.Values[DefaultCheckpointTenantKey].(string)
	return fromSession
}
// ---- Checkpoint integrity (HMAC) ----
// Checkpoint integrity parameters.
const (
// hmacLen is the length in bytes of the HMAC-SHA256 tag that prefixes every
// stored checkpoint.
hmacLen = 32
// envHMACKey names the environment variable holding the base64-encoded key.
envHMACKey = "CHECKPOINT_HMAC_KEY"
)
// checkpointHMACKey is the key used to authenticate stored checkpoints. It is
// initialized once, at package initialization, by loadCheckpointHMACKey, which
// reads the CHECKPOINT_HMAC_KEY env var (base64-encoded, 32 bytes) and panics
// if the value is malformed. If the variable is unset, a random key is
// generated per startup with a log warning — this is safe for single-process
// in-memory usage but will BREAK checkpoint resume across process restarts.
// Production deployments MUST set CHECKPOINT_HMAC_KEY to a stable
// base64-encoded 32-byte secret.
var checkpointHMACKey = loadCheckpointHMACKey()
// loadCheckpointHMACKey returns the 32-byte checkpoint HMAC key.
//
// When the CHECKPOINT_HMAC_KEY environment variable is set, it must be
// standard base64 that decodes to exactly 32 bytes; otherwise the function
// panics. When the variable is empty or unset, a random key is read from
// crypto/rand and a warning is logged. It panics if the random source fails.
func loadCheckpointHMACKey() []byte {
	encoded := os.Getenv(envHMACKey)
	if encoded == "" {
		generated := make([]byte, 32)
		if _, err := rand.Read(generated); err != nil {
			panic("failed to generate checkpoint HMAC key: " + err.Error())
		}
		common.Warn("checkpoint HMAC env not set — using random per-process key; checkpoint resume across restarts will fail", zap.String("env", envHMACKey))
		return generated
	}

	key, err := base64.StdEncoding.DecodeString(encoded)
	switch {
	case err != nil:
		panic("checkpoint HMAC key: invalid base64 in " + envHMACKey + ": " + err.Error())
	case len(key) != 32:
		panic(fmt.Sprintf("checkpoint HMAC key: %s must decode to exactly 32 bytes, got %d", envHMACKey, len(key)))
	}
	return key
}
// computeCheckpointHMAC returns the HMAC-SHA256 tag of payload under
// checkpointHMACKey.
func computeCheckpointHMAC(payload []byte) []byte {
	h := hmac.New(sha256.New, checkpointHMACKey)
	// hash.Hash documents that Write never returns an error.
	_, _ = h.Write(payload)
	return h.Sum(nil)
}
// loadCheckpoint reads the checkpoint stored under cid, verifies it, and
// returns the data needed to resume the run.
//
// Steps, in order:
//  1. fetch the record from store;
//  2. split it into the leading hmacLen-byte tag and the payload, and verify
//     the tag in constant time before any decoding takes place;
//  3. gob-decode the payload;
//  4. enforce tenant isolation against the tenant resolved from ctx;
//  5. rebuild Info.InterruptContexts from the stored ID maps.
//
// It returns ctx unchanged, the stored run context, and a ResumeInfo with
// EnableStreaming and InterruptInfo populated. On any failure, including a
// nil store, all results are nil and err describes the problem.
func loadCheckpoint(store CheckPointStore, ctx context.Context, cid string) (context.Context, *runContext, *ResumeInfo, error) {
	if store == nil {
		// saveCheckpoint tolerates a nil store; report it here instead of
		// panicking on the method call below.
		return nil, nil, nil, fmt.Errorf("checkpoint %s: store is nil", cid)
	}

	data, exist, err := store.Get(ctx, cid)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("checkpoint get: %w", err)
	}
	if !exist {
		return nil, nil, nil, fmt.Errorf("checkpoint %s not found", cid)
	}

	// Record layout: first hmacLen bytes = HMAC tag, rest = gob payload.
	if len(data) < hmacLen {
		return nil, nil, nil, fmt.Errorf("checkpoint %s too short (%d bytes)", cid, len(data))
	}
	mac, payload := data[:hmacLen], data[hmacLen:]

	// Authenticate before decoding so tampered bytes never reach gob.
	if !hmac.Equal(mac, computeCheckpointHMAC(payload)) {
		return nil, nil, nil, fmt.Errorf("checkpoint %s integrity check failed", cid)
	}

	var p checkpointPayload
	if err := gob.NewDecoder(bytes.NewReader(payload)).Decode(&p); err != nil {
		return nil, nil, nil, fmt.Errorf("decode checkpoint: %w", err)
	}

	// Tenant isolation policy: when EITHER side carries a tenant ID, BOTH must
	// be present and equal. Empty on both sides is allowed for backward
	// compatibility with non-tenant deployments.
	currentTenant := extractCheckpointTenant(ctx)
	switch {
	case p.TenantID == currentTenant:
		// Same tenant, or no tenant on either side.
	case p.TenantID == "":
		return nil, nil, nil, fmt.Errorf("checkpoint %s tenant mismatch: stored is empty, current=%q", cid, currentTenant)
	case currentTenant == "":
		return nil, nil, nil, fmt.Errorf("checkpoint %s tenant mismatch: stored=%q, current is empty", cid, p.TenantID)
	default:
		return nil, nil, nil, fmt.Errorf("checkpoint %s tenant mismatch: stored=%q current=%q", cid, p.TenantID, currentTenant)
	}

	// Rebuild InterruptContexts from the checkpoint maps.
	ics := mapsToInterruptContexts(p.InterruptID2Address, p.InterruptID2State)
	if p.Info != nil {
		p.Info.InterruptContexts = ics
	}

	return ctx, p.RunCtx, &ResumeInfo{
		EnableStreaming: p.EnableStreaming,
		InterruptInfo:   p.Info,
	}, nil
}
// saveCheckpoint serializes the current run state and writes it to store
// under key. A nil store is a no-op that returns nil.
//
// The payload bundles the run context from ctx, info, enableStreaming, the
// flattened ID maps of the signal tree is, and the tenant resolved from ctx.
// It is gob-encoded and stored as the HMAC-SHA256 tag followed by the encoded
// bytes, so loadCheckpoint can verify integrity before decoding.
func saveCheckpoint(store CheckPointStore, ctx context.Context, key string, enableStreaming bool, info *InterruptInfo, is *InterruptSignal) error {
	if store == nil {
		return nil
	}

	runCtx := getRunCtx(ctx)
	addrByID, stateByID := signalToMaps(is)
	snapshot := checkpointPayload{
		RunCtx:              runCtx,
		Info:                info,
		EnableStreaming:     enableStreaming,
		InterruptID2Address: addrByID,
		InterruptID2State:   stateByID,
		TenantID:            extractCheckpointTenant(ctx),
	}

	var encoded bytes.Buffer
	if err := gob.NewEncoder(&encoded).Encode(snapshot); err != nil {
		return fmt.Errorf("encode checkpoint: %w", err)
	}

	// Stored record = tag || payload; the tag slice is freshly allocated, so
	// appending to it cannot alias any other data.
	body := encoded.Bytes()
	record := append(computeCheckpointHMAC(body), body...)
	return store.Set(ctx, key, record)
}
// signalToMaps flattens the InterruptSignal tree rooted at is into an
// ID-to-Address map and an ID-to-State map for checkpoint serialization.
//
// The tree is walked in pre-order (a node before its Children, children in
// slice order). Nodes with an empty ID are not recorded, but their children
// are still visited. A state entry is written only when the node's State is
// non-nil. If an ID occurs more than once, the last visited occurrence wins.
// Nil nodes are ignored, and both returned maps are always non-nil.
//
// All nodes write into one shared pair of maps, so no per-node maps are
// allocated and merged.
func signalToMaps(is *InterruptSignal) (map[string]Address, map[string]InterruptState) {
	addrs, states := make(map[string]Address), make(map[string]InterruptState)

	var walk func(node *InterruptSignal)
	walk = func(node *InterruptSignal) {
		if node == nil {
			return
		}
		if node.ID != "" {
			addrs[node.ID] = node.Address
			if node.State != nil {
				states[node.ID] = InterruptState{State: node.State}
			}
		}
		for _, child := range node.Children {
			walk(child)
		}
	}
	walk(is)

	return addrs, states
}
// mapsToInterruptContexts reconstructs a slice of InterruptCtx from the
// checkpoint maps. It returns nil when id2addr is empty.
//
// One entry is produced per key of id2addr, with its own copy of the address
// and, when id2state has an entry for the ID, the stored state. Info is left
// unset because the maps do not carry it. Entries are ordered by ID so the
// result is deterministic; ranging over the map directly would yield a
// different order on every call.
func mapsToInterruptContexts(id2addr map[string]Address, id2state map[string]InterruptState) []*InterruptCtx {
	if len(id2addr) == 0 {
		return nil
	}

	ids := make([]string, 0, len(id2addr))
	for id := range id2addr {
		ids = append(ids, id)
	}
	sort.Strings(ids)

	ics := make([]*InterruptCtx, 0, len(ids))
	for _, id := range ids {
		addr := id2addr[id]
		ic := &InterruptCtx{ID: id, Address: make(Address, len(addr))}
		copy(ic.Address, addr)
		if st, ok := id2state[id]; ok {
			ic.State = st.State
		}
		ics = append(ics, ic)
	}
	return ics
}
// getNextResumeAgent returns the ID of the last agent-typed segment in the
// address path stored on ctx, for single-agent resume routing. It returns an
// error when the path is empty or contains no agent segment. The info
// argument is not consulted.
func getNextResumeAgent(ctx context.Context, info *ResumeInfo) (string, error) {
	path := getAddressSegments(ctx)
	if len(path) == 0 {
		return "", errors.New("no address segments for resume")
	}

	// Peel segments off the tail until an agent segment is found.
	for rest := path; len(rest) > 0; rest = rest[:len(rest)-1] {
		tail := rest[len(rest)-1]
		if tail.Type == AddressSegmentAgent {
			return tail.ID, nil
		}
	}
	return "", errors.New("no agent address segment found for resume")
}
// getNextResumeAgents returns, as a set, the IDs of every agent-typed segment
// in the address path stored on ctx, for multi-agent resume routing. It
// returns an error when the path is empty or contains no agent segment. The
// info argument is not consulted.
func getNextResumeAgents(ctx context.Context, info *ResumeInfo) (map[string]bool, error) {
	path := getAddressSegments(ctx)
	if len(path) == 0 {
		return nil, errors.New("no address segments for resume")
	}

	agents := map[string]bool{}
	for i := range path {
		if path[i].Type != AddressSegmentAgent {
			continue
		}
		agents[path[i].ID] = true
	}

	if len(agents) > 0 {
		return agents, nil
	}
	return nil, errors.New("no agent address segments found for resume")
}
// buildResumeInfo prepares the context and ResumeInfo handed to the agent
// nextID. The returned context has an agent segment for nextID appended to its
// address path and is then passed through updateRunPathOnly. The returned
// ResumeInfo is a new struct carrying EnableStreaming, InterruptInfo,
// WasInterrupted, IsResumeTarget and ResumeData from info; the last two are
// copied regardless of WasInterrupted.
//
// NOTE(review): InterruptState is not carried over and stays nil in the
// result — confirm whether that omission is intentional.
func buildResumeInfo(ctx context.Context, nextID string, info *ResumeInfo) (context.Context, *ResumeInfo) {
	childCtx := AppendAddressSegment(ctx, AddressSegmentAgent, nextID)

	next := new(ResumeInfo)
	next.EnableStreaming = info.EnableStreaming
	next.InterruptInfo = info.InterruptInfo
	next.WasInterrupted = info.WasInterrupted
	next.IsResumeTarget = info.IsResumeTarget
	next.ResumeData = info.ResumeData

	return updateRunPathOnly(childCtx, nextID), next
}