Fix: validate memory tenant model IDs on update and enforce tenant scope in memory pipeline (#14923)

### Related issues

Closes #14922

### What problem does this PR solve?

`POST /memories` already resolves `tenant_llm_id` and `tenant_embd_id`
through `ensure_tenant_model_id_for_params`, but `PUT
/memories/<memory_id>` accepted client-supplied `tenant_llm_id` /
`tenant_embd_id` without checking that those `tenant_llm` rows belong to
the memory owner’s tenant. A caller could persist another tenant’s row
IDs and later trigger extraction or embedding that loaded foreign model
credentials via `get_model_config_by_id(tenant_model_id)` with no tenant
allow-list.

This change aligns the update path with create: updates that change
models must go through `llm_id` / `embd_id` and
`ensure_tenant_model_id_for_params` scoped to the **memory’s**
`tenant_id` (not only the current user, so team-access cases stay
correct). Direct `tenant_*` fields in the body without `llm_id` /
`embd_id` are rejected. As defense in depth, `memory_message_service`
passes `allowed_tenant_ids` / `requester_tenant_id` into
`get_model_config_by_id` for LLM and embedding resolution so mismatched
IDs cannot be used even if bad data existed. A regression test rejects
payloads that set only `tenant_llm_id` / `tenant_embd_id`.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: jony376 <jony376@gmail.com>
This commit is contained in:
jony376
2026-05-19 05:11:46 +03:00
committed by GitHub
parent b69a6a5d80
commit 198f3c4b9a
3 changed files with 54 additions and 13 deletions

View File

@@ -21,10 +21,11 @@ from api.db.services.canvas_service import UserCanvasService
from api.db.services.task_service import TaskService from api.db.services.task_service import TaskService
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
from api.utils.tenant_utils import ensure_tenant_model_id_for_params
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
from memory.services.messages import MessageService from memory.services.messages import MessageService
from memory.utils.prompt_util import PromptAssembler from memory.utils.prompt_util import PromptAssembler
from common.constants import MemoryType, ForgettingPolicy from common.constants import MemoryType, ForgettingPolicy, LLMType
from common.exceptions import ArgumentException, NotFoundException from common.exceptions import ArgumentException, NotFoundException
from common.time_utils import current_timestamp, timestamp_to_date from common.time_utils import current_timestamp, timestamp_to_date
@@ -131,6 +132,9 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
"user_prompt": str "user_prompt": str
} }
""" """
current_memory = _require_memory_access(memory_id)
owner_tenant_id = current_memory.tenant_id
update_dict = {} update_dict = {}
# check name length # check name length
if "name" in new_memory_setting: if "name" in new_memory_setting:
@@ -146,14 +150,32 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]: if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]:
raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.") raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.")
update_dict["permissions"] = new_memory_setting["permissions"] update_dict["permissions"] = new_memory_setting["permissions"]
if new_memory_setting.get("llm_id"): if ("tenant_llm_id" in new_memory_setting or "tenant_embd_id" in new_memory_setting) and not (
update_dict["llm_id"] = new_memory_setting["llm_id"] new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id")
if new_memory_setting.get("embd_id"): ):
update_dict["embd_id"] = new_memory_setting["embd_id"] raise ArgumentException(
if new_memory_setting.get("tenant_llm_id"): "Do not set tenant_llm_id or tenant_embd_id directly; update llm_id and/or embd_id instead."
update_dict["tenant_llm_id"] = new_memory_setting["tenant_llm_id"] )
if new_memory_setting.get("tenant_embd_id"): if new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id"):
update_dict["tenant_embd_id"] = new_memory_setting["tenant_embd_id"] merged = {
"llm_id": new_memory_setting.get("llm_id") or current_memory.llm_id,
"embd_id": new_memory_setting.get("embd_id") or current_memory.embd_id,
}
merged = ensure_tenant_model_id_for_params(owner_tenant_id, merged)
if not merged.get("tenant_llm_id"):
raise ArgumentException(
f"Tenant Model with name {merged['llm_id']} and type {LLMType.CHAT.value} not found"
)
if new_memory_setting.get("embd_id") and not merged.get("tenant_embd_id"):
raise ArgumentException(
f"Tenant Model with name {merged['embd_id']} and type {LLMType.EMBEDDING.value} not found"
)
if new_memory_setting.get("llm_id"):
update_dict["llm_id"] = merged["llm_id"]
if new_memory_setting.get("embd_id"):
update_dict["embd_id"] = merged["embd_id"]
update_dict["tenant_llm_id"] = merged["tenant_llm_id"]
update_dict["tenant_embd_id"] = merged.get("tenant_embd_id")
if new_memory_setting.get("memory_type"): if new_memory_setting.get("memory_type"):
memory_type = set(new_memory_setting["memory_type"]) memory_type = set(new_memory_setting["memory_type"])
invalid_type = memory_type - {e.name.lower() for e in MemoryType} invalid_type = memory_type - {e.name.lower() for e in MemoryType}
@@ -180,7 +202,6 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
for field in ["avatar", "description", "system_prompt", "user_prompt"]: for field in ["avatar", "description", "system_prompt", "user_prompt"]:
if field in new_memory_setting: if field in new_memory_setting:
update_dict[field] = new_memory_setting[field] update_dict[field] = new_memory_setting[field]
current_memory = _require_memory_access(memory_id)
memory_dict = current_memory.to_dict() memory_dict = current_memory.to_dict()
memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})

View File

@@ -154,7 +154,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict,
else: else:
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)}) user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
if tenant_llm_id: if tenant_llm_id:
llm_config = get_model_config_by_id(tenant_llm_id) llm_config = get_model_config_by_id(
tenant_llm_id,
allowed_tenant_ids=tenant_id,
requester_tenant_id=tenant_id,
)
else: else:
llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id)
llm = LLMBundle(tenant_id, llm_config) llm = LLMBundle(tenant_id, llm_config)
@@ -174,7 +178,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict,
async def embed_and_save(memory, message_list: list[dict], task_id: str=None): async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
if memory.tenant_embd_id: if memory.tenant_embd_id:
embd_model_config = get_model_config_by_id(memory.tenant_embd_id) embd_model_config = get_model_config_by_id(
memory.tenant_embd_id,
allowed_tenant_ids=memory.tenant_id,
requester_tenant_id=memory.tenant_id,
)
else: else:
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
embedding_model = LLMBundle(memory.tenant_id, embd_model_config) embedding_model = LLMBundle(memory.tenant_id, embd_model_config)
@@ -248,7 +256,11 @@ def query_message(filter_dict: dict, params: dict):
question = question.strip() question = question.strip()
memory = memory_list[0] memory = memory_list[0]
if memory.tenant_embd_id: if memory.tenant_embd_id:
embd_model_config = get_model_config_by_id(memory.tenant_embd_id) embd_model_config = get_model_config_by_id(
memory.tenant_embd_id,
allowed_tenant_ids=memory.tenant_id,
requester_tenant_id=memory.tenant_id,
)
else: else:
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
embd_model = LLMBundle(memory.tenant_id, embd_model_config) embd_model = LLMBundle(memory.tenant_id, embd_model_config)

View File

@@ -106,6 +106,14 @@ class TestMemoryUpdate:
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"]["llm_id"] == llm_id, res assert res["data"]["llm_id"] == llm_id, res
@pytest.mark.p2
def test_reject_direct_tenant_model_ids(self, WebApiAuth, add_memory_func):
memory_ids = add_memory_func
payload = {"tenant_llm_id": 999999, "tenant_embd_id": 999998}
res = update_memory(WebApiAuth, memory_ids[0], payload)
assert res["code"] == 101, res
assert "Do not set tenant_llm_id or tenant_embd_id directly" in res["message"], res
@pytest.mark.p2 @pytest.mark.p2
@pytest.mark.parametrize( @pytest.mark.parametrize(
"permission", "permission",