mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
port agent webhook trigger, agent file upload/download, component input-form + debug endpoints from Python - [x] New Feature (non-breaking change which adds functionality)
381 lines
13 KiB
Go
381 lines
13 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package handler
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
func securityCtx(t *testing.T, remoteAddr string, headers map[string]string) *gin.Context {
|
|
t.Helper()
|
|
gin.SetMode(gin.TestMode)
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = httptest.NewRequest("POST", "/api/v1/agents/c1/webhook", strings.NewReader("{}"))
|
|
c.Request.Header.Set("Content-Type", "application/json")
|
|
c.Request.RemoteAddr = remoteAddr
|
|
for k, v := range headers {
|
|
c.Request.Header.Set(k, v)
|
|
}
|
|
return c
|
|
}
|
|
|
|
// TestValidateMaxBodySize_Allowed covers the no-op branch.
|
|
func TestValidateMaxBodySize_Allowed(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", nil)
|
|
c.Request.ContentLength = 100
|
|
if err := validateMaxBodySize(c, map[string]any{"max_body_size": "1kb"}); err != nil {
|
|
t.Errorf("err = %v, want nil", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateMaxBodySize_TooLarge covers the size-mismatch branch.
|
|
func TestValidateMaxBodySize_TooLarge(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", nil)
|
|
c.Request.ContentLength = 2048
|
|
err := validateMaxBodySize(c, map[string]any{"max_body_size": "1kb"})
|
|
if err == nil || !strings.Contains(err.Error(), "request body too large") {
|
|
t.Errorf("err = %v, want 'request body too large'", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateMaxBodySize_BadFormat covers non-numeric format.
|
|
func TestValidateMaxBodySize_BadFormat(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", nil)
|
|
err := validateMaxBodySize(c, map[string]any{"max_body_size": "1gb"})
|
|
if err == nil || !strings.Contains(err.Error(), "invalid max_body_size format") {
|
|
t.Errorf("err = %v, want 'invalid max_body_size format'", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateIPWhitelist_EmptyIsAllow covers the empty-list branch.
|
|
func TestValidateIPWhitelist_EmptyIsAllow(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", nil)
|
|
if err := validateIPWhitelist(c, map[string]any{"ip_whitelist": []any{}}); err != nil {
|
|
t.Errorf("empty whitelist: err = %v, want nil", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateIPWhitelist_ExactMatch passes when client IP matches.
|
|
func TestValidateIPWhitelist_ExactMatch(t *testing.T) {
|
|
c := securityCtx(t, "10.0.0.5:0", nil)
|
|
cfg := map[string]any{"ip_whitelist": []any{"10.0.0.5"}}
|
|
if err := validateIPWhitelist(c, cfg); err != nil {
|
|
t.Errorf("exact match: err = %v, want nil", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateIPWhitelist_CIDR covers the CIDR branch.
|
|
func TestValidateIPWhitelist_CIDR(t *testing.T) {
|
|
c := securityCtx(t, "10.0.0.5:0", nil)
|
|
cfg := map[string]any{"ip_whitelist": []any{"10.0.0.0/8"}}
|
|
if err := validateIPWhitelist(c, cfg); err != nil {
|
|
t.Errorf("cidr match: err = %v, want nil", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateIPWhitelist_RejectForeign confirms a foreign IP is denied.
|
|
func TestValidateIPWhitelist_RejectForeign(t *testing.T) {
|
|
c := securityCtx(t, "192.168.1.5:0", nil)
|
|
cfg := map[string]any{"ip_whitelist": []any{"10.0.0.0/8"}}
|
|
err := validateIPWhitelist(c, cfg)
|
|
if err == nil || !strings.Contains(err.Error(), "not allowed by whitelist") {
|
|
t.Errorf("err = %v, want 'not allowed by whitelist'", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateAuth_NoneIsAllow covers the auth_type=="none" no-op.
|
|
func TestValidateAuth_NoneIsAllow(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", nil)
|
|
if err := validateAuth(c, map[string]any{"auth_type": "none"}); err != nil {
|
|
t.Errorf("auth_type=none: err = %v, want nil", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateAuth_Unsupported covers unknown auth_type.
|
|
func TestValidateAuth_Unsupported(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", nil)
|
|
err := validateAuth(c, map[string]any{"auth_type": "weird"})
|
|
if err == nil || !strings.Contains(err.Error(), "unsupported auth_type") {
|
|
t.Errorf("err = %v, want 'unsupported auth_type'", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateTokenAuth_HeaderValue covers matching and non-matching cases.
|
|
func TestValidateTokenAuth_HeaderValue(t *testing.T) {
|
|
cfg := map[string]any{
|
|
"token": map[string]any{
|
|
"token_header": "X-Token",
|
|
"token_value": "abc",
|
|
},
|
|
}
|
|
|
|
// pass
|
|
cPass := securityCtx(t, "1.2.3.4:0", map[string]string{"X-Token": "abc"})
|
|
if err := validateTokenAuth(cPass, cfg); err != nil {
|
|
t.Errorf("matching token: err = %v, want nil", err)
|
|
}
|
|
|
|
// fail
|
|
cFail := securityCtx(t, "1.2.3.4:0", map[string]string{"X-Token": "wrong"})
|
|
if err := validateTokenAuth(cFail, cfg); err == nil {
|
|
t.Errorf("non-matching token: err = nil, want error")
|
|
}
|
|
}
|
|
|
|
// TestValidateBasicAuth_PassAndFail covers both branches.
|
|
func TestValidateBasicAuth_PassAndFail(t *testing.T) {
|
|
cfg := map[string]any{
|
|
"basic_auth": map[string]any{
|
|
"username": "alice",
|
|
"password": "wonderland",
|
|
},
|
|
}
|
|
|
|
cPass := securityCtx(t, "1.2.3.4:0", nil)
|
|
cPass.Request.SetBasicAuth("alice", "wonderland")
|
|
if err := validateBasicAuth(cPass, cfg); err != nil {
|
|
t.Errorf("matching basic: err = %v, want nil", err)
|
|
}
|
|
|
|
cFail := securityCtx(t, "1.2.3.4:0", nil)
|
|
cFail.Request.SetBasicAuth("alice", "wrong")
|
|
if err := validateBasicAuth(cFail, cfg); err == nil {
|
|
t.Errorf("non-matching basic: err = nil, want error")
|
|
}
|
|
}
|
|
|
|
// TestValidateJWTAuth_NoSecret rejects empty secret config.
|
|
func TestValidateJWTAuth_NoSecret(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})
|
|
err := validateJWTAuth(c, map[string]any{"jwt": map[string]any{}})
|
|
if err == nil || !strings.Contains(err.Error(), "secret not configured") {
|
|
t.Errorf("err = %v, want 'secret not configured'", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateJWTAuth_MissingBearer rejects when Authorization is absent.
|
|
func TestValidateJWTAuth_MissingBearer(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", nil)
|
|
err := validateJWTAuth(c, map[string]any{
|
|
"jwt": map[string]any{"secret": "s"},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "missing bearer token") {
|
|
t.Errorf("err = %v, want 'missing bearer token'", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateJWTAuth_RS256HappyPath confirms that an RS256 token
|
|
// signed with a real RSA key is accepted when the secret field holds
|
|
// the matching PEM-encoded public key. The python reference uses the
|
|
// same `secret` slot for both HMAC and RSA inputs; this test pins that
|
|
// contract for the Go port.
|
|
func TestValidateJWTAuth_RS256HappyPath(t *testing.T) {
|
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("rsa.GenerateKey: %v", err)
|
|
}
|
|
pubPEM := encodeRSAPublicKeyPEM(&priv.PublicKey)
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
|
"sub": "u-1",
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
})
|
|
signed, err := token.SignedString(priv)
|
|
if err != nil {
|
|
t.Fatalf("sign: %v", err)
|
|
}
|
|
|
|
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer " + signed})
|
|
err = validateJWTAuth(c, map[string]any{
|
|
"jwt": map[string]any{
|
|
"secret": string(pubPEM),
|
|
"algorithm": "RS256",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Errorf("RS256 happy path: err = %v, want nil", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateJWTAuth_RS256BadPEM rejects a malformed public key.
|
|
func TestValidateJWTAuth_RS256BadPEM(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})
|
|
err := validateJWTAuth(c, map[string]any{
|
|
"jwt": map[string]any{
|
|
"secret": "not a pem block",
|
|
"algorithm": "RS256",
|
|
},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "rsa public key") {
|
|
t.Errorf("err = %v, want 'rsa public key' parse error", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateJWTAuth_UnsupportedAlgorithm rejects truly unknown algos.
|
|
func TestValidateJWTAuth_UnsupportedAlgorithm(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer x"})
|
|
err := validateJWTAuth(c, map[string]any{
|
|
"jwt": map[string]any{
|
|
"secret": "s",
|
|
"algorithm": "none",
|
|
},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "unsupported") {
|
|
t.Errorf("err = %v, want 'unsupported'", err)
|
|
}
|
|
}
|
|
|
|
// encodeRSAPublicKeyPEM serialises an *rsa.PublicKey into the
|
|
// PEM-encoded PKIX form that jwt.ParseRSAPublicKeyFromPEM expects.
|
|
func encodeRSAPublicKeyPEM(pub *rsa.PublicKey) []byte {
|
|
asn1, err := x509.MarshalPKIXPublicKey(pub)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
block := &pem.Block{Type: "PUBLIC KEY", Bytes: asn1}
|
|
return pem.EncodeToMemory(block)
|
|
}
|
|
|
|
// TestValidateJWTAuth_ReservedClaimRejected covers the reserved-claim
|
|
// guard. We build a valid HS256 JWT first so the parse succeeds, then
|
|
// ask for `exp` as a required claim — the validator must reject it.
|
|
func TestValidateJWTAuth_ReservedClaimRejected(t *testing.T) {
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"sub": "u-1",
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
})
|
|
signed, err := token.SignedString([]byte("s"))
|
|
if err != nil {
|
|
t.Fatalf("sign: %v", err)
|
|
}
|
|
c := securityCtx(t, "1.2.3.4:0", map[string]string{"Authorization": "Bearer " + signed})
|
|
err = validateJWTAuth(c, map[string]any{
|
|
"jwt": map[string]any{
|
|
"secret": "s",
|
|
"required_claims": "exp",
|
|
},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "reserved jwt claim") {
|
|
t.Errorf("err = %v, want 'reserved jwt claim'", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateRateLimit_NoConfig covers the no-rate-limit branch.
|
|
func TestValidateRateLimit_NoConfig(t *testing.T) {
|
|
if err := validateRateLimit("c1", map[string]any{}); err != nil {
|
|
t.Errorf("no rate_limit: err = %v, want nil", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateRateLimit_BadPer rejects unknown per window.
|
|
func TestValidateRateLimit_BadPer(t *testing.T) {
|
|
err := validateRateLimit("c1", map[string]any{
|
|
"rate_limit": map[string]any{"limit": 10, "per": "week"},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "invalid rate_limit.per") {
|
|
t.Errorf("err = %v, want 'invalid rate_limit.per'", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateRateLimit_BadLimit rejects non-positive limits.
|
|
func TestValidateRateLimit_BadLimit(t *testing.T) {
|
|
err := validateRateLimit("c1", map[string]any{
|
|
"rate_limit": map[string]any{"limit": 0, "per": "minute"},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "must be > 0") {
|
|
t.Errorf("err = %v, want 'must be > 0'", err)
|
|
}
|
|
}
|
|
|
|
// (No helper needed at the bottom of this file; helper functions
|
|
// inline above.)
|
|
// _securityUnused previously lived here as a placeholder; deleted
|
|
// during cleanup (code-review MEDIUM-2).
|
|
|
|
// TestValidateMaxBodySize_OverflowGuard covers CodeRabbit PR review
|
|
// #3: a configured n that would overflow n*bytesPerMB (e.g. a huge
|
|
// mb value) must be rejected before the multiplication, not
|
|
// silently wrap to a small number and pass the cap check.
|
|
//
|
|
// The chosen n=9_223_372_036_855 multiplied by 1_000_000 (= 1 MB)
|
|
// overflows int64 max (9.22e18), so without the pre-multiplication
|
|
// guard this would silently wrap to a small positive number and
|
|
// the cap check would (incorrectly) succeed. With the guard in
|
|
// place, the rejection happens at the parse step, before any
|
|
// multiplication.
|
|
func TestValidateMaxBodySize_OverflowGuard(t *testing.T) {
|
|
c := securityCtx(t, "1.2.3.4:0", nil)
|
|
err := validateMaxBodySize(c, map[string]any{"max_body_size": "9223372036855mb"})
|
|
if err == nil {
|
|
t.Errorf("huge mb value: err = nil, want overflow-rejection error")
|
|
} else if !strings.Contains(err.Error(), "exceeds maximum") {
|
|
t.Errorf("err = %v, want 'exceeds maximum'", err)
|
|
}
|
|
}
|
|
|
|
// TestParseMaxBodySize_DecimalUnits documents the per-user-request
|
|
// SI-decimal unit base: 1 kb = 1 000 B, 1 mb = 1 000 000 B.
|
|
// (The python reference uses 1 024 / 1 048 576.)
|
|
func TestParseMaxBodySize_DecimalUnits(t *testing.T) {
|
|
cases := []struct {
|
|
in string
|
|
want int64
|
|
}{
|
|
{"1kb", 1000},
|
|
{"5kb", 5000},
|
|
{"1mb", 1_000_000},
|
|
{"10mb", 10_000_000}, // exact cap
|
|
}
|
|
for _, tc := range cases {
|
|
got, err := parseMaxBodySize(map[string]any{"max_body_size": tc.in})
|
|
if err != nil {
|
|
t.Errorf("parseMaxBodySize(%q): err = %v, want nil", tc.in, err)
|
|
continue
|
|
}
|
|
if got != tc.want {
|
|
t.Errorf("parseMaxBodySize(%q) = %d, want %d", tc.in, got, tc.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestValidateTokenAuth_EmptyValueRejected covers CodeRabbit PR
|
|
// review #4: an empty configured token_value used to mean "accept
|
|
// any request without that header". Now it must be rejected.
|
|
func TestValidateTokenAuth_EmptyValueRejected(t *testing.T) {
|
|
cfg := map[string]any{
|
|
"token": map[string]any{
|
|
"token_header": "X-Token",
|
|
"token_value": "",
|
|
},
|
|
}
|
|
c := securityCtx(t, "1.2.3.4:0", nil) // no header at all
|
|
if err := validateTokenAuth(c, cfg); err == nil {
|
|
t.Errorf("empty token_value: err = nil, want error")
|
|
}
|
|
}
|