mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-06 03:18:36 +08:00
fix: prevent sensitive fields from leaking in user API responses (#14792)
Closes #14789 ### What problem does this PR solve? User API endpoints (`login`, `user_profile`, `user_add`, `forget_reset_password`) were returning full user objects via `to_json()` / `to_dict()`, which included sensitive fields like `password` and `access_token` in the response body. This leaks credentials to the client. This PR adds a `to_safe_dict()` method on the `User` model that strips sensitive fields (`password`, `access_token`) and replaces all affected call sites to use it. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -94,12 +94,14 @@ async def login():
|
||||
"""
|
||||
json_body = await get_request_json()
|
||||
if not json_body:
|
||||
logging.warning("Login failed: invalid or empty JSON body")
|
||||
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
|
||||
|
||||
email = json_body.get("email", "")
|
||||
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
logging.warning("Login failed: email not registered")
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=RetCode.AUTHENTICATION_ERROR,
|
||||
@@ -110,27 +112,30 @@ async def login():
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except BaseException:
|
||||
logging.warning("Login failed: password decryption error")
|
||||
return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password")
|
||||
|
||||
user = UserService.query_user(email, password)
|
||||
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
logging.warning("Login failed: disabled account for user_id=%s", user.id)
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=RetCode.FORBIDDEN,
|
||||
message="This account has been disabled, please contact the administrator!",
|
||||
)
|
||||
elif user:
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.update_time = current_timestamp()
|
||||
user.update_date = datetime_format(datetime.now())
|
||||
user.save()
|
||||
logging.info("Login successful: user_id=%s", user.id)
|
||||
msg = "Welcome back!"
|
||||
|
||||
return await construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
return await construct_response(data=user.to_safe_dict(for_self=True), auth=user.get_id(), message=msg)
|
||||
else:
|
||||
logging.warning("Login failed: wrong credentials")
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=RetCode.AUTHENTICATION_ERROR,
|
||||
@@ -169,6 +174,7 @@ async def oauth_login(channel):
|
||||
state = get_uuid()
|
||||
session["oauth_state"] = state
|
||||
auth_url = auth_cli.get_authorization_url(state)
|
||||
logging.info("OAuth login initiated: channel='%s', state='%s'", channel, state)
|
||||
return redirect(auth_url)
|
||||
|
||||
|
||||
@@ -283,9 +289,11 @@ async def log_out():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
user_id = current_user.id
|
||||
current_user.access_token = f"INVALID_{secrets.token_hex(16)}"
|
||||
current_user.save()
|
||||
logout_user()
|
||||
logging.info("Logout: user_id=%s, access_token invalidated", user_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@@ -383,7 +391,7 @@ async def user_profile():
|
||||
type: string
|
||||
description: User email.
|
||||
"""
|
||||
return get_json_result(data=current_user.to_dict())
|
||||
return get_json_result(data=current_user.to_safe_dict(for_self=True))
|
||||
|
||||
|
||||
def rollback_user_registration(user_id):
|
||||
@@ -528,7 +536,7 @@ async def user_add():
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return await construct_response(
|
||||
data=user.to_json(),
|
||||
data=user.to_safe_dict(for_self=True),
|
||||
auth=user.get_id(),
|
||||
message=f"{nickname}, welcome aboard!",
|
||||
)
|
||||
@@ -837,6 +845,6 @@ async def forget_reset_password():
|
||||
pass
|
||||
|
||||
msg = "Password reset successful. Logged in."
|
||||
return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
|
||||
return await construct_response(data=user.to_safe_dict(for_self=True), auth=user.get_id(), message=msg)
|
||||
|
||||
|
||||
|
||||
@@ -705,6 +705,8 @@ def fill_db_model_object(model_object, human_model_dict):
|
||||
|
||||
|
||||
class User(DataBaseModel, AuthUser):
|
||||
SENSITIVE_FIELDS = {"password", "access_token", "email"}
|
||||
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
access_token = CharField(max_length=255, null=True, index=True)
|
||||
nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
|
||||
@@ -729,6 +731,18 @@ class User(DataBaseModel, AuthUser):
|
||||
jwt = Serializer(secret_key=settings.get_secret_key())
|
||||
return jwt.dumps(str(self.access_token))
|
||||
|
||||
def to_safe_dict(self, *, for_self: bool = False):
|
||||
"""Return a dict with sensitive fields stripped for API responses.
|
||||
|
||||
Email is treated as sensitive in generic serialization. Pass for_self=True
|
||||
when returning the authenticated user's own record (login, profile, etc.).
|
||||
"""
|
||||
result = {k: v for k, v in self.to_dict().items() if k not in self.SENSITIVE_FIELDS}
|
||||
if for_self:
|
||||
result["email"] = self.email
|
||||
logging.debug("User %s serialized safely, filtered fields: %s", self.id, self.SENSITIVE_FIELDS)
|
||||
return result
|
||||
|
||||
class Meta:
|
||||
db_table = "user"
|
||||
|
||||
|
||||
@@ -392,6 +392,13 @@ class _DummyUser:
|
||||
def to_dict(self):
|
||||
return {"id": self.id, "email": self.email}
|
||||
|
||||
def to_safe_dict(self, *, for_self: bool = False):
|
||||
_sensitive = {"password", "access_token", "email"}
|
||||
result = {k: v for k, v in self.to_dict().items() if k not in _sensitive}
|
||||
if for_self:
|
||||
result["email"] = self.email
|
||||
return result
|
||||
|
||||
|
||||
def _set_request_args(monkeypatch, module, args=None):
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args(args or {})))
|
||||
|
||||
Reference in New Issue
Block a user