Push metadata filters down to Infinity (#14974)

### What problem does this PR solve?

Push metadata filters down to Infinity

### Type of change

- [x] Refactoring
This commit is contained in:
qinling0210
2026-05-18 14:22:04 +08:00
committed by GitHub
parent 7cdc74bbe5
commit f1d2383572
7 changed files with 1148 additions and 572 deletions

View File

@@ -0,0 +1,296 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Translate RAGflow document-metadata filter lists into Infinity SQL filter expressions.
"""
from __future__ import annotations
import ast
import re
from typing import Any, Dict, List, Sequence
_KEY_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
def _validate_key(key: str, flt: Dict[str, Any]) -> None:
if not _KEY_PATTERN.match(key):
raise ValueError(f"invalid key format (must be identifier-like): {flt}")
SUPPORTED_OPERATORS: frozenset[str] = frozenset(
{
"=",
"",
">",
"<",
"",
"",
"in",
"not in",
"contains",
"not contains",
"start with",
"end with",
"empty",
"not empty",
}
)
_RANGE_OPS: Dict[str, str] = {
">": ">",
"<": "<",
"": ">=",
"": "<=",
}
class MetaFilterTranslator:
"""Translate one user filter clause at a time into Infinity SQL filter strings."""
def translate(self, flt: Dict[str, Any]) -> str:
op = flt.get("op")
key = flt.get("key")
value = flt.get("value")
if not key or not isinstance(key, str):
raise ValueError(f"filter is missing a string key: {flt}")
_validate_key(key, flt)
if op not in SUPPORTED_OPERATORS:
raise ValueError(f"unknown operator: {op!r}, filter: {flt}")
if op == "empty":
return self._translate_empty(key)
if op == "not empty":
return self._translate_not_empty(key)
if op == "=":
return self._translate_equal(key, value, flt)
if op == "":
return self._translate_not_equal(key, value, flt)
if op in _RANGE_OPS:
return self._translate_range(key, op, value, flt)
if op == "in":
return self._translate_in(key, value, flt)
if op == "not in":
return self._translate_not_in(key, value, flt)
if op == "contains":
return self._translate_contains(key, value, flt)
if op == "not contains":
return self._translate_not_contains(key, value, flt)
if op == "start with":
return self._translate_start_with(key, value, flt)
if op == "end with":
return self._translate_end_with(key, value, flt)
raise ValueError(f"no handler for operator: {op!r}, filter: {flt}")
def _translate_empty(self, key: str) -> str:
return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') = '\"\"'"
def _translate_not_empty(self, key: str) -> str:
return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') != '\"\"'"
def _translate_equal(self, key: str, value: Any, flt: Dict[str, Any]) -> str:
coerced = _coerce_scalar(value, flt)
if isinstance(coerced, str):
escaped = _escape_sql_string(coerced)
return f"JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')"
return f"JSON_CONTAINS(meta_fields, '$.{key}', {coerced})"
def _translate_not_equal(self, key: str, value: Any, flt: Dict[str, Any]) -> str:
coerced = _coerce_scalar(value, flt)
if isinstance(coerced, str):
escaped = _escape_sql_string(coerced)
return f"NOT JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')"
return f"NOT JSON_CONTAINS(meta_fields, '$.{key}', {coerced})"
def _translate_range(self, key: str, op: str, value: Any, flt: Dict[str, Any]) -> str:
coerced = _coerce_range_value(value, flt)
sql_op = _RANGE_OPS.get(op, op)
if isinstance(coerced, str):
escaped = _escape_sql_string(coerced)
return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') {sql_op} '{escaped}'"
return f"JSON_EXTRACT_DOUBLE(meta_fields, '$.{key}') {sql_op} {coerced}"
def _translate_in(self, key: str, value: Any, flt: Dict[str, Any]) -> str:
members = _csv_or_list(value, flt)
string_parts = []
num_parts = []
for m in members:
# Use same coercion as range operators to detect numeric values
coerced = _coerce_range_value(m, flt)
if isinstance(coerced, (int, float)):
num_parts.append(f"JSON_CONTAINS(meta_fields, '$.{key}', {coerced})")
else:
escaped = _escape_sql_string(coerced)
string_parts.append(f"JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')")
conditions = []
if string_parts:
conditions.append("(" + " OR ".join(string_parts) + ")")
if num_parts:
conditions.append("(" + " OR ".join(num_parts) + ")")
return "(" + " OR ".join(conditions) + ")"
def _translate_not_in(self, key: str, value: Any, flt: Dict[str, Any]) -> str:
members = _csv_or_list(value, flt)
string_parts = []
num_parts = []
for m in members:
# Use same coercion as range operators to detect numeric values
coerced = _coerce_range_value(m, flt)
if isinstance(coerced, (int, float)):
num_parts.append(f"NOT JSON_CONTAINS(meta_fields, '$.{key}', {coerced})")
else:
escaped = _escape_sql_string(coerced)
string_parts.append(f"NOT JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')")
conditions = []
if string_parts:
conditions.append("(" + " AND ".join(string_parts) + ")")
if num_parts:
conditions.append("(" + " AND ".join(num_parts) + ")")
return " AND ".join(conditions)
def _translate_contains(self, key: str, value: Any, flt: Dict[str, Any]) -> str:
if not value and value != 0:
raise ValueError(f"contains value is empty: {flt}")
# Use same coercion as range operators to detect numeric values
coerced = _coerce_range_value(value, flt)
if isinstance(coerced, (int, float)):
return f"JSON_CONTAINS(meta_fields, '$.{key}', {coerced})"
escaped = _escape_sql_string(str(value))
return f"JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')"
def _translate_not_contains(self, key: str, value: Any, flt: Dict[str, Any]) -> str:
text = _coerce_string(value, flt)
escaped = _escape_sql_string(text)
# Use Infinity's JSON_CONTAINS to check if value does NOT exist in JSON array
return f"NOT JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')"
def _translate_start_with(self, key: str, value: Any, flt: Dict[str, Any]) -> str:
text = _coerce_string(value, flt)
escaped = _escape_sql_string(_escape_likeWildcards(text))
return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') LIKE '{escaped}%'"
def _translate_end_with(self, key: str, value: Any, flt: Dict[str, Any]) -> str:
text = _coerce_string(value, flt)
escaped = _escape_sql_string(_escape_likeWildcards(text))
return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') LIKE '%{escaped}'"
def plan_pushdown(filters: Sequence[Dict[str, Any]], logic: str) -> List[str]:
if logic not in {"and", "or"}:
raise ValueError(f"unknown logic {logic!r}")
translator = MetaFilterTranslator()
return [translator.translate(flt) for flt in filters]
def build_infinity_filter(filters: Sequence[Dict[str, Any]], logic: str) -> str:
if not filters:
return "1=1"
fragments = plan_pushdown(filters, logic)
joiner = " AND " if logic == "and" else " OR "
result = "(" + joiner.join(fragments) + ")"
return result
def is_pushdown_supported(filters: Sequence[Dict[str, Any]]) -> bool:
for flt in filters:
op = flt.get("op")
if op not in SUPPORTED_OPERATORS:
return False
if not isinstance(flt.get("key"), str) or not flt.get("key"):
return False
return True
def extract_doc_ids(df) -> List[str]:
if df is None or not hasattr(df, "iterrows"):
return []
return [str(row["id"]) for _, row in df.iterrows() if "id" in row]
# ---------------------------------------------------------------------------
# Value coercion helpers
# ---------------------------------------------------------------------------
def _coerce_scalar(value: Any, flt: Dict[str, Any]) -> Any:
if value is None:
raise ValueError(f"scalar comparison value is None: {flt}")
if isinstance(value, (list, dict)):
raise ValueError(f"scalar comparison value is non-scalar: {flt}")
try:
parsed = ast.literal_eval(str(value).strip())
if isinstance(parsed, (int, float, bool)):
return parsed
except Exception:
pass
return str(value)
def _coerce_range_value(value: Any, flt: Dict[str, Any]) -> Any:
if value is None:
raise ValueError(f"range comparison value is None: {flt}")
try:
parsed = ast.literal_eval(str(value).strip())
if isinstance(parsed, (int, float)):
return parsed
except Exception:
pass
return str(value)
def _coerce_string(value: Any, flt: Dict[str, Any]) -> str:
if value is None:
raise ValueError(f"string-operator value is None: {flt}")
if isinstance(value, (list, dict)):
raise ValueError(f"string-operator value must be a scalar: {flt}")
s = str(value)
if not s:
raise ValueError(f"string-operator value is empty: {flt}")
return s
def _csv_or_list(value: Any, flt: Dict[str, Any]) -> List[Any]:
if value is None:
raise ValueError(f"membership value is None: {flt}")
if isinstance(value, (list, tuple)):
members = list(value)
elif isinstance(value, str):
try:
parsed = ast.literal_eval(value)
except Exception:
parsed = value
if isinstance(parsed, (list, tuple)):
members = list(parsed)
else:
members = [m.strip() for m in value.split(",") if m.strip()]
else:
members = [value]
if not members:
raise ValueError(f"membership value resolved to empty list: {flt}")
normalised: List[Any] = []
for m in members:
if isinstance(m, str):
normalised.append(m.lower().strip())
else:
normalised.append(m)
return normalised
def _escape_sql_string(s: str) -> str:
return s.replace("'", "''")
def _escape_likeWildcards(text: str) -> str:
return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

View File

@@ -19,6 +19,7 @@ from typing import Any, Callable, Dict
import json_repair
def convert_conditions(metadata_condition):
if metadata_condition is None:
metadata_condition = {}
@@ -60,21 +61,21 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
# Strict date format detection: YYYY-MM-DD (must be 10 chars with correct format)
is_input_date = (
len(input_str) == 10 and
input_str[4] == '-' and
input_str[7] == '-' and
input_str[:4].isdigit() and
input_str[5:7].isdigit() and
input_str[8:10].isdigit()
len(input_str) == 10 and
input_str[4] == '-' and
input_str[7] == '-' and
input_str[:4].isdigit() and
input_str[5:7].isdigit() and
input_str[8:10].isdigit()
)
is_value_date = (
len(value_str) == 10 and
value_str[4] == '-' and
value_str[7] == '-' and
value_str[:4].isdigit() and
value_str[5:7].isdigit() and
value_str[8:10].isdigit()
len(value_str) == 10 and
value_str[4] == '-' and
value_str[7] == '-' and
value_str[:4].isdigit() and
value_str[5:7].isdigit() and
value_str[8:10].isdigit()
)
if is_value_date:
@@ -109,17 +110,23 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
matched = False
try:
if operator == "contains":
matched = str(input).find(value) >= 0 if not isinstance(input, list) else any(str(i).find(value) >= 0 for i in input)
matched = str(input).find(value) >= 0 if not isinstance(input, list) else any(
str(i).find(value) >= 0 for i in input)
elif operator == "not contains":
matched = str(input).find(value) == -1 if not isinstance(input, list) else all(str(i).find(value) == -1 for i in input)
matched = str(input).find(value) == -1 if not isinstance(input, list) else all(
str(i).find(value) == -1 for i in input)
elif operator == "in":
matched = input in value if not isinstance(input, list) else all(i in value for i in input)
elif operator == "not in":
matched = input not in value if not isinstance(input, list) else all(i not in value for i in input)
elif operator == "start with":
matched = str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower())
matched = str(input).lower().startswith(str(value).lower()) if not isinstance(input,
list) else "".join(
[str(i).lower() for i in input]).startswith(str(value).lower())
elif operator == "end with":
matched = str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower())
matched = str(input).lower().endswith(str(value).lower()) if not isinstance(input,
list) else "".join(
[str(i).lower() for i in input]).endswith(str(value).lower())
elif operator == "empty":
matched = not input
elif operator == "not empty":
@@ -158,21 +165,23 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
if logic == "and":
doc_ids = doc_ids & set(ids)
if not doc_ids:
logging.debug(f"meta_filter filters={filters}, logic={logic}, early return []")
return []
else:
doc_ids = doc_ids | set(ids)
logging.debug(f"meta_filter filters={filters}, logic={logic}, returning doc_ids={list(doc_ids)}")
return list(doc_ids)
async def apply_meta_data_filter(
meta_data_filter: dict | None,
metas: dict | None = None,
question: str = "",
chat_mdl: Any = None,
base_doc_ids: list[str] | None = None,
manual_value_resolver: Callable[[dict], dict] | None = None,
kb_ids: list[str] | None = None,
metas_loader: Callable[[], dict] | None = None,
meta_data_filter: dict | None,
metas: dict | None = None,
question: str = "",
chat_mdl: Any = None,
base_doc_ids: list[str] | None = None,
manual_value_resolver: Callable[[dict], dict] | None = None,
kb_ids: list[str] | None = None,
metas_loader: Callable[[], dict] | None = None,
) -> list[str] | None:
"""
Apply metadata filtering rules and return the filtered doc_ids.
@@ -182,12 +191,11 @@ async def apply_meta_data_filter(
- semi_auto: generate conditions using selected metadata keys only
- manual: directly filter based on provided conditions
When ``kb_ids`` is supplied and the active doc store is Elasticsearch the
generated filter conditions are pushed down to ES via
``DocMetadataService.filter_doc_ids_by_meta_pushdown`` instead of being
evaluated in Python over ``metas``. The in-memory ``meta_filter`` path
remains the fallback so callers without a KB scope, or backends without
push-down support, behave exactly as before.
When ``kb_ids`` is supplied, metadata filters are pushed down to the doc metadata
index (ES/Infinity) via ``DocMetadataService.filter_doc_ids_by_metadata`` instead
of being evaluated in Python over ``metas``. The in-memory ``meta_filter`` path
remains the fallback so callers without a KB scope, or backends without push-down
support, behave exactly as before.
``metas`` may be supplied eagerly or via ``metas_loader``. The loader is
only invoked when the metadata dict is actually needed — i.e. for the LLM
@@ -200,7 +208,7 @@ async def apply_meta_data_filter(
list of doc_ids, ["-999"] when manual filters yield no result, or None
when auto/semi_auto filters return empty.
"""
from rag.prompts.generator import gen_meta_filter # move from the top of the file to avoid circular import
from rag.prompts.generator import gen_meta_filter # move from the top of the file to avoid circular import
doc_ids = list(base_doc_ids) if base_doc_ids else []
@@ -220,17 +228,26 @@ async def apply_meta_data_filter(
cached_metas = metas_loader() if metas_loader else {}
return cached_metas
def _evaluate(conditions: list[dict], logic: str) -> list[str]:
"""Run conditions through ES push-down when possible, in-memory otherwise."""
def _run_metadata_filter(conditions: list[dict], logic: str) -> list[str]:
"""Run conditions through ES/Infinity push-down when possible, in-memory otherwise."""
if conditions and kb_ids:
pushed = _try_meta_pushdown(kb_ids, conditions, logic)
if pushed is not None:
return pushed
try:
from api.db.services.doc_metadata_service import DocMetadataService
doc_ids = DocMetadataService.filter_doc_ids_by_meta_pushdown(kb_ids, conditions, logic)
logging.debug(f"Doc ids filtered by metadata: {doc_ids}")
if doc_ids is not None:
return doc_ids
except Exception as e:
logging.error(f"Metadata filter push down errored: {e}")
# In-memory fallback
logging.debug("Metadata filter falls back to in-memory filter")
return meta_filter(_get_metas(), conditions, logic)
if method == "auto":
filters: dict = await gen_meta_filter(chat_mdl, _get_metas(), question)
doc_ids.extend(_evaluate(filters["conditions"], filters.get("logic", "and")))
logging.debug(f"Metadata filter(auto) generated: {filters}")
doc_ids.extend(_run_metadata_filter(filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
return None
elif method == "semi_auto":
@@ -251,24 +268,27 @@ async def apply_meta_data_filter(
filtered_metas = {key: current_metas[key] for key in selected_keys if key in current_metas}
if filtered_metas:
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question, constraints=constraints)
doc_ids.extend(_evaluate(filters["conditions"], filters.get("logic", "and")))
logging.debug(f"Metadata filter(semi_auto) generated: {filters}")
doc_ids.extend(_run_metadata_filter(filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
return None
elif method == "manual":
filters = meta_data_filter.get("manual", [])
if manual_value_resolver:
filters = [manual_value_resolver(flt) for flt in filters]
doc_ids.extend(_evaluate(filters, meta_data_filter.get("logic", "and")))
logging.debug(f"Metadata filter(manual): {filters}")
doc_ids.extend(_run_metadata_filter(filters, meta_data_filter.get("logic", "and")))
if filters and not doc_ids:
doc_ids = ["-999"]
logging.debug(f"apply_meta_data_filter meta_filter={meta_data_filter}, returning doc_ids={doc_ids}")
return doc_ids
def _try_meta_pushdown(
kb_ids: list[str],
conditions: list[dict],
logic: str,
kb_ids: list[str],
conditions: list[dict],
logic: str,
) -> list[str] | None:
"""Attempt the ES push-down path; return ``None`` to fall back in-memory.
@@ -335,7 +355,7 @@ def update_metadata_to(metadata, meta):
return metadata
def metadata_schema(metadata: dict|list|None) -> Dict[str, Any]:
def metadata_schema(metadata: dict | list | None) -> Dict[str, Any]:
if not metadata:
return {}
properties = {}