mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 00:05:43 +08:00
### What problem does this PR solve?

Fixes the OpenSearch side of #10747: hybrid search drops the keyword (BM25) leg and ends up doing plain vector search.

When a search has both a text and a vector leg, `OSConnection.search()` throws the text query away:

```python
del q["query"]
q["query"] = {"knn": knn_query}
```

The text clause only survives as a filter inside the knn query, so it narrows the candidate set but does not contribute to scoring. As a result, hybrid search on OpenSearch behaves like plain vector search, unlike the Elasticsearch backend.

What I changed:

- When both legs are present, send a real hybrid query `{"hybrid": {"queries": [bm25, {"knn": ...}]}}` and let a normalization-processor search pipeline score and combine the two legs.
- Only the actual filters (kb_id, available_int, ...) go in the knn filter, not the text must clause.
- Create (or update) the pipeline on startup, so there is no separate provisioning step. The pipeline name and weights can be set under `os:` in service_conf.yaml (`hybrid_search_pipeline`, `hybrid_search_weights`). The name can also be overridden via the `OS_HYBRID_PIPELINE` environment variable. Defaults are `ragflow_hybrid_pipeline` and `[0.5, 0.5]`.
- normalization-processor needs OpenSearch 2.10+. On older clusters, or when the pipeline can't be created, log a warning and fall back to vector-only instead of pointing at a pipeline that doesn't exist.

This is only the hybrid-search fix; `create_doc_meta_idx` is already on main.

Testing (there's no OpenSearch path in CI): added a unit test (`test/unit_test/rag/utils/test_opensearch_hybrid_search.py`, no services needed) that checks the query built in each case:

- hybrid + pipeline param for text+vector;
- plain knn for vector-only;
- plain bool for text-only;
- the knn filter never carrying the text query_string;
- the vector-only fallback when the pipeline isn't available.
Also ran it against a real OpenSearch 2.19.1 container with a doc that matches the keyword but sits outside the knn top-k. Pure knn returns `['D1','D2','D5']` (keyword doc missing), while the hybrid query returns `['A','D1','D2','D5']` (keyword doc present).

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Signed-off-by: Danut Matei <matei.danut.dm@gmail.com>
873 lines
36 KiB
Python
873 lines
36 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import logging
|
|
import re
|
|
import json
|
|
import time
|
|
import os
|
|
|
|
import copy
|
|
from opensearchpy import OpenSearch, NotFoundError
|
|
from opensearchpy import UpdateByQuery, Q, Search, Index
|
|
from opensearchpy import ConnectionTimeout
|
|
from common.decorator import singleton
|
|
from common.file_utils import get_project_base_directory
|
|
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
|
FusionExpr
|
|
from rag.nlp import is_english, rag_tokenizer
|
|
from common.constants import PAGERANK_FLD, TAG_FLD
|
|
from common import settings
|
|
|
|
ATTEMPT_TIME = 2
|
|
|
|
_PAGERANK_FEA_ADJUST_SCRIPT = """
|
|
double cur = 0.0;
|
|
if (ctx._source.containsKey(params.pf)) {
|
|
Object v = ctx._source[params.pf];
|
|
if (v != null) {
|
|
if (v instanceof Number) {
|
|
cur = ((Number)v).doubleValue();
|
|
} else {
|
|
try { cur = Double.parseDouble(v.toString()); } catch (Exception e) { cur = 0.0; }
|
|
}
|
|
}
|
|
}
|
|
double nw = cur + params.delta;
|
|
if (nw < params.min_w) { nw = params.min_w; }
|
|
if (nw > params.max_w) { nw = params.max_w; }
|
|
if (nw <= 0.0) {
|
|
if (ctx._source.containsKey(params.pf)) {
|
|
ctx._source.remove(params.pf);
|
|
}
|
|
} else {
|
|
ctx._source[params.pf] = nw;
|
|
}
|
|
"""
|
|
|
|
logger = logging.getLogger('ragflow.opensearch_conn')
|
|
|
|
|
|
@singleton
class OSConnection(DocStoreConnection):
    """OpenSearch implementation of ``DocStoreConnection``.

    The ``opensearchpy.OpenSearch`` client lives in ``self.os``. Index
    management, CRUD, search (keyword / knn / hybrid) and SQL are built on it.
    """

    # normalization-processor (needed to merge the BM25 and KNN scores) only
    # exists on OpenSearch 2.10+.
    HYBRID_MIN_VERSION = (2, 10)

    def __init__(self):
        """Connect to the cluster configured in ``settings.OS``.

        Makes up to ATTEMPT_TIME connection attempts, sleeping 5s after each
        failure. It then checks the cluster with ``ping()``, requires major
        version >= 2, loads ``conf/os_mapping.json`` into ``self.mapping`` and
        finally sets up the hybrid-search pipeline.

        Raises:
            Exception: the cluster is unreachable, older than 2.x, or the
                mapping file is missing.
        """
        self.info = {}
        # Pre-set so that, if the client could never be constructed, the
        # health check below raises the intended "unhealthy" error instead of
        # an AttributeError on the missing attribute.
        self.os = None
        logger.info(f"Use OpenSearch {settings.OS['hosts']} as the doc engine.")
        for _ in range(ATTEMPT_TIME):
            try:
                self.os = OpenSearch(
                    settings.OS["hosts"].split(","),
                    http_auth=(settings.OS["username"], settings.OS["password"])
                    if "username" in settings.OS and "password" in settings.OS else None,
                    verify_certs=False,
                    timeout=600,
                )
                if self.os:
                    self.info = self.os.info()
                    break
            except Exception as e:
                logger.warning(f"{str(e)}. Waiting OpenSearch {settings.OS['hosts']} to be healthy.")
                time.sleep(5)
        if self.os is None or not self.os.ping():
            msg = f"OpenSearch {settings.OS['hosts']} is unhealthy in 120s."
            logger.error(msg)
            raise Exception(msg)
        v = self.info.get("version", {"number": "2.18.0"})
        v = v["number"].split(".")[0]
        if int(v) < 2:
            msg = f"OpenSearch version must be greater than or equal to 2, current version: {v}"
            logger.error(msg)
            raise Exception(msg)
        fp_mapping = os.path.join(get_project_base_directory(), "conf", "os_mapping.json")
        if not os.path.exists(fp_mapping):
            msg = f"OpenSearch mapping file not found at {fp_mapping}"
            logger.error(msg)
            raise Exception(msg)
        with open(fp_mapping, "r") as f:
            self.mapping = json.load(f)
        logger.info(f"OpenSearch {settings.OS['hosts']} is healthy.")
        self._init_hybrid_search()

    def _init_hybrid_search(self):
        """Create or update the hybrid-search pipeline at startup.

        A {"hybrid": {...}} query is scored by a normalization-processor that
        has to live on a search pipeline; otherwise OpenSearch rejects the
        query. The pipeline is PUT once at startup, so there is no extra setup
        step to run. PUT _search/pipeline creates the pipeline or overwrites
        an existing one with the same name, so the configured weights always
        apply.

        Pipeline name, in order of precedence:
        1. the ``OS_HYBRID_PIPELINE`` environment variable;
        2. ``settings.OS["hybrid_search_pipeline"]``;
        3. ``"ragflow_hybrid_pipeline"``.

        Weights come from ``settings.OS["hybrid_search_weights"]`` (default
        ``[0.5, 0.5]``). Their order matches the order of the legs built in
        search(): [keyword, knn].

        Sets ``self.hybrid_search_enabled``. If the pipeline can't be created
        (OpenSearch < 2.10, or no permission to manage pipelines), this logs a
        warning, leaves hybrid search off, and search() keeps doing
        vector-only.
        """
        self.hybrid_search_enabled = False
        self._hybrid_pipeline = os.environ.get("OS_HYBRID_PIPELINE") \
            or settings.OS.get("hybrid_search_pipeline") or "ragflow_hybrid_pipeline"

        version_number = self.info.get("version", {}).get("number", "")
        try:
            version = tuple(int(p) for p in version_number.split(".")[:2])
        except (ValueError, AttributeError):
            version = (0, 0)
        if version < self.HYBRID_MIN_VERSION:
            logger.warning(f"OpenSearch {version_number or 'unknown'} does not support the "
                           f"normalization-processor (requires >= {self.HYBRID_MIN_VERSION[0]}."
                           f"{self.HYBRID_MIN_VERSION[1]}); hybrid search is disabled and "
                           f"queries fall back to vector-only.")
            return

        weights = settings.OS.get("hybrid_search_weights", [0.5, 0.5])
        pipeline_body = {
            "description": "RAGFlow hybrid search normalization pipeline (BM25 + KNN).",
            "phase_results_processors": [
                {"normalization-processor": {
                    "normalization": {"technique": "min_max"},
                    "combination": {"technique": "arithmetic_mean",
                                    "parameters": {"weights": weights}}}}
            ],
        }
        try:
            self.os.transport.perform_request(
                "PUT", f"/_search/pipeline/{self._hybrid_pipeline}", body=pipeline_body)
            self.hybrid_search_enabled = True
            logger.info(f"OpenSearch hybrid search enabled via pipeline "
                        f"'{self._hybrid_pipeline}' (weights {weights}).")
        except Exception:
            logger.warning(f"Could not create OpenSearch search pipeline '{self._hybrid_pipeline}'; "
                           f"hybrid search is disabled and queries fall back to vector-only. "
                           f"Creating a search pipeline needs the "
                           f"'cluster:admin/search/pipeline/put' privilege (relevant on "
                           f"locked-down or managed OpenSearch).", exc_info=True)

    """
    Database operations
    """

    def db_type(self) -> str:
        """Return the backend identifier ``"opensearch"``."""
        return "opensearch"

    def health(self) -> dict:
        """Return the cluster health response with ``type`` set to ``"opensearch"``."""
        health_dict = dict(self.os.cluster.health())
        health_dict["type"] = "opensearch"
        return health_dict

    """
    Table operations
    """

    def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int, parser_id: str = None):
        """Create ``indexName`` with ``self.mapping`` unless it already exists.

        ``vectorSize`` and ``parser_id`` are not used by this implementation.

        Returns:
            True if the index already exists, the create response on success,
            or False if creation failed (the error is logged).
        """
        if self.index_exist(indexName, knowledgebaseId):
            return True
        try:
            from opensearchpy.client import IndicesClient
            return IndicesClient(self.os).create(index=indexName,
                                                 body=self.mapping)
        except Exception:
            logger.exception("OSConnection.createIndex error %s" % (indexName))
            # Explicit falsy result, consistent with create_doc_meta_idx.
            return False

    def create_doc_meta_idx(self, index_name: str):
        """
        Create a per-tenant document metadata index on OpenSearch.

        Mirrors ESConnectionBase.create_doc_meta_idx so that the
        DocMetadataService dispatches uniformly across ES and OS backends.
        Index name pattern: ragflow_doc_meta_{tenant_id}
        """
        if self.index_exist(index_name, ""):
            return True
        try:
            fp_mapping = os.path.join(get_project_base_directory(), "conf", "doc_meta_es_mapping.json")
            if not os.path.exists(fp_mapping):
                logger.error(f"Document metadata mapping file not found at {fp_mapping}")
                return False

            with open(fp_mapping, "r") as f:
                doc_meta_mapping = json.load(f)

            mappings = doc_meta_mapping["mappings"]
            # `conf/doc_meta_es_mapping.json` declares a top-level
            # `"dynamic": "runtime"`. Runtime fields are an Elasticsearch-only
            # feature, and OpenSearch cannot parse the value. It rejects index
            # creation with `mapper_parsing_exception: Could not convert
            # [dynamic.dynamic] to boolean`.
            # On OpenSearch, fall back to standard dynamic mapping (`true`) so
            # dynamic field discovery is kept without the ES-specific runtime
            # semantics. The shared mapping file is left untouched so the
            # Elasticsearch backend still gets runtime fields.
            if mappings.get("dynamic") == "runtime":
                mappings = {**mappings, "dynamic": True}

            from opensearchpy.client import IndicesClient
            body = {
                "settings": doc_meta_mapping["settings"],
                "mappings": mappings,
            }
            return IndicesClient(self.os).create(index=index_name, body=body)
        except Exception as e:
            logger.exception(f"OSConnection.create_doc_meta_idx error creating {index_name}: {e}")
            return False

    def refresh_idx(self, index_name: str) -> bool:
        """
        Refresh an index so that recently inserted documents become searchable.

        DocMetadataService used to call ``settings.docStoreConn.es.indices.refresh``
        directly. That raised AttributeError on the OpenSearch backend, because
        OSConnection exposes ``self.os`` rather than ``self.es``. This wrapper
        gives both backends a uniform abstract entry point.
        """
        try:
            self.os.indices.refresh(index=index_name)
            return True
        except NotFoundError:
            return False
        except Exception as e:
            logger.warning(f"OSConnection.refresh_idx({index_name}) failed: {e}")
            return False

    def count_idx(self, index_name: str) -> int:
        """
        Return the document count for an index, or -1 if the call fails.

        A missing index counts as 0. Used by
        DocMetadataService._drop_empty_metadata_table to decide whether a
        per-tenant metadata index is empty without paying for a full search.
        """
        try:
            response = self.os.count(index=index_name)
            return int(response.get("count", 0))
        except NotFoundError:
            return 0
        except Exception as e:
            logger.warning(f"OSConnection.count_idx({index_name}) failed: {e}")
            return -1

    def replace_meta_fields(self, index_name: str, doc_id: str, meta_fields: dict) -> bool:
        """
        Replace the ``meta_fields`` object on a single document.

        ES.update with a ``doc`` body deep-merges object fields, which retains
        old keys that should be removed. The fix in ESConnection is a script
        that fully assigns the new meta_fields. We provide the same primitive
        on OpenSearch so the service layer never reaches into ``self.es`` or
        ``self.os`` directly.

        Retries (after a 1s sleep) only on timeout/connection errors. Returns
        False if the document is missing or every attempt fails.
        """
        body = {
            "script": {
                "source": "ctx._source.meta_fields = params.meta_fields",
                "params": {"meta_fields": meta_fields},
            }
        }
        for _ in range(ATTEMPT_TIME):
            try:
                self.os.update(index=index_name, id=doc_id, body=body, refresh=True)
                return True
            except NotFoundError:
                return False
            except Exception as e:
                logger.warning(f"OSConnection.replace_meta_fields({index_name}, {doc_id}) failed: {e}")
                if re.search(r"(timeout|connection)", str(e).lower()):
                    time.sleep(1)
                    continue
                return False
        return False

    def delete_idx(self, indexName: str, knowledgebaseId: str):
        """Delete ``indexName``, but only when ``knowledgebaseId`` is empty.

        A missing index is ignored. Other errors are logged, not raised.
        """
        if len(knowledgebaseId) > 0:
            # The index needs to stay alive after any kb deletion, since all
            # kbs under this tenant share one index.
            return
        try:
            self.os.indices.delete(index=indexName, allow_no_indices=True)
        except NotFoundError:
            pass
        except Exception:
            logger.exception("OSConnection.deleteIdx error %s" % (indexName))

    def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
        """Return whether ``indexName`` exists.

        Retries on timeout/conflict errors. Returns False on any other error,
        or once the retries are exhausted. ``knowledgebaseId`` is not used.
        """
        s = Index(indexName, self.os)
        for _ in range(ATTEMPT_TIME):
            try:
                return s.exists()
            except Exception as e:
                logger.exception("OSConnection.indexExist got exception")
                # `in` rather than `find(...) > 0`, which missed a match at index 0.
                if "Timeout" in str(e) or "Conflict" in str(e):
                    continue
                break
        return False

    """
    CRUD operations
    """

    def search(
            self, select_fields: list[str],
            highlight_fields: list[str],
            condition: dict,
            match_expressions: list[MatchExpr],
            order_by: OrderByExpr,
            offset: int,
            limit: int,
            index_names: str | list[str],
            knowledgebase_ids: list[str],
            agg_fields: list[str] | None = None,
            rank_feature: dict | None = None
    ):
        """Build and run a filtered keyword / knn / hybrid search; return the raw response.

        Refers to https://github.com/opensearch-project/opensearch-py/blob/main/guides/dsl.md

        The query shape depends on ``match_expressions``:

        * text only: a ``bool`` query with the query_string in ``must``.
        * vector only: a plain ``knn`` query whose filter holds only the
          structural filters built from ``condition``.
        * text + vector, with the hybrid pipeline enabled: a ``hybrid`` query
          with the bool (BM25) leg and the knn leg, sent with the
          ``search_pipeline`` request parameter.
        * text + vector, without the pipeline: falls back to the plain ``knn``
          query; the text leg is dropped.

        Note: ``condition`` is modified in place (``kb_id`` is set to
        ``knowledgebase_ids``). ``select_fields`` is not used; the full
        ``_source`` is always returned.

        Raises:
            Exception: a non-timeout error from OpenSearch, or ATTEMPT_TIME
                consecutive timeouts.
        """
        use_knn = False
        use_text = False
        if isinstance(index_names, str):
            index_names = index_names.split(",")
        assert isinstance(index_names, list) and len(index_names) > 0
        assert "_id" not in condition

        bqry = Q("bool", must=[])
        condition["kb_id"] = knowledgebase_ids
        for k, v in condition.items():
            if k == "available_int":
                if v == 0:
                    bqry.filter.append(Q("range", available_int={"lt": 1}))
                else:
                    bqry.filter.append(
                        Q("bool", must_not=Q("range", available_int={"lt": 1})))
                continue
            if not v:
                continue
            if isinstance(v, list):
                bqry.filter.append(Q("terms", **{k: v}))
            elif isinstance(v, str) or isinstance(v, int):
                bqry.filter.append(Q("term", **{k: v}))
            else:
                raise Exception(
                    f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")

        s = Search()
        vector_similarity_weight = 0.5
        for m in match_expressions:
            if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
                assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) \
                    and isinstance(match_expressions[1], MatchDenseExpr) \
                    and isinstance(match_expressions[2], FusionExpr)
                weights = m.fusion_params["weights"]
                vector_similarity_weight = float(weights.split(",")[1])
        knn_query = {}
        for m in match_expressions:
            if isinstance(m, MatchTextExpr):
                use_text = True
                minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
                if isinstance(minimum_should_match, float):
                    minimum_should_match = str(int(minimum_should_match * 100)) + "%"
                bqry.must.append(Q("query_string", fields=m.fields,
                                   type="best_fields", query=m.matching_text,
                                   minimum_should_match=minimum_should_match,
                                   boost=1))
                bqry.boost = 1.0 - vector_similarity_weight

            # Elasticsearch's Python SDK encapsulates KNN search, while the
            # OpenSearch Python SDK does not, so KNN search is built here with
            # the raw DSL. OpenSearch's knn query syntax also differs from
            # Elasticsearch's, hence the adaptations below.
            elif isinstance(m, MatchDenseExpr):
                assert (bqry is not None)
                similarity = 0.0
                if "similarity" in m.extra_options:
                    similarity = m.extra_options["similarity"]
                use_knn = True
                vector_column_name = m.vector_column_name
                knn_query[vector_column_name] = {}
                knn_query[vector_column_name]["vector"] = list(m.embedding_data)
                knn_query[vector_column_name]["k"] = m.topn
                # The knn filter holds only the structural filters (kb_id,
                # available_int, ...). The text query is deliberately kept out
                # of it: it's scored as its own leg in the hybrid query below,
                # not used to pre-filter knn candidates.
                bool_inner = bqry.to_dict().get("bool", {})
                if bool_inner.get("filter"):
                    knn_query[vector_column_name]["filter"] = {"bool": {"filter": bool_inner["filter"]}}
                knn_query[vector_column_name]["boost"] = similarity

        if bqry and rank_feature:
            for fld, sc in rank_feature.items():
                if fld != PAGERANK_FLD:
                    fld = f"{TAG_FLD}.{fld}"
                bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))

        if bqry:
            s = s.query(bqry)
        for field in highlight_fields:
            s = s.highlight(field, force_source=True, no_match_size=30, require_field_match=False)

        if order_by:
            orders = list()
            for field, order in order_by.fields:
                order = "asc" if order == 0 else "desc"
                if field in ["page_num_int", "top_int"]:
                    order_info = {"order": order, "unmapped_type": "float",
                                  "mode": "avg", "numeric_type": "double"}
                elif field.endswith("_int") or field.endswith("_flt"):
                    order_info = {"order": order, "unmapped_type": "float"}
                else:
                    order_info = {"order": order, "unmapped_type": "text"}
                orders.append({field: order_info})
            s = s.sort(*orders)

        for fld in agg_fields or []:
            s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)

        if limit > 0:
            s = s[offset:offset + limit]
        q = s.to_dict()
        logger.debug(f"OSConnection.search {str(index_names)} query: " + json.dumps(q))

        hybrid_search = use_knn and use_text and getattr(self, "hybrid_search_enabled", False)
        if use_knn:
            if hybrid_search:
                # Both legs present and a pipeline available: send a real
                # hybrid query so the keyword (BM25) and vector (knn) legs are
                # scored separately and merged by the pipeline.
                keyword_query = q.get("query")
                q["query"] = {"hybrid": {"queries": [keyword_query, {"knn": knn_query}]}}
            else:
                # Vector-only, or no pipeline available: fall back to a plain knn query.
                del q["query"]
                q["query"] = {"knn": knn_query}

        search_kwargs = {}
        if hybrid_search:
            search_kwargs["params"] = {"search_pipeline": self._hybrid_pipeline}

        for _ in range(ATTEMPT_TIME):
            try:
                res = self.os.search(index=index_names,
                                     body=q,
                                     timeout=600,
                                     # search_type="dfs_query_then_fetch",
                                     track_total_hits=True,
                                     _source=True,
                                     **search_kwargs)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("OpenSearch Timeout.")
                logger.debug(f"OSConnection.search {str(index_names)} res: " + str(res))
                return res
            except Exception as e:
                logger.exception(f"OSConnection.search {str(index_names)} query: " + str(q))
                # `in` rather than `find(...) > 0`, which missed a match at index 0.
                if "Timeout" in str(e):
                    continue
                raise e
        logger.error(f"OSConnection.search timeout for {ATTEMPT_TIME} times!")
        raise Exception("OSConnection.search timeout.")

    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        """Fetch one chunk by id.

        Returns:
            The document ``_source`` with ``id`` added, or None if it does
            not exist.

        Raises:
            Exception: a non-timeout error, or ATTEMPT_TIME consecutive
                timeouts.
        """
        for _ in range(ATTEMPT_TIME):
            try:
                res = self.os.get(index=(indexName),
                                  id=chunkId, _source=True, )
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("OpenSearch Timeout.")
                chunk = res["_source"]
                chunk["id"] = chunkId
                return chunk
            except NotFoundError:
                return None
            except Exception as e:
                logger.exception(f"OSConnection.get({chunkId}) got exception")
                if "Timeout" in str(e):
                    continue
                raise e
        logger.error(f"OSConnection.get timeout for {ATTEMPT_TIME} times!")
        raise Exception("OSConnection.get timeout.")

    def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
        """Bulk-index ``documents`` into ``indexName``.

        Each document must carry ``id`` (used as ``_id`` and removed from the
        stored body) and must not carry ``_id``.

        Returns:
            A list of error strings; an empty list means success.
        """
        # Refers to https://opensearch.org/docs/latest/api-reference/document-apis/bulk/
        operations = []
        for d in documents:
            assert "_id" not in d
            assert "id" in d
            d_copy = copy.deepcopy(d)
            meta_id = d_copy.pop("id", "")
            operations.append(
                {"index": {"_index": indexName, "_id": meta_id}})
            operations.append(d_copy)

        res = []
        for _ in range(ATTEMPT_TIME):
            try:
                res = []
                r = self.os.bulk(index=(indexName), body=operations,
                                 refresh="wait_for", timeout=60)
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res

                for item in r["items"]:
                    for action in ["create", "delete", "index", "update"]:
                        if action in item and "error" in item[action]:
                            res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
                return res
            except Exception as e:
                logger.warning("OSConnection.insert got exception: " + str(e))
                # Always report the failure: returning an empty list here
                # (the previous behavior for non-timeout errors) meant "success".
                res = [str(e)]
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                return res
        return res

    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        """Update a single chunk, or every chunk matching ``condition``.

        Single-document path: taken when ``condition["id"]`` is a str.
        ``newValue["remove"]`` may be a field name (str), or a
        ``{field: value}`` dict that removes the first occurrence of ``value``
        from list ``field``. The remaining keys are applied as a partial
        ``doc`` update.

        Multi-document path: an update-by-query over the filters built from
        ``condition``. The ``"remove"`` key works as above. ``"add"``
        (``{field: value}``) appends ``value.strip()`` to list ``field``.
        Other non-falsy keys (and ``available_int``) are assigned.

        Returns:
            True on success, False otherwise (errors are logged).
        """
        doc = copy.deepcopy(newValue)
        doc.pop("id", None)
        if "id" in condition and isinstance(condition["id"], str):
            # update specific single document
            chunkId = condition["id"]
            for _ in range(ATTEMPT_TIME):
                doc_part = copy.deepcopy(doc)
                remove_value = doc_part.pop("remove", None)
                remove_field = remove_value if isinstance(remove_value, str) else None
                remove_dict = remove_value if isinstance(remove_value, dict) else None
                try:
                    if remove_field is not None:
                        self.os.update(
                            index=indexName,
                            id=chunkId,
                            body={"script": {"source": f"ctx._source.remove('{remove_field}');"}},
                        )
                    if remove_dict is not None:
                        scripts = []
                        params = {}
                        for kk, vv in remove_dict.items():
                            scripts.append(
                                f"if (ctx._source.containsKey('{kk}') && ctx._source.{kk} != null) "
                                f"{{ int i = ctx._source.{kk}.indexOf(params.p_{kk}); "
                                f"if (i >= 0) {{ ctx._source.{kk}.remove(i); }} }}"
                            )
                            params[f"p_{kk}"] = vv
                        if scripts:
                            self.os.update(
                                index=indexName,
                                id=chunkId,
                                body={"script": {"source": "".join(scripts), "params": params}},
                            )
                    if doc_part:
                        self.os.update(index=indexName, id=chunkId, body={"doc": doc_part})
                    if remove_field is not None or remove_dict is not None or doc_part:
                        return True
                except Exception as e:
                    # Log the chunk id (the old message interpolated the builtin `id`).
                    logger.exception(
                        f"OSConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
                    if re.search(r"(timeout|connection)", str(e).lower()):
                        continue
                    break
            return False

        # update unspecific maybe-multiple documents
        bqry = Q("bool")
        for k, v in condition.items():
            if not isinstance(k, str) or not v:
                continue
            if k == "exists":
                bqry.filter.append(Q("exists", field=v))
                continue
            if isinstance(v, list):
                bqry.filter.append(Q("terms", **{k: v}))
            elif isinstance(v, str) or isinstance(v, int):
                bqry.filter.append(Q("term", **{k: v}))
            else:
                raise Exception(
                    f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
        scripts = []
        params = {}
        for k, v in newValue.items():
            if k == "remove":
                if isinstance(v, str):
                    scripts.append(f"ctx._source.remove('{v}');")
                if isinstance(v, dict):
                    for kk, vv in v.items():
                        # Same guarded form as the single-document path. The
                        # braces scope `i`, so several keys don't redeclare it
                        # in one Painless script. A missing field or value is
                        # skipped instead of throwing on remove(-1).
                        scripts.append(
                            f"if (ctx._source.containsKey('{kk}') && ctx._source.{kk} != null) "
                            f"{{ int i = ctx._source.{kk}.indexOf(params.p_{kk}); "
                            f"if (i >= 0) {{ ctx._source.{kk}.remove(i); }} }}"
                        )
                        params[f"p_{kk}"] = vv
                continue
            if k == "add":
                if isinstance(v, dict):
                    for kk, vv in v.items():
                        scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
                        params[f"pp_{kk}"] = vv.strip()
                continue
            if (not isinstance(k, str) or not v) and k != "available_int":
                continue
            if isinstance(v, str):
                v = re.sub(r"(['\n\r]|\\.)", " ", v)
                params[f"pp_{k}"] = v
                scripts.append(f"ctx._source.{k}=params.pp_{k};")
            elif isinstance(v, int) or isinstance(v, float):
                scripts.append(f"ctx._source.{k}={v};")
            elif isinstance(v, list):
                scripts.append(f"ctx._source.{k}=params.pp_{k};")
                params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
            else:
                raise Exception(
                    f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
        ubq = UpdateByQuery(
            index=indexName).using(
            self.os).query(bqry)
        ubq = ubq.script(source="".join(scripts), params=params)
        ubq = ubq.params(refresh=True)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")

        for _ in range(ATTEMPT_TIME):
            try:
                _ = ubq.execute()
                return True
            except Exception as e:
                logger.error("OSConnection.update got exception: " + str(e) + "\n".join(scripts))
                if re.search(r"(timeout|connection|conflict)", str(e).lower()):
                    continue
                break
        return False

    def adjust_chunk_pagerank_fea(
            self,
            chunk_id: str,
            indexName: str,
            knowledgebaseId: str,
            delta: float,
            min_w: float = 0.0,
            max_w: float = 100.0,
            row_id: int | None = None,
    ) -> bool:
        """Atomically adjust pagerank_fea on one chunk (painless script).

        Runs ``_PAGERANK_FEA_ADJUST_SCRIPT`` with ``retry_on_conflict=3``.
        ``row_id`` is accepted but unused. Returns True on success, False on
        any error (logged).
        """
        _ = row_id
        try:
            self.os.update(
                index=indexName,
                id=chunk_id,
                retry_on_conflict=3,
                body={
                    "script": {
                        "source": _PAGERANK_FEA_ADJUST_SCRIPT.strip(),
                        "lang": "painless",
                        "params": {
                            "pf": PAGERANK_FLD,
                            "delta": float(delta),
                            "min_w": float(min_w),
                            "max_w": float(max_w),
                        },
                    }
                },
            )
            logger.debug(
                "OSConnection.adjust_chunk_pagerank_fea(index=%s, id=%s, delta=%s) succeeded",
                indexName,
                chunk_id,
                delta,
            )
            return True
        except Exception as e:
            logger.exception(
                "OSConnection.adjust_chunk_pagerank_fea(index=%s, id=%s): %s",
                indexName,
                chunk_id,
                e,
            )
            return False

    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        """Delete-by-query the chunks matching ``condition`` within ``knowledgebaseId``.

        ``condition`` is modified in place (``kb_id`` is set). Supported keys:
        - ``id``: a chunk id or a list of ids;
        - ``exists``: a field name;
        - ``must_not``: ``{"exists": field}``;
        - any field mapped to a str, int or list value.

        Returns:
            The number of deleted documents; 0 on error or after
            ATTEMPT_TIME timeouts.
        """
        assert "_id" not in condition
        condition["kb_id"] = knowledgebaseId

        # Build a bool query that combines the id filter with other conditions
        bool_query = Q("bool")

        # Handle chunk IDs if present
        if "id" in condition:
            chunk_ids = condition["id"]
            if not isinstance(chunk_ids, list):
                chunk_ids = [chunk_ids]
            if chunk_ids:
                # Filter by specific chunk IDs
                bool_query.filter.append(Q("ids", values=chunk_ids))
            # If chunk_ids is empty, no ids filter is added; rely on other conditions

        # Add all other conditions as filters
        for k, v in condition.items():
            if k == "id":
                continue  # Already handled above
            if k == "exists":
                bool_query.filter.append(Q("exists", field=v))
            elif k == "must_not":
                if isinstance(v, dict):
                    for kk, vv in v.items():
                        if kk == "exists":
                            bool_query.must_not.append(Q("exists", field=vv))
            elif isinstance(v, list):
                bool_query.must.append(Q("terms", **{k: v}))
            elif isinstance(v, str) or isinstance(v, int):
                bool_query.must.append(Q("term", **{k: v}))
            elif v is not None:
                raise Exception("Condition value must be int, str or list.")

        # If no filters were added, use match_all (for tenant-wide operations)
        if not bool_query.filter and not bool_query.must and not bool_query.must_not:
            qry = Q("match_all")
        else:
            qry = bool_query
        logger.debug("OSConnection.delete query: " + json.dumps(qry.to_dict()))
        for _ in range(ATTEMPT_TIME):
            try:
                res = self.os.delete_by_query(
                    index=indexName,
                    body=Search().query(qry).to_dict(),
                    refresh=True)
                return res["deleted"]
            except Exception as e:
                logger.warning("OSConnection.delete got exception: " + str(e))
                if re.search(r"(timeout|connection)", str(e).lower()):
                    time.sleep(3)
                    continue
                # Any other error (including "not_found" for a missing index)
                # deletes nothing.
                return 0
        # Every attempt timed out. Previously this fell through and returned None.
        return 0

    """
    Helper functions for search result
    """

    def get_total(self, res):
        """Return the total hit count, whether ``hits.total`` is an int or a ``{"value": ...}`` dict."""
        total = res["hits"]["total"]
        if isinstance(total, dict):
            return total["value"]
        return total

    def get_doc_ids(self, res):
        """Return the ``_id`` of every hit, in response order."""
        return [d["_id"] for d in res["hits"]["hits"]]

    def get_scores(self, res) -> dict[str, float]:
        """
        Map hit `_id` to its raw `_score`.

        Used by rag/nlp/search.py:_knn_scores() to recover the cosine
        similarity returned by a KNN-only second-pass search, without pulling
        the chunk vectors out of the index. OpenSearch hit headers carry
        `_score` exactly like Elasticsearch, so this mirrors
        ESConnectionBase.get_scores. A missing score maps to 0.0.
        """
        out = {}
        for d in res.get("hits", {}).get("hits", []):
            doc_id = d.get("_id")
            if doc_id is None:
                continue
            score = d.get("_score")
            out[doc_id] = float(score) if score is not None else 0.0
        return out

    def __getSource(self, res):
        """Return each hit's ``_source``, with ``id`` and ``_score`` copied in (mutates ``res``)."""
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr

    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        """Map chunk id to the requested non-None ``fields``.

        List values are kept as is; other non-str values are converted with
        ``str()``.
        """
        res_fields = {}
        if not fields:
            return {}
        for d in self.__getSource(res):
            m = {n: d.get(n) for n in fields if d.get(n) is not None}
            for n, v in m.items():
                if isinstance(v, list):
                    m[n] = v
                    continue
                if not isinstance(v, str):
                    m[n] = str(m[n])
            if m:
                res_fields[d["id"]] = m
        return res_fields

    def get_highlight(self, res, keywords: list[str], fieldnm: str):
        """Map chunk id to a highlighted snippet.

        Non-English snippets use the server highlight as is. For English text,
        ``keywords`` are wrapped in ``<em>`` within the sentences of
        ``_source[fieldnm]``. If no sentence matches, the server highlight is
        used instead.
        """
        ans = {}
        for d in res["hits"]["hits"]:
            hlts = d.get("highlight")
            if not hlts:
                continue
            txt = "...".join([a for a in list(hlts.items())[0][1]])
            if not is_english(txt.split()):
                ans[d["_id"]] = txt
                continue

            txt = d["_source"][fieldnm]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txts = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
                               flags=re.IGNORECASE | re.MULTILINE)
                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
                    continue
                txts.append(t)
            ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])

        return ans

    def get_aggregation(self, res, fieldnm: str):
        """Return ``(key, doc_count)`` pairs for the ``aggs_<fieldnm>`` terms aggregation, or []."""
        agg_field = "aggs_" + fieldnm
        if "aggregations" not in res or agg_field not in res["aggregations"]:
            return list()
        bkts = res["aggregations"][agg_field]["buckets"]
        return [(b["key"], b["doc_count"]) for b in bkts]

    """
    SQL
    """

    def sql(self, sql: str, fetch_size: int, format: str):
        """Run ``sql`` through the OpenSearch SQL plugin; return the response or None.

        Equality/LIKE predicates on ``*_tks`` / ``*_ltks`` fields are rewritten
        into ``MATCH(...)`` over the tokenized value. Retries on
        ``ConnectionTimeout``; any other error is logged and returns None.
        """
        logger.debug(f"OSConnection.sql get sql: {sql}")
        sql = re.sub(r"[ `]+", " ", sql)
        sql = sql.replace("%", "")
        replaces = []
        for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
            fld, v = r.group(1), r.group(3)
            match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
                fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
            replaces.append(
                ("{}{}'{}'".format(
                    r.group(1),
                    r.group(2),
                    r.group(3)),
                 match))

        for p, r in replaces:
            sql = sql.replace(p, r, 1)
        logger.debug(f"OSConnection.sql to os: {sql}")

        for _ in range(ATTEMPT_TIME):
            try:
                # request_timeout is handed to the HTTP layer as seconds and
                # must be numeric; the former "2s" string was rejected there.
                res = self.os.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
                                        request_timeout=2)
                return res
            except ConnectionTimeout:
                logger.exception("OSConnection.sql timeout")
                continue
            except Exception:
                logger.exception("OSConnection.sql got exception")
                return None
        logger.error(f"OSConnection.sql timeout for {ATTEMPT_TIME} times!")
        return None
|