mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 00:05:43 +08:00
Refactor: Consolidation WEB API & HTTP API for document get_filter (#14248)
### What problem does this PR solve? Before consolidation Web API: POST /v1/document/filter Http API - GET /api/v1/datasets/<dataset_id>/documents After consolidation, Restful API -- GET /api/v1/datasets/<dataset_id>/documents?type=filter ### Type of change - [x] Refactoring
This commit is contained in:
@@ -22,7 +22,7 @@ from quart import make_response, request
|
||||
from api.apps import current_user, login_required
|
||||
from api.common.check_team_permission import check_kb_team_permission
|
||||
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
|
||||
from api.db import VALID_FILE_TYPES, FileType
|
||||
from api.db import FileType
|
||||
from api.db.db_models import Task
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.doc_metadata_service import DocMetadataService
|
||||
@@ -31,7 +31,6 @@ from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
@@ -42,7 +41,7 @@ from api.utils.api_utils import (
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP, apply_safe_file_response_headers, html2pdf, is_valid_url
|
||||
from common import settings
|
||||
from common.constants import SANDBOX_ARTIFACT_BUCKET, VALID_TASK_STATUS, ParserType, RetCode, TaskStatus
|
||||
from common.constants import SANDBOX_ARTIFACT_BUCKET, ParserType, RetCode, TaskStatus
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from deepdoc.parser.html_parser import RAGFlowHtmlParser
|
||||
@@ -184,44 +183,6 @@ async def create():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/filter", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_filter():
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
keywords = req.get("keywords", "")
|
||||
|
||||
suffix = req.get("suffix", [])
|
||||
|
||||
run_status = req.get("run_status", [])
|
||||
if run_status:
|
||||
invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
types = req.get("types", [])
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
try:
|
||||
filter, total = DocumentService.get_filter_by_kb_id(kb_id, keywords, run_status, types, suffix)
|
||||
return get_json_result(data={"total": total, "filter": filter})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/infos", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def doc_infos():
|
||||
|
||||
@@ -436,16 +436,19 @@ def list_docs(dataset_id, tenant_id):
|
||||
if err_code != RetCode.SUCCESS:
|
||||
return get_data_error_result(code=err_code, message=err_msg)
|
||||
|
||||
renamed_doc_list = [map_doc_keys(doc) for doc in docs]
|
||||
for doc_item in renamed_doc_list:
|
||||
if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
|
||||
doc_item["thumbnail"] = f"/v1/document/image/{dataset_id}-{doc_item['thumbnail']}"
|
||||
if doc_item.get("source_type"):
|
||||
doc_item["source_type"] = doc_item["source_type"].split("/")[0]
|
||||
if doc_item["parser_config"].get("metadata"):
|
||||
doc_item["parser_config"]["metadata"] = turn2jsonschema(doc_item["parser_config"]["metadata"])
|
||||
|
||||
return get_json_result(data={"total": total, "docs": renamed_doc_list})
|
||||
if request.args.get("type") == "filter":
|
||||
docs_filter = _aggregate_filters(docs)
|
||||
return get_json_result(data={"total": total, "filter": docs_filter})
|
||||
else:
|
||||
renamed_doc_list = [map_doc_keys(doc) for doc in docs]
|
||||
for doc_item in renamed_doc_list:
|
||||
if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
|
||||
doc_item["thumbnail"] = f"/v1/document/image/{dataset_id}-{doc_item['thumbnail']}"
|
||||
if doc_item.get("source_type"):
|
||||
doc_item["source_type"] = doc_item["source_type"].split("/")[0]
|
||||
if doc_item["parser_config"].get("metadata"):
|
||||
doc_item["parser_config"]["metadata"] = turn2jsonschema(doc_item["parser_config"]["metadata"])
|
||||
return get_json_result(data={"total": total, "docs": renamed_doc_list})
|
||||
|
||||
|
||||
def _get_docs_with_request(req, dataset_id:str):
|
||||
@@ -517,13 +520,15 @@ def _get_docs_with_request(req, dataset_id:str):
|
||||
|
||||
doc_name = q.get("name")
|
||||
doc_id = q.get("id")
|
||||
if doc_id and not DocumentService.query(id=doc_id, kb_id=dataset_id):
|
||||
return RetCode.DATA_ERROR, f"You don't own the document {doc_id}.", [], 0
|
||||
if doc_id:
|
||||
if not DocumentService.query(id=doc_id, kb_id=dataset_id):
|
||||
return RetCode.DATA_ERROR, f"You don't own the document {doc_id}.", [], 0
|
||||
doc_ids_filter = [doc_id] # id provided, ignore other filters
|
||||
if doc_name and not DocumentService.query(name=doc_name, kb_id=dataset_id):
|
||||
return RetCode.DATA_ERROR, f"You don't own the document {doc_name}.", [], 0
|
||||
|
||||
docs, total = DocumentService.get_by_kb_id(dataset_id, page, page_size, orderby, desc, keywords, run_status_converted, types, suffix,
|
||||
doc_id=doc_id, name=doc_name, doc_ids_filter=doc_ids_filter, return_empty_metadata=return_empty_metadata)
|
||||
name=doc_name, doc_ids=doc_ids_filter, return_empty_metadata=return_empty_metadata)
|
||||
|
||||
# time range filter (0 means no bound)
|
||||
create_time_from = int(q.get("create_time_from", 0))
|
||||
@@ -622,11 +627,11 @@ def _parse_doc_id_filter_with_metadata(req, kb_id):
|
||||
if metadata and not isinstance(metadata, dict):
|
||||
return RetCode.DATA_ERROR, "metadata must be an object.", [], return_empty_metadata
|
||||
|
||||
doc_ids_filter = None
|
||||
metas = None
|
||||
metas = dict()
|
||||
if metadata_condition or metadata:
|
||||
metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id])
|
||||
|
||||
doc_ids_filter = None
|
||||
if metadata_condition:
|
||||
doc_ids_filter = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
|
||||
if metadata_condition.get("conditions") and not doc_ids_filter:
|
||||
@@ -651,6 +656,7 @@ def _parse_doc_id_filter_with_metadata(req, kb_id):
|
||||
metadata_doc_ids &= key_doc_ids
|
||||
if not metadata_doc_ids:
|
||||
return RetCode.SUCCESS, "", [], return_empty_metadata
|
||||
|
||||
if metadata_doc_ids is not None:
|
||||
if doc_ids_filter is None:
|
||||
doc_ids_filter = metadata_doc_ids
|
||||
@@ -660,3 +666,62 @@ def _parse_doc_id_filter_with_metadata(req, kb_id):
|
||||
return RetCode.SUCCESS, "", [], return_empty_metadata
|
||||
|
||||
return RetCode.SUCCESS, "", list(doc_ids_filter) if doc_ids_filter is not None else [], return_empty_metadata
|
||||
|
||||
|
||||
def _aggregate_filters(docs):
|
||||
"""Aggregate filter options from a list of documents.
|
||||
|
||||
This function processes a list of document dictionaries and aggregates
|
||||
available filter values for building filter UI on the client side.
|
||||
|
||||
Args:
|
||||
docs (list): List of document dictionaries, each containing:
|
||||
- id (str): Document ID
|
||||
- suffix (str): File extension (e.g., "pdf", "docx")
|
||||
- run (int): Parsing status code (0=UNSTART, 1=RUNNING, 2=CANCEL, 3=DONE, 4=FAIL)
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
- dict: Aggregated filter options with keys:
|
||||
- suffix: Dict mapping file extensions to document counts
|
||||
- run_status: Dict mapping status codes to document counts
|
||||
- metadata: Dict mapping metadata field names to value counts
|
||||
- int: Total number of documents processed
|
||||
"""
|
||||
suffix_counter = {}
|
||||
run_status_counter = {}
|
||||
metadata_counter = {}
|
||||
empty_metadata_count = 0
|
||||
|
||||
for doc in docs:
|
||||
suffix_counter[doc.get("suffix")] = suffix_counter.get(doc.get("suffix"), 0) + 1
|
||||
key_of_run = str(doc.get("run"))
|
||||
run_status_counter[key_of_run] = run_status_counter.get(key_of_run, 0) + 1
|
||||
meta_fields = doc.get("meta_fields", {})
|
||||
|
||||
if not meta_fields:
|
||||
empty_metadata_count += 1
|
||||
continue
|
||||
has_valid_meta = False
|
||||
|
||||
for key, value in meta_fields.items():
|
||||
values = value if isinstance(value, list) else [value]
|
||||
for vv in values:
|
||||
if vv is None:
|
||||
continue
|
||||
if isinstance(vv, str) and not vv.strip():
|
||||
continue
|
||||
sv = str(vv)
|
||||
if key not in metadata_counter:
|
||||
metadata_counter[key] = {}
|
||||
metadata_counter[key][sv] = metadata_counter[key].get(sv, 0) + 1
|
||||
has_valid_meta = True
|
||||
if not has_valid_meta:
|
||||
empty_metadata_count += 1
|
||||
|
||||
metadata_counter["empty_metadata"] = {"true": empty_metadata_count}
|
||||
return {
|
||||
"suffix": suffix_counter,
|
||||
"run_status": run_status_counter,
|
||||
"metadata": metadata_counter,
|
||||
}
|
||||
|
||||
@@ -127,7 +127,7 @@ class DocumentService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_id=None, name=None, doc_ids_filter=None, return_empty_metadata=False):
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, name=None, doc_ids=None, return_empty_metadata=False):
|
||||
fields = cls.get_cls_model_fields()
|
||||
if keywords:
|
||||
docs = (
|
||||
@@ -147,10 +147,8 @@ class DocumentService(CommonService):
|
||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
)
|
||||
if doc_id:
|
||||
docs = docs.where(cls.model.id == doc_id)
|
||||
if doc_ids_filter:
|
||||
docs = docs.where(cls.model.id.in_(doc_ids_filter))
|
||||
if doc_ids:
|
||||
docs = docs.where(cls.model.id.in_(doc_ids))
|
||||
if run_status:
|
||||
docs = docs.where(cls.model.run.in_(run_status))
|
||||
if types:
|
||||
|
||||
@@ -375,7 +375,7 @@ def create_document(auth, payload=None, *, headers=HEADERS, data=None):
|
||||
|
||||
def list_documents(auth, params=None, payload=None, *, headers=HEADERS, data=None):
|
||||
kb_id = params.get("kb_id") if params else None
|
||||
url = f"{HOST_ADDRESS}/api/{VERSION}/datasets/{kb_id}/documents"
|
||||
url = f"{HOST_ADDRESS}{DATASETS_URL}/{kb_id}/documents"
|
||||
if payload is None:
|
||||
payload = {}
|
||||
res = requests.get(url=url, headers=headers, auth=auth, params=params, json=payload, data=data)
|
||||
@@ -392,8 +392,8 @@ def parse_documents(auth, payload=None, *, headers=HEADERS, data=None):
|
||||
return res.json()
|
||||
|
||||
|
||||
def document_filter(auth, payload=None, *, headers=HEADERS, data=None):
|
||||
res = requests.post(url=f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/filter", headers=headers, auth=auth, json=payload, data=data)
|
||||
def document_filter(auth, dataset_id, payload=None, *, headers=HEADERS, data=None):
|
||||
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/documents?type=filter", params=payload, headers=headers, auth=auth, data=data)
|
||||
return res.json()
|
||||
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class TestAuthorization:
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES)
|
||||
def test_filter_auth_invalid(self, invalid_auth, expected_code, expected_fragment):
|
||||
res = document_filter(invalid_auth, {"kb_id": "kb_id"})
|
||||
res = document_filter(invalid_auth, "kb_id", {})
|
||||
assert res["code"] == expected_code, res
|
||||
assert expected_fragment in res["message"], res
|
||||
|
||||
@@ -84,7 +84,7 @@ class TestDocumentMetadata:
|
||||
@pytest.mark.p2
|
||||
def test_filter(self, WebApiAuth, add_dataset_func):
|
||||
kb_id = add_dataset_func
|
||||
res = document_filter(WebApiAuth, {"kb_id": kb_id})
|
||||
res = document_filter(WebApiAuth, kb_id, {})
|
||||
assert res["code"] == 0, res
|
||||
assert "filter" in res["data"], res
|
||||
assert "total" in res["data"], res
|
||||
@@ -148,12 +148,12 @@ class TestDocumentMetadata:
|
||||
|
||||
|
||||
class TestDocumentMetadataNegative:
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.p2
|
||||
def test_filter_missing_kb_id(self, WebApiAuth, add_document_func):
|
||||
_, doc_id = add_document_func
|
||||
res = document_filter(WebApiAuth, {"doc_ids": [doc_id]})
|
||||
assert res["code"] == 101, res
|
||||
assert "KB ID" in res["message"], res
|
||||
kb_id, doc_id = add_document_func
|
||||
res = document_filter(WebApiAuth, "", {"doc_ids": [doc_id]})
|
||||
assert res["code"] == 100, res
|
||||
assert "<MethodNotAllowed '405: Method Not Allowed'>" == res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_metadata_summary_missing_kb_id(self, WebApiAuth, add_document_func):
|
||||
@@ -228,77 +228,6 @@ class TestDocumentMetadataUnit:
|
||||
monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id=tenant_id)])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: True if _kwargs.get("id") == kb_id else False)
|
||||
|
||||
def test_filter_missing_kb_id(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
|
||||
async def fake_request_json():
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
res = _run(module.get_filter())
|
||||
assert res["code"] == 101
|
||||
assert "KB ID" in res["message"]
|
||||
|
||||
def test_filter_unauthorized(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant1")])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False)
|
||||
|
||||
async def fake_request_json():
|
||||
return {"kb_id": "kb1"}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
res = _run(module.get_filter())
|
||||
assert res["code"] == 103
|
||||
|
||||
def test_filter_invalid_filters(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
self._allow_kb(module, monkeypatch)
|
||||
|
||||
async def fake_request_json():
|
||||
return {"kb_id": "kb1", "run_status": ["INVALID"]}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
res = _run(module.get_filter())
|
||||
assert res["code"] == 102
|
||||
assert "Invalid filter run status" in res["message"]
|
||||
|
||||
async def fake_request_json_types():
|
||||
return {"kb_id": "kb1", "types": ["INVALID"]}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json_types)
|
||||
res = _run(module.get_filter())
|
||||
assert res["code"] == 102
|
||||
assert "Invalid filter conditions" in res["message"]
|
||||
|
||||
def test_filter_keywords_suffix(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
self._allow_kb(module, monkeypatch)
|
||||
monkeypatch.setattr(module.DocumentService, "get_filter_by_kb_id", lambda *_args, **_kwargs: ({"run": {}}, 1))
|
||||
|
||||
async def fake_request_json():
|
||||
return {"kb_id": "kb1", "keywords": "ragflow", "suffix": ["txt"]}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
res = _run(module.get_filter())
|
||||
assert res["code"] == 0
|
||||
assert "filter" in res["data"]
|
||||
|
||||
def test_filter_exception(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
self._allow_kb(module, monkeypatch)
|
||||
|
||||
def raise_error(*_args, **_kwargs):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "get_filter_by_kb_id", raise_error)
|
||||
|
||||
async def fake_request_json():
|
||||
return {"kb_id": "kb1"}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
res = _run(module.get_filter())
|
||||
assert res["code"] == 100
|
||||
|
||||
def test_infos_meta_fields(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
import i18n from '@/locales/config';
|
||||
import { EMPTY_METADATA_FIELD } from '@/pages/dataset/dataset/use-select-filters';
|
||||
import kbService, {
|
||||
documentFilter,
|
||||
listDocument,
|
||||
renameDocument,
|
||||
uploadDocument,
|
||||
@@ -214,10 +215,7 @@ export const useGetDocumentFilter = (): {
|
||||
knowledgeId,
|
||||
],
|
||||
queryFn: async () => {
|
||||
const { data } = await kbService.documentFilter({
|
||||
kb_id: knowledgeId || id,
|
||||
keywords: debouncedSearchString,
|
||||
});
|
||||
const { data } = await documentFilter(knowledgeId || id);
|
||||
if (data.code === 0) {
|
||||
return data.data;
|
||||
}
|
||||
|
||||
@@ -152,7 +152,7 @@ const methods = {
|
||||
},
|
||||
documentFilter: {
|
||||
url: api.getDatasetFilter,
|
||||
method: 'post',
|
||||
method: 'get',
|
||||
},
|
||||
getMeta: {
|
||||
url: getMeta,
|
||||
@@ -262,7 +262,7 @@ export const listDocument = (
|
||||
};
|
||||
|
||||
export const documentFilter = (kb_id: string) =>
|
||||
request.post(api.getDatasetFilter, { kb_id });
|
||||
request.get(api.getDatasetFilter(kb_id), { params: {} });
|
||||
|
||||
// Custom upload function that handles dynamic URL using axios directly
|
||||
export const uploadDocument = async (datasetId: string, formData: FormData) => {
|
||||
|
||||
@@ -126,7 +126,8 @@ export default {
|
||||
documentInfos: `${webAPI}/document/infos`,
|
||||
uploadAndParse: `${webAPI}/document/upload_info`,
|
||||
setMeta: `${webAPI}/document/set_meta`,
|
||||
getDatasetFilter: `${webAPI}/document/filter`,
|
||||
getDatasetFilter: (datasetId: string) =>
|
||||
`${restAPIv1}/datasets/${datasetId}/documents?type=filter`,
|
||||
|
||||
// chat
|
||||
createChat: `${restAPIv1}/chats`,
|
||||
|
||||
Reference in New Issue
Block a user