From 94f8779a00993395bf4f3506bae7a9229f5d6d6e Mon Sep 17 00:00:00 2001 From: jony376 Date: Tue, 5 May 2026 23:10:47 -0700 Subject: [PATCH] Memory API: enforce tenant permissions on memory and message endpoints (#14535) ### What problem does this PR solve? This PR fixes missing authorization checks in the Memory API. Previously, several authenticated endpoints accepted caller-supplied `tenant_id`, `owner_ids`, or `memory_id` values and used them directly to list, read, update, delete, or search Memory data. That could allow an authenticated user to access or mutate another tenant's Memory records if they knew a tenant ID or memory ID. The fix centralizes Memory access checks and applies them consistently across Memory and Memory-message operations. The change: - Adds helper logic to parse list filters and compute tenant IDs accessible to `current_user`. - Requires direct `memory_id` operations to pass Memory access checks before reading, updating, deleting, or changing message state. - Filters list/search/recent-message requests to accessible memories only. - Applies Memory visibility filtering before count and pagination in `MemoryService.get_by_filter`. - Accepts `owner_ids` in the Memory list route, matching the frontend owner filter while still intersecting values with the caller's accessible tenants. - ### Related issues Closes #14534 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: jony376 --- api/apps/restful_apis/memory_api.py | 2 +- api/apps/services/memory_api_service.py | 109 ++++++++++++++++-------- api/db/services/memory_service.py | 5 ++ 3 files changed, 80 insertions(+), 36 deletions(-) diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index 8f92661e70..c361d816b6 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -130,7 +130,7 @@ async def delete_memory(memory_id): @login_required async def list_memory(): filter_params = { - k: request.args.get(k) for k in ["memory_type", "tenant_id", "storage_type"] if k in request.args + k: request.args.get(k) for k in ["memory_type", "tenant_id", "owner_ids", "storage_type"] if k in request.args } keywords = request.args.get("keywords") page = int(request.args.get("page", 1)) diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index 1b640cff66..9040f0ce44 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -29,6 +29,49 @@ from common.exceptions import ArgumentException, NotFoundException from common.time_utils import current_timestamp, timestamp_to_date +def _split_filter_values(values): + if not values: + return [] + if isinstance(values, str): + values = [values] + res = [] + for value in values: + if not value: + continue + if isinstance(value, str): + res.extend([v.strip() for v in value.split(",") if v.strip()]) + else: + res.append(value) + return res + + +def _joined_tenant_ids(user_id: str) -> set[str]: + user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(user_id) + return {user_id, *[tenant["tenant_id"] for tenant in user_tenants]} + + +def _memory_accessible(memory) -> bool: + if memory.tenant_id == current_user.id: + return True + if memory.permissions != TenantPermission.TEAM.value: + return False + return memory.tenant_id in _joined_tenant_ids(current_user.id) + + +def _require_memory_access(memory_id: str): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory or not _memory_accessible(memory): + raise NotFoundException(f"Memory '{memory_id}' not found.") + return memory + + +def _filter_accessible_memories(memory_ids: list[str]): + memory_ids = _split_filter_values(memory_ids) + if not memory_ids: + return [] + return [memory for memory in MemoryService.get_by_ids(memory_ids) if _memory_accessible(memory)] + + async def create_memory(memory_info: dict): """ :param memory_info: { @@ -137,9 +180,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict): for field in ["avatar", "description", "system_prompt", "user_prompt"]: if field in new_memory_setting: update_dict[field] = new_memory_setting[field] - current_memory = MemoryService.get_by_memory_id(memory_id) - if not current_memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + current_memory = _require_memory_access(memory_id) memory_dict = current_memory.to_dict() memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) @@ -168,9 +209,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict): async def delete_memory(memory_id): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) MemoryService.delete_memory(memory_id) if MessageService.has_index(memory.tenant_id, memory_id): MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id) @@ -188,19 +227,16 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size :param page: int :param page_size: int """ - filter_dict: dict = {"storage_type": filter_params.get("storage_type")} - tenant_ids = filter_params.get("tenant_id") - if not filter_params.get("tenant_id"): - # restrict to current user's tenants - user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) - filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] + filter_dict: dict = {"storage_type": filter_params.get("storage_type"), "accessible_user_id": current_user.id} + allowed_tenant_ids = _joined_tenant_ids(current_user.id) + tenant_ids = _split_filter_values(filter_params.get("tenant_id") or filter_params.get("owner_ids")) + if tenant_ids: + filter_dict["tenant_id"] = [tenant_id for tenant_id in tenant_ids if tenant_id in allowed_tenant_ids] + if not filter_dict["tenant_id"]: + return {"memory_list": [], "total_count": 0} else: - if len(tenant_ids) == 1 and ',' in tenant_ids[0]: - tenant_ids = tenant_ids[0].split(',') - filter_dict["tenant_id"] = tenant_ids - memory_types = filter_params.get("memory_type") - if memory_types and len(memory_types) == 1 and ',' in memory_types[0]: - memory_types = memory_types[0].split(',') + filter_dict["tenant_id"] = list(allowed_tenant_ids) + memory_types = _split_filter_values(filter_params.get("memory_type")) filter_dict["memory_type"] = memory_types memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) @@ -212,15 +248,13 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size async def get_memory_config(memory_id): memory = MemoryService.get_with_owner_name_by_id(memory_id) - if not memory: + if not memory or not _memory_accessible(memory): raise NotFoundException(f"Memory '{memory_id}' not found.") return format_ret_data_from_memory(memory) async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) messages = MessageService.list_message( memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) agent_name_mapping = {} @@ -253,13 +287,14 @@ async def add_message(memory_ids: list[str], message_dict: dict): "message_type": str } """ - return await queue_save_to_memory_task(memory_ids, message_dict) + accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)] + if not accessible_memory_ids: + return False, "Memory not found." + return await queue_save_to_memory_task(accessible_memory_ids, message_dict) async def forget_message(memory_id: str, message_id: int): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) forget_time = timestamp_to_date(current_timestamp()) update_succeed = MessageService.update_message( @@ -272,9 +307,7 @@ async def forget_message(memory_id: str, message_id: int): async def update_message_status(memory_id: str, message_id: int, status: bool): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) update_succeed = MessageService.update_message( {"memory_id": memory_id, "message_id": int(message_id)}, @@ -300,6 +333,11 @@ async def search_message(filter_dict: dict, params: dict): "top_n": int } """ + memory_ids = _split_filter_values(filter_dict.get("memory_id")) + accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)] + if not accessible_memory_ids: + return [] + filter_dict = {**filter_dict, "memory_id": accessible_memory_ids} return query_message(filter_dict, params) @@ -313,11 +351,14 @@ async def get_messages(memory_ids: list[str], agent_id: str = "", session_id: st :param limit: maximum number of messages to return :return: list of recent messages """ - memory_list = MemoryService.get_by_ids(memory_ids) + memory_list = _filter_accessible_memories(memory_ids) + if not memory_list: + return [] uids = [memory.tenant_id for memory in memory_list] + accessible_memory_ids = [memory.id for memory in memory_list] res = MessageService.get_recent_messages( uids, - memory_ids, + accessible_memory_ids, agent_id, session_id, limit @@ -334,11 +375,9 @@ async def get_message_content(memory_id: str, message_id: int): :return: message content :raises NotFoundException: if memory or message not found """ - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id) if res: return res - raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.") \ No newline at end of file + raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.") diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py index d2433d01d0..530fc5ad9e 100644 --- a/api/db/services/memory_service.py +++ b/api/db/services/memory_service.py @@ -92,6 +92,11 @@ class MemoryService(CommonService): memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)) if filter_dict.get("tenant_id"): memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"])) + if filter_dict.get("accessible_user_id"): + memories = memories.where( + (cls.model.tenant_id == filter_dict["accessible_user_id"]) | + (cls.model.permissions == "team") + ) if filter_dict.get("memory_type"): memory_type_int = calculate_memory_type(filter_dict["memory_type"]) memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0)