feat[Go]: port agent webhook trigger, agent file upload/download, component input-form + debug endpoints from Python (#16403)

port agent webhook trigger, agent file upload/download, component
input-form + debug endpoints from Python
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Zhichang Yu
2026-06-27 14:07:22 +08:00
committed by GitHub
parent 70546ea406
commit 5b09910d52
26 changed files with 4530 additions and 188 deletions

View File

@@ -21,7 +21,7 @@ RESTful API migration. Each deprecated route forwards to the corresponding
new API implementation.
Deprecated APIs and their replacements:
- POST /api/v1/agents/{agent_id}/completions -> POST /api/v1/agents/chat/completion
- POST /api/v1/agents/{agent_id}/completions -> POST /api/v1/agents/chat/completions
- POST /api/v1/agents_openai/{agent_id}/chat/completions -> POST /api/v1/agents/chat/completions
- POST /api/v1/chats/{chat_id}/completions -> POST /api/v1/chat/completions
- POST /api/v1/chats_openai/{chat_id}/chat/completions -> POST /api/v1/openai/{chat_id}/chat/completions

View File

@@ -1207,7 +1207,10 @@ async def test_db_connection():
return server_error_response(exc)
@manager.route("/agents/chat/completion", methods=["POST"]) # noqa: F821
# NOTE: The singular form `/agents/chat/completion` was a historical typo
# in earlier releases — no client, SDK, or doc ever used it, and the
# plural form below is the canonical route. The singular is intentionally
# NOT registered; clients sending it receive 404.
@manager.route("/agents/chat/completions", methods=["POST"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs

1
go.mod
View File

@@ -117,6 +117,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect

2
go.sum
View File

@@ -211,6 +211,8 @@ github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=

View File

@@ -82,6 +82,16 @@ func (b *BeginComponent) Invoke(ctx context.Context, inputs map[string]any) (map
state.Sys["user_id"] = uid
}
// Webhook payload injection. The webhook HTTP handler sets
// root["webhook_payload"] (see service/agent.go RunAgentWithWebhook)
// which BuildWorkflow forwards into inputs. Surfacing it on
// state.Sys lets downstream components read sys.webhook_payload the
// same way they read sys.query / sys.user_id. The chat path never
// sets this key, so existing tests stay green.
if payload, ok := inputs["webhook_payload"].(map[string]any); ok && len(payload) > 0 {
state.Sys["webhook_payload"] = payload
}
// Passthrough: a shallow copy keeps the caller's map un-aliased.
out := make(map[string]any, len(inputs))
mapsCopy(out, inputs)
@@ -106,18 +116,20 @@ func (b *BeginComponent) Stream(ctx context.Context, inputs map[string]any) (<-c
// strings live on the struct / method above.
func (b *BeginComponent) Inputs() map[string]string {
	// One human-readable description per accepted input key. Every key
	// appears exactly once: Go rejects duplicate constant keys in a map
	// literal at compile time, so the pre-change entries for "query",
	// "user_id" and "inputs" must not be repeated alongside the new ones.
	return map[string]string{
		"query":           "User query string (the chat input).",
		"user_id":         "Optional user/tenant identifier.",
		"webhook_payload": "Optional structured webhook request (set by the webhook HTTP handler; absent on chat flows).",
		"inputs":          "Optional free-form inputs map; passthrough only.",
	}
}
// Outputs returns a description for each output key. The key set matches
// Inputs (Begin is a passthrough). Each key is listed exactly once, because
// duplicate constant keys in a map literal do not compile.
func (b *BeginComponent) Outputs() map[string]string {
	return map[string]string{
		"query":           "Query string (passthrough).",
		"user_id":         "User id, if provided (passthrough).",
		"webhook_payload": "Webhook request payload, if provided (passthrough; also written to state.Sys[webhook_payload]).",
		"inputs":          "Raw inputs map (passthrough).",
	}
}

View File

@@ -86,3 +86,77 @@ func TestBegin_PassesThroughInputs(t *testing.T) {
// withStateForTest returns a context derived from ctx that carries s. It is
// a thin test-side wrapper around canvas.WithState.
func withStateForTest(ctx context.Context, s *canvas.CanvasState) context.Context {
	stateful := canvas.WithState(ctx, s)
	return stateful
}
// TestBegin_InjectsWebhookPayload verifies the webhook contract of Begin:
// a non-empty map under inputs["webhook_payload"] must end up on
// state.Sys["webhook_payload"], and must still be present, unchanged, in the
// map returned by Invoke.
func TestBegin_InjectsWebhookPayload(t *testing.T) {
	begin, _ := NewBeginComponent(nil)
	st := canvas.NewCanvasState("run-3", "task-3")
	runCtx := canvas.WithState(context.Background(), st)

	want := map[string]any{
		"query":   map[string]any{"q": "hello"},
		"headers": map[string]any{"x-token": "abc"},
		"body":    map[string]any{"k": "v"},
	}

	result, err := begin.Invoke(runCtx, map[string]any{
		"query":           "",
		"webhook_payload": want,
	})
	if err != nil {
		t.Fatalf("Invoke: %v", err)
	}

	// The payload must be visible to downstream readers of state.Sys.
	onState, isMap := st.Sys["webhook_payload"].(map[string]any)
	if !isMap {
		t.Fatalf("state.Sys[webhook_payload] missing or wrong type: %T", st.Sys["webhook_payload"])
	}
	if !reflect.DeepEqual(onState, want) {
		t.Errorf("state.Sys[webhook_payload] mismatch:\n got %v\n want %v", onState, want)
	}

	// The passthrough output must carry the same payload.
	passed, _ := result["webhook_payload"].(map[string]any)
	if !reflect.DeepEqual(passed, want) {
		t.Errorf("outputs[webhook_payload] mismatch:\n got %v\n want %v", passed, want)
	}
}
// TestBegin_AbsentWebhookPayload covers the plain chat path: when inputs
// carries no "webhook_payload" key, Invoke must leave
// state.Sys["webhook_payload"] unset.
func TestBegin_AbsentWebhookPayload(t *testing.T) {
	begin, _ := NewBeginComponent(nil)
	st := canvas.NewCanvasState("run-4", "task-4")
	runCtx := canvas.WithState(context.Background(), st)

	_, err := begin.Invoke(runCtx, map[string]any{"query": "plain chat"})
	if err != nil {
		t.Fatalf("Invoke: %v", err)
	}

	if leaked, present := st.Sys["webhook_payload"]; present {
		t.Errorf("state.Sys[webhook_payload] should not be set when inputs lack it; got %v", leaked)
	}
}
// TestBegin_EmptyWebhookPayload checks that an explicitly empty
// webhook_payload map is treated like an absent one: nothing is written to
// state.Sys["webhook_payload"].
func TestBegin_EmptyWebhookPayload(t *testing.T) {
	begin, _ := NewBeginComponent(nil)
	st := canvas.NewCanvasState("run-5", "task-5")
	runCtx := canvas.WithState(context.Background(), st)

	emptyInputs := map[string]any{
		"query":           "",
		"webhook_payload": map[string]any{},
	}
	_, err := begin.Invoke(runCtx, emptyInputs)
	if err != nil {
		t.Fatalf("Invoke: %v", err)
	}

	if leaked, present := st.Sys["webhook_payload"]; present {
		t.Errorf("state.Sys[webhook_payload] should not be set for empty payload; got %v", leaked)
	}
}

View File

@@ -0,0 +1,40 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package dsl
import "errors"
// Sentinel errors for the DSL extractors. Callers are expected to classify
// failures with errors.Is rather than by matching message text.
var (
	// ErrMalformedDSL signals a structural problem in the DSL itself:
	// a nil dsl, a missing components map, or a field of the wrong type.
	// It is kept separate from ErrComponentNotFound so callers can tell
	// "broken document" apart from "unknown component id".
	ErrMalformedDSL = errors.New("dsl: malformed")

	// ErrComponentNotFound signals that the requested componentID is not
	// a key of dsl["components"].
	ErrComponentNotFound = errors.New("dsl: component not found")

	// ErrMissingInputForm signals that the component exists but carries
	// no `obj.input_form` dict.
	ErrMissingInputForm = errors.New("dsl: component has no input_form")
)

View File

@@ -0,0 +1,133 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Package dsl contains pure-function helpers for working with the agent
// canvas DSL map structure (`map[string]any`). It is intentionally
// runtime-free: no Canvas instantiation, no component factories, no
// database access. The agent_component handlers use it to introspect
// the DSL before deciding whether to wire up a runtime component.
package dsl
import "fmt"
// ExtractComponentInputForm looks up
// `dsl["components"][componentID]["obj"]["input_form"]` in the raw DSL map
// and returns it as a map.
//
// Outcomes:
//   - the form map, when the field is present and is a map[string]any
//   - any error from navigateToComponent (unknown component, broken dsl)
//   - ErrMalformedDSL when the component has no map-typed "obj"
//   - ErrMissingInputForm when "input_form" is absent or nil
//   - ErrMalformedDSL when "input_form" is present but not a map
//
// A wrongly-typed input_form is deliberately reported as ErrMalformedDSL
// rather than ErrMissingInputForm, so a corrupt DSL is not mistaken for a
// component that simply has no form.
func ExtractComponentInputForm(dsl map[string]any, componentID string) (map[string]any, error) {
	component, err := navigateToComponent(dsl, componentID)
	if err != nil {
		return nil, err
	}

	inner, isMap := component["obj"].(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("%w: component %q has no obj", ErrMalformedDSL, componentID)
	}

	// A missing key and an explicit nil both land in the nil case.
	switch form := inner["input_form"].(type) {
	case nil:
		return nil, fmt.Errorf("%w: component %q has no input_form", ErrMissingInputForm, componentID)
	case map[string]any:
		return form, nil
	default:
		return nil, fmt.Errorf("%w: component %q input_form is not a dict", ErrMalformedDSL, componentID)
	}
}
// ExtractComponentParams looks up
// `dsl["components"][componentID]["obj"]["params"]` and returns it as a map.
//
// params is optional: when the key is absent or nil the result is
// (nil, nil). A present value that is not a map[string]any yields
// ErrMalformedDSL, as does a component without a map-typed "obj". Errors
// from navigateToComponent are returned unchanged.
func ExtractComponentParams(dsl map[string]any, componentID string) (map[string]any, error) {
	component, err := navigateToComponent(dsl, componentID)
	if err != nil {
		return nil, err
	}

	inner, isMap := component["obj"].(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("%w: component %q has no obj", ErrMalformedDSL, componentID)
	}

	// A missing key and an explicit nil both land in the nil case.
	switch params := inner["params"].(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return params, nil
	default:
		return nil, fmt.Errorf("%w: component %q params is not a dict", ErrMalformedDSL, componentID)
	}
}
// ExtractComponentName returns the string stored at
// `dsl["components"][componentID]["obj"]["component_name"]` (for example
// "Begin"). A missing, non-string, or empty name is reported as
// ErrMalformedDSL, as is a component without a map-typed "obj". Errors from
// navigateToComponent are returned unchanged.
func ExtractComponentName(dsl map[string]any, componentID string) (string, error) {
	component, err := navigateToComponent(dsl, componentID)
	if err != nil {
		return "", err
	}

	inner, isMap := component["obj"].(map[string]any)
	if !isMap {
		return "", fmt.Errorf("%w: component %q has no obj", ErrMalformedDSL, componentID)
	}

	if label, isString := inner["component_name"].(string); isString && label != "" {
		return label, nil
	}
	return "", fmt.Errorf("%w: component %q has no component_name", ErrMalformedDSL, componentID)
}
// navigateToComponent resolves dsl["components"][componentID] to its map
// value. It is the shared first step of the exported extractors.
//
// It returns ErrMalformedDSL for a nil dsl, a missing or non-map
// "components" entry, or a component value that is not a map, and
// ErrComponentNotFound when componentID is not a key of the components map.
func navigateToComponent(dsl map[string]any, componentID string) (map[string]any, error) {
	if dsl == nil {
		return nil, fmt.Errorf("%w: nil dsl", ErrMalformedDSL)
	}

	components, isMap := dsl["components"].(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("%w: missing components map", ErrMalformedDSL)
	}

	raw, present := components[componentID]
	if !present {
		return nil, fmt.Errorf("%w: %q", ErrComponentNotFound, componentID)
	}

	if node, isDict := raw.(map[string]any); isDict {
		return node, nil
	}
	return nil, fmt.Errorf("%w: %q is not a dict", ErrMalformedDSL, componentID)
}

View File

@@ -0,0 +1,149 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package dsl
import (
"errors"
"testing"
)
// happyDSL builds a fresh two-component DSL fixture: "begin" has a
// component_name, params and an input_form; "answer" has only a
// component_name.
func happyDSL() map[string]any {
	beginObj := map[string]any{
		"component_name": "Begin",
		"params":         map[string]any{"mode": "Manual"},
		"input_form": map[string]any{
			"query": map[string]any{"type": "string"},
		},
	}
	answerObj := map[string]any{"component_name": "Answer"}

	return map[string]any{
		"components": map[string]any{
			"begin":  map[string]any{"obj": beginObj},
			"answer": map[string]any{"obj": answerObj},
		},
	}
}
// TestExtractComponentInputForm_HappyPath checks that the "begin" fixture
// component yields its input_form, including the nested query type.
func TestExtractComponentInputForm_HappyPath(t *testing.T) {
	form, err := ExtractComponentInputForm(happyDSL(), "begin")
	if err != nil {
		t.Fatalf("err = %v, want nil", err)
	}

	field, isMap := form["query"].(map[string]any)
	if !isMap || field["type"] != "string" {
		t.Errorf("query type = %v, want string", form["query"])
	}
}
// TestExtractComponentInputForm_NotFound checks that an unknown component id
// is reported as ErrComponentNotFound.
func TestExtractComponentInputForm_NotFound(t *testing.T) {
	if _, err := ExtractComponentInputForm(happyDSL(), "missing"); !errors.Is(err, ErrComponentNotFound) {
		t.Errorf("err = %v, want ErrComponentNotFound", err)
	}
}
// TestExtractComponentInputForm_MissingObj checks that a component without
// an "obj" entry is reported as ErrMalformedDSL.
func TestExtractComponentInputForm_MissingObj(t *testing.T) {
	bare := map[string]any{} // component with no obj
	doc := map[string]any{
		"components": map[string]any{"bare": bare},
	}

	if _, err := ExtractComponentInputForm(doc, "bare"); !errors.Is(err, ErrMalformedDSL) {
		t.Errorf("err = %v, want ErrMalformedDSL", err)
	}
}
// TestExtractComponentInputForm_MissingInputForm uses the "answer" fixture
// component, which has an obj but no input_form, and expects
// ErrMissingInputForm.
func TestExtractComponentInputForm_MissingInputForm(t *testing.T) {
	if _, err := ExtractComponentInputForm(happyDSL(), "answer"); !errors.Is(err, ErrMissingInputForm) {
		t.Errorf("err = %v, want ErrMissingInputForm", err)
	}
}
// TestExtractComponentInputForm_NilDSL checks that a nil dsl map is reported
// as ErrMalformedDSL.
func TestExtractComponentInputForm_NilDSL(t *testing.T) {
	if _, err := ExtractComponentInputForm(nil, "anything"); !errors.Is(err, ErrMalformedDSL) {
		t.Errorf("err = %v, want ErrMalformedDSL", err)
	}
}
// TestExtractComponentParams_HappyPath checks that the "begin" fixture
// component yields its params map.
func TestExtractComponentParams_HappyPath(t *testing.T) {
	params, err := ExtractComponentParams(happyDSL(), "begin")
	if err != nil {
		t.Fatalf("err = %v, want nil", err)
	}

	if mode := params["mode"]; mode != "Manual" {
		t.Errorf("mode = %v, want Manual", mode)
	}
}
// TestExtractComponentParams_NoParams checks that a component without a
// params entry produces no error and an empty (or nil) result.
func TestExtractComponentParams_NoParams(t *testing.T) {
	params, err := ExtractComponentParams(happyDSL(), "answer")
	if err != nil {
		t.Fatalf("err = %v, want nil (params is optional)", err)
	}

	// len of a nil map is 0, so this one check covers both nil and empty.
	if len(params) != 0 {
		t.Errorf("params = %v, want empty/nil", params)
	}
}
// TestExtractComponentParams_WrongType checks that a params field that is
// present but not a map is reported as ErrMalformedDSL.
func TestExtractComponentParams_WrongType(t *testing.T) {
	badObj := map[string]any{
		"component_name": "Begin",
		"params":         "this is a string, not a dict",
	}
	doc := map[string]any{
		"components": map[string]any{
			"bad": map[string]any{"obj": badObj},
		},
	}

	if _, err := ExtractComponentParams(doc, "bad"); !errors.Is(err, ErrMalformedDSL) {
		t.Errorf("err = %v, want ErrMalformedDSL", err)
	}
}
// TestExtractComponentName_HappyPath checks that the "begin" fixture
// component resolves to the name "Begin".
func TestExtractComponentName_HappyPath(t *testing.T) {
	name, err := ExtractComponentName(happyDSL(), "begin")
	switch {
	case err != nil:
		t.Fatalf("err = %v, want nil", err)
	case name != "Begin":
		t.Errorf("name = %q, want Begin", name)
	}
}
// TestExtractComponentName_NotFound checks that an unknown component id is
// reported as ErrComponentNotFound.
func TestExtractComponentName_NotFound(t *testing.T) {
	if _, err := ExtractComponentName(happyDSL(), "missing"); !errors.Is(err, ErrComponentNotFound) {
		t.Errorf("err = %v, want ErrComponentNotFound", err)
	}
}

View File

@@ -989,6 +989,38 @@ func (r *RedisClient) GetClient() *redis.Client {
return r.client
}
// EvalTokenBucketStrict runs the token-bucket Lua script for key and reports
// whether one token was granted. It is the fail-closed variant intended for
// security gates (e.g. the webhook rate limiter): every failure is returned
// as (false, error) so the caller can deny the request. Per the original
// author's note, TokenBucket.Allow fails open instead and is reserved for
// the chat path.
//
// Parameters:
//   - ctx: passed to the script invocation (cancellation / deadline).
//   - key: the Redis key holding the bucket.
//   - capacity, rate: forwarded to the script together with the current
//     Unix time in whole seconds and a fixed cost of 1.0.
//
// Errors are returned when the receiver, its client, or its token-bucket
// script is nil, when the script call fails, or when the reply is not a
// non-empty array whose first element is an int64.
func (r *RedisClient) EvalTokenBucketStrict(
	ctx context.Context, key string, capacity, rate float64,
) (allowed bool, err error) {
	if r == nil || r.client == nil {
		return false, fmt.Errorf("redis: not initialised")
	}
	// A RedisClient assembled without the script would otherwise panic on
	// the nil script below; a fail-closed gate must return an error instead.
	if r.luaTokenBucket == nil {
		return false, fmt.Errorf("redis: token bucket script not initialised")
	}

	now := float64(time.Now().Unix())
	res, err := r.luaTokenBucket.Run(ctx, r.client, []string{key},
		capacity, rate, now, 1.0).Result()
	if err != nil {
		return false, fmt.Errorf("token bucket: %w", err)
	}

	values, ok := res.([]interface{})
	if !ok || len(values) < 1 {
		return false, fmt.Errorf("token bucket: malformed reply")
	}
	allowedI, ok := values[0].(int64)
	if !ok {
		return false, fmt.Errorf("token bucket: malformed reply")
	}
	// The script signals "allowed" with 1; any other integer is a denial.
	return allowedI == 1, nil
}
// RandomSleep sleeps for random duration between min and max milliseconds
func RandomSleep(minMs, maxMs int) {
duration := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond

View File

@@ -0,0 +1,107 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package redis
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
)
// newStrictTestClient starts a dedicated miniredis server and returns a
// RedisClient connected to it, together with the server handle. Only the
// delete-if-equal and token-bucket scripts are populated. Both the server
// and the go-redis client are closed through t.Cleanup.
func newStrictTestClient(t *testing.T) (*RedisClient, *miniredis.Miniredis) {
	t.Helper()

	server, err := miniredis.Run()
	if err != nil {
		t.Fatalf("miniredis: %v", err)
	}
	t.Cleanup(server.Close)

	conn := redis.NewClient(&redis.Options{Addr: server.Addr()})
	t.Cleanup(func() { _ = conn.Close() })

	// luaAutoIncrement is left unset on purpose; these tests never use it.
	wrapped := &RedisClient{
		client:           conn,
		luaDeleteIfEqual: redis.NewScript(luaDeleteIfEqualScript),
		luaTokenBucket:   redis.NewScript(luaTokenBucketScript),
	}
	return wrapped, server
}
// TestEvalTokenBucketStrict_AllowedThenDenied drains a bucket with
// capacity=2 and rate=0.1: the first two calls must be allowed and the third
// must be denied, all without error.
func TestEvalTokenBucketStrict_AllowedThenDenied(t *testing.T) {
	client, _ := newStrictTestClient(t)
	bg := context.Background()

	for attempt := 1; attempt < 3; attempt++ {
		allowed, err := client.EvalTokenBucketStrict(bg, "tb:webhook", 2, 0.1)
		switch {
		case err != nil:
			t.Fatalf("call %d unexpected error: %v", attempt, err)
		case !allowed:
			t.Fatalf("call %d: expected allowed=true", attempt)
		}
	}

	// The bucket is now empty, so the next call must be refused.
	allowed, err := client.EvalTokenBucketStrict(bg, "tb:webhook", 2, 0.1)
	switch {
	case err != nil:
		t.Fatalf("call 3 unexpected error: %v", err)
	case allowed:
		t.Fatalf("call 3: expected allowed=false (bucket exhausted)")
	}
}
// TestEvalTokenBucketStrict_RedisDownFailsClosed shuts the backing server
// down before the call and expects the strict contract: a non-nil error
// together with allowed=false.
func TestEvalTokenBucketStrict_RedisDownFailsClosed(t *testing.T) {
	client, server := newStrictTestClient(t)
	server.Close() // sever the transport before calling

	deadlineCtx, stop := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer stop()

	allowed, err := client.EvalTokenBucketStrict(deadlineCtx, "tb:webhook", 5, 1)
	switch {
	case err == nil:
		t.Fatalf("expected transport error, got nil")
	case allowed:
		t.Fatalf("expected allowed=false on transport failure (fail-closed)")
	}
}
// TestEvalTokenBucketStrict_NilClient exercises the nil-receiver guard: a
// nil *RedisClient must yield an error and allowed=false rather than
// letting traffic through.
func TestEvalTokenBucketStrict_NilClient(t *testing.T) {
	var uninitialised *RedisClient

	allowed, err := uninitialised.EvalTokenBucketStrict(context.Background(), "tb:webhook", 1, 1)
	switch {
	case err == nil:
		t.Fatalf("expected error on nil client, got nil")
	case allowed:
		t.Fatalf("expected allowed=false on nil client (fail-closed)")
	}
}

View File

@@ -42,9 +42,20 @@ import (
// agentFileService is the subset of FileService used by agent handlers.
//
// NOTE(review): the comment that previously stood here said this interface
// deliberately does NOT list UploadFile, because that method is consumed by
// the FileHandler (handler/file.go) and not by any agent handler (code
// review CR1). The UploadFile line is nevertheless still present below. It
// may be a removed line left over from the diff rendering — confirm, and
// drop it if no agent handler calls it.
type agentFileService interface {
UploadFile(tenantID, parentID string, files []*multipart.FileHeader) ([]map[string]interface{}, error)
// DownloadAgentFile: judging by the signature, returns the bytes stored
// at location for the given tenant — TODO confirm against FileService.
DownloadAgentFile(tenantID, location string) ([]byte, error)
// UploadInfos stores raw bytes in the per-user downloads bucket and
// returns lightweight descriptors. Mirrors python FileService.upload_info
// (multi-file path) used by the agent upload endpoint.
UploadInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, error)
// UploadFromURL downloads a remote file (with SSRF protection) and
// stores it as an info blob. Mirrors python FileService.upload_info
// (single-file path with ?url=) used by the agent upload endpoint.
UploadFromURL(tenantID, rawURL string) (map[string]interface{}, error)
}
// chatAgentService is the subset of AgentService used by the chat-completion
@@ -62,6 +73,7 @@ type AgentHandler struct {
agentService *service.AgentService
chatRunner chatAgentService
fileService agentFileService
loader canvasLoader
}
// NewAgentHandler create agent handler
@@ -71,6 +83,7 @@ func NewAgentHandler(agentService *service.AgentService, fileService *service.Fi
agentService: agentService,
chatRunner: agentService,
fileService: fileService,
loader: agentService,
}
}

View File

@@ -0,0 +1,200 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Gap C — `GET /api/v1/agents/<canvas_id>/components/<component_id>/input-form`.
// Gap D — `POST /api/v1/agents/<canvas_id>/components/<component_id>/debug`.
//
// Both endpoints introspect the user_canvas DSL via
// internal/agent/dsl helpers and, for debug only, construct a runtime
// component via the production factory. The python reference uses
// @_require_canvas_access_sync/_async decorators which return
// OPERATING_ERROR (103) with the python permission message; we
// inline the check here via loader.LoadCanvasByID and surface the
// same envelope (CodeOperatingError + canvasNoAccessMessage) so
// existing clients can pattern-match the response.
package handler
import (
"errors"
"github.com/gin-gonic/gin"
"ragflow/internal/agent/canvas"
"ragflow/internal/agent/dsl"
"ragflow/internal/agent/runtime"
"ragflow/internal/common"
"ragflow/internal/dao"
)
// GetComponentInputForm handles
// GET /api/v1/agents/:canvas_id/components/:component_id/input-form.
//
// Flow:
//  1. Resolve the caller via GetUser; a non-success code is returned as-is.
//  2. Require both path parameters (CodeArgumentError otherwise).
//  3. Load the canvas for the caller. dao.ErrUserCanvasNotFound maps to
//     CodeOperatingError with canvasNoAccessMessage; any other load error
//     maps to CodeServerError.
//  4. Extract the component's input_form from the canvas DSL; extractor
//     errors are translated by mapDSLError.
//  5. On success respond with {code, data: form, message: "success"}.
func (h *AgentHandler) GetComponentInputForm(c *gin.Context) {
	user, code, msg := GetUser(c)
	if code != common.CodeSuccess {
		jsonError(c, code, msg)
		return
	}

	canvasID := c.Param("canvas_id")
	componentID := c.Param("component_id")
	if canvasID == "" || componentID == "" {
		jsonError(c, common.CodeArgumentError, "`canvas_id` and `component_id` are required.")
		return
	}

	cv, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID)
	if err != nil {
		// errors.Is also matches when the loader wraps the sentinel with
		// extra context; a plain == comparison would fall through to the
		// server-error branch in that case.
		if errors.Is(err, dao.ErrUserCanvasNotFound) {
			jsonError(c, common.CodeOperatingError, canvasNoAccessMessage)
			return
		}
		jsonError(c, common.CodeServerError, err.Error())
		return
	}

	form, err := dsl.ExtractComponentInputForm(cv.DSL, componentID)
	if err != nil {
		mapDSLError(c, componentID, err)
		return
	}

	c.JSON(200, gin.H{
		"code":    common.CodeSuccess,
		"data":    form,
		"message": "success",
	})
}
// DebugComponent handles
// POST /api/v1/agents/:canvas_id/components/:component_id/debug.
//
// Request body: {"params": {"input_name": {"value": ...}, ...}}. Each
// entry's "value" field is flattened into the inputs map handed to the
// component's Invoke (the original note describes this as parity with the
// python `component.invoke(**{k: o["value"] ...})` call).
//
// Flow:
//  1. Resolve the caller via GetUser; require both path parameters.
//  2. Bind and validate the body: `params` must be present and every entry
//     must carry a "value" key (CodeArgumentError otherwise).
//  3. Load the canvas for the caller. dao.ErrUserCanvasNotFound maps to
//     CodeOperatingError with canvasNoAccessMessage; other load errors map
//     to CodeServerError.
//  4. Read the component name and params from the DSL; extractor errors
//     from either are translated by mapDSLError.
//  5. Build the component with the default runtime factory and invoke it
//     with a fresh CanvasState attached to the request context.
//  6. On success respond with {code, data: outputs, message: "success"}.
func (h *AgentHandler) DebugComponent(c *gin.Context) {
	user, code, msg := GetUser(c)
	if code != common.CodeSuccess {
		jsonError(c, code, msg)
		return
	}

	canvasID := c.Param("canvas_id")
	componentID := c.Param("component_id")
	if canvasID == "" || componentID == "" {
		jsonError(c, common.CodeArgumentError, "`canvas_id` and `component_id` are required.")
		return
	}

	var body struct {
		Params map[string]map[string]any `json:"params"`
	}
	if err := c.ShouldBindJSON(&body); err != nil {
		jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error())
		return
	}
	if body.Params == nil {
		jsonError(c, common.CodeArgumentError, "`params` is required.")
		return
	}

	cv, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID)
	if err != nil {
		// errors.Is also matches a wrapped sentinel; == would not.
		if errors.Is(err, dao.ErrUserCanvasNotFound) {
			jsonError(c, common.CodeOperatingError, canvasNoAccessMessage)
			return
		}
		jsonError(c, common.CodeServerError, err.Error())
		return
	}

	name, err := dsl.ExtractComponentName(cv.DSL, componentID)
	if err != nil {
		mapDSLError(c, componentID, err)
		return
	}

	// A present-but-wrongly-typed `params` field is a DSL contract
	// violation. Discarding this error would build the component with nil
	// params and debug it against corrupt data, so it is surfaced here.
	dslParams, err := dsl.ExtractComponentParams(cv.DSL, componentID)
	if err != nil {
		mapDSLError(c, componentID, err)
		return
	}

	// Flatten {param: {value: ...}} into {param: value}. Every entry must
	// carry a "value" key; a missing one is rejected instead of being
	// passed on as nil. Indexing a nil map yields ok == false, so a JSON
	// null entry is rejected by the same check.
	inputs := make(map[string]any, len(body.Params))
	for k, v := range body.Params {
		value, ok := v["value"]
		if !ok {
			jsonError(c, common.CodeArgumentError, "`params."+k+".value` is required.")
			return
		}
		inputs[k] = value
	}

	factory := runtime.DefaultFactory()
	if factory == nil {
		jsonError(c, common.CodeServerError, "component factory not initialised")
		return
	}
	comp, err := factory(name, dslParams)
	if err != nil {
		jsonError(c, common.CodeDataError, "component factory: "+err.Error())
		return
	}

	// Per the original note, the python-only set_debug_inputs hook is
	// skipped, and components such as Begin read a *CanvasState from the
	// context. A fresh state is attached so a single component can be
	// debugged without compiling the whole canvas.
	invokeCtx := runtime.WithState(c.Request.Context(), canvas.NewCanvasState("debug-"+componentID, "debug-task"))
	outputs, err := comp.Invoke(invokeCtx, inputs)
	if err != nil {
		jsonError(c, common.CodeServerError, "invoke: "+err.Error())
		return
	}

	c.JSON(200, gin.H{
		"code":    common.CodeSuccess,
		"data":    outputs,
		"message": "success",
	})
}
// mapDSLError writes the error envelope for a failure returned by a dsl
// extractor, so the component handlers share one error shape.
//
// The three known sentinels (ErrComponentNotFound, ErrMissingInputForm,
// ErrMalformedDSL) map to CodeDataError. Anything else is reported as
// CodeServerError, so an unmapped future sentinel is not presented as a
// user-data problem.
func mapDSLError(c *gin.Context, componentID string, err error) {
	if errors.Is(err, dsl.ErrComponentNotFound) {
		jsonError(c, common.CodeDataError, "component not found: "+componentID)
		return
	}
	if errors.Is(err, dsl.ErrMissingInputForm) {
		jsonError(c, common.CodeDataError, "component has no input_form: "+componentID)
		return
	}
	if errors.Is(err, dsl.ErrMalformedDSL) {
		jsonError(c, common.CodeDataError, "malformed dsl: "+err.Error())
		return
	}
	jsonError(c, common.CodeServerError, "internal: "+err.Error())
}

View File

@@ -0,0 +1,315 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"encoding/json"
"errors"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
_ "ragflow/internal/agent/component" // registers the production factory
"ragflow/internal/dao"
"ragflow/internal/entity"
)
// componentCtx returns a Gin test context and its response recorder for
// the /agents/:canvas_id/components/:component_id/* endpoints.
//
// `body` becomes the request body; when it is non-empty the request is
// marked as application/json. The context carries a stub user with ID
// "u-1" under the "user" key.
func componentCtx(t *testing.T, method, path, body string) (*gin.Context, *httptest.ResponseRecorder) {
	t.Helper()
	gin.SetMode(gin.TestMode)

	req := httptest.NewRequest(method, path, strings.NewReader(body))
	if len(body) > 0 {
		req.Header.Set("Content-Type", "application/json")
	}

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = req
	ctx.Set("user", &entity.User{ID: "u-1"})
	return ctx, rec
}
// beginCanvas builds the canvas fixture used by the happy-path tests: a
// UserCanvas with ID "c1" whose DSL holds one component, "begin", that
// carries both a params block and an input_form with a single string
// field named "query".
func beginCanvas() *entity.UserCanvas {
	inputForm := map[string]any{
		"query": map[string]any{"type": "string"},
	}
	beginObj := map[string]any{
		"component_name": "Begin",
		"params":         map[string]any{"mode": "Manual"},
		"input_form":     inputForm,
	}
	components := map[string]any{
		"begin": map[string]any{"obj": beginObj},
	}
	return &entity.UserCanvas{
		ID:  "c1",
		DSL: map[string]any{"components": components},
	}
}
// ---------- GetComponentInputForm ----------
func TestGetComponentInputForm_HappyPath(t *testing.T) {
cv := beginCanvas()
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
c, w := componentCtx(t, "GET", "/api/v1/agents/c1/components/begin/input-form", "")
c.Params = gin.Params{
{Key: "canvas_id", Value: "c1"},
{Key: "component_id", Value: "begin"},
}
h.GetComponentInputForm(c)
if w.Code != 200 {
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
var env struct {
Code int `json:"code"`
Data map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
}
q, ok := env.Data["query"].(map[string]interface{})
if !ok {
t.Fatalf("data.query = %v, want a map", env.Data["query"])
}
if q["type"] != "string" {
t.Errorf("data.query.type = %v, want string", q["type"])
}
}
// TestGetComponentInputForm_CanvasNotFound checks that a loader failure
// of dao.ErrUserCanvasNotFound is reported as code 103 with the fixed
// permission message.
func TestGetComponentInputForm_CanvasNotFound(t *testing.T) {
	const wantMsg = "Make sure you have permission to access the agent."

	handler := &AgentHandler{loader: &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}}
	ctx, rec := componentCtx(t, "GET", "/api/v1/agents/c1/components/begin/input-form", "")
	ctx.Params = gin.Params{
		gin.Param{Key: "canvas_id", Value: "c1"},
		gin.Param{Key: "component_id", Value: "begin"},
	}

	handler.GetComponentInputForm(ctx)

	gotCode, gotMsg := errBody(t, rec.Body.Bytes())
	// 103 is CodeOperatingError.
	if gotCode != 103 {
		t.Errorf("code = %d, want 103; msg=%q", gotCode, gotMsg)
	}
	if gotMsg != wantMsg {
		t.Errorf("msg = %q, want %q", gotMsg, wantMsg)
	}
}
func TestGetComponentInputForm_ComponentNotFound(t *testing.T) {
cv := beginCanvas()
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
c, w := componentCtx(t, "GET", "/api/v1/agents/c1/components/missing/input-form", "")
c.Params = gin.Params{
{Key: "canvas_id", Value: "c1"},
{Key: "component_id", Value: "missing"},
}
h.GetComponentInputForm(c)
code, msg := errBody(t, w.Body.Bytes())
if code != 102 { // CodeDataError
t.Errorf("code = %d, want 102; msg=%q", code, msg)
}
if !strings.Contains(msg, "component not found") {
t.Errorf("msg = %q, want 'component not found'", msg)
}
}
// TestGetComponentInputForm_NoInputForm checks that a component which
// exists but has no input_form entry yields code 102 and a message
// mentioning "no input_form".
func TestGetComponentInputForm_NoInputForm(t *testing.T) {
	// The "answer" component deliberately omits input_form.
	answerObj := map[string]any{
		"component_name": "Answer",
	}
	canvas := &entity.UserCanvas{
		ID: "c1",
		DSL: map[string]any{
			"components": map[string]any{
				"answer": map[string]any{"obj": answerObj},
			},
		},
	}

	handler := &AgentHandler{loader: &fakeCanvasLoader{canvas: canvas}}
	ctx, rec := componentCtx(t, "GET", "/api/v1/agents/c1/components/answer/input-form", "")
	ctx.Params = gin.Params{
		gin.Param{Key: "canvas_id", Value: "c1"},
		gin.Param{Key: "component_id", Value: "answer"},
	}

	handler.GetComponentInputForm(ctx)

	gotCode, gotMsg := errBody(t, rec.Body.Bytes())
	if gotCode != 102 {
		t.Errorf("code = %d, want 102; msg=%q", gotCode, gotMsg)
	}
	if mentioned := strings.Contains(gotMsg, "no input_form"); !mentioned {
		t.Errorf("msg = %q, want 'no input_form'", gotMsg)
	}
}
// ---------- DebugComponent ----------
func TestDebugComponent_HappyPath_Begin(t *testing.T) {
cv := beginCanvas()
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
body := `{"params":{"query":{"value":"hello world"}}}`
c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", body)
c.Params = gin.Params{
{Key: "canvas_id", Value: "c1"},
{Key: "component_id", Value: "begin"},
}
h.DebugComponent(c)
if w.Code != 200 {
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
// Begin's Invoke is a passthrough — it returns the input map as
// outputs (see internal/agent/component/begin.go:96-98). Pin the
// exact passthrough so a regression that drops the `query` field
// would be caught (architect review A2).
var env struct {
Code int `json:"code"`
Data map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
}
if env.Code != 0 {
t.Errorf("code = %d, want 0; body=%s", env.Code, w.Body.String())
}
if env.Data["query"] != "hello world" {
t.Errorf("data.query = %v, want 'hello world'", env.Data["query"])
}
}
func TestDebugComponent_InvalidParams_MissingField(t *testing.T) {
cv := beginCanvas()
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
// No `params` field.
c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", `{}`)
c.Params = gin.Params{
{Key: "canvas_id", Value: "c1"},
{Key: "component_id", Value: "begin"},
}
h.DebugComponent(c)
code, msg := errBody(t, w.Body.Bytes())
if code != 101 { // CodeArgumentError
t.Errorf("code = %d, want 101; msg=%q", code, msg)
}
if !strings.Contains(msg, "params") {
t.Errorf("msg = %q, want 'params'", msg)
}
}
// TestDebugComponent_InvalidParams_MissingValue covers CodeRabbit
// PR review #2: `{"params":{"q":{}}}` (no `value` key) should
// fail-fast with 101 rather than silently invoking the component
// with inputs["q"]=nil.
func TestDebugComponent_InvalidParams_MissingValue(t *testing.T) {
cv := beginCanvas()
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", `{"params":{"q":{}}}`)
c.Params = gin.Params{
{Key: "canvas_id", Value: "c1"},
{Key: "component_id", Value: "begin"},
}
h.DebugComponent(c)
code, msg := errBody(t, w.Body.Bytes())
if code != 101 {
t.Errorf("code = %d, want 101; msg=%q", code, msg)
}
if !strings.Contains(msg, "value") {
t.Errorf("msg = %q, want mention of 'value'", msg)
}
}
func TestDebugComponent_UnknownComponent(t *testing.T) {
cv := beginCanvas()
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
body := `{"params":{"x":{"value":1}}}`
c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/missing/debug", body)
c.Params = gin.Params{
{Key: "canvas_id", Value: "c1"},
{Key: "component_id", Value: "missing"},
}
h.DebugComponent(c)
code, msg := errBody(t, w.Body.Bytes())
if code != 102 {
t.Errorf("code = %d, want 102; msg=%q", code, msg)
}
if !strings.Contains(msg, "component not found") {
t.Errorf("msg = %q, want 'component not found'", msg)
}
}
// TestDebugComponent_CannotAccessCanvas checks that a loader failure of
// dao.ErrUserCanvasNotFound is reported as code 103 with the fixed
// permission message.
func TestDebugComponent_CannotAccessCanvas(t *testing.T) {
	const (
		reqBody = `{"params":{"x":{"value":1}}}`
		wantMsg = "Make sure you have permission to access the agent."
	)

	handler := &AgentHandler{loader: &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}}
	ctx, rec := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", reqBody)
	ctx.Params = gin.Params{
		gin.Param{Key: "canvas_id", Value: "c1"},
		gin.Param{Key: "component_id", Value: "begin"},
	}

	handler.DebugComponent(ctx)

	gotCode, gotMsg := errBody(t, rec.Body.Bytes())
	// 103 is CodeOperatingError.
	if gotCode != 103 {
		t.Errorf("code = %d, want 103; msg=%q", gotCode, gotMsg)
	}
	if gotMsg != wantMsg {
		t.Errorf("msg = %q, want %q", gotMsg, wantMsg)
	}
}
// TestDebugComponent_LoaderError_NonNotFound checks that a loader error
// other than dao.ErrUserCanvasNotFound is reported as code 500.
func TestDebugComponent_LoaderError_NonNotFound(t *testing.T) {
	const reqBody = `{"params":{"x":{"value":1}}}`

	loaderErr := errors.New("db conn lost")
	handler := &AgentHandler{loader: &fakeCanvasLoader{err: loaderErr}}
	ctx, rec := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", reqBody)
	ctx.Params = gin.Params{
		gin.Param{Key: "canvas_id", Value: "c1"},
		gin.Param{Key: "component_id", Value: "begin"},
	}

	handler.DebugComponent(ctx)

	gotCode, _ := errBody(t, rec.Body.Bytes())
	// 500 is CodeServerError.
	if gotCode != 500 {
		t.Errorf("code = %d, want 500", gotCode)
	}
}

View File

@@ -0,0 +1,86 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Gap A — `GET /api/v1/agents/download` (Python
// api/apps/restful_apis/agent_api.py:523-530).
//
// Mirrors the python download_agent_file handler:
// - auth via @login_required → GetUser
// - tenant injection via @add_tenant_id_to_kwargs → user.TenantID (we use
// user.ID here because the user/tenant distinction is collapsed in
// the Go session model; FileService buckets by userID for download
// retrieval, same as the python tenantID). See service/file.go:1033.
// - reads `id` from query string and streams raw bytes back as
// application/octet-stream.
package handler
import (
"fmt"
"net/http"
"net/url"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
"ragflow/internal/common"
)
// DownloadAgentFile GET /api/v1/agents/download?id=<file_id>
//
// Returns the bytes that h.fileService.DownloadAgentFile yields for the
// authenticated user and the `id` query parameter, as an
// application/octet-stream attachment whose filename is derived from
// the id.
//
// Error envelopes:
//   - the code/message from GetUser when authentication fails;
//   - common.CodeArgumentError when `id` is missing or cannot be turned
//     into a safe attachment filename;
//   - common.CodeServerError when the file service returns an error.
func (h *AgentHandler) DownloadAgentFile(c *gin.Context) {
	user, code, msg := GetUser(c)
	if code != common.CodeSuccess {
		jsonError(c, code, msg)
		return
	}
	fileID := c.Query("id")
	if fileID == "" {
		jsonError(c, common.CodeArgumentError, "`id` is required.")
		return
	}

	// Derive and validate the attachment filename BEFORE touching
	// storage, so an id that is going to be rejected never costs a blob
	// read. filepath.Base drops any directory elements. CR/LF and the
	// double quote would break out of the header value, and a backslash
	// is the escape character inside a quoted-string, so all four are
	// refused at the source instead of relying on lower layers.
	safe := filepath.Base(fileID)
	if safe == "" || safe == "." || safe == "/" || strings.ContainsAny(safe, "\r\n\"\\") {
		jsonError(c, common.CodeArgumentError, "invalid file id.")
		return
	}

	// NOTE(review): there is no per-object ownership check here; the id
	// from the query string is handed to the file service together with
	// the caller's user.ID, and any scoping is left to that service.
	blob, err := h.fileService.DownloadAgentFile(user.ID, fileID)
	if err != nil {
		jsonError(c, common.CodeServerError, err.Error())
		return
	}

	// Send both the plain quoted filename and a percent-encoded
	// filename*= parameter for clients that prefer the extended form.
	c.Header("Content-Disposition", fmt.Sprintf(
		`attachment; filename="%s"; filename*=UTF-8''%s`,
		safe, url.PathEscape(safe),
	))
	c.Data(http.StatusOK, "application/octet-stream", blob)
}

View File

@@ -0,0 +1,139 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"errors"
"mime/multipart"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"ragflow/internal/entity"
)
// fakeAgentFileService is the test double that the handler tests assign
// to AgentHandler.fileService, so they can run without a real file
// service or object store. Each method returns the canned value (or
// error) configured in the matching fields below.
type fakeAgentFileService struct {
	// Returned by DownloadAgentFile.
	blob []byte
	err  error

	// Returned by UploadInfos.
	uploadList []map[string]interface{}
	uploadErr  error

	// Returned by UploadFromURL (the ?url= import path).
	urlUpload    map[string]interface{}
	urlUploadErr error
}
// UploadInfos ignores its arguments and returns the configured
// uploadList, or uploadErr when one is set.
func (f *fakeAgentFileService) UploadInfos(_ string, _ []*multipart.FileHeader) ([]map[string]interface{}, error) {
	if f.uploadErr == nil {
		return f.uploadList, nil
	}
	return nil, f.uploadErr
}
// UploadFromURL ignores its arguments and returns the configured
// urlUpload, or urlUploadErr when one is set.
func (f *fakeAgentFileService) UploadFromURL(_ string, _ string) (map[string]interface{}, error) {
	if f.urlUploadErr == nil {
		return f.urlUpload, nil
	}
	return nil, f.urlUploadErr
}
// DownloadAgentFile ignores its arguments and returns the configured
// blob, or err when one is set.
func (f *fakeAgentFileService) DownloadAgentFile(_ string, _ string) ([]byte, error) {
	if f.err == nil {
		return f.blob, nil
	}
	return nil, f.err
}
// downloadCtx returns a Gin test context and its recorder for a GET to
// /api/v1/agents/download. A non-empty id is appended verbatim as the
// `id` query parameter. The context carries a stub user with ID
// "user-1" under the "user" key.
func downloadCtx(t *testing.T, id string) (*gin.Context, *httptest.ResponseRecorder) {
	t.Helper()
	gin.SetMode(gin.TestMode)

	target := "/api/v1/agents/download"
	if len(id) > 0 {
		target = target + "?id=" + id
	}

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest("GET", target, nil)
	ctx.Set("user", &entity.User{ID: "user-1"})
	return ctx, rec
}
// TestDownloadAgentFile_HappyPath verifies that a known id returns the
// raw bytes with application/octet-stream content type and a sane
// Content-Disposition header.
func TestDownloadAgentFile_HappyPath(t *testing.T) {
h := &AgentHandler{fileService: &fakeAgentFileService{blob: []byte("hello world")}}
c, w := downloadCtx(t, "file-123")
h.DownloadAgentFile(c)
if w.Code != 200 {
t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
if w.Body.String() != "hello world" {
t.Errorf("body = %q, want %q", w.Body.String(), "hello world")
}
if got := w.Header().Get("Content-Disposition"); !strings.Contains(got, "file-123") {
t.Errorf("Content-Disposition = %q, want filename=file-123", got)
}
if got := w.Header().Get("Content-Type"); got != "application/octet-stream" {
t.Errorf("Content-Type = %q, want application/octet-stream", got)
}
}
// TestDownloadAgentFile_MissingID checks that a request without an `id`
// query parameter is answered with HTTP 200 and an error envelope whose
// code is 101 (CodeArgumentError).
func TestDownloadAgentFile_MissingID(t *testing.T) {
	handler := &AgentHandler{fileService: &fakeAgentFileService{}}
	ctx, rec := downloadCtx(t, "") // empty id: no query parameter is added

	handler.DownloadAgentFile(ctx)

	if status := rec.Code; status != 200 {
		t.Errorf("status = %d, want 200 (envelope returned in body)", status)
	}
	gotCode, _ := errBody(t, rec.Body.Bytes())
	if gotCode != 101 {
		t.Errorf("code = %d, want 101", gotCode)
	}
}
// TestDownloadAgentFile_LoaderError verifies that a storage error
// bubbles up as a 500-class envelope.
func TestDownloadAgentFile_LoaderError(t *testing.T) {
h := &AgentHandler{fileService: &fakeAgentFileService{err: errors.New("minio down")}}
c, w := downloadCtx(t, "file-1")
h.DownloadAgentFile(c)
if w.Code != 200 {
t.Errorf("status = %d, want 200 (envelope in body)", w.Code)
}
code, msg := errBody(t, w.Body.Bytes())
if code != 500 { // CodeServerError
t.Errorf("code = %d, want 500", code)
}
if !strings.Contains(msg, "minio down") {
t.Errorf("msg = %q, want contains 'minio down'", msg)
}
}

View File

@@ -0,0 +1,152 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Gap B — `POST /api/v1/agents/<agent_id>/upload` (Python
// api/apps/restful_apis/agent_api.py:761-790).
//
// Mirrors the python upload_agent_file handler:
// - @_require_canvas_access_async → loader.LoadCanvasByID (returns
// ErrUserCanvasNotFound for both missing and forbidden; we map that
// to CodeOperatingError (103) with the python permission message
// instead of the chat-path 103 "canvas not found.")
// - single file + ?url= → FileService.upload_info(tenant, file, url)
// via UploadFromURL (URL-import mode)
// - single file + no url → FileService.upload_info(tenant, file)
// via UploadInfos with a one-element slice
// - multi file + any url → FileService.upload_info * N
// via UploadInfos (Python ignores ?url= on the multi-file path)
//
// 64 MB upload cap (D3 default).
package handler
import (
	"errors"
	"net/http"

	"github.com/gin-gonic/gin"

	"ragflow/internal/common"
	"ragflow/internal/dao"
)
const (
	// uploadMaxBytes is the upper bound, in bytes, that UploadAgentFile
	// places on the multipart request body: 64 MiB. It is spelled out
	// here so the limit is explicit and the same in every environment.
	uploadMaxBytes int64 = 64 * 1024 * 1024

	// canvasNoAccessMessage is the message returned when the caller
	// cannot load the requested canvas. The exact wording is kept stable
	// so clients that match on the message text keep working.
	canvasNoAccessMessage = "Make sure you have permission to access the agent."
)
// UploadAgentFile POST /api/v1/agents/:canvas_id/upload
//
// Accepts a multipart form whose "file" field carries one or more files
// and stores them through h.fileService for the authenticated user.
//
// Flow:
//  1. authenticate via GetUser and require the canvas_id path param;
//  2. confirm the caller can load the canvas;
//  3. cap and parse the multipart body;
//  4. with exactly one file AND a non-empty ?url= query parameter,
//     delegate to UploadFromURL; otherwise pass the files to UploadInfos.
//
// The success payload is a single object when the service returns one
// result and a list otherwise.
//
// Error envelopes: common.CodeArgumentError for a missing canvas_id, an
// oversized or malformed form, or a missing "file" field;
// common.CodeOperatingError when the canvas cannot be accessed;
// common.CodeServerError for any other loader or file-service failure.
func (h *AgentHandler) UploadAgentFile(c *gin.Context) {
	user, code, msg := GetUser(c)
	if code != common.CodeSuccess {
		jsonError(c, code, msg)
		return
	}
	canvasID := c.Param("canvas_id")
	if canvasID == "" {
		jsonError(c, common.CodeArgumentError, "`canvas_id` is required.")
		return
	}

	// Canvas access check. "Missing" and "forbidden" are deliberately
	// not told apart: both are expected to arrive as
	// dao.ErrUserCanvasNotFound and both get the same permission
	// message. errors.Is is used so the sentinel is still recognised if
	// a layer in between wraps it with extra context.
	if _, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID); err != nil {
		if errors.Is(err, dao.ErrUserCanvasNotFound) {
			jsonError(c, common.CodeOperatingError, canvasNoAccessMessage)
			return
		}
		jsonError(c, common.CodeServerError, err.Error())
		return
	}

	// Hard-cap the body before any parsing. The argument to
	// ParseMultipartForm only limits how much of the form is held in
	// memory (the remainder goes to temporary files), so MaxBytesReader
	// is what bounds the total number of bytes read from the client.
	c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, uploadMaxBytes)
	if cl := c.Request.ContentLength; cl > uploadMaxBytes {
		// A declared length over the cap is refused without reading.
		jsonError(c, common.CodeArgumentError, "request body too large.")
		return
	}
	if err := c.Request.ParseMultipartForm(uploadMaxBytes); err != nil {
		jsonError(c, common.CodeArgumentError, "invalid multipart form: "+err.Error())
		return
	}
	// Remove any temporary files the parser created once we are done.
	defer func() {
		if c.Request.MultipartForm != nil {
			_ = c.Request.MultipartForm.RemoveAll() // best-effort cleanup
		}
	}()
	form := c.Request.MultipartForm
	if form == nil {
		jsonError(c, common.CodeArgumentError, "missing multipart form.")
		return
	}
	files := form.File["file"]

	// URL-import mode: the url query parameter is consulted ONLY when
	// exactly one file was sent. With zero or several files it is
	// ignored and the request continues into the UploadInfos path.
	if rawURL := c.Query("url"); rawURL != "" && len(files) == 1 {
		uploaded, err := h.fileService.UploadFromURL(user.ID, rawURL)
		if err != nil {
			jsonError(c, common.CodeServerError, err.Error())
			return
		}
		c.JSON(200, gin.H{
			"code":    common.CodeSuccess,
			"data":    uploaded,
			"message": "success",
		})
		return
	}

	if len(files) == 0 {
		jsonError(c, common.CodeArgumentError, "`file` field is required.")
		return
	}
	results, err := h.fileService.UploadInfos(user.ID, files)
	if err != nil {
		jsonError(c, common.CodeServerError, err.Error())
		return
	}

	// One result is returned as a single object; anything else as a list.
	var payload any
	if len(results) == 1 {
		payload = results[0]
	} else {
		payload = results
	}
	c.JSON(200, gin.H{
		"code":    common.CodeSuccess,
		"data":    payload,
		"message": "success",
	})
}

View File

@@ -13,230 +13,305 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
//go:build ignore
package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"mime/multipart"
"net/http/httptest"
"strconv"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
"ragflow/internal/service"
)
// setupUploadTestDB sets up SQLite in-memory DB for upload handler tests.
func setupUploadTestDB(t *testing.T) *gorm.DB {
// uploadFakes bundles the two fakes that the upload-handler tests wire
// into AgentHandler: the canvas loader fake, which grants or denies
// access per test case, and the file service fake, which returns canned
// upload results.
type uploadFakes struct {
	loader  *fakeCanvasLoader
	fileSvc *fakeAgentFileService
}
// setUploadResult configures the list that the file service fake
// returns from UploadInfos.
func (u *uploadFakes) setUploadResult(list []map[string]interface{}) {
	svc := u.fileSvc
	svc.uploadList = list
}
// makeUploadCtx builds a Gin context with a multipart body containing
// `n` files under the "file" field. Each file's body is the
// decimal-string representation of its 1-based index.
func makeUploadCtx(t *testing.T, n int) (*gin.Context, *httptest.ResponseRecorder) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
TranslateError: true,
body := &bytes.Buffer{}
mw := multipart.NewWriter(body)
for i := 1; i <= n; i++ {
name := "f" + strconv.Itoa(i) + ".txt"
fw, err := mw.CreateFormFile("file", name)
if err != nil {
t.Fatalf("CreateFormFile: %v", err)
}
if _, err := fw.Write([]byte(name)); err != nil {
t.Fatalf("write file: %v", err)
}
}
mw.Close()
req := httptest.NewRequest("POST", "/api/v1/agents/c1/upload", body)
req.Header.Set("Content-Type", mw.FormDataContentType())
c.Request = req
c.Params = gin.Params{{Key: "canvas_id", Value: "c1"}}
c.Set("user", &entity.User{ID: "u-1"})
return c, w
}
// makeUploadCtxNoFile returns a Gin test context and its recorder for a
// multipart POST to /api/v1/agents/c1/upload whose only part is under
// the "other" field, so the form has no "file" field. Used by the
// missing-file test.
//
// The context has canvas_id set to "c1" and carries a stub user with ID
// "u-1" under the "user" key.
func makeUploadCtxNoFile(t *testing.T) (*gin.Context, *httptest.ResponseRecorder) {
	t.Helper()
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)

	body := &bytes.Buffer{}
	mw := multipart.NewWriter(body)
	// "other" field, not "file". Errors are checked so a broken fixture
	// fails the test loudly instead of dereferencing a nil part writer.
	fw, err := mw.CreateFormFile("other", "o.txt")
	if err != nil {
		t.Fatalf("CreateFormFile: %v", err)
	}
	if _, err := fw.Write([]byte("o")); err != nil {
		t.Fatalf("write file: %v", err)
	}
	if err := mw.Close(); err != nil {
		t.Fatalf("close multipart writer: %v", err)
	}

	req := httptest.NewRequest("POST", "/api/v1/agents/c1/upload", body)
	req.Header.Set("Content-Type", mw.FormDataContentType())
	c.Request = req
	c.Params = gin.Params{{Key: "canvas_id", Value: "c1"}}
	c.Set("user", &entity.User{ID: "u-1"})
	return c, w
}
// TestUploadAgentFile_SingleFile pins the 1-file → single-dict
// response shape (python agent_api.py:775-779).
func TestUploadAgentFile_SingleFile(t *testing.T) {
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
fu := &uploadFakes{
loader: loader,
fileSvc: &fakeAgentFileService{},
}
fu.setUploadResult([]map[string]interface{}{{"id": "upload-1", "name": "f1.txt"}})
h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc}
c, w := makeUploadCtx(t, 1)
h.UploadAgentFile(c)
if w.Code != 200 {
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
// Single-file path returns data as a single dict, NOT a list.
var env struct {
Code int `json:"code"`
Data map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
}
if env.Code != 0 {
t.Errorf("code = %d, want 0", env.Code)
}
if env.Data["id"] != "upload-1" {
t.Errorf("data.id = %v, want upload-1", env.Data["id"])
}
}
// TestUploadAgentFile_MultiFile pins the >1-file → list-of-dicts
// response shape (python agent_api.py:780-783).
func TestUploadAgentFile_MultiFile(t *testing.T) {
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{}}
fu.setUploadResult([]map[string]interface{}{
{"id": "u1", "name": "f1.txt"},
{"id": "u2", "name": "f2.txt"},
{"id": "u3", "name": "f3.txt"},
})
if err != nil {
t.Fatalf("failed to open sqlite: %v", err)
}
h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc}
if err := db.AutoMigrate(
&entity.User{},
&entity.UserCanvas{},
&entity.UserTenant{},
); err != nil {
t.Fatalf("failed to migrate: %v", err)
}
return db
}
// fakeUploadFileService implements fileUploader for tests.
type fakeUploadFileService struct {
uploaded []map[string]interface{}
err error
lastTenantID string
lastParentID string
}
func (f *fakeUploadFileService) UploadFile(tenantID, parentID string, files []*multipart.FileHeader) ([]map[string]interface{}, error) {
f.lastTenantID = tenantID
f.lastParentID = parentID
return f.uploaded, f.err
}
// TestUploadAgentFileHandler_Success verifies the happy path.
func TestUploadAgentFileHandler_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
db := setupUploadTestDB(t)
orig := dao.DB
dao.DB = db
t.Cleanup(func() { dao.DB = orig })
db.Create(&entity.User{ID: "user-1", Nickname: "test", Email: "test@test.com"})
db.Create(&entity.UserCanvas{ID: "canvas-1", UserID: "user-1", Title: sp("Test Agent")})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := strings.NewReader("--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\nContent-Type: text/plain\r\n\r\nhello world\r\n--boundary--")
req := httptest.NewRequest("POST", "/api/v1/agents/canvas-1/upload", body)
req.Header.Set("Content-Type", "multipart/form-data; boundary=boundary")
c.Request = req
c.Set("user", &entity.User{ID: "user-1"})
c.Set("user_id", "user-1")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
svc := &fakeUploadFileService{
uploaded: []map[string]interface{}{
{"id": "file-1", "name": "test.txt"},
},
}
h := &AgentHandler{
agentService: service.NewAgentService(),
fileService: svc,
}
c, w := makeUploadCtx(t, 3)
h.UploadAgentFile(c)
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
code, _ := resp["code"].(float64)
if code != float64(common.CodeSuccess) {
t.Fatalf("expected code 0, got %v: %v", code, resp["message"])
if w.Code != 200 {
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
var env struct {
Code int `json:"code"`
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
}
if env.Code != 0 {
t.Errorf("code = %d, want 0", env.Code)
}
if len(env.Data) != 3 {
t.Errorf("data has %d items, want 3", len(env.Data))
}
}
// TestUploadAgentFileHandler_NoPermission verifies cross-user access is denied.
func TestUploadAgentFileHandler_NoPermission(t *testing.T) {
gin.SetMode(gin.TestMode)
// TestUploadAgentFile_MissingFileField verifies that a multipart body
// without a "file" field is rejected with a 101 envelope.
func TestUploadAgentFile_MissingFileField(t *testing.T) {
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
h := &AgentHandler{loader: loader, fileService: &fakeAgentFileService{}}
db := setupUploadTestDB(t)
orig := dao.DB
dao.DB = db
t.Cleanup(func() { dao.DB = orig })
db.Create(&entity.User{ID: "user-a", Nickname: "a", Email: "a@test.com"})
db.Create(&entity.UserCanvas{ID: "canvas-b", UserID: "user-b", Title: sp("Not Yours")})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/v1/agents/canvas-b/upload", nil)
c.Set("user", &entity.User{ID: "user-a"})
c.Set("user_id", "user-a")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-b"}}
h := &AgentHandler{
agentService: service.NewAgentService(),
fileService: &fakeUploadFileService{},
}
c, w := makeUploadCtxNoFile(t)
h.UploadAgentFile(c)
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
code, _ := resp["code"].(float64)
if code != float64(common.CodeOperatingError) {
t.Errorf("expected operating error %d, got %v", common.CodeOperatingError, code)
code, msg := errBody(t, w.Body.Bytes())
if code != 101 { // CodeArgumentError
t.Errorf("code = %d, want 101; msg=%q", code, msg)
}
if !strings.Contains(msg, "file") {
t.Errorf("msg = %q, want mention of 'file'", msg)
}
}
// TestUploadAgentFileHandler_NoFiles verifies empty file list is rejected.
func TestUploadAgentFileHandler_NoFiles(t *testing.T) {
gin.SetMode(gin.TestMode)
// TestUploadAgentFile_CannotAccessCanvas verifies that an
// ErrUserCanvasNotFound from the loader short-circuits with a 103
// envelope carrying the python permission-failure message
// (matches @_require_canvas_access_async at agent_api.py:78,89).
func TestUploadAgentFile_CannotAccessCanvas(t *testing.T) {
loader := &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}
h := &AgentHandler{loader: loader, fileService: &fakeAgentFileService{}}
db := setupUploadTestDB(t)
orig := dao.DB
dao.DB = db
t.Cleanup(func() { dao.DB = orig })
db.Create(&entity.User{ID: "user-1", Nickname: "test", Email: "test@test.com"})
db.Create(&entity.UserCanvas{ID: "canvas-1", UserID: "user-1", Title: sp("Test Agent")})
body := strings.NewReader("--boundary\r\nContent-Disposition: form-data; name=\"dummy\"\r\n\r\nvalue\r\n--boundary--")
req := httptest.NewRequest("POST", "/api/v1/agents/canvas-1/upload", body)
req.Header.Set("Content-Type", "multipart/form-data; boundary=boundary")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
c.Set("user", &entity.User{ID: "user-1"})
c.Set("user_id", "user-1")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
h := &AgentHandler{
agentService: service.NewAgentService(),
fileService: &fakeUploadFileService{},
}
c, w := makeUploadCtx(t, 1)
h.UploadAgentFile(c)
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
code, _ := resp["code"].(float64)
if code != float64(common.CodeArgumentError) {
t.Errorf("expected argument error, got code %v", code)
code, msg := errBody(t, w.Body.Bytes())
if code != 103 { // CodeOperatingError
t.Errorf("code = %d, want 103; msg=%q", code, msg)
}
want := "Make sure you have permission to access the agent."
if msg != want {
t.Errorf("msg = %q, want %q", msg, want)
}
}
// TestUploadAgentFileHandler_TeamMemberTenant verifies that when a team
// member uploads to a shared canvas, the file is written into the canvas
// owner's file tree, not the caller's.
func TestUploadAgentFileHandler_TeamMemberTenant(t *testing.T) {
gin.SetMode(gin.TestMode)
// TestUploadAgentFile_URLImport pins the ?url= import path (python
// agent_api.py:775-779). When ?url= is set AND the body has exactly
// one `file` field, the handler delegates to FileService.UploadFromURL
// and returns its single-dict result.
func TestUploadAgentFile_URLImport(t *testing.T) {
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{}}
fu.fileSvc.urlUpload = map[string]interface{}{"id": "url-1", "name": "remote.bin"}
h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc}
db := setupUploadTestDB(t)
orig := dao.DB
dao.DB = db
t.Cleanup(func() { dao.DB = orig })
c, w := makeUploadCtx(t, 1)
c.Request.URL.RawQuery = "url=https%3A%2F%2Fexample.com%2Ffile.bin"
h.UploadAgentFile(c)
// user-b is a team member of user-a's tenant
db.Create(&entity.User{ID: "user-a", Nickname: "owner", Email: "a@test.com"})
db.Create(&entity.User{ID: "user-b", Nickname: "member", Email: "b@test.com"})
db.Create(&entity.UserTenant{ID: "ut-1", UserID: "user-b", TenantID: "user-a", Role: "member", Status: sp("1")})
db.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-a",
Permission: "team",
Title: sp("Shared Agent"),
if w.Code != 200 {
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
var env struct {
Code int `json:"code"`
Data map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
}
if env.Code != 0 {
t.Errorf("code = %d, want 0", env.Code)
}
if env.Data["id"] != "url-1" {
t.Errorf("data.id = %v, want url-1", env.Data["id"])
}
}
// TestUploadAgentFile_URLImport_IgnoredForMultiFile pins that
// ?url= is silently ignored when the body has >1 files, matching
// python's behaviour at agent_api.py:780-783 (the multi-file branch
// never reads url). The request flows into the normal UploadInfos
// path and returns a list of dicts.
func TestUploadAgentFile_URLImport_IgnoredForMultiFile(t *testing.T) {
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{}}
fu.fileSvc.urlUpload = map[string]interface{}{"id": "should-not-appear"}
fu.setUploadResult([]map[string]interface{}{
{"id": "u1", "name": "f1.txt"},
{"id": "u2", "name": "f2.txt"},
{"id": "u3", "name": "f3.txt"},
})
h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := strings.NewReader("--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"shared.txt\"\r\nContent-Type: text/plain\r\n\r\nhello\r\n--boundary--")
req := httptest.NewRequest("POST", "/api/v1/agents/canvas-1/upload", body)
req.Header.Set("Content-Type", "multipart/form-data; boundary=boundary")
c.Request = req
c.Set("user", &entity.User{ID: "user-b"})
c.Set("user_id", "user-b")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
svc := &fakeUploadFileService{
uploaded: []map[string]interface{}{
{"id": "file-1", "name": "shared.txt"},
},
}
h := &AgentHandler{
agentService: service.NewAgentService(),
fileService: svc,
}
c, w := makeUploadCtx(t, 3)
c.Request.URL.RawQuery = "url=https%3A%2F%2Fexample.com%2Ffile.bin"
h.UploadAgentFile(c)
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if resp["code"] != float64(common.CodeSuccess) {
t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"])
if w.Code != 200 {
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
if svc.lastTenantID != "user-b" {
t.Errorf("expected UploadFile called with authenticated user 'user-a', got '%s'", svc.lastTenantID)
var env struct {
Code int `json:"code"`
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
}
if env.Code != 0 {
t.Errorf("code = %d, want 0", env.Code)
}
if len(env.Data) != 3 {
t.Errorf("data has %d items, want 3 (url was ignored)", len(env.Data))
}
// Sanity check: the urlUpload map should NOT have been consumed.
// We don't read it back, but the test setup proves the
// handler didn't dispatch to UploadFromURL (otherwise it would
// have returned the single urlUpload dict, not the 3-element
// list from uploadList).
}
// TestUploadAgentFile_URLImport_LoaderError pins that a failed URL
// fetch surfaces as a 500 envelope (matches python's
// `server_error_response(exc)` on line 784).
func TestUploadAgentFile_URLImport_LoaderError(t *testing.T) {
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{urlUploadErr: errors.New("ssrf guard tripped")}}
h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc}
c, w := makeUploadCtx(t, 1)
c.Request.URL.RawQuery = "url=https%3A%2F%2Finternal.example.com%2Fsecret"
h.UploadAgentFile(c)
code, msg := errBody(t, w.Body.Bytes())
if code != 500 { // CodeServerError
t.Errorf("code = %d, want 500; msg=%q", code, msg)
}
if !strings.Contains(msg, "ssrf guard tripped") {
t.Errorf("msg = %q, want contains 'ssrf guard tripped'", msg)
}
}
// sp copies s into a fresh variable and hands back its address, which is
// convenient for populating *string struct fields from literals.
func sp(s string) *string {
	v := s
	return &v
}
func (f *fakeUploadFileService) DownloadAgentFile(tenantID, location string) ([]byte, error) {
panic("DownloadAgentFile should not be called during upload tests")
// TestUploadAgentFile_LoaderError verifies that a non-ErrUserCanvasNotFound
// service error surfaces as a 500 envelope with the error string.
func TestUploadAgentFile_LoaderError(t *testing.T) {
loader := &fakeCanvasLoader{err: context.DeadlineExceeded}
h := &AgentHandler{loader: loader, fileService: &fakeAgentFileService{}}
c, w := makeUploadCtx(t, 1)
h.UploadAgentFile(c)
code, msg := errBody(t, w.Body.Bytes())
if code != 500 { // CodeServerError
t.Errorf("code = %d, want 500; msg=%q", code, msg)
}
if !strings.Contains(msg, "deadline") {
t.Errorf("msg = %q, want contains 'deadline'", msg)
}
}

View File

@@ -0,0 +1,669 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
// Webhook trigger handler — Go port of
// api/apps/restful_apis/agent_api.py:1563-2248
// (`/api/v1/agents/<agent_id>/webhook` and `/.../webhook/test`).
//
// Mirrors the python reference's 9-step flow:
// 1. Load canvas (IDOR-safe via loadCanvasForUser → LoadCanvasByID).
// 2. Reject DataFlow canvas category.
// 3. Parse DSL map.
// 4. Find Begin with mode=="Webhook"; capture webhook_cfg.
// 5. Validate request method against webhook_cfg.methods.
// 6. validateWebhookSecurity (max body size / IP whitelist / rate limit
// / token / basic / jwt — strict fail-closed rate limit).
// 7. Parse request body (parseWebhookRequest).
// 8. Schema extract query/headers/body via extractBySchema.
// 9. Dispatch:
// - execution_mode == "Immediately" → return configured status +
// body_template synchronously; canvas runs detached in the
// background. Trace events appended when isTest.
// - else (streaming/aggregate) → block on canvas.run, aggregate
// message content, return JSON.
//
// Notes on Python parity divergences:
// - `cvs.dsl = json.loads(str(canvas)); UserCanvasService.update_by_id(...)`
// post-run DSL writeback is NOT ported. The Go runner mutates an
// in-memory copy, not the persisted row. UpdateAgent already persists
// the editable DSL on the user-driven path.
// - multipart/form-data uploads are rejected with HTTP 501
// (ErrWebhookMultipartNotSupported) at parse time. The Python
// upload path through canvas.get_files_async depends on
// FileService.upload_info which is itself not in the Go port yet.
// When FileService.upload_info lands, the
// ErrWebhookMultipartNotSupported branch of the parse-error switch
// in Webhook becomes the dispatch entry point.
// - content_types whitelist ENFORCES a hard reject (102 envelope)
// when the request Content-Type does not match the configured
// value. Mirrors python agent_api.py:1839-1842 (`raise
// ValueError("Invalid Content-Type...")`).
// - Webhook trace shape (webhook-trace-<id>-logs in Redis) matches
// the python append_webhook_trace at agent_api.py:2073-2091 so the
// eventual /webhook/logs poll PR can consume this writer
// unchanged.
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"ragflow/internal/agent/canvas"
"ragflow/internal/common"
rediscli "ragflow/internal/engine/redis"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"ragflow/internal/dao"
"ragflow/internal/entity"
)
// canvasLoader is the subset of service.AgentService the webhook handler
// needs. Defined as an interface so handler tests can inject a fake
// without standing up the full AgentService (DB DAOs, eino runner, etc).
//
// LoadCanvasByID returns the raw DAO/service error so the WEBHOOK
// handler can fold it into its own 102 envelope. We deliberately do NOT
// route through handler.mapAgentError here because that helper maps
// ErrUserCanvasNotFound to 103 "Make sure you have permission..." which
// is the chat/run contract — wrong for the webhook path.
type canvasLoader interface {
// LoadCanvasByID fetches the canvas identified by canvasID on behalf
// of userID. The Webhook handler checks the returned error against
// dao.ErrUserCanvasNotFound and dao.ErrUserCanvasVersionNotFound.
LoadCanvasByID(ctx context.Context, userID, canvasID string) (*entity.UserCanvas, error)
// RunAgentWithWebhook starts a canvas run with the given webhook
// payload and returns a receive-only channel of run events. The
// webhook runners range over the channel until it is closed.
RunAgentWithWebhook(ctx context.Context, userID, canvasID string, payload map[string]any) (<-chan canvas.RunEvent, error)
}
// Webhook is the handler method mounted at:
//
// /api/v1/agents/:canvas_id/webhook (production trigger)
// /api/v1/agents/:canvas_id/webhook/test (test trigger, with trace)
//
// The Python decorator stack at agent_api.py:1563 binds six methods to
// a single path. On the Go side the router registers each verb
// individually via registerAnyMethod (see internal/router/agent_routes.go).
//
// The handler walks a fixed sequence of gates (load canvas, reject
// DataFlow, locate the Webhook Begin component, method gate, security
// gate, body parsing, schema extraction) and writes an error envelope
// at the first gate that fails. Only when every gate passes does it
// dispatch the canvas run, either detached ("Immediately" mode) or
// blocking (any other mode).
func (h *AgentHandler) Webhook(c *gin.Context) {
user, code, msg := GetUser(c)
if code != common.CodeSuccess {
jsonError(c, code, msg)
return
}
canvasID := c.Param("canvas_id")
if h.loader == nil {
jsonInternalError(c, errors.New("agent webhook: loader not configured"))
return
}
// The test endpoint shares this handler; it is recognised purely by
// its path suffix and enables trace recording further down.
isTest := strings.HasSuffix(c.Request.URL.Path, "/webhook/test")
// startTs is passed to the runners, which use it when recording trace
// events for test runs.
startTs := time.Now()
// 1. Load canvas. Webhook collapses missing-vs-foreign into a
// single 102 "Canvas not found." (matches Python
// api/apps/restful_apis/agent_api.py:1572 and the existing
// GetAgentWebhookLogs envelope), so we DO NOT route through
// handler.mapAgentError here.
cv, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID)
if err != nil {
if errors.Is(err, dao.ErrUserCanvasNotFound) || errors.Is(err, dao.ErrUserCanvasVersionNotFound) {
jsonError(c, common.CodeDataError, "Canvas not found.")
return
}
// Any other loader failure is reported as an internal error.
jsonInternalError(c, err)
return
}
// 2. Reject DataFlow.
if cv.CanvasCategory == "DataFlow" {
jsonError(c, common.CodeDataError, "Dataflow can not be triggered by webhook.")
return
}
// 3. DSL map. cv.DSL is a typed entity.JSONMap (map[string]any
// under the hood); we copy the top-level keys into a plain
// map[string]any. The copy is shallow: nested maps are still shared
// with cv.DSL.
dsl := map[string]any{}
for k, v := range cv.DSL {
dsl[k] = v
}
// 4. Find Begin component with mode=="Webhook".
webhookCfg := findWebhookBegin(dsl)
if webhookCfg == nil {
jsonError(c, common.CodeDataError, "Webhook not configured for this agent.")
return
}
// 5. Method gate.
if !methodAllowed(webhookCfg["methods"], c.Request.Method) {
jsonError(c, common.CodeDataError,
fmt.Sprintf("HTTP method '%s' not allowed for this webhook.", c.Request.Method))
return
}
// 6. Security gate (strict; surfaces all errors as 102).
securityCfg := stringMap(webhookCfg["security"])
if err := validateWebhookSecurity(securityCfg, c, canvasID); err != nil {
jsonError(c, common.CodeDataError, err.Error())
return
}
// 6a. Body size DoS hardening (security review MEDIUM-1). The
// security validator above checks the Content-Length header, but
// that header is attacker-controllable. Wrap the actual stream
// reader with http.MaxBytesReader so the body read inside
// parseWebhookRequest is bounded by the same parsed limit.
// The wrap is skipped when no positive limit is configured or the
// limit cannot be parsed. How a read that exceeds the limit is
// reported is decided by parseWebhookRequest, not here.
if limit, perr := parseMaxBodySize(securityCfg); perr == nil && limit > 0 && c.Request.Body != nil {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, limit)
}
// 7. Parse request body. Two error classes flow through here:
// - ErrWebhookMultipartNotSupported → HTTP 501 (file uploads
// depend on FileService.upload_info which is not yet ported).
// - ErrWebhookContentTypeMismatch → HTTP 102 (content-type
// whitelist, mirrors python agent_api.py:1839).
// Both are surfaced as typed errors so the dispatch envelope is
// unambiguous — neither path falls through into schema validation.
// A non-string content_types value is treated as "no whitelist".
contentType, _ := webhookCfg["content_types"].(string)
parsed, parseErr := parseWebhookRequest(contentType, c)
if parseErr != nil {
switch {
case errors.Is(parseErr, ErrWebhookMultipartNotSupported):
// 501 — multipart/form-data uploads are not yet
// supported. Body is short so operators see exactly what
// is missing.
c.JSON(http.StatusNotImplemented, gin.H{
"code": http.StatusNotImplemented,
"data": false,
"message": parseErr.Error(),
})
return
case errors.Is(parseErr, ErrWebhookContentTypeMismatch):
// Same 102 envelope as the default branch below; kept as its
// own case so the content-type rejection is explicit.
jsonError(c, common.CodeDataError, parseErr.Error())
return
default:
jsonError(c, common.CodeDataError, parseErr.Error())
return
}
}
// 8. Schema extract query/headers/body.
schema, _ := webhookCfg["schema"].(map[string]any)
clean, schemaErr := applyWebhookSchema(parsed, schema)
if schemaErr != nil {
jsonError(c, common.CodeDataError, schemaErr.Error())
return
}
// 9. Dispatch. A missing or non-string execution_mode falls back to
// "Immediately".
mode, _ := webhookCfg["execution_mode"].(string)
if mode == "" {
mode = "Immediately"
}
if mode == "Immediately" {
// Render the configured response first so a bad response config
// is rejected before any run is started.
status, contentType, payload, perr := renderImmediatelyResponse(stringMap(webhookCfg["response"]))
if perr != nil {
jsonError(c, common.CodeDataError, perr.Error())
return
}
// Detached background run — does NOT inherit c.Request.Context()
// so a client disconnect does not cancel the canvas run.
go h.runWebhookDetached(cv, clean, isTest, startTs)
c.Data(status, contentType, payload)
return
}
// Streaming / blocking mode: the request context is passed to the
// run and the handler waits for the aggregated result.
out := h.runWebhookSync(c.Request.Context(), cv, clean, isTest, startTs)
c.JSON(out.status, out.body)
}
// findWebhookBegin looks through dsl["components"] for a component whose
// obj.component_name is "begin" (compared case-insensitively) and whose
// obj.params.mode is exactly "Webhook". The matching params map — the
// webhook_cfg — is returned; nil is returned when dsl is nil, has no
// usable components map, or contains no such component. Mirrors
// agent_api.py:1584-1592.
func findWebhookBegin(dsl map[string]any) map[string]any {
	// Indexing a nil map is safe in Go, so each failed assertion below
	// simply yields a nil map and the chain falls through to "no match".
	components, _ := dsl["components"].(map[string]any)
	for _, node := range components {
		entry, _ := node.(map[string]any)
		obj, _ := entry["obj"].(map[string]any)
		componentName, _ := obj["component_name"].(string)
		params, _ := obj["params"].(map[string]any)
		if params == nil || !strings.EqualFold(componentName, "begin") {
			continue
		}
		mode, _ := params["mode"].(string)
		if mode == "Webhook" {
			return params
		}
	}
	return nil
}
// methodAllowed reports whether requestMethod appears in the configured
// method list, comparing upper-cased forms. When raw is not a []any or
// the list is empty, every method is allowed (the source comment cites
// the python short-circuit at agent_api.py:1596).
func methodAllowed(raw any, requestMethod string) bool {
	configured, _ := raw.([]any)
	if len(configured) == 0 {
		return true
	}
	target := strings.ToUpper(requestMethod)
	for i := range configured {
		// Non-string entries assert to "" and are compared as such.
		candidate, _ := configured[i].(string)
		if strings.ToUpper(candidate) == target {
			return true
		}
	}
	return false
}
// stringMap returns v as a map[string]any when it already is one, and a
// fresh empty map for a nil interface or any other type. It keeps the
// dispatch logic free of repeated type assertions.
func stringMap(v any) map[string]any {
	asMap, isMap := v.(map[string]any)
	if !isMap {
		return map[string]any{}
	}
	return asMap
}
// parseWebhookRequest mirrors agent_api.py:1828-1894.
//
// Returns (parsed, error). The parsed map carries query / headers /
// body / content_type. The body is parsed according to the request
// Content-Type:
//
//   - application/x-www-form-urlencoded → form fields
//   - application/json / text/plain / octet-stream / unknown / empty →
//     raw bytes decoded as a JSON object; anything that does not decode
//     to an object yields an empty body (python parity)
//
// Errors:
//
//   - ErrWebhookMultipartNotSupported (501): the request used
//     multipart/form-data, whose upload path is not ported yet.
//
//   - ErrWebhookContentTypeMismatch (102): configuredContentType is set
//     and the request Content-Type (media type only, lowercased) does
//     not match it. A missing Content-Type header counts as a mismatch;
//     this is deliberately stricter than the python `if ctype and ...`
//     short-circuit, which lets header-less requests bypass the
//     whitelist. The check runs BEFORE the body is read so rejected
//     requests never have their body consumed.
//
//   - a wrapped read error when the body stream itself fails, e.g.
//     because the caller wrapped it in http.MaxBytesReader and the limit
//     was exceeded. Previously this was swallowed and the truncated
//     request was processed as if it had an empty body.
func parseWebhookRequest(configuredContentType string, c *gin.Context) (map[string]any, error) {
	// 1. Query and 2. Headers: single values stay scalars, repeated
	// values stay lists.
	query := webhookFlattenValues(c.Request.URL.Query())
	headers := webhookFlattenValues(c.Request.Header)

	// 3. Body. Only the media type matters; parameters such as charset
	// are dropped.
	ctype, _, _ := strings.Cut(c.GetHeader("Content-Type"), ";")
	ctype = strings.TrimSpace(strings.ToLower(ctype))

	// multipart/form-data → 501. Checked BEFORE the whitelist because
	// multipart uploads don't carry a useful `content_types` value to
	// validate against.
	if ctype == "multipart/form-data" {
		return nil, ErrWebhookMultipartNotSupported
	}

	// content_type whitelist. The configured value is normalised the
	// same way as the request value; otherwise a config such as
	// "Application/JSON" could never match the lowercased request type.
	want := strings.ToLower(strings.TrimSpace(configuredContentType))
	if want != "" && want != ctype {
		return nil, fmt.Errorf("%w: expect %q, got %q",
			ErrWebhookContentTypeMismatch, configuredContentType, ctype)
	}

	body := map[string]any{}
	switch ctype {
	case "application/x-www-form-urlencoded":
		// ParseForm failures leave the body empty, matching the python
		// handler's blanket `except Exception: body_data = {}`.
		if err := c.Request.ParseForm(); err == nil {
			body = webhookFlattenValues(c.Request.PostForm)
		}
	default:
		decoded, err := webhookReadJSONBody(c.Request.Body)
		if err != nil {
			return nil, err
		}
		body = decoded
	}

	return map[string]any{
		"query":        query,
		"headers":      headers,
		"body":         body,
		"content_type": ctype,
	}, nil
}

// webhookFlattenValues converts a multi-value map (url.Values,
// http.Header) into map[string]any, storing a plain string for keys
// with exactly one value and the []string otherwise.
func webhookFlattenValues(src map[string][]string) map[string]any {
	out := make(map[string]any, len(src))
	for k, vals := range src {
		if len(vals) == 1 {
			out[k] = vals[0]
		} else {
			out[k] = vals
		}
	}
	return out
}

// webhookReadJSONBody reads r to the end and decodes it as a JSON
// object. Read failures are returned; decode failures and non-object
// payloads (including JSON null) yield an empty, non-nil map.
func webhookReadJSONBody(r io.Reader) (map[string]any, error) {
	if r == nil {
		return map[string]any{}, nil
	}
	raw, err := io.ReadAll(r)
	if err != nil {
		return nil, fmt.Errorf("read webhook request body: %w", err)
	}
	body := map[string]any{}
	if len(raw) == 0 {
		return body, nil
	}
	// Unmarshal sets the map to nil for a literal `null`; downstream
	// code expects a usable map, so fall back to an empty one.
	if err := json.Unmarshal(raw, &body); err != nil || body == nil {
		return map[string]any{}, nil
	}
	return body, nil
}
// Sentinel errors produced by parseWebhookRequest. Callers match them
// with errors.Is to choose the response envelope.
var (
	// ErrWebhookMultipartNotSupported signals that the inbound
	// Content-Type was multipart/form-data. The handler answers with
	// HTTP 501 Not Implemented because the Python file upload path
	// (FileService.upload_info → canvas.get_files_async) is not ported
	// yet.
	ErrWebhookMultipartNotSupported = errors.New("multipart/form-data uploads are not supported in this port")

	// ErrWebhookContentTypeMismatch signals that the configured
	// content_types whitelist disagrees with the request's Content-Type.
	// Mirrors python agent_api.py:1839-1842 (`raise
	// ValueError("Invalid Content-Type...")`) and is surfaced as a 102
	// envelope.
	ErrWebhookContentTypeMismatch = errors.New("invalid content-type")
)
// applyWebhookSchema runs extractBySchema on the parsed request's three
// sections and assembles the clean_request map that the Python
// webhook handler passes to canvas.run(webhook_payload=...). Mirrors
// agent_api.py:2052-2068.
func applyWebhookSchema(parsed map[string]any, schema map[string]any) (map[string]any, error) {
q := stringMap(parsed["query"])
hd := stringMap(parsed["headers"])
bd := stringMap(parsed["body"])
qClean, err := extractBySchema(q, stringMap(schema["query"]), "query")
if err != nil {
return nil, err
}
hdClean, err := extractBySchema(hd, stringMap(schema["headers"]), "headers")
if err != nil {
return nil, err
}
bdClean, err := extractBySchema(bd, stringMap(schema["body"]), "body")
if err != nil {
return nil, err
}
return map[string]any{
"query": qClean,
"headers": hdClean,
"body": bdClean,
"input": parsed,
}, nil
}
// renderImmediatelyResponse turns the webhook's `response` config into
// the synchronous reply for the Immediately execution mode. Mirrors
// agent_api.py:2093-2121.
//
// Returns (status, content type, body bytes, error):
//
//   - status: taken from cfg["status"] (int, int64, float64 or numeric
//     string); 200 when the key is absent. Any other type, an
//     unparsable string, or a value outside [200, 399] is an error so
//     the operator notices the config bug.
//   - body_template: when it parses as JSON the re-encoded JSON is
//     returned as application/json; otherwise the raw text is returned
//     as text/plain. An empty or non-string template yields a nil body
//     with content type application/json.
func renderImmediatelyResponse(cfg map[string]any) (int, string, []byte, error) {
	code := 200
	if raw, present := cfg["status"]; present {
		switch typed := raw.(type) {
		case string:
			parsed, convErr := strconv.Atoi(typed)
			if convErr != nil {
				return 0, "", nil, fmt.Errorf("invalid response status code: %v", typed)
			}
			code = parsed
		case float64:
			code = int(typed)
		case int64:
			code = int(typed)
		case int:
			code = typed
		default:
			return 0, "", nil, fmt.Errorf("invalid response status code: %v", typed)
		}
	}
	if code < 200 || code > 399 {
		return 0, "", nil, fmt.Errorf("invalid response status code: %d, must be between 200 and 399", code)
	}

	tpl, _ := cfg["body_template"].(string)
	if tpl == "" {
		return code, "application/json", nil, nil
	}
	// JSON first (python: parse_body() branch at line 2110); plain text
	// is the fallback.
	var decoded any
	if json.Unmarshal([]byte(tpl), &decoded) != nil {
		return code, "text/plain", []byte(tpl), nil
	}
	reencoded, _ := json.Marshal(decoded)
	return code, "application/json", reencoded, nil
}
// runWebhookDetached executes the canvas run behind an "Immediately"
// webhook. The run gets its own context — context.Background() with a
// 5-minute timeout rather than the request context — so a client
// disconnect does NOT cancel it. The event channel is always drained;
// events are written to the redis trace only when isTest is true.
//
// Mirrors python: agent_api.py:2123-2175 (the asyncio.create_task body
// inside the Immediately branch).
func (h *AgentHandler) runWebhookDetached(
	cv *entity.UserCanvas, payload map[string]any, isTest bool, startTs time.Time,
) {
	runCtx, stop := context.WithTimeout(context.Background(), 5*time.Minute)
	defer stop()

	stream, runErr := h.loader.RunAgentWithWebhook(runCtx, cv.UserID, cv.ID, payload)
	if runErr == nil {
		for ev := range stream {
			if isTest {
				appendWebhookTrace(cv.ID, startTs, ev)
			}
		}
		return
	}

	// The run never started: log it, and for test triggers leave an
	// error event in the trace.
	common.Warn("webhook detached run start failed",
		zap.String("canvas", cv.ID),
		zap.Error(runErr))
	if !isTest {
		return
	}
	appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{
		Type: "error",
		Data: mustJSON(map[string]any{"message": runErr.Error()}),
	})
}
// webhookSyncResult is the outcome of a blocking webhook run: the HTTP
// status to send and the value to serialise as the JSON response body.
type webhookSyncResult struct {
	status int
	body   any
}

// runWebhookSync drives the canvas in the non-Immediately (streaming)
// mode and blocks until the event channel is closed. Mirrors
// agent_api.py:2178-2247 with the python sse() coroutine flattened into
// a synchronous loop.
//
//   - If the run cannot be started, the result is HTTP 400 with
//     {"code": 400, "message": <error>, "success": false}.
//   - Otherwise the content of every "message" event is concatenated
//     into "message", and a "message_end" event may override the
//     status (default 200), which is also echoed as "code".
//
// When isTest is true every event, plus a closing "finished" event, is
// appended to the webhook trace.
func (h *AgentHandler) runWebhookSync(
	ctx context.Context, cv *entity.UserCanvas, payload map[string]any,
	isTest bool, startTs time.Time,
) webhookSyncResult {
	status := http.StatusOK
	events, err := h.loader.RunAgentWithWebhook(ctx, cv.UserID, cv.ID, payload)
	if err != nil {
		if isTest {
			appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "error", Data: mustJSON(map[string]any{"message": err.Error()})})
			appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "finished", Data: mustJSON(map[string]any{"success": false})})
		}
		return webhookSyncResult{status: http.StatusBadRequest, body: gin.H{
			"code":    400,
			"message": err.Error(),
			"success": false,
		}}
	}

	var final strings.Builder
	for ev := range events {
		if isTest {
			appendWebhookTrace(cv.ID, startTs, ev)
		}
		switch ev.Type {
		case "message":
			var msg struct {
				Content      string `json:"content"`
				StartToThink bool   `json:"start_to_think"`
				EndToThink   bool   `json:"end_to_think"`
			}
			// Events whose data does not decode are skipped.
			if json.Unmarshal([]byte(ev.Data), &msg) != nil {
				continue
			}
			content := msg.Content
			// Think markers replace the content of their event.
			if msg.StartToThink {
				content = "think"
			} else if msg.EndToThink {
				content = "/think"
			}
			final.WriteString(content)
		case "message_end":
			var end struct {
				Status *int `json:"status"`
			}
			if json.Unmarshal([]byte(ev.Data), &end) != nil || end.Status == nil {
				continue
			}
			// The status ends up in ResponseWriter.WriteHeader, which
			// panics on codes that are not three digits. Ignore such
			// values and keep the previous status instead of crashing
			// the request after the canvas has already run.
			if *end.Status >= 100 && *end.Status <= 999 {
				status = *end.Status
			}
		}
	}

	if isTest {
		appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "finished", Data: mustJSON(map[string]any{"success": true})})
	}
	return webhookSyncResult{status: status, body: gin.H{
		"message": final.String(),
		"success": true,
		"code":    status,
	}}
}
// mustJSON encodes v with a json.Encoder and returns the result as a
// string, including the trailing newline the encoder appends. Despite
// the name it never panics: if v cannot be encoded the result is the
// empty string. Used by the trace appenders, which only pass plain
// map[string]any values.
func mustJSON(v any) string {
	out := new(bytes.Buffer)
	if err := json.NewEncoder(out).Encode(v); err != nil {
		// Encode writes nothing when marshalling fails.
		return ""
	}
	return out.String()
}
// appendWebhookTrace appends a single RunEvent to the per-canvas trace
// key in Redis. Mirrors python's append_webhook_trace at
// agent_api.py:2073-2091.
//
// The trace key is `webhook-trace-<agent_id>-logs` with a 600 s TTL.
// The stored document has the shape
// {"webhooks": {<start_ts>: {"start_ts": <float>, "events": [...]}}},
// and each event is recorded as {"ts": <float>, "event": <type>, ...}.
// The call is best-effort: it is a no-op when no redis client is
// available, and an unreadable or malformed existing document is
// replaced by a fresh one.
//
// NOTE(review): the update is a plain read-modify-write (Get followed
// by SetObj), so concurrent runs for the same agent can overwrite each
// other's events — confirm this is acceptable for trace data.
func appendWebhookTrace(agentID string, startTs time.Time, ev canvas.RunEvent) {
	rdb := rediscli.Get()
	if rdb == nil {
		return
	}
	key := fmt.Sprintf("webhook-trace-%s-logs", agentID)

	// Load the existing document, if any. A read error is treated like
	// a missing key.
	raw, _ := rdb.Get(key)
	obj := map[string]any{}
	if raw != "" {
		// A stored JSON `null` decodes without error but leaves the map
		// nil; writing to it below would panic, so start over instead.
		if err := json.Unmarshal([]byte(raw), &obj); err != nil || obj == nil {
			obj = map[string]any{}
		}
	}
	whs, _ := obj["webhooks"].(map[string]any)
	if whs == nil {
		whs = map[string]any{}
		obj["webhooks"] = whs
	}

	// One entry per webhook invocation, keyed by its start time in
	// fractional seconds.
	startSec := float64(startTs.UnixNano()) / 1e9
	entryKey := strconv.FormatFloat(startSec, 'f', -1, 64)
	entry, _ := whs[entryKey].(map[string]any)
	if entry == nil {
		entry = map[string]any{
			"start_ts": startSec,
			"events":   []any{},
		}
		whs[entryKey] = entry
	}
	events, _ := entry["events"].([]any)

	eventRecord := map[string]any{
		"ts":    float64(time.Now().UnixNano()) / 1e9,
		"event": ev.Type,
	}
	if ev.Data != "" {
		// Embed the payload verbatim when it is JSON. A payload that is
		// not valid JSON would make the Marshal below fail and drop the
		// whole event, so store it as a plain string instead.
		if json.Valid([]byte(ev.Data)) {
			eventRecord["data"] = json.RawMessage(ev.Data)
		} else {
			eventRecord["data"] = ev.Data
		}
	}
	if ev.MessageID != "" {
		eventRecord["message_id"] = ev.MessageID
	}
	if ev.TaskID != "" {
		eventRecord["task_id"] = ev.TaskID
	}
	if ev.SessionID != "" {
		eventRecord["session_id"] = ev.SessionID
	}
	entry["events"] = append(events, eventRecord)

	encoded, err := json.Marshal(obj)
	if err != nil {
		common.Warn("webhook trace marshal failed", zap.Error(err))
		return
	}
	rdb.SetObj(key, string(encoded), 600*time.Second)
}

View File

@@ -0,0 +1,273 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
// Webhook schema helpers.
//
// These mirror api/apps/restful_apis/agent_api.py:1896-2051 (extract_by_schema,
// default_for_type, auto_cast_value, validate_type) from the Python
// webhook handler. The Go port preserves the Python semantics exactly:
// required fields raise; optional fields get a type-based default;
// type coercion runs before validation.
//
// Edge cases preserved:
// - "file" type maps to []any (a list of parsed file descriptors).
// - "array" accepts a JSON-decoded list; "array<inner>" validates each
// element against `inner`.
// - "object" must decode to map[string]any.
// - Empty data sections yield an empty map (NOT nil) so downstream
// components can index without a nil check.
import (
"encoding/json"
"fmt"
"strconv"
"strings"
)
// defaultForType yields the value used for an optional schema field
// that is absent from the request. Mirrors python's `default_for_type`
// at agent_api.py:1939-1955.
//
//   - "string" → ""
//   - "number" → float64(0)
//   - "boolean" → false
//   - "object" → empty map[string]any
//   - "file" and anything starting with "array" → empty []any
//   - "null" and every other token → nil
func defaultForType(t string) any {
	switch {
	case t == "string":
		return ""
	case t == "number":
		return float64(0)
	case t == "boolean":
		return false
	case t == "object":
		return map[string]any{}
	case t == "file", strings.HasPrefix(t, "array"):
		return []any{}
	default:
		return nil
	}
}
// autoCastValue mirrors python's auto_cast_value at agent_api.py:1957-2016.
// Values that are not strings are returned unchanged. String values are
// trimmed and converted to the schema-expected type:
//
//   - "boolean": "true"/"1" and "false"/"0" (case-insensitive)
//   - "number": any integer or float literal, returned as float64
//   - "object": a JSON object, returned as map[string]any
//   - "array…": a JSON array, returned as []any
//   - anything else ("string", "file", unknown): the original,
//     untrimmed value
//
// An error is returned when the conversion is impossible (e.g. "abc" →
// number). The JSON literal `null` is rejected for object and array
// targets: it decodes without error but is neither an object nor a
// list, and accepting it would hand a nil map/slice to downstream code.
func autoCastValue(value any, expectedType string) (any, error) {
	text, isString := value.(string)
	if !isString {
		// The Python reference only converts when the source is a string.
		return value, nil
	}
	v := strings.TrimSpace(text)

	switch expectedType {
	case "boolean":
		switch strings.ToLower(v) {
		case "true", "1":
			return true, nil
		case "false", "0":
			return false, nil
		}
		return nil, fmt.Errorf("cannot convert %q to boolean", text)

	case "number":
		// Integers first, as the python reference does. This also keeps
		// "-0" a positive zero, which ParseFloat alone would not.
		if i, err := strconv.ParseInt(v, 10, 64); err == nil {
			return float64(i), nil
		}
		f, err := strconv.ParseFloat(v, 64)
		if err != nil {
			return nil, fmt.Errorf("cannot convert %q to number", text)
		}
		return f, nil

	case "object":
		var parsed map[string]any
		if err := json.Unmarshal([]byte(v), &parsed); err != nil || parsed == nil {
			return nil, fmt.Errorf("cannot convert %q to object", text)
		}
		return parsed, nil
	}

	if strings.HasPrefix(expectedType, "array") {
		var parsed []any
		if err := json.Unmarshal([]byte(v), &parsed); err != nil || parsed == nil {
			return nil, fmt.Errorf("cannot convert %q to array", text)
		}
		return parsed, nil
	}

	// "string" / "file" / unknown — return as-is (matching the python
	// passthrough branches at the bottom of auto_cast_value).
	return value, nil
}
// validateType reports whether value is structurally compatible with
// the schema type token t. Mirrors python's validate_type at
// agent_api.py:2019-2051.
//
//   - "" and unknown tokens accept everything (python parity at line 2051).
//   - "string", "boolean", "object" require string, bool, map[string]any.
//   - "number" accepts int, int32, int64, float32 and float64.
//   - "file" requires []any; "null" requires a nil value.
//   - tokens starting with "array" require []any; when the token has
//     the form array<inner>, every element must also satisfy inner
//     (the text between the first "<" and the first ">" after it).
func validateType(value any, t string) bool {
	switch t {
	case "":
		return true
	case "string":
		_, isString := value.(string)
		return isString
	case "boolean":
		_, isBool := value.(bool)
		return isBool
	case "object":
		_, isObject := value.(map[string]any)
		return isObject
	case "file":
		_, isList := value.([]any)
		return isList
	case "null":
		return value == nil
	case "number":
		switch value.(type) {
		case int, int32, int64, float32, float64:
			return true
		default:
			return false
		}
	}

	if !strings.HasPrefix(t, "array") {
		// Unknown type — pass through.
		return true
	}
	items, isList := value.([]any)
	if !isList {
		return false
	}
	// Element validation only applies when both brackets are present.
	_, afterOpen, hasOpen := strings.Cut(t, "<")
	if !hasOpen {
		return true
	}
	inner, _, hasClose := strings.Cut(afterOpen, ">")
	if !hasClose {
		return true
	}
	for _, item := range items {
		if !validateType(item, inner) {
			return false
		}
	}
	return true
}
// extractBySchema mirrors python's extract_by_schema at agent_api.py:1896-1936.
//
// It walks schema["properties"] and builds a cleaned copy of data:
//   - a field listed in schema["required"] but absent from data is an error;
//   - an absent optional field receives defaultForType(<field type>);
//   - a present field is passed through autoCastValue and then checked
//     with validateType; a failure of either step is an error.
//
// name is only used as a prefix in error messages. Both data and schema may
// be nil (reads from a nil map simply find nothing); the returned map is
// never nil on success.
func extractBySchema(data map[string]any, schema map[string]any, name string) (map[string]any, error) {
	properties, _ := schema["properties"].(map[string]any)
	requiredFields := stringSlice(schema["required"])

	cleaned := map[string]any{}
	for field, rawSchema := range properties {
		// A malformed per-field schema degrades to an empty type string.
		var fieldType string
		if spec, ok := rawSchema.(map[string]any); ok {
			fieldType, _ = spec["type"].(string)
		}

		raw, present := data[field]
		if !present {
			// Required field missing → error (python agent_api.py:1913).
			if contains(requiredFields, field) {
				return nil, fmt.Errorf("%s missing required field: %s", name, field)
			}
			// Optional field missing → type default.
			cleaned[field] = defaultForType(fieldType)
			continue
		}

		// Coerce first (string → typed value), then validate the result.
		casted, err := autoCastValue(raw, fieldType)
		if err != nil {
			return nil, fmt.Errorf("%s.%s auto-cast failed: %s", name, field, err.Error())
		}
		if !validateType(casted, fieldType) {
			return nil, fmt.Errorf(
				"%s.%s type mismatch: expected %s, got %T",
				name, field, fieldType, casted,
			)
		}
		cleaned[field] = casted
	}
	return cleaned, nil
}
// stringSlice normalises the schema's `required` value into a []string.
//
// A []string is returned unchanged. A []any is filtered down to its string
// elements (non-string items are dropped) and always yields a non-nil
// slice. Any other input, including nil, yields nil.
func stringSlice(v any) []string {
	if strs, ok := v.([]string); ok {
		return strs
	}
	items, ok := v.([]any)
	if !ok {
		return nil
	}
	out := make([]string, 0, len(items))
	for i := range items {
		str, isString := items[i].(string)
		if !isString {
			continue
		}
		out = append(out, str)
	}
	return out
}
// contains reports whether needle occurs in haystack. A nil or empty
// haystack never contains anything.
func contains(haystack []string, needle string) bool {
	found := false
	for i := 0; i < len(haystack) && !found; i++ {
		found = haystack[i] == needle
	}
	return found
}

View File

@@ -0,0 +1,242 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"strings"
"testing"
)
// TestExtractBySchema_RequiredMissing verifies that a field named in
// `required` but absent from the input produces the
// "missing required field" error (mirrors python agent_api.py:1913).
func TestExtractBySchema_RequiredMissing(t *testing.T) {
	props := map[string]any{"q": map[string]any{"type": "string"}}
	schema := map[string]any{"properties": props, "required": []any{"q"}}

	_, err := extractBySchema(map[string]any{}, schema, "query")
	if err != nil && strings.Contains(err.Error(), "missing required field") {
		return
	}
	t.Errorf("err = %v, want missing required field", err)
}
// TestExtractBySchema_OptionalDefault checks that each optional field
// missing from the input is populated with the default for its type.
func TestExtractBySchema_OptionalDefault(t *testing.T) {
	fieldTypes := map[string]string{
		"s": "string",
		"b": "boolean",
		"o": "object",
		"a": "array",
		"f": "file",
		"n": "number",
	}
	props := make(map[string]any, len(fieldTypes))
	for field, typ := range fieldTypes {
		props[field] = map[string]any{"type": typ}
	}

	got, err := extractBySchema(map[string]any{}, map[string]any{"properties": props}, "query")
	if err != nil {
		t.Fatalf("extractBySchema: %v", err)
	}

	if s, ok := got["s"].(string); !ok || s != "" {
		t.Errorf("default for string: got %v", got["s"])
	}
	if b, ok := got["b"].(bool); !ok || b {
		t.Errorf("default for boolean: got %v", got["b"])
	}
	if _, isMap := got["o"].(map[string]any); !isMap {
		t.Errorf("default for object: got %T", got["o"])
	}
	if _, isList := got["a"].([]any); !isList {
		t.Errorf("default for array: got %T", got["a"])
	}
	if _, isList := got["f"].([]any); !isList {
		t.Errorf("default for file: got %T", got["f"])
	}
	if n, ok := got["n"].(float64); !ok || n != 0 {
		t.Errorf("default for number: got %v", got["n"])
	}
}
// TestExtractBySchema_NumberCoercion checks that an integer-looking string
// supplied for a "number" field is coerced to float64.
func TestExtractBySchema_NumberCoercion(t *testing.T) {
	schema := map[string]any{
		"properties": map[string]any{"n": map[string]any{"type": "number"}},
	}
	input := map[string]any{"n": "42"}

	got, err := extractBySchema(input, schema, "body")
	if err != nil {
		t.Fatalf("extractBySchema: %v", err)
	}
	if n := got["n"]; n != float64(42) {
		t.Errorf("number coercion: got %v (%T)", n, n)
	}
}
// TestExtractBySchema_NumberCoercionFloat checks that a decimal string
// supplied for a "number" field is coerced to the matching float64.
func TestExtractBySchema_NumberCoercionFloat(t *testing.T) {
	schema := map[string]any{
		"properties": map[string]any{"n": map[string]any{"type": "number"}},
	}
	input := map[string]any{"n": "3.14"}

	got, err := extractBySchema(input, schema, "body")
	if err != nil {
		t.Fatalf("extractBySchema: %v", err)
	}
	if n := got["n"]; n != 3.14 {
		t.Errorf("number float coercion: got %v", n)
	}
}
// TestExtractBySchema_BooleanCoercion checks the string spellings that a
// "boolean" field accepts and the value each one is coerced to.
func TestExtractBySchema_BooleanCoercion(t *testing.T) {
	type coercion struct {
		input    string
		expected bool
	}
	schema := map[string]any{
		"properties": map[string]any{"b": map[string]any{"type": "boolean"}},
	}

	for _, tc := range []coercion{
		{input: "true", expected: true},
		{input: "TRUE", expected: true},
		{input: "1", expected: true},
		{input: "false", expected: false},
		{input: "0", expected: false},
	} {
		got, err := extractBySchema(map[string]any{"b": tc.input}, schema, "body")
		if err != nil {
			t.Errorf("coerce %q: err=%v", tc.input, err)
		} else if got["b"] != tc.expected {
			t.Errorf("coerce %q: got %v, want %v", tc.input, got["b"], tc.expected)
		}
	}
}
// TestExtractBySchema_BooleanCoercionBad checks that a string which is not
// a recognised boolean spelling fails with an "auto-cast failed" error.
func TestExtractBySchema_BooleanCoercionBad(t *testing.T) {
	schema := map[string]any{
		"properties": map[string]any{"b": map[string]any{"type": "boolean"}},
	}

	_, err := extractBySchema(map[string]any{"b": "maybe"}, schema, "body")
	if err != nil && strings.Contains(err.Error(), "auto-cast failed") {
		return
	}
	t.Errorf("err = %v, want auto-cast failed", err)
}
// TestExtractBySchema_TypeMismatch checks that a value of the wrong shape
// (a list supplied for a "number" field) is rejected with "type mismatch".
func TestExtractBySchema_TypeMismatch(t *testing.T) {
	schema := map[string]any{
		"properties": map[string]any{"n": map[string]any{"type": "number"}},
	}
	wrongShape := map[string]any{"n": []any{1}}

	_, err := extractBySchema(wrongShape, schema, "body")
	if err != nil && strings.Contains(err.Error(), "type mismatch") {
		return
	}
	t.Errorf("err = %v, want type mismatch", err)
}
// TestExtractBySchema_ArrayInnerType checks element-level validation for
// an "array<number>" field: an all-numeric list passes, a list containing
// a string is rejected with "type mismatch".
func TestExtractBySchema_ArrayInnerType(t *testing.T) {
	schema := map[string]any{
		"properties": map[string]any{"xs": map[string]any{"type": "array<number>"}},
	}

	// All elements numeric → accepted.
	allNumbers := map[string]any{"xs": []any{1, 2, 3}}
	if _, err := extractBySchema(allNumbers, schema, "body"); err != nil {
		t.Errorf("valid array: err=%v", err)
	}

	// One element is a string → rejected.
	mixed := map[string]any{"xs": []any{1, "two", 3}}
	if _, err := extractBySchema(mixed, schema, "body"); err == nil || !strings.Contains(err.Error(), "type mismatch") {
		t.Errorf("invalid array inner: err=%v, want type mismatch", err)
	}
}
// TestExtractBySchema_ObjectCoercion checks that a JSON string supplied
// for an "object" field is decoded into a map.
func TestExtractBySchema_ObjectCoercion(t *testing.T) {
	schema := map[string]any{
		"properties": map[string]any{"o": map[string]any{"type": "object"}},
	}
	input := map[string]any{"o": `{"k":"v"}`}

	got, err := extractBySchema(input, schema, "body")
	if err != nil {
		t.Fatalf("extractBySchema: %v", err)
	}
	if decoded, isMap := got["o"].(map[string]any); !isMap || decoded["k"] != "v" {
		t.Errorf("object coercion: got %v", got["o"])
	}
}
// TestExtractBySchema_NilData checks that nil data and a nil schema yield
// an empty, non-nil result map and no error.
func TestExtractBySchema_NilData(t *testing.T) {
	got, err := extractBySchema(nil, nil, "query")
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	switch {
	case got == nil:
		t.Errorf("expected empty map, got nil")
	case len(got) != 0:
		t.Errorf("expected empty map, got %v", got)
	}
}
// TestValidateType_FileAndArray exercises validateType directly for the
// "file" type and for element validation of "array<number>".
func TestValidateType_FileAndArray(t *testing.T) {
	checks := []struct {
		value   any
		typ     string
		want    bool
		failure string
	}{
		{[]any{map[string]any{}}, "file", true, "file should accept []any"},
		{[]any{1, 2, 3}, "array<number>", true, "array<number> should accept []int-equivalent"},
		{[]any{1, "x"}, "array<number>", false, "array<number> should reject mixed types"},
	}
	for _, check := range checks {
		if validateType(check.value, check.typ) != check.want {
			t.Error(check.failure)
		}
	}
}
// TestDefaultForType calls defaultForType for every type name in the
// python default_for_type table plus an unknown one. It is a smoke test
// only: the returned values are discarded, so the test fails solely if a
// call panics.
func TestDefaultForType(t *testing.T) {
	typeNames := []string{
		"file", "object", "boolean", "number", "string",
		"array", "array<object>", "null", "unknown",
	}
	for i := range typeNames {
		// No assertion on the value — only that the call completes.
		_ = defaultForType(typeNames[i])
	}
}

View File

@@ -0,0 +1,491 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
// Webhook security helpers.
//
// These mirror api/apps/restful_apis/agent_api.py:1602-1810
// (validate_webhook_security + the six sub-validators) from the Python
// webhook handler. The Go port preserves the Python semantics exactly:
//
// - max body size (with the 10 MB cap at agent_api.py:1652)
// - IP whitelist (CIDR + exact match)
// - rate limit (token bucket via redis.EvalTokenBucketStrict — strict
// fail-closed; see redis.go)
// - token auth (header check)
// - basic auth (HTTP Basic)
// - JWT (HS256/384/512, RS256/384/512, ES256/384/512; audience/issuer/required-claims)
//
// All helpers return a Go error so the handler can surface the
// Python-shaped 102 envelope directly. Empty security config means
// "no checks" (matches agent_api.py:1607).
import (
	"context"
	"crypto/subtle"
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"

	rediscli "ragflow/internal/engine/redis"
)
const (
	// bytesPerKB and bytesPerMB are SI decimal units: 1 KB = 1 000 B and
	// 1 MB = 1 000 000 B. This intentionally differs from the python
	// reference (agent_api.py:1643), which uses binary units
	// (1 KB = 1 024 B); the doc comment on parseMaxBodySize records the
	// same divergence.
	bytesPerKB int64 = 1_000
	bytesPerMB int64 = 1_000 * bytesPerKB

	// webhookBodyMaxBytes is the hard ceiling for a configured
	// max_body_size; larger configured values are rejected when the
	// config is parsed. It is 10 decimal MB = 10 000 000 bytes, which is
	// 485 760 bytes below the python reference's 10 * 1024 * 1024
	// (agent_api.py:1652).
	webhookBodyMaxBytes int64 = bytesPerMB * 10

	// webhookRateLimitTimeout bounds the redis call made for the strict
	// token-bucket lookup. It is deliberately short so the security gate
	// cannot stall a request.
	webhookRateLimitTimeout = time.Second / 2

	// jwtReservedClaims is the comma-separated set of standard claims
	// that required_claims must NOT name (python parity:
	// agent_api.py:1800).
	jwtReservedClaims = "exp,sub,aud,iss,nbf,iat"
)
// validateWebhookSecurity runs every webhook security check for one
// request and returns the first failure, or nil when all checks pass.
// An empty or nil securityCfg means "no checks" (matches
// agent_api.py:1607).
//
// Checks run in the python-defined order:
//  1. validateMaxBodySize
//  2. validateIPWhitelist
//  3. validateRateLimit (keyed by canvasID)
//  4. validateAuth (dispatches on auth_type)
func validateWebhookSecurity(
	securityCfg map[string]any,
	c *gin.Context,
	canvasID string,
) error {
	if len(securityCfg) == 0 {
		return nil
	}
	checks := []func() error{
		func() error { return validateMaxBodySize(c, securityCfg) },
		func() error { return validateIPWhitelist(c, securityCfg) },
		func() error { return validateRateLimit(canvasID, securityCfg) },
		func() error { return validateAuth(c, securityCfg) },
	}
	// Stop at the first failing check; later checks are not evaluated.
	for _, check := range checks {
		if err := check(); err != nil {
			return err
		}
	}
	return nil
}
// validateMaxBodySize mirrors python agent_api.py:1636-1658.
//
// The limit comes from parseMaxBodySize ("<n>kb" | "<n>mb",
// case-insensitive, capped at webhookBodyMaxBytes); a malformed or
// over-cap setting is returned as a configuration error. When no limit is
// configured the request is allowed. Otherwise the request's declared
// Content-Length is compared with the limit; a negative (unknown) length
// is treated as 0.
func validateMaxBodySize(c *gin.Context, cfg map[string]any) error {
	limit, err := parseMaxBodySize(cfg)
	switch {
	case err != nil:
		return err
	case limit <= 0:
		// No limit configured.
		return nil
	}

	declared := c.Request.ContentLength
	if declared < 0 {
		declared = 0
	}
	if declared <= limit {
		return nil
	}
	return fmt.Errorf("request body too large: %d > %d", declared, limit)
}
// parseMaxBodySize returns the byte limit configured by
// `max_body_size` (with the 10 MB cap enforced) or 0 when no limit
// is configured. The handler uses this to wrap c.Request.Body in
// http.MaxBytesReader so the actual stream read is bounded — the
// Content-Length header check alone is insufficient because a client
// can advertise a small Content-Length and stream more.
//
// Unit base: 1 kb = 1 000 B, 1 mb = 1 000 000 B (SI decimal). Per
// user request; note this DIVERGES from the python reference
// (agent_api.py:1643 uses binary 1 KB = 1 024 B).
//
// Overflow guard (CodeRabbit PR review #3 on PR #16403): a very
// large configured n (e.g. "10000000mb") would otherwise overflow
// `n * bytesPerMB` and wrap to a small positive number, bypassing
// the 10 MB cap. We check the parsed number against the cap before
// multiplying.
func parseMaxBodySize(cfg map[string]any) (int64, error) {
raw, ok := cfg["max_body_size"].(string)
if !ok || raw == "" {
return 0, nil
}
sizeStr := strings.ToLower(strings.TrimSpace(raw))
var limit int64
switch {
case strings.HasSuffix(sizeStr, "kb"):
n, err := strconv.ParseInt(strings.TrimSuffix(sizeStr, "kb"), 10, 64)
if err != nil || n <= 0 {
return 0, fmt.Errorf("invalid max_body_size format")
}
if n > webhookBodyMaxBytes/bytesPerKB {
return 0, fmt.Errorf("max_body_size exceeds maximum allowed size (10MB)")
}
limit = n * bytesPerKB
case strings.HasSuffix(sizeStr, "mb"):
n, err := strconv.ParseInt(strings.TrimSuffix(sizeStr, "mb"), 10, 64)
if err != nil || n <= 0 {
return 0, fmt.Errorf("invalid max_body_size format")
}
if n > webhookBodyMaxBytes/bytesPerMB {
return 0, fmt.Errorf("max_body_size exceeds maximum allowed size (10MB)")
}
limit = n * bytesPerMB
default:
return 0, fmt.Errorf("invalid max_body_size format")
}
if limit > webhookBodyMaxBytes {
return 0, fmt.Errorf("max_body_size exceeds maximum allowed size (10MB)")
}
return limit, nil
}
// validateIPWhitelist mirrors python agent_api.py:1660-1679.
//
// Configuration handling for `ip_whitelist`:
//   - key absent, nil, or an empty list → allow every client;
//   - a list ([]any or []string) → the client must match at least one
//     rule; non-string and blank entries are ignored;
//   - any other non-nil value is a malformed whitelist: it matches
//     nothing, so the request is rejected rather than silently allowed.
//
// Rules are either CIDR blocks ("10.0.0.0/8") or single addresses
// ("1.2.3.4"). Unparseable CIDR rules are skipped. Single addresses match
// on exact text or, when both sides parse as IPs, on net.IP.Equal, so an
// address written in a different but equivalent form still matches.
//
// The client IP comes from gin's c.ClientIP(), which honours
// X-Forwarded-For when trusted proxies are configured.
func validateIPWhitelist(c *gin.Context, cfg map[string]any) error {
	rawList, present := cfg["ip_whitelist"]
	if !present || rawList == nil {
		return nil
	}

	var rules []string
	switch list := rawList.(type) {
	case []any:
		if len(list) == 0 {
			return nil
		}
		rules = make([]string, 0, len(list))
		for _, item := range list {
			if rule, ok := item.(string); ok {
				rules = append(rules, rule)
			}
		}
	case []string:
		if len(list) == 0 {
			return nil
		}
		rules = list
	default:
		// Malformed whitelist: leave rules empty so the request is denied.
	}

	clientIP := c.ClientIP()
	// Parsed once; nil when the client address is not a valid IP literal.
	clientAddr := net.ParseIP(clientIP)

	for _, rule := range rules {
		rule = strings.TrimSpace(rule)
		if rule == "" {
			continue
		}
		if strings.Contains(rule, "/") {
			_, ipNet, err := net.ParseCIDR(rule)
			if err == nil && clientAddr != nil && ipNet.Contains(clientAddr) {
				return nil
			}
			continue
		}
		if rule == clientIP {
			return nil
		}
		if clientAddr == nil {
			continue
		}
		if ruleAddr := net.ParseIP(rule); ruleAddr != nil && ruleAddr.Equal(clientAddr) {
			return nil
		}
	}
	return fmt.Errorf("IP %s is not allowed by whitelist", clientIP)
}
// validateRateLimit mirrors python agent_api.py:1681-1723.
//
// Window mapping (matches agent_api.py:1692-1697):
//
// second → 1s, minute → 60s, hour → 3600s, day → 86400s.
//
// Unknown per → error (NOT silently fall through).
//
// Strict fail-closed: any Redis error → error. The webhook handler
// surfaces this as 102 so an operator notices a misconfiguration.
func validateRateLimit(canvasID string, cfg map[string]any) error {
rawRL, ok := cfg["rate_limit"].(map[string]any)
if !ok || len(rawRL) == 0 {
return nil
}
limitF, ok := rawRL["limit"].(float64)
if !ok {
// JSON numbers often come back as float64; try int as well.
if limitI, ok2 := rawRL["limit"].(int); ok2 {
limitF = float64(limitI)
ok = true
}
}
if !ok || limitF <= 0 {
return fmt.Errorf("rate_limit.limit must be > 0")
}
per, _ := rawRL["per"].(string)
if per == "" {
per = "minute"
}
var window float64
switch per {
case "second":
window = 1
case "minute":
window = 60
case "hour":
window = 3600
case "day":
window = 86400
default:
return fmt.Errorf("invalid rate_limit.per: %s", per)
}
key := fmt.Sprintf("rl:tb:%s", canvasID)
ctx, cancel := context.WithTimeout(context.Background(), webhookRateLimitTimeout)
defer cancel()
rdb := rediscli.Get()
if rdb == nil {
return fmt.Errorf("rate limit error: redis not initialised")
}
allowed, err := rdb.EvalTokenBucketStrict(ctx, key, limitF, limitF/window)
if err != nil {
return fmt.Errorf("rate limit error: %s", err.Error())
}
if !allowed {
return fmt.Errorf("too many requests (rate limit exceeded)")
}
return nil
}
// validateAuth selects the authentication check named by cfg["auth_type"].
// A missing/empty auth_type or "none" allows the request (matches
// agent_api.py:1621); "token", "basic" and "jwt" delegate to their
// validators; anything else is an "unsupported auth_type" error.
func validateAuth(c *gin.Context, cfg map[string]any) error {
	authType, _ := cfg["auth_type"].(string)
	switch authType {
	case "", "none":
		return nil
	case "token":
		return validateTokenAuth(c, cfg)
	case "basic":
		return validateBasicAuth(c, cfg)
	case "jwt":
		return validateJWTAuth(c, cfg)
	default:
		return fmt.Errorf("unsupported auth_type: %s", authType)
	}
}
// validateTokenAuth mirrors python agent_api.py:1725-1733.
//
// cfg["token"] must be a map holding non-empty `token_header` and
// `token_value` strings; a missing map or an empty header/value is
// treated as a misconfiguration and the request is rejected (CodeRabbit
// PR review #4). The request passes only when the named header carries
// exactly the configured value.
//
// The comparison uses subtle.ConstantTimeCompare so the time taken does
// not depend on how many leading bytes of the presented token are
// correct. Every failure returns the same error text so callers cannot
// distinguish a misconfiguration from a wrong token.
func validateTokenAuth(c *gin.Context, cfg map[string]any) error {
	// Indexing a nil map is safe, so a missing "token" section simply
	// yields empty header/value and is rejected below.
	tokenCfg, _ := cfg["token"].(map[string]any)
	header, _ := tokenCfg["token_header"].(string)
	want, _ := tokenCfg["token_value"].(string)
	if header == "" || want == "" {
		return fmt.Errorf("Invalid token authentication")
	}
	got := c.GetHeader(header)
	if subtle.ConstantTimeCompare([]byte(got), []byte(want)) != 1 {
		return fmt.Errorf("Invalid token authentication")
	}
	return nil
}
// validateBasicAuth mirrors python agent_api.py:1735-1743.
//
// cfg["basic_auth"] must be a map holding non-empty `username` and
// `password` strings; a missing map or an empty credential is treated as
// a misconfiguration and the request is rejected (CodeRabbit PR review
// #4). The presented credentials are read with net/http's
// Request.BasicAuth, which parses the Authorization header.
//
// Both fields are compared with subtle.ConstantTimeCompare and both
// comparisons always run, so timing does not reveal which field — or how
// much of it — was wrong. Every failure returns the same error text.
func validateBasicAuth(c *gin.Context, cfg map[string]any) error {
	// Indexing a nil map is safe, so a missing "basic_auth" section
	// yields empty credentials and is rejected below.
	basicCfg, _ := cfg["basic_auth"].(map[string]any)
	username, _ := basicCfg["username"].(string)
	password, _ := basicCfg["password"].(string)
	if username == "" || password == "" {
		return fmt.Errorf("Invalid Basic Auth credentials")
	}
	u, p, ok := c.Request.BasicAuth()
	userMatch := subtle.ConstantTimeCompare([]byte(u), []byte(username))
	passMatch := subtle.ConstantTimeCompare([]byte(p), []byte(password))
	if !ok || userMatch&passMatch != 1 {
		return fmt.Errorf("Invalid Basic Auth credentials")
	}
	return nil
}
// validateJWTAuth mirrors python agent_api.py:1745-1809.
//
// It reads the "jwt" section of cfg, extracts the bearer token from the
// Authorization header, and verifies it with golang-jwt:
//
//   - `secret` is mandatory; a missing section or empty secret is an error.
//   - `algorithm` defaults to HS256 and is upper-cased; only that single
//     algorithm is accepted by the parser.
//   - `audience` / `issuer` are enforced only when configured as
//     non-empty strings.
//   - `required_claims` (string or list) must not name a reserved claim
//     (exp sub aud iss nbf iat) and each named claim must be present.
//
// Key material, by algorithm family (resolved by jwtKeyFunc):
//   - HS256 / HS384 / HS512 → secret is a shared HMAC key.
//   - RS256 / RS384 / RS512 → secret is a PEM-encoded RSA public key.
//   - ES256 / ES384 / ES512 → secret is a PEM-encoded EC public key.
//
// All three families share the one `secret` config slot, following the
// python reference.
func validateJWTAuth(c *gin.Context, cfg map[string]any) error {
	// Indexing a nil map is safe: a missing "jwt" section yields an
	// empty secret and the same error as an unset secret.
	jwtCfg, _ := cfg["jwt"].(map[string]any)
	secret, _ := jwtCfg["secret"].(string)
	if secret == "" {
		return fmt.Errorf("jwt secret not configured")
	}

	const bearerPrefix = "Bearer "
	header := c.GetHeader("Authorization")
	if !strings.HasPrefix(header, bearerPrefix) {
		return fmt.Errorf("missing bearer token")
	}
	rawToken := strings.TrimSpace(strings.TrimPrefix(header, bearerPrefix))
	if rawToken == "" {
		return fmt.Errorf("empty bearer token")
	}

	algorithm := "HS256"
	if configured, _ := jwtCfg["algorithm"].(string); configured != "" {
		algorithm = strings.ToUpper(configured)
	}

	// Parser options: pin the algorithm, then add optional constraints.
	opts := []jwt.ParserOption{jwt.WithValidMethods([]string{algorithm})}
	if audience, _ := jwtCfg["audience"].(string); audience != "" {
		opts = append(opts, jwt.WithAudience(audience))
	}
	if issuer, _ := jwtCfg["issuer"].(string); issuer != "" {
		opts = append(opts, jwt.WithIssuer(issuer))
	}

	keyFunc, err := jwtKeyFunc(algorithm, secret)
	if err != nil {
		return err
	}
	parsed, err := jwt.Parse(rawToken, keyFunc, opts...)
	if err != nil {
		return fmt.Errorf("invalid jwt: %s", err.Error())
	}
	if !parsed.Valid {
		return fmt.Errorf("invalid jwt")
	}
	claims, isMap := parsed.Claims.(jwt.MapClaims)
	if !isMap {
		return fmt.Errorf("invalid jwt claims")
	}

	// Required claims validation (mirrors agent_api.py:1787-1808).
	reserved := splitCSV(jwtReservedClaims)
	for _, claim := range collectStringSlice(jwtCfg["required_claims"]) {
		if contains(reserved, claim) {
			return fmt.Errorf("reserved jwt claim cannot be required: %s", claim)
		}
		if _, present := claims[claim]; !present {
			return fmt.Errorf("missing jwt claim: %s", claim)
		}
	}
	return nil
}
// jwtKeyFunc builds the key-lookup callback handed to jwt.Parse. The
// callback ignores the token and always returns one fixed key chosen by
// the algorithm family:
//
//   - HS256/384/512 → the secret's raw bytes (HMAC key);
//   - RS256/384/512 → the RSA public key parsed from the secret as PEM;
//   - ES256/384/512 → the EC public key parsed from the secret as PEM.
//
// A PEM parse failure or an algorithm outside these families is returned
// as an error and no callback is produced.
func jwtKeyFunc(alg, secret string) (jwt.Keyfunc, error) {
	// fixedKey wraps an already-resolved key in the Keyfunc signature.
	fixedKey := func(key any) jwt.Keyfunc {
		return func(*jwt.Token) (any, error) { return key, nil }
	}
	material := []byte(secret)

	switch alg {
	case "HS256", "HS384", "HS512":
		return fixedKey(material), nil
	case "RS256", "RS384", "RS512":
		rsaPub, err := jwt.ParseRSAPublicKeyFromPEM(material)
		if err != nil {
			return nil, fmt.Errorf("jwt rsa public key: %s", err.Error())
		}
		return fixedKey(rsaPub), nil
	case "ES256", "ES384", "ES512":
		ecPub, err := jwt.ParseECPublicKeyFromPEM(material)
		if err != nil {
			return nil, fmt.Errorf("jwt ec public key: %s", err.Error())
		}
		return fixedKey(ecPub), nil
	default:
		return nil, fmt.Errorf("unsupported jwt algorithm: %s", alg)
	}
}
// collectStringSlice normalises the `required_claims` setting, which may
// be a single string or a list, into a slice of trimmed, non-empty
// strings (mirrors agent_api.py:1788-1798).
//
//   - string → one-element slice, or nil when blank;
//   - []string / []any → each string element trimmed, blanks and
//     non-string items dropped (the result is non-nil, possibly empty);
//   - anything else → nil.
func collectStringSlice(v any) []string {
	var candidates []string
	switch typed := v.(type) {
	case string:
		if single := strings.TrimSpace(typed); single != "" {
			return []string{single}
		}
		return nil
	case []string:
		candidates = typed
	case []any:
		candidates = make([]string, 0, len(typed))
		for _, item := range typed {
			if s, isString := item.(string); isString {
				candidates = append(candidates, s)
			}
		}
	default:
		return nil
	}

	out := make([]string, 0, len(candidates))
	for _, candidate := range candidates {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			out = append(out, trimmed)
		}
	}
	return out
}
// splitCSV splits s on commas, trims whitespace around every field and
// drops fields that are empty after trimming. The result is never nil.
func splitCSV(s string) []string {
	out := make([]string, 0, strings.Count(s, ",")+1)
	rest, more := s, true
	for more {
		var field string
		// Peel off one comma-delimited field per iteration.
		field, rest, more = strings.Cut(rest, ",")
		if field = strings.TrimSpace(field); field != "" {
			out = append(out, field)
		}
	}
	return out
}

View File

@@ -0,0 +1,380 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
)
// securityCtx builds a gin test context around a POST request to the
// webhook route with a "{}" JSON body. remoteAddr becomes the request's
// RemoteAddr and every entry of headers is set on the request.
func securityCtx(t *testing.T, remoteAddr string, headers map[string]string) *gin.Context {
	t.Helper()
	gin.SetMode(gin.TestMode)

	req := httptest.NewRequest("POST", "/api/v1/agents/c1/webhook", strings.NewReader("{}"))
	req.Header.Set("Content-Type", "application/json")
	req.RemoteAddr = remoteAddr
	for name, value := range headers {
		req.Header.Set(name, value)
	}

	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	c.Request = req
	return c
}
// TestValidateMaxBodySize_Allowed checks that a request whose
// Content-Length (100) is below the configured "1kb" limit is accepted.
func TestValidateMaxBodySize_Allowed(t *testing.T) {
	cfg := map[string]any{"max_body_size": "1kb"}
	c := securityCtx(t, "1.2.3.4:0", nil)
	c.Request.ContentLength = 100

	err := validateMaxBodySize(c, cfg)
	if err != nil {
		t.Errorf("err = %v, want nil", err)
	}
}
// TestValidateMaxBodySize_TooLarge checks that a Content-Length (2048)
// above the configured "1kb" limit is rejected as too large.
func TestValidateMaxBodySize_TooLarge(t *testing.T) {
	cfg := map[string]any{"max_body_size": "1kb"}
	c := securityCtx(t, "1.2.3.4:0", nil)
	c.Request.ContentLength = 2048

	err := validateMaxBodySize(c, cfg)
	if err != nil && strings.Contains(err.Error(), "request body too large") {
		return
	}
	t.Errorf("err = %v, want 'request body too large'", err)
}
// TestValidateMaxBodySize_BadFormat checks that a limit with an
// unsupported unit suffix ("1gb") is reported as an invalid format.
func TestValidateMaxBodySize_BadFormat(t *testing.T) {
	cfg := map[string]any{"max_body_size": "1gb"}
	c := securityCtx(t, "1.2.3.4:0", nil)

	err := validateMaxBodySize(c, cfg)
	if err != nil && strings.Contains(err.Error(), "invalid max_body_size format") {
		return
	}
	t.Errorf("err = %v, want 'invalid max_body_size format'", err)
}
// TestValidateIPWhitelist_EmptyIsAllow checks that an empty whitelist
// imposes no restriction.
func TestValidateIPWhitelist_EmptyIsAllow(t *testing.T) {
	cfg := map[string]any{"ip_whitelist": []any{}}
	c := securityCtx(t, "1.2.3.4:0", nil)

	err := validateIPWhitelist(c, cfg)
	if err != nil {
		t.Errorf("empty whitelist: err = %v, want nil", err)
	}
}
// TestValidateIPWhitelist_ExactMatch checks that a client whose address
// is listed verbatim in the whitelist is accepted.
func TestValidateIPWhitelist_ExactMatch(t *testing.T) {
	cfg := map[string]any{"ip_whitelist": []any{"10.0.0.5"}}
	c := securityCtx(t, "10.0.0.5:0", nil)

	err := validateIPWhitelist(c, cfg)
	if err != nil {
		t.Errorf("exact match: err = %v, want nil", err)
	}
}
// TestValidateIPWhitelist_CIDR checks that a client inside a whitelisted
// CIDR block is accepted.
func TestValidateIPWhitelist_CIDR(t *testing.T) {
	cfg := map[string]any{"ip_whitelist": []any{"10.0.0.0/8"}}
	c := securityCtx(t, "10.0.0.5:0", nil)

	err := validateIPWhitelist(c, cfg)
	if err != nil {
		t.Errorf("cidr match: err = %v, want nil", err)
	}
}
// TestValidateIPWhitelist_RejectForeign checks that a client outside
// every whitelisted range is denied.
func TestValidateIPWhitelist_RejectForeign(t *testing.T) {
	cfg := map[string]any{"ip_whitelist": []any{"10.0.0.0/8"}}
	c := securityCtx(t, "192.168.1.5:0", nil)

	err := validateIPWhitelist(c, cfg)
	if err != nil && strings.Contains(err.Error(), "not allowed by whitelist") {
		return
	}
	t.Errorf("err = %v, want 'not allowed by whitelist'", err)
}
// TestValidateAuth_NoneIsAllow checks that auth_type "none" performs no
// authentication and allows the request.
func TestValidateAuth_NoneIsAllow(t *testing.T) {
	cfg := map[string]any{"auth_type": "none"}
	c := securityCtx(t, "1.2.3.4:0", nil)

	err := validateAuth(c, cfg)
	if err != nil {
		t.Errorf("auth_type=none: err = %v, want nil", err)
	}
}
// TestValidateAuth_Unsupported checks that an unrecognised auth_type is
// rejected with an "unsupported auth_type" error.
func TestValidateAuth_Unsupported(t *testing.T) {
	cfg := map[string]any{"auth_type": "weird"}
	c := securityCtx(t, "1.2.3.4:0", nil)

	err := validateAuth(c, cfg)
	if err != nil && strings.Contains(err.Error(), "unsupported auth_type") {
		return
	}
	t.Errorf("err = %v, want 'unsupported auth_type'", err)
}
// TestValidateTokenAuth_HeaderValue checks token auth against a
// configured header/value pair: the matching value is accepted and a
// different value is rejected.
func TestValidateTokenAuth_HeaderValue(t *testing.T) {
	tokenCfg := map[string]any{"token_header": "X-Token", "token_value": "abc"}
	cfg := map[string]any{"token": tokenCfg}

	// Correct token → accepted.
	accepted := securityCtx(t, "1.2.3.4:0", map[string]string{"X-Token": "abc"})
	if err := validateTokenAuth(accepted, cfg); err != nil {
		t.Errorf("matching token: err = %v, want nil", err)
	}

	// Wrong token → rejected.
	rejected := securityCtx(t, "1.2.3.4:0", map[string]string{"X-Token": "wrong"})
	if validateTokenAuth(rejected, cfg) == nil {
		t.Errorf("non-matching token: err = nil, want error")
	}
}
// TestValidateBasicAuth_PassAndFail checks HTTP Basic auth against the
// configured credentials: the right password is accepted and a wrong
// password is rejected.
func TestValidateBasicAuth_PassAndFail(t *testing.T) {
	credentials := map[string]any{"username": "alice", "password": "wonderland"}
	cfg := map[string]any{"basic_auth": credentials}

	// Correct credentials → accepted.
	accepted := securityCtx(t, "1.2.3.4:0", nil)
	accepted.Request.SetBasicAuth("alice", "wonderland")
	if err := validateBasicAuth(accepted, cfg); err != nil {
		t.Errorf("matching basic: err = %v, want nil", err)
	}

	// Wrong password → rejected.
	rejected := securityCtx(t, "1.2.3.4:0", nil)
	rejected.Request.SetBasicAuth("alice", "wrong")
	if validateBasicAuth(rejected, cfg) == nil {
		t.Errorf("non-matching basic: err = nil, want error")
	}
}
// TestValidateJWTAuth_NoSecret checks that a jwt section without a
// secret is reported as "secret not configured".
func TestValidateJWTAuth_NoSecret(t *testing.T) {
	cfg := map[string]any{"jwt": map[string]any{}}
	c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})

	err := validateJWTAuth(c, cfg)
	if err != nil && strings.Contains(err.Error(), "secret not configured") {
		return
	}
	t.Errorf("err = %v, want 'secret not configured'", err)
}
// TestValidateJWTAuth_MissingBearer checks that a request without an
// Authorization header is rejected with "missing bearer token".
func TestValidateJWTAuth_MissingBearer(t *testing.T) {
	cfg := map[string]any{"jwt": map[string]any{"secret": "s"}}
	c := securityCtx(t, "1.2.3.4:0", nil)

	err := validateJWTAuth(c, cfg)
	if err != nil && strings.Contains(err.Error(), "missing bearer token") {
		return
	}
	t.Errorf("err = %v, want 'missing bearer token'", err)
}
// TestValidateJWTAuth_RS256HappyPath confirms that an RS256 token
// signed with a real RSA key is accepted when the secret field holds
// the matching PEM-encoded public key. The python reference uses the
// same `secret` slot for both HMAC and RSA inputs; this test pins that
// contract for the Go port.
func TestValidateJWTAuth_RS256HappyPath(t *testing.T) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("rsa.GenerateKey: %v", err)
}
pubPEM := encodeRSAPublicKeyPEM(&priv.PublicKey)
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"sub": "u-1",
"exp": time.Now().Add(time.Hour).Unix(),
})
signed, err := token.SignedString(priv)
if err != nil {
t.Fatalf("sign: %v", err)
}
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer " + signed})
err = validateJWTAuth(c, map[string]any{
"jwt": map[string]any{
"secret": string(pubPEM),
"algorithm": "RS256",
},
})
if err != nil {
t.Errorf("RS256 happy path: err = %v, want nil", err)
}
}
// TestValidateJWTAuth_RS256BadPEM checks that an RS256 configuration
// whose secret is not a PEM block fails with an "rsa public key" error.
func TestValidateJWTAuth_RS256BadPEM(t *testing.T) {
	jwtCfg := map[string]any{
		"secret":    "not a pem block",
		"algorithm": "RS256",
	}
	c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})

	err := validateJWTAuth(c, map[string]any{"jwt": jwtCfg})
	if err != nil && strings.Contains(err.Error(), "rsa public key") {
		return
	}
	t.Errorf("err = %v, want 'rsa public key' parse error", err)
}
// TestValidateJWTAuth_UnsupportedAlgorithm checks that an algorithm
// outside the supported families ("none") is rejected as unsupported.
func TestValidateJWTAuth_UnsupportedAlgorithm(t *testing.T) {
	jwtCfg := map[string]any{
		"secret":    "s",
		"algorithm": "none",
	}
	c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})

	err := validateJWTAuth(c, map[string]any{"jwt": jwtCfg})
	if err != nil && strings.Contains(err.Error(), "unsupported") {
		return
	}
	t.Errorf("err = %v, want 'unsupported'", err)
}
// encodeRSAPublicKeyPEM marshals pub to PKIX DER and wraps it in a
// "PUBLIC KEY" PEM block. It panics if marshalling fails, which is
// acceptable for a test-only helper.
func encodeRSAPublicKeyPEM(pub *rsa.PublicKey) []byte {
	der, err := x509.MarshalPKIXPublicKey(pub)
	if err != nil {
		panic(err)
	}
	return pem.EncodeToMemory(&pem.Block{
		Type:  "PUBLIC KEY",
		Bytes: der,
	})
}
// TestValidateJWTAuth_ReservedClaimRejected exercises the
// reserved-claim guard. A valid HS256 token is signed with the
// configured secret so parsing succeeds; `exp` is then listed in
// required_claims, which must produce a "reserved jwt claim" error.
func TestValidateJWTAuth_ReservedClaimRejected(t *testing.T) {
	claims := jwt.MapClaims{
		"sub": "u-1",
		"exp": time.Now().Add(time.Hour).Unix(),
	}
	bearer, signErr := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("s"))
	if signErr != nil {
		t.Fatalf("sign: %v", signErr)
	}

	cfg := map[string]any{
		"jwt": map[string]any{
			"secret":          "s",
			"required_claims": "exp",
		},
	}
	ctx := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer " + bearer})

	got := validateJWTAuth(ctx, cfg)
	if got != nil && strings.Contains(got.Error(), "reserved jwt claim") {
		return
	}
	t.Errorf("err = %v, want 'reserved jwt claim'", got)
}
// TestValidateRateLimit_NoConfig expects validateRateLimit to return
// nil when the security config has no rate_limit entry.
func TestValidateRateLimit_NoConfig(t *testing.T) {
	got := validateRateLimit("c1", map[string]any{})
	if got == nil {
		return
	}
	t.Errorf("no rate_limit: err = %v, want nil", got)
}
// TestValidateRateLimit_BadPer expects an "invalid rate_limit.per"
// error when the window is set to "week".
func TestValidateRateLimit_BadPer(t *testing.T) {
	cfg := map[string]any{
		"rate_limit": map[string]any{"limit": 10, "per": "week"},
	}

	got := validateRateLimit("c1", cfg)
	if got != nil && strings.Contains(got.Error(), "invalid rate_limit.per") {
		return
	}
	t.Errorf("err = %v, want 'invalid rate_limit.per'", got)
}
// TestValidateRateLimit_BadLimit expects a "must be > 0" error when
// the configured limit is zero.
func TestValidateRateLimit_BadLimit(t *testing.T) {
	cfg := map[string]any{
		"rate_limit": map[string]any{"limit": 0, "per": "minute"},
	}

	got := validateRateLimit("c1", cfg)
	if got != nil && strings.Contains(got.Error(), "must be > 0") {
		return
	}
	t.Errorf("err = %v, want 'must be > 0'", got)
}
// (No helper needed at the bottom of this file; helper functions
// inline above.)
// _securityUnused previously lived here as a placeholder; deleted
// during cleanup (code-review MEDIUM-2).
// TestValidateMaxBodySize_OverflowGuard covers CodeRabbit PR review
// #3: a max_body_size whose numeric part would overflow when
// multiplied by the per-MB byte count must be rejected rather than
// wrapping around to a small value that passes the cap check.
//
// The input 9_223_372_036_855 mb, at 1_000_000 bytes per MB, exceeds
// the int64 maximum (about 9.22e18). The test expects an error whose
// text contains "exceeds maximum".
func TestValidateMaxBodySize_OverflowGuard(t *testing.T) {
	cfg := map[string]any{"max_body_size": "9223372036855mb"}
	ctx := securityCtx(t, "1.2.3.4:0", nil)

	got := validateMaxBodySize(ctx, cfg)
	switch {
	case got == nil:
		t.Errorf("huge mb value: err = nil, want overflow-rejection error")
	case !strings.Contains(got.Error(), "exceeds maximum"):
		t.Errorf("err = %v, want 'exceeds maximum'", got)
	}
}
// TestParseMaxBodySize_DecimalUnits pins the SI-decimal unit base
// used by parseMaxBodySize: 1 kb = 1 000 B and 1 mb = 1 000 000 B.
// (Per the original author, the python reference uses 1 024 and
// 1 048 576 instead.)
func TestParseMaxBodySize_DecimalUnits(t *testing.T) {
	type sizeCase struct {
		in   string
		want int64
	}
	table := []sizeCase{
		{in: "1kb", want: 1000},
		{in: "5kb", want: 5000},
		{in: "1mb", want: 1_000_000},
		{in: "10mb", want: 10_000_000}, // exact cap
	}

	for _, entry := range table {
		got, err := parseMaxBodySize(map[string]any{"max_body_size": entry.in})
		switch {
		case err != nil:
			t.Errorf("parseMaxBodySize(%q): err = %v, want nil", entry.in, err)
		case got != entry.want:
			t.Errorf("parseMaxBodySize(%q) = %d, want %d", entry.in, got, entry.want)
		}
	}
}
// TestValidateTokenAuth_EmptyValueRejected covers CodeRabbit PR
// review #4. Per the original author, an empty configured
// token_value used to let through any request lacking the header;
// validateTokenAuth must now return an error for that configuration.
func TestValidateTokenAuth_EmptyValueRejected(t *testing.T) {
	tokenCfg := map[string]any{
		"token_header": "X-Token",
		"token_value":  "",
	}
	// The request is built without any headers.
	ctx := securityCtx(t, "1.2.3.4:0", nil)

	got := validateTokenAuth(ctx, map[string]any{"token": tokenCfg})
	if got != nil {
		return
	}
	t.Errorf("empty token_value: err = nil, want error")
}

View File

@@ -0,0 +1,659 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"ragflow/internal/agent/canvas"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
)
// fakeCanvasLoader is a stand-in canvasLoader for webhook tests. It
// returns the canned canvas/run events configured at construction time
// without touching the database. Mirrors the pattern used by
// waitFakeAgentService at internal/handler/agent_wait_for_user_test.go.
type fakeCanvasLoader struct {
	// canvas is returned by LoadCanvasByID when err is nil.
	canvas *entity.UserCanvas
	// err, when non-nil, is returned by LoadCanvasByID in place of canvas.
	err error
	// events are replayed, in order, on the channel that
	// RunAgentWithWebhook returns.
	events []canvas.RunEvent
}
// LoadCanvasByID ignores its arguments and returns the configured
// error if one is set, otherwise the configured canvas.
func (f *fakeCanvasLoader) LoadCanvasByID(_ context.Context, _, _ string) (*entity.UserCanvas, error) {
	if f.err == nil {
		return f.canvas, nil
	}
	return nil, f.err
}
// RunAgentWithWebhook ignores its arguments and returns a closed
// channel pre-filled with the configured events. The buffer is sized
// to len(f.events), so filling it never blocks.
func (f *fakeCanvasLoader) RunAgentWithWebhook(_ context.Context, _, _ string, _ map[string]any) (<-chan canvas.RunEvent, error) {
	ch := make(chan canvas.RunEvent, len(f.events))
	defer close(ch)
	for i := range f.events {
		ch <- f.events[i]
	}
	return ch, nil
}
// makeWebhookCanvas builds a minimal agent canvas whose DSL holds a
// single "begin" component named "Begin". The component's params start
// as {"mode": mode} and every entry of `params` is then copied on top,
// so a "mode" key in `params` overrides the mode argument. According
// to the original author, these params become webhook_cfg inside the
// handler.
func makeWebhookCanvas(id, userID, mode string, params map[string]any) *entity.UserCanvas {
	beginParams := make(map[string]any, len(params)+1)
	beginParams["mode"] = mode
	for key, val := range params {
		beginParams[key] = val
	}

	dsl := map[string]any{
		"components": map[string]any{
			"begin": map[string]any{
				"obj": map[string]any{
					"component_name": "Begin",
					"params":         beginParams,
				},
			},
		},
	}

	return &entity.UserCanvas{
		ID:             id,
		UserID:         userID,
		CanvasCategory: "agent_canvas",
		DSL:            entity.JSONMap(dsl),
	}
}
// webhookCtx builds a gin test context plus its response recorder.
//
// The request uses the supplied method, path and body; an empty body
// yields an empty reader. The Content-Type header is set only when
// contentType is non-empty, so passing "" produces a request without
// that header. A "user" entry (ID "u-1") is stored on the context,
// which the original author notes is what makes GetUser() succeed.
func webhookCtx(method, path, body, contentType string) (*gin.Context, *httptest.ResponseRecorder) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	// strings.NewReader("") is already a valid empty reader, so the
	// body needs no special-casing.
	c.Request = httptest.NewRequest(method, path, strings.NewReader(body))
	if contentType != "" {
		c.Request.Header.Set("Content-Type", contentType)
	}
	c.Set("user", &entity.User{ID: "u-1"})
	return c, w
}
// errBody decodes a JSON response envelope and returns its `code` and
// `message` fields. It fails the test immediately if the body is not
// valid JSON.
func errBody(t *testing.T, body []byte) (int, string) {
	t.Helper()
	envelope := struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	}{}
	decodeErr := json.Unmarshal(body, &envelope)
	if decodeErr != nil {
		t.Fatalf("decode response: %v (body=%s)", decodeErr, body)
	}
	return envelope.Code, envelope.Message
}
// ---------- Phase 1: canvas loading / DataFlow rejection ----------
// TestWebhook_RejectsUnknownCanvas pins the 102 "Canvas not found."
// envelope when LoadCanvasByID returns ErrUserCanvasNotFound. This is
// the deliberate divergence from mapAgentError (which would surface 103).
func TestWebhook_RejectsUnknownCanvas(t *testing.T) {
h := &AgentHandler{loader: &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}}
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
h.Webhook(c)
code, msg := errBody(t, w.Body.Bytes())
if code != int(common.CodeDataError) {
t.Errorf("code = %d, want %d (DataError)", code, common.CodeDataError)
}
if msg != "Canvas not found." {
t.Errorf("message = %q, want %q", msg, "Canvas not found.")
}
}
// TestWebhook_RejectsDataFlowCanvas covers the second guard: the
// canvas loads successfully but its category is "DataFlow", which the
// original author notes python agent_api.py:1575 explicitly rejects.
func TestWebhook_RejectsDataFlowCanvas(t *testing.T) {
	const wantMsg = "Dataflow can not be triggered by webhook."
	dataflow := makeWebhookCanvas("c1", "u-1", "Webhook", nil)
	dataflow.CanvasCategory = "DataFlow"
	handler := &AgentHandler{loader: &fakeCanvasLoader{canvas: dataflow}}
	ctx, rec := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")

	handler.Webhook(ctx)

	gotCode, gotMsg := errBody(t, rec.Body.Bytes())
	if gotCode != int(common.CodeDataError) {
		t.Errorf("code = %d, want %d", gotCode, common.CodeDataError)
	}
	if gotMsg != wantMsg {
		t.Errorf("message = %q, want %q", gotMsg, wantMsg)
	}
}
// TestWebhook_RejectsMissingWebhookConfig covers a canvas whose Begin
// component is in "Manual" mode, i.e. the DSL has no Begin with
// mode="Webhook". The handler must answer with CodeDataError and the
// message "Webhook not configured for this agent.".
func TestWebhook_RejectsMissingWebhookConfig(t *testing.T) {
	cv := makeWebhookCanvas("c1", "u-1", "Manual", nil) // wrong mode
	h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
	c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
	h.Webhook(c)
	code, msg := errBody(t, w.Body.Bytes())
	if code != int(common.CodeDataError) {
		t.Errorf("code = %d, want %d", code, common.CodeDataError)
	}
	// Report the expected text as well, matching the sibling tests, so
	// a failure is diagnosable from the log alone.
	want := "Webhook not configured for this agent."
	if msg != want {
		t.Errorf("message = %q, want %q", msg, want)
	}
}
// ---------- Phase 5: method gate ----------
// TestWebhook_RejectsDisallowedMethod covers the methods allow-list:
// with methods set to ["POST"], a GET request must be answered with
// CodeDataError and a message naming the rejected method.
func TestWebhook_RejectsDisallowedMethod(t *testing.T) {
	const wantMsg = "HTTP method 'GET' not allowed for this webhook."
	beginParams := map[string]any{
		"methods": []any{"POST"},
	}
	postOnly := makeWebhookCanvas("c1", "u-1", "Webhook", beginParams)
	handler := &AgentHandler{loader: &fakeCanvasLoader{canvas: postOnly}}
	ctx, rec := webhookCtx("GET", "/api/v1/agents/c1/webhook", "", "")

	handler.Webhook(ctx)

	gotCode, gotMsg := errBody(t, rec.Body.Bytes())
	if gotCode != int(common.CodeDataError) {
		t.Errorf("code = %d, want %d", gotCode, common.CodeDataError)
	}
	if gotMsg != wantMsg {
		t.Errorf("message = %q, want %q", gotMsg, wantMsg)
	}
}
// ---------- Phase 6: security ----------
// TestWebhook_TokenAuthPasses pins the happy path for token auth.
func TestWebhook_TokenAuthPasses(t *testing.T) {
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
"security": map[string]any{
"auth_type": "token",
"token": map[string]any{
"token_header": "X-Webhook-Token",
"token_value": "s3cret",
},
},
"execution_mode": "Immediately",
"response": map[string]any{
"status": 200,
},
})
loader := &fakeCanvasLoader{canvas: cv}
h := &AgentHandler{loader: loader}
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
c.Request.Header.Set("X-Webhook-Token", "s3cret")
h.Webhook(c)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
}
// TestWebhook_TokenAuthFails sends a token header whose value differs
// from the configured one and expects CodeDataError with the message
// "Invalid token authentication".
func TestWebhook_TokenAuthFails(t *testing.T) {
	const wantMsg = "Invalid token authentication"
	security := map[string]any{
		"auth_type": "token",
		"token": map[string]any{
			"token_header": "X-Webhook-Token",
			"token_value":  "s3cret",
		},
	}
	guarded := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{"security": security})
	handler := &AgentHandler{loader: &fakeCanvasLoader{canvas: guarded}}
	ctx, rec := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
	ctx.Request.Header.Set("X-Webhook-Token", "wrong")

	handler.Webhook(ctx)

	gotCode, gotMsg := errBody(t, rec.Body.Bytes())
	if gotCode != int(common.CodeDataError) {
		t.Errorf("code = %d, want %d", gotCode, common.CodeDataError)
	}
	if gotMsg != wantMsg {
		t.Errorf("message = %q, want %q", gotMsg, wantMsg)
	}
}
// TestWebhook_BodySizeLimit configures a 1kb max_body_size and posts
// a 2048-byte body; the handler must answer with CodeDataError.
func TestWebhook_BodySizeLimit(t *testing.T) {
	security := map[string]any{
		"max_body_size": "1kb",
	}
	limited := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{"security": security})
	handler := &AgentHandler{loader: &fakeCanvasLoader{canvas: limited}}
	payload := strings.Repeat("a", 2048)
	ctx, rec := webhookCtx("POST", "/api/v1/agents/c1/webhook", payload, "text/plain")

	handler.Webhook(ctx)

	gotCode, _ := errBody(t, rec.Body.Bytes())
	if gotCode != int(common.CodeDataError) {
		t.Errorf("code = %d, want %d", gotCode, common.CodeDataError)
	}
}
// TestWebhook_BodySizeLimitTooLarge configures max_body_size as
// "11mb" and expects the handler to reject the configuration itself
// with CodeDataError and a message containing "exceeds maximum".
func TestWebhook_BodySizeLimitTooLarge(t *testing.T) {
	security := map[string]any{
		"max_body_size": "11mb",
	}
	oversized := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{"security": security})
	handler := &AgentHandler{loader: &fakeCanvasLoader{canvas: oversized}}
	ctx, rec := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")

	handler.Webhook(ctx)

	gotCode, gotMsg := errBody(t, rec.Body.Bytes())
	if gotCode != int(common.CodeDataError) {
		t.Errorf("code = %d, want %d", gotCode, common.CodeDataError)
	}
	if strings.Contains(gotMsg, "exceeds maximum") {
		return
	}
	t.Errorf("message = %q, want contains 'exceeds maximum'", gotMsg)
}
// TestWebhook_BasicAuthPasses / TestWebhook_BasicAuthFails exercise the
// HTTP Basic branch.
func TestWebhook_BasicAuthPasses(t *testing.T) {
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
"security": map[string]any{
"auth_type": "basic",
"basic_auth": map[string]any{
"username": "alice",
"password": "wonderland",
},
},
"execution_mode": "Immediately",
"response": map[string]any{"status": 200},
})
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
c.Request.SetBasicAuth("alice", "wonderland")
h.Webhook(c)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
}
// TestWebhook_BasicAuthFails exercises the HTTP Basic branch with a
// wrong password and expects the CodeDataError envelope.
func TestWebhook_BasicAuthFails(t *testing.T) {
	security := map[string]any{
		"auth_type": "basic",
		"basic_auth": map[string]any{
			"username": "alice",
			"password": "wonderland",
		},
	}
	guarded := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{"security": security})
	handler := &AgentHandler{loader: &fakeCanvasLoader{canvas: guarded}}
	ctx, rec := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
	ctx.Request.SetBasicAuth("alice", "wrong")

	handler.Webhook(ctx)

	gotCode, _ := errBody(t, rec.Body.Bytes())
	if gotCode != int(common.CodeDataError) {
		t.Errorf("code = %d, want %d", gotCode, common.CodeDataError)
	}
}
// ---------- Phase 8: schema extraction ----------
// TestWebhook_SchemaExtractionRequired: missing required field → 102.
func TestWebhook_SchemaExtractionRequired(t *testing.T) {
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
"schema": map[string]any{
"query": map[string]any{
"properties": map[string]any{
"q": map[string]any{"type": "string"},
},
"required": []any{"q"},
},
},
"execution_mode": "Immediately",
"response": map[string]any{"status": 200},
})
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook?other=v", ``, "")
h.Webhook(c)
code, msg := errBody(t, w.Body.Bytes())
if code != int(common.CodeDataError) {
t.Errorf("code = %d, want %d", code, common.CodeDataError)
}
if !strings.Contains(msg, "missing required field") {
t.Errorf("message = %q, want contains 'missing required field'", msg)
}
}
// TestWebhook_SchemaExtractionTypeMismatch: non-numeric in number field.
func TestWebhook_SchemaExtractionTypeMismatch(t *testing.T) {
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
"schema": map[string]any{
"body": map[string]any{
"properties": map[string]any{
"n": map[string]any{"type": "number"},
},
"required": []any{"n"},
},
},
"execution_mode": "Immediately",
"response": map[string]any{"status": 200},
})
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{"n":"abc"}`, "application/json")
h.Webhook(c)
code, msg := errBody(t, w.Body.Bytes())
if code != int(common.CodeDataError) {
t.Errorf("code = %d, want %d", code, common.CodeDataError)
}
if !strings.Contains(msg, "type mismatch") && !strings.Contains(msg, "auto-cast") {
t.Errorf("message = %q, want contains 'type mismatch' or 'auto-cast'", msg)
}
}
// ---------- Phase 9: dispatch ----------
// TestWebhook_ImmediatelyReturnsConfiguredStatus confirms the
// Immediately mode returns the configured status code and runs the
// canvas detached.
func TestWebhook_ImmediatelyReturnsConfiguredStatus(t *testing.T) {
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
"execution_mode": "Immediately",
"response": map[string]any{
"status": 201,
"body_template": `{"ok":true}`,
},
})
loader := &fakeCanvasLoader{
canvas: cv,
events: []canvas.RunEvent{
{Type: "message", Data: `{"content":"hi"}`},
{Type: "done", Data: ""},
},
}
h := &AgentHandler{loader: loader}
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
h.Webhook(c)
if w.Code != http.StatusCreated {
t.Errorf("status = %d, want 201; body=%s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), `"ok":true`) {
t.Errorf("body = %q, want contains '\"ok\":true'", w.Body.String())
}
}
// TestWebhook_DefaultModeAggregatesContent covers the path taken when
// execution_mode is not "Immediately": the content of successive
// "message" events is concatenated into the response's `message`
// field, and the response `code` is expected to be 200.
func TestWebhook_DefaultModeAggregatesContent(t *testing.T) {
	const wantMessage = "hello world"
	streaming := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
		"execution_mode": "Streaming",
	})
	replay := []canvas.RunEvent{
		{Type: "message", Data: `{"content":"hello "}`},
		{Type: "message", Data: `{"content":"world"}`},
		{Type: "message_end", Data: `{"status":200}`},
		{Type: "done", Data: ""},
	}
	handler := &AgentHandler{loader: &fakeCanvasLoader{canvas: streaming, events: replay}}
	ctx, rec := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")

	handler.Webhook(ctx)

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String())
	}
	reply := struct {
		Message string `json:"message"`
		Code    int    `json:"code"`
	}{}
	if decodeErr := json.Unmarshal(rec.Body.Bytes(), &reply); decodeErr != nil {
		t.Fatalf("decode: %v", decodeErr)
	}
	if reply.Message != wantMessage {
		t.Errorf("aggregated message = %q, want %q", reply.Message, wantMessage)
	}
	if reply.Code != 200 {
		t.Errorf("code = %d, want 200", reply.Code)
	}
}
// TestWebhook_InvalidResponseStatusRange rejects bad config (status 500).
func TestWebhook_InvalidResponseStatusRange(t *testing.T) {
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
"execution_mode": "Immediately",
"response": map[string]any{"status": 500},
})
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
h.Webhook(c)
code, msg := errBody(t, w.Body.Bytes())
if code != int(common.CodeDataError) {
t.Errorf("code = %d, want %d", code, common.CodeDataError)
}
if !strings.Contains(msg, "must be between 200 and 399") {
t.Errorf("message = %q, want contains 'must be between 200 and 399'", msg)
}
}
// TestWebhook_FindWebhookBegin_NilDSL checks that findWebhookBegin
// returns nil when given a nil DSL.
func TestWebhook_FindWebhookBegin_NilDSL(t *testing.T) {
	found := findWebhookBegin(nil)
	if found == nil {
		return
	}
	t.Errorf("findWebhookBegin(nil) = %v, want nil", found)
}
// TestWebhook_FindWebhookBegin_NoComponents checks that
// findWebhookBegin returns nil for a DSL map with no entries.
func TestWebhook_FindWebhookBegin_NoComponents(t *testing.T) {
	found := findWebhookBegin(map[string]any{})
	if found == nil {
		return
	}
	t.Errorf("findWebhookBegin(empty) = %v, want nil", found)
}
// ---------- Regression tests for issues raised in code review ----------
// TestWebhook_BodySizeStreamBounded pins the security MEDIUM-1
// hardening: with max_body_size set to 1kb, a 2048-byte body is
// rejected. The envelope carries code=CodeDataError; the original
// author notes the HTTP status itself stays at 200, matching the
// existing /api/v1 error envelope shape (not asserted here).
func TestWebhook_BodySizeStreamBounded(t *testing.T) {
	security := map[string]any{
		"max_body_size": "1kb",
	}
	limited := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{"security": security})
	handler := &AgentHandler{loader: &fakeCanvasLoader{canvas: limited}}
	payload := strings.Repeat("a", 2048)
	ctx, rec := webhookCtx("POST", "/api/v1/agents/c1/webhook", payload, "text/plain")

	handler.Webhook(ctx)

	gotCode, gotMsg := errBody(t, rec.Body.Bytes())
	if gotCode != int(common.CodeDataError) {
		t.Errorf("envelope code = %d, want %d (CodeDataError)", gotCode, common.CodeDataError)
	}
	// Both the Content-Length pre-check ("request body too large") and
	// the MaxBytesReader runtime check ("http: request body too large")
	// share the substring matched below, so either is accepted.
	if strings.Contains(gotMsg, "body too large") {
		return
	}
	t.Errorf("message = %q, want contains 'body too large'", gotMsg)
}
// The tests below cover behaviour the implementation summary claimed
// but did NOT actually enforce. They ensure the dispatch path returns
// the documented envelope for multipart and content-type mismatch
// cases.
// TestWebhook_MultipartReturns501 pins the contract that
// multipart/form-data uploads are refused with HTTP 501 BEFORE the
// schema-extraction phase runs. Before the fix, the handler silently
// stuffed __multipart_unsupported__ into body and continued; operators
// saw a confusing schema-validation error instead of the documented
// "not implemented" envelope.
func TestWebhook_MultipartReturns501(t *testing.T) {
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
"execution_mode": "Immediately",
"response": map[string]any{"status": 200},
})
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook",
"--boundary\r\nContent-Disposition: form-data; name=\"k\"\r\n\r\nv\r\n--boundary--\r\n",
"multipart/form-data; boundary=boundary")
h.Webhook(c)
if w.Code != http.StatusNotImplemented {
t.Errorf("status = %d, want 501; body=%s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "not supported") {
t.Errorf("body should mention 'not supported', got %q", w.Body.String())
}
}
// TestWebhook_ContentTypeMismatchReturns102 pins the contract that
// content_types whitelist enforces a HARD REJECT (HTTP 102 envelope),
// matching python agent_api.py:1839-1842. Before the fix, the
// mismatch was silently recorded as __content_type_mismatch__ in the
// body and the request flowed into schema validation — a real
// whitelist bypass.
func TestWebhook_ContentTypeMismatchReturns102(t *testing.T) {
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
"content_types": "application/json",
"execution_mode": "Immediately",
"response": map[string]any{"status": 200},
})
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
// text/plain where the config demands application/json.
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", "raw body", "text/plain")
h.Webhook(c)
if w.Code != http.StatusOK {
t.Fatalf("envelope status = %d, want 200 (the envelope itself succeeds; the code field carries 102)", w.Code)
}
code, msg := errBody(t, w.Body.Bytes())
if code != int(common.CodeDataError) {
t.Errorf("envelope code = %d, want %d (CodeDataError)", code, common.CodeDataError)
}
if !strings.Contains(msg, "invalid content-type") || !strings.Contains(msg, "application/json") {
t.Errorf("message = %q, want contains 'invalid content-type' and 'application/json'", msg)
}
}
// TestWebhook_ContentTypesEmptyAllowsAnything confirms the inverse:
// when webhook_cfg does NOT set content_types, any Content-Type is
// accepted (matches python: agent_api.py:1839 only raises when the
// config actually sets content_types).
func TestWebhook_ContentTypesEmptyAllowsAnything(t *testing.T) {
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
// content_types deliberately omitted
"execution_mode": "Immediately",
"response": map[string]any{"status": 200},
})
loader := &fakeCanvasLoader{canvas: cv}
h := &AgentHandler{loader: loader}
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{"k":"v"}`, "text/plain")
h.Webhook(c)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String())
}
}
// TestWebhook_ContentTypeMissingHeaderRejected pins the gap from the
// review: when content_types is configured, a request that OMITS the
// Content-Type header entirely must be rejected with 102. The python
// reference has a `if ctype and ...` short-circuit that lets such
// requests through — a real whitelist bypass. The Go port tightens
// this: an operator who configured content_types gets an enforceable
// contract; the caller MUST send the header.
func TestWebhook_ContentTypeMissingHeaderRejected(t *testing.T) {
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
"content_types": "application/json",
"execution_mode": "Immediately",
"response": map[string]any{"status": 200},
})
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{"k":"v"}`, "")
// ↑ "" content-type means the header is absent. gin will not
// produce a default Content-Type on a POST so this is a clean
// "no header" test case.
h.Webhook(c)
if w.Code != http.StatusOK {
t.Fatalf("envelope status = %d, want 200 (the envelope itself succeeds; the code field carries 102)", w.Code)
}
code, msg := errBody(t, w.Body.Bytes())
if code != int(common.CodeDataError) {
t.Errorf("envelope code = %d, want %d (CodeDataError)", code, common.CodeDataError)
}
if !strings.Contains(msg, "invalid content-type") {
t.Errorf("message = %q, want contains 'invalid content-type'", msg)
}
if !strings.Contains(msg, "application/json") {
t.Errorf("message = %q, want contains 'application/json'", msg)
}
}
// (No helper needed at the bottom of this file; helper functions
// inline above.)
// _unused previously lived here as a placeholder; deleted during
// cleanup (code-review MEDIUM-2).

View File

@@ -56,6 +56,14 @@ func RegisterAgentRoutes(g *gin.RouterGroup, h *handler.AgentHandler) {
g.PUT("/:canvas_id/tags", h.UpdateAgentTags)
g.POST("/:canvas_id/reset", h.ResetAgent)
// File operations.
g.GET("/download", h.DownloadAgentFile)
g.POST("/:canvas_id/upload", h.UploadAgentFile)
// Component introspection + debug.
g.GET("/:canvas_id/components/:component_id/input-form", h.GetComponentInputForm)
g.POST("/:canvas_id/components/:component_id/debug", h.DebugComponent)
// Versions.
g.GET("/:canvas_id/versions", h.ListVersions)
g.GET("/:canvas_id/versions/:version_id", h.GetVersion)
@@ -71,9 +79,43 @@ func RegisterAgentRoutes(g *gin.RouterGroup, h *handler.AgentHandler) {
// Logs and webhook.
g.GET("/:canvas_id/logs/:message_id", h.GetAgentLogs)
g.GET("/:canvas_id/webhook/logs", h.GetAgentWebhookLogs)
// Webhook trigger endpoints. The Python agent API
// (api/apps/restful_apis/agent_api.py:1563-1564) registers six
// HTTP methods on a single path. Gin's Any() would additionally
// accept OPTIONS, CONNECT and TRACE, so we register exactly these
// six verbs via registerAnyMethod. The handler is identical for
// all six; semantics differ only by c.Request.Method.
registerAnyMethod(g, "/:canvas_id/webhook", h.Webhook)
registerAnyMethod(g, "/:canvas_id/webhook/test", h.Webhook)
// Top-level actions (no canvas id in path).
// NOTE: `/chat/completion` (singular) is intentionally NOT registered.
// The singular form was a historical typo in earlier Python releases —
// no client, SDK, or doc ever called it, and the Python side
// (api/apps/restful_apis/agent_api.py) has since removed the route.
// See plan: .claude/plans/agent-api-gaps-go-port.md §Gap E.
g.POST("/chat/completions", h.AgentChatCompletions)
g.POST("/rerun", h.RerunAgent)
g.POST("/test_db_connection", h.TestDBConnection)
}
// registerAnyMethod registers h on path for POST, GET, PUT, PATCH,
// DELETE and HEAD, mirroring the Python
// `@manager.route(path, methods=["POST","GET","PUT","PATCH","DELETE","HEAD"])`
// pattern. The handler is the same for all six verbs; behaviour
// differs only by c.Request.Method.
//
// It is a no-op when either g or h is nil. Keeping the verb list in
// one place lets RegisterAgentRoutes stay readable while the
// production trigger and the test trigger share the same six-method
// shape.
//
// NOTE(review): newer gin releases expose RouterGroup.Match, which
// could replace this loop — confirm against the pinned gin version.
func registerAnyMethod(g *gin.RouterGroup, path string, h gin.HandlerFunc) {
	if g == nil || h == nil {
		return
	}
	// Registration order matches the original explicit calls.
	verbs := [...]string{"POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"}
	for _, verb := range verbs {
		g.Handle(verb, path, h)
	}
}

View File

@@ -49,6 +49,51 @@ func genID32() string {
return strings.ReplaceAll(uuid.New().String(), "-", "")[:32]
}
// webhookPayloadKey is the unexported context key under which
// RunAgentWithWebhook stores the webhook payload. RunAgent looks the
// key up and, when a non-nil payload is present, copies it into
// root["webhook_payload"]. The chat / agent-run paths leave the key
// absent, so existing callers see no behaviour change.
//
// We deliberately do NOT surface the payload as a new RunAgent parameter
// — keeping the public signature stable means existing tests
// (agent_run_e2e_test.go, agent_wait_for_user_test.go) keep compiling.
type webhookPayloadKey struct{}
// LoadCanvasByID is the exported entry point the webhook handler uses
// to fetch a canvas. It delegates to loadCanvasForUser and hands back
// that call's result and error untouched.
//
// No error-code mapping happens here on purpose: per the original
// author, the webhook envelope answers 102 "Canvas not found." while
// the chat/run envelope answers 103 "Make sure you have
// permission...", so the choice has to stay at the HTTP layer where
// each handler knows its own spec.
//
// Mirrors python: api/apps/restful_apis/agent_api.py:1570
// (`UserCanvasService.get_by_id(agent_id)`); the original author
// states loadCanvasForUser applies the same IDOR guard as the chat
// handler.
func (s *AgentService) LoadCanvasByID(ctx context.Context, userID, canvasID string) (*entity.UserCanvas, error) {
	return s.loadCanvasForUser(ctx, userID, canvasID)
}
// RunAgentWithWebhook starts an agent run via RunAgent, passing empty
// strings for RunAgent's three trailing arguments. When payload is
// non-nil it is first attached to the context under
// webhookPayloadKey; a nil payload leaves the context untouched.
//
// The payload travels as a context value rather than as a new
// RunAgent parameter so the public RunAgent signature stays stable
// for the existing chat tests. Per the original author, the Begin
// component then surfaces it as state.Sys["webhook_payload"] for
// downstream components.
//
// Mirrors python: api/apps/restful_apis/agent_api.py:2125
// (`canvas.run(..., webhook_payload=clean_request)`).
func (s *AgentService) RunAgentWithWebhook(ctx context.Context, userID, canvasID string, payload map[string]any) (<-chan canvas.RunEvent, error) {
	if payload == nil {
		return s.RunAgent(ctx, userID, canvasID, "", "", "")
	}
	withPayload := context.WithValue(ctx, webhookPayloadKey{}, payload)
	return s.RunAgent(withPayload, userID, canvasID, "", "", "")
}
// ErrAgentNotOwner is returned by DeleteAgent when the canvas exists and
// is accessible to the caller but is owned by a different user. It maps
// to the Python "Only the owner of the agent is authorized for this
@@ -670,6 +715,14 @@ func (s *AgentService) RunAgent(ctx context.Context, userID, canvasID, sessionID
if dsl != nil {
root["__dsl_present__"] = true
}
// Webhook payload injection. Only RunAgentWithWebhook sets this
// context value; the chat / agent-run paths leave it nil so the
// existing surface is unchanged. The Begin component reads
// inputs["webhook_payload"] and writes it to state.Sys so
// downstream components can read sys.webhook_payload.
if payload, ok := ctx.Value(webhookPayloadKey{}).(map[string]any); ok && payload != nil {
root["webhook_payload"] = payload
}
// Phase 4.4 V2.1 (v3.6.1): populate root["tenant_id"] so the
// RunTracker.Start call (in buildRunFunc) records the run
// under the right tenant. The lookup is best-effort — a