mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Refactor: Allow searching multiple datasets (#14685)
### What problem does this PR solve? Refactor: allow searching across multiple datasets. 1. Support /datasets/search. 2. Get rid of /graph/search; use /graph instead. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Refactoring
This commit is contained in:
@@ -25,6 +25,7 @@ from api.utils.validation_utils import (
|
||||
DeleteDatasetReq,
|
||||
ListDatasetReq,
|
||||
SearchDatasetReq,
|
||||
SearchDatasetsReq,
|
||||
UpdateDatasetReq,
|
||||
validate_and_parse_json_request,
|
||||
validate_and_parse_request_args,
|
||||
@@ -477,6 +478,35 @@ async def rename_tag(tenant_id, dataset_id):
|
||||
return get_error_data_result(message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/datasets/search", methods=["POST"])  # noqa: F821
@login_required
@add_tenant_id_to_kwargs
async def search_datasets(tenant_id):
    """Run a retrieval test over several datasets in one request.

    POST /api/v1/datasets/search

    JSON body:
        {"dataset_ids": list[str] (required), "question": str (required),
         "doc_ids": list[str], "top_k": int, "page": int, "size": int,
         "similarity_threshold": float, "vector_similarity_weight": float,
         "use_kg": bool, "cross_languages": list[str], "keyword": bool,
         "meta_data_filter": dict}

    Success: {"code": 0, "data": {"chunks": [...], "total": int, "labels": [...]}}
    Errors: ARGUMENT_ERROR (101) for invalid payload; DATA_ERROR (102) for
    access denied or internal errors.
    """
    payload, validation_error = await validate_and_parse_json_request(request, SearchDatasetsReq)
    if validation_error is not None:
        return get_error_argument_result(validation_error)

    try:
        ok, outcome = await dataset_api_service.search_datasets(tenant_id, payload)
        # On failure the service hands back a human-readable message instead of data.
        if not ok:
            return get_error_data_result(message=outcome)
        return get_result(data=outcome)
    except Exception as exc:
        logging.exception(exc)
        # A "not_found" marker in the exception text gets a dedicated hint;
        # every other failure is reported generically.
        if "not_found" in str(exc):
            failure_message = "No chunk found! Check the chunk status please!"
        else:
            failure_message = "Internal server error"
        return get_error_data_result(message=failure_message)
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/search", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
@@ -493,8 +523,9 @@ async def search(tenant_id, dataset_id):
|
||||
req, err = await validate_and_parse_json_request(request, SearchDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
req['dataset_ids'] = [dataset_id]
|
||||
try:
|
||||
success, result = await dataset_api_service.search(dataset_id, tenant_id, req)
|
||||
success, result = await dataset_api_service.search_datasets(tenant_id, req)
|
||||
if success:
|
||||
return get_result(data=result)
|
||||
else:
|
||||
@@ -506,21 +537,6 @@ async def search(tenant_id, dataset_id):
|
||||
return get_error_data_result(message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/graph/search", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
async def knowledge_graph(tenant_id, dataset_id):
|
||||
try:
|
||||
success, result = await dataset_api_service.get_knowledge_graph(dataset_id, tenant_id)
|
||||
if success:
|
||||
return get_result(data=result)
|
||||
else:
|
||||
return get_result(data=False, message=result, code=RetCode.AUTHENTICATION_ERROR)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Internal server error")
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/graph", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
|
||||
@@ -26,6 +26,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.connector_service import Connector2KbService
|
||||
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from common.constants import FileSource, StatusEnum
|
||||
from api.utils.api_utils import deep_merge, get_parser_config, remap_dictionary_keys, verify_embedding_availability
|
||||
|
||||
@@ -1050,3 +1051,162 @@ async def search(dataset_id: str, tenant_id: str, req: dict):
|
||||
ranks["labels"] = labels
|
||||
|
||||
return True, ranks
|
||||
|
||||
|
||||
async def search_datasets(tenant_id: str, req: dict):
    """
    Search (retrieval test) across multiple datasets.

    The request is rejected when any dataset is not accessible to the caller,
    when no dataset is found, or when the datasets do not share a single
    embedding model name. Embedding, cross-language and keyword models are
    resolved from the first dataset returned by the lookup.

    :param tenant_id: tenant ID of the caller
    :param req: search request containing ``dataset_ids`` and optional
        ``question``, ``doc_ids``, ``page``, ``size``, ``top_k`` (clamped to
        1..2048), ``similarity_threshold``, ``vector_similarity_weight``,
        ``use_kg``, ``cross_languages``, ``keyword``, ``search_id``,
        ``rerank_id``, ``tenant_rerank_id`` and ``meta_data_filter``
    :return: ``(True, ranks)`` where ``ranks`` is the retriever result with
        ``chunks`` (vectors removed), ``total`` and ``labels``; or
        ``(False, error_message)``
    """
    from api.db.joint_services.tenant_model_service import (
        get_model_config_by_id,
        get_model_config_by_type_and_name,
        get_tenant_default_model_by_type,
    )
    from api.db.services.doc_metadata_service import DocMetadataService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.search_service import SearchService
    from api.db.services.user_service import UserTenantService
    from common.constants import LLMType
    from common.metadata_utils import apply_meta_data_filter
    from rag.app.tag import label_question
    from rag.prompts.generator import cross_languages, keyword_extraction

    kb_ids = req.get("dataset_ids", [])
    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req.get("question", "")
    doc_ids = req.get("doc_ids", [])
    use_kg = req.get("use_kg", False)
    top = max(1, min(int(req.get("top_k", 1024)), 2048))
    langs = req.get("cross_languages", [])

    logging.debug(
        "search_datasets(datasets=%s, tenant=%s, question_len=%s)",
        kb_ids,
        tenant_id,
        len(question),
    )

    # Access check for all datasets: a single inaccessible dataset rejects the request.
    for kb_id in kb_ids:
        if not KnowledgebaseService.accessible(kb_id, tenant_id):
            logging.warning("search_datasets access denied: dataset=%s tenant=%s", kb_id, tenant_id)
            return False, f"Only owner of dataset {kb_id} authorized for this operation."

    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    if not kbs:
        return False, "Datasets not found!"

    # All datasets must use the same embedding model, because one embedding
    # model (taken from the first dataset) is used for the whole query.
    embd_nms = {TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs}
    if len(embd_nms) != 1:
        return False, "Datasets use different embedding models."

    if doc_ids is not None and not isinstance(doc_ids, list):
        return False, "`doc_ids` should be a list"
    # Copy so the metadata filter below never mutates the caller's list.
    local_doc_ids = list(doc_ids) if doc_ids else []

    meta_data_filter = {}
    chat_mdl = None
    search_id = req.get("search_id", "")
    if search_id:
        # A saved search configuration overrides the filter supplied in the request.
        search_detail = SearchService.get_detail(search_id)
        if not search_detail:
            logging.warning("search config not found: search_id=%s", search_id)
            return False, "Invalid search_id"
        # `or {}` guards against keys that are present but stored as null,
        # mirroring the handling of the request-level filter below.
        search_config = search_detail.get("search_config") or {}
        meta_data_filter = search_config.get("meta_data_filter") or {}
        if meta_data_filter.get("method") in ["auto", "semi_auto"]:
            chat_id = search_config.get("chat_id", "")
            if chat_id:
                chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
            else:
                chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(tenant_id, chat_model_config)
    else:
        meta_data_filter = req.get("meta_data_filter") or {}
        if meta_data_filter.get("method") in ["auto", "semi_auto"]:
            chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(tenant_id, chat_model_config)

    if meta_data_filter:
        local_doc_ids = await apply_meta_data_filter(
            meta_data_filter,
            None,
            question,
            chat_mdl,
            local_doc_ids,
            kb_ids=kb_ids,
            metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
        )

    # Collect EVERY joined tenant that owns at least one requested dataset.
    # Stopping at the first match (as a single-dataset search may do) would
    # leave out the owners of the remaining datasets, and `tenant_ids` is what
    # the retrievers below receive alongside `kb_ids`.
    tenant_ids = [
        tenant.tenant_id
        for tenant in UserTenantService.query(user_id=tenant_id)
        if any(KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id) for kb_id in kb_ids)
    ]
    if not tenant_ids:
        return False, "Only owner of datasets authorized for this operation."

    # The first dataset stands in for all of them when resolving models; the
    # embedding-model check above ensures they share one model name.
    kb = kbs[0]
    _question = question
    if langs:
        _question = await cross_languages(kb.tenant_id, None, _question, langs)
    if kb.tenant_embd_id:
        embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
    elif kb.embd_id:
        embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
    else:
        embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
    embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)

    # Reranking is optional; an explicit tenant-level model id wins over a model name.
    rerank_mdl = None
    if req.get("tenant_rerank_id"):
        rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
        rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
    elif req.get("rerank_id"):
        rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
        rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)

    if req.get("keyword", False):
        # Append extracted keywords to the (possibly translated) question.
        default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
        chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
        _question += await keyword_extraction(chat_mdl, _question)

    labels = label_question(_question, kbs)
    ranks = await settings.retriever.retrieval(
        _question,
        embd_mdl,
        tenant_ids,
        kb_ids,
        page,
        size,
        float(req.get("similarity_threshold", 0.0)),
        float(req.get("vector_similarity_weight", 0.3)),
        doc_ids=local_doc_ids,
        top=top,
        rerank_mdl=rerank_mdl,
        rank_feature=labels,
    )

    if use_kg:
        # Knowledge-graph retrieval is best effort: a failure is logged and the
        # regular retrieval result is still returned.
        try:
            default_chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
            ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, default_chat_model_config))
            if ck["content_with_weight"]:
                ranks["chunks"].insert(0, ck)
        except Exception:
            logging.warning("search_datasets KG retrieval failed: datasets=%s tenant=%s", kb_ids, tenant_id, exc_info=True)

    # Keep the total reported by the retriever even though the chunk list is
    # rebuilt by `retrieval_by_children`.
    total = ranks.get("total", 0)
    ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
    ranks["total"] = total

    # Embedding vectors are not part of the response payload.
    for c in ranks["chunks"]:
        c.pop("vector", None)
    ranks["labels"] = labels

    return True, ranks
|
||||
|
||||
@@ -858,6 +858,26 @@ class SearchDatasetReq(BaseModel):
|
||||
meta_data_filter: Annotated[dict | None, Field(default=None)]
|
||||
|
||||
|
||||
class SearchDatasetsReq(BaseModel):
    """Request body schema for searching (retrieval test) across multiple datasets.

    Only ``dataset_ids`` and ``question`` are required; every other field has
    a default. Unknown keys in the payload are ignored rather than rejected.
    """

    model_config = ConfigDict(extra="ignore")

    # Datasets to search; at least one ID is required.
    dataset_ids: Annotated[list[str], Field(..., min_length=1)]
    # Query text; surrounding whitespace is stripped and the result must be non-empty.
    question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)]
    # Optional list of document IDs (defaults to an empty list).
    doc_ids: Annotated[list[str], Field(default=[])]
    # Pagination: 1-based page number and page size.
    page: Annotated[int, Field(default=1, ge=1)]
    size: Annotated[int, Field(default=30, ge=1)]
    # Must be >= 1; no upper bound is enforced by this schema.
    top_k: Annotated[int, Field(default=1024, ge=1)]
    # Both values are constrained to the range [0.0, 1.0].
    similarity_threshold: Annotated[float, Field(default=0.0, ge=0.0, le=1.0)]
    vector_similarity_weight: Annotated[float, Field(default=0.3, ge=0.0, le=1.0)]
    # Flag for knowledge-graph retrieval (off by default).
    use_kg: Annotated[bool, Field(default=False)]
    # Optional list of language names (defaults to an empty list).
    cross_languages: Annotated[list[str], Field(default=[])]
    # Flag for keyword extraction (off by default).
    keyword: Annotated[bool, Field(default=False)]
    # Optional ID of a saved search configuration.
    search_id: Annotated[str | None, Field(default=None)]
    # Optional rerank model, given by name or by tenant-level model ID.
    rerank_id: Annotated[str | None, Field(default=None)]
    tenant_rerank_id: Annotated[str | None, Field(default=None)]
    # Optional metadata filter definition.
    meta_data_filter: Annotated[dict | None, Field(default=None)]
|
||||
|
||||
|
||||
class BaseListReq(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
@@ -298,7 +298,7 @@ def batch_add_sessions_with_chat_assistant(auth, chat_assistant_id, num):
|
||||
|
||||
# DATASET GRAPH AND TASKS
|
||||
def knowledge_graph(auth, dataset_id, params=None):
    """Issue a GET to the dataset's ``/graph`` endpoint and return the decoded JSON body.

    :param auth: authentication object handed to ``requests.get``
    :param dataset_id: ID of the dataset whose graph is requested
    :param params: optional query-string parameters
    :return: the response body parsed as JSON
    """
    # The former "/graph/search" URL was assigned and then immediately
    # overwritten; only the "/graph" route is ever requested.
    url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/graph"
    res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
    return res.json()
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ export default {
|
||||
rmKb: `${restAPIv1}/datasets`,
|
||||
getKbDetail: (datasetId: string) => `${restAPIv1}/datasets/${datasetId}`,
|
||||
getKnowledgeGraph: (knowledgeId: string) =>
|
||||
`${restAPIv1}/datasets/${knowledgeId}/graph/search`,
|
||||
`${restAPIv1}/datasets/${knowledgeId}/graph`,
|
||||
knowledgeGraph: (datasetId: string) =>
|
||||
`${restAPIv1}/datasets/${datasetId}/graph`,
|
||||
deleteKnowledgeGraph: (knowledgeId: string) =>
|
||||
|
||||
Reference in New Issue
Block a user