# # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import re import json import time import copy from elasticsearch_dsl import UpdateByQuery, Q, Search from elastic_transport import ConnectionTimeout from common.decorator import singleton from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr from common.doc_store.es_conn_base import ESConnectionBase from common.float_utils import get_float from common.constants import PAGERANK_FLD, TAG_FLD ATTEMPT_TIME = 2 MAX_RESULT_WINDOW = 10000 SEARCH_AFTER_BATCH_SIZE = 1000 # Single-document atomic pagerank_fea adjust (chunk feedback). Clamps using params.min_w / max_w; # removes field at zero for rank_feature compatibility. _PAGERANK_FEA_ADJUST_SCRIPT = """ double cur = 0.0; if (ctx._source.containsKey(params.pf)) { Object v = ctx._source[params.pf]; if (v != null) { if (v instanceof Number) { cur = ((Number)v).doubleValue(); } else { try { cur = Double.parseDouble(v.toString()); } catch (Exception e) { cur = 0.0; } } } } double nw = cur + params.delta; if (nw < params.min_w) { nw = params.min_w; } if (nw > params.max_w) { nw = params.max_w; } if (nw <= 0.0) { if (ctx._source.containsKey(params.pf)) { ctx._source.remove(params.pf); } } else { ctx._source[params.pf] = nw; } """ @singleton class ESConnection(ESConnectionBase): """ CRUD operations """ def _es_search_once(self, index_names: list[str], query: dict, track_total_hits: bool): return self.es.search( index=index_names, body=query, timeout="600s", track_total_hits=track_total_hits ) def _search_with_search_after(self, index_names: list[str], query: dict, offset: int, limit: int): q_base = copy.deepcopy(query) q_base.pop("from", None) q_base.pop("size", None) search_after = None template_res = None collected_hits = [] remaining_skip = max(0, offset) remaining_take = max(0, limit) with_aggs = True while remaining_skip > 0: batch = min(SEARCH_AFTER_BATCH_SIZE, remaining_skip) q_iter = copy.deepcopy(q_base) q_iter["size"] = batch if search_after is not None: q_iter["search_after"] = search_after if not with_aggs: q_iter.pop("aggs", None) res = self._es_search_once(index_names, q_iter, track_total_hits=template_res is None) if template_res is None: template_res = res hits = res.get("hits", {}).get("hits", []) if not hits: break next_search_after = hits[-1].get("sort") if not next_search_after or next_search_after == search_after: break search_after = next_search_after remaining_skip -= len(hits) with_aggs = False if len(hits) < batch: break while remaining_skip <= 0 and remaining_take > 0: batch = min(SEARCH_AFTER_BATCH_SIZE, remaining_take) q_iter = copy.deepcopy(q_base) q_iter["size"] = batch if search_after is not None: q_iter["search_after"] = search_after if not with_aggs: q_iter.pop("aggs", None) res = self._es_search_once(index_names, q_iter, track_total_hits=template_res is None) if template_res is None: template_res = res hits = res.get("hits", {}).get("hits", []) if not hits: break collected_hits.extend(hits) remaining_take -= len(hits) next_search_after = hits[-1].get("sort") if not next_search_after or next_search_after == search_after: break search_after = next_search_after with_aggs = False if len(hits) < batch: break if template_res is None: q_count = copy.deepcopy(q_base) q_count["size"] = 0 template_res = self._es_search_once(index_names, q_count, track_total_hits=True) template_res["hits"]["hits"] = collected_hits return template_res def search( self, select_fields: list[str], highlight_fields: list[str], condition: dict, match_expressions: list[MatchExpr], order_by: OrderByExpr, offset: int, limit: int, index_names: str | list[str], knowledgebase_ids: list[str], agg_fields: list[str] | None = None, rank_feature: dict | None = None ): """ Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html """ if isinstance(index_names, str): index_names = index_names.split(",") assert isinstance(index_names, list) and len(index_names) > 0 assert "_id" not in condition bool_query = Q("bool", must=[]) condition["kb_id"] = knowledgebase_ids for k, v in condition.items(): if k == "available_int": if v == 0: bool_query.filter.append(Q("range", available_int={"lt": 1})) else: bool_query.filter.append( Q("bool", must_not=Q("range", available_int={"lt": 1}))) continue if k == "id": if not v: continue if isinstance(v, list): bool_query.filter.append( Q("bool", should=[Q("terms", id=v), Q("terms", _id=v)], minimum_should_match=1)) elif isinstance(v, str) or isinstance(v, int): bool_query.filter.append( Q("bool", should=[Q("term", id=v), Q("term", _id=v)], minimum_should_match=1)) continue if not v: continue if isinstance(v, list): bool_query.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): bool_query.filter.append(Q("term", **{k: v})) else: raise Exception( f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") s = Search() vector_similarity_weight = 0.5 for m in match_expressions: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance( match_expressions[1], MatchDenseExpr) and isinstance( match_expressions[2], FusionExpr) weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) for m in match_expressions: if isinstance(m, MatchTextExpr): minimum_should_match = m.extra_options.get("minimum_should_match", 0.0) if isinstance(minimum_should_match, float): minimum_should_match = str(int(minimum_should_match * 100)) + "%" bool_query.must.append(Q("query_string", fields=m.fields, type="best_fields", query=m.matching_text, minimum_should_match=minimum_should_match, boost=1)) bool_query.boost = 1.0 - vector_similarity_weight elif isinstance(m, MatchDenseExpr): assert (bool_query is not None) similarity = 0.0 if "similarity" in m.extra_options: similarity = m.extra_options["similarity"] s = s.knn(m.vector_column_name, m.topn, m.topn * 2, query_vector=list(m.embedding_data), filter=bool_query.to_dict(), similarity=similarity, ) if bool_query and rank_feature: for fld, sc in rank_feature.items(): if fld != PAGERANK_FLD: fld = f"{TAG_FLD}.{fld}" bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc)) if bool_query: s = s.query(bool_query) for field in highlight_fields: s = s.highlight(field) if order_by: orders = list() for field, order in order_by.fields: order = "asc" if order == 0 else "desc" if field in ["page_num_int", "top_int"]: order_info = {"order": order, "unmapped_type": "float", "mode": "avg", "numeric_type": "double"} elif field.endswith("_int") or field.endswith("_flt"): order_info = {"order": order, "unmapped_type": "float"} elif field == "id": continue # id as "text", not a "keyword", order by it will cause error else: order_info = {"order": order, "unmapped_type": "keyword"} orders.append({field: order_info}) s = s.sort(*orders) if agg_fields: for fld in agg_fields: s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) has_dense = any(isinstance(m, MatchDenseExpr) for m in match_expressions) has_explicit_sort = bool(order_by and order_by.fields) use_search_after = ( limit > 0 and (offset + limit > MAX_RESULT_WINDOW) and has_explicit_sort and not has_dense ) if limit > 0 and not use_search_after: s = s[offset:offset + limit] # Filter _source to only requested fields for efficiency, and add vector # fields to "fields" param so they appear in hit.fields when ES 9.x # exclude_source_vectors is enabled (dense_vector not in _source). if select_fields: s = s.source(select_fields) q = s.to_dict() # ES 9.x: dense_vector fields excluded from _source; request them via fields. # Note: knn does NOT have a "fields" parameter - adding it inside the knn # object causes BadRequestError on ES 9.x. We add "fields" at top level. vector_fields = [f for f in (select_fields or []) if f.endswith("_vec")] if vector_fields: q["fields"] = vector_fields self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q)) for i in range(ATTEMPT_TIME): try: if use_search_after: res = self._search_with_search_after(index_names, q, offset, limit) else: # print(json.dumps(q, ensure_ascii=False)) res = self._es_search_once(index_names, q, track_total_hits=True) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res)) return res except ConnectionTimeout: self.logger.exception("ES request timeout") self._connect() continue except Exception as e: # Only log debug for NotFoundError(accepted when metadata index doesn't exist) if 'NotFound' in str(e): self.logger.debug(f"ESConnection.search {str(index_names)} query: " + str(q) + " - " + str(e)) else: self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e)) raise e self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!") raise Exception("ESConnection.search timeout.") def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]: # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html operations = [] for d in documents: assert "_id" not in d assert "id" in d d_copy = copy.deepcopy(d) d_copy["kb_id"] = knowledgebase_id # Use id as _id for uniqueness, also keep "id" as a regular field for sorting meta_id = d_copy.get("id", "") operations.append( {"index": {"_index": index_name, "_id": meta_id}}) operations.append(d_copy) res = [] for _ in range(ATTEMPT_TIME): try: res = [] r = self.es.bulk(index=index_name, operations=operations, refresh="wait_for", timeout="60s") if re.search(r"False", str(r["errors"]), re.IGNORECASE): return res for item in r["items"]: for action in ["create", "delete", "index", "update"]: if action in item and "error" in item[action]: res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"])) return res except ConnectionTimeout: self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: res.append(str(e)) self.logger.warning("ESConnection.insert got exception: " + str(e)) return res def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool: doc = copy.deepcopy(new_value) doc.pop("id", None) condition["kb_id"] = knowledgebase_id if "id" in condition and isinstance(condition["id"], str): # update specific single document chunk_id = condition["id"] for i in range(ATTEMPT_TIME): doc_part = copy.deepcopy(doc) remove_value = doc_part.pop("remove", None) remove_field = remove_value if isinstance(remove_value, str) else None remove_dict = remove_value if isinstance(remove_value, dict) else None for k in doc_part.keys(): if "feas" != k.split("_")[-1]: continue try: self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");") except Exception: self.logger.exception( f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") try: if remove_field is not None: self.es.update( index=index_name, id=chunk_id, script=f"ctx._source.remove('{remove_field}');", ) if remove_dict is not None: scripts = [] params = {} for kk, vv in remove_dict.items(): scripts.append( f"if (ctx._source.containsKey('{kk}') && ctx._source.{kk} != null) " f"{{ int i = ctx._source.{kk}.indexOf(params.p_{kk}); " f"if (i >= 0) {{ ctx._source.{kk}.remove(i); }} }}" ) params[f"p_{kk}"] = vv if scripts: self.es.update( index=index_name, id=chunk_id, script={"source": "".join(scripts), "params": params}, ) if doc_part: self.es.update(index=index_name, id=chunk_id, doc=doc_part) if remove_field is not None or remove_dict is not None or doc_part: return True except Exception as e: self.logger.exception( f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str( e)) break return False # update unspecific maybe-multiple documents bool_query = Q("bool") for k, v in condition.items(): if not isinstance(k, str) or not v: continue if k == "exists": bool_query.filter.append(Q("exists", field=v)) continue if isinstance(v, list): bool_query.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): bool_query.filter.append(Q("term", **{k: v})) else: raise Exception( f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") scripts = [] params = {} for k, v in new_value.items(): if k == "remove": if isinstance(v, str): scripts.append(f"ctx._source.remove('{v}');") if isinstance(v, dict): for kk, vv in v.items(): scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);") params[f"p_{kk}"] = vv continue if k == "add": if isinstance(v, dict): for kk, vv in v.items(): scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});") params[f"pp_{kk}"] = vv.strip() continue if (not isinstance(k, str) or not v) and k != "available_int": continue if isinstance(v, str): v = re.sub(r"(['\n\r]|\\.)", " ", v) params[f"pp_{k}"] = v scripts.append(f"ctx._source.{k}=params.pp_{k};") elif isinstance(v, int) or isinstance(v, float): scripts.append(f"ctx._source.{k}={v};") elif isinstance(v, list): scripts.append(f"ctx._source.{k}=params.pp_{k};") params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False) else: raise Exception( f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") ubq = UpdateByQuery( index=index_name).using( self.es).query(bool_query) ubq = ubq.script(source="".join(scripts), params=params) ubq = ubq.params(refresh=True) ubq = ubq.params(slices=5) ubq = ubq.params(conflicts="proceed") for _ in range(ATTEMPT_TIME): try: _ = ubq.execute() return True except ConnectionTimeout: self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts)) break return False def adjust_chunk_pagerank_fea( self, chunk_id: str, index_name: str, knowledgebase_id: str, delta: float, min_w: float = 0.0, max_w: float = 100.0, row_id: int | None = None, ) -> bool: """Atomically adjust pagerank_fea on one chunk (painless script).""" _ = row_id for _ in range(ATTEMPT_TIME): try: self.es.update( index=index_name, id=chunk_id, retry_on_conflict=3, script={ "source": _PAGERANK_FEA_ADJUST_SCRIPT.strip(), "lang": "painless", "params": { "pf": PAGERANK_FLD, "delta": float(delta), "min_w": float(min_w), "max_w": float(max_w), }, }, ) self.logger.debug( "ESConnection.adjust_chunk_pagerank_fea(index=%s, id=%s, delta=%s) succeeded", index_name, chunk_id, delta, ) return True except ConnectionTimeout: self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: self.logger.exception( "ESConnection.adjust_chunk_pagerank_fea(index=%s, id=%s): %s", index_name, chunk_id, e, ) if re.search(r"connection", str(e).lower()): time.sleep(3) self._connect() continue break return False def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int: assert "_id" not in condition condition["kb_id"] = knowledgebase_id # Build a bool query that combines id filter with other conditions bool_query = Q("bool") # Handle chunk IDs if present if "id" in condition: chunk_ids = condition["id"] if not isinstance(chunk_ids, list): chunk_ids = [chunk_ids] if chunk_ids: # Filter by specific chunk IDs bool_query.filter.append(Q("ids", values=chunk_ids)) # If chunk_ids is empty, we don't add an ids filter - rely on other conditions # Add all other conditions as filters for k, v in condition.items(): if k == "id": continue # Already handled above if k == "exists": bool_query.filter.append(Q("exists", field=v)) elif k == "must_not": if isinstance(v, dict): for kk, vv in v.items(): if kk == "exists": bool_query.must_not.append(Q("exists", field=vv)) elif isinstance(v, list): bool_query.must.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): bool_query.must.append(Q("term", **{k: v})) elif v is not None: raise Exception("Condition value must be int, str or list.") # If no filters were added, use match_all (for tenant-wide operations) if not bool_query.filter and not bool_query.must and not bool_query.must_not: qry = Q("match_all") else: qry = bool_query self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict())) for _ in range(ATTEMPT_TIME): try: res = self.es.delete_by_query( index=index_name, body=Search().query(qry).to_dict(), refresh=True) return res["deleted"] except ConnectionTimeout: self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: self.logger.warning("ESConnection.delete got exception: " + str(e)) if re.search(r"(not_found)", str(e), re.IGNORECASE): return 0 return 0 """ Helper functions for search result """ def get_fields(self, res, fields: list[str]) -> dict[str, dict]: res_fields = {} if not fields: return {} hits = res.get("hits", {}).get("hits", []) for hit in hits: doc_id = hit.get("_id") d = hit.get("_source", {}) # Also extract fields from ES "fields" response (used by dense_vector in ES 9.x) hit_fields = hit.get("fields", {}) m = {} for n in fields: # First check _source if d.get(n) is not None: m[n] = d.get(n) # Then check fields (ES 9.x stores dense_vector here, not in _source) elif n in hit_fields: vals = hit_fields[n] # ES fields response wraps dense_vector in 2 levels: [[v1,v2,...]] -> [v1,v2,...] if isinstance(vals, list) and len(vals) == 1: vals = vals[0] m[n] = vals for n, v in m.items(): if isinstance(v, list): m[n] = v continue if n == "available_int" and isinstance(v, (int, float)): m[n] = v continue if not isinstance(v, str): m[n] = str(m[n]) # if n.find("tks") > 0: # m[n] = remove_redundant_spaces(m[n]) if m: res_fields[doc_id] = m return res_fields