mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
## Summary - add public Go route for `/api/v1/searchbots/detail` - implement beta-token auth flow for shared search access - add tenant-based access check for shared search apps - add joined search detail query for the share response - align Go response shape with the current Python runtime behavior - add DAO / service / handler tests for the new endpoint
1367 lines
42 KiB
Go
1367 lines
42 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rsa"
|
|
"crypto/sha256"
|
|
"crypto/sha512"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"hash"
|
|
"os"
|
|
"ragflow/internal/common"
|
|
"ragflow/internal/engine/redis"
|
|
"ragflow/internal/entity"
|
|
"ragflow/internal/server"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/pbkdf2"
|
|
"golang.org/x/crypto/scrypt"
|
|
"gorm.io/gorm"
|
|
|
|
"ragflow/internal/dao"
|
|
|
|
"ragflow/internal/utility"
|
|
)
|
|
|
|
// UserService groups the account operations implemented in this file:
// registration, login, profile/settings updates and token-based user lookup.
type UserService struct {
	// userDAO is the data-access object every method uses to read and write users.
	userDAO *dao.UserDAO
}
|
|
|
|
// NewUserService create user service
|
|
func NewUserService() *UserService {
|
|
return &UserService{
|
|
userDAO: dao.NewUserDAO(),
|
|
}
|
|
}
|
|
|
|
// RegisterRequest is the JSON payload consumed by UserService.Register.
type RegisterRequest struct {
	// Email is the address of the new account; the binding tag requires a valid email.
	Email string `json:"email" binding:"required,email"`
	// Password is handed to common.DecryptPassword by Register before it is
	// hashed, so it arrives in the encrypted transport form, not as plain text.
	Password string `json:"password" binding:"required,min=1"`
	// Nickname is the display name; Register also derives the tenant name from it.
	Nickname string `json:"nickname" binding:"required"`
}
|
|
|
|
// LoginRequest is the JSON payload consumed by UserService.Login.
type LoginRequest struct {
	// Username is looked up as an email address by Login.
	Username string `json:"username" binding:"required"`
	// Password is handed to common.DecryptPassword by Login before verification.
	Password string `json:"password" binding:"required"`
}
|
|
|
|
// EmailLoginRequest is the JSON payload consumed by UserService.LoginByEmail.
type EmailLoginRequest struct {
	// Email identifies the account; the binding tag requires a valid email.
	Email string `json:"email" binding:"required,email"`
	// Password is handed to common.DecryptPassword by LoginByEmail before verification.
	Password string `json:"password" binding:"required"`
}
|
|
|
|
// UpdateSettingsRequest is the JSON payload consumed by
// UserService.UpdateUserSettings. Every field is optional: a nil pointer
// means "leave the stored value unchanged".
type UpdateSettingsRequest struct {
	// Nickname replaces the user's display name when set.
	Nickname *string `json:"nickname,omitempty"`
	// Avatar replaces the stored avatar value when set.
	Avatar *string `json:"avatar,omitempty"`
	// Language replaces the stored language preference when set.
	Language *string `json:"language,omitempty"`
	// ColorSchema replaces the stored color schema preference when set.
	ColorSchema *string `json:"color_schema,omitempty"`
	// Timezone replaces the stored timezone preference when set.
	Timezone *string `json:"timezone,omitempty"`
	// Password is the current password, base64-encoded RSA ciphertext. It must
	// be present for NewPassword to be considered at all.
	Password *string `json:"password,omitempty"`
	// NewPassword is the replacement password, base64-encoded RSA ciphertext.
	NewPassword *string `json:"new_password,omitempty"`
}
|
|
|
|
// ChangePasswordRequest is the JSON payload consumed by
// UserService.ChangePassword. Both fields are optional.
type ChangePasswordRequest struct {
	// Password is the current password; ChangePassword verifies it only when set.
	Password *string `json:"password,omitempty"`
	// NewPassword is the replacement password; it is hashed and stored when set.
	NewPassword *string `json:"new_password,omitempty"`
}
|
|
|
|
// UserResponse is the reduced user view returned by GetUserByID and ListUsers.
type UserResponse struct {
	// ID is the user's identifier.
	ID string `json:"id"`
	// Email is the user's email address.
	Email string `json:"email"`
	// Nickname is the user's display name.
	Nickname string `json:"nickname"`
	// Status is copied verbatim from the user entity and may be nil.
	Status *string `json:"status"`
	// CreatedAt is formatted as "2006-01-02 15:04:05", or "" when the entity
	// has no creation time.
	CreatedAt string `json:"created_at"`
}
|
|
|
|
// Register user registration
|
|
func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.ErrorCode, error) {
|
|
cfg := server.GetConfig()
|
|
if !cfg.Authentication.RegisterEnabled {
|
|
return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!")
|
|
}
|
|
|
|
emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`)
|
|
if !emailRegex.MatchString(req.Email) {
|
|
return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email)
|
|
}
|
|
|
|
existUser, err := s.userDAO.GetByEmail(req.Email)
|
|
if existUser != nil {
|
|
return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email)
|
|
}
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to check existing user: %w", err)
|
|
}
|
|
|
|
decryptedPassword, err := common.DecryptPassword(req.Password)
|
|
if err != nil {
|
|
return nil, common.CodeExceptionError, err
|
|
}
|
|
|
|
var hashedPassword string
|
|
hashedPassword, err = common.GenerateWerkzeugPasswordHash(decryptedPassword)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err)
|
|
}
|
|
|
|
userID := utility.GenerateToken()
|
|
accessToken := utility.GenerateToken()
|
|
status := "1"
|
|
loginChannel := "password"
|
|
isSuperuser := false
|
|
language := defaultUserLanguage()
|
|
colorSchema := "Bright"
|
|
timezone := "UTC+8\tAsia/Shanghai"
|
|
|
|
now := time.Now().Truncate(time.Second)
|
|
user := &entity.User{
|
|
ID: userID,
|
|
AccessToken: &accessToken,
|
|
Email: req.Email,
|
|
Nickname: req.Nickname,
|
|
Password: &hashedPassword,
|
|
Status: &status,
|
|
Language: &language,
|
|
ColorSchema: &colorSchema,
|
|
Timezone: &timezone,
|
|
IsActive: "1",
|
|
IsAuthenticated: "1",
|
|
IsAnonymous: "0",
|
|
LastLoginTime: &now,
|
|
LoginChannel: &loginChannel,
|
|
IsSuperuser: &isSuperuser,
|
|
}
|
|
|
|
tenantName := req.Nickname + "'s Kingdom"
|
|
|
|
llmID := cfg.UserDefaultLLM.DefaultModels.ChatModel.Name
|
|
if llmID == "" {
|
|
llmID = ""
|
|
}
|
|
embdID := cfg.UserDefaultLLM.DefaultModels.EmbeddingModel.Name
|
|
if embdID == "" {
|
|
embdID = ""
|
|
}
|
|
asrID := cfg.UserDefaultLLM.DefaultModels.ASRModel.Name
|
|
if asrID == "" {
|
|
asrID = ""
|
|
}
|
|
img2txtID := cfg.UserDefaultLLM.DefaultModels.Image2TextModel.Name
|
|
if img2txtID == "" {
|
|
img2txtID = ""
|
|
}
|
|
rerankID := cfg.UserDefaultLLM.DefaultModels.RerankModel.Name
|
|
if rerankID == "" {
|
|
rerankID = ""
|
|
}
|
|
ttsID := cfg.UserDefaultLLM.DefaultModels.TTSModel.Name
|
|
if ttsID == "" {
|
|
ttsID = ""
|
|
}
|
|
ocrID := cfg.UserDefaultLLM.DefaultModels.OCRModel.Name
|
|
if ocrID == "" {
|
|
ocrID = ""
|
|
}
|
|
|
|
tenant := &entity.Tenant{
|
|
ID: userID,
|
|
Name: &tenantName,
|
|
LLMID: llmID,
|
|
EmbdID: embdID,
|
|
ASRID: asrID,
|
|
Img2TxtID: img2txtID,
|
|
RerankID: rerankID,
|
|
TTSID: ttsID,
|
|
OCRID: ocrID,
|
|
ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag",
|
|
Status: &status,
|
|
}
|
|
userTenantID := utility.GenerateToken()
|
|
userTenant := &entity.UserTenant{
|
|
ID: userTenantID,
|
|
UserID: userID,
|
|
TenantID: userID,
|
|
Role: "owner",
|
|
InvitedBy: userID,
|
|
Status: &status,
|
|
}
|
|
fileID := utility.GenerateToken()
|
|
file__ := ""
|
|
rootFile := &entity.File{
|
|
ID: fileID,
|
|
ParentID: fileID,
|
|
TenantID: userID,
|
|
CreatedBy: userID,
|
|
Name: "/",
|
|
Type: "folder",
|
|
Location: &file__,
|
|
Size: 0,
|
|
}
|
|
|
|
tenantLLMs, err := s.getInitTenantLLM(userID)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to initialize tenant llm: %w", err)
|
|
}
|
|
|
|
db := dao.GetDB()
|
|
if err := db.Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Create(user).Error; err != nil {
|
|
return fmt.Errorf("failed to create user: %w", err)
|
|
}
|
|
|
|
if err := tx.Create(tenant).Error; err != nil {
|
|
return fmt.Errorf("failed to create tenant: %w", err)
|
|
}
|
|
|
|
if err := tx.Create(userTenant).Error; err != nil {
|
|
return fmt.Errorf("failed to create user tenant relation: %w", err)
|
|
}
|
|
|
|
if len(tenantLLMs) > 0 {
|
|
if err := tx.Create(&tenantLLMs).Error; err != nil {
|
|
return fmt.Errorf("failed to create tenant llm: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Create(rootFile).Error; err != nil {
|
|
return fmt.Errorf("failed to create root folder: %w", err)
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("fail to create transaction: %w", err)
|
|
}
|
|
return user, common.CodeSuccess, nil
|
|
}
|
|
|
|
// getInitTenantLLM builds the tenant_llm rows created for a new user's default tenant.
|
|
func (s *UserService) getInitTenantLLM(userID string) ([]*entity.TenantLLM, error) {
|
|
cfg := server.GetConfig()
|
|
if cfg == nil {
|
|
return nil, fmt.Errorf("config not initialized")
|
|
}
|
|
|
|
modelConfigs := map[string]server.ModelConfig{
|
|
string(entity.ModelTypeChat): cfg.UserDefaultLLM.DefaultModels.ChatModel,
|
|
string(entity.ModelTypeEmbedding): cfg.UserDefaultLLM.DefaultModels.EmbeddingModel,
|
|
string(entity.ModelTypeSpeech2Text): cfg.UserDefaultLLM.DefaultModels.ASRModel,
|
|
string(entity.ModelTypeImage2Text): cfg.UserDefaultLLM.DefaultModels.Image2TextModel,
|
|
string(entity.ModelTypeRerank): cfg.UserDefaultLLM.DefaultModels.RerankModel,
|
|
}
|
|
|
|
seenFactories := make(map[string]bool)
|
|
factoryConfigs := make([]server.ModelConfig, 0, len(modelConfigs))
|
|
for _, modelConfig := range []server.ModelConfig{
|
|
cfg.UserDefaultLLM.DefaultModels.ChatModel,
|
|
cfg.UserDefaultLLM.DefaultModels.EmbeddingModel,
|
|
cfg.UserDefaultLLM.DefaultModels.ASRModel,
|
|
cfg.UserDefaultLLM.DefaultModels.Image2TextModel,
|
|
cfg.UserDefaultLLM.DefaultModels.RerankModel,
|
|
} {
|
|
if modelConfig.Factory == "" || seenFactories[modelConfig.Factory] {
|
|
continue
|
|
}
|
|
seenFactories[modelConfig.Factory] = true
|
|
factoryConfigs = append(factoryConfigs, modelConfig)
|
|
}
|
|
|
|
llmDAO := dao.NewLLMDAO()
|
|
tenantLLMs := make([]*entity.TenantLLM, 0)
|
|
for _, factoryConfig := range factoryConfigs {
|
|
llms, err := llmDAO.GetByFactory(factoryConfig.Factory)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get LLMs for factory %s: %w", factoryConfig.Factory, err)
|
|
}
|
|
|
|
for _, llm := range llms {
|
|
apiKey := factoryConfig.APIKey
|
|
apiBase := factoryConfig.BaseURL
|
|
if modelConfig, ok := modelConfigs[llm.ModelType]; ok {
|
|
if modelConfig.APIKey != "" {
|
|
apiKey = modelConfig.APIKey
|
|
}
|
|
if modelConfig.BaseURL != "" {
|
|
apiBase = modelConfig.BaseURL
|
|
}
|
|
}
|
|
|
|
maxTokens := int64(8192)
|
|
if llm.MaxTokens > 0 {
|
|
maxTokens = llm.MaxTokens
|
|
}
|
|
|
|
llmName := llm.LLMName
|
|
modelType := llm.ModelType
|
|
tenantLLMs = append(tenantLLMs, &entity.TenantLLM{
|
|
TenantID: userID,
|
|
LLMFactory: factoryConfig.Factory,
|
|
LLMName: &llmName,
|
|
ModelType: &modelType,
|
|
APIKey: &apiKey,
|
|
APIBase: &apiBase,
|
|
MaxTokens: maxTokens,
|
|
Status: "1",
|
|
})
|
|
}
|
|
}
|
|
|
|
seen := make(map[string]bool)
|
|
uniqueTenantLLMs := make([]*entity.TenantLLM, 0, len(tenantLLMs))
|
|
for _, tenantLLM := range tenantLLMs {
|
|
llmName := ""
|
|
if tenantLLM.LLMName != nil {
|
|
llmName = *tenantLLM.LLMName
|
|
}
|
|
key := strings.Join([]string{tenantLLM.TenantID, tenantLLM.LLMFactory, llmName}, "|")
|
|
if seen[key] {
|
|
continue
|
|
}
|
|
seen[key] = true
|
|
uniqueTenantLLMs = append(uniqueTenantLLMs, tenantLLM)
|
|
}
|
|
return uniqueTenantLLMs, nil
|
|
}
|
|
|
|
// Login user login
|
|
func (s *UserService) Login(req *LoginRequest) (*entity.User, common.ErrorCode, error) {
|
|
// Get user by email (using username field as email)
|
|
user, err := s.userDAO.GetByEmail(req.Username)
|
|
if err != nil {
|
|
return nil, common.CodeAuthenticationError, fmt.Errorf("invalid email or password")
|
|
}
|
|
|
|
// Decrypt password using RSA
|
|
decryptedPassword, err := common.DecryptPassword(req.Password)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to decrypt password: %w", err)
|
|
}
|
|
|
|
// Verify password
|
|
if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) {
|
|
return nil, common.CodeAuthenticationError, fmt.Errorf("invalid username or password")
|
|
}
|
|
|
|
if user.Status == nil || *user.Status != "1" {
|
|
return nil, common.CodeForbidden, fmt.Errorf("user is disabled")
|
|
}
|
|
|
|
// Generate new access token
|
|
token := utility.GenerateToken()
|
|
user.AccessToken = &token
|
|
now := time.Now().Truncate(time.Second)
|
|
user.LastLoginTime = &now
|
|
if err := s.userDAO.Update(user); err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
|
|
return user, common.CodeSuccess, nil
|
|
}
|
|
|
|
// LoginByEmail user login by email
|
|
// Returns user on success, or error with specific code:
|
|
// - CodeAuthenticationError (109): Email not registered or password mismatch
|
|
// - CodeServerError (500): Password decryption failure
|
|
// - CodeForbidden (403): Account disabled
|
|
func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*entity.User, common.ErrorCode, error) {
|
|
user, err := s.userDAO.GetByEmail(req.Email)
|
|
if err != nil {
|
|
return nil, common.CodeAuthenticationError, fmt.Errorf("email: %s is not registered!", req.Email)
|
|
}
|
|
|
|
decryptedPassword, err := common.DecryptPassword(req.Password)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("fail to crypt password")
|
|
}
|
|
|
|
if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) {
|
|
return nil, common.CodeAuthenticationError, fmt.Errorf("email and password do not match!")
|
|
}
|
|
|
|
if user.IsActive == "0" {
|
|
return nil, common.CodeForbidden, fmt.Errorf("This account has been disabled, please contact the administrator!")
|
|
}
|
|
|
|
// Generate new access token
|
|
token := utility.GenerateToken()
|
|
user.AccessToken = &token
|
|
now := time.Now().Truncate(time.Second)
|
|
user.LastLoginTime = &now
|
|
|
|
if err := s.userDAO.Update(user); err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
|
|
return user, common.CodeSuccess, nil
|
|
}
|
|
|
|
// GetUserByID get user by ID
|
|
func (s *UserService) GetUserByID(id uint) (*UserResponse, common.ErrorCode, error) {
|
|
user, err := s.userDAO.GetByID(id)
|
|
if err != nil {
|
|
return nil, common.CodeNotFound, err
|
|
}
|
|
|
|
return &UserResponse{
|
|
ID: user.ID,
|
|
Email: user.Email,
|
|
Nickname: user.Nickname,
|
|
Status: user.Status,
|
|
CreatedAt: func() string {
|
|
if user.CreateTime != nil {
|
|
return time.Unix(*user.CreateTime, 0).Format("2006-01-02 15:04:05")
|
|
}
|
|
return ""
|
|
}(),
|
|
}, common.CodeSuccess, nil
|
|
}
|
|
|
|
// ListUsers list users
|
|
func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, common.ErrorCode, error) {
|
|
offset := (page - 1) * pageSize
|
|
users, total, err := s.userDAO.List(offset, pageSize)
|
|
if err != nil {
|
|
return nil, 0, common.CodeServerError, err
|
|
}
|
|
|
|
responses := make([]*UserResponse, len(users))
|
|
for i, user := range users {
|
|
responses[i] = &UserResponse{
|
|
ID: user.ID,
|
|
Email: user.Email,
|
|
Nickname: user.Nickname,
|
|
Status: user.Status,
|
|
CreatedAt: func() string {
|
|
if user.CreateTime != nil {
|
|
return time.Unix(*user.CreateTime, 0).Format("2006-01-02 15:04:05")
|
|
}
|
|
return ""
|
|
}(),
|
|
}
|
|
}
|
|
|
|
return responses, total, common.CodeSuccess, nil
|
|
}
|
|
|
|
// VerifyPassword verify password
|
|
// Supports both werkzeug pbkdf2 format (pbkdf2:sha256:iterations$salt$hash) and scrypt format
|
|
func (s *UserService) VerifyPassword(hashedPassword, password string) bool {
|
|
// Check if it's pbkdf2 format (werkzeug)
|
|
if strings.HasPrefix(hashedPassword, "pbkdf2:") {
|
|
return s.verifyPBKDF2Password(hashedPassword, password)
|
|
}
|
|
|
|
// Check if it's scrypt format
|
|
if strings.HasPrefix(hashedPassword, "scrypt:") {
|
|
return s.verifyScryptPassword(hashedPassword, password)
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// verifyPBKDF2Password verifies password using PBKDF2 (werkzeug format)
|
|
// Format: pbkdf2:sha256:iterations$salt$hash
|
|
func (s *UserService) verifyPBKDF2Password(hashedPassword, password string) bool {
|
|
parts := strings.Split(hashedPassword, "$")
|
|
if len(parts) != 3 {
|
|
return false
|
|
}
|
|
|
|
// Parse method (e.g., "pbkdf2:sha256:150000")
|
|
methodParts := strings.Split(parts[0], ":")
|
|
if len(methodParts) != 3 {
|
|
return false
|
|
}
|
|
|
|
if methodParts[0] != "pbkdf2" {
|
|
return false
|
|
}
|
|
|
|
var hashFunc func() hash.Hash
|
|
switch methodParts[1] {
|
|
case "sha256":
|
|
hashFunc = sha256.New
|
|
case "sha512":
|
|
hashFunc = sha512.New
|
|
default:
|
|
return false
|
|
}
|
|
|
|
iterations, err := strconv.Atoi(methodParts[2])
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
salt := parts[1]
|
|
expectedHash := parts[2]
|
|
|
|
// Decode salt from base64
|
|
saltBytes, err := base64.StdEncoding.DecodeString(salt)
|
|
if err != nil {
|
|
// Try hex encoding
|
|
saltBytes, err = hex.DecodeString(salt)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Generate hash using PBKDF2
|
|
key := pbkdf2.Key([]byte(password), saltBytes, iterations, 32, hashFunc)
|
|
computedHash := base64.StdEncoding.EncodeToString(key)
|
|
|
|
return computedHash == expectedHash
|
|
}
|
|
|
|
// verifyScryptPassword verifies password using scrypt format
|
|
// Format: scrypt:n:r:p$base64(salt)$hex(hash)
|
|
// IMPORTANT: werkzeug uses the base64-encoded salt string as UTF-8 bytes, NOT the decoded bytes
|
|
func (s *UserService) verifyScryptPassword(hashedPassword, password string) bool {
|
|
// Parse hash format: scrypt:n:r:p$base64(salt)$hex(hash)
|
|
parts := strings.Split(hashedPassword, "$")
|
|
if len(parts) != 3 {
|
|
return false
|
|
}
|
|
|
|
params := strings.Split(parts[0], ":")
|
|
if len(params) != 4 || params[0] != "scrypt" {
|
|
return false
|
|
}
|
|
|
|
n, err := strconv.ParseUint(params[1], 10, 0)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
r, err := strconv.ParseUint(params[2], 10, 0)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
p, err := strconv.ParseUint(params[3], 10, 0)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
saltB64 := parts[1]
|
|
hashHex := parts[2]
|
|
|
|
// IMPORTANT: werkzeug uses the base64 string as UTF-8 bytes, NOT decoded bytes
|
|
// This is the key difference from standard implementations
|
|
salt := []byte(saltB64)
|
|
|
|
// Decode expected hash from hex
|
|
expectedHash, err := hex.DecodeString(hashHex)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Compute password hash
|
|
computed, err := scrypt.Key([]byte(password), salt, int(n), int(r), int(p), len(expectedHash))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Constant time comparison
|
|
return s.constantTimeCompare(expectedHash, computed)
|
|
}
|
|
|
|
// constantTimeCompare constant time comparison
|
|
func (s *UserService) constantTimeCompare(a, b []byte) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
|
|
var result byte
|
|
for i := 0; i < len(a); i++ {
|
|
result |= a[i] ^ b[i]
|
|
}
|
|
|
|
return result == 0
|
|
}
|
|
|
|
// defaultUserLanguage picks the language assigned to users that have none:
// "Chinese" when the LANG environment variable contains "zh_CN", otherwise
// "English".
func defaultUserLanguage() string {
	lang := "English"
	if locale := os.Getenv("LANG"); strings.Contains(locale, "zh_CN") {
		lang = "Chinese"
	}
	return lang
}
|
|
|
|
// GetUserByToken gets user by authorization header
|
|
// The token parameter is the authorization header value, which needs to be decrypted
|
|
// using itsdangerous URLSafeTimedSerializer to get the actual access_token
|
|
func (s *UserService) GetUserByToken(authorization string) (*entity.User, common.ErrorCode, error) {
|
|
// Get secret key from config
|
|
secretKey, err := server.GetSecretKey(redis.Get())
|
|
if err != nil {
|
|
return nil, common.CodeUnauthorized, err
|
|
}
|
|
|
|
// Extract access token from authorization header
|
|
// Equivalent to: access_token = str(jwt.loads(authorization)) in Python
|
|
accessToken, err := utility.ExtractAccessToken(authorization, secretKey)
|
|
if err != nil {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization token: %w", err)
|
|
}
|
|
|
|
// Validate token format (should be at least 32 chars, UUID format)
|
|
if len(accessToken) < 32 {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("invalid access token format")
|
|
}
|
|
|
|
// Get user by access token
|
|
user, err := s.userDAO.GetByAccessToken(accessToken)
|
|
if err != nil {
|
|
return nil, common.CodeUnauthorized, err
|
|
}
|
|
|
|
return user, common.CodeSuccess, nil
|
|
}
|
|
|
|
// UpdateUserAccessToken updates user's access token
|
|
func (s *UserService) UpdateUserAccessToken(user *entity.User, token string) error {
|
|
return s.userDAO.UpdateAccessToken(user, token)
|
|
}
|
|
|
|
// Logout invalidates user's access token
|
|
func (s *UserService) Logout(user *entity.User) (common.ErrorCode, error) {
|
|
// Invalidate token by setting it to an invalid value
|
|
// Similar to Python implementation: "INVALID_" + secrets.token_hex(16)
|
|
invalidToken := "INVALID_" + utility.GenerateToken()
|
|
err := s.UpdateUserAccessToken(user, invalidToken)
|
|
if err != nil {
|
|
return common.CodeServerError, err
|
|
}
|
|
return common.CodeSuccess, nil
|
|
}
|
|
|
|
// GetUserProfile returns user profile information
|
|
func (s *UserService) GetUserProfile(user *entity.User) map[string]interface{} {
|
|
// Format create time and date (from database fields)
|
|
createTime := user.CreateTime
|
|
createDate := ""
|
|
if user.CreateDate != nil {
|
|
createDate = user.CreateDate.Format("2006-01-02T15:04:05")
|
|
}
|
|
|
|
// Format update time and date (from database fields)
|
|
var updateTime int64
|
|
updateDate := ""
|
|
if user.UpdateTime != nil {
|
|
updateTime = *user.UpdateTime
|
|
}
|
|
if user.UpdateDate != nil {
|
|
updateDate = user.UpdateDate.Format("2006-01-02T15:04:05")
|
|
}
|
|
|
|
// Format last login time
|
|
var lastLoginTime string
|
|
if user.LastLoginTime != nil {
|
|
lastLoginTime = user.LastLoginTime.Format("2006-01-02T15:04:05")
|
|
}
|
|
|
|
// Get access token
|
|
var accessToken string
|
|
if user.AccessToken != nil {
|
|
accessToken = *user.AccessToken
|
|
}
|
|
|
|
// Get avatar
|
|
var avatar interface{}
|
|
if user.Avatar != nil {
|
|
avatar = *user.Avatar
|
|
} else {
|
|
avatar = nil
|
|
}
|
|
|
|
// Get color schema
|
|
colorSchema := "Bright"
|
|
if user.ColorSchema != nil && *user.ColorSchema != "" {
|
|
colorSchema = *user.ColorSchema
|
|
}
|
|
|
|
// Get language
|
|
language := defaultUserLanguage()
|
|
if user.Language != nil && *user.Language != "" {
|
|
language = *user.Language
|
|
}
|
|
|
|
// Get timezone
|
|
timezone := "UTC+8\tAsia/Shanghai"
|
|
if user.Timezone != nil && *user.Timezone != "" {
|
|
timezone = *user.Timezone
|
|
}
|
|
|
|
// Get login channel
|
|
loginChannel := "password"
|
|
if user.LoginChannel != nil && *user.LoginChannel != "" {
|
|
loginChannel = *user.LoginChannel
|
|
}
|
|
|
|
// Get password
|
|
var password string
|
|
if user.Password != nil {
|
|
password = *user.Password
|
|
}
|
|
|
|
// Get status
|
|
status := "1"
|
|
if user.Status != nil {
|
|
status = *user.Status
|
|
}
|
|
|
|
// Get is_superuser
|
|
isSuperuser := false
|
|
if user.IsSuperuser != nil {
|
|
isSuperuser = *user.IsSuperuser
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"access_token": accessToken,
|
|
"avatar": avatar,
|
|
"color_schema": colorSchema,
|
|
"create_date": createDate,
|
|
"create_time": createTime,
|
|
"email": user.Email,
|
|
"id": user.ID,
|
|
"is_active": user.IsActive,
|
|
"is_anonymous": user.IsAnonymous,
|
|
"is_authenticated": user.IsAuthenticated,
|
|
"is_superuser": isSuperuser,
|
|
"language": language,
|
|
"last_login_time": lastLoginTime,
|
|
"login_channel": loginChannel,
|
|
"nickname": user.Nickname,
|
|
"password": password,
|
|
"status": status,
|
|
"timezone": timezone,
|
|
"update_date": updateDate,
|
|
"update_time": updateTime,
|
|
}
|
|
}
|
|
|
|
// UpdateUserSettings updates user settings
|
|
func (s *UserService) UpdateUserSettings(user *entity.User, req *UpdateSettingsRequest) (common.ErrorCode, error) {
|
|
// Update fields if provided
|
|
if req.Password != nil {
|
|
ciphertext, err := base64.StdEncoding.DecodeString(*req.Password)
|
|
if err != nil {
|
|
return common.CodeExceptionError, fmt.Errorf("Error('Incorrect padding')")
|
|
}
|
|
privateKey, err := common.LoadPrivateKey()
|
|
if err != nil {
|
|
return common.CodeExceptionError, err
|
|
}
|
|
oldPasswordBytes, err := rsa.DecryptPKCS1v15(nil, privateKey, ciphertext)
|
|
oldPassword := "Fail to decrypt password!"
|
|
if err == nil {
|
|
oldPassword = string(oldPasswordBytes)
|
|
}
|
|
if user.Password == nil || !s.VerifyPassword(*user.Password, oldPassword) {
|
|
return common.CodeAuthenticationError, fmt.Errorf("Password error!")
|
|
}
|
|
|
|
if req.NewPassword != nil {
|
|
ciphertext, err := base64.StdEncoding.DecodeString(*req.NewPassword)
|
|
if err != nil {
|
|
return common.CodeExceptionError, fmt.Errorf("Error('Incorrect padding')")
|
|
}
|
|
newPasswordBytes, err := rsa.DecryptPKCS1v15(nil, privateKey, ciphertext)
|
|
if err != nil {
|
|
return common.CodeExceptionError, err
|
|
}
|
|
|
|
hashedPassword, err := common.GenerateWerkzeugPasswordHash(string(newPasswordBytes))
|
|
if err != nil {
|
|
return common.CodeExceptionError, err
|
|
}
|
|
user.Password = &hashedPassword
|
|
}
|
|
}
|
|
if req.Nickname != nil {
|
|
user.Nickname = *req.Nickname
|
|
}
|
|
if req.Avatar != nil {
|
|
// In Go version, avatar might be stored differently
|
|
// For now, just update if field exists
|
|
user.Avatar = req.Avatar
|
|
}
|
|
if req.Language != nil {
|
|
// Store language preference
|
|
user.Language = req.Language
|
|
}
|
|
if req.ColorSchema != nil {
|
|
// Store color schema preference
|
|
user.ColorSchema = req.ColorSchema
|
|
}
|
|
if req.Timezone != nil {
|
|
// Store timezone preference
|
|
user.Timezone = req.Timezone
|
|
}
|
|
|
|
// Save updated user
|
|
if err := s.userDAO.Update(user); err != nil {
|
|
return common.CodeServerError, err
|
|
}
|
|
return common.CodeSuccess, nil
|
|
}
|
|
|
|
// ChangePassword changes user password
|
|
func (s *UserService) ChangePassword(user *entity.User, req *ChangePasswordRequest) (common.ErrorCode, error) {
|
|
// If password is provided, verify current password
|
|
if req.Password != nil {
|
|
if user.Password == nil || !s.VerifyPassword(*user.Password, *req.Password) {
|
|
return common.CodeBadRequest, fmt.Errorf("current password is incorrect")
|
|
}
|
|
}
|
|
|
|
// If new password is provided, update password
|
|
if req.NewPassword != nil {
|
|
hashedPassword, err := common.GenerateWerkzeugPasswordHash(*req.NewPassword)
|
|
if err != nil {
|
|
return common.CodeServerError, fmt.Errorf("failed to hash new password: %w", err)
|
|
}
|
|
user.Password = &hashedPassword
|
|
}
|
|
|
|
// Save updated user
|
|
if err := s.userDAO.Update(user); err != nil {
|
|
return common.CodeServerError, err
|
|
}
|
|
return common.CodeSuccess, nil
|
|
}
|
|
|
|
// LoginChannel describes one configured OAuth login option as returned by
// GetLoginChannels.
type LoginChannel struct {
	// Channel is the key of the entry in the OAuth configuration.
	Channel string `json:"channel"`
	// DisplayName is the configured display name, or the title-cased channel
	// key when none is configured.
	DisplayName string `json:"display_name"`
	// Icon is the configured icon, or "sso" when none is configured.
	Icon string `json:"icon"`
}
|
|
|
|
// GetLoginChannels gets all supported authentication channels
|
|
func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, error) {
|
|
cfg := server.GetConfig()
|
|
channels := make([]*LoginChannel, 0)
|
|
|
|
for channel, oauthCfg := range cfg.OAuth {
|
|
displayName := oauthCfg.DisplayName
|
|
if displayName == "" {
|
|
displayName = strings.Title(channel)
|
|
}
|
|
|
|
icon := oauthCfg.Icon
|
|
if icon == "" {
|
|
icon = "sso"
|
|
}
|
|
|
|
channels = append(channels, &LoginChannel{
|
|
Channel: channel,
|
|
DisplayName: displayName,
|
|
Icon: icon,
|
|
})
|
|
}
|
|
|
|
return channels, common.CodeSuccess, nil
|
|
}
|
|
|
|
// SetTenantInfoRequest represents the request for setting tenant info.
// SetTenantInfo reads only TenantID and Raw; the remaining typed fields
// describe the expected JSON keys.
type SetTenantInfoRequest struct {
	// TenantID selects the tenant to update; nil is treated as "".
	TenantID *string `json:"tenant_id"`
	// ASRID is the "asr_id" key of the request body.
	ASRID *string `json:"asr_id"`
	// EmbdID is the "embd_id" key of the request body.
	EmbdID *string `json:"embd_id"`
	// Img2TxtID is the "img2txt_id" key of the request body.
	Img2TxtID *string `json:"img2txt_id"`
	// LLMID is the "llm_id" key of the request body.
	LLMID *string `json:"llm_id"`
	// RerankID is the "rerank_id" key of the request body.
	RerankID *string `json:"rerank_id"`
	// TTSID is the "tts_id" key of the request body.
	TTSID *string `json:"tts_id"`
	// Raw is excluded from JSON decoding (tag "-"). SetTenantInfo copies every
	// entry except "tenant_id" from it into the update set.
	// NOTE(review): presumably the handler fills it with the raw request body — confirm.
	Raw map[string]interface{} `json:"-"`
}
|
|
|
|
// SetTenantInfo updates tenant model configuration
|
|
func (s *UserService) SetTenantInfo(userID string, req *SetTenantInfoRequest) (common.ErrorCode, error) {
|
|
_ = userID
|
|
tenantDAO := dao.NewTenantDAO()
|
|
updates := make(map[string]interface{})
|
|
|
|
for key, value := range req.Raw {
|
|
if key == "tenant_id" {
|
|
continue
|
|
}
|
|
updates[key] = value
|
|
}
|
|
|
|
tenantID := ""
|
|
if req.TenantID != nil {
|
|
tenantID = *req.TenantID
|
|
}
|
|
|
|
tenantLLMService := NewTenantLLMService()
|
|
updates = tenantLLMService.EnsureTenantModelIDForParams(tenantID, updates)
|
|
|
|
if len(updates) > 0 {
|
|
if err := tenantDAO.Update(tenantID, updates); err != nil {
|
|
return common.CodeExceptionError, err
|
|
}
|
|
}
|
|
|
|
return common.CodeSuccess, nil
|
|
}
|
|
|
|
// UserTenantService provides the business logic for user-tenant relationship
// management.
type UserTenantService struct {
	// userTenantDAO is the data-access object used to read user-tenant relations.
	userTenantDAO *dao.UserTenantDAO
}
|
|
|
|
// NewUserTenantService creates a new UserTenantService instance
|
|
/**
|
|
* Returns:
|
|
* - *UserTenantService: a new UserTenantService instance
|
|
*
|
|
* Example:
|
|
*
|
|
* service := NewUserTenantService()
|
|
* relations, err := service.GetUserTenantRelationByUserID("user123")
|
|
*/
|
|
func NewUserTenantService() *UserTenantService {
|
|
return &UserTenantService{
|
|
userTenantDAO: dao.NewUserTenantDAO(),
|
|
}
|
|
}
|
|
|
|
// UserTenantRelation is the response view of one user-tenant relationship.
// According to the file's original note, this shape matches the return format
// of the Python implementation.
type UserTenantRelation struct {
	// ID is the identifier of the relationship row.
	ID string `json:"id"`
	// UserID is the user taking part in the relationship.
	UserID string `json:"user_id"`
	// TenantID is the tenant taking part in the relationship.
	TenantID string `json:"tenant_id"`
	// Role is the user's role within the tenant.
	Role string `json:"role"`
}
|
|
|
|
// GetUserTenantRelationByUserID retrieves all user-tenant relationships for a given user ID
|
|
/**
|
|
* This method returns a list of user-tenant relationships with selected fields:
|
|
* - id: the relationship ID
|
|
* - user_id: the user ID
|
|
* - tenant_id: the tenant ID
|
|
* - role: the user's role in the tenant
|
|
*
|
|
* Parameters:
|
|
* - userID: the unique identifier of the user
|
|
*
|
|
* Returns:
|
|
* - []*UserTenantRelation: list of user-tenant relationships
|
|
* - error: error if the operation fails, nil otherwise
|
|
*
|
|
* Example:
|
|
*
|
|
* service := NewUserTenantService()
|
|
* relations, err := service.GetUserTenantRelationByUserID("user123")
|
|
* if err != nil {
|
|
* log.Printf("Failed to get user tenant relations: %v", err)
|
|
* return
|
|
* }
|
|
* for _, rel := range relations {
|
|
* fmt.Printf("User %s has role %s in tenant %s\n", rel.UserID, rel.Role, rel.TenantID)
|
|
* }
|
|
*/
|
|
func (s *UserTenantService) GetUserTenantRelationByUserID(userID string) ([]*UserTenantRelation, error) {
|
|
return s.GetUserTenantRelationByUserIDWithContext(context.Background(), userID)
|
|
}
|
|
|
|
// GetUserTenantRelationByUserIDWithContext retrieves all user-tenant relationships for a given user ID with context.
|
|
func (s *UserTenantService) GetUserTenantRelationByUserIDWithContext(ctx context.Context, userID string) ([]*UserTenantRelation, error) {
|
|
relations, err := s.userTenantDAO.GetByUserIDWithContext(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make([]*UserTenantRelation, len(relations))
|
|
for i, rel := range relations {
|
|
result[i] = convertToUserTenantRelation(rel)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// convertToUserTenantRelation converts model.UserTenant to UserTenantRelation
|
|
/**
|
|
* Parameters:
|
|
* - userTenant: the model.UserTenant to convert
|
|
*
|
|
* Returns:
|
|
* - *UserTenantRelation: the converted UserTenantRelation
|
|
*/
|
|
func convertToUserTenantRelation(userTenant *entity.UserTenant) *UserTenantRelation {
|
|
return &UserTenantRelation{
|
|
ID: userTenant.ID,
|
|
UserID: userTenant.UserID,
|
|
TenantID: userTenant.TenantID,
|
|
Role: userTenant.Role,
|
|
}
|
|
}
|
|
|
|
// GetUserByAPIToken gets user by access key from Authorization header
|
|
// This is used for API token authentication
|
|
// The authorization parameter should be in format: "Bearer <token>" or just "<token>"
|
|
func (s *UserService) GetUserByAPIToken(authorization string) (*entity.User, common.ErrorCode, error) {
|
|
if authorization == "" {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("authorization header is empty")
|
|
}
|
|
|
|
// Split authorization header to get the token
|
|
// Expected format: "Bearer <token>" or "<token>"
|
|
parts := strings.Split(authorization, " ")
|
|
var token string
|
|
if len(parts) == 2 {
|
|
token = parts[1]
|
|
} else if len(parts) == 1 {
|
|
token = parts[0]
|
|
} else {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization format")
|
|
}
|
|
|
|
// Query API token from database
|
|
apiTokenDAO := dao.NewAPITokenDAO()
|
|
userToken, err := apiTokenDAO.GetUserByAPIToken(token)
|
|
if err != nil {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("invalid access token")
|
|
}
|
|
|
|
// Get user by tenant_id from API token
|
|
user, err := s.userDAO.GetByTenantID(userToken.TenantID)
|
|
if err != nil {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("user not found for this access token")
|
|
}
|
|
|
|
// Check if user's access_token is empty
|
|
if user.AccessToken == nil || *user.AccessToken == "" {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("user has empty access_token in database")
|
|
}
|
|
|
|
return user, common.CodeSuccess, nil
|
|
|
|
}
|
|
|
|
// GetUserByBetaAPIToken gets user by beta access key from Authorization
|
|
// header. This mirrors Python's AUTH_BETA flow used by public bot endpoints.
|
|
func (s *UserService) GetUserByBetaAPIToken(authorization string) (*entity.User, common.ErrorCode, error) {
|
|
authorization = strings.TrimSpace(authorization)
|
|
if authorization == "" {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("authorization header is empty")
|
|
}
|
|
|
|
parts := strings.Fields(authorization)
|
|
var token string
|
|
if len(parts) == 2 {
|
|
token = parts[1]
|
|
} else if len(parts) == 1 {
|
|
if strings.EqualFold(parts[0], "Bearer") {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization format")
|
|
}
|
|
token = parts[0]
|
|
} else {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization format")
|
|
}
|
|
if token == "" {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization format")
|
|
}
|
|
|
|
apiTokenDAO := dao.NewAPITokenDAO()
|
|
userToken, err := apiTokenDAO.GetByBeta(token)
|
|
if err != nil {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("invalid beta access token")
|
|
}
|
|
|
|
user, err := s.userDAO.GetByTenantID(userToken.TenantID)
|
|
if err != nil {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("user not found for this beta access token")
|
|
}
|
|
|
|
if user.AccessToken == nil || *user.AccessToken == "" {
|
|
return nil, common.CodeUnauthorized, fmt.Errorf("user has empty access_token in database")
|
|
}
|
|
|
|
return user, common.CodeSuccess, nil
|
|
}
|
|
|
|
// ---- Forgot-password flow (mirrors the `/auth/password/...` endpoints in
// api/apps/restful_apis/user_api.py; fixes #15282) -------------------------
|
|
|
|
// ForgotIssueCaptcha mints a captcha for the given email and stores the
|
|
// expected text in Redis under utility.CaptchaIDRedisKey, keyed by a
|
|
// fresh server-side captcha_id, with a 60s TTL. Returns the captcha_id
|
|
// and a renderable SVG image (data URL) the FE drops into <img src> so
|
|
// the human can read the challenge and type the answer. The plaintext
|
|
// code itself is never sent to the client outside the rendered image.
|
|
//
|
|
// Refuses unknown emails to avoid leaking the user list — matches Python.
|
|
func (s *UserService) ForgotIssueCaptcha(email string) (captchaID, imageDataURL string, code common.ErrorCode, err error) {
|
|
if email == "" {
|
|
return "", "", common.CodeArgumentError, fmt.Errorf("email is required")
|
|
}
|
|
if _, err := s.userDAO.GetByEmail(email); err != nil {
|
|
return "", "", common.CodeDataError, fmt.Errorf("invalid email")
|
|
}
|
|
|
|
text, err := utility.GenerateCaptchaCode()
|
|
if err != nil {
|
|
return "", "", common.CodeServerError, err
|
|
}
|
|
captchaID = utility.GenerateToken()
|
|
if ok := redis.Get().Set(utility.CaptchaIDRedisKey(captchaID), text, 60*time.Second); !ok {
|
|
return "", "", common.CodeServerError, fmt.Errorf("failed to store captcha")
|
|
}
|
|
imageDataURL = utility.RenderCaptchaPNGDataURL(text)
|
|
return captchaID, imageDataURL, common.CodeSuccess, nil
|
|
}
|
|
|
|
// ForgotSendOTP verifies the captcha (looked up by the server-issued
|
|
// captcha_id), then issues an OTP and emails it. Hash-and-salt is
|
|
// stored in Redis under the keys returned by utility.OTPRedisKeys.
|
|
// Resend cooldown and per-email lockout behaviour otherwise match the
|
|
// Python implementation byte-for-byte.
|
|
func (s *UserService) ForgotSendOTP(email, captchaID, captcha string) (common.ErrorCode, error) {
|
|
if email == "" || captchaID == "" || captcha == "" {
|
|
return common.CodeArgumentError, fmt.Errorf("email, captcha_id and captcha required")
|
|
}
|
|
if _, err := s.userDAO.GetByEmail(email); err != nil {
|
|
return common.CodeDataError, fmt.Errorf("invalid email")
|
|
}
|
|
|
|
rc := redis.Get()
|
|
captchaKey := utility.CaptchaIDRedisKey(captchaID)
|
|
stored, _ := rc.Get(captchaKey)
|
|
if stored == "" {
|
|
return common.CodeNotEffective, fmt.Errorf("invalid or expired captcha")
|
|
}
|
|
if !strings.EqualFold(strings.TrimSpace(stored), strings.TrimSpace(captcha)) {
|
|
return common.CodeAuthenticationError, fmt.Errorf("invalid or expired captcha")
|
|
}
|
|
// One-shot: consume the captcha so a leaked captcha_id cannot be
|
|
// reused for a stream of OTP requests.
|
|
rc.Delete(captchaKey)
|
|
|
|
codeKey, attemptsKey, lastSentKey, lockKey := utility.OTPRedisKeys(email)
|
|
|
|
// Lockout — a previous verify burst already locked this email; do not
|
|
// let a request for a new OTP wipe the lock (deliberate divergence
|
|
// from the Python implementation, which deletes the lock here and so
|
|
// allows a locked attacker to clear their own lockout by re-requesting).
|
|
if locked, _ := rc.Get(lockKey); locked != "" {
|
|
return common.CodeNotEffective, fmt.Errorf("too many attempts, try later")
|
|
}
|
|
|
|
// Resend cooldown — refuse if we already sent within the window.
|
|
if lastSent, _ := rc.Get(lastSentKey); lastSent != "" {
|
|
ts, parseErr := strconv.ParseInt(lastSent, 10, 64)
|
|
if parseErr == nil {
|
|
elapsed := time.Since(time.Unix(ts, 0))
|
|
remaining := utility.OTPResendCooldown - elapsed
|
|
if remaining > 0 {
|
|
return common.CodeNotEffective, fmt.Errorf("you still have to wait %d seconds", int(remaining.Seconds()))
|
|
}
|
|
}
|
|
}
|
|
|
|
otp, err := utility.GenerateOTPCode()
|
|
if err != nil {
|
|
return common.CodeServerError, err
|
|
}
|
|
salt, err := utility.GenerateOTPSalt()
|
|
if err != nil {
|
|
return common.CodeServerError, err
|
|
}
|
|
codeHash := utility.HashOTPCode(otp, salt)
|
|
now := strconv.FormatInt(time.Now().Unix(), 10)
|
|
|
|
// Snapshot the previous OTP-flow state so we can restore it if email
|
|
// delivery fails — otherwise the user is throttled by lastSentKey
|
|
// even though they never received the code.
|
|
prevCode, _ := rc.Get(codeKey)
|
|
prevAttempts, _ := rc.Get(attemptsKey)
|
|
prevLastSent, _ := rc.Get(lastSentKey)
|
|
|
|
if !rc.Set(codeKey, utility.EncodeOTPStorageValue(codeHash, salt), utility.OTPTTL) {
|
|
return common.CodeServerError, fmt.Errorf("failed to store otp")
|
|
}
|
|
rc.Set(attemptsKey, "0", utility.OTPTTL)
|
|
rc.Set(lastSentKey, now, utility.OTPTTL)
|
|
// Note: lockKey is intentionally not cleared here. If the user has
|
|
// been locked out by a previous verify burst, requesting a new OTP
|
|
// does not lift the lock — we already refused above.
|
|
|
|
ttlMin := int(utility.OTPTTL.Minutes())
|
|
cfg := server.GetConfig()
|
|
if err := utility.SendResetCodeEmail(cfg.SMTP, email, otp, ttlMin); err != nil {
|
|
// Roll back: restore prior code/attempts/last-sent or remove the
|
|
// keys we just wrote so the next attempt isn't blocked by the
|
|
// resend cooldown a failed send just installed.
|
|
if prevCode != "" {
|
|
rc.Set(codeKey, prevCode, utility.OTPTTL)
|
|
} else {
|
|
rc.Delete(codeKey)
|
|
}
|
|
if prevAttempts != "" {
|
|
rc.Set(attemptsKey, prevAttempts, utility.OTPTTL)
|
|
} else {
|
|
rc.Delete(attemptsKey)
|
|
}
|
|
if prevLastSent != "" {
|
|
rc.Set(lastSentKey, prevLastSent, utility.OTPTTL)
|
|
} else {
|
|
rc.Delete(lastSentKey)
|
|
}
|
|
return common.CodeServerError, fmt.Errorf("failed to send email")
|
|
}
|
|
return common.CodeSuccess, nil
|
|
}
|
|
|
|
// ForgotVerifyOTP checks an OTP submitted by the user. On success it
|
|
// consumes the OTP/attempt counters and writes a short-lived "verified"
|
|
// flag the reset endpoint will gate on.
|
|
func (s *UserService) ForgotVerifyOTP(email, otp string) (common.ErrorCode, error) {
|
|
if email == "" || otp == "" {
|
|
return common.CodeArgumentError, fmt.Errorf("email and otp are required")
|
|
}
|
|
if _, err := s.userDAO.GetByEmail(email); err != nil {
|
|
return common.CodeDataError, fmt.Errorf("invalid email")
|
|
}
|
|
|
|
rc := redis.Get()
|
|
codeKey, attemptsKey, lastSentKey, lockKey := utility.OTPRedisKeys(email)
|
|
|
|
if locked, _ := rc.Get(lockKey); locked != "" {
|
|
return common.CodeNotEffective, fmt.Errorf("too many attempts, try later")
|
|
}
|
|
|
|
stored, _ := rc.Get(codeKey)
|
|
if stored == "" {
|
|
return common.CodeNotEffective, fmt.Errorf("expired otp")
|
|
}
|
|
storedHash, salt, err := utility.DecodeOTPStorageValue(stored)
|
|
if err != nil {
|
|
return common.CodeServerError, fmt.Errorf("otp storage corrupted")
|
|
}
|
|
|
|
if utility.HashOTPCode(strings.ToUpper(strings.TrimSpace(otp)), salt) != storedHash {
|
|
// bump attempts; lock on >= limit
|
|
attempts := 0
|
|
if cur, _ := rc.Get(attemptsKey); cur != "" {
|
|
if n, perr := strconv.Atoi(cur); perr == nil {
|
|
attempts = n
|
|
}
|
|
}
|
|
attempts++
|
|
rc.Set(attemptsKey, strconv.Itoa(attempts), utility.OTPTTL)
|
|
if attempts >= utility.OTPAttemptLimit {
|
|
rc.Set(lockKey, strconv.FormatInt(time.Now().Unix(), 10), utility.OTPAttemptLockDuration)
|
|
}
|
|
return common.CodeAuthenticationError, fmt.Errorf("expired otp")
|
|
}
|
|
|
|
// Success: clear OTP state, mark email verified.
|
|
rc.Delete(codeKey)
|
|
rc.Delete(attemptsKey)
|
|
rc.Delete(lastSentKey)
|
|
rc.Delete(lockKey)
|
|
if !rc.Set(utility.OTPVerifiedRedisKey(email), "1", utility.OTPTTL) {
|
|
return common.CodeServerError, fmt.Errorf("failed to set verification state")
|
|
}
|
|
return common.CodeSuccess, nil
|
|
}
|
|
|
|
// ForgotResetPasswordRequest is the JSON body accepted by
// /auth/password/reset.
//
// The fields deliberately carry no `binding` tags. With them, gin's validator
// would fire inside c.ShouldBindJSON and emit a verbose
// `Key: 'ForgotResetPasswordRequest.Email' Error:Field validation ...`
// message, which diverges from the Python contract for this endpoint
// (api/apps/restful_apis/user_api.py:forget_reset_password). That contract
// answers with the friendlier "email and passwords are required". Letting the
// bind succeed with zero values leaves the emptiness check to the service
// layer, which produces the matching message, including when the JSON body is
// missing entirely.
type ForgotResetPasswordRequest struct {
	// Email identifies the account whose password is being reset.
	Email string `json:"email"`
	// NewPassword is the new password as submitted by the client.
	NewPassword string `json:"new_password"`
	// ConfirmNewPassword is the confirmation copy of the new password.
	ConfirmNewPassword string `json:"confirm_new_password"`
}
|
|
|
|
// ForgotResetPassword finalises the reset: only proceeds if the verified
|
|
// flag is set, validates the two ciphertexts match after RSA decryption,
|
|
// updates the password hash, and clears the verified flag. Returns the
|
|
// user so the handler can auto-login (matching Python's
|
|
// `construct_response(auth=user.get_id())`).
|
|
func (s *UserService) ForgotResetPassword(req *ForgotResetPasswordRequest) (*entity.User, common.ErrorCode, error) {
|
|
if req.Email == "" || req.NewPassword == "" || req.ConfirmNewPassword == "" {
|
|
return nil, common.CodeArgumentError, fmt.Errorf("email and passwords are required")
|
|
}
|
|
|
|
rc := redis.Get()
|
|
verifiedKey := utility.OTPVerifiedRedisKey(req.Email)
|
|
if v, _ := rc.Get(verifiedKey); v != "1" {
|
|
return nil, common.CodeAuthenticationError, fmt.Errorf("email not verified")
|
|
}
|
|
|
|
plain, err := common.DecryptPassword(req.NewPassword)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("fail to decrypt password")
|
|
}
|
|
confirm, err := common.DecryptPassword(req.ConfirmNewPassword)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("fail to decrypt password")
|
|
}
|
|
if plain != confirm {
|
|
return nil, common.CodeArgumentError, fmt.Errorf("passwords do not match")
|
|
}
|
|
|
|
user, err := s.userDAO.GetByEmail(req.Email)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, fmt.Errorf("invalid email")
|
|
}
|
|
|
|
hashed, err := common.GenerateWerkzeugPasswordHash(plain)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to hash new password: %w", err)
|
|
}
|
|
user.Password = &hashed
|
|
|
|
// Auto-login: rotate the access token like LoginByEmail does so the
|
|
// handler can immediately mint an Authorization header.
|
|
token := utility.GenerateToken()
|
|
user.AccessToken = &token
|
|
now := time.Now().Truncate(time.Second)
|
|
user.LastLoginTime = &now
|
|
|
|
if err := s.userDAO.Update(user); err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to reset password: %w", err)
|
|
}
|
|
|
|
rc.Delete(verifiedKey)
|
|
return user, common.CodeSuccess, nil
|
|
}
|