2026-02-03 16:46:17 +08:00
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from typing import Optional
import numpy as np
from pydantic import BaseModel
from pymysql . converters import escape_string
from sqlalchemy import Column , String , Integer
from sqlalchemy . dialects . mysql import LONGTEXT
from common . decorator import singleton
2026-03-05 23:51:22 -05:00
from memory . utils . aggregation_utils import aggregate_by_field
2026-03-06 07:17:11 -05:00
from memory . utils . highlight_utils import get_highlight_from_messages
2026-02-03 16:46:17 +08:00
from common . doc_store . doc_store_base import MatchExpr , OrderByExpr , FusionExpr , MatchTextExpr , MatchDenseExpr
from common . doc_store . ob_conn_base import OBConnectionBase , get_value_str , vector_search_template
from common . float_utils import get_float
2026-03-06 07:17:11 -05:00
from rag . nlp import is_english
2026-02-03 16:46:17 +08:00
from rag . nlp . rag_tokenizer import tokenize , fine_grained_tokenize
# Schema of the memory message table: one SQLAlchemy Column per stored attribute.
COLUMN_DEFINITIONS: list[Column] = [
    # Identity
    Column("id", String(256), primary_key=True, comment="unique record id"),
    Column("message_id", String(256), nullable=False, index=True, comment="message id"),
    Column("message_type_kwd", String(64), nullable=True, comment="message type"),
    Column("source_id", String(256), nullable=True, comment="source message id"),
    # Ownership / scoping
    Column("memory_id", String(256), nullable=False, index=True, comment="memory id"),
    Column("user_id", String(256), nullable=True, comment="user id"),
    Column("agent_id", String(256), nullable=True, comment="agent id"),
    Column("session_id", String(256), nullable=True, comment="session id"),
    Column("zone_id", Integer, nullable=True, server_default="0", comment="zone id"),
    # Lifecycle timestamps (stored as strings) and status flag
    Column("valid_at", String(64), nullable=True, comment="valid at timestamp string"),
    Column("invalid_at", String(64), nullable=True, comment="invalid at timestamp string"),
    Column("forget_at", String(64), nullable=True, comment="forget at timestamp string"),
    Column("status_int", Integer, nullable=False, server_default="1", comment="status: 1 for active, 0 for inactive"),
    # Text payload
    Column("content_ltks", LONGTEXT, nullable=True, comment="content with tokenization"),
    Column("tokenized_content_ltks", LONGTEXT, nullable=True, comment="fine-grained tokenized content"),
]

# Column names in declaration order, derived from the schema above.
COLUMN_NAMES: list[str] = [column.name for column in COLUMN_DEFINITIONS]
# Names of the columns for which indexes are created.
INDEX_COLUMNS: list[str] = ["message_id", "memory_id", "status_int"]
# Names of the columns used for full-text search.
FTS_COLUMNS: list[str] = ["content_ltks", "tokenized_content_ltks"]
class SearchResult(BaseModel):
    """Result container: a hit count plus the list of message rows (as dicts)."""

    total: int
    messages: list[dict]
@singleton
class OBConnection ( OBConnectionBase ) :
def __init__(self):
    """Initialise the base connection and record the full-text search columns."""
    super().__init__(logger_name="ragflow.memory_ob_conn")
    # Columns consulted by get_fulltext_columns() and by full-text matching in search().
    self._fulltext_search_columns = FTS_COLUMNS
"""
Template method implementations
"""
def get_index_columns(self) -> list[str]:
    """Return the module-level list of columns to index (``INDEX_COLUMNS``)."""
    return INDEX_COLUMNS
def get_fulltext_columns(self) -> list[str]:
    """Return the full-text column names with any ``^weight`` suffix removed."""
    names: list[str] = []
    for spec in self._fulltext_search_columns:
        # Everything before the first "^" is the column name; the rest is a weight.
        name, _, _ = spec.partition("^")
        names.append(name)
    return names
def get_column_definitions(self) -> list[Column]:
    """Return the module-level table schema (``COLUMN_DEFINITIONS``)."""
    return COLUMN_DEFINITIONS
def get_lock_prefix(self) -> str:
    """Return the lock-name prefix used by this connection."""
    prefix = "ob_memory_"
    return prefix
def _get_dataset_id_field(self) -> str:
    """Return the name of the column that identifies the owning dataset (the memory)."""
    dataset_id_field = "memory_id"
    return dataset_id_field
def _get_vector_column_name_from_table(self, table_name: str) -> Optional[str]:
    """Return the name of the table's vector column, or None if it cannot be found.

    Queries INFORMATION_SCHEMA.COLUMNS for a column of ``table_name`` (in schema
    ``self.db_name``) whose name matches ``q_<digits>_vec``.

    Args:
        table_name: Table to inspect.

    Returns:
        The first matching column name, or None when there is no such column or
        the lookup fails. Failures are logged rather than raised.
    """
    sql = f"""
        SELECT COLUMN_NAME
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = '{self.db_name}'
        AND TABLE_NAME = '{table_name}'
        AND COLUMN_NAME REGEXP '^q_[0-9]+_vec$'
        LIMIT 1
    """
    try:
        res = self.client.perform_raw_text_sql(sql)
        row = res.fetchone()
        return row[0] if row else None
    except Exception as exc:
        # Best-effort lookup: keep returning None, but leave a trace of the failure.
        self.logger.warning("OBConnection failed to look up vector column for table %s: %s", table_name, exc)
        return None
"""
Field conversion methods
"""
@staticmethod
def convert_field_name ( field_name : str , use_tokenized_content = False ) - > str :
""" Convert message field name to database column name. """
match field_name :
case " message_type " :
return " message_type_kwd "
case " status " :
return " status_int "
case " content " :
if use_tokenized_content :
return " tokenized_content_ltks "
return " content_ltks "
case _ :
return field_name
@staticmethod
def map_message_to_ob_fields(message: dict) -> dict:
    """Map a message dictionary onto the column layout of the storage table.

    Args:
        message: Message to store. ``message_id``, ``message_type``,
            ``memory_id``, ``agent_id``, ``session_id``, ``valid_at``,
            ``status`` and ``content`` are required (KeyError otherwise); the
            remaining keys are optional.

    Returns:
        A dict keyed by column name. When ``content_embed`` is present and
        non-empty it is stored under ``q_<len>_vec``.
    """
    storage_doc = {
        "id": message.get("id"),
        "message_id": message["message_id"],
        "message_type_kwd": message["message_type"],
        "source_id": message.get("source_id"),
        "memory_id": message["memory_id"],
        "user_id": message.get("user_id", ""),
        "agent_id": message["agent_id"],
        "session_id": message["session_id"],
        "valid_at": message["valid_at"],
        "invalid_at": message.get("invalid_at"),
        "forget_at": message.get("forget_at"),
        "status_int": 1 if message["status"] else 0,
        "zone_id": message.get("zone_id", 0),
        "content_ltks": message["content"],
        "tokenized_content_ltks": fine_grained_tokenize(tokenize(message["content"])),
    }
    # An explicit ``content_embed: None`` means "no embedding", same as a missing key;
    # calling len() on it would raise TypeError.
    content_embed = message.get("content_embed")
    if content_embed is not None and len(content_embed) > 0:
        # The vector column name encodes the embedding dimension.
        storage_doc[f"q_{len(content_embed)}_vec"] = content_embed
    return storage_doc
@staticmethod
def get_message_from_ob_doc ( doc : dict ) - > dict :
""" Convert an OceanBase document back to a message dictionary. """
embd_field_name = next ( ( key for key in doc . keys ( ) if re . match ( r " q_ \ d+_vec " , key ) ) , None )
content_embed = doc . get ( embd_field_name , [ ] ) if embd_field_name else [ ]
if isinstance ( content_embed , np . ndarray ) :
content_embed = content_embed . tolist ( )
message = {
" message_id " : doc . get ( " message_id " ) ,
" message_type " : doc . get ( " message_type_kwd " ) ,
" source_id " : doc . get ( " source_id " ) if doc . get ( " source_id " ) else None ,
" memory_id " : doc . get ( " memory_id " ) ,
" user_id " : doc . get ( " user_id " , " " ) ,
" agent_id " : doc . get ( " agent_id " ) ,
" session_id " : doc . get ( " session_id " ) ,
" zone_id " : doc . get ( " zone_id " , 0 ) ,
" valid_at " : doc . get ( " valid_at " ) ,
" invalid_at " : doc . get ( " invalid_at " , " - " ) ,
" forget_at " : doc . get ( " forget_at " , " - " ) ,
" status " : bool ( int ( doc . get ( " status_int " , 0 ) ) ) ,
" content " : doc . get ( " content_ltks " , " " ) ,
" content_embed " : content_embed ,
}
if doc . get ( " id " ) :
message [ " id " ] = doc [ " id " ]
return message
"""
CRUD operations
"""
def search(
    self,
    select_fields: list[str],
    highlight_fields: list[str],
    condition: dict,
    match_expressions: list[MatchExpr],
    order_by: OrderByExpr,
    offset: int,
    limit: int,
    index_names: str | list[str],
    memory_ids: list[str],
    agg_fields: list[str] | None = None,
    rank_feature: dict | None = None,
    hide_forgotten: bool = True,
):
    """Search messages in memory storage.

    The search mode follows from ``match_expressions``: text + dense -> "fusion",
    text only -> "fulltext", dense only -> "vector", neither -> plain "filter".

    Args:
        select_fields: Message-level field names to return. ``id`` is always
            added; ``_score`` is handled internally; ``content_embed`` is resolved
            to the actual ``q_<n>_vec`` column of the first existing table.
        highlight_fields: Extra fields that must be fetched for highlighting.
        condition: Filter conditions keyed by message-level field name. Not
            modified; a copy is extended with ``memory_id`` (and ``must_not``).
        match_expressions: MatchTextExpr / MatchDenseExpr / FusionExpr items.
        order_by: Ordering, used by the "filter" mode only.
        offset: Row offset for LIMIT.
        limit: Maximum rows; capped by the match expressions' ``topn``. 0 means
            no LIMIT clause in "filter" mode.
        index_names: Table name(s); a string is split on ",".
        memory_ids: Memory ids the search is restricted to.
        agg_fields: Accepted for interface compatibility; not used here.
        rank_feature: Accepted for interface compatibility; not used here.
        hide_forgotten: When true, rows with ``forget_at`` set are excluded.

    Returns:
        Tuple ``(SearchResult, total)`` where ``total`` is the number of rows collected.

    NOTE: ``offset``/``limit`` are applied to each table's query separately, so
    several tables can together yield more than ``limit`` rows.
    """
    if isinstance(index_names, str):
        index_names = index_names.split(",")
    assert isinstance(index_names, list) and len(index_names) > 0

    result: SearchResult = SearchResult(total=0, messages=[])

    # ---- Output columns -------------------------------------------------
    output_fields = select_fields.copy()
    if "id" not in output_fields:
        output_fields = ["id"] + output_fields
    if "_score" in output_fields:
        output_fields.remove("_score")

    # content_embed is a logical name: swap it for the real vector column.
    if "content_embed" in output_fields:
        output_fields = [f for f in output_fields if f != "content_embed"]
        for idx_name in index_names:
            if self._check_table_exists_cached(idx_name):
                actual_vector_column: Optional[str] = self._get_vector_column_name_from_table(idx_name)
                if actual_vector_column:
                    output_fields.append(actual_vector_column)
                    break

    if highlight_fields:
        for field in highlight_fields:
            field_name = self.convert_field_name(field)
            if field_name not in output_fields:
                output_fields.append(field_name)

    # De-duplicate while keeping order: a highlight field such as "content" may
    # already be selected under its message-level name.
    db_output_fields = list(dict.fromkeys(self.convert_field_name(f) for f in output_fields))
    scored_fields = db_output_fields + ["_score"]
    fields_expr = ", ".join(db_output_fields)

    # ---- Filters --------------------------------------------------------
    # Work on a copy so the caller's dict is not mutated.
    condition = dict(condition or {})
    condition["memory_id"] = memory_ids
    if hide_forgotten:
        condition["must_not"] = {"exists": "forget_at"}
    condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
    filters: list[str] = self._get_filters(condition_dict)
    filters_expr = " AND ".join(filters) if filters else "1=1"

    # ---- Parse match expressions ---------------------------------------
    fulltext_query: Optional[str] = None
    fulltext_topn: Optional[int] = None
    fulltext_search_expr: dict[str, str] = {}
    fulltext_search_weight: dict[str, float] = {}
    fulltext_search_filter: Optional[str] = None
    fulltext_search_score_expr: Optional[str] = None

    vector_column_name: Optional[str] = None
    vector_data: Optional[list[float]] = None
    vector_topn: Optional[int] = None
    vector_similarity_threshold: Optional[float] = None
    vector_similarity_weight: Optional[float] = None
    vector_search_expr: Optional[str] = None
    vector_search_score_expr: Optional[str] = None
    vector_search_filter: Optional[str] = None

    for m in match_expressions:
        if isinstance(m, MatchTextExpr):
            assert "original_query" in m.extra_options, "'original_query' is missing in extra_options."
            fulltext_query = m.extra_options["original_query"]
            fulltext_query = escape_string(fulltext_query.strip())
            fulltext_topn = m.topn
            fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns(fulltext_query, self._fulltext_search_columns)
        elif isinstance(m, MatchDenseExpr):
            vector_column_name = m.vector_column_name
            vector_data = m.embedding_data
            vector_topn = m.topn
            vector_similarity_threshold = m.extra_options.get("similarity", 0.0) if m.extra_options else 0.0
        elif isinstance(m, FusionExpr):
            weights = m.fusion_params.get("weights", "0.5,0.5") if m.fusion_params else "0.5,0.5"
            vector_similarity_weight = get_float(weights.split(",")[1])

    has_fulltext = bool(fulltext_query)
    # Explicit length check: bare truthiness is ambiguous for numpy arrays.
    has_vector = vector_data is not None and len(vector_data) > 0

    if has_fulltext:
        fulltext_search_filter = f"({' OR '.join(fulltext_search_expr.values())})"
        fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"

    if has_vector:
        vector_data_str = "[" + ",".join(str(np.float32(v)) for v in vector_data) + "]"
        vector_search_expr = vector_search_template % (vector_column_name, vector_data_str)
        vector_search_score_expr = f"(1 - {vector_search_expr})"
        vector_search_filter = f"{vector_search_score_expr} >= {vector_similarity_threshold}"

    # ---- Determine search type -----------------------------------------
    if has_fulltext:
        search_type = "fusion" if has_vector else "fulltext"
    else:
        search_type = "vector" if has_vector else "filter"

    if limit:
        if vector_topn is not None:
            limit = min(vector_topn, limit)
        if fulltext_topn is not None:
            limit = min(fulltext_topn, limit)

    # ---- Table-independent SQL fragments -------------------------------
    num_candidates = 0
    score_expr = ""
    order_by_expr = ""
    limit_expr = ""
    if search_type == "fusion":
        if vector_similarity_weight is None:
            # No FusionExpr supplied: use the same 0.5/0.5 split FusionExpr defaults to.
            vector_similarity_weight = 0.5
        num_candidates = (vector_topn or limit) + (fulltext_topn or limit)
        score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight})"
    elif search_type == "filter":
        orders: list[str] = []
        if order_by and order_by.fields:
            for field, order_dir in order_by.fields:
                order_str = "ASC" if order_dir == 0 else "DESC"
                orders.append(f"{self.convert_field_name(field)} {order_str}")
        order_by_expr = ("ORDER BY " + ", ".join(orders)) if orders else ""
        limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else ""

    # ---- Run one query per existing table ------------------------------
    for table_name in index_names:
        if not self._check_table_exists_cached(table_name):
            continue

        row_fields = scored_fields
        if search_type == "fusion":
            # Full-text pre-selects candidates; the vector score filters and re-ranks them.
            sql = (
                f"WITH fulltext_results AS ("
                f" SELECT *, {fulltext_search_score_expr} AS relevance"
                f" FROM {table_name}"
                f" WHERE {filters_expr} AND {fulltext_search_filter}"
                f" ORDER BY relevance DESC"
                f" LIMIT {num_candidates}"
                f")"
                f" SELECT {fields_expr}, {score_expr} AS _score"
                f" FROM fulltext_results"
                f" WHERE {vector_search_filter}"
                f" ORDER BY _score DESC"
                f" LIMIT {offset}, {limit}"
            )
        elif search_type == "vector":
            sql = self._build_vector_search_sql(table_name, fields_expr, vector_search_score_expr, filters_expr, vector_search_filter, vector_search_expr, limit, vector_topn, offset)
        elif search_type == "fulltext":
            sql = self._build_fulltext_search_sql(table_name, fields_expr, fulltext_search_score_expr, filters_expr, fulltext_search_filter, offset, limit, fulltext_topn)
        else:
            sql = self._build_filter_search_sql(table_name, fields_expr, filters_expr, order_by_expr, limit_expr)
            # Plain filtering produces no _score column.
            row_fields = db_output_fields

        self.logger.debug("OBConnection.search with %s sql: %s", search_type, sql)
        rows, elapsed_time = self._execute_search_sql(sql)
        self.logger.info(
            "OBConnection.search table %s, search type: %s, elapsed time: %.3fs, rows: %d",
            table_name, search_type, elapsed_time, len(rows),
        )
        for row in rows:
            result.messages.append(self._row_to_entity(row, row_fields))
            result.total += 1

    return result, result.total
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int = 512):
    """Fetch messages of one memory whose ``forget_at`` is set, oldest ``forget_at`` first.

    Args:
        select_fields: Message-level field names to return.
        index_name: Table to query.
        memory_id: Memory whose messages are inspected.
        limit: Maximum number of rows.

    Returns:
        A SearchResult, or None when the table does not exist.
    """
    if not self._check_table_exists_cached(index_name):
        return None

    columns = [self.convert_field_name(name) for name in select_fields]
    sql = (
        f"SELECT {', '.join(columns)} FROM {index_name} "
        f"WHERE memory_id = {get_value_str(memory_id)} AND forget_at IS NOT NULL "
        f"ORDER BY forget_at ASC LIMIT {limit}"
    )
    self.logger.debug("OBConnection.get_forgotten_messages sql: %s", sql)

    rows = self.client.perform_raw_text_sql(sql).fetchall()
    found = SearchResult(total=len(rows), messages=[])
    found.messages.extend(self._row_to_entity(row, columns) for row in rows)
    return found
2026-07-03 12:53:39 +08:00
def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int = 512):
    """Fetch messages of one memory where ``field_name`` is NULL, oldest ``valid_at`` first.

    Args:
        select_fields: Message-level field names to return.
        index_name: Table to query.
        memory_id: Memory whose messages are inspected.
        field_name: Message-level name of the field that must be missing.
        limit: Maximum number of rows.

    Returns:
        A SearchResult, or None when the table does not exist.
    """
    if not self._check_table_exists_cached(index_name):
        return None

    missing_column = self.convert_field_name(field_name)
    columns = [self.convert_field_name(name) for name in select_fields]
    sql = (
        f"SELECT {', '.join(columns)} FROM {index_name} "
        f"WHERE memory_id = {get_value_str(memory_id)} AND {missing_column} IS NULL "
        f"ORDER BY valid_at ASC LIMIT {limit}"
    )
    self.logger.debug("OBConnection.get_missing_field_message sql: %s", sql)

    rows = self.client.perform_raw_text_sql(sql).fetchall()
    found = SearchResult(total=len(rows), messages=[])
    found.messages.extend(self._row_to_entity(row, columns) for row in rows)
    return found
def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
    """Fetch one message by id, or None when the base lookup finds nothing."""
    stored = super().get(doc_id, index_name, memory_ids)
    # Translate column names back to message-level field names.
    return None if stored is None else self.get_message_from_ob_doc(stored)
def insert ( self , documents : list [ dict ] , index_name : str , memory_id : str = None ) - > list [ str ] :
""" Insert messages into memory storage. """
if not documents :
return [ ]
vector_size = len ( documents [ 0 ] . get ( " content_embed " , [ ] ) ) if " content_embed " in documents [ 0 ] else 0
if not self . _check_table_exists_cached ( index_name ) :
if vector_size == 0 :
raise ValueError ( " Cannot infer vector size from documents " )
self . create_idx ( index_name , memory_id , vector_size )
elif vector_size > 0 :
# Table exists but may not have the required vector column
self . _ensure_vector_column_exists ( index_name , vector_size )
docs : list [ dict ] = [ ]
ids : list [ str ] = [ ]
for document in documents :
d = self . map_message_to_ob_fields ( document )
ids . append ( d [ " id " ] )
for column_name in COLUMN_NAMES :
if column_name not in d :
d [ column_name ] = None
docs . append ( d )
self . logger . debug ( " OBConnection.insert messages: %s " , ids )
res = [ ]
try :
self . client . upsert ( index_name , docs )
except Exception as e :
self . logger . error ( f " OBConnection.insert error: { str ( e ) } " )
res . append ( str ( e ) )
return res
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
    """Update the messages of one memory that match ``condition``.

    Args:
        condition: Filter keyed by message-level field name. Not modified; a
            copy is extended with ``memory_id``.
        new_value: Fields to set, keyed by message-level field name. The special
            key ``remove`` with a string value sets that field to NULL. ``id`` is
            never updated. Changing ``content`` also refreshes the fine-grained
            tokenized column.
        index_name: Table to update.
        memory_id: Memory the update is restricted to.

    Returns:
        True on success, and also when the table does not exist or there is
        nothing to set; False when the UPDATE statement fails.
    """
    if not self._check_table_exists_cached(index_name):
        return True

    # Work on a copy so the caller's dict is not mutated.
    scoped_condition = dict(condition)
    scoped_condition["memory_id"] = memory_id
    condition_dict = {self.convert_field_name(k): v for k, v in scoped_condition.items()}
    filters = self._get_filters(condition_dict)

    update_dict = {self.convert_field_name(k): v for k, v in new_value.items()}
    if "content_ltks" in update_dict:
        update_dict["tokenized_content_ltks"] = fine_grained_tokenize(tokenize(update_dict["content_ltks"]))
    update_dict.pop("id", None)

    set_values: list[str] = []
    for k, v in update_dict.items():
        if k == "remove":
            if isinstance(v, str):
                # The target is a message-level name as well; map it to its column.
                set_values.append(f"{self.convert_field_name(v)} = NULL")
        elif k == "status_int":
            # Keys are already converted here, so "status" arrives as "status_int".
            set_values.append(f"status_int = {1 if v else 0}")
        else:
            set_values.append(f"{k} = {get_value_str(v)}")

    if not set_values:
        return True

    update_sql = f"UPDATE {index_name} SET {', '.join(set_values)} WHERE {' AND '.join(filters)}"
    self.logger.debug("OBConnection.update sql: %s", update_sql)
    try:
        self.client.perform_raw_text_sql(update_sql)
        return True
    except Exception as e:
        self.logger.error("OBConnection.update error: %s", e)
        return False
def delete(self, condition: dict, index_name: str, memory_id: str) -> int:
    """Delete messages matching ``condition``; returns whatever the base ``delete`` returns."""
    db_condition: dict = {}
    for field, value in condition.items():
        # The base class expects database column names.
        db_condition[self.convert_field_name(field)] = value
    return super().delete(db_condition, index_name, memory_id)
"""
Helper functions for search result
"""
def get_total(self, res) -> int:
    """Return the hit count of a search result.

    Accepts the ``(result, total)`` tuple returned by ``search`` or any object
    with a ``total`` attribute; anything else counts as 0.
    """
    if isinstance(res, tuple):
        return res[1]
    return getattr(res, "total", 0)
def get_doc_ids(self, res) -> list[str]:
    """Return the truthy ``id`` values of the rows in a search result.

    Accepts the ``(result, total)`` tuple returned by ``search`` or an object
    with a ``messages`` attribute; anything else yields an empty list.
    """
    hits = res[0] if isinstance(res, tuple) else res
    if not hasattr(hits, "messages"):
        return []
    doc_ids = []
    for row in hits.messages:
        row_id = row.get("id")
        if row_id:
            doc_ids.append(row_id)
    return doc_ids
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
    """Extract the requested message fields from a search result, keyed by row id.

    Args:
        res: ``(result, total)`` tuple or an object with a ``messages`` attribute.
        fields: Message-level field names to keep. Empty -> empty result.

    Returns:
        ``{row_id: {field: value}}``. Lists and strings are kept as they are;
        numeric/bool values are kept for the id/timestamp/status fields listed
        below; every other value is stringified, with None becoming "". Rows
        without an id, or with none of the requested fields, are skipped.
    """
    if isinstance(res, tuple):
        res = res[0]
    if not fields:
        return {}

    keep_numeric = ("message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status")
    collected: dict[str, dict] = {}
    for doc in getattr(res, "messages", []):
        message = self.get_message_from_ob_doc(doc)
        picked = {}
        for name, value in message.items():
            if name not in fields:
                continue
            if isinstance(value, (list, str)):
                picked[name] = value
            elif name in keep_numeric and isinstance(value, (int, float, bool)):
                picked[name] = value
            elif value is None:
                picked[name] = ""
            else:
                picked[name] = str(value)
        doc_id = doc.get("id") or message.get("id")
        if picked and doc_id:
            collected[doc_id] = picked
    return collected
def get_highlight(self, res, keywords: list[str], field_name: str):
    """Build highlighted text for a search result.

    Unwraps a ``(result, total)`` tuple, takes the result's ``messages`` (None
    when absent) and delegates to ``get_highlight_from_messages``.
    """
    hits = res[0] if isinstance(res, tuple) else res
    return get_highlight_from_messages(
        getattr(hits, "messages", None),
        keywords,
        field_name,
        # is_english expects a list of texts, so wrap the single string.
        is_english_fn=lambda s: is_english([s]),
    )
2026-02-03 16:46:17 +08:00
def get_aggregation(self, res, field_name: str):
    """Aggregate a search result by ``field_name``.

    Unwraps a ``(result, total)`` tuple, takes the result's ``messages`` (None
    when absent) and delegates to ``aggregate_by_field``.
    """
    res_obj = res[0] if isinstance(res, tuple) else res
    return aggregate_by_field(getattr(res_obj, "messages", None), field_name)