Fix tag_feas code injection in retrieval ranking (#13923)

## Summary
- remove eval-based parsing from retrieval rank feature scoring
- validate `tag_feas` at write time in chunk APIs and SDK routes
- add regression tests for safe parsing and malicious payload rejection

## Details
`tag_feas` is intended to be structured rank-feature data, but the
retrieval ranking path was evaluating stored values as Python
expressions. This change treats `tag_feas` strictly as data.

### What changed
- replace `eval()` in `rag/nlp/search.py` with safe parsing via
`json.loads()` and optional `ast.literal_eval()` compatibility for
legacy Python-dict strings
- strictly filter parsed values down to `dict[str, finite number]`
- reject invalid `tag_feas` payloads at write time in web chunk routes
and SDK document chunk routes
- add focused regression tests to prove executable strings are ignored
and invalid payloads are rejected

## Validation
- `python -m pytest test/unit_test/common/test_tag_feature_utils.py
test/unit_test/rag/test_rank_feature_scores.py -q`

---------

Co-authored-by: unknown <zhenglinkai@CCN.Local>
Co-authored-by: Yingfeng Zhang <yingfeng.zhang@gmail.com>
This commit is contained in:
Ea001
2026-04-15 16:31:11 +08:00
committed by GitHub
parent 1f33ca1099
commit 38cefd88e2
8 changed files with 259 additions and 8 deletions

View File

@@ -38,6 +38,7 @@ from api.utils.api_utils import (
get_request_json, get_request_json,
) )
from common.misc_utils import thread_pool_exec from common.misc_utils import thread_pool_exec
from common.tag_feature_utils import validate_tag_features
from rag.app.qa import beAdoc, rmPrefix from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
@@ -161,7 +162,10 @@ async def set():
return get_data_error_result(message="`tag_kwd` must be a list of strings") return get_data_error_result(message="`tag_kwd` must be a list of strings")
d["tag_kwd"] = req["tag_kwd"] d["tag_kwd"] = req["tag_kwd"]
if "tag_feas" in req: if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"] try:
d["tag_feas"] = validate_tag_features(req["tag_feas"])
except ValueError as exc:
return get_data_error_result(message=f"`tag_feas` {exc}")
if "available_int" in req: if "available_int" in req:
d["available_int"] = req["available_int"] d["available_int"] = req["available_int"]
@@ -328,7 +332,10 @@ async def create():
return get_data_error_result(message="`tag_kwd` must be a list of strings") return get_data_error_result(message="`tag_kwd` must be a list of strings")
d["tag_kwd"] = req["tag_kwd"] d["tag_kwd"] = req["tag_kwd"]
if "tag_feas" in req: if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"] try:
d["tag_feas"] = validate_tag_features(req["tag_feas"])
except ValueError as exc:
return get_data_error_result(message=f"`tag_feas` {exc}")
image_base64 = req.get("image_base64", None) image_base64 = req.get("image_base64", None)
try: try:

View File

@@ -39,6 +39,7 @@ from common.constants import FileSource, LLMType, ParserType, RetCode, TaskStatu
from common.metadata_utils import convert_conditions, meta_filter from common.metadata_utils import convert_conditions, meta_filter
from common.misc_utils import thread_pool_exec from common.misc_utils import thread_pool_exec
from common.string_utils import remove_redundant_spaces from common.string_utils import remove_redundant_spaces
from common.tag_feature_utils import validate_tag_features
from rag.app.qa import beAdoc, rmPrefix from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
@@ -963,7 +964,10 @@ async def add_chunk(tenant_id, dataset_id, document_id):
return get_error_data_result("`tag_kwd` must be a list of strings") return get_error_data_result("`tag_kwd` must be a list of strings")
d["tag_kwd"] = req["tag_kwd"] d["tag_kwd"] = req["tag_kwd"]
if "tag_feas" in req: if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"] try:
d["tag_feas"] = validate_tag_features(req["tag_feas"])
except ValueError as exc:
return get_error_data_result(f"`tag_feas` {exc}")
import base64 import base64
image_base64 = req.get("image_base64", None) image_base64 = req.get("image_base64", None)
@@ -1202,7 +1206,10 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
return get_error_data_result("`tag_kwd` must be a list of strings") return get_error_data_result("`tag_kwd` must be a list of strings")
d["tag_kwd"] = req["tag_kwd"] d["tag_kwd"] = req["tag_kwd"]
if "tag_feas" in req: if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"] try:
d["tag_feas"] = validate_tag_features(req["tag_feas"])
except ValueError as exc:
return get_error_data_result(f"`tag_feas` {exc}")
tenant_embd_id = DocumentService.get_tenant_embd_id(document_id) tenant_embd_id = DocumentService.get_tenant_embd_id(document_id)
if tenant_embd_id: if tenant_embd_id:
model_config = get_model_config_by_id(tenant_embd_id) model_config = get_model_config_by_id(tenant_embd_id)

View File

@@ -0,0 +1,85 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ast
import json
import math
def parse_tag_features(raw, *, allow_json_string=True, allow_python_literal=False):
if raw is None:
return {}
parsed = raw
if isinstance(raw, str):
raw = raw.strip()
if not raw:
return {}
parsed = None
if allow_json_string:
try:
parsed = json.loads(raw)
except Exception:
parsed = None
if parsed is None and allow_python_literal:
try:
parsed = ast.literal_eval(raw)
except Exception:
parsed = None
if parsed is None:
return {}
elif not isinstance(raw, dict):
return {}
if not isinstance(parsed, dict):
return {}
cleaned = {}
for key, value in parsed.items():
if not isinstance(key, str):
continue
key = key.strip()
if not key:
continue
if isinstance(value, bool):
continue
if isinstance(value, (int, float)) and math.isfinite(float(value)):
cleaned[key] = float(value)
return cleaned
def validate_tag_features(raw):
if raw is None:
return None
if not isinstance(raw, dict):
raise ValueError("must be an object mapping string tags to finite numeric scores")
cleaned = {}
for key, value in raw.items():
if not isinstance(key, str):
raise ValueError("keys must be strings")
key = key.strip()
if not key:
raise ValueError("keys must be non-empty strings")
if isinstance(value, bool) or not isinstance(value, (int, float)):
raise ValueError("values must be finite numbers")
numeric = float(value)
if not math.isfinite(numeric):
raise ValueError("values must be finite numbers")
cleaned[key] = numeric
return cleaned

View File

@@ -26,6 +26,7 @@ from common.doc_store.doc_store_base import MatchDenseExpr, FusionExpr, OrderByE
from common.string_utils import remove_redundant_spaces from common.string_utils import remove_redundant_spaces
from common.float_utils import get_float from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD from common.constants import PAGERANK_FLD, TAG_FLD
from common.tag_feature_utils import parse_tag_features
from common import settings from common import settings
from common.misc_utils import thread_pool_exec from common.misc_utils import thread_pool_exec
@@ -279,12 +280,18 @@ class Dealer:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD])) q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
if q_denor == 0:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
for i in search_res.ids: for i in search_res.ids:
nor, denor = 0, 0 nor, denor = 0, 0
if not search_res.field[i].get(TAG_FLD): if not search_res.field[i].get(TAG_FLD):
rank_fea.append(0) rank_fea.append(0)
continue continue
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items(): tag_feas = parse_tag_features(search_res.field[i].get(TAG_FLD), allow_json_string=True, allow_python_literal=True)
if not tag_feas:
rank_fea.append(0)
continue
for t, sc in tag_feas.items():
if t in query_rfea: if t in query_rfea:
nor += query_rfea[t] * sc nor += query_rfea[t] * sc
denor += sc * sc denor += sc * sc

View File

@@ -584,6 +584,14 @@ def test_set_chunk_bytes_qa_image_and_guard_matrix_unit(monkeypatch):
"get_by_id", "get_by_id",
lambda _doc_id: (True, _DummyDoc(doc_id="doc-1", parser_id=module.ParserType.NAIVE)), lambda _doc_id: (True, _DummyDoc(doc_id="doc-1", parser_id=module.ParserType.NAIVE)),
) )
_set_request_json(
monkeypatch,
module,
{"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc", "tag_feas": [0.1]},
)
res = _run(module.set())
assert "`tag_feas` must be an object mapping string tags to finite numeric scores" in res["message"], res
_set_request_json( _set_request_json(
monkeypatch, monkeypatch,
module, module,
@@ -594,7 +602,7 @@ def test_set_chunk_bytes_qa_image_and_guard_matrix_unit(monkeypatch):
"important_kwd": ["important"], "important_kwd": ["important"],
"question_kwd": ["question"], "question_kwd": ["question"],
"tag_kwd": ["tag"], "tag_kwd": ["tag"],
"tag_feas": [0.1], "tag_feas": {"tag": 0.1},
"available_int": 0, "available_int": 0,
}, },
) )
@@ -762,6 +770,14 @@ def test_create_chunk_guards_pagerank_and_success_unit(monkeypatch):
assert res["message"] == "Knowledgebase not found!", res assert res["message"] == "Knowledgebase not found!", res
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(pagerank=0.8))) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(pagerank=0.8)))
_set_request_json(
monkeypatch,
module,
{"doc_id": "doc-1", "content_with_weight": "chunk", "tag_feas": [0.2]},
)
res = _run(module.create())
assert "`tag_feas` must be an object mapping string tags to finite numeric scores" in res["message"], res
_set_request_json( _set_request_json(
monkeypatch, monkeypatch,
module, module,
@@ -770,7 +786,7 @@ def test_create_chunk_guards_pagerank_and_success_unit(monkeypatch):
"content_with_weight": "chunk", "content_with_weight": "chunk",
"important_kwd": ["i1"], "important_kwd": ["i1"],
"question_kwd": ["q1"], "question_kwd": ["q1"],
"tag_feas": [0.2], "tag_feas": {"tag": 0.2},
}, },
) )
res = _run(module.create()) res = _run(module.create())

View File

@@ -166,7 +166,7 @@ class TestAddChunk:
payload = { payload = {
"doc_id": doc_id, "doc_id": doc_id,
"content_with_weight": "chunk with tags", "content_with_weight": "chunk with tags",
"tag_feas": [0.1, 0.2], "tag_feas": {"tag1": 0.1, "tag2": 0.2},
"important_kwd": ["tag"], "important_kwd": ["tag"],
"question_kwd": ["question"], "question_kwd": ["question"],
} }

View File

@@ -0,0 +1,32 @@
import pytest
from common.tag_feature_utils import parse_tag_features, validate_tag_features
def test_validate_tag_features_accepts_numeric_dict():
assert validate_tag_features({"apple": 1, "banana": 2.5}) == {
"apple": 1.0,
"banana": 2.5,
}
def test_validate_tag_features_rejects_string_payload():
with pytest.raises(ValueError, match="object mapping string tags"):
validate_tag_features('{"apple": 1.0}')
def test_validate_tag_features_rejects_non_finite_or_non_numeric_values():
with pytest.raises(ValueError, match="finite numbers"):
validate_tag_features({"apple": float("inf")})
with pytest.raises(ValueError, match="finite numbers"):
validate_tag_features({"apple": "1.0"})
def test_parse_tag_features_supports_legacy_python_literal_strings():
assert parse_tag_features("{'apple': 2.0}", allow_python_literal=True) == {"apple": 2.0}
def test_parse_tag_features_ignores_executable_strings():
payload = '{"apple": (__import__("time").sleep(1) or 1.0)}'
assert parse_tag_features(payload, allow_python_literal=True) == {}

View File

@@ -0,0 +1,97 @@
import sys
import types
import numpy as np
from common.constants import PAGERANK_FLD, TAG_FLD
class _DummyTokenizer:
def tag(self, *args, **kwargs):
return []
def freq(self, *args, **kwargs):
return 0
def _tradi2simp(self, text):
return text
def _strQ2B(self, text):
return text
fake_infinity = types.ModuleType("infinity")
fake_infinity_tokenizer = types.ModuleType("infinity.rag_tokenizer")
fake_infinity_tokenizer.RagTokenizer = _DummyTokenizer
fake_infinity_tokenizer.is_chinese = lambda text: False
fake_infinity_tokenizer.is_number = lambda text: False
fake_infinity_tokenizer.is_alphabet = lambda text: True
fake_infinity_tokenizer.naive_qie = lambda text: text.split()
fake_infinity.rag_tokenizer = fake_infinity_tokenizer
sys.modules.setdefault("infinity", fake_infinity)
sys.modules.setdefault("infinity.rag_tokenizer", fake_infinity_tokenizer)
fake_query = types.ModuleType("rag.nlp.query")
class _DummyFulltextQueryer:
pass
fake_query.FulltextQueryer = _DummyFulltextQueryer
sys.modules.setdefault("rag.nlp.query", fake_query)
fake_settings = types.ModuleType("common.settings")
sys.modules.setdefault("common.settings", fake_settings)
from rag.nlp.search import Dealer
def _make_search_res(tag_feas):
return Dealer.SearchResult(
total=1,
ids=["c1"],
field={"c1": {TAG_FLD: tag_feas, PAGERANK_FLD: 0}},
)
def test_rank_feature_scores_parses_python_dict_string():
dealer = Dealer.__new__(Dealer)
sres = _make_search_res("{'apple': 2.0}")
scores = dealer._rank_feature_scores({"apple": 1.0}, sres)
assert np.isclose(scores[0], 10.0)
def test_rank_feature_scores_parses_json_string():
dealer = Dealer.__new__(Dealer)
sres = _make_search_res('{"apple": 2.0}')
scores = dealer._rank_feature_scores({"apple": 1.0}, sres)
assert np.isclose(scores[0], 10.0)
def test_rank_feature_scores_handles_dict_value():
dealer = Dealer.__new__(Dealer)
sres = _make_search_res({"apple": 2.0})
scores = dealer._rank_feature_scores({"apple": 1.0}, sres)
assert np.isclose(scores[0], 10.0)
def test_rank_feature_scores_ignores_invalid_tag_feas_string():
dealer = Dealer.__new__(Dealer)
sres = _make_search_res("not a dict")
scores = dealer._rank_feature_scores({"apple": 1.0}, sres)
assert np.isclose(scores[0], 0.0)
def test_rank_feature_scores_ignores_executable_tag_feas_string():
dealer = Dealer.__new__(Dealer)
sres = _make_search_res('{"apple": (__import__("time").sleep(1) or 1.0)}')
scores = dealer._rank_feature_scores({"apple": 1.0}, sres)
assert np.isclose(scores[0], 0.0)
def test_rank_feature_scores_returns_pagerank_when_no_tag_feature():
dealer = Dealer.__new__(Dealer)
sres = _make_search_res("{'apple': 2.0}")
scores = dealer._rank_feature_scores({PAGERANK_FLD: 10}, sres)
assert np.isclose(scores[0], 0.0)