From 30d5fc1a07bbe7770a07f0e23c10da954bc68691 Mon Sep 17 00:00:00 2001 From: Lynn Date: Thu, 12 Feb 2026 10:11:50 +0800 Subject: [PATCH] Refactor: split memory API into gateway and service layers (#13111) ### What problem does this PR solve? Decouple the memory API into a gateway layer (for routing/param parse) and a service layer (for business logic). ### Type of change - [x] Refactoring --- api/apps/__init__.py | 8 +- api/apps/restful_apis/memory_api.py | 173 ++++++++++++++ api/apps/sdk/memories.py | 291 ------------------------ api/apps/services/__init__.py | 0 api/apps/services/memory_api_service.py | 223 ++++++++++++++++++ common/exceptions.py | 10 + 6 files changed, 413 insertions(+), 292 deletions(-) create mode 100644 api/apps/restful_apis/memory_api.py delete mode 100644 api/apps/sdk/memories.py create mode 100644 api/apps/services/__init__.py create mode 100644 api/apps/services/memory_api_service.py diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 7feae696e3..89078d9fb8 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -244,6 +244,10 @@ def search_pages_path(page_path): path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".") ] app_path_list.extend(api_path_list) + restful_api_path_list = [ + path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".") + ] + app_path_list.extend(restful_api_path_list) return app_path_list @@ -263,8 +267,9 @@ def register_page(page_path): spec.loader.exec_module(page) page_name = getattr(page, "page_name", page_name) sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/" + restful_api_path = "\\restful_apis\\" if sys.platform.startswith("win") else "/restful_apis/" url_prefix = ( - f"/api/{API_VERSION}" if sdk_path in path else f"/{API_VERSION}/{page_name}" + f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}" ) app.register_blueprint(page.manager, url_prefix=url_prefix) @@ -274,6 +279,7 @@ def register_page(page_path): pages_dir = [ Path(__file__).parent, Path(__file__).parent.parent / "api" / "apps", + Path(__file__).parent.parent / "api" / "apps" / "restful_apis", Path(__file__).parent.parent / "api" / "apps" / "sdk", ] diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py new file mode 100644 index 0000000000..53c7f866e2 --- /dev/null +++ b/api/apps/restful_apis/memory_api.py @@ -0,0 +1,173 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +import os +import time + +from quart import request +from common.constants import RetCode +from common.exceptions import ArgumentException, NotFoundException +from api.apps import login_required +from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result +from api.apps.services import memory_api_service + + +@manager.route("/memories", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("name", "memory_type", "embd_id", "llm_id") +async def create_memory(): + timing_enabled = os.getenv("RAGFLOW_API_TIMING") + t_start = time.perf_counter() if timing_enabled else None + req = await get_request_json() + t_parsed = time.perf_counter() if timing_enabled else None + try: + memory_info = { + "name": req["name"], + "memory_type": req["memory_type"], + "embd_id": req["embd_id"], + "llm_id": req["llm_id"] + } + success, res = await memory_api_service.create_memory(memory_info) + if timing_enabled: + logging.info( + "api_timing create_memory parse_ms=%.2f validate_and_db_ms=%.2f total_ms=%.2f path=%s", + (t_parsed - t_start) * 1000, + (time.perf_counter() - t_parsed) * 1000, + (time.perf_counter() - t_start) * 1000, + request.path, + ) + if success: + return get_json_result(message=True, data=res) + else: + return get_json_result(message=res, code=RetCode.SERVER_ERROR) + + except ArgumentException as arg_error: + logging.error(arg_error) + if timing_enabled: + logging.info( + "api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s", + str(arg_error), + (t_parsed - t_start) * 1000, + (time.perf_counter() - t_start) * 1000, + request.path, + ) + return get_error_argument_result(str(arg_error)) + + except Exception as e: + logging.error(e) + if timing_enabled: + logging.info( + "api_timing create_memory error=%s parse_ms=%.2f total_ms=%.2f path=%s", + str(e), + (t_parsed - t_start) * 1000, + (time.perf_counter() - t_start) * 1000, + request.path, + ) + return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") + + +@manager.route("/memories/", methods=["PUT"]) # noqa: F821 +@login_required +async def update_memory(memory_id): + req = await get_request_json() + new_settings = {k: req[k] for k in [ + "name", "permissions", "llm_id", "embd_id", "memory_type", "memory_size", "forgetting_policy", "temperature", + "avatar", "description", "system_prompt", "user_prompt" + ] if k in req} + try: + success, res = await memory_api_service.update_memory(memory_id, new_settings) + if success: + return get_json_result(message=True, data=res) + else: + return get_json_result(message=res, code=RetCode.SERVER_ERROR) + except NotFoundException as not_found_exception: + logging.error(not_found_exception) + return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception)) + except ArgumentException as arg_error: + logging.error(arg_error) + return get_error_argument_result(str(arg_error)) + except Exception as e: + logging.error(e) + return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") + + +@manager.route("/memories/", methods=["DELETE"]) # noqa: F821 +@login_required +async def delete_memory(memory_id): + try: + await memory_api_service.delete_memory(memory_id) + return get_json_result(message=True) + except NotFoundException as not_found_exception: + logging.error(not_found_exception) + return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception)) + except Exception as e: + logging.error(e) + return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") + + +@manager.route("/memories", methods=["GET"]) # noqa: F821 +@login_required +async def list_memory(): + filter_params = { + k: request.args.get(k) for k in ["memory_type", "tenant_id", "storage_type"] if k in request.args + } + keywords = request.args.get("keywords") + page = int(request.args.get("page", 1)) + page_size = int(request.args.get("page_size", 50)) + try: + res = await memory_api_service.list_memory(filter_params, keywords, page, page_size) + return get_json_result(message=True, data=res) + except Exception as e: + logging.error(e) + return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") + + +@manager.route("/memories//config", methods=["GET"]) # noqa: F821 +@login_required +async def get_memory_config(memory_id): + try: + res = await memory_api_service.get_memory_config(memory_id) + return get_json_result(message=True, data=res) + except NotFoundException as not_found_exception: + logging.error(not_found_exception) + return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception)) + except Exception as e: + logging.error(e) + return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") + + +@manager.route("/memories/", methods=["GET"]) # noqa: F821 +@login_required +async def get_memory_messages(memory_id): + args = request.args + agent_ids = args.getlist("agent_id") + if len(agent_ids) == 1 and ',' in agent_ids[0]: + agent_ids = agent_ids[0].split(',') + keywords = args.get("keywords", "") + keywords = keywords.strip() + page = int(args.get("page", 1)) + page_size = int(args.get("page_size", 50)) + try: + res = await memory_api_service.get_memory_messages( + memory_id, agent_ids, keywords, page, page_size + ) + return get_json_result(message=True, data=res) + except NotFoundException as not_found_exception: + logging.error(not_found_exception) + return get_json_result(code=RetCode.NOT_FOUND, message=str(not_found_exception)) + except Exception as e: + logging.error(e) + return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") diff --git a/api/apps/sdk/memories.py b/api/apps/sdk/memories.py deleted file mode 100644 index ada4b34fab..0000000000 --- a/api/apps/sdk/memories.py +++ /dev/null @@ -1,291 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import logging -import os -import time - -from quart import request -from api.apps import login_required, current_user -from api.db import TenantPermission -from api.db.services.memory_service import MemoryService -from api.db.services.user_service import UserTenantService -from api.db.services.canvas_service import UserCanvasService -from api.db.services.task_service import TaskService -from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default -from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result -from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human -from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT -from memory.services.messages import MessageService -from memory.utils.prompt_util import PromptAssembler -from common.constants import MemoryType, RetCode, ForgettingPolicy - - -@manager.route("/memories", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("name", "memory_type", "embd_id", "llm_id") -async def create_memory(): - timing_enabled = os.getenv("RAGFLOW_API_TIMING") - t_start = time.perf_counter() if timing_enabled else None - req = await get_request_json() - t_parsed = time.perf_counter() if timing_enabled else None - # check name length - name = req["name"] - memory_name = name.strip() - if len(memory_name) == 0: - if timing_enabled: - logging.info( - "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s", - (t_parsed - t_start) * 1000, - (time.perf_counter() - t_start) * 1000, - request.path, - ) - return get_error_argument_result("Memory name cannot be empty or whitespace.") - if len(memory_name) > MEMORY_NAME_LIMIT: - if timing_enabled: - logging.info( - "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s", - (t_parsed - t_start) * 1000, - (time.perf_counter() - t_start) * 1000, - request.path, - ) - return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") - # check memory_type valid - if not isinstance(req["memory_type"], list): - if timing_enabled: - logging.info( - "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s", - (t_parsed - t_start) * 1000, - (time.perf_counter() - t_start) * 1000, - request.path, - ) - return get_error_argument_result("Memory type must be a list.") - memory_type = set(req["memory_type"]) - invalid_type = memory_type - {e.name.lower() for e in MemoryType} - if invalid_type: - if timing_enabled: - logging.info( - "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s", - (t_parsed - t_start) * 1000, - (time.perf_counter() - t_start) * 1000, - request.path, - ) - return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.") - memory_type = list(memory_type) - - try: - t_before_db = time.perf_counter() if timing_enabled else None - res, memory = MemoryService.create_memory( - tenant_id=current_user.id, - name=memory_name, - memory_type=memory_type, - embd_id=req["embd_id"], - llm_id=req["llm_id"] - ) - if timing_enabled: - logging.info( - "api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s", - (t_parsed - t_start) * 1000, - (t_before_db - t_parsed) * 1000, - (time.perf_counter() - t_before_db) * 1000, - (time.perf_counter() - t_start) * 1000, - request.path, - ) - - if res: - return get_json_result(message=True, data=format_ret_data_from_memory(memory)) - else: - return get_json_result(message=memory, code=RetCode.SERVER_ERROR) - - except Exception as e: - return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) - - -@manager.route("/memories/", methods=["PUT"]) # noqa: F821 -@login_required -async def update_memory(memory_id): - req = await get_request_json() - update_dict = {} - # check name length - if "name" in req: - name = req["name"] - memory_name = name.strip() - if len(memory_name) == 0: - return get_error_argument_result("Memory name cannot be empty or whitespace.") - if len(memory_name) > MEMORY_NAME_LIMIT: - return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") - update_dict["name"] = memory_name - # check permissions valid - if req.get("permissions"): - if req["permissions"] not in [e.value for e in TenantPermission]: - return get_error_argument_result(f"Unknown permission '{req['permissions']}'.") - update_dict["permissions"] = req["permissions"] - if req.get("llm_id"): - update_dict["llm_id"] = req["llm_id"] - if req.get("embd_id"): - update_dict["embd_id"] = req["embd_id"] - if req.get("memory_type"): - memory_type = set(req["memory_type"]) - invalid_type = memory_type - {e.name.lower() for e in MemoryType} - if invalid_type: - return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.") - update_dict["memory_type"] = list(memory_type) - # check memory_size valid - if req.get("memory_size"): - if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT: - return get_error_argument_result(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.") - update_dict["memory_size"] = req["memory_size"] - # check forgetting_policy valid - if req.get("forgetting_policy"): - if req["forgetting_policy"] not in [e.value for e in ForgettingPolicy]: - return get_error_argument_result(f"Forgetting policy '{req['forgetting_policy']}' is not supported.") - update_dict["forgetting_policy"] = req["forgetting_policy"] - # check temperature valid - if "temperature" in req: - temperature = float(req["temperature"]) - if not 0 <= temperature <= 1: - return get_error_argument_result("Temperature should be in range [0, 1].") - update_dict["temperature"] = temperature - # allow update to empty fields - for field in ["avatar", "description", "system_prompt", "user_prompt"]: - if field in req: - update_dict[field] = req[field] - current_memory = MemoryService.get_by_memory_id(memory_id) - if not current_memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - - memory_dict = current_memory.to_dict() - memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) - to_update = {} - for k, v in update_dict.items(): - if isinstance(v, list) and set(memory_dict[k]) != set(v): - to_update[k] = v - elif memory_dict[k] != v: - to_update[k] = v - - if not to_update: - return get_json_result(message=True, data=memory_dict) - # check memory empty when update embd_id, memory_type - memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id) - not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0] - if not_allowed_update: - return get_error_argument_result(f"Can't update {not_allowed_update} when memory isn't empty.") - if "memory_type" in to_update: - if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type): - # update old default prompt, assemble a new one - to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]}) - - try: - MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update) - updated_memory = MemoryService.get_by_memory_id(memory_id) - return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory)) - - except Exception as e: - logging.error(e) - return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) - - -@manager.route("/memories/", methods=["DELETE"]) # noqa: F821 -@login_required -async def delete_memory(memory_id): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - return get_json_result(message=True, code=RetCode.NOT_FOUND) - try: - MemoryService.delete_memory(memory_id) - if MessageService.has_index(memory.tenant_id, memory_id): - MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id) - return get_json_result(message=True) - except Exception as e: - logging.error(e) - return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) - - -@manager.route("/memories", methods=["GET"]) # noqa: F821 -@login_required -async def list_memory(): - args = request.args - try: - tenant_ids = args.getlist("tenant_id") - memory_types = args.getlist("memory_type") - storage_type = args.get("storage_type") - keywords = args.get("keywords", "") - page = int(args.get("page", 1)) - page_size = int(args.get("page_size", 50)) - # make filter dict - filter_dict: dict = {"storage_type": storage_type} - if not tenant_ids: - # restrict to current user's tenants - user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) - filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] - else: - if len(tenant_ids) == 1 and ',' in tenant_ids[0]: - tenant_ids = tenant_ids[0].split(',') - filter_dict["tenant_id"] = tenant_ids - if memory_types and len(memory_types) == 1 and ',' in memory_types[0]: - memory_types = memory_types[0].split(',') - filter_dict["memory_type"] = memory_types - - memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) - [memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list] - return get_json_result(message=True, data={"memory_list": memory_list, "total_count": count}) - - except Exception as e: - logging.error(e) - return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) - - -@manager.route("/memories//config", methods=["GET"]) # noqa: F821 -@login_required -async def get_memory_config(memory_id): - memory = MemoryService.get_with_owner_name_by_id(memory_id) - if not memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - return get_json_result(message=True, data=format_ret_data_from_memory(memory)) - - -@manager.route("/memories/", methods=["GET"]) # noqa: F821 -@login_required -async def get_memory_detail(memory_id): - args = request.args - agent_ids = args.getlist("agent_id") - if len(agent_ids) == 1 and ',' in agent_ids[0]: - agent_ids = agent_ids[0].split(',') - keywords = args.get("keywords", "") - keywords = keywords.strip() - page = int(args.get("page", 1)) - page_size = int(args.get("page_size", 50)) - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") - messages = MessageService.list_message( - memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) - agent_name_mapping = {} - extract_task_mapping = {} - if messages["message_list"]: - agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]]) - agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list} - task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id]) - if task_list: - task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task - for task in task_list: - # the 'digest' field carries the source_id when a task is created, so use 'digest' as key - extract_task_mapping.update({int(task["digest"]): task}) - for message in messages["message_list"]: - message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown") - message["task"] = extract_task_mapping.get(message["message_id"], {}) - for extract_msg in message["extract"]: - extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown") - return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True) diff --git a/api/apps/services/__init__.py b/api/apps/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py new file mode 100644 index 0000000000..53bb0f6e9e --- /dev/null +++ b/api/apps/services/memory_api_service.py @@ -0,0 +1,223 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from api.apps import current_user +from api.db import TenantPermission +from api.db.services.memory_service import MemoryService +from api.db.services.user_service import UserTenantService +from api.db.services.canvas_service import UserCanvasService +from api.db.services.task_service import TaskService +from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default +from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human +from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT +from memory.services.messages import MessageService +from memory.utils.prompt_util import PromptAssembler +from common.constants import MemoryType, ForgettingPolicy +from common.exceptions import ArgumentException, NotFoundException + + +async def create_memory(memory_info: dict): + """ + :param memory_info: { + "name": str, + "memory_type": list[str], + "embd_id": str, + "llm_id": str + } + """ + # check name length + name = memory_info["name"] + memory_name = name.strip() + if len(memory_name) == 0: + raise ArgumentException("Memory name cannot be empty or whitespace.") + if len(memory_name) > MEMORY_NAME_LIMIT: + raise ArgumentException(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") + # check memory_type valid + if not isinstance(memory_info["memory_type"], list): + raise ArgumentException("Memory type must be a list.") + memory_type = set(memory_info["memory_type"]) + invalid_type = memory_type - {e.name.lower() for e in MemoryType} + if invalid_type: + raise ArgumentException(f"Memory type '{invalid_type}' is not supported.") + memory_type = list(memory_type) + success, res = MemoryService.create_memory( + tenant_id=current_user.id, + name=memory_name, + memory_type=memory_type, + embd_id=memory_info["embd_id"], + llm_id=memory_info["llm_id"] + ) + if success: + return True, format_ret_data_from_memory(res) + else: + return False, res + + +async def update_memory(memory_id: str, new_memory_setting: dict): + """ + :param memory_id: str + :param new_memory_setting: { + "name": str, + "permissions": str, + "llm_id": str, + "embd_id": str, + "memory_type": list[str], + "memory_size": int, + "forgetting_policy": str, + "temperature": float, + "avatar": str, + "description": str, + "system_prompt": str, + "user_prompt": str + } + """ + update_dict = {} + # check name length + if "name" in new_memory_setting: + name = new_memory_setting["name"] + memory_name = name.strip() + if len(memory_name) == 0: + raise ArgumentException("Memory name cannot be empty or whitespace.") + if len(memory_name) > MEMORY_NAME_LIMIT: + raise ArgumentException(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") + update_dict["name"] = memory_name + # check permissions valid + if new_memory_setting.get("permissions"): + if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]: + raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.") + update_dict["permissions"] = new_memory_setting["permissions"] + if new_memory_setting.get("llm_id"): + update_dict["llm_id"] = new_memory_setting["llm_id"] + if new_memory_setting.get("embd_id"): + update_dict["embd_id"] = new_memory_setting["embd_id"] + if new_memory_setting.get("memory_type"): + memory_type = set(new_memory_setting["memory_type"]) + invalid_type = memory_type - {e.name.lower() for e in MemoryType} + if invalid_type: + raise ArgumentException(f"Memory type '{invalid_type}' is not supported.") + update_dict["memory_type"] = list(memory_type) + # check memory_size valid + if new_memory_setting.get("memory_size"): + if not 0 < int(new_memory_setting["memory_size"]) <= MEMORY_SIZE_LIMIT: + raise ArgumentException(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.") + update_dict["memory_size"] = new_memory_setting["memory_size"] + # check forgetting_policy valid + if new_memory_setting.get("forgetting_policy"): + if new_memory_setting["forgetting_policy"] not in [e.value for e in ForgettingPolicy]: + raise ArgumentException(f"Forgetting policy '{new_memory_setting['forgetting_policy']}' is not supported.") + update_dict["forgetting_policy"] = new_memory_setting["forgetting_policy"] + # check temperature valid + if "temperature" in new_memory_setting: + temperature = float(new_memory_setting["temperature"]) + if not 0 <= temperature <= 1: + raise ArgumentException("Temperature should be in range [0, 1].") + update_dict["temperature"] = temperature + # allow update to empty fields + for field in ["avatar", "description", "system_prompt", "user_prompt"]: + if field in new_memory_setting: + update_dict[field] = new_memory_setting[field] + current_memory = MemoryService.get_by_memory_id(memory_id) + if not current_memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + + memory_dict = current_memory.to_dict() + memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) + to_update = {} + for k, v in update_dict.items(): + if isinstance(v, list) and set(memory_dict[k]) != set(v): + to_update[k] = v + elif memory_dict[k] != v: + to_update[k] = v + + if not to_update: + return True, memory_dict + # check memory empty when update embd_id, memory_type + memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id) + not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0] + if not_allowed_update: + raise ArgumentException(f"Can't update {not_allowed_update} when memory isn't empty.") + if "memory_type" in to_update: + if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type): + # update old default prompt, assemble a new one + to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]}) + + MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update) + updated_memory = MemoryService.get_by_memory_id(memory_id) + return True, format_ret_data_from_memory(updated_memory) + + +async def delete_memory(memory_id): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + MemoryService.delete_memory(memory_id) + if MessageService.has_index(memory.tenant_id, memory_id): + MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id) + return True + + +async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size: int = 50): + filter_dict: dict = {"storage_type": filter_params.get("storage_type")} + tenant_ids = filter_params.get("tenant_id") + if not filter_params.get("tenant_id"): + # restrict to current user's tenants + user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) + filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] + else: + if len(tenant_ids) == 1 and ',' in tenant_ids[0]: + tenant_ids = tenant_ids[0].split(',') + filter_dict["tenant_id"] = tenant_ids + memory_types = filter_params.get("memory_type") + if memory_types and len(memory_types) == 1 and ',' in memory_types[0]: + memory_types = memory_types[0].split(',') + filter_dict["memory_type"] = memory_types + + memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) + [memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list] + return { + "memory_list": memory_list, "total_count": count + } + + +async def get_memory_config(memory_id): + memory = MemoryService.get_with_owner_name_by_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + return format_ret_data_from_memory(memory) + + +async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + raise NotFoundException(f"Memory '{memory_id}' not found.") + messages = MessageService.list_message( + memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) + agent_name_mapping = {} + extract_task_mapping = {} + if messages["message_list"]: + agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]]) + agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list} + task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id]) + if task_list: + task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task + for task in task_list: + # the 'digest' field carries the source_id when a task is created, so use 'digest' as key + extract_task_mapping.update({int(task["digest"]): task}) + for message in messages["message_list"]: + message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown") + message["task"] = extract_task_mapping.get(message["message_id"], {}) + for extract_msg in message["extract"]: + extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown") + return {"messages": messages, "storage_type": memory.storage_type} diff --git a/common/exceptions.py b/common/exceptions.py index c0caac4842..9511304720 100644 --- a/common/exceptions.py +++ b/common/exceptions.py @@ -16,3 +16,13 @@ class TaskCanceledException(Exception): def __init__(self, msg): self.msg = msg + + +class ArgumentException(Exception): + def __init__(self, msg): + self.msg = msg + + +class NotFoundException(Exception): + def __init__(self, msg): + self.msg = msg