mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Memory API: enforce tenant permissions on memory and message endpoints (#14535)
### What problem does this PR solve? This PR fixes missing authorization checks in the Memory API. Previously, several authenticated endpoints accepted caller-supplied `tenant_id`, `owner_ids`, or `memory_id` values and used them directly to list, read, update, delete, or search Memory data. That could allow an authenticated user to access or mutate another tenant's Memory records if they knew a tenant ID or memory ID. The fix centralizes Memory access checks and applies them consistently across Memory and Memory-message operations. The change: - Adds helper logic to parse list filters and compute tenant IDs accessible to `current_user`. - Requires direct `memory_id` operations to pass Memory access checks before reading, updating, deleting, or changing message state. - Filters list/search/recent-message requests to accessible memories only. - Applies Memory visibility filtering before count and pagination in `MemoryService.get_by_filter`. - Accepts `owner_ids` in the Memory list route, matching the frontend owner filter while still intersecting values with the caller's accessible tenants. - ### Related issues Closes #14534 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: jony376 <jony376@gmail.com>
This commit is contained in:
@@ -130,7 +130,7 @@ async def delete_memory(memory_id):
|
||||
@login_required
|
||||
async def list_memory():
|
||||
filter_params = {
|
||||
k: request.args.get(k) for k in ["memory_type", "tenant_id", "storage_type"] if k in request.args
|
||||
k: request.args.get(k) for k in ["memory_type", "tenant_id", "owner_ids", "storage_type"] if k in request.args
|
||||
}
|
||||
keywords = request.args.get("keywords")
|
||||
page = int(request.args.get("page", 1))
|
||||
|
||||
@@ -29,6 +29,49 @@ from common.exceptions import ArgumentException, NotFoundException
|
||||
from common.time_utils import current_timestamp, timestamp_to_date
|
||||
|
||||
|
||||
def _split_filter_values(values):
|
||||
if not values:
|
||||
return []
|
||||
if isinstance(values, str):
|
||||
values = [values]
|
||||
res = []
|
||||
for value in values:
|
||||
if not value:
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
res.extend([v.strip() for v in value.split(",") if v.strip()])
|
||||
else:
|
||||
res.append(value)
|
||||
return res
|
||||
|
||||
|
||||
def _joined_tenant_ids(user_id: str) -> set[str]:
|
||||
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(user_id)
|
||||
return {user_id, *[tenant["tenant_id"] for tenant in user_tenants]}
|
||||
|
||||
|
||||
def _memory_accessible(memory) -> bool:
|
||||
if memory.tenant_id == current_user.id:
|
||||
return True
|
||||
if memory.permissions != TenantPermission.TEAM.value:
|
||||
return False
|
||||
return memory.tenant_id in _joined_tenant_ids(current_user.id)
|
||||
|
||||
|
||||
def _require_memory_access(memory_id: str):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory or not _memory_accessible(memory):
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
return memory
|
||||
|
||||
|
||||
def _filter_accessible_memories(memory_ids: list[str]):
|
||||
memory_ids = _split_filter_values(memory_ids)
|
||||
if not memory_ids:
|
||||
return []
|
||||
return [memory for memory in MemoryService.get_by_ids(memory_ids) if _memory_accessible(memory)]
|
||||
|
||||
|
||||
async def create_memory(memory_info: dict):
|
||||
"""
|
||||
:param memory_info: {
|
||||
@@ -137,9 +180,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
|
||||
for field in ["avatar", "description", "system_prompt", "user_prompt"]:
|
||||
if field in new_memory_setting:
|
||||
update_dict[field] = new_memory_setting[field]
|
||||
current_memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not current_memory:
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
current_memory = _require_memory_access(memory_id)
|
||||
|
||||
memory_dict = current_memory.to_dict()
|
||||
memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
|
||||
@@ -168,9 +209,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
|
||||
|
||||
|
||||
async def delete_memory(memory_id):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
memory = _require_memory_access(memory_id)
|
||||
MemoryService.delete_memory(memory_id)
|
||||
if MessageService.has_index(memory.tenant_id, memory_id):
|
||||
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
|
||||
@@ -188,19 +227,16 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size
|
||||
:param page: int
|
||||
:param page_size: int
|
||||
"""
|
||||
filter_dict: dict = {"storage_type": filter_params.get("storage_type")}
|
||||
tenant_ids = filter_params.get("tenant_id")
|
||||
if not filter_params.get("tenant_id"):
|
||||
# restrict to current user's tenants
|
||||
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
|
||||
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
|
||||
filter_dict: dict = {"storage_type": filter_params.get("storage_type"), "accessible_user_id": current_user.id}
|
||||
allowed_tenant_ids = _joined_tenant_ids(current_user.id)
|
||||
tenant_ids = _split_filter_values(filter_params.get("tenant_id") or filter_params.get("owner_ids"))
|
||||
if tenant_ids:
|
||||
filter_dict["tenant_id"] = [tenant_id for tenant_id in tenant_ids if tenant_id in allowed_tenant_ids]
|
||||
if not filter_dict["tenant_id"]:
|
||||
return {"memory_list": [], "total_count": 0}
|
||||
else:
|
||||
if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
|
||||
tenant_ids = tenant_ids[0].split(',')
|
||||
filter_dict["tenant_id"] = tenant_ids
|
||||
memory_types = filter_params.get("memory_type")
|
||||
if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
|
||||
memory_types = memory_types[0].split(',')
|
||||
filter_dict["tenant_id"] = list(allowed_tenant_ids)
|
||||
memory_types = _split_filter_values(filter_params.get("memory_type"))
|
||||
filter_dict["memory_type"] = memory_types
|
||||
|
||||
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
|
||||
@@ -212,15 +248,13 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size
|
||||
|
||||
async def get_memory_config(memory_id):
|
||||
memory = MemoryService.get_with_owner_name_by_id(memory_id)
|
||||
if not memory:
|
||||
if not memory or not _memory_accessible(memory):
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
return format_ret_data_from_memory(memory)
|
||||
|
||||
|
||||
async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
memory = _require_memory_access(memory_id)
|
||||
messages = MessageService.list_message(
|
||||
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
|
||||
agent_name_mapping = {}
|
||||
@@ -253,13 +287,14 @@ async def add_message(memory_ids: list[str], message_dict: dict):
|
||||
"message_type": str
|
||||
}
|
||||
"""
|
||||
return await queue_save_to_memory_task(memory_ids, message_dict)
|
||||
accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)]
|
||||
if not accessible_memory_ids:
|
||||
return False, "Memory not found."
|
||||
return await queue_save_to_memory_task(accessible_memory_ids, message_dict)
|
||||
|
||||
|
||||
async def forget_message(memory_id: str, message_id: int):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
memory = _require_memory_access(memory_id)
|
||||
|
||||
forget_time = timestamp_to_date(current_timestamp())
|
||||
update_succeed = MessageService.update_message(
|
||||
@@ -272,9 +307,7 @@ async def forget_message(memory_id: str, message_id: int):
|
||||
|
||||
|
||||
async def update_message_status(memory_id: str, message_id: int, status: bool):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
memory = _require_memory_access(memory_id)
|
||||
|
||||
update_succeed = MessageService.update_message(
|
||||
{"memory_id": memory_id, "message_id": int(message_id)},
|
||||
@@ -300,6 +333,11 @@ async def search_message(filter_dict: dict, params: dict):
|
||||
"top_n": int
|
||||
}
|
||||
"""
|
||||
memory_ids = _split_filter_values(filter_dict.get("memory_id"))
|
||||
accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)]
|
||||
if not accessible_memory_ids:
|
||||
return []
|
||||
filter_dict = {**filter_dict, "memory_id": accessible_memory_ids}
|
||||
return query_message(filter_dict, params)
|
||||
|
||||
|
||||
@@ -313,11 +351,14 @@ async def get_messages(memory_ids: list[str], agent_id: str = "", session_id: st
|
||||
:param limit: maximum number of messages to return
|
||||
:return: list of recent messages
|
||||
"""
|
||||
memory_list = MemoryService.get_by_ids(memory_ids)
|
||||
memory_list = _filter_accessible_memories(memory_ids)
|
||||
if not memory_list:
|
||||
return []
|
||||
uids = [memory.tenant_id for memory in memory_list]
|
||||
accessible_memory_ids = [memory.id for memory in memory_list]
|
||||
res = MessageService.get_recent_messages(
|
||||
uids,
|
||||
memory_ids,
|
||||
accessible_memory_ids,
|
||||
agent_id,
|
||||
session_id,
|
||||
limit
|
||||
@@ -334,11 +375,9 @@ async def get_message_content(memory_id: str, message_id: int):
|
||||
:return: message content
|
||||
:raises NotFoundException: if memory or message not found
|
||||
"""
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
raise NotFoundException(f"Memory '{memory_id}' not found.")
|
||||
memory = _require_memory_access(memory_id)
|
||||
|
||||
res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id)
|
||||
if res:
|
||||
return res
|
||||
raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.")
|
||||
raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.")
|
||||
|
||||
@@ -92,6 +92,11 @@ class MemoryService(CommonService):
|
||||
memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
|
||||
if filter_dict.get("tenant_id"):
|
||||
memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"]))
|
||||
if filter_dict.get("accessible_user_id"):
|
||||
memories = memories.where(
|
||||
(cls.model.tenant_id == filter_dict["accessible_user_id"]) |
|
||||
(cls.model.permissions == "team")
|
||||
)
|
||||
if filter_dict.get("memory_type"):
|
||||
memory_type_int = calculate_memory_type(filter_dict["memory_type"])
|
||||
memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0)
|
||||
|
||||
Reference in New Issue
Block a user