mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 09:11:59 +08:00
### What problem does this PR solve? The Profile **Name** field currently lacks application-level validation and allows users to save excessively long names and unsupported special characters. While the database enforces a maximum length of 100 characters, neither the frontend nor backend validates nickname format before persistence. This can result in inconsistent user data, poor user experience, and UI layout issues when long names wrap across multiple lines. This PR introduces consistent frontend and backend validation for profile names, enforces length and character constraints, provides clear validation feedback, and prevents invalid values from being saved. Fixes #15693 ### Type of change * [x] Bug Fix (non-breaking change which fixes an issue)
858 lines
27 KiB
Python
858 lines
27 KiB
Python
#
|
||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
#
|
||
import logging
|
||
import string
|
||
import os
|
||
import re
|
||
import secrets
|
||
import time
|
||
from datetime import datetime
|
||
import base64
|
||
|
||
from quart import make_response, redirect, request, session
|
||
from werkzeug.security import check_password_hash, generate_password_hash
|
||
|
||
from api.apps.auth import get_auth_client
|
||
from api.db import FileType, UserTenantRole
|
||
from api.db.services.file_service import FileService
|
||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||
from common.time_utils import current_timestamp, datetime_format, get_format_time
|
||
from common.misc_utils import download_img, get_uuid
|
||
from common.constants import RetCode
|
||
from common.connection_utils import construct_response
|
||
from api.utils.api_utils import (
|
||
get_data_error_result,
|
||
get_json_result,
|
||
get_request_json,
|
||
server_error_response,
|
||
validate_request,
|
||
)
|
||
from api.utils.nickname_validation import validate_nickname
|
||
from api.utils.crypt import decrypt
|
||
from rag.utils.redis_conn import REDIS_CONN
|
||
from api.apps import login_required, current_user, login_user, logout_user
|
||
from api.utils.web_utils import (
|
||
send_email_html,
|
||
OTP_LENGTH,
|
||
OTP_TTL_SECONDS,
|
||
ATTEMPT_LIMIT,
|
||
ATTEMPT_LOCK_SECONDS,
|
||
RESEND_COOLDOWN_SECONDS,
|
||
otp_keys,
|
||
hash_code,
|
||
captcha_key,
|
||
)
|
||
from common import settings
|
||
|
||
|
||
@manager.route("/auth/login", methods=["POST"])  # noqa: F821
async def login():
    """
    User login endpoint.

    Looks the user up by email, decrypts the client-encrypted password and
    checks the credentials. Disabled accounts (``is_active == "0"``) are
    rejected. On success a fresh access token is issued and persisted.
    ---
    tags:
      - User
    parameters:
      - in: body
        name: body
        description: Login credentials.
        required: true
        schema:
          type: object
          properties:
            email:
              type: string
              description: User email.
            password:
              type: string
              description: User password.
    responses:
      200:
        description: Login successful.
        schema:
          type: object
      401:
        description: Authentication failed.
        schema:
          type: object
    """
    json_body = await get_request_json()
    if not json_body:
        logging.warning("Login failed: invalid or empty JSON body")
        return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")

    email = json_body.get("email", "")

    users = UserService.query(email=email)
    if not users:
        logging.warning("Login failed: email not registered")
        return get_json_result(
            data=False,
            code=RetCode.AUTHENTICATION_ERROR,
            message=f"Email: {email} is not registered!",
        )

    password = json_body.get("password")
    try:
        # The client sends the password encrypted; decrypt before verifying.
        password = decrypt(password)
    except Exception:
        # Catch ordinary errors only: BaseException would also swallow
        # asyncio.CancelledError / KeyboardInterrupt raised in this coroutine.
        logging.warning("Login failed: password decryption error")
        return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password")

    user = UserService.query_user(email, password)

    if user and hasattr(user, 'is_active') and user.is_active == "0":
        logging.warning("Login failed: disabled account for user_id=%s", user.id)
        return get_json_result(
            data=False,
            code=RetCode.FORBIDDEN,
            message="This account has been disabled, please contact the administrator!",
        )
    elif user:
        # Rotate the access token on every successful login.
        user.access_token = get_uuid()
        login_user(user)
        user.update_time = current_timestamp()
        user.update_date = datetime_format(datetime.now())
        user.save()
        logging.info("Login successful: user_id=%s", user.id)
        msg = "Welcome back!"

        return await construct_response(data=user.to_safe_dict(for_self=True), auth=user.get_id(), message=msg)
    else:
        logging.warning("Login failed: wrong credentials")
        return get_json_result(
            data=False,
            code=RetCode.AUTHENTICATION_ERROR,
            message="Email and password do not match!",
        )
|
||
|
||
|
||
@manager.route("/auth/login/channels", methods=["GET"])  # noqa: F821
async def get_login_channels():
    """
    Get all supported authentication channels.

    Every entry of ``settings.OAUTH_CONFIG`` is turned into a descriptor with
    the channel key, a display name (title-cased key by default) and an icon
    (``"sso"`` by default). Any failure is logged and reported as an empty
    list with an error message.
    """
    try:
        descriptors = [
            {
                "channel": name,
                "display_name": cfg.get("display_name", name.title()),
                "icon": cfg.get("icon", "sso"),
            }
            for name, cfg in settings.OAUTH_CONFIG.items()
        ]
        return get_json_result(data=descriptors)
    except Exception as err:
        logging.exception(err)
        return get_json_result(data=[], message=f"Load channels failure, error: {str(err)}", code=RetCode.EXCEPTION_ERROR)
|
||
|
||
|
||
@manager.route("/auth/login/<channel>", methods=["GET"])  # noqa: F821
async def oauth_login(channel):
    """
    Start an OAuth/OIDC login for ``channel``.

    A newly generated ``state`` value is stored in the session (the callback
    compares against it) and the browser is redirected to the provider's
    authorization URL.

    Raises:
        ValueError: if ``channel`` is not configured in ``settings.OAUTH_CONFIG``.
    """
    config = settings.OAUTH_CONFIG.get(channel)
    if not config:
        raise ValueError(f"Invalid channel name: {channel}")
    client = get_auth_client(config)

    oauth_state = get_uuid()
    session["oauth_state"] = oauth_state
    target_url = client.get_authorization_url(oauth_state)
    logging.info("OAuth login initiated: channel='%s', state='%s'", channel, oauth_state)
    return redirect(target_url)
|
||
|
||
|
||
@manager.route("/auth/oauth/<channel>/callback", methods=["GET"])  # noqa: F821
async def oauth_callback(channel):
    """
    Handle the OAuth/OIDC callback for various channels dynamically.

    Flow: verify the ``state`` saved by ``oauth_login``, exchange the
    authorization ``code`` for tokens, fetch the provider's user info, then
    either register a new local account (unknown email) or log the existing
    account in. The browser is always redirected to ``/`` with either an
    ``auth`` or an ``error`` query parameter.
    """
    # Exception messages can contain '&', '#', '?' or spaces, which would break
    # or truncate the redirect's query string unless percent-encoded.
    from urllib.parse import quote

    try:
        channel_config = settings.OAUTH_CONFIG.get(channel)
        if not channel_config:
            raise ValueError(f"Invalid channel name: {channel}")
        auth_cli = get_auth_client(channel_config)

        # Check the state
        state = request.args.get("state")
        if not state or state != session.get("oauth_state"):
            return redirect("/?error=invalid_state")
        session.pop("oauth_state", None)

        # Obtain the authorization code
        code = request.args.get("code")
        if not code:
            return redirect("/?error=missing_code")

        # Exchange authorization code for access token (prefer the async API when available)
        if hasattr(auth_cli, "async_exchange_code_for_token"):
            token_info = await auth_cli.async_exchange_code_for_token(code)
        else:
            token_info = auth_cli.exchange_code_for_token(code)
        access_token = token_info.get("access_token")
        if not access_token:
            return redirect("/?error=token_failed")

        id_token = token_info.get("id_token")

        # Fetch user info
        if hasattr(auth_cli, "async_fetch_user_info"):
            user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token)
        else:
            user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
        if not user_info.email:
            return redirect("/?error=email_missing")

        # Login or register
        users = UserService.query(email=user_info.email)
        user_id = get_uuid()

        if not users:
            try:
                try:
                    avatar = await download_img(user_info.avatar_url)
                except Exception as e:
                    # A missing avatar must not block registration.
                    logging.exception(e)
                    avatar = ""

                users = user_register(
                    user_id,
                    {
                        "access_token": get_uuid(),
                        "email": user_info.email,
                        "avatar": avatar,
                        "nickname": user_info.nickname,
                        "login_channel": channel,
                        "last_login_time": get_format_time(),
                        "is_superuser": False,
                    },
                )

                if not users:
                    raise Exception(f"Failed to register {user_info.email}")
                if len(users) > 1:
                    raise Exception(f"Same email: {user_info.email} exists!")

                # Try to log in
                user = users[0]
                login_user(user)
                return redirect(f"/?auth={user.get_id()}")

            except Exception as e:
                # Undo any partially created user/tenant records.
                rollback_user_registration(user_id)
                logging.exception(e)
                return redirect(f"/?error={quote(str(e), safe='')}")

        # User exists, try to log in
        user = users[0]
        user.access_token = get_uuid()
        if user and hasattr(user, 'is_active') and user.is_active == "0":
            return redirect("/?error=user_inactive")

        login_user(user)
        user.save()
        return redirect(f"/?auth={user.get_id()}")
    except Exception as e:
        logging.exception(e)
        return redirect(f"/?error={quote(str(e), safe='')}")
|
||
|
||
|
||
@manager.route("/auth/logout", methods=["POST"])  # noqa: F821
@login_required
async def log_out():
    """
    User logout endpoint.

    The stored access token is replaced with a random ``INVALID_`` placeholder
    and persisted before the session is cleared. If the save reports zero
    affected rows, a server error is returned and the session is left intact.
    ---
    tags:
      - User
    security:
      - ApiKeyAuth: []
    responses:
      200:
        description: Logout successful.
        schema:
          type: object
    """
    # Unwrap the proxy (when there is one) so save() acts on the real model.
    if hasattr(current_user, "_get_current_object"):
        account = current_user._get_current_object()
    else:
        account = current_user
    uid = account.id

    account.access_token = f"INVALID_{secrets.token_hex(16)}"
    if account.save() == 0:
        logging.error("Logout failed to persist access token update: user_id=%s", uid)
        return get_json_result(code=RetCode.SERVER_ERROR, data=False, message="Failed to update access token")

    logout_user()
    logging.info("Logout: user_id=%s, access_token invalidated", uid)
    return get_json_result(data=True)
|
||
|
||
|
||
@manager.route("/users/me", methods=["PATCH"])  # noqa: F821
@login_required
async def setting_user():
    """
    Update user settings.

    Changing the password requires the current password; both values arrive
    encrypted and are decrypted server-side. Account-control and identity
    fields (email, status, superuser flag, login channel, id, access token, ...)
    are ignored. A supplied nickname is validated and stripped before saving.
    ---
    tags:
      - User
    security:
      - ApiKeyAuth: []
    parameters:
      - in: body
        name: body
        description: User settings to update.
        required: true
        schema:
          type: object
          properties:
            nickname:
              type: string
              description: New nickname.
            email:
              type: string
              description: Ignored; the email address cannot be changed through this endpoint.
            password:
              type: string
              description: Current password (encrypted); required to change the password.
            new_password:
              type: string
              description: New password (encrypted).
    responses:
      200:
        description: Settings updated successfully.
        schema:
          type: object
    """
    # Keys never copied from the request body. "password"/"new_password" are
    # handled explicitly below; "id" and "access_token" are protected so a
    # client cannot overwrite its own primary key or session token.
    protected_fields = frozenset(
        {
            "password",
            "new_password",
            "email",
            "status",
            "is_superuser",
            "login_channel",
            "is_anonymous",
            "is_active",
            "is_authenticated",
            "last_login_time",
            "id",
            "access_token",
        }
    )

    update_dict = {}
    request_data = await get_request_json()
    if request_data.get("password"):
        new_password = request_data.get("new_password")
        if not check_password_hash(current_user.password, decrypt(request_data["password"])):
            return get_json_result(
                data=False,
                code=RetCode.AUTHENTICATION_ERROR,
                message="Password error!",
            )

        if new_password:
            update_dict["password"] = generate_password_hash(decrypt(new_password))

    for k in request_data.keys():
        if k in protected_fields:
            continue
        update_dict[k] = request_data[k]

    if "nickname" in update_dict:
        error_message, error_code = validate_nickname(update_dict["nickname"])
        if error_message:
            return get_json_result(data=False, message=error_message, code=error_code)
        update_dict["nickname"] = update_dict["nickname"].strip()

    try:
        UserService.update_by_id(current_user.id, update_dict)
        return get_json_result(data=True)
    except Exception as e:
        logging.exception(e)
        return get_json_result(data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR)
|
||
|
||
|
||
@manager.route("/users/me", methods=["GET"])  # noqa: F821
@login_required
async def user_profile():
    """
    Get user profile information.

    Returns ``current_user.to_safe_dict(for_self=True)``.
    ---
    tags:
      - User
    security:
      - ApiKeyAuth: []
    responses:
      200:
        description: User profile retrieved successfully.
        schema:
          type: object
          properties:
            id:
              type: string
              description: User ID.
            nickname:
              type: string
              description: User nickname.
            email:
              type: string
              description: User email.
    """
    profile = current_user.to_safe_dict(for_self=True)
    return get_json_result(data=profile)
|
||
|
||
|
||
def rollback_user_registration(user_id):
    """
    Best-effort removal of the records created by ``user_register`` for ``user_id``.

    Each deletion is attempted independently, so one failure does not prevent
    the others. Failures are logged rather than silently discarded, because a
    half-rolled-back registration leaves orphaned rows that are otherwise hard
    to diagnose. This function never raises.

    NOTE(review): the root folder inserted by ``user_register`` through
    ``FileService.insert`` is not removed here. Confirm whether it should be.
    """
    try:
        UserService.delete_by_id(user_id)
    except Exception:
        logging.exception("Registration rollback: failed to delete user %s", user_id)

    try:
        TenantService.delete_by_id(user_id)
    except Exception:
        logging.exception("Registration rollback: failed to delete tenant %s", user_id)

    try:
        # The tenant id is freshly generated, so every link to it belongs to
        # this registration; remove all of them, not just the first.
        for user_tenant in UserTenantService.query(tenant_id=user_id) or []:
            UserTenantService.delete_by_id(user_tenant.id)
    except Exception:
        logging.exception("Registration rollback: failed to delete user-tenant links for %s", user_id)
|
||
|
||
|
||
def user_register(user_id, user):
    """
    Persist a newly registered user together with a default workspace.

    Side effects, in order:
      1. ``user["id"]`` is set to ``user_id`` (the caller's dict is mutated).
      2. The user row is saved. If ``UserService.save`` returns a falsy value,
         ``None`` is returned and nothing else is written.
      3. A tenant sharing the user's id (named after the nickname, using the
         default models from ``settings``), an OWNER membership linking the
         user to it, and a root folder ``"/"`` are inserted.

    Returns:
        ``UserService.query(email=user["email"])`` after the inserts, or
        ``None`` when the initial save failed.
    """
    user["id"] = user_id

    default_tenant = dict(
        id=user_id,
        name=user["nickname"] + "‘s Kingdom",
        llm_id=settings.CHAT_MDL,
        embd_id=settings.EMBEDDING_MDL,
        asr_id=settings.ASR_MDL,
        parser_ids=settings.PARSERS,
        img2txt_id=settings.IMAGE2TEXT_MDL,
        rerank_id=settings.RERANK_MDL,
    )
    owner_link = dict(
        tenant_id=user_id,
        user_id=user_id,
        invited_by=user_id,
        role=UserTenantRole.OWNER,
    )

    # The root folder is its own parent.
    root_id = get_uuid()
    root_folder = dict(
        id=root_id,
        parent_id=root_id,
        tenant_id=user_id,
        created_by=user_id,
        name="/",
        type=FileType.FOLDER.value,
        size=0,
        location="",
    )

    if not UserService.save(**user):
        return None

    TenantService.insert(**default_tenant)
    UserTenantService.insert(**owner_link)
    # Per-tenant LLM initialisation is currently disabled (previously done here).
    FileService.insert(root_folder)
    return UserService.query(email=user["email"])
|
||
|
||
|
||
@manager.route("/users", methods=["POST"])  # noqa: F821
@validate_request("nickname", "email", "password")
async def user_add():
    """
    Register a new user.

    Rejects the request when registration is disabled, the email is malformed
    or already used, the nickname fails validation, or the encrypted password
    cannot be decrypted. On success the user is created, logged in, and the
    safe user dict is returned with an auth token.
    ---
    tags:
      - User
    parameters:
      - in: body
        name: body
        description: Registration details.
        required: true
        schema:
          type: object
          properties:
            nickname:
              type: string
              description: User nickname.
            email:
              type: string
              description: User email.
            password:
              type: string
              description: User password.
    responses:
      200:
        description: Registration successful.
        schema:
          type: object
    """

    if not settings.REGISTER_ENABLED:
        return get_json_result(
            data=False,
            message="User registration is disabled!",
            code=RetCode.OPERATING_ERROR,
        )

    req = await get_request_json()
    email_address = req["email"]

    # Validate the email address. re.fullmatch is used instead of
    # re.match(r"^...$") because "$" also matches right before a trailing
    # newline, which let values such as "a@b.com\n" be registered. Non-string
    # values are rejected instead of raising TypeError.
    if not isinstance(email_address, str) or not re.fullmatch(r"[\w\._-]+@([\w_-]+\.)+[\w-]{2,}", email_address):
        return get_json_result(
            data=False,
            message=f"Invalid email address: {email_address}!",
            code=RetCode.OPERATING_ERROR,
        )

    # Check if the email address is already used
    if UserService.query(email=email_address):
        return get_json_result(
            data=False,
            message=f"Email: {email_address} has already registered!",
            code=RetCode.OPERATING_ERROR,
        )

    # Construct user info data
    nickname = req["nickname"]
    error_message, error_code = validate_nickname(nickname)
    if error_message:
        return get_json_result(data=False, message=error_message, code=error_code)
    nickname = nickname.strip()

    # Decrypt before touching the database so a malformed password yields a
    # clean error response (same as the login endpoint) instead of a 500.
    try:
        password = decrypt(req["password"])
    except Exception:
        logging.warning("Registration failed: password decryption error")
        return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password")

    user_dict = {
        "access_token": get_uuid(),
        "email": email_address,
        "nickname": nickname,
        "password": password,
        "login_channel": "password",
        "last_login_time": get_format_time(),
        "is_superuser": False,
    }

    user_id = get_uuid()
    try:
        users = user_register(user_id, user_dict)
        if not users:
            raise Exception(f"Fail to register {email_address}.")
        if len(users) > 1:
            raise Exception(f"Same email: {email_address} exists!")
        user = users[0]
        login_user(user)
        return await construct_response(
            data=user.to_safe_dict(for_self=True),
            auth=user.get_id(),
            message=f"{nickname}, welcome aboard!",
        )
    except Exception as e:
        # Undo any partially created user/tenant records.
        rollback_user_registration(user_id)
        logging.exception(e)
        return get_json_result(
            data=False,
            message=f"User registration failure, error: {str(e)}",
            code=RetCode.EXCEPTION_ERROR,
        )
|
||
|
||
|
||
@manager.route("/users/me/models", methods=["GET"])  # noqa: F821
@login_required
async def tenant_info():
    """
    Get tenant information.

    Returns the first entry of ``TenantService.get_info_by(current_user.id)``,
    or a data error when there is none.
    ---
    tags:
      - Tenant
    security:
      - ApiKeyAuth: []
    responses:
      200:
        description: Tenant information retrieved successfully.
        schema:
          type: object
          properties:
            tenant_id:
              type: string
              description: Tenant ID.
            name:
              type: string
              description: Tenant name.
            llm_id:
              type: string
              description: LLM ID.
            embd_id:
              type: string
              description: Embedding model ID.
    """
    try:
        found = TenantService.get_info_by(current_user.id)
        if found:
            return get_json_result(data=found[0])
        return get_data_error_result(message="Tenant not found!")
    except Exception as err:
        return server_error_response(err)
|
||
|
||
|
||
@manager.route("/users/me/models", methods=["PATCH"])  # noqa: F821
@login_required
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
async def set_tenant_info():
    """
    Update tenant information.

    The caller must be linked to the target tenant; otherwise the request is
    rejected with an authorization error. All remaining body fields (after
    removing ``tenant_id``) are written to the tenant.
    ---
    tags:
      - Tenant
    security:
      - ApiKeyAuth: []
    parameters:
      - in: body
        name: body
        description: Tenant information to update.
        required: true
        schema:
          type: object
          properties:
            tenant_id:
              type: string
              description: Tenant ID.
            llm_id:
              type: string
              description: LLM ID.
            embd_id:
              type: string
              description: Embedding model ID.
            asr_id:
              type: string
              description: ASR model ID.
            img2txt_id:
              type: string
              description: Image to Text model ID.
    responses:
      200:
        description: Tenant information updated successfully.
        schema:
          type: object
    """
    req = await get_request_json()
    try:
        tid = req.pop("tenant_id")
        # Without this check any logged-in user could rewrite another
        # tenant's model settings just by supplying its id.
        # NOTE(review): this accepts any membership role; consider
        # restricting to UserTenantRole.OWNER if only owners should change models.
        if not UserTenantService.query(tenant_id=tid, user_id=current_user.id):
            return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
        TenantService.update_by_id(tid, req)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
|
||
|
||
|
||
@manager.route("/auth/password/forgot/captcha", methods=["POST"])  # noqa: F821
async def forget_get_captcha():
    """
    POST /auth/password/forgot/captcha?email=<email>

    - The email is read from the query string, not from a JSON body.
    - Generate an image captcha of OTP_LENGTH characters (A-Z, 0-9) and cache
      it in Redis under captcha_key(email) for 60 seconds.
    - Returns the captcha as a PNG image.
    """
    email = (request.args.get("email") or "")
    if not email:
        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email is required")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=RetCode.DATA_ERROR, message="invalid email")

    # Generate captcha text
    allowed = string.ascii_uppercase + string.digits
    captcha_text = "".join(secrets.choice(allowed) for _ in range(OTP_LENGTH))
    REDIS_CONN.set(captcha_key(email), captcha_text, 60)  # Valid for 60 seconds

    from captcha.image import ImageCaptcha

    image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
    # Request PNG explicitly and advertise the matching media type; the old
    # "image/JPEG" header did not match the PNG bytes actually produced.
    img_bytes = image.generate(captcha_text, format="png").read()
    response = await make_response(img_bytes)
    response.headers.set("Content-Type", "image/png")
    return response
|
||
|
||
|
||
@manager.route("/auth/password/forgot/otp", methods=["POST"])  # noqa: F821
async def forget_send_otp():
    """
    POST /forget/otp

    Request JSON: { email, captcha }

    - Check the image captcha cached under captcha_key(email). The comparison
      ignores case, and the cached captcha is deleted once it matches.
    - Refuse to resend within RESEND_COOLDOWN_SECONDS of the previous send.
    - Generate an OTP of OTP_LENGTH uppercase letters and store
      "<hash>:<salt hex>" with TTL OTP_TTL_SECONDS. Reset the attempt counter,
      record the send time, clear any lock, and email the code.
    """
    payload = await get_request_json()
    email = payload.get("email") or ""
    captcha = (payload.get("captcha") or "").strip()

    if not (email and captcha):
        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and captcha required")

    if not UserService.query(email=email):
        return get_json_result(data=False, code=RetCode.DATA_ERROR, message="invalid email")

    expected_captcha = REDIS_CONN.get(captcha_key(email))
    if not expected_captcha:
        return get_json_result(data=False, code=RetCode.NOT_EFFECTIVE, message="invalid or expired captcha")
    if expected_captcha.strip().lower() != captcha.lower():
        return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="invalid or expired captcha")

    # One-shot: a matched captcha cannot be reused.
    REDIS_CONN.delete(captcha_key(email))

    code_key, attempts_key, last_sent_key, lock_key = otp_keys(email)
    now = int(time.time())

    previous_send = REDIS_CONN.get(last_sent_key)
    if previous_send:
        try:
            elapsed = now - int(previous_send)
        except Exception:
            elapsed = RESEND_COOLDOWN_SECONDS
        wait_for = RESEND_COOLDOWN_SECONDS - elapsed
        if wait_for > 0:
            return get_json_result(data=False, code=RetCode.NOT_EFFECTIVE, message=f"you still have to wait {wait_for} seconds")

    # Only the salted hash of the OTP is stored, never the OTP itself.
    otp = "".join(secrets.choice(string.ascii_uppercase) for _ in range(OTP_LENGTH))
    salt = os.urandom(16)
    digest = hash_code(otp, salt)
    REDIS_CONN.set(code_key, f"{digest}:{salt.hex()}", OTP_TTL_SECONDS)
    REDIS_CONN.set(attempts_key, 0, OTP_TTL_SECONDS)
    REDIS_CONN.set(last_sent_key, now, OTP_TTL_SECONDS)
    REDIS_CONN.delete(lock_key)

    try:
        await send_email_html(
            subject="Your Password Reset Code",
            to_email=email,
            template_key="reset_code",
            code=otp,
            ttl_min=OTP_TTL_SECONDS // 60,
        )
    except Exception as err:
        logging.exception(err)
        return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email")

    return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent")
|
||
|
||
|
||
def _verified_key(email: str) -> str:
|
||
return f"otp:verified:{email}"
|
||
|
||
|
||
@manager.route("/auth/password/forgot/otp/verify", methods=["POST"])  # noqa: F821
async def forget_verify_otp():
    """
    Verify email + OTP only. On success:
    - consume the OTP and attempt counters
    - set a short-lived verified flag in Redis for the email

    A wrong OTP increments the attempt counter. Reaching ATTEMPT_LIMIT sets a
    lock for ATTEMPT_LOCK_SECONDS, during which every attempt is refused.

    Request JSON: { email, otp }
    """
    # Constant-time comparison so response timing does not leak how much of
    # the hash matched.
    import hmac

    req = await get_request_json()
    email = req.get("email") or ""
    otp = (req.get("otp") or "").strip()

    if not all([email, otp]):
        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and otp are required")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=RetCode.DATA_ERROR, message="invalid email")

    # Verify OTP from Redis
    k_code, k_attempts, k_last, k_lock = otp_keys(email)
    if REDIS_CONN.get(k_lock):
        return get_json_result(data=False, code=RetCode.NOT_EFFECTIVE, message="too many attempts, try later")

    stored = REDIS_CONN.get(k_code)
    if not stored:
        return get_json_result(data=False, code=RetCode.NOT_EFFECTIVE, message="expired otp")

    # Stored format is "<hash>:<salt hex>" (written by forget_send_otp).
    try:
        stored_hash, salt_hex = str(stored).split(":", 1)
        salt = bytes.fromhex(salt_hex)
    except ValueError:
        return get_json_result(data=False, code=RetCode.EXCEPTION_ERROR, message="otp storage corrupted")

    calc = hash_code(otp.upper(), salt)
    if not hmac.compare_digest(str(calc).encode("utf-8"), stored_hash.encode("utf-8")):
        # bump attempts
        try:
            attempts = int(REDIS_CONN.get(k_attempts) or 0) + 1
        except Exception:
            attempts = 1
        REDIS_CONN.set(k_attempts, attempts, OTP_TTL_SECONDS)
        if attempts >= ATTEMPT_LIMIT:
            REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS)
        # The code exists but does not match, so report it as invalid, not expired.
        return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="invalid otp")

    # Success: consume OTP and attempts; mark verified
    REDIS_CONN.delete(k_code)
    REDIS_CONN.delete(k_attempts)
    REDIS_CONN.delete(k_last)
    REDIS_CONN.delete(k_lock)

    # set verified flag with limited TTL, reuse OTP_TTL_SECONDS or smaller window
    try:
        REDIS_CONN.set(_verified_key(email), "1", OTP_TTL_SECONDS)
    except Exception:
        return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to set verification state")

    return get_json_result(data=True, code=RetCode.SUCCESS, message="otp verified")
|
||
|
||
|
||
@manager.route("/auth/password/reset", methods=["POST"])  # noqa: F821
async def forget_reset_password():
    """
    Reset password after successful OTP verification.

    Requires: { email, new_password, confirm_new_password }
    Both passwords arrive encrypted. After decryption each is base64-encoded
    text, and the decoded values must match.

    Steps:
    - check verified flag in Redis
    - update user password
    - clear verified flag
    - respond via construct_response with the user's auth id (auto login)
    """

    req = await get_request_json()
    email = req.get("email") or ""
    new_pwd = req.get("new_password")
    new_pwd2 = req.get("confirm_new_password")

    if not all([email, new_pwd, new_pwd2]):
        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and passwords are required")

    if not REDIS_CONN.get(_verified_key(email)):
        return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="email not verified")

    # Malformed ciphertext, invalid base64 or non-UTF-8 bytes used to surface
    # as an unhandled exception (HTTP 500); report them as bad arguments.
    try:
        new_pwd_base64 = decrypt(new_pwd)
        new_pwd_string = base64.b64decode(new_pwd_base64).decode('utf-8')
        new_pwd2_string = base64.b64decode(decrypt(new_pwd2)).decode('utf-8')
    except Exception:
        logging.warning("Password reset failed: password decryption/decoding error")
        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="invalid password format")

    if new_pwd_string != new_pwd2_string:
        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="passwords do not match")

    users = UserService.query_user_by_email(email=email)
    if not users:
        return get_json_result(data=False, code=RetCode.DATA_ERROR, message="invalid email")

    user = users[0]
    try:
        # The decrypted (still base64-encoded) form is what gets stored,
        # matching the value passed at registration.
        UserService.update_user_password(user.id, new_pwd_base64)
    except Exception as e:
        logging.exception(e)
        return get_json_result(data=False, code=RetCode.EXCEPTION_ERROR, message="failed to reset password")

    # clear verified flag (best effort; the flag also expires on its own TTL)
    try:
        REDIS_CONN.delete(_verified_key(email))
    except Exception:
        logging.warning("Password reset: failed to clear verified flag")

    msg = "Password reset successful. Logged in."
    return await construct_response(data=user.to_safe_dict(for_self=True), auth=user.get_id(), message=msg)
|
||
|
||
|