From 6fdca2d2125edcf9bc91eeb26975314483e9e892 Mon Sep 17 00:00:00 2001 From: Tong Liu <57178900+Lyutoon@users.noreply.github.com> Date: Mon, 13 Apr 2026 19:24:13 +0800 Subject: [PATCH] [Security] Fix jinja2 SSTI vulnerability using SandboxedEnvironment (#14068) --- rag/prompts/generator.py | 6 +- .../rag/prompts/test_generator_sandbox.py | 66 +++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 test/unit_test/rag/prompts/test_generator_sandbox.py diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index e363fe180c..47c0b9f2ba 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -20,7 +20,7 @@ import logging import re from copy import deepcopy from typing import Tuple -import jinja2 +from jinja2.sandbox import SandboxedEnvironment import json_repair from common.misc_utils import hash_str2int from rag.nlp import rag_tokenizer @@ -183,7 +183,9 @@ RANK_MEMORY = load_prompt("rank_memory") META_FILTER = load_prompt("meta_filter") ASK_SUMMARY = load_prompt("ask_summary") -PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True) +PROMPT_JINJA_ENV = SandboxedEnvironment( + autoescape=False, trim_blocks=True, lstrip_blocks=True +) def citation_prompt(user_defined_prompts: dict = {}) -> str: diff --git a/test/unit_test/rag/prompts/test_generator_sandbox.py b/test/unit_test/rag/prompts/test_generator_sandbox.py new file mode 100644 index 0000000000..55095788b0 --- /dev/null +++ b/test/unit_test/rag/prompts/test_generator_sandbox.py @@ -0,0 +1,66 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from jinja2.exceptions import SecurityError, UndefinedError +from jinja2.sandbox import SandboxedEnvironment + +from rag.prompts.generator import PROMPT_JINJA_ENV + + +@pytest.mark.p1 +class TestJinjaSandbox: + """Test that PROMPT_JINJA_ENV uses SandboxedEnvironment to prevent SSTI attacks.""" + + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload", + [ + # Classic SSTI payloads targeting __globals__, __mro__, __subclasses__ + "{{ self.__class__.__mro__[1].__subclasses__() }}", + "{{ ''.__class__.__mro__[1].__subclasses__() }}", + "{{ request.__class__.__mro__[1].__subclasses__() }}", + # Attribute traversal (no hardcoded subclass index) + "{{ config.__class__.__init__.__globals__['os'] }}", + ], + ) + def test_ssti_payload_blocked(self, payload): + """Verify that SSTI payloads are blocked by SandboxedEnvironment.""" + assert isinstance(PROMPT_JINJA_ENV, SandboxedEnvironment), ( + "PROMPT_JINJA_ENV must use SandboxedEnvironment to prevent SSTI" + ) + template = PROMPT_JINJA_ENV.from_string(payload) + # SandboxedEnvironment raises SecurityError, AttributeError, or UndefinedError to block SSTI attacks + with pytest.raises((SecurityError, AttributeError, UndefinedError)) as exc_info: + template.render() + # Verify exception contains sandbox indicators + exc_msg = str(exc_info.value) + assert any(x in exc_msg.lower() for x in ["unsafe", "security", "__mro__"]) + + @pytest.mark.p1 + def test_safe_template_rendering(self): + """Verify that benign templates still render correctly.""" + template = PROMPT_JINJA_ENV.from_string("Hello, {{ name }}!") + result = template.render(name="World") + assert result == "Hello, World!" + + @pytest.mark.p1 + def test_loop_and_conditional_rendering(self): + """Verify control flow templates work properly.""" + template = PROMPT_JINJA_ENV.from_string( + "{% for item in items %}{{ item }}{% endfor %}" + ) + result = template.render(items=["a", "b", "c"]) + assert result == "abc"