Fix user registration initialization in Go API (#15349)

### What problem does this PR solve?

This PR fixes several behavior gaps in the Go implementation of the user
registration API.

### Type of change

- Make `nickname` required for user registration.
- Align registration error messages and response data with expected API
behavior.
- Handle password decryption errors for registration more consistently.
- Generate UUID v1-style IDs for new users, access tokens, tenants,
user-tenant records, and root files.
- Initialize default user fields during registration, including:
  - language
  - color schema
  - timezone
  - last login time
- Create user, tenant, user-tenant relation, tenant LLM records, and
root folder in a single DB transaction.
- Initialize default tenant LLM records from configured default models.
- Avoid partial registration data when one creation step fails.
- Use locale-based default language fallback for user profile responses.
This commit is contained in:
Hz_
2026-05-29 19:29:23 +08:00
committed by GitHub
parent 658ff06ca4
commit 09e91a8e61
3 changed files with 188 additions and 48 deletions

View File

@@ -66,10 +66,14 @@ func (h *UserHandler) Register(c *gin.Context) {
user, code, err := h.userService.Register(&req)
if err != nil {
var data interface{} = false
if code == common.CodeExceptionError {
data = nil
}
c.JSON(http.StatusOK, gin.H{
"code": code,
"message": err.Error(),
"data": false,
"data": data,
})
return
}

View File

@@ -40,6 +40,7 @@ import (
"golang.org/x/crypto/pbkdf2"
"golang.org/x/crypto/scrypt"
"gorm.io/gorm"
"ragflow/internal/dao"
@@ -62,7 +63,7 @@ func NewUserService() *UserService {
type RegisterRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=1"`
Nickname string `json:"nickname"`
Nickname string `json:"nickname" binding:"required"`
}
// LoginRequest login request
@@ -107,22 +108,25 @@ type UserResponse struct {
func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.ErrorCode, error) {
cfg := server.GetConfig()
if !cfg.Authentication.RegisterEnabled {
return nil, common.CodeOperatingError, fmt.Errorf("user registration is disabled")
return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!")
}
emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`)
if !emailRegex.MatchString(req.Email) {
return nil, common.CodeOperatingError, fmt.Errorf("invalid email address: %s", req.Email)
return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email)
}
existUser, _ := s.userDAO.GetByEmail(req.Email)
existUser, err := s.userDAO.GetByEmail(req.Email)
if existUser != nil {
return nil, common.CodeOperatingError, fmt.Errorf("email: %s has already registered", req.Email)
return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email)
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, common.CodeServerError, fmt.Errorf("failed to check existing user: %w", err)
}
decryptedPassword, err := s.decryptPassword(req.Password)
decryptedPassword, err := s.decryptPasswordForRegister(req.Password)
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("fail to decrypt password")
return nil, common.CodeExceptionError, err
}
var hashedPassword string
@@ -131,12 +135,22 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error
return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err)
}
userID := utility.GenerateToken()
accessToken := utility.GenerateToken()
userID, err := utility.GenerateUUID1()
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to generate user id: %w", err)
}
accessToken, err := utility.GenerateUUID1()
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to generate access token: %w", err)
}
status := "1"
loginChannel := "password"
isSuperuser := false
language := defaultUserLanguage()
colorSchema := "Bright"
timezone := "UTC+8\tAsia/Shanghai"
now := time.Now().Truncate(time.Second)
user := &entity.User{
ID: userID,
AccessToken: &accessToken,
@@ -144,9 +158,13 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error
Nickname: req.Nickname,
Password: &hashedPassword,
Status: &status,
Language: &language,
ColorSchema: &colorSchema,
Timezone: &timezone,
IsActive: "1",
IsAuthenticated: "1",
IsAnonymous: "0",
LastLoginTime: &now,
LoginChannel: &loginChannel,
IsSuperuser: &isSuperuser,
}
@@ -185,7 +203,10 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error
ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag",
Status: &status,
}
userTenantID := utility.GenerateToken()
userTenantID, err := utility.GenerateUUID1()
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to generate user tenant id: %w", err)
}
userTenant := &entity.UserTenant{
ID: userTenantID,
UserID: userID,
@@ -194,7 +215,11 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error
InvitedBy: userID,
Status: &status,
}
fileID := utility.GenerateToken()
fileID, err := utility.GenerateUUID1()
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to generate file id: %w", err)
}
file__ := ""
rootFile := &entity.File{
ID: fileID,
ParentID: fileID,
@@ -202,55 +227,133 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error
CreatedBy: userID,
Name: "/",
Type: "folder",
Location: &file__,
Size: 0,
}
tenantDAO := dao.NewTenantDAO()
userTenantDAO := dao.NewUserTenantDAO()
fileDAO := dao.NewFileDAO()
if err = s.userDAO.Create(user); err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to create user: %w", err)
tenantLLMs, err := s.getInitTenantLLM(userID)
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to initialize tenant llm: %w", err)
}
if err = tenantDAO.Create(tenant); err != nil {
err = s.userDAO.DeleteByID(userID)
if err != nil {
return nil, 0, err
db := dao.GetDB()
if err := db.Transaction(func(tx *gorm.DB) error {
if err := tx.Create(user).Error; err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil, common.CodeServerError, fmt.Errorf("failed to create tenant: %w", err)
}
if err = userTenantDAO.Create(userTenant); err != nil {
err = s.userDAO.DeleteByID(userID)
if err != nil {
return nil, 0, err
if err := tx.Create(tenant).Error; err != nil {
return fmt.Errorf("failed to create tenant: %w", err)
}
err = tenantDAO.Delete(userID)
if err != nil {
return nil, 0, err
}
return nil, common.CodeServerError, fmt.Errorf("failed to create user tenant relation: %w", err)
}
if err = fileDAO.Create(rootFile); err != nil {
err = s.userDAO.DeleteByID(userID)
if err != nil {
return nil, 0, err
if err := tx.Create(userTenant).Error; err != nil {
return fmt.Errorf("failed to create user tenant relation: %w", err)
}
err = tenantDAO.Delete(userID)
if err != nil {
return nil, 0, err
}
err = userTenantDAO.Delete(userTenantID)
if err != nil {
return nil, 0, err
}
return nil, common.CodeServerError, fmt.Errorf("failed to create root folder: %w", err)
}
if len(tenantLLMs) > 0 {
if err := tx.Create(&tenantLLMs).Error; err != nil {
return fmt.Errorf("failed to create tenant llm: %w", err)
}
}
if err := tx.Create(rootFile).Error; err != nil {
return fmt.Errorf("failed to create root folder: %w", err)
}
return nil
}); err != nil {
return nil, common.CodeServerError, fmt.Errorf("fail to create transaction: %w", err)
}
return user, common.CodeSuccess, nil
}
// getInitTenantLLM builds the tenant_llm rows created for a new user's default tenant.
func (s *UserService) getInitTenantLLM(userID string) ([]*entity.TenantLLM, error) {
cfg := server.GetConfig()
if cfg == nil {
return nil, fmt.Errorf("config not initialized")
}
modelConfigs := map[string]server.ModelConfig{
string(entity.ModelTypeChat): cfg.UserDefaultLLM.DefaultModels.ChatModel,
string(entity.ModelTypeEmbedding): cfg.UserDefaultLLM.DefaultModels.EmbeddingModel,
string(entity.ModelTypeSpeech2Text): cfg.UserDefaultLLM.DefaultModels.ASRModel,
string(entity.ModelTypeImage2Text): cfg.UserDefaultLLM.DefaultModels.Image2TextModel,
string(entity.ModelTypeRerank): cfg.UserDefaultLLM.DefaultModels.RerankModel,
}
seenFactories := make(map[string]bool)
factoryConfigs := make([]server.ModelConfig, 0, len(modelConfigs))
for _, modelConfig := range []server.ModelConfig{
cfg.UserDefaultLLM.DefaultModels.ChatModel,
cfg.UserDefaultLLM.DefaultModels.EmbeddingModel,
cfg.UserDefaultLLM.DefaultModels.ASRModel,
cfg.UserDefaultLLM.DefaultModels.Image2TextModel,
cfg.UserDefaultLLM.DefaultModels.RerankModel,
} {
if modelConfig.Factory == "" || seenFactories[modelConfig.Factory] {
continue
}
seenFactories[modelConfig.Factory] = true
factoryConfigs = append(factoryConfigs, modelConfig)
}
llmDAO := dao.NewLLMDAO()
tenantLLMs := make([]*entity.TenantLLM, 0)
for _, factoryConfig := range factoryConfigs {
llms, err := llmDAO.GetByFactory(factoryConfig.Factory)
if err != nil {
return nil, fmt.Errorf("failed to get LLMs for factory %s: %w", factoryConfig.Factory, err)
}
for _, llm := range llms {
apiKey := factoryConfig.APIKey
apiBase := factoryConfig.BaseURL
if modelConfig, ok := modelConfigs[llm.ModelType]; ok {
if modelConfig.APIKey != "" {
apiKey = modelConfig.APIKey
}
if modelConfig.BaseURL != "" {
apiBase = modelConfig.BaseURL
}
}
maxTokens := int64(8192)
if llm.MaxTokens > 0 {
maxTokens = llm.MaxTokens
}
llmName := llm.LLMName
modelType := llm.ModelType
tenantLLMs = append(tenantLLMs, &entity.TenantLLM{
TenantID: userID,
LLMFactory: factoryConfig.Factory,
LLMName: &llmName,
ModelType: &modelType,
APIKey: &apiKey,
APIBase: &apiBase,
MaxTokens: maxTokens,
Status: "1",
})
}
}
seen := make(map[string]bool)
uniqueTenantLLMs := make([]*entity.TenantLLM, 0, len(tenantLLMs))
for _, tenantLLM := range tenantLLMs {
llmName := ""
if tenantLLM.LLMName != nil {
llmName = *tenantLLM.LLMName
}
key := strings.Join([]string{tenantLLM.TenantID, tenantLLM.LLMFactory, llmName}, "|")
if seen[key] {
continue
}
seen[key] = true
uniqueTenantLLMs = append(uniqueTenantLLMs, tenantLLM)
}
return uniqueTenantLLMs, nil
}
// Login user login
func (s *UserService) Login(req *LoginRequest) (*entity.User, common.ErrorCode, error) {
// Get user by email (using username field as email)
@@ -610,6 +713,31 @@ func (s *UserService) decryptPassword(encryptedPassword string) (string, error)
return string(plaintext), nil
}
func (s *UserService) decryptPasswordForRegister(encryptedPassword string) (string, error) {
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
if err != nil {
return "", fmt.Errorf("Error('Incorrect padding')")
}
privateKey, err := s.loadPrivateKey()
if err != nil {
return "", err
}
plaintext, err := rsa.DecryptPKCS1v15(nil, privateKey, ciphertext)
if err != nil {
return "Fail to decrypt password!", nil
}
return string(plaintext), nil
}
func defaultUserLanguage() string {
if strings.Contains(os.Getenv("LANG"), "zh_CN") {
return "Chinese"
}
return "English"
}
// GetUserByToken gets user by authorization header
// The token parameter is the authorization header value, which needs to be decrypted
// using itsdangerous URLSafeTimedSerializer to get the actual access_token
@@ -704,7 +832,7 @@ func (s *UserService) GetUserProfile(user *entity.User) map[string]interface{} {
}
// Get language
language := "English"
language := defaultUserLanguage()
if user.Language != nil && *user.Language != "" {
language = *user.Language
}

View File

@@ -142,6 +142,14 @@ func GenerateToken() string {
return strings.ReplaceAll(uuid.New().String(), "-", "")
}
func GenerateUUID1() (string, error) {
id, err := uuid.NewUUID()
if err != nil {
return "", err
}
return strings.ReplaceAll(id.String(), "-", ""), nil
}
// GenerateAPIToken generates a secure random access key
// Equivalent to Python's generate_confirmation_token():
// return "ragflow-" + secrets.token_urlsafe(32)