Fix: JWT algorithm-confusion in OIDC ID token verification (#15181)

### What problem does this PR solve?

Closes #15180.

`OIDCClient.parse_id_token` in `api/apps/auth/oidc.py` read the JWT
signing
algorithm from the **unverified** JWT header and passed it through to
`jwt.decode(..., algorithms=[alg], ...)` as the trust anchor. This is
the
textbook JWT algorithm-confusion vulnerability (CWE-345 / CWE-347). Any
unauthenticated client capable of reaching the OIDC callback could take
over
an arbitrary account on any RAGFlow deployment with OIDC login enabled:

1. **`alg: "none"`** — present a JWT with `{"alg": "none"}` and no
   signature segment → `jwt.decode(..., algorithms=["none"])` → PyJWT's
   `NoneAlgorithm` accepts the token without verification → login as any
   user.
2. **RSA / HMAC confusion** — fetch the public RSA key from the
provider's
   JWKS (it's public), forge a JWT with `{"alg": "HS256"}` HMAC-signed
   using the public-key bytes as the secret → `jwt.decode(...,
   algorithms=["HS256"], key=public_key)` → verifier accepts → login as
   any user. (Modern PyJWT independently refuses to use a PEM-formatted
   key as an HMAC secret, which mitigates this leg for PEM key formats;
the fix here is the only mitigation for raw / DER / JWK octet keys and
   for older PyJWT versions.)

### What changed

**`api/apps/auth/oidc.py`:**

- New module constants `_ALLOWED_OIDC_SIGNING_ALGS` (asymmetric-only:
  `RS*`, `ES*`, `PS*`, `EdDSA` — explicitly excludes `none` and `HS*`)
  and `_DEFAULT_OIDC_SIGNING_ALGS = ("RS256",)` (the OIDC Core 1.0 §2
  spec default).
- New helper `_resolve_id_token_signing_algs(metadata)` — intersects the
  provider's advertised `id_token_signing_alg_values_supported` from
`/.well-known/openid-configuration` with the safe allowlist; falls back
  to RS256 when the field is missing or contains only unsafe values.
- `OIDCClient.__init__` now stores the resolved allowlist on
  `self.id_token_signing_algs` — pinned once, from a trusted source, at
  construction time.
- `parse_id_token` no longer calls `jwt.get_unverified_header` and no
  longer reads `alg` from the JWT header. It passes
  `self.id_token_signing_algs` to `jwt.decode(..., algorithms=...)`.
  `PyJWKClient.get_signing_key_from_jwt` still reads the `kid` from the
  header internally for JWKS lookup — that's fine, `kid` is not a
  security decision; the signature still proves which key was actually
  used.


**`test/testcases/test_web_api/test_auth_app/test_oidc_client_unit.py`:**

- Existing `test_parse_id_token_success_and_error` drops its
`jwt.get_unverified_header` mock (no longer called by `parse_id_token`).
- `_metadata` and `_make_client` helpers grew an optional `signing_algs`
parameter so tests can configure what the discovery document advertises.
- New `TestSSRFValidation` / algorithm-confusion regression block (7
  tests):
  - `test_id_token_signing_algs_default_to_rs256_when_metadata_missing`
  - `test_id_token_signing_algs_intersect_metadata_with_safe_allowlist`
  - `test_id_token_signing_algs_fall_back_when_only_unsafe_advertised`
  - `test_id_token_signing_algs_ignores_non_string_entries`
  - `test_id_token_signing_algs_handles_non_list_metadata_field`
  - `test_parse_id_token_passes_pinned_algorithms_to_jwt_decode` —
    sabotages `jwt.get_unverified_header` to raise on call, proving the
    verification path never consults the unverified header.
- `test_parse_id_token_rejects_alg_none` — uses real PyJWT to encode an
    `alg: "none"` token; `parse_id_token` raises `ValueError("Error
    parsing ID Token: …")` instead of accepting it.
  - `test_parse_id_token_rejects_hs256_when_allowlist_is_asymmetric` —
    uses real PyJWT to forge an `alg: "HS256"` token with a non-PEM
    shared secret (so PyJWT's incidental PEM-as-HMAC refusal isn't what
    blocks it); `parse_id_token` raises because `HS256` is not in the
    pinned allowlist.

Sanity-checked end-to-end with real PyJWT outside the project test
runner:

- `alg=none` forged token + `algorithms=["RS256"]` →
`InvalidAlgorithmError` ✓
- `alg=HS256` forged token + `algorithms=["RS256"]` →
`InvalidAlgorithmError` ✓
- Same `alg=HS256` token + `algorithms=["HS256"]` → **accepted**
({'sub': 'admin'})
  — confirming the attack path was real before the fix.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: galuis116 <contact@duerrimports.com>
This commit is contained in:
galuis116
2026-05-29 04:37:01 -07:00
committed by GitHub
parent cb1ea5a47f
commit d1f6594618
2 changed files with 260 additions and 16 deletions

View File

@@ -19,6 +19,45 @@ from common.http_client import sync_request
from .oauth import OAuthClient
# Asymmetric signing algorithms safe to accept for OIDC ID tokens.
# Symmetric HMAC algorithms (HS*) are intentionally excluded — when the
# verification key is the asymmetric public key fetched from the provider's
# JWKS (as it is for every OIDC ID token), accepting HS256 lets an attacker
# forge tokens by HMAC-signing them with the public key bytes
# (RSA/HMAC algorithm-confusion attack, CWE-347). "none" is excluded for the
# obvious reason that it disables signature verification entirely.
_ALLOWED_OIDC_SIGNING_ALGS = frozenset({
"RS256", "RS384", "RS512",
"ES256", "ES384", "ES512",
"PS256", "PS384", "PS512",
"EdDSA",
})
# OIDC Core 1.0 § 2 makes RS256 the spec-default ``id_token_signing_alg``,
# so this is the safe fallback when a provider's discovery document does not
# advertise ``id_token_signing_alg_values_supported`` (or advertises only
# algorithms outside the safe allowlist).
_DEFAULT_OIDC_SIGNING_ALGS = ("RS256",)
def _resolve_id_token_signing_algs(metadata):
"""Return the algorithms to pass to ``jwt.decode(..., algorithms=...)``.
Intersects the provider-advertised
``id_token_signing_alg_values_supported`` with
:data:`_ALLOWED_OIDC_SIGNING_ALGS`. Falls back to
:data:`_DEFAULT_OIDC_SIGNING_ALGS` when the provider does not advertise
the field or advertises only algorithms outside the safe allowlist —
crucially, the fallback is to RS256, **never** to whatever the JWT
header claims at verification time.
"""
advertised = metadata.get("id_token_signing_alg_values_supported") or []
if not isinstance(advertised, (list, tuple)):
advertised = []
safe = [a for a in advertised if isinstance(a, str) and a in _ALLOWED_OIDC_SIGNING_ALGS]
return safe or list(_DEFAULT_OIDC_SIGNING_ALGS)
class OIDCClient(OAuthClient):
def __init__(self, config):
"""
@@ -32,7 +71,7 @@ class OIDCClient(OAuthClient):
oidc_metadata = self._load_oidc_metadata(self.issuer)
config.update({
'issuer': oidc_metadata['issuer'],
'jwks_uri': oidc_metadata['jwks_uri'],
'jwks_uri': oidc_metadata['jwks_uri'],
'authorization_url': oidc_metadata['authorization_endpoint'],
'token_url': oidc_metadata['token_endpoint'],
'userinfo_url': oidc_metadata['userinfo_endpoint']
@@ -41,6 +80,11 @@ class OIDCClient(OAuthClient):
super().__init__(config)
self.issuer = config['issuer']
self.jwks_uri = config['jwks_uri']
# Pin the accepted ID-token signing algorithms at construction time
# from a trusted source (provider metadata + safe allowlist) so the
# JWT verification step in :meth:`parse_id_token` cannot be tricked
# by attacker-controlled JWT headers (CWE-345 / CWE-347).
self.id_token_signing_algs = _resolve_id_token_signing_algs(oidc_metadata)
@staticmethod
@@ -60,23 +104,29 @@ class OIDCClient(OAuthClient):
def parse_id_token(self, id_token):
"""
Parse and validate OIDC ID Token (JWT format) with signature verification.
The accepted signing algorithms come from ``self.id_token_signing_algs``
(pinned at construction time from the provider's discovery metadata,
intersected with :data:`_ALLOWED_OIDC_SIGNING_ALGS`). We deliberately
do **not** read the algorithm from the unverified JWT header — doing
so would let an attacker bypass signature verification by setting
``"alg": "none"`` or pull off the classic RSA / HMAC algorithm
confusion by setting ``"alg": "HS256"`` and signing with the public
key fetched from the provider's JWKS (CWE-345 / CWE-347).
"""
try:
# Decode JWT header without verifying signature
headers = jwt.get_unverified_header(id_token)
# OIDC usually uses `RS256` for signing
alg = headers.get("alg", "RS256")
# Use PyJWT's PyJWKClient to fetch JWKS and find signing key
# Use PyJWT's PyJWKClient to fetch JWKS and find signing key.
# The client reads the ``kid`` from the JWT header internally to
# look up the key — that's fine: ``kid`` is not a security
# decision, the signature still proves which key was used.
jwks_cli = jwt.PyJWKClient(self.jwks_uri)
signing_key = jwks_cli.get_signing_key_from_jwt(id_token).key
# Decode and verify signature
# Decode and verify signature against the pinned allowlist.
decoded_token = jwt.decode(
id_token,
key=signing_key,
algorithms=[alg],
algorithms=list(self.id_token_signing_algs),
audience=str(self.client_id),
issuer=self.issuer,
)

View File

@@ -127,18 +127,25 @@ def _base_config():
}
def _metadata(issuer):
return {
def _metadata(issuer, signing_algs=None):
md = {
"issuer": issuer,
"jwks_uri": f"{issuer}/jwks",
"authorization_endpoint": f"{issuer}/authorize",
"token_endpoint": f"{issuer}/token",
"userinfo_endpoint": f"{issuer}/userinfo",
}
if signing_algs is not None:
md["id_token_signing_alg_values_supported"] = signing_algs
return md
def _make_client(monkeypatch, oidc_module):
monkeypatch.setattr(oidc_module.OIDCClient, "_load_oidc_metadata", staticmethod(lambda issuer: _metadata(issuer)))
def _make_client(monkeypatch, oidc_module, signing_algs=None):
monkeypatch.setattr(
oidc_module.OIDCClient,
"_load_oidc_metadata",
staticmethod(lambda issuer: _metadata(issuer, signing_algs=signing_algs)),
)
return oidc_module.OIDCClient(_base_config())
@@ -199,8 +206,6 @@ def test_parse_id_token_success_and_error(monkeypatch):
_, oidc_module = _load_auth_modules(monkeypatch)
client = _make_client(monkeypatch, oidc_module)
monkeypatch.setattr(oidc_module.jwt, "get_unverified_header", lambda _token: {})
seen = {}
class _JwkClient(_DummyJwkClient):
@@ -245,6 +250,195 @@ def test_parse_id_token_success_and_error(monkeypatch):
assert str(exc_info.value) == "Error parsing ID Token: decode boom"
# ===================================================================== #
# JWT algorithm-confusion regression tests #
# #
# Before the fix, ``parse_id_token`` read the signing algorithm from #
# the unverified JWT header. An attacker who presents a JWT with #
# ``"alg": "none"`` would have signature verification disabled, and an #
# attacker who presents ``"alg": "HS256"`` and signs the JWT with the #
# public key bytes (RSA / HMAC confusion) would get the forged token #
# accepted. The tests below pin the contract: #
# #
# - the algorithm allowlist is pinned at construction time from the #
# provider's discovery metadata intersected with the safe allowlist #
# - the JWT header's ``alg`` claim is never read at decode time #
# ===================================================================== #
@pytest.mark.p2
def test_id_token_signing_algs_default_to_rs256_when_metadata_missing(monkeypatch):
"""No ``id_token_signing_alg_values_supported`` in metadata → RS256 only.
Crucially the fallback is RS256, never whatever the JWT header claims.
"""
_, oidc_module = _load_auth_modules(monkeypatch)
client = _make_client(monkeypatch, oidc_module)
assert client.id_token_signing_algs == ["RS256"]
@pytest.mark.p2
def test_id_token_signing_algs_intersect_metadata_with_safe_allowlist(monkeypatch):
"""Metadata advertises a mix of safe and unsafe algs — only safe ones kept."""
_, oidc_module = _load_auth_modules(monkeypatch)
client = _make_client(
monkeypatch,
oidc_module,
signing_algs=["RS256", "ES256", "HS256", "none", "PS512"],
)
assert set(client.id_token_signing_algs) == {"RS256", "ES256", "PS512"}
# The dangerous algorithms must not appear in the verification allowlist.
assert "HS256" not in client.id_token_signing_algs
assert "none" not in client.id_token_signing_algs
@pytest.mark.p2
def test_id_token_signing_algs_fall_back_when_only_unsafe_advertised(monkeypatch):
"""Provider advertises only HS256 / none → fall back to RS256, do not trust."""
_, oidc_module = _load_auth_modules(monkeypatch)
client = _make_client(
monkeypatch,
oidc_module,
signing_algs=["HS256", "none", "bogus"],
)
assert client.id_token_signing_algs == ["RS256"]
@pytest.mark.p2
def test_id_token_signing_algs_ignores_non_string_entries(monkeypatch):
"""Malformed entries (None / dict / int) are filtered out, not crashed on."""
_, oidc_module = _load_auth_modules(monkeypatch)
client = _make_client(
monkeypatch,
oidc_module,
signing_algs=["RS256", None, 42, {"x": 1}, "ES384"],
)
assert set(client.id_token_signing_algs) == {"RS256", "ES384"}
@pytest.mark.p2
def test_id_token_signing_algs_handles_non_list_metadata_field(monkeypatch):
"""If metadata gives a non-list type for the field, fall back to default."""
_, oidc_module = _load_auth_modules(monkeypatch)
client = _make_client(monkeypatch, oidc_module, signing_algs="RS256")
assert client.id_token_signing_algs == ["RS256"]
@pytest.mark.p2
def test_parse_id_token_passes_pinned_algorithms_to_jwt_decode(monkeypatch):
"""``jwt.decode`` receives the pinned allowlist, regardless of JWT header."""
_, oidc_module = _load_auth_modules(monkeypatch)
client = _make_client(monkeypatch, oidc_module, signing_algs=["RS256", "ES256"])
# Even if the unverified header claims something dangerous, the
# verification path must not consult it. We sabotage
# ``jwt.get_unverified_header`` to prove the code never calls it.
def _explode(_token): # pragma: no cover - must not be called
raise AssertionError(
"parse_id_token must not read the algorithm from the unverified JWT header"
)
monkeypatch.setattr(oidc_module.jwt, "get_unverified_header", _explode)
monkeypatch.setattr(oidc_module.jwt, "PyJWKClient", _DummyJwkClient)
seen = {}
def _decode(id_token, key, algorithms, audience, issuer):
seen["algorithms"] = list(algorithms)
return {"sub": "user-2"}
monkeypatch.setattr(oidc_module.jwt, "decode", _decode)
client.parse_id_token("malicious-header-token")
assert set(seen["algorithms"]) == {"RS256", "ES256"}
# Hard-stop: dangerous algorithms must never reach ``jwt.decode``.
assert "none" not in seen["algorithms"]
assert "HS256" not in seen["algorithms"]
@pytest.mark.p2
def test_parse_id_token_rejects_alg_none(monkeypatch):
"""End-to-end: an ``alg: "none"`` JWT must not authenticate.
Uses the real PyJWT decoder so the test exercises the actual contract
between ``parse_id_token`` and the upstream library.
"""
import jwt as real_jwt
_, oidc_module = _load_auth_modules(monkeypatch)
client = _make_client(monkeypatch, oidc_module) # defaults to RS256
# PyJWT requires explicit opt-in to encode ``alg=none``; even then it
# produces a token with no signature segment.
forged = real_jwt.encode(
{
"sub": "victim-subject",
"email": "admin@target.example",
"aud": "client-1",
"iss": "https://issuer.example",
},
key="",
algorithm="none",
)
# Force the JWKS step into a no-op so we exercise *just* the alg gate.
monkeypatch.setattr(oidc_module.jwt, "PyJWKClient", _DummyJwkClient)
with pytest.raises(ValueError) as exc_info:
client.parse_id_token(forged)
assert "Error parsing ID Token" in str(exc_info.value)
@pytest.mark.p2
def test_parse_id_token_rejects_hs256_when_allowlist_is_asymmetric(monkeypatch):
"""End-to-end: a JWT whose header claims ``alg: HS256`` must not be
accepted when the pinned allowlist is asymmetric-only.
This is the algorithm half of the RSA / HMAC confusion attack
(CWE-347). The attacker forges a JWT with ``"alg": "HS256"`` so the
server picks the HMAC verifier; pre-fix the server would read that alg
straight from the header and call
``jwt.decode(..., algorithms=["HS256"], key=public_key)`` which lets
the attacker forge tokens with the public key bytes. After the fix the
allowlist pinned at construction time wins — HS* is never in it — so
PyJWT raises ``InvalidAlgorithmError`` before the HMAC verifier is
ever invoked.
Note: modern PyJWT (>=2.0) also independently refuses to use a
PEM-encoded key as an HMAC secret, so the public-key-bytes leg of the
full attack is partially mitigated at the library level. The fix here
is defense in depth and the only mitigation for non-PEM key formats
(raw bytes, DER, JWK octet keys, older PyJWT versions).
"""
import jwt as real_jwt
_, oidc_module = _load_auth_modules(monkeypatch)
client = _make_client(monkeypatch, oidc_module) # defaults to RS256
# Use a non-PEM byte string so we exercise the alg gate (not PyJWT's
# incidental PEM-as-HMAC-secret refusal).
shared_secret = b"shared-secret-bytes-not-a-pem-key"
forged = real_jwt.encode(
{
"sub": "victim-subject",
"email": "admin@target.example",
"aud": "client-1",
"iss": "https://issuer.example",
},
key=shared_secret,
algorithm="HS256",
)
class _SecretJwkClient(_DummyJwkClient):
def get_signing_key_from_jwt(self, _id_token):
return SimpleNamespace(key=shared_secret)
monkeypatch.setattr(oidc_module.jwt, "PyJWKClient", _SecretJwkClient)
with pytest.raises(ValueError) as exc_info:
client.parse_id_token(forged)
assert "Error parsing ID Token" in str(exc_info.value)
@pytest.mark.p2
def test_fetch_user_info_merges_id_token_and_oauth_userinfo(monkeypatch):
oauth_module, oidc_module = _load_auth_modules(monkeypatch)