Fix search vector_similarity_weight (#15108)

### What problem does this PR solve?

Fix search vector_similarity_weight

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Wang Qi
2026-05-22 16:05:13 +08:00
committed by GitHub
parent f9ce07ced1
commit 7e6844118b
5 changed files with 124 additions and 28 deletions

View File

@@ -16,6 +16,7 @@
import json
import logging
from numbers import Real
from quart import Response, request
from api.db.services.dialog_service import async_ask
@@ -31,6 +32,12 @@ from common.constants import RetCode, StatusEnum
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
def _full_text_weight(vector_similarity_weight):
if isinstance(vector_similarity_weight, Real):
return 1 - vector_similarity_weight
return None
@manager.route("/searches", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name")
@@ -158,6 +165,16 @@ async def update(search_id):
if not isinstance(new_config, dict):
return get_data_error_result(message="search_config must be a JSON object")
req["search_config"] = {**current_config, **new_config}
logging.debug(
"Search update weight: search_id=%s user_id=%s "
"incoming_vector_similarity_weight=%s stored_vector_similarity_weight=%s "
"stored_full_text_weight=%s",
search_id,
current_user.id,
new_config.get("vector_similarity_weight"),
req["search_config"].get("vector_similarity_weight"),
_full_text_weight(req["search_config"].get("vector_similarity_weight", 0.3)),
)
for field in ("search_id", "tenant_id", "created_by", "update_time", "id"):
req.pop(field, None)
@@ -209,6 +226,14 @@ async def completion(search_id):
return get_data_error_result(message=f"Cannot find search {search_id}")
search_config = search_app.get("search_config", {})
logging.debug(
"Search completion loaded weight: search_id=%s user_id=%s "
"stored_vector_similarity_weight=%s stored_full_text_weight=%s",
search_id,
uid,
search_config.get("vector_similarity_weight", 0.3),
_full_text_weight(search_config.get("vector_similarity_weight", 0.3)),
)
kb_ids = search_config.get("kb_ids") or req.get("kb_ids") or []
if not kb_ids:
return get_data_error_result(message="`kb_ids` is required.")
@@ -216,7 +241,7 @@ async def completion(search_id):
async def stream():
nonlocal req, uid, kb_ids, search_config
try:
async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config):
async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config, search_id=search_id):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as ex:
yield "data:" + json.dumps(

View File

@@ -938,6 +938,8 @@ async def search(dataset_id: str, tenant_id: str, req: dict):
question = req.get("question", "")
doc_ids = req.get("doc_ids", [])
use_kg = req.get("use_kg", False)
similarity_threshold = float(req.get("similarity_threshold", 0.0))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = max(1, min(int(req.get("top_k", 1024)), 2048))
langs = req.get("cross_languages", [])
@@ -955,14 +957,31 @@ async def search(dataset_id: str, tenant_id: str, req: dict):
local_doc_ids = list(doc_ids) if doc_ids else []
meta_data_filter = {}
search_id = req.get("search_id", "")
search_config = {}
chat_mdl = None
if req.get("search_id", ""):
search_detail = SearchService.get_detail(req.get("search_id", ""))
if search_id:
search_detail = SearchService.get_detail(search_id)
if not search_detail:
logging.warning("search config not found: search_id=%s", req.get("search_id", ""))
logging.warning("search config not found: search_id=%s", search_id)
return False, "Invalid search_id"
search_config = search_detail.get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
top = max(1, min(int(search_config.get("top_k", top)), 2048))
use_kg = search_config.get("use_kg", use_kg)
langs = search_config.get("cross_languages", langs)
logging.debug(
"Dataset search loaded Search config: search_id=%s dataset_id=%s "
"vector_similarity_weight=%s full_text_weight=%s similarity_threshold=%s top_k=%s",
search_id,
dataset_id,
vector_similarity_weight,
1 - vector_similarity_weight,
similarity_threshold,
top,
)
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_id = search_config.get("chat_id", "")
if chat_id:
@@ -1016,11 +1035,13 @@ async def search(dataset_id: str, tenant_id: str, req: dict):
requester_tenant_id=tenant_id,
)
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
elif req.get("rerank_id"):
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
else:
rerank_id = search_config.get("rerank_id") or req.get("rerank_id")
if rerank_id:
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, rerank_id)
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
if req.get("keyword", False):
if search_config.get("keyword", req.get("keyword", False)):
default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
_question += await keyword_extraction(chat_mdl, _question)
@@ -1033,12 +1054,13 @@ async def search(dataset_id: str, tenant_id: str, req: dict):
[dataset_id],
page,
size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
similarity_threshold,
vector_similarity_weight,
doc_ids=local_doc_ids,
top=top,
rerank_mdl=rerank_mdl,
rank_feature=labels,
trace_id=search_id,
)
if use_kg:
@@ -1291,6 +1313,8 @@ async def search_datasets(tenant_id: str, req: dict):
question = req.get("question", "")
doc_ids = req.get("doc_ids", [])
use_kg = req.get("use_kg", False)
similarity_threshold = float(req.get("similarity_threshold", 0.0))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = max(1, min(int(req.get("top_k", 1024)), 2048))
langs = req.get("cross_languages", [])
@@ -1321,14 +1345,31 @@ async def search_datasets(tenant_id: str, req: dict):
local_doc_ids = list(doc_ids) if doc_ids else []
meta_data_filter = {}
search_id = req.get("search_id", "")
search_config = {}
chat_mdl = None
if req.get("search_id", ""):
search_detail = SearchService.get_detail(req.get("search_id", ""))
if search_id:
search_detail = SearchService.get_detail(search_id)
if not search_detail:
logging.warning("search config not found: search_id=%s", req.get("search_id", ""))
logging.warning("search config not found: search_id=%s", search_id)
return False, "Invalid search_id"
search_config = search_detail.get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
top = max(1, min(int(search_config.get("top_k", top)), 2048))
use_kg = search_config.get("use_kg", use_kg)
langs = search_config.get("cross_languages", langs)
logging.debug(
"Dataset search loaded Search config: search_id=%s dataset_ids=%s "
"vector_similarity_weight=%s full_text_weight=%s similarity_threshold=%s top_k=%s",
search_id,
kb_ids,
vector_similarity_weight,
1 - vector_similarity_weight,
similarity_threshold,
top,
)
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_id = search_config.get("chat_id", "")
if chat_id:
@@ -1384,11 +1425,13 @@ async def search_datasets(tenant_id: str, req: dict):
requester_tenant_id=tenant_id,
)
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
elif req.get("rerank_id"):
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
else:
rerank_id = search_config.get("rerank_id") or req.get("rerank_id")
if rerank_id:
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, rerank_id)
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
if req.get("keyword", False):
if search_config.get("keyword", req.get("keyword", False)):
default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
_question += await keyword_extraction(chat_mdl, _question)
@@ -1401,12 +1444,13 @@ async def search_datasets(tenant_id: str, req: dict):
kb_ids,
page,
size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
similarity_threshold,
vector_similarity_weight,
doc_ids=local_doc_ids,
top=top,
rerank_mdl=rerank_mdl,
rank_feature=labels,
trace_id=search_id,
)
if use_kg:

View File

@@ -1525,7 +1525,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
yield ("marker", "</think>", state)
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}, search_id=None):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
kb_ids = search_config.get("kb_ids", kb_ids)
@@ -1561,6 +1561,21 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
)
vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
try:
full_text_weight = 1 - vector_similarity_weight
except TypeError:
full_text_weight = None
logger.debug(
"Search async_ask retrieval weight: search_id=%s tenant_id=%s kb_count=%s "
"vector_similarity_weight=%s full_text_weight=%s",
search_id,
tenant_id,
len(kb_ids),
vector_similarity_weight,
full_text_weight,
)
kbinfos = await retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
@@ -1569,12 +1584,13 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
page=1,
page_size=12,
similarity_threshold=search_config.get("similarity_threshold", 0.1),
vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
vector_similarity_weight=vector_similarity_weight,
top=search_config.get("top_k", 1024),
doc_ids=doc_ids,
aggs=True,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, kbs),
trace_id=search_id,
)
if include_reference_metadata:
logging.debug(

View File

@@ -575,6 +575,7 @@ class Dealer:
rerank_mdl=None,
highlight=False,
rank_feature: dict | None = {PAGERANK_FLD: 10},
trace_id=None,
):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
@@ -614,12 +615,24 @@ class Dealer:
ranks["doc_aggs"] = []
return ranks
term_similarity_weight = 1 - vector_similarity_weight
logging.debug(
"[Search] retrieval weights: trace_id=%s kb_count=%s similarity_threshold=%s "
"vector_similarity_weight=%s full_text_weight=%s rerank_enabled=%s",
trace_id,
len(kb_ids),
similarity_threshold,
vector_similarity_weight,
term_similarity_weight,
bool(rerank_mdl),
)
if rerank_mdl and sres.total > 0:
sim, tsim, vsim = self.rerank_by_model(
rerank_mdl,
sres,
question,
1 - vector_similarity_weight,
term_similarity_weight,
vector_similarity_weight,
rank_feature=rank_feature,
)
@@ -636,7 +649,7 @@ class Dealer:
sim, tsim, vsim = self.rerank(
sres,
question,
1 - vector_similarity_weight,
term_similarity_weight,
vector_similarity_weight,
rank_feature=rank_feature,
)
@@ -650,7 +663,7 @@ class Dealer:
sres,
question,
knn_scores,
1 - vector_similarity_weight,
term_similarity_weight,
vector_similarity_weight,
rank_feature=rank_feature,
)

View File

@@ -132,9 +132,7 @@ const SearchSetting: React.FC<SearchSettingProps> = ({
search_config: {
kb_ids: search_config?.kb_ids || [],
vector_similarity_weight:
(search_config?.vector_similarity_weight
? 1 - search_config?.vector_similarity_weight
: 0.3) || 0.3,
search_config?.vector_similarity_weight ?? 0.3,
web_search: search_config?.web_search || false,
doc_ids: [],
similarity_threshold: search_config?.similarity_threshold || 0.2,
@@ -309,7 +307,7 @@ const SearchSetting: React.FC<SearchSettingProps> = ({
...other_config,
reference_metadata: normalizedReferenceMetadata,
chat_id: llm_setting.llm_id,
vector_similarity_weight: 1 - vector_similarity_weight,
vector_similarity_weight,
rerank_id: use_rerank ? rerank_id : '',
llm_setting: { ...llmSetting },
},