mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Fix user registration initialization in Go API (#15349)
### What problem does this PR solve? This PR fixes several behavior gaps in the Go implementation of the user registration API. ### Type of change - Make `nickname` required for user registration. - Align registration error messages and response data with expected API behavior. - Handle password decryption errors for registration more consistently. - Generate UUID v1-style IDs for new users, access tokens, tenants, user-tenant records, and root files. - Initialize default user fields during registration, including: - language - color schema - timezone - last login time - Create user, tenant, user-tenant relation, tenant LLM records, and root folder in a single DB transaction. - Initialize default tenant LLM records from configured default models. - Avoid partial registration data when one creation step fails. - Use locale-based default language fallback for user profile responses.
This commit is contained in:
@@ -66,10 +66,14 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
|
||||
user, code, err := h.userService.Register(&req)
|
||||
if err != nil {
|
||||
var data interface{} = false
|
||||
if code == common.CodeExceptionError {
|
||||
data = nil
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": false,
|
||||
"data": data,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
|
||||
@@ -62,7 +63,7 @@ func NewUserService() *UserService {
|
||||
type RegisterRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=1"`
|
||||
Nickname string `json:"nickname"`
|
||||
Nickname string `json:"nickname" binding:"required"`
|
||||
}
|
||||
|
||||
// LoginRequest login request
|
||||
@@ -107,22 +108,25 @@ type UserResponse struct {
|
||||
func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.ErrorCode, error) {
|
||||
cfg := server.GetConfig()
|
||||
if !cfg.Authentication.RegisterEnabled {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("user registration is disabled")
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!")
|
||||
}
|
||||
|
||||
emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`)
|
||||
if !emailRegex.MatchString(req.Email) {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("invalid email address: %s", req.Email)
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email)
|
||||
}
|
||||
|
||||
existUser, _ := s.userDAO.GetByEmail(req.Email)
|
||||
existUser, err := s.userDAO.GetByEmail(req.Email)
|
||||
if existUser != nil {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("email: %s has already registered", req.Email)
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email)
|
||||
}
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to check existing user: %w", err)
|
||||
}
|
||||
|
||||
decryptedPassword, err := s.decryptPassword(req.Password)
|
||||
decryptedPassword, err := s.decryptPasswordForRegister(req.Password)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("fail to decrypt password")
|
||||
return nil, common.CodeExceptionError, err
|
||||
}
|
||||
|
||||
var hashedPassword string
|
||||
@@ -131,12 +135,22 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
userID := utility.GenerateToken()
|
||||
accessToken := utility.GenerateToken()
|
||||
userID, err := utility.GenerateUUID1()
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to generate user id: %w", err)
|
||||
}
|
||||
accessToken, err := utility.GenerateUUID1()
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
status := "1"
|
||||
loginChannel := "password"
|
||||
isSuperuser := false
|
||||
language := defaultUserLanguage()
|
||||
colorSchema := "Bright"
|
||||
timezone := "UTC+8\tAsia/Shanghai"
|
||||
|
||||
now := time.Now().Truncate(time.Second)
|
||||
user := &entity.User{
|
||||
ID: userID,
|
||||
AccessToken: &accessToken,
|
||||
@@ -144,9 +158,13 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error
|
||||
Nickname: req.Nickname,
|
||||
Password: &hashedPassword,
|
||||
Status: &status,
|
||||
Language: &language,
|
||||
ColorSchema: &colorSchema,
|
||||
Timezone: &timezone,
|
||||
IsActive: "1",
|
||||
IsAuthenticated: "1",
|
||||
IsAnonymous: "0",
|
||||
LastLoginTime: &now,
|
||||
LoginChannel: &loginChannel,
|
||||
IsSuperuser: &isSuperuser,
|
||||
}
|
||||
@@ -185,7 +203,10 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error
|
||||
ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag",
|
||||
Status: &status,
|
||||
}
|
||||
userTenantID := utility.GenerateToken()
|
||||
userTenantID, err := utility.GenerateUUID1()
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to generate user tenant id: %w", err)
|
||||
}
|
||||
userTenant := &entity.UserTenant{
|
||||
ID: userTenantID,
|
||||
UserID: userID,
|
||||
@@ -194,7 +215,11 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error
|
||||
InvitedBy: userID,
|
||||
Status: &status,
|
||||
}
|
||||
fileID := utility.GenerateToken()
|
||||
fileID, err := utility.GenerateUUID1()
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to generate file id: %w", err)
|
||||
}
|
||||
file__ := ""
|
||||
rootFile := &entity.File{
|
||||
ID: fileID,
|
||||
ParentID: fileID,
|
||||
@@ -202,55 +227,133 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error
|
||||
CreatedBy: userID,
|
||||
Name: "/",
|
||||
Type: "folder",
|
||||
Location: &file__,
|
||||
Size: 0,
|
||||
}
|
||||
tenantDAO := dao.NewTenantDAO()
|
||||
userTenantDAO := dao.NewUserTenantDAO()
|
||||
fileDAO := dao.NewFileDAO()
|
||||
|
||||
if err = s.userDAO.Create(user); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create user: %w", err)
|
||||
tenantLLMs, err := s.getInitTenantLLM(userID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to initialize tenant llm: %w", err)
|
||||
}
|
||||
|
||||
if err = tenantDAO.Create(tenant); err != nil {
|
||||
err = s.userDAO.DeleteByID(userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
db := dao.GetDB()
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(user).Error; err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create tenant: %w", err)
|
||||
}
|
||||
|
||||
if err = userTenantDAO.Create(userTenant); err != nil {
|
||||
err = s.userDAO.DeleteByID(userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
if err := tx.Create(tenant).Error; err != nil {
|
||||
return fmt.Errorf("failed to create tenant: %w", err)
|
||||
}
|
||||
err = tenantDAO.Delete(userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create user tenant relation: %w", err)
|
||||
}
|
||||
|
||||
if err = fileDAO.Create(rootFile); err != nil {
|
||||
err = s.userDAO.DeleteByID(userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
if err := tx.Create(userTenant).Error; err != nil {
|
||||
return fmt.Errorf("failed to create user tenant relation: %w", err)
|
||||
}
|
||||
err = tenantDAO.Delete(userID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
err = userTenantDAO.Delete(userTenantID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to create root folder: %w", err)
|
||||
}
|
||||
|
||||
if len(tenantLLMs) > 0 {
|
||||
if err := tx.Create(&tenantLLMs).Error; err != nil {
|
||||
return fmt.Errorf("failed to create tenant llm: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Create(rootFile).Error; err != nil {
|
||||
return fmt.Errorf("failed to create root folder: %w", err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("fail to create transaction: %w", err)
|
||||
}
|
||||
return user, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// getInitTenantLLM builds the tenant_llm rows created for a new user's default tenant.
|
||||
func (s *UserService) getInitTenantLLM(userID string) ([]*entity.TenantLLM, error) {
|
||||
cfg := server.GetConfig()
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config not initialized")
|
||||
}
|
||||
|
||||
modelConfigs := map[string]server.ModelConfig{
|
||||
string(entity.ModelTypeChat): cfg.UserDefaultLLM.DefaultModels.ChatModel,
|
||||
string(entity.ModelTypeEmbedding): cfg.UserDefaultLLM.DefaultModels.EmbeddingModel,
|
||||
string(entity.ModelTypeSpeech2Text): cfg.UserDefaultLLM.DefaultModels.ASRModel,
|
||||
string(entity.ModelTypeImage2Text): cfg.UserDefaultLLM.DefaultModels.Image2TextModel,
|
||||
string(entity.ModelTypeRerank): cfg.UserDefaultLLM.DefaultModels.RerankModel,
|
||||
}
|
||||
|
||||
seenFactories := make(map[string]bool)
|
||||
factoryConfigs := make([]server.ModelConfig, 0, len(modelConfigs))
|
||||
for _, modelConfig := range []server.ModelConfig{
|
||||
cfg.UserDefaultLLM.DefaultModels.ChatModel,
|
||||
cfg.UserDefaultLLM.DefaultModels.EmbeddingModel,
|
||||
cfg.UserDefaultLLM.DefaultModels.ASRModel,
|
||||
cfg.UserDefaultLLM.DefaultModels.Image2TextModel,
|
||||
cfg.UserDefaultLLM.DefaultModels.RerankModel,
|
||||
} {
|
||||
if modelConfig.Factory == "" || seenFactories[modelConfig.Factory] {
|
||||
continue
|
||||
}
|
||||
seenFactories[modelConfig.Factory] = true
|
||||
factoryConfigs = append(factoryConfigs, modelConfig)
|
||||
}
|
||||
|
||||
llmDAO := dao.NewLLMDAO()
|
||||
tenantLLMs := make([]*entity.TenantLLM, 0)
|
||||
for _, factoryConfig := range factoryConfigs {
|
||||
llms, err := llmDAO.GetByFactory(factoryConfig.Factory)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get LLMs for factory %s: %w", factoryConfig.Factory, err)
|
||||
}
|
||||
|
||||
for _, llm := range llms {
|
||||
apiKey := factoryConfig.APIKey
|
||||
apiBase := factoryConfig.BaseURL
|
||||
if modelConfig, ok := modelConfigs[llm.ModelType]; ok {
|
||||
if modelConfig.APIKey != "" {
|
||||
apiKey = modelConfig.APIKey
|
||||
}
|
||||
if modelConfig.BaseURL != "" {
|
||||
apiBase = modelConfig.BaseURL
|
||||
}
|
||||
}
|
||||
|
||||
maxTokens := int64(8192)
|
||||
if llm.MaxTokens > 0 {
|
||||
maxTokens = llm.MaxTokens
|
||||
}
|
||||
|
||||
llmName := llm.LLMName
|
||||
modelType := llm.ModelType
|
||||
tenantLLMs = append(tenantLLMs, &entity.TenantLLM{
|
||||
TenantID: userID,
|
||||
LLMFactory: factoryConfig.Factory,
|
||||
LLMName: &llmName,
|
||||
ModelType: &modelType,
|
||||
APIKey: &apiKey,
|
||||
APIBase: &apiBase,
|
||||
MaxTokens: maxTokens,
|
||||
Status: "1",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
seen := make(map[string]bool)
|
||||
uniqueTenantLLMs := make([]*entity.TenantLLM, 0, len(tenantLLMs))
|
||||
for _, tenantLLM := range tenantLLMs {
|
||||
llmName := ""
|
||||
if tenantLLM.LLMName != nil {
|
||||
llmName = *tenantLLM.LLMName
|
||||
}
|
||||
key := strings.Join([]string{tenantLLM.TenantID, tenantLLM.LLMFactory, llmName}, "|")
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
uniqueTenantLLMs = append(uniqueTenantLLMs, tenantLLM)
|
||||
}
|
||||
return uniqueTenantLLMs, nil
|
||||
}
|
||||
|
||||
// Login user login
|
||||
func (s *UserService) Login(req *LoginRequest) (*entity.User, common.ErrorCode, error) {
|
||||
// Get user by email (using username field as email)
|
||||
@@ -610,6 +713,31 @@ func (s *UserService) decryptPassword(encryptedPassword string) (string, error)
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
func (s *UserService) decryptPasswordForRegister(encryptedPassword string) (string, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Error('Incorrect padding')")
|
||||
}
|
||||
|
||||
privateKey, err := s.loadPrivateKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
plaintext, err := rsa.DecryptPKCS1v15(nil, privateKey, ciphertext)
|
||||
if err != nil {
|
||||
return "Fail to decrypt password!", nil
|
||||
}
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
func defaultUserLanguage() string {
|
||||
if strings.Contains(os.Getenv("LANG"), "zh_CN") {
|
||||
return "Chinese"
|
||||
}
|
||||
return "English"
|
||||
}
|
||||
|
||||
// GetUserByToken gets user by authorization header
|
||||
// The token parameter is the authorization header value, which needs to be decrypted
|
||||
// using itsdangerous URLSafeTimedSerializer to get the actual access_token
|
||||
@@ -704,7 +832,7 @@ func (s *UserService) GetUserProfile(user *entity.User) map[string]interface{} {
|
||||
}
|
||||
|
||||
// Get language
|
||||
language := "English"
|
||||
language := defaultUserLanguage()
|
||||
if user.Language != nil && *user.Language != "" {
|
||||
language = *user.Language
|
||||
}
|
||||
|
||||
@@ -142,6 +142,14 @@ func GenerateToken() string {
|
||||
return strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
}
|
||||
|
||||
func GenerateUUID1() (string, error) {
|
||||
id, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.ReplaceAll(id.String(), "-", ""), nil
|
||||
}
|
||||
|
||||
// GenerateAPIToken generates a secure random access key
|
||||
// Equivalent to Python's generate_confirmation_token():
|
||||
// return "ragflow-" + secrets.token_urlsafe(32)
|
||||
|
||||
Reference in New Issue
Block a user