mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix: validate memory tenant model IDs on update and enforce tenant scope in memory pipeline (#14923)
### Related issues

Closes #14922

### What problem does this PR solve?

`POST /memories` already resolves `tenant_llm_id` and `tenant_embd_id` through `ensure_tenant_model_id_for_params`, but `PUT /memories/<memory_id>` accepted client-supplied `tenant_llm_id` / `tenant_embd_id` without checking that those `tenant_llm` rows belong to the memory owner’s tenant. A caller could persist another tenant’s row IDs and later trigger extraction or embedding that loaded foreign model credentials via `get_model_config_by_id(tenant_model_id)` with no tenant allow-list.

This change aligns the update path with create: updates that change models must go through `llm_id` / `embd_id` and `ensure_tenant_model_id_for_params` scoped to the **memory’s** `tenant_id` (not only the current user, so team-access cases stay correct). Direct `tenant_*` fields in the body without `llm_id` / `embd_id` are rejected.

As defense in depth, `memory_message_service` passes `allowed_tenant_ids` / `requester_tenant_id` into `get_model_config_by_id` for LLM and embedding resolution so mismatched IDs cannot be used even if bad data existed. A regression test rejects payloads that set only `tenant_llm_id` / `tenant_embd_id`.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: jony376 <jony376@gmail.com>
This commit is contained in:
@@ -21,10 +21,11 @@ from api.db.services.canvas_service import UserCanvasService
|
|||||||
from api.db.services.task_service import TaskService
|
from api.db.services.task_service import TaskService
|
||||||
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message
|
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message
|
||||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||||
|
from api.utils.tenant_utils import ensure_tenant_model_id_for_params
|
||||||
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
||||||
from memory.services.messages import MessageService
|
from memory.services.messages import MessageService
|
||||||
from memory.utils.prompt_util import PromptAssembler
|
from memory.utils.prompt_util import PromptAssembler
|
||||||
from common.constants import MemoryType, ForgettingPolicy
|
from common.constants import MemoryType, ForgettingPolicy, LLMType
|
||||||
from common.exceptions import ArgumentException, NotFoundException
|
from common.exceptions import ArgumentException, NotFoundException
|
||||||
from common.time_utils import current_timestamp, timestamp_to_date
|
from common.time_utils import current_timestamp, timestamp_to_date
|
||||||
|
|
||||||
@@ -131,6 +132,9 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
|
|||||||
"user_prompt": str
|
"user_prompt": str
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
current_memory = _require_memory_access(memory_id)
|
||||||
|
owner_tenant_id = current_memory.tenant_id
|
||||||
|
|
||||||
update_dict = {}
|
update_dict = {}
|
||||||
# check name length
|
# check name length
|
||||||
if "name" in new_memory_setting:
|
if "name" in new_memory_setting:
|
||||||
@@ -146,14 +150,32 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
|
|||||||
if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]:
|
if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]:
|
||||||
raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.")
|
raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.")
|
||||||
update_dict["permissions"] = new_memory_setting["permissions"]
|
update_dict["permissions"] = new_memory_setting["permissions"]
|
||||||
if new_memory_setting.get("llm_id"):
|
if ("tenant_llm_id" in new_memory_setting or "tenant_embd_id" in new_memory_setting) and not (
|
||||||
update_dict["llm_id"] = new_memory_setting["llm_id"]
|
new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id")
|
||||||
if new_memory_setting.get("embd_id"):
|
):
|
||||||
update_dict["embd_id"] = new_memory_setting["embd_id"]
|
raise ArgumentException(
|
||||||
if new_memory_setting.get("tenant_llm_id"):
|
"Do not set tenant_llm_id or tenant_embd_id directly; update llm_id and/or embd_id instead."
|
||||||
update_dict["tenant_llm_id"] = new_memory_setting["tenant_llm_id"]
|
)
|
||||||
if new_memory_setting.get("tenant_embd_id"):
|
if new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id"):
|
||||||
update_dict["tenant_embd_id"] = new_memory_setting["tenant_embd_id"]
|
merged = {
|
||||||
|
"llm_id": new_memory_setting.get("llm_id") or current_memory.llm_id,
|
||||||
|
"embd_id": new_memory_setting.get("embd_id") or current_memory.embd_id,
|
||||||
|
}
|
||||||
|
merged = ensure_tenant_model_id_for_params(owner_tenant_id, merged)
|
||||||
|
if not merged.get("tenant_llm_id"):
|
||||||
|
raise ArgumentException(
|
||||||
|
f"Tenant Model with name {merged['llm_id']} and type {LLMType.CHAT.value} not found"
|
||||||
|
)
|
||||||
|
if new_memory_setting.get("embd_id") and not merged.get("tenant_embd_id"):
|
||||||
|
raise ArgumentException(
|
||||||
|
f"Tenant Model with name {merged['embd_id']} and type {LLMType.EMBEDDING.value} not found"
|
||||||
|
)
|
||||||
|
if new_memory_setting.get("llm_id"):
|
||||||
|
update_dict["llm_id"] = merged["llm_id"]
|
||||||
|
if new_memory_setting.get("embd_id"):
|
||||||
|
update_dict["embd_id"] = merged["embd_id"]
|
||||||
|
update_dict["tenant_llm_id"] = merged["tenant_llm_id"]
|
||||||
|
update_dict["tenant_embd_id"] = merged.get("tenant_embd_id")
|
||||||
if new_memory_setting.get("memory_type"):
|
if new_memory_setting.get("memory_type"):
|
||||||
memory_type = set(new_memory_setting["memory_type"])
|
memory_type = set(new_memory_setting["memory_type"])
|
||||||
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
||||||
@@ -180,7 +202,6 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
|
|||||||
for field in ["avatar", "description", "system_prompt", "user_prompt"]:
|
for field in ["avatar", "description", "system_prompt", "user_prompt"]:
|
||||||
if field in new_memory_setting:
|
if field in new_memory_setting:
|
||||||
update_dict[field] = new_memory_setting[field]
|
update_dict[field] = new_memory_setting[field]
|
||||||
current_memory = _require_memory_access(memory_id)
|
|
||||||
|
|
||||||
memory_dict = current_memory.to_dict()
|
memory_dict = current_memory.to_dict()
|
||||||
memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
|
memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
|
||||||
|
|||||||
@@ -154,7 +154,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict,
|
|||||||
else:
|
else:
|
||||||
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
|
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
|
||||||
if tenant_llm_id:
|
if tenant_llm_id:
|
||||||
llm_config = get_model_config_by_id(tenant_llm_id)
|
llm_config = get_model_config_by_id(
|
||||||
|
tenant_llm_id,
|
||||||
|
allowed_tenant_ids=tenant_id,
|
||||||
|
requester_tenant_id=tenant_id,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id)
|
llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id)
|
||||||
llm = LLMBundle(tenant_id, llm_config)
|
llm = LLMBundle(tenant_id, llm_config)
|
||||||
@@ -174,7 +178,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict,
|
|||||||
|
|
||||||
async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
|
async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
|
||||||
if memory.tenant_embd_id:
|
if memory.tenant_embd_id:
|
||||||
embd_model_config = get_model_config_by_id(memory.tenant_embd_id)
|
embd_model_config = get_model_config_by_id(
|
||||||
|
memory.tenant_embd_id,
|
||||||
|
allowed_tenant_ids=memory.tenant_id,
|
||||||
|
requester_tenant_id=memory.tenant_id,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
|
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
|
||||||
embedding_model = LLMBundle(memory.tenant_id, embd_model_config)
|
embedding_model = LLMBundle(memory.tenant_id, embd_model_config)
|
||||||
@@ -248,7 +256,11 @@ def query_message(filter_dict: dict, params: dict):
|
|||||||
question = question.strip()
|
question = question.strip()
|
||||||
memory = memory_list[0]
|
memory = memory_list[0]
|
||||||
if memory.tenant_embd_id:
|
if memory.tenant_embd_id:
|
||||||
embd_model_config = get_model_config_by_id(memory.tenant_embd_id)
|
embd_model_config = get_model_config_by_id(
|
||||||
|
memory.tenant_embd_id,
|
||||||
|
allowed_tenant_ids=memory.tenant_id,
|
||||||
|
requester_tenant_id=memory.tenant_id,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
|
embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id)
|
||||||
embd_model = LLMBundle(memory.tenant_id, embd_model_config)
|
embd_model = LLMBundle(memory.tenant_id, embd_model_config)
|
||||||
|
|||||||
@@ -106,6 +106,14 @@ class TestMemoryUpdate:
|
|||||||
assert res["code"] == 0, res
|
assert res["code"] == 0, res
|
||||||
assert res["data"]["llm_id"] == llm_id, res
|
assert res["data"]["llm_id"] == llm_id, res
|
||||||
|
|
||||||
|
@pytest.mark.p2
|
||||||
|
def test_reject_direct_tenant_model_ids(self, WebApiAuth, add_memory_func):
|
||||||
|
memory_ids = add_memory_func
|
||||||
|
payload = {"tenant_llm_id": 999999, "tenant_embd_id": 999998}
|
||||||
|
res = update_memory(WebApiAuth, memory_ids[0], payload)
|
||||||
|
assert res["code"] == 101, res
|
||||||
|
assert "Do not set tenant_llm_id or tenant_embd_id directly" in res["message"], res
|
||||||
|
|
||||||
@pytest.mark.p2
|
@pytest.mark.p2
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"permission",
|
"permission",
|
||||||
|
|||||||
Reference in New Issue
Block a user