From 198f3c4b9a4750a80a7ecf6abcdd3c421e8beff4 Mon Sep 17 00:00:00 2001 From: jony376 Date: Tue, 19 May 2026 05:11:46 +0300 Subject: [PATCH] Fix: validate memory tenant model IDs on update and enforce tenant scope in memory pipeline (#14923) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Related issues Closes #14922 ### What problem does this PR solve? `POST /memories` already resolves `tenant_llm_id` and `tenant_embd_id` through `ensure_tenant_model_id_for_params`, but `PUT /memories/` accepted client-supplied `tenant_llm_id` / `tenant_embd_id` without checking that those `tenant_llm` rows belong to the memory owner’s tenant. A caller could persist another tenant’s row IDs and later trigger extraction or embedding that loaded foreign model credentials via `get_model_config_by_id(tenant_model_id)` with no tenant allow-list. This change aligns the update path with create: updates that change models must go through `llm_id` / `embd_id` and `ensure_tenant_model_id_for_params` scoped to the **memory’s** `tenant_id` (not only the current user, so team-access cases stay correct). Direct `tenant_*` fields in the body without `llm_id` / `embd_id` are rejected. As defense in depth, `memory_message_service` passes `allowed_tenant_ids` / `requester_tenant_id` into `get_model_config_by_id` for LLM and embedding resolution so mismatched IDs cannot be used even if bad data existed. A regression test rejects payloads that set only `tenant_llm_id` / `tenant_embd_id`. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: jony376 --- api/apps/services/memory_api_service.py | 41 ++++++++++++++----- .../joint_services/memory_message_service.py | 18 ++++++-- .../test_memory_app/test_update_memory.py | 8 ++++ 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index 9040f0ce44..53cff623ce 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -21,10 +21,11 @@ from api.db.services.canvas_service import UserCanvasService from api.db.services.task_service import TaskService from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human +from api.utils.tenant_utils import ensure_tenant_model_id_for_params from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT from memory.services.messages import MessageService from memory.utils.prompt_util import PromptAssembler -from common.constants import MemoryType, ForgettingPolicy +from common.constants import MemoryType, ForgettingPolicy, LLMType from common.exceptions import ArgumentException, NotFoundException from common.time_utils import current_timestamp, timestamp_to_date @@ -131,6 +132,9 @@ async def update_memory(memory_id: str, new_memory_setting: dict): "user_prompt": str } """ + current_memory = _require_memory_access(memory_id) + owner_tenant_id = current_memory.tenant_id + update_dict = {} # check name length if "name" in new_memory_setting: @@ -146,14 +150,32 @@ async def update_memory(memory_id: str, new_memory_setting: dict): if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]: raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.") update_dict["permissions"] = new_memory_setting["permissions"] - if new_memory_setting.get("llm_id"): - update_dict["llm_id"] = new_memory_setting["llm_id"] - if new_memory_setting.get("embd_id"): - update_dict["embd_id"] = new_memory_setting["embd_id"] - if new_memory_setting.get("tenant_llm_id"): - update_dict["tenant_llm_id"] = new_memory_setting["tenant_llm_id"] - if new_memory_setting.get("tenant_embd_id"): - update_dict["tenant_embd_id"] = new_memory_setting["tenant_embd_id"] + if ("tenant_llm_id" in new_memory_setting or "tenant_embd_id" in new_memory_setting) and not ( + new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id") + ): + raise ArgumentException( + "Do not set tenant_llm_id or tenant_embd_id directly; update llm_id and/or embd_id instead." + ) + if new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id"): + merged = { + "llm_id": new_memory_setting.get("llm_id") or current_memory.llm_id, + "embd_id": new_memory_setting.get("embd_id") or current_memory.embd_id, + } + merged = ensure_tenant_model_id_for_params(owner_tenant_id, merged) + if not merged.get("tenant_llm_id"): + raise ArgumentException( + f"Tenant Model with name {merged['llm_id']} and type {LLMType.CHAT.value} not found" + ) + if new_memory_setting.get("embd_id") and not merged.get("tenant_embd_id"): + raise ArgumentException( + f"Tenant Model with name {merged['embd_id']} and type {LLMType.EMBEDDING.value} not found" + ) + if new_memory_setting.get("llm_id"): + update_dict["llm_id"] = merged["llm_id"] + if new_memory_setting.get("embd_id"): + update_dict["embd_id"] = merged["embd_id"] + update_dict["tenant_llm_id"] = merged["tenant_llm_id"] + update_dict["tenant_embd_id"] = merged.get("tenant_embd_id") if new_memory_setting.get("memory_type"): memory_type = set(new_memory_setting["memory_type"]) invalid_type = memory_type - {e.name.lower() for e in MemoryType} @@ -180,7 +202,6 @@ async def update_memory(memory_id: str, new_memory_setting: dict): for field in ["avatar", "description", "system_prompt", "user_prompt"]: if field in new_memory_setting: update_dict[field] = new_memory_setting[field] - current_memory = _require_memory_access(memory_id) memory_dict = current_memory.to_dict() memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py index 4765b2bdbb..1a6da3a8d6 100644 --- a/api/db/joint_services/memory_message_service.py +++ b/api/db/joint_services/memory_message_service.py @@ -154,7 +154,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, else: user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)}) if tenant_llm_id: - llm_config = get_model_config_by_id(tenant_llm_id) + llm_config = get_model_config_by_id( + tenant_llm_id, + allowed_tenant_ids=tenant_id, + requester_tenant_id=tenant_id, + ) else: llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) llm = LLMBundle(tenant_id, llm_config) @@ -174,7 +178,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, async def embed_and_save(memory, message_list: list[dict], task_id: str=None): if memory.tenant_embd_id: - embd_model_config = get_model_config_by_id(memory.tenant_embd_id) + embd_model_config = get_model_config_by_id( + memory.tenant_embd_id, + allowed_tenant_ids=memory.tenant_id, + requester_tenant_id=memory.tenant_id, + ) else: embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) embedding_model = LLMBundle(memory.tenant_id, embd_model_config) @@ -248,7 +256,11 @@ def query_message(filter_dict: dict, params: dict): question = question.strip() memory = memory_list[0] if memory.tenant_embd_id: - embd_model_config = get_model_config_by_id(memory.tenant_embd_id) + embd_model_config = get_model_config_by_id( + memory.tenant_embd_id, + allowed_tenant_ids=memory.tenant_id, + requester_tenant_id=memory.tenant_id, + ) else: embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) embd_model = LLMBundle(memory.tenant_id, embd_model_config) diff --git a/test/testcases/test_web_api/test_memory_app/test_update_memory.py b/test/testcases/test_web_api/test_memory_app/test_update_memory.py index 1fa92b8e44..72ecfaa8ec 100644 --- a/test/testcases/test_web_api/test_memory_app/test_update_memory.py +++ b/test/testcases/test_web_api/test_memory_app/test_update_memory.py @@ -106,6 +106,14 @@ class TestMemoryUpdate: assert res["code"] == 0, res assert res["data"]["llm_id"] == llm_id, res + @pytest.mark.p2 + def test_reject_direct_tenant_model_ids(self, WebApiAuth, add_memory_func): + memory_ids = add_memory_func + payload = {"tenant_llm_id": 999999, "tenant_embd_id": 999998} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 101, res + assert "Do not set tenant_llm_id or tenant_embd_id directly" in res["message"], res + @pytest.mark.p2 @pytest.mark.parametrize( "permission",