mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve? Core module for the agent layer, built on top of the graph engine (#16039). ### Type of change - [x] New Feature (non-breaking change that adds functionality)
101 lines
2.5 KiB
Go
101 lines
2.5 KiB
Go
package core
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/gob"
|
|
"errors"
|
|
"fmt"
|
|
)
|
|
|
|
// ---- AgentLoop checkpoint serialization and lifecycle ----
|
|
|
|
// CheckPointDeleter is an optional capability of a checkpoint store.
//
// If the store configured on an AgentLoop also satisfies this interface, the
// loop can ask it to remove a checkpoint by ID; stores that do not implement
// it are simply never asked to delete anything.
type CheckPointDeleter interface {
	// Delete is expected to remove the checkpoint stored under checkPointID.
	Delete(ctx context.Context, checkPointID string) error
}
|
|
|
|
func marshalTurnLoopCheckpoint[T any](c *agentLoopCheckpoint[T]) ([]byte, error) {
|
|
buf := new(bytes.Buffer)
|
|
if err := gob.NewEncoder(buf).Encode(c); err != nil {
|
|
return nil, err
|
|
}
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
func unmarshalTurnLoopCheckpoint[T any](data []byte) (*agentLoopCheckpoint[T], error) {
|
|
var c agentLoopCheckpoint[T]
|
|
if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&c); err != nil {
|
|
return nil, err
|
|
}
|
|
return &c, nil
|
|
}
|
|
|
|
func (l *AgentLoop[T]) saveTurnLoopCheckpoint(ctx context.Context, checkPointID string, c *agentLoopCheckpoint[T]) error {
|
|
if l.config.Store == nil {
|
|
return errors.New("checkpoint store is nil")
|
|
}
|
|
data, err := marshalTurnLoopCheckpoint(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return l.config.Store.Set(ctx, checkPointID, data)
|
|
}
|
|
|
|
func (l *AgentLoop[T]) deleteTurnLoopCheckpoint(ctx context.Context, checkPointID string) error {
|
|
if l.config.Store == nil {
|
|
return nil
|
|
}
|
|
if deleter, ok := l.config.Store.(CheckPointDeleter); ok {
|
|
return deleter.Delete(ctx, checkPointID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (l *AgentLoop[T]) tryLoadCheckpoint(ctx context.Context) error {
|
|
checkPointID := l.config.CheckpointID
|
|
if checkPointID == "" || l.config.Store == nil {
|
|
return nil
|
|
}
|
|
|
|
l.loadCheckpointID = checkPointID
|
|
|
|
data, existed, err := l.config.Store.Get(ctx, checkPointID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load checkpoint[%s]: %w", checkPointID, err)
|
|
}
|
|
if !existed {
|
|
return nil
|
|
}
|
|
|
|
var cp *agentLoopCheckpoint[T]
|
|
if len(data) == 0 {
|
|
return nil
|
|
}
|
|
cp, err = unmarshalTurnLoopCheckpoint[T](data)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to unmarshal checkpoint[%s]: %w", checkPointID, err)
|
|
}
|
|
|
|
newItems := l.buffer.TakeAll()
|
|
|
|
if cp.HasRunnerState {
|
|
if len(cp.RunnerCheckpoint) == 0 {
|
|
l.buffer.PushFront(newItems)
|
|
return fmt.Errorf("checkpoint[%s] has runner state but bytes are empty", checkPointID)
|
|
}
|
|
l.pendingResume = &agentLoopPendingResume[T]{
|
|
interrupted: append([]T{}, cp.CanceledItems...),
|
|
unhandled: append([]T{}, cp.UnhandledItems...),
|
|
newItems: append([]T{}, newItems...),
|
|
resumeBytes: append([]byte{}, cp.RunnerCheckpoint...),
|
|
}
|
|
} else {
|
|
items := make([]T, 0, len(cp.UnhandledItems)+len(newItems))
|
|
items = append(items, cp.UnhandledItems...)
|
|
items = append(items, newItems...)
|
|
l.buffer.PushFront(items)
|
|
}
|
|
|
|
return nil
|
|
}
|