mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix: validate memory tenant model IDs on update and enforce tenant scope in memory pipeline (#14923)
### Related issues

Closes #14922

### What problem does this PR solve?

`POST /memories` already resolves `tenant_llm_id` and `tenant_embd_id` through `ensure_tenant_model_id_for_params`, but `PUT /memories/<memory_id>` accepted client-supplied `tenant_llm_id` / `tenant_embd_id` without checking that those `tenant_llm` rows belong to the memory owner’s tenant. A caller could persist another tenant’s row IDs and later trigger extraction or embedding that loaded foreign model credentials via `get_model_config_by_id(tenant_model_id)` with no tenant allow-list.

This change aligns the update path with create: updates that change models must go through `llm_id` / `embd_id` and `ensure_tenant_model_id_for_params` scoped to the **memory’s** `tenant_id` (not only the current user, so team-access cases stay correct). Direct `tenant_*` fields in the body without `llm_id` / `embd_id` are rejected.

As defense in depth, `memory_message_service` passes `allowed_tenant_ids` / `requester_tenant_id` into `get_model_config_by_id` for LLM and embedding resolution so mismatched IDs cannot be used even if bad data existed.

A regression test rejects payloads that set only `tenant_llm_id` / `tenant_embd_id`.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: jony376 <jony376@gmail.com>
This commit is contained in:
@@ -21,10 +21,11 @@ from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message
|
||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||
from api.utils.tenant_utils import ensure_tenant_model_id_for_params
|
||||
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
||||
from memory.services.messages import MessageService
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
from common.constants import MemoryType, ForgettingPolicy
|
||||
from common.constants import MemoryType, ForgettingPolicy, LLMType
|
||||
from common.exceptions import ArgumentException, NotFoundException
|
||||
from common.time_utils import current_timestamp, timestamp_to_date
|
||||
|
||||
@@ -131,6 +132,9 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
|
||||
"user_prompt": str
|
||||
}
|
||||
"""
|
||||
current_memory = _require_memory_access(memory_id)
|
||||
owner_tenant_id = current_memory.tenant_id
|
||||
|
||||
update_dict = {}
|
||||
# check name length
|
||||
if "name" in new_memory_setting:
|
||||
@@ -146,14 +150,32 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
|
||||
if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]:
|
||||
raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.")
|
||||
update_dict["permissions"] = new_memory_setting["permissions"]
|
||||
if new_memory_setting.get("llm_id"):
|
||||
update_dict["llm_id"] = new_memory_setting["llm_id"]
|
||||
if new_memory_setting.get("embd_id"):
|
||||
update_dict["embd_id"] = new_memory_setting["embd_id"]
|
||||
if new_memory_setting.get("tenant_llm_id"):
|
||||
update_dict["tenant_llm_id"] = new_memory_setting["tenant_llm_id"]
|
||||
if new_memory_setting.get("tenant_embd_id"):
|
||||
update_dict["tenant_embd_id"] = new_memory_setting["tenant_embd_id"]
|
||||
if ("tenant_llm_id" in new_memory_setting or "tenant_embd_id" in new_memory_setting) and not (
|
||||
new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id")
|
||||
):
|
||||
raise ArgumentException(
|
||||
"Do not set tenant_llm_id or tenant_embd_id directly; update llm_id and/or embd_id instead."
|
||||
)
|
||||
if new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id"):
|
||||
merged = {
|
||||
"llm_id": new_memory_setting.get("llm_id") or current_memory.llm_id,
|
||||
"embd_id": new_memory_setting.get("embd_id") or current_memory.embd_id,
|
||||
}
|
||||
merged = ensure_tenant_model_id_for_params(owner_tenant_id, merged)
|
||||
if not merged.get("tenant_llm_id"):
|
||||
raise ArgumentException(
|
||||
f"Tenant Model with name {merged['llm_id']} and type {LLMType.CHAT.value} not found"
|
||||
)
|
||||
if new_memory_setting.get("embd_id") and not merged.get("tenant_embd_id"):
|
||||
raise ArgumentException(
|
||||
f"Tenant Model with name {merged['embd_id']} and type {LLMType.EMBEDDING.value} not found"
|
||||
)
|
||||
if new_memory_setting.get("llm_id"):
|
||||
update_dict["llm_id"] = merged["llm_id"]
|
||||
if new_memory_setting.get("embd_id"):
|
||||
update_dict["embd_id"] = merged["embd_id"]
|
||||
update_dict["tenant_llm_id"] = merged["tenant_llm_id"]
|
||||
update_dict["tenant_embd_id"] = merged.get("tenant_embd_id")
|
||||
if new_memory_setting.get("memory_type"):
|
||||
memory_type = set(new_memory_setting["memory_type"])
|
||||
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
||||
@@ -180,7 +202,6 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
|
||||
for field in ["avatar", "description", "system_prompt", "user_prompt"]:
|
||||
if field in new_memory_setting:
|
||||
update_dict[field] = new_memory_setting[field]
|
||||
current_memory = _require_memory_access(memory_id)
|
||||
|
||||
memory_dict = current_memory.to_dict()
|
||||
memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
|
||||
|
||||
@@ -154,7 +154,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict,
|
||||
else:
|
||||
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
|
||||
if tenant_llm_id:
|
||||
llm_config = get_model_config_by_id(tenant_llm_id)
|
||||
llm_config = get_model_config_by_id(
|
||||
tenant_llm_id,
|
||||
allowed_tenant_ids=tenant_id,
|
||||
requester_tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id)
|
||||
llm = LLMBundle(tenant_id, llm_config)
|
||||
@@ -174,7 +178,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict,
|
||||
|
||||
async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
|
||||
if memory.tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(memory.tenant_embd_id)
|
||||
embd_model_config = get_model_config_by_id(
|
||||
memory.tenant_embd_id,
|
||||
allowed_tenant_ids=memory.tenant_id,
|
||||
requester_tenant_id=memory.tenant_id,
|
||||
)
|
||||
else:
|
||||
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
|
||||
embedding_model = LLMBundle(memory.tenant_id, embd_model_config)
|
||||
@@ -248,7 +256,11 @@ def query_message(filter_dict: dict, params: dict):
|
||||
question = question.strip()
|
||||
memory = memory_list[0]
|
||||
if memory.tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(memory.tenant_embd_id)
|
||||
embd_model_config = get_model_config_by_id(
|
||||
memory.tenant_embd_id,
|
||||
allowed_tenant_ids=memory.tenant_id,
|
||||
requester_tenant_id=memory.tenant_id,
|
||||
)
|
||||
else:
|
||||
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
|
||||
embd_model = LLMBundle(memory.tenant_id, embd_model_config)
|
||||
|
||||
@@ -106,6 +106,14 @@ class TestMemoryUpdate:
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["llm_id"] == llm_id, res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_reject_direct_tenant_model_ids(self, WebApiAuth, add_memory_func):
|
||||
memory_ids = add_memory_func
|
||||
payload = {"tenant_llm_id": 999999, "tenant_embd_id": 999998}
|
||||
res = update_memory(WebApiAuth, memory_ids[0], payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Do not set tenant_llm_id or tenant_embd_id directly" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"permission",
|
||||
|
||||
Reference in New Issue
Block a user