diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 19f215bc85..30f6f047de 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -540,15 +540,18 @@ class Dealer: res = [] bs = 128 for p in range(offset, max_count, bs): - es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id), + limit = min(bs, max_count - p) + if limit <= 0: + break + es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, limit, index_name(tenant_id), kb_ids) dict_chunks = self.dataStore.get_fields(es_res, fields) for id, doc in dict_chunks.items(): doc["id"] = id if dict_chunks: res.extend(dict_chunks.values()) - # FIX: Solo terminar si no hay chunks, no si hay menos de bs - if len(dict_chunks.values()) == 0: + chunk_count = len(dict_chunks) + if chunk_count == 0 or chunk_count < limit: break return res diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 8c1e506b4d..6f88c9a44e 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -28,6 +28,8 @@ from common.float_utils import get_float from common.constants import PAGERANK_FLD, TAG_FLD ATTEMPT_TIME = 2 +MAX_RESULT_WINDOW = 10000 +SEARCH_AFTER_BATCH_SIZE = 1000 @singleton @@ -36,6 +38,81 @@ class ESConnection(ESConnectionBase): CRUD operations """ + def _es_search_once(self, index_names: list[str], query: dict, track_total_hits: bool): + return self.es.search( + index=index_names, + body=query, + timeout="600s", + track_total_hits=track_total_hits, + _source=True, + ) + + def _search_with_search_after(self, index_names: list[str], query: dict, offset: int, limit: int): + q_base = copy.deepcopy(query) + q_base.pop("from", None) + q_base.pop("size", None) + + search_after = None + template_res = None + collected_hits = [] + remaining_skip = max(0, offset) + remaining_take = max(0, limit) + with_aggs = True + + while remaining_skip > 0: + batch = min(SEARCH_AFTER_BATCH_SIZE, remaining_skip) + q_iter = copy.deepcopy(q_base) + q_iter["size"] = batch + if search_after is not None: + q_iter["search_after"] = search_after + if not with_aggs: + q_iter.pop("aggs", None) + res = self._es_search_once(index_names, q_iter, track_total_hits=template_res is None) + if template_res is None: + template_res = res + hits = res.get("hits", {}).get("hits", []) + if not hits: + break + next_search_after = hits[-1].get("sort") + if not next_search_after or next_search_after == search_after: + break + search_after = next_search_after + remaining_skip -= len(hits) + with_aggs = False + if len(hits) < batch: + break + + while remaining_skip <= 0 and remaining_take > 0: + batch = min(SEARCH_AFTER_BATCH_SIZE, remaining_take) + q_iter = copy.deepcopy(q_base) + q_iter["size"] = batch + if search_after is not None: + q_iter["search_after"] = search_after + if not with_aggs: + q_iter.pop("aggs", None) + res = self._es_search_once(index_names, q_iter, track_total_hits=template_res is None) + if template_res is None: + template_res = res + hits = res.get("hits", {}).get("hits", []) + if not hits: + break + collected_hits.extend(hits) + remaining_take -= len(hits) + next_search_after = hits[-1].get("sort") + if not next_search_after or next_search_after == search_after: + break + search_after = next_search_after + with_aggs = False + if len(hits) < batch: + break + + if template_res is None: + q_count = copy.deepcopy(q_base) + q_count["size"] = 0 + template_res = self._es_search_once(index_names, q_count, track_total_hits=True) + template_res["hits"]["hits"] = collected_hits + return template_res + def search( self, select_fields: list[str], highlight_fields: list[str], @@ -139,20 +216,27 @@ class ESConnection(ESConnectionBase): for fld in agg_fields: s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) - if limit > 0: + has_dense = any(isinstance(m, MatchDenseExpr) for m in match_expressions) + has_explicit_sort = bool(order_by and order_by.fields) + use_search_after = ( + limit > 0 + and (offset + limit > MAX_RESULT_WINDOW) + and has_explicit_sort + and not has_dense + ) + + if limit > 0 and not use_search_after: s = s[offset:offset + limit] q = s.to_dict() self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q)) for i in range(ATTEMPT_TIME): try: - # print(json.dumps(q, ensure_ascii=False)) - res = self.es.search(index=index_names, - body=q, - timeout="600s", - # search_type="dfs_query_then_fetch", - track_total_hits=True, - _source=True) + if use_search_after: + res = self._search_with_search_after(index_names, q, offset, limit) + else: + # print(json.dumps(q, ensure_ascii=False)) + res = self._es_search_once(index_names, q, track_total_hits=True) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))