diff --git a/api/apps/restful_apis/dataset_api.py b/api/apps/restful_apis/dataset_api.py index bdfa98699d..701c7340b7 100644 --- a/api/apps/restful_apis/dataset_api.py +++ b/api/apps/restful_apis/dataset_api.py @@ -25,6 +25,7 @@ from api.utils.validation_utils import ( DeleteDatasetReq, ListDatasetReq, SearchDatasetReq, + SearchDatasetsReq, UpdateDatasetReq, validate_and_parse_json_request, validate_and_parse_request_args, @@ -477,6 +478,35 @@ async def rename_tag(tenant_id, dataset_id): return get_error_data_result(message="Internal server error") +@manager.route("/datasets/search", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def search_datasets(tenant_id): + """Search (retrieval test) across multiple datasets. + + POST /api/v1/datasets/search + JSON body: {"dataset_ids": list[str] (required), "question": str (required), "doc_ids": list[str], "top_k": int, "page": int, "size": int, + "similarity_threshold": float, "vector_similarity_weight": float, "use_kg": bool, + "cross_languages": list[str], "keyword": bool, "meta_data_filter": dict} + Success: {"code": 0, "data": {"chunks": [...], "total": int, "labels": [...]}} + Errors: ARGUMENT_ERROR (101) for invalid payload; DATA_ERROR (102) for access denied or internal errors. + """ + req, err = await validate_and_parse_json_request(request, SearchDatasetsReq) + if err is not None: + return get_error_argument_result(err) + try: + success, result = await dataset_api_service.search_datasets(tenant_id, req) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + if "not_found" in str(e): + return get_error_data_result(message="No chunk found! Check the chunk status please!") + return get_error_data_result(message="Internal server error") + + @manager.route("/datasets//search", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs @@ -493,8 +523,9 @@ async def search(tenant_id, dataset_id): req, err = await validate_and_parse_json_request(request, SearchDatasetReq) if err is not None: return get_error_argument_result(err) + req['dataset_ids'] = [dataset_id] try: - success, result = await dataset_api_service.search(dataset_id, tenant_id, req) + success, result = await dataset_api_service.search_datasets(tenant_id, req) if success: return get_result(data=result) else: @@ -506,21 +537,6 @@ async def search(tenant_id, dataset_id): return get_error_data_result(message="Internal server error") -@manager.route("/datasets//graph/search", methods=["GET"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -async def knowledge_graph(tenant_id, dataset_id): - try: - success, result = await dataset_api_service.get_knowledge_graph(dataset_id, tenant_id) - if success: - return get_result(data=result) - else: - return get_result(data=False, message=result, code=RetCode.AUTHENTICATION_ERROR) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - @manager.route("/datasets//graph", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py index 16418d83d8..795e42b7b8 100644 --- a/api/apps/services/dataset_api_service.py +++ b/api/apps/services/dataset_api_service.py @@ -26,6 +26,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.connector_service import Connector2KbService from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService from api.db.services.user_service import TenantService, UserService, UserTenantService +from api.db.services.tenant_llm_service import TenantLLMService from common.constants import FileSource, StatusEnum from api.utils.api_utils import deep_merge, get_parser_config, remap_dictionary_keys, verify_embedding_availability @@ -1050,3 +1051,162 @@ async def search(dataset_id: str, tenant_id: str, req: dict): ranks["labels"] = labels return True, ranks + + +async def search_datasets(tenant_id: str, req: dict): + """ + Search (retrieval test) across multiple datasets. + + :param tenant_id: tenant ID + :param req: search request containing dataset_ids and other params + :return: (success, result) or (success, error_message) + """ + from api.db.joint_services.tenant_model_service import ( + get_model_config_by_id, + get_model_config_by_type_and_name, + get_tenant_default_model_by_type, + ) + from api.db.services.doc_metadata_service import DocMetadataService + from api.db.services.llm_service import LLMBundle + from api.db.services.search_service import SearchService + from api.db.services.user_service import UserTenantService + from common.constants import LLMType + from common.metadata_utils import apply_meta_data_filter + from rag.app.tag import label_question + from rag.prompts.generator import cross_languages, keyword_extraction + + kb_ids = req.get("dataset_ids", []) + page = int(req.get("page", 1)) + size = int(req.get("size", 30)) + question = req.get("question", "") + doc_ids = req.get("doc_ids", []) + use_kg = req.get("use_kg", False) + top = max(1, min(int(req.get("top_k", 1024)), 2048)) + langs = req.get("cross_languages", []) + + logging.debug( + "search_datasets(datasets=%s, tenant=%s, question_len=%s)", + kb_ids, + tenant_id, + len(question), + ) + + # Access check for all datasets + for kb_id in kb_ids: + if not KnowledgebaseService.accessible(kb_id, tenant_id): + logging.warning("search_datasets access denied: dataset=%s tenant=%s", kb_id, tenant_id) + return False, f"Only owner of dataset {kb_id} authorized for this operation." + + kbs = KnowledgebaseService.get_by_ids(kb_ids) + if not kbs: + return False, "Datasets not found!" + + # All datasets must use the same embedding model + embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) + if len(embd_nms) != 1: + return False, "Datasets use different embedding models." + + if doc_ids is not None and not isinstance(doc_ids, list): + return False, "`doc_ids` should be a list" + local_doc_ids = list(doc_ids) if doc_ids else [] + + meta_data_filter = {} + chat_mdl = None + if req.get("search_id", ""): + search_detail = SearchService.get_detail(req.get("search_id", "")) + if not search_detail: + logging.warning("search config not found: search_id=%s", req.get("search_id", "")) + return False, "Invalid search_id" + search_config = search_detail.get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_id = search_config.get("chat_id", "") + if chat_id: + chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, search_config["chat_id"]) + else: + chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(tenant_id, chat_model_config) + else: + meta_data_filter = req.get("meta_data_filter") or {} + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(tenant_id, chat_model_config) + + if meta_data_filter: + local_doc_ids = await apply_meta_data_filter( + meta_data_filter, + None, + question, + chat_mdl, + local_doc_ids, + kb_ids=kb_ids, + metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids), + ) + + tenant_ids = [] + tenants = UserTenantService.query(user_id=tenant_id) + for tenant in tenants: + if any(KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id) for kb_id in kb_ids): + tenant_ids.append(tenant.tenant_id) + break + else: + return False, "Only owner of datasets authorized for this operation." + + kb = kbs[0] + _question = question + if langs: + _question = await cross_languages(kb.tenant_id, None, _question, langs) + if kb.tenant_embd_id: + embd_model_config = get_model_config_by_id(kb.tenant_embd_id) + elif kb.embd_id: + embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + else: + embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING) + embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) + + rerank_mdl = None + if req.get("tenant_rerank_id"): + rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + elif req.get("rerank_id"): + rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + + if req.get("keyword", False): + default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config) + _question += await keyword_extraction(chat_mdl, _question) + + labels = label_question(_question, kbs) + ranks = await settings.retriever.retrieval( + _question, + embd_mdl, + tenant_ids, + kb_ids, + page, + size, + float(req.get("similarity_threshold", 0.0)), + float(req.get("vector_similarity_weight", 0.3)), + doc_ids=local_doc_ids, + top=top, + rerank_mdl=rerank_mdl, + rank_feature=labels, + ) + + if use_kg: + try: + default_chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, default_chat_model_config)) + if ck["content_with_weight"]: + ranks["chunks"].insert(0, ck) + except Exception: + logging.warning("search_datasets KG retrieval failed: datasets=%s tenant=%s", kb_ids, tenant_id, exc_info=True) + total = ranks.get("total", 0) + ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) + ranks["total"] = total + + for c in ranks["chunks"]: + c.pop("vector", None) + ranks["labels"] = labels + + return True, ranks diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index c51cf5acc4..94e0fa2ab8 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -858,6 +858,26 @@ class SearchDatasetReq(BaseModel): meta_data_filter: Annotated[dict | None, Field(default=None)] +class SearchDatasetsReq(BaseModel): + model_config = ConfigDict(extra="ignore") + + dataset_ids: Annotated[list[str], Field(..., min_length=1)] + question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)] + doc_ids: Annotated[list[str], Field(default=[])] + page: Annotated[int, Field(default=1, ge=1)] + size: Annotated[int, Field(default=30, ge=1)] + top_k: Annotated[int, Field(default=1024, ge=1)] + similarity_threshold: Annotated[float, Field(default=0.0, ge=0.0, le=1.0)] + vector_similarity_weight: Annotated[float, Field(default=0.3, ge=0.0, le=1.0)] + use_kg: Annotated[bool, Field(default=False)] + cross_languages: Annotated[list[str], Field(default=[])] + keyword: Annotated[bool, Field(default=False)] + search_id: Annotated[str | None, Field(default=None)] + rerank_id: Annotated[str | None, Field(default=None)] + tenant_rerank_id: Annotated[str | None, Field(default=None)] + meta_data_filter: Annotated[dict | None, Field(default=None)] + + class BaseListReq(BaseModel): model_config = ConfigDict(extra="forbid") diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py index 1e90415579..f62cf6338d 100644 --- a/test/testcases/test_http_api/common.py +++ b/test/testcases/test_http_api/common.py @@ -298,7 +298,7 @@ def batch_add_sessions_with_chat_assistant(auth, chat_assistant_id, num): # DATASET GRAPH AND TASKS def knowledge_graph(auth, dataset_id, params=None): - url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/graph/search" + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/graph" res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) return res.json() diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts index 2e23727b76..b74d169611 100644 --- a/web/src/utils/api.ts +++ b/web/src/utils/api.ts @@ -65,7 +65,7 @@ export default { rmKb: `${restAPIv1}/datasets`, getKbDetail: (datasetId: string) => `${restAPIv1}/datasets/${datasetId}`, getKnowledgeGraph: (knowledgeId: string) => - `${restAPIv1}/datasets/${knowledgeId}/graph/search`, + `${restAPIv1}/datasets/${knowledgeId}/graph`, knowledgeGraph: (datasetId: string) => `${restAPIv1}/datasets/${datasetId}/graph`, deleteKnowledgeGraph: (knowledgeId: string) =>