From 09e91a8e61670ed9bf97e8e01431a125292ba731 Mon Sep 17 00:00:00 2001 From: Hz_ Date: Fri, 29 May 2026 19:29:23 +0800 Subject: [PATCH] Fix user registration initialization in Go API (#15349) ### What problem does this PR solve? This PR fixes several behavior gaps in the Go implementation of the user registration API. ### Type of change - Make `nickname` required for user registration. - Align registration error messages and response data with expected API behavior. - Handle password decryption errors for registration more consistently. - Generate UUID v1-style IDs for new users, access tokens, tenants, user-tenant records, and root files. - Initialize default user fields during registration, including: - language - color schema - timezone - last login time - Create user, tenant, user-tenant relation, tenant LLM records, and root folder in a single DB transaction. - Initialize default tenant LLM records from configured default models. - Avoid partial registration data when one creation step fails. - Use locale-based default language fallback for user profile responses. --- internal/handler/user.go | 6 +- internal/service/user.go | 222 ++++++++++++++++++++++++++++++-------- internal/utility/token.go | 8 ++ 3 files changed, 188 insertions(+), 48 deletions(-) diff --git a/internal/handler/user.go b/internal/handler/user.go index e167f59078..6076b54421 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -66,10 +66,14 @@ func (h *UserHandler) Register(c *gin.Context) { user, code, err := h.userService.Register(&req) if err != nil { + var data interface{} = false + if code == common.CodeExceptionError { + data = nil + } c.JSON(http.StatusOK, gin.H{ "code": code, "message": err.Error(), - "data": false, + "data": data, }) return } diff --git a/internal/service/user.go b/internal/service/user.go index 87bb260095..5db52dc10e 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -40,6 +40,7 @@ import ( "golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/scrypt" + "gorm.io/gorm" "ragflow/internal/dao" @@ -62,7 +63,7 @@ func NewUserService() *UserService { type RegisterRequest struct { Email string `json:"email" binding:"required,email"` Password string `json:"password" binding:"required,min=1"` - Nickname string `json:"nickname"` + Nickname string `json:"nickname" binding:"required"` } // LoginRequest login request @@ -107,22 +108,25 @@ type UserResponse struct { func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.ErrorCode, error) { cfg := server.GetConfig() if !cfg.Authentication.RegisterEnabled { - return nil, common.CodeOperatingError, fmt.Errorf("user registration is disabled") + return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!") } emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`) if !emailRegex.MatchString(req.Email) { - return nil, common.CodeOperatingError, fmt.Errorf("invalid email address: %s", req.Email) + return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email) } - existUser, _ := s.userDAO.GetByEmail(req.Email) + existUser, err := s.userDAO.GetByEmail(req.Email) if existUser != nil { - return nil, common.CodeOperatingError, fmt.Errorf("email: %s has already registered", req.Email) + return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email) + } + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, common.CodeServerError, fmt.Errorf("failed to check existing user: %w", err) } - decryptedPassword, err := s.decryptPassword(req.Password) + decryptedPassword, err := s.decryptPasswordForRegister(req.Password) if err != nil { - return nil, common.CodeServerError, fmt.Errorf("fail to decrypt password") + return nil, common.CodeExceptionError, err } var hashedPassword string @@ -131,12 +135,22 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err) } - userID := utility.GenerateToken() - accessToken := utility.GenerateToken() + userID, err := utility.GenerateUUID1() + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to generate user id: %w", err) + } + accessToken, err := utility.GenerateUUID1() + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to generate access token: %w", err) + } status := "1" loginChannel := "password" isSuperuser := false + language := defaultUserLanguage() + colorSchema := "Bright" + timezone := "UTC+8\tAsia/Shanghai" + now := time.Now().Truncate(time.Second) user := &entity.User{ ID: userID, AccessToken: &accessToken, @@ -144,9 +158,13 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error Nickname: req.Nickname, Password: &hashedPassword, Status: &status, + Language: &language, + ColorSchema: &colorSchema, + Timezone: &timezone, IsActive: "1", IsAuthenticated: "1", IsAnonymous: "0", + LastLoginTime: &now, LoginChannel: &loginChannel, IsSuperuser: &isSuperuser, } @@ -185,7 +203,10 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag", Status: &status, } - userTenantID := utility.GenerateToken() + userTenantID, err := utility.GenerateUUID1() + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to generate user tenant id: %w", err) + } userTenant := &entity.UserTenant{ ID: userTenantID, UserID: userID, @@ -194,7 +215,11 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error InvitedBy: userID, Status: &status, } - fileID := utility.GenerateToken() + fileID, err := utility.GenerateUUID1() + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to generate file id: %w", err) + } + file__ := "" rootFile := &entity.File{ ID: fileID, ParentID: fileID, @@ -202,55 +227,133 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error CreatedBy: userID, Name: "/", Type: "folder", + Location: &file__, Size: 0, } - tenantDAO := dao.NewTenantDAO() - userTenantDAO := dao.NewUserTenantDAO() - fileDAO := dao.NewFileDAO() - if err = s.userDAO.Create(user); err != nil { - return nil, common.CodeServerError, fmt.Errorf("failed to create user: %w", err) + tenantLLMs, err := s.getInitTenantLLM(userID) + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to initialize tenant llm: %w", err) } - if err = tenantDAO.Create(tenant); err != nil { - err = s.userDAO.DeleteByID(userID) - if err != nil { - return nil, 0, err + db := dao.GetDB() + if err := db.Transaction(func(tx *gorm.DB) error { + if err := tx.Create(user).Error; err != nil { + return fmt.Errorf("failed to create user: %w", err) } - return nil, common.CodeServerError, fmt.Errorf("failed to create tenant: %w", err) - } - if err = userTenantDAO.Create(userTenant); err != nil { - err = s.userDAO.DeleteByID(userID) - if err != nil { - return nil, 0, err + if err := tx.Create(tenant).Error; err != nil { + return fmt.Errorf("failed to create tenant: %w", err) } - err = tenantDAO.Delete(userID) - if err != nil { - return nil, 0, err - } - return nil, common.CodeServerError, fmt.Errorf("failed to create user tenant relation: %w", err) - } - if err = fileDAO.Create(rootFile); err != nil { - err = s.userDAO.DeleteByID(userID) - if err != nil { - return nil, 0, err + if err := tx.Create(userTenant).Error; err != nil { + return fmt.Errorf("failed to create user tenant relation: %w", err) } - err = tenantDAO.Delete(userID) - if err != nil { - return nil, 0, err - } - err = userTenantDAO.Delete(userTenantID) - if err != nil { - return nil, 0, err - } - return nil, common.CodeServerError, fmt.Errorf("failed to create root folder: %w", err) - } + if len(tenantLLMs) > 0 { + if err := tx.Create(&tenantLLMs).Error; err != nil { + return fmt.Errorf("failed to create tenant llm: %w", err) + } + } + + if err := tx.Create(rootFile).Error; err != nil { + return fmt.Errorf("failed to create root folder: %w", err) + } + return nil + }); err != nil { + return nil, common.CodeServerError, fmt.Errorf("fail to create transaction: %w", err) + } return user, common.CodeSuccess, nil } +// getInitTenantLLM builds the tenant_llm rows created for a new user's default tenant. +func (s *UserService) getInitTenantLLM(userID string) ([]*entity.TenantLLM, error) { + cfg := server.GetConfig() + if cfg == nil { + return nil, fmt.Errorf("config not initialized") + } + + modelConfigs := map[string]server.ModelConfig{ + string(entity.ModelTypeChat): cfg.UserDefaultLLM.DefaultModels.ChatModel, + string(entity.ModelTypeEmbedding): cfg.UserDefaultLLM.DefaultModels.EmbeddingModel, + string(entity.ModelTypeSpeech2Text): cfg.UserDefaultLLM.DefaultModels.ASRModel, + string(entity.ModelTypeImage2Text): cfg.UserDefaultLLM.DefaultModels.Image2TextModel, + string(entity.ModelTypeRerank): cfg.UserDefaultLLM.DefaultModels.RerankModel, + } + + seenFactories := make(map[string]bool) + factoryConfigs := make([]server.ModelConfig, 0, len(modelConfigs)) + for _, modelConfig := range []server.ModelConfig{ + cfg.UserDefaultLLM.DefaultModels.ChatModel, + cfg.UserDefaultLLM.DefaultModels.EmbeddingModel, + cfg.UserDefaultLLM.DefaultModels.ASRModel, + cfg.UserDefaultLLM.DefaultModels.Image2TextModel, + cfg.UserDefaultLLM.DefaultModels.RerankModel, + } { + if modelConfig.Factory == "" || seenFactories[modelConfig.Factory] { + continue + } + seenFactories[modelConfig.Factory] = true + factoryConfigs = append(factoryConfigs, modelConfig) + } + + llmDAO := dao.NewLLMDAO() + tenantLLMs := make([]*entity.TenantLLM, 0) + for _, factoryConfig := range factoryConfigs { + llms, err := llmDAO.GetByFactory(factoryConfig.Factory) + if err != nil { + return nil, fmt.Errorf("failed to get LLMs for factory %s: %w", factoryConfig.Factory, err) + } + + for _, llm := range llms { + apiKey := factoryConfig.APIKey + apiBase := factoryConfig.BaseURL + if modelConfig, ok := modelConfigs[llm.ModelType]; ok { + if modelConfig.APIKey != "" { + apiKey = modelConfig.APIKey + } + if modelConfig.BaseURL != "" { + apiBase = modelConfig.BaseURL + } + } + + maxTokens := int64(8192) + if llm.MaxTokens > 0 { + maxTokens = llm.MaxTokens + } + + llmName := llm.LLMName + modelType := llm.ModelType + tenantLLMs = append(tenantLLMs, &entity.TenantLLM{ + TenantID: userID, + LLMFactory: factoryConfig.Factory, + LLMName: &llmName, + ModelType: &modelType, + APIKey: &apiKey, + APIBase: &apiBase, + MaxTokens: maxTokens, + Status: "1", + }) + } + } + + seen := make(map[string]bool) + uniqueTenantLLMs := make([]*entity.TenantLLM, 0, len(tenantLLMs)) + for _, tenantLLM := range tenantLLMs { + llmName := "" + if tenantLLM.LLMName != nil { + llmName = *tenantLLM.LLMName + } + key := strings.Join([]string{tenantLLM.TenantID, tenantLLM.LLMFactory, llmName}, "|") + if seen[key] { + continue + } + seen[key] = true + uniqueTenantLLMs = append(uniqueTenantLLMs, tenantLLM) + } + return uniqueTenantLLMs, nil +} + // Login user login func (s *UserService) Login(req *LoginRequest) (*entity.User, common.ErrorCode, error) { // Get user by email (using username field as email) @@ -610,6 +713,31 @@ func (s *UserService) decryptPassword(encryptedPassword string) (string, error) return string(plaintext), nil } +func (s *UserService) decryptPasswordForRegister(encryptedPassword string) (string, error) { + ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword) + if err != nil { + return "", fmt.Errorf("Error('Incorrect padding')") + } + + privateKey, err := s.loadPrivateKey() + if err != nil { + return "", err + } + + plaintext, err := rsa.DecryptPKCS1v15(nil, privateKey, ciphertext) + if err != nil { + return "Fail to decrypt password!", nil + } + return string(plaintext), nil +} + +func defaultUserLanguage() string { + if strings.Contains(os.Getenv("LANG"), "zh_CN") { + return "Chinese" + } + return "English" +} + // GetUserByToken gets user by authorization header // The token parameter is the authorization header value, which needs to be decrypted // using itsdangerous URLSafeTimedSerializer to get the actual access_token @@ -704,7 +832,7 @@ func (s *UserService) GetUserProfile(user *entity.User) map[string]interface{} { } // Get language - language := "English" + language := defaultUserLanguage() if user.Language != nil && *user.Language != "" { language = *user.Language } diff --git a/internal/utility/token.go b/internal/utility/token.go index d3e67f9e81..92258eb850 100644 --- a/internal/utility/token.go +++ b/internal/utility/token.go @@ -142,6 +142,14 @@ func GenerateToken() string { return strings.ReplaceAll(uuid.New().String(), "-", "") } +func GenerateUUID1() (string, error) { + id, err := uuid.NewUUID() + if err != nil { + return "", err + } + return strings.ReplaceAll(id.String(), "-", ""), nil +} + // GenerateAPIToken generates a secure random access key // Equivalent to Python's generate_confirmation_token(): // return "ragflow-" + secrets.token_urlsafe(32)