diff --git a/cmd/admin_server.go b/cmd/admin_server.go index 1165ce219a..f9b095c908 100644 --- a/cmd/admin_server.go +++ b/cmd/admin_server.go @@ -33,21 +33,18 @@ import ( "ragflow/internal/admin" "ragflow/internal/dao" - "ragflow/internal/handler" "ragflow/internal/logger" "ragflow/internal/server" - "ragflow/internal/service" "ragflow/internal/utility" ) // AdminServer admin server type AdminServer struct { - router *admin.Router - handler *admin.Handler - service *admin.Service - userHandler *handler.UserHandler - engine *gin.Engine - port string + router *admin.Router + handler *admin.Handler + service *admin.Service + engine *gin.Engine + port string } func main() { @@ -112,8 +109,12 @@ func main() { } adminService := admin.NewService() - userService := service.NewUserService() - adminHandler := admin.NewHandler(adminService, userService) + adminHandler := admin.NewHandler(adminService) + + // Initialize default admin user + if err := adminService.InitDefaultAdmin(); err != nil { + logger.Error("Failed to initialize default admin user", err) + } // Initialize router r := admin.NewRouter(adminHandler) diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 4a4e36bfb6..a47628f646 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -18,6 +18,7 @@ package admin import ( "errors" + "fmt" "net/http" "ragflow/internal/common" "ragflow/internal/server" @@ -31,9 +32,7 @@ import ( // Common errors var ( - ErrInvalidCredentials = errors.New("invalid credentials") - ErrUserNotFound = errors.New("user not found") - ErrInvalidToken = errors.New("invalid token") + ErrUserNotFound = errors.New("user not found") ) // Handler admin handler @@ -43,8 +42,11 @@ type Handler struct { } // NewHandler create admin handler -func NewHandler(service *Service, userService *service.UserService) *Handler { - return &Handler{service: service, userService: userService} +func NewHandler(svc *Service) *Handler { + return &Handler{ + service: svc, + userService: service.NewUserService(), + } } // SuccessResponse success response @@ -96,40 +98,58 @@ func (h *Handler) Ping(c *gin.Context) { successNoData(c, "PONG") } -// LoginHTTPRequest login request body -type LoginHTTPRequest struct { - Email string `json:"email" binding:"required"` - Password string `json:"password" binding:"required"` -} - // Login handle admin login +// @Summary Admin Login +// @Description Admin login verification using email, only superuser can login +// @Tags admin +// @Accept json +// @Produce json +// @Param request body service.EmailLoginRequest true "login info with email" +// @Success 200 {object} map[string]interface{} +// @Router /admin/login [post] func (h *Handler) Login(c *gin.Context) { var req service.EmailLoginRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, + "code": common.CodeBadRequest, "message": err.Error(), }) return } + // Use userService.LoginByEmail with adminLogin=true + // This allows default admin account to login admin system user, code, err := h.userService.LoginByEmail(&req, true) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ + c.JSON(http.StatusOK, gin.H{ "code": code, "message": err.Error(), }) return } + // Check if user is superuser (admin) + if user.IsSuperuser == nil || !*user.IsSuperuser { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeForbidden, + "message": "Only superuser can login admin system", + }) + return + } + variables := server.GetVariables() secretKey := variables.SecretKey authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": fmt.Sprintf("Failed to generate auth token: %s", err.Error()), + }) + return + } // Set Authorization header with access_token - if user.AccessToken != nil { - c.Header("Authorization", authToken) - } + c.Header("Authorization", authToken) // Set CORS headers c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Methods", "*") @@ -919,6 +939,7 @@ func (h *Handler) TestSandboxConnection(c *gin.Context) { } // AuthMiddleware JWT auth middleware +// Validates that the user is authenticated and is a superuser (admin) func (h *Handler) AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { token := c.GetHeader("Authorization") @@ -935,6 +956,7 @@ func (h *Handler) AuthMiddleware() gin.HandlerFunc { "code": code, "message": "Invalid access token", }) + c.Abort() return } diff --git a/internal/admin/password.go b/internal/admin/password.go new file mode 100644 index 0000000000..ab81b169ba --- /dev/null +++ b/internal/admin/password.go @@ -0,0 +1,111 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package admin + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "hash" + "strconv" + "strings" + + "golang.org/x/crypto/pbkdf2" +) + +// CheckWerkzeugPassword verifies a password against a werkzeug password hash +// Format: pbkdf2:sha256:iterations$salt$hash +func CheckWerkzeugPassword(password, hashStr string) bool { + parts := strings.Split(hashStr, "$") + if len(parts) != 3 { + return false + } + + // Parse method (e.g., "pbkdf2:sha256:150000") + methodParts := strings.Split(parts[0], ":") + if len(methodParts) != 3 { + return false + } + + if methodParts[0] != "pbkdf2" { + return false + } + + var hashFunc func() hash.Hash + switch methodParts[1] { + case "sha256": + hashFunc = sha256.New + case "sha512": + // sha512 not supported in this implementation + return false + default: + return false + } + + iterations, err := strconv.Atoi(methodParts[2]) + if err != nil { + return false + } + + salt := parts[1] + expectedHash := parts[2] + + // Decode salt from base64 + saltBytes, err := base64.StdEncoding.DecodeString(salt) + if err != nil { + // Try hex encoding + saltBytes, err = hex.DecodeString(salt) + if err != nil { + return false + } + } + + // Generate hash using PBKDF2 + key := pbkdf2.Key([]byte(password), saltBytes, iterations, 32, hashFunc) + computedHash := base64.StdEncoding.EncodeToString(key) + + return computedHash == expectedHash +} + +// IsWerkzeugHash checks if a hash is in werkzeug format +func IsWerkzeugHash(hashStr string) bool { + return strings.HasPrefix(hashStr, "pbkdf2:") +} + +// GenerateWerkzeugPasswordHash generates a werkzeug-compatible password hash +func GenerateWerkzeugPasswordHash(password string, iterations int) (string, error) { + if iterations == 0 { + iterations = 150000 + } + + // Generate random salt + salt := make([]byte, 16) + if _, err := rand.Read(salt); err != nil { + return "", err + } + + // Generate hash using PBKDF2-SHA256 + key := pbkdf2.Key([]byte(password), salt, iterations, 32, sha256.New) + + // Format: pbkdf2:sha256:iterations$base64(salt)$base64(hash) + saltB64 := base64.StdEncoding.EncodeToString(salt) + hashB64 := base64.StdEncoding.EncodeToString(key) + + return fmt.Sprintf("pbkdf2:sha256:%d$%s$%s", iterations, saltB64, hashB64), nil +} diff --git a/internal/admin/router.go b/internal/admin/router.go index 4e9dd21346..6e2239d4de 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -18,14 +18,11 @@ package admin import ( "github.com/gin-gonic/gin" - - "ragflow/internal/handler" ) // Router admin router type Router struct { - handler *Handler - userHandler *handler.UserHandler + handler *Handler } // NewRouter create admin router diff --git a/internal/admin/service.go b/internal/admin/service.go index 6e2e1b346c..79fba69955 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -19,6 +19,7 @@ package admin import ( "crypto/rand" "crypto/tls" + "encoding/base64" "encoding/hex" "errors" "fmt" @@ -35,6 +36,13 @@ import ( "time" ) +// Service errors +var ( + ErrInvalidToken = errors.New("invalid token") + ErrNotAdmin = errors.New("user is not admin") + ErrUserInactive = errors.New("user is inactive") +) + // Service admin service layer type Service struct { userDAO *dao.UserDAO @@ -47,47 +55,6 @@ func NewService() *Service { } } -// LoginRequest login request -type LoginRequest struct { - Email string - Password string -} - -// LoginResponse login response -type LoginResponse struct { - Token string - UserID string - Email string - Nickname string -} - -// Login admin login -func (s *Service) Login(req *LoginRequest) (*LoginResponse, error) { - // Get user by email - user, err := s.userDAO.GetByEmail(req.Email) - if err != nil { - return nil, ErrInvalidCredentials - } - - // Check if user is active - if user.IsActive != "1" { - return nil, errors.New("user is not active") - } - - // Generate access token - token := utility.GenerateToken() - if err := s.userDAO.UpdateAccessToken(user, token); err != nil { - return nil, err - } - - return &LoginResponse{ - Token: token, - UserID: user.ID, - Email: user.Email, - Nickname: user.Nickname, - }, nil -} - // Logout user logout func (s *Service) Logout(user interface{}) error { // Invalidate token by setting it to INVALID_ prefix @@ -98,6 +65,24 @@ func (s *Service) Logout(user interface{}) error { return nil } +// GetUserByToken get user by access token +func (s *Service) GetUserByToken(token string) (*model.User, error) { + user, err := s.userDAO.GetByAccessToken(token) + if err != nil { + return nil, ErrInvalidToken + } + + if user.IsSuperuser == nil || !*user.IsSuperuser { + return nil, ErrNotAdmin + } + + if user.IsActive != "1" { + return nil, fmt.Errorf("user inactive") + } + + return user, nil +} + // generateRandomHex generate random hex string func generateRandomHex(n int) string { bytes := make([]byte, n) @@ -780,3 +765,133 @@ func (s *Service) HandleHeartbeat(msg *common.BaseMessage) error { GlobalServerStatusStore.UpdateStatus(msg.ServerName, status) return nil } + +// InitDefaultAdmin initialize default admin user +// This matches Python's init_default_admin behavior +func (s *Service) InitDefaultAdmin() error { + // Default superuser settings (matching Python's DEFAULT_SUPERUSER_* defaults) + defaultNickname := "admin" + defaultEmail := "admin@ragflow.io" + defaultPassword := "admin" + + // Query superusers + var users []*model.User + err := dao.DB.Where("is_superuser = ? AND status = ?", true, "1").Find(&users).Error + if err != nil { + return fmt.Errorf("failed to query superusers: %w", err) + } + + if len(users) == 0 { + now := time.Now().Unix() + nowDate := time.Now() + userID := utility.GenerateToken() + accessToken := utility.GenerateToken() + status := "1" + loginChannel := "password" + isSuperuser := true + + // Python: password = encode_to_base64(password) = base64.b64encode(password) + // Then: generate_password_hash(base64_password) creates werkzeug hash + password := base64.StdEncoding.EncodeToString([]byte(defaultPassword)) + hashedPassword, err := GenerateWerkzeugPasswordHash(password, 150000) + if err != nil { + return fmt.Errorf("failed to hash password: %w", err) + } + + user := &model.User{ + ID: userID, + Email: defaultEmail, + Nickname: defaultNickname, + Password: &hashedPassword, + AccessToken: &accessToken, + Status: &status, + IsActive: "1", + IsAuthenticated: "1", + IsAnonymous: "0", + LoginChannel: &loginChannel, + IsSuperuser: &isSuperuser, + BaseModel: model.BaseModel{ + CreateTime: &now, + CreateDate: &nowDate, + UpdateTime: &now, + UpdateDate: &nowDate, + }, + } + + if err := dao.DB.Create(user).Error; err != nil { + return fmt.Errorf("can't init admin: %w", err) + } + + if err := s.addTenantForAdmin(userID, defaultNickname); err != nil { + return fmt.Errorf("failed to add tenant for admin: %w", err) + } + + return nil + } + + for _, user := range users { + if user.IsActive != "1" { + return fmt.Errorf("no active admin. Please update 'is_active' in db manually") + } + } + + for _, user := range users { + if user.Email == defaultEmail { + // Check if tenant exists + var count int64 + dao.DB.Model(&model.UserTenant{}).Where("user_id = ? AND status = ?", user.ID, "1").Count(&count) + if count == 0 { + nickname := defaultNickname + if user.Nickname != "" { + nickname = user.Nickname + } + if err := s.addTenantForAdmin(user.ID, nickname); err != nil { + return err + } + } + break + } + } + + return nil +} + +// addTenantForAdmin add tenant for admin user +func (s *Service) addTenantForAdmin(userID, nickname string) error { + now := time.Now().Unix() + nowDate := time.Now() + status := "1" + role := "owner" + tenantName := nickname + "'s Kingdom" + + tenant := &model.Tenant{ + ID: userID, + Name: &tenantName, + BaseModel: model.BaseModel{ + CreateTime: &now, + CreateDate: &nowDate, + UpdateTime: &now, + UpdateDate: &nowDate, + }, + } + + if err := dao.DB.Create(tenant).Error; err != nil { + return err + } + + userTenant := &model.UserTenant{ + TenantID: userID, + UserID: userID, + InvitedBy: userID, + Role: role, + Status: &status, + BaseModel: model.BaseModel{ + CreateTime: &now, + CreateDate: &nowDate, + UpdateTime: &now, + UpdateDate: &nowDate, + }, + } + + return dao.DB.Create(userTenant).Error +} diff --git a/internal/service/user.go b/internal/service/user.go index a87260b680..bf3aff7952 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -18,12 +18,15 @@ package service import ( "crypto/rsa" + "crypto/sha256" + "crypto/sha512" "crypto/x509" "encoding/base64" "encoding/hex" "encoding/pem" "errors" "fmt" + "hash" "os" "ragflow/internal/common" "ragflow/internal/server" @@ -32,6 +35,7 @@ import ( "strings" "time" + "golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/scrypt" "ragflow/internal/dao" @@ -408,7 +412,77 @@ func (s *UserService) HashPassword(password string) (string, error) { } // VerifyPassword verify password +// Supports both werkzeug pbkdf2 format (pbkdf2:sha256:iterations$salt$hash) and scrypt format func (s *UserService) VerifyPassword(hashedPassword, password string) bool { + // Check if it's pbkdf2 format (werkzeug) + if strings.HasPrefix(hashedPassword, "pbkdf2:") { + return s.verifyPBKDF2Password(hashedPassword, password) + } + + // Check if it's scrypt format + if strings.HasPrefix(hashedPassword, "scrypt:") { + return s.verifyScryptPassword(hashedPassword, password) + } + + return false +} + +// verifyPBKDF2Password verifies password using PBKDF2 (werkzeug format) +// Format: pbkdf2:sha256:iterations$salt$hash +func (s *UserService) verifyPBKDF2Password(hashedPassword, password string) bool { + parts := strings.Split(hashedPassword, "$") + if len(parts) != 3 { + return false + } + + // Parse method (e.g., "pbkdf2:sha256:150000") + methodParts := strings.Split(parts[0], ":") + if len(methodParts) != 3 { + return false + } + + if methodParts[0] != "pbkdf2" { + return false + } + + var hashFunc func() hash.Hash + switch methodParts[1] { + case "sha256": + hashFunc = sha256.New + case "sha512": + hashFunc = sha512.New + default: + return false + } + + iterations, err := strconv.Atoi(methodParts[2]) + if err != nil { + return false + } + + salt := parts[1] + expectedHash := parts[2] + + // Decode salt from base64 + saltBytes, err := base64.StdEncoding.DecodeString(salt) + if err != nil { + // Try hex encoding + saltBytes, err = hex.DecodeString(salt) + if err != nil { + return false + } + } + + // Generate hash using PBKDF2 + key := pbkdf2.Key([]byte(password), saltBytes, iterations, 32, hashFunc) + computedHash := base64.StdEncoding.EncodeToString(key) + + return computedHash == expectedHash +} + +// verifyScryptPassword verifies password using scrypt format +// Format: scrypt:n:r:p$salt$hash +func (s *UserService) verifyScryptPassword(hashedPassword, password string) bool { // Parse hash format: scrypt:n:r:p$salt$hash parts := strings.Split(hashedPassword, "$") if len(parts) != 3 {