diff --git a/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py b/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py index 11af6aa46b..38d9f9808b 100644 --- a/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py +++ b/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py @@ -41,9 +41,10 @@ class TreeStructuredQueryDecompositionRetrieval: async def _retrieve_information(self, search_query): """Retrieve information from different sources""" # 1. Knowledge base retrieval - kbinfos = [] + kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} try: - kbinfos = await self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []} + kbinfos = await self._kb_retrieve(question=search_query) if self._kb_retrieve else {"total": 0, "chunks": [], "doc_aggs": []} + kbinfos.setdefault("total", 0) except Exception as e: logging.error(f"Knowledge base retrieval error: {e}") @@ -87,12 +88,18 @@ class TreeStructuredQueryDecompositionRetrieval: if d["doc_id"] not in dids: chunk_info["doc_aggs"].append(d) + chunk_info["total"] = chunk_info.get("total", 0) + kbinfos.get("total", 0) + async def research(self, chunk_info, question, query, depth=3, callback=None): if callback: await callback("") - await self._research(chunk_info, question, query, depth, callback) - if callback: - await callback("") + try: + await self._research(chunk_info, question, query, depth, callback) + except Exception: + logging.exception("Unhandled exception in deep research for query: %s", query) + finally: + if callback: + await callback("") async def _research(self, chunk_info, question, query, depth=3, callback=None): if depth == 0: @@ -111,14 +118,14 @@ class TreeStructuredQueryDecompositionRetrieval: if callback: await callback("Checking the sufficiency for retrieved information.") suff = await sufficiency_check(self.chat_mdl, question, ret) - if suff["is_sufficient"]: + if suff.get("is_sufficient"): if callback: await callback(f"Yes, the retrieved information is sufficient for '{question}'.") return ret #if callback: # await callback("The retrieved information is not sufficient. Planing next steps...") - succ_question_info = await multi_queries_gen(self.chat_mdl, question, query, suff["missing_information"], ret) + succ_question_info = await multi_queries_gen(self.chat_mdl, question, query, suff.get("missing_information", []), ret) if callback: await callback("Next step is to search for the following questions:
- " + "
- ".join(step["question"] for step in succ_question_info["questions"])) steps = []