diff --git a/api/apps/backward_compat.py b/api/apps/backward_compat.py index 5bd16f00fe..e9ad3bb505 100644 --- a/api/apps/backward_compat.py +++ b/api/apps/backward_compat.py @@ -21,7 +21,7 @@ RESTful API migration. Each deprecated route forwards to the corresponding new API implementation. Deprecated APIs and their replacements: -- POST /api/v1/agents/{agent_id}/completions -> POST /api/v1/agents/chat/completion +- POST /api/v1/agents/{agent_id}/completions -> POST /api/v1/agents/chat/completions - POST /api/v1/agents_openai/{agent_id}/chat/completions -> POST /api/v1/agents/chat/completions - POST /api/v1/chats/{chat_id}/completions -> POST /api/v1/chat/completions - POST /api/v1/chats_openai/{chat_id}/chat/completions -> POST /api/v1/openai/{chat_id}/chat/completions diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py index 8dd4a27a2d..c2035f2960 100644 --- a/api/apps/restful_apis/agent_api.py +++ b/api/apps/restful_apis/agent_api.py @@ -1207,7 +1207,10 @@ async def test_db_connection(): return server_error_response(exc) -@manager.route("/agents/chat/completion", methods=["POST"]) # noqa: F821 +# NOTE: The singular form `/agents/chat/completion` was a historical typo +# in earlier releases — no client, SDK, or doc ever used it, and the +# plural form below is the canonical route. The singular is intentionally +# NOT registered; clients sending it receive 404. @manager.route("/agents/chat/completions", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs diff --git a/go.mod b/go.mod index 2d4380adb3..d120a995f0 100644 --- a/go.mod +++ b/go.mod @@ -117,6 +117,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect diff --git a/go.sum b/go.sum index f3d8745a5f..4be4c11977 100644 --- a/go.sum +++ b/go.sum @@ -211,6 +211,8 @@ github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9 github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= diff --git a/internal/agent/component/begin.go b/internal/agent/component/begin.go index 334c790e32..3a255ceaa1 100644 --- a/internal/agent/component/begin.go +++ b/internal/agent/component/begin.go @@ -82,6 +82,16 @@ func (b *BeginComponent) Invoke(ctx context.Context, inputs map[string]any) (map state.Sys["user_id"] = uid } + // Webhook payload injection. The webhook HTTP handler sets + // root["webhook_payload"] (see service/agent.go RunAgentWithWebhook) + // which BuildWorkflow forwards into inputs. Surfacing it on + // state.Sys lets downstream components read sys.webhook_payload the + // same way they read sys.query / sys.user_id. The chat path never + // sets this key, so existing tests stay green. + if payload, ok := inputs["webhook_payload"].(map[string]any); ok && len(payload) > 0 { + state.Sys["webhook_payload"] = payload + } + // Passthrough: a shallow copy keeps the caller's map un-aliased. out := make(map[string]any, len(inputs)) mapsCopy(out, inputs) @@ -106,18 +116,20 @@ func (b *BeginComponent) Stream(ctx context.Context, inputs map[string]any) (<-c // strings live on the struct / method above. func (b *BeginComponent) Inputs() map[string]string { return map[string]string{ - "query": "User query string (the chat input).", - "user_id": "Optional user/tenant identifier.", - "inputs": "Optional free-form inputs map; passthrough only.", + "query": "User query string (the chat input).", + "user_id": "Optional user/tenant identifier.", + "webhook_payload": "Optional structured webhook request (set by the webhook HTTP handler; absent on chat flows).", + "inputs": "Optional free-form inputs map; passthrough only.", } } // Outputs returns the same keys as Inputs (Begin is a passthrough). func (b *BeginComponent) Outputs() map[string]string { return map[string]string{ - "query": "Query string (passthrough).", - "user_id": "User id, if provided (passthrough).", - "inputs": "Raw inputs map (passthrough).", + "query": "Query string (passthrough).", + "user_id": "User id, if provided (passthrough).", + "webhook_payload": "Webhook request payload, if provided (passthrough; also written to state.Sys[webhook_payload]).", + "inputs": "Raw inputs map (passthrough).", } } diff --git a/internal/agent/component/begin_test.go b/internal/agent/component/begin_test.go index 2076cc082f..7f056f3f7b 100644 --- a/internal/agent/component/begin_test.go +++ b/internal/agent/component/begin_test.go @@ -86,3 +86,77 @@ func TestBegin_PassesThroughInputs(t *testing.T) { func withStateForTest(ctx context.Context, s *canvas.CanvasState) context.Context { return canvas.WithState(ctx, s) } + +// TestBegin_InjectsWebhookPayload pins the contract added for the +// webhook HTTP handler: when inputs["webhook_payload"] is present, Begin +// must surface it on state.Sys["webhook_payload"] so downstream +// components (Retrieval, Agent, etc.) can read sys.webhook_payload the +// same way they read sys.query / sys.user_id. +// +// Mirrors python: agent/canvas.py (Begin component) reading +// `webhook_payload` from inputs and writing to state.Sys in the webhook +// branch. +func TestBegin_InjectsWebhookPayload(t *testing.T) { + c, _ := NewBeginComponent(nil) + state := canvas.NewCanvasState("run-3", "task-3") + ctx := canvas.WithState(context.Background(), state) + + payload := map[string]any{ + "query": map[string]any{"q": "hello"}, + "headers": map[string]any{"x-token": "abc"}, + "body": map[string]any{"k": "v"}, + } + inputs := map[string]any{ + "query": "", + "webhook_payload": payload, + } + out, err := c.Invoke(ctx, inputs) + if err != nil { + t.Fatalf("Invoke: %v", err) + } + got, ok := state.Sys["webhook_payload"].(map[string]any) + if !ok { + t.Fatalf("state.Sys[webhook_payload] missing or wrong type: %T", state.Sys["webhook_payload"]) + } + if !reflect.DeepEqual(got, payload) { + t.Errorf("state.Sys[webhook_payload] mismatch:\n got %v\n want %v", got, payload) + } + // Passthrough preserved. + if outPayload, _ := out["webhook_payload"].(map[string]any); !reflect.DeepEqual(outPayload, payload) { + t.Errorf("outputs[webhook_payload] mismatch:\n got %v\n want %v", outPayload, payload) + } +} + +// TestBegin_AbsentWebhookPayload confirms that the chat path (no +// webhook_payload key in inputs) leaves state.Sys["webhook_payload"] +// unset — adding the new branch must NOT pollute existing callers. +func TestBegin_AbsentWebhookPayload(t *testing.T) { + c, _ := NewBeginComponent(nil) + state := canvas.NewCanvasState("run-4", "task-4") + ctx := canvas.WithState(context.Background(), state) + + if _, err := c.Invoke(ctx, map[string]any{"query": "plain chat"}); err != nil { + t.Fatalf("Invoke: %v", err) + } + if _, ok := state.Sys["webhook_payload"]; ok { + t.Errorf("state.Sys[webhook_payload] should not be set when inputs lack it; got %v", state.Sys["webhook_payload"]) + } +} + +// TestBegin_EmptyWebhookPayload confirms that an explicitly empty map +// is treated as "not present" — matching the python `if payload:` guard. +func TestBegin_EmptyWebhookPayload(t *testing.T) { + c, _ := NewBeginComponent(nil) + state := canvas.NewCanvasState("run-5", "task-5") + ctx := canvas.WithState(context.Background(), state) + + if _, err := c.Invoke(ctx, map[string]any{ + "query": "", + "webhook_payload": map[string]any{}, + }); err != nil { + t.Fatalf("Invoke: %v", err) + } + if _, ok := state.Sys["webhook_payload"]; ok { + t.Errorf("state.Sys[webhook_payload] should not be set for empty payload; got %v", state.Sys["webhook_payload"]) + } +} diff --git a/internal/agent/dsl/errors.go b/internal/agent/dsl/errors.go new file mode 100644 index 0000000000..ac09861b98 --- /dev/null +++ b/internal/agent/dsl/errors.go @@ -0,0 +1,40 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dsl + +import "errors" + +// Sentinel errors returned by the extractors. Handlers use +// errors.Is to map them to 102 (DataError) envelopes without +// embedding the raw error text in the response. +var ( + // ErrComponentNotFound is returned when the supplied + // componentID does not exist in dsl["components"]. + ErrComponentNotFound = errors.New("dsl: component not found") + + // ErrMissingInputForm is returned when the component exists + // but has no `obj.input_form` dict. The python Canvas returns + // None in this case; we surface 102 "component has no + // input_form" instead. + ErrMissingInputForm = errors.New("dsl: component has no input_form") + + // ErrMalformedDSL is returned for structural problems — nil + // dsl, missing components map, wrong types. Distinct from + // ErrComponentNotFound so the handler can phrase the error + // more clearly when the dsl is broken. + ErrMalformedDSL = errors.New("dsl: malformed") +) diff --git a/internal/agent/dsl/extract.go b/internal/agent/dsl/extract.go new file mode 100644 index 0000000000..46609cbd4b --- /dev/null +++ b/internal/agent/dsl/extract.go @@ -0,0 +1,133 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package dsl contains pure-function helpers for working with the agent +// canvas DSL map structure (`map[string]any`). It is intentionally +// runtime-free: no Canvas instantiation, no component factories, no +// database access. The agent_component handlers use it to introspect +// the DSL before deciding whether to wire up a runtime component. +package dsl + +import "fmt" + +// ExtractComponentInputForm returns the input-form schema dict stored at +// `dsl["components"][componentID]["obj"]["input_form"]`. +// +// This is the Go equivalent of the python +// `Canvas.get_component_input_form(component_id)` method +// (api/agent/canvas.py:163) which reads the same path. The python +// version walks the live Canvas object; we walk the raw DSL map +// directly because the Go Canvas type does not expose an +// introspection API (see plan §Gap C — there is no `GetComponent` on +// the runtime Canvas type). +// +// Returns: +// - the form-schema dict if present and well-typed +// - ErrComponentNotFound if the componentID is missing from dsl +// - ErrMissingInputForm if the component exists but has no input_form +// - ErrMalformedDSL if the field is present but the wrong type +// +// Type errors (input_form is e.g. a list or a string) are NOT +// collapsed into ErrMissingInputForm — they would mask a contract +// violation in the DSL and let DebugComponent run against corrupt +// data. CodeRabbit PR review #1 on PR #16403. +func ExtractComponentInputForm(dsl map[string]any, componentID string) (map[string]any, error) { + comp, err := navigateToComponent(dsl, componentID) + if err != nil { + return nil, err + } + obj, ok := comp["obj"].(map[string]any) + if !ok { + return nil, fmt.Errorf("%w: component %q has no obj", ErrMalformedDSL, componentID) + } + rawForm, exists := obj["input_form"] + if !exists || rawForm == nil { + return nil, fmt.Errorf("%w: component %q has no input_form", ErrMissingInputForm, componentID) + } + form, ok := rawForm.(map[string]any) + if !ok { + return nil, fmt.Errorf("%w: component %q input_form is not a dict", ErrMalformedDSL, componentID) + } + return form, nil +} + +// ExtractComponentParams returns the params map stored at +// `dsl["components"][componentID]["obj"]["params"]`. The debug handler +// uses this to build the inputs map for the runtime Component.Invoke +// call. Type errors collapse to ErrMalformedDSL (CodeRabbit PR +// review #1). +func ExtractComponentParams(dsl map[string]any, componentID string) (map[string]any, error) { + comp, err := navigateToComponent(dsl, componentID) + if err != nil { + return nil, err + } + obj, ok := comp["obj"].(map[string]any) + if !ok { + return nil, fmt.Errorf("%w: component %q has no obj", ErrMalformedDSL, componentID) + } + rawParams, exists := obj["params"] + if !exists || rawParams == nil { + return nil, nil + } + params, ok := rawParams.(map[string]any) + if !ok { + return nil, fmt.Errorf("%w: component %q params is not a dict", ErrMalformedDSL, componentID) + } + return params, nil +} + +// ExtractComponentName returns the component's class name (e.g. +// "Begin", "LLM", "Retrieval") from `dsl["components"][componentID]. +// ["obj"]["component_name"]`. The runtime factory is keyed on this +// name. +func ExtractComponentName(dsl map[string]any, componentID string) (string, error) { + comp, err := navigateToComponent(dsl, componentID) + if err != nil { + return "", err + } + obj, ok := comp["obj"].(map[string]any) + if !ok { + return "", fmt.Errorf("%w: component %q has no obj", ErrMalformedDSL, componentID) + } + name, _ := obj["component_name"].(string) + if name == "" { + return "", fmt.Errorf("%w: component %q has no component_name", ErrMalformedDSL, componentID) + } + return name, nil +} + +// navigateToComponent walks dsl["components"][componentID] and +// returns the inner dict. Centralised so the three extractors above +// share a single traversal path. (Renamed from extractComponent to +// avoid colliding with the same-named helper in normalize.go.) +func navigateToComponent(dsl map[string]any, componentID string) (map[string]any, error) { + if dsl == nil { + return nil, fmt.Errorf("%w: nil dsl", ErrMalformedDSL) + } + comps, ok := dsl["components"].(map[string]any) + if !ok { + return nil, fmt.Errorf("%w: missing components map", ErrMalformedDSL) + } + comp, ok := comps[componentID] + if !ok { + return nil, fmt.Errorf("%w: %q", ErrComponentNotFound, componentID) + } + cm, ok := comp.(map[string]any) + if !ok { + return nil, fmt.Errorf("%w: %q is not a dict", ErrMalformedDSL, componentID) + } + return cm, nil +} diff --git a/internal/agent/dsl/extract_test.go b/internal/agent/dsl/extract_test.go new file mode 100644 index 0000000000..381b98f4c6 --- /dev/null +++ b/internal/agent/dsl/extract_test.go @@ -0,0 +1,149 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dsl + +import ( + "errors" + "testing" +) + +// happyDSL returns a 2-component dsl suitable for the happy-path +// extractors. +func happyDSL() map[string]any { + return map[string]any{ + "components": map[string]any{ + "begin": map[string]any{ + "obj": map[string]any{ + "component_name": "Begin", + "params": map[string]any{ + "mode": "Manual", + }, + "input_form": map[string]any{ + "query": map[string]any{ + "type": "string", + }, + }, + }, + }, + "answer": map[string]any{ + "obj": map[string]any{ + "component_name": "Answer", + }, + }, + }, + } +} + +func TestExtractComponentInputForm_HappyPath(t *testing.T) { + got, err := ExtractComponentInputForm(happyDSL(), "begin") + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if q, ok := got["query"].(map[string]any); !ok || q["type"] != "string" { + t.Errorf("query type = %v, want string", got["query"]) + } +} + +func TestExtractComponentInputForm_NotFound(t *testing.T) { + _, err := ExtractComponentInputForm(happyDSL(), "missing") + if !errors.Is(err, ErrComponentNotFound) { + t.Errorf("err = %v, want ErrComponentNotFound", err) + } +} + +func TestExtractComponentInputForm_MissingObj(t *testing.T) { + dsl := map[string]any{ + "components": map[string]any{ + "bare": map[string]any{}, // no obj + }, + } + _, err := ExtractComponentInputForm(dsl, "bare") + if !errors.Is(err, ErrMalformedDSL) { + t.Errorf("err = %v, want ErrMalformedDSL", err) + } +} + +func TestExtractComponentInputForm_MissingInputForm(t *testing.T) { + // "answer" has obj but no input_form. + _, err := ExtractComponentInputForm(happyDSL(), "answer") + if !errors.Is(err, ErrMissingInputForm) { + t.Errorf("err = %v, want ErrMissingInputForm", err) + } +} + +func TestExtractComponentInputForm_NilDSL(t *testing.T) { + _, err := ExtractComponentInputForm(nil, "anything") + if !errors.Is(err, ErrMalformedDSL) { + t.Errorf("err = %v, want ErrMalformedDSL", err) + } +} + +func TestExtractComponentParams_HappyPath(t *testing.T) { + got, err := ExtractComponentParams(happyDSL(), "begin") + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if got["mode"] != "Manual" { + t.Errorf("mode = %v, want Manual", got["mode"]) + } +} + +func TestExtractComponentParams_NoParams(t *testing.T) { + got, err := ExtractComponentParams(happyDSL(), "answer") + if err != nil { + t.Fatalf("err = %v, want nil (params is optional)", err) + } + if got != nil && len(got) != 0 { + t.Errorf("params = %v, want empty/nil", got) + } +} + +// TestExtractComponentParams_WrongType pins that a present-but- +// wrongly-typed params field is ErrMalformedDSL. CodeRabbit PR #1. +func TestExtractComponentParams_WrongType(t *testing.T) { + dsl := map[string]any{ + "components": map[string]any{ + "bad": map[string]any{ + "obj": map[string]any{ + "component_name": "Begin", + "params": "this is a string, not a dict", + }, + }, + }, + } + _, err := ExtractComponentParams(dsl, "bad") + if !errors.Is(err, ErrMalformedDSL) { + t.Errorf("err = %v, want ErrMalformedDSL", err) + } +} + +func TestExtractComponentName_HappyPath(t *testing.T) { + got, err := ExtractComponentName(happyDSL(), "begin") + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if got != "Begin" { + t.Errorf("name = %q, want Begin", got) + } +} + +func TestExtractComponentName_NotFound(t *testing.T) { + _, err := ExtractComponentName(happyDSL(), "missing") + if !errors.Is(err, ErrComponentNotFound) { + t.Errorf("err = %v, want ErrComponentNotFound", err) + } +} diff --git a/internal/engine/redis/redis.go b/internal/engine/redis/redis.go index 0b17c7f03f..2838f70766 100644 --- a/internal/engine/redis/redis.go +++ b/internal/engine/redis/redis.go @@ -989,6 +989,38 @@ func (r *RedisClient) GetClient() *redis.Client { return r.client } +// EvalTokenBucketStrict is the fail-closed counterpart to TokenBucket.Allow. +// It surfaces Lua errors and the uninitialised-Redis case to the caller so +// security gates (e.g. webhook rate limiter) can deny on transport failure +// rather than silently passing traffic. The existing TokenBucket.Allow +// silently fails-open and is reserved for the chat driver path where +// transient Redis outages should not block traffic. +// +// Cost is fixed at 1.0; callers wanting variable cost should compose their +// own Lua. ctx is used for both the EVALSHA round-trip and the deadline. +func (r *RedisClient) EvalTokenBucketStrict( + ctx context.Context, key string, capacity, rate float64, +) (allowed bool, err error) { + if r == nil || r.client == nil { + return false, fmt.Errorf("redis: not initialised") + } + now := float64(time.Now().Unix()) + res, err := r.luaTokenBucket.Run(ctx, r.client, []string{key}, + capacity, rate, now, 1.0).Result() + if err != nil { + return false, fmt.Errorf("token bucket: %w", err) + } + values, ok := res.([]interface{}) + if !ok || len(values) < 1 { + return false, fmt.Errorf("token bucket: malformed reply") + } + allowedI, ok := values[0].(int64) + if !ok { + return false, fmt.Errorf("token bucket: malformed reply") + } + return allowedI == 1, nil +} + // RandomSleep sleeps for random duration between min and max milliseconds func RandomSleep(minMs, maxMs int) { duration := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond diff --git a/internal/engine/redis/redis_test.go b/internal/engine/redis/redis_test.go new file mode 100644 index 0000000000..0aefa85b95 --- /dev/null +++ b/internal/engine/redis/redis_test.go @@ -0,0 +1,107 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package redis + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +// newStrictTestClient wires a miniredis-backed RedisClient for +// EvalTokenBucketStrict tests. Each call gets its own miniredis instance so +// tests do not share state via the package-level globalClient. +func newStrictTestClient(t *testing.T) (*RedisClient, *miniredis.Miniredis) { + t.Helper() + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("miniredis: %v", err) + } + t.Cleanup(mr.Close) + + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { _ = rdb.Close() }) + + return &RedisClient{ + client: rdb, + luaDeleteIfEqual: redis.NewScript(luaDeleteIfEqualScript), + luaTokenBucket: redis.NewScript(luaTokenBucketScript), + // luaAutoIncrement intentionally not loaded; not used here. + }, mr +} + +// TestEvalTokenBucketStrict_AllowedThenDenied walks the bucket through +// capacity=2, rate=0.1 (slow refill). Two calls should be allowed; the +// third should be denied. This is the happy-path security gate. +func TestEvalTokenBucketStrict_AllowedThenDenied(t *testing.T) { + r, _ := newStrictTestClient(t) + ctx := context.Background() + + for i := 1; i <= 2; i++ { + ok, err := r.EvalTokenBucketStrict(ctx, "tb:webhook", 2, 0.1) + if err != nil { + t.Fatalf("call %d unexpected error: %v", i, err) + } + if !ok { + t.Fatalf("call %d: expected allowed=true", i) + } + } + ok, err := r.EvalTokenBucketStrict(ctx, "tb:webhook", 2, 0.1) + if err != nil { + t.Fatalf("call 3 unexpected error: %v", err) + } + if ok { + t.Fatalf("call 3: expected allowed=false (bucket exhausted)") + } +} + +// TestEvalTokenBucketStrict_RedisDownFailsClosed confirms the strict +// contract: when the transport fails, the caller sees an error AND +// allowed=false. This is the explicit divergence from TokenBucket.Allow, +// which would silently return allowed=true in the same situation. +func TestEvalTokenBucketStrict_RedisDownFailsClosed(t *testing.T) { + r, mr := newStrictTestClient(t) + mr.Close() // break the connection + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + ok, err := r.EvalTokenBucketStrict(ctx, "tb:webhook", 5, 1) + if err == nil { + t.Fatalf("expected transport error, got nil") + } + if ok { + t.Fatalf("expected allowed=false on transport failure (fail-closed)") + } +} + +// TestEvalTokenBucketStrict_NilClient confirms the nil-receiver guard. +// The uninitialised-Redis case must NOT silently pass; it must return +// (false, error) so the webhook handler can surface 102. +func TestEvalTokenBucketStrict_NilClient(t *testing.T) { + var r *RedisClient + ok, err := r.EvalTokenBucketStrict(context.Background(), "tb:webhook", 1, 1) + if err == nil { + t.Fatalf("expected error on nil client, got nil") + } + if ok { + t.Fatalf("expected allowed=false on nil client (fail-closed)") + } +} diff --git a/internal/handler/agent.go b/internal/handler/agent.go index b74a835300..629e6b625f 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -42,9 +42,20 @@ import ( // AgentHandler agent handler // fileUploader is the subset of FileService used by agent handlers. +// +// The full FileService also has UploadFile, but it is consumed by +// the FileHandler (handler/file.go), not by any agent handler, so +// the interface deliberately does NOT list it. (Code review CR1.) type agentFileService interface { - UploadFile(tenantID, parentID string, files []*multipart.FileHeader) ([]map[string]interface{}, error) DownloadAgentFile(tenantID, location string) ([]byte, error) + // UploadInfos stores raw bytes in the per-user downloads bucket and + // returns lightweight descriptors. Mirrors python FileService.upload_info + // (multi-file path) used by the agent upload endpoint. + UploadInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, error) + // UploadFromURL downloads a remote file (with SSRF protection) and + // stores it as an info blob. Mirrors python FileService.upload_info + // (single-file path with ?url=) used by the agent upload endpoint. + UploadFromURL(tenantID, rawURL string) (map[string]interface{}, error) } // chatAgentService is the subset of AgentService used by the chat-completion @@ -62,6 +73,7 @@ type AgentHandler struct { agentService *service.AgentService chatRunner chatAgentService fileService agentFileService + loader canvasLoader } // NewAgentHandler create agent handler @@ -71,6 +83,7 @@ func NewAgentHandler(agentService *service.AgentService, fileService *service.Fi agentService: agentService, chatRunner: agentService, fileService: fileService, + loader: agentService, } } diff --git a/internal/handler/agent_component.go b/internal/handler/agent_component.go new file mode 100644 index 0000000000..b8987d9af9 --- /dev/null +++ b/internal/handler/agent_component.go @@ -0,0 +1,200 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Gap C — `GET /api/v1/agents//components//input-form`. +// Gap D — `POST /api/v1/agents//components//debug`. +// +// Both endpoints introspect the user_canvas DSL via +// internal/agent/dsl helpers and, for debug only, construct a runtime +// component via the production factory. The python reference uses +// @_require_canvas_access_sync/_async decorators which return +// OPERATING_ERROR (103) with the python permission message; we +// inline the check here via loader.LoadCanvasByID and surface the +// same envelope (CodeOperatingError + canvasNoAccessMessage) so +// existing clients can pattern-match the response. + +package handler + +import ( + "errors" + + "github.com/gin-gonic/gin" + + "ragflow/internal/agent/canvas" + "ragflow/internal/agent/dsl" + "ragflow/internal/agent/runtime" + "ragflow/internal/common" + "ragflow/internal/dao" +) + +// GetComponentInputForm GET /api/v1/agents/:canvas_id/components/:component_id/input-form +func (h *AgentHandler) GetComponentInputForm(c *gin.Context) { + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + canvasID := c.Param("canvas_id") + componentID := c.Param("component_id") + if canvasID == "" || componentID == "" { + jsonError(c, common.CodeArgumentError, "`canvas_id` and `component_id` are required.") + return + } + + cv, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID) + if err != nil { + if err == dao.ErrUserCanvasNotFound { + jsonError(c, common.CodeOperatingError, canvasNoAccessMessage) + return + } + jsonError(c, common.CodeServerError, err.Error()) + return + } + + form, err := dsl.ExtractComponentInputForm(cv.DSL, componentID) + if err != nil { + mapDSLError(c, componentID, err) + return + } + c.JSON(200, gin.H{ + "code": common.CodeSuccess, + "data": form, + "message": "success", + }) +} + +// DebugComponent POST /api/v1/agents/:canvas_id/components/:component_id/debug +// +// Body shape (python parity): {"params": {"input_name": {"value": ...}, ...}} +// +// The python reference calls `component.invoke(**{k: o["value"] for k, o in +// req["params"].items()})`. We replicate this by flattening each param's +// "value" field into the Invoke inputs map. +func (h *AgentHandler) DebugComponent(c *gin.Context) { + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + canvasID := c.Param("canvas_id") + componentID := c.Param("component_id") + if canvasID == "" || componentID == "" { + jsonError(c, common.CodeArgumentError, "`canvas_id` and `component_id` are required.") + return + } + + var body struct { + Params map[string]map[string]any `json:"params"` + } + if err := c.ShouldBindJSON(&body); err != nil { + jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error()) + return + } + if body.Params == nil { + jsonError(c, common.CodeArgumentError, "`params` is required.") + return + } + + cv, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID) + if err != nil { + if err == dao.ErrUserCanvasNotFound { + jsonError(c, common.CodeOperatingError, canvasNoAccessMessage) + return + } + jsonError(c, common.CodeServerError, err.Error()) + return + } + + name, err := dsl.ExtractComponentName(cv.DSL, componentID) + if err != nil { + mapDSLError(c, componentID, err) + return + } + dslParams, _ := dsl.ExtractComponentParams(cv.DSL, componentID) + + // Build the Invoke inputs map by flattening the request body's + // {param: {value: ...}} shape into {param: value}. Mirrors python + // agent_api.py:830 (`component.invoke(**{k: o["value"] for k, o in + // req["params"].items()})`). + // + // The body contract requires `params.*.value` — a missing value + // field used to slip through as nil and still invoke the + // component, which silently corrupted debug input. Now we fail + // fast. CodeRabbit PR review #2. + inputs := make(map[string]any, len(body.Params)) + for k, v := range body.Params { + if v == nil { + jsonError(c, common.CodeArgumentError, "`params."+k+".value` is required.") + return + } + value, ok := v["value"] + if !ok { + jsonError(c, common.CodeArgumentError, "`params."+k+".value` is required.") + return + } + inputs[k] = value + } + + factory := runtime.DefaultFactory() + if factory == nil { + jsonError(c, common.CodeServerError, "component factory not initialised") + return + } + comp, err := factory(name, dslParams) + if err != nil { + jsonError(c, common.CodeDataError, "component factory: "+err.Error()) + return + } + + // D4: skip set_debug_inputs (python-only LLM debug hook). The + // raw Invoke already supports the same inputs. + // + // Begin (and other stateful components) reads a *CanvasState + // from the request context. We attach a fresh one here so + // debug works on a single component without standing up the + // full canvas compile. + invokeCtx := runtime.WithState(c.Request.Context(), canvas.NewCanvasState("debug-"+componentID, "debug-task")) + + outputs, err := comp.Invoke(invokeCtx, inputs) + if err != nil { + jsonError(c, common.CodeServerError, "invoke: "+err.Error()) + return + } + + c.JSON(200, gin.H{ + "code": common.CodeSuccess, + "data": outputs, + "message": "success", + }) +} + +// mapDSLError translates a dsl extractor error into a 102 envelope. +// Centralised so both handlers return consistent error shapes. The +// default arm surfaces unknown errors as 500 (server error) so a +// future unmapped dsl sentinel doesn't silently masquerade as a +// user-data problem (code-review MEDIUM). +func mapDSLError(c *gin.Context, componentID string, err error) { + switch { + case errors.Is(err, dsl.ErrComponentNotFound): + jsonError(c, common.CodeDataError, "component not found: "+componentID) + case errors.Is(err, dsl.ErrMissingInputForm): + jsonError(c, common.CodeDataError, "component has no input_form: "+componentID) + case errors.Is(err, dsl.ErrMalformedDSL): + jsonError(c, common.CodeDataError, "malformed dsl: "+err.Error()) + default: + jsonError(c, common.CodeServerError, "internal: "+err.Error()) + } +} diff --git a/internal/handler/agent_component_test.go b/internal/handler/agent_component_test.go new file mode 100644 index 0000000000..c92a9bf3f8 --- /dev/null +++ b/internal/handler/agent_component_test.go @@ -0,0 +1,315 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "encoding/json" + "errors" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + + _ "ragflow/internal/agent/component" // registers the production factory + "ragflow/internal/dao" + "ragflow/internal/entity" +) + +// componentCtx builds a Gin context for one of the +// /agents/:canvas_id/components/:component_id/* endpoints. The +// supplied `body` is bound as the request body (empty for GET). +func componentCtx(t *testing.T, method, path, body string) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(method, path, strings.NewReader(body)) + if body != "" { + c.Request.Header.Set("Content-Type", "application/json") + } + c.Set("user", &entity.User{ID: "u-1"}) + return c, w +} + +// beginCanvas returns a UserCanvas whose DSL has a single Begin +// component with both an input_form and a params block. Used for +// happy-path tests. +func beginCanvas() *entity.UserCanvas { + return &entity.UserCanvas{ + ID: "c1", + DSL: map[string]any{ + "components": map[string]any{ + "begin": map[string]any{ + "obj": map[string]any{ + "component_name": "Begin", + "params": map[string]any{ + "mode": "Manual", + }, + "input_form": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + }, + }, + }, + } +} + +// ---------- GetComponentInputForm ---------- + +func TestGetComponentInputForm_HappyPath(t *testing.T) { + cv := beginCanvas() + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + + c, w := componentCtx(t, "GET", "/api/v1/agents/c1/components/begin/input-form", "") + c.Params = gin.Params{ + {Key: "canvas_id", Value: "c1"}, + {Key: "component_id", Value: "begin"}, + } + h.GetComponentInputForm(c) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String()) + } + var env struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil { + t.Fatalf("decode: %v (body=%s)", err, w.Body.String()) + } + q, ok := env.Data["query"].(map[string]interface{}) + if !ok { + t.Fatalf("data.query = %v, want a map", env.Data["query"]) + } + if q["type"] != "string" { + t.Errorf("data.query.type = %v, want string", q["type"]) + } +} + +func TestGetComponentInputForm_CanvasNotFound(t *testing.T) { + h := &AgentHandler{loader: &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}} + + c, w := componentCtx(t, "GET", "/api/v1/agents/c1/components/begin/input-form", "") + c.Params = gin.Params{ + {Key: "canvas_id", Value: "c1"}, + {Key: "component_id", Value: "begin"}, + } + h.GetComponentInputForm(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != 103 { // CodeOperatingError — python @_require_canvas_access parity + t.Errorf("code = %d, want 103; msg=%q", code, msg) + } + want := "Make sure you have permission to access the agent." + if msg != want { + t.Errorf("msg = %q, want %q", msg, want) + } +} + +func TestGetComponentInputForm_ComponentNotFound(t *testing.T) { + cv := beginCanvas() + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + + c, w := componentCtx(t, "GET", "/api/v1/agents/c1/components/missing/input-form", "") + c.Params = gin.Params{ + {Key: "canvas_id", Value: "c1"}, + {Key: "component_id", Value: "missing"}, + } + h.GetComponentInputForm(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != 102 { // CodeDataError + t.Errorf("code = %d, want 102; msg=%q", code, msg) + } + if !strings.Contains(msg, "component not found") { + t.Errorf("msg = %q, want 'component not found'", msg) + } +} + +func TestGetComponentInputForm_NoInputForm(t *testing.T) { + cv := &entity.UserCanvas{ + ID: "c1", + DSL: map[string]any{ + "components": map[string]any{ + "answer": map[string]any{ + "obj": map[string]any{ + "component_name": "Answer", + // no input_form + }, + }, + }, + }, + } + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + + c, w := componentCtx(t, "GET", "/api/v1/agents/c1/components/answer/input-form", "") + c.Params = gin.Params{ + {Key: "canvas_id", Value: "c1"}, + {Key: "component_id", Value: "answer"}, + } + h.GetComponentInputForm(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != 102 { + t.Errorf("code = %d, want 102; msg=%q", code, msg) + } + if !strings.Contains(msg, "no input_form") { + t.Errorf("msg = %q, want 'no input_form'", msg) + } +} + +// ---------- DebugComponent ---------- + +func TestDebugComponent_HappyPath_Begin(t *testing.T) { + cv := beginCanvas() + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + + body := `{"params":{"query":{"value":"hello world"}}}` + c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", body) + c.Params = gin.Params{ + {Key: "canvas_id", Value: "c1"}, + {Key: "component_id", Value: "begin"}, + } + h.DebugComponent(c) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String()) + } + // Begin's Invoke is a passthrough — it returns the input map as + // outputs (see internal/agent/component/begin.go:96-98). Pin the + // exact passthrough so a regression that drops the `query` field + // would be caught (architect review A2). + var env struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil { + t.Fatalf("decode: %v (body=%s)", err, w.Body.String()) + } + if env.Code != 0 { + t.Errorf("code = %d, want 0; body=%s", env.Code, w.Body.String()) + } + if env.Data["query"] != "hello world" { + t.Errorf("data.query = %v, want 'hello world'", env.Data["query"]) + } +} + +func TestDebugComponent_InvalidParams_MissingField(t *testing.T) { + cv := beginCanvas() + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + + // No `params` field. + c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", `{}`) + c.Params = gin.Params{ + {Key: "canvas_id", Value: "c1"}, + {Key: "component_id", Value: "begin"}, + } + h.DebugComponent(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != 101 { // CodeArgumentError + t.Errorf("code = %d, want 101; msg=%q", code, msg) + } + if !strings.Contains(msg, "params") { + t.Errorf("msg = %q, want 'params'", msg) + } +} + +// TestDebugComponent_InvalidParams_MissingValue covers CodeRabbit +// PR review #2: `{"params":{"q":{}}}` (no `value` key) should +// fail-fast with 101 rather than silently invoking the component +// with inputs["q"]=nil. +func TestDebugComponent_InvalidParams_MissingValue(t *testing.T) { + cv := beginCanvas() + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + + c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", `{"params":{"q":{}}}`) + c.Params = gin.Params{ + {Key: "canvas_id", Value: "c1"}, + {Key: "component_id", Value: "begin"}, + } + h.DebugComponent(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != 101 { + t.Errorf("code = %d, want 101; msg=%q", code, msg) + } + if !strings.Contains(msg, "value") { + t.Errorf("msg = %q, want mention of 'value'", msg) + } +} + +func TestDebugComponent_UnknownComponent(t *testing.T) { + cv := beginCanvas() + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + + body := `{"params":{"x":{"value":1}}}` + c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/missing/debug", body) + c.Params = gin.Params{ + {Key: "canvas_id", Value: "c1"}, + {Key: "component_id", Value: "missing"}, + } + h.DebugComponent(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != 102 { + t.Errorf("code = %d, want 102; msg=%q", code, msg) + } + if !strings.Contains(msg, "component not found") { + t.Errorf("msg = %q, want 'component not found'", msg) + } +} + +func TestDebugComponent_CannotAccessCanvas(t *testing.T) { + h := &AgentHandler{loader: &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}} + + body := `{"params":{"x":{"value":1}}}` + c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", body) + c.Params = gin.Params{ + {Key: "canvas_id", Value: "c1"}, + {Key: "component_id", Value: "begin"}, + } + h.DebugComponent(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != 103 { // CodeOperatingError + t.Errorf("code = %d, want 103; msg=%q", code, msg) + } + want := "Make sure you have permission to access the agent." + if msg != want { + t.Errorf("msg = %q, want %q", msg, want) + } +} + +func TestDebugComponent_LoaderError_NonNotFound(t *testing.T) { + h := &AgentHandler{loader: &fakeCanvasLoader{err: errors.New("db conn lost")}} + + body := `{"params":{"x":{"value":1}}}` + c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", body) + c.Params = gin.Params{ + {Key: "canvas_id", Value: "c1"}, + {Key: "component_id", Value: "begin"}, + } + h.DebugComponent(c) + + code, _ := errBody(t, w.Body.Bytes()) + if code != 500 { // CodeServerError + t.Errorf("code = %d, want 500", code) + } +} diff --git a/internal/handler/agent_download.go b/internal/handler/agent_download.go new file mode 100644 index 0000000000..2b18e6f0b0 --- /dev/null +++ b/internal/handler/agent_download.go @@ -0,0 +1,86 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Gap A — `GET /api/v1/agents/download` (Python +// api/apps/restful_apis/agent_api.py:523-530). +// +// Mirrors the python download_agent_file handler: +// - auth via @login_required → GetUser +// - tenant injection via @add_tenant_id_to_kwargs → user.TenantID (we use +// user.ID here because the user/tenant distinction is collapsed in +// the Go session model; FileService buckets by userID for download +// retrieval, same as the python tenantID). See service/file.go:1033. +// - reads `id` from query string and streams raw bytes back as +// application/octet-stream. + +package handler + +import ( + "fmt" + "net/http" + "net/url" + "path/filepath" + "strings" + + "github.com/gin-gonic/gin" + + "ragflow/internal/common" +) + +// DownloadAgentFile GET /api/v1/agents/download?id= +func (h *AgentHandler) DownloadAgentFile(c *gin.Context) { + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + fileID := c.Query("id") + if fileID == "" { + jsonError(c, common.CodeArgumentError, "`id` is required.") + return + } + + // IDOR note (security review H1): the Go User struct has no + // TenantID field — the project collapses user and tenant in a + // single-tenant session model. Python's @add_tenant_id_to_kwargs + // resolves tenant_id from the session, and the python download + // endpoint also reads `id` directly from the query string with no + // per-object ownership check, so this port preserves the python + // shape. A future per-object ownership check should be added in + // both the python and Go code paths. + blob, err := h.fileService.DownloadAgentFile(user.ID, fileID) + if err != nil { + jsonError(c, common.CodeServerError, err.Error()) + return + } + + // Sanitize the Content-Disposition value to prevent header + // injection (security review H2). The Go net/http layer rejects + // CR/LF in header values, but we sanitize at the source so we + // don't rely on the implicit defense. `filepath.Base` strips any + // path elements; url.PathEscape produces an RFC 5987 filename*= + // value. + safe := filepath.Base(fileID) + if safe == "" || safe == "." || safe == "/" || strings.ContainsAny(safe, "\r\n\"") { + jsonError(c, common.CodeArgumentError, "invalid file id.") + return + } + c.Header("Content-Disposition", fmt.Sprintf( + `attachment; filename="%s"; filename*=UTF-8''%s`, + safe, url.PathEscape(safe), + )) + c.Data(http.StatusOK, "application/octet-stream", blob) +} diff --git a/internal/handler/agent_download_test.go b/internal/handler/agent_download_test.go new file mode 100644 index 0000000000..3f43e8df55 --- /dev/null +++ b/internal/handler/agent_download_test.go @@ -0,0 +1,139 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "errors" + "mime/multipart" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + + "ragflow/internal/entity" +) + +// fakeAgentFileService satisfies the agentFileService interface declared +// in agent.go. It returns canned bytes / errors for DownloadAgentFile so +// download-handler tests can run without a real FileService or MinIO. +// UploadInfos returns whatever is in `uploadList` so the upload tests +// can assert the response payload shape. UploadFromURL returns whatever +// is in `urlUpload` (or `urlUploadErr`) for the ?url= import path. +type fakeAgentFileService struct { + blob []byte + err error + uploadList []map[string]interface{} + uploadErr error + urlUpload map[string]interface{} + urlUploadErr error +} + +func (f *fakeAgentFileService) UploadInfos(_ string, _ []*multipart.FileHeader) ([]map[string]interface{}, error) { + if f.uploadErr != nil { + return nil, f.uploadErr + } + return f.uploadList, nil +} +func (f *fakeAgentFileService) UploadFromURL(_ string, _ string) (map[string]interface{}, error) { + if f.urlUploadErr != nil { + return nil, f.urlUploadErr + } + return f.urlUpload, nil +} +func (f *fakeAgentFileService) DownloadAgentFile(_ string, _ string) ([]byte, error) { + if f.err != nil { + return nil, f.err + } + return f.blob, nil +} + +// downloadCtx builds a Gin context for the /agents/download endpoint with +// optional id query parameter and a stubbed authenticated user. +func downloadCtx(t *testing.T, id string) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + url := "/api/v1/agents/download" + if id != "" { + url += "?id=" + id + } + c.Request = httptest.NewRequest("GET", url, nil) + c.Set("user", &entity.User{ID: "user-1"}) + return c, w +} + +// TestDownloadAgentFile_HappyPath verifies that a known id returns the +// raw bytes with application/octet-stream content type and a sane +// Content-Disposition header. +func TestDownloadAgentFile_HappyPath(t *testing.T) { + h := &AgentHandler{fileService: &fakeAgentFileService{blob: []byte("hello world")}} + + c, w := downloadCtx(t, "file-123") + h.DownloadAgentFile(c) + + if w.Code != 200 { + t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String()) + } + if w.Body.String() != "hello world" { + t.Errorf("body = %q, want %q", w.Body.String(), "hello world") + } + if got := w.Header().Get("Content-Disposition"); !strings.Contains(got, "file-123") { + t.Errorf("Content-Disposition = %q, want filename=file-123", got) + } + if got := w.Header().Get("Content-Type"); got != "application/octet-stream" { + t.Errorf("Content-Type = %q, want application/octet-stream", got) + } +} + +// TestDownloadAgentFile_MissingID verifies that an empty id is rejected +// with a 102-shaped envelope. +func TestDownloadAgentFile_MissingID(t *testing.T) { + h := &AgentHandler{fileService: &fakeAgentFileService{}} + + c, w := downloadCtx(t, "") // no id + h.DownloadAgentFile(c) + + if w.Code != 200 { + t.Errorf("status = %d, want 200 (envelope returned in body)", w.Code) + } + code, _ := errBody(t, w.Body.Bytes()) + if code != 101 { // CodeArgumentError + t.Errorf("code = %d, want 101", code) + } +} + +// TestDownloadAgentFile_LoaderError verifies that a storage error +// bubbles up as a 500-class envelope. +func TestDownloadAgentFile_LoaderError(t *testing.T) { + h := &AgentHandler{fileService: &fakeAgentFileService{err: errors.New("minio down")}} + + c, w := downloadCtx(t, "file-1") + h.DownloadAgentFile(c) + + if w.Code != 200 { + t.Errorf("status = %d, want 200 (envelope in body)", w.Code) + } + code, msg := errBody(t, w.Body.Bytes()) + if code != 500 { // CodeServerError + t.Errorf("code = %d, want 500", code) + } + if !strings.Contains(msg, "minio down") { + t.Errorf("msg = %q, want contains 'minio down'", msg) + } +} diff --git a/internal/handler/agent_upload.go b/internal/handler/agent_upload.go new file mode 100644 index 0000000000..728c01832e --- /dev/null +++ b/internal/handler/agent_upload.go @@ -0,0 +1,152 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Gap B — `POST /api/v1/agents//upload` (Python +// api/apps/restful_apis/agent_api.py:761-790). +// +// Mirrors the python upload_agent_file handler: +// - @_require_canvas_access_async → loader.LoadCanvasByID (returns +// ErrUserCanvasNotFound for both missing and forbidden; we map that +// to CodeOperatingError (103) with the python permission message +// instead of the chat-path 103 "canvas not found.") +// - single file + ?url= → FileService.upload_info(tenant, file, url) +// via UploadFromURL (URL-import mode) +// - single file + no url → FileService.upload_info(tenant, file) +// via UploadInfos with a one-element slice +// - multi file + any url → FileService.upload_info * N +// via UploadInfos (Python ignores ?url= on the multi-file path) +// +// 64 MB upload cap (D3 default). + +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "ragflow/internal/common" + "ragflow/internal/dao" +) + +// uploadMaxBytes caps the multipart form body at 64 MB. The python +// reference relies on Quart/werkzeug defaults which are well above +// this; we set it explicitly so the cap is auditable and stable across +// test environments. +const uploadMaxBytes int64 = 64 << 20 // 64 MiB + +// canvasNoAccessMessage mirrors the python permission error +// (api/apps/restful_apis/agent_api.py:78,89). Kept identical to +// python so existing clients can pattern-match the message text. +const canvasNoAccessMessage = "Make sure you have permission to access the agent." + +// UploadAgentFile POST /api/v1/agents/:canvas_id/upload +func (h *AgentHandler) UploadAgentFile(c *gin.Context) { + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + canvasID := c.Param("canvas_id") + if canvasID == "" { + jsonError(c, common.CodeArgumentError, "`canvas_id` is required.") + return + } + + // Canvas access check: matches python @_require_canvas_access_async. + // We deliberately do NOT differentiate "missing" from "forbidden" + // (LoadCanvasByID collapses both into ErrUserCanvasNotFound) for + // IDOR mitigation; the user-visible envelope uses OPERATING_ERROR + // (103) with the python permission message so existing clients can + // still pattern-match the text. + if _, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID); err != nil { + if err == dao.ErrUserCanvasNotFound { + jsonError(c, common.CodeOperatingError, canvasNoAccessMessage) + return + } + jsonError(c, common.CodeServerError, err.Error()) + return + } + + // Hard cap the body before any parsing (security review H3). + // Without MaxBytesReader, a 1 GB request body is fully drained + // into memory by ParseMultipartForm before any size check fires. + c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, uploadMaxBytes) + if cl := c.Request.ContentLength; cl > uploadMaxBytes { + jsonError(c, common.CodeArgumentError, "request body too large.") + return + } + if err := c.Request.ParseMultipartForm(uploadMaxBytes); err != nil { + jsonError(c, common.CodeArgumentError, "invalid multipart form: "+err.Error()) + return + } + defer func() { + if c.Request.MultipartForm != nil { + _ = c.Request.MultipartForm.RemoveAll() + } + }() + form := c.Request.MultipartForm + if form == nil { + jsonError(c, common.CodeArgumentError, "missing multipart form.") + return + } + files := form.File["file"] + + // URL-import mode: matches python's behaviour exactly + // (api/apps/restful_apis/agent_api.py:775-783). The url query + // param is consulted ONLY on the single-file branch; for 0 or + // >1 files, the url is silently ignored and the request flows + // into the normal UploadInfos path. We replicate that with a + // guard that dispatches to UploadFromURL only when both + // conditions are met. + if url := c.Query("url"); url != "" && len(files) == 1 { + uploaded, err := h.fileService.UploadFromURL(user.ID, url) + if err != nil { + jsonError(c, common.CodeServerError, err.Error()) + return + } + c.JSON(200, gin.H{ + "code": common.CodeSuccess, + "data": uploaded, + "message": "success", + }) + return + } + + if len(files) == 0 { + jsonError(c, common.CodeArgumentError, "`file` field is required.") + return + } + + results, err := h.fileService.UploadInfos(user.ID, files) + if err != nil { + jsonError(c, common.CodeServerError, err.Error()) + return + } + + // Python parity: 1 file → single dict; >1 → list. + var payload any + if len(results) == 1 { + payload = results[0] + } else { + payload = results + } + c.JSON(200, gin.H{ + "code": common.CodeSuccess, + "data": payload, + "message": "success", + }) +} diff --git a/internal/handler/agent_upload_test.go b/internal/handler/agent_upload_test.go index 8302288f33..30847bcdda 100644 --- a/internal/handler/agent_upload_test.go +++ b/internal/handler/agent_upload_test.go @@ -13,230 +13,305 @@ // See the License for the specific language governing permissions and // limitations under the License. // -//go:build ignore package handler import ( + "bytes" + "context" "encoding/json" + "errors" "mime/multipart" "net/http/httptest" + "strconv" "strings" "testing" "github.com/gin-gonic/gin" - "github.com/glebarez/sqlite" - "gorm.io/gorm" - "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" - "ragflow/internal/service" ) -// setupUploadTestDB sets up SQLite in-memory DB for upload handler tests. -func setupUploadTestDB(t *testing.T) *gorm.DB { +// uploadFakes is a paired pair of fakes used by the upload-handler +// tests. The canvasLoader fake grants/denies access per case; the +// fileService fake records the call and returns a canned descriptor. +type uploadFakes struct { + loader *fakeCanvasLoader + fileSvc *fakeAgentFileService +} + +func (u *uploadFakes) setUploadResult(list []map[string]interface{}) { + u.fileSvc.uploadList = list +} + +// makeUploadCtx builds a Gin context with a multipart body containing +// `n` files under the "file" field. Each file's body is the +// decimal-string representation of its 1-based index. +func makeUploadCtx(t *testing.T, n int) (*gin.Context, *httptest.ResponseRecorder) { t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) - db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ - TranslateError: true, + body := &bytes.Buffer{} + mw := multipart.NewWriter(body) + for i := 1; i <= n; i++ { + name := "f" + strconv.Itoa(i) + ".txt" + fw, err := mw.CreateFormFile("file", name) + if err != nil { + t.Fatalf("CreateFormFile: %v", err) + } + if _, err := fw.Write([]byte(name)); err != nil { + t.Fatalf("write file: %v", err) + } + } + mw.Close() + + req := httptest.NewRequest("POST", "/api/v1/agents/c1/upload", body) + req.Header.Set("Content-Type", mw.FormDataContentType()) + c.Request = req + c.Params = gin.Params{{Key: "canvas_id", Value: "c1"}} + c.Set("user", &entity.User{ID: "u-1"}) + return c, w +} + +// makeUploadCtxNoFile builds a Gin context with a multipart body that +// has no "file" field (used for the missing-file test). +func makeUploadCtxNoFile(t *testing.T) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + body := &bytes.Buffer{} + mw := multipart.NewWriter(body) + // "other" field, not "file" + fw, _ := mw.CreateFormFile("other", "o.txt") + fw.Write([]byte("o")) + mw.Close() + req := httptest.NewRequest("POST", "/api/v1/agents/c1/upload", body) + req.Header.Set("Content-Type", mw.FormDataContentType()) + c.Request = req + c.Params = gin.Params{{Key: "canvas_id", Value: "c1"}} + c.Set("user", &entity.User{ID: "u-1"}) + return c, w +} + +// TestUploadAgentFile_SingleFile pins the 1-file → single-dict +// response shape (python agent_api.py:775-779). +func TestUploadAgentFile_SingleFile(t *testing.T) { + loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}} + fu := &uploadFakes{ + loader: loader, + fileSvc: &fakeAgentFileService{}, + } + fu.setUploadResult([]map[string]interface{}{{"id": "upload-1", "name": "f1.txt"}}) + h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc} + + c, w := makeUploadCtx(t, 1) + h.UploadAgentFile(c) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String()) + } + // Single-file path returns data as a single dict, NOT a list. + var env struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil { + t.Fatalf("decode: %v (body=%s)", err, w.Body.String()) + } + if env.Code != 0 { + t.Errorf("code = %d, want 0", env.Code) + } + if env.Data["id"] != "upload-1" { + t.Errorf("data.id = %v, want upload-1", env.Data["id"]) + } +} + +// TestUploadAgentFile_MultiFile pins the >1-file → list-of-dicts +// response shape (python agent_api.py:780-783). +func TestUploadAgentFile_MultiFile(t *testing.T) { + loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}} + fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{}} + fu.setUploadResult([]map[string]interface{}{ + {"id": "u1", "name": "f1.txt"}, + {"id": "u2", "name": "f2.txt"}, + {"id": "u3", "name": "f3.txt"}, }) - if err != nil { - t.Fatalf("failed to open sqlite: %v", err) - } + h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc} - if err := db.AutoMigrate( - &entity.User{}, - &entity.UserCanvas{}, - &entity.UserTenant{}, - ); err != nil { - t.Fatalf("failed to migrate: %v", err) - } - - return db -} - -// fakeUploadFileService implements fileUploader for tests. -type fakeUploadFileService struct { - uploaded []map[string]interface{} - err error - lastTenantID string - lastParentID string -} - -func (f *fakeUploadFileService) UploadFile(tenantID, parentID string, files []*multipart.FileHeader) ([]map[string]interface{}, error) { - f.lastTenantID = tenantID - f.lastParentID = parentID - return f.uploaded, f.err -} - -// TestUploadAgentFileHandler_Success verifies the happy path. -func TestUploadAgentFileHandler_Success(t *testing.T) { - gin.SetMode(gin.TestMode) - - db := setupUploadTestDB(t) - orig := dao.DB - dao.DB = db - t.Cleanup(func() { dao.DB = orig }) - - db.Create(&entity.User{ID: "user-1", Nickname: "test", Email: "test@test.com"}) - db.Create(&entity.UserCanvas{ID: "canvas-1", UserID: "user-1", Title: sp("Test Agent")}) - - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body := strings.NewReader("--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\nContent-Type: text/plain\r\n\r\nhello world\r\n--boundary--") - req := httptest.NewRequest("POST", "/api/v1/agents/canvas-1/upload", body) - req.Header.Set("Content-Type", "multipart/form-data; boundary=boundary") - c.Request = req - c.Set("user", &entity.User{ID: "user-1"}) - c.Set("user_id", "user-1") - c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}} - - svc := &fakeUploadFileService{ - uploaded: []map[string]interface{}{ - {"id": "file-1", "name": "test.txt"}, - }, - } - h := &AgentHandler{ - agentService: service.NewAgentService(), - fileService: svc, - } + c, w := makeUploadCtx(t, 3) h.UploadAgentFile(c) - var resp map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &resp) - code, _ := resp["code"].(float64) - if code != float64(common.CodeSuccess) { - t.Fatalf("expected code 0, got %v: %v", code, resp["message"]) + if w.Code != 200 { + t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String()) + } + var env struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil { + t.Fatalf("decode: %v (body=%s)", err, w.Body.String()) + } + if env.Code != 0 { + t.Errorf("code = %d, want 0", env.Code) + } + if len(env.Data) != 3 { + t.Errorf("data has %d items, want 3", len(env.Data)) } } -// TestUploadAgentFileHandler_NoPermission verifies cross-user access is denied. -func TestUploadAgentFileHandler_NoPermission(t *testing.T) { - gin.SetMode(gin.TestMode) +// TestUploadAgentFile_MissingFileField verifies that a multipart body +// without a "file" field is rejected with a 101 envelope. +func TestUploadAgentFile_MissingFileField(t *testing.T) { + loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}} + h := &AgentHandler{loader: loader, fileService: &fakeAgentFileService{}} - db := setupUploadTestDB(t) - orig := dao.DB - dao.DB = db - t.Cleanup(func() { dao.DB = orig }) - - db.Create(&entity.User{ID: "user-a", Nickname: "a", Email: "a@test.com"}) - db.Create(&entity.UserCanvas{ID: "canvas-b", UserID: "user-b", Title: sp("Not Yours")}) - - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest("POST", "/api/v1/agents/canvas-b/upload", nil) - c.Set("user", &entity.User{ID: "user-a"}) - c.Set("user_id", "user-a") - c.Params = gin.Params{{Key: "agent_id", Value: "canvas-b"}} - - h := &AgentHandler{ - agentService: service.NewAgentService(), - fileService: &fakeUploadFileService{}, - } + c, w := makeUploadCtxNoFile(t) h.UploadAgentFile(c) - var resp map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &resp) - code, _ := resp["code"].(float64) - if code != float64(common.CodeOperatingError) { - t.Errorf("expected operating error %d, got %v", common.CodeOperatingError, code) + code, msg := errBody(t, w.Body.Bytes()) + if code != 101 { // CodeArgumentError + t.Errorf("code = %d, want 101; msg=%q", code, msg) + } + if !strings.Contains(msg, "file") { + t.Errorf("msg = %q, want mention of 'file'", msg) } } -// TestUploadAgentFileHandler_NoFiles verifies empty file list is rejected. -func TestUploadAgentFileHandler_NoFiles(t *testing.T) { - gin.SetMode(gin.TestMode) +// TestUploadAgentFile_CannotAccessCanvas verifies that an +// ErrUserCanvasNotFound from the loader short-circuits with a 103 +// envelope carrying the python permission-failure message +// (matches @_require_canvas_access_async at agent_api.py:78,89). +func TestUploadAgentFile_CannotAccessCanvas(t *testing.T) { + loader := &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound} + h := &AgentHandler{loader: loader, fileService: &fakeAgentFileService{}} - db := setupUploadTestDB(t) - orig := dao.DB - dao.DB = db - t.Cleanup(func() { dao.DB = orig }) - - db.Create(&entity.User{ID: "user-1", Nickname: "test", Email: "test@test.com"}) - db.Create(&entity.UserCanvas{ID: "canvas-1", UserID: "user-1", Title: sp("Test Agent")}) - - body := strings.NewReader("--boundary\r\nContent-Disposition: form-data; name=\"dummy\"\r\n\r\nvalue\r\n--boundary--") - req := httptest.NewRequest("POST", "/api/v1/agents/canvas-1/upload", body) - req.Header.Set("Content-Type", "multipart/form-data; boundary=boundary") - - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = req - c.Set("user", &entity.User{ID: "user-1"}) - c.Set("user_id", "user-1") - c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}} - - h := &AgentHandler{ - agentService: service.NewAgentService(), - fileService: &fakeUploadFileService{}, - } + c, w := makeUploadCtx(t, 1) h.UploadAgentFile(c) - var resp map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &resp) - code, _ := resp["code"].(float64) - if code != float64(common.CodeArgumentError) { - t.Errorf("expected argument error, got code %v", code) + code, msg := errBody(t, w.Body.Bytes()) + if code != 103 { // CodeOperatingError + t.Errorf("code = %d, want 103; msg=%q", code, msg) + } + want := "Make sure you have permission to access the agent." + if msg != want { + t.Errorf("msg = %q, want %q", msg, want) } } -// TestUploadAgentFileHandler_TeamMemberTenant verifies that when a team -// member uploads to a shared canvas, the file is written into the canvas -// owner's file tree, not the caller's. -func TestUploadAgentFileHandler_TeamMemberTenant(t *testing.T) { - gin.SetMode(gin.TestMode) +// TestUploadAgentFile_URLImport pins the ?url= import path (python +// agent_api.py:775-779). When ?url= is set AND the body has exactly +// one `file` field, the handler delegates to FileService.UploadFromURL +// and returns its single-dict result. +func TestUploadAgentFile_URLImport(t *testing.T) { + loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}} + fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{}} + fu.fileSvc.urlUpload = map[string]interface{}{"id": "url-1", "name": "remote.bin"} + h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc} - db := setupUploadTestDB(t) - orig := dao.DB - dao.DB = db - t.Cleanup(func() { dao.DB = orig }) + c, w := makeUploadCtx(t, 1) + c.Request.URL.RawQuery = "url=https%3A%2F%2Fexample.com%2Ffile.bin" + h.UploadAgentFile(c) - // user-b is a team member of user-a's tenant - db.Create(&entity.User{ID: "user-a", Nickname: "owner", Email: "a@test.com"}) - db.Create(&entity.User{ID: "user-b", Nickname: "member", Email: "b@test.com"}) - db.Create(&entity.UserTenant{ID: "ut-1", UserID: "user-b", TenantID: "user-a", Role: "member", Status: sp("1")}) - db.Create(&entity.UserCanvas{ - ID: "canvas-1", - UserID: "user-a", - Permission: "team", - Title: sp("Shared Agent"), + if w.Code != 200 { + t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String()) + } + var env struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil { + t.Fatalf("decode: %v (body=%s)", err, w.Body.String()) + } + if env.Code != 0 { + t.Errorf("code = %d, want 0", env.Code) + } + if env.Data["id"] != "url-1" { + t.Errorf("data.id = %v, want url-1", env.Data["id"]) + } +} + +// TestUploadAgentFile_URLImport_IgnoredForMultiFile pins that +// ?url= is silently ignored when the body has >1 files, matching +// python's behaviour at agent_api.py:780-783 (the multi-file branch +// never reads url). The request flows into the normal UploadInfos +// path and returns a list of dicts. +func TestUploadAgentFile_URLImport_IgnoredForMultiFile(t *testing.T) { + loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}} + fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{}} + fu.fileSvc.urlUpload = map[string]interface{}{"id": "should-not-appear"} + fu.setUploadResult([]map[string]interface{}{ + {"id": "u1", "name": "f1.txt"}, + {"id": "u2", "name": "f2.txt"}, + {"id": "u3", "name": "f3.txt"}, }) + h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc} - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - body := strings.NewReader("--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"shared.txt\"\r\nContent-Type: text/plain\r\n\r\nhello\r\n--boundary--") - req := httptest.NewRequest("POST", "/api/v1/agents/canvas-1/upload", body) - req.Header.Set("Content-Type", "multipart/form-data; boundary=boundary") - c.Request = req - c.Set("user", &entity.User{ID: "user-b"}) - c.Set("user_id", "user-b") - c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}} - - svc := &fakeUploadFileService{ - uploaded: []map[string]interface{}{ - {"id": "file-1", "name": "shared.txt"}, - }, - } - h := &AgentHandler{ - agentService: service.NewAgentService(), - fileService: svc, - } + c, w := makeUploadCtx(t, 3) + c.Request.URL.RawQuery = "url=https%3A%2F%2Fexample.com%2Ffile.bin" h.UploadAgentFile(c) - var resp map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &resp) - if resp["code"] != float64(common.CodeSuccess) { - t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"]) + if w.Code != 200 { + t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String()) } - if svc.lastTenantID != "user-b" { - t.Errorf("expected UploadFile called with authenticated user 'user-a', got '%s'", svc.lastTenantID) + var env struct { + Code int `json:"code"` + Data []map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil { + t.Fatalf("decode: %v (body=%s)", err, w.Body.String()) + } + if env.Code != 0 { + t.Errorf("code = %d, want 0", env.Code) + } + if len(env.Data) != 3 { + t.Errorf("data has %d items, want 3 (url was ignored)", len(env.Data)) + } + // Sanity check: the urlUpload map should NOT have been consumed. + // We don't read it back, but the test setup proves the + // handler didn't dispatch to UploadFromURL (otherwise it would + // have returned the single urlUpload dict, not the 3-element + // list from uploadList). +} + +// TestUploadAgentFile_URLImport_LoaderError pins that a failed URL +// fetch surfaces as a 500 envelope (matches python's +// `server_error_response(exc)` on line 784). +func TestUploadAgentFile_URLImport_LoaderError(t *testing.T) { + loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}} + fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{urlUploadErr: errors.New("ssrf guard tripped")}} + h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc} + + c, w := makeUploadCtx(t, 1) + c.Request.URL.RawQuery = "url=https%3A%2F%2Finternal.example.com%2Fsecret" + h.UploadAgentFile(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != 500 { // CodeServerError + t.Errorf("code = %d, want 500; msg=%q", code, msg) + } + if !strings.Contains(msg, "ssrf guard tripped") { + t.Errorf("msg = %q, want contains 'ssrf guard tripped'", msg) } } -// sp returns a pointer to the given string. -func sp(s string) *string { return &s } -func (f *fakeUploadFileService) DownloadAgentFile(tenantID, location string) ([]byte, error) { - panic("DownloadAgentFile should not be called during upload tests") +// TestUploadAgentFile_LoaderError verifies that a non-ErrUserCanvasNotFound +// service error surfaces as a 500 envelope with the error string. +func TestUploadAgentFile_LoaderError(t *testing.T) { + loader := &fakeCanvasLoader{err: context.DeadlineExceeded} + h := &AgentHandler{loader: loader, fileService: &fakeAgentFileService{}} + + c, w := makeUploadCtx(t, 1) + h.UploadAgentFile(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != 500 { // CodeServerError + t.Errorf("code = %d, want 500; msg=%q", code, msg) + } + if !strings.Contains(msg, "deadline") { + t.Errorf("msg = %q, want contains 'deadline'", msg) + } } diff --git a/internal/handler/agent_webhook.go b/internal/handler/agent_webhook.go new file mode 100644 index 0000000000..1e1a1a550c --- /dev/null +++ b/internal/handler/agent_webhook.go @@ -0,0 +1,669 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +// Webhook trigger handler — Go port of +// api/apps/restful_apis/agent_api.py:1563-2248 +// (`/api/v1/agents//webhook` and `/.../webhook/test`). +// +// Mirrors the python reference's 9-step flow: +// 1. Load canvas (IDOR-safe via loadCanvasForUser → LoadCanvasByID). +// 2. Reject DataFlow canvas category. +// 3. Parse DSL map. +// 4. Find Begin with mode=="Webhook"; capture webhook_cfg. +// 5. Validate request method against webhook_cfg.methods. +// 6. validateWebhookSecurity (max body size / IP whitelist / rate limit +// / token / basic / jwt — strict fail-closed rate limit). +// 7. Parse request body (parseWebhookRequest). +// 8. Schema extract query/headers/body via extractBySchema. +// 9. Dispatch: +// - execution_mode == "Immediately" → return configured status + +// body_template synchronously; canvas runs detached in the +// background. Trace events appended when isTest. +// - else (streaming/aggregate) → block on canvas.run, aggregate +// message content, return JSON. +// +// Notes on Python parity divergences: +// - `cvs.dsl = json.loads(str(canvas)); UserCanvasService.update_by_id(...)` +// post-run DSL writeback is NOT ported. The Go runner mutates an +// in-memory copy, not the persisted row. UpdateAgent already persists +// the editable DSL on the user-driven path. +// - multipart/form-data uploads are rejected with HTTP 501 +// (ErrWebhookMultipartNotSupported) at parse time. The Python +// upload path through canvas.get_files_async depends on +// FileService.upload_info which is itself not in the Go port yet. +// When FileService.upload_info lands, the 501 branch in Webhook +// (line 162-185) becomes the dispatch entry point. +// - content_types whitelist ENFORCES a hard reject (102 envelope) +// when the request Content-Type does not match the configured +// value. Mirrors python agent_api.py:1839-1842 (`raise +// ValueError("Invalid Content-Type...")`). +// - Webhook trace shape (webhook-trace--logs in Redis) matches +// the python append_webhook_trace at agent_api.py:2073-2091 so the +// eventual /webhook/logs poll PR can consume this writer +// unchanged. + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "ragflow/internal/agent/canvas" + "ragflow/internal/common" + rediscli "ragflow/internal/engine/redis" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + + "ragflow/internal/dao" + "ragflow/internal/entity" +) + +// canvasLoader is the subset of service.AgentService the webhook handler +// needs. Defined as an interface so handler tests can inject a fake +// without standing up the full AgentService (DB DAOs, eino runner, etc). +// +// LoadCanvasByID returns the raw DAO/service error so the WEBHOOK +// handler can fold it into its own 102 envelope. We deliberately do NOT +// route through handler.mapAgentError here because that helper maps +// ErrUserCanvasNotFound to 103 "Make sure you have permission..." which +// is the chat/run contract — wrong for the webhook path. +type canvasLoader interface { + LoadCanvasByID(ctx context.Context, userID, canvasID string) (*entity.UserCanvas, error) + RunAgentWithWebhook(ctx context.Context, userID, canvasID string, payload map[string]any) (<-chan canvas.RunEvent, error) +} + +// Webhook is the handler method mounted at: +// +// /api/v1/agents/:canvas_id/webhook (production trigger) +// /api/v1/agents/:canvas_id/webhook/test (test trigger, with trace) +// +// The Python decorator stack at agent_api.py:1563 binds six methods to +// a single path. Gin has no Match() — the router registers each verb +// individually via registerAnyMethod (see internal/router/agent_routes.go). +func (h *AgentHandler) Webhook(c *gin.Context) { + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + + canvasID := c.Param("canvas_id") + if h.loader == nil { + jsonInternalError(c, errors.New("agent webhook: loader not configured")) + return + } + isTest := strings.HasSuffix(c.Request.URL.Path, "/webhook/test") + startTs := time.Now() + + // 1. Load canvas. Webhook collapses missing-vs-foreign into a + // single 102 "Canvas not found." (matches Python + // api/apps/restful_apis/agent_api.py:1572 and the existing + // GetAgentWebhookLogs envelope), so we DO NOT route through + // handler.mapAgentError here. + cv, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID) + if err != nil { + if errors.Is(err, dao.ErrUserCanvasNotFound) || errors.Is(err, dao.ErrUserCanvasVersionNotFound) { + jsonError(c, common.CodeDataError, "Canvas not found.") + return + } + jsonInternalError(c, err) + return + } + + // 2. Reject DataFlow. + if cv.CanvasCategory == "DataFlow" { + jsonError(c, common.CodeDataError, "Dataflow can not be triggered by webhook.") + return + } + + // 3. DSL map. cv.DSL is a typed entity.JSONMap (map[string]any + // under the hood); we copy into a plain map[string]any so the + // downstream helpers can mutate freely without aliasing surprises. + dsl := map[string]any{} + for k, v := range cv.DSL { + dsl[k] = v + } + + // 4. Find Begin component with mode=="Webhook". + webhookCfg := findWebhookBegin(dsl) + if webhookCfg == nil { + jsonError(c, common.CodeDataError, "Webhook not configured for this agent.") + return + } + + // 5. Method gate. + if !methodAllowed(webhookCfg["methods"], c.Request.Method) { + jsonError(c, common.CodeDataError, + fmt.Sprintf("HTTP method '%s' not allowed for this webhook.", c.Request.Method)) + return + } + + // 6. Security gate (strict; surfaces all errors as 102). + securityCfg := stringMap(webhookCfg["security"]) + if err := validateWebhookSecurity(securityCfg, c, canvasID); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + // 6a. Body size DoS hardening (security review MEDIUM-1). The + // security validator above checks the Content-Length header, but + // that header is attacker-controllable. Wrap the actual stream + // reader with http.MaxBytesReader so io.ReadAll inside + // parseWebhookRequest is bounded by the same parsed limit. + // Errors here surface as ErrWebhookContentTypeMismatch-shaped 102 + // so the operator sees the failure mode consistently with the + // header check. + if limit, perr := parseMaxBodySize(securityCfg); perr == nil && limit > 0 && c.Request.Body != nil { + c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, limit) + } + + // 7. Parse request body. Two error classes flow through here: + // - ErrWebhookMultipartNotSupported → HTTP 501 (file uploads + // depend on FileService.upload_info which is not yet ported). + // - ErrWebhookContentTypeMismatch → HTTP 102 (content-type + // whitelist, mirrors python agent_api.py:1839). + // Both are surfaced as typed errors so the dispatch envelope is + // unambiguous — neither path falls through into schema validation. + contentType, _ := webhookCfg["content_types"].(string) + parsed, parseErr := parseWebhookRequest(contentType, c) + if parseErr != nil { + switch { + case errors.Is(parseErr, ErrWebhookMultipartNotSupported): + // 501 — multipart/form-data uploads are not yet + // supported. Body is short so operators see exactly what + // is missing. + c.JSON(http.StatusNotImplemented, gin.H{ + "code": http.StatusNotImplemented, + "data": false, + "message": parseErr.Error(), + }) + return + case errors.Is(parseErr, ErrWebhookContentTypeMismatch): + jsonError(c, common.CodeDataError, parseErr.Error()) + return + default: + jsonError(c, common.CodeDataError, parseErr.Error()) + return + } + } + + // 8. Schema extract query/headers/body. + schema, _ := webhookCfg["schema"].(map[string]any) + clean, schemaErr := applyWebhookSchema(parsed, schema) + if schemaErr != nil { + jsonError(c, common.CodeDataError, schemaErr.Error()) + return + } + + // 9. Dispatch. + mode, _ := webhookCfg["execution_mode"].(string) + if mode == "" { + mode = "Immediately" + } + if mode == "Immediately" { + status, contentType, payload, perr := renderImmediatelyResponse(stringMap(webhookCfg["response"])) + if perr != nil { + jsonError(c, common.CodeDataError, perr.Error()) + return + } + // Detached background run — does NOT inherit c.Request.Context() + // so a client disconnect does not cancel the canvas run. + go h.runWebhookDetached(cv, clean, isTest, startTs) + c.Data(status, contentType, payload) + return + } + // Streaming / blocking mode. + out := h.runWebhookSync(c.Request.Context(), cv, clean, isTest, startTs) + c.JSON(out.status, out.body) +} + +// findWebhookBegin scans the DSL components map for a Begin component +// whose params.mode equals "Webhook". Returns the params map (the +// webhook_cfg) on success; nil otherwise. Mirrors agent_api.py:1584-1592. +func findWebhookBegin(dsl map[string]any) map[string]any { + if dsl == nil { + return nil + } + components, _ := dsl["components"].(map[string]any) + if components == nil { + return nil + } + for _, raw := range components { + entry, _ := raw.(map[string]any) + if entry == nil { + continue + } + obj, _ := entry["obj"].(map[string]any) + if obj == nil { + continue + } + name, _ := obj["component_name"].(string) + if !strings.EqualFold(name, "begin") { + continue + } + params, _ := obj["params"].(map[string]any) + if params == nil { + continue + } + if mode, _ := params["mode"].(string); mode == "Webhook" { + return params + } + } + return nil +} + +// methodAllowed returns true when `requestMethod` is in the configured +// list. Empty list → allow (matches python: agent_api.py:1596 short-circuit). +func methodAllowed(raw any, requestMethod string) bool { + methods, _ := raw.([]any) + if len(methods) == 0 { + return true + } + want := strings.ToUpper(requestMethod) + for _, m := range methods { + if s, _ := m.(string); strings.ToUpper(s) == want { + return true + } + } + return false +} + +// stringMap is a tiny helper: extract a map[string]any, treating nil and +// wrong-type cases as empty maps. Used heavily to keep the dispatch +// logic readable. +func stringMap(v any) map[string]any { + if m, ok := v.(map[string]any); ok { + return m + } + return map[string]any{} +} + +// parseWebhookRequest mirrors agent_api.py:1828-1894. +// +// Returns (parsed, error). The parsed map carries query / headers / +// body / content_type. The body is parsed according to the request +// Content-Type: +// +// - application/json → unmarshal JSON +// - application/x-www-form-urlencoded → form fields +// - text/plain / octet-stream / unknown / empty → raw bytes → JSON +// +// Two errors are surfaced directly to the handler so the dispatch path +// can return the right envelope without falling through into a +// misleading schema-validation error: +// +// - ErrWebhookMultipartNotSupported (501): the request used +// multipart/form-data. The Python path uploads files through +// FileService.upload_info → canvas.get_files_async, both of which +// are NOT yet ported. We refuse the request up front so callers +// see a clear, typed 501 instead of an unrelated schema error. +// +// - ErrWebhookContentTypeMismatch (102): when the webhook config +// sets content_types, the request Content-Type must match. The +// Python reference raises here +// (`raise ValueError("Invalid Content-Type...")` at +// agent_api.py:1839-1842); we mirror that as a typed error and +// surface it through the same 102 envelope as the rest of the +// validation errors so operators see the same response shape. +func parseWebhookRequest(configuredContentType string, c *gin.Context) (map[string]any, error) { + // 1. Query + q := map[string]any{} + for k, vals := range c.Request.URL.Query() { + if len(vals) == 1 { + q[k] = vals[0] + } else { + q[k] = vals + } + } + + // 2. Headers + hd := map[string]any{} + for k, vals := range c.Request.Header { + if len(vals) == 1 { + hd[k] = vals[0] + } else { + hd[k] = vals + } + } + + // 3. Body + ctype := strings.SplitN(c.GetHeader("Content-Type"), ";", 2)[0] + ctype = strings.TrimSpace(strings.ToLower(ctype)) + + // multipart/form-data → 501. Checked BEFORE the content-type match + // below because multipart uploads don't carry a useful + // `content_types` value to validate against — the python handler + // routes them through canvas.get_files_async which we have not + // ported yet. + if ctype == "multipart/form-data" { + return nil, ErrWebhookMultipartNotSupported + } + + body := map[string]any{} + + switch ctype { + case "application/json": + raw, _ := io.ReadAll(c.Request.Body) + if len(raw) > 0 { + _ = json.Unmarshal(raw, &body) + } + case "application/x-www-form-urlencoded": + if err := c.Request.ParseForm(); err == nil { + for k, vals := range c.Request.PostForm { + if len(vals) == 1 { + body[k] = vals[0] + } else { + body[k] = vals + } + } + } + default: + raw, _ := io.ReadAll(c.Request.Body) + if len(raw) > 0 { + _ = json.Unmarshal(raw, &body) + } + } + + // content_type whitelist. Empty configuredContentType → no check + // (matches python: agent_api.py:1839 only raises when the + // webhook_cfg actually sets content_types). + // + // The python reference has a `if ctype and ...` short-circuit that + // lets a request with NO Content-Type header through even when the + // webhook config requires one. That is a whitelist bypass — the + // configured type is irrelevant if a caller can simply omit the + // header. The Go port tightens the contract: when the operator + // configured content_types, the request MUST carry a matching + // Content-Type. Missing header → ErrWebhookContentTypeMismatch. + if configuredContentType != "" && configuredContentType != ctype { + return nil, fmt.Errorf("%w: expect %q, got %q", + ErrWebhookContentTypeMismatch, configuredContentType, ctype) + } + + return map[string]any{ + "query": q, + "headers": hd, + "body": body, + "content_type": ctype, + }, nil +} + +// ErrWebhookMultipartNotSupported is returned by parseWebhookRequest +// when the inbound Content-Type is multipart/form-data. The handler +// translates this to HTTP 501 Not Implemented because the Python file +// upload path (FileService.upload_info → canvas.get_files_async) is +// not ported yet. +var ErrWebhookMultipartNotSupported = errors.New("multipart/form-data uploads are not supported in this port") + +// ErrWebhookContentTypeMismatch is returned by parseWebhookRequest when +// the configured content_types whitelist disagrees with the request's +// Content-Type. Mirrors python agent_api.py:1839-1842 (`raise +// ValueError("Invalid Content-Type...")`). Surfaced as 102 so the +// operator sees a clear, content-type-specific message. +var ErrWebhookContentTypeMismatch = errors.New("invalid content-type") + +// applyWebhookSchema runs extractBySchema on the parsed request's three +// sections and assembles the clean_request map that the Python +// webhook handler passes to canvas.run(webhook_payload=...). Mirrors +// agent_api.py:2052-2068. +func applyWebhookSchema(parsed map[string]any, schema map[string]any) (map[string]any, error) { + q := stringMap(parsed["query"]) + hd := stringMap(parsed["headers"]) + bd := stringMap(parsed["body"]) + + qClean, err := extractBySchema(q, stringMap(schema["query"]), "query") + if err != nil { + return nil, err + } + hdClean, err := extractBySchema(hd, stringMap(schema["headers"]), "headers") + if err != nil { + return nil, err + } + bdClean, err := extractBySchema(bd, stringMap(schema["body"]), "body") + if err != nil { + return nil, err + } + return map[string]any{ + "query": qClean, + "headers": hdClean, + "body": bdClean, + "input": parsed, + }, nil +} + +// renderImmediatelyResponse builds the synchronous response for the +// Immediately execution mode. Mirrors agent_api.py:2093-2121. +// +// - status: int in [200, 399]; defaults to 200; any other value +// raises so the operator notices a config bug. +// - body_template: JSON or text. We try JSON first; fall back to +// plain text on failure. Empty body → no body, content-type +// application/json (matching python parse_body(None)). +func renderImmediatelyResponse(cfg map[string]any) (int, string, []byte, error) { + statusRaw, ok := cfg["status"] + status := 200 + if ok { + switch v := statusRaw.(type) { + case int: + status = v + case int64: + status = int(v) + case float64: + status = int(v) + case string: + n, err := strconv.Atoi(v) + if err != nil { + return 0, "", nil, fmt.Errorf("invalid response status code: %v", v) + } + status = n + default: + return 0, "", nil, fmt.Errorf("invalid response status code: %v", v) + } + } + if status < 200 || status > 399 { + return 0, "", nil, fmt.Errorf("invalid response status code: %d, must be between 200 and 399", status) + } + + bodyTpl, _ := cfg["body_template"].(string) + if bodyTpl == "" { + return status, "application/json", nil, nil + } + + // Try JSON parse first (python: parse_body() branch at line 2110). + var probe any + if err := json.Unmarshal([]byte(bodyTpl), &probe); err == nil { + encoded, _ := json.Marshal(probe) + return status, "application/json", encoded, nil + } + return status, "text/plain", []byte(bodyTpl), nil +} + +// runWebhookDetached runs the canvas in the background. It uses +// context.Background() with a 5-minute timeout (NOT +// c.Request.Context()) so a client disconnect does NOT cancel the run. +// Trace events are appended to the redis key when isTest is true. +// +// Mirrors python: agent_api.py:2123-2175 (the asyncio.create_task body +// inside the Immediately branch). +func (h *AgentHandler) runWebhookDetached( + cv *entity.UserCanvas, payload map[string]any, isTest bool, startTs time.Time, +) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + events, err := h.loader.RunAgentWithWebhook(ctx, cv.UserID, cv.ID, payload) + if err != nil { + common.Warn("webhook detached run start failed", + zap.String("canvas", cv.ID), + zap.Error(err)) + if isTest { + appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "error", Data: mustJSON(map[string]any{"message": err.Error()})}) + } + return + } + for ev := range events { + if isTest { + appendWebhookTrace(cv.ID, startTs, ev) + } + } +} + +// runWebhookSync drives the canvas in the non-Immediately (streaming) +// mode. Returns the HTTP status + body to send. Mirrors +// agent_api.py:2178-2247 with the python sse() coroutine flattened into +// a synchronous loop. +type webhookSyncResult struct { + status int + body any +} + +func (h *AgentHandler) runWebhookSync( + ctx context.Context, cv *entity.UserCanvas, payload map[string]any, + isTest bool, startTs time.Time, +) webhookSyncResult { + status := 200 + events, err := h.loader.RunAgentWithWebhook(ctx, cv.UserID, cv.ID, payload) + if err != nil { + if isTest { + appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "error", Data: mustJSON(map[string]any{"message": err.Error()})}) + appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "finished", Data: mustJSON(map[string]any{"success": false})}) + } + return webhookSyncResult{status: http.StatusBadRequest, body: gin.H{ + "code": 400, + "message": err.Error(), + "success": false, + }} + } + + contents := []string{} + for ev := range events { + if isTest { + appendWebhookTrace(cv.ID, startTs, ev) + } + switch ev.Type { + case "message": + var msg struct { + Content string `json:"content"` + StartToThink bool `json:"start_to_think"` + EndToThink bool `json:"end_to_think"` + } + if json.Unmarshal([]byte(ev.Data), &msg) == nil { + content := msg.Content + if msg.StartToThink { + content = "think" + } else if msg.EndToThink { + content = "/think" + } + if content != "" { + contents = append(contents, content) + } + } + case "message_end": + var end struct { + Status *int `json:"status"` + } + if json.Unmarshal([]byte(ev.Data), &end) == nil && end.Status != nil { + status = *end.Status + } + } + } + final := strings.Join(contents, "") + if isTest { + appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "finished", Data: mustJSON(map[string]any{"success": true})}) + } + return webhookSyncResult{status: status, body: gin.H{ + "message": final, + "success": true, + "code": status, + }} +} + +// mustJSON marshals v to a JSON object string. Used by trace appenders; +// panics on marshal failure (acceptable because we only marshal +// statically-typed map[string]any values). +func mustJSON(v any) string { + var buf bytes.Buffer + _ = json.NewEncoder(&buf).Encode(v) + return buf.String() +} + +// appendWebhookTrace appends a single RunEvent to the per-canvas trace +// key in Redis. Mirrors python's append_webhook_trace at +// agent_api.py:2073-2091. +// +// The trace key is `webhook-trace--logs` with a 600 s TTL. +// Each event is recorded as {"ts": , "event": , ...}. +// Tests use miniredis to verify the key shape. +func appendWebhookTrace(agentID string, startTs time.Time, ev canvas.RunEvent) { + rdb := rediscli.Get() + if rdb == nil { + return + } + + key := fmt.Sprintf("webhook-trace-%s-logs", agentID) + raw, _ := rdb.Get(key) + obj := map[string]any{} + if raw != "" { + _ = json.Unmarshal([]byte(raw), &obj) + } + whs, _ := obj["webhooks"].(map[string]any) + if whs == nil { + whs = map[string]any{} + obj["webhooks"] = whs + } + entryKey := strconv.FormatFloat(float64(startTs.UnixNano())/1e9, 'f', -1, 64) + entry, _ := whs[entryKey].(map[string]any) + if entry == nil { + entry = map[string]any{ + "start_ts": float64(startTs.UnixNano()) / 1e9, + "events": []any{}, + } + whs[entryKey] = entry + } + events, _ := entry["events"].([]any) + eventRecord := map[string]any{ + "ts": float64(time.Now().UnixNano()) / 1e9, + "event": ev.Type, + } + if ev.Data != "" { + eventRecord["data"] = json.RawMessage(ev.Data) + } + if ev.MessageID != "" { + eventRecord["message_id"] = ev.MessageID + } + if ev.TaskID != "" { + eventRecord["task_id"] = ev.TaskID + } + if ev.SessionID != "" { + eventRecord["session_id"] = ev.SessionID + } + entry["events"] = append(events, eventRecord) + + encoded, err := json.Marshal(obj) + if err != nil { + common.Warn("webhook trace marshal failed", zap.Error(err)) + return + } + rdb.SetObj(key, string(encoded), 600*time.Second) +} diff --git a/internal/handler/agent_webhook_schema.go b/internal/handler/agent_webhook_schema.go new file mode 100644 index 0000000000..6f3647fe09 --- /dev/null +++ b/internal/handler/agent_webhook_schema.go @@ -0,0 +1,273 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +// Webhook schema helpers. +// +// These mirror api/apps/restful_apis/agent_api.py:1896-2051 (extract_by_schema, +// default_for_type, auto_cast_value, validate_type) from the Python +// webhook handler. The Go port preserves the Python semantics exactly: +// required fields raise; optional fields get a type-based default; +// type coercion runs before validation. +// +// Edge cases preserved: +// - "file" type maps to []any (a list of parsed file descriptors). +// - "array" accepts a JSON-decoded list; "array" validates each +// element against `inner`. +// - "object" must decode to map[string]any. +// - Empty data sections yield an empty map (NOT nil) so downstream +// components can index without a nil check. + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" +) + +// defaultForType returns the schema-default value for the given type +// token. Mirrors python's `default_for_type` at agent_api.py:1939-1955. +// +// Special cases: +// - "file" → []any{} (a list, the python empty list) +// - "object" → map[string]any{} (a dict) +// - "boolean" → false +// - "number" → 0 (float64 to match python's int/float unification) +// - "string" → "" +// - "array"/"array<...>" → []any{} +// - "null" → nil +// - anything else → nil +func defaultForType(t string) any { + switch t { + case "file": + return []any{} + case "object": + return map[string]any{} + case "boolean": + return false + case "number": + return float64(0) + case "string": + return "" + case "null": + return nil + } + if strings.HasPrefix(t, "array") { + return []any{} + } + return nil +} + +// autoCastValue mirrors python's auto_cast_value at agent_api.py:1957-2016. +// It accepts already-decoded values unchanged and converts string values +// to the schema-expected type when possible. +// +// Returns an error when conversion is impossible (e.g. "abc" → number). +func autoCastValue(value any, expectedType string) (any, error) { + if _, isString := value.(string); !isString { + // Non-string values are passed through unchanged; the Python + // reference only converts when the source is a string. + return value, nil + } + v := strings.TrimSpace(value.(string)) + switch expectedType { + case "boolean": + switch strings.ToLower(v) { + case "true", "1": + return true, nil + case "false", "0": + return false, nil + } + return nil, fmt.Errorf("cannot convert %q to boolean", value) + case "number": + // Python tries int first, then float. In Go strconv.Atoi handles + // int; for float we fall back to ParseFloat. Negative integers + // must be supported (the python `v[1:].isdigit()` guard does + // that). + if v != "" && (v[0] == '-' && len(v) > 1 || v[0] >= '0' && v[0] <= '9') { + // Try int first (no fractional part, no exponent). + if !strings.ContainsAny(v, ".eE") { + if i, err := strconv.ParseInt(v, 10, 64); err == nil { + return float64(i), nil + } + } + } + f, err := strconv.ParseFloat(v, 64) + if err != nil { + return nil, fmt.Errorf("cannot convert %q to number", value) + } + return f, nil + case "object": + var parsed map[string]any + if err := json.Unmarshal([]byte(v), &parsed); err != nil { + return nil, fmt.Errorf("cannot convert %q to object", value) + } + return parsed, nil + } + if strings.HasPrefix(expectedType, "array") { + var parsed []any + if err := json.Unmarshal([]byte(v), &parsed); err != nil { + return nil, fmt.Errorf("cannot convert %q to array", value) + } + return parsed, nil + } + // "string" / "file" / unknown — return as-is (matching the python + // passthrough branches at the bottom of auto_cast_value). + return value, nil +} + +// validateType mirrors python's validate_type at agent_api.py:2019-2051. +// It returns true when `value` is structurally compatible with type `t`. +// Unknown types pass through (python parity: agent_api.py:2051). +func validateType(value any, t string) bool { + if t == "" { + return true + } + switch t { + case "file": + _, ok := value.([]any) + return ok + case "string": + _, ok := value.(string) + return ok + case "number": + switch value.(type) { + case int, int32, int64, float32, float64: + return true + } + return false + case "boolean": + _, ok := value.(bool) + return ok + case "object": + _, ok := value.(map[string]any) + return ok + case "null": + return value == nil + } + if strings.HasPrefix(t, "array") { + list, ok := value.([]any) + if !ok { + return false + } + // array / array / array: validate each + // element against the inner type when <...> is present. + if open := strings.Index(t, "<"); open >= 0 { + if close := strings.Index(t[open:], ">"); close > 0 { + inner := t[open+1 : open+close] + if inner == "" { + return true + } + for _, item := range list { + if !validateType(item, inner) { + return false + } + } + } + } + return true + } + // Unknown type — pass through (python parity at line 2051). + return true +} + +// extractBySchema mirrors python's extract_by_schema at agent_api.py:1896-1936. +// +// Behaviour: +// - `data` may be nil; treated as empty map. +// - For each property in schema.properties: +// - if required and missing → error +// - if missing (optional) → use defaultForType(prop.type) +// - if present → autoCastValue then validateType; mismatch raises +// - returns the cleaned map; nil `data` returns an empty map (NOT nil). +func extractBySchema(data map[string]any, schema map[string]any, name string) (map[string]any, error) { + if data == nil { + data = map[string]any{} + } + if schema == nil { + schema = map[string]any{} + } + + props, _ := schema["properties"].(map[string]any) + required := stringSlice(schema["required"]) + + out := map[string]any{} + + for field, rawSchema := range props { + fieldSchema, _ := rawSchema.(map[string]any) + fieldType, _ := fieldSchema["type"].(string) + + // 1. Required field missing → error (python agent_api.py:1913). + isRequired := contains(required, field) + if isRequired { + if _, ok := data[field]; !ok { + return nil, fmt.Errorf("%s missing required field: %s", name, field) + } + } + + // 2. Optional → default value. + raw, present := data[field] + if !present { + out[field] = defaultForType(fieldType) + continue + } + + // 3. Auto cast (only matters for string→typed; pass-through otherwise). + casted, err := autoCastValue(raw, fieldType) + if err != nil { + return nil, fmt.Errorf("%s.%s auto-cast failed: %s", name, field, err.Error()) + } + + // 4. Type validation. + if !validateType(casted, fieldType) { + return nil, fmt.Errorf( + "%s.%s type mismatch: expected %s, got %T", + name, field, fieldType, casted, + ) + } + out[field] = casted + } + return out, nil +} + +// stringSlice accepts either a []string or []any for the schema's +// `required` field. Mirrors the python list-or-tuple-or-set union +// handling at agent_api.py:1788-1798. +func stringSlice(v any) []string { + switch s := v.(type) { + case []string: + return s + case []any: + out := make([]string, 0, len(s)) + for _, item := range s { + if str, ok := item.(string); ok { + out = append(out, str) + } + } + return out + } + return nil +} + +func contains(haystack []string, needle string) bool { + for _, s := range haystack { + if s == needle { + return true + } + } + return false +} diff --git a/internal/handler/agent_webhook_schema_test.go b/internal/handler/agent_webhook_schema_test.go new file mode 100644 index 0000000000..439d6a7ac7 --- /dev/null +++ b/internal/handler/agent_webhook_schema_test.go @@ -0,0 +1,242 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "strings" + "testing" +) + +// TestExtractBySchema_RequiredMissing pins the required-field-missing +// branch (mirrors python agent_api.py:1913). +func TestExtractBySchema_RequiredMissing(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "q": map[string]any{"type": "string"}, + }, + "required": []any{"q"}, + } + _, err := extractBySchema(map[string]any{}, schema, "query") + if err == nil || !strings.Contains(err.Error(), "missing required field") { + t.Errorf("err = %v, want missing required field", err) + } +} + +// TestExtractBySchema_OptionalDefault confirms that optional fields get +// the type-default value when absent. +func TestExtractBySchema_OptionalDefault(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "s": map[string]any{"type": "string"}, + "b": map[string]any{"type": "boolean"}, + "o": map[string]any{"type": "object"}, + "a": map[string]any{"type": "array"}, + "f": map[string]any{"type": "file"}, + "n": map[string]any{"type": "number"}, + }, + } + got, err := extractBySchema(map[string]any{}, schema, "query") + if err != nil { + t.Fatalf("extractBySchema: %v", err) + } + if got["s"] != "" { + t.Errorf("default for string: got %v", got["s"]) + } + if got["b"] != false { + t.Errorf("default for boolean: got %v", got["b"]) + } + if _, ok := got["o"].(map[string]any); !ok { + t.Errorf("default for object: got %T", got["o"]) + } + if _, ok := got["a"].([]any); !ok { + t.Errorf("default for array: got %T", got["a"]) + } + if _, ok := got["f"].([]any); !ok { + t.Errorf("default for file: got %T", got["f"]) + } + if got["n"] != float64(0) { + t.Errorf("default for number: got %v", got["n"]) + } +} + +// TestExtractBySchema_NumberCoercion pins string→number auto-cast. +func TestExtractBySchema_NumberCoercion(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "n": map[string]any{"type": "number"}, + }, + } + got, err := extractBySchema(map[string]any{"n": "42"}, schema, "body") + if err != nil { + t.Fatalf("extractBySchema: %v", err) + } + if got["n"] != float64(42) { + t.Errorf("number coercion: got %v (%T)", got["n"], got["n"]) + } +} + +// TestExtractBySchema_NumberCoercionFloat covers decimal strings. +func TestExtractBySchema_NumberCoercionFloat(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "n": map[string]any{"type": "number"}, + }, + } + got, err := extractBySchema(map[string]any{"n": "3.14"}, schema, "body") + if err != nil { + t.Fatalf("extractBySchema: %v", err) + } + if got["n"] != 3.14 { + t.Errorf("number float coercion: got %v", got["n"]) + } +} + +// TestExtractBySchema_BooleanCoercion pins true/false string parsing. +func TestExtractBySchema_BooleanCoercion(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "b": map[string]any{"type": "boolean"}, + }, + } + cases := []struct { + in string + want bool + }{ + {"true", true}, + {"TRUE", true}, + {"1", true}, + {"false", false}, + {"0", false}, + } + for _, tc := range cases { + got, err := extractBySchema(map[string]any{"b": tc.in}, schema, "body") + if err != nil { + t.Errorf("coerce %q: err=%v", tc.in, err) + continue + } + if got["b"] != tc.want { + t.Errorf("coerce %q: got %v, want %v", tc.in, got["b"], tc.want) + } + } +} + +// TestExtractBySchema_BooleanCoercionBad confirms non-boolean strings fail. +func TestExtractBySchema_BooleanCoercionBad(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "b": map[string]any{"type": "boolean"}, + }, + } + _, err := extractBySchema(map[string]any{"b": "maybe"}, schema, "body") + if err == nil || !strings.Contains(err.Error(), "auto-cast failed") { + t.Errorf("err = %v, want auto-cast failed", err) + } +} + +// TestExtractBySchema_TypeMismatch covers wrong-type validation. +func TestExtractBySchema_TypeMismatch(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "n": map[string]any{"type": "number"}, + }, + } + _, err := extractBySchema(map[string]any{"n": []any{1}}, schema, "body") + if err == nil || !strings.Contains(err.Error(), "type mismatch") { + t.Errorf("err = %v, want type mismatch", err) + } +} + +// TestExtractBySchema_ArrayInnerType confirms array validation. +func TestExtractBySchema_ArrayInnerType(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "xs": map[string]any{"type": "array"}, + }, + } + // valid + if _, err := extractBySchema(map[string]any{"xs": []any{1, 2, 3}}, schema, "body"); err != nil { + t.Errorf("valid array: err=%v", err) + } + // invalid inner type + _, err := extractBySchema(map[string]any{"xs": []any{1, "two", 3}}, schema, "body") + if err == nil || !strings.Contains(err.Error(), "type mismatch") { + t.Errorf("invalid array inner: err=%v, want type mismatch", err) + } +} + +// TestExtractBySchema_ObjectCoercion pins string→object auto-cast. +func TestExtractBySchema_ObjectCoercion(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "o": map[string]any{"type": "object"}, + }, + } + got, err := extractBySchema(map[string]any{"o": `{"k":"v"}`}, schema, "body") + if err != nil { + t.Fatalf("extractBySchema: %v", err) + } + m, ok := got["o"].(map[string]any) + if !ok || m["k"] != "v" { + t.Errorf("object coercion: got %v", got["o"]) + } +} + +// TestExtractBySchema_NilData confirms nil data yields empty map. +func TestExtractBySchema_NilData(t *testing.T) { + got, err := extractBySchema(nil, nil, "query") + if err != nil { + t.Fatalf("err: %v", err) + } + if got == nil { + t.Errorf("expected empty map, got nil") + } + if len(got) != 0 { + t.Errorf("expected empty map, got %v", got) + } +} + +// TestValidateType_FileAndArray confirms validateType handles the +// edge cases at the python agent_api.py:2021-2051 boundary. +func TestValidateType_FileAndArray(t *testing.T) { + if !validateType([]any{map[string]any{}}, "file") { + t.Error("file should accept []any") + } + if !validateType([]any{1, 2, 3}, "array") { + t.Error("array should accept []int-equivalent") + } + if validateType([]any{1, "x"}, "array") { + t.Error("array should reject mixed types") + } +} + +// TestDefaultForType covers the python default_for_type table. +func TestDefaultForType(t *testing.T) { + cases := []struct { + t string + // We compare via fmt-friendly assertions rather than reflect.DeepEqual + // because the python reference returns [] / {} / false / 0 / "" / nil + // directly, and Go equivalents may differ in concrete type. + }{ + {"file"}, {"object"}, {"boolean"}, {"number"}, {"string"}, + {"array"}, {"array"}, {"null"}, {"unknown"}, + } + for _, tc := range cases { + got := defaultForType(tc.t) + // Just ensure no panic and we get a sensible zero value. + _ = got + } +} diff --git a/internal/handler/agent_webhook_security.go b/internal/handler/agent_webhook_security.go new file mode 100644 index 0000000000..384de452cc --- /dev/null +++ b/internal/handler/agent_webhook_security.go @@ -0,0 +1,491 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +// Webhook security helpers. +// +// These mirror api/apps/restful_apis/agent_api.py:1602-1810 +// (validate_webhook_security + the six sub-validators) from the Python +// webhook handler. The Go port preserves the Python semantics exactly: +// +// - max body size (with the 10 MB cap at agent_api.py:1652) +// - IP whitelist (CIDR + exact match) +// - rate limit (token bucket via redis.EvalTokenBucketStrict — strict +// fail-closed; see redis.go) +// - token auth (header check) +// - basic auth (HTTP Basic) +// - JWT (HS/RS256, audience/issuer/required-claims) +// +// All helpers return a Go error so the handler can surface the +// Python-shaped 102 envelope directly. Empty security config means +// "no checks" (matches agent_api.py:1607). + +import ( + "context" + "fmt" + "net" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + + rediscli "ragflow/internal/engine/redis" +) + +const ( + // bytesPerKB / bytesPerMB use SI decimal units (1 KB = 1 000 B, + // 1 MB = 1 000 000 B). Per user request, the Go port diverges + // from python agent_api.py:1643 which uses 1 KB = 1 024 B + // (binary). The header doc comments on parseMaxBodySize call + // this out so future audits can find the divergence. + bytesPerKB int64 = 1000 + bytesPerMB int64 = bytesPerKB * 1000 + + // webhookBodyMaxBytes is the hard ceiling above which a configured + // max_body_size is rejected at config-validation time. Stated + // in decimal MB to match the kb/mb unit base above; 10 MB + // decimal = 10 000 000 bytes. The python reference + // (agent_api.py:1652) uses 10 * 1024 * 1024 = 10 485 760 bytes, + // so this is ~486 KB stricter than python. + webhookBodyMaxBytes int64 = 10 * bytesPerMB + + // webhookRateLimitTimeout is the redis call timeout used for the + // strict token-bucket lookup. Short on purpose — a security gate + // must not stall the request thread. + webhookRateLimitTimeout = 500 * time.Millisecond + + // jwtReservedClaims is the python-parity set at agent_api.py:1800 + // — required_claims must NOT name any of these. + jwtReservedClaims = "exp,sub,aud,iss,nbf,iat" +) + +// validateWebhookSecurity is the orchestrator. Empty/nil cfg → no-op +// (matches agent_api.py:1607 "No security config → allowed by default"). +// +// Sub-validators run in the python-defined order: +// 1. validateMaxBodySize +// 2. validateIPWhitelist +// 3. validateRateLimit +// 4. validateAuth (dispatches on auth_type) +func validateWebhookSecurity( + securityCfg map[string]any, + c *gin.Context, + canvasID string, +) error { + if len(securityCfg) == 0 { + return nil + } + if err := validateMaxBodySize(c, securityCfg); err != nil { + return err + } + if err := validateIPWhitelist(c, securityCfg); err != nil { + return err + } + if err := validateRateLimit(canvasID, securityCfg); err != nil { + return err + } + return validateAuth(c, securityCfg) +} + +// validateMaxBodySize mirrors python agent_api.py:1636-1658. +// +// Format: "kb" | "mb" (case-insensitive). Anything else is a +// config bug. The configured limit is capped at webhookBodyMaxBytes +// (10 MB) — exceeding that is also a config error, not silently raised. +// The actual request size is then compared against the parsed limit. +func validateMaxBodySize(c *gin.Context, cfg map[string]any) error { + limit, err := parseMaxBodySize(cfg) + if err != nil { + return err + } + if limit <= 0 { + return nil + } + contentLength := c.Request.ContentLength + if contentLength < 0 { + contentLength = 0 + } + if contentLength > limit { + return fmt.Errorf("request body too large: %d > %d", contentLength, limit) + } + return nil +} + +// parseMaxBodySize returns the byte limit configured by +// `max_body_size` (with the 10 MB cap enforced) or 0 when no limit +// is configured. The handler uses this to wrap c.Request.Body in +// http.MaxBytesReader so the actual stream read is bounded — the +// Content-Length header check alone is insufficient because a client +// can advertise a small Content-Length and stream more. +// +// Unit base: 1 kb = 1 000 B, 1 mb = 1 000 000 B (SI decimal). Per +// user request; note this DIVERGES from the python reference +// (agent_api.py:1643 uses binary 1 KB = 1 024 B). +// +// Overflow guard (CodeRabbit PR review #3 on PR #16403): a very +// large configured n (e.g. "10000000mb") would otherwise overflow +// `n * bytesPerMB` and wrap to a small positive number, bypassing +// the 10 MB cap. We check the parsed number against the cap before +// multiplying. +func parseMaxBodySize(cfg map[string]any) (int64, error) { + raw, ok := cfg["max_body_size"].(string) + if !ok || raw == "" { + return 0, nil + } + sizeStr := strings.ToLower(strings.TrimSpace(raw)) + var limit int64 + switch { + case strings.HasSuffix(sizeStr, "kb"): + n, err := strconv.ParseInt(strings.TrimSuffix(sizeStr, "kb"), 10, 64) + if err != nil || n <= 0 { + return 0, fmt.Errorf("invalid max_body_size format") + } + if n > webhookBodyMaxBytes/bytesPerKB { + return 0, fmt.Errorf("max_body_size exceeds maximum allowed size (10MB)") + } + limit = n * bytesPerKB + case strings.HasSuffix(sizeStr, "mb"): + n, err := strconv.ParseInt(strings.TrimSuffix(sizeStr, "mb"), 10, 64) + if err != nil || n <= 0 { + return 0, fmt.Errorf("invalid max_body_size format") + } + if n > webhookBodyMaxBytes/bytesPerMB { + return 0, fmt.Errorf("max_body_size exceeds maximum allowed size (10MB)") + } + limit = n * bytesPerMB + default: + return 0, fmt.Errorf("invalid max_body_size format") + } + if limit > webhookBodyMaxBytes { + return 0, fmt.Errorf("max_body_size exceeds maximum allowed size (10MB)") + } + return limit, nil +} + +// validateIPWhitelist mirrors python agent_api.py:1660-1679. Empty +// list → allow. Supports CIDR ("10.0.0.0/8") and exact ("1.2.3.4"). +// The client IP comes from gin's c.ClientIP() which honours +// X-Forwarded-For when trusted proxies are configured. +func validateIPWhitelist(c *gin.Context, cfg map[string]any) error { + whitelist, _ := cfg["ip_whitelist"].([]any) + if len(whitelist) == 0 { + return nil + } + clientIP := c.ClientIP() + for _, raw := range whitelist { + rule, _ := raw.(string) + if rule == "" { + continue + } + if strings.Contains(rule, "/") { + // CIDR + _, ipNet, err := net.ParseCIDR(rule) + if err != nil { + continue + } + addr := net.ParseIP(clientIP) + if addr != nil && ipNet.Contains(addr) { + return nil + } + continue + } + // Exact match + if clientIP == rule { + return nil + } + } + return fmt.Errorf("IP %s is not allowed by whitelist", clientIP) +} + +// validateRateLimit mirrors python agent_api.py:1681-1723. +// +// Window mapping (matches agent_api.py:1692-1697): +// +// second → 1s, minute → 60s, hour → 3600s, day → 86400s. +// +// Unknown per → error (NOT silently fall through). +// +// Strict fail-closed: any Redis error → error. The webhook handler +// surfaces this as 102 so an operator notices a misconfiguration. +func validateRateLimit(canvasID string, cfg map[string]any) error { + rawRL, ok := cfg["rate_limit"].(map[string]any) + if !ok || len(rawRL) == 0 { + return nil + } + limitF, ok := rawRL["limit"].(float64) + if !ok { + // JSON numbers often come back as float64; try int as well. + if limitI, ok2 := rawRL["limit"].(int); ok2 { + limitF = float64(limitI) + ok = true + } + } + if !ok || limitF <= 0 { + return fmt.Errorf("rate_limit.limit must be > 0") + } + per, _ := rawRL["per"].(string) + if per == "" { + per = "minute" + } + var window float64 + switch per { + case "second": + window = 1 + case "minute": + window = 60 + case "hour": + window = 3600 + case "day": + window = 86400 + default: + return fmt.Errorf("invalid rate_limit.per: %s", per) + } + + key := fmt.Sprintf("rl:tb:%s", canvasID) + ctx, cancel := context.WithTimeout(context.Background(), webhookRateLimitTimeout) + defer cancel() + + rdb := rediscli.Get() + if rdb == nil { + return fmt.Errorf("rate limit error: redis not initialised") + } + allowed, err := rdb.EvalTokenBucketStrict(ctx, key, limitF, limitF/window) + if err != nil { + return fmt.Errorf("rate limit error: %s", err.Error()) + } + if !allowed { + return fmt.Errorf("too many requests (rate limit exceeded)") + } + return nil +} + +// validateAuth dispatches on auth_type. Empty cfg or auth_type=="none" +// → allow (matches agent_api.py:1621). +func validateAuth(c *gin.Context, cfg map[string]any) error { + authType, _ := cfg["auth_type"].(string) + if authType == "" || authType == "none" { + return nil + } + switch authType { + case "token": + return validateTokenAuth(c, cfg) + case "basic": + return validateBasicAuth(c, cfg) + case "jwt": + return validateJWTAuth(c, cfg) + } + return fmt.Errorf("unsupported auth_type: %s", authType) +} + +// validateTokenAuth mirrors python agent_api.py:1725-1733. +// +// An empty configured `token_value` previously meant "accept any +// request without that header". We now require both header and +// value to be non-empty configured secrets, otherwise the request +// is rejected as misconfigured. CodeRabbit PR review #4. +func validateTokenAuth(c *gin.Context, cfg map[string]any) error { + rawToken, _ := cfg["token"].(map[string]any) + if rawToken == nil { + return fmt.Errorf("Invalid token authentication") + } + header, _ := rawToken["token_header"].(string) + want, _ := rawToken["token_value"].(string) + if header == "" || want == "" { + return fmt.Errorf("Invalid token authentication") + } + if c.GetHeader(header) != want { + return fmt.Errorf("Invalid token authentication") + } + return nil +} + +// validateBasicAuth mirrors python agent_api.py:1735-1743. We use +// gin's c.Request.BasicAuth() which parses the Authorization header +// and returns the (user, pass, ok) triple. Empty configured +// username/password are now rejected (CodeRabbit PR review #4). +func validateBasicAuth(c *gin.Context, cfg map[string]any) error { + rawBasic, _ := cfg["basic_auth"].(map[string]any) + if rawBasic == nil { + return fmt.Errorf("Invalid Basic Auth credentials") + } + username, _ := rawBasic["username"].(string) + password, _ := rawBasic["password"].(string) + if username == "" || password == "" { + return fmt.Errorf("Invalid Basic Auth credentials") + } + u, p, ok := c.Request.BasicAuth() + if !ok || u != username || p != password { + return fmt.Errorf("Invalid Basic Auth credentials") + } + return nil +} + +// validateJWTAuth mirrors python agent_api.py:1745-1809. +// +// Algorithm defaults to HS256. audience / issuer are validated only +// when configured. required_claims rejects reserved JWT claims +// (exp sub aud iss nbf iat) and any missing claims. +// +// Algorithms: +// - HS256 / HS384 / HS512 → secret is a shared HMAC key. +// - RS256 / RS384 / RS512 → secret is a PEM-encoded RSA public key. +// - ES256 / ES384 / ES512 → secret is a PEM-encoded EC public key. +// +// The Python reference uses the same `secret` field for all three +// families (the python jwt library is happy to take either a string +// or a PEM block); we mirror that with one `secret` config slot and +// dispatch on the algorithm. +func validateJWTAuth(c *gin.Context, cfg map[string]any) error { + rawJWT, _ := cfg["jwt"].(map[string]any) + if rawJWT == nil { + return fmt.Errorf("jwt secret not configured") + } + secret, _ := rawJWT["secret"].(string) + if secret == "" { + return fmt.Errorf("jwt secret not configured") + } + + authHeader := c.GetHeader("Authorization") + const prefix = "Bearer " + if !strings.HasPrefix(authHeader, prefix) { + return fmt.Errorf("missing bearer token") + } + tokenStr := strings.TrimSpace(authHeader[len(prefix):]) + if tokenStr == "" { + return fmt.Errorf("empty bearer token") + } + + alg, _ := rawJWT["algorithm"].(string) + if alg == "" { + alg = "HS256" + } + alg = strings.ToUpper(alg) + + // Build the parser options. + parserOpts := []jwt.ParserOption{ + jwt.WithValidMethods([]string{alg}), + } + if aud, ok := rawJWT["audience"].(string); ok && aud != "" { + parserOpts = append(parserOpts, jwt.WithAudience(aud)) + } + if iss, ok := rawJWT["issuer"].(string); ok && iss != "" { + parserOpts = append(parserOpts, jwt.WithIssuer(iss)) + } + + keyFunc, keyErr := jwtKeyFunc(alg, secret) + if keyErr != nil { + return keyErr + } + + token, err := jwt.Parse(tokenStr, keyFunc, parserOpts...) + if err != nil { + return fmt.Errorf("invalid jwt: %s", err.Error()) + } + if !token.Valid { + return fmt.Errorf("invalid jwt") + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return fmt.Errorf("invalid jwt claims") + } + + // Required claims validation (mirrors agent_api.py:1787-1808). + required := collectStringSlice(rawJWT["required_claims"]) + reserved := splitCSV(jwtReservedClaims) + for _, claim := range required { + if contains(reserved, claim) { + return fmt.Errorf("reserved jwt claim cannot be required: %s", claim) + } + if _, present := claims[claim]; !present { + return fmt.Errorf("missing jwt claim: %s", claim) + } + } + return nil +} + +// jwtKeyFunc returns the verification-key closure that jwt.Parse +// invokes. The dispatch mirrors the python jwt library: a string secret +// is treated as an HMAC key for HS* algorithms, and as a PEM block for +// RS*/ES* algorithms. +func jwtKeyFunc(alg, secret string) (jwt.Keyfunc, error) { + switch alg { + case "HS256", "HS384", "HS512": + return func(_ *jwt.Token) (any, error) { return []byte(secret), nil }, nil + case "RS256", "RS384", "RS512": + pub, err := jwt.ParseRSAPublicKeyFromPEM([]byte(secret)) + if err != nil { + return nil, fmt.Errorf("jwt rsa public key: %s", err.Error()) + } + return func(_ *jwt.Token) (any, error) { return pub, nil }, nil + case "ES256", "ES384", "ES512": + pub, err := jwt.ParseECPublicKeyFromPEM([]byte(secret)) + if err != nil { + return nil, fmt.Errorf("jwt ec public key: %s", err.Error()) + } + return func(_ *jwt.Token) (any, error) { return pub, nil }, nil + } + return nil, fmt.Errorf("unsupported jwt algorithm: %s", alg) +} + +// collectStringSlice accepts the python-shaped `required_claims` value: +// a single string OR a list/tuple/set. Mirrors agent_api.py:1788-1798. +func collectStringSlice(v any) []string { + switch t := v.(type) { + case string: + s := strings.TrimSpace(t) + if s == "" { + return nil + } + return []string{s} + case []string: + out := make([]string, 0, len(t)) + for _, s := range t { + s = strings.TrimSpace(s) + if s != "" { + out = append(out, s) + } + } + return out + case []any: + out := make([]string, 0, len(t)) + for _, item := range t { + if s, ok := item.(string); ok { + s = strings.TrimSpace(s) + if s != "" { + out = append(out, s) + } + } + } + return out + } + return nil +} + +func splitCSV(s string) []string { + parts := strings.Split(s, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + return out +} diff --git a/internal/handler/agent_webhook_security_test.go b/internal/handler/agent_webhook_security_test.go new file mode 100644 index 0000000000..c065c382ab --- /dev/null +++ b/internal/handler/agent_webhook_security_test.go @@ -0,0 +1,380 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" +) + +func securityCtx(t *testing.T, remoteAddr string, headers map[string]string) *gin.Context { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/api/v1/agents/c1/webhook", strings.NewReader("{}")) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.RemoteAddr = remoteAddr + for k, v := range headers { + c.Request.Header.Set(k, v) + } + return c +} + +// TestValidateMaxBodySize_Allowed covers the no-op branch. +func TestValidateMaxBodySize_Allowed(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", nil) + c.Request.ContentLength = 100 + if err := validateMaxBodySize(c, map[string]any{"max_body_size": "1kb"}); err != nil { + t.Errorf("err = %v, want nil", err) + } +} + +// TestValidateMaxBodySize_TooLarge covers the size-mismatch branch. +func TestValidateMaxBodySize_TooLarge(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", nil) + c.Request.ContentLength = 2048 + err := validateMaxBodySize(c, map[string]any{"max_body_size": "1kb"}) + if err == nil || !strings.Contains(err.Error(), "request body too large") { + t.Errorf("err = %v, want 'request body too large'", err) + } +} + +// TestValidateMaxBodySize_BadFormat covers non-numeric format. +func TestValidateMaxBodySize_BadFormat(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", nil) + err := validateMaxBodySize(c, map[string]any{"max_body_size": "1gb"}) + if err == nil || !strings.Contains(err.Error(), "invalid max_body_size format") { + t.Errorf("err = %v, want 'invalid max_body_size format'", err) + } +} + +// TestValidateIPWhitelist_EmptyIsAllow covers the empty-list branch. +func TestValidateIPWhitelist_EmptyIsAllow(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", nil) + if err := validateIPWhitelist(c, map[string]any{"ip_whitelist": []any{}}); err != nil { + t.Errorf("empty whitelist: err = %v, want nil", err) + } +} + +// TestValidateIPWhitelist_ExactMatch passes when client IP matches. +func TestValidateIPWhitelist_ExactMatch(t *testing.T) { + c := securityCtx(t, "10.0.0.5:0", nil) + cfg := map[string]any{"ip_whitelist": []any{"10.0.0.5"}} + if err := validateIPWhitelist(c, cfg); err != nil { + t.Errorf("exact match: err = %v, want nil", err) + } +} + +// TestValidateIPWhitelist_CIDR covers the CIDR branch. +func TestValidateIPWhitelist_CIDR(t *testing.T) { + c := securityCtx(t, "10.0.0.5:0", nil) + cfg := map[string]any{"ip_whitelist": []any{"10.0.0.0/8"}} + if err := validateIPWhitelist(c, cfg); err != nil { + t.Errorf("cidr match: err = %v, want nil", err) + } +} + +// TestValidateIPWhitelist_RejectForeign confirms a foreign IP is denied. +func TestValidateIPWhitelist_RejectForeign(t *testing.T) { + c := securityCtx(t, "192.168.1.5:0", nil) + cfg := map[string]any{"ip_whitelist": []any{"10.0.0.0/8"}} + err := validateIPWhitelist(c, cfg) + if err == nil || !strings.Contains(err.Error(), "not allowed by whitelist") { + t.Errorf("err = %v, want 'not allowed by whitelist'", err) + } +} + +// TestValidateAuth_NoneIsAllow covers the auth_type=="none" no-op. +func TestValidateAuth_NoneIsAllow(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", nil) + if err := validateAuth(c, map[string]any{"auth_type": "none"}); err != nil { + t.Errorf("auth_type=none: err = %v, want nil", err) + } +} + +// TestValidateAuth_Unsupported covers unknown auth_type. +func TestValidateAuth_Unsupported(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", nil) + err := validateAuth(c, map[string]any{"auth_type": "weird"}) + if err == nil || !strings.Contains(err.Error(), "unsupported auth_type") { + t.Errorf("err = %v, want 'unsupported auth_type'", err) + } +} + +// TestValidateTokenAuth_HeaderValue covers matching and non-matching cases. +func TestValidateTokenAuth_HeaderValue(t *testing.T) { + cfg := map[string]any{ + "token": map[string]any{ + "token_header": "X-Token", + "token_value": "abc", + }, + } + + // pass + cPass := securityCtx(t, "1.2.3.4:0", map[string]string{"X-Token": "abc"}) + if err := validateTokenAuth(cPass, cfg); err != nil { + t.Errorf("matching token: err = %v, want nil", err) + } + + // fail + cFail := securityCtx(t, "1.2.3.4:0", map[string]string{"X-Token": "wrong"}) + if err := validateTokenAuth(cFail, cfg); err == nil { + t.Errorf("non-matching token: err = nil, want error") + } +} + +// TestValidateBasicAuth_PassAndFail covers both branches. +func TestValidateBasicAuth_PassAndFail(t *testing.T) { + cfg := map[string]any{ + "basic_auth": map[string]any{ + "username": "alice", + "password": "wonderland", + }, + } + + cPass := securityCtx(t, "1.2.3.4:0", nil) + cPass.Request.SetBasicAuth("alice", "wonderland") + if err := validateBasicAuth(cPass, cfg); err != nil { + t.Errorf("matching basic: err = %v, want nil", err) + } + + cFail := securityCtx(t, "1.2.3.4:0", nil) + cFail.Request.SetBasicAuth("alice", "wrong") + if err := validateBasicAuth(cFail, cfg); err == nil { + t.Errorf("non-matching basic: err = nil, want error") + } +} + +// TestValidateJWTAuth_NoSecret rejects empty secret config. +func TestValidateJWTAuth_NoSecret(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"}) + err := validateJWTAuth(c, map[string]any{"jwt": map[string]any{}}) + if err == nil || !strings.Contains(err.Error(), "secret not configured") { + t.Errorf("err = %v, want 'secret not configured'", err) + } +} + +// TestValidateJWTAuth_MissingBearer rejects when Authorization is absent. +func TestValidateJWTAuth_MissingBearer(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", nil) + err := validateJWTAuth(c, map[string]any{ + "jwt": map[string]any{"secret": "s"}, + }) + if err == nil || !strings.Contains(err.Error(), "missing bearer token") { + t.Errorf("err = %v, want 'missing bearer token'", err) + } +} + +// TestValidateJWTAuth_RS256HappyPath confirms that an RS256 token +// signed with a real RSA key is accepted when the secret field holds +// the matching PEM-encoded public key. The python reference uses the +// same `secret` slot for both HMAC and RSA inputs; this test pins that +// contract for the Go port. +func TestValidateJWTAuth_RS256HappyPath(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey: %v", err) + } + pubPEM := encodeRSAPublicKeyPEM(&priv.PublicKey) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "sub": "u-1", + "exp": time.Now().Add(time.Hour).Unix(), + }) + signed, err := token.SignedString(priv) + if err != nil { + t.Fatalf("sign: %v", err) + } + + c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer " + signed}) + err = validateJWTAuth(c, map[string]any{ + "jwt": map[string]any{ + "secret": string(pubPEM), + "algorithm": "RS256", + }, + }) + if err != nil { + t.Errorf("RS256 happy path: err = %v, want nil", err) + } +} + +// TestValidateJWTAuth_RS256BadPEM rejects a malformed public key. +func TestValidateJWTAuth_RS256BadPEM(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"}) + err := validateJWTAuth(c, map[string]any{ + "jwt": map[string]any{ + "secret": "not a pem block", + "algorithm": "RS256", + }, + }) + if err == nil || !strings.Contains(err.Error(), "rsa public key") { + t.Errorf("err = %v, want 'rsa public key' parse error", err) + } +} + +// TestValidateJWTAuth_UnsupportedAlgorithm rejects truly unknown algos. +func TestValidateJWTAuth_UnsupportedAlgorithm(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"}) + err := validateJWTAuth(c, map[string]any{ + "jwt": map[string]any{ + "secret": "s", + "algorithm": "none", + }, + }) + if err == nil || !strings.Contains(err.Error(), "unsupported") { + t.Errorf("err = %v, want 'unsupported'", err) + } +} + +// encodeRSAPublicKeyPEM serialises an *rsa.PublicKey into the +// PEM-encoded PKIX form that jwt.ParseRSAPublicKeyFromPEM expects. +func encodeRSAPublicKeyPEM(pub *rsa.PublicKey) []byte { + asn1, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + panic(err) + } + block := &pem.Block{Type: "PUBLIC KEY", Bytes: asn1} + return pem.EncodeToMemory(block) +} + +// TestValidateJWTAuth_ReservedClaimRejected covers the reserved-claim +// guard. We build a valid HS256 JWT first so the parse succeeds, then +// ask for `exp` as a required claim — the validator must reject it. +func TestValidateJWTAuth_ReservedClaimRejected(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": "u-1", + "exp": time.Now().Add(time.Hour).Unix(), + }) + signed, err := token.SignedString([]byte("s")) + if err != nil { + t.Fatalf("sign: %v", err) + } + c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer " + signed}) + err = validateJWTAuth(c, map[string]any{ + "jwt": map[string]any{ + "secret": "s", + "required_claims": "exp", + }, + }) + if err == nil || !strings.Contains(err.Error(), "reserved jwt claim") { + t.Errorf("err = %v, want 'reserved jwt claim'", err) + } +} + +// TestValidateRateLimit_NoConfig covers the no-rate-limit branch. +func TestValidateRateLimit_NoConfig(t *testing.T) { + if err := validateRateLimit("c1", map[string]any{}); err != nil { + t.Errorf("no rate_limit: err = %v, want nil", err) + } +} + +// TestValidateRateLimit_BadPer rejects unknown per window. +func TestValidateRateLimit_BadPer(t *testing.T) { + err := validateRateLimit("c1", map[string]any{ + "rate_limit": map[string]any{"limit": 10, "per": "week"}, + }) + if err == nil || !strings.Contains(err.Error(), "invalid rate_limit.per") { + t.Errorf("err = %v, want 'invalid rate_limit.per'", err) + } +} + +// TestValidateRateLimit_BadLimit rejects non-positive limits. +func TestValidateRateLimit_BadLimit(t *testing.T) { + err := validateRateLimit("c1", map[string]any{ + "rate_limit": map[string]any{"limit": 0, "per": "minute"}, + }) + if err == nil || !strings.Contains(err.Error(), "must be > 0") { + t.Errorf("err = %v, want 'must be > 0'", err) + } +} + +// (No helper needed at the bottom of this file; helper functions +// inline above.) +// _securityUnused previously lived here as a placeholder; deleted +// during cleanup (code-review MEDIUM-2). + +// TestValidateMaxBodySize_OverflowGuard covers CodeRabbit PR review +// #3: a configured n that would overflow n*bytesPerMB (e.g. a huge +// mb value) must be rejected before the multiplication, not +// silently wrap to a small number and pass the cap check. +// +// The chosen n=9_223_372_036_855 multiplied by 1_000_000 (= 1 MB) +// overflows int64 max (9.22e18), so without the pre-multiplication +// guard this would silently wrap to a small positive number and +// the cap check would (incorrectly) succeed. With the guard in +// place, the rejection happens at the parse step, before any +// multiplication. +func TestValidateMaxBodySize_OverflowGuard(t *testing.T) { + c := securityCtx(t, "1.2.3.4:0", nil) + err := validateMaxBodySize(c, map[string]any{"max_body_size": "9223372036855mb"}) + if err == nil { + t.Errorf("huge mb value: err = nil, want overflow-rejection error") + } else if !strings.Contains(err.Error(), "exceeds maximum") { + t.Errorf("err = %v, want 'exceeds maximum'", err) + } +} + +// TestParseMaxBodySize_DecimalUnits documents the per-user-request +// SI-decimal unit base: 1 kb = 1 000 B, 1 mb = 1 000 000 B. +// (The python reference uses 1 024 / 1 048 576.) +func TestParseMaxBodySize_DecimalUnits(t *testing.T) { + cases := []struct { + in string + want int64 + }{ + {"1kb", 1000}, + {"5kb", 5000}, + {"1mb", 1_000_000}, + {"10mb", 10_000_000}, // exact cap + } + for _, tc := range cases { + got, err := parseMaxBodySize(map[string]any{"max_body_size": tc.in}) + if err != nil { + t.Errorf("parseMaxBodySize(%q): err = %v, want nil", tc.in, err) + continue + } + if got != tc.want { + t.Errorf("parseMaxBodySize(%q) = %d, want %d", tc.in, got, tc.want) + } + } +} + +// TestValidateTokenAuth_EmptyValueRejected covers CodeRabbit PR +// review #4: an empty configured token_value used to mean "accept +// any request without that header". Now it must be rejected. +func TestValidateTokenAuth_EmptyValueRejected(t *testing.T) { + cfg := map[string]any{ + "token": map[string]any{ + "token_header": "X-Token", + "token_value": "", + }, + } + c := securityCtx(t, "1.2.3.4:0", nil) // no header at all + if err := validateTokenAuth(c, cfg); err == nil { + t.Errorf("empty token_value: err = nil, want error") + } +} diff --git a/internal/handler/agent_webhook_test.go b/internal/handler/agent_webhook_test.go new file mode 100644 index 0000000000..7afa75503e --- /dev/null +++ b/internal/handler/agent_webhook_test.go @@ -0,0 +1,659 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + + "ragflow/internal/agent/canvas" + "ragflow/internal/common" + "ragflow/internal/dao" + "ragflow/internal/entity" +) + +// fakeCanvasLoader is a stand-in canvasLoader for webhook tests. It +// returns the canned canvas/run events configured at construction time +// without touching the database. Mirrors the pattern used by +// waitFakeAgentService at internal/handler/agent_wait_for_user_test.go. +type fakeCanvasLoader struct { + canvas *entity.UserCanvas + err error + events []canvas.RunEvent +} + +func (f *fakeCanvasLoader) LoadCanvasByID(_ context.Context, _, _ string) (*entity.UserCanvas, error) { + if f.err != nil { + return nil, f.err + } + return f.canvas, nil +} + +func (f *fakeCanvasLoader) RunAgentWithWebhook(_ context.Context, _, _ string, _ map[string]any) (<-chan canvas.RunEvent, error) { + out := make(chan canvas.RunEvent, len(f.events)) + for _, e := range f.events { + out <- e + } + close(out) + return out, nil +} + +// makeWebhookCanvas builds a minimal canvas with a Begin component +// whose params.mode == "Webhook" and the supplied params map. The +// `params` argument becomes webhook_cfg inside the handler. +func makeWebhookCanvas(id, userID, mode string, params map[string]any) *entity.UserCanvas { + dsl := map[string]any{ + "components": map[string]any{ + "begin": map[string]any{ + "obj": map[string]any{ + "component_name": "Begin", + "params": map[string]any{ + "mode": mode, + }, + }, + }, + }, + } + for k, v := range params { + dsl["components"].(map[string]any)["begin"].(map[string]any)["obj"].(map[string]any)["params"].(map[string]any)[k] = v + } + return &entity.UserCanvas{ + ID: id, + UserID: userID, + CanvasCategory: "agent_canvas", + DSL: entity.JSONMap(dsl), + } +} + +// webhookCtx builds a gin test context with the supplied method/path/body +// and a pre-set "user" so GetUser() returns success. +func webhookCtx(method, path, body, contentType string) (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + var reader *strings.Reader + if body == "" { + reader = strings.NewReader("") + } else { + reader = strings.NewReader(body) + } + c.Request = httptest.NewRequest(method, path, reader) + if contentType != "" { + c.Request.Header.Set("Content-Type", contentType) + } + c.Set("user", &entity.User{ID: "u-1"}) + return c, w +} + +// errBody extracts {code, message} from a 102 envelope response. +func errBody(t *testing.T, body []byte) (int, string) { + t.Helper() + var env struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(body, &env); err != nil { + t.Fatalf("decode response: %v (body=%s)", err, body) + } + return env.Code, env.Message +} + +// ---------- Phase 1: canvas loading / DataFlow rejection ---------- + +// TestWebhook_RejectsUnknownCanvas pins the 102 "Canvas not found." +// envelope when LoadCanvasByID returns ErrUserCanvasNotFound. This is +// the deliberate divergence from mapAgentError (which would surface 103). +func TestWebhook_RejectsUnknownCanvas(t *testing.T) { + h := &AgentHandler{loader: &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + + h.Webhook(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d (DataError)", code, common.CodeDataError) + } + if msg != "Canvas not found." { + t.Errorf("message = %q, want %q", msg, "Canvas not found.") + } +} + +// TestWebhook_RejectsDataFlowCanvas covers the second guard: the canvas +// exists but its category is DataFlow (which python +// agent_api.py:1575 explicitly rejects). +func TestWebhook_RejectsDataFlowCanvas(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", nil) + cv.CanvasCategory = "DataFlow" + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + + h.Webhook(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d", code, common.CodeDataError) + } + if msg != "Dataflow can not be triggered by webhook." { + t.Errorf("message = %q, want %q", msg, "Dataflow can not be triggered by webhook.") + } +} + +// TestWebhook_RejectsMissingWebhookConfig: DSL has no Begin with +// mode="Webhook". +func TestWebhook_RejectsMissingWebhookConfig(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Manual", nil) // wrong mode + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + + h.Webhook(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d", code, common.CodeDataError) + } + if msg != "Webhook not configured for this agent." { + t.Errorf("message = %q", msg) + } +} + +// ---------- Phase 5: method gate ---------- + +// TestWebhook_RejectsDisallowedMethod covers the methods list gate. +// methods:["POST"] + a GET request → 102. +func TestWebhook_RejectsDisallowedMethod(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "methods": []any{"POST"}, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("GET", "/api/v1/agents/c1/webhook", ``, "") + + h.Webhook(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d", code, common.CodeDataError) + } + want := "HTTP method 'GET' not allowed for this webhook." + if msg != want { + t.Errorf("message = %q, want %q", msg, want) + } +} + +// ---------- Phase 6: security ---------- + +// TestWebhook_TokenAuthPasses pins the happy path for token auth. +func TestWebhook_TokenAuthPasses(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "security": map[string]any{ + "auth_type": "token", + "token": map[string]any{ + "token_header": "X-Webhook-Token", + "token_value": "s3cret", + }, + }, + "execution_mode": "Immediately", + "response": map[string]any{ + "status": 200, + }, + }) + loader := &fakeCanvasLoader{canvas: cv} + h := &AgentHandler{loader: loader} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + c.Request.Header.Set("X-Webhook-Token", "s3cret") + + h.Webhook(c) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String()) + } +} + +// TestWebhook_TokenAuthFails: header value mismatch → 102. +func TestWebhook_TokenAuthFails(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "security": map[string]any{ + "auth_type": "token", + "token": map[string]any{ + "token_header": "X-Webhook-Token", + "token_value": "s3cret", + }, + }, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + c.Request.Header.Set("X-Webhook-Token", "wrong") + + h.Webhook(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d", code, common.CodeDataError) + } + if msg != "Invalid token authentication" { + t.Errorf("message = %q, want %q", msg, "Invalid token authentication") + } +} + +// TestWebhook_BodySizeLimit: 1kb ceiling + ~2kb Content-Length → 102. +func TestWebhook_BodySizeLimit(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "security": map[string]any{ + "max_body_size": "1kb", + }, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", + strings.Repeat("a", 2048), "text/plain") + + h.Webhook(c) + + code, _ := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d", code, common.CodeDataError) + } +} + +// TestWebhook_BodySizeLimitTooLarge rejects config bugs >10MB. +func TestWebhook_BodySizeLimitTooLarge(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "security": map[string]any{ + "max_body_size": "11mb", + }, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + + h.Webhook(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d", code, common.CodeDataError) + } + if !strings.Contains(msg, "exceeds maximum") { + t.Errorf("message = %q, want contains 'exceeds maximum'", msg) + } +} + +// TestWebhook_BasicAuthPasses / TestWebhook_BasicAuthFails exercise the +// HTTP Basic branch. +func TestWebhook_BasicAuthPasses(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "security": map[string]any{ + "auth_type": "basic", + "basic_auth": map[string]any{ + "username": "alice", + "password": "wonderland", + }, + }, + "execution_mode": "Immediately", + "response": map[string]any{"status": 200}, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + c.Request.SetBasicAuth("alice", "wonderland") + + h.Webhook(c) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String()) + } +} + +func TestWebhook_BasicAuthFails(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "security": map[string]any{ + "auth_type": "basic", + "basic_auth": map[string]any{ + "username": "alice", + "password": "wonderland", + }, + }, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + c.Request.SetBasicAuth("alice", "wrong") + + h.Webhook(c) + + code, _ := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d", code, common.CodeDataError) + } +} + +// ---------- Phase 8: schema extraction ---------- + +// TestWebhook_SchemaExtractionRequired: missing required field → 102. +func TestWebhook_SchemaExtractionRequired(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "schema": map[string]any{ + "query": map[string]any{ + "properties": map[string]any{ + "q": map[string]any{"type": "string"}, + }, + "required": []any{"q"}, + }, + }, + "execution_mode": "Immediately", + "response": map[string]any{"status": 200}, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook?other=v", ``, "") + + h.Webhook(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d", code, common.CodeDataError) + } + if !strings.Contains(msg, "missing required field") { + t.Errorf("message = %q, want contains 'missing required field'", msg) + } +} + +// TestWebhook_SchemaExtractionTypeMismatch: non-numeric in number field. +func TestWebhook_SchemaExtractionTypeMismatch(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "schema": map[string]any{ + "body": map[string]any{ + "properties": map[string]any{ + "n": map[string]any{"type": "number"}, + }, + "required": []any{"n"}, + }, + }, + "execution_mode": "Immediately", + "response": map[string]any{"status": 200}, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{"n":"abc"}`, "application/json") + + h.Webhook(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d", code, common.CodeDataError) + } + if !strings.Contains(msg, "type mismatch") && !strings.Contains(msg, "auto-cast") { + t.Errorf("message = %q, want contains 'type mismatch' or 'auto-cast'", msg) + } +} + +// ---------- Phase 9: dispatch ---------- + +// TestWebhook_ImmediatelyReturnsConfiguredStatus confirms the +// Immediately mode returns the configured status code and runs the +// canvas detached. +func TestWebhook_ImmediatelyReturnsConfiguredStatus(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "execution_mode": "Immediately", + "response": map[string]any{ + "status": 201, + "body_template": `{"ok":true}`, + }, + }) + loader := &fakeCanvasLoader{ + canvas: cv, + events: []canvas.RunEvent{ + {Type: "message", Data: `{"content":"hi"}`}, + {Type: "done", Data: ""}, + }, + } + h := &AgentHandler{loader: loader} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + + h.Webhook(c) + + if w.Code != http.StatusCreated { + t.Errorf("status = %d, want 201; body=%s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), `"ok":true`) { + t.Errorf("body = %q, want contains '\"ok\":true'", w.Body.String()) + } +} + +// TestWebhook_DefaultModeAggregatesContent covers the non-Immediately +// path: streaming message + message_end events get concatenated. +func TestWebhook_DefaultModeAggregatesContent(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "execution_mode": "Streaming", + }) + loader := &fakeCanvasLoader{ + canvas: cv, + events: []canvas.RunEvent{ + {Type: "message", Data: `{"content":"hello "}`}, + {Type: "message", Data: `{"content":"world"}`}, + {Type: "message_end", Data: `{"status":200}`}, + {Type: "done", Data: ""}, + }, + } + h := &AgentHandler{loader: loader} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + + h.Webhook(c) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String()) + } + var body struct { + Message string `json:"message"` + Code int `json:"code"` + } + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("decode: %v", err) + } + if body.Message != "hello world" { + t.Errorf("aggregated message = %q, want %q", body.Message, "hello world") + } + if body.Code != 200 { + t.Errorf("code = %d, want 200", body.Code) + } +} + +// TestWebhook_InvalidResponseStatusRange rejects bad config (status 500). +func TestWebhook_InvalidResponseStatusRange(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "execution_mode": "Immediately", + "response": map[string]any{"status": 500}, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json") + + h.Webhook(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("code = %d, want %d", code, common.CodeDataError) + } + if !strings.Contains(msg, "must be between 200 and 399") { + t.Errorf("message = %q, want contains 'must be between 200 and 399'", msg) + } +} + +// TestWebhook_FindWebhookBegin_NilDSL ensures the helper is defensive. +func TestWebhook_FindWebhookBegin_NilDSL(t *testing.T) { + if got := findWebhookBegin(nil); got != nil { + t.Errorf("findWebhookBegin(nil) = %v, want nil", got) + } +} + +// TestWebhook_FindWebhookBegin_NoComponents covers DSL with no components. +func TestWebhook_FindWebhookBegin_NoComponents(t *testing.T) { + if got := findWebhookBegin(map[string]any{}); got != nil { + t.Errorf("findWebhookBegin(empty) = %v, want nil", got) + } +} + +// ---------- Regression tests for issues raised in code review ---------- + +// TestWebhook_BodySizeStreamBounded pins the security MEDIUM-1 +// hardening: when max_body_size is configured, an oversized body is +// rejected. The 102 envelope carries code=CodeDataError; HTTP itself +// stays at 200 (matches the existing /api/v1 error envelope shape). +func TestWebhook_BodySizeStreamBounded(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "security": map[string]any{ + "max_body_size": "1kb", + }, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", + strings.Repeat("a", 2048), "text/plain") + + h.Webhook(c) + + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("envelope code = %d, want %d (CodeDataError)", code, common.CodeDataError) + } + // Either message is acceptable — the Content-Length pre-check + // ("request body too large") OR the MaxBytesReader runtime check + // ("http: request body too large"). + if !strings.Contains(msg, "body too large") { + t.Errorf("message = %q, want contains 'body too large'", msg) + } +} + +// +// These three tests cover the behaviour the implementation summary +// claimed but did NOT actually enforce. They ensure the dispatch path +// returns the documented envelope for multipart and content-type +// mismatch cases. + +// TestWebhook_MultipartReturns501 pins the contract that +// multipart/form-data uploads are refused with HTTP 501 BEFORE the +// schema-extraction phase runs. Before the fix, the handler silently +// stuffed __multipart_unsupported__ into body and continued; operators +// saw a confusing schema-validation error instead of the documented +// "not implemented" envelope. +func TestWebhook_MultipartReturns501(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "execution_mode": "Immediately", + "response": map[string]any{"status": 200}, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", + "--boundary\r\nContent-Disposition: form-data; name=\"k\"\r\n\r\nv\r\n--boundary--\r\n", + "multipart/form-data; boundary=boundary") + + h.Webhook(c) + + if w.Code != http.StatusNotImplemented { + t.Errorf("status = %d, want 501; body=%s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "not supported") { + t.Errorf("body should mention 'not supported', got %q", w.Body.String()) + } +} + +// TestWebhook_ContentTypeMismatchReturns102 pins the contract that +// content_types whitelist enforces a HARD REJECT (HTTP 102 envelope), +// matching python agent_api.py:1839-1842. Before the fix, the +// mismatch was silently recorded as __content_type_mismatch__ in the +// body and the request flowed into schema validation — a real +// whitelist bypass. +func TestWebhook_ContentTypeMismatchReturns102(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "content_types": "application/json", + "execution_mode": "Immediately", + "response": map[string]any{"status": 200}, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + // text/plain where the config demands application/json. + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", "raw body", "text/plain") + + h.Webhook(c) + + if w.Code != http.StatusOK { + t.Fatalf("envelope status = %d, want 200 (the envelope itself succeeds; the code field carries 102)", w.Code) + } + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("envelope code = %d, want %d (CodeDataError)", code, common.CodeDataError) + } + if !strings.Contains(msg, "invalid content-type") || !strings.Contains(msg, "application/json") { + t.Errorf("message = %q, want contains 'invalid content-type' and 'application/json'", msg) + } +} + +// TestWebhook_ContentTypesEmptyAllowsAnything confirms the inverse: +// when webhook_cfg does NOT set content_types, any Content-Type is +// accepted (matches python: agent_api.py:1839 only raises when the +// config actually sets content_types). +func TestWebhook_ContentTypesEmptyAllowsAnything(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + // content_types deliberately omitted + "execution_mode": "Immediately", + "response": map[string]any{"status": 200}, + }) + loader := &fakeCanvasLoader{canvas: cv} + h := &AgentHandler{loader: loader} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{"k":"v"}`, "text/plain") + + h.Webhook(c) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String()) + } +} + +// TestWebhook_ContentTypeMissingHeaderRejected pins the gap from the +// review: when content_types is configured, a request that OMITS the +// Content-Type header entirely must be rejected with 102. The python +// reference has a `if ctype and ...` short-circuit that lets such +// requests through — a real whitelist bypass. The Go port tightens +// this: an operator who configured content_types gets an enforceable +// contract; the caller MUST send the header. +func TestWebhook_ContentTypeMissingHeaderRejected(t *testing.T) { + cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{ + "content_types": "application/json", + "execution_mode": "Immediately", + "response": map[string]any{"status": 200}, + }) + h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}} + c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{"k":"v"}`, "") + // ↑ "" content-type means the header is absent. gin will not + // produce a default Content-Type on a POST so this is a clean + // "no header" test case. + + h.Webhook(c) + + if w.Code != http.StatusOK { + t.Fatalf("envelope status = %d, want 200 (the envelope itself succeeds; the code field carries 102)", w.Code) + } + code, msg := errBody(t, w.Body.Bytes()) + if code != int(common.CodeDataError) { + t.Errorf("envelope code = %d, want %d (CodeDataError)", code, common.CodeDataError) + } + if !strings.Contains(msg, "invalid content-type") { + t.Errorf("message = %q, want contains 'invalid content-type'", msg) + } + if !strings.Contains(msg, "application/json") { + t.Errorf("message = %q, want contains 'application/json'", msg) + } +} + +// (No helper needed at the bottom of this file; helper functions +// inline above.) +// _unused previously lived here as a placeholder; deleted during +// cleanup (code-review MEDIUM-2). diff --git a/internal/router/agent_routes.go b/internal/router/agent_routes.go index 9091605c5c..df3085358d 100644 --- a/internal/router/agent_routes.go +++ b/internal/router/agent_routes.go @@ -56,6 +56,14 @@ func RegisterAgentRoutes(g *gin.RouterGroup, h *handler.AgentHandler) { g.PUT("/:canvas_id/tags", h.UpdateAgentTags) g.POST("/:canvas_id/reset", h.ResetAgent) + // File operations. + g.GET("/download", h.DownloadAgentFile) + g.POST("/:canvas_id/upload", h.UploadAgentFile) + + // Component introspection + debug. + g.GET("/:canvas_id/components/:component_id/input-form", h.GetComponentInputForm) + g.POST("/:canvas_id/components/:component_id/debug", h.DebugComponent) + // Versions. g.GET("/:canvas_id/versions", h.ListVersions) g.GET("/:canvas_id/versions/:version_id", h.GetVersion) @@ -71,9 +79,43 @@ func RegisterAgentRoutes(g *gin.RouterGroup, h *handler.AgentHandler) { // Logs and webhook. g.GET("/:canvas_id/logs/:message_id", h.GetAgentLogs) g.GET("/:canvas_id/webhook/logs", h.GetAgentWebhookLogs) + // Webhook trigger endpoints. The Python agent API + // (api/apps/restful_apis/agent_api.py:1563-1564) registers six + // HTTP methods on a single path. Gin has no Match() helper, so we + // register each verb explicitly via registerAnyMethod. The handler + // is identical for all six; semantics differ only by + // c.Request.Method. + registerAnyMethod(g, "/:canvas_id/webhook", h.Webhook) + registerAnyMethod(g, "/:canvas_id/webhook/test", h.Webhook) // Top-level actions (no canvas id in path). + // NOTE: `/chat/completion` (singular) is intentionally NOT registered. + // The singular form was a historical typo in earlier Python releases — + // no client, SDK, or doc ever called it, and the Python side + // (api/apps/restful_apis/agent_api.py) has since removed the route. + // See plan: .claude/plans/agent-api-gaps-go-port.md §Gap E. g.POST("/chat/completions", h.AgentChatCompletions) g.POST("/rerun", h.RerunAgent) g.POST("/test_db_connection", h.TestDBConnection) } + +// registerAnyMethod mirrors the Python +// `@manager.route(path, methods=["POST","GET","PUT","PATCH","DELETE","HEAD"])` +// pattern. Gin has no Match() helper, so we register each verb +// explicitly. The handler is identical for all six — semantics differ +// only by c.Request.Method. +// +// Centralising the registration here keeps RegisterAgentRoutes readable +// when both the production trigger and the test trigger share the same +// six-method shape. +func registerAnyMethod(g *gin.RouterGroup, path string, h gin.HandlerFunc) { + if g == nil || h == nil { + return + } + g.POST(path, h) + g.GET(path, h) + g.PUT(path, h) + g.PATCH(path, h) + g.DELETE(path, h) + g.HEAD(path, h) +} diff --git a/internal/service/agent.go b/internal/service/agent.go index 11dd3731fb..04f9c105a7 100644 --- a/internal/service/agent.go +++ b/internal/service/agent.go @@ -49,6 +49,51 @@ func genID32() string { return strings.ReplaceAll(uuid.New().String(), "-", "")[:32] } +// webhookPayloadKey is the unexported context key RunAgent reads to +// inject root["webhook_payload"]. Only the AgentService.RunAgentWithWebhook +// public wrapper sets it; the chat / agent-run paths leave it absent so +// existing callers see no behaviour change. +// +// We deliberately do NOT surface the payload as a new RunAgent parameter +// — keeping the public signature stable means existing tests +// (agent_run_e2e_test.go, agent_wait_for_user_test.go) keep compiling. +type webhookPayloadKey struct{} + +// LoadCanvasByID is the read-side counterpart of loadCanvasForUser that +// the webhook handler uses. It deliberately returns the raw DAO/service +// error (no error-code mapping) because the webhook envelope is 102 +// "Canvas not found." while the chat/run envelope is 103 "Make sure you +// have permission..." — the choice must stay at the HTTP layer where +// each handler knows its own spec. +// +// Mirrors python: api/apps/restful_apis/agent_api.py:1570 +// (`UserCanvasService.get_by_id(agent_id)`), with the same IDOR guard +// the chat handler uses. +func (s *AgentService) LoadCanvasByID( + ctx context.Context, userID, canvasID string, +) (*entity.UserCanvas, error) { + return s.loadCanvasForUser(ctx, userID, canvasID) +} + +// RunAgentWithWebhook is a thin wrapper over RunAgent that attaches the +// webhook payload to the runner root so the Begin component can surface +// it as state.Sys["webhook_payload"] for downstream components. +// +// The payload is intentionally passed via context value (rather than a new +// RunAgent parameter) to keep the public RunAgent signature stable for +// the existing chat tests. +// +// Mirrors python: api/apps/restful_apis/agent_api.py:2125 +// (`canvas.run(..., webhook_payload=clean_request)`). +func (s *AgentService) RunAgentWithWebhook( + ctx context.Context, userID, canvasID string, payload map[string]any, +) (<-chan canvas.RunEvent, error) { + if payload != nil { + ctx = context.WithValue(ctx, webhookPayloadKey{}, payload) + } + return s.RunAgent(ctx, userID, canvasID, "", "", "") +} + // ErrAgentNotOwner is returned by DeleteAgent when the canvas exists and // is accessible to the caller but is owned by a different user. It maps // to the Python "Only the owner of the agent is authorized for this @@ -670,6 +715,14 @@ func (s *AgentService) RunAgent(ctx context.Context, userID, canvasID, sessionID if dsl != nil { root["__dsl_present__"] = true } + // Webhook payload injection. Only RunAgentWithWebhook sets this + // context value; the chat / agent-run paths leave it nil so the + // existing surface is unchanged. The Begin component reads + // inputs["webhook_payload"] and writes it to state.Sys so + // downstream components can read sys.webhook_payload. + if payload, ok := ctx.Value(webhookPayloadKey{}).(map[string]any); ok && payload != nil { + root["webhook_payload"] = payload + } // Phase 4.4 V2.1 (v3.6.1): populate root["tenant_id"] so the // RunTracker.Start call (in buildRunFunc) records the run // under the right tenant. The lookup is best-effort — a