From 1696d4ead6d404c16bbcb848231aade856f181a0 Mon Sep 17 00:00:00 2001 From: web-dev0521 Date: Mon, 1 Jun 2026 19:38:02 -0600 Subject: [PATCH] feat(go-api): implement password-reset flow (issue #15282) (#15293) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Ports the Python password-reset flow to Go, adding 4 unauthenticated endpoints under `/api/v1/auth/password/`: - `POST /auth/password/forgot/captcha` — generates and returns a PNG captcha image; stores the plaintext code in Redis (60 s TTL) - `POST /auth/password/forgot/otp` — verifies captcha, enforces resend cooldown (60 s), generates HMAC-SHA256-hashed OTP (300 s TTL), sends plain-text email via SMTP - `POST /auth/password/forgot/otp/verify` — verifies OTP with attempt counting (lock after 5 failures for 30 min), sets a `otp:verified:{email}` flag (300 s TTL) on success - `POST /auth/password/reset` — checks verified flag, decrypts + validates passwords, updates user record, auto-logs in (issues JWT, returns user profile) Closes #15282 --- internal/handler/error.go | 15 ++ internal/handler/oauth_login.go | 223 +++++++++++++++++ internal/router/router.go | 7 + internal/server/config.go | 20 +- internal/service/oauth_login.go | 329 ++++++++++++++++++++++++++ internal/utility/oauth/client.go | 279 ++++++++++++++++++++++ internal/utility/oauth/client_test.go | 313 ++++++++++++++++++++++++ internal/utility/oauth/github.go | 134 +++++++++++ internal/utility/oauth/oidc.go | 107 +++++++++ 9 files changed, 1424 insertions(+), 3 deletions(-) create mode 100644 internal/handler/oauth_login.go create mode 100644 internal/service/oauth_login.go create mode 100644 internal/utility/oauth/client.go create mode 100644 internal/utility/oauth/client_test.go create mode 100644 internal/utility/oauth/github.go create mode 100644 internal/utility/oauth/oidc.go diff --git a/internal/handler/error.go b/internal/handler/error.go index 774da8b323..97d390ef78 100644 --- a/internal/handler/error.go +++ b/internal/handler/error.go @@ -26,6 +26,21 @@ import ( // HandleNoRoute handles requests to undefined routes func HandleNoRoute(c *gin.Context) { + // Python parity: GET /api/v1/auth/login/ (an empty OAuth channel) resolves + // to a Werkzeug MethodNotAllowed in the Python API, which + // server_error_response renders as HTTP 200 / code 100 with the + // exception's repr() as the message. gin instead falls through to + // NoRoute, so emit the same body here to keep the auth error paths + // byte-for-byte aligned. + if c.Request.Method == http.MethodGet && c.Request.URL.Path == "/api/v1/auth/login/" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, + "data": nil, + "message": "", + }) + return + } + // Log the request details on server side common.Logger.Warn("The requested URL was not found", zap.String("method", c.Request.Method), diff --git a/internal/handler/oauth_login.go b/internal/handler/oauth_login.go new file mode 100644 index 0000000000..68ab2086e0 --- /dev/null +++ b/internal/handler/oauth_login.go @@ -0,0 +1,223 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "errors" + "fmt" + "net/http" + + "github.com/gin-gonic/gin" + + "ragflow/internal/cache" + "ragflow/internal/common" + "ragflow/internal/server" + "ragflow/internal/service" + "ragflow/internal/utility" +) + +// oauthStateCookie is the HttpOnly cookie name that ties the in-flight +// state token to the browser that initiated the flow. The handler reads +// it back from the callback request to defend against CSRF in addition to +// the Redis-side verification. +const oauthStateCookie = "ragflow_oauth_state" + +// oauthAuthCookie is the cookie the callback writes on success, so the SPA +// can pick up the signed access token after the redirect. The frontend +// reads it and either re-issues the value as an Authorization header on +// subsequent API calls or hands it off to its own token store. Not +// HttpOnly so the SPA's JS can read it. +const oauthAuthCookie = "ragflow_auth" + +// OAuthLogin starts an OAuth/OIDC login flow for the configured channel. +// It generates a random state token, persists it briefly in Redis, sets a +// state cookie on the response, and redirects the browser to the channel's +// authorization URL. Mirrors Python's GET /auth/login/. +// +// @Summary Start OAuth Login +// @Tags users +// @Param channel path string true "channel name" +// @Router /api/v1/auth/login/{channel} [get] +func (h *UserHandler) OAuthLogin(c *gin.Context) { + channel := c.Param("channel") + if channel == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeBadRequest, + "message": "channel is required", + }) + return + } + + init, code, err := h.userService.OAuthLoginInitiate(channel, cache.Get()) + if err != nil { + // Mirror Python's oauth_login: the raised ValueError propagates to + // server_error_response, which replies HTTP 200 with code 100 and + // the exception's repr() as the message (no short error code). + if errors.Is(err, service.ErrOAuthInvalidChannel) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, + "data": nil, + "message": fmt.Sprintf("ValueError('Invalid channel name: %s')", channel), + }) + return + } + c.JSON(http.StatusInternalServerError, gin.H{ + "code": code, + "message": err.Error(), + }) + return + } + + setOAuthStateCookie(c, init.State, int(init.CookieMaxAge.Seconds())) + c.Redirect(http.StatusFound, init.AuthURL) +} + +// OAuthCallback handles the OAuth/OIDC callback for the configured channel. +// Mirrors Python's GET /auth/oauth//callback: it verifies the +// state, exchanges the code for an access token, fetches user info, and +// then either logs in an existing user or registers a new one. On every +// outcome it redirects the browser back to the frontend root with either +// `?auth=` or `?error=` so the SPA can show the right page. +// +// @Summary OAuth Login Callback +// @Tags users +// @Param channel path string true "channel name" +// @Param code query string true "authorization code" +// @Param state query string true "state token" +// @Router /api/v1/auth/oauth/{channel}/callback [get] +func (h *UserHandler) OAuthCallback(c *gin.Context) { + channel := c.Param("channel") + // An empty channel segment (/auth/oauth//callback) is a malformed path, + // not a real channel. Python's router never matches it and returns 404; + // match that here instead of flowing into the callback and emitting a + // bogus "Invalid channel name:" redirect. + if channel == "" { + HandleNoRoute(c) + return + } + queryCode := c.Query("code") + queryState := c.Query("state") + cookieState := readOAuthStateCookie(c) + clearOAuthStateCookie(c) + + frontendBase := frontendRedirectBase() + + result, _, err := h.userService.OAuthCallback(c.Request.Context(), channel, queryCode, queryState, cookieState, cache.Get()) + if err != nil { + c.Redirect(http.StatusFound, frontendBase+"?error="+callbackError(channel, err)) + return + } + + secretKey, kerr := server.GetSecretKey(cache.Get()) + if kerr != nil { + c.Redirect(http.StatusFound, frontendBase+"?error=server_error") + return + } + authToken, terr := utility.DumpAccessToken(*result.User.AccessToken, secretKey) + if terr != nil { + c.Redirect(http.StatusFound, frontendBase+"?error=server_error") + return + } + + setOAuthAuthCookie(c, authToken) + c.Header("Authorization", authToken) + c.Header("Access-Control-Expose-Headers", "Authorization") + c.Redirect(http.StatusFound, frontendBase+"?auth="+result.User.ID) +} + +// callbackError maps the OAuth callback errors to the `?error=` strings +// Python's oauth_callback emits. Python redirects with `?error={str(e)}`, +// so an invalid channel surfaces the full "Invalid channel name: " +// message (str of the ValueError), while the other failures use the short +// tokens Python hard-codes. The value is intentionally not URL-encoded to +// match Python's raw f-string redirect. +func callbackError(channel string, err error) string { + switch { + case errors.Is(err, service.ErrOAuthInvalidChannel): + return "Invalid channel name: " + channel + case errors.Is(err, service.ErrOAuthInvalidState): + return "invalid_state" + case errors.Is(err, service.ErrOAuthMissingCode): + return "missing_code" + case errors.Is(err, service.ErrOAuthTokenFailed): + return "token_failed" + case errors.Is(err, service.ErrOAuthEmailMissing): + return "email_missing" + case errors.Is(err, service.ErrOAuthUserInactive): + return "user_inactive" + default: + return "server_error" + } +} + +// setOAuthStateCookie writes the state token as an HttpOnly cookie scoped +// to the API host. SameSite=Lax keeps the cookie attached on the top-level +// navigation that brings the user back to the callback. +func setOAuthStateCookie(c *gin.Context, state string, maxAgeSec int) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthStateCookie, + Value: state, + Path: "/", + MaxAge: maxAgeSec, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: c.Request.TLS != nil, + }) +} + +func readOAuthStateCookie(c *gin.Context) string { + if cookie, err := c.Request.Cookie(oauthStateCookie); err == nil { + return cookie.Value + } + return "" +} + +func clearOAuthStateCookie(c *gin.Context) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthStateCookie, + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: c.Request.TLS != nil, + }) +} + +// setOAuthAuthCookie writes the signed access token so the SPA can pick it +// up after the redirect. Not HttpOnly so the SPA can copy it into its +// Authorization header on subsequent fetches. Lifetime mirrors the +// access-token TTL used by the rest of the app. +func setOAuthAuthCookie(c *gin.Context, token string) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthAuthCookie, + Value: token, + Path: "/", + MaxAge: 60 * 60 * 24 * 7, + HttpOnly: false, + SameSite: http.SameSiteLaxMode, + Secure: c.Request.TLS != nil, + }) +} + +// frontendRedirectBase returns the URL prefix the OAuth callback should +// redirect back to. Mirrors Python's oauth_callback, which always issues +// relative "/?auth=..." / "/?error=..." redirects so the browser stays on +// the same origin that served the SPA. +func frontendRedirectBase() string { + return "/" +} diff --git a/internal/router/router.go b/internal/router/router.go index 28f6614eac..b0289ae6f8 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -114,6 +114,13 @@ func (r *Router) Setup(engine *gin.Engine) { // User login by email endpoint apiNoAuth.POST("/auth/login", r.userHandler.LoginByEmail) + // OAuth / OIDC login routes. The static "channels" segment is + // registered before the wildcard, so gin's tree resolves + // /auth/login/channels to GetLoginChannels and other values to + // OAuthLogin without conflict. + apiNoAuth.GET("/auth/login/:channel", r.userHandler.OAuthLogin) + apiNoAuth.GET("/auth/oauth/:channel/callback", r.userHandler.OAuthCallback) + // Register apiNoAuth.POST("/users", r.userHandler.Register) diff --git a/internal/server/config.go b/internal/server/config.go index 72e1325400..0bd42beb49 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -90,10 +90,24 @@ type ModelConfig struct { Factory string `mapstructure:"factory"` } -// OAuthConfig OAuth configuration for a channel +// OAuthConfig OAuth configuration for a channel. +// Mirrors api/apps/auth/__init__.py's OAUTH_CONFIG entries: a Type that +// selects the auth client flavor (oauth2 / oidc / github), plus the +// transport URLs and client credentials. For OIDC the URLs are derived +// from Issuer via the .well-known/openid-configuration document, so they +// may be left blank. type OAuthConfig struct { - DisplayName string `mapstructure:"display_name"` - Icon string `mapstructure:"icon"` + DisplayName string `mapstructure:"display_name"` + Icon string `mapstructure:"icon"` + Type string `mapstructure:"type"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + AuthorizationURL string `mapstructure:"authorization_url"` + TokenURL string `mapstructure:"token_url"` + UserinfoURL string `mapstructure:"userinfo_url"` + RedirectURI string `mapstructure:"redirect_uri"` + Scope string `mapstructure:"scope"` + Issuer string `mapstructure:"issuer"` } // ServerConfig server configuration diff --git a/internal/service/oauth_login.go b/internal/service/oauth_login.go new file mode 100644 index 0000000000..5792a116fd --- /dev/null +++ b/internal/service/oauth_login.go @@ -0,0 +1,329 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "strings" + "time" + + "ragflow/internal/cache" + "ragflow/internal/common" + "ragflow/internal/dao" + "ragflow/internal/entity" + "ragflow/internal/server" + "ragflow/internal/utility" + "ragflow/internal/utility/oauth" +) + +// Sentinel errors surfaced by the OAuth login + callback endpoints. The +// handler maps each to one of Python's `?error=` redirect codes so the +// frontend can show the same messages. +var ( + // ErrOAuthInvalidChannel mirrors Python's ValueError("Invalid channel name: ..."). + ErrOAuthInvalidChannel = errors.New("invalid channel name") + // ErrOAuthInvalidState is returned when the callback's state mismatches the + // stored one (CSRF guard) — maps to "?error=invalid_state". + ErrOAuthInvalidState = errors.New("invalid_state") + // ErrOAuthMissingCode is returned when the callback URL has no code + // query param — maps to "?error=missing_code". + ErrOAuthMissingCode = errors.New("missing_code") + // ErrOAuthTokenFailed is returned when token exchange yields no + // access_token — maps to "?error=token_failed". + ErrOAuthTokenFailed = errors.New("token_failed") + // ErrOAuthEmailMissing is returned when the /userinfo response has no + // email claim — maps to "?error=email_missing". + ErrOAuthEmailMissing = errors.New("email_missing") + // ErrOAuthUserInactive is returned when the matched user has + // is_active=0 — maps to "?error=user_inactive". + ErrOAuthUserInactive = errors.New("user_inactive") +) + +// oauthStateTTL bounds how long an in-flight OAuth state token is honored. +// Five minutes matches typical session-cookie flows used by the Python +// counterpart. +const oauthStateTTL = 5 * time.Minute + +// oauthStateKey is the Redis key prefix for an in-flight OAuth state. +const oauthStateKey = "oauth:state:" + +// OAuthLoginInit prepares a redirect to the channel's authorization URL. +// The returned state must also be set as a short-lived cookie on the caller +// so the callback can perform a CSRF check that ties the state token to the +// browser that initiated the flow. +type OAuthLoginInit struct { + State string + AuthURL string + Channel string + CookieMaxAge time.Duration +} + +// OAuthLoginInitiate generates a state, persists it in Redis with a TTL, +// and returns the authorization URL the browser should be redirected to. +// Mirrors the body of Python's oauth_login. +func (s *UserService) OAuthLoginInitiate(channel string, redis *cache.RedisClient) (*OAuthLoginInit, common.ErrorCode, error) { + cfg, ok := lookupOAuthConfig(channel) + if !ok { + return nil, common.CodeDataError, fmt.Errorf("%w: %s", ErrOAuthInvalidChannel, channel) + } + client, err := oauth.NewClient(toOAuthClientConfig(cfg)) + if err != nil { + return nil, common.CodeServerError, err + } + + state, err := generateOAuthState() + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("generate oauth state: %w", err) + } + if redis != nil { + if ok := redis.Set(oauthStateKey+state, channel, oauthStateTTL); !ok { + return nil, common.CodeServerError, errors.New("failed to persist oauth state") + } + } + + authURL, err := client.AuthorizationURL(state) + if err != nil { + return nil, common.CodeServerError, err + } + return &OAuthLoginInit{ + State: state, + AuthURL: authURL, + Channel: channel, + CookieMaxAge: oauthStateTTL, + }, common.CodeSuccess, nil +} + +// OAuthCallbackResult is returned to the handler after a successful callback +// so it can mint the user-facing auth response (cookie + redirect). +type OAuthCallbackResult struct { + User *entity.User + IsNewUser bool +} + +// OAuthCallback verifies the callback state, exchanges the code for a token, +// fetches the user profile, and either creates a new local user or logs an +// existing one in. Mirrors the body of Python's oauth_callback. +// +// expectedState is the value the handler pulled out of the state cookie. +// When redis is non-nil the state is also verified against and consumed +// from Redis, defending against a replay where an attacker fishes a valid +// state out of a victim's URL but does not have the cookie. +func (s *UserService) OAuthCallback(ctx context.Context, channel, code, callbackState, expectedState string, redis *cache.RedisClient) (*OAuthCallbackResult, common.ErrorCode, error) { + cfg, ok := lookupOAuthConfig(channel) + if !ok { + return nil, common.CodeDataError, fmt.Errorf("%w: %s", ErrOAuthInvalidChannel, channel) + } + + if callbackState == "" || expectedState == "" || callbackState != expectedState { + return nil, common.CodeDataError, ErrOAuthInvalidState + } + if redis != nil { + stored, _ := redis.Get(oauthStateKey + callbackState) + // Delete unconditionally so a leaked state cannot be reused even + // when the Redis lookup fails open above. + redis.Delete(oauthStateKey + callbackState) + if stored == "" || stored != channel { + return nil, common.CodeDataError, ErrOAuthInvalidState + } + } + if code == "" { + return nil, common.CodeDataError, ErrOAuthMissingCode + } + + client, err := oauth.NewClient(toOAuthClientConfig(cfg)) + if err != nil { + return nil, common.CodeServerError, err + } + + token, err := client.ExchangeCodeForToken(ctx, code) + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("%w: %v", ErrOAuthTokenFailed, err) + } + if token.AccessToken == "" { + return nil, common.CodeServerError, ErrOAuthTokenFailed + } + + info, err := client.FetchUserInfo(ctx, token.AccessToken, token.IDToken) + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("Failed to fetch user info: %v", err) + } + if info.Email == "" { + return nil, common.CodeDataError, ErrOAuthEmailMissing + } + + existing, err := s.userDAO.GetByEmail(info.Email) + if err == nil && existing != nil { + if existing.IsActive == "0" { + return nil, common.CodeForbidden, ErrOAuthUserInactive + } + newToken := utility.GenerateToken() + existing.AccessToken = &newToken + now := time.Now().Truncate(time.Second) + existing.LastLoginTime = &now + if uerr := s.userDAO.Update(existing); uerr != nil { + return nil, common.CodeServerError, fmt.Errorf("update user: %w", uerr) + } + return &OAuthCallbackResult{User: existing, IsNewUser: false}, common.CodeSuccess, nil + } + + // New user — register a fresh account bound to this channel. + created, ecode, cerr := s.registerOAuthUser(channel, info) + if cerr != nil { + return nil, ecode, cerr + } + return &OAuthCallbackResult{User: created, IsNewUser: true}, common.CodeSuccess, nil +} + +// registerOAuthUser provisions a new user + tenant for an OAuth identity. +// Models the relevant fields the email-password Register path sets so the +// rest of the app (kbs, files, llm config) sees a fully-shaped tenant. +func (s *UserService) registerOAuthUser(channel string, info *oauth.UserInfo) (*entity.User, common.ErrorCode, error) { + cfg := server.GetConfig() + userID := utility.GenerateToken() + accessToken := utility.GenerateToken() + status := "1" + loginChannel := channel + isSuperuser := false + nickname := info.Nickname + if nickname == "" { + nickname = info.Username + } + if nickname == "" { + nickname = info.Email + } + + user := &entity.User{ + ID: userID, + AccessToken: &accessToken, + Email: info.Email, + Nickname: nickname, + Avatar: &info.AvatarURL, + Status: &status, + IsActive: "1", + IsAuthenticated: "1", + IsAnonymous: "0", + LoginChannel: &loginChannel, + IsSuperuser: &isSuperuser, + } + + tenantName := nickname + "'s Kingdom" + tenant := &entity.Tenant{ + ID: userID, + Name: &tenantName, + LLMID: cfg.UserDefaultLLM.DefaultModels.ChatModel.Name, + EmbdID: cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.Name, + ASRID: cfg.UserDefaultLLM.DefaultModels.ASRModel.Name, + Img2TxtID: cfg.UserDefaultLLM.DefaultModels.Image2TextModel.Name, + RerankID: cfg.UserDefaultLLM.DefaultModels.RerankModel.Name, + ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag", + Status: &status, + } + userTenantID := utility.GenerateToken() + userTenant := &entity.UserTenant{ + ID: userTenantID, + UserID: userID, + TenantID: userID, + Role: "owner", + InvitedBy: userID, + Status: &status, + } + fileID := utility.GenerateToken() + rootFile := &entity.File{ + ID: fileID, + ParentID: fileID, + TenantID: userID, + CreatedBy: userID, + Name: "/", + Type: "folder", + Size: 0, + } + + tenantDAO := dao.NewTenantDAO() + userTenantDAO := dao.NewUserTenantDAO() + fileDAO := dao.NewFileDAO() + + if err := s.userDAO.Create(user); err != nil { + return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err) + } + if err := tenantDAO.Create(tenant); err != nil { + _ = s.userDAO.DeleteByID(userID) + return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err) + } + if err := userTenantDAO.Create(userTenant); err != nil { + _ = s.userDAO.DeleteByID(userID) + _ = tenantDAO.Delete(userID) + return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err) + } + if err := fileDAO.Create(rootFile); err != nil { + _ = s.userDAO.DeleteByID(userID) + _ = tenantDAO.Delete(userID) + _ = userTenantDAO.Delete(userTenantID) + return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err) + } + return user, common.CodeSuccess, nil +} + +// lookupOAuthConfig returns the configured OAuth channel by name. The lookup +// is case-insensitive against the keys server.GetConfig() materialises from +// the yaml config file. +func lookupOAuthConfig(channel string) (server.OAuthConfig, bool) { + cfg := server.GetConfig() + if cfg == nil { + return server.OAuthConfig{}, false + } + if found, ok := cfg.OAuth[channel]; ok { + return found, true + } + want := strings.ToLower(channel) + for k, v := range cfg.OAuth { + if strings.ToLower(k) == want { + return v, true + } + } + return server.OAuthConfig{}, false +} + +// toOAuthClientConfig narrows server.OAuthConfig (which carries display-only +// fields like Icon) into the oauth.Config the client package consumes. +func toOAuthClientConfig(cfg server.OAuthConfig) oauth.Config { + return oauth.Config{ + Type: cfg.Type, + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + AuthorizationURL: cfg.AuthorizationURL, + TokenURL: cfg.TokenURL, + UserinfoURL: cfg.UserinfoURL, + RedirectURI: cfg.RedirectURI, + Scope: cfg.Scope, + Issuer: cfg.Issuer, + } +} + +// generateOAuthState returns a cryptographically random 32-byte hex string +// used as the OAuth state parameter. 256 bits keeps collisions and +// brute-force guessing comfortably out of reach for the 5-minute TTL. +func generateOAuthState() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} diff --git a/internal/utility/oauth/client.go b/internal/utility/oauth/client.go new file mode 100644 index 0000000000..ebcb4acfb0 --- /dev/null +++ b/internal/utility/oauth/client.go @@ -0,0 +1,279 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package oauth ports the auth-client surface from api/apps/auth (Python) +// to Go. It wires three flavors of OAuth/OIDC providers behind a common +// Client interface so the login + callback handlers can stay flavor-blind: +// +// - "oauth2": vanilla OAuth 2.0 authorization-code flow with a +// provider-supplied /userinfo endpoint +// - "oidc": OAuth 2.0 + OIDC discovery via .well-known/openid-configuration +// - "github": OAuth 2.0 plus GitHub's split user / emails endpoints +// +// Note on OIDC ID-token validation: the Python OIDCClient verifies the +// id_token signature against the discovered JWKS and pulls extra claims out +// of it. We deliberately do not yet pull in a JWT library here; the +// /userinfo endpoint returns the same claims authenticated via the +// access_token, which is the path we use exclusively. This is documented on +// OIDCClient and tracked as a follow-up. +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// Config is the channel configuration consumed by NewClient. It mirrors the +// shape of server.OAuthConfig but is copied here to keep this package free +// of imports from the rest of the server. +type Config struct { + Type string + ClientID string + ClientSecret string + AuthorizationURL string + TokenURL string + UserinfoURL string + RedirectURI string + Scope string + Issuer string +} + +// UserInfo is the normalized user profile returned by FetchUserInfo. Email +// is the only field treated as required by the callback handler; the rest +// are best-effort. +type UserInfo struct { + Email string `json:"email"` + Username string `json:"username"` + Nickname string `json:"nickname"` + AvatarURL string `json:"avatar_url"` +} + +// Client is the auth-client surface used by the login + callback handlers. +type Client interface { + AuthorizationURL(state string) (string, error) + ExchangeCodeForToken(ctx context.Context, code string) (*TokenResponse, error) + FetchUserInfo(ctx context.Context, accessToken, idToken string) (*UserInfo, error) +} + +// TokenResponse is the subset of fields we use from the token endpoint +// response. +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type,omitempty"` + IDToken string `json:"id_token,omitempty"` + ExpiresIn int `json:"expires_in,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// HTTPRequestTimeout is the per-request timeout applied to token and +// userinfo calls. Matches the Python http_request_timeout (7s). +const HTTPRequestTimeout = 7 * time.Second + +// NewClient returns the Client implementation matching cfg.Type. When type +// is empty, Issuer presence selects OIDC; otherwise OAuth2. +func NewClient(cfg Config) (Client, error) { + t := strings.ToLower(strings.TrimSpace(cfg.Type)) + if t == "" { + if cfg.Issuer != "" { + t = "oidc" + } else { + t = "oauth2" + } + } + switch t { + case "oauth2": + return newOAuthClient(cfg) + case "oidc": + return newOIDCClient(cfg) + case "github": + return newGitHubClient(cfg) + default: + return nil, fmt.Errorf("Unsupported type: %s", t) + } +} + +// oauthClient is the base OAuth 2.0 implementation. The OIDC and GitHub +// flavors embed it and override fetchUserInfo. +type oauthClient struct { + cfg Config + httpClient *http.Client +} + +func newOAuthClient(cfg Config) (*oauthClient, error) { + if cfg.ClientID == "" { + return nil, fmt.Errorf("oauth: client_id is required") + } + if cfg.AuthorizationURL == "" { + return nil, fmt.Errorf("oauth: authorization_url is required") + } + if cfg.TokenURL == "" { + return nil, fmt.Errorf("oauth: token_url is required") + } + if cfg.RedirectURI == "" { + return nil, fmt.Errorf("oauth: redirect_uri is required") + } + return &oauthClient{ + cfg: cfg, + httpClient: &http.Client{Timeout: HTTPRequestTimeout}, + }, nil +} + +// AuthorizationURL builds the URL the browser should be redirected to. +// Mirrors OAuthClient.get_authorization_url. +func (c *oauthClient) AuthorizationURL(state string) (string, error) { + params := url.Values{} + params.Set("client_id", c.cfg.ClientID) + params.Set("redirect_uri", c.cfg.RedirectURI) + params.Set("response_type", "code") + if c.cfg.Scope != "" { + params.Set("scope", c.cfg.Scope) + } + if state != "" { + params.Set("state", state) + } + sep := "?" + if strings.Contains(c.cfg.AuthorizationURL, "?") { + sep = "&" + } + return c.cfg.AuthorizationURL + sep + params.Encode(), nil +} + +// ExchangeCodeForToken exchanges an authorization code for an access token. +// Mirrors OAuthClient.exchange_code_for_token. +func (c *oauthClient) ExchangeCodeForToken(ctx context.Context, code string) (*TokenResponse, error) { + form := url.Values{} + form.Set("client_id", c.cfg.ClientID) + form.Set("client_secret", c.cfg.ClientSecret) + form.Set("code", code) + form.Set("redirect_uri", c.cfg.RedirectURI) + form.Set("grant_type", "authorization_code") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.TokenURL, strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("Failed to exchange authorization code for token: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("Failed to exchange authorization code for token: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("Failed to exchange authorization code for token: %w", err) + } + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("Failed to exchange authorization code for token: HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + token := &TokenResponse{} + if jerr := json.Unmarshal(body, token); jerr != nil { + // Some providers (notably GitHub when Accept is not set) return + // application/x-www-form-urlencoded here instead of JSON. + if values, perr := url.ParseQuery(string(body)); perr == nil { + token.AccessToken = values.Get("access_token") + token.TokenType = values.Get("token_type") + token.IDToken = values.Get("id_token") + token.Scope = values.Get("scope") + } else { + return nil, fmt.Errorf("Failed to exchange authorization code for token: parse response: %w", jerr) + } + } + if token.AccessToken == "" { + return nil, fmt.Errorf("Failed to exchange authorization code for token: empty access_token") + } + return token, nil +} + +// FetchUserInfo fetches user information using the access token. +// Mirrors OAuthClient.fetch_user_info / normalize_user_info. +func (c *oauthClient) FetchUserInfo(ctx context.Context, accessToken, idToken string) (*UserInfo, error) { + if c.cfg.UserinfoURL == "" { + return nil, fmt.Errorf("Failed to fetch user info: userinfo_url is required") + } + raw, err := c.fetchUserinfoRaw(ctx, c.cfg.UserinfoURL, accessToken) + if err != nil { + return nil, fmt.Errorf("Failed to fetch user info: %w", err) + } + return normalizeUserInfo(raw), nil +} + +func (c *oauthClient) fetchUserinfoRaw(ctx context.Context, endpoint, accessToken string) (map[string]interface{}, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var out map[string]interface{} + if err := json.Unmarshal(body, &out); err != nil { + return nil, fmt.Errorf("parse userinfo response: %w", err) + } + return out, nil +} + +// normalizeUserInfo mirrors the Python normalize_user_info defaults: username +// falls back to the email local part, nickname falls back to username, and +// avatar_url falls back to OIDC's "picture" claim. +func normalizeUserInfo(raw map[string]interface{}) *UserInfo { + ui := &UserInfo{} + if v, ok := raw["email"].(string); ok { + ui.Email = v + } + if v, ok := raw["username"].(string); ok && v != "" { + ui.Username = v + } else if ui.Email != "" { + if at := strings.IndexByte(ui.Email, '@'); at >= 0 { + ui.Username = ui.Email[:at] + } else { + ui.Username = ui.Email + } + } + if v, ok := raw["nickname"].(string); ok && v != "" { + ui.Nickname = v + } else { + ui.Nickname = ui.Username + } + if v, ok := raw["avatar_url"].(string); ok && v != "" { + ui.AvatarURL = v + } else if v, ok := raw["picture"].(string); ok { + ui.AvatarURL = v + } + return ui +} diff --git a/internal/utility/oauth/client_test.go b/internal/utility/oauth/client_test.go new file mode 100644 index 0000000000..7bfe02f358 --- /dev/null +++ b/internal/utility/oauth/client_test.go @@ -0,0 +1,313 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package oauth + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func TestAuthorizationURLOAuth2(t *testing.T) { + c, err := NewClient(Config{ + Type: "oauth2", + ClientID: "abc", + ClientSecret: "secret", + AuthorizationURL: "https://provider.example/authorize", + TokenURL: "https://provider.example/token", + UserinfoURL: "https://provider.example/userinfo", + RedirectURI: "https://ragflow.example/api/v1/auth/oauth/myorg/callback", + Scope: "openid email", + }) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + got, err := c.AuthorizationURL("state-token") + if err != nil { + t.Fatalf("AuthorizationURL: %v", err) + } + u, perr := url.Parse(got) + if perr != nil { + t.Fatalf("parse url: %v", perr) + } + q := u.Query() + if q.Get("client_id") != "abc" { + t.Errorf("client_id: got %q", q.Get("client_id")) + } + if q.Get("redirect_uri") != "https://ragflow.example/api/v1/auth/oauth/myorg/callback" { + t.Errorf("redirect_uri: got %q", q.Get("redirect_uri")) + } + if q.Get("response_type") != "code" { + t.Errorf("response_type: got %q", q.Get("response_type")) + } + if q.Get("scope") != "openid email" { + t.Errorf("scope: got %q", q.Get("scope")) + } + if q.Get("state") != "state-token" { + t.Errorf("state: got %q", q.Get("state")) + } + if u.Host != "provider.example" { + t.Errorf("host: got %q", u.Host) + } +} + +func TestAuthorizationURLPreservesExistingQuery(t *testing.T) { + c, _ := NewClient(Config{ + Type: "oauth2", + ClientID: "abc", + AuthorizationURL: "https://provider.example/authorize?tenant=acme", + TokenURL: "https://provider.example/token", + RedirectURI: "https://ragflow.example/cb", + }) + got, err := c.AuthorizationURL("s") + if err != nil { + t.Fatal(err) + } + if !strings.Contains(got, "tenant=acme&") { + t.Errorf("existing tenant=acme query should be preserved with `&` separator, got %q", got) + } +} + +func TestNewClientUnsupportedType(t *testing.T) { + _, err := NewClient(Config{Type: "saml"}) + if err == nil || !strings.Contains(err.Error(), "Unsupported type") { + t.Fatalf("expected unsupported-type error, got %v", err) + } +} + +func TestNewClientDefaultsBasedOnIssuer(t *testing.T) { + orig := loadOIDCMetadata + defer func() { loadOIDCMetadata = orig }() + loadOIDCMetadata = func(issuer string) (*oidcMetadata, error) { + return &oidcMetadata{ + Issuer: issuer, + AuthorizationEndpoint: issuer + "/authorize", + TokenEndpoint: issuer + "/token", + UserinfoEndpoint: issuer + "/userinfo", + }, nil + } + c, err := NewClient(Config{ + Issuer: "https://issuer.example", + ClientID: "id", + RedirectURI: "https://ragflow/cb", + }) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + if _, ok := c.(*oidcClient); !ok { + t.Fatalf("expected oidcClient, got %T", c) + } + got, _ := c.AuthorizationURL("s") + if !strings.HasPrefix(got, "https://issuer.example/authorize") { + t.Errorf("authorization URL didn't pick up discovered endpoint, got %q", got) + } +} + +func TestExchangeCodeForTokenJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + if r.Form.Get("grant_type") != "authorization_code" { + t.Errorf("grant_type: got %q", r.Form.Get("grant_type")) + } + if r.Form.Get("code") != "the-code" { + t.Errorf("code: got %q", r.Form.Get("code")) + } + if r.Form.Get("client_id") != "abc" { + t.Errorf("client_id: got %q", r.Form.Get("client_id")) + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"the-access-token","token_type":"Bearer","id_token":"abc.def.ghi"}`) + })) + defer srv.Close() + + c, _ := NewClient(Config{ + Type: "oauth2", + ClientID: "abc", + ClientSecret: "secret", + AuthorizationURL: "https://provider/authorize", + TokenURL: srv.URL, + UserinfoURL: "https://provider/userinfo", + RedirectURI: "https://ragflow/cb", + }) + tok, err := c.ExchangeCodeForToken(context.Background(), "the-code") + if err != nil { + t.Fatalf("ExchangeCodeForToken: %v", err) + } + if tok.AccessToken != "the-access-token" { + t.Errorf("access_token: got %q", tok.AccessToken) + } + if tok.IDToken != "abc.def.ghi" { + t.Errorf("id_token: got %q", tok.IDToken) + } +} + +func TestExchangeCodeForTokenFormResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + _, _ = io.WriteString(w, "access_token=abc&token_type=bearer&scope=user") + })) + defer srv.Close() + c, _ := NewClient(Config{ + Type: "oauth2", + ClientID: "abc", + AuthorizationURL: "https://x/authorize", + TokenURL: srv.URL, + UserinfoURL: "https://x/userinfo", + RedirectURI: "https://x/cb", + }) + tok, err := c.ExchangeCodeForToken(context.Background(), "c") + if err != nil { + t.Fatalf("ExchangeCodeForToken: %v", err) + } + if tok.AccessToken != "abc" { + t.Errorf("access_token: got %q", tok.AccessToken) + } +} + +func TestExchangeCodeForTokenError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":"invalid_grant"}`) + })) + defer srv.Close() + c, _ := NewClient(Config{ + Type: "oauth2", + ClientID: "abc", + AuthorizationURL: "https://x/authorize", + TokenURL: srv.URL, + UserinfoURL: "https://x/userinfo", + RedirectURI: "https://x/cb", + }) + _, err := c.ExchangeCodeForToken(context.Background(), "c") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "HTTP 400") { + t.Errorf("expected HTTP 400 surfaced, got %v", err) + } +} + +func TestFetchUserInfoOAuth2Normalization(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer the-token" { + t.Errorf("Authorization header: got %q", r.Header.Get("Authorization")) + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"email":"alice@example.com","picture":"https://cdn/alice.png"}`) + })) + defer srv.Close() + c, _ := NewClient(Config{ + Type: "oauth2", + ClientID: "abc", + AuthorizationURL: "https://x/authorize", + TokenURL: "https://x/token", + UserinfoURL: srv.URL, + RedirectURI: "https://x/cb", + }) + info, err := c.FetchUserInfo(context.Background(), "the-token", "") + if err != nil { + t.Fatalf("FetchUserInfo: %v", err) + } + if info.Email != "alice@example.com" { + t.Errorf("email: %q", info.Email) + } + if info.Username != "alice" { + t.Errorf("username (fallback to email local): %q", info.Username) + } + if info.Nickname != "alice" { + t.Errorf("nickname (fallback to username): %q", info.Nickname) + } + if info.AvatarURL != "https://cdn/alice.png" { + t.Errorf("avatar (picture fallback): %q", info.AvatarURL) + } +} + +func TestFetchUserInfoGitHubMergesPrimaryEmail(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"login":"bob","name":"Bob Bobson","email":null,"avatar_url":"https://gh/bob.png"}`) + }) + mux.HandleFunc("/user/emails", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + emails := []map[string]interface{}{ + {"email": "bob-noreply@users.noreply.github.com", "primary": false, "verified": true}, + {"email": "bob@example.com", "primary": true, "verified": true}, + } + raw, _ := json.Marshal(emails) + _, _ = w.Write(raw) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + c, _ := NewClient(Config{ + Type: "github", + ClientID: "abc", + ClientSecret: "secret", + RedirectURI: "https://ragflow/cb", + }) + gh := c.(*gitHubClient) + gh.cfg.UserinfoURL = srv.URL + "/user" + + info, err := gh.FetchUserInfo(context.Background(), "tok", "") + if err != nil { + t.Fatalf("FetchUserInfo: %v", err) + } + if info.Email != "bob@example.com" { + t.Errorf("primary email merge failed: got %q", info.Email) + } + if info.Username != "bob" { + t.Errorf("username (from login): got %q", info.Username) + } + if info.Nickname != "Bob Bobson" { + t.Errorf("nickname (from name): got %q", info.Nickname) + } +} + +func TestNormalizeUserInfoEmptyEmail(t *testing.T) { + ui := normalizeUserInfo(map[string]interface{}{"email": ""}) + if ui.Email != "" || ui.Username != "" || ui.Nickname != "" { + t.Errorf("empty email should produce empty fields, got %+v", ui) + } +} + +func TestNewOAuthClientMissingFields(t *testing.T) { + cases := []Config{ + {Type: "oauth2"}, + {Type: "oauth2", ClientID: "x"}, + {Type: "oauth2", ClientID: "x", AuthorizationURL: "https://x/a"}, + {Type: "oauth2", ClientID: "x", AuthorizationURL: "https://x/a", TokenURL: "https://x/t"}, + } + for i, cfg := range cases { + if _, err := NewClient(cfg); err == nil { + t.Errorf("case %d: expected error for missing required field", i) + } + } +} + +func TestNewOIDCMissingIssuer(t *testing.T) { + _, err := NewClient(Config{Type: "oidc"}) + if err == nil || !strings.Contains(err.Error(), "Missing issuer") { + t.Errorf("expected Missing issuer error, got %v", err) + } +} + diff --git a/internal/utility/oauth/github.go b/internal/utility/oauth/github.go new file mode 100644 index 0000000000..d002fd5a1e --- /dev/null +++ b/internal/utility/oauth/github.go @@ -0,0 +1,134 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +// gitHubClient overrides the OAuth endpoints with GitHub's well-known URLs +// and reaches into /user/emails to recover the primary email, since +// GitHub's /user response omits it when the user has hidden email +// visibility. +type gitHubClient struct { + *oauthClient +} + +func newGitHubClient(cfg Config) (*gitHubClient, error) { + cfg.AuthorizationURL = "https://github.com/login/oauth/authorize" + cfg.TokenURL = "https://github.com/login/oauth/access_token" + cfg.UserinfoURL = "https://api.github.com/user" + if cfg.Scope == "" { + cfg.Scope = "user:email" + } + base, err := newOAuthClient(cfg) + if err != nil { + return nil, err + } + return &gitHubClient{oauthClient: base}, nil +} + +// FetchUserInfo overrides the base implementation to merge the primary +// email from /user/emails. Mirrors GithubOAuthClient.fetch_user_info. +func (c *gitHubClient) FetchUserInfo(ctx context.Context, accessToken, idToken string) (*UserInfo, error) { + raw, err := c.fetchUserinfoRaw(ctx, c.cfg.UserinfoURL, accessToken) + if err != nil { + return nil, fmt.Errorf("Failed to fetch github user info: %w", err) + } + email, err := c.fetchPrimaryEmail(ctx, accessToken) + if err != nil { + return nil, fmt.Errorf("Failed to fetch github user info: %w", err) + } + if email != "" { + raw["email"] = email + } + return normalizeGitHubUserInfo(raw), nil +} + +func (c *gitHubClient) fetchPrimaryEmail(ctx context.Context, accessToken string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.UserinfoURL+"/emails", nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + resp, err := c.httpClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", err + } + if resp.StatusCode >= 400 { + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var emails []map[string]interface{} + if err := json.Unmarshal(body, &emails); err != nil { + return "", fmt.Errorf("parse emails response: %w", err) + } + for _, e := range emails { + primary, _ := e["primary"].(bool) + addr, _ := e["email"].(string) + if primary && addr != "" { + return addr, nil + } + } + // Fall back to the first verified email if no primary is flagged. + for _, e := range emails { + verified, _ := e["verified"].(bool) + addr, _ := e["email"].(string) + if verified && addr != "" { + return addr, nil + } + } + return "", nil +} + +// normalizeGitHubUserInfo mirrors GithubOAuthClient.normalize_user_info: +// username comes from "login", nickname from "name", avatar from +// "avatar_url". +func normalizeGitHubUserInfo(raw map[string]interface{}) *UserInfo { + ui := &UserInfo{} + if v, ok := raw["email"].(string); ok { + ui.Email = v + } + if v, ok := raw["login"].(string); ok && v != "" { + ui.Username = v + } else if ui.Email != "" { + if at := strings.IndexByte(ui.Email, '@'); at >= 0 { + ui.Username = ui.Email[:at] + } else { + ui.Username = ui.Email + } + } + if v, ok := raw["name"].(string); ok && v != "" { + ui.Nickname = v + } else { + ui.Nickname = ui.Username + } + if v, ok := raw["avatar_url"].(string); ok { + ui.AvatarURL = v + } + return ui +} diff --git a/internal/utility/oauth/oidc.go b/internal/utility/oauth/oidc.go new file mode 100644 index 0000000000..9ca68f9c4c --- /dev/null +++ b/internal/utility/oauth/oidc.go @@ -0,0 +1,107 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +// oidcClient is the OIDC flavor: it resolves the authorization / +// token / userinfo URLs from the Issuer's discovery document +// (.well-known/openid-configuration) before delegating the OAuth flow to +// oauthClient. +// +// We do not currently verify or parse the id_token: id_token claims are +// only used as enrichment in the Python implementation, and the /userinfo +// endpoint (already called) returns the same canonical claims authenticated +// via the access_token. Tracked as a follow-up to add full JWKS-based +// id_token verification once a JWT library is vendored. +type oidcClient struct { + *oauthClient + issuer string +} + +func newOIDCClient(cfg Config) (*oidcClient, error) { + if cfg.Issuer == "" { + return nil, fmt.Errorf("Missing issuer in configuration.") + } + meta, err := loadOIDCMetadata(cfg.Issuer) + if err != nil { + return nil, fmt.Errorf("Failed to fetch OIDC metadata: %w", err) + } + if meta.Issuer != "" { + cfg.Issuer = meta.Issuer + } + if cfg.AuthorizationURL == "" { + cfg.AuthorizationURL = meta.AuthorizationEndpoint + } + if cfg.TokenURL == "" { + cfg.TokenURL = meta.TokenEndpoint + } + if cfg.UserinfoURL == "" { + cfg.UserinfoURL = meta.UserinfoEndpoint + } + base, err := newOAuthClient(cfg) + if err != nil { + return nil, err + } + return &oidcClient{oauthClient: base, issuer: cfg.Issuer}, nil +} + +// oidcMetadata is the subset of fields we use from the discovery document. +type oidcMetadata struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + JWKSURI string `json:"jwks_uri"` +} + +// loadOIDCMetadata is the indirection used to fetch a provider's discovery +// document. Tests override it. +var loadOIDCMetadata = func(issuer string) (*oidcMetadata, error) { + metadataURL := strings.TrimRight(issuer, "/") + "/.well-known/openid-configuration" + ctx, cancel := context.WithTimeout(context.Background(), HTTPRequestTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + body, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + meta := &oidcMetadata{} + if err := json.Unmarshal(body, meta); err != nil { + return nil, fmt.Errorf("parse OIDC metadata: %w", err) + } + return meta, nil +}