diff --git a/rag/graphrag/general/extractor.py b/rag/graphrag/general/extractor.py index 3328604b67..2c0dae54d0 100644 --- a/rag/graphrag/general/extractor.py +++ b/rag/graphrag/general/extractor.py @@ -62,11 +62,26 @@ class Extractor: self._language = language self._entity_types = entity_types or DEFAULT_ENTITY_TYPES + @staticmethod + def _normalize_response_text(response): + if isinstance(response, (list, tuple)): + response = response[0] if response else "" + if response is None: + return "" + return response if isinstance(response, str) else str(response) + + @staticmethod + def _is_truncated_cache(response): + return len((response or "").strip()) <= 1 + @timeout(60 * 20) def _chat(self, system, history, gen_conf={}, task_id=""): hist = deepcopy(history) conf = deepcopy(gen_conf) response = get_llm_cache(self._llm.llm_name, system, hist, conf) + response = self._normalize_response_text(response) + if self._is_truncated_cache(response): + response = "" if response: return response _, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92)) @@ -78,10 +93,12 @@ class Extractor: raise TaskCanceledException(f"Task {task_id} was cancelled") try: response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf)) - response = re.sub(r"^.*", "", response[0], flags=re.DOTALL) + response = self._normalize_response_text(response) + response = re.sub(r"^.*", "", response, flags=re.DOTALL) if response.find("**ERROR**") >= 0: raise Exception(response) - set_llm_cache(self._llm.llm_name, system, response, history, gen_conf) + if not self._is_truncated_cache(response): + set_llm_cache(self._llm.llm_name, system, response, history, gen_conf) break except Exception as e: logging.exception(e)