ragflow/internal/agent/component/categorize.go
Zhichang Yu 3f805a64f1 feat(agent): align Go agent behavior with Python (except retrieval component) (#16225)
## Summary

Aligns the **Go agent runtime/canvas/components/tools** behavior with
the **Python `agent/` implementation** so the same stored canvas DSL
produces the same execution result on either side. Every component,
tool, and runtime primitive in `internal/agent/` is now driven by the
same semantics as its Python counterpart — variable resolution, template
substitution, control flow, error reporting, retry/cancel, and stream
event shapes.

The **retrieval component is the one explicit exception** in this PR. It
is being reworked in a separate change and is excluded from this
alignment pass; the wrapper slot (`universe_a_wrappers.go →
newRetrievalComponent`) is preserved.

## Scope of alignment

### Components (all aligned with `agent/component/`)
`Begin` · `Message` · `LLM` (incl. ChatTemplateKwargs,
MessageHistoryWindowSize, VisualFiles, Cite, OutputStructure,
JSONOutput, TopP, MaxRetries, DelayAfterError, credentials) · `Agent`
(react + tool artifact capture + `Reset()` interface-assert) · `Switch`
(12/12 operators, Python-equivalent semantics) · `Categorize` · `Invoke`
· `Iteration` · `Loop` (macro-expansion through `workflowx.AddLoopNode`)
· `UserFillUp` (Python-equivalent interrupt/resume via eino
`compose.Interrupt`/`ResumeWithData`) · `FillUp` · `DataOperations` ·
`ListOperations` · `StringTransform` · `VariableAggregator` ·
`VariableAssigner` · `Browser` (full stagehand runtime parity) ·
`DocsGenerator` · `ExcelProcessor`.

### Tools (all aligned with `agent/tools/`)
`Retrieval` (wrapper slot only — logic out of scope) · `MCPToolAdapter`
(streamable-HTTP) · `CodeExec` (sandbox bridge with
`code_exec_contract.go` matching Python contract) · `AkShare` · `ArXiv`
· `Crawler` · `DeepL` · `DuckDuckGo` · `Email` · `ExeSQL` · `GitHub` ·
`Google` · `GoogleScholar` · `Jin10` · `PubMed` · `QWeather` · `SearXNG`
· `Tavily` · `Tushare` · `Wencai` · `Wikipedia` · `YahooFinance` —
uniform `eino tool.InvokableTool` interface, SSRF protection, shared
HTTP client.

### Canvas execution engine (`internal/agent/canvas/`)
Aligned with Python's `agent/canvas.py`:
- **Scheduler** (`scheduler.go`): state pre/post handlers, node lambdas,
per-component timeout resolver (4-level: per-class env → per-class table
→ uniform env → 600s fallback), `legacyNoOpNames`.
- **Loop subgraph** (`loop_subgraph.go`): Python-equivalent
`AddLoopNode` macro expansion + condition translation.
- **Multibranch** (`multibranch.go`): `Switch` / `Categorize` routing
via `compose.NewGraphMultiBranch` — same branch selection semantics as
Python.
- **Parallel subgraph** (`parallel_subgraph.go`): matches Python's
parallel fan-out contract.
- **Interrupt/Resume** (`interrupt_resume.go`): `UserFillUpNodeBody` /
`IsInterruptError` / `ExtractInterruptContexts` — replaces the
deprecated Python sentinel chain with eino's native interrupt API,
preserving the same external behavior.
- **Checkpoint** (`checkpoint_store.go`): `RedisCheckPointStore`
Get/Set/Delete, with business metadata (status / canvas_id /
parent_run_id) on a parallel Redis Hash.
- **RunTracker** (`run_tracker.go`): Start / MarkSucceeded / MarkFailed
/ MarkCancelled / AttachCheckpoint — same lifecycle as the Python run
record.
- **Cancel** (`cancel.go`): Redis pub/sub watch.
- **Stream** (`stream.go`): SSE channel with `messages` / `waiting` /
`errors` / `done` events, same shape as Python's `agent.canvas.RunEvent`
payload.

### DSL bridge (`internal/agent/dsl/`)
- `normalize.go`: v1↔v2 collapsed into a single wire format — Python and
Go consume the same stored JSON.
- `reset.go`: per-run state reset matches Python's `Canvas.reset()`
semantics.
- Testdata mirrors Python's `agent_msg.json` / `all.json` / etc.

### Runtime (`internal/agent/runtime/`)
- `CanvasState` / `NewCanvasState` / `GetVar` / `SetVar` / `ReadVars`:
same `{{cpn_id@param}}` resolution model.
- `ResolveTemplate` (regex fast path + gonja fallback) — Python
Jinja-style semantics.
- `selector.go`, `metrics.go`, `component.go`: shared runtime contracts.
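
A regex fast path for `{{cpn_id@param}}` and `{{sys.query}}` references can be done as a single `ReplaceAllStringFunc` pass over the template. The sketch below is an illustration under assumptions. The pattern, the `resolveSimple` name and the error text are hypothetical. The character class admits `:` on the assumption that component IDs may contain it. Templates that need Jinja features (filters, loops, conditionals) would go to the gonja fallback, which is not shown. Failing on an unresolved reference mirrors the error case exercised by `TestRunAgent_RealCanvas_InvokeFails`.

```go
package main

import (
	"fmt"
	"regexp"
)

// refPattern matches simple placeholders such as {{begin@query}} or {{ sys.query }}.
var refPattern = regexp.MustCompile(`\{\{\s*([A-Za-z0-9_.:@-]+)\s*\}\}`)

// resolveSimple substitutes every placeholder via lookup. It reports the
// first reference that lookup cannot resolve instead of leaving it in place.
func resolveSimple(tmpl string, lookup func(ref string) (string, bool)) (string, error) {
	var missing string
	out := refPattern.ReplaceAllStringFunc(tmpl, func(m string) string {
		ref := refPattern.FindStringSubmatch(m)[1]
		v, ok := lookup(ref)
		if !ok {
			if missing == "" {
				missing = ref
			}
			return m
		}
		return v
	})
	if missing != "" {
		return "", fmt.Errorf("template: unresolved reference %q", missing)
	}
	return out, nil
}
```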

## Out of scope (intentionally)

- **`Retrieval` component logic** — wrapped only; full parity lands in a
follow-up PR.
- **Frontend** — only minor dsl-bridge / canvas UX fixes ride along.
- **CLI / admin / model registry** — orthogonal to agent behavior.

## How alignment is verified

`internal/service/agent_run_e2e_test.go` exercises the **full production
chain** against real Python-shaped DSL fixtures:
```
loadCanvasForUser → versionDAO.GetLatest → decodeCanvasFromDSL →
canvas.Compile → cc.Workflow.Invoke → answer extraction
```
using in-memory SQLite + miniredis (no Docker). Covers:
- `TestRunAgent_RealCanvas_BeginMessage` — happy path, `{{sys.query}}`
resolution
- `TestRunAgent_RealCanvas_WaitForUserResume` — two-run resume cycle
(Python-equivalent)
- `TestRunAgent_RealCanvas_CompileFails` — unknown component name →
sanitized error (Python-equivalent)
- `TestRunAgent_RealCanvas_InvokeFails` — unresolvable template ref
(Python-equivalent)
- `TestRunAgent_RunTracker_AttachCheckpoint_CallSequence` —
Start→AttachCheckpoint→MarkSucceeded lifecycle

`internal/handler/agent_test.go` — SSE streaming parity (`Content-Type:
text/event-stream`, `data: {…}\n\n`, trailing `data: [DONE]\n\n`,
OpenAI-compatible non-stream `choices`).
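
The framing the handler test pins (one `data: {…}\n\n` frame per event, then `data: [DONE]\n\n`) can be sketched as a writer over any `io.Writer`. The event keys in the test are illustrative and do not reproduce the actual payload shape of `agent.canvas.RunEvent`. `writeSSE` and `renderSSE` are hypothetical names.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
)

// writeSSE writes each event as one SSE "data:" frame terminated by a blank
// line, then the OpenAI-style [DONE] terminator frame.
func writeSSE(w io.Writer, events []map[string]any) error {
	for _, ev := range events {
		payload, err := json.Marshal(ev)
		if err != nil {
			return fmt.Errorf("sse: marshal event: %w", err)
		}
		if _, err := fmt.Fprintf(w, "data: %s\n\n", payload); err != nil {
			return err
		}
	}
	_, err := io.WriteString(w, "data: [DONE]\n\n")
	return err
}

// renderSSE collects the frames into a string (handy for tests).
func renderSSE(events []map[string]any) (string, error) {
	var buf bytes.Buffer
	if err := writeSSE(&buf, events); err != nil {
		return "", err
	}
	return buf.String(), nil
}
```

In an HTTP handler the writer would be the `http.ResponseWriter`. The handler would call `w.Header().Set("Content-Type", "text/event-stream")` before the first write and call `Flush` through the `http.Flusher` interface after each frame so events are not buffered.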

`internal/agent/canvas/fixture_compile_test.go` + per-component tests
pin the Python-equivalent outputs.

```
go test -count=1 -v -run 'TestRunAgent_RealCanvas|TestRunAgent_RunTracker' ./internal/service/
```

## Design reference

`docs/develop/agent-go-port-design.md` (1329 lines, last cross-checked
2026-06-17) — module layout, per-component / per-tool inventory,
corner-case catalogue, and the actionable backlog (Section 14, including
the retrieval alignment follow-up).

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-06-22 11:58:29 +08:00


// Package component: Categorize (T3).
//
// LLM-based classifier. The component asks the model to pick exactly
// one of the configured categories and returns the chosen category name
// plus a one-hot score map (1.0 for the chosen category, 0.0 for the
// rest). The MultiBranch wiring in canvas/multibranch.go consumes
// outputs["_next"] for runtime routing; the field is reserved for
// that consumer.
package component

import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/cloudwego/eino/schema"
)

// CategorizeComponent is an LLM classifier.
type CategorizeComponent struct {
	param CategorizeParam
}

// CategorizeParam captures the (resolved) DSL parameters for a Categorize node.
type CategorizeParam struct {
	ModelID         string
	Items           []string
	Categories      []string
	CategoryRoutes  map[string]string
	SysPrompt       string
	DefaultCategory string
	Driver          string
	APIKey          string
	BaseURL         string
}

// CategorizeOutput mirrors the outputs map (per plan §2.11.3 row 6):
//
//	"category"      string             chosen category name (or the default
//	                                   if the model returned something not
//	                                   in the list)
//	"category_name" string             alias of "category" for v1 templates
//	"scores"        map[string]float64 1.0 for the chosen category, 0.0 otherwise
//	"_next"         []string           reserved for canvas/multibranch.go routing
type CategorizeOutput struct {
	Category string
	Scores   map[string]float64
	Next     []string
}

// NewCategorizeComponent builds a CategorizeComponent from raw params.
func NewCategorizeComponent(p CategorizeParam) *CategorizeComponent {
	return &CategorizeComponent{param: p}
}

// Name returns the registered component name.
func (c *CategorizeComponent) Name() string { return "Categorize" }

// Invoke calls the chat model, parses the response for a category, and
// returns the chosen category (or the default if the model returned
// something outside the configured set).
func (c *CategorizeComponent) Invoke(ctx context.Context, inputs map[string]any) (map[string]any, error) {
	p := mergeCategorizeParam(c.param, inputs)
	originalModelID := p.ModelID
	if p.Driver == "" && p.ModelID != "" {
		if modelID, driver, ok := splitCompositeLLMID(p.ModelID); ok {
			p.Driver = driver
			p.ModelID = modelID
		}
	}
	p.APIKey, p.BaseURL = resolveTenantLLMConfig(ctx, p.Driver, p.ModelID, p.APIKey, p.BaseURL, originalModelID)
	if p.ModelID == "" {
		return nil, &ParamError{Field: "model_id", Reason: "required"}
	}
	if len(p.Categories) == 0 {
		return nil, &ParamError{Field: "categories", Reason: "at least one category is required"}
	}
	if p.DefaultCategory == "" {
		// Fall back to the first category so the run never fails purely
		// because the user omitted the default.
		p.DefaultCategory = p.Categories[0]
	}

	inv := getDefaultChatInvoker()
	sysPrompt := p.SysPrompt
	if sysPrompt == "" {
		sysPrompt = "You are a strict classifier."
	}
	userPrompt := buildCategorizePrompt(p)
	msgs := []schema.Message{
		{Role: schema.System, Content: sysPrompt},
		{Role: schema.User, Content: userPrompt},
	}
	resp, err := inv.Invoke(ctx, ChatInvokeRequest{
		Driver:    p.Driver,
		ModelName: p.ModelID,
		APIKey:    p.APIKey,
		BaseURL:   p.BaseURL,
		Messages:  msgs,
	})
	if err != nil {
		return nil, fmt.Errorf("component: Categorize.Invoke: %w", err)
	}

	chosen, score := pickCategory(resp.Content, p.Categories, p.DefaultCategory)
	next := []string{}
	if route := p.CategoryRoutes[chosen]; route != "" {
		next = []string{route}
	}
	return map[string]any{
		"category":      chosen,
		"category_name": chosen,
		"scores":        score,
		"_next":         next,
	}, nil
}

// Stream mirrors Invoke as a single chunk. Errors are delivered in-band as
// one {"error": msg} chunk, so the returned error is always nil.
func (c *CategorizeComponent) Stream(ctx context.Context, inputs map[string]any) (<-chan map[string]any, error) {
	out := make(chan map[string]any, 1)
	go func() {
		defer close(out)
		result, err := c.Invoke(ctx, inputs)
		if err != nil {
			out <- map[string]any{"error": err.Error()}
			return
		}
		out <- result
	}()
	return out, nil
}

// Inputs returns parameter metadata for tooling.
func (c *CategorizeComponent) Inputs() map[string]string {
	return map[string]string{
		"model_id":         "Provider-side model identifier",
		"items":            "Optional list of items to classify (added to the prompt as context)",
		"categories":       "List of allowed category names (response must match one)",
		"sys_prompt":       "Optional system prompt; defaults to a strict classifier instruction",
		"default_category": "Category returned if the model's answer is not in `categories` (defaults to categories[0])",
		"driver":           "Provider driver name",
		"api_key":          "Override API key",
		"base_url":         "Override provider base URL",
	}
}

// Outputs returns output metadata.
func (c *CategorizeComponent) Outputs() map[string]string {
	return map[string]string{
		"category":      "Chosen category name (one of the configured list, or the default)",
		"category_name": "Alias of category for v1 canvas templates",
		"scores":        "Score map (1.0 for the chosen category, 0.0 for the rest)",
		"_next":         "Downstream route handle(s) selected from categorize item uuids",
	}
}

// buildCategorizePrompt assembles a prompt that asks the model to pick a
// category. The categories are listed deterministically (sorted) so the
// prompt is stable across runs.
func buildCategorizePrompt(p CategorizeParam) string {
	cats := append([]string(nil), p.Categories...)
	sort.Strings(cats)

	var b strings.Builder
	b.WriteString("Classify the following item into exactly one of these categories:\n")
	for _, c := range cats {
		b.WriteString("- ")
		b.WriteString(c)
		b.WriteString("\n")
	}
	if len(p.Items) > 0 {
		b.WriteString("\nItems:\n")
		for _, it := range p.Items {
			b.WriteString("- ")
			b.WriteString(it)
			b.WriteString("\n")
		}
	}
	b.WriteString("\nRespond with ONLY the category name, no other text.")
	return b.String()
}

// pickCategory extracts a category from the model's response. Before
// matching, surrounding whitespace and quote/backtick characters are
// trimmed and a leading "category:" or "Category:" label is dropped.
// Strategy:
//  1. exact match (case-sensitive)
//  2. case-insensitive match
//  3. fall back to default
//
// Substring matching is intentionally avoided: it makes the picker too
// eager ("I have no idea" would match a category named "a"). If the model
// can't produce one of the categories verbatim, the default is used.
//
// Scores are 1.0 for the chosen category, 0.0 for the rest.
func pickCategory(response string, categories []string, def string) (string, map[string]float64) {
	scores := make(map[string]float64, len(categories))
	for _, c := range categories {
		scores[c] = 0
	}

	resp := strings.TrimSpace(response)
	resp = strings.Trim(resp, "\"'`\n\r\t ")
	resp = strings.TrimPrefix(resp, "category:")
	resp = strings.TrimPrefix(resp, "Category:")
	resp = strings.TrimSpace(resp)

	for _, c := range categories {
		if resp == c {
			scores[c] = 1
			return c, scores
		}
	}
	lower := strings.ToLower(resp)
	for _, c := range categories {
		if strings.ToLower(c) == lower {
			scores[c] = 1
			return c, scores
		}
	}
	scores[def] = 1
	return def, scores
}

// mergeCategorizeParam layers raw inputs over the receiver's default param set.
//
// v1 aliases are accepted alongside the v2 names: "llm_id" → "model_id",
// "system_prompt" → "sys_prompt", and "category_description" (a map keyed
// by category name) → "categories" (the sorted keys of the map), with any
// per-category route taken from the nested "to"/"uuid" fields. v1 fixtures
// use these short / dict forms; without the aliases the v1→v2 conversion
// step would have to run before the factory builds the component.
func mergeCategorizeParam(base CategorizeParam, inputs map[string]any) CategorizeParam {
	p := base
	if v, ok := stringFrom(inputs, "model_id"); ok {
		p.ModelID = v
	} else if v, ok := stringFrom(inputs, "llm_id"); ok {
		p.ModelID = v
	}
	if v, ok := sliceFrom(inputs, "items"); ok {
		p.Items = v
	}
	if v, ok := sliceFrom(inputs, "categories"); ok {
		p.Categories = v
	} else if m, ok := stringMapFrom(inputs, "category_description"); ok && len(m) > 0 {
		// v1 stores the categories as a map of {name: description}.
		// We only need the keys to drive the picker.
		keys := make([]string, 0, len(m))
		for k := range m {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		p.Categories = keys
	}
	if routes, ok := categoryRoutesFrom(inputs, "category_description"); ok {
		p.CategoryRoutes = routes
	}
	if v, ok := stringFrom(inputs, "sys_prompt"); ok {
		p.SysPrompt = v
	} else if v, ok := stringFrom(inputs, "system_prompt"); ok {
		p.SysPrompt = v
	}
	if v, ok := stringFrom(inputs, "default_category"); ok {
		p.DefaultCategory = v
	}
	if v, ok := stringFrom(inputs, "driver"); ok {
		p.Driver = v
	}
	if v, ok := stringFrom(inputs, "api_key"); ok {
		p.APIKey = v
	}
	if v, ok := stringFrom(inputs, "base_url"); ok {
		p.BaseURL = v
	}
	return p
}

// stringMapFrom extracts map[string]string from inputs[name]. The v1
// "category_description" field is shaped this way (name → human
// description); we only consume the keys.
func stringMapFrom(inputs map[string]any, name string) (map[string]string, bool) {
	v, ok := inputs[name]
	if !ok {
		return nil, false
	}
	raw, ok := v.(map[string]any)
	if !ok {
		return nil, false
	}
	out := make(map[string]string, len(raw))
	for k, child := range raw {
		if s, ok := child.(string); ok {
			out[k] = s
			continue
		}
		// Some encoders nest the description under a "description"
		// key; handle that fallback defensively.
		if nested, ok := child.(map[string]any); ok {
			if s, ok := nested["description"].(string); ok {
				out[k] = s
				continue
			}
		}
		out[k] = ""
	}
	return out, true
}
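
// categoryRoutesFrom extracts per-category route targets from the map at
// inputs[name] (the v1 "category_description" shape). For each category
// object, the nested "to" target wins; otherwise its "uuid" is used.
// Entries that are not objects are skipped. ok is false when no route
// was found.
func categoryRoutesFrom(inputs map[string]any, name string) (map[string]string, bool) {
	raw, ok := inputs[name]
	if !ok {
		return nil, false
	}
	src, ok := raw.(map[string]any)
	if !ok || len(src) == 0 {
		return nil, false
	}
	out := make(map[string]string, len(src))
	for category, child := range src {
		nested, ok := child.(map[string]any)
		if !ok {
			continue
		}
		if s, ok := firstRouteTarget(nested["to"]); ok {
			out[category] = s
			continue
		}
		if s, ok := nested["uuid"].(string); ok && s != "" {
			out[category] = s
		}
	}
	return out, len(out) > 0
}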

// init registers CategorizeComponent with the orchestrator-owned registry.
func init() {
	Register("Categorize", func(params map[string]any) (Component, error) {
		var p CategorizeParam
		if v, ok := stringFrom(params, "model_id"); ok {
			p.ModelID = v
		} else if v, ok := stringFrom(params, "llm_id"); ok {
			p.ModelID = v
		}
		// Check the object-style []any of maps first. sliceFrom would
		// otherwise match the same []any input and return (empty, true)
		// for non-string elements, making the object branch unreachable.
		// A plain []any of strings contains no named objects, so it falls
		// back to sliceFrom instead of being silently dropped.
		if items, ok := params["items"].([]any); ok && len(items) > 0 {
			names := make([]string, 0, len(items))
			routes := make(map[string]string, len(items))
			for _, item := range items {
				m, ok := item.(map[string]any)
				if !ok {
					continue
				}
				name, _ := m["name"].(string)
				if name == "" {
					continue
				}
				names = append(names, name)
				if route, ok := firstRouteTarget(m["to"]); ok {
					routes[name] = route
				} else if uuid, _ := m["uuid"].(string); uuid != "" {
					routes[name] = uuid
				}
				if examples, ok := m["examples"].([]any); ok {
					for _, example := range examples {
						em, ok := example.(map[string]any)
						if !ok {
							continue
						}
						if v, _ := em["value"].(string); v != "" {
							p.Items = append(p.Items, v)
						}
					}
				}
			}
			if len(names) > 0 {
				p.Categories = names
			} else if v, ok := sliceFrom(params, "items"); ok {
				p.Items = v
			}
			if len(routes) > 0 {
				p.CategoryRoutes = routes
			}
		} else if v, ok := sliceFrom(params, "items"); ok {
			p.Items = v
		}

		if v, ok := sliceFrom(params, "categories"); ok {
			p.Categories = v
		} else if m, ok := stringMapFrom(params, "category_description"); ok && len(m) > 0 {
			keys := make([]string, 0, len(m))
			for k := range m {
				keys = append(keys, k)
			}
			sort.Strings(keys)
			p.Categories = keys
			if routes, ok := categoryRoutesFrom(params, "category_description"); ok {
				p.CategoryRoutes = routes
			}
		}

		if v, ok := stringFrom(params, "sys_prompt"); ok {
			p.SysPrompt = v
		} else if v, ok := stringFrom(params, "system_prompt"); ok {
			p.SysPrompt = v
		}
		if v, ok := stringFrom(params, "default_category"); ok {
			p.DefaultCategory = v
		}
		if v, ok := stringFrom(params, "driver"); ok {
			p.Driver = v
		}
		if v, ok := stringFrom(params, "api_key"); ok {
			p.APIKey = v
		}
		if v, ok := stringFrom(params, "base_url"); ok {
			p.BaseURL = v
		}
		return NewCategorizeComponent(p), nil
	})
}
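
// firstRouteTarget returns v when it is a non-empty string, or the first
// element of v when v is a non-empty []any whose first element is a
// non-empty string.
func firstRouteTarget(v any) (string, bool) {
	if s, ok := v.(string); ok && s != "" {
		return s, true
	}
	items, ok := v.([]any)
	if !ok || len(items) == 0 {
		return "", false
	}
	s, ok := items[0].(string)
	if !ok || s == "" {
		return "", false
	}
	return s, true
}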