diff --git a/rag/nlp/search.py b/rag/nlp/search.py index f37ce24572..23e86cb9db 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -60,6 +60,58 @@ class Dealer: vector_column_name = f"q_{len(embedding_data)}_vec" return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity}) + async def _existing_doc_ids(self, doc_ids: list[str]) -> set[str]: + if not doc_ids: + return set() + + unique_doc_ids = list(dict.fromkeys(doc_ids)) + + def _load(): + from api.db.services.document_service import DocumentService + + return {row["id"] for row in DocumentService.get_by_ids(unique_doc_ids).dicts()} + + return await thread_pool_exec(_load) + + async def _prune_deleted_chunks(self, sres: SearchResult) -> SearchResult: + chunk_doc_ids = [chunk.get("doc_id") for chunk in sres.field.values() if chunk and chunk.get("doc_id")] + if not chunk_doc_ids: + return sres + + existing_doc_ids = await self._existing_doc_ids(chunk_doc_ids) + if len(existing_doc_ids) == len(set(chunk_doc_ids)): + return sres + + filtered_ids = [] + filtered_field = {} + filtered_highlight = {} if sres.highlight else sres.highlight + removed = 0 + + for chunk_id in sres.ids: + chunk = sres.field.get(chunk_id) + if not chunk or chunk.get("doc_id") not in existing_doc_ids: + removed += 1 + continue + + filtered_ids.append(chunk_id) + filtered_field[chunk_id] = chunk + if sres.highlight and chunk_id in sres.highlight: + filtered_highlight[chunk_id] = sres.highlight[chunk_id] + + if removed: + logging.warning("Pruned %s stale chunks whose documents no longer exist.", removed) + + return self.SearchResult( + total=len(filtered_ids), + ids=filtered_ids, + query_vector=sres.query_vector, + field=filtered_field, + highlight=filtered_highlight, + aggregation=sres.aggregation, + keywords=sres.keywords, + group_docs=sres.group_docs, + ) + def get_filters(self, req): condition = dict() for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items(): @@ -436,6 +488,10 @@ class Dealer: sres = await self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature) + sres = await self._prune_deleted_chunks(sres) + if sres.total == 0: + ranks["doc_aggs"] = [] + return ranks if rerank_mdl and sres.total > 0: sim, tsim, vsim = self.rerank_by_model(