diff --git a/common/doc_store/infinity_conn_pool.py b/common/doc_store/infinity_conn_pool.py index 1aa3f81254..83ea4d51ff 100644 --- a/common/doc_store/infinity_conn_pool.py +++ b/common/doc_store/infinity_conn_pool.py @@ -14,6 +14,7 @@ # limitations under the License. # import logging +import os import time import infinity @@ -37,30 +38,47 @@ class InfinityConnectionPool: "db_name": "default_db" }) + raw_pool_max_size = os.environ.get("INFINITY_POOL_MAX_SIZE", "4") + try: + self.pool_max_size = int(raw_pool_max_size) + except ValueError as e: + raise ValueError("INFINITY_POOL_MAX_SIZE must be a positive integer") from e + if self.pool_max_size < 1: + raise ValueError("INFINITY_POOL_MAX_SIZE must be >= 1") + infinity_uri = self.INFINITY_CONFIG["uri"] if ":" in infinity_uri: host, port = infinity_uri.split(":") self.infinity_uri = infinity.common.NetworkAddress(host, int(port)) + self.conn_pool = None for _ in range(24): + conn_pool = None + inf_conn = None try: - conn_pool = ConnectionPool(self.infinity_uri, max_size=4) + conn_pool = ConnectionPool(self.infinity_uri, max_size=self.pool_max_size) inf_conn = conn_pool.get_conn() res = inf_conn.show_current_node() if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]: self.conn_pool = conn_pool - conn_pool.release_conn(inf_conn) break + logging.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.") + time.sleep(5) except Exception as e: logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.") time.sleep(5) + finally: + if inf_conn is not None and conn_pool is not None: + conn_pool.release_conn(inf_conn) + if conn_pool is not None and conn_pool is not self.conn_pool: + conn_pool.destroy() if self.conn_pool is None: msg = f"Infinity {infinity_uri} is unhealthy in 120s." logging.error(msg) raise Exception(msg) - logging.info(f"Infinity {infinity_uri} is healthy.") + logging.info(f"Infinity {infinity_uri} is healthy. Connection pool max_size={self.pool_max_size}") def get_conn_pool(self): return self.conn_pool @@ -91,7 +109,7 @@ class InfinityConnectionPool: logging.error(str(e)) if hasattr(self, "conn_pool") and self.conn_pool: self.conn_pool.destroy() - self.conn_pool = ConnectionPool(self.infinity_uri, max_size=32) + self.conn_pool = ConnectionPool(self.infinity_uri, max_size=self.pool_max_size) return self.conn_pool def __del__(self): diff --git a/rag/graphrag/utils.py b/rag/graphrag/utils.py index 1c2b3cbea3..1d8d2a1dd2 100644 --- a/rag/graphrag/utils.py +++ b/rag/graphrag/utils.py @@ -457,13 +457,24 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang if change.removed_edges: async def del_edges(from_node, to_node): - async with chat_limiter: - await thread_pool_exec( - settings.docStoreConn.delete, - {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, - search.index_name(tenant_id), - kb_id - ) + max_retries = 3 + for attempt in range(max_retries): + try: + async with chat_limiter: + await thread_pool_exec( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, + search.index_name(tenant_id), + kb_id + ) + return + except Exception as e: + if attempt < max_retries - 1: + wait = 2 ** attempt + logging.warning(f"del_edges({from_node}, {to_node}) attempt {attempt + 1} failed: {e}, retrying in {wait}s") + await asyncio.sleep(wait) + else: + raise tasks = [] for from_node, to_node in change.removed_edges: @@ -558,15 +569,40 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang es_bulk_size = 4 for b in range(0, len(chunks), es_bulk_size): timeout = 3 if enable_timeout_assertion else 30000000 - doc_store_result = await asyncio.wait_for( - thread_pool_exec( - settings.docStoreConn.insert, - chunks[b : b + es_bulk_size], - search.index_name(tenant_id), - kb_id - ), - timeout=timeout - ) + max_retries = 3 + for attempt in range(max_retries): + task = asyncio.create_task( + thread_pool_exec( + settings.docStoreConn.insert, + chunks[b : b + es_bulk_size], + search.index_name(tenant_id), + kb_id + ) + ) + try: + doc_store_result = await asyncio.wait_for(task, timeout=timeout) + break + except asyncio.TimeoutError: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + if attempt < max_retries - 1: + wait = 2 ** attempt + logging.warning(f"Insert batch {b}/{len(chunks)} attempt {attempt + 1} timed out, retrying in {wait}s") + await asyncio.sleep(wait) + else: + raise + except asyncio.CancelledError: + raise + except Exception as e: + if attempt < max_retries - 1: + wait = 2 ** attempt + logging.warning(f"Insert batch {b}/{len(chunks)} attempt {attempt + 1} failed: {e}, retrying in {wait}s") + await asyncio.sleep(wait) + else: + raise if b % 100 == es_bulk_size and callback: callback(msg=f"Insert chunks: {b}/{len(chunks)}") if doc_store_result: