Files
ragflow/internal/handler/agent_webhook_security_test.go
Zhichang Yu 5b09910d52 feat[Go]: port agent webhook trigger, agent file upload/download, component input-form + debug endpoints from Python (#16403)
port agent webhook trigger, agent file upload/download, component
input-form + debug endpoints from Python
- [x] New Feature (non-breaking change which adds functionality)
2026-06-27 14:07:22 +08:00

381 lines
13 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
)
func securityCtx(t *testing.T, remoteAddr string, headers map[string]string) *gin.Context {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/v1/agents/c1/webhook", strings.NewReader("{}"))
c.Request.Header.Set("Content-Type", "application/json")
c.Request.RemoteAddr = remoteAddr
for k, v := range headers {
c.Request.Header.Set(k, v)
}
return c
}
// TestValidateMaxBodySize_Allowed covers the no-op branch.
func TestValidateMaxBodySize_Allowed(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", nil)
c.Request.ContentLength = 100
if err := validateMaxBodySize(c, map[string]any{"max_body_size": "1kb"}); err != nil {
t.Errorf("err = %v, want nil", err)
}
}
// TestValidateMaxBodySize_TooLarge covers the size-mismatch branch.
func TestValidateMaxBodySize_TooLarge(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", nil)
c.Request.ContentLength = 2048
err := validateMaxBodySize(c, map[string]any{"max_body_size": "1kb"})
if err == nil || !strings.Contains(err.Error(), "request body too large") {
t.Errorf("err = %v, want 'request body too large'", err)
}
}
// TestValidateMaxBodySize_BadFormat covers non-numeric format.
func TestValidateMaxBodySize_BadFormat(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", nil)
err := validateMaxBodySize(c, map[string]any{"max_body_size": "1gb"})
if err == nil || !strings.Contains(err.Error(), "invalid max_body_size format") {
t.Errorf("err = %v, want 'invalid max_body_size format'", err)
}
}
// TestValidateIPWhitelist_EmptyIsAllow covers the empty-list branch.
func TestValidateIPWhitelist_EmptyIsAllow(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", nil)
if err := validateIPWhitelist(c, map[string]any{"ip_whitelist": []any{}}); err != nil {
t.Errorf("empty whitelist: err = %v, want nil", err)
}
}
// TestValidateIPWhitelist_ExactMatch passes when client IP matches.
func TestValidateIPWhitelist_ExactMatch(t *testing.T) {
c := securityCtx(t, "10.0.0.5:0", nil)
cfg := map[string]any{"ip_whitelist": []any{"10.0.0.5"}}
if err := validateIPWhitelist(c, cfg); err != nil {
t.Errorf("exact match: err = %v, want nil", err)
}
}
// TestValidateIPWhitelist_CIDR covers the CIDR branch.
func TestValidateIPWhitelist_CIDR(t *testing.T) {
c := securityCtx(t, "10.0.0.5:0", nil)
cfg := map[string]any{"ip_whitelist": []any{"10.0.0.0/8"}}
if err := validateIPWhitelist(c, cfg); err != nil {
t.Errorf("cidr match: err = %v, want nil", err)
}
}
// TestValidateIPWhitelist_RejectForeign confirms a foreign IP is denied.
func TestValidateIPWhitelist_RejectForeign(t *testing.T) {
c := securityCtx(t, "192.168.1.5:0", nil)
cfg := map[string]any{"ip_whitelist": []any{"10.0.0.0/8"}}
err := validateIPWhitelist(c, cfg)
if err == nil || !strings.Contains(err.Error(), "not allowed by whitelist") {
t.Errorf("err = %v, want 'not allowed by whitelist'", err)
}
}
// TestValidateAuth_NoneIsAllow covers the auth_type=="none" no-op.
func TestValidateAuth_NoneIsAllow(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", nil)
if err := validateAuth(c, map[string]any{"auth_type": "none"}); err != nil {
t.Errorf("auth_type=none: err = %v, want nil", err)
}
}
// TestValidateAuth_Unsupported covers unknown auth_type.
func TestValidateAuth_Unsupported(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", nil)
err := validateAuth(c, map[string]any{"auth_type": "weird"})
if err == nil || !strings.Contains(err.Error(), "unsupported auth_type") {
t.Errorf("err = %v, want 'unsupported auth_type'", err)
}
}
// TestValidateTokenAuth_HeaderValue covers matching and non-matching cases.
func TestValidateTokenAuth_HeaderValue(t *testing.T) {
cfg := map[string]any{
"token": map[string]any{
"token_header": "X-Token",
"token_value": "abc",
},
}
// pass
cPass := securityCtx(t, "1.2.3.4:0", map[string]string{"X-Token": "abc"})
if err := validateTokenAuth(cPass, cfg); err != nil {
t.Errorf("matching token: err = %v, want nil", err)
}
// fail
cFail := securityCtx(t, "1.2.3.4:0", map[string]string{"X-Token": "wrong"})
if err := validateTokenAuth(cFail, cfg); err == nil {
t.Errorf("non-matching token: err = nil, want error")
}
}
// TestValidateBasicAuth_PassAndFail covers both branches.
func TestValidateBasicAuth_PassAndFail(t *testing.T) {
cfg := map[string]any{
"basic_auth": map[string]any{
"username": "alice",
"password": "wonderland",
},
}
cPass := securityCtx(t, "1.2.3.4:0", nil)
cPass.Request.SetBasicAuth("alice", "wonderland")
if err := validateBasicAuth(cPass, cfg); err != nil {
t.Errorf("matching basic: err = %v, want nil", err)
}
cFail := securityCtx(t, "1.2.3.4:0", nil)
cFail.Request.SetBasicAuth("alice", "wrong")
if err := validateBasicAuth(cFail, cfg); err == nil {
t.Errorf("non-matching basic: err = nil, want error")
}
}
// TestValidateJWTAuth_NoSecret rejects empty secret config.
func TestValidateJWTAuth_NoSecret(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})
err := validateJWTAuth(c, map[string]any{"jwt": map[string]any{}})
if err == nil || !strings.Contains(err.Error(), "secret not configured") {
t.Errorf("err = %v, want 'secret not configured'", err)
}
}
// TestValidateJWTAuth_MissingBearer rejects when Authorization is absent.
func TestValidateJWTAuth_MissingBearer(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", nil)
err := validateJWTAuth(c, map[string]any{
"jwt": map[string]any{"secret": "s"},
})
if err == nil || !strings.Contains(err.Error(), "missing bearer token") {
t.Errorf("err = %v, want 'missing bearer token'", err)
}
}
// TestValidateJWTAuth_RS256HappyPath confirms that an RS256 token
// signed with a real RSA key is accepted when the secret field holds
// the matching PEM-encoded public key. The python reference uses the
// same `secret` slot for both HMAC and RSA inputs; this test pins that
// contract for the Go port.
func TestValidateJWTAuth_RS256HappyPath(t *testing.T) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("rsa.GenerateKey: %v", err)
}
pubPEM := encodeRSAPublicKeyPEM(&priv.PublicKey)
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"sub": "u-1",
"exp": time.Now().Add(time.Hour).Unix(),
})
signed, err := token.SignedString(priv)
if err != nil {
t.Fatalf("sign: %v", err)
}
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer " + signed})
err = validateJWTAuth(c, map[string]any{
"jwt": map[string]any{
"secret": string(pubPEM),
"algorithm": "RS256",
},
})
if err != nil {
t.Errorf("RS256 happy path: err = %v, want nil", err)
}
}
// TestValidateJWTAuth_RS256BadPEM rejects a malformed public key.
func TestValidateJWTAuth_RS256BadPEM(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})
err := validateJWTAuth(c, map[string]any{
"jwt": map[string]any{
"secret": "not a pem block",
"algorithm": "RS256",
},
})
if err == nil || !strings.Contains(err.Error(), "rsa public key") {
t.Errorf("err = %v, want 'rsa public key' parse error", err)
}
}
// TestValidateJWTAuth_UnsupportedAlgorithm rejects truly unknown algos.
func TestValidateJWTAuth_UnsupportedAlgorithm(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})
err := validateJWTAuth(c, map[string]any{
"jwt": map[string]any{
"secret": "s",
"algorithm": "none",
},
})
if err == nil || !strings.Contains(err.Error(), "unsupported") {
t.Errorf("err = %v, want 'unsupported'", err)
}
}
// encodeRSAPublicKeyPEM serialises an *rsa.PublicKey into the
// PEM-encoded PKIX form that jwt.ParseRSAPublicKeyFromPEM expects.
func encodeRSAPublicKeyPEM(pub *rsa.PublicKey) []byte {
asn1, err := x509.MarshalPKIXPublicKey(pub)
if err != nil {
panic(err)
}
block := &pem.Block{Type: "PUBLIC KEY", Bytes: asn1}
return pem.EncodeToMemory(block)
}
// TestValidateJWTAuth_ReservedClaimRejected covers the reserved-claim
// guard. We build a valid HS256 JWT first so the parse succeeds, then
// ask for `exp` as a required claim — the validator must reject it.
func TestValidateJWTAuth_ReservedClaimRejected(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": "u-1",
"exp": time.Now().Add(time.Hour).Unix(),
})
signed, err := token.SignedString([]byte("s"))
if err != nil {
t.Fatalf("sign: %v", err)
}
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer " + signed})
err = validateJWTAuth(c, map[string]any{
"jwt": map[string]any{
"secret": "s",
"required_claims": "exp",
},
})
if err == nil || !strings.Contains(err.Error(), "reserved jwt claim") {
t.Errorf("err = %v, want 'reserved jwt claim'", err)
}
}
// TestValidateRateLimit_NoConfig covers the no-rate-limit branch.
func TestValidateRateLimit_NoConfig(t *testing.T) {
if err := validateRateLimit("c1", map[string]any{}); err != nil {
t.Errorf("no rate_limit: err = %v, want nil", err)
}
}
// TestValidateRateLimit_BadPer rejects unknown per window.
func TestValidateRateLimit_BadPer(t *testing.T) {
err := validateRateLimit("c1", map[string]any{
"rate_limit": map[string]any{"limit": 10, "per": "week"},
})
if err == nil || !strings.Contains(err.Error(), "invalid rate_limit.per") {
t.Errorf("err = %v, want 'invalid rate_limit.per'", err)
}
}
// TestValidateRateLimit_BadLimit rejects non-positive limits.
func TestValidateRateLimit_BadLimit(t *testing.T) {
err := validateRateLimit("c1", map[string]any{
"rate_limit": map[string]any{"limit": 0, "per": "minute"},
})
if err == nil || !strings.Contains(err.Error(), "must be > 0") {
t.Errorf("err = %v, want 'must be > 0'", err)
}
}
// (No helper needed at the bottom of this file; helper functions
// inline above.)
// _securityUnused previously lived here as a placeholder; deleted
// during cleanup (code-review MEDIUM-2).
// TestValidateMaxBodySize_OverflowGuard covers CodeRabbit PR review
// #3: a configured n that would overflow n*bytesPerMB (e.g. a huge
// mb value) must be rejected before the multiplication, not
// silently wrap to a small number and pass the cap check.
//
// The chosen n=9_223_372_036_855 multiplied by 1_000_000 (= 1 MB)
// overflows int64 max (9.22e18), so without the pre-multiplication
// guard this would silently wrap to a small positive number and
// the cap check would (incorrectly) succeed. With the guard in
// place, the rejection happens at the parse step, before any
// multiplication.
func TestValidateMaxBodySize_OverflowGuard(t *testing.T) {
c := securityCtx(t, "1.2.3.4:0", nil)
err := validateMaxBodySize(c, map[string]any{"max_body_size": "9223372036855mb"})
if err == nil {
t.Errorf("huge mb value: err = nil, want overflow-rejection error")
} else if !strings.Contains(err.Error(), "exceeds maximum") {
t.Errorf("err = %v, want 'exceeds maximum'", err)
}
}
// TestParseMaxBodySize_DecimalUnits documents the per-user-request
// SI-decimal unit base: 1 kb = 1 000 B, 1 mb = 1 000 000 B.
// (The python reference uses 1 024 / 1 048 576.)
func TestParseMaxBodySize_DecimalUnits(t *testing.T) {
cases := []struct {
in string
want int64
}{
{"1kb", 1000},
{"5kb", 5000},
{"1mb", 1_000_000},
{"10mb", 10_000_000}, // exact cap
}
for _, tc := range cases {
got, err := parseMaxBodySize(map[string]any{"max_body_size": tc.in})
if err != nil {
t.Errorf("parseMaxBodySize(%q): err = %v, want nil", tc.in, err)
continue
}
if got != tc.want {
t.Errorf("parseMaxBodySize(%q) = %d, want %d", tc.in, got, tc.want)
}
}
}
// TestValidateTokenAuth_EmptyValueRejected covers CodeRabbit PR
// review #4: an empty configured token_value used to mean "accept
// any request without that header". Now it must be rejected.
func TestValidateTokenAuth_EmptyValueRejected(t *testing.T) {
cfg := map[string]any{
"token": map[string]any{
"token_header": "X-Token",
"token_value": "",
},
}
c := securityCtx(t, "1.2.3.4:0", nil) // no header at all
if err := validateTokenAuth(c, cfg); err == nil {
t.Errorf("empty token_value: err = nil, want error")
}
}