mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat(admin): Implemented default administrator initialization and login functionality. (#13504)
### What problem does this PR solve? feat(admin): Implemented default administrator initialization and login functionality. Added support for default administrator configuration, including super user nickname, email, and password. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -33,21 +33,18 @@ import (
|
||||
|
||||
"ragflow/internal/admin"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/handler"
|
||||
"ragflow/internal/logger"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/service"
|
||||
"ragflow/internal/utility"
|
||||
)
|
||||
|
||||
// AdminServer admin server
|
||||
type AdminServer struct {
|
||||
router *admin.Router
|
||||
handler *admin.Handler
|
||||
service *admin.Service
|
||||
userHandler *handler.UserHandler
|
||||
engine *gin.Engine
|
||||
port string
|
||||
router *admin.Router
|
||||
handler *admin.Handler
|
||||
service *admin.Service
|
||||
engine *gin.Engine
|
||||
port string
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -112,8 +109,12 @@ func main() {
|
||||
}
|
||||
|
||||
adminService := admin.NewService()
|
||||
userService := service.NewUserService()
|
||||
adminHandler := admin.NewHandler(adminService, userService)
|
||||
adminHandler := admin.NewHandler(adminService)
|
||||
|
||||
// Initialize default admin user
|
||||
if err := adminService.InitDefaultAdmin(); err != nil {
|
||||
logger.Error("Failed to initialize default admin user", err)
|
||||
}
|
||||
|
||||
// Initialize router
|
||||
r := admin.NewRouter(adminHandler)
|
||||
|
||||
@@ -18,6 +18,7 @@ package admin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/server"
|
||||
@@ -31,9 +32,7 @@ import (
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
)
|
||||
|
||||
// Handler admin handler
|
||||
@@ -43,8 +42,11 @@ type Handler struct {
|
||||
}
|
||||
|
||||
// NewHandler create admin handler
|
||||
func NewHandler(service *Service, userService *service.UserService) *Handler {
|
||||
return &Handler{service: service, userService: userService}
|
||||
func NewHandler(svc *Service) *Handler {
|
||||
return &Handler{
|
||||
service: svc,
|
||||
userService: service.NewUserService(),
|
||||
}
|
||||
}
|
||||
|
||||
// SuccessResponse success response
|
||||
@@ -96,40 +98,58 @@ func (h *Handler) Ping(c *gin.Context) {
|
||||
successNoData(c, "PONG")
|
||||
}
|
||||
|
||||
// LoginHTTPRequest login request body
|
||||
type LoginHTTPRequest struct {
|
||||
Email string `json:"email" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
// Login handle admin login
|
||||
// @Summary Admin Login
|
||||
// @Description Admin login verification using email, only superuser can login
|
||||
// @Tags admin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body service.EmailLoginRequest true "login info with email"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /admin/login [post]
|
||||
func (h *Handler) Login(c *gin.Context) {
|
||||
var req service.EmailLoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Use userService.LoginByEmail with adminLogin=true
|
||||
// This allows default admin account to login admin system
|
||||
user, code, err := h.userService.LoginByEmail(&req, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is superuser (admin)
|
||||
if user.IsSuperuser == nil || !*user.IsSuperuser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeForbidden,
|
||||
"message": "Only superuser can login admin system",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
variables := server.GetVariables()
|
||||
secretKey := variables.SecretKey
|
||||
authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": fmt.Sprintf("Failed to generate auth token: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set Authorization header with access_token
|
||||
if user.AccessToken != nil {
|
||||
c.Header("Authorization", authToken)
|
||||
}
|
||||
c.Header("Authorization", authToken)
|
||||
// Set CORS headers
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "*")
|
||||
@@ -919,6 +939,7 @@ func (h *Handler) TestSandboxConnection(c *gin.Context) {
|
||||
}
|
||||
|
||||
// AuthMiddleware JWT auth middleware
|
||||
// Validates that the user is authenticated and is a superuser (admin)
|
||||
func (h *Handler) AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
@@ -935,6 +956,7 @@ func (h *Handler) AuthMiddleware() gin.HandlerFunc {
|
||||
"code": code,
|
||||
"message": "Invalid access token",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
111
internal/admin/password.go
Normal file
111
internal/admin/password.go
Normal file
@@ -0,0 +1,111 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"hash"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
// CheckWerkzeugPassword verifies a password against a werkzeug password hash
|
||||
// Format: pbkdf2:sha256:iterations$salt$hash
|
||||
func CheckWerkzeugPassword(password, hashStr string) bool {
|
||||
parts := strings.Split(hashStr, "$")
|
||||
if len(parts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse method (e.g., "pbkdf2:sha256:150000")
|
||||
methodParts := strings.Split(parts[0], ":")
|
||||
if len(methodParts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
if methodParts[0] != "pbkdf2" {
|
||||
return false
|
||||
}
|
||||
|
||||
var hashFunc func() hash.Hash
|
||||
switch methodParts[1] {
|
||||
case "sha256":
|
||||
hashFunc = sha256.New
|
||||
case "sha512":
|
||||
// sha512 not supported in this implementation
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
iterations, err := strconv.Atoi(methodParts[2])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
salt := parts[1]
|
||||
expectedHash := parts[2]
|
||||
|
||||
// Decode salt from base64
|
||||
saltBytes, err := base64.StdEncoding.DecodeString(salt)
|
||||
if err != nil {
|
||||
// Try hex encoding
|
||||
saltBytes, err = hex.DecodeString(salt)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Generate hash using PBKDF2
|
||||
key := pbkdf2.Key([]byte(password), saltBytes, iterations, 32, hashFunc)
|
||||
computedHash := base64.StdEncoding.EncodeToString(key)
|
||||
|
||||
return computedHash == expectedHash
|
||||
}
|
||||
|
||||
// IsWerkzeugHash checks if a hash is in werkzeug format
|
||||
func IsWerkzeugHash(hashStr string) bool {
|
||||
return strings.HasPrefix(hashStr, "pbkdf2:")
|
||||
}
|
||||
|
||||
// GenerateWerkzeugPasswordHash generates a werkzeug-compatible password hash
|
||||
func GenerateWerkzeugPasswordHash(password string, iterations int) (string, error) {
|
||||
if iterations == 0 {
|
||||
iterations = 150000
|
||||
}
|
||||
|
||||
// Generate random salt
|
||||
salt := make([]byte, 16)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Generate hash using PBKDF2-SHA256
|
||||
key := pbkdf2.Key([]byte(password), salt, iterations, 32, sha256.New)
|
||||
|
||||
// Format: pbkdf2:sha256:iterations$base64(salt)$base64(hash)
|
||||
saltB64 := base64.StdEncoding.EncodeToString(salt)
|
||||
hashB64 := base64.StdEncoding.EncodeToString(key)
|
||||
|
||||
return fmt.Sprintf("pbkdf2:sha256:%d$%s$%s", iterations, saltB64, hashB64), nil
|
||||
}
|
||||
@@ -18,14 +18,11 @@ package admin
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/handler"
|
||||
)
|
||||
|
||||
// Router admin router
|
||||
type Router struct {
|
||||
handler *Handler
|
||||
userHandler *handler.UserHandler
|
||||
handler *Handler
|
||||
}
|
||||
|
||||
// NewRouter create admin router
|
||||
|
||||
@@ -19,6 +19,7 @@ package admin
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -35,6 +36,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Service errors
|
||||
var (
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrNotAdmin = errors.New("user is not admin")
|
||||
ErrUserInactive = errors.New("user is inactive")
|
||||
)
|
||||
|
||||
// Service admin service layer
|
||||
type Service struct {
|
||||
userDAO *dao.UserDAO
|
||||
@@ -47,47 +55,6 @@ func NewService() *Service {
|
||||
}
|
||||
}
|
||||
|
||||
// LoginRequest login request
|
||||
type LoginRequest struct {
|
||||
Email string
|
||||
Password string
|
||||
}
|
||||
|
||||
// LoginResponse login response
|
||||
type LoginResponse struct {
|
||||
Token string
|
||||
UserID string
|
||||
Email string
|
||||
Nickname string
|
||||
}
|
||||
|
||||
// Login admin login
|
||||
func (s *Service) Login(req *LoginRequest) (*LoginResponse, error) {
|
||||
// Get user by email
|
||||
user, err := s.userDAO.GetByEmail(req.Email)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if user.IsActive != "1" {
|
||||
return nil, errors.New("user is not active")
|
||||
}
|
||||
|
||||
// Generate access token
|
||||
token := utility.GenerateToken()
|
||||
if err := s.userDAO.UpdateAccessToken(user, token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LoginResponse{
|
||||
Token: token,
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
Nickname: user.Nickname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Logout user logout
|
||||
func (s *Service) Logout(user interface{}) error {
|
||||
// Invalidate token by setting it to INVALID_ prefix
|
||||
@@ -98,6 +65,24 @@ func (s *Service) Logout(user interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByToken get user by access token
|
||||
func (s *Service) GetUserByToken(token string) (*model.User, error) {
|
||||
user, err := s.userDAO.GetByAccessToken(token)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
if user.IsSuperuser == nil || !*user.IsSuperuser {
|
||||
return nil, ErrNotAdmin
|
||||
}
|
||||
|
||||
if user.IsActive != "1" {
|
||||
return nil, fmt.Errorf("user inactive")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// generateRandomHex generate random hex string
|
||||
func generateRandomHex(n int) string {
|
||||
bytes := make([]byte, n)
|
||||
@@ -780,3 +765,133 @@ func (s *Service) HandleHeartbeat(msg *common.BaseMessage) error {
|
||||
GlobalServerStatusStore.UpdateStatus(msg.ServerName, status)
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitDefaultAdmin initialize default admin user
|
||||
// This matches Python's init_default_admin behavior
|
||||
func (s *Service) InitDefaultAdmin() error {
|
||||
// Default superuser settings (matching Python's DEFAULT_SUPERUSER_* defaults)
|
||||
defaultNickname := "admin"
|
||||
defaultEmail := "admin@ragflow.io"
|
||||
defaultPassword := "admin"
|
||||
|
||||
// Query superusers
|
||||
var users []*model.User
|
||||
err := dao.DB.Where("is_superuser = ? AND status = ?", true, "1").Find(&users).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query superusers: %w", err)
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now()
|
||||
userID := utility.GenerateToken()
|
||||
accessToken := utility.GenerateToken()
|
||||
status := "1"
|
||||
loginChannel := "password"
|
||||
isSuperuser := true
|
||||
|
||||
// Python: password = encode_to_base64(password) = base64.b64encode(password)
|
||||
// Then: generate_password_hash(base64_password) creates werkzeug hash
|
||||
password := base64.StdEncoding.EncodeToString([]byte(defaultPassword))
|
||||
hashedPassword, err := GenerateWerkzeugPasswordHash(password, 150000)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
user := &model.User{
|
||||
ID: userID,
|
||||
Email: defaultEmail,
|
||||
Nickname: defaultNickname,
|
||||
Password: &hashedPassword,
|
||||
AccessToken: &accessToken,
|
||||
Status: &status,
|
||||
IsActive: "1",
|
||||
IsAuthenticated: "1",
|
||||
IsAnonymous: "0",
|
||||
LoginChannel: &loginChannel,
|
||||
IsSuperuser: &isSuperuser,
|
||||
BaseModel: model.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
UpdateDate: &nowDate,
|
||||
},
|
||||
}
|
||||
|
||||
if err := dao.DB.Create(user).Error; err != nil {
|
||||
return fmt.Errorf("can't init admin: %w", err)
|
||||
}
|
||||
|
||||
if err := s.addTenantForAdmin(userID, defaultNickname); err != nil {
|
||||
return fmt.Errorf("failed to add tenant for admin: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
if user.IsActive != "1" {
|
||||
return fmt.Errorf("no active admin. Please update 'is_active' in db manually")
|
||||
}
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
if user.Email == defaultEmail {
|
||||
// Check if tenant exists
|
||||
var count int64
|
||||
dao.DB.Model(&model.UserTenant{}).Where("user_id = ? AND status = ?", user.ID, "1").Count(&count)
|
||||
if count == 0 {
|
||||
nickname := defaultNickname
|
||||
if user.Nickname != "" {
|
||||
nickname = user.Nickname
|
||||
}
|
||||
if err := s.addTenantForAdmin(user.ID, nickname); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addTenantForAdmin add tenant for admin user
|
||||
func (s *Service) addTenantForAdmin(userID, nickname string) error {
|
||||
now := time.Now().Unix()
|
||||
nowDate := time.Now()
|
||||
status := "1"
|
||||
role := "owner"
|
||||
tenantName := nickname + "'s Kingdom"
|
||||
|
||||
tenant := &model.Tenant{
|
||||
ID: userID,
|
||||
Name: &tenantName,
|
||||
BaseModel: model.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
UpdateDate: &nowDate,
|
||||
},
|
||||
}
|
||||
|
||||
if err := dao.DB.Create(tenant).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userTenant := &model.UserTenant{
|
||||
TenantID: userID,
|
||||
UserID: userID,
|
||||
InvitedBy: userID,
|
||||
Role: role,
|
||||
Status: &status,
|
||||
BaseModel: model.BaseModel{
|
||||
CreateTime: &now,
|
||||
CreateDate: &nowDate,
|
||||
UpdateTime: &now,
|
||||
UpdateDate: &nowDate,
|
||||
},
|
||||
}
|
||||
|
||||
return dao.DB.Create(userTenant).Error
|
||||
}
|
||||
|
||||
@@ -18,12 +18,15 @@ package service
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"os"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/server"
|
||||
@@ -32,6 +35,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
@@ -408,7 +412,77 @@ func (s *UserService) HashPassword(password string) (string, error) {
|
||||
}
|
||||
|
||||
// VerifyPassword verify password
|
||||
// Supports both werkzeug pbkdf2 format (pbkdf2:sha256:iterations$salt$hash) and scrypt format
|
||||
func (s *UserService) VerifyPassword(hashedPassword, password string) bool {
|
||||
// Check if it's pbkdf2 format (werkzeug)
|
||||
if strings.HasPrefix(hashedPassword, "pbkdf2:") {
|
||||
return s.verifyPBKDF2Password(hashedPassword, password)
|
||||
}
|
||||
|
||||
// Check if it's scrypt format
|
||||
if strings.HasPrefix(hashedPassword, "scrypt:") {
|
||||
return s.verifyScryptPassword(hashedPassword, password)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// verifyPBKDF2Password verifies password using PBKDF2 (werkzeug format)
|
||||
// Format: pbkdf2:sha256:iterations$salt$hash
|
||||
func (s *UserService) verifyPBKDF2Password(hashedPassword, password string) bool {
|
||||
parts := strings.Split(hashedPassword, "$")
|
||||
if len(parts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse method (e.g., "pbkdf2:sha256:150000")
|
||||
methodParts := strings.Split(parts[0], ":")
|
||||
if len(methodParts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
if methodParts[0] != "pbkdf2" {
|
||||
return false
|
||||
}
|
||||
|
||||
var hashFunc func() hash.Hash
|
||||
switch methodParts[1] {
|
||||
case "sha256":
|
||||
hashFunc = sha256.New
|
||||
case "sha512":
|
||||
hashFunc = sha512.New
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
iterations, err := strconv.Atoi(methodParts[2])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
salt := parts[1]
|
||||
expectedHash := parts[2]
|
||||
|
||||
// Decode salt from base64
|
||||
saltBytes, err := base64.StdEncoding.DecodeString(salt)
|
||||
if err != nil {
|
||||
// Try hex encoding
|
||||
saltBytes, err = hex.DecodeString(salt)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Generate hash using PBKDF2
|
||||
key := pbkdf2.Key([]byte(password), saltBytes, iterations, 32, hashFunc)
|
||||
computedHash := base64.StdEncoding.EncodeToString(key)
|
||||
|
||||
return computedHash == expectedHash
|
||||
}
|
||||
|
||||
// verifyScryptPassword verifies password using scrypt format
|
||||
// Format: scrypt:n:r:p$salt$hash
|
||||
func (s *UserService) verifyScryptPassword(hashedPassword, password string) bool {
|
||||
// Parse hash format: scrypt:n:r:p$salt$hash
|
||||
parts := strings.Split(hashedPassword, "$")
|
||||
if len(parts) != 3 {
|
||||
|
||||
Reference in New Issue
Block a user