diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index 8f92661e70..c361d816b6 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -130,7 +130,7 @@ async def delete_memory(memory_id): @login_required async def list_memory(): filter_params = { - k: request.args.get(k) for k in ["memory_type", "tenant_id", "storage_type"] if k in request.args + k: request.args.get(k) for k in ["memory_type", "tenant_id", "owner_ids", "storage_type"] if k in request.args } keywords = request.args.get("keywords") page = int(request.args.get("page", 1)) diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index 1b640cff66..9040f0ce44 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -29,6 +29,49 @@ from common.exceptions import ArgumentException, NotFoundException from common.time_utils import current_timestamp, timestamp_to_date +def _split_filter_values(values): + if not values: + return [] + if isinstance(values, str): + values = [values] + res = [] + for value in values: + if not value: + continue + if isinstance(value, str): + res.extend([v.strip() for v in value.split(",") if v.strip()]) + else: + res.append(value) + return res + + +def _joined_tenant_ids(user_id: str) -> set[str]: + user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(user_id) + return {user_id, *[tenant["tenant_id"] for tenant in user_tenants]} + + +def _memory_accessible(memory) -> bool: + if memory.tenant_id == current_user.id: + return True + if memory.permissions != TenantPermission.TEAM.value: + return False + return memory.tenant_id in _joined_tenant_ids(current_user.id) + + +def _require_memory_access(memory_id: str): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory or not _memory_accessible(memory): + raise NotFoundException(f"Memory '{memory_id}' not found.") + return memory + + +def _filter_accessible_memories(memory_ids: list[str]): + memory_ids = _split_filter_values(memory_ids) + if not memory_ids: + return [] + return [memory for memory in MemoryService.get_by_ids(memory_ids) if _memory_accessible(memory)] + + async def create_memory(memory_info: dict): """ :param memory_info: { @@ -137,9 +180,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict): for field in ["avatar", "description", "system_prompt", "user_prompt"]: if field in new_memory_setting: update_dict[field] = new_memory_setting[field] - current_memory = MemoryService.get_by_memory_id(memory_id) - if not current_memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + current_memory = _require_memory_access(memory_id) memory_dict = current_memory.to_dict() memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) @@ -168,9 +209,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict): async def delete_memory(memory_id): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) MemoryService.delete_memory(memory_id) if MessageService.has_index(memory.tenant_id, memory_id): MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id) @@ -188,19 +227,16 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size :param page: int :param page_size: int """ - filter_dict: dict = {"storage_type": filter_params.get("storage_type")} - tenant_ids = filter_params.get("tenant_id") - if not filter_params.get("tenant_id"): - # restrict to current user's tenants - user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) - filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] + filter_dict: dict = {"storage_type": filter_params.get("storage_type"), "accessible_user_id": current_user.id} + allowed_tenant_ids = _joined_tenant_ids(current_user.id) + tenant_ids = _split_filter_values(filter_params.get("tenant_id") or filter_params.get("owner_ids")) + if tenant_ids: + filter_dict["tenant_id"] = [tenant_id for tenant_id in tenant_ids if tenant_id in allowed_tenant_ids] + if not filter_dict["tenant_id"]: + return {"memory_list": [], "total_count": 0} else: - if len(tenant_ids) == 1 and ',' in tenant_ids[0]: - tenant_ids = tenant_ids[0].split(',') - filter_dict["tenant_id"] = tenant_ids - memory_types = filter_params.get("memory_type") - if memory_types and len(memory_types) == 1 and ',' in memory_types[0]: - memory_types = memory_types[0].split(',') + filter_dict["tenant_id"] = list(allowed_tenant_ids) + memory_types = _split_filter_values(filter_params.get("memory_type")) filter_dict["memory_type"] = memory_types memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) @@ -212,15 +248,13 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size async def get_memory_config(memory_id): memory = MemoryService.get_with_owner_name_by_id(memory_id) - if not memory: + if not memory or not _memory_accessible(memory): raise NotFoundException(f"Memory '{memory_id}' not found.") return format_ret_data_from_memory(memory) async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) messages = MessageService.list_message( memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) agent_name_mapping = {} @@ -253,13 +287,14 @@ async def add_message(memory_ids: list[str], message_dict: dict): "message_type": str } """ - return await queue_save_to_memory_task(memory_ids, message_dict) + accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)] + if not accessible_memory_ids: + return False, "Memory not found." + return await queue_save_to_memory_task(accessible_memory_ids, message_dict) async def forget_message(memory_id: str, message_id: int): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) forget_time = timestamp_to_date(current_timestamp()) update_succeed = MessageService.update_message( @@ -272,9 +307,7 @@ async def forget_message(memory_id: str, message_id: int): async def update_message_status(memory_id: str, message_id: int, status: bool): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) update_succeed = MessageService.update_message( {"memory_id": memory_id, "message_id": int(message_id)}, @@ -300,6 +333,11 @@ async def search_message(filter_dict: dict, params: dict): "top_n": int } """ + memory_ids = _split_filter_values(filter_dict.get("memory_id")) + accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)] + if not accessible_memory_ids: + return [] + filter_dict = {**filter_dict, "memory_id": accessible_memory_ids} return query_message(filter_dict, params) @@ -313,11 +351,14 @@ async def get_messages(memory_ids: list[str], agent_id: str = "", session_id: st :param limit: maximum number of messages to return :return: list of recent messages """ - memory_list = MemoryService.get_by_ids(memory_ids) + memory_list = _filter_accessible_memories(memory_ids) + if not memory_list: + return [] uids = [memory.tenant_id for memory in memory_list] + accessible_memory_ids = [memory.id for memory in memory_list] res = MessageService.get_recent_messages( uids, - memory_ids, + accessible_memory_ids, agent_id, session_id, limit @@ -334,11 +375,9 @@ async def get_message_content(memory_id: str, message_id: int): :return: message content :raises NotFoundException: if memory or message not found """ - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id) if res: return res - raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.") \ No newline at end of file + raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.") diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py index d2433d01d0..530fc5ad9e 100644 --- a/api/db/services/memory_service.py +++ b/api/db/services/memory_service.py @@ -92,6 +92,11 @@ class MemoryService(CommonService): memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)) if filter_dict.get("tenant_id"): memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"])) + if filter_dict.get("accessible_user_id"): + memories = memories.where( + (cls.model.tenant_id == filter_dict["accessible_user_id"]) | + (cls.model.permissions == "team") + ) if filter_dict.get("memory_type"): memory_type_int = calculate_memory_type(filter_dict["memory_type"]) memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0)