Refactor : Allow search multiple datasets (#14685)

### What problem does this PR solve?

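Refactor: allow searching across multiple datasets in a single request.
1. Add the `POST /datasets/search` endpoint, which accepts a list of `dataset_ids`.
2. Remove the `GET /datasets/<dataset_id>/graph/search` endpoint; use `GET /datasets/<dataset_id>/graph` instead.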

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
This commit is contained in:
Wang Qi
2026-05-08 19:01:35 +08:00
committed by GitHub
parent 26d70189b6
commit 7d35e40c7b
5 changed files with 214 additions and 18 deletions

View File

@@ -25,6 +25,7 @@ from api.utils.validation_utils import (
DeleteDatasetReq,
ListDatasetReq,
SearchDatasetReq,
SearchDatasetsReq,
UpdateDatasetReq,
validate_and_parse_json_request,
validate_and_parse_request_args,
@@ -477,6 +478,35 @@ async def rename_tag(tenant_id, dataset_id):
return get_error_data_result(message="Internal server error")
@manager.route("/datasets/search", methods=["POST"])  # noqa: F821
@login_required
@add_tenant_id_to_kwargs
async def search_datasets(tenant_id):
    """Run a retrieval test against several datasets in one request.

    POST /api/v1/datasets/search

    The JSON body is validated against ``SearchDatasetsReq``:
        required: "dataset_ids" (list[str]), "question" (str)
        optional: "doc_ids" (list[str]), "top_k" (int), "page" (int),
                  "size" (int), "similarity_threshold" (float),
                  "vector_similarity_weight" (float), "use_kg" (bool),
                  "cross_languages" (list[str]), "keyword" (bool),
                  "meta_data_filter" (dict)

    Success: {"code": 0, "data": {"chunks": [...], "total": int, "labels": [...]}}
    Errors: ARGUMENT_ERROR (101) for an invalid payload; DATA_ERROR (102)
    for access denied or internal errors.
    """
    payload, validation_error = await validate_and_parse_json_request(request, SearchDatasetsReq)
    if validation_error is not None:
        return get_error_argument_result(validation_error)

    try:
        ok, outcome = await dataset_api_service.search_datasets(tenant_id, payload)
        if not ok:
            # On failure the service returns a human-readable message.
            return get_error_data_result(message=outcome)
        return get_result(data=outcome)
    except Exception as exc:
        logging.exception(exc)
        # A "not_found" marker in the exception text gets a dedicated hint;
        # everything else is reported as a generic internal error.
        message = (
            "No chunk found! Check the chunk status please!"
            if "not_found" in str(exc)
            else "Internal server error"
        )
        return get_error_data_result(message=message)
@manager.route("/datasets/<dataset_id>/search", methods=["POST"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
@@ -493,8 +523,9 @@ async def search(tenant_id, dataset_id):
req, err = await validate_and_parse_json_request(request, SearchDatasetReq)
if err is not None:
return get_error_argument_result(err)
req['dataset_ids'] = [dataset_id]
try:
success, result = await dataset_api_service.search(dataset_id, tenant_id, req)
success, result = await dataset_api_service.search_datasets(tenant_id, req)
if success:
return get_result(data=result)
else:
@@ -506,21 +537,6 @@ async def search(tenant_id, dataset_id):
return get_error_data_result(message="Internal server error")
@manager.route("/datasets/<dataset_id>/graph/search", methods=["GET"])  # noqa: F821
@login_required
@add_tenant_id_to_kwargs
async def knowledge_graph(tenant_id, dataset_id):
    """Handle GET on ``/datasets/<dataset_id>/graph/search``.

    Delegates to ``dataset_api_service.get_knowledge_graph`` and wraps the
    ``(success, result)`` pair it returns. A failed lookup is reported with
    ``data=False`` and ``RetCode.AUTHENTICATION_ERROR``; any exception is
    logged and turned into a generic error response.
    """
    try:
        ok, payload = await dataset_api_service.get_knowledge_graph(dataset_id, tenant_id)
        if not ok:
            return get_result(data=False, message=payload, code=RetCode.AUTHENTICATION_ERROR)
        return get_result(data=payload)
    except Exception as exc:
        logging.exception(exc)
        return get_error_data_result(message="Internal server error")
@manager.route("/datasets/<dataset_id>/graph", methods=["GET"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs

View File

@@ -26,6 +26,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.connector_service import Connector2KbService
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from api.db.services.tenant_llm_service import TenantLLMService
from common.constants import FileSource, StatusEnum
from api.utils.api_utils import deep_merge, get_parser_config, remap_dictionary_keys, verify_embedding_availability
@@ -1050,3 +1051,162 @@ async def search(dataset_id: str, tenant_id: str, req: dict):
ranks["labels"] = labels
return True, ranks
async def search_datasets(tenant_id: str, req: dict):
    """
    Search (retrieval test) across multiple datasets.

    All requested datasets must be accessible to ``tenant_id`` and must share
    one embedding model; the first dataset supplies the embedding model and
    the tenant used for the LLM bundles.

    :param tenant_id: tenant ID
    :param req: search request containing ``dataset_ids`` and other params
        (``question``, ``doc_ids``, ``page``, ``size``, ``top_k``,
        ``similarity_threshold``, ``vector_similarity_weight``, ``use_kg``,
        ``cross_languages``, ``keyword``, ``search_id``, ``rerank_id``,
        ``tenant_rerank_id``, ``meta_data_filter``)
    :return: ``(True, ranks)`` on success, where ``ranks`` carries
        ``chunks``, ``total`` and ``labels``; ``(False, error_message)``
        otherwise
    """
    from api.db.joint_services.tenant_model_service import (
        get_model_config_by_id,
        get_model_config_by_type_and_name,
        get_tenant_default_model_by_type,
    )
    from api.db.services.doc_metadata_service import DocMetadataService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.search_service import SearchService
    from api.db.services.user_service import UserTenantService
    from common.constants import LLMType
    from common.metadata_utils import apply_meta_data_filter
    from rag.app.tag import label_question
    from rag.prompts.generator import cross_languages, keyword_extraction

    kb_ids = req.get("dataset_ids", [])
    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req.get("question", "")
    doc_ids = req.get("doc_ids", [])
    use_kg = req.get("use_kg", False)
    # Clamp top_k into [1, 2048].
    top = max(1, min(int(req.get("top_k", 1024)), 2048))
    langs = req.get("cross_languages", [])

    logging.debug(
        "search_datasets(datasets=%s, tenant=%s, question_len=%s)",
        kb_ids,
        tenant_id,
        len(question),
    )

    # Access check for all datasets
    for kb_id in kb_ids:
        if not KnowledgebaseService.accessible(kb_id, tenant_id):
            logging.warning("search_datasets access denied: dataset=%s tenant=%s", kb_id, tenant_id)
            return False, f"Only owner of dataset {kb_id} authorized for this operation."

    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    if not kbs:
        return False, "Datasets not found!"

    # All datasets must use the same embedding model
    embd_names = {TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs}
    if len(embd_names) != 1:
        return False, "Datasets use different embedding models."

    if doc_ids is not None and not isinstance(doc_ids, list):
        return False, "`doc_ids` should be a list"
    local_doc_ids = list(doc_ids) if doc_ids else []

    meta_data_filter = {}
    chat_mdl = None
    search_id = req.get("search_id", "")
    if search_id:
        # A saved search configuration overrides the filter in the request.
        search_detail = SearchService.get_detail(search_id)
        if not search_detail:
            logging.warning("search config not found: search_id=%s", search_id)
            return False, "Invalid search_id"
        # Stored values may be None; normalise so the `.get` calls below are
        # safe, mirroring the `or {}` guard used for the request filter.
        search_config = search_detail.get("search_config") or {}
        meta_data_filter = search_config.get("meta_data_filter") or {}
        if meta_data_filter.get("method") in ["auto", "semi_auto"]:
            chat_id = search_config.get("chat_id", "")
            if chat_id:
                chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
            else:
                chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(tenant_id, chat_model_config)
    else:
        meta_data_filter = req.get("meta_data_filter") or {}
        if meta_data_filter.get("method") in ["auto", "semi_auto"]:
            chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(tenant_id, chat_model_config)

    if meta_data_filter:
        local_doc_ids = await apply_meta_data_filter(
            meta_data_filter,
            None,
            question,
            chat_mdl,
            local_doc_ids,
            kb_ids=kb_ids,
            metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
        )

    # Resolve the owning tenant of EVERY dataset. Stopping at the first match
    # would leave datasets owned by other tenants out of the retrieval call.
    user_tenants = list(UserTenantService.query(user_id=tenant_id))
    tenant_ids = []
    for kb_id in kb_ids:
        for tenant in user_tenants:
            if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
                if tenant.tenant_id not in tenant_ids:
                    tenant_ids.append(tenant.tenant_id)
                break
        else:
            # No tenant of this user owns the dataset.
            return False, "Only owner of datasets authorized for this operation."

    # The embedding-model check above guarantees one shared model, so the
    # first dataset is used as the representative for model lookups.
    kb = kbs[0]
    _question = question
    if langs:
        _question = await cross_languages(kb.tenant_id, None, _question, langs)

    if kb.tenant_embd_id:
        embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
    elif kb.embd_id:
        embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
    else:
        embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
    embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)

    # `tenant_rerank_id` takes precedence over `rerank_id`.
    rerank_mdl = None
    if req.get("tenant_rerank_id"):
        rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
        rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
    elif req.get("rerank_id"):
        rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
        rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)

    if req.get("keyword", False):
        # Append the extracted keywords to the question.
        default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
        chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
        _question += await keyword_extraction(chat_mdl, _question)

    labels = label_question(_question, kbs)
    ranks = await settings.retriever.retrieval(
        _question,
        embd_mdl,
        tenant_ids,
        kb_ids,
        page,
        size,
        float(req.get("similarity_threshold", 0.0)),
        float(req.get("vector_similarity_weight", 0.3)),
        doc_ids=local_doc_ids,
        top=top,
        rerank_mdl=rerank_mdl,
        rank_feature=labels,
    )

    if use_kg:
        # Knowledge-graph retrieval is best effort: a failure is logged and
        # the regular retrieval result is still returned.
        try:
            default_chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
            ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, default_chat_model_config))
            if ck["content_with_weight"]:
                ranks["chunks"].insert(0, ck)
        except Exception:
            logging.warning("search_datasets KG retrieval failed: datasets=%s tenant=%s", kb_ids, tenant_id, exc_info=True)

    # Keep the total reported by retrieval while the chunk list is rewritten.
    total = ranks.get("total", 0)
    ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
    ranks["total"] = total

    # Drop the "vector" entry from every chunk before returning.
    for c in ranks["chunks"]:
        c.pop("vector", None)
    ranks["labels"] = labels

    return True, ranks

View File

@@ -858,6 +858,26 @@ class SearchDatasetReq(BaseModel):
meta_data_filter: Annotated[dict | None, Field(default=None)]
class SearchDatasetsReq(BaseModel):
    """Request body for a search that spans several datasets.

    Keys that are not declared below are ignored (``extra="ignore"``).
    """

    model_config = ConfigDict(extra="ignore")

    # Required: at least one dataset ID and a non-empty question
    # (surrounding whitespace is stripped before the length check).
    dataset_ids: Annotated[list[str], Field(..., min_length=1)]
    question: Annotated[
        str,
        StringConstraints(strip_whitespace=True, min_length=1),
        Field(...),
    ]

    # Optional list of document IDs.
    doc_ids: Annotated[list[str], Field(default=[])]

    # Paging and candidate-pool size.
    page: Annotated[int, Field(default=1, ge=1)]
    size: Annotated[int, Field(default=30, ge=1)]
    top_k: Annotated[int, Field(default=1024, ge=1)]

    # Scoring parameters, both constrained to [0.0, 1.0].
    similarity_threshold: Annotated[float, Field(default=0.0, ge=0.0, le=1.0)]
    vector_similarity_weight: Annotated[float, Field(default=0.3, ge=0.0, le=1.0)]

    # Feature switches.
    use_kg: Annotated[bool, Field(default=False)]
    cross_languages: Annotated[list[str], Field(default=[])]
    keyword: Annotated[bool, Field(default=False)]

    # Optional references to a saved search and to a rerank model.
    search_id: Annotated[str | None, Field(default=None)]
    rerank_id: Annotated[str | None, Field(default=None)]
    tenant_rerank_id: Annotated[str | None, Field(default=None)]

    # Optional metadata filter; its "method" key is read by the search service.
    meta_data_filter: Annotated[dict | None, Field(default=None)]
class BaseListReq(BaseModel):
model_config = ConfigDict(extra="forbid")

View File

@@ -298,7 +298,7 @@ def batch_add_sessions_with_chat_assistant(auth, chat_assistant_id, num):
# DATASET GRAPH AND TASKS
def knowledge_graph(auth, dataset_id, params=None):
    """Issue ``GET {DATASETS_API_URL}/{dataset_id}/graph`` and return the JSON body.

    :param auth: authentication object passed to ``requests.get``
    :param dataset_id: ID of the dataset placed in the URL path
    :param params: optional query-string parameters
    :return: the response body decoded with ``res.json()``
    """
    # Only the `/graph` route is targeted; the earlier assignment to the
    # `/graph/search` URL was immediately overwritten and has been dropped.
    url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/graph"
    res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
    return res.json()

View File

@@ -65,7 +65,7 @@ export default {
rmKb: `${restAPIv1}/datasets`,
getKbDetail: (datasetId: string) => `${restAPIv1}/datasets/${datasetId}`,
getKnowledgeGraph: (knowledgeId: string) =>
`${restAPIv1}/datasets/${knowledgeId}/graph/search`,
`${restAPIv1}/datasets/${knowledgeId}/graph`,
knowledgeGraph: (datasetId: string) =>
`${restAPIv1}/datasets/${datasetId}/graph`,
deleteKnowledgeGraph: (knowledgeId: string) =>