mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 17:21:59 +08:00
Refactor: migrate chunk retrieval_test and knowledge_graph to REST API endpoints (#14402)
### What problem does this PR solve?

## Summary

Migrate two web API endpoints to REST-style HTTP API endpoints, following the pattern established in #14222:

| Old Endpoint | New Endpoint |
|---|---|
| `POST /v1/chunk/retrieval_test` | `POST /api/v1/datasets/<dataset_id>/search` |
| `GET /v1/chunk/knowledge_graph` | `GET /api/v1/datasets/<dataset_id>/graph` |
This commit is contained in:
@@ -1,215 +0,0 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
|
||||
from quart import request
|
||||
|
||||
from api.apps import current_user, login_required
|
||||
from api.db.joint_services.tenant_model_service import (
|
||||
get_model_config_by_id,
|
||||
get_model_config_by_type_and_name,
|
||||
get_tenant_default_model_by_type,
|
||||
)
|
||||
from api.db.services.doc_metadata_service import DocMetadataService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
get_request_json,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from common import settings
|
||||
from common.constants import LLMType, RetCode
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import search
|
||||
from rag.prompts.generator import cross_languages, keyword_extraction
|
||||
|
||||
|
||||
@manager.route('/retrieval_test', methods=['POST'])  # noqa: F821
@login_required
@validate_request("kb_id", "question")
async def retrieval_test():
    """Run a retrieval test against one or more datasets.

    Reads the JSON request body. ``kb_id`` (a single ID or a list of IDs) and
    ``question`` are required; ``page``, ``size``, ``doc_ids``, ``use_kg``,
    ``top_k``, ``cross_languages``, ``search_id``, ``meta_data_filter``,
    ``tenant_rerank_id``, ``rerank_id``, ``keyword``, ``similarity_threshold``
    and ``vector_similarity_weight`` are optional.

    On success the JSON result's ``data`` is the retriever's ranking dict with
    ``chunks`` (the ``vector`` key removed from each chunk), ``total`` and
    ``labels``. Error results are returned when no dataset is given, when a
    dataset does not belong to any of the current user's tenants, when the
    first dataset cannot be loaded, or when retrieval raises.
    """
    req = await get_request_json()
    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req["question"]
    kb_ids = req["kb_id"]
    if isinstance(kb_ids, str):
        kb_ids = [kb_ids]
    if not kb_ids:
        return get_json_result(data=False, message='Please specify dataset firstly.',
                               code=RetCode.DATA_ERROR)

    doc_ids = req.get("doc_ids", [])
    use_kg = req.get("use_kg", False)
    top = int(req.get("top_k", 1024))
    langs = req.get("cross_languages", [])
    user_id = current_user.id

    async def _retrieval():
        local_doc_ids = list(doc_ids) if doc_ids else []
        tenant_ids = []

        meta_data_filter = {}
        chat_mdl = None
        if req.get("search_id", ""):
            # A saved search supplies the metadata filter (and optionally the chat model).
            search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
            # `or {}` so a filter stored as null behaves like "no filter",
            # the same way the request-body branch below treats it.
            meta_data_filter = search_config.get("meta_data_filter") or {}
            if meta_data_filter.get("method") in ["auto", "semi_auto"]:
                chat_id = search_config.get("chat_id", "")
                if chat_id:
                    chat_model_config = get_model_config_by_type_and_name(user_id, LLMType.CHAT, chat_id)
                else:
                    chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
                chat_mdl = LLMBundle(user_id, chat_model_config)
        else:
            meta_data_filter = req.get("meta_data_filter") or {}
            if meta_data_filter.get("method") in ["auto", "semi_auto"]:
                chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
                chat_mdl = LLMBundle(user_id, chat_model_config)

        if meta_data_filter:
            metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
            local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)

        # Every requested dataset must belong to one of the user's tenants.
        tenants = UserTenantService.query(user_id=user_id)
        for kb_id in kb_ids:
            for tenant in tenants:
                if KnowledgebaseService.query(
                        tenant_id=tenant.tenant_id, id=kb_id):
                    tenant_ids.append(tenant.tenant_id)
                    break
            else:
                return get_json_result(
                    data=False, message='Only owner of dataset authorized for this operation.',
                    code=RetCode.OPERATING_ERROR)

        # The first dataset decides the embedding model and the tenant used for LLM calls.
        e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
        if not e:
            return get_data_error_result(message="Knowledgebase not found!")

        _question = question
        if langs:
            _question = await cross_languages(kb.tenant_id, None, _question, langs)
        if kb.tenant_embd_id:
            embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
        elif kb.embd_id:
            embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
        else:
            embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
        embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)

        rerank_mdl = None
        if req.get("tenant_rerank_id"):
            rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
            rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
        elif req.get("rerank_id"):
            rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
            rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)

        if req.get("keyword", False):
            # Append extracted keywords to the (possibly translated) question.
            default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
            _question += await keyword_extraction(chat_mdl, _question)

        labels = label_question(_question, [kb])
        ranks = await settings.retriever.retrieval(
            _question,
            embd_mdl,
            tenant_ids,
            kb_ids,
            page,
            size,
            float(req.get("similarity_threshold", 0.0)),
            float(req.get("vector_similarity_weight", 0.3)),
            doc_ids=local_doc_ids,
            top=top,
            rerank_mdl=rerank_mdl,
            rank_feature=labels
        )

        if use_kg:
            default_chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
            ck = await settings.kg_retriever.retrieval(_question,
                                                       tenant_ids,
                                                       kb_ids,
                                                       embd_mdl,
                                                       LLMBundle(kb.tenant_id, default_chat_model_config))
            # A knowledge-graph chunk with content goes to the front of the list.
            if ck["content_with_weight"]:
                ranks["chunks"].insert(0, ck)
        ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
        ranks["total"] = len(ranks["chunks"])

        # Vectors are not returned to the client.
        for c in ranks["chunks"]:
            c.pop("vector", None)
        ranks["labels"] = labels

        return get_json_result(data=ranks)

    try:
        return await _retrieval()
    except Exception as e:
        # Substring test rather than `find(...) > 0`, which ignored a match
        # at the very start of the message.
        if "not_found" in str(e):
            return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
                                   code=RetCode.DATA_ERROR)
        return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/knowledge_graph', methods=['GET'])  # noqa: F821
@login_required
async def knowledge_graph():
    """Return the stored graph and mind map for the document in ``doc_id``.

    Searches the owning tenant's index for chunks of that document whose
    ``knowledge_graph_kwd`` is ``graph`` or ``mind_map`` and inspects at most
    the first two hits. Each hit's ``content_with_weight`` is parsed as JSON;
    hits that fail to parse are skipped. The result's ``data`` starts as
    ``{"graph": {}, "mind_map": {}}`` and each parsed hit is stored under its
    own ``knowledge_graph_kwd`` value.
    """
    doc_id = request.args["doc_id"]
    tenant_id = DocumentService.get_tenant_id(doc_id)
    kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)

    def _suffix_repeated_ids(node, seen):
        # Give every repeated node id a "(n)" suffix so ids are unique in the
        # tree; `seen` maps an original id to how often it has occurred so far.
        if 'id' in node:
            name = node['id']
            if name in seen:
                node['id'] = name + f"({seen[name]})"
                seen[name] += 1
            else:
                seen[name] = 1
        if 'children' in node and node['children']:
            for child in node['children']:
                _suffix_repeated_ids(child, seen)

    query = {
        "doc_ids": [doc_id],
        "knowledge_graph_kwd": ["graph", "mind_map"]
    }
    sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids)

    result = {"graph": {}, "mind_map": {}}
    for chunk_id in sres.ids[:2]:
        fields = sres.field[chunk_id]
        kind = fields["knowledge_graph_kwd"]
        try:
            parsed = json.loads(fields["content_with_weight"])
        except Exception:
            # Unparseable content: leave the default for this kind.
            continue

        if kind == 'mind_map':
            _suffix_repeated_ids(parsed, {})

        result[kind] = parsed

    return get_json_result(data=result)
|
||||
@@ -24,6 +24,7 @@ from api.utils.validation_utils import (
|
||||
CreateDatasetReq,
|
||||
DeleteDatasetReq,
|
||||
ListDatasetReq,
|
||||
SearchDatasetReq,
|
||||
UpdateDatasetReq,
|
||||
validate_and_parse_json_request,
|
||||
validate_and_parse_request_args,
|
||||
@@ -476,6 +477,35 @@ async def rename_tag(tenant_id, dataset_id):
|
||||
return get_error_data_result(message="Internal server error")
|
||||
|
||||
|
||||
@manager.route('/datasets/<dataset_id>/search', methods=['POST'])  # noqa: F821
@login_required
@add_tenant_id_to_kwargs
async def search(tenant_id, dataset_id):
    """Search (retrieval test) within a dataset.

    POST /api/v1/datasets/<dataset_id>/search
    JSON body: {"question": str (required), "doc_ids": list[str], "top_k": int, "page": int, "size": int,
                "similarity_threshold": float, "vector_similarity_weight": float, "use_kg": bool,
                "cross_languages": list[str], "keyword": bool, "meta_data_filter": dict}
    Success: {"code": 0, "data": {"chunks": [...], "total": int, "labels": [...]}}
    Errors: ARGUMENT_ERROR (101) for invalid payload; DATA_ERROR (102) for access denied or internal errors.
    """
    payload, validation_error = await validate_and_parse_json_request(request, SearchDatasetReq)
    if validation_error is not None:
        return get_error_argument_result(validation_error)

    try:
        ok, outcome = await dataset_api_service.search(dataset_id, tenant_id, payload)
        if not ok:
            # On failure the service hands back a human-readable message.
            return get_error_data_result(message=outcome)
        return get_result(data=outcome)
    except Exception as exc:
        logging.exception(exc)
        # Errors mentioning "not_found" get a dedicated hint; anything else
        # is reported generically so internals are not exposed.
        if "not_found" in str(exc):
            message = "No chunk found! Check the chunk status please!"
        else:
            message = "Internal server error"
        return get_error_data_result(message=message)
|
||||
|
||||
|
||||
@manager.route('/datasets/<dataset_id>/graph/search', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
@@ -495,6 +525,32 @@ async def knowledge_graph(tenant_id, dataset_id):
|
||||
return get_error_data_result(message="Internal server error")
|
||||
|
||||
|
||||
@manager.route('/datasets/<dataset_id>/graph', methods=['GET'])  # noqa: F821
@login_required
@add_tenant_id_to_kwargs
async def get_knowledge_graph(tenant_id, dataset_id):
    """Get the knowledge graph of a dataset.

    GET /api/v1/datasets/<dataset_id>/graph
    Only the dataset ID and tenant ID are passed to the service; this handler
    reads no query parameters itself.
    Success: {"code": 0, "data": {...}}
    Errors: AUTHENTICATION_ERROR when the service reports failure; DATA_ERROR
    when an exception is raised.
    """
    try:
        ok, outcome = await dataset_api_service.get_knowledge_graph(dataset_id, tenant_id)
        if not ok:
            # The service's failure message is returned with an authentication error code.
            return get_result(
                data=False,
                message=outcome,
                code=RetCode.AUTHENTICATION_ERROR
            )
        return get_result(data=outcome)
    except Exception as exc:
        logging.exception(exc)
        return get_error_data_result(message="Internal server error")
|
||||
|
||||
|
||||
@manager.route('/datasets/<dataset_id>/graph', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
|
||||
@@ -900,3 +900,153 @@ def rename_tag(dataset_id: str, tenant_id: str, from_tag: str, to_tag: str):
|
||||
|
||||
return True, {"from": from_tag, "to": to_tag}
|
||||
|
||||
|
||||
async def search(dataset_id: str, tenant_id: str, req: dict):
    """
    Search (retrieval test) within a dataset.

    Steps: access check, optional metadata filtering of the candidate
    documents, optional cross-language rewrite and keyword expansion of the
    question, retrieval (with an optional rerank model), and an optional
    knowledge-graph chunk placed at the front of the results.

    :param dataset_id: dataset ID
    :param tenant_id: tenant ID
    :param req: search request
    :return: (True, ranks) on success, where ranks carries "chunks" (without
        their "vector" key), "total" and "labels"; (False, error_message) otherwise
    """
    from api.db.joint_services.tenant_model_service import (
        get_model_config_by_id,
        get_model_config_by_type_and_name,
        get_tenant_default_model_by_type,
    )
    from api.db.services.doc_metadata_service import DocMetadataService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.search_service import SearchService
    from api.db.services.user_service import UserTenantService
    from common.constants import LLMType
    from common.metadata_utils import apply_meta_data_filter
    from rag.app.tag import label_question
    from rag.prompts.generator import cross_languages, keyword_extraction

    logging.debug(
        "search(dataset=%s, tenant=%s, question_len=%s)",
        dataset_id,
        tenant_id,
        len(req.get("question", "")),
    )

    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req.get("question", "")
    doc_ids = req.get("doc_ids", [])
    use_kg = req.get("use_kg", False)
    # top_k is clamped to the range [1, 2048].
    top = max(1, min(int(req.get("top_k", 1024)), 2048))
    langs = req.get("cross_languages", [])

    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        logging.warning("search access denied: dataset=%s tenant=%s", dataset_id, tenant_id)
        return False, "Only owner of dataset authorized for this operation."

    e, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not e:
        logging.warning("search dataset not found: dataset=%s", dataset_id)
        return False, "Dataset not found!"

    if doc_ids is not None and not isinstance(doc_ids, list):
        return False, "`doc_ids` should be a list"
    local_doc_ids = list(doc_ids) if doc_ids else []

    meta_data_filter = {}
    chat_mdl = None
    search_id = req.get("search_id", "")
    if search_id:
        # A saved search supplies the metadata filter (and optionally the chat model).
        search_detail = SearchService.get_detail(search_id)
        if not search_detail:
            logging.warning("search config not found: search_id=%s", search_id)
            return False, "Invalid search_id"
        # `or {}` so a config or filter stored as null behaves like "absent"
        # instead of raising AttributeError on the `.get` calls below.
        search_config = search_detail.get("search_config") or {}
        meta_data_filter = search_config.get("meta_data_filter") or {}
        if meta_data_filter.get("method") in ["auto", "semi_auto"]:
            chat_id = search_config.get("chat_id", "")
            if chat_id:
                chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
            else:
                chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(tenant_id, chat_model_config)
    else:
        meta_data_filter = req.get("meta_data_filter") or {}
        if meta_data_filter.get("method") in ["auto", "semi_auto"]:
            chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(tenant_id, chat_model_config)

    if meta_data_filter:
        metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id])
        local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)

    # Find the tenant (among those the caller belongs to) that owns the dataset.
    tenant_ids = []
    tenants = UserTenantService.query(user_id=tenant_id)
    for tenant in tenants:
        if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=dataset_id):
            tenant_ids.append(tenant.tenant_id)
            break
    else:
        return False, "Only owner of dataset authorized for this operation."

    _question = question
    if langs:
        _question = await cross_languages(kb.tenant_id, None, _question, langs)
    if kb.tenant_embd_id:
        embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
    elif kb.embd_id:
        embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
    else:
        embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
    embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)

    rerank_mdl = None
    if req.get("tenant_rerank_id"):
        rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
        rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
    elif req.get("rerank_id"):
        rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
        rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)

    if req.get("keyword", False):
        # Append extracted keywords to the (possibly translated) question.
        default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
        chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
        _question += await keyword_extraction(chat_mdl, _question)

    labels = label_question(_question, [kb])
    ranks = await settings.retriever.retrieval(
        _question,
        embd_mdl,
        tenant_ids,
        [dataset_id],
        page,
        size,
        float(req.get("similarity_threshold", 0.0)),
        float(req.get("vector_similarity_weight", 0.3)),
        doc_ids=local_doc_ids,
        top=top,
        rerank_mdl=rerank_mdl,
        rank_feature=labels
    )

    if use_kg:
        # Knowledge-graph retrieval is best effort: a failure is logged and
        # the plain retrieval results are still returned.
        try:
            default_chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
            ck = await settings.kg_retriever.retrieval(_question,
                                                       tenant_ids,
                                                       [dataset_id],
                                                       embd_mdl,
                                                       LLMBundle(kb.tenant_id, default_chat_model_config))
            if ck["content_with_weight"]:
                ranks["chunks"].insert(0, ck)
        except Exception:
            logging.warning("search KG retrieval failed: dataset=%s tenant=%s", dataset_id, tenant_id, exc_info=True)

    # Keep the retriever's reported total rather than the length of this page.
    total = ranks.get("total", 0)
    ranks["chunks"] = settings.retriever.retrieval_by_children(
        ranks["chunks"], tenant_ids
    )
    ranks["total"] = total

    # Vectors are not returned to the caller.
    for c in ranks["chunks"]:
        c.pop("vector", None)
    ranks["labels"] = labels

    return True, ranks
|
||||
|
||||
@@ -819,6 +819,25 @@ class DeleteReq(Base):
|
||||
# Request model for deleting datasets; adds nothing of its own to DeleteReq.
class DeleteDatasetReq(DeleteReq): ...
|
||||
|
||||
|
||||
class SearchDatasetReq(BaseModel):
    # Request body model for the dataset search endpoint.
    # Unknown keys in the payload are ignored rather than rejected.
    model_config = ConfigDict(extra="ignore")

    # Required: the query text, stripped of surrounding whitespace, at least one character.
    question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)]
    # Optional list of document IDs.
    doc_ids: Annotated[list[str], Field([])]
    # Paging and candidate limits; each must be >= 1.
    page: Annotated[int, Field(1, ge=1)]
    size: Annotated[int, Field(30, ge=1)]
    top_k: Annotated[int, Field(1024, ge=1)]
    # Similarity settings, each constrained to [0.0, 1.0].
    similarity_threshold: Annotated[float, Field(0.0, ge=0.0, le=1.0)]
    vector_similarity_weight: Annotated[float, Field(0.3, ge=0.0, le=1.0)]
    # Feature switches and their inputs.
    use_kg: Annotated[bool, Field(False)]
    cross_languages: Annotated[list[str], Field([])]
    keyword: Annotated[bool, Field(False)]
    # Optional identifiers and the optional metadata filter.
    search_id: Annotated[str | None, Field(None)]
    rerank_id: Annotated[str | None, Field(None)]
    tenant_rerank_id: Annotated[str | None, Field(None)]
    meta_data_filter: Annotated[dict | None, Field(None)]
|
||||
|
||||
|
||||
# Request model for deleting documents; adds nothing of its own to DeleteReq.
class DeleteDocumentReq(DeleteReq): ...
|
||||
|
||||
|
||||
|
||||
@@ -517,3 +517,12 @@ def get_flattened_metadata(auth, dataset_ids, *, headers=HEADERS):
|
||||
url = f"{HOST_ADDRESS}{DATASETS_API_URL}/metadata/flattened"
|
||||
res = requests.get(url=url, headers=headers, auth=auth, params={"dataset_ids": ",".join(dataset_ids)})
|
||||
return res.json()
|
||||
|
||||
|
||||
def search_dataset(auth, dataset_id, payload=None, *, headers=HEADERS):
    """POST ``payload`` as JSON to the dataset's search endpoint and return the decoded JSON response."""
    endpoint = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/search"
    return requests.post(url=endpoint, headers=headers, auth=auth, json=payload).json()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pytest
|
||||
from common import search_dataset, knowledge_graph
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowHttpApiAuth
|
||||
|
||||
|
||||
@pytest.mark.p2
class TestAuthorization:
    """The search endpoint must reject requests with missing or invalid credentials."""

    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        """No auth and a bad token both yield code 401 with the Unauthorized message."""
        response = search_dataset(invalid_auth, "dataset_id", {"question": "test"})
        message = response.get("message", "")
        assert response["code"] == expected_code
        assert expected_message in message
|
||||
|
||||
|
||||
class TestDatasetSearch:
    """Checks for POST /datasets/<dataset_id>/search with valid credentials."""

    @pytest.mark.p2
    def test_search_without_question(self, HttpApiAuth, add_dataset_func):
        """An empty payload (no question) is answered with code 101."""
        response = search_dataset(HttpApiAuth, add_dataset_func, {})
        assert response["code"] == 101, response

    @pytest.mark.p2
    def test_search_basic(self, HttpApiAuth, add_chunks):
        """A question alone returns code 0 and a ``chunks`` entry."""
        dataset_id, _, _ = add_chunks
        response = search_dataset(HttpApiAuth, dataset_id, {"question": "chunk"})
        assert response["code"] == 0, response
        assert "chunks" in response["data"], response

    @pytest.mark.p2
    def test_search_with_doc_ids(self, HttpApiAuth, add_chunks):
        """Passing ``doc_ids`` alongside the question still returns code 0 and ``chunks``."""
        dataset_id, document_id, _ = add_chunks
        body = {"question": "chunk", "doc_ids": [document_id]}
        response = search_dataset(HttpApiAuth, dataset_id, body)
        assert response["code"] == 0, response
        assert "chunks" in response["data"], response

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code",
        [
            ({"question": "chunk", "page": 1, "size": 2}, 0),
            ({"question": "chunk", "similarity_threshold": 0.5}, 0),
            ({"question": "chunk", "vector_similarity_weight": 0.7}, 0),
            ({"question": "chunk", "top_k": 10}, 0),
        ],
    )
    def test_search_params(self, HttpApiAuth, add_chunks, payload, expected_code):
        """Each optional tuning parameter is accepted and yields the expected code."""
        dataset_id = add_chunks[0]
        response = search_dataset(HttpApiAuth, dataset_id, payload)
        assert response["code"] == expected_code, response
|
||||
|
||||
|
||||
@pytest.mark.p2
class TestDatasetGraph:
    """Checks for the dataset knowledge-graph endpoint via the ``knowledge_graph`` helper."""

    def test_graph_requires_auth(self):
        """A request without credentials is answered with code 401."""
        response = knowledge_graph(None, "dataset_id")
        assert response["code"] == 401

    def test_graph_basic(self, HttpApiAuth, add_dataset_func):
        """An authenticated request for a freshly added dataset returns code 0."""
        response = knowledge_graph(HttpApiAuth, add_dataset_func)
        assert response["code"] == 0, response
|
||||
@@ -17,7 +17,6 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
@@ -491,13 +490,15 @@ def _load_chunk_module(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
|
||||
services_pkg.user_service = user_service_mod
|
||||
|
||||
module_name = "test_chunk_routes_unit_module"
|
||||
module_path = repo_root / "api" / "apps" / "chunk_app.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
monkeypatch.setitem(sys.modules, module_name, module)
|
||||
spec.loader.exec_module(module)
|
||||
module = None
|
||||
if module_path.exists():
|
||||
module_name = "test_chunk_routes_unit_module"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
monkeypatch.setitem(sys.modules, module_name, module)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
@@ -653,167 +654,3 @@ def test_restful_chunk_guard_branches_unit(monkeypatch):
|
||||
assert res["message"] == "`available_int` or `available` is required.", res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_retrieval_test_branch_matrix_unit(monkeypatch):
    """Drive ``retrieval_test`` through its main branches with stubbed collaborators."""
    module = _load_chunk_module(monkeypatch)
    module.request = SimpleNamespace(headers={"X-Request-ID": "req-r"}, args={})

    # Recorders filled in by the stubs below.
    applied_filters, llm_calls, cross_calls, keyword_calls = [], [], [], []

    async def _apply_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids):
        applied_filters.append(
            dict(
                meta_data_filter=meta_data_filter,
                metas=metas,
                question=question,
                chat_mdl=chat_mdl,
                local_doc_ids=list(local_doc_ids),
            )
        )
        return ["doc-filtered"]

    async def _cross_languages(_tenant_id, _dialog, question, langs):
        cross_calls.append((question, tuple(langs)))
        return f"{question}-xl"

    async def _keyword_extraction(_chat_mdl, question):
        keyword_calls.append(question)
        return "-kw"

    def _fake_llm_bundle(*args, **kwargs):
        llm_calls.append((args, kwargs))
        return SimpleNamespace()

    class _Retriever:
        """Retriever stub whose ``mode`` selects success or one of two failures."""

        def __init__(self, mode="ok"):
            self.mode = mode
            self.retrieval_questions = []

        async def retrieval(self, question, *_args, **_kwargs):
            if self.mode == "not_found":
                raise Exception("boom not_found boom")
            if self.mode == "explode":
                raise RuntimeError("retrieval boom")
            self.retrieval_questions.append(question)
            return {"chunks": [{"id": "c1", "vector": [0.1], "content_with_weight": "chunk-content"}]}

        def retrieval_by_children(self, chunks, _tenant_ids):
            return list(chunks)

    class _KgRetriever:
        async def retrieval(self, *_args, **_kwargs):
            return {"id": "kg-1", "content_with_weight": "kg-content"}

    class _NoContentKgRetriever:
        async def retrieval(self, *_args, **_kwargs):
            return {"id": "kg-2", "content_with_weight": ""}

    monkeypatch.setattr(module, "LLMBundle", _fake_llm_bundle)
    monkeypatch.setattr(module, "get_model_config_by_type_and_name", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "chat"})
    monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "chat"})
    monkeypatch.setattr(module, "get_model_config_by_id", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "embedding"})
    monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: [{"meta": "v"}], raising=False)
    monkeypatch.setattr(module, "apply_meta_data_filter", _apply_filter)
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"meta_data_filter": {"method": "auto"}, "chat_id": "chat-1"}}, raising=False)
    monkeypatch.setattr(module, "cross_languages", _cross_languages)
    monkeypatch.setattr(module, "keyword_extraction", _keyword_extraction)
    monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: ["lbl"])
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [_DummyTenant("tenant-1")])

    # Scenario 1: dataset not owned by any of the user's tenants. The
    # search_id metadata filter still runs before the ownership check.
    monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False, raising=False)
    _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "question": "q", "search_id": "search-1"})
    response = _run(module.retrieval_test())
    assert response["code"] == module.RetCode.OPERATING_ERROR, response
    assert "Only owner of dataset authorized for this operation." in response["message"], response
    assert applied_filters and applied_filters[-1]["meta_data_filter"]["method"] == "auto"
    assert llm_calls, "search_id metadata auto branch should instantiate chat model"

    # Scenario 2: empty kb_id list.
    _set_request_json(monkeypatch, module, {"kb_id": [], "question": "q"})
    response = _run(module.retrieval_test())
    assert response["code"] == module.RetCode.DATA_ERROR, response
    assert "Please specify dataset firstly." in response["message"], response

    # Scenario 3: ownership passes but the dataset cannot be loaded.
    monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: True, raising=False)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None), raising=False)
    _set_request_json(
        monkeypatch,
        module,
        {"kb_id": ["kb-1"], "question": "q", "meta_data_filter": {"method": "semi_auto"}},
    )
    response = _run(module.retrieval_test())
    assert response["code"] == module.RetCode.DATA_ERROR, response
    assert "Knowledgebase not found!" in response["message"], response

    # Scenario 4: full happy path with cross-language, rerank, keyword and KG.
    retriever = _Retriever(mode="ok")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-1", tenant_embd_id=2)), raising=False)
    monkeypatch.setattr(module.settings, "retriever", retriever)
    monkeypatch.setattr(module.settings, "kg_retriever", _KgRetriever(), raising=False)
    _set_request_json(
        monkeypatch,
        module,
        {
            "kb_id": ["kb-1"],
            "question": "q",
            "cross_languages": ["fr"],
            "rerank_id": "rerank-1",
            "keyword": True,
            "use_kg": True,
        },
    )
    response = _run(module.retrieval_test())
    assert response["code"] == 0, response
    assert cross_calls[-1] == ("q", ("fr",))
    assert keyword_calls[-1] == "q-xl"
    assert retriever.retrieval_questions[-1] == "q-xl-kw"
    assert response["data"]["chunks"][0]["id"] == "kg-1", response
    assert all("vector" not in chunk for chunk in response["data"]["chunks"])

    # Scenario 5: KG retriever returns no content, so the plain chunk stays first.
    monkeypatch.setattr(module.settings, "kg_retriever", _NoContentKgRetriever(), raising=False)
    _set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q", "use_kg": True})
    response = _run(module.retrieval_test())
    assert response["code"] == 0, response
    assert response["data"]["chunks"][0]["id"] == "c1", response

    # Scenario 6: retriever raises an error mentioning "not_found".
    monkeypatch.setattr(module.settings, "retriever", _Retriever(mode="not_found"))
    _set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q"})
    response = _run(module.retrieval_test())
    assert response["code"] == module.RetCode.DATA_ERROR, response
    assert "No chunk found! Check the chunk status please!" in response["message"], response

    # Scenario 7: any other retriever failure surfaces as an exception error.
    monkeypatch.setattr(module.settings, "retriever", _Retriever(mode="explode"))
    _set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q"})
    response = _run(module.retrieval_test())
    assert response["code"] == module.RetCode.EXCEPTION_ERROR, response
    assert "retrieval boom" in response["message"], response
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_knowledge_graph_repeat_deal_matrix_unit(monkeypatch):
    """Exercise knowledge_graph() with a broken graph chunk and a mind map with repeated ids.

    The stubbed search result carries two chunks: a "graph" chunk whose
    content is not valid JSON and a "mind_map" chunk whose tree reuses the
    id "dup" three times. The test expects an empty graph and the repeated
    ids to come back as "dup", "dup(1)" and "dup(2)".
    """
    module = _load_chunk_module(monkeypatch)
    module.request = SimpleNamespace(args={"doc_id": "doc-1"}, headers={})

    # Tree with the same id on two siblings and on a grandchild.
    nested_duplicate = {"id": "dup", "children": [{"id": "dup"}]}
    mind_map_payload = {
        "id": "root",
        "children": [
            {"id": "dup"},
            nested_duplicate,
        ],
    }

    class _FakeSearchResult:
        # Minimal stand-in exposing only the attributes set up for this test.
        ids = ["bad-json", "mind-map"]
        field = {
            "bad-json": {"knowledge_graph_kwd": "graph", "content_with_weight": "{bad json"},
            "mind-map": {"knowledge_graph_kwd": "mind_map", "content_with_weight": json.dumps(mind_map_payload)},
        }

    async def _fake_search(*_args, **_kwargs):
        return _FakeSearchResult()

    monkeypatch.setattr(module.settings.retriever, "search", _fake_search)

    res = _run(module.knowledge_graph())
    assert res["code"] == 0, res
    assert res["data"]["graph"] == {}, res

    first_child, second_child = res["data"]["mind_map"]["children"][0], res["data"]["mind_map"]["children"][1]
    assert first_child["id"] == "dup", res
    assert second_child["id"] == "dup(1)", res
    assert second_child["children"][0]["id"] == "dup(2)", res
|
||||
|
||||
@@ -1,308 +0,0 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from test_common import retrieval_chunks
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
|
||||
@pytest.mark.p2
class TestAuthorization:
    """Checks that retrieval_chunks rejects missing or invalid credentials."""

    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        """Send a dummy retrieval request with bad auth and compare code and message."""
        dummy_payload = {"kb_id": "dummy_kb_id", "question": "dummy question"}
        response = retrieval_chunks(invalid_auth, dummy_payload)

        assert response["code"] == expected_code, response
        assert response["message"] == expected_message, response
|
||||
|
||||
|
||||
class TestChunksRetrieval:
    """Parameter-by-parameter checks of the chunk retrieval helper.

    Every test receives its request fragment from ``pytest.mark.parametrize``.
    pytest hands those dict objects to the test without copying them, so each
    test builds a fresh request dict instead of mutating ``payload`` in place;
    otherwise the modification would persist on the shared parameter object
    across invocations.
    """

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"question": "chunk", "kb_id": None}, 0, 4, ""),
            ({"question": "chunk", "doc_ids": None}, 101, 0, "required argument are missing: kb_id; "),
            ({"question": "chunk", "kb_id": None, "doc_ids": None}, 0, 4, ""),
            ({"question": "chunk"}, 101, 0, "required argument are missing: kb_id; "),
        ],
    )
    def test_basic_scenarios(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        """Retrieve with and without kb_id / doc_ids and check code, chunk count or message."""
        dataset_id, document_id, _ = add_chunks
        # The None placeholders only mark which keys to fill in with real ids.
        request_payload = dict(payload)
        if "kb_id" in request_payload:
            request_payload["kb_id"] = [dataset_id]
        if "doc_ids" in request_payload:
            request_payload["doc_ids"] = [document_id]

        res = retrieval_chunks(WebApiAuth, request_payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            pytest.param(
                {"page": None, "size": 2},
                100,
                0,
                """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""",
                marks=pytest.mark.skip,
            ),
            pytest.param(
                {"page": 0, "size": 2},
                100,
                0,
                "ValueError('Search does not support negative slicing.')",
                marks=pytest.mark.skip,
            ),
            pytest.param({"page": 2, "size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")),
            ({"page": 3, "size": 2}, 0, 0, ""),
            ({"page": "3", "size": 2}, 0, 0, ""),
            pytest.param(
                {"page": -1, "size": 2},
                100,
                0,
                "ValueError('Search does not support negative slicing.')",
                marks=pytest.mark.skip,
            ),
            pytest.param(
                {"page": "a", "size": 2},
                100,
                0,
                """ValueError("invalid literal for int() with base 10: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_page(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        """Retrieve with different ``page`` values and check code, chunk count or message."""
        dataset_id, _, _ = add_chunks
        request_payload = {**payload, "question": "chunk", "kb_id": [dataset_id]}

        res = retrieval_chunks(WebApiAuth, request_payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            pytest.param(
                {"size": None},
                100,
                0,
                """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""",
                marks=pytest.mark.skip,
            ),
            # ({"size": 0}, 0, 0, ""),
            ({"size": 1}, 0, 1, ""),
            ({"size": 5}, 0, 4, ""),
            ({"size": "1"}, 0, 1, ""),
            # ({"size": -1}, 0, 0, ""),
            pytest.param(
                {"size": "a"},
                100,
                0,
                """ValueError("invalid literal for int() with base 10: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_page_size(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        """Retrieve with different ``size`` values and check code, chunk count or message."""
        dataset_id, _, _ = add_chunks
        request_payload = {**payload, "question": "chunk", "kb_id": [dataset_id]}

        res = retrieval_chunks(WebApiAuth, request_payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"vector_similarity_weight": 0}, 0, 4, ""),
            ({"vector_similarity_weight": 0.5}, 0, 4, ""),
            ({"vector_similarity_weight": 10}, 0, 4, ""),
            pytest.param(
                {"vector_similarity_weight": "a"},
                100,
                0,
                """ValueError("could not convert string to float: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_vector_similarity_weight(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        """Retrieve with different ``vector_similarity_weight`` values."""
        dataset_id, _, _ = add_chunks
        request_payload = {**payload, "question": "chunk", "kb_id": [dataset_id]}

        res = retrieval_chunks(WebApiAuth, request_payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"top_k": 10}, 0, 4, ""),
            pytest.param(
                {"top_k": 1},
                0,
                4,
                "",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
            ),
            pytest.param(
                {"top_k": 1},
                0,
                1,
                "",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
            ),
            pytest.param(
                {"top_k": -1},
                100,
                4,
                "must be greater than 0",
                marks=pytest.mark.skip(reason="Web API does not validate top_k"),
            ),
            pytest.param(
                {"top_k": -1},
                100,
                4,
                "3014",
                marks=pytest.mark.skip(reason="Web API does not validate top_k"),
            ),
            pytest.param(
                {"top_k": "a"},
                100,
                0,
                """ValueError("invalid literal for int() with base 10: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_top_k(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        """Retrieve with different ``top_k`` values; error messages are matched by substring."""
        dataset_id, _, _ = add_chunks
        request_payload = {**payload, "question": "chunk", "kb_id": [dataset_id]}

        res = retrieval_chunks(WebApiAuth, request_payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert expected_message in res["message"], res

    @pytest.mark.skip
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""),
            pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
        ],
    )
    def test_rerank_id(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        """Retrieve with a ``rerank_id``; success only requires at least one chunk."""
        dataset_id, _, _ = add_chunks
        request_payload = {**payload, "question": "chunk", "kb_id": [dataset_id]}

        res = retrieval_chunks(WebApiAuth, request_payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) > 0, res
        else:
            assert expected_message in res["message"], res

    @pytest.mark.skip
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"keyword": True}, 0, 5, ""),
            ({"keyword": "True"}, 0, 5, ""),
            ({"keyword": False}, 0, 5, ""),
            ({"keyword": "False"}, 0, 5, ""),
            ({"keyword": None}, 0, 5, ""),
        ],
    )
    def test_keyword(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        """Retrieve with different ``keyword`` values using the question "chunk test"."""
        dataset_id, _, _ = add_chunks
        request_payload = {**payload, "question": "chunk test", "kb_id": [dataset_id]}

        res = retrieval_chunks(WebApiAuth, request_payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_highlight, expected_message",
        [
            pytest.param({"highlight": True}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")),
            pytest.param({"highlight": "True"}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")),
            ({"highlight": False}, 0, False, ""),
            ({"highlight": "False"}, 0, False, ""),
            ({"highlight": None}, 0, False, ""),
        ],
    )
    def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message):
        """Check that a "highlight" key is present on every chunk only when highlighting is expected."""
        dataset_id, _, _ = add_chunks
        request_payload = {**payload, "question": "chunk", "kb_id": [dataset_id]}

        res = retrieval_chunks(WebApiAuth, request_payload)
        assert res["code"] == expected_code, res
        if expected_highlight:
            for chunk in res["data"]["chunks"]:
                assert "highlight" in chunk, res
        else:
            for chunk in res["data"]["chunks"]:
                assert "highlight" not in chunk, res

        if expected_code != 0:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    def test_invalid_params(self, WebApiAuth, add_chunks):
        """An unknown extra key ("a") is expected not to affect the result."""
        dataset_id, _, _ = add_chunks
        payload = {"question": "chunk", "kb_id": [dataset_id], "a": "b"}

        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 4, res

    @pytest.mark.p3
    def test_concurrent_retrieval(self, WebApiAuth, add_chunks):
        """Submit 100 retrievals through a 5-worker thread pool and expect all to succeed."""
        dataset_id, _, _ = add_chunks
        count = 100
        # Shared between workers; the test itself never modifies it.
        payload = {"question": "chunk", "kb_id": [dataset_id]}

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(retrieval_chunks, WebApiAuth, payload) for _ in range(count)]
        responses = list(as_completed(futures))

        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures)
|
||||
@@ -244,22 +244,6 @@ def kb_pipeline_log_detail(auth, dataset_id, log_id, *, headers=HEADERS):
|
||||
return res.json()
|
||||
|
||||
|
||||
# DATASET GRAPH AND TASKS
|
||||
def knowledge_graph(auth, dataset_id, params=None):
    """Issue a GET to the dataset's ``knowledge_graph`` URL and return the JSON body.

    ``params`` is forwarded as the query string; the module-level ``HEADERS``
    are always used for this request.
    """
    endpoint = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/knowledge_graph"
    response = requests.get(url=endpoint, headers=HEADERS, auth=auth, params=params)
    return response.json()
|
||||
|
||||
|
||||
def delete_knowledge_graph(auth, dataset_id, payload=None):
    """Issue a DELETE to the dataset's ``knowledge_graph`` URL and return the JSON body.

    A JSON body is attached only when ``payload`` is not None; otherwise the
    request is sent without the ``json`` argument.
    """
    request_kwargs = {
        "url": f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/knowledge_graph",
        "headers": HEADERS,
        "auth": auth,
    }
    if payload is not None:
        request_kwargs["json"] = payload
    return requests.delete(**request_kwargs).json()
|
||||
|
||||
|
||||
def list_tags_from_kbs(auth, dataset_ids, *, headers=HEADERS):
|
||||
params = {"dataset_ids": dataset_ids}
|
||||
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}/tags/aggregation", headers=headers, auth=auth, params=params)
|
||||
@@ -518,11 +502,6 @@ def delete_chunks(auth, dataset_id, document_id, payload=None, *, headers=HEADER
|
||||
return res.json()
|
||||
|
||||
|
||||
def retrieval_chunks(auth, payload=None, *, headers=HEADERS):
    """POST ``payload`` as JSON to the chunk app's ``retrieval_test`` URL and return the JSON body."""
    endpoint = f"{HOST_ADDRESS}{CHUNK_APP_URL}/retrieval_test"
    response = requests.post(url=endpoint, headers=headers, auth=auth, json=payload)
    return response.json()
|
||||
|
||||
|
||||
def batch_add_chunks(auth, dataset_id, document_id, num):
|
||||
chunk_ids = []
|
||||
for i in range(num):
|
||||
|
||||
@@ -18,11 +18,9 @@ const {
|
||||
documentChangeStatus,
|
||||
documentChangeParser,
|
||||
documentThumbnails,
|
||||
retrievalTest,
|
||||
documentIngest,
|
||||
documentUpload,
|
||||
webCrawl,
|
||||
knowledgeGraph,
|
||||
listTagByKnowledgeIds,
|
||||
setMeta,
|
||||
getMeta,
|
||||
@@ -71,14 +69,6 @@ const methods = {
|
||||
url: setMeta,
|
||||
method: 'post',
|
||||
},
|
||||
retrievalTest: {
|
||||
url: retrievalTest,
|
||||
method: 'post',
|
||||
},
|
||||
knowledgeGraph: {
|
||||
url: knowledgeGraph,
|
||||
method: 'get',
|
||||
},
|
||||
listTagByKnowledgeIds: {
|
||||
url: listTagByKnowledgeIds,
|
||||
method: 'get',
|
||||
@@ -151,6 +141,17 @@ const getAvailableParam = (available?: number) => {
|
||||
};
|
||||
|
||||
const chunkService = {
|
||||
retrievalTest: async (params: Record<string, any>) => {
  // The dataset id resolved from params selects the per-dataset search URL;
  // the full params object is still sent as the request body.
  const resolvedDatasetId = getDatasetId(params);
  if (resolvedDatasetId) {
    return request.post(api.retrievalTest(resolvedDatasetId), {
      data: params,
    });
  }
  throw new Error(
    'dataset_id (or kb_id/knowledge_id) is required for retrievalTest',
  );
},
|
||||
chunkList: async (params: Record<string, any>) => {
|
||||
const datasetId = getDatasetId(params);
|
||||
const documentId = getDocumentId(params);
|
||||
|
||||
@@ -66,6 +66,8 @@ export default {
|
||||
getKbDetail: (datasetId: string) => `${restAPIv1}/datasets/${datasetId}`,
|
||||
getKnowledgeGraph: (knowledgeId: string) =>
|
||||
`${restAPIv1}/datasets/${knowledgeId}/graph/search`,
|
||||
knowledgeGraph: (datasetId: string) =>
|
||||
`${restAPIv1}/datasets/${datasetId}/graph`,
|
||||
deleteKnowledgeGraph: (knowledgeId: string) =>
|
||||
`${restAPIv1}/datasets/${knowledgeId}/graph`,
|
||||
getMeta: `${restAPIv1}/datasets/metadata/flattened`,
|
||||
@@ -107,8 +109,8 @@ export default {
|
||||
`${restAPIv1}/datasets/${datasetId}/documents/${documentId}/chunks`,
|
||||
chunkDetail: (datasetId: string, documentId: string, chunkId: string) =>
|
||||
`${restAPIv1}/datasets/${datasetId}/documents/${documentId}/chunks/${chunkId}`,
|
||||
retrievalTest: `${webAPI}/chunk/retrieval_test`,
|
||||
knowledgeGraph: `${webAPI}/chunk/knowledge_graph`,
|
||||
retrievalTest: (datasetId: string) =>
|
||||
`${restAPIv1}/datasets/${datasetId}/search`,
|
||||
|
||||
// document
|
||||
getDocumentList: (datasetId: string) =>
|
||||
|
||||
Reference in New Issue
Block a user