diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index 9040f0ce44..53cff623ce 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -21,10 +21,11 @@ from api.db.services.canvas_service import UserCanvasService from api.db.services.task_service import TaskService from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human +from api.utils.tenant_utils import ensure_tenant_model_id_for_params from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT from memory.services.messages import MessageService from memory.utils.prompt_util import PromptAssembler -from common.constants import MemoryType, ForgettingPolicy +from common.constants import MemoryType, ForgettingPolicy, LLMType from common.exceptions import ArgumentException, NotFoundException from common.time_utils import current_timestamp, timestamp_to_date @@ -131,6 +132,9 @@ async def update_memory(memory_id: str, new_memory_setting: dict): "user_prompt": str } """ + current_memory = _require_memory_access(memory_id) + owner_tenant_id = current_memory.tenant_id + update_dict = {} # check name length if "name" in new_memory_setting: @@ -146,14 +150,32 @@ async def update_memory(memory_id: str, new_memory_setting: dict): if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]: raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.") update_dict["permissions"] = new_memory_setting["permissions"] - if new_memory_setting.get("llm_id"): - update_dict["llm_id"] = new_memory_setting["llm_id"] - if new_memory_setting.get("embd_id"): - update_dict["embd_id"] = new_memory_setting["embd_id"] - if new_memory_setting.get("tenant_llm_id"): - update_dict["tenant_llm_id"] = new_memory_setting["tenant_llm_id"] - if new_memory_setting.get("tenant_embd_id"): - update_dict["tenant_embd_id"] = new_memory_setting["tenant_embd_id"] + if ("tenant_llm_id" in new_memory_setting or "tenant_embd_id" in new_memory_setting) and not ( + new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id") + ): + raise ArgumentException( + "Do not set tenant_llm_id or tenant_embd_id directly; update llm_id and/or embd_id instead." + ) + if new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id"): + merged = { + "llm_id": new_memory_setting.get("llm_id") or current_memory.llm_id, + "embd_id": new_memory_setting.get("embd_id") or current_memory.embd_id, + } + merged = ensure_tenant_model_id_for_params(owner_tenant_id, merged) + if not merged.get("tenant_llm_id"): + raise ArgumentException( + f"Tenant Model with name {merged['llm_id']} and type {LLMType.CHAT.value} not found" + ) + if new_memory_setting.get("embd_id") and not merged.get("tenant_embd_id"): + raise ArgumentException( + f"Tenant Model with name {merged['embd_id']} and type {LLMType.EMBEDDING.value} not found" + ) + if new_memory_setting.get("llm_id"): + update_dict["llm_id"] = merged["llm_id"] + if new_memory_setting.get("embd_id"): + update_dict["embd_id"] = merged["embd_id"] + update_dict["tenant_llm_id"] = merged["tenant_llm_id"] + update_dict["tenant_embd_id"] = merged.get("tenant_embd_id") if new_memory_setting.get("memory_type"): memory_type = set(new_memory_setting["memory_type"]) invalid_type = memory_type - {e.name.lower() for e in MemoryType} @@ -180,7 +202,6 @@ async def update_memory(memory_id: str, new_memory_setting: dict): for field in ["avatar", "description", "system_prompt", "user_prompt"]: if field in new_memory_setting: update_dict[field] = new_memory_setting[field] - current_memory = _require_memory_access(memory_id) memory_dict = current_memory.to_dict() memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py index 4765b2bdbb..1a6da3a8d6 100644 --- a/api/db/joint_services/memory_message_service.py +++ b/api/db/joint_services/memory_message_service.py @@ -154,7 +154,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, else: user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)}) if tenant_llm_id: - llm_config = get_model_config_by_id(tenant_llm_id) + llm_config = get_model_config_by_id( + tenant_llm_id, + allowed_tenant_ids=tenant_id, + requester_tenant_id=tenant_id, + ) else: llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) llm = LLMBundle(tenant_id, llm_config) @@ -174,7 +178,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, async def embed_and_save(memory, message_list: list[dict], task_id: str=None): if memory.tenant_embd_id: - embd_model_config = get_model_config_by_id(memory.tenant_embd_id) + embd_model_config = get_model_config_by_id( + memory.tenant_embd_id, + allowed_tenant_ids=memory.tenant_id, + requester_tenant_id=memory.tenant_id, + ) else: embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) embedding_model = LLMBundle(memory.tenant_id, embd_model_config) @@ -248,7 +256,11 @@ def query_message(filter_dict: dict, params: dict): question = question.strip() memory = memory_list[0] if memory.tenant_embd_id: - embd_model_config = get_model_config_by_id(memory.tenant_embd_id) + embd_model_config = get_model_config_by_id( + memory.tenant_embd_id, + allowed_tenant_ids=memory.tenant_id, + requester_tenant_id=memory.tenant_id, + ) else: embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) embd_model = LLMBundle(memory.tenant_id, embd_model_config) diff --git a/test/testcases/test_web_api/test_memory_app/test_update_memory.py b/test/testcases/test_web_api/test_memory_app/test_update_memory.py index 1fa92b8e44..72ecfaa8ec 100644 --- a/test/testcases/test_web_api/test_memory_app/test_update_memory.py +++ b/test/testcases/test_web_api/test_memory_app/test_update_memory.py @@ -106,6 +106,14 @@ class TestMemoryUpdate: assert res["code"] == 0, res assert res["data"]["llm_id"] == llm_id, res + @pytest.mark.p2 + def test_reject_direct_tenant_model_ids(self, WebApiAuth, add_memory_func): + memory_ids = add_memory_func + payload = {"tenant_llm_id": 999999, "tenant_embd_id": 999998} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 101, res + assert "Do not set tenant_llm_id or tenant_embd_id directly" in res["message"], res + @pytest.mark.p2 @pytest.mark.parametrize( "permission",