Files
ragflow/api/apps/auth/oidc.py

164 lines
6.6 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import jwt
from common.http_client import sync_request
from .oauth import OAuthClient
# Allowlist of signature algorithms accepted for OIDC ID tokens. Only
# asymmetric algorithms are listed. The HMAC variants (HS*) are left out on
# purpose: ID tokens are verified with the public key published in the
# provider's JWKS, and if HS256 were accepted an attacker could HMAC-sign a
# forged token using those public key bytes (RSA/HMAC algorithm confusion,
# CWE-347). "none" is left out because it switches signature verification off
# altogether.
_ALLOWED_OIDC_SIGNING_ALGS = frozenset(
    (
        "RS256", "RS384", "RS512",
        "ES256", "ES384", "ES512",
        "PS256", "PS384", "PS512",
        "EdDSA",
    )
)

# Fallback used when a provider's discovery document omits
# ``id_token_signing_alg_values_supported`` or lists nothing that is on the
# allowlist above. RS256 is chosen because OIDC Core 1.0 § 2 makes it the
# spec-default ``id_token_signing_alg``.
_DEFAULT_OIDC_SIGNING_ALGS = ("RS256",)
def _resolve_id_token_signing_algs(metadata):
    """Choose the algorithms handed to ``jwt.decode(..., algorithms=...)``.

    Looks up ``id_token_signing_alg_values_supported`` in ``metadata`` and
    keeps, in advertised order, only the string entries that are also members
    of :data:`_ALLOWED_OIDC_SIGNING_ALGS`.

    If the field is missing, empty, not a list/tuple, or contains nothing from
    the allowlist, a new list built from :data:`_DEFAULT_OIDC_SIGNING_ALGS`
    is returned instead. The result therefore never depends on what a JWT
    header claims at verification time.

    Args:
        metadata: Provider discovery document; must offer ``.get``.

    Returns:
        A non-empty list of algorithm names.
    """
    candidates = metadata.get("id_token_signing_alg_values_supported")
    if not candidates or not isinstance(candidates, (list, tuple)):
        return list(_DEFAULT_OIDC_SIGNING_ALGS)

    accepted = []
    for name in candidates:
        # The ``str`` check comes first so unhashable junk never reaches the
        # frozenset membership test.
        if isinstance(name, str) and name in _ALLOWED_OIDC_SIGNING_ALGS:
            accepted.append(name)

    if accepted:
        return accepted
    return list(_DEFAULT_OIDC_SIGNING_ALGS)
class OIDCClient(OAuthClient):
    """OAuth client that configures itself from an OIDC discovery document.

    Endpoints and the JWKS location are read from the provider's
    ``/.well-known/openid-configuration`` document, and ID tokens are verified
    against a set of signing algorithms pinned at construction time.
    """

    def __init__(self, config):
        """
        Initialize the OIDCClient with the provider's configuration.
        Use `issuer` as the single source of truth for configuration discovery.

        Args:
            config: Mapping holding at least ``issuer``. It is updated in
                place with ``issuer``, ``jwks_uri``, ``authorization_url``,
                ``token_url`` and ``userinfo_url`` taken from the discovery
                document before being passed to the parent constructor.

        Raises:
            ValueError: If ``issuer`` is missing/empty or the discovery
                document cannot be loaded.
            KeyError: If the discovery document lacks one of the fields
                copied into ``config``.
        """
        self.issuer = config.get("issuer")
        if not self.issuer:
            raise ValueError("Missing issuer in configuration.")

        oidc_metadata = self._load_oidc_metadata(self.issuer)
        config.update(
            {
                "issuer": oidc_metadata["issuer"],
                "jwks_uri": oidc_metadata["jwks_uri"],
                "authorization_url": oidc_metadata["authorization_endpoint"],
                "token_url": oidc_metadata["token_endpoint"],
                "userinfo_url": oidc_metadata["userinfo_endpoint"],
            }
        )

        super().__init__(config)
        self.issuer = config["issuer"]
        self.jwks_uri = config["jwks_uri"]
        # Pin the accepted ID-token signing algorithms at construction time
        # from a trusted source (provider metadata + safe allowlist) so the
        # JWT verification step in :meth:`parse_id_token` cannot be tricked
        # by attacker-controlled JWT headers (CWE-345 / CWE-347).
        self.id_token_signing_algs = _resolve_id_token_signing_algs(oidc_metadata)

    @staticmethod
    def _load_oidc_metadata(issuer):
        """
        Load OIDC metadata from `/.well-known/openid-configuration`.

        Args:
            issuer: Issuer URL; a terminating ``/`` is tolerated.

        Returns:
            The decoded discovery document as a dict.

        Raises:
            ValueError: If the request fails, the response status is an
                error, the body is not valid JSON, or the JSON value is not
                an object.
        """
        # OIDC Discovery 1.0 § 4.1: any terminating "/" on the issuer must be
        # removed before appending the well-known path; otherwise an issuer
        # configured as "https://idp.example/" yields "//.well-known/...".
        base_url = str(issuer).rstrip("/")
        metadata_url = f"{base_url}/.well-known/openid-configuration"

        try:
            response = sync_request("GET", metadata_url, timeout=7)
            response.raise_for_status()
            metadata = response.json()
        except Exception as e:
            raise ValueError(f"Failed to fetch OIDC metadata: {e}") from e

        # Callers index the document by key; reject any other JSON value here
        # with a clear error instead of failing later with a TypeError.
        if not isinstance(metadata, dict):
            raise ValueError(
                "Failed to fetch OIDC metadata: expected a JSON object, "
                f"got {type(metadata).__name__}"
            )
        return metadata

    def parse_id_token(self, id_token):
        """
        Parse and validate OIDC ID Token (JWT format) with signature verification.

        The accepted signing algorithms come from ``self.id_token_signing_algs``
        (pinned at construction time from the provider's discovery metadata,
        intersected with :data:`_ALLOWED_OIDC_SIGNING_ALGS`). We deliberately
        do **not** read the algorithm from the unverified JWT header — doing
        so would let an attacker bypass signature verification by setting
        ``"alg": "none"`` or pull off the classic RSA / HMAC algorithm
        confusion by setting ``"alg": "HS256"`` and signing with the public
        key fetched from the provider's JWKS (CWE-345 / CWE-347).

        Args:
            id_token: The encoded ID token.

        Returns:
            The decoded claims as returned by ``jwt.decode``.

        Raises:
            ValueError: If key lookup, signature verification, or the
                audience/issuer checks fail; the underlying exception is
                attached as ``__cause__``.
        """
        try:
            # Use PyJWT's PyJWKClient to fetch JWKS and find signing key.
            # The client reads the ``kid`` from the JWT header internally to
            # look up the key — that's fine: ``kid`` is not a security
            # decision, the signature still proves which key was used.
            jwks_cli = jwt.PyJWKClient(self.jwks_uri)
            signing_key = jwks_cli.get_signing_key_from_jwt(id_token).key

            # Decode and verify signature against the pinned allowlist.
            decoded_token = jwt.decode(
                id_token,
                key=signing_key,
                algorithms=list(self.id_token_signing_algs),
                audience=str(self.client_id),
                issuer=self.issuer,
            )
            return decoded_token
        except Exception as e:
            raise ValueError(f"Error parsing ID Token: {e}") from e

    def fetch_user_info(self, access_token, id_token=None, **kwargs):
        """
        Fetch user info.

        Starts from the verified ID-token claims (when ``id_token`` is given),
        then overlays the parent client's userinfo result, so userinfo values
        win on key collisions. The merged dict is passed through
        :meth:`normalize_user_info`. Extra ``kwargs`` are accepted but unused.
        """
        user_info = {}
        if id_token:
            user_info = self.parse_id_token(id_token)
        user_info.update(super().fetch_user_info(access_token).to_dict())
        return self.normalize_user_info(user_info)

    async def async_fetch_user_info(self, access_token, id_token=None, **kwargs):
        """
        Async counterpart of :meth:`fetch_user_info`.

        ID-token parsing is still performed synchronously; only the parent's
        userinfo call is awaited.
        """
        user_info = {}
        if id_token:
            user_info = self.parse_id_token(id_token)
        user_info.update((await super().async_fetch_user_info(access_token)).to_dict())
        return self.normalize_user_info(user_info)

    def normalize_user_info(self, user_info):
        """Delegate normalization of ``user_info`` to the parent client."""
        return super().normalize_user_info(user_info)