Files
ragflow/internal/service/user.go
Jin Hai fad82fd1c0 Go: fix register user (#16058)
### What problem does this PR solve?

Fix register user

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-06-16 14:03:53 +08:00

1451 lines
44 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"hash"
"os"
"ragflow/internal/common"
"ragflow/internal/engine/redis"
"ragflow/internal/entity"
"ragflow/internal/server"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/crypto/scrypt"
"gorm.io/gorm"
"ragflow/internal/dao"
"ragflow/internal/utility"
)
// UserService implements user-account operations (registration, login,
// profile and password management) on top of dao.UserDAO.
type UserService struct {
	userDAO *dao.UserDAO
}

// NewUserService returns a UserService backed by a freshly constructed
// dao.UserDAO.
func NewUserService() *UserService {
	svc := new(UserService)
	svc.userDAO = dao.NewUserDAO()
	return svc
}
// RegisterRequest is the payload accepted by Register.
// The `binding` tags hold validation rules for the request binder
// (presumably gin/validator syntax — confirm against the handler).
// Password is not plaintext: Register passes it to decryptPasswordForRegister,
// which base64-decodes and RSA-decrypts it.
type RegisterRequest struct {
	Email    string `json:"email" binding:"required,email"`
	Password string `json:"password" binding:"required,min=1"`
	Nickname string `json:"nickname" binding:"required"`
}

// LoginRequest is the payload accepted by Login. Username is looked up as
// an email address; Password goes through decryptPassword first.
type LoginRequest struct {
	Username string `json:"username" binding:"required"`
	Password string `json:"password" binding:"required"`
}

// EmailLoginRequest is the payload accepted by LoginByEmail; Password goes
// through decryptPassword first.
type EmailLoginRequest struct {
	Email    string `json:"email" binding:"required,email"`
	Password string `json:"password" binding:"required"`
}

// UpdateSettingsRequest is the payload accepted by UpdateUserSettings.
// A nil field means "not supplied" and leaves the stored value untouched.
// Password (the current password) and NewPassword are base64-encoded RSA
// ciphertexts.
type UpdateSettingsRequest struct {
	Nickname    *string `json:"nickname,omitempty"`
	Avatar      *string `json:"avatar,omitempty"`
	Language    *string `json:"language,omitempty"`
	ColorSchema *string `json:"color_schema,omitempty"`
	Timezone    *string `json:"timezone,omitempty"`
	Password    *string `json:"password,omitempty"`
	NewPassword *string `json:"new_password,omitempty"`
}

// ChangePasswordRequest is the payload accepted by ChangePassword. Unlike
// UpdateSettingsRequest, ChangePassword uses these values as-is (no RSA
// decryption step).
type ChangePasswordRequest struct {
	Password    *string `json:"password,omitempty"`
	NewPassword *string `json:"new_password,omitempty"`
}

// UserResponse is the public view of a user returned by GetUserByID and
// ListUsers. CreatedAt uses the layout "2006-01-02 15:04:05", or is empty
// when the user has no create time.
// Field order here determines the JSON key order.
type UserResponse struct {
	ID        string  `json:"id"`
	Email     string  `json:"email"`
	Nickname  string  `json:"nickname"`
	Status    *string `json:"status"`
	CreatedAt string  `json:"created_at"`
}
// registerEmailPattern validates the email address supplied to Register.
// It is compiled once at package initialisation instead of on every call.
var registerEmailPattern = regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`)

// Register creates a new user together with:
//   - its default tenant (tenant ID == user ID),
//   - the owner user-tenant relation,
//   - the tenant's initial tenant_llm rows (see getInitTenantLLM),
//   - the tenant's root "/" folder.
//
// All rows are written in one database transaction.
//
// req.Password must be a base64 RSA ciphertext. It is decrypted by
// decryptPasswordForRegister and hashed with HashPassword. Default model names
// for the tenant come from cfg.UserDefaultLLM.DefaultModels and are stored
// verbatim; an unset model is stored as "".
//
// Returns the created user with CodeSuccess, or nil with:
//   - CodeOperatingError: registration disabled, malformed email, or email already used
//   - CodeExceptionError: password decoding/decryption setup failed
//   - CodeServerError: missing config, lookup, hashing, seeding or transaction failure
func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.ErrorCode, error) {
	cfg := server.GetConfig()
	if cfg == nil {
		return nil, common.CodeServerError, fmt.Errorf("config not initialized")
	}
	if !cfg.Authentication.RegisterEnabled {
		return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!")
	}
	if !registerEmailPattern.MatchString(req.Email) {
		return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email)
	}

	existUser, err := s.userDAO.GetByEmail(req.Email)
	if existUser != nil {
		return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email)
	}
	// "Not found" is the expected outcome here; anything else is a real DB failure.
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, common.CodeServerError, fmt.Errorf("failed to check existing user: %w", err)
	}

	decryptedPassword, err := s.decryptPasswordForRegister(req.Password)
	if err != nil {
		return nil, common.CodeExceptionError, err
	}
	hashedPassword, err := s.HashPassword(decryptedPassword)
	if err != nil {
		return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err)
	}

	userID := utility.GenerateToken()
	accessToken := utility.GenerateToken()
	status := "1"
	loginChannel := "password"
	isSuperuser := false
	language := defaultUserLanguage()
	colorSchema := "Bright"
	timezone := "UTC+8\tAsia/Shanghai"
	now := time.Now().Truncate(time.Second)

	user := &entity.User{
		ID:              userID,
		AccessToken:     &accessToken,
		Email:           req.Email,
		Nickname:        req.Nickname,
		Password:        &hashedPassword,
		Status:          &status,
		Language:        &language,
		ColorSchema:     &colorSchema,
		Timezone:        &timezone,
		IsActive:        "1",
		IsAuthenticated: "1",
		IsAnonymous:     "0",
		LastLoginTime:   &now,
		LoginChannel:    &loginChannel,
		IsSuperuser:     &isSuperuser,
	}

	tenantName := req.Nickname + "'s Kingdom"
	defaults := cfg.UserDefaultLLM.DefaultModels
	tenant := &entity.Tenant{
		ID:        userID,
		Name:      &tenantName,
		LLMID:     defaults.ChatModel.Name,
		EmbdID:    defaults.EmbeddingModel.Name,
		ASRID:     defaults.ASRModel.Name,
		Img2TxtID: defaults.Image2TextModel.Name,
		RerankID:  defaults.RerankModel.Name,
		TTSID:     defaults.TTSModel.Name,
		OCRID:     defaults.OCRModel.Name,
		ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag",
		Status:    &status,
	}

	userTenant := &entity.UserTenant{
		ID:        utility.GenerateToken(),
		UserID:    userID,
		TenantID:  userID,
		Role:      "owner",
		InvitedBy: userID,
		Status:    &status,
	}

	// The root folder is its own parent.
	fileID := utility.GenerateToken()
	rootLocation := ""
	rootFile := &entity.File{
		ID:        fileID,
		ParentID:  fileID,
		TenantID:  userID,
		CreatedBy: userID,
		Name:      "/",
		Type:      "folder",
		Location:  &rootLocation,
		Size:      0,
	}

	tenantLLMs, err := s.getInitTenantLLM(userID)
	if err != nil {
		return nil, common.CodeServerError, fmt.Errorf("failed to initialize tenant llm: %w", err)
	}

	db := dao.GetDB()
	if err := db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(user).Error; err != nil {
			return fmt.Errorf("failed to create user: %w", err)
		}
		if err := tx.Create(tenant).Error; err != nil {
			return fmt.Errorf("failed to create tenant: %w", err)
		}
		if err := tx.Create(userTenant).Error; err != nil {
			return fmt.Errorf("failed to create user tenant relation: %w", err)
		}
		if len(tenantLLMs) > 0 {
			if err := tx.Create(&tenantLLMs).Error; err != nil {
				return fmt.Errorf("failed to create tenant llm: %w", err)
			}
		}
		if err := tx.Create(rootFile).Error; err != nil {
			return fmt.Errorf("failed to create root folder: %w", err)
		}
		return nil
	}); err != nil {
		// The error comes from one of the inserts above; the transaction itself was opened fine.
		return nil, common.CodeServerError, fmt.Errorf("failed to register user: %w", err)
	}
	return user, common.CodeSuccess, nil
}
// getInitTenantLLM returns the tenant_llm rows that seed the default tenant
// of a newly registered user (tenant ID == userID).
//
// Expansion: every distinct non-empty factory referenced by the configured
// default chat, embedding, speech-to-text, image-to-text and rerank models is
// expanded into all of that factory's LLMs (dao.LLMDAO.GetByFactory).
//
// Credentials: each row starts from the factory entry's API key and base URL.
// If the default model configured for the row's model type supplies a
// non-empty value, that value wins.
//
// MaxTokens falls back to 8192 when the catalog value is not positive. Rows
// are de-duplicated on (tenant, factory, llm name), keeping the first one.
//
// Returns an error if the config is not initialized or a factory query fails.
func (s *UserService) getInitTenantLLM(userID string) ([]*entity.TenantLLM, error) {
	cfg := server.GetConfig()
	if cfg == nil {
		return nil, fmt.Errorf("config not initialized")
	}
	defaults := cfg.UserDefaultLLM.DefaultModels

	// Order matters: it fixes the order in which factories are queried.
	typedDefaults := []struct {
		modelType string
		config    server.ModelConfig
	}{
		{string(entity.ModelTypeChat), defaults.ChatModel},
		{string(entity.ModelTypeEmbedding), defaults.EmbeddingModel},
		{string(entity.ModelTypeSpeech2Text), defaults.ASRModel},
		{string(entity.ModelTypeImage2Text), defaults.Image2TextModel},
		{string(entity.ModelTypeRerank), defaults.RerankModel},
	}

	byModelType := make(map[string]server.ModelConfig, len(typedDefaults))
	factories := make([]server.ModelConfig, 0, len(typedDefaults))
	factorySeen := make(map[string]bool)
	for _, td := range typedDefaults {
		byModelType[td.modelType] = td.config
		if name := td.config.Factory; name != "" && !factorySeen[name] {
			factorySeen[name] = true
			factories = append(factories, td.config)
		}
	}

	llmDAO := dao.NewLLMDAO()
	rows := make([]*entity.TenantLLM, 0)
	rowSeen := make(map[string]bool)
	for _, fc := range factories {
		llms, err := llmDAO.GetByFactory(fc.Factory)
		if err != nil {
			return nil, fmt.Errorf("failed to get LLMs for factory %s: %w", fc.Factory, err)
		}
		for _, llm := range llms {
			key := strings.Join([]string{userID, fc.Factory, llm.LLMName}, "|")
			if rowSeen[key] {
				continue
			}
			rowSeen[key] = true

			apiKey, apiBase := fc.APIKey, fc.BaseURL
			if override, ok := byModelType[llm.ModelType]; ok {
				if override.APIKey != "" {
					apiKey = override.APIKey
				}
				if override.BaseURL != "" {
					apiBase = override.BaseURL
				}
			}

			maxTokens := int64(8192)
			if llm.MaxTokens > 0 {
				maxTokens = llm.MaxTokens
			}

			// Fresh per-iteration copies so each row gets its own pointers.
			llmName, modelType := llm.LLMName, llm.ModelType
			rows = append(rows, &entity.TenantLLM{
				TenantID:   userID,
				LLMFactory: fc.Factory,
				LLMName:    &llmName,
				ModelType:  &modelType,
				APIKey:     &apiKey,
				APIBase:    &apiBase,
				MaxTokens:  maxTokens,
				Status:     "1",
			})
		}
	}
	return rows, nil
}
// Login authenticates a user whose login name (req.Username) is an email
// address.
//
// req.Password goes through decryptPassword, which falls back to the raw
// value when it is not an RSA ciphertext. It is then verified against the
// stored hash. The status check runs only after the password verifies. On
// success a fresh access token is issued, the last-login time is refreshed,
// and the user is persisted.
//
// Error codes:
//   - CodeAuthenticationError: unknown email or wrong password
//   - CodeServerError: password decryption or persistence failure
//   - CodeForbidden: status is missing or not "1"
func (s *UserService) Login(req *LoginRequest) (*entity.User, common.ErrorCode, error) {
	user, lookupErr := s.userDAO.GetByEmail(req.Username)
	if lookupErr != nil {
		return nil, common.CodeAuthenticationError, fmt.Errorf("invalid email or password")
	}

	plain, decryptErr := s.decryptPassword(req.Password)
	if decryptErr != nil {
		return nil, common.CodeServerError, fmt.Errorf("failed to decrypt password: %w", decryptErr)
	}

	passwordOK := user.Password != nil && s.VerifyPassword(*user.Password, plain)
	if !passwordOK {
		return nil, common.CodeAuthenticationError, fmt.Errorf("invalid username or password")
	}

	enabled := user.Status != nil && *user.Status == "1"
	if !enabled {
		return nil, common.CodeForbidden, fmt.Errorf("user is disabled")
	}

	freshToken := utility.GenerateToken()
	loginAt := time.Now().Truncate(time.Second)
	user.AccessToken, user.LastLoginTime = &freshToken, &loginAt

	if updateErr := s.userDAO.Update(user); updateErr != nil {
		return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", updateErr)
	}
	return user, common.CodeSuccess, nil
}
// LoginByEmail authenticates a user by email and password.
//
// req.Password goes through decryptPassword before verification. An account
// whose IsActive flag is "0" is rejected, but only after the password
// verifies. On success the user receives a new access token and an updated
// last-login time, and is persisted.
//
// Error codes:
//   - CodeAuthenticationError (109): email not registered or password mismatch
//   - CodeServerError (500): password decryption or persistence failure
//   - CodeForbidden (403): account disabled
func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*entity.User, common.ErrorCode, error) {
	account, lookupErr := s.userDAO.GetByEmail(req.Email)
	if lookupErr != nil {
		return nil, common.CodeAuthenticationError, fmt.Errorf("email: %s is not registered!", req.Email)
	}

	plain, decryptErr := s.decryptPassword(req.Password)
	if decryptErr != nil {
		return nil, common.CodeServerError, fmt.Errorf("fail to crypt password")
	}

	if account.Password == nil || !s.VerifyPassword(*account.Password, plain) {
		return nil, common.CodeAuthenticationError, fmt.Errorf("email and password do not match!")
	}
	if account.IsActive == "0" {
		return nil, common.CodeForbidden, fmt.Errorf("This account has been disabled, please contact the administrator!")
	}

	token := utility.GenerateToken()
	loggedInAt := time.Now().Truncate(time.Second)
	account.AccessToken = &token
	account.LastLoginTime = &loggedInAt

	if updateErr := s.userDAO.Update(account); updateErr != nil {
		return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", updateErr)
	}
	return account, common.CodeSuccess, nil
}
// GetUserByID loads a user by ID and returns its public UserResponse view.
// Any DAO error is reported with CodeNotFound. CreatedAt uses the layout
// "2006-01-02 15:04:05", or is empty when the user has no create time.
func (s *UserService) GetUserByID(id uint) (*UserResponse, common.ErrorCode, error) {
	user, err := s.userDAO.GetByID(id)
	if err != nil {
		return nil, common.CodeNotFound, err
	}

	createdAt := ""
	if user.CreateTime != nil {
		createdAt = time.Unix(*user.CreateTime, 0).Format("2006-01-02 15:04:05")
	}

	resp := &UserResponse{
		ID:        user.ID,
		Email:     user.Email,
		Nickname:  user.Nickname,
		Status:    user.Status,
		CreatedAt: createdAt,
	}
	return resp, common.CodeSuccess, nil
}
// ListUsers returns one page of users in their public UserResponse view,
// together with the total user count reported by the DAO.
//
// page is 1-based. Values below 1 are treated as page 1, and the computed
// offset is clamped to 0, so the DAO never receives a negative offset.
// CreatedAt uses the layout "2006-01-02 15:04:05", or is empty when the user
// has no create time. DAO errors are reported with CodeServerError.
func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, common.ErrorCode, error) {
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pageSize
	if offset < 0 {
		// Only reachable with a negative pageSize; never emit a negative OFFSET.
		offset = 0
	}

	users, total, err := s.userDAO.List(offset, pageSize)
	if err != nil {
		return nil, 0, common.CodeServerError, err
	}

	responses := make([]*UserResponse, len(users))
	for i, user := range users {
		createdAt := ""
		if user.CreateTime != nil {
			createdAt = time.Unix(*user.CreateTime, 0).Format("2006-01-02 15:04:05")
		}
		responses[i] = &UserResponse{
			ID:        user.ID,
			Email:     user.Email,
			Nickname:  user.Nickname,
			Status:    user.Status,
			CreatedAt: createdAt,
		}
	}
	return responses, total, common.CodeSuccess, nil
}
// HashPassword hashes password with scrypt and returns a string in
// werkzeug's format: "scrypt:32768:8:1$<salt>$<hex(key)>".
//
// Salt: 12 random bytes, base64-encoded into a 16-character string. As in
// werkzeug, the scrypt salt is that encoded string's bytes, not the raw
// random bytes.
//
// Input: password is expected to be the value produced by the RSA decrypt
// step (a base64-encoded string). It is hashed as given.
func (s *UserService) HashPassword(password string) (string, error) {
	const (
		scryptN      = 32768
		scryptR      = 8
		scryptP      = 1
		scryptKeyLen = 64
	)

	rawSalt, err := s.generateSalt()
	if err != nil {
		return "", fmt.Errorf("failed to generate salt: %w", err)
	}
	salt := base64.StdEncoding.EncodeToString(rawSalt)

	derived, err := scrypt.Key([]byte(password), []byte(salt), scryptN, scryptR, scryptP, scryptKeyLen)
	if err != nil {
		return "", fmt.Errorf("failed to compute scrypt hash: %w", err)
	}
	return fmt.Sprintf("scrypt:%d:%d:%d$%s$%x", scryptN, scryptR, scryptP, salt, derived), nil
}
// VerifyPassword reports whether password matches hashedPassword.
// The scheme is chosen by prefix: "pbkdf2:" hashes go to
// verifyPBKDF2Password and "scrypt:" hashes to verifyScryptPassword.
// Any other format is rejected.
func (s *UserService) VerifyPassword(hashedPassword, password string) bool {
	switch {
	case strings.HasPrefix(hashedPassword, "pbkdf2:"):
		return s.verifyPBKDF2Password(hashedPassword, password)
	case strings.HasPrefix(hashedPassword, "scrypt:"):
		return s.verifyScryptPassword(hashedPassword, password)
	default:
		return false
	}
}
// verifyPBKDF2Password checks password against a werkzeug PBKDF2 hash.
//
// Werkzeug stores these hashes as
//
//	pbkdf2:<hash_name>:<iterations>$<salt>$<hex(derived key)>
//
// and computes them as hashlib.pbkdf2_hmac(hash_name, password, salt.encode(),
// iterations). Three consequences for verification:
//   - the salt is used as its literal UTF-8 bytes and is never base64/hex decoded;
//   - the derived-key length is the digest size of hash_name;
//   - the stored digest is hex, not base64.
//
// Only sha256 and sha512 are accepted. The final comparison is constant-time.
func (s *UserService) verifyPBKDF2Password(hashedPassword, password string) bool {
	parts := strings.Split(hashedPassword, "$")
	if len(parts) != 3 {
		return false
	}

	// Method, e.g. "pbkdf2:sha256:150000".
	methodParts := strings.Split(parts[0], ":")
	if len(methodParts) != 3 || methodParts[0] != "pbkdf2" {
		return false
	}

	var hashFunc func() hash.Hash
	switch methodParts[1] {
	case "sha256":
		hashFunc = sha256.New
	case "sha512":
		hashFunc = sha512.New
	default:
		return false
	}

	iterations, err := strconv.Atoi(methodParts[2])
	if err != nil || iterations < 1 {
		return false
	}

	expectedHash, err := hex.DecodeString(parts[2])
	if err != nil || len(expectedHash) == 0 {
		return false
	}

	// werkzeug passes the salt string itself (UTF-8) and lets dklen default to the digest size.
	computed := pbkdf2.Key([]byte(password), []byte(parts[1]), iterations, hashFunc().Size(), hashFunc)
	return s.constantTimeCompare(expectedHash, computed)
}
// verifyScryptPassword checks password against a werkzeug scrypt hash of the
// form "scrypt:<N>:<r>:<p>$<salt>$<hex(key)>".
//
// As in werkzeug, the salt segment is used as its literal bytes (it is not
// base64-decoded). The derived-key length equals the decoded length of the
// stored digest, and the comparison is constant-time. A malformed hash,
// including an empty digest, never verifies.
func (s *UserService) verifyScryptPassword(hashedPassword, password string) bool {
	parts := strings.Split(hashedPassword, "$")
	if len(parts) != 3 {
		return false
	}

	params := strings.Split(parts[0], ":")
	if len(params) != 4 || params[0] != "scrypt" {
		return false
	}
	n, err := strconv.ParseUint(params[1], 10, 0)
	if err != nil {
		return false
	}
	r, err := strconv.ParseUint(params[2], 10, 0)
	if err != nil {
		return false
	}
	p, err := strconv.ParseUint(params[3], 10, 0)
	if err != nil {
		return false
	}

	// IMPORTANT: werkzeug uses the salt string's bytes, NOT the decoded bytes.
	salt := []byte(parts[1])

	expectedHash, err := hex.DecodeString(parts[2])
	// An empty digest would make scrypt derive a zero-length key that
	// trivially "matches" and would accept any password.
	if err != nil || len(expectedHash) == 0 {
		return false
	}

	computed, err := scrypt.Key([]byte(password), salt, int(n), int(r), int(p), len(expectedHash))
	if err != nil {
		return false
	}
	return s.constantTimeCompare(expectedHash, computed)
}
// generateSalt returns 12 bytes from crypto/rand. Base64-encoded, these
// become the 16-character salt string used by HashPassword.
func (s *UserService) generateSalt() ([]byte, error) {
	const saltLen = 12
	buf := make([]byte, saltLen)
	_, err := rand.Read(buf)
	if err != nil {
		return nil, fmt.Errorf("failed to generate random salt: %w", err)
	}
	return buf, nil
}
// constantTimeCompare reports whether a and b are equal. Slices of different
// lengths return false immediately. For equal lengths, every byte is examined
// regardless of where a mismatch occurs, so timing does not reveal the
// position of the first differing byte.
func (s *UserService) constantTimeCompare(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}
	var diff byte
	for i := range a {
		diff |= a[i] ^ b[i]
	}
	return diff == 0
}
// loadPrivateKey reads conf/private.pem (relative to the process working
// directory) and returns the RSA private key it contains. The file is read on
// every call.
//
// Encryption: a legacy-encrypted PEM block (header Proc-Type: 4,ENCRYPTED) is
// decrypted with the fixed passphrase "Welcome".
//
// Format: the key material is parsed as PKCS#1 ("RSA PRIVATE KEY"). If that
// fails, it is parsed as PKCS#8 ("PRIVATE KEY"), which must then hold an RSA
// key.
func (s *UserService) loadPrivateKey() (*rsa.PrivateKey, error) {
	keyData, err := os.ReadFile("conf/private.pem")
	if err != nil {
		return nil, fmt.Errorf("failed to read private key file: %w", err)
	}

	block, _ := pem.Decode(keyData)
	if block == nil {
		return nil, errors.New("failed to decode PEM block")
	}

	der := block.Bytes
	if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
		// RFC 1423 PEM encryption is insecure but still used by this key file;
		// the deprecated stdlib API is the only stdlib decoder for it.
		der, err = x509.DecryptPEMBlock(block, []byte("Welcome")) //nolint:staticcheck // SA1019: legacy encrypted PEM
		if err != nil {
			return nil, fmt.Errorf("failed to decrypt private key: %w", err)
		}
	}

	key, pkcs1Err := x509.ParsePKCS1PrivateKey(der)
	if pkcs1Err == nil {
		return key, nil
	}

	parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(der)
	if pkcs8Err != nil {
		// Report the PKCS#1 error: that is the format this file normally holds.
		return nil, fmt.Errorf("failed to parse private key: %w", pkcs1Err)
	}
	rsaKey, ok := parsed.(*rsa.PrivateKey)
	if !ok {
		return nil, errors.New("not an RSA private key")
	}
	return rsaKey, nil
}
// decryptPassword RSA-decrypts (PKCS#1 v1.5) a base64-encoded password
// ciphertext with the key from loadPrivateKey.
//
// Best-effort by design: if the input is not valid base64, or RSA decryption
// fails, the input is returned unchanged as if it were already plaintext.
// Only a failure to load the private key is reported as an error.
func (s *UserService) decryptPassword(encryptedPassword string) (string, error) {
	raw, decodeErr := base64.StdEncoding.DecodeString(encryptedPassword)
	if decodeErr != nil {
		return encryptedPassword, nil
	}

	key, keyErr := s.loadPrivateKey()
	if keyErr != nil {
		return "", keyErr
	}

	plain, rsaErr := rsa.DecryptPKCS1v15(nil, key, raw)
	if rsaErr != nil {
		return encryptedPassword, nil
	}
	return string(plain), nil
}
// decryptPasswordForRegister RSA-decrypts (PKCS#1 v1.5) the base64-encoded
// password sent at registration.
//
// Errors:
//   - input that is not valid base64: "Error('Incorrect padding')"
//   - private-key load failure: that error is returned
//
// NOTE(review): when RSA decryption itself fails, the placeholder
// "Fail to decrypt password!" is returned as the password with a nil error,
// so Register will hash it as the user's real password. This appears to
// mirror the Python decrypt helper — confirm this is intended.
func (s *UserService) decryptPasswordForRegister(encryptedPassword string) (string, error) {
	const decryptFailurePlaceholder = "Fail to decrypt password!"

	raw, decodeErr := base64.StdEncoding.DecodeString(encryptedPassword)
	if decodeErr != nil {
		return "", fmt.Errorf("Error('Incorrect padding')")
	}

	key, keyErr := s.loadPrivateKey()
	if keyErr != nil {
		return "", keyErr
	}

	plain, rsaErr := rsa.DecryptPKCS1v15(nil, key, raw)
	if rsaErr != nil {
		return decryptFailurePlaceholder, nil
	}
	return string(plain), nil
}
// defaultUserLanguage picks the UI language for new or unset profiles.
// It returns "Chinese" when the LANG environment variable contains "zh_CN",
// and "English" otherwise.
func defaultUserLanguage() string {
	language := "English"
	if lang := os.Getenv("LANG"); strings.Contains(lang, "zh_CN") {
		language = "Chinese"
	}
	return language
}
// GetUserByToken resolves the user behind a signed authorization header.
//
// Steps:
//  1. Fetch the signing secret via server.GetSecretKey.
//  2. Unwrap the header with utility.ExtractAccessToken (the Go counterpart
//     of Python's jwt.loads(authorization)).
//  3. Reject tokens shorter than 32 characters.
//  4. Look up the user by access token.
//
// Every failure is reported with CodeUnauthorized.
func (s *UserService) GetUserByToken(authorization string) (*entity.User, common.ErrorCode, error) {
	const minAccessTokenLen = 32

	secret, keyErr := server.GetSecretKey(redis.Get())
	if keyErr != nil {
		return nil, common.CodeUnauthorized, keyErr
	}

	token, parseErr := utility.ExtractAccessToken(authorization, secret)
	switch {
	case parseErr != nil:
		return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization token: %w", parseErr)
	case len(token) < minAccessTokenLen:
		return nil, common.CodeUnauthorized, fmt.Errorf("invalid access token format")
	}

	owner, lookupErr := s.userDAO.GetByAccessToken(token)
	if lookupErr != nil {
		return nil, common.CodeUnauthorized, lookupErr
	}
	return owner, common.CodeSuccess, nil
}
// UpdateUserAccessToken persists token as user's access token via the user
// DAO.
func (s *UserService) UpdateUserAccessToken(user *entity.User, token string) error {
	userDAO := s.userDAO
	return userDAO.UpdateAccessToken(user, token)
}

// Logout invalidates the user's session. It overwrites the access token with
// "INVALID_" followed by a freshly generated token, mirroring the Python
// implementation's "INVALID_" + secrets.token_hex(16). A persistence failure
// is reported with CodeServerError.
func (s *UserService) Logout(user *entity.User) (common.ErrorCode, error) {
	if err := s.UpdateUserAccessToken(user, "INVALID_"+utility.GenerateToken()); err != nil {
		return common.CodeServerError, err
	}
	return common.CodeSuccess, nil
}
// GetUserProfile flattens user into the map served as the profile payload.
//
// Defaults for missing or empty preference fields:
//   - color_schema: "Bright"
//   - language: defaultUserLanguage()
//   - timezone: "UTC+8\tAsia/Shanghai"
//   - login_channel: "password"
//
// A nil status reports "1" (an empty status is kept as-is), and a nil
// is_superuser reports false. Date fields use the layout
// "2006-01-02T15:04:05". avatar is nil when unset, and create_time is passed
// through as the raw pointer.
//
// NOTE(review): the map includes the stored password hash under "password".
// If this payload reaches clients, consider dropping that key once it is
// confirmed no consumer depends on it.
func (s *UserService) GetUserProfile(user *entity.User) map[string]interface{} {
	const layout = "2006-01-02T15:04:05"

	// orDefault returns *p when it is set and non-empty, otherwise fallback.
	orDefault := func(p *string, fallback string) string {
		if p == nil || *p == "" {
			return fallback
		}
		return *p
	}

	createDate, updateDate, lastLoginTime := "", "", ""
	if user.CreateDate != nil {
		createDate = user.CreateDate.Format(layout)
	}
	if user.UpdateDate != nil {
		updateDate = user.UpdateDate.Format(layout)
	}
	if user.LastLoginTime != nil {
		lastLoginTime = user.LastLoginTime.Format(layout)
	}

	var updateTime int64
	if user.UpdateTime != nil {
		updateTime = *user.UpdateTime
	}

	var avatar interface{}
	if user.Avatar != nil {
		avatar = *user.Avatar
	}

	status := "1"
	if user.Status != nil {
		status = *user.Status
	}

	profile := map[string]interface{}{
		"access_token":     orDefault(user.AccessToken, ""),
		"avatar":           avatar,
		"color_schema":     orDefault(user.ColorSchema, "Bright"),
		"create_date":      createDate,
		"create_time":      user.CreateTime,
		"email":            user.Email,
		"id":               user.ID,
		"is_active":        user.IsActive,
		"is_anonymous":     user.IsAnonymous,
		"is_authenticated": user.IsAuthenticated,
		"is_superuser":     user.IsSuperuser != nil && *user.IsSuperuser,
		"language":         orDefault(user.Language, defaultUserLanguage()),
		"last_login_time":  lastLoginTime,
		"login_channel":    orDefault(user.LoginChannel, "password"),
		"nickname":         user.Nickname,
		"password":         orDefault(user.Password, ""),
		"status":           status,
		"timezone":         orDefault(user.Timezone, "UTC+8\tAsia/Shanghai"),
		"update_date":      updateDate,
		"update_time":      updateTime,
	}
	return profile
}
// UpdateUserSettings applies profile changes and an optional password change
// to user, then persists the user.
//
// Password change:
//   - req.Password must carry the current password as a base64-encoded RSA
//     PKCS#1 v1.5 ciphertext, decrypted with the key from loadPrivateKey.
//   - A ciphertext that fails to decrypt is rejected as a wrong password.
//   - When the current password verifies, req.NewPassword (same encoding), if
//     present, replaces the stored hash.
//   - Empty Password / NewPassword values are treated as absent, matching the
//     Python handler's truthiness checks.
//
// Profile changes: non-nil Nickname, Avatar, Language, ColorSchema and
// Timezone are copied onto user.
//
// Error codes:
//   - CodeExceptionError: malformed ciphertext, key-load failure, or hashing failure
//   - CodeAuthenticationError: the current password does not match
//   - CodeServerError: persisting the user failed
func (s *UserService) UpdateUserSettings(user *entity.User, req *UpdateSettingsRequest) (common.ErrorCode, error) {
	if req.Password != nil && *req.Password != "" {
		ciphertext, err := base64.StdEncoding.DecodeString(*req.Password)
		if err != nil {
			return common.CodeExceptionError, fmt.Errorf("Error('Incorrect padding')")
		}
		privateKey, err := s.loadPrivateKey()
		if err != nil {
			return common.CodeExceptionError, err
		}

		// A ciphertext that cannot be decrypted can never prove knowledge of the
		// current password. Reject it outright rather than verifying a
		// placeholder string against the stored hash.
		oldPasswordBytes, err := rsa.DecryptPKCS1v15(nil, privateKey, ciphertext)
		if err != nil || user.Password == nil || !s.VerifyPassword(*user.Password, string(oldPasswordBytes)) {
			return common.CodeAuthenticationError, fmt.Errorf("Password error!")
		}

		if req.NewPassword != nil && *req.NewPassword != "" {
			newCiphertext, err := base64.StdEncoding.DecodeString(*req.NewPassword)
			if err != nil {
				return common.CodeExceptionError, fmt.Errorf("Error('Incorrect padding')")
			}
			newPasswordBytes, err := rsa.DecryptPKCS1v15(nil, privateKey, newCiphertext)
			if err != nil {
				return common.CodeExceptionError, err
			}
			hashedPassword, err := s.HashPassword(string(newPasswordBytes))
			if err != nil {
				return common.CodeExceptionError, err
			}
			user.Password = &hashedPassword
		}
	}

	if req.Nickname != nil {
		user.Nickname = *req.Nickname
	}
	if req.Avatar != nil {
		user.Avatar = req.Avatar
	}
	if req.Language != nil {
		user.Language = req.Language
	}
	if req.ColorSchema != nil {
		user.ColorSchema = req.ColorSchema
	}
	if req.Timezone != nil {
		user.Timezone = req.Timezone
	}

	if err := s.userDAO.Update(user); err != nil {
		return common.CodeServerError, err
	}
	return common.CodeSuccess, nil
}
// ChangePassword optionally verifies the current password, optionally sets a
// new one, and always persists user.
//
// Behavior:
//   - When req.Password is non-nil, it must match the stored hash, otherwise
//     CodeBadRequest is returned.
//   - When req.NewPassword is non-nil, it is hashed with HashPassword and stored.
//   - Both values are used as-is; no RSA decryption is applied here.
//
// NOTE(review): a new password is applied even when req.Password is nil, so
// callers must authorize the change themselves. Also confirm that callers pass
// already-decrypted values; other password flows in this file RSA-decrypt first.
func (s *UserService) ChangePassword(user *entity.User, req *ChangePasswordRequest) (common.ErrorCode, error) {
	if current := req.Password; current != nil {
		matches := user.Password != nil && s.VerifyPassword(*user.Password, *current)
		if !matches {
			return common.CodeBadRequest, fmt.Errorf("current password is incorrect")
		}
	}

	if next := req.NewPassword; next != nil {
		hashed, err := s.HashPassword(*next)
		if err != nil {
			return common.CodeServerError, fmt.Errorf("failed to hash new password: %w", err)
		}
		user.Password = &hashed
	}

	if err := s.userDAO.Update(user); err != nil {
		return common.CodeServerError, err
	}
	return common.CodeSuccess, nil
}
// LoginChannel describes one configured OAuth/SSO login option.
type LoginChannel struct {
	Channel     string `json:"channel"`
	DisplayName string `json:"display_name"`
	Icon        string `json:"icon"`
}

// GetLoginChannels lists the configured OAuth login channels, ordered by
// channel key.
//
// Defaults: an empty DisplayName becomes the title-cased channel key, and an
// empty Icon becomes "sso".
//
// Returns CodeServerError if the config has not been initialized.
func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, error) {
	cfg := server.GetConfig()
	if cfg == nil {
		return nil, common.CodeServerError, fmt.Errorf("config not initialized")
	}

	channels := make([]*LoginChannel, 0, len(cfg.OAuth))
	for channel, oauthCfg := range cfg.OAuth {
		displayName := oauthCfg.DisplayName
		if displayName == "" {
			// strings.Title is deprecated for Unicode text, but channel keys are
			// config identifiers, for which its behavior is adequate.
			displayName = strings.Title(channel) //nolint:staticcheck // SA1019: ASCII config keys
		}
		icon := oauthCfg.Icon
		if icon == "" {
			icon = "sso"
		}
		channels = append(channels, &LoginChannel{
			Channel:     channel,
			DisplayName: displayName,
			Icon:        icon,
		})
	}

	// Map iteration order is randomized. Order by key so clients get a stable
	// list; the list is tiny, so an in-place insertion sort suffices.
	for i := 1; i < len(channels); i++ {
		for j := i; j > 0 && channels[j].Channel < channels[j-1].Channel; j-- {
			channels[j], channels[j-1] = channels[j-1], channels[j]
		}
	}
	return channels, common.CodeSuccess, nil
}
// SetTenantInfoRequest carries a tenant model-configuration update. Raw holds
// the full decoded request body (not bound from JSON). SetTenantInfo forwards
// every key except "tenant_id" from Raw to the tenant update.
type SetTenantInfoRequest struct {
	TenantID  *string                `json:"tenant_id"`
	ASRID     *string                `json:"asr_id"`
	EmbdID    *string                `json:"embd_id"`
	Img2TxtID *string                `json:"img2txt_id"`
	LLMID     *string                `json:"llm_id"`
	RerankID  *string                `json:"rerank_id"`
	TTSID     *string                `json:"tts_id"`
	Raw       map[string]interface{} `json:"-"`
}

// SetTenantInfo updates the model configuration of tenant req.TenantID.
//
// Processing:
//  1. Copy every key of req.Raw except "tenant_id" into the update set.
//  2. Pass the set through TenantLLMService.EnsureTenantModelIDForParams.
//  3. Write the result with TenantDAO.Update, skipping the write when nothing
//     remains.
//
// Errors: a missing or empty tenant_id yields CodeArgumentError; a DAO failure
// yields CodeExceptionError.
//
// NOTE(review): userID is not used to check that the caller belongs to
// req.TenantID, and Raw keys are forwarded without a whitelist. The Python
// handler behaves the same way, but consider verifying membership (e.g. via
// UserTenantService) and restricting the updatable columns.
func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) (common.ErrorCode, error) {
	_ = userID
	tenantDAO := dao.NewTenantDAO()

	tenantID := ""
	if req.TenantID != nil {
		tenantID = *req.TenantID
	}
	if tenantID == "" {
		// Without an ID the update would target id = "" instead of a real tenant.
		return common.CodeArgumentError, fmt.Errorf("tenant_id is required")
	}

	updates := make(map[string]interface{}, len(req.Raw))
	for key, value := range req.Raw {
		if key == "tenant_id" {
			continue
		}
		updates[key] = value
	}

	tenantLLMService := NewTenantLLMService()
	updates = tenantLLMService.EnsureTenantModelIDForParams(tenantID, updates)
	if len(updates) > 0 {
		if err := tenantDAO.Update(tenantID, updates); err != nil {
			return common.CodeExceptionError, err
		}
	}
	return common.CodeSuccess, nil
}
// UserTenantService provides read access to user-tenant memberships backed
// by dao.UserTenantDAO.
type UserTenantService struct {
	userTenantDAO *dao.UserTenantDAO
}

// NewUserTenantService returns a UserTenantService backed by a freshly
// constructed dao.UserTenantDAO.
//
// Example:
//
//	svc := NewUserTenantService()
//	relations, err := svc.GetUserTenantRelationByUserID("user123")
func NewUserTenantService() *UserTenantService {
	svc := &UserTenantService{}
	svc.userTenantDAO = dao.NewUserTenantDAO()
	return svc
}
// UserTenantRelation is the API view of one user-tenant membership row, as
// built by convertToUserTenantRelation from entity.UserTenant.
// This structure matches the Python implementation's return format.
// Field order here determines the JSON key order.
type UserTenantRelation struct {
	ID       string `json:"id"`
	UserID   string `json:"user_id"`
	TenantID string `json:"tenant_id"`
	Role     string `json:"role"`
}
// GetUserTenantRelationByUserID returns every tenant membership of userID.
// Each entry carries the relation ID, user ID, tenant ID and role.
// It is GetUserTenantRelationByUserIDWithContext with context.Background().
//
// Example:
//
//	svc := NewUserTenantService()
//	relations, err := svc.GetUserTenantRelationByUserID("user123")
//	if err != nil {
//		log.Printf("Failed to get user tenant relations: %v", err)
//		return
//	}
//	for _, rel := range relations {
//		fmt.Printf("User %s has role %s in tenant %s\n", rel.UserID, rel.Role, rel.TenantID)
//	}
func (s *UserTenantService) GetUserTenantRelationByUserID(userID string) ([]*UserTenantRelation, error) {
	ctx := context.Background()
	return s.GetUserTenantRelationByUserIDWithContext(ctx, userID)
}
// GetUserTenantRelationByUserIDWithContext returns every tenant membership of
// userID, querying the DAO with ctx.
//
// DAO errors are returned unchanged. On success the result is non-nil, and
// empty when the user has no memberships.
func (s *UserTenantService) GetUserTenantRelationByUserIDWithContext(ctx context.Context, userID string) ([]*UserTenantRelation, error) {
	rows, err := s.userTenantDAO.GetByUserIDWithContext(ctx, userID)
	if err != nil {
		return nil, err
	}
	out := make([]*UserTenantRelation, 0, len(rows))
	for _, row := range rows {
		out = append(out, convertToUserTenantRelation(row))
	}
	return out, nil
}
// convertToUserTenantRelation copies the ID, user ID, tenant ID and role of
// an entity.UserTenant row into a new UserTenantRelation.
func convertToUserTenantRelation(userTenant *entity.UserTenant) *UserTenantRelation {
	rel := new(UserTenantRelation)
	rel.ID = userTenant.ID
	rel.UserID = userTenant.UserID
	rel.TenantID = userTenant.TenantID
	rel.Role = userTenant.Role
	return rel
}
// GetUserByAPIToken authenticates an API-token request and returns the owner
// of the token's tenant.
//
// Header format: authorization may be "<scheme> <token>" (typically
// "Bearer <token>") or a bare "<token>". Surrounding and repeated whitespace
// is tolerated.
//
// Lookup: the token is resolved through the API-token table, and the user is
// then found by the token's tenant ID. Users with an empty access_token are
// rejected. Every failure is reported with CodeUnauthorized.
func (s *UserService) GetUserByAPIToken(authorization string) (*entity.User, common.ErrorCode, error) {
	if authorization == "" {
		return nil, common.CodeUnauthorized, fmt.Errorf("authorization header is empty")
	}

	// strings.Fields splits on any whitespace run and drops empty fields, so
	// "Bearer  tok" works and "Bearer " can no longer produce an empty token.
	parts := strings.Fields(authorization)
	var token string
	switch len(parts) {
	case 1:
		token = parts[0]
	case 2:
		token = parts[1]
	default:
		return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization format")
	}

	apiTokenDAO := dao.NewAPITokenDAO()
	userToken, err := apiTokenDAO.GetUserByAPIToken(token)
	if err != nil {
		return nil, common.CodeUnauthorized, fmt.Errorf("invalid access token")
	}

	user, err := s.userDAO.GetByTenantID(userToken.TenantID)
	if err != nil {
		return nil, common.CodeUnauthorized, fmt.Errorf("user not found for this access token")
	}

	if user.AccessToken == nil || *user.AccessToken == "" {
		return nil, common.CodeUnauthorized, fmt.Errorf("user has empty access_token in database")
	}
	return user, common.CodeSuccess, nil
}
// ---- Forgot-password flow (mirrors api/apps/restful_apis/user_api.py
// `/auth/password/...` endpoints, fixes #15282) -------------------------
// ForgotIssueCaptcha mints a captcha challenge for the password-reset flow.
//
// Storage: the expected answer is written to Redis under
// utility.CaptchaIDRedisKey(captchaID) with a 60-second TTL. captchaID is
// a fresh server-generated token.
//
// Return values:
//   - captchaID, the key the client sends back later;
//   - imageDataURL, a PNG data URL from utility.RenderCaptchaPNGDataURL
//     that the frontend can place in an <img src>.
//
// The plaintext answer is sent to the client only inside the rendered
// image.
//
// If looking up a user by email fails, the request is refused with
// CodeDataError ("invalid email"), which the original notes say matches
// the Python endpoint. NOTE(review): this distinct response lets a caller
// distinguish registered from unregistered addresses; it does not hide
// the user list.
//
// Error codes:
//   - CodeArgumentError: empty email.
//   - CodeDataError: the email lookup fails.
//   - CodeServerError: captcha generation or the Redis write fails.
func (s *UserService) ForgotIssueCaptcha(email string) (captchaID, imageDataURL string, code common.ErrorCode, err error) {
	// captchaTTL is how long an issued captcha_id remains redeemable.
	const captchaTTL = 60 * time.Second

	if email == "" {
		return "", "", common.CodeArgumentError, fmt.Errorf("email is required")
	}
	if _, lookupErr := s.userDAO.GetByEmail(email); lookupErr != nil {
		return "", "", common.CodeDataError, fmt.Errorf("invalid email")
	}

	text, genErr := utility.GenerateCaptchaCode()
	if genErr != nil {
		return "", "", common.CodeServerError, genErr
	}

	id := utility.GenerateToken()
	if ok := redis.Get().Set(utility.CaptchaIDRedisKey(id), text, captchaTTL); !ok {
		return "", "", common.CodeServerError, fmt.Errorf("failed to store captcha")
	}
	return id, utility.RenderCaptchaPNGDataURL(text), common.CodeSuccess, nil
}
// ForgotSendOTP validates a captcha answer and, if it matches, emails a
// fresh one-time password (OTP) for the password-reset flow.
//
// Captcha handling:
//   - The captcha is looked up by the server-issued captchaID.
//   - It is compared case-insensitively after trimming whitespace.
//   - It is deleted once it matches, so a captchaID cannot be replayed.
//
// OTP storage: the OTP is stored only as a salted hash
// (utility.HashOTPCode) under the keys from utility.OTPRedisKeys. An
// attempt counter reset to "0" and the send timestamp that drives the
// resend cooldown are stored alongside it. If sending the email fails,
// those keys are restored to their previous values (or removed), so a
// failed send does not start a cooldown.
//
// The resend-cooldown message matches Python's whole-second arithmetic.
// The lockout deliberately differs from Python: an active lock is never
// cleared here.
//
// Error codes:
//   - CodeArgumentError: missing input.
//   - CodeDataError: the email lookup fails.
//   - CodeNotEffective: a missing or expired captcha, an active lock, or an
//     active cooldown.
//   - CodeAuthenticationError: a wrong captcha answer.
//   - CodeServerError: OTP generation, storage or email delivery fails.
func (s *UserService) ForgotSendOTP(email, captchaID, captcha string) (common.ErrorCode, error) {
	if email == "" || captchaID == "" || captcha == "" {
		return common.CodeArgumentError, fmt.Errorf("email, captcha_id and captcha required")
	}
	if _, err := s.userDAO.GetByEmail(email); err != nil {
		return common.CodeDataError, fmt.Errorf("invalid email")
	}
	rc := redis.Get()
	captchaKey := utility.CaptchaIDRedisKey(captchaID)
	// A Redis read error is treated the same as a missing key.
	stored, _ := rc.Get(captchaKey)
	if stored == "" {
		return common.CodeNotEffective, fmt.Errorf("invalid or expired captcha")
	}
	if !strings.EqualFold(strings.TrimSpace(stored), strings.TrimSpace(captcha)) {
		return common.CodeAuthenticationError, fmt.Errorf("invalid or expired captcha")
	}
	// One-shot: consume the captcha so a leaked captcha_id cannot be
	// reused for a stream of OTP requests.
	rc.Delete(captchaKey)

	codeKey, attemptsKey, lastSentKey, lockKey := utility.OTPRedisKeys(email)

	// Lockout: a previous verify burst already locked this email. Do not
	// let a request for a new OTP wipe the lock. This deliberately diverges
	// from the Python implementation, which deletes the lock here and so
	// lets a locked attacker clear their own lockout by re-requesting.
	if locked, _ := rc.Get(lockKey); locked != "" {
		return common.CodeNotEffective, fmt.Errorf("too many attempts, try later")
	}

	// Resend cooldown: refuse if we already sent within the window. An
	// unparsable timestamp is ignored, which means no cooldown applies.
	if lastSent, _ := rc.Get(lastSentKey); lastSent != "" {
		ts, parseErr := strconv.ParseInt(lastSent, 10, 64)
		if parseErr == nil {
			elapsed := time.Since(time.Unix(ts, 0))
			remaining := utility.OTPResendCooldown - elapsed
			if remaining > 0 {
				// Round up. Truncating would report "wait 0 seconds" when
				// less than a second is left. With a whole-second stored
				// timestamp, the ceiling equals Python's integer
				// `cooldown - (now - last)`.
				wait := int((remaining + time.Second - 1) / time.Second)
				return common.CodeNotEffective, fmt.Errorf("you still have to wait %d seconds", wait)
			}
		}
	}

	otp, err := utility.GenerateOTPCode()
	if err != nil {
		return common.CodeServerError, err
	}
	salt, err := utility.GenerateOTPSalt()
	if err != nil {
		return common.CodeServerError, err
	}
	codeHash := utility.HashOTPCode(otp, salt)
	now := strconv.FormatInt(time.Now().Unix(), 10)

	// Snapshot the previous OTP-flow state so it can be restored if email
	// delivery fails. Otherwise lastSentKey would throttle the user even
	// though they never received the code.
	prevCode, _ := rc.Get(codeKey)
	prevAttempts, _ := rc.Get(attemptsKey)
	prevLastSent, _ := rc.Get(lastSentKey)

	if !rc.Set(codeKey, utility.EncodeOTPStorageValue(codeHash, salt), utility.OTPTTL) {
		return common.CodeServerError, fmt.Errorf("failed to store otp")
	}
	rc.Set(attemptsKey, "0", utility.OTPTTL)
	rc.Set(lastSentKey, now, utility.OTPTTL)
	// lockKey is intentionally not cleared here. A lockout from a previous
	// verify burst is not lifted by requesting a new OTP; such requests
	// were already refused above.

	ttlMin := int(utility.OTPTTL.Minutes())
	cfg := server.GetConfig()
	if err := utility.SendResetCodeEmail(cfg.SMTP, email, otp, ttlMin); err != nil {
		// Roll back: restore the prior code, attempts and last-sent values,
		// or remove the keys just written, so the next attempt is not
		// blocked by the resend cooldown that this failed send installed.
		// Restored values get a full OTPTTL, not their remaining TTL.
		if prevCode != "" {
			rc.Set(codeKey, prevCode, utility.OTPTTL)
		} else {
			rc.Delete(codeKey)
		}
		if prevAttempts != "" {
			rc.Set(attemptsKey, prevAttempts, utility.OTPTTL)
		} else {
			rc.Delete(attemptsKey)
		}
		if prevLastSent != "" {
			rc.Set(lastSentKey, prevLastSent, utility.OTPTTL)
		} else {
			rc.Delete(lastSentKey)
		}
		return common.CodeServerError, fmt.Errorf("failed to send email")
	}
	return common.CodeSuccess, nil
}
// ForgotVerifyOTP checks a one-time password submitted during password
// reset.
//
// Before comparing, the submitted code is trimmed and upper-cased. It is
// then hashed with the stored salt and compared against the stored hash.
//
// On a mismatch:
//   - the per-email attempt counter is incremented;
//   - once the counter reaches utility.OTPAttemptLimit, a lock key is set
//     for utility.OTPAttemptLockDuration;
//   - the caller gets CodeAuthenticationError ("expired otp").
//
// On a match:
//   - the code, attempts, last-sent and lock keys are deleted;
//   - a "verified" flag is stored under utility.OTPVerifiedRedisKey(email)
//     for utility.OTPTTL, which ForgotResetPassword requires.
//
// Redis read errors are treated as absent keys throughout.
func (s *UserService) ForgotVerifyOTP(email, otp string) (common.ErrorCode, error) {
	if email == "" || otp == "" {
		return common.CodeArgumentError, fmt.Errorf("email and otp are required")
	}
	if _, lookupErr := s.userDAO.GetByEmail(email); lookupErr != nil {
		return common.CodeDataError, fmt.Errorf("invalid email")
	}

	client := redis.Get()
	codeKey, attemptsKey, lastSentKey, lockKey := utility.OTPRedisKeys(email)

	// An active lock short-circuits before the stored code is even read.
	if lockedAt, _ := client.Get(lockKey); lockedAt != "" {
		return common.CodeNotEffective, fmt.Errorf("too many attempts, try later")
	}

	encoded, _ := client.Get(codeKey)
	if encoded == "" {
		return common.CodeNotEffective, fmt.Errorf("expired otp")
	}
	wantHash, salt, decodeErr := utility.DecodeOTPStorageValue(encoded)
	if decodeErr != nil {
		return common.CodeServerError, fmt.Errorf("otp storage corrupted")
	}

	normalized := strings.ToUpper(strings.TrimSpace(otp))
	if utility.HashOTPCode(normalized, salt) != wantHash {
		// Count this failure. A missing or non-numeric counter counts as
		// zero prior failures.
		failures := 1
		if prev, _ := client.Get(attemptsKey); prev != "" {
			if n, convErr := strconv.Atoi(prev); convErr == nil {
				failures = n + 1
			}
		}
		client.Set(attemptsKey, strconv.Itoa(failures), utility.OTPTTL)
		if failures >= utility.OTPAttemptLimit {
			client.Set(lockKey, strconv.FormatInt(time.Now().Unix(), 10), utility.OTPAttemptLockDuration)
		}
		return common.CodeAuthenticationError, fmt.Errorf("expired otp")
	}

	// Verified: wipe the OTP-flow state, then record the verified flag.
	client.Delete(codeKey)
	client.Delete(attemptsKey)
	client.Delete(lastSentKey)
	client.Delete(lockKey)
	if stored := client.Set(utility.OTPVerifiedRedisKey(email), "1", utility.OTPTTL); !stored {
		return common.CodeServerError, fmt.Errorf("failed to set verification state")
	}
	return common.CodeSuccess, nil
}
// ForgotResetPasswordRequest is the JSON body accepted by
// /auth/password/reset.
//
// The struct deliberately has no `binding` tags. With them, gin's
// validator would run inside c.ShouldBindJSON and reject a missing field
// with its own verbose message, for example:
//
//	Key: 'ForgotResetPasswordRequest.Email' Error:Field validation ...
//
// That text differs from the Python contract for this endpoint
// (api/apps/restful_apis/user_api.py:forget_reset_password), which
// returns "email and passwords are required".
//
// Without tags, missing fields bind as zero values, and
// ForgotResetPassword's own emptiness check produces the Python-compatible
// message. The original notes state that an entirely missing JSON body
// also gets Python's response. NOTE(review): ShouldBindJSON itself errors
// on an empty body, so this depends on how the handler treats that error;
// confirm.
type ForgotResetPasswordRequest struct {
	// Email identifies the account whose password is being reset.
	Email string `json:"email"`
	// NewPassword is the encrypted new password; ForgotResetPassword
	// decrypts it with s.decryptPassword.
	NewPassword string `json:"new_password"`
	// ConfirmNewPassword is the encrypted confirmation, which must decrypt
	// to the same plaintext as NewPassword.
	ConfirmNewPassword string `json:"confirm_new_password"`
}
// ForgotResetPassword completes the password-reset flow.
//
// It proceeds only when the "verified" flag written by ForgotVerifyOTP is
// present for req.Email. Steps:
//  1. Decrypt both password ciphertexts with s.decryptPassword and
//     require the plaintexts to match.
//  2. Store the hash of the new password.
//  3. Rotate the user's access token and set LastLoginTime, truncated to
//     the second.
//  4. Persist the user.
//  5. Delete the verified flag. This happens only after a successful
//     update, so a failed update can be retried.
//
// The updated user is returned so the handler can log the caller in
// immediately, matching Python's `construct_response(auth=user.get_id())`.
//
// Error codes:
//   - CodeArgumentError: missing fields, or mismatched passwords.
//   - CodeAuthenticationError: the email is not verified.
//   - CodeServerError: decryption, hashing or the update fails.
//   - CodeDataError: the email lookup fails.
func (s *UserService) ForgotResetPassword(req *ForgotResetPasswordRequest) (*entity.User, common.ErrorCode, error) {
	missing := req.Email == "" || req.NewPassword == "" || req.ConfirmNewPassword == ""
	if missing {
		return nil, common.CodeArgumentError, fmt.Errorf("email and passwords are required")
	}

	client := redis.Get()
	flagKey := utility.OTPVerifiedRedisKey(req.Email)
	// A Redis read error leaves flag empty and is treated as "not verified".
	if flag, _ := client.Get(flagKey); flag != "1" {
		return nil, common.CodeAuthenticationError, fmt.Errorf("email not verified")
	}

	newPlain, decErr := s.decryptPassword(req.NewPassword)
	if decErr != nil {
		return nil, common.CodeServerError, fmt.Errorf("fail to decrypt password")
	}
	confirmPlain, decErr := s.decryptPassword(req.ConfirmNewPassword)
	if decErr != nil {
		return nil, common.CodeServerError, fmt.Errorf("fail to decrypt password")
	}
	if newPlain != confirmPlain {
		return nil, common.CodeArgumentError, fmt.Errorf("passwords do not match")
	}

	user, lookupErr := s.userDAO.GetByEmail(req.Email)
	if lookupErr != nil {
		return nil, common.CodeDataError, fmt.Errorf("invalid email")
	}

	passwordHash, hashErr := s.HashPassword(newPlain)
	if hashErr != nil {
		return nil, common.CodeServerError, fmt.Errorf("failed to hash new password: %w", hashErr)
	}
	user.Password = &passwordHash

	// Auto-login: rotate the access token, as the original notes say
	// LoginByEmail does, so the handler can mint an Authorization header
	// right away.
	freshToken := utility.GenerateToken()
	user.AccessToken = &freshToken
	loginAt := time.Now().Truncate(time.Second)
	user.LastLoginTime = &loginAt

	if updateErr := s.userDAO.Update(user); updateErr != nil {
		return nil, common.CodeServerError, fmt.Errorf("failed to reset password: %w", updateErr)
	}
	client.Delete(flagKey)
	return user, common.CodeSuccess, nil
}