diff --git a/memory/utils/aggregation_utils.py b/memory/utils/aggregation_utils.py new file mode 100644 index 0000000000..6de63f1ba1 --- /dev/null +++ b/memory/utils/aggregation_utils.py @@ -0,0 +1,56 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use it except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Pure aggregation helpers for search results (no heavy dependencies).""" + + +def aggregate_by_field(messages: list | None, field_name: str) -> list[tuple[str, int]]: + """Aggregate message documents by a field; returns [(value, count), ...]. + + Handles pre-aggregated rows (dicts with "value" and "count") and + per-doc field values (str or list of str). + """ + if not messages: + return [] + + counts: dict[str, int] = {} + result: list[tuple[str, int]] = [] + + for doc in messages: + if "value" in doc and "count" in doc: + result.append((doc["value"], doc["count"])) + continue + + if field_name not in doc: + continue + + v = doc[field_name] + if isinstance(v, list): + for vv in v: + if isinstance(vv, str): + key = vv.strip() + if key: + counts[key] = counts.get(key, 0) + 1 + elif isinstance(v, str): + key = v.strip() + if key: + counts[key] = counts.get(key, 0) + 1 + + if counts: + for k, v in counts.items(): + result.append((k, v)) + + return result diff --git a/memory/utils/ob_conn.py b/memory/utils/ob_conn.py index bf8ac40050..09c976e2ca 100644 --- a/memory/utils/ob_conn.py +++ b/memory/utils/ob_conn.py @@ -24,6 +24,7 @@ from sqlalchemy import Column, String, Integer from sqlalchemy.dialects.mysql import LONGTEXT from common.decorator import singleton +from memory.utils.aggregation_utils import aggregate_by_field from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr from common.doc_store.ob_conn_base import OBConnectionBase, get_value_str, vector_search_template from common.float_utils import get_float @@ -609,5 +610,10 @@ class OBConnection(OBConnectionBase): def get_aggregation(self, res, field_name: str): """Get aggregation for search results.""" - # TODO: Implement aggregation functionality for OceanBase memory - return [] + if isinstance(res, tuple): + res_obj = res[0] + else: + res_obj = res + + messages = getattr(res_obj, "messages", None) + return aggregate_by_field(messages, field_name) diff --git a/test/unit_test/memory/utils/test_ob_conn_aggregation.py b/test/unit_test/memory/utils/test_ob_conn_aggregation.py new file mode 100644 index 0000000000..cf136eb208 --- /dev/null +++ b/test/unit_test/memory/utils/test_ob_conn_aggregation.py @@ -0,0 +1,55 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use it except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for OceanBase memory aggregation. + +Tests the pure aggregation logic used by OBConnection.get_aggregation, +without requiring a real OceanBase instance or heavy dependencies. +""" + +from memory.utils.aggregation_utils import aggregate_by_field + + +class TestAggregateByField: + """Tests for aggregate_by_field (used by get_aggregation).""" + + def test_empty_messages_returns_empty_list(self): + assert aggregate_by_field([], "message_type_kwd") == [] + assert aggregate_by_field(None, "message_type_kwd") == [] + + def test_aggregates_field_values(self): + messages = [ + {"id": "m1", "message_type_kwd": "user", "content_ltks": "a", "message_id": "msg1", "memory_id": "mem1", "status_int": 1}, + {"id": "m2", "message_type_kwd": "assistant", "content_ltks": "b", "message_id": "msg2", "memory_id": "mem1", "status_int": 1}, + {"id": "m3", "message_type_kwd": "user", "content_ltks": "c", "message_id": "msg3", "memory_id": "mem1", "status_int": 1}, + ] + out = aggregate_by_field(messages, "message_type_kwd") + assert set(out) == {("user", 2), ("assistant", 1)} + + def test_single_doc_result(self): + messages = [ + {"id": "m1", "message_type_kwd": "user", "content_ltks": "x", "message_id": "msg1", "memory_id": "mem1", "status_int": 1} + ] + out = aggregate_by_field(messages, "message_type_kwd") + assert out == [("user", 1)] + + def test_pre_aggregated_value_count_rows(self): + messages = [ + {"value": "user", "count": 2}, + {"value": "assistant", "count": 1}, + ] + out = aggregate_by_field(messages, "message_type_kwd") + assert set(out) == {("user", 2), ("assistant", 1)}