From d1f6594618bf121f0463aa727495061b5c9bba1a Mon Sep 17 00:00:00 2001 From: galuis116 <116897328+galuis116@users.noreply.github.com> Date: Fri, 29 May 2026 04:37:01 -0700 Subject: [PATCH] Fix: JWT algorithm-confusion in OIDC ID token verification (#15181) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Closes #15180. `OIDCClient.parse_id_token` in `api/apps/auth/oidc.py` read the JWT signing algorithm from the **unverified** JWT header and passed it through to `jwt.decode(..., algorithms=[alg], ...)` as the trust anchor. This is the textbook JWT algorithm-confusion vulnerability (CWE-345 / CWE-347). Any unauthenticated client capable of reaching the OIDC callback could take over an arbitrary account on any RAGFlow deployment with OIDC login enabled: 1. **`alg: "none"`** — present a JWT with `{"alg": "none"}` and no signature segment → `jwt.decode(..., algorithms=["none"])` → PyJWT's `NoneAlgorithm` accepts the token without verification → login as any user. 2. **RSA / HMAC confusion** — fetch the public RSA key from the provider's JWKS (it's public), forge a JWT with `{"alg": "HS256"}` HMAC-signed using the public-key bytes as the secret → `jwt.decode(..., algorithms=["HS256"], key=public_key)` → verifier accepts → login as any user. (Modern PyJWT independently refuses to use a PEM-formatted key as an HMAC secret, which mitigates this leg for PEM key formats; the fix here is the only mitigation for raw / DER / JWK octet keys and for older PyJWT versions.) ### What changed **`api/apps/auth/oidc.py`:** - New module constants `_ALLOWED_OIDC_SIGNING_ALGS` (asymmetric-only: `RS*`, `ES*`, `PS*`, `EdDSA` — explicitly excludes `none` and `HS*`) and `_DEFAULT_OIDC_SIGNING_ALGS = ("RS256",)` (the OIDC Core 1.0 §2 spec default). 
- New helper `_resolve_id_token_signing_algs(metadata)` — intersects the provider's advertised `id_token_signing_alg_values_supported` from `/.well-known/openid-configuration` with the safe allowlist; falls back to RS256 when the field is missing or contains only unsafe values. - `OIDCClient.__init__` now stores the resolved allowlist on `self.id_token_signing_algs` — pinned once, from a trusted source, at construction time. - `parse_id_token` no longer calls `jwt.get_unverified_header` and no longer reads `alg` from the JWT header. It passes `self.id_token_signing_algs` to `jwt.decode(..., algorithms=...)`. `PyJWKClient.get_signing_key_from_jwt` still reads the `kid` from the header internally for JWKS lookup — that's fine, `kid` is not a security decision; the signature still proves which key was actually used. **`test/testcases/test_web_api/test_auth_app/test_oidc_client_unit.py`:** - Existing `test_parse_id_token_success_and_error` drops its `jwt.get_unverified_header` mock (no longer called by `parse_id_token`). - `_metadata` and `_make_client` helpers grew an optional `signing_algs` parameter so tests can configure what the discovery document advertises. - New algorithm-confusion regression block (8 tests): - `test_id_token_signing_algs_default_to_rs256_when_metadata_missing` - `test_id_token_signing_algs_intersect_metadata_with_safe_allowlist` - `test_id_token_signing_algs_fall_back_when_only_unsafe_advertised` - `test_id_token_signing_algs_ignores_non_string_entries` - `test_id_token_signing_algs_handles_non_list_metadata_field` - `test_parse_id_token_passes_pinned_algorithms_to_jwt_decode` — sabotages `jwt.get_unverified_header` to raise on call, proving the verification path never consults the unverified header. - `test_parse_id_token_rejects_alg_none` — uses real PyJWT to encode an `alg: "none"` token; `parse_id_token` raises `ValueError("Error parsing ID Token: …")` instead of accepting it. 
- `test_parse_id_token_rejects_hs256_when_allowlist_is_asymmetric` — uses real PyJWT to forge an `alg: "HS256"` token with a non-PEM shared secret (so PyJWT's incidental PEM-as-HMAC refusal isn't what blocks it); `parse_id_token` raises because `HS256` is not in the pinned allowlist. Sanity-checked end-to-end with real PyJWT outside the project test runner: - `alg=none` forged token + `algorithms=["RS256"]` → `InvalidAlgorithmError` ✓ - `alg=HS256` forged token + `algorithms=["RS256"]` → `InvalidAlgorithmError` ✓ - Same `alg=HS256` token + `algorithms=["HS256"]` → **accepted** ({'sub': 'admin'}) — confirming the attack path was real before the fix. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: galuis116 --- api/apps/auth/oidc.py | 70 +++++- .../test_auth_app/test_oidc_client_unit.py | 206 +++++++++++++++++- 2 files changed, 260 insertions(+), 16 deletions(-) diff --git a/api/apps/auth/oidc.py b/api/apps/auth/oidc.py index 80ac79399f..e28e982805 100644 --- a/api/apps/auth/oidc.py +++ b/api/apps/auth/oidc.py @@ -19,6 +19,45 @@ from common.http_client import sync_request from .oauth import OAuthClient +# Asymmetric signing algorithms safe to accept for OIDC ID tokens. +# Symmetric HMAC algorithms (HS*) are intentionally excluded — when the +# verification key is the asymmetric public key fetched from the provider's +# JWKS (as it is for every OIDC ID token), accepting HS256 lets an attacker +# forge tokens by HMAC-signing them with the public key bytes +# (RSA/HMAC algorithm-confusion attack, CWE-347). "none" is excluded for the +# obvious reason that it disables signature verification entirely. 
+_ALLOWED_OIDC_SIGNING_ALGS = frozenset({ + "RS256", "RS384", "RS512", + "ES256", "ES384", "ES512", + "PS256", "PS384", "PS512", + "EdDSA", +}) + +# OIDC Core 1.0 § 2 makes RS256 the spec-default ``id_token_signing_alg``, +# so this is the safe fallback when a provider's discovery document does not +# advertise ``id_token_signing_alg_values_supported`` (or advertises only +# algorithms outside the safe allowlist). +_DEFAULT_OIDC_SIGNING_ALGS = ("RS256",) + + +def _resolve_id_token_signing_algs(metadata): + """Return the algorithms to pass to ``jwt.decode(..., algorithms=...)``. + + Intersects the provider-advertised + ``id_token_signing_alg_values_supported`` with + :data:`_ALLOWED_OIDC_SIGNING_ALGS`. Falls back to + :data:`_DEFAULT_OIDC_SIGNING_ALGS` when the provider does not advertise + the field or advertises only algorithms outside the safe allowlist — + crucially, the fallback is to RS256, **never** to whatever the JWT + header claims at verification time. + """ + advertised = metadata.get("id_token_signing_alg_values_supported") or [] + if not isinstance(advertised, (list, tuple)): + advertised = [] + safe = [a for a in advertised if isinstance(a, str) and a in _ALLOWED_OIDC_SIGNING_ALGS] + return safe or list(_DEFAULT_OIDC_SIGNING_ALGS) + + class OIDCClient(OAuthClient): def __init__(self, config): """ @@ -32,7 +71,7 @@ class OIDCClient(OAuthClient): oidc_metadata = self._load_oidc_metadata(self.issuer) config.update({ 'issuer': oidc_metadata['issuer'], - 'jwks_uri': oidc_metadata['jwks_uri'], + 'jwks_uri': oidc_metadata['jwks_uri'], 'authorization_url': oidc_metadata['authorization_endpoint'], 'token_url': oidc_metadata['token_endpoint'], 'userinfo_url': oidc_metadata['userinfo_endpoint'] @@ -41,6 +80,11 @@ class OIDCClient(OAuthClient): super().__init__(config) self.issuer = config['issuer'] self.jwks_uri = config['jwks_uri'] + # Pin the accepted ID-token signing algorithms at construction time + # from a trusted source (provider metadata + safe 
allowlist) so the + # JWT verification step in :meth:`parse_id_token` cannot be tricked + # by attacker-controlled JWT headers (CWE-345 / CWE-347). + self.id_token_signing_algs = _resolve_id_token_signing_algs(oidc_metadata) @staticmethod @@ -60,23 +104,29 @@ class OIDCClient(OAuthClient): def parse_id_token(self, id_token): """ Parse and validate OIDC ID Token (JWT format) with signature verification. + + The accepted signing algorithms come from ``self.id_token_signing_algs`` + (pinned at construction time from the provider's discovery metadata, + intersected with :data:`_ALLOWED_OIDC_SIGNING_ALGS`). We deliberately + do **not** read the algorithm from the unverified JWT header — doing + so would let an attacker bypass signature verification by setting + ``"alg": "none"`` or pull off the classic RSA / HMAC algorithm + confusion by setting ``"alg": "HS256"`` and signing with the public + key fetched from the provider's JWKS (CWE-345 / CWE-347). """ try: - # Decode JWT header without verifying signature - headers = jwt.get_unverified_header(id_token) - - # OIDC usually uses `RS256` for signing - alg = headers.get("alg", "RS256") - - # Use PyJWT's PyJWKClient to fetch JWKS and find signing key + # Use PyJWT's PyJWKClient to fetch JWKS and find signing key. + # The client reads the ``kid`` from the JWT header internally to + # look up the key — that's fine: ``kid`` is not a security + # decision, the signature still proves which key was used. jwks_cli = jwt.PyJWKClient(self.jwks_uri) signing_key = jwks_cli.get_signing_key_from_jwt(id_token).key - # Decode and verify signature + # Decode and verify signature against the pinned allowlist. 
decoded_token = jwt.decode( id_token, key=signing_key, - algorithms=[alg], + algorithms=list(self.id_token_signing_algs), audience=str(self.client_id), issuer=self.issuer, ) diff --git a/test/testcases/test_web_api/test_auth_app/test_oidc_client_unit.py b/test/testcases/test_web_api/test_auth_app/test_oidc_client_unit.py index f1e620d65d..7f48a3b95e 100644 --- a/test/testcases/test_web_api/test_auth_app/test_oidc_client_unit.py +++ b/test/testcases/test_web_api/test_auth_app/test_oidc_client_unit.py @@ -127,18 +127,25 @@ def _base_config(): } -def _metadata(issuer): - return { +def _metadata(issuer, signing_algs=None): + md = { "issuer": issuer, "jwks_uri": f"{issuer}/jwks", "authorization_endpoint": f"{issuer}/authorize", "token_endpoint": f"{issuer}/token", "userinfo_endpoint": f"{issuer}/userinfo", } + if signing_algs is not None: + md["id_token_signing_alg_values_supported"] = signing_algs + return md -def _make_client(monkeypatch, oidc_module): - monkeypatch.setattr(oidc_module.OIDCClient, "_load_oidc_metadata", staticmethod(lambda issuer: _metadata(issuer))) +def _make_client(monkeypatch, oidc_module, signing_algs=None): + monkeypatch.setattr( + oidc_module.OIDCClient, + "_load_oidc_metadata", + staticmethod(lambda issuer: _metadata(issuer, signing_algs=signing_algs)), + ) return oidc_module.OIDCClient(_base_config()) @@ -199,8 +206,6 @@ def test_parse_id_token_success_and_error(monkeypatch): _, oidc_module = _load_auth_modules(monkeypatch) client = _make_client(monkeypatch, oidc_module) - monkeypatch.setattr(oidc_module.jwt, "get_unverified_header", lambda _token: {}) - seen = {} class _JwkClient(_DummyJwkClient): @@ -245,6 +250,195 @@ def test_parse_id_token_success_and_error(monkeypatch): assert str(exc_info.value) == "Error parsing ID Token: decode boom" +# ===================================================================== # +# JWT algorithm-confusion regression tests # +# # +# Before the fix, ``parse_id_token`` read the signing algorithm from # +# the 
unverified JWT header. An attacker who presents a JWT with # +# ``"alg": "none"`` would have signature verification disabled, and an # +# attacker who presents ``"alg": "HS256"`` and signs the JWT with the # +# public key bytes (RSA / HMAC confusion) would get the forged token # +# accepted. The tests below pin the contract: # +# # +# - the algorithm allowlist is pinned at construction time from the # +# provider's discovery metadata intersected with the safe allowlist # +# - the JWT header's ``alg`` claim is never read at decode time # +# ===================================================================== # + + +@pytest.mark.p2 +def test_id_token_signing_algs_default_to_rs256_when_metadata_missing(monkeypatch): + """No ``id_token_signing_alg_values_supported`` in metadata → RS256 only. + + Crucially the fallback is RS256, never whatever the JWT header claims. + """ + _, oidc_module = _load_auth_modules(monkeypatch) + client = _make_client(monkeypatch, oidc_module) + assert client.id_token_signing_algs == ["RS256"] + + +@pytest.mark.p2 +def test_id_token_signing_algs_intersect_metadata_with_safe_allowlist(monkeypatch): + """Metadata advertises a mix of safe and unsafe algs — only safe ones kept.""" + _, oidc_module = _load_auth_modules(monkeypatch) + client = _make_client( + monkeypatch, + oidc_module, + signing_algs=["RS256", "ES256", "HS256", "none", "PS512"], + ) + assert set(client.id_token_signing_algs) == {"RS256", "ES256", "PS512"} + # The dangerous algorithms must not appear in the verification allowlist. 
+ assert "HS256" not in client.id_token_signing_algs + assert "none" not in client.id_token_signing_algs + + +@pytest.mark.p2 +def test_id_token_signing_algs_fall_back_when_only_unsafe_advertised(monkeypatch): + """Provider advertises only HS256 / none → fall back to RS256, do not trust.""" + _, oidc_module = _load_auth_modules(monkeypatch) + client = _make_client( + monkeypatch, + oidc_module, + signing_algs=["HS256", "none", "bogus"], + ) + assert client.id_token_signing_algs == ["RS256"] + + +@pytest.mark.p2 +def test_id_token_signing_algs_ignores_non_string_entries(monkeypatch): + """Malformed entries (None / dict / int) are filtered out, not crashed on.""" + _, oidc_module = _load_auth_modules(monkeypatch) + client = _make_client( + monkeypatch, + oidc_module, + signing_algs=["RS256", None, 42, {"x": 1}, "ES384"], + ) + assert set(client.id_token_signing_algs) == {"RS256", "ES384"} + + +@pytest.mark.p2 +def test_id_token_signing_algs_handles_non_list_metadata_field(monkeypatch): + """If metadata gives a non-list type for the field, fall back to default.""" + _, oidc_module = _load_auth_modules(monkeypatch) + client = _make_client(monkeypatch, oidc_module, signing_algs="RS256") + assert client.id_token_signing_algs == ["RS256"] + + +@pytest.mark.p2 +def test_parse_id_token_passes_pinned_algorithms_to_jwt_decode(monkeypatch): + """``jwt.decode`` receives the pinned allowlist, regardless of JWT header.""" + _, oidc_module = _load_auth_modules(monkeypatch) + client = _make_client(monkeypatch, oidc_module, signing_algs=["RS256", "ES256"]) + + # Even if the unverified header claims something dangerous, the + # verification path must not consult it. We sabotage + # ``jwt.get_unverified_header`` to prove the code never calls it. 
+ def _explode(_token): # pragma: no cover - must not be called + raise AssertionError( + "parse_id_token must not read the algorithm from the unverified JWT header" + ) + + monkeypatch.setattr(oidc_module.jwt, "get_unverified_header", _explode) + monkeypatch.setattr(oidc_module.jwt, "PyJWKClient", _DummyJwkClient) + + seen = {} + + def _decode(id_token, key, algorithms, audience, issuer): + seen["algorithms"] = list(algorithms) + return {"sub": "user-2"} + + monkeypatch.setattr(oidc_module.jwt, "decode", _decode) + client.parse_id_token("malicious-header-token") + + assert set(seen["algorithms"]) == {"RS256", "ES256"} + # Hard-stop: dangerous algorithms must never reach ``jwt.decode``. + assert "none" not in seen["algorithms"] + assert "HS256" not in seen["algorithms"] + + +@pytest.mark.p2 +def test_parse_id_token_rejects_alg_none(monkeypatch): + """End-to-end: an ``alg: "none"`` JWT must not authenticate. + + Uses the real PyJWT decoder so the test exercises the actual contract + between ``parse_id_token`` and the upstream library. + """ + import jwt as real_jwt + + _, oidc_module = _load_auth_modules(monkeypatch) + client = _make_client(monkeypatch, oidc_module) # defaults to RS256 + + # PyJWT requires explicit opt-in to encode ``alg=none``; even then it + # produces a token with no signature segment. + forged = real_jwt.encode( + { + "sub": "victim-subject", + "email": "admin@target.example", + "aud": "client-1", + "iss": "https://issuer.example", + }, + key="", + algorithm="none", + ) + # Force the JWKS step into a no-op so we exercise *just* the alg gate. 
+ monkeypatch.setattr(oidc_module.jwt, "PyJWKClient", _DummyJwkClient) + + with pytest.raises(ValueError) as exc_info: + client.parse_id_token(forged) + assert "Error parsing ID Token" in str(exc_info.value) + + +@pytest.mark.p2 +def test_parse_id_token_rejects_hs256_when_allowlist_is_asymmetric(monkeypatch): + """End-to-end: a JWT whose header claims ``alg: HS256`` must not be + accepted when the pinned allowlist is asymmetric-only. + + This is the algorithm half of the RSA / HMAC confusion attack + (CWE-347). The attacker forges a JWT with ``"alg": "HS256"`` so the + server picks the HMAC verifier; pre-fix the server would read that alg + straight from the header and call + ``jwt.decode(..., algorithms=["HS256"], key=public_key)`` which lets + the attacker forge tokens with the public key bytes. After the fix the + allowlist pinned at construction time wins — HS* is never in it — so + PyJWT raises ``InvalidAlgorithmError`` before the HMAC verifier is + ever invoked. + + Note: modern PyJWT (>=2.0) also independently refuses to use a + PEM-encoded key as an HMAC secret, so the public-key-bytes leg of the + full attack is partially mitigated at the library level. The fix here + is defense in depth and the only mitigation for non-PEM key formats + (raw bytes, DER, JWK octet keys, older PyJWT versions). + """ + import jwt as real_jwt + + _, oidc_module = _load_auth_modules(monkeypatch) + client = _make_client(monkeypatch, oidc_module) # defaults to RS256 + + # Use a non-PEM byte string so we exercise the alg gate (not PyJWT's + # incidental PEM-as-HMAC-secret refusal). 
+ shared_secret = b"shared-secret-bytes-not-a-pem-key" + forged = real_jwt.encode( + { + "sub": "victim-subject", + "email": "admin@target.example", + "aud": "client-1", + "iss": "https://issuer.example", + }, + key=shared_secret, + algorithm="HS256", + ) + + class _SecretJwkClient(_DummyJwkClient): + def get_signing_key_from_jwt(self, _id_token): + return SimpleNamespace(key=shared_secret) + + monkeypatch.setattr(oidc_module.jwt, "PyJWKClient", _SecretJwkClient) + + with pytest.raises(ValueError) as exc_info: + client.parse_id_token(forged) + assert "Error parsing ID Token" in str(exc_info.value) + + @pytest.mark.p2 def test_fetch_user_info_merges_id_token_and_oauth_userinfo(monkeypatch): oauth_module, oidc_module = _load_auth_modules(monkeypatch)