Refactor: migrate chunk retrieval_test and knowledge_graph to REST API endpoints (#14402)

### What problem does this PR solve?

## Summary

Migrate two web API endpoints to REST-style HTTP API endpoints,
following the pattern established in #14222:

| Old Endpoint | New Endpoint |
|---|---|
| `POST /v1/chunk/retrieval_test` | `POST
/api/v1/datasets/<dataset_id>/search` |
| `GET /v1/chunk/knowledge_graph` | `GET
/api/v1/datasets/<dataset_id>/graph` |
This commit is contained in:
euvre
2026-04-28 12:00:26 +00:00
committed by GitHub
parent 85575259ac
commit 35f6d81b73
11 changed files with 340 additions and 727 deletions

View File

@@ -1,215 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from quart import request
from api.apps import current_user, login_required
from api.db.joint_services.tenant_model_service import (
get_model_config_by_id,
get_model_config_by_type_and_name,
get_tenant_default_model_by_type,
)
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
get_request_json,
server_error_response,
validate_request,
)
from common import settings
from common.constants import LLMType, RetCode
from common.metadata_utils import apply_meta_data_filter
from rag.app.tag import label_question
from rag.nlp import search
from rag.prompts.generator import cross_languages, keyword_extraction
@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
@login_required
@validate_request("kb_id", "question")
async def retrieval_test():
req = await get_request_json()
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
kb_ids = req["kb_id"]
if isinstance(kb_ids, str):
kb_ids = [kb_ids]
if not kb_ids:
return get_json_result(data=False, message='Please specify dataset firstly.',
code=RetCode.DATA_ERROR)
doc_ids = req.get("doc_ids", [])
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", [])
user_id = current_user.id
async def _retrieval():
local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = []
meta_data_filter = {}
chat_mdl = None
if req.get("search_id", ""):
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_id = search_config.get("chat_id", "")
if chat_id:
chat_model_config = get_model_config_by_type_and_name(user_id, LLMType.CHAT, search_config["chat_id"])
else:
chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
chat_mdl = LLMBundle(user_id, chat_model_config)
else:
meta_data_filter = req.get("meta_data_filter") or {}
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
chat_mdl = LLMBundle(user_id, chat_model_config)
if meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)
tenants = UserTenantService.query(user_id=user_id)
for kb_id in kb_ids:
for tenant in tenants:
if KnowledgebaseService.query(
tenant_id=tenant.tenant_id, id=kb_id):
tenant_ids.append(tenant.tenant_id)
break
else:
return get_json_result(
data=False, message='Only owner of dataset authorized for this operation.',
code=RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
return get_data_error_result(message="Knowledgebase not found!")
_question = question
if langs:
_question = await cross_languages(kb.tenant_id, None, _question, langs)
if kb.tenant_embd_id:
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
elif kb.embd_id:
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
else:
embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
rerank_mdl = None
if req.get("tenant_rerank_id"):
rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
elif req.get("rerank_id"):
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
if req.get("keyword", False):
default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
_question += await keyword_extraction(chat_mdl, _question)
labels = label_question(_question, [kb])
ranks = await settings.retriever.retrieval(
_question,
embd_mdl,
tenant_ids,
kb_ids,
page,
size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
doc_ids=local_doc_ids,
top=top,
rerank_mdl=rerank_mdl,
rank_feature=labels
)
if use_kg:
default_chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
ck = await settings.kg_retriever.retrieval(_question,
tenant_ids,
kb_ids,
embd_mdl,
LLMBundle(kb.tenant_id, default_chat_model_config))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
ranks["total"] = len(ranks["chunks"])
for c in ranks["chunks"]:
c.pop("vector", None)
ranks["labels"] = labels
return get_json_result(data=ranks)
try:
return await _retrieval()
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR)
return server_error_response(e)
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
@login_required
async def knowledge_graph():
doc_id = request.args["doc_id"]
tenant_id = DocumentService.get_tenant_id(doc_id)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
req = {
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]
try:
content_json = json.loads(sres.field[id]["content_with_weight"])
except Exception:
continue
if ty == 'mind_map':
node_dict = {}
def repeat_deal(content_json, node_dict):
if 'id' in content_json:
if content_json['id'] in node_dict:
node_name = content_json['id']
content_json['id'] += f"({node_dict[content_json['id']]})"
node_dict[node_name] += 1
else:
node_dict[content_json['id']] = 1
if 'children' in content_json and content_json['children']:
for item in content_json['children']:
repeat_deal(item, node_dict)
repeat_deal(content_json, node_dict)
obj[ty] = content_json
return get_json_result(data=obj)

View File

@@ -24,6 +24,7 @@ from api.utils.validation_utils import (
CreateDatasetReq,
DeleteDatasetReq,
ListDatasetReq,
SearchDatasetReq,
UpdateDatasetReq,
validate_and_parse_json_request,
validate_and_parse_request_args,
@@ -476,6 +477,35 @@ async def rename_tag(tenant_id, dataset_id):
return get_error_data_result(message="Internal server error")
@manager.route('/datasets/<dataset_id>/search', methods=['POST']) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
async def search(tenant_id, dataset_id):
"""Search (retrieval test) within a dataset.
POST /api/v1/datasets/<dataset_id>/search
JSON body: {"question": str (required), "doc_ids": list[str], "top_k": int, "page": int, "size": int,
"similarity_threshold": float, "vector_similarity_weight": float, "use_kg": bool,
"cross_languages": list[str], "keyword": bool, "meta_data_filter": dict}
Success: {"code": 0, "data": {"chunks": [...], "total": int, "labels": [...]}}
Errors: ARGUMENT_ERROR (101) for invalid payload; DATA_ERROR (102) for access denied or internal errors.
"""
req, err = await validate_and_parse_json_request(request, SearchDatasetReq)
if err is not None:
return get_error_argument_result(err)
try:
success, result = await dataset_api_service.search(dataset_id, tenant_id, req)
if success:
return get_result(data=result)
else:
return get_error_data_result(message=result)
except Exception as e:
logging.exception(e)
if "not_found" in str(e):
return get_error_data_result(message="No chunk found! Check the chunk status please!")
return get_error_data_result(message="Internal server error")
@manager.route('/datasets/<dataset_id>/graph/search', methods=['GET']) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
@@ -495,6 +525,32 @@ async def knowledge_graph(tenant_id, dataset_id):
return get_error_data_result(message="Internal server error")
@manager.route('/datasets/<dataset_id>/graph', methods=['GET']) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
async def get_knowledge_graph(tenant_id, dataset_id):
"""Get the knowledge graph of a dataset.
GET /api/v1/datasets/<dataset_id>/graph
Query params: optional filter params.
Success: {"code": 0, "data": {...}}
Errors: AUTHENTICATION_ERROR for access denied; DATA_ERROR for internal errors.
"""
try:
success, result = await dataset_api_service.get_knowledge_graph(dataset_id, tenant_id)
if success:
return get_result(data=result)
else:
return get_result(
data=False,
message=result,
code=RetCode.AUTHENTICATION_ERROR
)
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
@manager.route('/datasets/<dataset_id>/graph', methods=['DELETE']) # noqa: F821
@login_required
@add_tenant_id_to_kwargs

View File

@@ -900,3 +900,153 @@ def rename_tag(dataset_id: str, tenant_id: str, from_tag: str, to_tag: str):
return True, {"from": from_tag, "to": to_tag}
async def search(dataset_id: str, tenant_id: str, req: dict):
"""
Search (retrieval test) within a dataset.
:param dataset_id: dataset ID
:param tenant_id: tenant ID
:param req: search request
:return: (success, result) or (success, error_message)
"""
from api.db.joint_services.tenant_model_service import (
get_model_config_by_id,
get_model_config_by_type_and_name,
get_tenant_default_model_by_type,
)
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from common.constants import LLMType
from common.metadata_utils import apply_meta_data_filter
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, keyword_extraction
logging.debug(
"search(dataset=%s, tenant=%s, question_len=%s)",
dataset_id,
tenant_id,
len(req.get("question", "")),
)
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req.get("question", "")
doc_ids = req.get("doc_ids", [])
use_kg = req.get("use_kg", False)
top = max(1, min(int(req.get("top_k", 1024)), 2048))
langs = req.get("cross_languages", [])
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
logging.warning("search access denied: dataset=%s tenant=%s", dataset_id, tenant_id)
return False, "Only owner of dataset authorized for this operation."
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e:
logging.warning("search dataset not found: dataset=%s", dataset_id)
return False, "Dataset not found!"
if doc_ids is not None and not isinstance(doc_ids, list):
return False, "`doc_ids` should be a list"
local_doc_ids = list(doc_ids) if doc_ids else []
meta_data_filter = {}
chat_mdl = None
if req.get("search_id", ""):
search_detail = SearchService.get_detail(req.get("search_id", ""))
if not search_detail:
logging.warning("search config not found: search_id=%s", req.get("search_id", ""))
return False, "Invalid search_id"
search_config = search_detail.get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_id = search_config.get("chat_id", "")
if chat_id:
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, search_config["chat_id"])
else:
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
else:
meta_data_filter = req.get("meta_data_filter") or {}
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
if meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id])
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)
tenant_ids = []
tenants = UserTenantService.query(user_id=tenant_id)
for tenant in tenants:
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=dataset_id):
tenant_ids.append(tenant.tenant_id)
break
else:
return False, "Only owner of dataset authorized for this operation."
_question = question
if langs:
_question = await cross_languages(kb.tenant_id, None, _question, langs)
if kb.tenant_embd_id:
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
elif kb.embd_id:
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
else:
embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
rerank_mdl = None
if req.get("tenant_rerank_id"):
rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
elif req.get("rerank_id"):
rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
if req.get("keyword", False):
default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
_question += await keyword_extraction(chat_mdl, _question)
labels = label_question(_question, [kb])
ranks = await settings.retriever.retrieval(
_question,
embd_mdl,
tenant_ids,
[dataset_id],
page,
size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
doc_ids=local_doc_ids,
top=top,
rerank_mdl=rerank_mdl,
rank_feature=labels
)
if use_kg:
try:
default_chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
ck = await settings.kg_retriever.retrieval(_question,
tenant_ids,
[dataset_id],
embd_mdl,
LLMBundle(kb.tenant_id, default_chat_model_config))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
except Exception:
logging.warning("search KG retrieval failed: dataset=%s tenant=%s", dataset_id, tenant_id, exc_info=True)
total = ranks.get("total", 0)
ranks["chunks"] = settings.retriever.retrieval_by_children(
ranks["chunks"], tenant_ids
)
ranks["total"] = total
for c in ranks["chunks"]:
c.pop("vector", None)
ranks["labels"] = labels
return True, ranks

View File

@@ -819,6 +819,25 @@ class DeleteReq(Base):
class DeleteDatasetReq(DeleteReq): ...
class SearchDatasetReq(BaseModel):
model_config = ConfigDict(extra="ignore")
question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)]
doc_ids: Annotated[list[str], Field(default=[])]
page: Annotated[int, Field(default=1, ge=1)]
size: Annotated[int, Field(default=30, ge=1)]
top_k: Annotated[int, Field(default=1024, ge=1)]
similarity_threshold: Annotated[float, Field(default=0.0, ge=0.0, le=1.0)]
vector_similarity_weight: Annotated[float, Field(default=0.3, ge=0.0, le=1.0)]
use_kg: Annotated[bool, Field(default=False)]
cross_languages: Annotated[list[str], Field(default=[])]
keyword: Annotated[bool, Field(default=False)]
search_id: Annotated[str | None, Field(default=None)]
rerank_id: Annotated[str | None, Field(default=None)]
tenant_rerank_id: Annotated[str | None, Field(default=None)]
meta_data_filter: Annotated[dict | None, Field(default=None)]
class DeleteDocumentReq(DeleteReq): ...

View File

@@ -517,3 +517,12 @@ def get_flattened_metadata(auth, dataset_ids, *, headers=HEADERS):
url = f"{HOST_ADDRESS}{DATASETS_API_URL}/metadata/flattened"
res = requests.get(url=url, headers=headers, auth=auth, params={"dataset_ids": ",".join(dataset_ids)})
return res.json()
def search_dataset(auth, dataset_id, payload=None, *, headers=HEADERS):
url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/search"
res = requests.post(url=url, headers=headers, auth=auth, json=payload)
return res.json()

View File

@@ -0,0 +1,83 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import search_dataset, knowledge_graph
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p2
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 401, "<Unauthorized '401: Unauthorized'>"),
(RAGFlowHttpApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = search_dataset(invalid_auth, "dataset_id", {"question": "test"})
assert res["code"] == expected_code
assert expected_message in res.get("message", "")
class TestDatasetSearch:
@pytest.mark.p2
def test_search_without_question(self, HttpApiAuth, add_dataset_func):
dataset_id = add_dataset_func
res = search_dataset(HttpApiAuth, dataset_id, {})
assert res["code"] == 101, res
@pytest.mark.p2
def test_search_basic(self, HttpApiAuth, add_chunks):
dataset_id, document_id, _ = add_chunks
res = search_dataset(HttpApiAuth, dataset_id, {"question": "chunk"})
assert res["code"] == 0, res
assert "chunks" in res["data"], res
@pytest.mark.p2
def test_search_with_doc_ids(self, HttpApiAuth, add_chunks):
dataset_id, document_id, _ = add_chunks
res = search_dataset(HttpApiAuth, dataset_id, {"question": "chunk", "doc_ids": [document_id]})
assert res["code"] == 0, res
assert "chunks" in res["data"], res
@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code",
[
({"question": "chunk", "page": 1, "size": 2}, 0),
({"question": "chunk", "similarity_threshold": 0.5}, 0),
({"question": "chunk", "vector_similarity_weight": 0.7}, 0),
({"question": "chunk", "top_k": 10}, 0),
],
)
def test_search_params(self, HttpApiAuth, add_chunks, payload, expected_code):
dataset_id, _, _ = add_chunks
res = search_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == expected_code, res
@pytest.mark.p2
class TestDatasetGraph:
def test_graph_requires_auth(self):
res = knowledge_graph(None, "dataset_id")
assert res["code"] == 401
def test_graph_basic(self, HttpApiAuth, add_dataset_func):
dataset_id = add_dataset_func
res = knowledge_graph(HttpApiAuth, dataset_id)
assert res["code"] == 0, res

View File

@@ -17,7 +17,6 @@
import asyncio
import inspect
import importlib.util
import json
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
@@ -491,13 +490,15 @@ def _load_chunk_module(monkeypatch):
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
services_pkg.user_service = user_service_mod
module_name = "test_chunk_routes_unit_module"
module_path = repo_root / "api" / "apps" / "chunk_app.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
module.manager = _DummyManager()
monkeypatch.setitem(sys.modules, module_name, module)
spec.loader.exec_module(module)
module = None
if module_path.exists():
module_name = "test_chunk_routes_unit_module"
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
module.manager = _DummyManager()
monkeypatch.setitem(sys.modules, module_name, module)
spec.loader.exec_module(module)
return module
@@ -653,167 +654,3 @@ def test_restful_chunk_guard_branches_unit(monkeypatch):
assert res["message"] == "`available_int` or `available` is required.", res
@pytest.mark.p2
def test_retrieval_test_branch_matrix_unit(monkeypatch):
module = _load_chunk_module(monkeypatch)
module.request = SimpleNamespace(headers={"X-Request-ID": "req-r"}, args={})
applied_filters = []
llm_calls = []
cross_calls = []
keyword_calls = []
async def _apply_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids):
applied_filters.append(
{
"meta_data_filter": meta_data_filter,
"metas": metas,
"question": question,
"chat_mdl": chat_mdl,
"local_doc_ids": list(local_doc_ids),
}
)
return ["doc-filtered"]
async def _cross_languages(_tenant_id, _dialog, question, langs):
cross_calls.append((question, tuple(langs)))
return f"{question}-xl"
async def _keyword_extraction(_chat_mdl, question):
keyword_calls.append(question)
return "-kw"
class _Retriever:
def __init__(self, mode="ok"):
self.mode = mode
self.retrieval_questions = []
async def retrieval(self, question, *_args, **_kwargs):
if self.mode == "not_found":
raise Exception("boom not_found boom")
if self.mode == "explode":
raise RuntimeError("retrieval boom")
self.retrieval_questions.append(question)
return {"chunks": [{"id": "c1", "vector": [0.1], "content_with_weight": "chunk-content"}]}
def retrieval_by_children(self, chunks, _tenant_ids):
return list(chunks)
class _KgRetriever:
async def retrieval(self, *_args, **_kwargs):
return {"id": "kg-1", "content_with_weight": "kg-content"}
class _NoContentKgRetriever:
async def retrieval(self, *_args, **_kwargs):
return {"id": "kg-2", "content_with_weight": ""}
monkeypatch.setattr(module, "LLMBundle", lambda *args, **kwargs: llm_calls.append((args, kwargs)) or SimpleNamespace())
monkeypatch.setattr(module, "get_model_config_by_type_and_name", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "chat"})
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "chat"})
monkeypatch.setattr(module, "get_model_config_by_id", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "embedding"})
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: [{"meta": "v"}], raising=False)
monkeypatch.setattr(module, "apply_meta_data_filter", _apply_filter)
monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"meta_data_filter": {"method": "auto"}, "chat_id": "chat-1"}}, raising=False)
monkeypatch.setattr(module, "cross_languages", _cross_languages)
monkeypatch.setattr(module, "keyword_extraction", _keyword_extraction)
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: ["lbl"])
monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [_DummyTenant("tenant-1")])
monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False, raising=False)
_set_request_json(monkeypatch, module, {"kb_id": "kb-1", "question": "q", "search_id": "search-1"})
res = _run(module.retrieval_test())
assert res["code"] == module.RetCode.OPERATING_ERROR, res
assert "Only owner of dataset authorized for this operation." in res["message"], res
assert applied_filters and applied_filters[-1]["meta_data_filter"]["method"] == "auto"
assert llm_calls, "search_id metadata auto branch should instantiate chat model"
_set_request_json(monkeypatch, module, {"kb_id": [], "question": "q"})
res = _run(module.retrieval_test())
assert res["code"] == module.RetCode.DATA_ERROR, res
assert "Please specify dataset firstly." in res["message"], res
monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: True, raising=False)
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None), raising=False)
_set_request_json(
monkeypatch,
module,
{"kb_id": ["kb-1"], "question": "q", "meta_data_filter": {"method": "semi_auto"}},
)
res = _run(module.retrieval_test())
assert res["code"] == module.RetCode.DATA_ERROR, res
assert "Knowledgebase not found!" in res["message"], res
retriever = _Retriever(mode="ok")
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-1", tenant_embd_id=2)), raising=False)
monkeypatch.setattr(module.settings, "retriever", retriever)
monkeypatch.setattr(module.settings, "kg_retriever", _KgRetriever(), raising=False)
_set_request_json(
monkeypatch,
module,
{
"kb_id": ["kb-1"],
"question": "q",
"cross_languages": ["fr"],
"rerank_id": "rerank-1",
"keyword": True,
"use_kg": True,
},
)
res = _run(module.retrieval_test())
assert res["code"] == 0, res
assert cross_calls[-1] == ("q", ("fr",))
assert keyword_calls[-1] == "q-xl"
assert retriever.retrieval_questions[-1] == "q-xl-kw"
assert res["data"]["chunks"][0]["id"] == "kg-1", res
assert all("vector" not in chunk for chunk in res["data"]["chunks"])
monkeypatch.setattr(module.settings, "kg_retriever", _NoContentKgRetriever(), raising=False)
_set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q", "use_kg": True})
res = _run(module.retrieval_test())
assert res["code"] == 0, res
assert res["data"]["chunks"][0]["id"] == "c1", res
monkeypatch.setattr(module.settings, "retriever", _Retriever(mode="not_found"))
_set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q"})
res = _run(module.retrieval_test())
assert res["code"] == module.RetCode.DATA_ERROR, res
assert "No chunk found! Check the chunk status please!" in res["message"], res
monkeypatch.setattr(module.settings, "retriever", _Retriever(mode="explode"))
_set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q"})
res = _run(module.retrieval_test())
assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
assert "retrieval boom" in res["message"], res
@pytest.mark.p2
def test_knowledge_graph_repeat_deal_matrix_unit(monkeypatch):
module = _load_chunk_module(monkeypatch)
module.request = SimpleNamespace(args={"doc_id": "doc-1"}, headers={})
payload = {
"id": "root",
"children": [
{"id": "dup"},
{"id": "dup", "children": [{"id": "dup"}]},
],
}
class _SRes:
ids = ["bad-json", "mind-map"]
field = {
"bad-json": {"knowledge_graph_kwd": "graph", "content_with_weight": "{bad json"},
"mind-map": {"knowledge_graph_kwd": "mind_map", "content_with_weight": json.dumps(payload)},
}
async def _search(*_args, **_kwargs):
return _SRes()
monkeypatch.setattr(module.settings.retriever, "search", _search)
res = _run(module.knowledge_graph())
assert res["code"] == 0, res
assert res["data"]["graph"] == {}, res
mind_map = res["data"]["mind_map"]
assert mind_map["children"][0]["id"] == "dup", res
assert mind_map["children"][1]["id"] == "dup(1)", res
assert mind_map["children"][1]["children"][0]["id"] == "dup(2)", res

View File

@@ -1,308 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from test_common import retrieval_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
@pytest.mark.p2
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 401, "<Unauthorized '401: Unauthorized'>"),
(RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = retrieval_chunks(invalid_auth, {"kb_id": "dummy_kb_id", "question": "dummy question"})
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res
class TestChunksRetrieval:
@pytest.mark.p1
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"question": "chunk", "kb_id": None}, 0, 4, ""),
({"question": "chunk", "doc_ids": None}, 101, 0, "required argument are missing: kb_id; "),
({"question": "chunk", "kb_id": None, "doc_ids": None}, 0, 4, ""),
({"question": "chunk"}, 101, 0, "required argument are missing: kb_id; "),
],
)
def test_basic_scenarios(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, document_id, _ = add_chunks
if "kb_id" in payload:
payload["kb_id"] = [dataset_id]
if "doc_ids" in payload:
payload["doc_ids"] = [document_id]
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res
@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
pytest.param(
{"page": None, "size": 2},
100,
0,
"""TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""",
marks=pytest.mark.skip,
),
pytest.param(
{"page": 0, "size": 2},
100,
0,
"ValueError('Search does not support negative slicing.')",
marks=pytest.mark.skip,
),
pytest.param({"page": 2, "size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")),
({"page": 3, "size": 2}, 0, 0, ""),
({"page": "3", "size": 2}, 0, 0, ""),
pytest.param(
{"page": -1, "size": 2},
100,
0,
"ValueError('Search does not support negative slicing.')",
marks=pytest.mark.skip,
),
pytest.param(
{"page": "a", "size": 2},
100,
0,
"""ValueError("invalid literal for int() with base 10: 'a'")""",
marks=pytest.mark.skip,
),
],
)
def test_page(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
pytest.param(
{"size": None},
100,
0,
"""TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""",
marks=pytest.mark.skip,
),
# ({"size": 0}, 0, 0, ""),
({"size": 1}, 0, 1, ""),
({"size": 5}, 0, 4, ""),
({"size": "1"}, 0, 1, ""),
# ({"size": -1}, 0, 0, ""),
pytest.param(
{"size": "a"},
100,
0,
"""ValueError("invalid literal for int() with base 10: 'a'")""",
marks=pytest.mark.skip,
),
],
)
def test_page_size(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"vector_similarity_weight": 0}, 0, 4, ""),
({"vector_similarity_weight": 0.5}, 0, 4, ""),
({"vector_similarity_weight": 10}, 0, 4, ""),
pytest.param(
{"vector_similarity_weight": "a"},
100,
0,
"""ValueError("could not convert string to float: 'a'")""",
marks=pytest.mark.skip,
),
],
)
def test_vector_similarity_weight(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res
@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"top_k": 10}, 0, 4, ""),
pytest.param(
{"top_k": 1},
0,
4,
"",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
),
pytest.param(
{"top_k": 1},
0,
1,
"",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
),
pytest.param(
{"top_k": -1},
100,
4,
"must be greater than 0",
marks=pytest.mark.skip(reason="Web API does not validate top_k"),
),
pytest.param(
{"top_k": -1},
100,
4,
"3014",
marks=pytest.mark.skip(reason="Web API does not validate top_k"),
),
pytest.param(
{"top_k": "a"},
100,
0,
"""ValueError("invalid literal for int() with base 10: 'a'")""",
marks=pytest.mark.skip,
),
],
)
def test_top_k(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert expected_message in res["message"], res
@pytest.mark.skip
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""),
pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
],
)
def test_rerank_id(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) > 0, res
else:
assert expected_message in res["message"], res
@pytest.mark.skip
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"keyword": True}, 0, 5, ""),
({"keyword": "True"}, 0, 5, ""),
({"keyword": False}, 0, 5, ""),
({"keyword": "False"}, 0, 5, ""),
({"keyword": None}, 0, 5, ""),
],
)
def test_keyword(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk test", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_highlight, expected_message",
[
pytest.param({"highlight": True}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")),
pytest.param({"highlight": "True"}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")),
({"highlight": False}, 0, False, ""),
({"highlight": "False"}, 0, False, ""),
({"highlight": None}, 0, False, "")
],
)
def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_highlight:
for chunk in res["data"]["chunks"]:
assert "highlight" in chunk, res
else:
for chunk in res["data"]["chunks"]:
assert "highlight" not in chunk, res
if expected_code != 0:
assert res["message"] == expected_message, res
@pytest.mark.p3
def test_invalid_params(self, WebApiAuth, add_chunks):
dataset_id, _, _ = add_chunks
payload = {"question": "chunk", "kb_id": [dataset_id], "a": "b"}
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == 0, res
assert len(res["data"]["chunks"]) == 4, res
@pytest.mark.p3
def test_concurrent_retrieval(self, WebApiAuth, add_chunks):
dataset_id, _, _ = add_chunks
count = 100
payload = {"question": "chunk", "kb_id": [dataset_id]}
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(retrieval_chunks, WebApiAuth, payload) for i in range(count)]
responses = list(as_completed(futures))
assert len(responses) == count, responses
assert all(future.result()["code"] == 0 for future in futures)

View File

@@ -244,22 +244,6 @@ def kb_pipeline_log_detail(auth, dataset_id, log_id, *, headers=HEADERS):
return res.json()
# DATASET GRAPH AND TASKS
def knowledge_graph(auth, dataset_id, params=None):
url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/knowledge_graph"
res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
return res.json()
def delete_knowledge_graph(auth, dataset_id, payload=None):
url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/knowledge_graph"
if payload is None:
res = requests.delete(url=url, headers=HEADERS, auth=auth)
else:
res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def list_tags_from_kbs(auth, dataset_ids, *, headers=HEADERS):
params = {"dataset_ids": dataset_ids}
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}/tags/aggregation", headers=headers, auth=auth, params=params)
@@ -518,11 +502,6 @@ def delete_chunks(auth, dataset_id, document_id, payload=None, *, headers=HEADER
return res.json()
def retrieval_chunks(auth, payload=None, *, headers=HEADERS):
res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_APP_URL}/retrieval_test", headers=headers, auth=auth, json=payload)
return res.json()
def batch_add_chunks(auth, dataset_id, document_id, num):
chunk_ids = []
for i in range(num):

View File

@@ -18,11 +18,9 @@ const {
documentChangeStatus,
documentChangeParser,
documentThumbnails,
retrievalTest,
documentIngest,
documentUpload,
webCrawl,
knowledgeGraph,
listTagByKnowledgeIds,
setMeta,
getMeta,
@@ -71,14 +69,6 @@ const methods = {
url: setMeta,
method: 'post',
},
retrievalTest: {
url: retrievalTest,
method: 'post',
},
knowledgeGraph: {
url: knowledgeGraph,
method: 'get',
},
listTagByKnowledgeIds: {
url: listTagByKnowledgeIds,
method: 'get',
@@ -151,6 +141,17 @@ const getAvailableParam = (available?: number) => {
};
const chunkService = {
retrievalTest: async (params: Record<string, any>) => {
const datasetId = getDatasetId(params);
if (!datasetId) {
throw new Error(
'dataset_id (or kb_id/knowledge_id) is required for retrievalTest',
);
}
return request.post(api.retrievalTest(datasetId), {
data: params,
});
},
chunkList: async (params: Record<string, any>) => {
const datasetId = getDatasetId(params);
const documentId = getDocumentId(params);

View File

@@ -66,6 +66,8 @@ export default {
getKbDetail: (datasetId: string) => `${restAPIv1}/datasets/${datasetId}`,
getKnowledgeGraph: (knowledgeId: string) =>
`${restAPIv1}/datasets/${knowledgeId}/graph/search`,
knowledgeGraph: (datasetId: string) =>
`${restAPIv1}/datasets/${datasetId}/graph`,
deleteKnowledgeGraph: (knowledgeId: string) =>
`${restAPIv1}/datasets/${knowledgeId}/graph`,
getMeta: `${restAPIv1}/datasets/metadata/flattened`,
@@ -107,8 +109,8 @@ export default {
`${restAPIv1}/datasets/${datasetId}/documents/${documentId}/chunks`,
chunkDetail: (datasetId: string, documentId: string, chunkId: string) =>
`${restAPIv1}/datasets/${datasetId}/documents/${documentId}/chunks/${chunkId}`,
retrievalTest: `${webAPI}/chunk/retrieval_test`,
knowledgeGraph: `${webAPI}/chunk/knowledge_graph`,
retrievalTest: (datasetId: string) =>
`${restAPIv1}/datasets/${datasetId}/search`,
// document
getDocumentList: (datasetId: string) =>