diff --git a/rag/utils/ob_conn.py b/rag/utils/ob_conn.py index 10e0334005..22fbc9c7b1 100644 --- a/rag/utils/ob_conn.py +++ b/rag/utils/ob_conn.py @@ -260,6 +260,9 @@ def get_metadata_filter_expression(metadata_filtering_conditions: dict) -> str: return f"({f' {logical_operator} '.join(metadata_filters)})" +_VALID_FILTER_COLUMNS: set[str] = set(column_names) | set(doc_meta_column_names) + + def get_filters(condition: dict) -> list[str]: filters: list[str] = [] for k, v in condition.items(): @@ -267,9 +270,12 @@ def get_filters(condition: dict) -> list[str]: continue if k == "exists": - filters.append(f"{v} IS NOT NULL") + if isinstance(v, str) and v in _VALID_FILTER_COLUMNS: + filters.append(f"{v} IS NOT NULL") elif k == "must_not" and isinstance(v, dict) and "exists" in v: - filters.append(f"{v.get('exists')} IS NULL") + col = v.get("exists") + if isinstance(col, str) and col in _VALID_FILTER_COLUMNS: + filters.append(f"{col} IS NULL") elif k == "metadata_filtering_conditions": # Handle metadata filtering conditions metadata_filter = get_metadata_filter_expression(v) @@ -284,14 +290,15 @@ def get_filters(condition: dict) -> list[str]: filters.append(f"({array_filter})") else: filters.append(f"array_contains({k}, {get_value_str(v)})") - elif isinstance(v, list): - values: list[str] = [] - for item in v: - values.append(get_value_str(item)) - value = ", ".join(values) - filters.append(f"{k} IN ({value})") - else: - filters.append(f"{k} = {get_value_str(v)}") + elif k in _VALID_FILTER_COLUMNS: + if isinstance(v, list): + values: list[str] = [] + for item in v: + values.append(get_value_str(item)) + value = ", ".join(values) + filters.append(f"{k} IN ({value})") + else: + filters.append(f"{k} = {get_value_str(v)}") return filters @@ -530,7 +537,8 @@ class OBConnection(OBConnectionBase): ): if isinstance(index_names, str): index_names = index_names.split(",") - assert isinstance(index_names, list) and len(index_names) > 0 + if not (isinstance(index_names, list) and len(index_names) > 0): + raise ValueError("index_names must be a non-empty list") index_names = list(set(index_names)) if len(match_expressions) == 3: @@ -579,10 +587,10 @@ class OBConnection(OBConnectionBase): vector_similarity_weight = 0.5 for m in match_expressions: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: - assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance( - match_expressions[1], - MatchDenseExpr) and isinstance( - match_expressions[2], FusionExpr) + if not (len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance( + match_expressions[1], MatchDenseExpr) and isinstance( + match_expressions[2], FusionExpr)): + raise ValueError("match_expressions must contain MatchTextExpr, MatchDenseExpr, and FusionExpr") weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) for m in match_expressions: @@ -597,7 +605,8 @@ class OBConnection(OBConnectionBase): bqry.boost = 1.0 - vector_similarity_weight elif isinstance(m, MatchDenseExpr): - assert (bqry is not None) + if bqry is None: + raise ValueError("bqry must not be None") similarity = 0.0 if "similarity" in m.extra_options: similarity = m.extra_options["similarity"] @@ -701,7 +710,8 @@ class OBConnection(OBConnectionBase): for m in match_expressions: if isinstance(m, MatchTextExpr): - assert "original_query" in m.extra_options, "'original_query' is missing in extra_options." + if "original_query" not in m.extra_options: + raise ValueError("'original_query' is missing in extra_options.") fulltext_query = m.extra_options["original_query"] fulltext_query = escape_string(fulltext_query.strip()) fulltext_topn = m.topn @@ -713,11 +723,12 @@ class OBConnection(OBConnectionBase): fulltext_search_idx_list.append(fulltext_index_name_template % column_name) elif isinstance(m, MatchDenseExpr): - assert m.embedding_data_type == "float", f"embedding data type '{m.embedding_data_type}' is not float." + if m.embedding_data_type != "float": + raise ValueError(f"embedding data type '{m.embedding_data_type}' is not float.") vector_column_name = m.vector_column_name vector_data = m.embedding_data vector_topn = m.topn - vector_similarity_threshold = m.extra_options.get("similarity", 0.0) + vector_similarity_threshold = float(m.extra_options.get("similarity", 0.0)) elif isinstance(m, FusionExpr): weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) @@ -945,7 +956,8 @@ class OBConnection(OBConnectionBase): result.chunks.append(self._row_to_entity(row, output_fields)) elif search_type == "aggregation": # aggregation search - assert len(agg_fields) == 1, "Only one aggregation field is supported in OceanBase." + if len(agg_fields) != 1: + raise ValueError("Only one aggregation field is supported in OceanBase.") agg_field = agg_fields[0] if agg_field in array_columns: res = self.client.perform_raw_text_sql( @@ -1174,17 +1186,22 @@ class OBConnection(OBConnectionBase): if isinstance(v, str): set_values.append(f"{v} = NULL") else: - assert isinstance(v, dict), f"Expected str or dict for 'remove', got {type(new_value[k])}." + if not isinstance(v, dict): + raise ValueError(f"Expected str or dict for 'remove', got {type(new_value[k])}.") for kk, vv in v.items(): - assert kk in array_columns, f"Column '{kk}' is not an array column." + if kk not in array_columns: + raise ValueError(f"Column '{kk}' is not an array column.") set_values.append(f"{kk} = array_remove({kk}, {get_value_str(vv)})") elif k == "add": - assert isinstance(v, dict), f"Expected str or dict for 'add', got {type(new_value[k])}." + if not isinstance(v, dict): + raise ValueError(f"Expected str or dict for 'add', got {type(new_value[k])}.") for kk, vv in v.items(): - assert kk in array_columns, f"Column '{kk}' is not an array column." + if kk not in array_columns: + raise ValueError(f"Column '{kk}' is not an array column.") set_values.append(f"{kk} = array_append({kk}, {get_value_str(vv)})") elif k == "metadata": - assert isinstance(v, dict), f"Expected dict for 'metadata', got {type(new_value[k])}" + if not isinstance(v, dict): + raise ValueError(f"Expected dict for 'metadata', got {type(new_value[k])}") set_values.append(f"{k} = {get_value_str(v)}") if v and "doc_id" in condition: group_id = v.get("_group_id")