Memory API: enforce tenant permissions on memory and message endpoints (#14535)

### What problem does this PR solve?

This PR fixes missing authorization checks in the Memory API.
Previously, several authenticated endpoints accepted caller-supplied
`tenant_id`, `owner_ids`, or `memory_id` values and used them directly
to list, read, update, delete, or search Memory data.

That could allow an authenticated user to access or mutate another
tenant's Memory records if they knew a tenant ID or memory ID. The fix
centralizes Memory access checks and applies them consistently across
Memory and Memory-message operations.

The change:

- Adds helper logic to parse list filters and compute tenant IDs
accessible to `current_user`.
- Requires direct `memory_id` operations to pass Memory access checks
before reading, updating, deleting, or changing message state.
- Filters list/search/recent-message requests to accessible memories
only.
- Applies Memory visibility filtering before count and pagination in
`MemoryService.get_by_filter`.
- Accepts `owner_ids` in the Memory list route, matching the frontend
owner filter while still intersecting values with the caller's
accessible tenants.
- 

### Related issues
Closes #14534 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: jony376 <jony376@gmail.com>
This commit is contained in:
jony376
2026-05-05 23:10:47 -07:00
committed by GitHub
parent 5672be0652
commit 94f8779a00
3 changed files with 80 additions and 36 deletions

View File

@@ -130,7 +130,7 @@ async def delete_memory(memory_id):
@login_required
async def list_memory():
filter_params = {
k: request.args.get(k) for k in ["memory_type", "tenant_id", "storage_type"] if k in request.args
k: request.args.get(k) for k in ["memory_type", "tenant_id", "owner_ids", "storage_type"] if k in request.args
}
keywords = request.args.get("keywords")
page = int(request.args.get("page", 1))

View File

@@ -29,6 +29,49 @@ from common.exceptions import ArgumentException, NotFoundException
from common.time_utils import current_timestamp, timestamp_to_date
def _split_filter_values(values):
if not values:
return []
if isinstance(values, str):
values = [values]
res = []
for value in values:
if not value:
continue
if isinstance(value, str):
res.extend([v.strip() for v in value.split(",") if v.strip()])
else:
res.append(value)
return res
def _joined_tenant_ids(user_id: str) -> set[str]:
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(user_id)
return {user_id, *[tenant["tenant_id"] for tenant in user_tenants]}
def _memory_accessible(memory) -> bool:
if memory.tenant_id == current_user.id:
return True
if memory.permissions != TenantPermission.TEAM.value:
return False
return memory.tenant_id in _joined_tenant_ids(current_user.id)
def _require_memory_access(memory_id: str):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory or not _memory_accessible(memory):
raise NotFoundException(f"Memory '{memory_id}' not found.")
return memory
def _filter_accessible_memories(memory_ids: list[str]):
memory_ids = _split_filter_values(memory_ids)
if not memory_ids:
return []
return [memory for memory in MemoryService.get_by_ids(memory_ids) if _memory_accessible(memory)]
async def create_memory(memory_info: dict):
"""
:param memory_info: {
@@ -137,9 +180,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
for field in ["avatar", "description", "system_prompt", "user_prompt"]:
if field in new_memory_setting:
update_dict[field] = new_memory_setting[field]
current_memory = MemoryService.get_by_memory_id(memory_id)
if not current_memory:
raise NotFoundException(f"Memory '{memory_id}' not found.")
current_memory = _require_memory_access(memory_id)
memory_dict = current_memory.to_dict()
memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
@@ -168,9 +209,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
async def delete_memory(memory_id):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
raise NotFoundException(f"Memory '{memory_id}' not found.")
memory = _require_memory_access(memory_id)
MemoryService.delete_memory(memory_id)
if MessageService.has_index(memory.tenant_id, memory_id):
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
@@ -188,19 +227,16 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size
:param page: int
:param page_size: int
"""
filter_dict: dict = {"storage_type": filter_params.get("storage_type")}
tenant_ids = filter_params.get("tenant_id")
if not filter_params.get("tenant_id"):
# restrict to current user's tenants
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
filter_dict: dict = {"storage_type": filter_params.get("storage_type"), "accessible_user_id": current_user.id}
allowed_tenant_ids = _joined_tenant_ids(current_user.id)
tenant_ids = _split_filter_values(filter_params.get("tenant_id") or filter_params.get("owner_ids"))
if tenant_ids:
filter_dict["tenant_id"] = [tenant_id for tenant_id in tenant_ids if tenant_id in allowed_tenant_ids]
if not filter_dict["tenant_id"]:
return {"memory_list": [], "total_count": 0}
else:
if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
tenant_ids = tenant_ids[0].split(',')
filter_dict["tenant_id"] = tenant_ids
memory_types = filter_params.get("memory_type")
if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
memory_types = memory_types[0].split(',')
filter_dict["tenant_id"] = list(allowed_tenant_ids)
memory_types = _split_filter_values(filter_params.get("memory_type"))
filter_dict["memory_type"] = memory_types
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
@@ -212,15 +248,13 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size
async def get_memory_config(memory_id):
memory = MemoryService.get_with_owner_name_by_id(memory_id)
if not memory:
if not memory or not _memory_accessible(memory):
raise NotFoundException(f"Memory '{memory_id}' not found.")
return format_ret_data_from_memory(memory)
async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
raise NotFoundException(f"Memory '{memory_id}' not found.")
memory = _require_memory_access(memory_id)
messages = MessageService.list_message(
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
agent_name_mapping = {}
@@ -253,13 +287,14 @@ async def add_message(memory_ids: list[str], message_dict: dict):
"message_type": str
}
"""
return await queue_save_to_memory_task(memory_ids, message_dict)
accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)]
if not accessible_memory_ids:
return False, "Memory not found."
return await queue_save_to_memory_task(accessible_memory_ids, message_dict)
async def forget_message(memory_id: str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
raise NotFoundException(f"Memory '{memory_id}' not found.")
memory = _require_memory_access(memory_id)
forget_time = timestamp_to_date(current_timestamp())
update_succeed = MessageService.update_message(
@@ -272,9 +307,7 @@ async def forget_message(memory_id: str, message_id: int):
async def update_message_status(memory_id: str, message_id: int, status: bool):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
raise NotFoundException(f"Memory '{memory_id}' not found.")
memory = _require_memory_access(memory_id)
update_succeed = MessageService.update_message(
{"memory_id": memory_id, "message_id": int(message_id)},
@@ -300,6 +333,11 @@ async def search_message(filter_dict: dict, params: dict):
"top_n": int
}
"""
memory_ids = _split_filter_values(filter_dict.get("memory_id"))
accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)]
if not accessible_memory_ids:
return []
filter_dict = {**filter_dict, "memory_id": accessible_memory_ids}
return query_message(filter_dict, params)
@@ -313,11 +351,14 @@ async def get_messages(memory_ids: list[str], agent_id: str = "", session_id: st
:param limit: maximum number of messages to return
:return: list of recent messages
"""
memory_list = MemoryService.get_by_ids(memory_ids)
memory_list = _filter_accessible_memories(memory_ids)
if not memory_list:
return []
uids = [memory.tenant_id for memory in memory_list]
accessible_memory_ids = [memory.id for memory in memory_list]
res = MessageService.get_recent_messages(
uids,
memory_ids,
accessible_memory_ids,
agent_id,
session_id,
limit
@@ -334,11 +375,9 @@ async def get_message_content(memory_id: str, message_id: int):
:return: message content
:raises NotFoundException: if memory or message not found
"""
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
raise NotFoundException(f"Memory '{memory_id}' not found.")
memory = _require_memory_access(memory_id)
res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id)
if res:
return res
raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.")
raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.")

View File

@@ -92,6 +92,11 @@ class MemoryService(CommonService):
memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
if filter_dict.get("tenant_id"):
memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"]))
if filter_dict.get("accessible_user_id"):
memories = memories.where(
(cls.model.tenant_id == filter_dict["accessible_user_id"]) |
(cls.model.permissions == "team")
)
if filter_dict.get("memory_type"):
memory_type_int = calculate_memory_type(filter_dict["memory_type"])
memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0)