From 81fd1811b8906fbe392ba2c40bdc4d467c2d8bf0 Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Fri, 6 Mar 2026 16:42:49 +0800 Subject: [PATCH] =?UTF-8?q?Feat=EF=BC=9AUsing=20Go=20to=20implement=20user?= =?UTF-8?q?=20registration=20logic=20(#13431)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Feat:Using Go to implement user registration logic ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- .agents/rules/named.md | 192 ++++++++++++++++++++++++ .agents/skills/go-naming/SKILL.md | 6 + go.mod | 3 +- go.sum | 4 + internal/common/error_code.go | 40 +++++ internal/dao/database.go | 13 +- internal/dao/file.go | 5 + internal/dao/tenant.go | 10 ++ internal/dao/user.go | 5 + internal/handler/chat.go | 24 +-- internal/handler/chat_session.go | 24 +-- internal/handler/chunk.go | 6 +- internal/handler/connector.go | 6 +- internal/handler/file.go | 24 +-- internal/handler/kb.go | 6 +- internal/handler/llm.go | 16 +- internal/handler/search.go | 6 +- internal/handler/tenant.go | 11 +- internal/handler/user.go | 230 +++++++++++++++++++---------- internal/router/router.go | 1 + internal/service/user.go | 235 ++++++++++++++++++++++-------- 21 files changed, 660 insertions(+), 207 deletions(-) create mode 100644 .agents/rules/named.md create mode 100644 .agents/skills/go-naming/SKILL.md create mode 100644 internal/common/error_code.go diff --git a/.agents/rules/named.md b/.agents/rules/named.md new file mode 100644 index 0000000000..32ba41e1f8 --- /dev/null +++ b/.agents/rules/named.md @@ -0,0 +1,192 @@ +# Go Naming Best Practices + +## 1. Package Naming + +- **All lowercase, no underscores**: `package user`, not `package userService` or `package user_service` +- **Short and meaningful**: `package http`, `package json`, `package dao` +- **Avoid plurals**: `package user` not `package users` +- **Avoid generic names**: Avoid `package util`, `package common`, `package base` + +```go +// Recommended +package user +package handler +package service + +// Not recommended +package UserService +package user_service +package utils +``` + +## 2. File Naming + +- **All lowercase, underscore separated**: `user_handler.go`, `user_service.go` +- **Test files**: `user_handler_test.go` +- **Platform-specific**: `user_linux.go`, `user_windows.go` + +``` +user/ +├── user_handler.go +├── user_service.go +├── user_dao.go +└── user_test.go +``` + +## 3. Directory Naming + +- **All lowercase, no underscores or hyphens**: `internal/`, `pkg/`, `cmd/` +- **Short and descriptive**: `handler/`, `service/`, `dao/` + +``` +project/ +├── cmd/ # Main entry point +│ └── server_main.go +├── internal/ # Private code +│ ├── handler/ +│ ├── service/ +│ ├── dao/ +│ ├── model/ +│ └── middleware/ +├── pkg/ # Public code +└── api/ # API definitions +``` + +## 4. Interface Naming + +- **Single-method interfaces end with "-er"**: `Reader`, `Writer`, `Handler` +- **Verb form**: `Reader`, `Executor`, `Validator` + +```go +// Recommended +type Reader interface { + Read(p []byte) (n int, err error) +} + +type UserService interface { + Register(req *RegisterRequest) (*User, error) + Login(req *LoginRequest) (*User, error) +} + +// Not recommended +type UserInterface interface {} +type IUserService interface {} +``` + +## 5. Struct Naming + +- **CamelCase**: `UserService`, `UserHandler` +- **Avoid redundant prefixes**: `User` not `UserModel` + +```go +// Recommended +type UserService struct {} +type UserHandler struct {} +type RegisterRequest struct {} + +// Not recommended +type user_service struct {} +type SUserService struct {} +type UserModel struct {} +``` + +## 6. Method/Function Naming + +- **CamelCase** +- **Start with verb**: `GetUser`, `CreateUser`, `DeleteUser` +- **Boolean returns use Is/Has/Can prefix**: `IsValid`, `HasPermission` + +```go +// Recommended +func (s *UserService) Register(req *RegisterRequest) (*User, error) +func (s *UserService) GetUserByID(id uint) (*User, error) +func (s *UserService) IsEmailExists(email string) bool + +// Not recommended +func (s *UserService) register_user() +func (s *UserService) get_user_by_id() +func (s *UserService) CheckEmailExists() // Should use Is/Has +``` + +## 7. Constant Naming + +- **CamelCase**: `const MaxRetryCount = 3` +- **Enum constants**: `const StatusActive = "active"` + +```go +// Recommended +const ( + StatusActive = "1" + StatusInactive = "0" + MaxRetryCount = 3 +) + +// Not recommended +const ( + STATUS_ACTIVE = "1" // Not all uppercase + status_active = "1" // Not all lowercase +) +``` + +## 8. Error Variable Naming + +- **Start with "Err"**: `ErrNotFound`, `ErrInvalidInput` + +```go +// Recommended +var ( + ErrNotFound = errors.New("not found") + ErrInvalidInput = errors.New("invalid input") + ErrUnauthorized = errors.New("unauthorized") +) +``` + +## 9. Acronyms Keep Consistent Case + +```go +// Recommended +type HTTPHandler struct {} +var URL string +func GetHTTPClient() {} +func ParseJSON() {} + +// Not recommended +type HttpHandler struct {} +var Url string +func GetHttpClient() {} +``` + +## 10. Project Structure Naming + +``` +project-name/ +├── cmd/ # Main programs +│ └── app_name/ +│ └── main.go +├── internal/ # Private code +│ ├── handler/ # HTTP handlers +│ ├── service/ # Business logic +│ ├── repository/ # Data access +│ ├── model/ # Data models +│ └── config/ # Configuration +├── pkg/ # Public code +├── api/ # API definitions +├── configs/ # Config files +├── scripts/ # Scripts +├── docs/ # Documentation +├── go.mod +└── go.sum +``` + +## Summary Table + +| Type | Rule | Example | +| -------------- | ----------------------------------- | ------------------- | +| Package | All lowercase, no underscores | `package user` | +| File | All lowercase, underscore separated | `user_service.go` | +| Directory | All lowercase, no separators | `internal/handler/` | +| Struct | CamelCase, capitalized first letter | `UserService` | +| Interface | CamelCase, -er suffix | `Reader`, `Writer` | +| Method | CamelCase, verb prefix | `GetUserByID` | +| Constant | CamelCase | `MaxRetryCount` | +| Error Variable | Err prefix | `ErrNotFound` | diff --git a/.agents/skills/go-naming/SKILL.md b/.agents/skills/go-naming/SKILL.md new file mode 100644 index 0000000000..fb7f2b96a5 --- /dev/null +++ b/.agents/skills/go-naming/SKILL.md @@ -0,0 +1,6 @@ +--- +name: go-naming +description: Go naming conventions and best practices. Use this skill when working with Go code and need to name packages, files, directories, structs, interfaces, functions, variables, or constants. Provides comprehensive naming guidelines following Go community standards. +--- + +Strictly follow the naming conventions in [rules/named.md](rules/named.md) diff --git a/go.mod b/go.mod index 256f066ac6..139d21f861 100644 --- a/go.mod +++ b/go.mod @@ -61,7 +61,8 @@ require ( golang.org/x/arch v0.6.0 // indirect golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect golang.org/x/net v0.48.0 // indirect - golang.org/x/sys v0.40.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/term v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect google.golang.org/protobuf v1.32.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index 6b405659db..eb856996c4 100644 --- a/go.sum +++ b/go.sum @@ -156,6 +156,10 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= diff --git a/internal/common/error_code.go b/internal/common/error_code.go new file mode 100644 index 0000000000..0817c17431 --- /dev/null +++ b/internal/common/error_code.go @@ -0,0 +1,40 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package common + +type ErrorCode int + +const ( + CodeSuccess ErrorCode = 0 + CodeNotEffective ErrorCode = 10 + CodeExceptionError ErrorCode = 100 + CodeArgumentError ErrorCode = 101 + CodeDataError ErrorCode = 102 + CodeOperatingError ErrorCode = 103 + CodeTimeoutError ErrorCode = 104 + CodeConnectionError ErrorCode = 105 + CodeRunning ErrorCode = 106 + CodeResourceExhausted ErrorCode = 107 + CodePermissionError ErrorCode = 108 + CodeAuthenticationError ErrorCode = 109 + CodeBadRequest ErrorCode = 400 + CodeUnauthorized ErrorCode = 401 + CodeForbidden ErrorCode = 403 + CodeNotFound ErrorCode = 404 + CodeConflict ErrorCode = 409 + CodeServerError ErrorCode = 500 +) diff --git a/internal/dao/database.go b/internal/dao/database.go index b0d0e1ee5e..163f172df9 100644 --- a/internal/dao/database.go +++ b/internal/dao/database.go @@ -18,6 +18,7 @@ package dao import ( "fmt" + "ragflow/internal/model" "ragflow/internal/server" "time" @@ -77,9 +78,15 @@ func InitDB() error { sqlDB.SetConnMaxLifetime(time.Hour) // Auto migrate - //if err := DB.AutoMigrate(&model.User{}, &model.Document{}); err != nil { - // return fmt.Errorf("failed to migrate database: %w", err) - //} + if err := DB.AutoMigrate( + &model.User{}, + &model.Tenant{}, + &model.UserTenant{}, + &model.File{}, + &model.File2Document{}, + ); err != nil { + return fmt.Errorf("failed to migrate database: %w", err) + } logger.Info("Database connected and migrated successfully") return nil diff --git a/internal/dao/file.go b/internal/dao/file.go index bbf9a66098..c665d1077b 100644 --- a/internal/dao/file.go +++ b/internal/dao/file.go @@ -195,6 +195,11 @@ func (dao *FileDAO) GetAllParentFolders(startID string) ([]*model.File, error) { return parentFolders, nil } +// Create creates a new file +func (dao *FileDAO) Create(file *model.File) error { + return DB.Create(file).Error +} + // generateUUID generates a UUID func generateUUID() string { id := uuid.New().String() diff --git a/internal/dao/tenant.go b/internal/dao/tenant.go index c992b1a742..781c6c2058 100644 --- a/internal/dao/tenant.go +++ b/internal/dao/tenant.go @@ -88,3 +88,13 @@ func (dao *TenantDAO) GetByID(id string) (*model.Tenant, error) { } return &tenant, nil } + +// Create creates a new tenant +func (dao *TenantDAO) Create(tenant *model.Tenant) error { + return DB.Create(tenant).Error +} + +// Delete deletes a tenant by ID (soft delete) +func (dao *TenantDAO) Delete(id string) error { + return DB.Model(&model.Tenant{}).Where("id = ?", id).Update("status", "0").Error +} diff --git a/internal/dao/user.go b/internal/dao/user.go index 014be06197..ff134683bc 100644 --- a/internal/dao/user.go +++ b/internal/dao/user.go @@ -101,3 +101,8 @@ func (dao *UserDAO) List(offset, limit int) ([]*model.User, int64, error) { func (dao *UserDAO) Delete(id uint) error { return DB.Delete(&model.User{}, id).Error } + +// DeleteByID delete user by string ID +func (dao *UserDAO) DeleteByID(id string) error { + return DB.Model(&model.User{}).Where("id = ?", id).Update("status", "0").Error +} diff --git a/internal/handler/chat.go b/internal/handler/chat.go index aa09c3353b..c7b2dde984 100644 --- a/internal/handler/chat.go +++ b/internal/handler/chat.go @@ -59,11 +59,11 @@ func (h *ChatHandler) ListChats(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -112,11 +112,11 @@ func (h *ChatHandler) ListChatsNext(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -196,11 +196,11 @@ func (h *ChatHandler) SetDialog(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -268,11 +268,11 @@ func (h *ChatHandler) RemoveChats(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index fd5d449231..54995371a5 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -61,11 +61,11 @@ func (h *ChatSessionHandler) SetChatSession(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -124,11 +124,11 @@ func (h *ChatSessionHandler) RemoveChatSessions(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -190,11 +190,11 @@ func (h *ChatSessionHandler) ListChatSessions(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -270,11 +270,11 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/chunk.go b/internal/handler/chunk.go index 10b19830da..d13f4ac279 100644 --- a/internal/handler/chunk.go +++ b/internal/handler/chunk.go @@ -59,11 +59,11 @@ func (h *ChunkHandler) RetrievalTest(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/connector.go b/internal/handler/connector.go index 9f54b80419..6c0ebedb05 100644 --- a/internal/handler/connector.go +++ b/internal/handler/connector.go @@ -58,11 +58,11 @@ func (h *ConnectorHandler) ListConnectors(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/file.go b/internal/handler/file.go index 974d3bbd68..3474ce0cb5 100644 --- a/internal/handler/file.go +++ b/internal/handler/file.go @@ -65,11 +65,11 @@ func (h *FileHandler) ListFiles(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -141,11 +141,11 @@ func (h *FileHandler) GetRootFolder(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -189,11 +189,11 @@ func (h *FileHandler) GetParentFolder(c *gin.Context) { } // Get user by access token (for validation) - _, err := h.userService.GetUserByToken(token) + _, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -246,11 +246,11 @@ func (h *FileHandler) GetAllParentFolders(c *gin.Context) { } // Get user by access token (for validation) - _, err := h.userService.GetUserByToken(token) + _, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/kb.go b/internal/handler/kb.go index 1c482fa89f..e4e2a025b4 100644 --- a/internal/handler/kb.go +++ b/internal/handler/kb.go @@ -130,11 +130,11 @@ func (h *KnowledgebaseHandler) ListKbs(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/llm.go b/internal/handler/llm.go index bcad7f2be1..6926dfc97d 100644 --- a/internal/handler/llm.go +++ b/internal/handler/llm.go @@ -71,10 +71,11 @@ func (h *LLMHandler) GetMyLLMs(c *gin.Context) { } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -127,10 +128,11 @@ func (h *LLMHandler) Factories(c *gin.Context) { } // Get user by token - _, err := h.userService.GetUserByToken(token) + _, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -207,11 +209,11 @@ func (h *LLMHandler) ListApp(c *gin.Context) { } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/search.go b/internal/handler/search.go index 5a6317b183..b291a78027 100644 --- a/internal/handler/search.go +++ b/internal/handler/search.go @@ -65,11 +65,11 @@ func (h *SearchHandler) ListSearchApps(c *gin.Context) { } // Get user by access token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/tenant.go b/internal/handler/tenant.go index ab96f958ce..02b87a4164 100644 --- a/internal/handler/tenant.go +++ b/internal/handler/tenant.go @@ -58,10 +58,11 @@ func (h *TenantHandler) TenantInfo(c *gin.Context) { return } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid access token", + "code": code, + "message": err.Error(), }) return } @@ -109,11 +110,11 @@ func (h *TenantHandler) TenantList(c *gin.Context) { } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + "code": code, + "message": err.Error(), }) return } diff --git a/internal/handler/user.go b/internal/handler/user.go index 2a4091857f..7fb39a5df9 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -17,7 +17,9 @@ package handler import ( + "fmt" "net/http" + "ragflow/internal/common" "ragflow/internal/server" "ragflow/internal/utility" "strconv" @@ -47,31 +49,51 @@ func NewUserHandler(userService *service.UserService) *UserHandler { // @Produce json // @Param request body service.RegisterRequest true "registration info" // @Success 200 {object} map[string]interface{} -// @Router /api/v1/users/register [post] +// @Router /v1/user/register [post] func (h *UserHandler) Register(c *gin.Context) { var req service.RegisterRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + "data": false, }) return } - user, err := h.userService.Register(&req) + user, code, err := h.userService.Register(&req) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } + variables := server.GetVariables() + secretKey := variables.SecretKey + authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "Failed to generate auth token", + "data": false, + }) + return + } + + c.Header("Authorization", authToken) + c.Header("Access-Control-Allow-Origin", "*") + c.Header("Access-Control-Allow-Methods", "*") + c.Header("Access-Control-Allow-Headers", "*") + c.Header("Access-Control-Expose-Headers", "Authorization") + + profile := h.userService.GetUserProfile(user) c.JSON(http.StatusOK, gin.H{ - "message": "registration successful", - "data": gin.H{ - "id": user.ID, - "nickname": user.Nickname, - "email": user.Email, - }, + "code": common.CodeSuccess, + "message": fmt.Sprintf("%s, welcome aboard!", req.Nickname), + "data": profile, }) } @@ -87,18 +109,20 @@ func (h *UserHandler) Register(c *gin.Context) { func (h *UserHandler) Login(c *gin.Context) { var req service.LoginRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, "message": err.Error(), + "data": false, }) return } - user, err := h.userService.Login(&req) + user, code, err := h.userService.Login(&req) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": code, "message": err.Error(), + "data": false, }) return } @@ -114,7 +138,7 @@ func (h *UserHandler) Login(c *gin.Context) { c.Header("Access-Control-Expose-Headers", "Authorization") c.JSON(http.StatusOK, gin.H{ - "code": 0, + "code": common.CodeSuccess, "message": "Welcome back!", "data": user, }) @@ -132,18 +156,20 @@ func (h *UserHandler) Login(c *gin.Context) { func (h *UserHandler) LoginByEmail(c *gin.Context) { var req service.EmailLoginRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, "message": err.Error(), + "data": false, }) return } - user, err := h.userService.LoginByEmail(&req) + user, code, err := h.userService.LoginByEmail(&req) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": code, "message": err.Error(), + "data": false, }) return } @@ -151,21 +177,26 @@ func (h *UserHandler) LoginByEmail(c *gin.Context) { variables := server.GetVariables() secretKey := variables.SecretKey authToken, err := utility.DumpAccessToken(*user.AccessToken, secretKey) - - // Set Authorization header with access_token - if user.AccessToken != nil { - c.Header("Authorization", authToken) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "Failed to generate auth token", + "data": false, + }) + return } - // Set CORS headers + + c.Header("Authorization", authToken) c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Methods", "*") c.Header("Access-Control-Allow-Headers", "*") c.Header("Access-Control-Expose-Headers", "Authorization") + profile := h.userService.GetUserProfile(user) c.JSON(http.StatusOK, gin.H{ - "code": 0, + "code": common.CodeSuccess, "message": "Welcome back!", - "data": user, + "data": profile, }) } @@ -182,22 +213,28 @@ func (h *UserHandler) GetUserByID(c *gin.Context) { idStr := c.Param("id") id, err := strconv.ParseUint(idStr, 10, 32) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "invalid user id", + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": "invalid user id", + "data": false, }) return } - user, err := h.userService.GetUserByID(uint(id)) + user, code, err := h.userService.GetUserByID(uint(id)) if err != nil { - c.JSON(http.StatusNotFound, gin.H{ - "error": "user not found", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "data": user, + "code": common.CodeSuccess, + "message": "success", + "data": user, }) } @@ -222,15 +259,19 @@ func (h *UserHandler) ListUsers(c *gin.Context) { pageSize = 10 } - users, total, err := h.userService.ListUsers(page, pageSize) + users, total, code, err := h.userService.ListUsers(page, pageSize) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to get users", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", "data": gin.H{ "items": users, "total": total, @@ -253,34 +294,38 @@ func (h *UserHandler) Logout(c *gin.Context) { // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, "message": "Missing Authorization header", + "data": false, }) return } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "Invalid access token", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } // Logout user - if err := h.userService.Logout(user); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, + code, err = h.userService.Logout(user) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ - "code": 0, + "code": common.CodeSuccess, "data": true, "message": "success", }) @@ -299,19 +344,21 @@ func (h *UserHandler) Info(c *gin.Context) { // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, "message": "Missing Authorization header", + "data": false, }) return } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "error": "Invalid access token", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } @@ -320,8 +367,9 @@ func (h *UserHandler) Info(c *gin.Context) { profile := h.userService.GetUserProfile(user) c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": profile, + "code": common.CodeSuccess, + "message": "success", + "data": profile, }) } @@ -339,18 +387,21 @@ func (h *UserHandler) Setting(c *gin.Context) { // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, "message": "Missing Authorization header", + "data": false, }) return } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid access token", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } @@ -358,22 +409,29 @@ func (h *UserHandler) Setting(c *gin.Context) { // Parse request var req service.UpdateSettingsRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + "data": false, }) return } // Update user settings - if err := h.userService.UpdateUserSettings(user, &req); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), + code, err = h.userService.UpdateUserSettings(user, &req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, "message": "settings updated successfully", + "data": true, }) } @@ -391,18 +449,21 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { // Extract token from request token := c.GetHeader("Authorization") if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeUnauthorized, "message": "Missing Authorization header", + "data": false, }) return } // Get user by token - user, err := h.userService.GetUserByToken(token) + user, code, err := h.userService.GetUserByToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid access token", + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } @@ -410,22 +471,29 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { // Parse request var req service.ChangePasswordRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + "data": false, }) return } // Change password - if err := h.userService.ChangePassword(user, &req); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), + code, err = h.userService.ChangePassword(user, &req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + "data": false, }) return } c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, "message": "password changed successfully", + "data": true, }) } @@ -438,10 +506,10 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/user/login/channels [get] func (h *UserHandler) GetLoginChannels(c *gin.Context) { - channels, err := h.userService.GetLoginChannels() + channels, code, err := h.userService.GetLoginChannels() if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, + c.JSON(http.StatusOK, gin.H{ + "code": code, "message": "Load channels failure, error: " + err.Error(), "data": []interface{}{}, }) @@ -449,7 +517,7 @@ func (h *UserHandler) GetLoginChannels(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{ - "code": 0, + "code": common.CodeSuccess, "message": "success", "data": channels, }) diff --git a/internal/router/router.go b/internal/router/router.go index bcc1d68384..5f41765d60 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -86,6 +86,7 @@ func (r *Router) Setup(engine *gin.Engine) { // User login by email endpoint engine.POST("/v1/user/login", r.userHandler.LoginByEmail) + engine.POST("/v1/user/register", r.userHandler.Register) // User login channels endpoint engine.GET("/v1/user/login/channels", r.userHandler.GetLoginChannels) // User logout endpoint diff --git a/internal/service/user.go b/internal/service/user.go index e92541502d..99dd5351b5 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -25,7 +25,9 @@ import ( "errors" "fmt" "os" + "ragflow/internal/common" "ragflow/internal/server" + "regexp" "strconv" "strings" "time" @@ -52,9 +54,8 @@ func NewUserService() *UserService { // RegisterRequest registration request type RegisterRequest struct { - Username string `json:"username" binding:"required,min=3,max=50"` - Password string `json:"password" binding:"required,min=6"` Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=6"` Nickname string `json:"nickname"` } @@ -96,125 +97,220 @@ type UserResponse struct { } // Register user registration -func (s *UserService) Register(req *RegisterRequest) (*model.User, error) { - // Check if email exists +func (s *UserService) Register(req *RegisterRequest) (*model.User, common.ErrorCode, error) { + cfg := server.GetConfig() + if cfg.RegisterEnabled == 0 { + return nil, common.CodeOperatingError, fmt.Errorf("User registration is disabled!") + } + + emailRegex := regexp.MustCompile(`^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$`) + if !emailRegex.MatchString(req.Email) { + return nil, common.CodeOperatingError, fmt.Errorf("Invalid email address: %s!", req.Email) + } + existUser, _ := s.userDAO.GetByEmail(req.Email) if existUser != nil { - return nil, errors.New("email already exists") + return nil, common.CodeOperatingError, fmt.Errorf("Email: %s has already registered!", req.Email) } - // Generate password hash - hashedPassword, err := s.HashPassword(req.Password) + decryptedPassword, err := s.decryptPassword(req.Password) if err != nil { - return nil, fmt.Errorf("failed to hash password: %w", err) + return nil, common.CodeServerError, fmt.Errorf("Fail to decrypt password") } - // Create user - status := "1" - user := &model.User{ - Password: &hashedPassword, - Email: req.Email, - Nickname: req.Nickname, - Status: &status, + hashedPassword, err := s.HashPassword(decryptedPassword) + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to hash password: %w", err) } + userID := s.GenerateToken() + accessToken := s.GenerateToken() + status := "1" + loginChannel := "password" + isSuperuser := false + + user := &model.User{ + ID: userID, + AccessToken: &accessToken, + Email: req.Email, + Nickname: req.Nickname, + Password: &hashedPassword, + Status: &status, + IsActive: "1", + IsAuthenticated: "1", + IsAnonymous: "0", + LoginChannel: &loginChannel, + IsSuperuser: &isSuperuser, + } + + now := time.Now().Unix() + user.CreateTime = now + user.UpdateTime = &now + now_date := time.Now() + user.CreateDate = &now_date + user.UpdateDate = &now_date + user.LastLoginTime = &now_date + + tenantName := req.Nickname + "'s Kingdom" + tenant := &model.Tenant{ + ID: userID, + Name: &tenantName, + LLMID: cfg.Server.Mode, + EmbDID: cfg.Server.Mode, + ASRID: cfg.Server.Mode, + Img2TxtID: cfg.Server.Mode, + RerankID: cfg.Server.Mode, + ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag", + } + tenant.CreateTime = now + tenant.UpdateTime = &now + tenant.CreateDate = &now_date + tenant.UpdateDate = &now_date + + userTenantID := s.GenerateToken() + userTenant := &model.UserTenant{ + ID: userTenantID, + UserID: userID, + TenantID: userID, + Role: "owner", + InvitedBy: userID, + Status: &status, + } + userTenant.CreateTime = now + userTenant.UpdateTime = &now + userTenant.CreateDate = &now_date + userTenant.UpdateDate = &now_date + + fileID := s.GenerateToken() + rootFile := &model.File{ + ID: fileID, + ParentID: fileID, + TenantID: userID, + CreatedBy: userID, + Name: "/", + Type: "folder", + Size: 0, + } + rootFile.CreateTime = now + rootFile.UpdateTime = &now + rootFile.CreateDate = &now_date + rootFile.UpdateDate = &now_date + + tenantDAO := dao.NewTenantDAO() + userTenantDAO := dao.NewUserTenantDAO() + fileDAO := dao.NewFileDAO() + if err := s.userDAO.Create(user); err != nil { - return nil, fmt.Errorf("failed to create user: %w", err) + return nil, common.CodeServerError, fmt.Errorf("failed to create user: %w", err) } - return user, nil + if err := tenantDAO.Create(tenant); err != nil { + s.userDAO.DeleteByID(userID) + return nil, common.CodeServerError, fmt.Errorf("failed to create tenant: %w", err) + } + + if err := userTenantDAO.Create(userTenant); err != nil { + s.userDAO.DeleteByID(userID) + tenantDAO.Delete(userID) + return nil, common.CodeServerError, fmt.Errorf("failed to create user tenant relation: %w", err) + } + + if err := fileDAO.Create(rootFile); err != nil { + s.userDAO.DeleteByID(userID) + tenantDAO.Delete(userID) + userTenantDAO.Delete(userTenantID) + return nil, common.CodeServerError, fmt.Errorf("failed to create root folder: %w", err) + } + + return user, common.CodeSuccess, nil } // Login user login -func (s *UserService) Login(req *LoginRequest) (*model.User, error) { +func (s *UserService) Login(req *LoginRequest) (*model.User, common.ErrorCode, error) { // Get user by email (using username field as email) user, err := s.userDAO.GetByEmail(req.Username) if err != nil { - return nil, errors.New("invalid email or password") + return nil, common.CodeAuthenticationError, fmt.Errorf("invalid email or password") } // Decrypt password using RSA decryptedPassword, err := s.decryptPassword(req.Password) if err != nil { - return nil, fmt.Errorf("failed to decrypt password: %w", err) + return nil, common.CodeServerError, fmt.Errorf("failed to decrypt password: %w", err) } // Verify password if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) { - return nil, errors.New("invalid username or password") + return nil, common.CodeAuthenticationError, fmt.Errorf("invalid username or password") } - // Check user status if user.Status == nil || *user.Status != "1" { - return nil, errors.New("user is disabled") + return nil, common.CodeForbidden, fmt.Errorf("user is disabled") } // Generate new access token token := s.GenerateToken() if err := s.UpdateUserAccessToken(user, token); err != nil { - return nil, fmt.Errorf("failed to update access token: %w", err) + return nil, common.CodeServerError, fmt.Errorf("failed to update access token: %w", err) } // Update timestamp now := time.Now().Unix() user.UpdateTime = &now if err := s.userDAO.Update(user); err != nil { - return nil, fmt.Errorf("failed to update user: %w", err) + return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err) } - return user, nil + return user, common.CodeSuccess, nil } // LoginByEmail user login by email -func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, error) { - // Check for default admin account +// Returns user on success, or error with specific code: +// - CodeAuthenticationError (109): Email not registered or password mismatch +// - CodeServerError (500): Password decryption failure +// - CodeForbidden (403): Account disabled +func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, common.ErrorCode, error) { if req.Email == "admin@ragflow.io" { - return nil, errors.New("default admin account cannot be used to login normal services") + return nil, common.CodeAuthenticationError, fmt.Errorf("default admin account cannot be used to login normal services") } - // Get user by email user, err := s.userDAO.GetByEmail(req.Email) if err != nil { - return nil, errors.New("invalid email or password") + return nil, common.CodeAuthenticationError, fmt.Errorf("Email: %s is not registered!", req.Email) } - // Decrypt password using RSA decryptedPassword, err := s.decryptPassword(req.Password) if err != nil { - return nil, fmt.Errorf("failed to decrypt password: %w", err) + return nil, common.CodeServerError, fmt.Errorf("Fail to crypt password") } - // Verify password if user.Password == nil || !s.VerifyPassword(*user.Password, decryptedPassword) { - return nil, errors.New("invalid email or password") + return nil, common.CodeAuthenticationError, fmt.Errorf("Email and password do not match!") } - // Check user status - if user.Status == nil || *user.Status != "1" { - return nil, errors.New("user is disabled") + if user.IsActive == "0" { + return nil, common.CodeForbidden, fmt.Errorf("This account has been disabled, please contact the administrator!") } - // Generate new access token token := s.GenerateToken() user.AccessToken = &token - // Update timestamp now := time.Now().Unix() user.UpdateTime = &now now_date := time.Now() user.UpdateDate = &now_date if err := s.userDAO.Update(user); err != nil { - return nil, fmt.Errorf("failed to update user: %w", err) + return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err) } - return user, nil + return user, common.CodeSuccess, nil } // GetUserByID get user by ID -func (s *UserService) GetUserByID(id uint) (*UserResponse, error) { +func (s *UserService) GetUserByID(id uint) (*UserResponse, common.ErrorCode, error) { user, err := s.userDAO.GetByID(id) if err != nil { - return nil, err + return nil, common.CodeNotFound, err } return &UserResponse{ @@ -223,15 +319,15 @@ func (s *UserService) GetUserByID(id uint) (*UserResponse, error) { Nickname: user.Nickname, Status: user.Status, CreatedAt: time.Unix(user.CreateTime, 0).Format("2006-01-02 15:04:05"), - }, nil + }, common.CodeSuccess, nil } // ListUsers list users -func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, error) { +func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, common.ErrorCode, error) { offset := (page - 1) * pageSize users, total, err := s.userDAO.List(offset, pageSize) if err != nil { - return nil, 0, err + return nil, 0, common.CodeServerError, err } responses := make([]*UserResponse, len(users)) @@ -245,7 +341,7 @@ func (s *UserService) ListUsers(page, pageSize int) ([]*UserResponse, int64, err } } - return responses, total, nil + return responses, total, common.CodeSuccess, nil } // HashPassword generate password hash @@ -399,7 +495,7 @@ func (s *UserService) GenerateToken() string { // GetUserByToken gets user by authorization header // The token parameter is the authorization header value, which needs to be decrypted // using itsdangerous URLSafeTimedSerializer to get the actual access_token -func (s *UserService) GetUserByToken(authorization string) (*model.User, error) { +func (s *UserService) GetUserByToken(authorization string) (*model.User, common.ErrorCode, error) { // Get secret key from config variables := server.GetVariables() secretKey := variables.SecretKey @@ -408,16 +504,21 @@ func (s *UserService) GetUserByToken(authorization string) (*model.User, error) // Equivalent to: access_token = str(jwt.loads(authorization)) in Python accessToken, err := utility.ExtractAccessToken(authorization, secretKey) if err != nil { - return nil, fmt.Errorf("invalid authorization token: %w", err) + return nil, common.CodeUnauthorized, fmt.Errorf("invalid authorization token: %w", err) } // Validate token format (should be at least 32 chars, UUID format) if len(accessToken) < 32 { - return nil, errors.New("invalid access token format") + return nil, common.CodeUnauthorized, fmt.Errorf("invalid access token format") } // Get user by access token - return s.userDAO.GetByAccessToken(accessToken) + user, err := s.userDAO.GetByAccessToken(accessToken) + if err != nil { + return nil, common.CodeUnauthorized, err + } + + return user, common.CodeSuccess, nil } // UpdateUserAccessToken updates user's access token @@ -426,11 +527,15 @@ func (s *UserService) UpdateUserAccessToken(user *model.User, token string) erro } // Logout invalidates user's access token -func (s *UserService) Logout(user *model.User) error { +func (s *UserService) Logout(user *model.User) (common.ErrorCode, error) { // Invalidate token by setting it to an invalid value // Similar to Python implementation: "INVALID_" + secrets.token_hex(16) invalidToken := "INVALID_" + s.GenerateToken() - return s.UpdateUserAccessToken(user, invalidToken) + err := s.UpdateUserAccessToken(user, invalidToken) + if err != nil { + return common.CodeServerError, err + } + return common.CodeSuccess, nil } // GetUserProfile returns user profile information @@ -539,7 +644,7 @@ func (s *UserService) GetUserProfile(user *model.User) map[string]interface{} { } // UpdateUserSettings updates user settings -func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRequest) error { +func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRequest) (common.ErrorCode, error) { // Update fields if provided if req.Nickname != nil { user.Nickname = *req.Nickname @@ -562,15 +667,18 @@ func (s *UserService) UpdateUserSettings(user *model.User, req *UpdateSettingsRe } // Save updated user - return s.userDAO.Update(user) + if err := s.userDAO.Update(user); err != nil { + return common.CodeServerError, err + } + return common.CodeSuccess, nil } // ChangePassword changes user password -func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordRequest) error { +func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordRequest) (common.ErrorCode, error) { // If password is provided, verify current password if req.Password != nil { if user.Password == nil || !s.VerifyPassword(*user.Password, *req.Password) { - return errors.New("current password is incorrect") + return common.CodeBadRequest, fmt.Errorf("current password is incorrect") } } @@ -578,13 +686,16 @@ func (s *UserService) ChangePassword(user *model.User, req *ChangePasswordReques if req.NewPassword != nil { hashedPassword, err := s.HashPassword(*req.NewPassword) if err != nil { - return fmt.Errorf("failed to hash new password: %w", err) + return common.CodeServerError, fmt.Errorf("failed to hash new password: %w", err) } user.Password = &hashedPassword } // Save updated user - return s.userDAO.Update(user) + if err := s.userDAO.Update(user); err != nil { + return common.CodeServerError, err + } + return common.CodeSuccess, nil } // LoginChannel represents a login channel response @@ -595,7 +706,7 @@ type LoginChannel struct { } // GetLoginChannels gets all supported authentication channels -func (s *UserService) GetLoginChannels() ([]*LoginChannel, error) { +func (s *UserService) GetLoginChannels() ([]*LoginChannel, common.ErrorCode, error) { cfg := server.GetConfig() channels := make([]*LoginChannel, 0) @@ -617,5 +728,5 @@ func (s *UserService) GetLoginChannels() ([]*LoginChannel, error) { }) } - return channels, nil + return channels, common.CodeSuccess, nil }