mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-30 07:51:10 +08:00
### What problem does this PR solve? core module for agent layer built on top of graph engine #16039 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
246 lines
5.6 KiB
Go
246 lines
5.6 KiB
Go
package core
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"runtime/debug"
|
|
"time"
|
|
|
|
"ragflow/internal/harness/core/schema"
|
|
)
|
|
|
|
// ---- AgentLoop agent execution and event handling ----
|
|
|
|
func (l *AgentLoop[T]) setupBridgeStore(spec *turnRunSpec[T], runOpts []RunOption) ([]RunOption, *bridgeStore, error) {
|
|
store := l.config.Store
|
|
if store == nil && spec.isResume {
|
|
return nil, nil, fmt.Errorf("failed to resume agent: checkpoint store is nil")
|
|
}
|
|
if store == nil {
|
|
return runOpts, nil, nil
|
|
}
|
|
runOpts = append(runOpts, WithCheckPointID(bridgeCheckpointID))
|
|
if spec.isResume {
|
|
if len(spec.resumeBytes) == 0 {
|
|
return nil, nil, fmt.Errorf("resume checkpoint is empty")
|
|
}
|
|
return runOpts, newResumeBridgeStore(bridgeCheckpointID, spec.resumeBytes), nil
|
|
}
|
|
return runOpts, newBridgeStore(), nil
|
|
}
|
|
|
|
// watchPreempt runs for the lifetime of a single active turn.
|
|
func (l *AgentLoop[T]) watchPreempt(done <-chan struct{}, agentCancelFunc AgentCancelFunc, preemptDone chan struct{}) {
|
|
preemptDoneClosed := false
|
|
for {
|
|
select {
|
|
case <-done:
|
|
return
|
|
case <-l.preemptCtrl.notify:
|
|
req, ok := l.preemptCtrl.receivePreempt()
|
|
if !ok {
|
|
continue
|
|
}
|
|
_, contributed := agentCancelFunc(req.cancelOptions(time.Now())...)
|
|
if contributed && !preemptDoneClosed {
|
|
close(preemptDone)
|
|
preemptDoneClosed = true
|
|
}
|
|
req.ack()
|
|
}
|
|
}
|
|
}
|
|
|
|
// watchStop runs for the lifetime of a single active turn.
|
|
func (l *AgentLoop[T]) watchStop(done <-chan struct{}, agentCancelFunc AgentCancelFunc, stoppedDone chan struct{}) {
|
|
stoppedClosed := false
|
|
|
|
submit := func(req *stopCancelRequest) {
|
|
_, contributed := agentCancelFunc(req.cancelOptions(time.Now())...)
|
|
if contributed && !stoppedClosed {
|
|
close(stoppedDone)
|
|
stoppedClosed = true
|
|
}
|
|
}
|
|
|
|
for {
|
|
if req, ok := l.stopCtrl.receiveCancel(); ok {
|
|
submit(req)
|
|
}
|
|
|
|
select {
|
|
case <-done:
|
|
return
|
|
case <-l.stopCtrl.notify:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (l *AgentLoop[T]) runAgentAndHandleEvents(
|
|
ctx context.Context,
|
|
agent Agent,
|
|
spec *turnRunSpec[T],
|
|
) error {
|
|
l.interruptContexts = nil
|
|
l.capturedCancelErr = nil
|
|
l.checkPointRunnerBytes = nil
|
|
|
|
var iter *AsyncIterator[*AgentEvent]
|
|
|
|
runOpts, ms, err := l.setupBridgeStore(spec, spec.runOpts)
|
|
if err != nil {
|
|
l.preemptCtrl.abortPlanningTurn().ack()
|
|
return err
|
|
}
|
|
store := l.config.Store
|
|
cancelOpt, agentCancelFunc := WithCancel()
|
|
runOpts = append(runOpts, cancelOpt)
|
|
|
|
enableStreaming := false
|
|
if spec.input != nil {
|
|
enableStreaming = spec.input.EnableStreaming
|
|
}
|
|
runner := NewRunner(ctx, RunnerConfig[*schema.Message]{
|
|
EnableStreaming: enableStreaming,
|
|
Agent: agent,
|
|
CheckPointStore: ms,
|
|
})
|
|
|
|
preemptDone := make(chan struct{})
|
|
stoppedDone := make(chan struct{})
|
|
|
|
tc := &TurnContext[T]{
|
|
Loop: l,
|
|
Consumed: spec.consumed,
|
|
Preempted: preemptDone,
|
|
Stopped: stoppedDone,
|
|
StopCause: l.stopCtrl.cause,
|
|
}
|
|
l.preemptCtrl.beginActiveTurn(ctx, tc)
|
|
l.stopCtrl.beginActiveTurn()
|
|
defer func() {
|
|
l.stopCtrl.endActiveTurn()
|
|
l.preemptCtrl.endActiveTurn().ack()
|
|
}()
|
|
|
|
if spec.isResume {
|
|
var err error
|
|
if spec.resumeParams != nil {
|
|
iter, err = runner.ResumeWithParams(ctx, bridgeCheckpointID, spec.resumeParams, runOpts...)
|
|
} else {
|
|
iter, err = runner.Resume(ctx, bridgeCheckpointID, runOpts...)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to resume agent: %w", err)
|
|
}
|
|
} else {
|
|
iter = runner.Run(ctx, spec.input.Messages, runOpts...)
|
|
}
|
|
|
|
// Wrap iterator to capture framework-level signals (CancelError, InterruptContexts)
|
|
srcIter := iter
|
|
proxyIter, proxyGen := NewAsyncIteratorPair[*AgentEvent]()
|
|
srcIterDone := make(chan struct{})
|
|
go func() {
|
|
defer func() {
|
|
proxyGen.Close()
|
|
close(srcIterDone)
|
|
}()
|
|
for {
|
|
event, ok := srcIter.Next()
|
|
if !ok {
|
|
break
|
|
}
|
|
if event != nil {
|
|
if event.Err != nil {
|
|
var cancelErr *CancelError
|
|
if errors.As(event.Err, &cancelErr) {
|
|
l.capturedCancelErr = cancelErr
|
|
}
|
|
}
|
|
if event.Action != nil && event.Action.Interrupted != nil {
|
|
l.interruptContexts = event.Action.Interrupted.InterruptContexts
|
|
}
|
|
}
|
|
proxyGen.Send(event)
|
|
}
|
|
}()
|
|
iter = proxyIter
|
|
|
|
handleEvents := func() error {
|
|
return l.onAgentEvents(ctx, tc, iter)
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
var handleErr error
|
|
|
|
go func() {
|
|
defer func() {
|
|
panicErr := recover()
|
|
if panicErr != nil {
|
|
handleErr = fmt.Errorf("panic in OnAgentEvents: %v\n%s", panicErr, debug.Stack())
|
|
}
|
|
close(done)
|
|
}()
|
|
handleErr = handleEvents()
|
|
}()
|
|
go l.watchPreempt(done, agentCancelFunc, preemptDone)
|
|
go l.watchStop(done, agentCancelFunc, stoppedDone)
|
|
|
|
finalizeCheckpoint := func() error {
|
|
if store != nil && ms != nil {
|
|
data, ok, err := ms.Get(ctx, bridgeCheckpointID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read runner checkpoint: %w", err)
|
|
}
|
|
if ok {
|
|
l.checkPointRunnerBytes = append([]byte{}, data...)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
select {
|
|
case <-done:
|
|
select {
|
|
case <-preemptDone:
|
|
return nil
|
|
default:
|
|
}
|
|
if err := finalizeCheckpoint(); err != nil {
|
|
if handleErr != nil {
|
|
handleErr = fmt.Errorf("%w; checkpoint error: %v", handleErr, err)
|
|
} else {
|
|
handleErr = err
|
|
}
|
|
}
|
|
return l.applyFrameworkCapturedError(handleErr)
|
|
case <-preemptDone:
|
|
srcIter.Close()
|
|
<-srcIterDone
|
|
<-done
|
|
return nil
|
|
case <-stoppedDone:
|
|
<-done
|
|
if err := finalizeCheckpoint(); err != nil {
|
|
if handleErr != nil {
|
|
handleErr = fmt.Errorf("%w; checkpoint error: %v", handleErr, err)
|
|
} else {
|
|
handleErr = err
|
|
}
|
|
}
|
|
return l.applyFrameworkCapturedError(handleErr)
|
|
}
|
|
}
|
|
|
|
func (l *AgentLoop[T]) applyFrameworkCapturedError(handleErr error) error {
|
|
if handleErr != nil {
|
|
return handleErr
|
|
}
|
|
if l.capturedCancelErr != nil {
|
|
return l.capturedCancelErr
|
|
}
|
|
return nil
|
|
}
|