diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index cf297c4b25..a71b901617 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -492,7 +492,12 @@ async def retrieval_test(tenant_id): rerank_mdl = None if req.get("tenant_rerank_id"): - rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) + allowed_rerank_tenant_ids = {tenant_id, *[dataset.tenant_id for dataset in kbs]} + rerank_model_config = get_model_config_by_id( + req["tenant_rerank_id"], + allowed_tenant_ids=allowed_rerank_tenant_ids, + requester_tenant_id=tenant_id, + ) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) elif req.get("rerank_id"): rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK, req["rerank_id"]) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 815fe79e35..7ba6fbd81d 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -467,7 +467,13 @@ async def retrieval_test_embedded(): rerank_mdl = None if tenant_rerank_id: - rerank_model_config = await thread_pool_exec(get_model_config_by_id, tenant_rerank_id) + allowed_rerank_tenant_ids = {tenant_id, *tenant_ids} + rerank_model_config = await thread_pool_exec( + get_model_config_by_id, + tenant_rerank_id, + allowed_rerank_tenant_ids, + tenant_id, + ) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) elif rerank_id: rerank_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.RERANK, rerank_id) diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py index d2b4497da8..74b081add3 100644 --- a/api/apps/services/dataset_api_service.py +++ b/api/apps/services/dataset_api_service.py @@ -1010,7 +1010,12 @@ async def search(dataset_id: str, tenant_id: str, req: dict): rerank_mdl = None if req.get("tenant_rerank_id"): - rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) + allowed_rerank_tenant_ids = {tenant_id, kb.tenant_id} + rerank_model_config = get_model_config_by_id( + req["tenant_rerank_id"], + allowed_tenant_ids=allowed_rerank_tenant_ids, + requester_tenant_id=tenant_id, + ) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) elif req.get("rerank_id"): rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) @@ -1372,7 +1377,12 @@ async def search_datasets(tenant_id: str, req: dict): rerank_mdl = None if req.get("tenant_rerank_id"): - rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) + allowed_rerank_tenant_ids = {tenant_id, *[dataset.tenant_id for dataset in kbs]} + rerank_model_config = get_model_config_by_id( + req["tenant_rerank_id"], + allowed_tenant_ids=allowed_rerank_tenant_ids, + requester_tenant_id=tenant_id, + ) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) elif req.get("rerank_id"): rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) diff --git a/api/db/joint_services/tenant_model_service.py b/api/db/joint_services/tenant_model_service.py index 645d756381..677bfcaaaf 100644 --- a/api/db/joint_services/tenant_model_service.py +++ b/api/db/joint_services/tenant_model_service.py @@ -24,10 +24,29 @@ from api.db.services.tenant_llm_service import TenantLLMService, TenantService logger = logging.getLogger(__name__) -def get_model_config_by_id(tenant_model_id: int) -> dict: +def get_model_config_by_id( + tenant_model_id: int, + allowed_tenant_ids: str | list[str] | set[str] | tuple[str, ...] | None = None, + requester_tenant_id: str | None = None, +) -> dict: found, model_config = TenantLLMService.get_by_id(tenant_model_id) if not found: raise LookupError(f"Tenant Model with id {tenant_model_id} not found") + if allowed_tenant_ids is not None: + if isinstance(allowed_tenant_ids, str): + allowed_tenant_ids = {allowed_tenant_ids} + else: + allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} + if str(model_config.tenant_id) not in allowed_tenant_ids: + logger.warning( + "Denied tenant model access: tenant_model_id=%s model_tenant_id=%s " + "allowed_tenant_ids=%s requester_tenant_id=%s", + tenant_model_id, + model_config.tenant_id, + sorted(allowed_tenant_ids), + requester_tenant_id, + ) + raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") config_dict = model_config.to_dict() api_key, is_tools, api_key_payload = TenantLLMService._decode_api_key_config(config_dict.get("api_key", "")) config_dict["api_key"] = api_key diff --git a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py index 8234866e82..6f4927b8d0 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py +++ b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py @@ -223,8 +223,20 @@ def _load_dify_retrieval_module(monkeypatch): "id": self.id } - def _get_model_config_by_id(tenant_model_id: int) -> dict: - return _MockModelConfig2("tenant-1", "model-1").to_dict() + def _get_model_config_by_id( + tenant_model_id: int, + allowed_tenant_ids=None, + requester_tenant_id=None, + ) -> dict: + mock_tenant_id = "tenant-1" + if allowed_tenant_ids is not None: + if isinstance(allowed_tenant_ids, str): + allowed_tenant_ids = {allowed_tenant_ids} + else: + allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} + if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id: + raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") + return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): if not model_name: diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py index b4ee851745..08055a57e6 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py @@ -417,8 +417,20 @@ def _load_doc_module(monkeypatch): "id": self.id } - def _get_model_config_by_id(tenant_model_id: int) -> dict: - return _MockModelConfig2("tenant-1", "model-1").to_dict() + def _get_model_config_by_id( + tenant_model_id: int, + allowed_tenant_ids=None, + requester_tenant_id=None, + ) -> dict: + mock_tenant_id = "tenant-1" + if allowed_tenant_ids is not None: + if isinstance(allowed_tenant_ids, str): + allowed_tenant_ids = {allowed_tenant_ids} + else: + allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} + if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id: + raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") + return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): if not model_name: diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index 77ac86232b..773660fdd4 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -466,8 +466,20 @@ def _load_session_module(monkeypatch): "id": self.id } - def _get_model_config_by_id(tenant_model_id: int) -> dict: - return _MockModelConfig2("tenant-1", "model-1").to_dict() + def _get_model_config_by_id( + tenant_model_id: int, + allowed_tenant_ids=None, + requester_tenant_id=None, + ) -> dict: + mock_tenant_id = "tenant-1" + if allowed_tenant_ids is not None: + if isinstance(allowed_tenant_ids, str): + allowed_tenant_ids = {allowed_tenant_ids} + else: + allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} + if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id: + raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") + return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): if not model_name: