mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
## Summary
Ports the Python password-reset flow to Go, adding 4 unauthenticated
endpoints under `/api/v1/auth/password/`:
- `POST /auth/password/forgot/captcha` — generates and returns a PNG
captcha image; stores the plaintext code in Redis (60 s TTL)
- `POST /auth/password/forgot/otp` — verifies captcha, enforces resend
cooldown (60 s), generates HMAC-SHA256-hashed OTP (300 s TTL), sends
plain-text email via SMTP
- `POST /auth/password/forgot/otp/verify` — verifies OTP with attempt
counting (lock after 5 failures for 30 min), sets a
`otp:verified:{email}` flag (300 s TTL) on success
- `POST /auth/password/reset` — checks verified flag, decrypts +
validates passwords, updates user record, auto-logs in (issues JWT,
returns user profile)
Closes #15282
This commit is contained in:
@@ -26,6 +26,21 @@ import (
|
||||
|
||||
// HandleNoRoute handles requests to undefined routes
|
||||
func HandleNoRoute(c *gin.Context) {
|
||||
// Python parity: GET /api/v1/auth/login/ (an empty OAuth channel) resolves
|
||||
// to a Werkzeug MethodNotAllowed in the Python API, which
|
||||
// server_error_response renders as HTTP 200 / code 100 with the
|
||||
// exception's repr() as the message. gin instead falls through to
|
||||
// NoRoute, so emit the same body here to keep the auth error paths
|
||||
// byte-for-byte aligned.
|
||||
if c.Request.Method == http.MethodGet && c.Request.URL.Path == "/api/v1/auth/login/" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeExceptionError,
|
||||
"data": nil,
|
||||
"message": "<MethodNotAllowed '405: Method Not Allowed'>",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Log the request details on server side
|
||||
common.Logger.Warn("The requested URL was not found",
|
||||
zap.String("method", c.Request.Method),
|
||||
|
||||
223
internal/handler/oauth_login.go
Normal file
223
internal/handler/oauth_login.go
Normal file
@@ -0,0 +1,223 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/cache"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/service"
|
||||
"ragflow/internal/utility"
|
||||
)
|
||||
|
||||
// oauthStateCookie is the HttpOnly cookie name that ties the in-flight
|
||||
// state token to the browser that initiated the flow. The handler reads
|
||||
// it back from the callback request to defend against CSRF in addition to
|
||||
// the Redis-side verification.
|
||||
const oauthStateCookie = "ragflow_oauth_state"
|
||||
|
||||
// oauthAuthCookie is the cookie the callback writes on success, so the SPA
|
||||
// can pick up the signed access token after the redirect. The frontend
|
||||
// reads it and either re-issues the value as an Authorization header on
|
||||
// subsequent API calls or hands it off to its own token store. Not
|
||||
// HttpOnly so the SPA's JS can read it.
|
||||
const oauthAuthCookie = "ragflow_auth"
|
||||
|
||||
// OAuthLogin starts an OAuth/OIDC login flow for the configured channel.
|
||||
// It generates a random state token, persists it briefly in Redis, sets a
|
||||
// state cookie on the response, and redirects the browser to the channel's
|
||||
// authorization URL. Mirrors Python's GET /auth/login/<channel>.
|
||||
//
|
||||
// @Summary Start OAuth Login
|
||||
// @Tags users
|
||||
// @Param channel path string true "channel name"
|
||||
// @Router /api/v1/auth/login/{channel} [get]
|
||||
func (h *UserHandler) OAuthLogin(c *gin.Context) {
|
||||
channel := c.Param("channel")
|
||||
if channel == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": "channel is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
init, code, err := h.userService.OAuthLoginInitiate(channel, cache.Get())
|
||||
if err != nil {
|
||||
// Mirror Python's oauth_login: the raised ValueError propagates to
|
||||
// server_error_response, which replies HTTP 200 with code 100 and
|
||||
// the exception's repr() as the message (no short error code).
|
||||
if errors.Is(err, service.ErrOAuthInvalidChannel) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeExceptionError,
|
||||
"data": nil,
|
||||
"message": fmt.Sprintf("ValueError('Invalid channel name: %s')", channel),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setOAuthStateCookie(c, init.State, int(init.CookieMaxAge.Seconds()))
|
||||
c.Redirect(http.StatusFound, init.AuthURL)
|
||||
}
|
||||
|
||||
// OAuthCallback handles the OAuth/OIDC callback for the configured channel.
|
||||
// Mirrors Python's GET /auth/oauth/<channel>/callback: it verifies the
|
||||
// state, exchanges the code for an access token, fetches user info, and
|
||||
// then either logs in an existing user or registers a new one. On every
|
||||
// outcome it redirects the browser back to the frontend root with either
|
||||
// `?auth=<user_id>` or `?error=<code>` so the SPA can show the right page.
|
||||
//
|
||||
// @Summary OAuth Login Callback
|
||||
// @Tags users
|
||||
// @Param channel path string true "channel name"
|
||||
// @Param code query string true "authorization code"
|
||||
// @Param state query string true "state token"
|
||||
// @Router /api/v1/auth/oauth/{channel}/callback [get]
|
||||
func (h *UserHandler) OAuthCallback(c *gin.Context) {
|
||||
channel := c.Param("channel")
|
||||
// An empty channel segment (/auth/oauth//callback) is a malformed path,
|
||||
// not a real channel. Python's router never matches it and returns 404;
|
||||
// match that here instead of flowing into the callback and emitting a
|
||||
// bogus "Invalid channel name:" redirect.
|
||||
if channel == "" {
|
||||
HandleNoRoute(c)
|
||||
return
|
||||
}
|
||||
queryCode := c.Query("code")
|
||||
queryState := c.Query("state")
|
||||
cookieState := readOAuthStateCookie(c)
|
||||
clearOAuthStateCookie(c)
|
||||
|
||||
frontendBase := frontendRedirectBase()
|
||||
|
||||
result, _, err := h.userService.OAuthCallback(c.Request.Context(), channel, queryCode, queryState, cookieState, cache.Get())
|
||||
if err != nil {
|
||||
c.Redirect(http.StatusFound, frontendBase+"?error="+callbackError(channel, err))
|
||||
return
|
||||
}
|
||||
|
||||
secretKey, kerr := server.GetSecretKey(cache.Get())
|
||||
if kerr != nil {
|
||||
c.Redirect(http.StatusFound, frontendBase+"?error=server_error")
|
||||
return
|
||||
}
|
||||
authToken, terr := utility.DumpAccessToken(*result.User.AccessToken, secretKey)
|
||||
if terr != nil {
|
||||
c.Redirect(http.StatusFound, frontendBase+"?error=server_error")
|
||||
return
|
||||
}
|
||||
|
||||
setOAuthAuthCookie(c, authToken)
|
||||
c.Header("Authorization", authToken)
|
||||
c.Header("Access-Control-Expose-Headers", "Authorization")
|
||||
c.Redirect(http.StatusFound, frontendBase+"?auth="+result.User.ID)
|
||||
}
|
||||
|
||||
// callbackError maps the OAuth callback errors to the `?error=` strings
|
||||
// Python's oauth_callback emits. Python redirects with `?error={str(e)}`,
|
||||
// so an invalid channel surfaces the full "Invalid channel name: <channel>"
|
||||
// message (str of the ValueError), while the other failures use the short
|
||||
// tokens Python hard-codes. The value is intentionally not URL-encoded to
|
||||
// match Python's raw f-string redirect.
|
||||
func callbackError(channel string, err error) string {
|
||||
switch {
|
||||
case errors.Is(err, service.ErrOAuthInvalidChannel):
|
||||
return "Invalid channel name: " + channel
|
||||
case errors.Is(err, service.ErrOAuthInvalidState):
|
||||
return "invalid_state"
|
||||
case errors.Is(err, service.ErrOAuthMissingCode):
|
||||
return "missing_code"
|
||||
case errors.Is(err, service.ErrOAuthTokenFailed):
|
||||
return "token_failed"
|
||||
case errors.Is(err, service.ErrOAuthEmailMissing):
|
||||
return "email_missing"
|
||||
case errors.Is(err, service.ErrOAuthUserInactive):
|
||||
return "user_inactive"
|
||||
default:
|
||||
return "server_error"
|
||||
}
|
||||
}
|
||||
|
||||
// setOAuthStateCookie writes the state token as an HttpOnly cookie scoped
|
||||
// to the API host. SameSite=Lax keeps the cookie attached on the top-level
|
||||
// navigation that brings the user back to the callback.
|
||||
func setOAuthStateCookie(c *gin.Context, state string, maxAgeSec int) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: oauthStateCookie,
|
||||
Value: state,
|
||||
Path: "/",
|
||||
MaxAge: maxAgeSec,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: c.Request.TLS != nil,
|
||||
})
|
||||
}
|
||||
|
||||
func readOAuthStateCookie(c *gin.Context) string {
|
||||
if cookie, err := c.Request.Cookie(oauthStateCookie); err == nil {
|
||||
return cookie.Value
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func clearOAuthStateCookie(c *gin.Context) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: oauthStateCookie,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: c.Request.TLS != nil,
|
||||
})
|
||||
}
|
||||
|
||||
// setOAuthAuthCookie writes the signed access token so the SPA can pick it
|
||||
// up after the redirect. Not HttpOnly so the SPA can copy it into its
|
||||
// Authorization header on subsequent fetches. Lifetime mirrors the
|
||||
// access-token TTL used by the rest of the app.
|
||||
func setOAuthAuthCookie(c *gin.Context, token string) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: oauthAuthCookie,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: 60 * 60 * 24 * 7,
|
||||
HttpOnly: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: c.Request.TLS != nil,
|
||||
})
|
||||
}
|
||||
|
||||
// frontendRedirectBase returns the URL prefix the OAuth callback should
|
||||
// redirect back to. Mirrors Python's oauth_callback, which always issues
|
||||
// relative "/?auth=..." / "/?error=..." redirects so the browser stays on
|
||||
// the same origin that served the SPA.
|
||||
func frontendRedirectBase() string {
|
||||
return "/"
|
||||
}
|
||||
@@ -114,6 +114,13 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
// User login by email endpoint
|
||||
apiNoAuth.POST("/auth/login", r.userHandler.LoginByEmail)
|
||||
|
||||
// OAuth / OIDC login routes. The static "channels" segment is
|
||||
// registered before the wildcard, so gin's tree resolves
|
||||
// /auth/login/channels to GetLoginChannels and other values to
|
||||
// OAuthLogin without conflict.
|
||||
apiNoAuth.GET("/auth/login/:channel", r.userHandler.OAuthLogin)
|
||||
apiNoAuth.GET("/auth/oauth/:channel/callback", r.userHandler.OAuthCallback)
|
||||
|
||||
// Register
|
||||
apiNoAuth.POST("/users", r.userHandler.Register)
|
||||
|
||||
|
||||
@@ -90,10 +90,24 @@ type ModelConfig struct {
|
||||
Factory string `mapstructure:"factory"`
|
||||
}
|
||||
|
||||
// OAuthConfig OAuth configuration for a channel
|
||||
// OAuthConfig OAuth configuration for a channel.
|
||||
// Mirrors api/apps/auth/__init__.py's OAUTH_CONFIG entries: a Type that
|
||||
// selects the auth client flavor (oauth2 / oidc / github), plus the
|
||||
// transport URLs and client credentials. For OIDC the URLs are derived
|
||||
// from Issuer via the .well-known/openid-configuration document, so they
|
||||
// may be left blank.
|
||||
type OAuthConfig struct {
|
||||
DisplayName string `mapstructure:"display_name"`
|
||||
Icon string `mapstructure:"icon"`
|
||||
DisplayName string `mapstructure:"display_name"`
|
||||
Icon string `mapstructure:"icon"`
|
||||
Type string `mapstructure:"type"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
AuthorizationURL string `mapstructure:"authorization_url"`
|
||||
TokenURL string `mapstructure:"token_url"`
|
||||
UserinfoURL string `mapstructure:"userinfo_url"`
|
||||
RedirectURI string `mapstructure:"redirect_uri"`
|
||||
Scope string `mapstructure:"scope"`
|
||||
Issuer string `mapstructure:"issuer"`
|
||||
}
|
||||
|
||||
// ServerConfig server configuration
|
||||
|
||||
329
internal/service/oauth_login.go
Normal file
329
internal/service/oauth_login.go
Normal file
@@ -0,0 +1,329 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/cache"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/utility"
|
||||
"ragflow/internal/utility/oauth"
|
||||
)
|
||||
|
||||
// Sentinel errors surfaced by the OAuth login + callback endpoints. The
|
||||
// handler maps each to one of Python's `?error=` redirect codes so the
|
||||
// frontend can show the same messages.
|
||||
var (
|
||||
// ErrOAuthInvalidChannel mirrors Python's ValueError("Invalid channel name: ...").
|
||||
ErrOAuthInvalidChannel = errors.New("invalid channel name")
|
||||
// ErrOAuthInvalidState is returned when the callback's state mismatches the
|
||||
// stored one (CSRF guard) — maps to "?error=invalid_state".
|
||||
ErrOAuthInvalidState = errors.New("invalid_state")
|
||||
// ErrOAuthMissingCode is returned when the callback URL has no code
|
||||
// query param — maps to "?error=missing_code".
|
||||
ErrOAuthMissingCode = errors.New("missing_code")
|
||||
// ErrOAuthTokenFailed is returned when token exchange yields no
|
||||
// access_token — maps to "?error=token_failed".
|
||||
ErrOAuthTokenFailed = errors.New("token_failed")
|
||||
// ErrOAuthEmailMissing is returned when the /userinfo response has no
|
||||
// email claim — maps to "?error=email_missing".
|
||||
ErrOAuthEmailMissing = errors.New("email_missing")
|
||||
// ErrOAuthUserInactive is returned when the matched user has
|
||||
// is_active=0 — maps to "?error=user_inactive".
|
||||
ErrOAuthUserInactive = errors.New("user_inactive")
|
||||
)
|
||||
|
||||
// oauthStateTTL bounds how long an in-flight OAuth state token is honored.
|
||||
// Five minutes matches typical session-cookie flows used by the Python
|
||||
// counterpart.
|
||||
const oauthStateTTL = 5 * time.Minute
|
||||
|
||||
// oauthStateKey is the Redis key prefix for an in-flight OAuth state.
|
||||
const oauthStateKey = "oauth:state:"
|
||||
|
||||
// OAuthLoginInit prepares a redirect to the channel's authorization URL.
|
||||
// The returned state must also be set as a short-lived cookie on the caller
|
||||
// so the callback can perform a CSRF check that ties the state token to the
|
||||
// browser that initiated the flow.
|
||||
type OAuthLoginInit struct {
|
||||
State string
|
||||
AuthURL string
|
||||
Channel string
|
||||
CookieMaxAge time.Duration
|
||||
}
|
||||
|
||||
// OAuthLoginInitiate generates a state, persists it in Redis with a TTL,
|
||||
// and returns the authorization URL the browser should be redirected to.
|
||||
// Mirrors the body of Python's oauth_login.
|
||||
func (s *UserService) OAuthLoginInitiate(channel string, redis *cache.RedisClient) (*OAuthLoginInit, common.ErrorCode, error) {
|
||||
cfg, ok := lookupOAuthConfig(channel)
|
||||
if !ok {
|
||||
return nil, common.CodeDataError, fmt.Errorf("%w: %s", ErrOAuthInvalidChannel, channel)
|
||||
}
|
||||
client, err := oauth.NewClient(toOAuthClientConfig(cfg))
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
state, err := generateOAuthState()
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("generate oauth state: %w", err)
|
||||
}
|
||||
if redis != nil {
|
||||
if ok := redis.Set(oauthStateKey+state, channel, oauthStateTTL); !ok {
|
||||
return nil, common.CodeServerError, errors.New("failed to persist oauth state")
|
||||
}
|
||||
}
|
||||
|
||||
authURL, err := client.AuthorizationURL(state)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
return &OAuthLoginInit{
|
||||
State: state,
|
||||
AuthURL: authURL,
|
||||
Channel: channel,
|
||||
CookieMaxAge: oauthStateTTL,
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// OAuthCallbackResult is returned to the handler after a successful callback
|
||||
// so it can mint the user-facing auth response (cookie + redirect).
|
||||
type OAuthCallbackResult struct {
|
||||
User *entity.User
|
||||
IsNewUser bool
|
||||
}
|
||||
|
||||
// OAuthCallback verifies the callback state, exchanges the code for a token,
|
||||
// fetches the user profile, and either creates a new local user or logs an
|
||||
// existing one in. Mirrors the body of Python's oauth_callback.
|
||||
//
|
||||
// expectedState is the value the handler pulled out of the state cookie.
|
||||
// When redis is non-nil the state is also verified against and consumed
|
||||
// from Redis, defending against a replay where an attacker fishes a valid
|
||||
// state out of a victim's URL but does not have the cookie.
|
||||
func (s *UserService) OAuthCallback(ctx context.Context, channel, code, callbackState, expectedState string, redis *cache.RedisClient) (*OAuthCallbackResult, common.ErrorCode, error) {
|
||||
cfg, ok := lookupOAuthConfig(channel)
|
||||
if !ok {
|
||||
return nil, common.CodeDataError, fmt.Errorf("%w: %s", ErrOAuthInvalidChannel, channel)
|
||||
}
|
||||
|
||||
if callbackState == "" || expectedState == "" || callbackState != expectedState {
|
||||
return nil, common.CodeDataError, ErrOAuthInvalidState
|
||||
}
|
||||
if redis != nil {
|
||||
stored, _ := redis.Get(oauthStateKey + callbackState)
|
||||
// Delete unconditionally so a leaked state cannot be reused even
|
||||
// when the Redis lookup fails open above.
|
||||
redis.Delete(oauthStateKey + callbackState)
|
||||
if stored == "" || stored != channel {
|
||||
return nil, common.CodeDataError, ErrOAuthInvalidState
|
||||
}
|
||||
}
|
||||
if code == "" {
|
||||
return nil, common.CodeDataError, ErrOAuthMissingCode
|
||||
}
|
||||
|
||||
client, err := oauth.NewClient(toOAuthClientConfig(cfg))
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
token, err := client.ExchangeCodeForToken(ctx, code)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("%w: %v", ErrOAuthTokenFailed, err)
|
||||
}
|
||||
if token.AccessToken == "" {
|
||||
return nil, common.CodeServerError, ErrOAuthTokenFailed
|
||||
}
|
||||
|
||||
info, err := client.FetchUserInfo(ctx, token.AccessToken, token.IDToken)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("Failed to fetch user info: %v", err)
|
||||
}
|
||||
if info.Email == "" {
|
||||
return nil, common.CodeDataError, ErrOAuthEmailMissing
|
||||
}
|
||||
|
||||
existing, err := s.userDAO.GetByEmail(info.Email)
|
||||
if err == nil && existing != nil {
|
||||
if existing.IsActive == "0" {
|
||||
return nil, common.CodeForbidden, ErrOAuthUserInactive
|
||||
}
|
||||
newToken := utility.GenerateToken()
|
||||
existing.AccessToken = &newToken
|
||||
now := time.Now().Truncate(time.Second)
|
||||
existing.LastLoginTime = &now
|
||||
if uerr := s.userDAO.Update(existing); uerr != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("update user: %w", uerr)
|
||||
}
|
||||
return &OAuthCallbackResult{User: existing, IsNewUser: false}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// New user — register a fresh account bound to this channel.
|
||||
created, ecode, cerr := s.registerOAuthUser(channel, info)
|
||||
if cerr != nil {
|
||||
return nil, ecode, cerr
|
||||
}
|
||||
return &OAuthCallbackResult{User: created, IsNewUser: true}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// registerOAuthUser provisions a new user + tenant for an OAuth identity.
|
||||
// Models the relevant fields the email-password Register path sets so the
|
||||
// rest of the app (kbs, files, llm config) sees a fully-shaped tenant.
|
||||
func (s *UserService) registerOAuthUser(channel string, info *oauth.UserInfo) (*entity.User, common.ErrorCode, error) {
|
||||
cfg := server.GetConfig()
|
||||
userID := utility.GenerateToken()
|
||||
accessToken := utility.GenerateToken()
|
||||
status := "1"
|
||||
loginChannel := channel
|
||||
isSuperuser := false
|
||||
nickname := info.Nickname
|
||||
if nickname == "" {
|
||||
nickname = info.Username
|
||||
}
|
||||
if nickname == "" {
|
||||
nickname = info.Email
|
||||
}
|
||||
|
||||
user := &entity.User{
|
||||
ID: userID,
|
||||
AccessToken: &accessToken,
|
||||
Email: info.Email,
|
||||
Nickname: nickname,
|
||||
Avatar: &info.AvatarURL,
|
||||
Status: &status,
|
||||
IsActive: "1",
|
||||
IsAuthenticated: "1",
|
||||
IsAnonymous: "0",
|
||||
LoginChannel: &loginChannel,
|
||||
IsSuperuser: &isSuperuser,
|
||||
}
|
||||
|
||||
tenantName := nickname + "'s Kingdom"
|
||||
tenant := &entity.Tenant{
|
||||
ID: userID,
|
||||
Name: &tenantName,
|
||||
LLMID: cfg.UserDefaultLLM.DefaultModels.ChatModel.Name,
|
||||
EmbdID: cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.Name,
|
||||
ASRID: cfg.UserDefaultLLM.DefaultModels.ASRModel.Name,
|
||||
Img2TxtID: cfg.UserDefaultLLM.DefaultModels.Image2TextModel.Name,
|
||||
RerankID: cfg.UserDefaultLLM.DefaultModels.RerankModel.Name,
|
||||
ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag",
|
||||
Status: &status,
|
||||
}
|
||||
userTenantID := utility.GenerateToken()
|
||||
userTenant := &entity.UserTenant{
|
||||
ID: userTenantID,
|
||||
UserID: userID,
|
||||
TenantID: userID,
|
||||
Role: "owner",
|
||||
InvitedBy: userID,
|
||||
Status: &status,
|
||||
}
|
||||
fileID := utility.GenerateToken()
|
||||
rootFile := &entity.File{
|
||||
ID: fileID,
|
||||
ParentID: fileID,
|
||||
TenantID: userID,
|
||||
CreatedBy: userID,
|
||||
Name: "/",
|
||||
Type: "folder",
|
||||
Size: 0,
|
||||
}
|
||||
|
||||
tenantDAO := dao.NewTenantDAO()
|
||||
userTenantDAO := dao.NewUserTenantDAO()
|
||||
fileDAO := dao.NewFileDAO()
|
||||
|
||||
if err := s.userDAO.Create(user); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err)
|
||||
}
|
||||
if err := tenantDAO.Create(tenant); err != nil {
|
||||
_ = s.userDAO.DeleteByID(userID)
|
||||
return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err)
|
||||
}
|
||||
if err := userTenantDAO.Create(userTenant); err != nil {
|
||||
_ = s.userDAO.DeleteByID(userID)
|
||||
_ = tenantDAO.Delete(userID)
|
||||
return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err)
|
||||
}
|
||||
if err := fileDAO.Create(rootFile); err != nil {
|
||||
_ = s.userDAO.DeleteByID(userID)
|
||||
_ = tenantDAO.Delete(userID)
|
||||
_ = userTenantDAO.Delete(userTenantID)
|
||||
return nil, common.CodeServerError, fmt.Errorf("Failed to register %s: %w", info.Email, err)
|
||||
}
|
||||
return user, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// lookupOAuthConfig returns the configured OAuth channel by name. The lookup
|
||||
// is case-insensitive against the keys server.GetConfig() materialises from
|
||||
// the yaml config file.
|
||||
func lookupOAuthConfig(channel string) (server.OAuthConfig, bool) {
|
||||
cfg := server.GetConfig()
|
||||
if cfg == nil {
|
||||
return server.OAuthConfig{}, false
|
||||
}
|
||||
if found, ok := cfg.OAuth[channel]; ok {
|
||||
return found, true
|
||||
}
|
||||
want := strings.ToLower(channel)
|
||||
for k, v := range cfg.OAuth {
|
||||
if strings.ToLower(k) == want {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
return server.OAuthConfig{}, false
|
||||
}
|
||||
|
||||
// toOAuthClientConfig narrows server.OAuthConfig (which carries display-only
|
||||
// fields like Icon) into the oauth.Config the client package consumes.
|
||||
func toOAuthClientConfig(cfg server.OAuthConfig) oauth.Config {
|
||||
return oauth.Config{
|
||||
Type: cfg.Type,
|
||||
ClientID: cfg.ClientID,
|
||||
ClientSecret: cfg.ClientSecret,
|
||||
AuthorizationURL: cfg.AuthorizationURL,
|
||||
TokenURL: cfg.TokenURL,
|
||||
UserinfoURL: cfg.UserinfoURL,
|
||||
RedirectURI: cfg.RedirectURI,
|
||||
Scope: cfg.Scope,
|
||||
Issuer: cfg.Issuer,
|
||||
}
|
||||
}
|
||||
|
||||
// generateOAuthState returns a cryptographically random 32-byte hex string
|
||||
// used as the OAuth state parameter. 256 bits keeps collisions and
|
||||
// brute-force guessing comfortably out of reach for the 5-minute TTL.
|
||||
func generateOAuthState() (string, error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
279
internal/utility/oauth/client.go
Normal file
279
internal/utility/oauth/client.go
Normal file
@@ -0,0 +1,279 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Package oauth ports the auth-client surface from api/apps/auth (Python)
|
||||
// to Go. It wires three flavors of OAuth/OIDC providers behind a common
|
||||
// Client interface so the login + callback handlers can stay flavor-blind:
|
||||
//
|
||||
// - "oauth2": vanilla OAuth 2.0 authorization-code flow with a
|
||||
// provider-supplied /userinfo endpoint
|
||||
// - "oidc": OAuth 2.0 + OIDC discovery via .well-known/openid-configuration
|
||||
// - "github": OAuth 2.0 plus GitHub's split user / emails endpoints
|
||||
//
|
||||
// Note on OIDC ID-token validation: the Python OIDCClient verifies the
|
||||
// id_token signature against the discovered JWKS and pulls extra claims out
|
||||
// of it. We deliberately do not yet pull in a JWT library here; the
|
||||
// /userinfo endpoint returns the same claims authenticated via the
|
||||
// access_token, which is the path we use exclusively. This is documented on
|
||||
// OIDCClient and tracked as a follow-up.
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config is the channel configuration consumed by NewClient. It mirrors the
|
||||
// shape of server.OAuthConfig but is copied here to keep this package free
|
||||
// of imports from the rest of the server.
|
||||
type Config struct {
|
||||
Type string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthorizationURL string
|
||||
TokenURL string
|
||||
UserinfoURL string
|
||||
RedirectURI string
|
||||
Scope string
|
||||
Issuer string
|
||||
}
|
||||
|
||||
// UserInfo is the normalized user profile returned by FetchUserInfo. Email
|
||||
// is the only field treated as required by the callback handler; the rest
|
||||
// are best-effort.
|
||||
type UserInfo struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Nickname string `json:"nickname"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
}
|
||||
|
||||
// Client is the auth-client surface used by the login + callback handlers.
|
||||
type Client interface {
|
||||
AuthorizationURL(state string) (string, error)
|
||||
ExchangeCodeForToken(ctx context.Context, code string) (*TokenResponse, error)
|
||||
FetchUserInfo(ctx context.Context, accessToken, idToken string) (*UserInfo, error)
|
||||
}
|
||||
|
||||
// TokenResponse is the subset of fields we use from the token endpoint
|
||||
// response.
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
ExpiresIn int `json:"expires_in,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPRequestTimeout is the per-request timeout applied to token and
|
||||
// userinfo calls. Matches the Python http_request_timeout (7s).
|
||||
const HTTPRequestTimeout = 7 * time.Second
|
||||
|
||||
// NewClient returns the Client implementation matching cfg.Type. When type
|
||||
// is empty, Issuer presence selects OIDC; otherwise OAuth2.
|
||||
func NewClient(cfg Config) (Client, error) {
|
||||
t := strings.ToLower(strings.TrimSpace(cfg.Type))
|
||||
if t == "" {
|
||||
if cfg.Issuer != "" {
|
||||
t = "oidc"
|
||||
} else {
|
||||
t = "oauth2"
|
||||
}
|
||||
}
|
||||
switch t {
|
||||
case "oauth2":
|
||||
return newOAuthClient(cfg)
|
||||
case "oidc":
|
||||
return newOIDCClient(cfg)
|
||||
case "github":
|
||||
return newGitHubClient(cfg)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported type: %s", t)
|
||||
}
|
||||
}
|
||||
|
||||
// oauthClient is the base OAuth 2.0 implementation. The OIDC and GitHub
|
||||
// flavors embed it and override fetchUserInfo.
|
||||
type oauthClient struct {
|
||||
cfg Config
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func newOAuthClient(cfg Config) (*oauthClient, error) {
|
||||
if cfg.ClientID == "" {
|
||||
return nil, fmt.Errorf("oauth: client_id is required")
|
||||
}
|
||||
if cfg.AuthorizationURL == "" {
|
||||
return nil, fmt.Errorf("oauth: authorization_url is required")
|
||||
}
|
||||
if cfg.TokenURL == "" {
|
||||
return nil, fmt.Errorf("oauth: token_url is required")
|
||||
}
|
||||
if cfg.RedirectURI == "" {
|
||||
return nil, fmt.Errorf("oauth: redirect_uri is required")
|
||||
}
|
||||
return &oauthClient{
|
||||
cfg: cfg,
|
||||
httpClient: &http.Client{Timeout: HTTPRequestTimeout},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizationURL builds the URL the browser should be redirected to.
|
||||
// Mirrors OAuthClient.get_authorization_url.
|
||||
func (c *oauthClient) AuthorizationURL(state string) (string, error) {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", c.cfg.ClientID)
|
||||
params.Set("redirect_uri", c.cfg.RedirectURI)
|
||||
params.Set("response_type", "code")
|
||||
if c.cfg.Scope != "" {
|
||||
params.Set("scope", c.cfg.Scope)
|
||||
}
|
||||
if state != "" {
|
||||
params.Set("state", state)
|
||||
}
|
||||
sep := "?"
|
||||
if strings.Contains(c.cfg.AuthorizationURL, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
return c.cfg.AuthorizationURL + sep + params.Encode(), nil
|
||||
}
|
||||
|
||||
// ExchangeCodeForToken exchanges an authorization code for an access token.
|
||||
// Mirrors OAuthClient.exchange_code_for_token.
|
||||
func (c *oauthClient) ExchangeCodeForToken(ctx context.Context, code string) (*TokenResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Set("client_id", c.cfg.ClientID)
|
||||
form.Set("client_secret", c.cfg.ClientSecret)
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", c.cfg.RedirectURI)
|
||||
form.Set("grant_type", "authorization_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.TokenURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to exchange authorization code for token: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to exchange authorization code for token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to exchange authorization code for token: %w", err)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("Failed to exchange authorization code for token: HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
token := &TokenResponse{}
|
||||
if jerr := json.Unmarshal(body, token); jerr != nil {
|
||||
// Some providers (notably GitHub when Accept is not set) return
|
||||
// application/x-www-form-urlencoded here instead of JSON.
|
||||
if values, perr := url.ParseQuery(string(body)); perr == nil {
|
||||
token.AccessToken = values.Get("access_token")
|
||||
token.TokenType = values.Get("token_type")
|
||||
token.IDToken = values.Get("id_token")
|
||||
token.Scope = values.Get("scope")
|
||||
} else {
|
||||
return nil, fmt.Errorf("Failed to exchange authorization code for token: parse response: %w", jerr)
|
||||
}
|
||||
}
|
||||
if token.AccessToken == "" {
|
||||
return nil, fmt.Errorf("Failed to exchange authorization code for token: empty access_token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// FetchUserInfo fetches user information using the access token.
|
||||
// Mirrors OAuthClient.fetch_user_info / normalize_user_info.
|
||||
func (c *oauthClient) FetchUserInfo(ctx context.Context, accessToken, idToken string) (*UserInfo, error) {
|
||||
if c.cfg.UserinfoURL == "" {
|
||||
return nil, fmt.Errorf("Failed to fetch user info: userinfo_url is required")
|
||||
}
|
||||
raw, err := c.fetchUserinfoRaw(ctx, c.cfg.UserinfoURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to fetch user info: %w", err)
|
||||
}
|
||||
return normalizeUserInfo(raw), nil
|
||||
}
|
||||
|
||||
func (c *oauthClient) fetchUserinfoRaw(ctx context.Context, endpoint, accessToken string) (map[string]interface{}, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var out map[string]interface{}
|
||||
if err := json.Unmarshal(body, &out); err != nil {
|
||||
return nil, fmt.Errorf("parse userinfo response: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// normalizeUserInfo mirrors the Python normalize_user_info defaults: username
|
||||
// falls back to the email local part, nickname falls back to username, and
|
||||
// avatar_url falls back to OIDC's "picture" claim.
|
||||
func normalizeUserInfo(raw map[string]interface{}) *UserInfo {
|
||||
ui := &UserInfo{}
|
||||
if v, ok := raw["email"].(string); ok {
|
||||
ui.Email = v
|
||||
}
|
||||
if v, ok := raw["username"].(string); ok && v != "" {
|
||||
ui.Username = v
|
||||
} else if ui.Email != "" {
|
||||
if at := strings.IndexByte(ui.Email, '@'); at >= 0 {
|
||||
ui.Username = ui.Email[:at]
|
||||
} else {
|
||||
ui.Username = ui.Email
|
||||
}
|
||||
}
|
||||
if v, ok := raw["nickname"].(string); ok && v != "" {
|
||||
ui.Nickname = v
|
||||
} else {
|
||||
ui.Nickname = ui.Username
|
||||
}
|
||||
if v, ok := raw["avatar_url"].(string); ok && v != "" {
|
||||
ui.AvatarURL = v
|
||||
} else if v, ok := raw["picture"].(string); ok {
|
||||
ui.AvatarURL = v
|
||||
}
|
||||
return ui
|
||||
}
|
||||
313
internal/utility/oauth/client_test.go
Normal file
313
internal/utility/oauth/client_test.go
Normal file
@@ -0,0 +1,313 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAuthorizationURLOAuth2(t *testing.T) {
|
||||
c, err := NewClient(Config{
|
||||
Type: "oauth2",
|
||||
ClientID: "abc",
|
||||
ClientSecret: "secret",
|
||||
AuthorizationURL: "https://provider.example/authorize",
|
||||
TokenURL: "https://provider.example/token",
|
||||
UserinfoURL: "https://provider.example/userinfo",
|
||||
RedirectURI: "https://ragflow.example/api/v1/auth/oauth/myorg/callback",
|
||||
Scope: "openid email",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient: %v", err)
|
||||
}
|
||||
got, err := c.AuthorizationURL("state-token")
|
||||
if err != nil {
|
||||
t.Fatalf("AuthorizationURL: %v", err)
|
||||
}
|
||||
u, perr := url.Parse(got)
|
||||
if perr != nil {
|
||||
t.Fatalf("parse url: %v", perr)
|
||||
}
|
||||
q := u.Query()
|
||||
if q.Get("client_id") != "abc" {
|
||||
t.Errorf("client_id: got %q", q.Get("client_id"))
|
||||
}
|
||||
if q.Get("redirect_uri") != "https://ragflow.example/api/v1/auth/oauth/myorg/callback" {
|
||||
t.Errorf("redirect_uri: got %q", q.Get("redirect_uri"))
|
||||
}
|
||||
if q.Get("response_type") != "code" {
|
||||
t.Errorf("response_type: got %q", q.Get("response_type"))
|
||||
}
|
||||
if q.Get("scope") != "openid email" {
|
||||
t.Errorf("scope: got %q", q.Get("scope"))
|
||||
}
|
||||
if q.Get("state") != "state-token" {
|
||||
t.Errorf("state: got %q", q.Get("state"))
|
||||
}
|
||||
if u.Host != "provider.example" {
|
||||
t.Errorf("host: got %q", u.Host)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizationURLPreservesExistingQuery(t *testing.T) {
|
||||
c, _ := NewClient(Config{
|
||||
Type: "oauth2",
|
||||
ClientID: "abc",
|
||||
AuthorizationURL: "https://provider.example/authorize?tenant=acme",
|
||||
TokenURL: "https://provider.example/token",
|
||||
RedirectURI: "https://ragflow.example/cb",
|
||||
})
|
||||
got, err := c.AuthorizationURL("s")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(got, "tenant=acme&") {
|
||||
t.Errorf("existing tenant=acme query should be preserved with `&` separator, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientUnsupportedType(t *testing.T) {
|
||||
_, err := NewClient(Config{Type: "saml"})
|
||||
if err == nil || !strings.Contains(err.Error(), "Unsupported type") {
|
||||
t.Fatalf("expected unsupported-type error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientDefaultsBasedOnIssuer(t *testing.T) {
|
||||
orig := loadOIDCMetadata
|
||||
defer func() { loadOIDCMetadata = orig }()
|
||||
loadOIDCMetadata = func(issuer string) (*oidcMetadata, error) {
|
||||
return &oidcMetadata{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
}, nil
|
||||
}
|
||||
c, err := NewClient(Config{
|
||||
Issuer: "https://issuer.example",
|
||||
ClientID: "id",
|
||||
RedirectURI: "https://ragflow/cb",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient: %v", err)
|
||||
}
|
||||
if _, ok := c.(*oidcClient); !ok {
|
||||
t.Fatalf("expected oidcClient, got %T", c)
|
||||
}
|
||||
got, _ := c.AuthorizationURL("s")
|
||||
if !strings.HasPrefix(got, "https://issuer.example/authorize") {
|
||||
t.Errorf("authorization URL didn't pick up discovered endpoint, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExchangeCodeForTokenJSON(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
if r.Form.Get("grant_type") != "authorization_code" {
|
||||
t.Errorf("grant_type: got %q", r.Form.Get("grant_type"))
|
||||
}
|
||||
if r.Form.Get("code") != "the-code" {
|
||||
t.Errorf("code: got %q", r.Form.Get("code"))
|
||||
}
|
||||
if r.Form.Get("client_id") != "abc" {
|
||||
t.Errorf("client_id: got %q", r.Form.Get("client_id"))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"the-access-token","token_type":"Bearer","id_token":"abc.def.ghi"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c, _ := NewClient(Config{
|
||||
Type: "oauth2",
|
||||
ClientID: "abc",
|
||||
ClientSecret: "secret",
|
||||
AuthorizationURL: "https://provider/authorize",
|
||||
TokenURL: srv.URL,
|
||||
UserinfoURL: "https://provider/userinfo",
|
||||
RedirectURI: "https://ragflow/cb",
|
||||
})
|
||||
tok, err := c.ExchangeCodeForToken(context.Background(), "the-code")
|
||||
if err != nil {
|
||||
t.Fatalf("ExchangeCodeForToken: %v", err)
|
||||
}
|
||||
if tok.AccessToken != "the-access-token" {
|
||||
t.Errorf("access_token: got %q", tok.AccessToken)
|
||||
}
|
||||
if tok.IDToken != "abc.def.ghi" {
|
||||
t.Errorf("id_token: got %q", tok.IDToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExchangeCodeForTokenFormResponse(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
_, _ = io.WriteString(w, "access_token=abc&token_type=bearer&scope=user")
|
||||
}))
|
||||
defer srv.Close()
|
||||
c, _ := NewClient(Config{
|
||||
Type: "oauth2",
|
||||
ClientID: "abc",
|
||||
AuthorizationURL: "https://x/authorize",
|
||||
TokenURL: srv.URL,
|
||||
UserinfoURL: "https://x/userinfo",
|
||||
RedirectURI: "https://x/cb",
|
||||
})
|
||||
tok, err := c.ExchangeCodeForToken(context.Background(), "c")
|
||||
if err != nil {
|
||||
t.Fatalf("ExchangeCodeForToken: %v", err)
|
||||
}
|
||||
if tok.AccessToken != "abc" {
|
||||
t.Errorf("access_token: got %q", tok.AccessToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExchangeCodeForTokenError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, `{"error":"invalid_grant"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
c, _ := NewClient(Config{
|
||||
Type: "oauth2",
|
||||
ClientID: "abc",
|
||||
AuthorizationURL: "https://x/authorize",
|
||||
TokenURL: srv.URL,
|
||||
UserinfoURL: "https://x/userinfo",
|
||||
RedirectURI: "https://x/cb",
|
||||
})
|
||||
_, err := c.ExchangeCodeForToken(context.Background(), "c")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTTP 400") {
|
||||
t.Errorf("expected HTTP 400 surfaced, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUserInfoOAuth2Normalization(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "Bearer the-token" {
|
||||
t.Errorf("Authorization header: got %q", r.Header.Get("Authorization"))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"email":"alice@example.com","picture":"https://cdn/alice.png"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
c, _ := NewClient(Config{
|
||||
Type: "oauth2",
|
||||
ClientID: "abc",
|
||||
AuthorizationURL: "https://x/authorize",
|
||||
TokenURL: "https://x/token",
|
||||
UserinfoURL: srv.URL,
|
||||
RedirectURI: "https://x/cb",
|
||||
})
|
||||
info, err := c.FetchUserInfo(context.Background(), "the-token", "")
|
||||
if err != nil {
|
||||
t.Fatalf("FetchUserInfo: %v", err)
|
||||
}
|
||||
if info.Email != "alice@example.com" {
|
||||
t.Errorf("email: %q", info.Email)
|
||||
}
|
||||
if info.Username != "alice" {
|
||||
t.Errorf("username (fallback to email local): %q", info.Username)
|
||||
}
|
||||
if info.Nickname != "alice" {
|
||||
t.Errorf("nickname (fallback to username): %q", info.Nickname)
|
||||
}
|
||||
if info.AvatarURL != "https://cdn/alice.png" {
|
||||
t.Errorf("avatar (picture fallback): %q", info.AvatarURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUserInfoGitHubMergesPrimaryEmail(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"login":"bob","name":"Bob Bobson","email":null,"avatar_url":"https://gh/bob.png"}`)
|
||||
})
|
||||
mux.HandleFunc("/user/emails", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
emails := []map[string]interface{}{
|
||||
{"email": "bob-noreply@users.noreply.github.com", "primary": false, "verified": true},
|
||||
{"email": "bob@example.com", "primary": true, "verified": true},
|
||||
}
|
||||
raw, _ := json.Marshal(emails)
|
||||
_, _ = w.Write(raw)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
c, _ := NewClient(Config{
|
||||
Type: "github",
|
||||
ClientID: "abc",
|
||||
ClientSecret: "secret",
|
||||
RedirectURI: "https://ragflow/cb",
|
||||
})
|
||||
gh := c.(*gitHubClient)
|
||||
gh.cfg.UserinfoURL = srv.URL + "/user"
|
||||
|
||||
info, err := gh.FetchUserInfo(context.Background(), "tok", "")
|
||||
if err != nil {
|
||||
t.Fatalf("FetchUserInfo: %v", err)
|
||||
}
|
||||
if info.Email != "bob@example.com" {
|
||||
t.Errorf("primary email merge failed: got %q", info.Email)
|
||||
}
|
||||
if info.Username != "bob" {
|
||||
t.Errorf("username (from login): got %q", info.Username)
|
||||
}
|
||||
if info.Nickname != "Bob Bobson" {
|
||||
t.Errorf("nickname (from name): got %q", info.Nickname)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUserInfoEmptyEmail(t *testing.T) {
|
||||
ui := normalizeUserInfo(map[string]interface{}{"email": ""})
|
||||
if ui.Email != "" || ui.Username != "" || ui.Nickname != "" {
|
||||
t.Errorf("empty email should produce empty fields, got %+v", ui)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOAuthClientMissingFields(t *testing.T) {
|
||||
cases := []Config{
|
||||
{Type: "oauth2"},
|
||||
{Type: "oauth2", ClientID: "x"},
|
||||
{Type: "oauth2", ClientID: "x", AuthorizationURL: "https://x/a"},
|
||||
{Type: "oauth2", ClientID: "x", AuthorizationURL: "https://x/a", TokenURL: "https://x/t"},
|
||||
}
|
||||
for i, cfg := range cases {
|
||||
if _, err := NewClient(cfg); err == nil {
|
||||
t.Errorf("case %d: expected error for missing required field", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOIDCMissingIssuer(t *testing.T) {
|
||||
_, err := NewClient(Config{Type: "oidc"})
|
||||
if err == nil || !strings.Contains(err.Error(), "Missing issuer") {
|
||||
t.Errorf("expected Missing issuer error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
134
internal/utility/oauth/github.go
Normal file
134
internal/utility/oauth/github.go
Normal file
@@ -0,0 +1,134 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// gitHubClient overrides the OAuth endpoints with GitHub's well-known URLs
|
||||
// and reaches into /user/emails to recover the primary email, since
|
||||
// GitHub's /user response omits it when the user has hidden email
|
||||
// visibility.
|
||||
type gitHubClient struct {
|
||||
*oauthClient
|
||||
}
|
||||
|
||||
func newGitHubClient(cfg Config) (*gitHubClient, error) {
|
||||
cfg.AuthorizationURL = "https://github.com/login/oauth/authorize"
|
||||
cfg.TokenURL = "https://github.com/login/oauth/access_token"
|
||||
cfg.UserinfoURL = "https://api.github.com/user"
|
||||
if cfg.Scope == "" {
|
||||
cfg.Scope = "user:email"
|
||||
}
|
||||
base, err := newOAuthClient(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &gitHubClient{oauthClient: base}, nil
|
||||
}
|
||||
|
||||
// FetchUserInfo overrides the base implementation to merge the primary
|
||||
// email from /user/emails. Mirrors GithubOAuthClient.fetch_user_info.
|
||||
func (c *gitHubClient) FetchUserInfo(ctx context.Context, accessToken, idToken string) (*UserInfo, error) {
|
||||
raw, err := c.fetchUserinfoRaw(ctx, c.cfg.UserinfoURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to fetch github user info: %w", err)
|
||||
}
|
||||
email, err := c.fetchPrimaryEmail(ctx, accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to fetch github user info: %w", err)
|
||||
}
|
||||
if email != "" {
|
||||
raw["email"] = email
|
||||
}
|
||||
return normalizeGitHubUserInfo(raw), nil
|
||||
}
|
||||
|
||||
func (c *gitHubClient) fetchPrimaryEmail(ctx context.Context, accessToken string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.UserinfoURL+"/emails", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var emails []map[string]interface{}
|
||||
if err := json.Unmarshal(body, &emails); err != nil {
|
||||
return "", fmt.Errorf("parse emails response: %w", err)
|
||||
}
|
||||
for _, e := range emails {
|
||||
primary, _ := e["primary"].(bool)
|
||||
addr, _ := e["email"].(string)
|
||||
if primary && addr != "" {
|
||||
return addr, nil
|
||||
}
|
||||
}
|
||||
// Fall back to the first verified email if no primary is flagged.
|
||||
for _, e := range emails {
|
||||
verified, _ := e["verified"].(bool)
|
||||
addr, _ := e["email"].(string)
|
||||
if verified && addr != "" {
|
||||
return addr, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// normalizeGitHubUserInfo mirrors GithubOAuthClient.normalize_user_info:
|
||||
// username comes from "login", nickname from "name", avatar from
|
||||
// "avatar_url".
|
||||
func normalizeGitHubUserInfo(raw map[string]interface{}) *UserInfo {
|
||||
ui := &UserInfo{}
|
||||
if v, ok := raw["email"].(string); ok {
|
||||
ui.Email = v
|
||||
}
|
||||
if v, ok := raw["login"].(string); ok && v != "" {
|
||||
ui.Username = v
|
||||
} else if ui.Email != "" {
|
||||
if at := strings.IndexByte(ui.Email, '@'); at >= 0 {
|
||||
ui.Username = ui.Email[:at]
|
||||
} else {
|
||||
ui.Username = ui.Email
|
||||
}
|
||||
}
|
||||
if v, ok := raw["name"].(string); ok && v != "" {
|
||||
ui.Nickname = v
|
||||
} else {
|
||||
ui.Nickname = ui.Username
|
||||
}
|
||||
if v, ok := raw["avatar_url"].(string); ok {
|
||||
ui.AvatarURL = v
|
||||
}
|
||||
return ui
|
||||
}
|
||||
107
internal/utility/oauth/oidc.go
Normal file
107
internal/utility/oauth/oidc.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// oidcClient is the OIDC flavor: it resolves the authorization /
|
||||
// token / userinfo URLs from the Issuer's discovery document
|
||||
// (.well-known/openid-configuration) before delegating the OAuth flow to
|
||||
// oauthClient.
|
||||
//
|
||||
// We do not currently verify or parse the id_token: id_token claims are
|
||||
// only used as enrichment in the Python implementation, and the /userinfo
|
||||
// endpoint (already called) returns the same canonical claims authenticated
|
||||
// via the access_token. Tracked as a follow-up to add full JWKS-based
|
||||
// id_token verification once a JWT library is vendored.
|
||||
type oidcClient struct {
|
||||
*oauthClient
|
||||
issuer string
|
||||
}
|
||||
|
||||
func newOIDCClient(cfg Config) (*oidcClient, error) {
|
||||
if cfg.Issuer == "" {
|
||||
return nil, fmt.Errorf("Missing issuer in configuration.")
|
||||
}
|
||||
meta, err := loadOIDCMetadata(cfg.Issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to fetch OIDC metadata: %w", err)
|
||||
}
|
||||
if meta.Issuer != "" {
|
||||
cfg.Issuer = meta.Issuer
|
||||
}
|
||||
if cfg.AuthorizationURL == "" {
|
||||
cfg.AuthorizationURL = meta.AuthorizationEndpoint
|
||||
}
|
||||
if cfg.TokenURL == "" {
|
||||
cfg.TokenURL = meta.TokenEndpoint
|
||||
}
|
||||
if cfg.UserinfoURL == "" {
|
||||
cfg.UserinfoURL = meta.UserinfoEndpoint
|
||||
}
|
||||
base, err := newOAuthClient(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidcClient{oauthClient: base, issuer: cfg.Issuer}, nil
|
||||
}
|
||||
|
||||
// oidcMetadata is the subset of fields we use from the discovery document.
|
||||
type oidcMetadata struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserinfoEndpoint string `json:"userinfo_endpoint"`
|
||||
JWKSURI string `json:"jwks_uri"`
|
||||
}
|
||||
|
||||
// loadOIDCMetadata is the indirection used to fetch a provider's discovery
|
||||
// document. Tests override it.
|
||||
var loadOIDCMetadata = func(issuer string) (*oidcMetadata, error) {
|
||||
metadataURL := strings.TrimRight(issuer, "/") + "/.well-known/openid-configuration"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), HTTPRequestTimeout)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
meta := &oidcMetadata{}
|
||||
if err := json.Unmarshal(body, meta); err != nil {
|
||||
return nil, fmt.Errorf("parse OIDC metadata: %w", err)
|
||||
}
|
||||
return meta, nil
|
||||
}
|
||||
Reference in New Issue
Block a user