From 7e6844118b699fbda40b33ed9fdaccace01e3baa Mon Sep 17 00:00:00 2001 From: Wang Qi Date: Fri, 22 May 2026 16:05:13 +0800 Subject: [PATCH] Fix search vector_similarity_weight (#15108) ### What problem does this PR solve? Fix search vector_similarity_weight ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/restful_apis/search_api.py | 27 ++++++- api/apps/services/dataset_api_service.py | 80 +++++++++++++++----- api/db/services/dialog_service.py | 20 ++++- rag/nlp/search.py | 19 ++++- web/src/pages/next-search/search-setting.tsx | 6 +- 5 files changed, 124 insertions(+), 28 deletions(-) diff --git a/api/apps/restful_apis/search_api.py b/api/apps/restful_apis/search_api.py index 7755704e4d..2693dad95b 100644 --- a/api/apps/restful_apis/search_api.py +++ b/api/apps/restful_apis/search_api.py @@ -16,6 +16,7 @@ import json import logging +from numbers import Real from quart import Response, request from api.db.services.dialog_service import async_ask @@ -31,6 +32,12 @@ from common.constants import RetCode, StatusEnum from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request +def _full_text_weight(vector_similarity_weight): + if isinstance(vector_similarity_weight, Real): + return 1 - vector_similarity_weight + return None + + @manager.route("/searches", methods=["POST"]) # noqa: F821 @login_required @validate_request("name") @@ -158,6 +165,16 @@ async def update(search_id): if not isinstance(new_config, dict): return get_data_error_result(message="search_config must be a JSON object") req["search_config"] = {**current_config, **new_config} + logging.debug( + "Search update weight: search_id=%s user_id=%s " + "incoming_vector_similarity_weight=%s stored_vector_similarity_weight=%s " + "stored_full_text_weight=%s", + search_id, + current_user.id, + new_config.get("vector_similarity_weight"), + req["search_config"].get("vector_similarity_weight"), + _full_text_weight(req["search_config"].get("vector_similarity_weight", 0.3)), + ) for field in ("search_id", "tenant_id", "created_by", "update_time", "id"): req.pop(field, None) @@ -209,6 +226,14 @@ async def completion(search_id): return get_data_error_result(message=f"Cannot find search {search_id}") search_config = search_app.get("search_config", {}) + logging.debug( + "Search completion loaded weight: search_id=%s user_id=%s " + "stored_vector_similarity_weight=%s stored_full_text_weight=%s", + search_id, + uid, + search_config.get("vector_similarity_weight", 0.3), + _full_text_weight(search_config.get("vector_similarity_weight", 0.3)), + ) kb_ids = search_config.get("kb_ids") or req.get("kb_ids") or [] if not kb_ids: return get_data_error_result(message="`kb_ids` is required.") @@ -216,7 +241,7 @@ async def completion(search_id): async def stream(): nonlocal req, uid, kb_ids, search_config try: - async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config): + async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config, search_id=search_id): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as ex: yield "data:" + json.dumps( diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py index 8d5f512a35..42c39364c5 100644 --- a/api/apps/services/dataset_api_service.py +++ b/api/apps/services/dataset_api_service.py @@ -938,6 +938,8 @@ async def search(dataset_id: str, tenant_id: str, req: dict): question = req.get("question", "") doc_ids = req.get("doc_ids", []) use_kg = req.get("use_kg", False) + similarity_threshold = float(req.get("similarity_threshold", 0.0)) + vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) top = max(1, min(int(req.get("top_k", 1024)), 2048)) langs = req.get("cross_languages", []) @@ -955,14 +957,31 @@ async def search(dataset_id: str, tenant_id: str, req: dict): local_doc_ids = list(doc_ids) if doc_ids else [] meta_data_filter = {} + search_id = req.get("search_id", "") + search_config = {} chat_mdl = None - if req.get("search_id", ""): - search_detail = SearchService.get_detail(req.get("search_id", "")) + if search_id: + search_detail = SearchService.get_detail(search_id) if not search_detail: - logging.warning("search config not found: search_id=%s", req.get("search_id", "")) + logging.warning("search config not found: search_id=%s", search_id) return False, "Invalid search_id" search_config = search_detail.get("search_config", {}) meta_data_filter = search_config.get("meta_data_filter", {}) + similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold)) + vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight)) + top = max(1, min(int(search_config.get("top_k", top)), 2048)) + use_kg = search_config.get("use_kg", use_kg) + langs = search_config.get("cross_languages", langs) + logging.debug( + "Dataset search loaded Search config: search_id=%s dataset_id=%s " + "vector_similarity_weight=%s full_text_weight=%s similarity_threshold=%s top_k=%s", + search_id, + dataset_id, + vector_similarity_weight, + 1 - vector_similarity_weight, + similarity_threshold, + top, + ) if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_id = search_config.get("chat_id", "") if chat_id: @@ -1016,11 +1035,13 @@ async def search(dataset_id: str, tenant_id: str, req: dict): requester_tenant_id=tenant_id, ) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - elif req.get("rerank_id"): - rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + else: + rerank_id = search_config.get("rerank_id") or req.get("rerank_id") + if rerank_id: + rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, rerank_id) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - if req.get("keyword", False): + if search_config.get("keyword", req.get("keyword", False)): default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config) _question += await keyword_extraction(chat_mdl, _question) @@ -1033,12 +1054,13 @@ async def search(dataset_id: str, tenant_id: str, req: dict): [dataset_id], page, size, - float(req.get("similarity_threshold", 0.0)), - float(req.get("vector_similarity_weight", 0.3)), + similarity_threshold, + vector_similarity_weight, doc_ids=local_doc_ids, top=top, rerank_mdl=rerank_mdl, rank_feature=labels, + trace_id=search_id, ) if use_kg: @@ -1291,6 +1313,8 @@ async def search_datasets(tenant_id: str, req: dict): question = req.get("question", "") doc_ids = req.get("doc_ids", []) use_kg = req.get("use_kg", False) + similarity_threshold = float(req.get("similarity_threshold", 0.0)) + vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) top = max(1, min(int(req.get("top_k", 1024)), 2048)) langs = req.get("cross_languages", []) @@ -1321,14 +1345,31 @@ async def search_datasets(tenant_id: str, req: dict): local_doc_ids = list(doc_ids) if doc_ids else [] meta_data_filter = {} + search_id = req.get("search_id", "") + search_config = {} chat_mdl = None - if req.get("search_id", ""): - search_detail = SearchService.get_detail(req.get("search_id", "")) + if search_id: + search_detail = SearchService.get_detail(search_id) if not search_detail: - logging.warning("search config not found: search_id=%s", req.get("search_id", "")) + logging.warning("search config not found: search_id=%s", search_id) return False, "Invalid search_id" search_config = search_detail.get("search_config", {}) meta_data_filter = search_config.get("meta_data_filter", {}) + similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold)) + vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight)) + top = max(1, min(int(search_config.get("top_k", top)), 2048)) + use_kg = search_config.get("use_kg", use_kg) + langs = search_config.get("cross_languages", langs) + logging.debug( + "Dataset search loaded Search config: search_id=%s dataset_ids=%s " + "vector_similarity_weight=%s full_text_weight=%s similarity_threshold=%s top_k=%s", + search_id, + kb_ids, + vector_similarity_weight, + 1 - vector_similarity_weight, + similarity_threshold, + top, + ) if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_id = search_config.get("chat_id", "") if chat_id: @@ -1384,11 +1425,13 @@ async def search_datasets(tenant_id: str, req: dict): requester_tenant_id=tenant_id, ) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - elif req.get("rerank_id"): - rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + else: + rerank_id = search_config.get("rerank_id") or req.get("rerank_id") + if rerank_id: + rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, rerank_id) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - if req.get("keyword", False): + if search_config.get("keyword", req.get("keyword", False)): default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config) _question += await keyword_extraction(chat_mdl, _question) @@ -1401,12 +1444,13 @@ async def search_datasets(tenant_id: str, req: dict): kb_ids, page, size, - float(req.get("similarity_threshold", 0.0)), - float(req.get("vector_similarity_weight", 0.3)), + similarity_threshold, + vector_similarity_weight, doc_ids=local_doc_ids, top=top, rerank_mdl=rerank_mdl, rank_feature=labels, + trace_id=search_id, ) if use_kg: diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 77c11ce57b..4084fb80fe 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -1525,7 +1525,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): yield ("marker", "", state) -async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): +async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}, search_id=None): doc_ids = search_config.get("doc_ids", []) rerank_mdl = None kb_ids = search_config.get("kb_ids", kb_ids) @@ -1561,6 +1561,21 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids), ) + vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3) + try: + full_text_weight = 1 - vector_similarity_weight + except TypeError: + full_text_weight = None + logger.debug( + "Search async_ask retrieval weight: search_id=%s tenant_id=%s kb_count=%s " + "vector_similarity_weight=%s full_text_weight=%s", + search_id, + tenant_id, + len(kb_ids), + vector_similarity_weight, + full_text_weight, + ) + kbinfos = await retriever.retrieval( question=question, embd_mdl=embd_mdl, @@ -1569,12 +1584,13 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf page=1, page_size=12, similarity_threshold=search_config.get("similarity_threshold", 0.1), - vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3), + vector_similarity_weight=vector_similarity_weight, top=search_config.get("top_k", 1024), doc_ids=doc_ids, aggs=True, rerank_mdl=rerank_mdl, rank_feature=label_question(question, kbs), + trace_id=search_id, ) if include_reference_metadata: logging.debug( diff --git a/rag/nlp/search.py b/rag/nlp/search.py index e79671f04e..a32d2cacba 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -575,6 +575,7 @@ class Dealer: rerank_mdl=None, highlight=False, rank_feature: dict | None = {PAGERANK_FLD: 10}, + trace_id=None, ): ranks = {"total": 0, "chunks": [], "doc_aggs": {}} if not question: @@ -614,12 +615,24 @@ class Dealer: ranks["doc_aggs"] = [] return ranks + term_similarity_weight = 1 - vector_similarity_weight + logging.debug( + "[Search] retrieval weights: trace_id=%s kb_count=%s similarity_threshold=%s " + "vector_similarity_weight=%s full_text_weight=%s rerank_enabled=%s", + trace_id, + len(kb_ids), + similarity_threshold, + vector_similarity_weight, + term_similarity_weight, + bool(rerank_mdl), + ) + if rerank_mdl and sres.total > 0: sim, tsim, vsim = self.rerank_by_model( rerank_mdl, sres, question, - 1 - vector_similarity_weight, + term_similarity_weight, vector_similarity_weight, rank_feature=rank_feature, ) @@ -636,7 +649,7 @@ class Dealer: sim, tsim, vsim = self.rerank( sres, question, - 1 - vector_similarity_weight, + term_similarity_weight, vector_similarity_weight, rank_feature=rank_feature, ) @@ -650,7 +663,7 @@ class Dealer: sres, question, knn_scores, - 1 - vector_similarity_weight, + term_similarity_weight, vector_similarity_weight, rank_feature=rank_feature, ) diff --git a/web/src/pages/next-search/search-setting.tsx b/web/src/pages/next-search/search-setting.tsx index cb658eea5e..3c3dc6c600 100644 --- a/web/src/pages/next-search/search-setting.tsx +++ b/web/src/pages/next-search/search-setting.tsx @@ -132,9 +132,7 @@ const SearchSetting: React.FC = ({ search_config: { kb_ids: search_config?.kb_ids || [], vector_similarity_weight: - (search_config?.vector_similarity_weight - ? 1 - search_config?.vector_similarity_weight - : 0.3) || 0.3, + search_config?.vector_similarity_weight ?? 0.3, web_search: search_config?.web_search || false, doc_ids: [], similarity_threshold: search_config?.similarity_threshold || 0.2, @@ -309,7 +307,7 @@ const SearchSetting: React.FC = ({ ...other_config, reference_metadata: normalizedReferenceMetadata, chat_id: llm_setting.llm_id, - vector_similarity_weight: 1 - vector_similarity_weight, + vector_similarity_weight, rerank_id: use_rerank ? rerank_id : '', llm_setting: { ...llmSetting }, },