mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix tag_feas code injection in retrieval ranking (#13923)
## Summary - remove eval-based parsing from retrieval rank feature scoring - validate `tag_feas` at write time in chunk APIs and SDK routes - add regression tests for safe parsing and malicious payload rejection ## Details `tag_feas` is intended to be structured rank-feature data, but the retrieval ranking path was evaluating stored values as Python expressions. This change treats `tag_feas` strictly as data. ### What changed - replace `eval()` in `rag/nlp/search.py` with safe parsing via `json.loads()` and optional `ast.literal_eval()` compatibility for legacy Python-dict strings - strictly filter parsed values down to `dict[str, finite number]` - reject invalid `tag_feas` payloads at write time in web chunk routes and SDK document chunk routes - add focused regression tests to prove executable strings are ignored and invalid payloads are rejected ## Validation - `python -m pytest test/unit_test/common/test_tag_feature_utils.py test/unit_test/rag/test_rank_feature_scores.py -q` --------- Co-authored-by: unknown <zhenglinkai@CCN.Local> Co-authored-by: Yingfeng Zhang <yingfeng.zhang@gmail.com>
This commit is contained in:
@@ -38,6 +38,7 @@ from api.utils.api_utils import (
|
|||||||
get_request_json,
|
get_request_json,
|
||||||
)
|
)
|
||||||
from common.misc_utils import thread_pool_exec
|
from common.misc_utils import thread_pool_exec
|
||||||
|
from common.tag_feature_utils import validate_tag_features
|
||||||
from rag.app.qa import beAdoc, rmPrefix
|
from rag.app.qa import beAdoc, rmPrefix
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp import rag_tokenizer, search
|
from rag.nlp import rag_tokenizer, search
|
||||||
@@ -161,7 +162,10 @@ async def set():
|
|||||||
return get_data_error_result(message="`tag_kwd` must be a list of strings")
|
return get_data_error_result(message="`tag_kwd` must be a list of strings")
|
||||||
d["tag_kwd"] = req["tag_kwd"]
|
d["tag_kwd"] = req["tag_kwd"]
|
||||||
if "tag_feas" in req:
|
if "tag_feas" in req:
|
||||||
d["tag_feas"] = req["tag_feas"]
|
try:
|
||||||
|
d["tag_feas"] = validate_tag_features(req["tag_feas"])
|
||||||
|
except ValueError as exc:
|
||||||
|
return get_data_error_result(message=f"`tag_feas` {exc}")
|
||||||
if "available_int" in req:
|
if "available_int" in req:
|
||||||
d["available_int"] = req["available_int"]
|
d["available_int"] = req["available_int"]
|
||||||
|
|
||||||
@@ -328,7 +332,10 @@ async def create():
|
|||||||
return get_data_error_result(message="`tag_kwd` must be a list of strings")
|
return get_data_error_result(message="`tag_kwd` must be a list of strings")
|
||||||
d["tag_kwd"] = req["tag_kwd"]
|
d["tag_kwd"] = req["tag_kwd"]
|
||||||
if "tag_feas" in req:
|
if "tag_feas" in req:
|
||||||
d["tag_feas"] = req["tag_feas"]
|
try:
|
||||||
|
d["tag_feas"] = validate_tag_features(req["tag_feas"])
|
||||||
|
except ValueError as exc:
|
||||||
|
return get_data_error_result(message=f"`tag_feas` {exc}")
|
||||||
image_base64 = req.get("image_base64", None)
|
image_base64 = req.get("image_base64", None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from common.constants import FileSource, LLMType, ParserType, RetCode, TaskStatu
|
|||||||
from common.metadata_utils import convert_conditions, meta_filter
|
from common.metadata_utils import convert_conditions, meta_filter
|
||||||
from common.misc_utils import thread_pool_exec
|
from common.misc_utils import thread_pool_exec
|
||||||
from common.string_utils import remove_redundant_spaces
|
from common.string_utils import remove_redundant_spaces
|
||||||
|
from common.tag_feature_utils import validate_tag_features
|
||||||
from rag.app.qa import beAdoc, rmPrefix
|
from rag.app.qa import beAdoc, rmPrefix
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp import rag_tokenizer, search
|
from rag.nlp import rag_tokenizer, search
|
||||||
@@ -963,7 +964,10 @@ async def add_chunk(tenant_id, dataset_id, document_id):
|
|||||||
return get_error_data_result("`tag_kwd` must be a list of strings")
|
return get_error_data_result("`tag_kwd` must be a list of strings")
|
||||||
d["tag_kwd"] = req["tag_kwd"]
|
d["tag_kwd"] = req["tag_kwd"]
|
||||||
if "tag_feas" in req:
|
if "tag_feas" in req:
|
||||||
d["tag_feas"] = req["tag_feas"]
|
try:
|
||||||
|
d["tag_feas"] = validate_tag_features(req["tag_feas"])
|
||||||
|
except ValueError as exc:
|
||||||
|
return get_error_data_result(f"`tag_feas` {exc}")
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
image_base64 = req.get("image_base64", None)
|
image_base64 = req.get("image_base64", None)
|
||||||
@@ -1202,7 +1206,10 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
|||||||
return get_error_data_result("`tag_kwd` must be a list of strings")
|
return get_error_data_result("`tag_kwd` must be a list of strings")
|
||||||
d["tag_kwd"] = req["tag_kwd"]
|
d["tag_kwd"] = req["tag_kwd"]
|
||||||
if "tag_feas" in req:
|
if "tag_feas" in req:
|
||||||
d["tag_feas"] = req["tag_feas"]
|
try:
|
||||||
|
d["tag_feas"] = validate_tag_features(req["tag_feas"])
|
||||||
|
except ValueError as exc:
|
||||||
|
return get_error_data_result(f"`tag_feas` {exc}")
|
||||||
tenant_embd_id = DocumentService.get_tenant_embd_id(document_id)
|
tenant_embd_id = DocumentService.get_tenant_embd_id(document_id)
|
||||||
if tenant_embd_id:
|
if tenant_embd_id:
|
||||||
model_config = get_model_config_by_id(tenant_embd_id)
|
model_config = get_model_config_by_id(tenant_embd_id)
|
||||||
|
|||||||
85
common/tag_feature_utils.py
Normal file
85
common/tag_feature_utils.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tag_features(raw, *, allow_json_string=True, allow_python_literal=False):
|
||||||
|
if raw is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
parsed = raw
|
||||||
|
if isinstance(raw, str):
|
||||||
|
raw = raw.strip()
|
||||||
|
if not raw:
|
||||||
|
return {}
|
||||||
|
parsed = None
|
||||||
|
if allow_json_string:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
except Exception:
|
||||||
|
parsed = None
|
||||||
|
if parsed is None and allow_python_literal:
|
||||||
|
try:
|
||||||
|
parsed = ast.literal_eval(raw)
|
||||||
|
except Exception:
|
||||||
|
parsed = None
|
||||||
|
if parsed is None:
|
||||||
|
return {}
|
||||||
|
elif not isinstance(raw, dict):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
cleaned = {}
|
||||||
|
for key, value in parsed.items():
|
||||||
|
if not isinstance(key, str):
|
||||||
|
continue
|
||||||
|
key = key.strip()
|
||||||
|
if not key:
|
||||||
|
continue
|
||||||
|
if isinstance(value, bool):
|
||||||
|
continue
|
||||||
|
if isinstance(value, (int, float)) and math.isfinite(float(value)):
|
||||||
|
cleaned[key] = float(value)
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def validate_tag_features(raw):
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
raise ValueError("must be an object mapping string tags to finite numeric scores")
|
||||||
|
|
||||||
|
cleaned = {}
|
||||||
|
for key, value in raw.items():
|
||||||
|
if not isinstance(key, str):
|
||||||
|
raise ValueError("keys must be strings")
|
||||||
|
key = key.strip()
|
||||||
|
if not key:
|
||||||
|
raise ValueError("keys must be non-empty strings")
|
||||||
|
if isinstance(value, bool) or not isinstance(value, (int, float)):
|
||||||
|
raise ValueError("values must be finite numbers")
|
||||||
|
numeric = float(value)
|
||||||
|
if not math.isfinite(numeric):
|
||||||
|
raise ValueError("values must be finite numbers")
|
||||||
|
cleaned[key] = numeric
|
||||||
|
|
||||||
|
return cleaned
|
||||||
@@ -26,6 +26,7 @@ from common.doc_store.doc_store_base import MatchDenseExpr, FusionExpr, OrderByE
|
|||||||
from common.string_utils import remove_redundant_spaces
|
from common.string_utils import remove_redundant_spaces
|
||||||
from common.float_utils import get_float
|
from common.float_utils import get_float
|
||||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||||
|
from common.tag_feature_utils import parse_tag_features
|
||||||
from common import settings
|
from common import settings
|
||||||
|
|
||||||
from common.misc_utils import thread_pool_exec
|
from common.misc_utils import thread_pool_exec
|
||||||
@@ -279,12 +280,18 @@ class Dealer:
|
|||||||
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
|
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
|
||||||
|
|
||||||
q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
|
q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
|
||||||
|
if q_denor == 0:
|
||||||
|
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
|
||||||
for i in search_res.ids:
|
for i in search_res.ids:
|
||||||
nor, denor = 0, 0
|
nor, denor = 0, 0
|
||||||
if not search_res.field[i].get(TAG_FLD):
|
if not search_res.field[i].get(TAG_FLD):
|
||||||
rank_fea.append(0)
|
rank_fea.append(0)
|
||||||
continue
|
continue
|
||||||
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
|
tag_feas = parse_tag_features(search_res.field[i].get(TAG_FLD), allow_json_string=True, allow_python_literal=True)
|
||||||
|
if not tag_feas:
|
||||||
|
rank_fea.append(0)
|
||||||
|
continue
|
||||||
|
for t, sc in tag_feas.items():
|
||||||
if t in query_rfea:
|
if t in query_rfea:
|
||||||
nor += query_rfea[t] * sc
|
nor += query_rfea[t] * sc
|
||||||
denor += sc * sc
|
denor += sc * sc
|
||||||
|
|||||||
@@ -584,6 +584,14 @@ def test_set_chunk_bytes_qa_image_and_guard_matrix_unit(monkeypatch):
|
|||||||
"get_by_id",
|
"get_by_id",
|
||||||
lambda _doc_id: (True, _DummyDoc(doc_id="doc-1", parser_id=module.ParserType.NAIVE)),
|
lambda _doc_id: (True, _DummyDoc(doc_id="doc-1", parser_id=module.ParserType.NAIVE)),
|
||||||
)
|
)
|
||||||
|
_set_request_json(
|
||||||
|
monkeypatch,
|
||||||
|
module,
|
||||||
|
{"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc", "tag_feas": [0.1]},
|
||||||
|
)
|
||||||
|
res = _run(module.set())
|
||||||
|
assert "`tag_feas` must be an object mapping string tags to finite numeric scores" in res["message"], res
|
||||||
|
|
||||||
_set_request_json(
|
_set_request_json(
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
module,
|
module,
|
||||||
@@ -594,7 +602,7 @@ def test_set_chunk_bytes_qa_image_and_guard_matrix_unit(monkeypatch):
|
|||||||
"important_kwd": ["important"],
|
"important_kwd": ["important"],
|
||||||
"question_kwd": ["question"],
|
"question_kwd": ["question"],
|
||||||
"tag_kwd": ["tag"],
|
"tag_kwd": ["tag"],
|
||||||
"tag_feas": [0.1],
|
"tag_feas": {"tag": 0.1},
|
||||||
"available_int": 0,
|
"available_int": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -762,6 +770,14 @@ def test_create_chunk_guards_pagerank_and_success_unit(monkeypatch):
|
|||||||
assert res["message"] == "Knowledgebase not found!", res
|
assert res["message"] == "Knowledgebase not found!", res
|
||||||
|
|
||||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(pagerank=0.8)))
|
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(pagerank=0.8)))
|
||||||
|
_set_request_json(
|
||||||
|
monkeypatch,
|
||||||
|
module,
|
||||||
|
{"doc_id": "doc-1", "content_with_weight": "chunk", "tag_feas": [0.2]},
|
||||||
|
)
|
||||||
|
res = _run(module.create())
|
||||||
|
assert "`tag_feas` must be an object mapping string tags to finite numeric scores" in res["message"], res
|
||||||
|
|
||||||
_set_request_json(
|
_set_request_json(
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
module,
|
module,
|
||||||
@@ -770,7 +786,7 @@ def test_create_chunk_guards_pagerank_and_success_unit(monkeypatch):
|
|||||||
"content_with_weight": "chunk",
|
"content_with_weight": "chunk",
|
||||||
"important_kwd": ["i1"],
|
"important_kwd": ["i1"],
|
||||||
"question_kwd": ["q1"],
|
"question_kwd": ["q1"],
|
||||||
"tag_feas": [0.2],
|
"tag_feas": {"tag": 0.2},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
res = _run(module.create())
|
res = _run(module.create())
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ class TestAddChunk:
|
|||||||
payload = {
|
payload = {
|
||||||
"doc_id": doc_id,
|
"doc_id": doc_id,
|
||||||
"content_with_weight": "chunk with tags",
|
"content_with_weight": "chunk with tags",
|
||||||
"tag_feas": [0.1, 0.2],
|
"tag_feas": {"tag1": 0.1, "tag2": 0.2},
|
||||||
"important_kwd": ["tag"],
|
"important_kwd": ["tag"],
|
||||||
"question_kwd": ["question"],
|
"question_kwd": ["question"],
|
||||||
}
|
}
|
||||||
|
|||||||
32
test/unit_test/common/test_tag_feature_utils.py
Normal file
32
test/unit_test/common/test_tag_feature_utils.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from common.tag_feature_utils import parse_tag_features, validate_tag_features
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_tag_features_accepts_numeric_dict():
|
||||||
|
assert validate_tag_features({"apple": 1, "banana": 2.5}) == {
|
||||||
|
"apple": 1.0,
|
||||||
|
"banana": 2.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_tag_features_rejects_string_payload():
|
||||||
|
with pytest.raises(ValueError, match="object mapping string tags"):
|
||||||
|
validate_tag_features('{"apple": 1.0}')
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_tag_features_rejects_non_finite_or_non_numeric_values():
|
||||||
|
with pytest.raises(ValueError, match="finite numbers"):
|
||||||
|
validate_tag_features({"apple": float("inf")})
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="finite numbers"):
|
||||||
|
validate_tag_features({"apple": "1.0"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_tag_features_supports_legacy_python_literal_strings():
|
||||||
|
assert parse_tag_features("{'apple': 2.0}", allow_python_literal=True) == {"apple": 2.0}
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_tag_features_ignores_executable_strings():
|
||||||
|
payload = '{"apple": (__import__("time").sleep(1) or 1.0)}'
|
||||||
|
assert parse_tag_features(payload, allow_python_literal=True) == {}
|
||||||
97
test/unit_test/rag/test_rank_feature_scores.py
Normal file
97
test/unit_test/rag/test_rank_feature_scores.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyTokenizer:
|
||||||
|
def tag(self, *args, **kwargs):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def freq(self, *args, **kwargs):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _tradi2simp(self, text):
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _strQ2B(self, text):
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
fake_infinity = types.ModuleType("infinity")
|
||||||
|
fake_infinity_tokenizer = types.ModuleType("infinity.rag_tokenizer")
|
||||||
|
fake_infinity_tokenizer.RagTokenizer = _DummyTokenizer
|
||||||
|
fake_infinity_tokenizer.is_chinese = lambda text: False
|
||||||
|
fake_infinity_tokenizer.is_number = lambda text: False
|
||||||
|
fake_infinity_tokenizer.is_alphabet = lambda text: True
|
||||||
|
fake_infinity_tokenizer.naive_qie = lambda text: text.split()
|
||||||
|
fake_infinity.rag_tokenizer = fake_infinity_tokenizer
|
||||||
|
sys.modules.setdefault("infinity", fake_infinity)
|
||||||
|
sys.modules.setdefault("infinity.rag_tokenizer", fake_infinity_tokenizer)
|
||||||
|
|
||||||
|
fake_query = types.ModuleType("rag.nlp.query")
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyFulltextQueryer:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
fake_query.FulltextQueryer = _DummyFulltextQueryer
|
||||||
|
sys.modules.setdefault("rag.nlp.query", fake_query)
|
||||||
|
|
||||||
|
fake_settings = types.ModuleType("common.settings")
|
||||||
|
sys.modules.setdefault("common.settings", fake_settings)
|
||||||
|
|
||||||
|
from rag.nlp.search import Dealer
|
||||||
|
|
||||||
|
|
||||||
|
def _make_search_res(tag_feas):
|
||||||
|
return Dealer.SearchResult(
|
||||||
|
total=1,
|
||||||
|
ids=["c1"],
|
||||||
|
field={"c1": {TAG_FLD: tag_feas, PAGERANK_FLD: 0}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rank_feature_scores_parses_python_dict_string():
|
||||||
|
dealer = Dealer.__new__(Dealer)
|
||||||
|
sres = _make_search_res("{'apple': 2.0}")
|
||||||
|
scores = dealer._rank_feature_scores({"apple": 1.0}, sres)
|
||||||
|
assert np.isclose(scores[0], 10.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rank_feature_scores_parses_json_string():
|
||||||
|
dealer = Dealer.__new__(Dealer)
|
||||||
|
sres = _make_search_res('{"apple": 2.0}')
|
||||||
|
scores = dealer._rank_feature_scores({"apple": 1.0}, sres)
|
||||||
|
assert np.isclose(scores[0], 10.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rank_feature_scores_handles_dict_value():
|
||||||
|
dealer = Dealer.__new__(Dealer)
|
||||||
|
sres = _make_search_res({"apple": 2.0})
|
||||||
|
scores = dealer._rank_feature_scores({"apple": 1.0}, sres)
|
||||||
|
assert np.isclose(scores[0], 10.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rank_feature_scores_ignores_invalid_tag_feas_string():
|
||||||
|
dealer = Dealer.__new__(Dealer)
|
||||||
|
sres = _make_search_res("not a dict")
|
||||||
|
scores = dealer._rank_feature_scores({"apple": 1.0}, sres)
|
||||||
|
assert np.isclose(scores[0], 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rank_feature_scores_ignores_executable_tag_feas_string():
|
||||||
|
dealer = Dealer.__new__(Dealer)
|
||||||
|
sres = _make_search_res('{"apple": (__import__("time").sleep(1) or 1.0)}')
|
||||||
|
scores = dealer._rank_feature_scores({"apple": 1.0}, sres)
|
||||||
|
assert np.isclose(scores[0], 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rank_feature_scores_returns_pagerank_when_no_tag_feature():
|
||||||
|
dealer = Dealer.__new__(Dealer)
|
||||||
|
sres = _make_search_res("{'apple': 2.0}")
|
||||||
|
scores = dealer._rank_feature_scores({PAGERANK_FLD: 10}, sres)
|
||||||
|
assert np.isclose(scores[0], 0.0)
|
||||||
Reference in New Issue
Block a user