mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 08:15:44 +08:00
Fix search vector_similarity_weight (#15108)
### What problem does this PR solve? Fix search vector_similarity_weight ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from numbers import Real
|
||||
|
||||
from quart import Response, request
|
||||
from api.db.services.dialog_service import async_ask
|
||||
@@ -31,6 +32,12 @@ from common.constants import RetCode, StatusEnum
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
|
||||
|
||||
def _full_text_weight(vector_similarity_weight):
|
||||
if isinstance(vector_similarity_weight, Real):
|
||||
return 1 - vector_similarity_weight
|
||||
return None
|
||||
|
||||
|
||||
@manager.route("/searches", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
@@ -158,6 +165,16 @@ async def update(search_id):
|
||||
if not isinstance(new_config, dict):
|
||||
return get_data_error_result(message="search_config must be a JSON object")
|
||||
req["search_config"] = {**current_config, **new_config}
|
||||
logging.debug(
|
||||
"Search update weight: search_id=%s user_id=%s "
|
||||
"incoming_vector_similarity_weight=%s stored_vector_similarity_weight=%s "
|
||||
"stored_full_text_weight=%s",
|
||||
search_id,
|
||||
current_user.id,
|
||||
new_config.get("vector_similarity_weight"),
|
||||
req["search_config"].get("vector_similarity_weight"),
|
||||
_full_text_weight(req["search_config"].get("vector_similarity_weight", 0.3)),
|
||||
)
|
||||
|
||||
for field in ("search_id", "tenant_id", "created_by", "update_time", "id"):
|
||||
req.pop(field, None)
|
||||
@@ -209,6 +226,14 @@ async def completion(search_id):
|
||||
return get_data_error_result(message=f"Cannot find search {search_id}")
|
||||
|
||||
search_config = search_app.get("search_config", {})
|
||||
logging.debug(
|
||||
"Search completion loaded weight: search_id=%s user_id=%s "
|
||||
"stored_vector_similarity_weight=%s stored_full_text_weight=%s",
|
||||
search_id,
|
||||
uid,
|
||||
search_config.get("vector_similarity_weight", 0.3),
|
||||
_full_text_weight(search_config.get("vector_similarity_weight", 0.3)),
|
||||
)
|
||||
kb_ids = search_config.get("kb_ids") or req.get("kb_ids") or []
|
||||
if not kb_ids:
|
||||
return get_data_error_result(message="`kb_ids` is required.")
|
||||
@@ -216,7 +241,7 @@ async def completion(search_id):
|
||||
async def stream():
|
||||
nonlocal req, uid, kb_ids, search_config
|
||||
try:
|
||||
async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config):
|
||||
async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config, search_id=search_id):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as ex:
|
||||
yield "data:" + json.dumps(
|
||||
|
||||
@@ -938,6 +938,8 @@ async def search(dataset_id: str, tenant_id: str, req: dict):
|
||||
question = req.get("question", "")
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
use_kg = req.get("use_kg", False)
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = max(1, min(int(req.get("top_k", 1024)), 2048))
|
||||
langs = req.get("cross_languages", [])
|
||||
|
||||
@@ -955,14 +957,31 @@ async def search(dataset_id: str, tenant_id: str, req: dict):
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
|
||||
meta_data_filter = {}
|
||||
search_id = req.get("search_id", "")
|
||||
search_config = {}
|
||||
chat_mdl = None
|
||||
if req.get("search_id", ""):
|
||||
search_detail = SearchService.get_detail(req.get("search_id", ""))
|
||||
if search_id:
|
||||
search_detail = SearchService.get_detail(search_id)
|
||||
if not search_detail:
|
||||
logging.warning("search config not found: search_id=%s", req.get("search_id", ""))
|
||||
logging.warning("search config not found: search_id=%s", search_id)
|
||||
return False, "Invalid search_id"
|
||||
search_config = search_detail.get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
|
||||
vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
|
||||
top = max(1, min(int(search_config.get("top_k", top)), 2048))
|
||||
use_kg = search_config.get("use_kg", use_kg)
|
||||
langs = search_config.get("cross_languages", langs)
|
||||
logging.debug(
|
||||
"Dataset search loaded Search config: search_id=%s dataset_id=%s "
|
||||
"vector_similarity_weight=%s full_text_weight=%s similarity_threshold=%s top_k=%s",
|
||||
search_id,
|
||||
dataset_id,
|
||||
vector_similarity_weight,
|
||||
1 - vector_similarity_weight,
|
||||
similarity_threshold,
|
||||
top,
|
||||
)
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_id = search_config.get("chat_id", "")
|
||||
if chat_id:
|
||||
@@ -1016,11 +1035,13 @@ async def search(dataset_id: str, tenant_id: str, req: dict):
|
||||
requester_tenant_id=tenant_id,
|
||||
)
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
elif req.get("rerank_id"):
|
||||
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
else:
|
||||
rerank_id = search_config.get("rerank_id") or req.get("rerank_id")
|
||||
if rerank_id:
|
||||
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, rerank_id)
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
|
||||
if req.get("keyword", False):
|
||||
if search_config.get("keyword", req.get("keyword", False)):
|
||||
default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
|
||||
_question += await keyword_extraction(chat_mdl, _question)
|
||||
@@ -1033,12 +1054,13 @@ async def search(dataset_id: str, tenant_id: str, req: dict):
|
||||
[dataset_id],
|
||||
page,
|
||||
size,
|
||||
float(req.get("similarity_threshold", 0.0)),
|
||||
float(req.get("vector_similarity_weight", 0.3)),
|
||||
similarity_threshold,
|
||||
vector_similarity_weight,
|
||||
doc_ids=local_doc_ids,
|
||||
top=top,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=labels,
|
||||
trace_id=search_id,
|
||||
)
|
||||
|
||||
if use_kg:
|
||||
@@ -1291,6 +1313,8 @@ async def search_datasets(tenant_id: str, req: dict):
|
||||
question = req.get("question", "")
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
use_kg = req.get("use_kg", False)
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = max(1, min(int(req.get("top_k", 1024)), 2048))
|
||||
langs = req.get("cross_languages", [])
|
||||
|
||||
@@ -1321,14 +1345,31 @@ async def search_datasets(tenant_id: str, req: dict):
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
|
||||
meta_data_filter = {}
|
||||
search_id = req.get("search_id", "")
|
||||
search_config = {}
|
||||
chat_mdl = None
|
||||
if req.get("search_id", ""):
|
||||
search_detail = SearchService.get_detail(req.get("search_id", ""))
|
||||
if search_id:
|
||||
search_detail = SearchService.get_detail(search_id)
|
||||
if not search_detail:
|
||||
logging.warning("search config not found: search_id=%s", req.get("search_id", ""))
|
||||
logging.warning("search config not found: search_id=%s", search_id)
|
||||
return False, "Invalid search_id"
|
||||
search_config = search_detail.get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
|
||||
vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
|
||||
top = max(1, min(int(search_config.get("top_k", top)), 2048))
|
||||
use_kg = search_config.get("use_kg", use_kg)
|
||||
langs = search_config.get("cross_languages", langs)
|
||||
logging.debug(
|
||||
"Dataset search loaded Search config: search_id=%s dataset_ids=%s "
|
||||
"vector_similarity_weight=%s full_text_weight=%s similarity_threshold=%s top_k=%s",
|
||||
search_id,
|
||||
kb_ids,
|
||||
vector_similarity_weight,
|
||||
1 - vector_similarity_weight,
|
||||
similarity_threshold,
|
||||
top,
|
||||
)
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_id = search_config.get("chat_id", "")
|
||||
if chat_id:
|
||||
@@ -1384,11 +1425,13 @@ async def search_datasets(tenant_id: str, req: dict):
|
||||
requester_tenant_id=tenant_id,
|
||||
)
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
elif req.get("rerank_id"):
|
||||
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
else:
|
||||
rerank_id = search_config.get("rerank_id") or req.get("rerank_id")
|
||||
if rerank_id:
|
||||
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, rerank_id)
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
|
||||
if req.get("keyword", False):
|
||||
if search_config.get("keyword", req.get("keyword", False)):
|
||||
default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
|
||||
_question += await keyword_extraction(chat_mdl, _question)
|
||||
@@ -1401,12 +1444,13 @@ async def search_datasets(tenant_id: str, req: dict):
|
||||
kb_ids,
|
||||
page,
|
||||
size,
|
||||
float(req.get("similarity_threshold", 0.0)),
|
||||
float(req.get("vector_similarity_weight", 0.3)),
|
||||
similarity_threshold,
|
||||
vector_similarity_weight,
|
||||
doc_ids=local_doc_ids,
|
||||
top=top,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=labels,
|
||||
trace_id=search_id,
|
||||
)
|
||||
|
||||
if use_kg:
|
||||
|
||||
@@ -1525,7 +1525,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
|
||||
yield ("marker", "</think>", state)
|
||||
|
||||
|
||||
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}, search_id=None):
|
||||
doc_ids = search_config.get("doc_ids", [])
|
||||
rerank_mdl = None
|
||||
kb_ids = search_config.get("kb_ids", kb_ids)
|
||||
@@ -1561,6 +1561,21 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
|
||||
)
|
||||
|
||||
vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
|
||||
try:
|
||||
full_text_weight = 1 - vector_similarity_weight
|
||||
except TypeError:
|
||||
full_text_weight = None
|
||||
logger.debug(
|
||||
"Search async_ask retrieval weight: search_id=%s tenant_id=%s kb_count=%s "
|
||||
"vector_similarity_weight=%s full_text_weight=%s",
|
||||
search_id,
|
||||
tenant_id,
|
||||
len(kb_ids),
|
||||
vector_similarity_weight,
|
||||
full_text_weight,
|
||||
)
|
||||
|
||||
kbinfos = await retriever.retrieval(
|
||||
question=question,
|
||||
embd_mdl=embd_mdl,
|
||||
@@ -1569,12 +1584,13 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
page=1,
|
||||
page_size=12,
|
||||
similarity_threshold=search_config.get("similarity_threshold", 0.1),
|
||||
vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
|
||||
vector_similarity_weight=vector_similarity_weight,
|
||||
top=search_config.get("top_k", 1024),
|
||||
doc_ids=doc_ids,
|
||||
aggs=True,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(question, kbs),
|
||||
trace_id=search_id,
|
||||
)
|
||||
if include_reference_metadata:
|
||||
logging.debug(
|
||||
|
||||
@@ -575,6 +575,7 @@ class Dealer:
|
||||
rerank_mdl=None,
|
||||
highlight=False,
|
||||
rank_feature: dict | None = {PAGERANK_FLD: 10},
|
||||
trace_id=None,
|
||||
):
|
||||
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
||||
if not question:
|
||||
@@ -614,12 +615,24 @@ class Dealer:
|
||||
ranks["doc_aggs"] = []
|
||||
return ranks
|
||||
|
||||
term_similarity_weight = 1 - vector_similarity_weight
|
||||
logging.debug(
|
||||
"[Search] retrieval weights: trace_id=%s kb_count=%s similarity_threshold=%s "
|
||||
"vector_similarity_weight=%s full_text_weight=%s rerank_enabled=%s",
|
||||
trace_id,
|
||||
len(kb_ids),
|
||||
similarity_threshold,
|
||||
vector_similarity_weight,
|
||||
term_similarity_weight,
|
||||
bool(rerank_mdl),
|
||||
)
|
||||
|
||||
if rerank_mdl and sres.total > 0:
|
||||
sim, tsim, vsim = self.rerank_by_model(
|
||||
rerank_mdl,
|
||||
sres,
|
||||
question,
|
||||
1 - vector_similarity_weight,
|
||||
term_similarity_weight,
|
||||
vector_similarity_weight,
|
||||
rank_feature=rank_feature,
|
||||
)
|
||||
@@ -636,7 +649,7 @@ class Dealer:
|
||||
sim, tsim, vsim = self.rerank(
|
||||
sres,
|
||||
question,
|
||||
1 - vector_similarity_weight,
|
||||
term_similarity_weight,
|
||||
vector_similarity_weight,
|
||||
rank_feature=rank_feature,
|
||||
)
|
||||
@@ -650,7 +663,7 @@ class Dealer:
|
||||
sres,
|
||||
question,
|
||||
knn_scores,
|
||||
1 - vector_similarity_weight,
|
||||
term_similarity_weight,
|
||||
vector_similarity_weight,
|
||||
rank_feature=rank_feature,
|
||||
)
|
||||
|
||||
@@ -132,9 +132,7 @@ const SearchSetting: React.FC<SearchSettingProps> = ({
|
||||
search_config: {
|
||||
kb_ids: search_config?.kb_ids || [],
|
||||
vector_similarity_weight:
|
||||
(search_config?.vector_similarity_weight
|
||||
? 1 - search_config?.vector_similarity_weight
|
||||
: 0.3) || 0.3,
|
||||
search_config?.vector_similarity_weight ?? 0.3,
|
||||
web_search: search_config?.web_search || false,
|
||||
doc_ids: [],
|
||||
similarity_threshold: search_config?.similarity_threshold || 0.2,
|
||||
@@ -309,7 +307,7 @@ const SearchSetting: React.FC<SearchSettingProps> = ({
|
||||
...other_config,
|
||||
reference_metadata: normalizedReferenceMetadata,
|
||||
chat_id: llm_setting.llm_id,
|
||||
vector_similarity_weight: 1 - vector_similarity_weight,
|
||||
vector_similarity_weight,
|
||||
rerank_id: use_rerank ? rerank_id : '',
|
||||
llm_setting: { ...llmSetting },
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user