From f1d238357258b792dc8937845e0ab6cfd0473982 Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Mon, 18 May 2026 14:22:04 +0800 Subject: [PATCH] Push metadata filters down to Infinity (#14974) ### What problem does this PR solve? Push metadata filters down to Infinity ### Type of change - [x] Refactoring --- api/apps/services/dataset_api_service.py | 1 + api/db/services/doc_metadata_service.py | 163 +++-- common/metadata_infinity_filter.py | 296 ++++++++ common/metadata_utils.py | 106 +-- rag/prompts/generator.py | 22 + .../common/test_metadata_es_filter.py | 473 ------------- test/unit_test/common/test_metadata_filter.py | 659 ++++++++++++++++++ 7 files changed, 1148 insertions(+), 572 deletions(-) create mode 100644 common/metadata_infinity_filter.py delete mode 100644 test/unit_test/common/test_metadata_es_filter.py create mode 100644 test/unit_test/common/test_metadata_filter.py diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py index 74b081add3..5927b780ec 100644 --- a/api/apps/services/dataset_api_service.py +++ b/api/apps/services/dataset_api_service.py @@ -1344,6 +1344,7 @@ async def search_datasets(tenant_id: str, req: dict): chat_mdl = LLMBundle(tenant_id, chat_model_config) if meta_data_filter: + logging.debug(f"Metadata filter: {meta_data_filter}, question: {question}, chat_mdl={'None' if chat_mdl is None else chat_mdl.llm_name}") local_doc_ids = await apply_meta_data_filter( meta_data_filter, None, diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py index 34258c69f5..fbe32f9e5b 100644 --- a/api/db/services/doc_metadata_service.py +++ b/api/db/services/doc_metadata_service.py @@ -404,7 +404,7 @@ class DocMetadataService: ) else: logging.debug(f"Backend {type(settings.docStoreConn).__name__} has no refresh_idx; skipping") - + logging.debug(f"Successfully inserted metadata for document {doc_id}") return True @@ -448,7 
+448,8 @@ class DocMetadataService: # Post-process to split combined values processed_meta = cls._split_combined_values(meta_fields) - logging.debug(f"[update_document_metadata] Updating doc_id: {doc_id}, kb_id: {kb_id}, meta_fields: {processed_meta}") + logging.debug( + f"[update_document_metadata] Updating doc_id: {doc_id}, kb_id: {kb_id}, meta_fields: {processed_meta}") # For Elasticsearch, use efficient partial update if not settings.DOC_ENGINE_INFINITY and not settings.DOC_ENGINE_OCEANBASE: @@ -456,7 +457,8 @@ class DocMetadataService: index_exists = settings.docStoreConn.index_exist(index_name, "") if not index_exists: # Index doesn't exist - create it and insert directly - logging.debug(f"[update_document_metadata] Index {index_name} does not exist, creating and inserting") + logging.debug( + f"[update_document_metadata] Index {index_name} does not exist, creating and inserting") result = settings.docStoreConn.create_doc_meta_idx(index_name) if result is False: logging.error(f"Failed to create metadata index {index_name}") @@ -477,7 +479,8 @@ class DocMetadataService: # to a backend-provided scripted assignment that fully overwrites it. 
replace_meta_fields = getattr(settings.docStoreConn, "replace_meta_fields", None) if callable(replace_meta_fields) and replace_meta_fields(index_name, doc_id, processed_meta): - logging.debug(f"Successfully updated metadata for document {doc_id} via {type(settings.docStoreConn).__name__}.replace_meta_fields") + logging.debug( + f"Successfully updated metadata for document {doc_id} via {type(settings.docStoreConn).__name__}.replace_meta_fields") return True logging.warning( f"replace_meta_fields unavailable or failed on backend " @@ -537,7 +540,8 @@ class DocMetadataService: # Check if metadata table exists before attempting deletion # This is the key optimization - no table = no metadata = nothing to delete if not settings.docStoreConn.index_exist(index_name, ""): - logging.debug(f"Metadata table {index_name} does not exist, skipping metadata deletion for document {doc_id}") + logging.debug( + f"Metadata table {index_name} does not exist, skipping metadata deletion for document {doc_id}") return True # No metadata to delete is considered success # Try to get the metadata to confirm it exists before deleting @@ -627,7 +631,8 @@ class DocMetadataService: if isinstance(results, tuple) and len(results) == 2: # Infinity returns (DataFrame, int) df, total = results - logging.debug(f"[DROP EMPTY TABLE] Infinity format - total: {total}, df length: {len(df) if hasattr(df, '__len__') else 'N/A'}") + logging.debug( + f"[DROP EMPTY TABLE] Infinity format - total: {total}, df length: {len(df) if hasattr(df, '__len__') else 'N/A'}") is_empty = (total == 0 or (hasattr(df, '__len__') and len(df) == 0)) elif hasattr(results, 'get') and 'hits' in results: # ES format - MUST check this before hasattr(results, '__len__') @@ -791,26 +796,62 @@ class DocMetadataService: @classmethod def filter_doc_ids_by_meta_pushdown( - cls, - kb_ids: List[str], - filters: List[Dict], - logic: str = "and", - limit: int = 10000, + cls, + kb_ids: List[str], + filters: List[Dict], + logic: str = "and", + 
limit: int = 10000, ) -> Optional[List[str]]: - """Run a metadata filter directly against ES, returning matching doc IDs. + """Run a metadata filter directly against ES or Infinity, returning matching doc IDs. Returns ``None`` to signal "push-down not viable, use the in-memory ``meta_filter`` fallback". Reasons for ``None``: - - Active doc store is not Elasticsearch (Infinity / OceanBase have - different filter semantics for the JSON ``meta_fields`` column). - - One of the user filters cannot be expressed in ES DSL. - - The ES request itself failed (network, mapping, missing index). + - kb_ids or filters is empty + - One of the user filters cannot be expressed in ES DSL or Infinity SQL + - The request itself failed (network, mapping, missing index) On success returns the deduplicated, ordered list of document IDs the - ES query matched. Callers can union or intersect this with their own + query matched. Callers can union or intersect this with their own base ``doc_ids`` rather than fetching the entire metadata table. 
""" + if not kb_ids or not filters: + logging.debug("Metadata filter skipped: empty kb_ids or filters") + return None + + try: + kb = Knowledgebase.get_by_id(kb_ids[0]) + except Exception as e: + logging.warning(f"Metadata filter cannot resolve tenant for kb {kb_ids[0]}: {e}") + return None + if not kb: + return None + + tenant_id = kb.tenant_id + index_name = cls._get_doc_meta_index_name(tenant_id) + + if not settings.docStoreConn.index_exist(index_name, ""): + return [] + + if settings.DOC_ENGINE_INFINITY: + return cls._filter_doc_ids_by_metadata_infinity( + index_name, kb_ids, filters, logic + ) + else: + return cls._filter_doc_ids_by_metadata_es( + index_name, kb_ids, filters, logic, limit + ) + + @classmethod + def _filter_doc_ids_by_metadata_es( + cls, + index_name: str, + kb_ids: List[str], + filters: List[Dict], + logic: str, + limit: int, + ) -> Optional[List[str]]: + """ES push-down path for metadata filtering.""" from common.metadata_es_filter import ( UnsupportedMetaFilter, build_meta_filter_query, @@ -818,14 +859,6 @@ class DocMetadataService: is_pushdown_supported, ) - if not kb_ids: - return [] - - if settings.DOC_ENGINE_INFINITY: - # Infinity stores ``meta_fields`` as a JSON column without dotted - # field access; the in-memory path is still the reliable answer. - return None - es_client = getattr(settings.docStoreConn, "es", None) if es_client is None: return None @@ -833,35 +866,12 @@ class DocMetadataService: if not is_pushdown_supported(filters): return None - try: - kb = Knowledgebase.get_by_id(kb_ids[0]) - except Exception as e: - logging.warning(f"[meta_pushdown] cannot resolve tenant for kb {kb_ids[0]}: {e}") - return None - if not kb: - return None - - tenant_id = kb.tenant_id - index_name = cls._get_doc_meta_index_name(tenant_id) - - try: - if not settings.docStoreConn.index_exist(index_name, ""): - # No metadata index → no metadata-filtered docs. 
Returning an - # empty list (rather than ``None``) so callers don't bounce - # back to the in-memory path and re-query MySQL for nothing. - return [] - except Exception as e: - logging.warning(f"[meta_pushdown] index_exist check failed for {index_name}: {e}") - return None - try: query_body = build_meta_filter_query(filters, logic, kb_ids) except UnsupportedMetaFilter as e: - logging.debug(f"[meta_pushdown] falling back to in-memory: {e.reason}") + logging.error(f"ES build query failed: {e.reason}, filters={filters}") return None - # Only the doc id is needed downstream; trimming ``_source`` keeps the - # response small when the metadata blob is large. request_body = { **query_body, "size": limit, @@ -871,12 +881,10 @@ class DocMetadataService: try: response = es_client.search(index=index_name, body=request_body) except Exception as e: - logging.warning(f"[meta_pushdown] ES query failed for {index_name}: {e}") + logging.error(f"ES metadata filter failed for {index_name}: {e}") return None doc_ids = extract_doc_ids(response if isinstance(response, dict) else dict(response)) - # Preserve order while removing duplicates so caller-side de-dupe stays - # cheap. 
seen: set[str] = set() unique: List[str] = [] for did in doc_ids: @@ -887,12 +895,52 @@ class DocMetadataService: if len(unique) >= limit: logging.warning( - f"[meta_pushdown] hit limit {limit} for KBs {kb_ids}; some matches may be missing" + f"ES metadata filter hit limit {limit} for KBs {kb_ids}" ) - logging.debug(f"[meta_pushdown] {len(unique)} matches for KBs {kb_ids}") + logging.debug(f"ES metadata filter returned {len(unique)} matches for KBs {kb_ids}") return unique + @classmethod + def _filter_doc_ids_by_metadata_infinity( + cls, + index_name: str, + kb_ids: List[str], + filters: List[Dict], + logic: str, + ) -> Optional[List[str]]: + """Infinity push-down path for metadata filtering.""" + from common.metadata_infinity_filter import ( + build_infinity_filter, + extract_doc_ids, + is_pushdown_supported, + ) + + if not is_pushdown_supported(filters): + return None + + try: + sql_filter = build_infinity_filter(filters, logic) + escaped_kb_ids = [k.replace("'", "''") for k in kb_ids] + kb_filter = "kb_id IN (" + ", ".join([f"'{k}'" for k in escaped_kb_ids]) + ")" + where_clause = f"{kb_filter} AND {sql_filter}" + logging.debug(f"Infinity metadata filter: {where_clause}") + + inf_conn = settings.docStoreConn.connPool.get_conn() + try: + db_instance = inf_conn.get_database(settings.docStoreConn.dbName) + table_instance = db_instance.get_table(index_name) + df, _ = table_instance.output(["id"]).filter(where_clause).to_df() + doc_ids = extract_doc_ids(df) + logging.debug( + f"Infinity metadata filter returned {len(doc_ids)} doc IDs for kb_ids={kb_ids}, logic={logic}") + return doc_ids + finally: + settings.docStoreConn.connPool.release_conn(inf_conn) + except Exception: + logging.warning("Metadata filter push-down failed; falling back to in-memory filter", exc_info=True) + return None + @classmethod def get_metadata_keys_by_kbs(cls, kb_ids: List[str]) -> List[str]: """ @@ -955,7 +1003,8 @@ class DocMetadataService: if doc_meta: meta_mapping[doc_id] = doc_meta - 
logging.debug(f"[get_metadata_for_documents] Found metadata for {len(meta_mapping)}/{len(doc_ids) if doc_ids else 'all'} documents") + logging.debug( + f"[get_metadata_for_documents] Found metadata for {len(meta_mapping)}/{len(doc_ids) if doc_ids else 'all'} documents") return meta_mapping except Exception as e: @@ -981,6 +1030,7 @@ class DocMetadataService: } } """ + def _is_time_string(value: str) -> bool: """Check if a string value is an ISO 8601 datetime (e.g., '2026-02-03T00:00:00').""" if not isinstance(value, str): @@ -1220,7 +1270,8 @@ class DocMetadataService: doc_ids_set = set(doc_ids) missing_doc_ids = doc_ids_set - found_doc_ids if missing_doc_ids and updates: - logging.debug(f"[batch_update_metadata] Inserting new metadata for documents without metadata rows: {missing_doc_ids}") + logging.debug( + f"[batch_update_metadata] Inserting new metadata for documents without metadata rows: {missing_doc_ids}") for doc_id in missing_doc_ids: # Apply updates to create new metadata meta = {} diff --git a/common/metadata_infinity_filter.py b/common/metadata_infinity_filter.py new file mode 100644 index 0000000000..076cc2e23e --- /dev/null +++ b/common/metadata_infinity_filter.py @@ -0,0 +1,296 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Translate RAGflow document-metadata filter lists into Infinity SQL filter expressions. 
+""" + +from __future__ import annotations + +import ast +import re +from typing import Any, Dict, List, Sequence + +_KEY_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + + +def _validate_key(key: str, flt: Dict[str, Any]) -> None: + if not _KEY_PATTERN.match(key): + raise ValueError(f"invalid key format (must be identifier-like): {flt}") + +SUPPORTED_OPERATORS: frozenset[str] = frozenset( + { + "=", + "≠", + ">", + "<", + "≥", + "≤", + "in", + "not in", + "contains", + "not contains", + "start with", + "end with", + "empty", + "not empty", + } +) + +_RANGE_OPS: Dict[str, str] = { + ">": ">", + "<": "<", + "≥": ">=", + "≤": "<=", +} + +class MetaFilterTranslator: + """Translate one user filter clause at a time into Infinity SQL filter strings.""" + + def translate(self, flt: Dict[str, Any]) -> str: + op = flt.get("op") + key = flt.get("key") + value = flt.get("value") + + if not key or not isinstance(key, str): + raise ValueError(f"filter is missing a string key: {flt}") + _validate_key(key, flt) + if op not in SUPPORTED_OPERATORS: + raise ValueError(f"unknown operator: {op!r}, filter: {flt}") + + if op == "empty": + return self._translate_empty(key) + if op == "not empty": + return self._translate_not_empty(key) + if op == "=": + return self._translate_equal(key, value, flt) + if op == "≠": + return self._translate_not_equal(key, value, flt) + if op in _RANGE_OPS: + return self._translate_range(key, op, value, flt) + if op == "in": + return self._translate_in(key, value, flt) + if op == "not in": + return self._translate_not_in(key, value, flt) + if op == "contains": + return self._translate_contains(key, value, flt) + if op == "not contains": + return self._translate_not_contains(key, value, flt) + if op == "start with": + return self._translate_start_with(key, value, flt) + if op == "end with": + return self._translate_end_with(key, value, flt) + + raise ValueError(f"no handler for operator: {op!r}, filter: {flt}") + + def _translate_empty(self, key: str) -> 
str: + return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') = '\"\"'" + + def _translate_not_empty(self, key: str) -> str: + return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') != '\"\"'" + + def _translate_equal(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + coerced = _coerce_scalar(value, flt) + if isinstance(coerced, str): + escaped = _escape_sql_string(coerced) + return f"JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')" + return f"JSON_CONTAINS(meta_fields, '$.{key}', {coerced})" + + def _translate_not_equal(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + coerced = _coerce_scalar(value, flt) + if isinstance(coerced, str): + escaped = _escape_sql_string(coerced) + return f"NOT JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')" + return f"NOT JSON_CONTAINS(meta_fields, '$.{key}', {coerced})" + + def _translate_range(self, key: str, op: str, value: Any, flt: Dict[str, Any]) -> str: + coerced = _coerce_range_value(value, flt) + sql_op = _RANGE_OPS.get(op, op) + if isinstance(coerced, str): + escaped = _escape_sql_string(coerced) + return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') {sql_op} '{escaped}'" + return f"JSON_EXTRACT_DOUBLE(meta_fields, '$.{key}') {sql_op} {coerced}" + + def _translate_in(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + members = _csv_or_list(value, flt) + string_parts = [] + num_parts = [] + for m in members: + # Use same coercion as range operators to detect numeric values + coerced = _coerce_range_value(m, flt) + if isinstance(coerced, (int, float)): + num_parts.append(f"JSON_CONTAINS(meta_fields, '$.{key}', {coerced})") + else: + escaped = _escape_sql_string(coerced) + string_parts.append(f"JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')") + conditions = [] + if string_parts: + conditions.append("(" + " OR ".join(string_parts) + ")") + if num_parts: + conditions.append("(" + " OR ".join(num_parts) + ")") + return "(" + " OR ".join(conditions) + ")" + + def _translate_not_in(self, 
key: str, value: Any, flt: Dict[str, Any]) -> str: + members = _csv_or_list(value, flt) + string_parts = [] + num_parts = [] + for m in members: + # Use same coercion as range operators to detect numeric values + coerced = _coerce_range_value(m, flt) + if isinstance(coerced, (int, float)): + num_parts.append(f"NOT JSON_CONTAINS(meta_fields, '$.{key}', {coerced})") + else: + escaped = _escape_sql_string(coerced) + string_parts.append(f"NOT JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')") + conditions = [] + if string_parts: + conditions.append("(" + " AND ".join(string_parts) + ")") + if num_parts: + conditions.append("(" + " AND ".join(num_parts) + ")") + return " AND ".join(conditions) + + def _translate_contains(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + if not value and value != 0: + raise ValueError(f"contains value is empty: {flt}") + # Use same coercion as range operators to detect numeric values + coerced = _coerce_range_value(value, flt) + if isinstance(coerced, (int, float)): + return f"JSON_CONTAINS(meta_fields, '$.{key}', {coerced})" + escaped = _escape_sql_string(str(value)) + return f"JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')" + + def _translate_not_contains(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + text = _coerce_string(value, flt) + escaped = _escape_sql_string(text) + # Use Infinity's JSON_CONTAINS to check if value does NOT exist in JSON array + return f"NOT JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')" + + def _translate_start_with(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + text = _coerce_string(value, flt) + escaped = _escape_sql_string(_escape_likeWildcards(text)) + return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') LIKE '{escaped}%'" + + def _translate_end_with(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + text = _coerce_string(value, flt) + escaped = _escape_sql_string(_escape_likeWildcards(text)) + return f"JSON_EXTRACT_STRING(meta_fields, 
'$.{key}') LIKE '%{escaped}'" + + +def plan_pushdown(filters: Sequence[Dict[str, Any]], logic: str) -> List[str]: + if logic not in {"and", "or"}: + raise ValueError(f"unknown logic {logic!r}") + translator = MetaFilterTranslator() + return [translator.translate(flt) for flt in filters] + + +def build_infinity_filter(filters: Sequence[Dict[str, Any]], logic: str) -> str: + if not filters: + return "1=1" + fragments = plan_pushdown(filters, logic) + joiner = " AND " if logic == "and" else " OR " + result = "(" + joiner.join(fragments) + ")" + return result + + +def is_pushdown_supported(filters: Sequence[Dict[str, Any]]) -> bool: + for flt in filters: + op = flt.get("op") + if op not in SUPPORTED_OPERATORS: + return False + if not isinstance(flt.get("key"), str) or not flt.get("key"): + return False + return True + + +def extract_doc_ids(df) -> List[str]: + if df is None or not hasattr(df, "iterrows"): + return [] + return [str(row["id"]) for _, row in df.iterrows() if "id" in row] + + +# --------------------------------------------------------------------------- +# Value coercion helpers +# --------------------------------------------------------------------------- + + +def _coerce_scalar(value: Any, flt: Dict[str, Any]) -> Any: + if value is None: + raise ValueError(f"scalar comparison value is None: {flt}") + if isinstance(value, (list, dict)): + raise ValueError(f"scalar comparison value is non-scalar: {flt}") + try: + parsed = ast.literal_eval(str(value).strip()) + if isinstance(parsed, (int, float, bool)): + return parsed + except Exception: + pass + return str(value) + + +def _coerce_range_value(value: Any, flt: Dict[str, Any]) -> Any: + if value is None: + raise ValueError(f"range comparison value is None: {flt}") + try: + parsed = ast.literal_eval(str(value).strip()) + if isinstance(parsed, (int, float)): + return parsed + except Exception: + pass + return str(value) + + +def _coerce_string(value: Any, flt: Dict[str, Any]) -> str: + if value is None: + 
raise ValueError(f"string-operator value is None: {flt}") + if isinstance(value, (list, dict)): + raise ValueError(f"string-operator value must be a scalar: {flt}") + s = str(value) + if not s: + raise ValueError(f"string-operator value is empty: {flt}") + return s + + +def _csv_or_list(value: Any, flt: Dict[str, Any]) -> List[Any]: + if value is None: + raise ValueError(f"membership value is None: {flt}") + if isinstance(value, (list, tuple)): + members = list(value) + elif isinstance(value, str): + try: + parsed = ast.literal_eval(value) + except Exception: + parsed = value + if isinstance(parsed, (list, tuple)): + members = list(parsed) + else: + members = [m.strip() for m in value.split(",") if m.strip()] + else: + members = [value] + if not members: + raise ValueError(f"membership value resolved to empty list: {flt}") + normalised: List[Any] = [] + for m in members: + if isinstance(m, str): + normalised.append(m.lower().strip()) + else: + normalised.append(m) + return normalised + + +def _escape_sql_string(s: str) -> str: + return s.replace("'", "''") + + +def _escape_likeWildcards(text: str) -> str: + return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") \ No newline at end of file diff --git a/common/metadata_utils.py b/common/metadata_utils.py index c2fc90b541..53af2b4eaf 100644 --- a/common/metadata_utils.py +++ b/common/metadata_utils.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Dict import json_repair + def convert_conditions(metadata_condition): if metadata_condition is None: metadata_condition = {} @@ -60,21 +61,21 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): # Strict date format detection: YYYY-MM-DD (must be 10 chars with correct format) is_input_date = ( - len(input_str) == 10 and - input_str[4] == '-' and - input_str[7] == '-' and - input_str[:4].isdigit() and - input_str[5:7].isdigit() and - input_str[8:10].isdigit() + len(input_str) == 10 and + input_str[4] == '-' and + input_str[7] == 
'-' and + input_str[:4].isdigit() and + input_str[5:7].isdigit() and + input_str[8:10].isdigit() ) is_value_date = ( - len(value_str) == 10 and - value_str[4] == '-' and - value_str[7] == '-' and - value_str[:4].isdigit() and - value_str[5:7].isdigit() and - value_str[8:10].isdigit() + len(value_str) == 10 and + value_str[4] == '-' and + value_str[7] == '-' and + value_str[:4].isdigit() and + value_str[5:7].isdigit() and + value_str[8:10].isdigit() ) if is_value_date: @@ -109,17 +110,23 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): matched = False try: if operator == "contains": - matched = str(input).find(value) >= 0 if not isinstance(input, list) else any(str(i).find(value) >= 0 for i in input) + matched = str(input).find(value) >= 0 if not isinstance(input, list) else any( + str(i).find(value) >= 0 for i in input) elif operator == "not contains": - matched = str(input).find(value) == -1 if not isinstance(input, list) else all(str(i).find(value) == -1 for i in input) + matched = str(input).find(value) == -1 if not isinstance(input, list) else all( + str(i).find(value) == -1 for i in input) elif operator == "in": matched = input in value if not isinstance(input, list) else all(i in value for i in input) elif operator == "not in": matched = input not in value if not isinstance(input, list) else all(i not in value for i in input) elif operator == "start with": - matched = str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower()) + matched = str(input).lower().startswith(str(value).lower()) if not isinstance(input, + list) else "".join( + [str(i).lower() for i in input]).startswith(str(value).lower()) elif operator == "end with": - matched = str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower()) + matched = 
str(input).lower().endswith(str(value).lower()) if not isinstance(input, + list) else "".join( + [str(i).lower() for i in input]).endswith(str(value).lower()) elif operator == "empty": matched = not input elif operator == "not empty": @@ -158,21 +165,23 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): if logic == "and": doc_ids = doc_ids & set(ids) if not doc_ids: + logging.debug(f"meta_filter filters={filters}, logic={logic}, early return []") return [] else: doc_ids = doc_ids | set(ids) + logging.debug(f"meta_filter filters={filters}, logic={logic}, returning doc_ids={list(doc_ids)}") return list(doc_ids) async def apply_meta_data_filter( - meta_data_filter: dict | None, - metas: dict | None = None, - question: str = "", - chat_mdl: Any = None, - base_doc_ids: list[str] | None = None, - manual_value_resolver: Callable[[dict], dict] | None = None, - kb_ids: list[str] | None = None, - metas_loader: Callable[[], dict] | None = None, + meta_data_filter: dict | None, + metas: dict | None = None, + question: str = "", + chat_mdl: Any = None, + base_doc_ids: list[str] | None = None, + manual_value_resolver: Callable[[dict], dict] | None = None, + kb_ids: list[str] | None = None, + metas_loader: Callable[[], dict] | None = None, ) -> list[str] | None: """ Apply metadata filtering rules and return the filtered doc_ids. @@ -182,12 +191,11 @@ async def apply_meta_data_filter( - semi_auto: generate conditions using selected metadata keys only - manual: directly filter based on provided conditions - When ``kb_ids`` is supplied and the active doc store is Elasticsearch the - generated filter conditions are pushed down to ES via - ``DocMetadataService.filter_doc_ids_by_meta_pushdown`` instead of being - evaluated in Python over ``metas``. The in-memory ``meta_filter`` path - remains the fallback so callers without a KB scope, or backends without - push-down support, behave exactly as before. 
+ When ``kb_ids`` is supplied, metadata filters are pushed down to the doc metadata + index (ES/Infinity) via ``DocMetadataService.filter_doc_ids_by_metadata`` instead + of being evaluated in Python over ``metas``. The in-memory ``meta_filter`` path + remains the fallback so callers without a KB scope, or backends without push-down + support, behave exactly as before. ``metas`` may be supplied eagerly or via ``metas_loader``. The loader is only invoked when the metadata dict is actually needed — i.e. for the LLM @@ -200,7 +208,7 @@ async def apply_meta_data_filter( list of doc_ids, ["-999"] when manual filters yield no result, or None when auto/semi_auto filters return empty. """ - from rag.prompts.generator import gen_meta_filter # move from the top of the file to avoid circular import + from rag.prompts.generator import gen_meta_filter # move from the top of the file to avoid circular import doc_ids = list(base_doc_ids) if base_doc_ids else [] @@ -220,17 +228,26 @@ async def apply_meta_data_filter( cached_metas = metas_loader() if metas_loader else {} return cached_metas - def _evaluate(conditions: list[dict], logic: str) -> list[str]: - """Run conditions through ES push-down when possible, in-memory otherwise.""" + def _run_metadata_filter(conditions: list[dict], logic: str) -> list[str]: + """Run conditions through ES/Infinity push-down when possible, in-memory otherwise.""" if conditions and kb_ids: - pushed = _try_meta_pushdown(kb_ids, conditions, logic) - if pushed is not None: - return pushed + try: + from api.db.services.doc_metadata_service import DocMetadataService + doc_ids = DocMetadataService.filter_doc_ids_by_meta_pushdown(kb_ids, conditions, logic) + logging.debug(f"Doc ids filtered by metadata: {doc_ids}") + if doc_ids is not None: + return doc_ids + except Exception as e: + logging.error(f"Metadata filter push down errored: {e}") + + # In-memory fallback + logging.debug("Metadata filter falls back to in-memory filter") return 
meta_filter(_get_metas(), conditions, logic) if method == "auto": filters: dict = await gen_meta_filter(chat_mdl, _get_metas(), question) - doc_ids.extend(_evaluate(filters["conditions"], filters.get("logic", "and"))) + logging.debug(f"Metadata filter(auto) generated: {filters}") + doc_ids.extend(_run_metadata_filter(filters["conditions"], filters.get("logic", "and"))) if not doc_ids: return None elif method == "semi_auto": @@ -251,24 +268,27 @@ async def apply_meta_data_filter( filtered_metas = {key: current_metas[key] for key in selected_keys if key in current_metas} if filtered_metas: filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question, constraints=constraints) - doc_ids.extend(_evaluate(filters["conditions"], filters.get("logic", "and"))) + logging.debug(f"Metadata filter(semi_auto) generated: {filters}") + doc_ids.extend(_run_metadata_filter(filters["conditions"], filters.get("logic", "and"))) if not doc_ids: return None elif method == "manual": filters = meta_data_filter.get("manual", []) if manual_value_resolver: filters = [manual_value_resolver(flt) for flt in filters] - doc_ids.extend(_evaluate(filters, meta_data_filter.get("logic", "and"))) + logging.debug(f"Metadata filter(manual): {filters}") + doc_ids.extend(_run_metadata_filter(filters, meta_data_filter.get("logic", "and"))) if filters and not doc_ids: doc_ids = ["-999"] + logging.debug(f"apply_meta_data_filter meta_filter={meta_data_filter}, returning doc_ids={doc_ids}") return doc_ids def _try_meta_pushdown( - kb_ids: list[str], - conditions: list[dict], - logic: str, + kb_ids: list[str], + conditions: list[dict], + logic: str, ) -> list[str] | None: """Attempt the ES push-down path; return ``None`` to fall back in-memory. 
@@ -335,7 +355,7 @@ def update_metadata_to(metadata, meta): return metadata -def metadata_schema(metadata: dict|list|None) -> Dict[str, Any]: +def metadata_schema(metadata: dict | list | None) -> Dict[str, Any]: if not metadata: return {} properties = {} diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index b55e7a4c91..fc4999dbe4 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -494,6 +494,28 @@ async def rank_memories_async(chat_mdl, goal: str, sub_goal: str, tool_call_summ async def gen_meta_filter(chat_mdl, meta_data: dict, query: str, constraints: dict = None) -> dict: + """Generate metadata filter conditions from a user query using an LLM. + + Args: + chat_mdl: LLM bundle for generating filters + meta_data: Dict of {key: set of values} - e.g. {"character": {"Caocao", "Liubei"}, "year": {2026}} + query: User question (e.g. "Caocao in 2026") + constraints: Optional dict of {key: operator} to constrain which op to use for a key + + Returns: + Dict with "logic" ("and"/"or") and "conditions" list. + Example return value: + { + "logic": "and", + "conditions": [ + {"key": "year", "value": "2026", "op": "="}, + {"key": "character", "value": "Caocao", "op": "="} + ] + } + + The LLM is prompted with the available metadata keys and values, and is asked to + generate filter conditions that match the user's query semantics. + """ meta_data_structure = {} for key, values in meta_data.items(): meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values diff --git a/test/unit_test/common/test_metadata_es_filter.py b/test/unit_test/common/test_metadata_es_filter.py deleted file mode 100644 index eb8217909e..0000000000 --- a/test/unit_test/common/test_metadata_es_filter.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Unit tests for the Elasticsearch push-down translator. - -These tests cover the public surface of ``common.metadata_es_filter`` without -touching the live ES cluster. 
They verify the shape of the produced query DSL -operator-by-operator and confirm that the parity rules with the in-memory -``meta_filter`` (lower-casing, list-membership coercion, date detection) hold. -""" - -import pytest - -from common.metadata_es_filter import ( - META_FIELDS_PREFIX, - MetaFilterPushdownPlan, - MetaFilterTranslator, - SUPPORTED_OPERATORS, - UnsupportedMetaFilter, - build_meta_filter_query, - extract_doc_ids, - is_pushdown_supported, - plan_pushdown, -) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def translator() -> MetaFilterTranslator: - return MetaFilterTranslator() - - -def _field(key: str) -> str: - return f"{META_FIELDS_PREFIX}.{key}" - - -# --------------------------------------------------------------------------- -# Translator: per-operator shape -# --------------------------------------------------------------------------- - - -def test_equal_translates_to_term_with_lowercased_value(translator): - """String equality runs against ``.keyword`` so multi-word phrases match. - - Querying the analyzed parent field with ``term`` only matches docs whose - inverted index contains the literal phrase token, which never happens for - multi-word values. The ``.keyword`` sub-field stores the unmodified string, - and ``case_insensitive: true`` keeps the lower-cased compare semantics from - the in-memory ``meta_filter``. 
- """ - clauses = translator.translate({"key": "tag", "op": "=", "value": "Alpha"}).to_clauses() - assert clauses == [ - {"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}} - ] - - -def test_equal_parses_numeric_literal(translator): - """Numeric values stay on the parent path — no ``.keyword`` sub-field exists for ``long``.""" - clauses = translator.translate({"key": "score", "op": "=", "value": "5"}).to_clauses() - assert clauses == [{"term": {_field("score"): 5}}] - - -def test_equal_multiword_uses_keyword_subfield(translator): - """Regression for qinling0210's report: multi-word string values must match. - - Before the keyword-routing fix this emitted - ``term: meta_fields.author = "alice wonderland"`` against an analyzed text - field, which never matched (inverted index only contained per-token - entries). Routing through ``.keyword`` preserves the full phrase. - """ - clauses = translator.translate( - {"key": "author", "op": "=", "value": "Alice Wonderland"} - ).to_clauses() - assert clauses == [ - { - "term": { - _field("author") + ".keyword": { - "value": "alice wonderland", - "case_insensitive": True, - } - } - } - ] - - -def test_not_equal_requires_field_to_exist(translator): - clauses = translator.translate({"key": "tag", "op": "≠", "value": "alpha"}).to_clauses() - assert clauses == [ - { - "bool": { - "must": [{"exists": {"field": _field("tag")}}], - "must_not": [ - {"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}} - ], - } - } - ] - - -@pytest.mark.parametrize( - "op,es_key", - [(">", "gt"), ("<", "lt"), ("≥", "gte"), ("≤", "lte")], -) -def test_range_operator_translation(translator, op, es_key): - # Multi-clause positive filters wrap into a single bool so OR-logic - # parents can't match on just the ``exists`` half of the range. 
- clauses = translator.translate({"key": "score", "op": op, "value": "10"}).to_clauses() - assert clauses == [ - { - "bool": { - "must": [ - {"exists": {"field": _field("score")}}, - {"range": {_field("score"): {es_key: 10}}}, - ] - } - } - ] - - -def test_range_passes_iso_date_through_unparsed(translator): - clauses = translator.translate({"key": "published", "op": "≥", "value": "2025-01-15"}).to_clauses() - range_clause = clauses[0]["bool"]["must"][1] - assert range_clause == {"range": {_field("published"): {"gte": "2025-01-15"}}} - - -def _string_terms_should(field_path: str, members): - """``in``/``not in`` over string members expands per-element so each ``term`` - can carry ``case_insensitive`` (``terms`` does not accept that flag).""" - return { - "bool": { - "should": [ - {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} - for m in members - ], - "minimum_should_match": 1, - } - } - - -def test_in_operator_csv_value_lowercased(translator): - clauses = translator.translate({"key": "status", "op": "in", "value": "Active,Pending"}).to_clauses() - assert clauses == [_string_terms_should(_field("status"), ["active", "pending"])] - - -def test_in_operator_python_list_literal(translator): - clauses = translator.translate({"key": "status", "op": "in", "value": "['Open', 'Closed']"}).to_clauses() - assert clauses == [_string_terms_should(_field("status"), ["open", "closed"])] - - -def test_in_operator_numeric_members_keep_terms(translator): - """All-numeric member lists keep the cheaper ``terms`` form on the parent path.""" - clauses = translator.translate({"key": "year", "op": "in", "value": "[2024, 2025]"}).to_clauses() - assert clauses == [{"terms": {_field("year"): [2024, 2025]}}] - - -def test_not_in_negates_with_existence_guard(translator): - clauses = translator.translate({"key": "status", "op": "not in", "value": "active,pending"}).to_clauses() - assert clauses == [ - { - "bool": { - "must": [{"exists": {"field": _field("status")}}], 
- "must_not": [_string_terms_should(_field("status"), ["active", "pending"])], - } - } - ] - - -def test_contains_uses_case_insensitive_wildcard(translator): - clauses = translator.translate({"key": "version", "op": "contains", "value": "earth"}).to_clauses() - assert clauses == [ - { - "wildcard": { - _field("version") + ".keyword": { - "value": "*earth*", - "case_insensitive": True, - } - } - } - ] - - -def test_contains_escapes_user_wildcards(translator): - clauses = translator.translate({"key": "title", "op": "contains", "value": "a*b?c"}).to_clauses() - pattern = clauses[0]["wildcard"][_field("title") + ".keyword"]["value"] - assert pattern == "*a\\*b\\?c*" - - -def test_not_contains_negates_with_exists(translator): - clauses = translator.translate({"key": "version", "op": "not contains", "value": "earth"}).to_clauses() - assert clauses == [ - { - "bool": { - "must": [{"exists": {"field": _field("version")}}], - "must_not": [ - { - "wildcard": { - _field("version") + ".keyword": { - "value": "*earth*", - "case_insensitive": True, - } - } - } - ], - } - } - ] - - -def test_start_with_uses_prefix(translator): - clauses = translator.translate({"key": "name", "op": "start with", "value": "pre"}).to_clauses() - assert clauses == [ - {"prefix": {_field("name") + ".keyword": {"value": "pre", "case_insensitive": True}}} - ] - - -def test_end_with_uses_trailing_wildcard(translator): - clauses = translator.translate({"key": "file", "op": "end with", "value": ".pdf"}).to_clauses() - pattern = clauses[0]["wildcard"][_field("file") + ".keyword"]["value"] - assert pattern == "*.pdf" - - -def test_empty_matches_missing_or_blank(translator): - clauses = translator.translate({"key": "notes", "op": "empty", "value": ""}).to_clauses() - assert clauses == [ - { - "bool": { - "should": [ - {"bool": {"must_not": [{"exists": {"field": _field("notes")}}]}}, - {"term": {_field("notes") + ".keyword": ""}}, - ], - "minimum_should_match": 1, - } - } - ] - - -def 
test_not_empty_requires_exists_and_excludes_blank(translator): - clauses = translator.translate({"key": "notes", "op": "not empty", "value": ""}).to_clauses() - assert clauses == [ - { - "bool": { - "must": [{"exists": {"field": _field("notes")}}], - "must_not": [{"term": {_field("notes") + ".keyword": ""}}], - } - } - ] - - -# --------------------------------------------------------------------------- -# Translator: validation paths -# --------------------------------------------------------------------------- - - -def test_unknown_operator_raises(translator): - with pytest.raises(UnsupportedMetaFilter) as exc: - translator.translate({"key": "tag", "op": "regex", "value": "^foo"}) - assert "regex" in exc.value.reason - - -def test_missing_key_raises(translator): - with pytest.raises(UnsupportedMetaFilter): - translator.translate({"op": "=", "value": "x"}) - - -def test_scalar_op_with_list_value_raises(translator): - with pytest.raises(UnsupportedMetaFilter): - translator.translate({"key": "tag", "op": "=", "value": ["a", "b"]}) - - -def test_string_op_with_empty_value_raises(translator): - with pytest.raises(UnsupportedMetaFilter): - translator.translate({"key": "tag", "op": "contains", "value": ""}) - - -def test_membership_with_empty_csv_raises(translator): - with pytest.raises(UnsupportedMetaFilter): - translator.translate({"key": "tag", "op": "in", "value": ""}) - - -def test_supported_operator_set_matches_documentation(): - expected = { - "=", - "≠", - ">", - "<", - "≥", - "≤", - "in", - "not in", - "contains", - "not contains", - "start with", - "end with", - "empty", - "not empty", - } - assert SUPPORTED_OPERATORS == expected - - -# --------------------------------------------------------------------------- -# Plan composition -# --------------------------------------------------------------------------- - - -def test_plan_emits_must_clauses_for_and_logic(): - plan = plan_pushdown( - [ - {"key": "tag", "op": "=", "value": "alpha"}, - {"key": "score", "op": 
">", "value": "5"}, - ], - logic="and", - ) - assert isinstance(plan, MetaFilterPushdownPlan) - body = plan.to_query(["kb1"]) - bool_root = body["query"]["bool"] - assert bool_root["filter"][0] == {"terms": {"kb_id": ["kb1"]}} - inner = bool_root["filter"][1]["bool"] - assert "must" in inner - # Each translated filter contributes exactly one clause to the parent bool: - # ``=`` is a single ``term``; ``>`` is wrapped into one atomic ``bool``. - assert len(inner["must"]) == 2 - expected_tag_term = { - "term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}} - } - assert expected_tag_term in inner["must"] - range_wrap = { - "bool": { - "must": [ - {"exists": {"field": _field("score")}}, - {"range": {_field("score"): {"gt": 5}}}, - ] - } - } - assert range_wrap in inner["must"] - - -def test_range_filter_under_or_stays_atomic(): - """An OR'd range must not split into independent ``exists`` + ``range`` should branches.""" - body = build_meta_filter_query( - [ - {"key": "tag", "op": "=", "value": "alpha"}, - {"key": "score", "op": ">", "value": "5"}, - ], - logic="or", - kb_ids=["kb1"], - ) - should = body["query"]["bool"]["filter"][1]["bool"]["should"] - # Two filters → two should branches, not three or four. 
- assert len(should) == 2 - assert { - "term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}} - } in should - - -def test_plan_emits_should_clauses_for_or_logic(): - plan = plan_pushdown( - [ - {"key": "tag", "op": "=", "value": "alpha"}, - {"key": "tag", "op": "=", "value": "beta"}, - ], - logic="or", - ) - inner = plan.to_query(["kb1"])["query"]["bool"]["filter"][1]["bool"] - assert inner["minimum_should_match"] == 1 - assert len(inner["should"]) == 2 - - -def test_unknown_logic_rejected(): - with pytest.raises(UnsupportedMetaFilter): - plan_pushdown([{"key": "k", "op": "=", "value": "v"}], logic="xor") - - -def test_empty_filter_list_returns_kb_only_query(): - body = build_meta_filter_query([], "and", ["kb1", "kb2"]) - assert body == {"query": {"bool": {"filter": [{"terms": {"kb_id": ["kb1", "kb2"]}}]}}} - - -def test_negative_filter_in_or_logic_keeps_negation_scope(): - """Wrapping ``≠`` in an OR should not let the ``must_not`` swallow other branches. - - ``≠`` is rejected by :func:`is_pushdown_supported` for multi-value safety, so - this test exercises the translator directly to confirm the per-filter - wrapping invariant. The same shape protects ``not contains`` (which IS - pushed down) from leaking its ``must_not`` into a parent should. - """ - body = build_meta_filter_query( - [ - {"key": "tag", "op": "=", "value": "alpha"}, - {"key": "tag", "op": "≠", "value": "beta"}, - ], - logic="or", - kb_ids=["kb1"], - ) - inner = body["query"]["bool"]["filter"][1]["bool"] - should = inner["should"] - assert should[0] == { - "term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}} - } - # The ≠ branch is wrapped so its must_not does not bleed into the OR set. 
- assert "bool" in should[1] - assert "must_not" in should[1]["bool"] - - -# --------------------------------------------------------------------------- -# is_pushdown_supported pre-check -# --------------------------------------------------------------------------- - - -def test_pushdown_check_accepts_known_ops(): - assert is_pushdown_supported( - [ - {"key": "tag", "op": "=", "value": "v"}, - {"key": "tag", "op": "contains", "value": "x"}, - ] - ) - - -def test_pushdown_check_rejects_unknown_op(): - assert not is_pushdown_supported([{"key": "tag", "op": "regex", "value": "^v"}]) - - -def test_pushdown_check_rejects_missing_key(): - assert not is_pushdown_supported([{"op": "=", "value": "v"}]) - - -@pytest.mark.parametrize("op", ["≠", "not in"]) -def test_pushdown_check_rejects_multivalue_unsafe_negatives(op): - """Negatives that diverge on multi-valued fields force the in-memory fallback.""" - assert not is_pushdown_supported([{"key": "tag", "op": op, "value": "x"}]) - - -def test_pushdown_check_one_unsafe_op_rejects_whole_request(): - """Mixing one unsafe op with safe ones still falls back, preserving correctness.""" - assert not is_pushdown_supported( - [ - {"key": "tag", "op": "=", "value": "v"}, - {"key": "tag", "op": "≠", "value": "w"}, - ] - ) - - -def test_pushdown_check_accepts_not_contains(): - """``not contains`` stays in push-down; ``all(not contains)`` ≡ ``not any(contains)``.""" - assert is_pushdown_supported([{"key": "tag", "op": "not contains", "value": "x"}]) - - -# --------------------------------------------------------------------------- -# extract_doc_ids -# --------------------------------------------------------------------------- - - -def test_extract_doc_ids_from_dict_response(): - response = { - "hits": { - "hits": [ - {"_id": "doc1", "_source": {"id": "doc1"}}, - {"_id": "doc2", "_source": {"id": "doc2"}}, - ] - } - } - assert extract_doc_ids(response) == ["doc1", "doc2"] - - -def test_extract_doc_ids_falls_back_to_source_id(): - 
response = {"hits": {"hits": [{"_source": {"id": "src-id"}}]}} - assert extract_doc_ids(response) == ["src-id"] - - -def test_extract_doc_ids_empty_response(): - assert extract_doc_ids({}) == [] - assert extract_doc_ids({"hits": {}}) == [] - assert extract_doc_ids({"hits": {"hits": []}}) == [] diff --git a/test/unit_test/common/test_metadata_filter.py b/test/unit_test/common/test_metadata_filter.py new file mode 100644 index 0000000000..d48b30fb6c --- /dev/null +++ b/test/unit_test/common/test_metadata_filter.py @@ -0,0 +1,659 @@ +"""Unit tests for the metadata filter push-down translators (ES and Infinity). + +Verifies the shape of the produced filter expressions for both ES DSL and +Infinity SQL, and confirms that coercion rules (lower-casing, list-membership, +date detection) are consistent between the two backends. +""" + +import pytest + +pytestmark = pytest.mark.p2 + +from common.metadata_es_filter import MetaFilterTranslator as ESMetaFilterTranslator +from common.metadata_infinity_filter import ( + MetaFilterTranslator as InfinityMetaFilterTranslator, + SUPPORTED_OPERATORS, + build_infinity_filter, + is_pushdown_supported, + plan_pushdown, + extract_doc_ids, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def es_translator() -> ESMetaFilterTranslator: + return ESMetaFilterTranslator() + + +@pytest.fixture +def infinity_translator() -> InfinityMetaFilterTranslator: + return InfinityMetaFilterTranslator() + + +# --------------------------------------------------------------------------- +# Shared: is_pushdown_supported pre-check (same logic for both backends) +# --------------------------------------------------------------------------- + + +def test_pushdown_check_accepts_known_ops(): + assert is_pushdown_supported( + [ + {"key": "tag", "op": "=", "value": "v"}, + {"key": "tag", "op": "contains", "value": "x"}, 
+ ] + ) + + +def test_pushdown_check_rejects_unknown_op(): + assert not is_pushdown_supported([{"key": "tag", "op": "regex", "value": "^v"}]) + + +def test_pushdown_check_rejects_missing_key(): + assert not is_pushdown_supported([{"op": "=", "value": "v"}]) + + +def test_pushdown_check_accepts_not_contains(): + assert is_pushdown_supported([{"key": "tag", "op": "not contains", "value": "x"}]) + + +# --------------------------------------------------------------------------- +# Shared: plan_pushdown (same logic for both backends) +# --------------------------------------------------------------------------- + + +def test_plan_pushdown_and_logic(): + fragments = plan_pushdown( + [ + {"key": "tag", "op": "=", "value": "alpha"}, + {"key": "score", "op": ">", "value": "5"}, + ], + logic="and", + ) + assert len(fragments) == 2 + + +def test_plan_pushdown_or_logic(): + fragments = plan_pushdown( + [ + {"key": "tag", "op": "=", "value": "alpha"}, + {"key": "tag", "op": "=", "value": "beta"}, + ], + logic="or", + ) + assert len(fragments) == 2 + + +def test_unknown_logic_rejected(): + with pytest.raises(ValueError): + plan_pushdown([{"key": "k", "op": "=", "value": "v"}], logic="xor") + + +# --------------------------------------------------------------------------- +# Shared: extract_doc_ids (same implementation) +# --------------------------------------------------------------------------- + + +def test_extract_doc_ids_from_dataframe(): + import pandas as pd + + df = pd.DataFrame({"id": ["doc1", "doc2", "doc3"]}) + assert extract_doc_ids(df) == ["doc1", "doc2", "doc3"] + + +def test_extract_doc_ids_empty_dataframe(): + import pandas as pd + + df = pd.DataFrame({"id": []}) + assert extract_doc_ids(df) == [] + + +def test_extract_doc_ids_none_input(): + assert extract_doc_ids(None) == [] + + +def test_extract_doc_ids_non_dataframe(): + assert extract_doc_ids("not a dataframe") == [] + + +# --------------------------------------------------------------------------- +# 
Shared: SUPPORTED_OPERATORS +# --------------------------------------------------------------------------- + + +def test_supported_operator_set_matches_documentation(): + expected = { + "=", + "≠", + ">", + "<", + "≥", + "≤", + "in", + "not in", + "contains", + "not contains", + "start with", + "end with", + "empty", + "not empty", + } + assert SUPPORTED_OPERATORS == expected + + +# =========================================================================== +# ES-only tests +# =========================================================================== + + +def test_equal_translates_to_term_with_lowercased_value(es_translator): + """String equality runs against ``.keyword`` so multi-word phrases match.""" + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "tag", "op": "=", "value": "Alpha"}).to_clauses() + assert clauses == [ + {"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}} + ] + + +def test_equal_parses_numeric_literal(es_translator): + """Numeric values stay on the parent path — no ``.keyword`` sub-field exists for ``long``.""" + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "score", "op": "=", "value": "5"}).to_clauses() + assert clauses == [{"term": {_field("score"): 5}}] + + +def test_equal_multiword_uses_keyword_subfield(es_translator): + """Regression: multi-word string values must match via .keyword sub-field.""" + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate( + {"key": "author", "op": "=", "value": "Alice Wonderland"} + ).to_clauses() + assert clauses == [ + { + "term": { + _field("author") + ".keyword": { + "value": "alice wonderland", + 
"case_insensitive": True, + } + } + } + ] + + +def test_not_equal_requires_field_to_exist(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "tag", "op": "≠", "value": "alpha"}).to_clauses() + assert clauses == [ + { + "bool": { + "must": [{"exists": {"field": _field("tag")}}], + "must_not": [ + {"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}} + ], + } + } + ] + + +@pytest.mark.parametrize( + "op,es_key", + [(">", "gt"), ("<", "lt"), ("≥", "gte"), ("≤", "lte")], +) +def test_range_operator_translation(es_translator, op, es_key): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "score", "op": op, "value": "10"}).to_clauses() + assert clauses == [ + { + "bool": { + "must": [ + {"exists": {"field": _field("score")}}, + {"range": {_field("score"): {es_key: 10}}}, + ] + } + } + ] + + +def test_range_passes_iso_date_through_unparsed(es_translator): + clauses = es_translator.translate({"key": "published", "op": "≥", "value": "2025-01-15"}).to_clauses() + range_clause = clauses[0]["bool"]["must"][1] + assert range_clause == {"range": {"meta_fields.published": {"gte": "2025-01-15"}}} + + +def test_in_operator_csv_value_lowercased(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + def _string_terms_should(field_path: str, members): + return { + "bool": { + "should": [ + {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} + for m in members + ], + "minimum_should_match": 1, + } + } + + clauses = es_translator.translate({"key": "status", "op": "in", "value": "Active,Pending"}).to_clauses() + assert clauses == [_string_terms_should(_field("status"), 
["active", "pending"])] + + +def test_in_operator_python_list_literal(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + def _string_terms_should(field_path: str, members): + return { + "bool": { + "should": [ + {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} + for m in members + ], + "minimum_should_match": 1, + } + } + + clauses = es_translator.translate({"key": "status", "op": "in", "value": "['Open', 'Closed']"}).to_clauses() + assert clauses == [_string_terms_should(_field("status"), ["open", "closed"])] + + +def test_in_operator_numeric_members_keep_terms(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "year", "op": "in", "value": "[2024, 2025]"}).to_clauses() + assert clauses == [{"terms": {_field("year"): [2024, 2025]}}] + + +def test_not_in_negates_with_existence_guard(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + def _string_terms_should(field_path: str, members): + return { + "bool": { + "should": [ + {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} + for m in members + ], + "minimum_should_match": 1, + } + } + + clauses = es_translator.translate({"key": "status", "op": "not in", "value": "active,pending"}).to_clauses() + assert clauses == [ + { + "bool": { + "must": [{"exists": {"field": _field("status")}}], + "must_not": [_string_terms_should(_field("status"), ["active", "pending"])], + } + } + ] + + +def test_contains_uses_case_insensitive_wildcard(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "version", "op": 
"contains", "value": "earth"}).to_clauses() + assert clauses == [ + { + "wildcard": { + _field("version") + ".keyword": { + "value": "*earth*", + "case_insensitive": True, + } + } + } + ] + + +def test_contains_escapes_user_wildcards(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "title", "op": "contains", "value": "a*b?c"}).to_clauses() + pattern = clauses[0]["wildcard"][_field("title") + ".keyword"]["value"] + assert pattern == "*a\\*b\\?c*" + + +def test_not_contains_negates_with_exists(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "version", "op": "not contains", "value": "earth"}).to_clauses() + assert clauses == [ + { + "bool": { + "must": [{"exists": {"field": _field("version")}}], + "must_not": [ + { + "wildcard": { + _field("version") + ".keyword": { + "value": "*earth*", + "case_insensitive": True, + } + } + } + ], + } + } + ] + + +def test_start_with_uses_prefix(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "name", "op": "start with", "value": "pre"}).to_clauses() + assert clauses == [ + {"prefix": {_field("name") + ".keyword": {"value": "pre", "case_insensitive": True}}} + ] + + +def test_end_with_uses_trailing_wildcard(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "file", "op": "end with", "value": ".pdf"}).to_clauses() + pattern = clauses[0]["wildcard"][_field("file") + ".keyword"]["value"] + assert pattern == "*.pdf" + + +def 
test_empty_matches_missing_or_blank(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "notes", "op": "empty", "value": ""}).to_clauses() + assert clauses == [ + { + "bool": { + "should": [ + {"bool": {"must_not": [{"exists": {"field": _field("notes")}}]}}, + {"term": {_field("notes") + ".keyword": ""}}, + ], + "minimum_should_match": 1, + } + } + ] + + +def test_not_empty_requires_exists_and_excludes_blank(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "notes", "op": "not empty", "value": ""}).to_clauses() + assert clauses == [ + { + "bool": { + "must": [{"exists": {"field": _field("notes")}}], + "must_not": [{"term": {_field("notes") + ".keyword": ""}}], + } + } + ] + + +def test_unknown_operator_raises(es_translator): + from common.metadata_es_filter import UnsupportedMetaFilter + + with pytest.raises(UnsupportedMetaFilter) as exc: + es_translator.translate({"key": "tag", "op": "regex", "value": "^foo"}) + assert "regex" in exc.value.reason + + +def test_missing_key_raises(es_translator): + from common.metadata_es_filter import UnsupportedMetaFilter + + with pytest.raises(UnsupportedMetaFilter): + es_translator.translate({"op": "=", "value": "x"}) + + +def test_scalar_op_with_list_value_raises(es_translator): + from common.metadata_es_filter import UnsupportedMetaFilter + + with pytest.raises(UnsupportedMetaFilter): + es_translator.translate({"key": "tag", "op": "=", "value": ["a", "b"]}) + + +def test_string_op_with_empty_value_raises(es_translator): + from common.metadata_es_filter import UnsupportedMetaFilter + + with pytest.raises(UnsupportedMetaFilter): + es_translator.translate({"key": "tag", "op": "contains", "value": ""}) + + +def 
test_membership_with_empty_csv_raises(es_translator): + from common.metadata_es_filter import UnsupportedMetaFilter + + with pytest.raises(UnsupportedMetaFilter): + es_translator.translate({"key": "tag", "op": "in", "value": ""}) + + +# =========================================================================== +# Infinity-only tests +# =========================================================================== + + +def test_build_infinity_filter_and_logic(): + body = build_infinity_filter( + [ + {"key": "tag", "op": "=", "value": "alpha"}, + {"key": "score", "op": ">", "value": "5"}, + ], + logic="and", + ) + assert " AND " in body + assert "alpha" in body + + +def test_build_infinity_filter_or_logic(): + body = build_infinity_filter( + [ + {"key": "tag", "op": "=", "value": "alpha"}, + {"key": "tag", "op": "=", "value": "beta"}, + ], + logic="or", + ) + assert " OR " in body + assert "alpha" in body + assert "beta" in body + + +def test_empty_filter_list_returns_1eq1(): + body = build_infinity_filter([], "and") + assert body == "1=1" + + +def test_infinity_equal_string_uses_lowercase(infinity_translator): + cond = infinity_translator.translate({"key": "tag", "op": "=", "value": "Alpha"}) + assert cond == "JSON_CONTAINS(meta_fields, '$.tag', '\"Alpha\"')" + + +def test_infinity_equal_numeric_keeps_number(infinity_translator): + cond = infinity_translator.translate({"key": "score", "op": "=", "value": "5"}) + assert cond == "JSON_CONTAINS(meta_fields, '$.score', 5)" + + +def test_infinity_equal_date_passes_unparsed(infinity_translator): + cond = infinity_translator.translate({"key": "published", "op": "=", "value": "2025-01-15"}) + assert cond == "JSON_CONTAINS(meta_fields, '$.published', '\"2025-01-15\"')" + + +def test_infinity_not_equal_string(infinity_translator): + cond = infinity_translator.translate({"key": "tag", "op": "≠", "value": "alpha"}) + assert "JSON_CONTAINS" in cond + assert "alpha" in cond + assert "NOT" in cond + + +def 
test_infinity_not_equal_numeric(infinity_translator): + cond = infinity_translator.translate({"key": "score", "op": "≠", "value": "5"}) + assert "JSON_CONTAINS" in cond and "NOT" in cond and "5" in cond + + +@pytest.mark.parametrize("op,sql_op", [(">", ">"), ("<", "<"), ("≥", ">="), ("≤", "<=")]) +def test_infinity_range_operators(infinity_translator, op, sql_op): + cond = infinity_translator.translate({"key": "score", "op": op, "value": "10"}) + assert sql_op in cond + assert "JSON_EXTRACT_DOUBLE(meta_fields, '$.score')" in cond + + +def test_infinity_range_string_value(infinity_translator): + cond = infinity_translator.translate({"key": "published", "op": "≥", "value": "2025-01-15"}) + assert ">=" in cond + assert "2025-01-15" in cond + + +def test_infinity_in_csv_lowercased(infinity_translator): + cond = infinity_translator.translate({"key": "status", "op": "in", "value": "Active,Pending"}) + assert "JSON_CONTAINS" in cond + assert "active" in cond + assert "pending" in cond + + +def test_infinity_in_python_list(infinity_translator): + cond = infinity_translator.translate({"key": "status", "op": "in", "value": "['Open', 'Closed']"}) + assert "JSON_CONTAINS" in cond + assert "open" in cond + assert "closed" in cond + + +def test_infinity_in_numeric_members(infinity_translator): + cond = infinity_translator.translate({"key": "year", "op": "in", "value": "[2024, 2025]"}) + assert "JSON_CONTAINS" in cond + assert "2024" in cond + assert "2025" in cond + + +def test_infinity_not_in_csv(infinity_translator): + cond = infinity_translator.translate({"key": "status", "op": "not in", "value": "active,pending"}) + assert "NOT JSON_CONTAINS" in cond + + +def test_infinity_contains_uses_JSON_CONTAINS(infinity_translator): + """Infinity 'contains' uses JSON_CONTAINS for JSON array membership.""" + cond = infinity_translator.translate({"key": "version", "op": "contains", "value": "earth"}) + assert "JSON_CONTAINS" in cond + assert "earth" in cond + + +def 
test_infinity_contains_escapes_quotes(infinity_translator): + """A contains value holding '%' and '_' appears unchanged in the JSON_CONTAINS condition (no quote characters are exercised).""" + cond = infinity_translator.translate({"key": "title", "op": "contains", "value": "a%b_c"}) + assert "JSON_CONTAINS" in cond + assert "a%b_c" in cond + + +def test_infinity_not_contains_uses_JSON_CONTAINS(infinity_translator): + """Infinity 'not contains' uses JSON_CONTAINS with NOT.""" + cond = infinity_translator.translate({"key": "version", "op": "not contains", "value": "earth"}) + assert "JSON_CONTAINS" in cond + assert "NOT" in cond or "not" in cond.lower() + + +def test_infinity_start_with(infinity_translator): + cond = infinity_translator.translate({"key": "name", "op": "start with", "value": "pre"}) + assert "LIKE" in cond + assert "'pre%" in cond + + +def test_infinity_end_with(infinity_translator): + """Infinity 'end with' uses LIKE with a leading wildcard.""" + cond = infinity_translator.translate({"key": "file", "op": "end with", "value": ".pdf"}) + assert "LIKE" in cond + assert "%.pdf" in cond + + +def test_infinity_empty(infinity_translator): + cond = infinity_translator.translate({"key": "notes", "op": "empty", "value": ""}) + assert "JSON_EXTRACT_STRING" in cond + assert '""' in cond + + +def test_infinity_not_empty(infinity_translator): + cond = infinity_translator.translate({"key": "notes", "op": "not empty", "value": ""}) + assert "JSON_EXTRACT_STRING" in cond + assert "!=" in cond + + +def test_infinity_unknown_operator_raises(infinity_translator): + with pytest.raises(ValueError) as exc: + infinity_translator.translate({"key": "tag", "op": "regex", "value": "^foo"}) + assert "regex" in str(exc.value) + + +def test_infinity_missing_key_raises(infinity_translator): + with pytest.raises(ValueError): + infinity_translator.translate({"op": "=", "value": "x"}) + + +def test_infinity_invalid_key_format_raises(infinity_translator): + with pytest.raises(ValueError, match="invalid key format"): + 
infinity_translator.translate({"key": "a;b", "op": "=", "value": "x"}) + + +def test_infinity_key_with_brace_raises(infinity_translator): + with pytest.raises(ValueError, match="invalid key format"): + infinity_translator.translate({"key": "field$}", "op": "=", "value": "x"}) + + +def test_infinity_scalar_op_with_list_value_raises(infinity_translator): + with pytest.raises(ValueError): + infinity_translator.translate({"key": "tag", "op": "=", "value": ["a", "b"]}) + + +def test_infinity_string_op_with_empty_value_raises(infinity_translator): + with pytest.raises(ValueError): + infinity_translator.translate({"key": "tag", "op": "contains", "value": ""}) + + +def test_infinity_membership_with_empty_csv_raises(infinity_translator): + with pytest.raises(ValueError): + infinity_translator.translate({"key": "tag", "op": "in", "value": ""}) \ No newline at end of file