mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat[Go]: port agent webhook trigger, agent file upload/download, component input-form + debug endpoints from Python (#16403)
port agent webhook trigger, agent file upload/download, component input-form + debug endpoints from Python - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -21,7 +21,7 @@ RESTful API migration. Each deprecated route forwards to the corresponding
|
||||
new API implementation.
|
||||
|
||||
Deprecated APIs and their replacements:
|
||||
- POST /api/v1/agents/{agent_id}/completions -> POST /api/v1/agents/chat/completion
|
||||
- POST /api/v1/agents/{agent_id}/completions -> POST /api/v1/agents/chat/completions
|
||||
- POST /api/v1/agents_openai/{agent_id}/chat/completions -> POST /api/v1/agents/chat/completions
|
||||
- POST /api/v1/chats/{chat_id}/completions -> POST /api/v1/chat/completions
|
||||
- POST /api/v1/chats_openai/{chat_id}/chat/completions -> POST /api/v1/openai/{chat_id}/chat/completions
|
||||
|
||||
@@ -1207,7 +1207,10 @@ async def test_db_connection():
|
||||
return server_error_response(exc)
|
||||
|
||||
|
||||
@manager.route("/agents/chat/completion", methods=["POST"]) # noqa: F821
|
||||
# NOTE: The singular form `/agents/chat/completion` was a historical typo
|
||||
# in earlier releases — no client, SDK, or doc ever used it, and the
|
||||
# plural form below is the canonical route. The singular is intentionally
|
||||
# NOT registered; clients sending it receive 404.
|
||||
@manager.route("/agents/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
|
||||
1
go.mod
1
go.mod
@@ -117,6 +117,7 @@ require (
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -211,6 +211,8 @@ github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
|
||||
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
|
||||
@@ -82,6 +82,16 @@ func (b *BeginComponent) Invoke(ctx context.Context, inputs map[string]any) (map
|
||||
state.Sys["user_id"] = uid
|
||||
}
|
||||
|
||||
// Webhook payload injection. The webhook HTTP handler sets
|
||||
// root["webhook_payload"] (see service/agent.go RunAgentWithWebhook)
|
||||
// which BuildWorkflow forwards into inputs. Surfacing it on
|
||||
// state.Sys lets downstream components read sys.webhook_payload the
|
||||
// same way they read sys.query / sys.user_id. The chat path never
|
||||
// sets this key, so existing tests stay green.
|
||||
if payload, ok := inputs["webhook_payload"].(map[string]any); ok && len(payload) > 0 {
|
||||
state.Sys["webhook_payload"] = payload
|
||||
}
|
||||
|
||||
// Passthrough: a shallow copy keeps the caller's map un-aliased.
|
||||
out := make(map[string]any, len(inputs))
|
||||
mapsCopy(out, inputs)
|
||||
@@ -106,18 +116,20 @@ func (b *BeginComponent) Stream(ctx context.Context, inputs map[string]any) (<-c
|
||||
// strings live on the struct / method above.
|
||||
func (b *BeginComponent) Inputs() map[string]string {
|
||||
return map[string]string{
|
||||
"query": "User query string (the chat input).",
|
||||
"user_id": "Optional user/tenant identifier.",
|
||||
"inputs": "Optional free-form inputs map; passthrough only.",
|
||||
"query": "User query string (the chat input).",
|
||||
"user_id": "Optional user/tenant identifier.",
|
||||
"webhook_payload": "Optional structured webhook request (set by the webhook HTTP handler; absent on chat flows).",
|
||||
"inputs": "Optional free-form inputs map; passthrough only.",
|
||||
}
|
||||
}
|
||||
|
||||
// Outputs returns the same keys as Inputs (Begin is a passthrough).
|
||||
func (b *BeginComponent) Outputs() map[string]string {
|
||||
return map[string]string{
|
||||
"query": "Query string (passthrough).",
|
||||
"user_id": "User id, if provided (passthrough).",
|
||||
"inputs": "Raw inputs map (passthrough).",
|
||||
"query": "Query string (passthrough).",
|
||||
"user_id": "User id, if provided (passthrough).",
|
||||
"webhook_payload": "Webhook request payload, if provided (passthrough; also written to state.Sys[webhook_payload]).",
|
||||
"inputs": "Raw inputs map (passthrough).",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -86,3 +86,77 @@ func TestBegin_PassesThroughInputs(t *testing.T) {
|
||||
func withStateForTest(ctx context.Context, s *canvas.CanvasState) context.Context {
|
||||
return canvas.WithState(ctx, s)
|
||||
}
|
||||
|
||||
// TestBegin_InjectsWebhookPayload pins the contract added for the
|
||||
// webhook HTTP handler: when inputs["webhook_payload"] is present, Begin
|
||||
// must surface it on state.Sys["webhook_payload"] so downstream
|
||||
// components (Retrieval, Agent, etc.) can read sys.webhook_payload the
|
||||
// same way they read sys.query / sys.user_id.
|
||||
//
|
||||
// Mirrors python: agent/canvas.py (Begin component) reading
|
||||
// `webhook_payload` from inputs and writing to state.Sys in the webhook
|
||||
// branch.
|
||||
func TestBegin_InjectsWebhookPayload(t *testing.T) {
|
||||
c, _ := NewBeginComponent(nil)
|
||||
state := canvas.NewCanvasState("run-3", "task-3")
|
||||
ctx := canvas.WithState(context.Background(), state)
|
||||
|
||||
payload := map[string]any{
|
||||
"query": map[string]any{"q": "hello"},
|
||||
"headers": map[string]any{"x-token": "abc"},
|
||||
"body": map[string]any{"k": "v"},
|
||||
}
|
||||
inputs := map[string]any{
|
||||
"query": "",
|
||||
"webhook_payload": payload,
|
||||
}
|
||||
out, err := c.Invoke(ctx, inputs)
|
||||
if err != nil {
|
||||
t.Fatalf("Invoke: %v", err)
|
||||
}
|
||||
got, ok := state.Sys["webhook_payload"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("state.Sys[webhook_payload] missing or wrong type: %T", state.Sys["webhook_payload"])
|
||||
}
|
||||
if !reflect.DeepEqual(got, payload) {
|
||||
t.Errorf("state.Sys[webhook_payload] mismatch:\n got %v\n want %v", got, payload)
|
||||
}
|
||||
// Passthrough preserved.
|
||||
if outPayload, _ := out["webhook_payload"].(map[string]any); !reflect.DeepEqual(outPayload, payload) {
|
||||
t.Errorf("outputs[webhook_payload] mismatch:\n got %v\n want %v", outPayload, payload)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBegin_AbsentWebhookPayload confirms that the chat path (no
|
||||
// webhook_payload key in inputs) leaves state.Sys["webhook_payload"]
|
||||
// unset — adding the new branch must NOT pollute existing callers.
|
||||
func TestBegin_AbsentWebhookPayload(t *testing.T) {
|
||||
c, _ := NewBeginComponent(nil)
|
||||
state := canvas.NewCanvasState("run-4", "task-4")
|
||||
ctx := canvas.WithState(context.Background(), state)
|
||||
|
||||
if _, err := c.Invoke(ctx, map[string]any{"query": "plain chat"}); err != nil {
|
||||
t.Fatalf("Invoke: %v", err)
|
||||
}
|
||||
if _, ok := state.Sys["webhook_payload"]; ok {
|
||||
t.Errorf("state.Sys[webhook_payload] should not be set when inputs lack it; got %v", state.Sys["webhook_payload"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBegin_EmptyWebhookPayload confirms that an explicitly empty map
|
||||
// is treated as "not present" — matching the python `if payload:` guard.
|
||||
func TestBegin_EmptyWebhookPayload(t *testing.T) {
|
||||
c, _ := NewBeginComponent(nil)
|
||||
state := canvas.NewCanvasState("run-5", "task-5")
|
||||
ctx := canvas.WithState(context.Background(), state)
|
||||
|
||||
if _, err := c.Invoke(ctx, map[string]any{
|
||||
"query": "",
|
||||
"webhook_payload": map[string]any{},
|
||||
}); err != nil {
|
||||
t.Fatalf("Invoke: %v", err)
|
||||
}
|
||||
if _, ok := state.Sys["webhook_payload"]; ok {
|
||||
t.Errorf("state.Sys[webhook_payload] should not be set for empty payload; got %v", state.Sys["webhook_payload"])
|
||||
}
|
||||
}
|
||||
|
||||
40
internal/agent/dsl/errors.go
Normal file
40
internal/agent/dsl/errors.go
Normal file
@@ -0,0 +1,40 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package dsl
|
||||
|
||||
import "errors"
|
||||
|
||||
// Sentinel errors returned by the DSL extractors. The extractors wrap them
// with per-component context, so callers should match with errors.Is
// rather than comparing values or echoing the error text. Handlers map
// them to 102 (DataError) envelopes without embedding the raw message in
// the response.
var (
	// ErrComponentNotFound reports that the requested componentID is not
	// a key of dsl["components"].
	ErrComponentNotFound = errors.New("dsl: component not found")

	// ErrMissingInputForm reports that the component exists but its
	// `obj.input_form` entry is absent or null. The python Canvas returns
	// None in this situation; the Go side surfaces 102 "component has no
	// input_form" instead.
	ErrMissingInputForm = errors.New("dsl: component has no input_form")

	// ErrMalformedDSL reports a structural problem: a nil dsl, a missing
	// components map, or a field of the wrong type. It is kept separate
	// from ErrComponentNotFound so a handler can word the response
	// differently when the DSL itself is broken.
	ErrMalformedDSL = errors.New("dsl: malformed")
)
|
||||
133
internal/agent/dsl/extract.go
Normal file
133
internal/agent/dsl/extract.go
Normal file
@@ -0,0 +1,133 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Package dsl contains pure-function helpers for working with the agent
|
||||
// canvas DSL map structure (`map[string]any`). It is intentionally
|
||||
// runtime-free: no Canvas instantiation, no component factories, no
|
||||
// database access. The agent_component handlers use it to introspect
|
||||
// the DSL before deciding whether to wire up a runtime component.
|
||||
package dsl
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ExtractComponentInputForm returns the input-form schema dict stored at
|
||||
// `dsl["components"][componentID]["obj"]["input_form"]`.
|
||||
//
|
||||
// This is the Go equivalent of the python
|
||||
// `Canvas.get_component_input_form(component_id)` method
|
||||
// (api/agent/canvas.py:163) which reads the same path. The python
|
||||
// version walks the live Canvas object; we walk the raw DSL map
|
||||
// directly because the Go Canvas type does not expose an
|
||||
// introspection API (see plan §Gap C — there is no `GetComponent` on
|
||||
// the runtime Canvas type).
|
||||
//
|
||||
// Returns:
|
||||
// - the form-schema dict if present and well-typed
|
||||
// - ErrComponentNotFound if the componentID is missing from dsl
|
||||
// - ErrMissingInputForm if the component exists but has no input_form
|
||||
// - ErrMalformedDSL if the field is present but the wrong type
|
||||
//
|
||||
// Type errors (input_form is e.g. a list or a string) are NOT
|
||||
// collapsed into ErrMissingInputForm — they would mask a contract
|
||||
// violation in the DSL and let DebugComponent run against corrupt
|
||||
// data. CodeRabbit PR review #1 on PR #16403.
|
||||
func ExtractComponentInputForm(dsl map[string]any, componentID string) (map[string]any, error) {
|
||||
comp, err := navigateToComponent(dsl, componentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
obj, ok := comp["obj"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: component %q has no obj", ErrMalformedDSL, componentID)
|
||||
}
|
||||
rawForm, exists := obj["input_form"]
|
||||
if !exists || rawForm == nil {
|
||||
return nil, fmt.Errorf("%w: component %q has no input_form", ErrMissingInputForm, componentID)
|
||||
}
|
||||
form, ok := rawForm.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: component %q input_form is not a dict", ErrMalformedDSL, componentID)
|
||||
}
|
||||
return form, nil
|
||||
}
|
||||
|
||||
// ExtractComponentParams returns the params map stored at
|
||||
// `dsl["components"][componentID]["obj"]["params"]`. The debug handler
|
||||
// uses this to build the inputs map for the runtime Component.Invoke
|
||||
// call. Type errors collapse to ErrMalformedDSL (CodeRabbit PR
|
||||
// review #1).
|
||||
func ExtractComponentParams(dsl map[string]any, componentID string) (map[string]any, error) {
|
||||
comp, err := navigateToComponent(dsl, componentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
obj, ok := comp["obj"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: component %q has no obj", ErrMalformedDSL, componentID)
|
||||
}
|
||||
rawParams, exists := obj["params"]
|
||||
if !exists || rawParams == nil {
|
||||
return nil, nil
|
||||
}
|
||||
params, ok := rawParams.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: component %q params is not a dict", ErrMalformedDSL, componentID)
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// ExtractComponentName returns the component's class name (e.g.
|
||||
// "Begin", "LLM", "Retrieval") from `dsl["components"][componentID].
|
||||
// ["obj"]["component_name"]`. The runtime factory is keyed on this
|
||||
// name.
|
||||
func ExtractComponentName(dsl map[string]any, componentID string) (string, error) {
|
||||
comp, err := navigateToComponent(dsl, componentID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
obj, ok := comp["obj"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("%w: component %q has no obj", ErrMalformedDSL, componentID)
|
||||
}
|
||||
name, _ := obj["component_name"].(string)
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("%w: component %q has no component_name", ErrMalformedDSL, componentID)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// navigateToComponent walks dsl["components"][componentID] and
|
||||
// returns the inner dict. Centralised so the three extractors above
|
||||
// share a single traversal path. (Renamed from extractComponent to
|
||||
// avoid colliding with the same-named helper in normalize.go.)
|
||||
func navigateToComponent(dsl map[string]any, componentID string) (map[string]any, error) {
|
||||
if dsl == nil {
|
||||
return nil, fmt.Errorf("%w: nil dsl", ErrMalformedDSL)
|
||||
}
|
||||
comps, ok := dsl["components"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: missing components map", ErrMalformedDSL)
|
||||
}
|
||||
comp, ok := comps[componentID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: %q", ErrComponentNotFound, componentID)
|
||||
}
|
||||
cm, ok := comp.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: %q is not a dict", ErrMalformedDSL, componentID)
|
||||
}
|
||||
return cm, nil
|
||||
}
|
||||
149
internal/agent/dsl/extract_test.go
Normal file
149
internal/agent/dsl/extract_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package dsl
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// happyDSL builds a fresh two-component DSL for the happy-path extractor
// tests: "begin" carries a component_name, params and an input_form,
// while "answer" carries only a component_name.
func happyDSL() map[string]any {
	begin := map[string]any{
		"obj": map[string]any{
			"component_name": "Begin",
			"params":         map[string]any{"mode": "Manual"},
			"input_form": map[string]any{
				"query": map[string]any{"type": "string"},
			},
		},
	}
	answer := map[string]any{
		"obj": map[string]any{"component_name": "Answer"},
	}

	return map[string]any{
		"components": map[string]any{
			"begin":  begin,
			"answer": answer,
		},
	}
}
|
||||
|
||||
func TestExtractComponentInputForm_HappyPath(t *testing.T) {
|
||||
got, err := ExtractComponentInputForm(happyDSL(), "begin")
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
if q, ok := got["query"].(map[string]any); !ok || q["type"] != "string" {
|
||||
t.Errorf("query type = %v, want string", got["query"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractComponentInputForm_NotFound(t *testing.T) {
|
||||
_, err := ExtractComponentInputForm(happyDSL(), "missing")
|
||||
if !errors.Is(err, ErrComponentNotFound) {
|
||||
t.Errorf("err = %v, want ErrComponentNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractComponentInputForm_MissingObj(t *testing.T) {
|
||||
dsl := map[string]any{
|
||||
"components": map[string]any{
|
||||
"bare": map[string]any{}, // no obj
|
||||
},
|
||||
}
|
||||
_, err := ExtractComponentInputForm(dsl, "bare")
|
||||
if !errors.Is(err, ErrMalformedDSL) {
|
||||
t.Errorf("err = %v, want ErrMalformedDSL", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractComponentInputForm_MissingInputForm(t *testing.T) {
|
||||
// "answer" has obj but no input_form.
|
||||
_, err := ExtractComponentInputForm(happyDSL(), "answer")
|
||||
if !errors.Is(err, ErrMissingInputForm) {
|
||||
t.Errorf("err = %v, want ErrMissingInputForm", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractComponentInputForm_NilDSL(t *testing.T) {
|
||||
_, err := ExtractComponentInputForm(nil, "anything")
|
||||
if !errors.Is(err, ErrMalformedDSL) {
|
||||
t.Errorf("err = %v, want ErrMalformedDSL", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractComponentParams_HappyPath(t *testing.T) {
|
||||
got, err := ExtractComponentParams(happyDSL(), "begin")
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
if got["mode"] != "Manual" {
|
||||
t.Errorf("mode = %v, want Manual", got["mode"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractComponentParams_NoParams(t *testing.T) {
|
||||
got, err := ExtractComponentParams(happyDSL(), "answer")
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v, want nil (params is optional)", err)
|
||||
}
|
||||
if got != nil && len(got) != 0 {
|
||||
t.Errorf("params = %v, want empty/nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractComponentParams_WrongType pins that a present-but-
|
||||
// wrongly-typed params field is ErrMalformedDSL. CodeRabbit PR #1.
|
||||
func TestExtractComponentParams_WrongType(t *testing.T) {
|
||||
dsl := map[string]any{
|
||||
"components": map[string]any{
|
||||
"bad": map[string]any{
|
||||
"obj": map[string]any{
|
||||
"component_name": "Begin",
|
||||
"params": "this is a string, not a dict",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err := ExtractComponentParams(dsl, "bad")
|
||||
if !errors.Is(err, ErrMalformedDSL) {
|
||||
t.Errorf("err = %v, want ErrMalformedDSL", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractComponentName_HappyPath(t *testing.T) {
|
||||
got, err := ExtractComponentName(happyDSL(), "begin")
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
if got != "Begin" {
|
||||
t.Errorf("name = %q, want Begin", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractComponentName_NotFound(t *testing.T) {
|
||||
_, err := ExtractComponentName(happyDSL(), "missing")
|
||||
if !errors.Is(err, ErrComponentNotFound) {
|
||||
t.Errorf("err = %v, want ErrComponentNotFound", err)
|
||||
}
|
||||
}
|
||||
@@ -989,6 +989,38 @@ func (r *RedisClient) GetClient() *redis.Client {
|
||||
return r.client
|
||||
}
|
||||
|
||||
// EvalTokenBucketStrict is the fail-closed counterpart to TokenBucket.Allow.
|
||||
// It surfaces Lua errors and the uninitialised-Redis case to the caller so
|
||||
// security gates (e.g. webhook rate limiter) can deny on transport failure
|
||||
// rather than silently passing traffic. The existing TokenBucket.Allow
|
||||
// silently fails-open and is reserved for the chat driver path where
|
||||
// transient Redis outages should not block traffic.
|
||||
//
|
||||
// Cost is fixed at 1.0; callers wanting variable cost should compose their
|
||||
// own Lua. ctx is used for both the EVALSHA round-trip and the deadline.
|
||||
func (r *RedisClient) EvalTokenBucketStrict(
|
||||
ctx context.Context, key string, capacity, rate float64,
|
||||
) (allowed bool, err error) {
|
||||
if r == nil || r.client == nil {
|
||||
return false, fmt.Errorf("redis: not initialised")
|
||||
}
|
||||
now := float64(time.Now().Unix())
|
||||
res, err := r.luaTokenBucket.Run(ctx, r.client, []string{key},
|
||||
capacity, rate, now, 1.0).Result()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("token bucket: %w", err)
|
||||
}
|
||||
values, ok := res.([]interface{})
|
||||
if !ok || len(values) < 1 {
|
||||
return false, fmt.Errorf("token bucket: malformed reply")
|
||||
}
|
||||
allowedI, ok := values[0].(int64)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("token bucket: malformed reply")
|
||||
}
|
||||
return allowedI == 1, nil
|
||||
}
|
||||
|
||||
// RandomSleep sleeps for random duration between min and max milliseconds
|
||||
func RandomSleep(minMs, maxMs int) {
|
||||
duration := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond
|
||||
|
||||
107
internal/engine/redis/redis_test.go
Normal file
107
internal/engine/redis/redis_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// newStrictTestClient wires a miniredis-backed RedisClient for
|
||||
// EvalTokenBucketStrict tests. Each call gets its own miniredis instance so
|
||||
// tests do not share state via the package-level globalClient.
|
||||
func newStrictTestClient(t *testing.T) (*RedisClient, *miniredis.Miniredis) {
|
||||
t.Helper()
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("miniredis: %v", err)
|
||||
}
|
||||
t.Cleanup(mr.Close)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
t.Cleanup(func() { _ = rdb.Close() })
|
||||
|
||||
return &RedisClient{
|
||||
client: rdb,
|
||||
luaDeleteIfEqual: redis.NewScript(luaDeleteIfEqualScript),
|
||||
luaTokenBucket: redis.NewScript(luaTokenBucketScript),
|
||||
// luaAutoIncrement intentionally not loaded; not used here.
|
||||
}, mr
|
||||
}
|
||||
|
||||
// TestEvalTokenBucketStrict_AllowedThenDenied walks the bucket through
|
||||
// capacity=2, rate=0.1 (slow refill). Two calls should be allowed; the
|
||||
// third should be denied. This is the happy-path security gate.
|
||||
func TestEvalTokenBucketStrict_AllowedThenDenied(t *testing.T) {
|
||||
r, _ := newStrictTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 1; i <= 2; i++ {
|
||||
ok, err := r.EvalTokenBucketStrict(ctx, "tb:webhook", 2, 0.1)
|
||||
if err != nil {
|
||||
t.Fatalf("call %d unexpected error: %v", i, err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("call %d: expected allowed=true", i)
|
||||
}
|
||||
}
|
||||
ok, err := r.EvalTokenBucketStrict(ctx, "tb:webhook", 2, 0.1)
|
||||
if err != nil {
|
||||
t.Fatalf("call 3 unexpected error: %v", err)
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("call 3: expected allowed=false (bucket exhausted)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEvalTokenBucketStrict_RedisDownFailsClosed confirms the strict
|
||||
// contract: when the transport fails, the caller sees an error AND
|
||||
// allowed=false. This is the explicit divergence from TokenBucket.Allow,
|
||||
// which would silently return allowed=true in the same situation.
|
||||
func TestEvalTokenBucketStrict_RedisDownFailsClosed(t *testing.T) {
|
||||
r, mr := newStrictTestClient(t)
|
||||
mr.Close() // break the connection
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
ok, err := r.EvalTokenBucketStrict(ctx, "tb:webhook", 5, 1)
|
||||
if err == nil {
|
||||
t.Fatalf("expected transport error, got nil")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("expected allowed=false on transport failure (fail-closed)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEvalTokenBucketStrict_NilClient confirms the nil-receiver guard.
|
||||
// The uninitialised-Redis case must NOT silently pass; it must return
|
||||
// (false, error) so the webhook handler can surface 102.
|
||||
func TestEvalTokenBucketStrict_NilClient(t *testing.T) {
|
||||
var r *RedisClient
|
||||
ok, err := r.EvalTokenBucketStrict(context.Background(), "tb:webhook", 1, 1)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error on nil client, got nil")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("expected allowed=false on nil client (fail-closed)")
|
||||
}
|
||||
}
|
||||
@@ -42,9 +42,20 @@ import (
|
||||
|
||||
// AgentHandler agent handler
|
||||
// agentFileService is the subset of FileService used by agent handlers.
|
||||
//
|
||||
// The full FileService also has UploadFile, but it is consumed by
|
||||
// the FileHandler (handler/file.go), not by any agent handler, so
|
||||
// the interface deliberately does NOT list it. (Code review CR1.)
|
||||
type agentFileService interface {
|
||||
UploadFile(tenantID, parentID string, files []*multipart.FileHeader) ([]map[string]interface{}, error)
|
||||
DownloadAgentFile(tenantID, location string) ([]byte, error)
|
||||
// UploadInfos stores raw bytes in the per-user downloads bucket and
|
||||
// returns lightweight descriptors. Mirrors python FileService.upload_info
|
||||
// (multi-file path) used by the agent upload endpoint.
|
||||
UploadInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, error)
|
||||
// UploadFromURL downloads a remote file (with SSRF protection) and
|
||||
// stores it as an info blob. Mirrors python FileService.upload_info
|
||||
// (single-file path with ?url=) used by the agent upload endpoint.
|
||||
UploadFromURL(tenantID, rawURL string) (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
// chatAgentService is the subset of AgentService used by the chat-completion
|
||||
@@ -62,6 +73,7 @@ type AgentHandler struct {
|
||||
agentService *service.AgentService
|
||||
chatRunner chatAgentService
|
||||
fileService agentFileService
|
||||
loader canvasLoader
|
||||
}
|
||||
|
||||
// NewAgentHandler create agent handler
|
||||
@@ -71,6 +83,7 @@ func NewAgentHandler(agentService *service.AgentService, fileService *service.Fi
|
||||
agentService: agentService,
|
||||
chatRunner: agentService,
|
||||
fileService: fileService,
|
||||
loader: agentService,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
200
internal/handler/agent_component.go
Normal file
200
internal/handler/agent_component.go
Normal file
@@ -0,0 +1,200 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Gap C — `GET /api/v1/agents/<canvas_id>/components/<component_id>/input-form`.
|
||||
// Gap D — `POST /api/v1/agents/<canvas_id>/components/<component_id>/debug`.
|
||||
//
|
||||
// Both endpoints introspect the user_canvas DSL via
|
||||
// internal/agent/dsl helpers and, for debug only, construct a runtime
|
||||
// component via the production factory. The python reference uses
|
||||
// @_require_canvas_access_sync/_async decorators which return
|
||||
// OPERATING_ERROR (103) with the python permission message; we
|
||||
// inline the check here via loader.LoadCanvasByID and surface the
|
||||
// same envelope (CodeOperatingError + canvasNoAccessMessage) so
|
||||
// existing clients can pattern-match the response.
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/agent/canvas"
|
||||
"ragflow/internal/agent/dsl"
|
||||
"ragflow/internal/agent/runtime"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
)
|
||||
|
||||
// GetComponentInputForm GET /api/v1/agents/:canvas_id/components/:component_id/input-form
|
||||
func (h *AgentHandler) GetComponentInputForm(c *gin.Context) {
|
||||
user, code, msg := GetUser(c)
|
||||
if code != common.CodeSuccess {
|
||||
jsonError(c, code, msg)
|
||||
return
|
||||
}
|
||||
canvasID := c.Param("canvas_id")
|
||||
componentID := c.Param("component_id")
|
||||
if canvasID == "" || componentID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "`canvas_id` and `component_id` are required.")
|
||||
return
|
||||
}
|
||||
|
||||
cv, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID)
|
||||
if err != nil {
|
||||
if err == dao.ErrUserCanvasNotFound {
|
||||
jsonError(c, common.CodeOperatingError, canvasNoAccessMessage)
|
||||
return
|
||||
}
|
||||
jsonError(c, common.CodeServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
form, err := dsl.ExtractComponentInputForm(cv.DSL, componentID)
|
||||
if err != nil {
|
||||
mapDSLError(c, componentID, err)
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": form,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// DebugComponent POST /api/v1/agents/:canvas_id/components/:component_id/debug
|
||||
//
|
||||
// Body shape (python parity): {"params": {"input_name": {"value": ...}, ...}}
|
||||
//
|
||||
// The python reference calls `component.invoke(**{k: o["value"] for k, o in
|
||||
// req["params"].items()})`. We replicate this by flattening each param's
|
||||
// "value" field into the Invoke inputs map.
|
||||
func (h *AgentHandler) DebugComponent(c *gin.Context) {
|
||||
user, code, msg := GetUser(c)
|
||||
if code != common.CodeSuccess {
|
||||
jsonError(c, code, msg)
|
||||
return
|
||||
}
|
||||
canvasID := c.Param("canvas_id")
|
||||
componentID := c.Param("component_id")
|
||||
if canvasID == "" || componentID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "`canvas_id` and `component_id` are required.")
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Params map[string]map[string]any `json:"params"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if body.Params == nil {
|
||||
jsonError(c, common.CodeArgumentError, "`params` is required.")
|
||||
return
|
||||
}
|
||||
|
||||
cv, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID)
|
||||
if err != nil {
|
||||
if err == dao.ErrUserCanvasNotFound {
|
||||
jsonError(c, common.CodeOperatingError, canvasNoAccessMessage)
|
||||
return
|
||||
}
|
||||
jsonError(c, common.CodeServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
name, err := dsl.ExtractComponentName(cv.DSL, componentID)
|
||||
if err != nil {
|
||||
mapDSLError(c, componentID, err)
|
||||
return
|
||||
}
|
||||
dslParams, _ := dsl.ExtractComponentParams(cv.DSL, componentID)
|
||||
|
||||
// Build the Invoke inputs map by flattening the request body's
|
||||
// {param: {value: ...}} shape into {param: value}. Mirrors python
|
||||
// agent_api.py:830 (`component.invoke(**{k: o["value"] for k, o in
|
||||
// req["params"].items()})`).
|
||||
//
|
||||
// The body contract requires `params.*.value` — a missing value
|
||||
// field used to slip through as nil and still invoke the
|
||||
// component, which silently corrupted debug input. Now we fail
|
||||
// fast. CodeRabbit PR review #2.
|
||||
inputs := make(map[string]any, len(body.Params))
|
||||
for k, v := range body.Params {
|
||||
if v == nil {
|
||||
jsonError(c, common.CodeArgumentError, "`params."+k+".value` is required.")
|
||||
return
|
||||
}
|
||||
value, ok := v["value"]
|
||||
if !ok {
|
||||
jsonError(c, common.CodeArgumentError, "`params."+k+".value` is required.")
|
||||
return
|
||||
}
|
||||
inputs[k] = value
|
||||
}
|
||||
|
||||
factory := runtime.DefaultFactory()
|
||||
if factory == nil {
|
||||
jsonError(c, common.CodeServerError, "component factory not initialised")
|
||||
return
|
||||
}
|
||||
comp, err := factory(name, dslParams)
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeDataError, "component factory: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// D4: skip set_debug_inputs (python-only LLM debug hook). The
|
||||
// raw Invoke already supports the same inputs.
|
||||
//
|
||||
// Begin (and other stateful components) reads a *CanvasState
|
||||
// from the request context. We attach a fresh one here so
|
||||
// debug works on a single component without standing up the
|
||||
// full canvas compile.
|
||||
invokeCtx := runtime.WithState(c.Request.Context(), canvas.NewCanvasState("debug-"+componentID, "debug-task"))
|
||||
|
||||
outputs, err := comp.Invoke(invokeCtx, inputs)
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeServerError, "invoke: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": outputs,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// mapDSLError translates a dsl extractor error into a 102 envelope.
|
||||
// Centralised so both handlers return consistent error shapes. The
|
||||
// default arm surfaces unknown errors as 500 (server error) so a
|
||||
// future unmapped dsl sentinel doesn't silently masquerade as a
|
||||
// user-data problem (code-review MEDIUM).
|
||||
func mapDSLError(c *gin.Context, componentID string, err error) {
|
||||
switch {
|
||||
case errors.Is(err, dsl.ErrComponentNotFound):
|
||||
jsonError(c, common.CodeDataError, "component not found: "+componentID)
|
||||
case errors.Is(err, dsl.ErrMissingInputForm):
|
||||
jsonError(c, common.CodeDataError, "component has no input_form: "+componentID)
|
||||
case errors.Is(err, dsl.ErrMalformedDSL):
|
||||
jsonError(c, common.CodeDataError, "malformed dsl: "+err.Error())
|
||||
default:
|
||||
jsonError(c, common.CodeServerError, "internal: "+err.Error())
|
||||
}
|
||||
}
|
||||
315
internal/handler/agent_component_test.go
Normal file
315
internal/handler/agent_component_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
_ "ragflow/internal/agent/component" // registers the production factory
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// componentCtx builds a Gin context for one of the
|
||||
// /agents/:canvas_id/components/:component_id/* endpoints. The
|
||||
// supplied `body` is bound as the request body (empty for GET).
|
||||
func componentCtx(t *testing.T, method, path, body string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(method, path, strings.NewReader(body))
|
||||
if body != "" {
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
c.Set("user", &entity.User{ID: "u-1"})
|
||||
return c, w
|
||||
}
|
||||
|
||||
// beginCanvas returns a UserCanvas whose DSL has a single Begin
|
||||
// component with both an input_form and a params block. Used for
|
||||
// happy-path tests.
|
||||
func beginCanvas() *entity.UserCanvas {
|
||||
return &entity.UserCanvas{
|
||||
ID: "c1",
|
||||
DSL: map[string]any{
|
||||
"components": map[string]any{
|
||||
"begin": map[string]any{
|
||||
"obj": map[string]any{
|
||||
"component_name": "Begin",
|
||||
"params": map[string]any{
|
||||
"mode": "Manual",
|
||||
},
|
||||
"input_form": map[string]any{
|
||||
"query": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- GetComponentInputForm ----------
|
||||
|
||||
func TestGetComponentInputForm_HappyPath(t *testing.T) {
|
||||
cv := beginCanvas()
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
|
||||
c, w := componentCtx(t, "GET", "/api/v1/agents/c1/components/begin/input-form", "")
|
||||
c.Params = gin.Params{
|
||||
{Key: "canvas_id", Value: "c1"},
|
||||
{Key: "component_id", Value: "begin"},
|
||||
}
|
||||
h.GetComponentInputForm(c)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var env struct {
|
||||
Code int `json:"code"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
|
||||
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
|
||||
}
|
||||
q, ok := env.Data["query"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("data.query = %v, want a map", env.Data["query"])
|
||||
}
|
||||
if q["type"] != "string" {
|
||||
t.Errorf("data.query.type = %v, want string", q["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetComponentInputForm_CanvasNotFound(t *testing.T) {
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}}
|
||||
|
||||
c, w := componentCtx(t, "GET", "/api/v1/agents/c1/components/begin/input-form", "")
|
||||
c.Params = gin.Params{
|
||||
{Key: "canvas_id", Value: "c1"},
|
||||
{Key: "component_id", Value: "begin"},
|
||||
}
|
||||
h.GetComponentInputForm(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 103 { // CodeOperatingError — python @_require_canvas_access parity
|
||||
t.Errorf("code = %d, want 103; msg=%q", code, msg)
|
||||
}
|
||||
want := "Make sure you have permission to access the agent."
|
||||
if msg != want {
|
||||
t.Errorf("msg = %q, want %q", msg, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetComponentInputForm_ComponentNotFound(t *testing.T) {
|
||||
cv := beginCanvas()
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
|
||||
c, w := componentCtx(t, "GET", "/api/v1/agents/c1/components/missing/input-form", "")
|
||||
c.Params = gin.Params{
|
||||
{Key: "canvas_id", Value: "c1"},
|
||||
{Key: "component_id", Value: "missing"},
|
||||
}
|
||||
h.GetComponentInputForm(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 102 { // CodeDataError
|
||||
t.Errorf("code = %d, want 102; msg=%q", code, msg)
|
||||
}
|
||||
if !strings.Contains(msg, "component not found") {
|
||||
t.Errorf("msg = %q, want 'component not found'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetComponentInputForm_NoInputForm(t *testing.T) {
|
||||
cv := &entity.UserCanvas{
|
||||
ID: "c1",
|
||||
DSL: map[string]any{
|
||||
"components": map[string]any{
|
||||
"answer": map[string]any{
|
||||
"obj": map[string]any{
|
||||
"component_name": "Answer",
|
||||
// no input_form
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
|
||||
c, w := componentCtx(t, "GET", "/api/v1/agents/c1/components/answer/input-form", "")
|
||||
c.Params = gin.Params{
|
||||
{Key: "canvas_id", Value: "c1"},
|
||||
{Key: "component_id", Value: "answer"},
|
||||
}
|
||||
h.GetComponentInputForm(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 102 {
|
||||
t.Errorf("code = %d, want 102; msg=%q", code, msg)
|
||||
}
|
||||
if !strings.Contains(msg, "no input_form") {
|
||||
t.Errorf("msg = %q, want 'no input_form'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- DebugComponent ----------
|
||||
|
||||
func TestDebugComponent_HappyPath_Begin(t *testing.T) {
|
||||
cv := beginCanvas()
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
|
||||
body := `{"params":{"query":{"value":"hello world"}}}`
|
||||
c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", body)
|
||||
c.Params = gin.Params{
|
||||
{Key: "canvas_id", Value: "c1"},
|
||||
{Key: "component_id", Value: "begin"},
|
||||
}
|
||||
h.DebugComponent(c)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
// Begin's Invoke is a passthrough — it returns the input map as
|
||||
// outputs (see internal/agent/component/begin.go:96-98). Pin the
|
||||
// exact passthrough so a regression that drops the `query` field
|
||||
// would be caught (architect review A2).
|
||||
var env struct {
|
||||
Code int `json:"code"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
|
||||
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
|
||||
}
|
||||
if env.Code != 0 {
|
||||
t.Errorf("code = %d, want 0; body=%s", env.Code, w.Body.String())
|
||||
}
|
||||
if env.Data["query"] != "hello world" {
|
||||
t.Errorf("data.query = %v, want 'hello world'", env.Data["query"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDebugComponent_InvalidParams_MissingField(t *testing.T) {
|
||||
cv := beginCanvas()
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
|
||||
// No `params` field.
|
||||
c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", `{}`)
|
||||
c.Params = gin.Params{
|
||||
{Key: "canvas_id", Value: "c1"},
|
||||
{Key: "component_id", Value: "begin"},
|
||||
}
|
||||
h.DebugComponent(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 101 { // CodeArgumentError
|
||||
t.Errorf("code = %d, want 101; msg=%q", code, msg)
|
||||
}
|
||||
if !strings.Contains(msg, "params") {
|
||||
t.Errorf("msg = %q, want 'params'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDebugComponent_InvalidParams_MissingValue covers CodeRabbit
|
||||
// PR review #2: `{"params":{"q":{}}}` (no `value` key) should
|
||||
// fail-fast with 101 rather than silently invoking the component
|
||||
// with inputs["q"]=nil.
|
||||
func TestDebugComponent_InvalidParams_MissingValue(t *testing.T) {
|
||||
cv := beginCanvas()
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
|
||||
c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", `{"params":{"q":{}}}`)
|
||||
c.Params = gin.Params{
|
||||
{Key: "canvas_id", Value: "c1"},
|
||||
{Key: "component_id", Value: "begin"},
|
||||
}
|
||||
h.DebugComponent(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 101 {
|
||||
t.Errorf("code = %d, want 101; msg=%q", code, msg)
|
||||
}
|
||||
if !strings.Contains(msg, "value") {
|
||||
t.Errorf("msg = %q, want mention of 'value'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDebugComponent_UnknownComponent(t *testing.T) {
|
||||
cv := beginCanvas()
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
|
||||
body := `{"params":{"x":{"value":1}}}`
|
||||
c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/missing/debug", body)
|
||||
c.Params = gin.Params{
|
||||
{Key: "canvas_id", Value: "c1"},
|
||||
{Key: "component_id", Value: "missing"},
|
||||
}
|
||||
h.DebugComponent(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 102 {
|
||||
t.Errorf("code = %d, want 102; msg=%q", code, msg)
|
||||
}
|
||||
if !strings.Contains(msg, "component not found") {
|
||||
t.Errorf("msg = %q, want 'component not found'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDebugComponent_CannotAccessCanvas(t *testing.T) {
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}}
|
||||
|
||||
body := `{"params":{"x":{"value":1}}}`
|
||||
c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", body)
|
||||
c.Params = gin.Params{
|
||||
{Key: "canvas_id", Value: "c1"},
|
||||
{Key: "component_id", Value: "begin"},
|
||||
}
|
||||
h.DebugComponent(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 103 { // CodeOperatingError
|
||||
t.Errorf("code = %d, want 103; msg=%q", code, msg)
|
||||
}
|
||||
want := "Make sure you have permission to access the agent."
|
||||
if msg != want {
|
||||
t.Errorf("msg = %q, want %q", msg, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDebugComponent_LoaderError_NonNotFound(t *testing.T) {
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{err: errors.New("db conn lost")}}
|
||||
|
||||
body := `{"params":{"x":{"value":1}}}`
|
||||
c, w := componentCtx(t, "POST", "/api/v1/agents/c1/components/begin/debug", body)
|
||||
c.Params = gin.Params{
|
||||
{Key: "canvas_id", Value: "c1"},
|
||||
{Key: "component_id", Value: "begin"},
|
||||
}
|
||||
h.DebugComponent(c)
|
||||
|
||||
code, _ := errBody(t, w.Body.Bytes())
|
||||
if code != 500 { // CodeServerError
|
||||
t.Errorf("code = %d, want 500", code)
|
||||
}
|
||||
}
|
||||
86
internal/handler/agent_download.go
Normal file
86
internal/handler/agent_download.go
Normal file
@@ -0,0 +1,86 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Gap A — `GET /api/v1/agents/download` (Python
|
||||
// api/apps/restful_apis/agent_api.py:523-530).
|
||||
//
|
||||
// Mirrors the python download_agent_file handler:
|
||||
// - auth via @login_required → GetUser
|
||||
// - tenant injection via @add_tenant_id_to_kwargs → user.TenantID (we use
|
||||
// user.ID here because the user/tenant distinction is collapsed in
|
||||
// the Go session model; FileService buckets by userID for download
|
||||
// retrieval, same as the python tenantID). See service/file.go:1033.
|
||||
// - reads `id` from query string and streams raw bytes back as
|
||||
// application/octet-stream.
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
)
|
||||
|
||||
// DownloadAgentFile GET /api/v1/agents/download?id=<file_id>
|
||||
func (h *AgentHandler) DownloadAgentFile(c *gin.Context) {
|
||||
user, code, msg := GetUser(c)
|
||||
if code != common.CodeSuccess {
|
||||
jsonError(c, code, msg)
|
||||
return
|
||||
}
|
||||
fileID := c.Query("id")
|
||||
if fileID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "`id` is required.")
|
||||
return
|
||||
}
|
||||
|
||||
// IDOR note (security review H1): the Go User struct has no
|
||||
// TenantID field — the project collapses user and tenant in a
|
||||
// single-tenant session model. Python's @add_tenant_id_to_kwargs
|
||||
// resolves tenant_id from the session, and the python download
|
||||
// endpoint also reads `id` directly from the query string with no
|
||||
// per-object ownership check, so this port preserves the python
|
||||
// shape. A future per-object ownership check should be added in
|
||||
// both the python and Go code paths.
|
||||
blob, err := h.fileService.DownloadAgentFile(user.ID, fileID)
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Sanitize the Content-Disposition value to prevent header
|
||||
// injection (security review H2). The Go net/http layer rejects
|
||||
// CR/LF in header values, but we sanitize at the source so we
|
||||
// don't rely on the implicit defense. `filepath.Base` strips any
|
||||
// path elements; url.PathEscape produces an RFC 5987 filename*=
|
||||
// value.
|
||||
safe := filepath.Base(fileID)
|
||||
if safe == "" || safe == "." || safe == "/" || strings.ContainsAny(safe, "\r\n\"") {
|
||||
jsonError(c, common.CodeArgumentError, "invalid file id.")
|
||||
return
|
||||
}
|
||||
c.Header("Content-Disposition", fmt.Sprintf(
|
||||
`attachment; filename="%s"; filename*=UTF-8''%s`,
|
||||
safe, url.PathEscape(safe),
|
||||
))
|
||||
c.Data(http.StatusOK, "application/octet-stream", blob)
|
||||
}
|
||||
139
internal/handler/agent_download_test.go
Normal file
139
internal/handler/agent_download_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"mime/multipart"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// fakeAgentFileService is a canned-response stand-in for the file service
// used by the agent download/upload handlers, so handler tests can run
// without a real FileService or object store. Every method ignores its
// arguments: it returns the configured error when one is set, otherwise
// the configured payload.
type fakeAgentFileService struct {
	// DownloadAgentFile: bytes to return, or err to fail with.
	blob []byte
	err  error
	// UploadInfos: descriptors to return, or uploadErr to fail with.
	uploadList []map[string]any
	uploadErr  error
	// UploadFromURL: descriptor to return, or urlUploadErr to fail with.
	urlUpload    map[string]any
	urlUploadErr error
}

// UploadInfos returns uploadList, or (nil, uploadErr) when uploadErr is set.
func (f *fakeAgentFileService) UploadInfos(_ string, _ []*multipart.FileHeader) ([]map[string]any, error) {
	if f.uploadErr == nil {
		return f.uploadList, nil
	}
	return nil, f.uploadErr
}

// UploadFromURL returns urlUpload, or (nil, urlUploadErr) when
// urlUploadErr is set.
func (f *fakeAgentFileService) UploadFromURL(_ string, _ string) (map[string]any, error) {
	if f.urlUploadErr == nil {
		return f.urlUpload, nil
	}
	return nil, f.urlUploadErr
}

// DownloadAgentFile returns blob, or (nil, err) when err is set.
func (f *fakeAgentFileService) DownloadAgentFile(_ string, _ string) ([]byte, error) {
	if f.err == nil {
		return f.blob, nil
	}
	return nil, f.err
}
|
||||
|
||||
// downloadCtx builds a Gin context for the /agents/download endpoint with
|
||||
// optional id query parameter and a stubbed authenticated user.
|
||||
func downloadCtx(t *testing.T, id string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
url := "/api/v1/agents/download"
|
||||
if id != "" {
|
||||
url += "?id=" + id
|
||||
}
|
||||
c.Request = httptest.NewRequest("GET", url, nil)
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
return c, w
|
||||
}
|
||||
|
||||
// TestDownloadAgentFile_HappyPath verifies that a known id returns the
|
||||
// raw bytes with application/octet-stream content type and a sane
|
||||
// Content-Disposition header.
|
||||
func TestDownloadAgentFile_HappyPath(t *testing.T) {
|
||||
h := &AgentHandler{fileService: &fakeAgentFileService{blob: []byte("hello world")}}
|
||||
|
||||
c, w := downloadCtx(t, "file-123")
|
||||
h.DownloadAgentFile(c)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if w.Body.String() != "hello world" {
|
||||
t.Errorf("body = %q, want %q", w.Body.String(), "hello world")
|
||||
}
|
||||
if got := w.Header().Get("Content-Disposition"); !strings.Contains(got, "file-123") {
|
||||
t.Errorf("Content-Disposition = %q, want filename=file-123", got)
|
||||
}
|
||||
if got := w.Header().Get("Content-Type"); got != "application/octet-stream" {
|
||||
t.Errorf("Content-Type = %q, want application/octet-stream", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDownloadAgentFile_MissingID verifies that an empty id is rejected
|
||||
// with a 102-shaped envelope.
|
||||
func TestDownloadAgentFile_MissingID(t *testing.T) {
|
||||
h := &AgentHandler{fileService: &fakeAgentFileService{}}
|
||||
|
||||
c, w := downloadCtx(t, "") // no id
|
||||
h.DownloadAgentFile(c)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("status = %d, want 200 (envelope returned in body)", w.Code)
|
||||
}
|
||||
code, _ := errBody(t, w.Body.Bytes())
|
||||
if code != 101 { // CodeArgumentError
|
||||
t.Errorf("code = %d, want 101", code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDownloadAgentFile_LoaderError verifies that a storage error
|
||||
// bubbles up as a 500-class envelope.
|
||||
func TestDownloadAgentFile_LoaderError(t *testing.T) {
|
||||
h := &AgentHandler{fileService: &fakeAgentFileService{err: errors.New("minio down")}}
|
||||
|
||||
c, w := downloadCtx(t, "file-1")
|
||||
h.DownloadAgentFile(c)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("status = %d, want 200 (envelope in body)", w.Code)
|
||||
}
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 500 { // CodeServerError
|
||||
t.Errorf("code = %d, want 500", code)
|
||||
}
|
||||
if !strings.Contains(msg, "minio down") {
|
||||
t.Errorf("msg = %q, want contains 'minio down'", msg)
|
||||
}
|
||||
}
|
||||
152
internal/handler/agent_upload.go
Normal file
152
internal/handler/agent_upload.go
Normal file
@@ -0,0 +1,152 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Gap B — `POST /api/v1/agents/<agent_id>/upload` (Python
|
||||
// api/apps/restful_apis/agent_api.py:761-790).
|
||||
//
|
||||
// Mirrors the python upload_agent_file handler:
|
||||
// - @_require_canvas_access_async → loader.LoadCanvasByID (returns
|
||||
// ErrUserCanvasNotFound for both missing and forbidden; we map that
|
||||
// to CodeOperatingError (103) with the python permission message
|
||||
// instead of the chat-path 103 "canvas not found.")
|
||||
// - single file + ?url= → FileService.upload_info(tenant, file, url)
|
||||
// via UploadFromURL (URL-import mode)
|
||||
// - single file + no url → FileService.upload_info(tenant, file)
|
||||
// via UploadInfos with a one-element slice
|
||||
// - multi file + any url → FileService.upload_info * N
|
||||
// via UploadInfos (Python ignores ?url= on the multi-file path)
|
||||
//
|
||||
// 64 MB upload cap (D3 default).
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
)
|
||||
|
||||
const (
	// uploadMaxBytes is the 64 MiB cap applied to the multipart upload
	// body. It is set explicitly (rather than inherited from a framework
	// default, as the original note says the python reference does) so
	// the limit is auditable and stable across test environments.
	uploadMaxBytes int64 = 64 * 1024 * 1024

	// canvasNoAccessMessage is the message returned when the caller may
	// not access a canvas. It is kept identical to the python permission
	// error (api/apps/restful_apis/agent_api.py, per the original note)
	// so existing clients can pattern-match the text.
	canvasNoAccessMessage = "Make sure you have permission to access the agent."
)
|
||||
|
||||
// UploadAgentFile POST /api/v1/agents/:canvas_id/upload
|
||||
func (h *AgentHandler) UploadAgentFile(c *gin.Context) {
|
||||
user, code, msg := GetUser(c)
|
||||
if code != common.CodeSuccess {
|
||||
jsonError(c, code, msg)
|
||||
return
|
||||
}
|
||||
canvasID := c.Param("canvas_id")
|
||||
if canvasID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "`canvas_id` is required.")
|
||||
return
|
||||
}
|
||||
|
||||
// Canvas access check: matches python @_require_canvas_access_async.
|
||||
// We deliberately do NOT differentiate "missing" from "forbidden"
|
||||
// (LoadCanvasByID collapses both into ErrUserCanvasNotFound) for
|
||||
// IDOR mitigation; the user-visible envelope uses OPERATING_ERROR
|
||||
// (103) with the python permission message so existing clients can
|
||||
// still pattern-match the text.
|
||||
if _, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID); err != nil {
|
||||
if err == dao.ErrUserCanvasNotFound {
|
||||
jsonError(c, common.CodeOperatingError, canvasNoAccessMessage)
|
||||
return
|
||||
}
|
||||
jsonError(c, common.CodeServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Hard cap the body before any parsing (security review H3).
|
||||
// Without MaxBytesReader, a 1 GB request body is fully drained
|
||||
// into memory by ParseMultipartForm before any size check fires.
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, uploadMaxBytes)
|
||||
if cl := c.Request.ContentLength; cl > uploadMaxBytes {
|
||||
jsonError(c, common.CodeArgumentError, "request body too large.")
|
||||
return
|
||||
}
|
||||
if err := c.Request.ParseMultipartForm(uploadMaxBytes); err != nil {
|
||||
jsonError(c, common.CodeArgumentError, "invalid multipart form: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if c.Request.MultipartForm != nil {
|
||||
_ = c.Request.MultipartForm.RemoveAll()
|
||||
}
|
||||
}()
|
||||
form := c.Request.MultipartForm
|
||||
if form == nil {
|
||||
jsonError(c, common.CodeArgumentError, "missing multipart form.")
|
||||
return
|
||||
}
|
||||
files := form.File["file"]
|
||||
|
||||
// URL-import mode: matches python's behaviour exactly
|
||||
// (api/apps/restful_apis/agent_api.py:775-783). The url query
|
||||
// param is consulted ONLY on the single-file branch; for 0 or
|
||||
// >1 files, the url is silently ignored and the request flows
|
||||
// into the normal UploadInfos path. We replicate that with a
|
||||
// guard that dispatches to UploadFromURL only when both
|
||||
// conditions are met.
|
||||
if url := c.Query("url"); url != "" && len(files) == 1 {
|
||||
uploaded, err := h.fileService.UploadFromURL(user.ID, url)
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeServerError, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": uploaded,
|
||||
"message": "success",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
jsonError(c, common.CodeArgumentError, "`file` field is required.")
|
||||
return
|
||||
}
|
||||
|
||||
results, err := h.fileService.UploadInfos(user.ID, files)
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Python parity: 1 file → single dict; >1 → list.
|
||||
var payload any
|
||||
if len(results) == 1 {
|
||||
payload = results[0]
|
||||
} else {
|
||||
payload = results
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": payload,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
@@ -13,230 +13,305 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
//go:build ignore
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"mime/multipart"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
// setupUploadTestDB sets up SQLite in-memory DB for upload handler tests.
|
||||
func setupUploadTestDB(t *testing.T) *gorm.DB {
|
||||
// uploadFakes is a paired pair of fakes used by the upload-handler
|
||||
// tests. The canvasLoader fake grants/denies access per case; the
|
||||
// fileService fake records the call and returns a canned descriptor.
|
||||
type uploadFakes struct {
|
||||
loader *fakeCanvasLoader
|
||||
fileSvc *fakeAgentFileService
|
||||
}
|
||||
|
||||
func (u *uploadFakes) setUploadResult(list []map[string]interface{}) {
|
||||
u.fileSvc.uploadList = list
|
||||
}
|
||||
|
||||
// makeUploadCtx builds a Gin context with a multipart body containing
|
||||
// `n` files under the "file" field. Each file's body is the
|
||||
// decimal-string representation of its 1-based index.
|
||||
func makeUploadCtx(t *testing.T, n int) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
TranslateError: true,
|
||||
body := &bytes.Buffer{}
|
||||
mw := multipart.NewWriter(body)
|
||||
for i := 1; i <= n; i++ {
|
||||
name := "f" + strconv.Itoa(i) + ".txt"
|
||||
fw, err := mw.CreateFormFile("file", name)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateFormFile: %v", err)
|
||||
}
|
||||
if _, err := fw.Write([]byte(name)); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
}
|
||||
mw.Close()
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/agents/c1/upload", body)
|
||||
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
c.Request = req
|
||||
c.Params = gin.Params{{Key: "canvas_id", Value: "c1"}}
|
||||
c.Set("user", &entity.User{ID: "u-1"})
|
||||
return c, w
|
||||
}
|
||||
|
||||
// makeUploadCtxNoFile builds a Gin context with a multipart body that
|
||||
// has no "file" field (used for the missing-file test).
|
||||
func makeUploadCtxNoFile(t *testing.T) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
body := &bytes.Buffer{}
|
||||
mw := multipart.NewWriter(body)
|
||||
// "other" field, not "file"
|
||||
fw, _ := mw.CreateFormFile("other", "o.txt")
|
||||
fw.Write([]byte("o"))
|
||||
mw.Close()
|
||||
req := httptest.NewRequest("POST", "/api/v1/agents/c1/upload", body)
|
||||
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
c.Request = req
|
||||
c.Params = gin.Params{{Key: "canvas_id", Value: "c1"}}
|
||||
c.Set("user", &entity.User{ID: "u-1"})
|
||||
return c, w
|
||||
}
|
||||
|
||||
// TestUploadAgentFile_SingleFile pins the 1-file → single-dict
|
||||
// response shape (python agent_api.py:775-779).
|
||||
func TestUploadAgentFile_SingleFile(t *testing.T) {
|
||||
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
|
||||
fu := &uploadFakes{
|
||||
loader: loader,
|
||||
fileSvc: &fakeAgentFileService{},
|
||||
}
|
||||
fu.setUploadResult([]map[string]interface{}{{"id": "upload-1", "name": "f1.txt"}})
|
||||
h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc}
|
||||
|
||||
c, w := makeUploadCtx(t, 1)
|
||||
h.UploadAgentFile(c)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
// Single-file path returns data as a single dict, NOT a list.
|
||||
var env struct {
|
||||
Code int `json:"code"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
|
||||
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
|
||||
}
|
||||
if env.Code != 0 {
|
||||
t.Errorf("code = %d, want 0", env.Code)
|
||||
}
|
||||
if env.Data["id"] != "upload-1" {
|
||||
t.Errorf("data.id = %v, want upload-1", env.Data["id"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestUploadAgentFile_MultiFile pins the >1-file → list-of-dicts
|
||||
// response shape (python agent_api.py:780-783).
|
||||
func TestUploadAgentFile_MultiFile(t *testing.T) {
|
||||
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
|
||||
fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{}}
|
||||
fu.setUploadResult([]map[string]interface{}{
|
||||
{"id": "u1", "name": "f1.txt"},
|
||||
{"id": "u2", "name": "f2.txt"},
|
||||
{"id": "u3", "name": "f3.txt"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc}
|
||||
|
||||
if err := db.AutoMigrate(
|
||||
&entity.User{},
|
||||
&entity.UserCanvas{},
|
||||
&entity.UserTenant{},
|
||||
); err != nil {
|
||||
t.Fatalf("failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// fakeUploadFileService implements fileUploader for tests.
|
||||
type fakeUploadFileService struct {
|
||||
uploaded []map[string]interface{}
|
||||
err error
|
||||
lastTenantID string
|
||||
lastParentID string
|
||||
}
|
||||
|
||||
func (f *fakeUploadFileService) UploadFile(tenantID, parentID string, files []*multipart.FileHeader) ([]map[string]interface{}, error) {
|
||||
f.lastTenantID = tenantID
|
||||
f.lastParentID = parentID
|
||||
return f.uploaded, f.err
|
||||
}
|
||||
|
||||
// TestUploadAgentFileHandler_Success verifies the happy path.
|
||||
func TestUploadAgentFileHandler_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
db := setupUploadTestDB(t)
|
||||
orig := dao.DB
|
||||
dao.DB = db
|
||||
t.Cleanup(func() { dao.DB = orig })
|
||||
|
||||
db.Create(&entity.User{ID: "user-1", Nickname: "test", Email: "test@test.com"})
|
||||
db.Create(&entity.UserCanvas{ID: "canvas-1", UserID: "user-1", Title: sp("Test Agent")})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
body := strings.NewReader("--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\nContent-Type: text/plain\r\n\r\nhello world\r\n--boundary--")
|
||||
req := httptest.NewRequest("POST", "/api/v1/agents/canvas-1/upload", body)
|
||||
req.Header.Set("Content-Type", "multipart/form-data; boundary=boundary")
|
||||
c.Request = req
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
c.Set("user_id", "user-1")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
|
||||
|
||||
svc := &fakeUploadFileService{
|
||||
uploaded: []map[string]interface{}{
|
||||
{"id": "file-1", "name": "test.txt"},
|
||||
},
|
||||
}
|
||||
h := &AgentHandler{
|
||||
agentService: service.NewAgentService(),
|
||||
fileService: svc,
|
||||
}
|
||||
c, w := makeUploadCtx(t, 3)
|
||||
h.UploadAgentFile(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
code, _ := resp["code"].(float64)
|
||||
if code != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %v", code, resp["message"])
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var env struct {
|
||||
Code int `json:"code"`
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
|
||||
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
|
||||
}
|
||||
if env.Code != 0 {
|
||||
t.Errorf("code = %d, want 0", env.Code)
|
||||
}
|
||||
if len(env.Data) != 3 {
|
||||
t.Errorf("data has %d items, want 3", len(env.Data))
|
||||
}
|
||||
}
|
||||
|
||||
// TestUploadAgentFileHandler_NoPermission verifies cross-user access is denied.
|
||||
func TestUploadAgentFileHandler_NoPermission(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// TestUploadAgentFile_MissingFileField verifies that a multipart body
|
||||
// without a "file" field is rejected with a 101 envelope.
|
||||
func TestUploadAgentFile_MissingFileField(t *testing.T) {
|
||||
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
|
||||
h := &AgentHandler{loader: loader, fileService: &fakeAgentFileService{}}
|
||||
|
||||
db := setupUploadTestDB(t)
|
||||
orig := dao.DB
|
||||
dao.DB = db
|
||||
t.Cleanup(func() { dao.DB = orig })
|
||||
|
||||
db.Create(&entity.User{ID: "user-a", Nickname: "a", Email: "a@test.com"})
|
||||
db.Create(&entity.UserCanvas{ID: "canvas-b", UserID: "user-b", Title: sp("Not Yours")})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/v1/agents/canvas-b/upload", nil)
|
||||
c.Set("user", &entity.User{ID: "user-a"})
|
||||
c.Set("user_id", "user-a")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-b"}}
|
||||
|
||||
h := &AgentHandler{
|
||||
agentService: service.NewAgentService(),
|
||||
fileService: &fakeUploadFileService{},
|
||||
}
|
||||
c, w := makeUploadCtxNoFile(t)
|
||||
h.UploadAgentFile(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
code, _ := resp["code"].(float64)
|
||||
if code != float64(common.CodeOperatingError) {
|
||||
t.Errorf("expected operating error %d, got %v", common.CodeOperatingError, code)
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 101 { // CodeArgumentError
|
||||
t.Errorf("code = %d, want 101; msg=%q", code, msg)
|
||||
}
|
||||
if !strings.Contains(msg, "file") {
|
||||
t.Errorf("msg = %q, want mention of 'file'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUploadAgentFileHandler_NoFiles verifies empty file list is rejected.
|
||||
func TestUploadAgentFileHandler_NoFiles(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// TestUploadAgentFile_CannotAccessCanvas verifies that an
|
||||
// ErrUserCanvasNotFound from the loader short-circuits with a 103
|
||||
// envelope carrying the python permission-failure message
|
||||
// (matches @_require_canvas_access_async at agent_api.py:78,89).
|
||||
func TestUploadAgentFile_CannotAccessCanvas(t *testing.T) {
|
||||
loader := &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}
|
||||
h := &AgentHandler{loader: loader, fileService: &fakeAgentFileService{}}
|
||||
|
||||
db := setupUploadTestDB(t)
|
||||
orig := dao.DB
|
||||
dao.DB = db
|
||||
t.Cleanup(func() { dao.DB = orig })
|
||||
|
||||
db.Create(&entity.User{ID: "user-1", Nickname: "test", Email: "test@test.com"})
|
||||
db.Create(&entity.UserCanvas{ID: "canvas-1", UserID: "user-1", Title: sp("Test Agent")})
|
||||
|
||||
body := strings.NewReader("--boundary\r\nContent-Disposition: form-data; name=\"dummy\"\r\n\r\nvalue\r\n--boundary--")
|
||||
req := httptest.NewRequest("POST", "/api/v1/agents/canvas-1/upload", body)
|
||||
req.Header.Set("Content-Type", "multipart/form-data; boundary=boundary")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
c.Set("user_id", "user-1")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
|
||||
|
||||
h := &AgentHandler{
|
||||
agentService: service.NewAgentService(),
|
||||
fileService: &fakeUploadFileService{},
|
||||
}
|
||||
c, w := makeUploadCtx(t, 1)
|
||||
h.UploadAgentFile(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
code, _ := resp["code"].(float64)
|
||||
if code != float64(common.CodeArgumentError) {
|
||||
t.Errorf("expected argument error, got code %v", code)
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 103 { // CodeOperatingError
|
||||
t.Errorf("code = %d, want 103; msg=%q", code, msg)
|
||||
}
|
||||
want := "Make sure you have permission to access the agent."
|
||||
if msg != want {
|
||||
t.Errorf("msg = %q, want %q", msg, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUploadAgentFileHandler_TeamMemberTenant verifies that when a team
|
||||
// member uploads to a shared canvas, the file is written into the canvas
|
||||
// owner's file tree, not the caller's.
|
||||
func TestUploadAgentFileHandler_TeamMemberTenant(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// TestUploadAgentFile_URLImport pins the ?url= import path (python
|
||||
// agent_api.py:775-779). When ?url= is set AND the body has exactly
|
||||
// one `file` field, the handler delegates to FileService.UploadFromURL
|
||||
// and returns its single-dict result.
|
||||
func TestUploadAgentFile_URLImport(t *testing.T) {
|
||||
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
|
||||
fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{}}
|
||||
fu.fileSvc.urlUpload = map[string]interface{}{"id": "url-1", "name": "remote.bin"}
|
||||
h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc}
|
||||
|
||||
db := setupUploadTestDB(t)
|
||||
orig := dao.DB
|
||||
dao.DB = db
|
||||
t.Cleanup(func() { dao.DB = orig })
|
||||
c, w := makeUploadCtx(t, 1)
|
||||
c.Request.URL.RawQuery = "url=https%3A%2F%2Fexample.com%2Ffile.bin"
|
||||
h.UploadAgentFile(c)
|
||||
|
||||
// user-b is a team member of user-a's tenant
|
||||
db.Create(&entity.User{ID: "user-a", Nickname: "owner", Email: "a@test.com"})
|
||||
db.Create(&entity.User{ID: "user-b", Nickname: "member", Email: "b@test.com"})
|
||||
db.Create(&entity.UserTenant{ID: "ut-1", UserID: "user-b", TenantID: "user-a", Role: "member", Status: sp("1")})
|
||||
db.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-a",
|
||||
Permission: "team",
|
||||
Title: sp("Shared Agent"),
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var env struct {
|
||||
Code int `json:"code"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
|
||||
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
|
||||
}
|
||||
if env.Code != 0 {
|
||||
t.Errorf("code = %d, want 0", env.Code)
|
||||
}
|
||||
if env.Data["id"] != "url-1" {
|
||||
t.Errorf("data.id = %v, want url-1", env.Data["id"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestUploadAgentFile_URLImport_IgnoredForMultiFile pins that
|
||||
// ?url= is silently ignored when the body has >1 files, matching
|
||||
// python's behaviour at agent_api.py:780-783 (the multi-file branch
|
||||
// never reads url). The request flows into the normal UploadInfos
|
||||
// path and returns a list of dicts.
|
||||
func TestUploadAgentFile_URLImport_IgnoredForMultiFile(t *testing.T) {
|
||||
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
|
||||
fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{}}
|
||||
fu.fileSvc.urlUpload = map[string]interface{}{"id": "should-not-appear"}
|
||||
fu.setUploadResult([]map[string]interface{}{
|
||||
{"id": "u1", "name": "f1.txt"},
|
||||
{"id": "u2", "name": "f2.txt"},
|
||||
{"id": "u3", "name": "f3.txt"},
|
||||
})
|
||||
h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
body := strings.NewReader("--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"shared.txt\"\r\nContent-Type: text/plain\r\n\r\nhello\r\n--boundary--")
|
||||
req := httptest.NewRequest("POST", "/api/v1/agents/canvas-1/upload", body)
|
||||
req.Header.Set("Content-Type", "multipart/form-data; boundary=boundary")
|
||||
c.Request = req
|
||||
c.Set("user", &entity.User{ID: "user-b"})
|
||||
c.Set("user_id", "user-b")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
|
||||
|
||||
svc := &fakeUploadFileService{
|
||||
uploaded: []map[string]interface{}{
|
||||
{"id": "file-1", "name": "shared.txt"},
|
||||
},
|
||||
}
|
||||
h := &AgentHandler{
|
||||
agentService: service.NewAgentService(),
|
||||
fileService: svc,
|
||||
}
|
||||
c, w := makeUploadCtx(t, 3)
|
||||
c.Request.URL.RawQuery = "url=https%3A%2F%2Fexample.com%2Ffile.bin"
|
||||
h.UploadAgentFile(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"])
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if svc.lastTenantID != "user-b" {
|
||||
t.Errorf("expected UploadFile called with authenticated user 'user-a', got '%s'", svc.lastTenantID)
|
||||
var env struct {
|
||||
Code int `json:"code"`
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &env); err != nil {
|
||||
t.Fatalf("decode: %v (body=%s)", err, w.Body.String())
|
||||
}
|
||||
if env.Code != 0 {
|
||||
t.Errorf("code = %d, want 0", env.Code)
|
||||
}
|
||||
if len(env.Data) != 3 {
|
||||
t.Errorf("data has %d items, want 3 (url was ignored)", len(env.Data))
|
||||
}
|
||||
// Sanity check: the urlUpload map should NOT have been consumed.
|
||||
// We don't read it back, but the test setup proves the
|
||||
// handler didn't dispatch to UploadFromURL (otherwise it would
|
||||
// have returned the single urlUpload dict, not the 3-element
|
||||
// list from uploadList).
|
||||
}
|
||||
|
||||
// TestUploadAgentFile_URLImport_LoaderError pins that a failed URL
|
||||
// fetch surfaces as a 500 envelope (matches python's
|
||||
// `server_error_response(exc)` on line 784).
|
||||
func TestUploadAgentFile_URLImport_LoaderError(t *testing.T) {
|
||||
loader := &fakeCanvasLoader{canvas: &entity.UserCanvas{ID: "c1"}}
|
||||
fu := &uploadFakes{loader: loader, fileSvc: &fakeAgentFileService{urlUploadErr: errors.New("ssrf guard tripped")}}
|
||||
h := &AgentHandler{loader: fu.loader, fileService: fu.fileSvc}
|
||||
|
||||
c, w := makeUploadCtx(t, 1)
|
||||
c.Request.URL.RawQuery = "url=https%3A%2F%2Finternal.example.com%2Fsecret"
|
||||
h.UploadAgentFile(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 500 { // CodeServerError
|
||||
t.Errorf("code = %d, want 500; msg=%q", code, msg)
|
||||
}
|
||||
if !strings.Contains(msg, "ssrf guard tripped") {
|
||||
t.Errorf("msg = %q, want contains 'ssrf guard tripped'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// sp returns a pointer to the given string.
|
||||
func sp(s string) *string { return &s }
|
||||
func (f *fakeUploadFileService) DownloadAgentFile(tenantID, location string) ([]byte, error) {
|
||||
panic("DownloadAgentFile should not be called during upload tests")
|
||||
// TestUploadAgentFile_LoaderError verifies that a non-ErrUserCanvasNotFound
|
||||
// service error surfaces as a 500 envelope with the error string.
|
||||
func TestUploadAgentFile_LoaderError(t *testing.T) {
|
||||
loader := &fakeCanvasLoader{err: context.DeadlineExceeded}
|
||||
h := &AgentHandler{loader: loader, fileService: &fakeAgentFileService{}}
|
||||
|
||||
c, w := makeUploadCtx(t, 1)
|
||||
h.UploadAgentFile(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != 500 { // CodeServerError
|
||||
t.Errorf("code = %d, want 500; msg=%q", code, msg)
|
||||
}
|
||||
if !strings.Contains(msg, "deadline") {
|
||||
t.Errorf("msg = %q, want contains 'deadline'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
669
internal/handler/agent_webhook.go
Normal file
669
internal/handler/agent_webhook.go
Normal file
@@ -0,0 +1,669 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
// Webhook trigger handler — Go port of
|
||||
// api/apps/restful_apis/agent_api.py:1563-2248
|
||||
// (`/api/v1/agents/<agent_id>/webhook` and `/.../webhook/test`).
|
||||
//
|
||||
// Mirrors the python reference's 9-step flow:
|
||||
// 1. Load canvas (IDOR-safe via loadCanvasForUser → LoadCanvasByID).
|
||||
// 2. Reject DataFlow canvas category.
|
||||
// 3. Parse DSL map.
|
||||
// 4. Find Begin with mode=="Webhook"; capture webhook_cfg.
|
||||
// 5. Validate request method against webhook_cfg.methods.
|
||||
// 6. validateWebhookSecurity (max body size / IP whitelist / rate limit
|
||||
// / token / basic / jwt — strict fail-closed rate limit).
|
||||
// 7. Parse request body (parseWebhookRequest).
|
||||
// 8. Schema extract query/headers/body via extractBySchema.
|
||||
// 9. Dispatch:
|
||||
// - execution_mode == "Immediately" → return configured status +
|
||||
// body_template synchronously; canvas runs detached in the
|
||||
// background. Trace events appended when isTest.
|
||||
// - else (streaming/aggregate) → block on canvas.run, aggregate
|
||||
// message content, return JSON.
|
||||
//
|
||||
// Notes on Python parity divergences:
|
||||
// - `cvs.dsl = json.loads(str(canvas)); UserCanvasService.update_by_id(...)`
|
||||
// post-run DSL writeback is NOT ported. The Go runner mutates an
|
||||
// in-memory copy, not the persisted row. UpdateAgent already persists
|
||||
// the editable DSL on the user-driven path.
|
||||
// - multipart/form-data uploads are rejected with HTTP 501
|
||||
// (ErrWebhookMultipartNotSupported) at parse time. The Python
|
||||
// upload path through canvas.get_files_async depends on
|
||||
// FileService.upload_info which is itself not in the Go port yet.
|
||||
// When FileService.upload_info lands, the 501 branch in Webhook
|
||||
// (line 162-185) becomes the dispatch entry point.
|
||||
// - content_types whitelist ENFORCES a hard reject (102 envelope)
|
||||
// when the request Content-Type does not match the configured
|
||||
// value. Mirrors python agent_api.py:1839-1842 (`raise
|
||||
// ValueError("Invalid Content-Type...")`).
|
||||
// - Webhook trace shape (webhook-trace-<id>-logs in Redis) matches
|
||||
// the python append_webhook_trace at agent_api.py:2073-2091 so the
|
||||
// eventual /webhook/logs poll PR can consume this writer
|
||||
// unchanged.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/agent/canvas"
|
||||
"ragflow/internal/common"
|
||||
rediscli "ragflow/internal/engine/redis"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// canvasLoader is the subset of service.AgentService the webhook handler
|
||||
// needs. Defined as an interface so handler tests can inject a fake
|
||||
// without standing up the full AgentService (DB DAOs, eino runner, etc).
|
||||
//
|
||||
// LoadCanvasByID returns the raw DAO/service error so the WEBHOOK
|
||||
// handler can fold it into its own 102 envelope. We deliberately do NOT
|
||||
// route through handler.mapAgentError here because that helper maps
|
||||
// ErrUserCanvasNotFound to 103 "Make sure you have permission..." which
|
||||
// is the chat/run contract — wrong for the webhook path.
|
||||
type canvasLoader interface {
|
||||
LoadCanvasByID(ctx context.Context, userID, canvasID string) (*entity.UserCanvas, error)
|
||||
RunAgentWithWebhook(ctx context.Context, userID, canvasID string, payload map[string]any) (<-chan canvas.RunEvent, error)
|
||||
}
|
||||
|
||||
// Webhook is the handler method mounted at:
|
||||
//
|
||||
// /api/v1/agents/:canvas_id/webhook (production trigger)
|
||||
// /api/v1/agents/:canvas_id/webhook/test (test trigger, with trace)
|
||||
//
|
||||
// The Python decorator stack at agent_api.py:1563 binds six methods to
|
||||
// a single path. Gin has no Match() — the router registers each verb
|
||||
// individually via registerAnyMethod (see internal/router/agent_routes.go).
|
||||
func (h *AgentHandler) Webhook(c *gin.Context) {
|
||||
user, code, msg := GetUser(c)
|
||||
if code != common.CodeSuccess {
|
||||
jsonError(c, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
canvasID := c.Param("canvas_id")
|
||||
if h.loader == nil {
|
||||
jsonInternalError(c, errors.New("agent webhook: loader not configured"))
|
||||
return
|
||||
}
|
||||
isTest := strings.HasSuffix(c.Request.URL.Path, "/webhook/test")
|
||||
startTs := time.Now()
|
||||
|
||||
// 1. Load canvas. Webhook collapses missing-vs-foreign into a
|
||||
// single 102 "Canvas not found." (matches Python
|
||||
// api/apps/restful_apis/agent_api.py:1572 and the existing
|
||||
// GetAgentWebhookLogs envelope), so we DO NOT route through
|
||||
// handler.mapAgentError here.
|
||||
cv, err := h.loader.LoadCanvasByID(c.Request.Context(), user.ID, canvasID)
|
||||
if err != nil {
|
||||
if errors.Is(err, dao.ErrUserCanvasNotFound) || errors.Is(err, dao.ErrUserCanvasVersionNotFound) {
|
||||
jsonError(c, common.CodeDataError, "Canvas not found.")
|
||||
return
|
||||
}
|
||||
jsonInternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Reject DataFlow.
|
||||
if cv.CanvasCategory == "DataFlow" {
|
||||
jsonError(c, common.CodeDataError, "Dataflow can not be triggered by webhook.")
|
||||
return
|
||||
}
|
||||
|
||||
// 3. DSL map. cv.DSL is a typed entity.JSONMap (map[string]any
|
||||
// under the hood); we copy into a plain map[string]any so the
|
||||
// downstream helpers can mutate freely without aliasing surprises.
|
||||
dsl := map[string]any{}
|
||||
for k, v := range cv.DSL {
|
||||
dsl[k] = v
|
||||
}
|
||||
|
||||
// 4. Find Begin component with mode=="Webhook".
|
||||
webhookCfg := findWebhookBegin(dsl)
|
||||
if webhookCfg == nil {
|
||||
jsonError(c, common.CodeDataError, "Webhook not configured for this agent.")
|
||||
return
|
||||
}
|
||||
|
||||
// 5. Method gate.
|
||||
if !methodAllowed(webhookCfg["methods"], c.Request.Method) {
|
||||
jsonError(c, common.CodeDataError,
|
||||
fmt.Sprintf("HTTP method '%s' not allowed for this webhook.", c.Request.Method))
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Security gate (strict; surfaces all errors as 102).
|
||||
securityCfg := stringMap(webhookCfg["security"])
|
||||
if err := validateWebhookSecurity(securityCfg, c, canvasID); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 6a. Body size DoS hardening (security review MEDIUM-1). The
|
||||
// security validator above checks the Content-Length header, but
|
||||
// that header is attacker-controllable. Wrap the actual stream
|
||||
// reader with http.MaxBytesReader so io.ReadAll inside
|
||||
// parseWebhookRequest is bounded by the same parsed limit.
|
||||
// Errors here surface as ErrWebhookContentTypeMismatch-shaped 102
|
||||
// so the operator sees the failure mode consistently with the
|
||||
// header check.
|
||||
if limit, perr := parseMaxBodySize(securityCfg); perr == nil && limit > 0 && c.Request.Body != nil {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, limit)
|
||||
}
|
||||
|
||||
// 7. Parse request body. Two error classes flow through here:
|
||||
// - ErrWebhookMultipartNotSupported → HTTP 501 (file uploads
|
||||
// depend on FileService.upload_info which is not yet ported).
|
||||
// - ErrWebhookContentTypeMismatch → HTTP 102 (content-type
|
||||
// whitelist, mirrors python agent_api.py:1839).
|
||||
// Both are surfaced as typed errors so the dispatch envelope is
|
||||
// unambiguous — neither path falls through into schema validation.
|
||||
contentType, _ := webhookCfg["content_types"].(string)
|
||||
parsed, parseErr := parseWebhookRequest(contentType, c)
|
||||
if parseErr != nil {
|
||||
switch {
|
||||
case errors.Is(parseErr, ErrWebhookMultipartNotSupported):
|
||||
// 501 — multipart/form-data uploads are not yet
|
||||
// supported. Body is short so operators see exactly what
|
||||
// is missing.
|
||||
c.JSON(http.StatusNotImplemented, gin.H{
|
||||
"code": http.StatusNotImplemented,
|
||||
"data": false,
|
||||
"message": parseErr.Error(),
|
||||
})
|
||||
return
|
||||
case errors.Is(parseErr, ErrWebhookContentTypeMismatch):
|
||||
jsonError(c, common.CodeDataError, parseErr.Error())
|
||||
return
|
||||
default:
|
||||
jsonError(c, common.CodeDataError, parseErr.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 8. Schema extract query/headers/body.
|
||||
schema, _ := webhookCfg["schema"].(map[string]any)
|
||||
clean, schemaErr := applyWebhookSchema(parsed, schema)
|
||||
if schemaErr != nil {
|
||||
jsonError(c, common.CodeDataError, schemaErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 9. Dispatch.
|
||||
mode, _ := webhookCfg["execution_mode"].(string)
|
||||
if mode == "" {
|
||||
mode = "Immediately"
|
||||
}
|
||||
if mode == "Immediately" {
|
||||
status, contentType, payload, perr := renderImmediatelyResponse(stringMap(webhookCfg["response"]))
|
||||
if perr != nil {
|
||||
jsonError(c, common.CodeDataError, perr.Error())
|
||||
return
|
||||
}
|
||||
// Detached background run — does NOT inherit c.Request.Context()
|
||||
// so a client disconnect does not cancel the canvas run.
|
||||
go h.runWebhookDetached(cv, clean, isTest, startTs)
|
||||
c.Data(status, contentType, payload)
|
||||
return
|
||||
}
|
||||
// Streaming / blocking mode.
|
||||
out := h.runWebhookSync(c.Request.Context(), cv, clean, isTest, startTs)
|
||||
c.JSON(out.status, out.body)
|
||||
}
|
||||
|
||||
// findWebhookBegin looks through dsl["components"] for a component whose
// obj.component_name is "begin" (compared case-insensitively) and whose
// obj.params.mode is exactly "Webhook". It returns that params map (the
// webhook_cfg), or nil when no such component exists or the DSL does not
// have the expected shape. Mirrors agent_api.py:1584-1592.
//
// Entries of the wrong type are skipped. The lookups lean on the fact
// that indexing a nil map yields the zero value, so a failed type
// assertion simply flows through to the next check.
func findWebhookBegin(dsl map[string]any) map[string]any {
	components, ok := dsl["components"].(map[string]any)
	if !ok {
		return nil
	}
	for _, component := range components {
		entry, _ := component.(map[string]any)
		obj, _ := entry["obj"].(map[string]any)
		if name, _ := obj["component_name"].(string); !strings.EqualFold(name, "begin") {
			continue
		}
		params, _ := obj["params"].(map[string]any)
		if mode, _ := params["mode"].(string); mode == "Webhook" {
			return params
		}
	}
	return nil
}
|
||||
|
||||
// methodAllowed reports whether requestMethod appears in the configured
// method list raw, comparing case-insensitively.
//
// raw is normally the []any produced by JSON decoding; a []string is
// accepted as well so that a list built in Go code is enforced rather
// than silently ignored. A missing, empty, or otherwise-typed value
// imposes no restriction and allows every method (matches python:
// agent_api.py:1596 short-circuit). Non-string elements of a []any list
// are treated as the empty string.
func methodAllowed(raw any, requestMethod string) bool {
	var methods []string
	switch list := raw.(type) {
	case []string:
		methods = list
	case []any:
		methods = make([]string, 0, len(list))
		for _, item := range list {
			name, _ := item.(string)
			methods = append(methods, name)
		}
	}
	if len(methods) == 0 {
		return true
	}
	for _, name := range methods {
		if strings.EqualFold(name, requestMethod) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// stringMap coerces v to a map[string]any. When v already holds a
// map[string]any it is returned as-is (the same map, not a copy); for
// nil or any other type a fresh empty map is returned so callers can
// index the result without checking the type first.
func stringMap(v any) map[string]any {
	typed, isMap := v.(map[string]any)
	if !isMap {
		return map[string]any{}
	}
	return typed
}
|
||||
|
||||
// parseWebhookRequest mirrors agent_api.py:1828-1894.
|
||||
//
|
||||
// Returns (parsed, error). The parsed map carries query / headers /
|
||||
// body / content_type. The body is parsed according to the request
|
||||
// Content-Type:
|
||||
//
|
||||
// - application/json → unmarshal JSON
|
||||
// - application/x-www-form-urlencoded → form fields
|
||||
// - text/plain / octet-stream / unknown / empty → raw bytes → JSON
|
||||
//
|
||||
// Two errors are surfaced directly to the handler so the dispatch path
|
||||
// can return the right envelope without falling through into a
|
||||
// misleading schema-validation error:
|
||||
//
|
||||
// - ErrWebhookMultipartNotSupported (501): the request used
|
||||
// multipart/form-data. The Python path uploads files through
|
||||
// FileService.upload_info → canvas.get_files_async, both of which
|
||||
// are NOT yet ported. We refuse the request up front so callers
|
||||
// see a clear, typed 501 instead of an unrelated schema error.
|
||||
//
|
||||
// - ErrWebhookContentTypeMismatch (102): when the webhook config
|
||||
// sets content_types, the request Content-Type must match. The
|
||||
// Python reference raises here
|
||||
// (`raise ValueError("Invalid Content-Type...")` at
|
||||
// agent_api.py:1839-1842); we mirror that as a typed error and
|
||||
// surface it through the same 102 envelope as the rest of the
|
||||
// validation errors so operators see the same response shape.
|
||||
func parseWebhookRequest(configuredContentType string, c *gin.Context) (map[string]any, error) {
|
||||
// 1. Query
|
||||
q := map[string]any{}
|
||||
for k, vals := range c.Request.URL.Query() {
|
||||
if len(vals) == 1 {
|
||||
q[k] = vals[0]
|
||||
} else {
|
||||
q[k] = vals
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Headers
|
||||
hd := map[string]any{}
|
||||
for k, vals := range c.Request.Header {
|
||||
if len(vals) == 1 {
|
||||
hd[k] = vals[0]
|
||||
} else {
|
||||
hd[k] = vals
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Body
|
||||
ctype := strings.SplitN(c.GetHeader("Content-Type"), ";", 2)[0]
|
||||
ctype = strings.TrimSpace(strings.ToLower(ctype))
|
||||
|
||||
// multipart/form-data → 501. Checked BEFORE the content-type match
|
||||
// below because multipart uploads don't carry a useful
|
||||
// `content_types` value to validate against — the python handler
|
||||
// routes them through canvas.get_files_async which we have not
|
||||
// ported yet.
|
||||
if ctype == "multipart/form-data" {
|
||||
return nil, ErrWebhookMultipartNotSupported
|
||||
}
|
||||
|
||||
body := map[string]any{}
|
||||
|
||||
switch ctype {
|
||||
case "application/json":
|
||||
raw, _ := io.ReadAll(c.Request.Body)
|
||||
if len(raw) > 0 {
|
||||
_ = json.Unmarshal(raw, &body)
|
||||
}
|
||||
case "application/x-www-form-urlencoded":
|
||||
if err := c.Request.ParseForm(); err == nil {
|
||||
for k, vals := range c.Request.PostForm {
|
||||
if len(vals) == 1 {
|
||||
body[k] = vals[0]
|
||||
} else {
|
||||
body[k] = vals
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
raw, _ := io.ReadAll(c.Request.Body)
|
||||
if len(raw) > 0 {
|
||||
_ = json.Unmarshal(raw, &body)
|
||||
}
|
||||
}
|
||||
|
||||
// content_type whitelist. Empty configuredContentType → no check
|
||||
// (matches python: agent_api.py:1839 only raises when the
|
||||
// webhook_cfg actually sets content_types).
|
||||
//
|
||||
// The python reference has a `if ctype and ...` short-circuit that
|
||||
// lets a request with NO Content-Type header through even when the
|
||||
// webhook config requires one. That is a whitelist bypass — the
|
||||
// configured type is irrelevant if a caller can simply omit the
|
||||
// header. The Go port tightens the contract: when the operator
|
||||
// configured content_types, the request MUST carry a matching
|
||||
// Content-Type. Missing header → ErrWebhookContentTypeMismatch.
|
||||
if configuredContentType != "" && configuredContentType != ctype {
|
||||
return nil, fmt.Errorf("%w: expect %q, got %q",
|
||||
ErrWebhookContentTypeMismatch, configuredContentType, ctype)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"query": q,
|
||||
"headers": hd,
|
||||
"body": body,
|
||||
"content_type": ctype,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Sentinel errors produced by parseWebhookRequest. Match them with
// errors.Is: the content-type mismatch is returned wrapped with the
// expected and actual values.
var (
	// ErrWebhookMultipartNotSupported is returned when the inbound
	// Content-Type is multipart/form-data, which this port does not
	// handle.
	ErrWebhookMultipartNotSupported = errors.New("multipart/form-data uploads are not supported in this port")

	// ErrWebhookContentTypeMismatch is returned when a content type is
	// configured for the webhook and the request's Content-Type does not
	// equal it.
	ErrWebhookContentTypeMismatch = errors.New("invalid content-type")
)
|
||||
|
||||
// applyWebhookSchema runs extractBySchema on the parsed request's three
|
||||
// sections and assembles the clean_request map that the Python
|
||||
// webhook handler passes to canvas.run(webhook_payload=...). Mirrors
|
||||
// agent_api.py:2052-2068.
|
||||
func applyWebhookSchema(parsed map[string]any, schema map[string]any) (map[string]any, error) {
|
||||
q := stringMap(parsed["query"])
|
||||
hd := stringMap(parsed["headers"])
|
||||
bd := stringMap(parsed["body"])
|
||||
|
||||
qClean, err := extractBySchema(q, stringMap(schema["query"]), "query")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hdClean, err := extractBySchema(hd, stringMap(schema["headers"]), "headers")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bdClean, err := extractBySchema(bd, stringMap(schema["body"]), "body")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"query": qClean,
|
||||
"headers": hdClean,
|
||||
"body": bdClean,
|
||||
"input": parsed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// renderImmediatelyResponse builds the response sent right away in the
// "Immediately" execution mode. It returns the HTTP status, the content
// type and the body bytes.
//
//   - cfg["status"] may be an int, int64, float64 (truncated) or a decimal
//     string. It defaults to 200 when absent. Any other type, an
//     unparsable string, or a value outside [200, 399] is an error.
//   - cfg["body_template"] is used only when it is a non-empty string.
//     If it parses as JSON it is re-encoded and served as
//     application/json; otherwise it is served verbatim as text/plain.
//     With no template the body is nil and the content type is
//     application/json.
func renderImmediatelyResponse(cfg map[string]any) (int, string, []byte, error) {
	code := 200
	if rawStatus, present := cfg["status"]; present {
		switch typed := rawStatus.(type) {
		case int:
			code = typed
		case int64:
			code = int(typed)
		case float64:
			code = int(typed)
		case string:
			parsed, convErr := strconv.Atoi(typed)
			if convErr != nil {
				return 0, "", nil, fmt.Errorf("invalid response status code: %v", typed)
			}
			code = parsed
		default:
			return 0, "", nil, fmt.Errorf("invalid response status code: %v", typed)
		}
	}
	if !(code >= 200 && code <= 399) {
		return 0, "", nil, fmt.Errorf("invalid response status code: %d, must be between 200 and 399", code)
	}

	tpl, _ := cfg["body_template"].(string)
	if tpl == "" {
		return code, "application/json", nil, nil
	}

	// JSON first; anything that does not parse is served as plain text.
	var decoded any
	if json.Unmarshal([]byte(tpl), &decoded) != nil {
		return code, "text/plain", []byte(tpl), nil
	}
	// The value was just decoded from JSON, so the marshal error is ignored.
	normalized, _ := json.Marshal(decoded)
	return code, "application/json", normalized, nil
}
|
||||
|
||||
// runWebhookDetached runs the canvas in the background. It uses
|
||||
// context.Background() with a 5-minute timeout (NOT
|
||||
// c.Request.Context()) so a client disconnect does NOT cancel the run.
|
||||
// Trace events are appended to the redis key when isTest is true.
|
||||
//
|
||||
// Mirrors python: agent_api.py:2123-2175 (the asyncio.create_task body
|
||||
// inside the Immediately branch).
|
||||
func (h *AgentHandler) runWebhookDetached(
|
||||
cv *entity.UserCanvas, payload map[string]any, isTest bool, startTs time.Time,
|
||||
) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
events, err := h.loader.RunAgentWithWebhook(ctx, cv.UserID, cv.ID, payload)
|
||||
if err != nil {
|
||||
common.Warn("webhook detached run start failed",
|
||||
zap.String("canvas", cv.ID),
|
||||
zap.Error(err))
|
||||
if isTest {
|
||||
appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "error", Data: mustJSON(map[string]any{"message": err.Error()})})
|
||||
}
|
||||
return
|
||||
}
|
||||
for ev := range events {
|
||||
if isTest {
|
||||
appendWebhookTrace(cv.ID, startTs, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runWebhookSync drives the canvas in the non-Immediately (streaming)
|
||||
// mode. Returns the HTTP status + body to send. Mirrors
|
||||
// agent_api.py:2178-2247 with the python sse() coroutine flattened into
|
||||
// a synchronous loop.
|
||||
type webhookSyncResult struct {
|
||||
status int
|
||||
body any
|
||||
}
|
||||
|
||||
func (h *AgentHandler) runWebhookSync(
|
||||
ctx context.Context, cv *entity.UserCanvas, payload map[string]any,
|
||||
isTest bool, startTs time.Time,
|
||||
) webhookSyncResult {
|
||||
status := 200
|
||||
events, err := h.loader.RunAgentWithWebhook(ctx, cv.UserID, cv.ID, payload)
|
||||
if err != nil {
|
||||
if isTest {
|
||||
appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "error", Data: mustJSON(map[string]any{"message": err.Error()})})
|
||||
appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "finished", Data: mustJSON(map[string]any{"success": false})})
|
||||
}
|
||||
return webhookSyncResult{status: http.StatusBadRequest, body: gin.H{
|
||||
"code": 400,
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
}}
|
||||
}
|
||||
|
||||
contents := []string{}
|
||||
for ev := range events {
|
||||
if isTest {
|
||||
appendWebhookTrace(cv.ID, startTs, ev)
|
||||
}
|
||||
switch ev.Type {
|
||||
case "message":
|
||||
var msg struct {
|
||||
Content string `json:"content"`
|
||||
StartToThink bool `json:"start_to_think"`
|
||||
EndToThink bool `json:"end_to_think"`
|
||||
}
|
||||
if json.Unmarshal([]byte(ev.Data), &msg) == nil {
|
||||
content := msg.Content
|
||||
if msg.StartToThink {
|
||||
content = "think"
|
||||
} else if msg.EndToThink {
|
||||
content = "/think"
|
||||
}
|
||||
if content != "" {
|
||||
contents = append(contents, content)
|
||||
}
|
||||
}
|
||||
case "message_end":
|
||||
var end struct {
|
||||
Status *int `json:"status"`
|
||||
}
|
||||
if json.Unmarshal([]byte(ev.Data), &end) == nil && end.Status != nil {
|
||||
status = *end.Status
|
||||
}
|
||||
}
|
||||
}
|
||||
final := strings.Join(contents, "")
|
||||
if isTest {
|
||||
appendWebhookTrace(cv.ID, startTs, canvas.RunEvent{Type: "finished", Data: mustJSON(map[string]any{"success": true})})
|
||||
}
|
||||
return webhookSyncResult{status: status, body: gin.H{
|
||||
"message": final,
|
||||
"success": true,
|
||||
"code": status,
|
||||
}}
|
||||
}
|
||||
|
||||
// mustJSON encodes v as JSON and returns it as a string. The output ends
// with the newline that json.Encoder appends.
//
// Despite the name it does not panic: an encoding error is discarded and
// the result is then an empty string.
func mustJSON(v any) string {
	out := &bytes.Buffer{}
	// Error intentionally ignored; a failed encode leaves out empty.
	_ = json.NewEncoder(out).Encode(v)
	return out.String()
}
|
||||
|
||||
// appendWebhookTrace appends a single RunEvent to the per-canvas trace
|
||||
// key in Redis. Mirrors python's append_webhook_trace at
|
||||
// agent_api.py:2073-2091.
|
||||
//
|
||||
// The trace key is `webhook-trace-<agent_id>-logs` with a 600 s TTL.
|
||||
// Each event is recorded as {"ts": <float>, "event": <type>, ...}.
|
||||
// Tests use miniredis to verify the key shape.
|
||||
func appendWebhookTrace(agentID string, startTs time.Time, ev canvas.RunEvent) {
|
||||
rdb := rediscli.Get()
|
||||
if rdb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("webhook-trace-%s-logs", agentID)
|
||||
raw, _ := rdb.Get(key)
|
||||
obj := map[string]any{}
|
||||
if raw != "" {
|
||||
_ = json.Unmarshal([]byte(raw), &obj)
|
||||
}
|
||||
whs, _ := obj["webhooks"].(map[string]any)
|
||||
if whs == nil {
|
||||
whs = map[string]any{}
|
||||
obj["webhooks"] = whs
|
||||
}
|
||||
entryKey := strconv.FormatFloat(float64(startTs.UnixNano())/1e9, 'f', -1, 64)
|
||||
entry, _ := whs[entryKey].(map[string]any)
|
||||
if entry == nil {
|
||||
entry = map[string]any{
|
||||
"start_ts": float64(startTs.UnixNano()) / 1e9,
|
||||
"events": []any{},
|
||||
}
|
||||
whs[entryKey] = entry
|
||||
}
|
||||
events, _ := entry["events"].([]any)
|
||||
eventRecord := map[string]any{
|
||||
"ts": float64(time.Now().UnixNano()) / 1e9,
|
||||
"event": ev.Type,
|
||||
}
|
||||
if ev.Data != "" {
|
||||
eventRecord["data"] = json.RawMessage(ev.Data)
|
||||
}
|
||||
if ev.MessageID != "" {
|
||||
eventRecord["message_id"] = ev.MessageID
|
||||
}
|
||||
if ev.TaskID != "" {
|
||||
eventRecord["task_id"] = ev.TaskID
|
||||
}
|
||||
if ev.SessionID != "" {
|
||||
eventRecord["session_id"] = ev.SessionID
|
||||
}
|
||||
entry["events"] = append(events, eventRecord)
|
||||
|
||||
encoded, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
common.Warn("webhook trace marshal failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
rdb.SetObj(key, string(encoded), 600*time.Second)
|
||||
}
|
||||
273
internal/handler/agent_webhook_schema.go
Normal file
273
internal/handler/agent_webhook_schema.go
Normal file
@@ -0,0 +1,273 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
// Webhook schema helpers.
|
||||
//
|
||||
// These mirror api/apps/restful_apis/agent_api.py:1896-2051 (extract_by_schema,
|
||||
// default_for_type, auto_cast_value, validate_type) from the Python
|
||||
// webhook handler. The Go port preserves the Python semantics exactly:
|
||||
// required fields raise; optional fields get a type-based default;
|
||||
// type coercion runs before validation.
|
||||
//
|
||||
// Edge cases preserved:
|
||||
// - "file" type maps to []any (a list of parsed file descriptors).
|
||||
// - "array" accepts a JSON-decoded list; "array<inner>" validates each
|
||||
// element against `inner`.
|
||||
// - "object" must decode to map[string]any.
|
||||
// - Empty data sections yield an empty map (NOT nil) so downstream
|
||||
// components can index without a nil check.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// defaultForType returns the value an optional schema field receives
// when the request does not supply it:
//
//   - "file" and any type starting with "array" → empty []any
//   - "object"  → empty map[string]any
//   - "boolean" → false
//   - "number"  → float64(0)
//   - "string"  → ""
//   - "null" and every unrecognised type → nil
func defaultForType(t string) any {
	switch {
	case t == "file", strings.HasPrefix(t, "array"):
		return []any{}
	case t == "object":
		return map[string]any{}
	case t == "boolean":
		return false
	case t == "number":
		return float64(0)
	case t == "string":
		return ""
	}
	// "null" and unknown type tokens have no meaningful default.
	return nil
}
|
||||
|
||||
// autoCastValue converts a string value to the type a schema field
// expects. Values that are not strings are returned unchanged.
//
// The string is trimmed before conversion:
//
//   - "boolean": "true"/"1" and "false"/"0", case-insensitively
//   - "number":  an integer or floating-point literal, returned as float64
//   - "object":  a JSON object, returned as map[string]any
//   - "array" and "array<...>": a JSON array, returned as []any
//   - any other type ("string", "file", unknown): the original value,
//     untrimmed
//
// An error is returned when the string cannot be converted. The JSON
// literal null is rejected for "object" and "array" targets: it decodes
// without error, but to a nil map or slice rather than an actual object
// or array.
func autoCastValue(value any, expectedType string) (any, error) {
	raw, isString := value.(string)
	if !isString {
		// Already-decoded values are left to validateType.
		return value, nil
	}
	v := strings.TrimSpace(raw)

	switch {
	case expectedType == "boolean":
		switch strings.ToLower(v) {
		case "true", "1":
			return true, nil
		case "false", "0":
			return false, nil
		}
		return nil, fmt.Errorf("cannot convert %q to boolean", value)

	case expectedType == "number":
		// Plain integers take the integer path first; everything else
		// (decimals, exponents, explicit "+") goes through ParseFloat.
		looksNumeric := v != "" && (v[0] == '-' && len(v) > 1 || v[0] >= '0' && v[0] <= '9')
		if looksNumeric && !strings.ContainsAny(v, ".eE") {
			if i, err := strconv.ParseInt(v, 10, 64); err == nil {
				return float64(i), nil
			}
		}
		f, err := strconv.ParseFloat(v, 64)
		if err != nil {
			return nil, fmt.Errorf("cannot convert %q to number", value)
		}
		return f, nil

	case expectedType == "object":
		var parsed map[string]any
		// parsed stays nil only for the literal null; "{}" yields an
		// empty, non-nil map.
		if err := json.Unmarshal([]byte(v), &parsed); err != nil || parsed == nil {
			return nil, fmt.Errorf("cannot convert %q to object", value)
		}
		return parsed, nil

	case strings.HasPrefix(expectedType, "array"):
		var parsed []any
		// As above: "[]" yields an empty, non-nil slice; null leaves nil.
		if err := json.Unmarshal([]byte(v), &parsed); err != nil || parsed == nil {
			return nil, fmt.Errorf("cannot convert %q to array", value)
		}
		return parsed, nil
	}

	// "string" / "file" / unknown types pass through untouched.
	return value, nil
}
|
||||
|
||||
// validateType reports whether value is structurally compatible with the
// schema type token t.
//
//   - "" and unknown tokens accept anything
//   - "file"    → []any
//   - "string"  → string
//   - "number"  → int, int32, int64, float32 or float64
//   - "boolean" → bool
//   - "object"  → map[string]any
//   - "null"    → nil
//   - "array"   → []any; with an element type, as in "array<number>",
//     every element must satisfy that type as well
//
// The element type is the text between the first "<" and the last ">",
// so nested types such as "array<array<number>>" are validated all the
// way down.
func validateType(value any, t string) bool {
	switch t {
	case "":
		return true
	case "file":
		_, ok := value.([]any)
		return ok
	case "string":
		_, ok := value.(string)
		return ok
	case "number":
		switch value.(type) {
		case int, int32, int64, float32, float64:
			return true
		}
		return false
	case "boolean":
		_, ok := value.(bool)
		return ok
	case "object":
		_, ok := value.(map[string]any)
		return ok
	case "null":
		return value == nil
	}

	if !strings.HasPrefix(t, "array") {
		// Unknown type tokens are not validated.
		return true
	}
	list, ok := value.([]any)
	if !ok {
		return false
	}

	// Match the outermost brackets. Cutting at the first ">" would
	// truncate a nested element type and silently skip its checks.
	start := strings.Index(t, "<")
	end := strings.LastIndex(t, ">")
	if start < 0 || end <= start {
		return true
	}
	inner := t[start+1 : end]
	if inner == "" {
		return true
	}
	for _, item := range list {
		if !validateType(item, inner) {
			return false
		}
	}
	return true
}
|
||||
|
||||
// extractBySchema mirrors python's extract_by_schema at agent_api.py:1896-1936.
|
||||
//
|
||||
// Behaviour:
|
||||
// - `data` may be nil; treated as empty map.
|
||||
// - For each property in schema.properties:
|
||||
// - if required and missing → error
|
||||
// - if missing (optional) → use defaultForType(prop.type)
|
||||
// - if present → autoCastValue then validateType; mismatch raises
|
||||
// - returns the cleaned map; nil `data` returns an empty map (NOT nil).
|
||||
func extractBySchema(data map[string]any, schema map[string]any, name string) (map[string]any, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
if schema == nil {
|
||||
schema = map[string]any{}
|
||||
}
|
||||
|
||||
props, _ := schema["properties"].(map[string]any)
|
||||
required := stringSlice(schema["required"])
|
||||
|
||||
out := map[string]any{}
|
||||
|
||||
for field, rawSchema := range props {
|
||||
fieldSchema, _ := rawSchema.(map[string]any)
|
||||
fieldType, _ := fieldSchema["type"].(string)
|
||||
|
||||
// 1. Required field missing → error (python agent_api.py:1913).
|
||||
isRequired := contains(required, field)
|
||||
if isRequired {
|
||||
if _, ok := data[field]; !ok {
|
||||
return nil, fmt.Errorf("%s missing required field: %s", name, field)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Optional → default value.
|
||||
raw, present := data[field]
|
||||
if !present {
|
||||
out[field] = defaultForType(fieldType)
|
||||
continue
|
||||
}
|
||||
|
||||
// 3. Auto cast (only matters for string→typed; pass-through otherwise).
|
||||
casted, err := autoCastValue(raw, fieldType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s.%s auto-cast failed: %s", name, field, err.Error())
|
||||
}
|
||||
|
||||
// 4. Type validation.
|
||||
if !validateType(casted, fieldType) {
|
||||
return nil, fmt.Errorf(
|
||||
"%s.%s type mismatch: expected %s, got %T",
|
||||
name, field, fieldType, casted,
|
||||
)
|
||||
}
|
||||
out[field] = casted
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// stringSlice normalises a schema's "required" value to a []string.
// A []string is returned as-is. A []any yields its string elements in
// order, with non-string elements dropped. Any other value, including
// nil, yields nil.
func stringSlice(v any) []string {
	if typed, ok := v.([]string); ok {
		return typed
	}
	mixed, ok := v.([]any)
	if !ok {
		return nil
	}
	kept := make([]string, 0, len(mixed))
	for i := range mixed {
		if s, isString := mixed[i].(string); isString {
			kept = append(kept, s)
		}
	}
	return kept
}
|
||||
|
||||
// contains reports whether needle is an element of haystack.
func contains(haystack []string, needle string) bool {
	for i := 0; i < len(haystack); i++ {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
|
||||
242
internal/handler/agent_webhook_schema_test.go
Normal file
242
internal/handler/agent_webhook_schema_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestExtractBySchema_RequiredMissing pins the required-field-missing
|
||||
// branch (mirrors python agent_api.py:1913).
|
||||
func TestExtractBySchema_RequiredMissing(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"properties": map[string]any{
|
||||
"q": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"q"},
|
||||
}
|
||||
_, err := extractBySchema(map[string]any{}, schema, "query")
|
||||
if err == nil || !strings.Contains(err.Error(), "missing required field") {
|
||||
t.Errorf("err = %v, want missing required field", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractBySchema_OptionalDefault confirms that optional fields get
|
||||
// the type-default value when absent.
|
||||
func TestExtractBySchema_OptionalDefault(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"properties": map[string]any{
|
||||
"s": map[string]any{"type": "string"},
|
||||
"b": map[string]any{"type": "boolean"},
|
||||
"o": map[string]any{"type": "object"},
|
||||
"a": map[string]any{"type": "array"},
|
||||
"f": map[string]any{"type": "file"},
|
||||
"n": map[string]any{"type": "number"},
|
||||
},
|
||||
}
|
||||
got, err := extractBySchema(map[string]any{}, schema, "query")
|
||||
if err != nil {
|
||||
t.Fatalf("extractBySchema: %v", err)
|
||||
}
|
||||
if got["s"] != "" {
|
||||
t.Errorf("default for string: got %v", got["s"])
|
||||
}
|
||||
if got["b"] != false {
|
||||
t.Errorf("default for boolean: got %v", got["b"])
|
||||
}
|
||||
if _, ok := got["o"].(map[string]any); !ok {
|
||||
t.Errorf("default for object: got %T", got["o"])
|
||||
}
|
||||
if _, ok := got["a"].([]any); !ok {
|
||||
t.Errorf("default for array: got %T", got["a"])
|
||||
}
|
||||
if _, ok := got["f"].([]any); !ok {
|
||||
t.Errorf("default for file: got %T", got["f"])
|
||||
}
|
||||
if got["n"] != float64(0) {
|
||||
t.Errorf("default for number: got %v", got["n"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractBySchema_NumberCoercion pins string→number auto-cast.
|
||||
func TestExtractBySchema_NumberCoercion(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"properties": map[string]any{
|
||||
"n": map[string]any{"type": "number"},
|
||||
},
|
||||
}
|
||||
got, err := extractBySchema(map[string]any{"n": "42"}, schema, "body")
|
||||
if err != nil {
|
||||
t.Fatalf("extractBySchema: %v", err)
|
||||
}
|
||||
if got["n"] != float64(42) {
|
||||
t.Errorf("number coercion: got %v (%T)", got["n"], got["n"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractBySchema_NumberCoercionFloat covers decimal strings.
|
||||
func TestExtractBySchema_NumberCoercionFloat(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"properties": map[string]any{
|
||||
"n": map[string]any{"type": "number"},
|
||||
},
|
||||
}
|
||||
got, err := extractBySchema(map[string]any{"n": "3.14"}, schema, "body")
|
||||
if err != nil {
|
||||
t.Fatalf("extractBySchema: %v", err)
|
||||
}
|
||||
if got["n"] != 3.14 {
|
||||
t.Errorf("number float coercion: got %v", got["n"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractBySchema_BooleanCoercion pins true/false string parsing.
|
||||
func TestExtractBySchema_BooleanCoercion(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"properties": map[string]any{
|
||||
"b": map[string]any{"type": "boolean"},
|
||||
},
|
||||
}
|
||||
cases := []struct {
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"true", true},
|
||||
{"TRUE", true},
|
||||
{"1", true},
|
||||
{"false", false},
|
||||
{"0", false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got, err := extractBySchema(map[string]any{"b": tc.in}, schema, "body")
|
||||
if err != nil {
|
||||
t.Errorf("coerce %q: err=%v", tc.in, err)
|
||||
continue
|
||||
}
|
||||
if got["b"] != tc.want {
|
||||
t.Errorf("coerce %q: got %v, want %v", tc.in, got["b"], tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractBySchema_BooleanCoercionBad confirms non-boolean strings fail.
|
||||
func TestExtractBySchema_BooleanCoercionBad(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"properties": map[string]any{
|
||||
"b": map[string]any{"type": "boolean"},
|
||||
},
|
||||
}
|
||||
_, err := extractBySchema(map[string]any{"b": "maybe"}, schema, "body")
|
||||
if err == nil || !strings.Contains(err.Error(), "auto-cast failed") {
|
||||
t.Errorf("err = %v, want auto-cast failed", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractBySchema_TypeMismatch covers wrong-type validation.
|
||||
func TestExtractBySchema_TypeMismatch(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"properties": map[string]any{
|
||||
"n": map[string]any{"type": "number"},
|
||||
},
|
||||
}
|
||||
_, err := extractBySchema(map[string]any{"n": []any{1}}, schema, "body")
|
||||
if err == nil || !strings.Contains(err.Error(), "type mismatch") {
|
||||
t.Errorf("err = %v, want type mismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractBySchema_ArrayInnerType confirms array<inner> validation.
|
||||
func TestExtractBySchema_ArrayInnerType(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"properties": map[string]any{
|
||||
"xs": map[string]any{"type": "array<number>"},
|
||||
},
|
||||
}
|
||||
// valid
|
||||
if _, err := extractBySchema(map[string]any{"xs": []any{1, 2, 3}}, schema, "body"); err != nil {
|
||||
t.Errorf("valid array: err=%v", err)
|
||||
}
|
||||
// invalid inner type
|
||||
_, err := extractBySchema(map[string]any{"xs": []any{1, "two", 3}}, schema, "body")
|
||||
if err == nil || !strings.Contains(err.Error(), "type mismatch") {
|
||||
t.Errorf("invalid array inner: err=%v, want type mismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractBySchema_ObjectCoercion pins string→object auto-cast.
|
||||
func TestExtractBySchema_ObjectCoercion(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"properties": map[string]any{
|
||||
"o": map[string]any{"type": "object"},
|
||||
},
|
||||
}
|
||||
got, err := extractBySchema(map[string]any{"o": `{"k":"v"}`}, schema, "body")
|
||||
if err != nil {
|
||||
t.Fatalf("extractBySchema: %v", err)
|
||||
}
|
||||
m, ok := got["o"].(map[string]any)
|
||||
if !ok || m["k"] != "v" {
|
||||
t.Errorf("object coercion: got %v", got["o"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractBySchema_NilData confirms nil data yields empty map.
|
||||
func TestExtractBySchema_NilData(t *testing.T) {
|
||||
got, err := extractBySchema(nil, nil, "query")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Errorf("expected empty map, got nil")
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Errorf("expected empty map, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateType_FileAndArray confirms validateType handles the
|
||||
// edge cases at the python agent_api.py:2021-2051 boundary.
|
||||
func TestValidateType_FileAndArray(t *testing.T) {
|
||||
if !validateType([]any{map[string]any{}}, "file") {
|
||||
t.Error("file should accept []any")
|
||||
}
|
||||
if !validateType([]any{1, 2, 3}, "array<number>") {
|
||||
t.Error("array<number> should accept []int-equivalent")
|
||||
}
|
||||
if validateType([]any{1, "x"}, "array<number>") {
|
||||
t.Error("array<number> should reject mixed types")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultForType covers the python default_for_type table.
|
||||
func TestDefaultForType(t *testing.T) {
|
||||
cases := []struct {
|
||||
t string
|
||||
// We compare via fmt-friendly assertions rather than reflect.DeepEqual
|
||||
// because the python reference returns [] / {} / false / 0 / "" / nil
|
||||
// directly, and Go equivalents may differ in concrete type.
|
||||
}{
|
||||
{"file"}, {"object"}, {"boolean"}, {"number"}, {"string"},
|
||||
{"array"}, {"array<object>"}, {"null"}, {"unknown"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := defaultForType(tc.t)
|
||||
// Just ensure no panic and we get a sensible zero value.
|
||||
_ = got
|
||||
}
|
||||
}
|
||||
491
internal/handler/agent_webhook_security.go
Normal file
491
internal/handler/agent_webhook_security.go
Normal file
@@ -0,0 +1,491 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
// Webhook security helpers.
|
||||
//
|
||||
// These mirror api/apps/restful_apis/agent_api.py:1602-1810
|
||||
// (validate_webhook_security + the six sub-validators) from the Python
|
||||
// webhook handler. The Go port preserves the Python semantics exactly:
|
||||
//
|
||||
// - max body size (with the 10 MB cap at agent_api.py:1652)
|
||||
// - IP whitelist (CIDR + exact match)
|
||||
// - rate limit (token bucket via redis.EvalTokenBucketStrict — strict
|
||||
// fail-closed; see redis.go)
|
||||
// - token auth (header check)
|
||||
// - basic auth (HTTP Basic)
|
||||
//   - JWT (HS*/RS*/ES* at 256/384/512, audience/issuer/required-claims)
|
||||
//
|
||||
// All helpers return a Go error so the handler can surface the
|
||||
// Python-shaped 102 envelope directly. Empty security config means
|
||||
// "no checks" (matches agent_api.py:1607).
|
||||
|
||||
import (
	"context"
	"crypto/subtle"
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"

	rediscli "ragflow/internal/engine/redis"
)
|
||||
|
||||
const (
	// bytesPerKB and bytesPerMB define the unit base used when parsing
	// max_body_size. They are SI decimal (1 KB = 1 000 B, 1 MB = 1 000 000 B).
	// Per user request this diverges from python agent_api.py:1643, which
	// uses the binary base (1 KB = 1 024 B); the parseMaxBodySize doc
	// comment repeats this so audits can find the divergence.
	bytesPerKB int64 = 1000
	bytesPerMB int64 = 1000 * bytesPerKB

	// webhookBodyMaxBytes is the hard ceiling for a configured
	// max_body_size; anything larger is rejected when the config is
	// validated. It is expressed in decimal MB to stay consistent with the
	// units above: 10 MB = 10 000 000 bytes. The python reference
	// (agent_api.py:1652) allows 10 * 1024 * 1024 = 10 485 760 bytes, so
	// this ceiling is roughly 486 KB stricter.
	webhookBodyMaxBytes int64 = 10 * bytesPerMB

	// webhookRateLimitTimeout bounds the redis call made for the strict
	// token-bucket lookup. It is intentionally short: a security gate must
	// not stall the request.
	webhookRateLimitTimeout = 500 * time.Millisecond

	// jwtReservedClaims is the comma-separated set of claim names that
	// required_claims must NOT contain (python parity, agent_api.py:1800).
	jwtReservedClaims = "exp,sub,aud,iss,nbf,iat"
)
|
||||
|
||||
// validateWebhookSecurity is the orchestrator. Empty/nil cfg → no-op
|
||||
// (matches agent_api.py:1607 "No security config → allowed by default").
|
||||
//
|
||||
// Sub-validators run in the python-defined order:
|
||||
// 1. validateMaxBodySize
|
||||
// 2. validateIPWhitelist
|
||||
// 3. validateRateLimit
|
||||
// 4. validateAuth (dispatches on auth_type)
|
||||
func validateWebhookSecurity(
|
||||
securityCfg map[string]any,
|
||||
c *gin.Context,
|
||||
canvasID string,
|
||||
) error {
|
||||
if len(securityCfg) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := validateMaxBodySize(c, securityCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateIPWhitelist(c, securityCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateRateLimit(canvasID, securityCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
return validateAuth(c, securityCfg)
|
||||
}
|
||||
|
||||
// validateMaxBodySize mirrors python agent_api.py:1636-1658.
|
||||
//
|
||||
// Format: "<n>kb" | "<n>mb" (case-insensitive). Anything else is a
|
||||
// config bug. The configured limit is capped at webhookBodyMaxBytes
|
||||
// (10 MB) — exceeding that is also a config error, not silently raised.
|
||||
// The actual request size is then compared against the parsed limit.
|
||||
func validateMaxBodySize(c *gin.Context, cfg map[string]any) error {
|
||||
limit, err := parseMaxBodySize(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if limit <= 0 {
|
||||
return nil
|
||||
}
|
||||
contentLength := c.Request.ContentLength
|
||||
if contentLength < 0 {
|
||||
contentLength = 0
|
||||
}
|
||||
if contentLength > limit {
|
||||
return fmt.Errorf("request body too large: %d > %d", contentLength, limit)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseMaxBodySize returns the byte limit configured by
|
||||
// `max_body_size` (with the 10 MB cap enforced) or 0 when no limit
|
||||
// is configured. The handler uses this to wrap c.Request.Body in
|
||||
// http.MaxBytesReader so the actual stream read is bounded — the
|
||||
// Content-Length header check alone is insufficient because a client
|
||||
// can advertise a small Content-Length and stream more.
|
||||
//
|
||||
// Unit base: 1 kb = 1 000 B, 1 mb = 1 000 000 B (SI decimal). Per
|
||||
// user request; note this DIVERGES from the python reference
|
||||
// (agent_api.py:1643 uses binary 1 KB = 1 024 B).
|
||||
//
|
||||
// Overflow guard (CodeRabbit PR review #3 on PR #16403): a very
|
||||
// large configured n (e.g. "10000000mb") would otherwise overflow
|
||||
// `n * bytesPerMB` and wrap to a small positive number, bypassing
|
||||
// the 10 MB cap. We check the parsed number against the cap before
|
||||
// multiplying.
|
||||
func parseMaxBodySize(cfg map[string]any) (int64, error) {
|
||||
raw, ok := cfg["max_body_size"].(string)
|
||||
if !ok || raw == "" {
|
||||
return 0, nil
|
||||
}
|
||||
sizeStr := strings.ToLower(strings.TrimSpace(raw))
|
||||
var limit int64
|
||||
switch {
|
||||
case strings.HasSuffix(sizeStr, "kb"):
|
||||
n, err := strconv.ParseInt(strings.TrimSuffix(sizeStr, "kb"), 10, 64)
|
||||
if err != nil || n <= 0 {
|
||||
return 0, fmt.Errorf("invalid max_body_size format")
|
||||
}
|
||||
if n > webhookBodyMaxBytes/bytesPerKB {
|
||||
return 0, fmt.Errorf("max_body_size exceeds maximum allowed size (10MB)")
|
||||
}
|
||||
limit = n * bytesPerKB
|
||||
case strings.HasSuffix(sizeStr, "mb"):
|
||||
n, err := strconv.ParseInt(strings.TrimSuffix(sizeStr, "mb"), 10, 64)
|
||||
if err != nil || n <= 0 {
|
||||
return 0, fmt.Errorf("invalid max_body_size format")
|
||||
}
|
||||
if n > webhookBodyMaxBytes/bytesPerMB {
|
||||
return 0, fmt.Errorf("max_body_size exceeds maximum allowed size (10MB)")
|
||||
}
|
||||
limit = n * bytesPerMB
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid max_body_size format")
|
||||
}
|
||||
if limit > webhookBodyMaxBytes {
|
||||
return 0, fmt.Errorf("max_body_size exceeds maximum allowed size (10MB)")
|
||||
}
|
||||
return limit, nil
|
||||
}
|
||||
|
||||
// validateIPWhitelist mirrors python agent_api.py:1660-1679. Empty
|
||||
// list → allow. Supports CIDR ("10.0.0.0/8") and exact ("1.2.3.4").
|
||||
// The client IP comes from gin's c.ClientIP() which honours
|
||||
// X-Forwarded-For when trusted proxies are configured.
|
||||
func validateIPWhitelist(c *gin.Context, cfg map[string]any) error {
|
||||
whitelist, _ := cfg["ip_whitelist"].([]any)
|
||||
if len(whitelist) == 0 {
|
||||
return nil
|
||||
}
|
||||
clientIP := c.ClientIP()
|
||||
for _, raw := range whitelist {
|
||||
rule, _ := raw.(string)
|
||||
if rule == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(rule, "/") {
|
||||
// CIDR
|
||||
_, ipNet, err := net.ParseCIDR(rule)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
addr := net.ParseIP(clientIP)
|
||||
if addr != nil && ipNet.Contains(addr) {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Exact match
|
||||
if clientIP == rule {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("IP %s is not allowed by whitelist", clientIP)
|
||||
}
|
||||
|
||||
// validateRateLimit mirrors python agent_api.py:1681-1723.
|
||||
//
|
||||
// Window mapping (matches agent_api.py:1692-1697):
|
||||
//
|
||||
// second → 1s, minute → 60s, hour → 3600s, day → 86400s.
|
||||
//
|
||||
// Unknown per → error (NOT silently fall through).
|
||||
//
|
||||
// Strict fail-closed: any Redis error → error. The webhook handler
|
||||
// surfaces this as 102 so an operator notices a misconfiguration.
|
||||
func validateRateLimit(canvasID string, cfg map[string]any) error {
|
||||
rawRL, ok := cfg["rate_limit"].(map[string]any)
|
||||
if !ok || len(rawRL) == 0 {
|
||||
return nil
|
||||
}
|
||||
limitF, ok := rawRL["limit"].(float64)
|
||||
if !ok {
|
||||
// JSON numbers often come back as float64; try int as well.
|
||||
if limitI, ok2 := rawRL["limit"].(int); ok2 {
|
||||
limitF = float64(limitI)
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
if !ok || limitF <= 0 {
|
||||
return fmt.Errorf("rate_limit.limit must be > 0")
|
||||
}
|
||||
per, _ := rawRL["per"].(string)
|
||||
if per == "" {
|
||||
per = "minute"
|
||||
}
|
||||
var window float64
|
||||
switch per {
|
||||
case "second":
|
||||
window = 1
|
||||
case "minute":
|
||||
window = 60
|
||||
case "hour":
|
||||
window = 3600
|
||||
case "day":
|
||||
window = 86400
|
||||
default:
|
||||
return fmt.Errorf("invalid rate_limit.per: %s", per)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("rl:tb:%s", canvasID)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), webhookRateLimitTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdb := rediscli.Get()
|
||||
if rdb == nil {
|
||||
return fmt.Errorf("rate limit error: redis not initialised")
|
||||
}
|
||||
allowed, err := rdb.EvalTokenBucketStrict(ctx, key, limitF, limitF/window)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rate limit error: %s", err.Error())
|
||||
}
|
||||
if !allowed {
|
||||
return fmt.Errorf("too many requests (rate limit exceeded)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateAuth dispatches on auth_type. Empty cfg or auth_type=="none"
|
||||
// → allow (matches agent_api.py:1621).
|
||||
func validateAuth(c *gin.Context, cfg map[string]any) error {
|
||||
authType, _ := cfg["auth_type"].(string)
|
||||
if authType == "" || authType == "none" {
|
||||
return nil
|
||||
}
|
||||
switch authType {
|
||||
case "token":
|
||||
return validateTokenAuth(c, cfg)
|
||||
case "basic":
|
||||
return validateBasicAuth(c, cfg)
|
||||
case "jwt":
|
||||
return validateJWTAuth(c, cfg)
|
||||
}
|
||||
return fmt.Errorf("unsupported auth_type: %s", authType)
|
||||
}
|
||||
|
||||
// validateTokenAuth mirrors python agent_api.py:1725-1733.
|
||||
//
|
||||
// An empty configured `token_value` previously meant "accept any
|
||||
// request without that header". We now require both header and
|
||||
// value to be non-empty configured secrets, otherwise the request
|
||||
// is rejected as misconfigured. CodeRabbit PR review #4.
|
||||
func validateTokenAuth(c *gin.Context, cfg map[string]any) error {
|
||||
rawToken, _ := cfg["token"].(map[string]any)
|
||||
if rawToken == nil {
|
||||
return fmt.Errorf("Invalid token authentication")
|
||||
}
|
||||
header, _ := rawToken["token_header"].(string)
|
||||
want, _ := rawToken["token_value"].(string)
|
||||
if header == "" || want == "" {
|
||||
return fmt.Errorf("Invalid token authentication")
|
||||
}
|
||||
if c.GetHeader(header) != want {
|
||||
return fmt.Errorf("Invalid token authentication")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateBasicAuth mirrors python agent_api.py:1735-1743. We use
|
||||
// gin's c.Request.BasicAuth() which parses the Authorization header
|
||||
// and returns the (user, pass, ok) triple. Empty configured
|
||||
// username/password are now rejected (CodeRabbit PR review #4).
|
||||
func validateBasicAuth(c *gin.Context, cfg map[string]any) error {
|
||||
rawBasic, _ := cfg["basic_auth"].(map[string]any)
|
||||
if rawBasic == nil {
|
||||
return fmt.Errorf("Invalid Basic Auth credentials")
|
||||
}
|
||||
username, _ := rawBasic["username"].(string)
|
||||
password, _ := rawBasic["password"].(string)
|
||||
if username == "" || password == "" {
|
||||
return fmt.Errorf("Invalid Basic Auth credentials")
|
||||
}
|
||||
u, p, ok := c.Request.BasicAuth()
|
||||
if !ok || u != username || p != password {
|
||||
return fmt.Errorf("Invalid Basic Auth credentials")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateJWTAuth mirrors python agent_api.py:1745-1809.
|
||||
//
|
||||
// Algorithm defaults to HS256. audience / issuer are validated only
|
||||
// when configured. required_claims rejects reserved JWT claims
|
||||
// (exp sub aud iss nbf iat) and any missing claims.
|
||||
//
|
||||
// Algorithms:
|
||||
// - HS256 / HS384 / HS512 → secret is a shared HMAC key.
|
||||
// - RS256 / RS384 / RS512 → secret is a PEM-encoded RSA public key.
|
||||
// - ES256 / ES384 / ES512 → secret is a PEM-encoded EC public key.
|
||||
//
|
||||
// The Python reference uses the same `secret` field for all three
|
||||
// families (the python jwt library is happy to take either a string
|
||||
// or a PEM block); we mirror that with one `secret` config slot and
|
||||
// dispatch on the algorithm.
|
||||
func validateJWTAuth(c *gin.Context, cfg map[string]any) error {
|
||||
rawJWT, _ := cfg["jwt"].(map[string]any)
|
||||
if rawJWT == nil {
|
||||
return fmt.Errorf("jwt secret not configured")
|
||||
}
|
||||
secret, _ := rawJWT["secret"].(string)
|
||||
if secret == "" {
|
||||
return fmt.Errorf("jwt secret not configured")
|
||||
}
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, prefix) {
|
||||
return fmt.Errorf("missing bearer token")
|
||||
}
|
||||
tokenStr := strings.TrimSpace(authHeader[len(prefix):])
|
||||
if tokenStr == "" {
|
||||
return fmt.Errorf("empty bearer token")
|
||||
}
|
||||
|
||||
alg, _ := rawJWT["algorithm"].(string)
|
||||
if alg == "" {
|
||||
alg = "HS256"
|
||||
}
|
||||
alg = strings.ToUpper(alg)
|
||||
|
||||
// Build the parser options.
|
||||
parserOpts := []jwt.ParserOption{
|
||||
jwt.WithValidMethods([]string{alg}),
|
||||
}
|
||||
if aud, ok := rawJWT["audience"].(string); ok && aud != "" {
|
||||
parserOpts = append(parserOpts, jwt.WithAudience(aud))
|
||||
}
|
||||
if iss, ok := rawJWT["issuer"].(string); ok && iss != "" {
|
||||
parserOpts = append(parserOpts, jwt.WithIssuer(iss))
|
||||
}
|
||||
|
||||
keyFunc, keyErr := jwtKeyFunc(alg, secret)
|
||||
if keyErr != nil {
|
||||
return keyErr
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenStr, keyFunc, parserOpts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid jwt: %s", err.Error())
|
||||
}
|
||||
if !token.Valid {
|
||||
return fmt.Errorf("invalid jwt")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid jwt claims")
|
||||
}
|
||||
|
||||
// Required claims validation (mirrors agent_api.py:1787-1808).
|
||||
required := collectStringSlice(rawJWT["required_claims"])
|
||||
reserved := splitCSV(jwtReservedClaims)
|
||||
for _, claim := range required {
|
||||
if contains(reserved, claim) {
|
||||
return fmt.Errorf("reserved jwt claim cannot be required: %s", claim)
|
||||
}
|
||||
if _, present := claims[claim]; !present {
|
||||
return fmt.Errorf("missing jwt claim: %s", claim)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// jwtKeyFunc returns the verification-key closure that jwt.Parse
|
||||
// invokes. The dispatch mirrors the python jwt library: a string secret
|
||||
// is treated as an HMAC key for HS* algorithms, and as a PEM block for
|
||||
// RS*/ES* algorithms.
|
||||
func jwtKeyFunc(alg, secret string) (jwt.Keyfunc, error) {
|
||||
switch alg {
|
||||
case "HS256", "HS384", "HS512":
|
||||
return func(_ *jwt.Token) (any, error) { return []byte(secret), nil }, nil
|
||||
case "RS256", "RS384", "RS512":
|
||||
pub, err := jwt.ParseRSAPublicKeyFromPEM([]byte(secret))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("jwt rsa public key: %s", err.Error())
|
||||
}
|
||||
return func(_ *jwt.Token) (any, error) { return pub, nil }, nil
|
||||
case "ES256", "ES384", "ES512":
|
||||
pub, err := jwt.ParseECPublicKeyFromPEM([]byte(secret))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("jwt ec public key: %s", err.Error())
|
||||
}
|
||||
return func(_ *jwt.Token) (any, error) { return pub, nil }, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported jwt algorithm: %s", alg)
|
||||
}
|
||||
|
||||
// collectStringSlice normalises the python-shaped `required_claims` value
// (a single string OR a list/tuple/set; mirrors agent_api.py:1788-1798) into
// a slice of trimmed, non-empty strings.
//
//   - string: a one-element slice, or nil when blank.
//   - []string / []any: the trimmed non-blank string entries (non-string
//     entries of a []any are ignored); the result is non-nil even when empty.
//   - anything else (including nil): nil.
func collectStringSlice(v any) []string {
	var candidates []string
	switch typed := v.(type) {
	case string:
		if trimmed := strings.TrimSpace(typed); trimmed != "" {
			return []string{trimmed}
		}
		return nil
	case []string:
		candidates = typed
	case []any:
		candidates = make([]string, 0, len(typed))
		for _, elem := range typed {
			if str, isStr := elem.(string); isStr {
				candidates = append(candidates, str)
			}
		}
	default:
		return nil
	}

	kept := make([]string, 0, len(candidates))
	for _, cand := range candidates {
		if trimmed := strings.TrimSpace(cand); trimmed != "" {
			kept = append(kept, trimmed)
		}
	}
	return kept
}
|
||||
|
||||
// splitCSV splits s on commas and returns the fields with surrounding
// whitespace removed, dropping fields that are empty after trimming. The
// result is always non-nil, and is empty when s holds no usable field.
func splitCSV(s string) []string {
	fields := strings.Split(s, ",")
	cleaned := make([]string, 0, len(fields))
	for i := range fields {
		field := strings.TrimSpace(fields[i])
		if field == "" {
			continue
		}
		cleaned = append(cleaned, field)
	}
	return cleaned
}
|
||||
380
internal/handler/agent_webhook_security_test.go
Normal file
380
internal/handler/agent_webhook_security_test.go
Normal file
@@ -0,0 +1,380 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func securityCtx(t *testing.T, remoteAddr string, headers map[string]string) *gin.Context {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/v1/agents/c1/webhook", strings.NewReader("{}"))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.RemoteAddr = remoteAddr
|
||||
for k, v := range headers {
|
||||
c.Request.Header.Set(k, v)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// TestValidateMaxBodySize_Allowed covers the no-op branch.
|
||||
func TestValidateMaxBodySize_Allowed(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", nil)
|
||||
c.Request.ContentLength = 100
|
||||
if err := validateMaxBodySize(c, map[string]any{"max_body_size": "1kb"}); err != nil {
|
||||
t.Errorf("err = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateMaxBodySize_TooLarge covers the size-mismatch branch.
|
||||
func TestValidateMaxBodySize_TooLarge(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", nil)
|
||||
c.Request.ContentLength = 2048
|
||||
err := validateMaxBodySize(c, map[string]any{"max_body_size": "1kb"})
|
||||
if err == nil || !strings.Contains(err.Error(), "request body too large") {
|
||||
t.Errorf("err = %v, want 'request body too large'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateMaxBodySize_BadFormat covers non-numeric format.
|
||||
func TestValidateMaxBodySize_BadFormat(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", nil)
|
||||
err := validateMaxBodySize(c, map[string]any{"max_body_size": "1gb"})
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid max_body_size format") {
|
||||
t.Errorf("err = %v, want 'invalid max_body_size format'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateIPWhitelist_EmptyIsAllow covers the empty-list branch.
|
||||
func TestValidateIPWhitelist_EmptyIsAllow(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", nil)
|
||||
if err := validateIPWhitelist(c, map[string]any{"ip_whitelist": []any{}}); err != nil {
|
||||
t.Errorf("empty whitelist: err = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateIPWhitelist_ExactMatch passes when client IP matches.
|
||||
func TestValidateIPWhitelist_ExactMatch(t *testing.T) {
|
||||
c := securityCtx(t, "10.0.0.5:0", nil)
|
||||
cfg := map[string]any{"ip_whitelist": []any{"10.0.0.5"}}
|
||||
if err := validateIPWhitelist(c, cfg); err != nil {
|
||||
t.Errorf("exact match: err = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateIPWhitelist_CIDR covers the CIDR branch.
|
||||
func TestValidateIPWhitelist_CIDR(t *testing.T) {
|
||||
c := securityCtx(t, "10.0.0.5:0", nil)
|
||||
cfg := map[string]any{"ip_whitelist": []any{"10.0.0.0/8"}}
|
||||
if err := validateIPWhitelist(c, cfg); err != nil {
|
||||
t.Errorf("cidr match: err = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateIPWhitelist_RejectForeign confirms a foreign IP is denied.
|
||||
func TestValidateIPWhitelist_RejectForeign(t *testing.T) {
|
||||
c := securityCtx(t, "192.168.1.5:0", nil)
|
||||
cfg := map[string]any{"ip_whitelist": []any{"10.0.0.0/8"}}
|
||||
err := validateIPWhitelist(c, cfg)
|
||||
if err == nil || !strings.Contains(err.Error(), "not allowed by whitelist") {
|
||||
t.Errorf("err = %v, want 'not allowed by whitelist'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateAuth_NoneIsAllow covers the auth_type=="none" no-op.
|
||||
func TestValidateAuth_NoneIsAllow(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", nil)
|
||||
if err := validateAuth(c, map[string]any{"auth_type": "none"}); err != nil {
|
||||
t.Errorf("auth_type=none: err = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateAuth_Unsupported covers unknown auth_type.
|
||||
func TestValidateAuth_Unsupported(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", nil)
|
||||
err := validateAuth(c, map[string]any{"auth_type": "weird"})
|
||||
if err == nil || !strings.Contains(err.Error(), "unsupported auth_type") {
|
||||
t.Errorf("err = %v, want 'unsupported auth_type'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateTokenAuth_HeaderValue covers matching and non-matching cases.
|
||||
func TestValidateTokenAuth_HeaderValue(t *testing.T) {
|
||||
cfg := map[string]any{
|
||||
"token": map[string]any{
|
||||
"token_header": "X-Token",
|
||||
"token_value": "abc",
|
||||
},
|
||||
}
|
||||
|
||||
// pass
|
||||
cPass := securityCtx(t, "1.2.3.4:0", map[string]string{"X-Token": "abc"})
|
||||
if err := validateTokenAuth(cPass, cfg); err != nil {
|
||||
t.Errorf("matching token: err = %v, want nil", err)
|
||||
}
|
||||
|
||||
// fail
|
||||
cFail := securityCtx(t, "1.2.3.4:0", map[string]string{"X-Token": "wrong"})
|
||||
if err := validateTokenAuth(cFail, cfg); err == nil {
|
||||
t.Errorf("non-matching token: err = nil, want error")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateBasicAuth_PassAndFail covers both branches.
|
||||
func TestValidateBasicAuth_PassAndFail(t *testing.T) {
|
||||
cfg := map[string]any{
|
||||
"basic_auth": map[string]any{
|
||||
"username": "alice",
|
||||
"password": "wonderland",
|
||||
},
|
||||
}
|
||||
|
||||
cPass := securityCtx(t, "1.2.3.4:0", nil)
|
||||
cPass.Request.SetBasicAuth("alice", "wonderland")
|
||||
if err := validateBasicAuth(cPass, cfg); err != nil {
|
||||
t.Errorf("matching basic: err = %v, want nil", err)
|
||||
}
|
||||
|
||||
cFail := securityCtx(t, "1.2.3.4:0", nil)
|
||||
cFail.Request.SetBasicAuth("alice", "wrong")
|
||||
if err := validateBasicAuth(cFail, cfg); err == nil {
|
||||
t.Errorf("non-matching basic: err = nil, want error")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateJWTAuth_NoSecret rejects empty secret config.
|
||||
func TestValidateJWTAuth_NoSecret(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})
|
||||
err := validateJWTAuth(c, map[string]any{"jwt": map[string]any{}})
|
||||
if err == nil || !strings.Contains(err.Error(), "secret not configured") {
|
||||
t.Errorf("err = %v, want 'secret not configured'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateJWTAuth_MissingBearer rejects when Authorization is absent.
|
||||
func TestValidateJWTAuth_MissingBearer(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", nil)
|
||||
err := validateJWTAuth(c, map[string]any{
|
||||
"jwt": map[string]any{"secret": "s"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "missing bearer token") {
|
||||
t.Errorf("err = %v, want 'missing bearer token'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateJWTAuth_RS256HappyPath confirms that an RS256 token
|
||||
// signed with a real RSA key is accepted when the secret field holds
|
||||
// the matching PEM-encoded public key. The python reference uses the
|
||||
// same `secret` slot for both HMAC and RSA inputs; this test pins that
|
||||
// contract for the Go port.
|
||||
func TestValidateJWTAuth_RS256HappyPath(t *testing.T) {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("rsa.GenerateKey: %v", err)
|
||||
}
|
||||
pubPEM := encodeRSAPublicKeyPEM(&priv.PublicKey)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"sub": "u-1",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
signed, err := token.SignedString(priv)
|
||||
if err != nil {
|
||||
t.Fatalf("sign: %v", err)
|
||||
}
|
||||
|
||||
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer " + signed})
|
||||
err = validateJWTAuth(c, map[string]any{
|
||||
"jwt": map[string]any{
|
||||
"secret": string(pubPEM),
|
||||
"algorithm": "RS256",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("RS256 happy path: err = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateJWTAuth_RS256BadPEM rejects a malformed public key.
|
||||
func TestValidateJWTAuth_RS256BadPEM(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})
|
||||
err := validateJWTAuth(c, map[string]any{
|
||||
"jwt": map[string]any{
|
||||
"secret": "not a pem block",
|
||||
"algorithm": "RS256",
|
||||
},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "rsa public key") {
|
||||
t.Errorf("err = %v, want 'rsa public key' parse error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateJWTAuth_UnsupportedAlgorithm rejects truly unknown algos.
|
||||
func TestValidateJWTAuth_UnsupportedAlgorithm(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})
|
||||
err := validateJWTAuth(c, map[string]any{
|
||||
"jwt": map[string]any{
|
||||
"secret": "s",
|
||||
"algorithm": "none",
|
||||
},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "unsupported") {
|
||||
t.Errorf("err = %v, want 'unsupported'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// encodeRSAPublicKeyPEM returns pub as a PEM block of type "PUBLIC KEY"
// wrapping its PKIX (DER) encoding — the form jwt.ParseRSAPublicKeyFromPEM
// expects. It panics if the key cannot be marshalled, which is acceptable
// for a test helper.
func encodeRSAPublicKeyPEM(pub *rsa.PublicKey) []byte {
	der, marshalErr := x509.MarshalPKIXPublicKey(pub)
	if marshalErr != nil {
		panic(marshalErr)
	}
	return pem.EncodeToMemory(&pem.Block{
		Type:  "PUBLIC KEY",
		Bytes: der,
	})
}
|
||||
|
||||
// TestValidateJWTAuth_ReservedClaimRejected covers the reserved-claim
|
||||
// guard. We build a valid HS256 JWT first so the parse succeeds, then
|
||||
// ask for `exp` as a required claim — the validator must reject it.
|
||||
func TestValidateJWTAuth_ReservedClaimRejected(t *testing.T) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": "u-1",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
signed, err := token.SignedString([]byte("s"))
|
||||
if err != nil {
|
||||
t.Fatalf("sign: %v", err)
|
||||
}
|
||||
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer " + signed})
|
||||
err = validateJWTAuth(c, map[string]any{
|
||||
"jwt": map[string]any{
|
||||
"secret": "s",
|
||||
"required_claims": "exp",
|
||||
},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "reserved jwt claim") {
|
||||
t.Errorf("err = %v, want 'reserved jwt claim'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateRateLimit_NoConfig covers the no-rate-limit branch.
|
||||
func TestValidateRateLimit_NoConfig(t *testing.T) {
|
||||
if err := validateRateLimit("c1", map[string]any{}); err != nil {
|
||||
t.Errorf("no rate_limit: err = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateRateLimit_BadPer rejects unknown per window.
|
||||
func TestValidateRateLimit_BadPer(t *testing.T) {
|
||||
err := validateRateLimit("c1", map[string]any{
|
||||
"rate_limit": map[string]any{"limit": 10, "per": "week"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid rate_limit.per") {
|
||||
t.Errorf("err = %v, want 'invalid rate_limit.per'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateRateLimit_BadLimit rejects non-positive limits.
|
||||
func TestValidateRateLimit_BadLimit(t *testing.T) {
|
||||
err := validateRateLimit("c1", map[string]any{
|
||||
"rate_limit": map[string]any{"limit": 0, "per": "minute"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "must be > 0") {
|
||||
t.Errorf("err = %v, want 'must be > 0'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// (No helper needed at the bottom of this file; helper functions
|
||||
// inline above.)
|
||||
// _securityUnused previously lived here as a placeholder; deleted
|
||||
// during cleanup (code-review MEDIUM-2).
|
||||
|
||||
// TestValidateMaxBodySize_OverflowGuard covers CodeRabbit PR review
|
||||
// #3: a configured n that would overflow n*bytesPerMB (e.g. a huge
|
||||
// mb value) must be rejected before the multiplication, not
|
||||
// silently wrap to a small number and pass the cap check.
|
||||
//
|
||||
// The chosen n=9_223_372_036_855 multiplied by 1_000_000 (= 1 MB)
|
||||
// overflows int64 max (9.22e18), so without the pre-multiplication
|
||||
// guard this would silently wrap to a small positive number and
|
||||
// the cap check would (incorrectly) succeed. With the guard in
|
||||
// place, the rejection happens at the parse step, before any
|
||||
// multiplication.
|
||||
func TestValidateMaxBodySize_OverflowGuard(t *testing.T) {
|
||||
c := securityCtx(t, "1.2.3.4:0", nil)
|
||||
err := validateMaxBodySize(c, map[string]any{"max_body_size": "9223372036855mb"})
|
||||
if err == nil {
|
||||
t.Errorf("huge mb value: err = nil, want overflow-rejection error")
|
||||
} else if !strings.Contains(err.Error(), "exceeds maximum") {
|
||||
t.Errorf("err = %v, want 'exceeds maximum'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseMaxBodySize_DecimalUnits documents the per-user-request
|
||||
// SI-decimal unit base: 1 kb = 1 000 B, 1 mb = 1 000 000 B.
|
||||
// (The python reference uses 1 024 / 1 048 576.)
|
||||
func TestParseMaxBodySize_DecimalUnits(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want int64
|
||||
}{
|
||||
{"1kb", 1000},
|
||||
{"5kb", 5000},
|
||||
{"1mb", 1_000_000},
|
||||
{"10mb", 10_000_000}, // exact cap
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got, err := parseMaxBodySize(map[string]any{"max_body_size": tc.in})
|
||||
if err != nil {
|
||||
t.Errorf("parseMaxBodySize(%q): err = %v, want nil", tc.in, err)
|
||||
continue
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Errorf("parseMaxBodySize(%q) = %d, want %d", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateTokenAuth_EmptyValueRejected covers CodeRabbit PR
|
||||
// review #4: an empty configured token_value used to mean "accept
|
||||
// any request without that header". Now it must be rejected.
|
||||
func TestValidateTokenAuth_EmptyValueRejected(t *testing.T) {
|
||||
cfg := map[string]any{
|
||||
"token": map[string]any{
|
||||
"token_header": "X-Token",
|
||||
"token_value": "",
|
||||
},
|
||||
}
|
||||
c := securityCtx(t, "1.2.3.4:0", nil) // no header at all
|
||||
if err := validateTokenAuth(c, cfg); err == nil {
|
||||
t.Errorf("empty token_value: err = nil, want error")
|
||||
}
|
||||
}
|
||||
659
internal/handler/agent_webhook_test.go
Normal file
659
internal/handler/agent_webhook_test.go
Normal file
@@ -0,0 +1,659 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/agent/canvas"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// fakeCanvasLoader is a stand-in canvasLoader for webhook tests. It
|
||||
// returns the canned canvas/run events configured at construction time
|
||||
// without touching the database. Mirrors the pattern used by
|
||||
// waitFakeAgentService at internal/handler/agent_wait_for_user_test.go.
|
||||
type fakeCanvasLoader struct {
|
||||
canvas *entity.UserCanvas
|
||||
err error
|
||||
events []canvas.RunEvent
|
||||
}
|
||||
|
||||
func (f *fakeCanvasLoader) LoadCanvasByID(_ context.Context, _, _ string) (*entity.UserCanvas, error) {
|
||||
if f.err != nil {
|
||||
return nil, f.err
|
||||
}
|
||||
return f.canvas, nil
|
||||
}
|
||||
|
||||
func (f *fakeCanvasLoader) RunAgentWithWebhook(_ context.Context, _, _ string, _ map[string]any) (<-chan canvas.RunEvent, error) {
|
||||
out := make(chan canvas.RunEvent, len(f.events))
|
||||
for _, e := range f.events {
|
||||
out <- e
|
||||
}
|
||||
close(out)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// makeWebhookCanvas builds a minimal canvas with a Begin component
|
||||
// whose params.mode == "Webhook" and the supplied params map. The
|
||||
// `params` argument becomes webhook_cfg inside the handler.
|
||||
func makeWebhookCanvas(id, userID, mode string, params map[string]any) *entity.UserCanvas {
|
||||
dsl := map[string]any{
|
||||
"components": map[string]any{
|
||||
"begin": map[string]any{
|
||||
"obj": map[string]any{
|
||||
"component_name": "Begin",
|
||||
"params": map[string]any{
|
||||
"mode": mode,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for k, v := range params {
|
||||
dsl["components"].(map[string]any)["begin"].(map[string]any)["obj"].(map[string]any)["params"].(map[string]any)[k] = v
|
||||
}
|
||||
return &entity.UserCanvas{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
CanvasCategory: "agent_canvas",
|
||||
DSL: entity.JSONMap(dsl),
|
||||
}
|
||||
}
|
||||
|
||||
// webhookCtx builds a gin test context with the supplied method/path/body
|
||||
// and a pre-set "user" so GetUser() returns success.
|
||||
func webhookCtx(method, path, body, contentType string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
var reader *strings.Reader
|
||||
if body == "" {
|
||||
reader = strings.NewReader("")
|
||||
} else {
|
||||
reader = strings.NewReader(body)
|
||||
}
|
||||
c.Request = httptest.NewRequest(method, path, reader)
|
||||
if contentType != "" {
|
||||
c.Request.Header.Set("Content-Type", contentType)
|
||||
}
|
||||
c.Set("user", &entity.User{ID: "u-1"})
|
||||
return c, w
|
||||
}
|
||||
|
||||
// errBody decodes a JSON response envelope and returns its "code" and
// "message" fields. A body that is not valid JSON fails the calling
// test immediately via t.Fatalf; fields missing from the JSON come
// back as their zero values.
func errBody(t *testing.T, body []byte) (int, string) {
	t.Helper()

	envelope := struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	}{}
	decodeErr := json.Unmarshal(body, &envelope)
	if decodeErr != nil {
		t.Fatalf("decode response: %v (body=%s)", decodeErr, body)
	}
	return envelope.Code, envelope.Message
}
|
||||
|
||||
// ---------- Phase 1: canvas loading / DataFlow rejection ----------
|
||||
|
||||
// TestWebhook_RejectsUnknownCanvas pins the 102 "Canvas not found."
|
||||
// envelope when LoadCanvasByID returns ErrUserCanvasNotFound. This is
|
||||
// the deliberate divergence from mapAgentError (which would surface 103).
|
||||
func TestWebhook_RejectsUnknownCanvas(t *testing.T) {
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{err: dao.ErrUserCanvasNotFound}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d (DataError)", code, common.CodeDataError)
|
||||
}
|
||||
if msg != "Canvas not found." {
|
||||
t.Errorf("message = %q, want %q", msg, "Canvas not found.")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_RejectsDataFlowCanvas covers the second guard: the canvas
|
||||
// exists but its category is DataFlow (which python
|
||||
// agent_api.py:1575 explicitly rejects).
|
||||
func TestWebhook_RejectsDataFlowCanvas(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", nil)
|
||||
cv.CanvasCategory = "DataFlow"
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d", code, common.CodeDataError)
|
||||
}
|
||||
if msg != "Dataflow can not be triggered by webhook." {
|
||||
t.Errorf("message = %q, want %q", msg, "Dataflow can not be triggered by webhook.")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_RejectsMissingWebhookConfig: DSL has no Begin with
|
||||
// mode="Webhook".
|
||||
func TestWebhook_RejectsMissingWebhookConfig(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Manual", nil) // wrong mode
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d", code, common.CodeDataError)
|
||||
}
|
||||
if msg != "Webhook not configured for this agent." {
|
||||
t.Errorf("message = %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Phase 5: method gate ----------
|
||||
|
||||
// TestWebhook_RejectsDisallowedMethod covers the methods list gate.
|
||||
// methods:["POST"] + a GET request → 102.
|
||||
func TestWebhook_RejectsDisallowedMethod(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"methods": []any{"POST"},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("GET", "/api/v1/agents/c1/webhook", ``, "")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d", code, common.CodeDataError)
|
||||
}
|
||||
want := "HTTP method 'GET' not allowed for this webhook."
|
||||
if msg != want {
|
||||
t.Errorf("message = %q, want %q", msg, want)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Phase 6: security ----------
|
||||
|
||||
// TestWebhook_TokenAuthPasses pins the happy path for token auth.
|
||||
func TestWebhook_TokenAuthPasses(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"security": map[string]any{
|
||||
"auth_type": "token",
|
||||
"token": map[string]any{
|
||||
"token_header": "X-Webhook-Token",
|
||||
"token_value": "s3cret",
|
||||
},
|
||||
},
|
||||
"execution_mode": "Immediately",
|
||||
"response": map[string]any{
|
||||
"status": 200,
|
||||
},
|
||||
})
|
||||
loader := &fakeCanvasLoader{canvas: cv}
|
||||
h := &AgentHandler{loader: loader}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
c.Request.Header.Set("X-Webhook-Token", "s3cret")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_TokenAuthFails: header value mismatch → 102.
|
||||
func TestWebhook_TokenAuthFails(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"security": map[string]any{
|
||||
"auth_type": "token",
|
||||
"token": map[string]any{
|
||||
"token_header": "X-Webhook-Token",
|
||||
"token_value": "s3cret",
|
||||
},
|
||||
},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
c.Request.Header.Set("X-Webhook-Token", "wrong")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d", code, common.CodeDataError)
|
||||
}
|
||||
if msg != "Invalid token authentication" {
|
||||
t.Errorf("message = %q, want %q", msg, "Invalid token authentication")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_BodySizeLimit: 1kb ceiling + ~2kb Content-Length → 102.
|
||||
func TestWebhook_BodySizeLimit(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"security": map[string]any{
|
||||
"max_body_size": "1kb",
|
||||
},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook",
|
||||
strings.Repeat("a", 2048), "text/plain")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, _ := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d", code, common.CodeDataError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_BodySizeLimitTooLarge rejects config bugs >10MB.
|
||||
func TestWebhook_BodySizeLimitTooLarge(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"security": map[string]any{
|
||||
"max_body_size": "11mb",
|
||||
},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d", code, common.CodeDataError)
|
||||
}
|
||||
if !strings.Contains(msg, "exceeds maximum") {
|
||||
t.Errorf("message = %q, want contains 'exceeds maximum'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_BasicAuthPasses / TestWebhook_BasicAuthFails exercise the
|
||||
// HTTP Basic branch.
|
||||
func TestWebhook_BasicAuthPasses(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"security": map[string]any{
|
||||
"auth_type": "basic",
|
||||
"basic_auth": map[string]any{
|
||||
"username": "alice",
|
||||
"password": "wonderland",
|
||||
},
|
||||
},
|
||||
"execution_mode": "Immediately",
|
||||
"response": map[string]any{"status": 200},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
c.Request.SetBasicAuth("alice", "wonderland")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhook_BasicAuthFails(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"security": map[string]any{
|
||||
"auth_type": "basic",
|
||||
"basic_auth": map[string]any{
|
||||
"username": "alice",
|
||||
"password": "wonderland",
|
||||
},
|
||||
},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
c.Request.SetBasicAuth("alice", "wrong")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, _ := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d", code, common.CodeDataError)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Phase 8: schema extraction ----------
|
||||
|
||||
// TestWebhook_SchemaExtractionRequired: missing required field → 102.
|
||||
func TestWebhook_SchemaExtractionRequired(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"schema": map[string]any{
|
||||
"query": map[string]any{
|
||||
"properties": map[string]any{
|
||||
"q": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"q"},
|
||||
},
|
||||
},
|
||||
"execution_mode": "Immediately",
|
||||
"response": map[string]any{"status": 200},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook?other=v", ``, "")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d", code, common.CodeDataError)
|
||||
}
|
||||
if !strings.Contains(msg, "missing required field") {
|
||||
t.Errorf("message = %q, want contains 'missing required field'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_SchemaExtractionTypeMismatch: non-numeric in number field.
|
||||
func TestWebhook_SchemaExtractionTypeMismatch(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"schema": map[string]any{
|
||||
"body": map[string]any{
|
||||
"properties": map[string]any{
|
||||
"n": map[string]any{"type": "number"},
|
||||
},
|
||||
"required": []any{"n"},
|
||||
},
|
||||
},
|
||||
"execution_mode": "Immediately",
|
||||
"response": map[string]any{"status": 200},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{"n":"abc"}`, "application/json")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d", code, common.CodeDataError)
|
||||
}
|
||||
if !strings.Contains(msg, "type mismatch") && !strings.Contains(msg, "auto-cast") {
|
||||
t.Errorf("message = %q, want contains 'type mismatch' or 'auto-cast'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Phase 9: dispatch ----------
|
||||
|
||||
// TestWebhook_ImmediatelyReturnsConfiguredStatus confirms the
|
||||
// Immediately mode returns the configured status code and runs the
|
||||
// canvas detached.
|
||||
func TestWebhook_ImmediatelyReturnsConfiguredStatus(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"execution_mode": "Immediately",
|
||||
"response": map[string]any{
|
||||
"status": 201,
|
||||
"body_template": `{"ok":true}`,
|
||||
},
|
||||
})
|
||||
loader := &fakeCanvasLoader{
|
||||
canvas: cv,
|
||||
events: []canvas.RunEvent{
|
||||
{Type: "message", Data: `{"content":"hi"}`},
|
||||
{Type: "done", Data: ""},
|
||||
},
|
||||
}
|
||||
h := &AgentHandler{loader: loader}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("status = %d, want 201; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), `"ok":true`) {
|
||||
t.Errorf("body = %q, want contains '\"ok\":true'", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_DefaultModeAggregatesContent covers the non-Immediately
|
||||
// path: streaming message + message_end events get concatenated.
|
||||
func TestWebhook_DefaultModeAggregatesContent(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"execution_mode": "Streaming",
|
||||
})
|
||||
loader := &fakeCanvasLoader{
|
||||
canvas: cv,
|
||||
events: []canvas.RunEvent{
|
||||
{Type: "message", Data: `{"content":"hello "}`},
|
||||
{Type: "message", Data: `{"content":"world"}`},
|
||||
{Type: "message_end", Data: `{"status":200}`},
|
||||
{Type: "done", Data: ""},
|
||||
},
|
||||
}
|
||||
h := &AgentHandler{loader: loader}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var body struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if body.Message != "hello world" {
|
||||
t.Errorf("aggregated message = %q, want %q", body.Message, "hello world")
|
||||
}
|
||||
if body.Code != 200 {
|
||||
t.Errorf("code = %d, want 200", body.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_InvalidResponseStatusRange rejects bad config (status 500).
|
||||
func TestWebhook_InvalidResponseStatusRange(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"execution_mode": "Immediately",
|
||||
"response": map[string]any{"status": 500},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{}`, "application/json")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("code = %d, want %d", code, common.CodeDataError)
|
||||
}
|
||||
if !strings.Contains(msg, "must be between 200 and 399") {
|
||||
t.Errorf("message = %q, want contains 'must be between 200 and 399'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_FindWebhookBegin_NilDSL ensures the helper is defensive.
|
||||
func TestWebhook_FindWebhookBegin_NilDSL(t *testing.T) {
|
||||
if got := findWebhookBegin(nil); got != nil {
|
||||
t.Errorf("findWebhookBegin(nil) = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_FindWebhookBegin_NoComponents covers DSL with no components.
|
||||
func TestWebhook_FindWebhookBegin_NoComponents(t *testing.T) {
|
||||
if got := findWebhookBegin(map[string]any{}); got != nil {
|
||||
t.Errorf("findWebhookBegin(empty) = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Regression tests for issues raised in code review ----------
|
||||
|
||||
// TestWebhook_BodySizeStreamBounded pins the security MEDIUM-1
|
||||
// hardening: when max_body_size is configured, an oversized body is
|
||||
// rejected. The 102 envelope carries code=CodeDataError; HTTP itself
|
||||
// stays at 200 (matches the existing /api/v1 error envelope shape).
|
||||
func TestWebhook_BodySizeStreamBounded(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"security": map[string]any{
|
||||
"max_body_size": "1kb",
|
||||
},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook",
|
||||
strings.Repeat("a", 2048), "text/plain")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("envelope code = %d, want %d (CodeDataError)", code, common.CodeDataError)
|
||||
}
|
||||
// Either message is acceptable — the Content-Length pre-check
|
||||
// ("request body too large") OR the MaxBytesReader runtime check
|
||||
// ("http: request body too large").
|
||||
if !strings.Contains(msg, "body too large") {
|
||||
t.Errorf("message = %q, want contains 'body too large'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// The tests below cover the behaviour the implementation summary
|
||||
// claimed but did NOT actually enforce. They ensure the dispatch path
|
||||
// returns the documented envelope for multipart and content-type
|
||||
// mismatch cases.
|
||||
|
||||
// TestWebhook_MultipartReturns501 pins the contract that
|
||||
// multipart/form-data uploads are refused with HTTP 501 BEFORE the
|
||||
// schema-extraction phase runs. Before the fix, the handler silently
|
||||
// stuffed __multipart_unsupported__ into body and continued; operators
|
||||
// saw a confusing schema-validation error instead of the documented
|
||||
// "not implemented" envelope.
|
||||
func TestWebhook_MultipartReturns501(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"execution_mode": "Immediately",
|
||||
"response": map[string]any{"status": 200},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook",
|
||||
"--boundary\r\nContent-Disposition: form-data; name=\"k\"\r\n\r\nv\r\n--boundary--\r\n",
|
||||
"multipart/form-data; boundary=boundary")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
if w.Code != http.StatusNotImplemented {
|
||||
t.Errorf("status = %d, want 501; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not supported") {
|
||||
t.Errorf("body should mention 'not supported', got %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_ContentTypeMismatchReturns102 pins the contract that
|
||||
// content_types whitelist enforces a HARD REJECT (HTTP 102 envelope),
|
||||
// matching python agent_api.py:1839-1842. Before the fix, the
|
||||
// mismatch was silently recorded as __content_type_mismatch__ in the
|
||||
// body and the request flowed into schema validation — a real
|
||||
// whitelist bypass.
|
||||
func TestWebhook_ContentTypeMismatchReturns102(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"content_types": "application/json",
|
||||
"execution_mode": "Immediately",
|
||||
"response": map[string]any{"status": 200},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
// text/plain where the config demands application/json.
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", "raw body", "text/plain")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("envelope status = %d, want 200 (the envelope itself succeeds; the code field carries 102)", w.Code)
|
||||
}
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("envelope code = %d, want %d (CodeDataError)", code, common.CodeDataError)
|
||||
}
|
||||
if !strings.Contains(msg, "invalid content-type") || !strings.Contains(msg, "application/json") {
|
||||
t.Errorf("message = %q, want contains 'invalid content-type' and 'application/json'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_ContentTypesEmptyAllowsAnything confirms the inverse:
|
||||
// when webhook_cfg does NOT set content_types, any Content-Type is
|
||||
// accepted (matches python: agent_api.py:1839 only raises when the
|
||||
// config actually sets content_types).
|
||||
func TestWebhook_ContentTypesEmptyAllowsAnything(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
// content_types deliberately omitted
|
||||
"execution_mode": "Immediately",
|
||||
"response": map[string]any{"status": 200},
|
||||
})
|
||||
loader := &fakeCanvasLoader{canvas: cv}
|
||||
h := &AgentHandler{loader: loader}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{"k":"v"}`, "text/plain")
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200; body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_ContentTypeMissingHeaderRejected pins the gap from the
|
||||
// review: when content_types is configured, a request that OMITS the
|
||||
// Content-Type header entirely must be rejected with 102. The python
|
||||
// reference has a `if ctype and ...` short-circuit that lets such
|
||||
// requests through — a real whitelist bypass. The Go port tightens
|
||||
// this: an operator who configured content_types gets an enforceable
|
||||
// contract; the caller MUST send the header.
|
||||
func TestWebhook_ContentTypeMissingHeaderRejected(t *testing.T) {
|
||||
cv := makeWebhookCanvas("c1", "u-1", "Webhook", map[string]any{
|
||||
"content_types": "application/json",
|
||||
"execution_mode": "Immediately",
|
||||
"response": map[string]any{"status": 200},
|
||||
})
|
||||
h := &AgentHandler{loader: &fakeCanvasLoader{canvas: cv}}
|
||||
c, w := webhookCtx("POST", "/api/v1/agents/c1/webhook", `{"k":"v"}`, "")
|
||||
// ↑ "" content-type means the header is absent. gin will not
|
||||
// produce a default Content-Type on a POST so this is a clean
|
||||
// "no header" test case.
|
||||
|
||||
h.Webhook(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("envelope status = %d, want 200 (the envelope itself succeeds; the code field carries 102)", w.Code)
|
||||
}
|
||||
code, msg := errBody(t, w.Body.Bytes())
|
||||
if code != int(common.CodeDataError) {
|
||||
t.Errorf("envelope code = %d, want %d (CodeDataError)", code, common.CodeDataError)
|
||||
}
|
||||
if !strings.Contains(msg, "invalid content-type") {
|
||||
t.Errorf("message = %q, want contains 'invalid content-type'", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "application/json") {
|
||||
t.Errorf("message = %q, want contains 'application/json'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// (No helper needed at the bottom of this file; helper functions
|
||||
// inline above.)
|
||||
// _unused previously lived here as a placeholder; deleted during
|
||||
// cleanup (code-review MEDIUM-2).
|
||||
@@ -56,6 +56,14 @@ func RegisterAgentRoutes(g *gin.RouterGroup, h *handler.AgentHandler) {
|
||||
g.PUT("/:canvas_id/tags", h.UpdateAgentTags)
|
||||
g.POST("/:canvas_id/reset", h.ResetAgent)
|
||||
|
||||
// File operations.
|
||||
g.GET("/download", h.DownloadAgentFile)
|
||||
g.POST("/:canvas_id/upload", h.UploadAgentFile)
|
||||
|
||||
// Component introspection + debug.
|
||||
g.GET("/:canvas_id/components/:component_id/input-form", h.GetComponentInputForm)
|
||||
g.POST("/:canvas_id/components/:component_id/debug", h.DebugComponent)
|
||||
|
||||
// Versions.
|
||||
g.GET("/:canvas_id/versions", h.ListVersions)
|
||||
g.GET("/:canvas_id/versions/:version_id", h.GetVersion)
|
||||
@@ -71,9 +79,43 @@ func RegisterAgentRoutes(g *gin.RouterGroup, h *handler.AgentHandler) {
|
||||
// Logs and webhook.
|
||||
g.GET("/:canvas_id/logs/:message_id", h.GetAgentLogs)
|
||||
g.GET("/:canvas_id/webhook/logs", h.GetAgentWebhookLogs)
|
||||
// Webhook trigger endpoints. The Python agent API
|
||||
// (api/apps/restful_apis/agent_api.py:1563-1564) registers six
|
||||
// HTTP methods on a single path. Gin has no Match() helper, so we
|
||||
// register each verb explicitly via registerAnyMethod. The handler
|
||||
// is identical for all six; semantics differ only by
|
||||
// c.Request.Method.
|
||||
registerAnyMethod(g, "/:canvas_id/webhook", h.Webhook)
|
||||
registerAnyMethod(g, "/:canvas_id/webhook/test", h.Webhook)
|
||||
|
||||
// Top-level actions (no canvas id in path).
|
||||
// NOTE: `/chat/completion` (singular) is intentionally NOT registered.
|
||||
// The singular form was a historical typo in earlier Python releases —
|
||||
// no client, SDK, or doc ever called it, and the Python side
|
||||
// (api/apps/restful_apis/agent_api.py) has since removed the route.
|
||||
// See plan: .claude/plans/agent-api-gaps-go-port.md §Gap E.
|
||||
g.POST("/chat/completions", h.AgentChatCompletions)
|
||||
g.POST("/rerun", h.RerunAgent)
|
||||
g.POST("/test_db_connection", h.TestDBConnection)
|
||||
}
|
||||
|
||||
// registerAnyMethod mirrors the Python
|
||||
// `@manager.route(path, methods=["POST","GET","PUT","PATCH","DELETE","HEAD"])`
|
||||
// pattern. Gin has no Match() helper, so we register each verb
|
||||
// explicitly. The handler is identical for all six — semantics differ
|
||||
// only by c.Request.Method.
|
||||
//
|
||||
// Centralising the registration here keeps RegisterAgentRoutes readable
|
||||
// when both the production trigger and the test trigger share the same
|
||||
// six-method shape.
|
||||
func registerAnyMethod(g *gin.RouterGroup, path string, h gin.HandlerFunc) {
|
||||
if g == nil || h == nil {
|
||||
return
|
||||
}
|
||||
g.POST(path, h)
|
||||
g.GET(path, h)
|
||||
g.PUT(path, h)
|
||||
g.PATCH(path, h)
|
||||
g.DELETE(path, h)
|
||||
g.HEAD(path, h)
|
||||
}
|
||||
|
||||
@@ -49,6 +49,51 @@ func genID32() string {
|
||||
return strings.ReplaceAll(uuid.New().String(), "-", "")[:32]
|
||||
}
|
||||
|
||||
// webhookPayloadKey is the private context key that carries a webhook
// payload into RunAgent, which copies the value into
// root["webhook_payload"]. AgentService.RunAgentWithWebhook is the only
// public entry point that stores a value under this key; the chat and
// agent-run paths never set it, so their behaviour is unaffected.
//
// The payload rides on the context instead of being a RunAgent argument
// so that the RunAgent signature stays stable and existing tests
// (agent_run_e2e_test.go, agent_wait_for_user_test.go) keep compiling.
type webhookPayloadKey struct{}
|
||||
|
||||
// LoadCanvasByID is the read-side counterpart of loadCanvasForUser that
|
||||
// the webhook handler uses. It deliberately returns the raw DAO/service
|
||||
// error (no error-code mapping) because the webhook envelope is 102
|
||||
// "Canvas not found." while the chat/run envelope is 103 "Make sure you
|
||||
// have permission..." — the choice must stay at the HTTP layer where
|
||||
// each handler knows its own spec.
|
||||
//
|
||||
// Mirrors python: api/apps/restful_apis/agent_api.py:1570
|
||||
// (`UserCanvasService.get_by_id(agent_id)`), with the same IDOR guard
|
||||
// the chat handler uses.
|
||||
func (s *AgentService) LoadCanvasByID(
|
||||
ctx context.Context, userID, canvasID string,
|
||||
) (*entity.UserCanvas, error) {
|
||||
return s.loadCanvasForUser(ctx, userID, canvasID)
|
||||
}
|
||||
|
||||
// RunAgentWithWebhook is a thin wrapper over RunAgent that attaches the
|
||||
// webhook payload to the runner root so the Begin component can surface
|
||||
// it as state.Sys["webhook_payload"] for downstream components.
|
||||
//
|
||||
// The payload is intentionally passed via context value (rather than a new
|
||||
// RunAgent parameter) to keep the public RunAgent signature stable for
|
||||
// the existing chat tests.
|
||||
//
|
||||
// Mirrors python: api/apps/restful_apis/agent_api.py:2125
|
||||
// (`canvas.run(..., webhook_payload=clean_request)`).
|
||||
func (s *AgentService) RunAgentWithWebhook(
|
||||
ctx context.Context, userID, canvasID string, payload map[string]any,
|
||||
) (<-chan canvas.RunEvent, error) {
|
||||
if payload != nil {
|
||||
ctx = context.WithValue(ctx, webhookPayloadKey{}, payload)
|
||||
}
|
||||
return s.RunAgent(ctx, userID, canvasID, "", "", "")
|
||||
}
|
||||
|
||||
// ErrAgentNotOwner is returned by DeleteAgent when the canvas exists and
|
||||
// is accessible to the caller but is owned by a different user. It maps
|
||||
// to the Python "Only the owner of the agent is authorized for this
|
||||
@@ -670,6 +715,14 @@ func (s *AgentService) RunAgent(ctx context.Context, userID, canvasID, sessionID
|
||||
if dsl != nil {
|
||||
root["__dsl_present__"] = true
|
||||
}
|
||||
// Webhook payload injection. Only RunAgentWithWebhook sets this
|
||||
// context value; the chat / agent-run paths leave it nil so the
|
||||
// existing surface is unchanged. The Begin component reads
|
||||
// inputs["webhook_payload"] and writes it to state.Sys so
|
||||
// downstream components can read sys.webhook_payload.
|
||||
if payload, ok := ctx.Value(webhookPayloadKey{}).(map[string]any); ok && payload != nil {
|
||||
root["webhook_payload"] = payload
|
||||
}
|
||||
// Phase 4.4 V2.1 (v3.6.1): populate root["tenant_id"] so the
|
||||
// RunTracker.Start call (in buildRunFunc) records the run
|
||||
// under the right tenant. The lookup is best-effort — a
|
||||
|
||||
Reference in New Issue
Block a user