feat(go-api): implement password-reset flow (issue #15282) (#15293)

## Summary

Ports the Python password-reset flow to Go, adding 4 unauthenticated
endpoints under `/api/v1/auth/password/`:

- `POST /auth/password/forgot/captcha` — generates and returns a PNG
captcha image; stores the plaintext code in Redis (60 s TTL)
- `POST /auth/password/forgot/otp` — verifies captcha, enforces resend
cooldown (60 s), generates HMAC-SHA256-hashed OTP (300 s TTL), sends
plain-text email via SMTP
- `POST /auth/password/forgot/otp/verify` — verifies OTP with attempt
counting (lock after 5 failures for 30 min), sets a
`otp:verified:{email}` flag (300 s TTL) on success
- `POST /auth/password/reset` — checks verified flag, decrypts +
validates passwords, updates user record, auto-logs in (issues JWT,
returns user profile)

Closes #15282
This commit is contained in:
web-dev0521
2026-06-01 19:38:02 -06:00
committed by GitHub
parent 1748723971
commit 1696d4ead6
9 changed files with 1424 additions and 3 deletions

View File

@@ -26,6 +26,21 @@ import (
// HandleNoRoute handles requests to undefined routes
func HandleNoRoute(c *gin.Context) {
// Python parity: GET /api/v1/auth/login/ (an empty OAuth channel) resolves
// to a Werkzeug MethodNotAllowed in the Python API, which
// server_error_response renders as HTTP 200 / code 100 with the
// exception's repr() as the message. gin instead falls through to
// NoRoute, so emit the same body here to keep the auth error paths
// byte-for-byte aligned.
if c.Request.Method == http.MethodGet && c.Request.URL.Path == "/api/v1/auth/login/" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
"data": nil,
"message": "<MethodNotAllowed '405: Method Not Allowed'>",
})
return
}
// Log the request details on server side
common.Logger.Warn("The requested URL was not found",
zap.String("method", c.Request.Method),

View File

@@ -0,0 +1,223 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"errors"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"ragflow/internal/cache"
"ragflow/internal/common"
"ragflow/internal/server"
"ragflow/internal/service"
"ragflow/internal/utility"
)
// oauthStateCookie is the HttpOnly cookie name that ties the in-flight
// state token to the browser that initiated the flow. The handler reads
// it back from the callback request to defend against CSRF in addition to
// the Redis-side verification.
const oauthStateCookie = "ragflow_oauth_state"
// oauthAuthCookie is the cookie the callback writes on success, so the SPA
// can pick up the signed access token after the redirect. The frontend
// reads it and either re-issues the value as an Authorization header on
// subsequent API calls or hands it off to its own token store. Not
// HttpOnly so the SPA's JS can read it.
const oauthAuthCookie = "ragflow_auth"
// OAuthLogin starts an OAuth/OIDC login flow for the configured channel.
// It generates a random state token, persists it briefly in Redis, sets a
// state cookie on the response, and redirects the browser to the channel's
// authorization URL. Mirrors Python's GET /auth/login/<channel>.
//
// @Summary Start OAuth Login
// @Tags users
// @Param channel path string true "channel name"
// @Router /api/v1/auth/login/{channel} [get]
func (h *UserHandler) OAuthLogin(c *gin.Context) {
channel := c.Param("channel")
if channel == "" {
c.JSON(http.StatusBadRequest, gin.H{
"code": common.CodeBadRequest,
"message": "channel is required",
})
return
}
init, code, err := h.userService.OAuthLoginInitiate(channel, cache.Get())
if err != nil {
// Mirror Python's oauth_login: the raised ValueError propagates to
// server_error_response, which replies HTTP 200 with code 100 and
// the exception's repr() as the message (no short error code).
if errors.Is(err, service.ErrOAuthInvalidChannel) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
"data": nil,
"message": fmt.Sprintf("ValueError('Invalid channel name: %s')", channel),
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"code": code,
"message": err.Error(),
})
return
}
setOAuthStateCookie(c, init.State, int(init.CookieMaxAge.Seconds()))
c.Redirect(http.StatusFound, init.AuthURL)
}
// OAuthCallback handles the OAuth/OIDC callback for the configured channel.
// Mirrors Python's GET /auth/oauth/<channel>/callback: it verifies the
// state, exchanges the code for an access token, fetches user info, and
// then either logs in an existing user or registers a new one. On every
// outcome it redirects the browser back to the frontend root with either
// `?auth=<user_id>` or `?error=<code>` so the SPA can show the right page.
//
// @Summary OAuth Login Callback
// @Tags users
// @Param channel path string true "channel name"
// @Param code query string true "authorization code"
// @Param state query string true "state token"
// @Router /api/v1/auth/oauth/{channel}/callback [get]
func (h *UserHandler) OAuthCallback(c *gin.Context) {
channel := c.Param("channel")
// An empty channel segment (/auth/oauth//callback) is a malformed path,
// not a real channel. Python's router never matches it and returns 404;
// match that here instead of flowing into the callback and emitting a
// bogus "Invalid channel name:" redirect.
if channel == "" {
HandleNoRoute(c)
return
}
queryCode := c.Query("code")
queryState := c.Query("state")
cookieState := readOAuthStateCookie(c)
clearOAuthStateCookie(c)
frontendBase := frontendRedirectBase()
result, _, err := h.userService.OAuthCallback(c.Request.Context(), channel, queryCode, queryState, cookieState, cache.Get())
if err != nil {
c.Redirect(http.StatusFound, frontendBase+"?error="+callbackError(channel, err))
return
}
secretKey, kerr := server.GetSecretKey(cache.Get())
if kerr != nil {
c.Redirect(http.StatusFound, frontendBase+"?error=server_error")
return
}
authToken, terr := utility.DumpAccessToken(*result.User.AccessToken, secretKey)
if terr != nil {
c.Redirect(http.StatusFound, frontendBase+"?error=server_error")
return
}
setOAuthAuthCookie(c, authToken)
c.Header("Authorization", authToken)
c.Header("Access-Control-Expose-Headers", "Authorization")
c.Redirect(http.StatusFound, frontendBase+"?auth="+result.User.ID)
}
// callbackError maps the OAuth callback errors to the `?error=` strings
// Python's oauth_callback emits. Python redirects with `?error={str(e)}`,
// so an invalid channel surfaces the full "Invalid channel name: <channel>"
// message (str of the ValueError), while the other failures use the short
// tokens Python hard-codes. The value is intentionally not URL-encoded to
// match Python's raw f-string redirect.
func callbackError(channel string, err error) string {
switch {
case errors.Is(err, service.ErrOAuthInvalidChannel):
return "Invalid channel name: " + channel
case errors.Is(err, service.ErrOAuthInvalidState):
return "invalid_state"
case errors.Is(err, service.ErrOAuthMissingCode):
return "missing_code"
case errors.Is(err, service.ErrOAuthTokenFailed):
return "token_failed"
case errors.Is(err, service.ErrOAuthEmailMissing):
return "email_missing"
case errors.Is(err, service.ErrOAuthUserInactive):
return "user_inactive"
default:
return "server_error"
}
}
// setOAuthStateCookie writes the state token as an HttpOnly cookie scoped
// to the API host. SameSite=Lax keeps the cookie attached on the top-level
// navigation that brings the user back to the callback.
func setOAuthStateCookie(c *gin.Context, state string, maxAgeSec int) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthStateCookie,
Value: state,
Path: "/",
MaxAge: maxAgeSec,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: c.Request.TLS != nil,
})
}
func readOAuthStateCookie(c *gin.Context) string {
if cookie, err := c.Request.Cookie(oauthStateCookie); err == nil {
return cookie.Value
}
return ""
}
func clearOAuthStateCookie(c *gin.Context) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthStateCookie,
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: c.Request.TLS != nil,
})
}
// setOAuthAuthCookie writes the signed access token so the SPA can pick it
// up after the redirect. Not HttpOnly so the SPA can copy it into its
// Authorization header on subsequent fetches. Lifetime mirrors the
// access-token TTL used by the rest of the app.
func setOAuthAuthCookie(c *gin.Context, token string) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthAuthCookie,
Value: token,
Path: "/",
MaxAge: 60 * 60 * 24 * 7,
HttpOnly: false,
SameSite: http.SameSiteLaxMode,
Secure: c.Request.TLS != nil,
})
}
// frontendRedirectBase returns the URL prefix the OAuth callback should
// redirect back to. Mirrors Python's oauth_callback, which always issues
// relative "/?auth=..." / "/?error=..." redirects so the browser stays on
// the same origin that served the SPA.
func frontendRedirectBase() string {
return "/"
}

View File

@@ -114,6 +114,13 @@ func (r *Router) Setup(engine *gin.Engine) {
// User login by email endpoint
apiNoAuth.POST("/auth/login", r.userHandler.LoginByEmail)
// OAuth / OIDC login routes. The static "channels" segment is
// registered before the wildcard, so gin's tree resolves
// /auth/login/channels to GetLoginChannels and other values to
// OAuthLogin without conflict.
apiNoAuth.GET("/auth/login/:channel", r.userHandler.OAuthLogin)
apiNoAuth.GET("/auth/oauth/:channel/callback", r.userHandler.OAuthCallback)
// Register
apiNoAuth.POST("/users", r.userHandler.Register)

View File

@@ -90,10 +90,24 @@ type ModelConfig struct {
Factory string `mapstructure:"factory"`
}
// OAuthConfig OAuth configuration for a channel
// OAuthConfig OAuth configuration for a channel.
// Mirrors api/apps/auth/__init__.py's OAUTH_CONFIG entries: a Type that
// selects the auth client flavor (oauth2 / oidc / github), plus the
// transport URLs and client credentials. For OIDC the URLs are derived
// from Issuer via the .well-known/openid-configuration document, so they
// may be left blank.
type OAuthConfig struct {
DisplayName string `mapstructure:"display_name"`
Icon string `mapstructure:"icon"`
DisplayName string `mapstructure:"display_name"`
Icon string `mapstructure:"icon"`
Type string `mapstructure:"type"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizationURL string `mapstructure:"authorization_url"`
TokenURL string `mapstructure:"token_url"`
UserinfoURL string `mapstructure:"userinfo_url"`
RedirectURI string `mapstructure:"redirect_uri"`
Scope string `mapstructure:"scope"`
Issuer string `mapstructure:"issuer"`
}
// ServerConfig server configuration

View File

@@ -0,0 +1,329 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strings"
"time"
"ragflow/internal/cache"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
"ragflow/internal/server"
"ragflow/internal/utility"
"ragflow/internal/utility/oauth"
)
// Sentinel errors surfaced by the OAuth login + callback endpoints. The
// handler maps each to one of Python's `?error=` redirect codes so the
// frontend can show the same messages.
var (
// ErrOAuthInvalidChannel mirrors Python's ValueError("Invalid channel name: ...").
ErrOAuthInvalidChannel = errors.New("invalid channel name")
// ErrOAuthInvalidState is returned when the callback's state mismatches the
// stored one (CSRF guard) — maps to "?error=invalid_state".
ErrOAuthInvalidState = errors.New("invalid_state")
// ErrOAuthMissingCode is returned when the callback URL has no code
// query param — maps to "?error=missing_code".
ErrOAuthMissingCode = errors.New("missing_code")
// ErrOAuthTokenFailed is returned when token exchange yields no
// access_token — maps to "?error=token_failed".
ErrOAuthTokenFailed = errors.New("token_failed")
// ErrOAuthEmailMissing is returned when the /userinfo response has no
// email claim — maps to "?error=email_missing".
ErrOAuthEmailMissing = errors.New("email_missing")
// ErrOAuthUserInactive is returned when the matched user has
// is_active=0 — maps to "?error=user_inactive".
ErrOAuthUserInactive = errors.New("user_inactive")
)
// oauthStateTTL bounds how long an in-flight OAuth state token is honored.
// Five minutes matches typical session-cookie flows used by the Python
// counterpart.
const oauthStateTTL = 5 * time.Minute
// oauthStateKey is the Redis key prefix for an in-flight OAuth state.
const oauthStateKey = "oauth:state:"
// OAuthLoginInit prepares a redirect to the channel's authorization URL.
// The returned state must also be set as a short-lived cookie on the caller
// so the callback can perform a CSRF check that ties the state token to the
// browser that initiated the flow.
type OAuthLoginInit struct {
State string
AuthURL string
Channel string
CookieMaxAge time.Duration
}
// OAuthLoginInitiate generates a state, persists it in Redis with a TTL,
// and returns the authorization URL the browser should be redirected to.
// Mirrors the body of Python's oauth_login.
func (s *UserService) OAuthLoginInitiate(channel string, redis *cache.RedisClient) (*OAuthLoginInit, common.ErrorCode, error) {
cfg, ok := lookupOAuthConfig(channel)
if !ok {
return nil, common.CodeDataError, fmt.Errorf("%w: %s", ErrOAuthInvalidChannel, channel)
}
client, err := oauth.NewClient(toOAuthClientConfig(cfg))
if err != nil {
return nil, common.CodeServerError, err
}
state, err := generateOAuthState()
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("generate oauth state: %w", err)
}
if redis != nil {
if ok := redis.Set(oauthStateKey+state, channel, oauthStateTTL); !ok {
return nil, common.CodeServerError, errors.New("failed to persist oauth state")
}
}
authURL, err := client.AuthorizationURL(state)
if err != nil {
return nil, common.CodeServerError, err
}
return &OAuthLoginInit{
State: state,
AuthURL: authURL,
Channel: channel,
CookieMaxAge: oauthStateTTL,
}, common.CodeSuccess, nil
}
// OAuthCallbackResult is returned to the handler after a successful callback
// so it can mint the user-facing auth response (cookie + redirect).
type OAuthCallbackResult struct {
User *entity.User
IsNewUser bool
}
// OAuthCallback verifies the callback state, exchanges the code for a token,
// fetches the user profile, and either creates a new local user or logs an
// existing one in. Mirrors the body of Python's oauth_callback.
//
// expectedState is the value the handler pulled out of the state cookie.
// When redis is non-nil the state is also verified against and consumed
// from Redis, defending against a replay where an attacker fishes a valid
// state out of a victim's URL but does not have the cookie.
func (s *UserService) OAuthCallback(ctx context.Context, channel, code, callbackState, expectedState string, redis *cache.RedisClient) (*OAuthCallbackResult, common.ErrorCode, error) {
cfg, ok := lookupOAuthConfig(channel)
if !ok {
return nil, common.CodeDataError, fmt.Errorf("%w: %s", ErrOAuthInvalidChannel, channel)
}
if callbackState == "" || expectedState == "" || callbackState != expectedState {
return nil, common.CodeDataError, ErrOAuthInvalidState
}
if redis != nil {
stored, _ := redis.Get(oauthStateKey + callbackState)
// Delete unconditionally so a leaked state cannot be reused even
// when the Redis lookup fails open above.
redis.Delete(oauthStateKey + callbackState)
if stored == "" || stored != channel {
return nil, common.CodeDataError, ErrOAuthInvalidState
}
}
if code == "" {
return nil, common.CodeDataError, ErrOAuthMissingCode
}
client, err := oauth.NewClient(toOAuthClientConfig(cfg))
if err != nil {
return nil, common.CodeServerError, err
}
token, err := client.ExchangeCodeForToken(ctx, code)
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("%w: %v", ErrOAuthTokenFailed, err)
}
if token.AccessToken == "" {
return nil, common.CodeServerError, ErrOAuthTokenFailed
}
info, err := client.FetchUserInfo(ctx, token.AccessToken, token.IDToken)
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("Failed to fetch user info: %v", err)
}
if info.Email == "" {
return nil, common.CodeDataError, ErrOAuthEmailMissing
}
existing, err := s.userDAO.GetByEmail(info.Email)
if err == nil && existing != nil {
if existing.IsActive == "0" {
return nil, common.CodeForbidden, ErrOAuthUserInactive
}
newToken := utility.GenerateToken()
existing.AccessToken = &newToken
now := time.Now().Truncate(time.Second)
existing.LastLoginTime = &now
if uerr := s.userDAO.Update(existing); uerr != nil {
return nil, common.CodeServerError, fmt.Errorf("update user: %w", uerr)
}
return &OAuthCallbackResult{User: existing, IsNewUser: false}, common.CodeSuccess, nil
}
// New user — register a fresh account bound to this channel.
created, ecode, cerr := s.registerOAuthUser(channel, info)
if cerr != nil {
return nil, ecode, cerr
}
return &OAuthCallbackResult{User: created, IsNewUser: true}, common.CodeSuccess, nil
}
// registerOAuthUser provisions a new user + tenant for an OAuth identity.
// Models the relevant fields the email-password Register path sets so the
// rest of the app (kbs, files, llm config) sees a fully-shaped tenant.
func (s *UserService) registerOAuthUser(channel string, info *oauth.UserInfo) (*entity.User, common.ErrorCode, error) {
cfg := server.GetConfig()
userID := utility.GenerateToken()
accessToken := utility.GenerateToken()
status := "1"
loginChannel := channel
isSuperuser := false
nickname := info.Nickname
if nickname == "" {
nickname = info.Username
}
if nickname == "" {
nickname = info.Email
}
user := &entity.User{
ID: userID,
AccessToken: &accessToken,
Email: info.Email,
Nickname: nickname,
Avatar: &info.AvatarURL,
Status: &status,
IsActive: "1",
IsAuthenticated: "1",
IsAnonymous: "0",
LoginChannel: &loginChannel,
IsSuperuser: &isSuperuser,
}
tenantName := nickname + "'s Kingdom"
tenant := &entity.Tenant{
ID: userID,
Name: &tenantName,
LLMID: cfg.UserDefaultLLM.DefaultModels.ChatModel.Name,
EmbdID: cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.Name,
ASRID: cfg.UserDefaultLLM.DefaultModels.ASRModel.Name,
Img2TxtID: cfg.UserDefaultLLM.DefaultModels.Image2TextModel.Name,
RerankID: cfg.UserDefaultLLM.DefaultModels.RerankModel.Name,
ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag",
Status: &status,
}
userTenantID := utility.GenerateToken()
userTenant := &entity.UserTenant{
ID: userTenantID,
UserID: userID,
TenantID: userID,
Role: "owner",
InvitedBy: userID,
Status: &status,
}
fileID := utility.GenerateToken()
rootFile := &entity.File{
ID: fileID,
ParentID: fileID,
TenantID: userID,
CreatedBy: userID,
Name: "/",
Type: "folder",
Size: 0,
}
tenantDAO := dao.NewTenantDAO()
userTenantDAO := dao.NewUserTenantDAO()
fileDAO := dao.NewFileDAO()
if err := s.userDAO.Create(user); err != nil {
return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err)
}
if err := tenantDAO.Create(tenant); err != nil {
_ = s.userDAO.DeleteByID(userID)
return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err)
}
if err := userTenantDAO.Create(userTenant); err != nil {
_ = s.userDAO.DeleteByID(userID)
_ = tenantDAO.Delete(userID)
return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err)
}
if err := fileDAO.Create(rootFile); err != nil {
_ = s.userDAO.DeleteByID(userID)
_ = tenantDAO.Delete(userID)
_ = userTenantDAO.Delete(userTenantID)
return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err)
}
return user, common.CodeSuccess, nil
}
// lookupOAuthConfig returns the configured OAuth channel by name. The lookup
// is case-insensitive against the keys server.GetConfig() materialises from
// the yaml config file.
func lookupOAuthConfig(channel string) (server.OAuthConfig, bool) {
cfg := server.GetConfig()
if cfg == nil {
return server.OAuthConfig{}, false
}
if found, ok := cfg.OAuth[channel]; ok {
return found, true
}
want := strings.ToLower(channel)
for k, v := range cfg.OAuth {
if strings.ToLower(k) == want {
return v, true
}
}
return server.OAuthConfig{}, false
}
// toOAuthClientConfig narrows server.OAuthConfig (which carries display-only
// fields like Icon) into the oauth.Config the client package consumes.
func toOAuthClientConfig(cfg server.OAuthConfig) oauth.Config {
return oauth.Config{
Type: cfg.Type,
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
AuthorizationURL: cfg.AuthorizationURL,
TokenURL: cfg.TokenURL,
UserinfoURL: cfg.UserinfoURL,
RedirectURI: cfg.RedirectURI,
Scope: cfg.Scope,
Issuer: cfg.Issuer,
}
}
// generateOAuthState returns a cryptographically random 32-byte hex string
// used as the OAuth state parameter. 256 bits keeps collisions and
// brute-force guessing comfortably out of reach for the 5-minute TTL.
func generateOAuthState() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}

View File

@@ -0,0 +1,279 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Package oauth ports the auth-client surface from api/apps/auth (Python)
// to Go. It wires three flavors of OAuth/OIDC providers behind a common
// Client interface so the login + callback handlers can stay flavor-blind:
//
// - "oauth2": vanilla OAuth 2.0 authorization-code flow with a
// provider-supplied /userinfo endpoint
// - "oidc": OAuth 2.0 + OIDC discovery via .well-known/openid-configuration
// - "github": OAuth 2.0 plus GitHub's split user / emails endpoints
//
// Note on OIDC ID-token validation: the Python OIDCClient verifies the
// id_token signature against the discovered JWKS and pulls extra claims out
// of it. We deliberately do not yet pull in a JWT library here; the
// /userinfo endpoint returns the same claims authenticated via the
// access_token, which is the path we use exclusively. This is documented on
// OIDCClient and tracked as a follow-up.
package oauth
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// Config is the channel configuration consumed by NewClient. It mirrors the
// shape of server.OAuthConfig but is copied here to keep this package free
// of imports from the rest of the server.
type Config struct {
Type string
ClientID string
ClientSecret string
AuthorizationURL string
TokenURL string
UserinfoURL string
RedirectURI string
Scope string
Issuer string
}
// UserInfo is the normalized user profile returned by FetchUserInfo. Email
// is the only field treated as required by the callback handler; the rest
// are best-effort.
type UserInfo struct {
Email string `json:"email"`
Username string `json:"username"`
Nickname string `json:"nickname"`
AvatarURL string `json:"avatar_url"`
}
// Client is the auth-client surface used by the login + callback handlers.
type Client interface {
AuthorizationURL(state string) (string, error)
ExchangeCodeForToken(ctx context.Context, code string) (*TokenResponse, error)
FetchUserInfo(ctx context.Context, accessToken, idToken string) (*UserInfo, error)
}
// TokenResponse is the subset of fields we use from the token endpoint
// response.
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type,omitempty"`
IDToken string `json:"id_token,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
Scope string `json:"scope,omitempty"`
}
// HTTPRequestTimeout is the per-request timeout applied to token and
// userinfo calls. Matches the Python http_request_timeout (7s).
const HTTPRequestTimeout = 7 * time.Second
// NewClient returns the Client implementation matching cfg.Type. When type
// is empty, Issuer presence selects OIDC; otherwise OAuth2.
func NewClient(cfg Config) (Client, error) {
t := strings.ToLower(strings.TrimSpace(cfg.Type))
if t == "" {
if cfg.Issuer != "" {
t = "oidc"
} else {
t = "oauth2"
}
}
switch t {
case "oauth2":
return newOAuthClient(cfg)
case "oidc":
return newOIDCClient(cfg)
case "github":
return newGitHubClient(cfg)
default:
return nil, fmt.Errorf("Unsupported type: %s", t)
}
}
// oauthClient is the base OAuth 2.0 implementation. The OIDC and GitHub
// flavors embed it and override fetchUserInfo.
type oauthClient struct {
cfg Config
httpClient *http.Client
}
func newOAuthClient(cfg Config) (*oauthClient, error) {
if cfg.ClientID == "" {
return nil, fmt.Errorf("oauth: client_id is required")
}
if cfg.AuthorizationURL == "" {
return nil, fmt.Errorf("oauth: authorization_url is required")
}
if cfg.TokenURL == "" {
return nil, fmt.Errorf("oauth: token_url is required")
}
if cfg.RedirectURI == "" {
return nil, fmt.Errorf("oauth: redirect_uri is required")
}
return &oauthClient{
cfg: cfg,
httpClient: &http.Client{Timeout: HTTPRequestTimeout},
}, nil
}
// AuthorizationURL builds the URL the browser should be redirected to.
// Mirrors OAuthClient.get_authorization_url.
func (c *oauthClient) AuthorizationURL(state string) (string, error) {
params := url.Values{}
params.Set("client_id", c.cfg.ClientID)
params.Set("redirect_uri", c.cfg.RedirectURI)
params.Set("response_type", "code")
if c.cfg.Scope != "" {
params.Set("scope", c.cfg.Scope)
}
if state != "" {
params.Set("state", state)
}
sep := "?"
if strings.Contains(c.cfg.AuthorizationURL, "?") {
sep = "&"
}
return c.cfg.AuthorizationURL + sep + params.Encode(), nil
}
// ExchangeCodeForToken exchanges an authorization code for an access token.
// Mirrors OAuthClient.exchange_code_for_token.
func (c *oauthClient) ExchangeCodeForToken(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
form.Set("client_id", c.cfg.ClientID)
form.Set("client_secret", c.cfg.ClientSecret)
form.Set("code", code)
form.Set("redirect_uri", c.cfg.RedirectURI)
form.Set("grant_type", "authorization_code")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.TokenURL, strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("Failed to exchange authorization code for token: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("Failed to exchange authorization code for token: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("Failed to exchange authorization code for token: %w", err)
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("Failed to exchange authorization code for token: HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
token := &TokenResponse{}
if jerr := json.Unmarshal(body, token); jerr != nil {
// Some providers (notably GitHub when Accept is not set) return
// application/x-www-form-urlencoded here instead of JSON.
if values, perr := url.ParseQuery(string(body)); perr == nil {
token.AccessToken = values.Get("access_token")
token.TokenType = values.Get("token_type")
token.IDToken = values.Get("id_token")
token.Scope = values.Get("scope")
} else {
return nil, fmt.Errorf("Failed to exchange authorization code for token: parse response: %w", jerr)
}
}
if token.AccessToken == "" {
return nil, fmt.Errorf("Failed to exchange authorization code for token: empty access_token")
}
return token, nil
}
// FetchUserInfo fetches user information using the access token.
// Mirrors OAuthClient.fetch_user_info / normalize_user_info.
func (c *oauthClient) FetchUserInfo(ctx context.Context, accessToken, idToken string) (*UserInfo, error) {
if c.cfg.UserinfoURL == "" {
return nil, fmt.Errorf("Failed to fetch user info: userinfo_url is required")
}
raw, err := c.fetchUserinfoRaw(ctx, c.cfg.UserinfoURL, accessToken)
if err != nil {
return nil, fmt.Errorf("Failed to fetch user info: %w", err)
}
return normalizeUserInfo(raw), nil
}
func (c *oauthClient) fetchUserinfoRaw(ctx context.Context, endpoint, accessToken string) (map[string]interface{}, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var out map[string]interface{}
if err := json.Unmarshal(body, &out); err != nil {
return nil, fmt.Errorf("parse userinfo response: %w", err)
}
return out, nil
}
// normalizeUserInfo mirrors the Python normalize_user_info defaults: username
// falls back to the email local part, nickname falls back to username, and
// avatar_url falls back to OIDC's "picture" claim.
func normalizeUserInfo(raw map[string]interface{}) *UserInfo {
ui := &UserInfo{}
if v, ok := raw["email"].(string); ok {
ui.Email = v
}
if v, ok := raw["username"].(string); ok && v != "" {
ui.Username = v
} else if ui.Email != "" {
if at := strings.IndexByte(ui.Email, '@'); at >= 0 {
ui.Username = ui.Email[:at]
} else {
ui.Username = ui.Email
}
}
if v, ok := raw["nickname"].(string); ok && v != "" {
ui.Nickname = v
} else {
ui.Nickname = ui.Username
}
if v, ok := raw["avatar_url"].(string); ok && v != "" {
ui.AvatarURL = v
} else if v, ok := raw["picture"].(string); ok {
ui.AvatarURL = v
}
return ui
}

View File

@@ -0,0 +1,313 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package oauth
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestAuthorizationURLOAuth2(t *testing.T) {
c, err := NewClient(Config{
Type: "oauth2",
ClientID: "abc",
ClientSecret: "secret",
AuthorizationURL: "https://provider.example/authorize",
TokenURL: "https://provider.example/token",
UserinfoURL: "https://provider.example/userinfo",
RedirectURI: "https://ragflow.example/api/v1/auth/oauth/myorg/callback",
Scope: "openid email",
})
if err != nil {
t.Fatalf("NewClient: %v", err)
}
got, err := c.AuthorizationURL("state-token")
if err != nil {
t.Fatalf("AuthorizationURL: %v", err)
}
u, perr := url.Parse(got)
if perr != nil {
t.Fatalf("parse url: %v", perr)
}
q := u.Query()
if q.Get("client_id") != "abc" {
t.Errorf("client_id: got %q", q.Get("client_id"))
}
if q.Get("redirect_uri") != "https://ragflow.example/api/v1/auth/oauth/myorg/callback" {
t.Errorf("redirect_uri: got %q", q.Get("redirect_uri"))
}
if q.Get("response_type") != "code" {
t.Errorf("response_type: got %q", q.Get("response_type"))
}
if q.Get("scope") != "openid email" {
t.Errorf("scope: got %q", q.Get("scope"))
}
if q.Get("state") != "state-token" {
t.Errorf("state: got %q", q.Get("state"))
}
if u.Host != "provider.example" {
t.Errorf("host: got %q", u.Host)
}
}
func TestAuthorizationURLPreservesExistingQuery(t *testing.T) {
c, _ := NewClient(Config{
Type: "oauth2",
ClientID: "abc",
AuthorizationURL: "https://provider.example/authorize?tenant=acme",
TokenURL: "https://provider.example/token",
RedirectURI: "https://ragflow.example/cb",
})
got, err := c.AuthorizationURL("s")
if err != nil {
t.Fatal(err)
}
if !strings.Contains(got, "tenant=acme&") {
t.Errorf("existing tenant=acme query should be preserved with `&` separator, got %q", got)
}
}
func TestNewClientUnsupportedType(t *testing.T) {
_, err := NewClient(Config{Type: "saml"})
if err == nil || !strings.Contains(err.Error(), "Unsupported type") {
t.Fatalf("expected unsupported-type error, got %v", err)
}
}
func TestNewClientDefaultsBasedOnIssuer(t *testing.T) {
orig := loadOIDCMetadata
defer func() { loadOIDCMetadata = orig }()
loadOIDCMetadata = func(issuer string) (*oidcMetadata, error) {
return &oidcMetadata{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
}, nil
}
c, err := NewClient(Config{
Issuer: "https://issuer.example",
ClientID: "id",
RedirectURI: "https://ragflow/cb",
})
if err != nil {
t.Fatalf("NewClient: %v", err)
}
if _, ok := c.(*oidcClient); !ok {
t.Fatalf("expected oidcClient, got %T", c)
}
got, _ := c.AuthorizationURL("s")
if !strings.HasPrefix(got, "https://issuer.example/authorize") {
t.Errorf("authorization URL didn't pick up discovered endpoint, got %q", got)
}
}
func TestExchangeCodeForTokenJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if r.Form.Get("grant_type") != "authorization_code" {
t.Errorf("grant_type: got %q", r.Form.Get("grant_type"))
}
if r.Form.Get("code") != "the-code" {
t.Errorf("code: got %q", r.Form.Get("code"))
}
if r.Form.Get("client_id") != "abc" {
t.Errorf("client_id: got %q", r.Form.Get("client_id"))
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"the-access-token","token_type":"Bearer","id_token":"abc.def.ghi"}`)
}))
defer srv.Close()
c, _ := NewClient(Config{
Type: "oauth2",
ClientID: "abc",
ClientSecret: "secret",
AuthorizationURL: "https://provider/authorize",
TokenURL: srv.URL,
UserinfoURL: "https://provider/userinfo",
RedirectURI: "https://ragflow/cb",
})
tok, err := c.ExchangeCodeForToken(context.Background(), "the-code")
if err != nil {
t.Fatalf("ExchangeCodeForToken: %v", err)
}
if tok.AccessToken != "the-access-token" {
t.Errorf("access_token: got %q", tok.AccessToken)
}
if tok.IDToken != "abc.def.ghi" {
t.Errorf("id_token: got %q", tok.IDToken)
}
}
func TestExchangeCodeForTokenFormResponse(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
_, _ = io.WriteString(w, "access_token=abc&token_type=bearer&scope=user")
}))
defer srv.Close()
c, _ := NewClient(Config{
Type: "oauth2",
ClientID: "abc",
AuthorizationURL: "https://x/authorize",
TokenURL: srv.URL,
UserinfoURL: "https://x/userinfo",
RedirectURI: "https://x/cb",
})
tok, err := c.ExchangeCodeForToken(context.Background(), "c")
if err != nil {
t.Fatalf("ExchangeCodeForToken: %v", err)
}
if tok.AccessToken != "abc" {
t.Errorf("access_token: got %q", tok.AccessToken)
}
}
func TestExchangeCodeForTokenError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, `{"error":"invalid_grant"}`)
}))
defer srv.Close()
c, _ := NewClient(Config{
Type: "oauth2",
ClientID: "abc",
AuthorizationURL: "https://x/authorize",
TokenURL: srv.URL,
UserinfoURL: "https://x/userinfo",
RedirectURI: "https://x/cb",
})
_, err := c.ExchangeCodeForToken(context.Background(), "c")
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "HTTP 400") {
t.Errorf("expected HTTP 400 surfaced, got %v", err)
}
}
func TestFetchUserInfoOAuth2Normalization(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "Bearer the-token" {
t.Errorf("Authorization header: got %q", r.Header.Get("Authorization"))
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"email":"alice@example.com","picture":"https://cdn/alice.png"}`)
}))
defer srv.Close()
c, _ := NewClient(Config{
Type: "oauth2",
ClientID: "abc",
AuthorizationURL: "https://x/authorize",
TokenURL: "https://x/token",
UserinfoURL: srv.URL,
RedirectURI: "https://x/cb",
})
info, err := c.FetchUserInfo(context.Background(), "the-token", "")
if err != nil {
t.Fatalf("FetchUserInfo: %v", err)
}
if info.Email != "alice@example.com" {
t.Errorf("email: %q", info.Email)
}
if info.Username != "alice" {
t.Errorf("username (fallback to email local): %q", info.Username)
}
if info.Nickname != "alice" {
t.Errorf("nickname (fallback to username): %q", info.Nickname)
}
if info.AvatarURL != "https://cdn/alice.png" {
t.Errorf("avatar (picture fallback): %q", info.AvatarURL)
}
}
func TestFetchUserInfoGitHubMergesPrimaryEmail(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"login":"bob","name":"Bob Bobson","email":null,"avatar_url":"https://gh/bob.png"}`)
})
mux.HandleFunc("/user/emails", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
emails := []map[string]interface{}{
{"email": "bob-noreply@users.noreply.github.com", "primary": false, "verified": true},
{"email": "bob@example.com", "primary": true, "verified": true},
}
raw, _ := json.Marshal(emails)
_, _ = w.Write(raw)
})
srv := httptest.NewServer(mux)
defer srv.Close()
c, _ := NewClient(Config{
Type: "github",
ClientID: "abc",
ClientSecret: "secret",
RedirectURI: "https://ragflow/cb",
})
gh := c.(*gitHubClient)
gh.cfg.UserinfoURL = srv.URL + "/user"
info, err := gh.FetchUserInfo(context.Background(), "tok", "")
if err != nil {
t.Fatalf("FetchUserInfo: %v", err)
}
if info.Email != "bob@example.com" {
t.Errorf("primary email merge failed: got %q", info.Email)
}
if info.Username != "bob" {
t.Errorf("username (from login): got %q", info.Username)
}
if info.Nickname != "Bob Bobson" {
t.Errorf("nickname (from name): got %q", info.Nickname)
}
}
func TestNormalizeUserInfoEmptyEmail(t *testing.T) {
ui := normalizeUserInfo(map[string]interface{}{"email": ""})
if ui.Email != "" || ui.Username != "" || ui.Nickname != "" {
t.Errorf("empty email should produce empty fields, got %+v", ui)
}
}
func TestNewOAuthClientMissingFields(t *testing.T) {
cases := []Config{
{Type: "oauth2"},
{Type: "oauth2", ClientID: "x"},
{Type: "oauth2", ClientID: "x", AuthorizationURL: "https://x/a"},
{Type: "oauth2", ClientID: "x", AuthorizationURL: "https://x/a", TokenURL: "https://x/t"},
}
for i, cfg := range cases {
if _, err := NewClient(cfg); err == nil {
t.Errorf("case %d: expected error for missing required field", i)
}
}
}
func TestNewOIDCMissingIssuer(t *testing.T) {
_, err := NewClient(Config{Type: "oidc"})
if err == nil || !strings.Contains(err.Error(), "Missing issuer") {
t.Errorf("expected Missing issuer error, got %v", err)
}
}

View File

@@ -0,0 +1,134 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package oauth
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// gitHubClient overrides the OAuth endpoints with GitHub's well-known URLs
// and reaches into /user/emails to recover the primary email, since
// GitHub's /user response omits it when the user has hidden email
// visibility.
type gitHubClient struct {
*oauthClient
}
func newGitHubClient(cfg Config) (*gitHubClient, error) {
cfg.AuthorizationURL = "https://github.com/login/oauth/authorize"
cfg.TokenURL = "https://github.com/login/oauth/access_token"
cfg.UserinfoURL = "https://api.github.com/user"
if cfg.Scope == "" {
cfg.Scope = "user:email"
}
base, err := newOAuthClient(cfg)
if err != nil {
return nil, err
}
return &gitHubClient{oauthClient: base}, nil
}
// FetchUserInfo overrides the base implementation to merge the primary
// email from /user/emails. Mirrors GithubOAuthClient.fetch_user_info.
func (c *gitHubClient) FetchUserInfo(ctx context.Context, accessToken, idToken string) (*UserInfo, error) {
raw, err := c.fetchUserinfoRaw(ctx, c.cfg.UserinfoURL, accessToken)
if err != nil {
return nil, fmt.Errorf("Failed to fetch github user info: %w", err)
}
email, err := c.fetchPrimaryEmail(ctx, accessToken)
if err != nil {
return nil, fmt.Errorf("Failed to fetch github user info: %w", err)
}
if email != "" {
raw["email"] = email
}
return normalizeGitHubUserInfo(raw), nil
}
func (c *gitHubClient) fetchPrimaryEmail(ctx context.Context, accessToken string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.UserinfoURL+"/emails", nil)
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", err
}
if resp.StatusCode >= 400 {
return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var emails []map[string]interface{}
if err := json.Unmarshal(body, &emails); err != nil {
return "", fmt.Errorf("parse emails response: %w", err)
}
for _, e := range emails {
primary, _ := e["primary"].(bool)
addr, _ := e["email"].(string)
if primary && addr != "" {
return addr, nil
}
}
// Fall back to the first verified email if no primary is flagged.
for _, e := range emails {
verified, _ := e["verified"].(bool)
addr, _ := e["email"].(string)
if verified && addr != "" {
return addr, nil
}
}
return "", nil
}
// normalizeGitHubUserInfo mirrors GithubOAuthClient.normalize_user_info:
// username comes from "login", nickname from "name", avatar from
// "avatar_url".
func normalizeGitHubUserInfo(raw map[string]interface{}) *UserInfo {
ui := &UserInfo{}
if v, ok := raw["email"].(string); ok {
ui.Email = v
}
if v, ok := raw["login"].(string); ok && v != "" {
ui.Username = v
} else if ui.Email != "" {
if at := strings.IndexByte(ui.Email, '@'); at >= 0 {
ui.Username = ui.Email[:at]
} else {
ui.Username = ui.Email
}
}
if v, ok := raw["name"].(string); ok && v != "" {
ui.Nickname = v
} else {
ui.Nickname = ui.Username
}
if v, ok := raw["avatar_url"].(string); ok {
ui.AvatarURL = v
}
return ui
}

View File

@@ -0,0 +1,107 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package oauth
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// oidcClient is the OIDC flavor: it resolves the authorization /
// token / userinfo URLs from the Issuer's discovery document
// (.well-known/openid-configuration) before delegating the OAuth flow to
// oauthClient.
//
// We do not currently verify or parse the id_token: id_token claims are
// only used as enrichment in the Python implementation, and the /userinfo
// endpoint (already called) returns the same canonical claims authenticated
// via the access_token. Tracked as a follow-up to add full JWKS-based
// id_token verification once a JWT library is vendored.
type oidcClient struct {
*oauthClient
issuer string
}
func newOIDCClient(cfg Config) (*oidcClient, error) {
if cfg.Issuer == "" {
return nil, fmt.Errorf("Missing issuer in configuration.")
}
meta, err := loadOIDCMetadata(cfg.Issuer)
if err != nil {
return nil, fmt.Errorf("Failed to fetch OIDC metadata: %w", err)
}
if meta.Issuer != "" {
cfg.Issuer = meta.Issuer
}
if cfg.AuthorizationURL == "" {
cfg.AuthorizationURL = meta.AuthorizationEndpoint
}
if cfg.TokenURL == "" {
cfg.TokenURL = meta.TokenEndpoint
}
if cfg.UserinfoURL == "" {
cfg.UserinfoURL = meta.UserinfoEndpoint
}
base, err := newOAuthClient(cfg)
if err != nil {
return nil, err
}
return &oidcClient{oauthClient: base, issuer: cfg.Issuer}, nil
}
// oidcMetadata is the subset of fields we use from the discovery document.
type oidcMetadata struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserinfoEndpoint string `json:"userinfo_endpoint"`
JWKSURI string `json:"jwks_uri"`
}
// loadOIDCMetadata is the indirection used to fetch a provider's discovery
// document. Tests override it.
var loadOIDCMetadata = func(issuer string) (*oidcMetadata, error) {
metadataURL := strings.TrimRight(issuer, "/") + "/.well-known/openid-configuration"
ctx, cancel := context.WithTimeout(context.Background(), HTTPRequestTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
meta := &oidcMetadata{}
if err := json.Unmarshal(body, meta); err != nil {
return nil, fmt.Errorf("parse OIDC metadata: %w", err)
}
return meta, nil
}