Feat: Support get aggregated parsing status to dataset via the API (#13481)

### What problem does this PR solve?

Support getting aggregated parsing status to dataset via the API

Issue: #12810

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: heyang.why <heyang.why@alibaba-inc.com>
This commit is contained in:
Heyang Wang
2026-03-10 18:05:45 +08:00
committed by GitHub
parent 68a623154a
commit 08f83ff331
7 changed files with 654 additions and 309 deletions

View File

@@ -138,12 +138,7 @@ async def create(tenant_id):
parser_cfg["metadata"] = fields
parser_cfg["enable_metadata"] = auto_meta.get("enabled", True)
req["parser_config"] = parser_cfg
e, req = KnowledgebaseService.create_with_name(
name = req.pop("name", None),
tenant_id = tenant_id,
parser_id = req.pop("parser_id", None),
**req
)
e, req = KnowledgebaseService.create_with_name(name=req.pop("name", None), tenant_id=tenant_id, parser_id=req.pop("parser_id", None), **req)
if not e:
return req
@@ -159,19 +154,19 @@ async def create(tenant_id):
if not ok:
return err
try:
if not KnowledgebaseService.save(**req):
return get_error_data_result()
ok, k = KnowledgebaseService.get_by_id(req["id"])
if not ok:
return get_error_data_result(message="Dataset created failed")
response_data = remap_dictionary_keys(k.to_dict())
return get_result(data=response_data)
if not KnowledgebaseService.save(**req):
return get_error_data_result()
ok, k = KnowledgebaseService.get_by_id(req["id"])
if not ok:
return get_error_data_result(message="Dataset created failed")
response_data = remap_dictionary_keys(k.to_dict())
return get_result(data=response_data)
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
@manager.route("/datasets", methods=["DELETE"]) # noqa: F821
@token_required
async def delete(tenant_id):
@@ -227,8 +222,7 @@ async def delete(tenant_id):
continue
kb_id_instance_pairs.append((kb_id, kb))
if len(error_kb_ids) > 0:
return get_error_permission_result(
message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
errors = []
success_count = 0
@@ -245,12 +239,12 @@ async def delete(tenant_id):
]
)
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
# Drop index for this dataset
try:
from rag.nlp import search
idxnm = search.index_name(kb.tenant_id)
settings.docStoreConn.delete_idx(idxnm, kb_id)
except Exception as e:
@@ -352,8 +346,7 @@ async def update(tenant_id, dataset_id):
try:
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
if kb is None:
return get_error_permission_result(
message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
# Map auto_metadata_config into parser_config if present
auto_meta = req.pop("auto_metadata_config", None)
@@ -384,8 +377,7 @@ async def update(tenant_id, dataset_id):
del req["parser_config"]
if "name" in req and req["name"].lower() != kb.name.lower():
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
if exists:
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
@@ -393,8 +385,7 @@ async def update(tenant_id, dataset_id):
if not req["embd_id"]:
req["embd_id"] = kb.embd_id
if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
return get_error_data_result(
message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
if not ok:
return err
@@ -404,12 +395,10 @@ async def update(tenant_id, dataset_id):
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
if req["pagerank"] > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)")
@@ -470,6 +459,15 @@ def list_datasets(tenant_id):
required: false
default: true
description: Order in descending.
- in: query
name: include_parsing_status
type: boolean
required: false
default: false
description: |
Whether to include document parsing status counts in the response.
When true, each dataset object will include: unstart_count, running_count,
cancel_count, done_count, and fail_count.
- in: header
name: Authorization
type: string
@@ -487,17 +485,18 @@ def list_datasets(tenant_id):
if err is not None:
return get_error_argument_result(err)
include_parsing_status = args.get("include_parsing_status", False)
try:
kb_id = request.args.get("id")
name = args.get("name")
# check whether user has permission for the dataset with specified id
if kb_id:
kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id)
if not kbs:
if not KnowledgebaseService.get_kb_by_id(kb_id, tenant_id):
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'")
# check whether user has permission for the dataset with specified name
if name:
kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
if not kbs:
if not KnowledgebaseService.get_kb_by_name(name, tenant_id):
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
@@ -512,9 +511,17 @@ def list_datasets(tenant_id):
name,
)
parsing_status_map = {}
if include_parsing_status and kbs:
kb_ids = [kb["id"] for kb in kbs]
parsing_status_map = DocumentService.get_parsing_status_by_kb_ids(kb_ids)
response_data_list = []
for kb in kbs:
response_data_list.append(remap_dictionary_keys(kb))
data = remap_dictionary_keys(kb)
if include_parsing_status:
data.update(parsing_status_map.get(kb["id"], {}))
response_data_list.append(data)
return get_result(data=response_data_list, total=total)
except OperationalError as e:
logging.exception(e)
@@ -530,9 +537,7 @@ def get_auto_metadata(tenant_id, dataset_id):
try:
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
if kb is None:
return get_error_permission_result(
message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'"
)
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
parser_cfg = kb.parser_config or {}
metadata = parser_cfg.get("metadata") or []
@@ -570,9 +575,7 @@ async def update_auto_metadata(tenant_id, dataset_id):
try:
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
if kb is None:
return get_error_permission_result(
message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'"
)
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
parser_cfg = kb.parser_config or {}
fields = []
@@ -598,20 +601,13 @@ async def update_auto_metadata(tenant_id, dataset_id):
return get_error_data_result(message="Database operation failed")
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
@manager.route("/datasets/<dataset_id>/knowledge_graph", methods=["GET"]) # noqa: F821
@token_required
async def knowledge_graph(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
_, kb = KnowledgebaseService.get_by_id(dataset_id)
req = {
"kb_id": [dataset_id],
"knowledge_graph_kwd": ["graph"]
}
req = {"kb_id": [dataset_id], "knowledge_graph_kwd": ["graph"]}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
@@ -633,39 +629,29 @@ async def knowledge_graph(tenant_id, dataset_id):
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
if "edges" in obj["graph"]:
node_id_set = {o["id"] for o in obj["graph"]["nodes"]}
filtered_edges = [o for o in obj["graph"]["edges"] if
o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
return get_result(data=obj)
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['DELETE']) # noqa: F821
@manager.route("/datasets/<dataset_id>/knowledge_graph", methods=["DELETE"]) # noqa: F821
@token_required
def delete_knowledge_graph(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
_, kb = KnowledgebaseService.get_by_id(dataset_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
search.index_name(kb.tenant_id), dataset_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id)
return get_result(data=True)
@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"]) # noqa: F821
@token_required
def run_graphrag(tenant_id,dataset_id):
def run_graphrag(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
@@ -707,15 +693,11 @@ def run_graphrag(tenant_id,dataset_id):
@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"]) # noqa: F821
@token_required
def trace_graphrag(tenant_id,dataset_id):
def trace_graphrag(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
@@ -734,15 +716,11 @@ def trace_graphrag(tenant_id,dataset_id):
@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"]) # noqa: F821
@token_required
def run_raptor(tenant_id,dataset_id):
def run_raptor(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
@@ -784,7 +762,7 @@ def run_raptor(tenant_id,dataset_id):
@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"]) # noqa: F821
@token_required
def trace_raptor(tenant_id,dataset_id):
def trace_raptor(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')

View File

@@ -28,8 +28,7 @@ from peewee import fn, Case, JOIN
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES, FileType, UserTenantRole, CanvasCategory
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
User
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, User
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -78,24 +77,21 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_list(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, id, name, suffix=None, run = None, doc_ids=None):
def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, id, name, suffix=None, run=None, doc_ids=None):
fields = cls.get_cls_model_fields()
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
.join(File, on = (File.id == File2Document.file_id))\
.join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\
docs = (
cls.model.select(*[*fields, UserCanvas.title])
.join(File2Document, on=(File2Document.document_id == cls.model.id))
.join(File, on=(File.id == File2Document.file_id))
.join(UserCanvas, on=((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)
.where(cls.model.kb_id == kb_id)
)
if id:
docs = docs.where(
cls.model.id == id)
docs = docs.where(cls.model.id == id)
if name:
docs = docs.where(
cls.model.name == name
)
docs = docs.where(cls.model.name == name)
if keywords:
docs = docs.where(
fn.LOWER(cls.model.name).contains(keywords.lower())
)
docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
if doc_ids:
docs = docs.where(cls.model.id.in_(doc_ids))
if suffix:
@@ -120,6 +116,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def check_doc_health(cls, tenant_id: str, filename):
import os
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id):
raise RuntimeError("Exceed the maximum file number of a free user!")
@@ -211,13 +208,14 @@ class DocumentService(CommonService):
"""
fields = cls.get_cls_model_fields()
if keywords:
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
query = (
cls.model.select(*fields)
.join(File2Document, on=(File2Document.document_id == cls.model.id))
.join(File, on=(File.id == File2Document.file_id))
.where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
)
else:
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
if run_status:
query = query.where(cls.model.run.in_(run_status))
@@ -272,14 +270,60 @@ class DocumentService(CommonService):
"metadata": metadata_counter,
}, total
@classmethod
@DB.connection_context()
def get_parsing_status_by_kb_ids(cls, kb_ids: list[str]) -> dict[str, dict[str, int]]:
"""Return aggregated document parsing status counts grouped by dataset (kb_id).
For each kb_id, counts documents in each run-status bucket:
- unstart_count (run == "0")
- running_count (run == "1")
- cancel_count (run == "2")
- done_count (run == "3")
- fail_count (run == "4")
Returns a dict keyed by kb_id, e.g.
{"kb-abc": {"unstart_count": 10, "running_count": 2, ...}, ...}
"""
if not kb_ids:
return {}
status_field_map = {
TaskStatus.UNSTART.value: "unstart_count",
TaskStatus.RUNNING.value: "running_count",
TaskStatus.CANCEL.value: "cancel_count",
TaskStatus.DONE.value: "done_count",
TaskStatus.FAIL.value: "fail_count",
}
empty_status = {v: 0 for v in status_field_map.values()}
result: dict[str, dict[str, int]] = {kb_id: dict(empty_status) for kb_id in kb_ids}
rows = (
cls.model.select(
cls.model.kb_id,
cls.model.run,
fn.COUNT(cls.model.id).alias("cnt"),
)
.where(cls.model.kb_id.in_(kb_ids))
.group_by(cls.model.kb_id, cls.model.run)
.dicts()
)
for row in rows:
kb_id = row["kb_id"]
run_val = str(row["run"])
field_name = status_field_map.get(run_val)
if field_name and kb_id in result:
result[kb_id][field_name] = int(row["cnt"])
return result
@classmethod
@DB.connection_context()
def count_by_kb_id(cls, kb_id, keywords, run_status, types):
if keywords:
docs = cls.model.select().where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
docs = cls.model.select().where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
else:
docs = cls.model.select().where(cls.model.kb_id == kb_id)
@@ -295,9 +339,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=[], types=[]):
query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(
cls.model.kb_id == kb_id
)
query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(cls.model.kb_id == kb_id)
if keywords:
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
@@ -329,12 +371,8 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_all_docs_by_creator_id(cls, creator_id):
fields = [
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
]
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
cls.model.created_by == creator_id
)
fields = [cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id]
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.created_by == creator_id)
docs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
offset, limit = 0, 100
@@ -361,6 +399,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
from api.db.services.task_service import TaskService, cancel_all_task_of
if not cls.delete_document_and_update_kb_counts(doc.id):
return True
@@ -406,17 +445,22 @@ class DocumentService(CommonService):
# Cleanup knowledge graph references (non-critical, log and continue)
try:
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]),
["source_id"],
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update(
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id),
doc.kb_id,
)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete(
{"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id),
doc.kb_id,
)
except Exception as e:
logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
@@ -428,9 +472,7 @@ class DocumentService(CommonService):
page = 0
page_size = 1000
while True:
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), page * page_size, page_size, search.index_name(tenant_id), [doc.kb_id])
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
if not chunk_ids:
break
@@ -455,71 +497,61 @@ class DocumentService(CommonService):
Tenant.embd_id,
Tenant.img2txt_id,
Tenant.asr_id,
cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
cls.model.update_time,
]
docs = (
cls.model.select(*fields)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value) \
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value,
)
.order_by(cls.model.update_time.asc())
)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def get_unfinished_docs(cls):
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
cls.model.run, cls.model.parser_id]
unfinished_task_query = Task.select(Task.doc_id).where(
(Task.progress >= 0) & (Task.progress < 1)
)
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run, cls.model.parser_id]
unfinished_task_query = Task.select(Task.doc_id).where((Task.progress >= 0) & (Task.progress < 1))
docs = cls.model.select(*fields) \
.where(
docs = cls.model.select(*fields).where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
(((cls.model.progress < 1) & (cls.model.progress > 0)) | (cls.model.id.in_(unfinished_task_query))),
) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duration=cls.model.process_duration + duration).where(
cls.model.id == doc_id).execute()
num = (
cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duration=cls.model.process_duration + duration)
.where(cls.model.id == doc_id)
.execute()
)
if num == 0:
logging.warning("Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num +
token_num,
chunk_num=Knowledgebase.chunk_num +
chunk_num).where(
Knowledgebase.id == kb_id).execute()
num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
num = cls.model.update(token_num=cls.model.token_num - token_num,
chunk_num=cls.model.chunk_num - chunk_num,
process_duration=cls.model.process_duration + duration).where(
cls.model.id == doc_id).execute()
num = (
cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duration=cls.model.process_duration + duration)
.where(cls.model.id == doc_id)
.execute()
)
if num == 0:
raise LookupError(
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
token_num,
chunk_num=Knowledgebase.chunk_num -
chunk_num
).where(
Knowledgebase.id == kb_id).execute()
raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute()
return num
@classmethod
@@ -551,17 +583,13 @@ class DocumentService(CommonService):
doc = cls.model.get_by_id(doc_id)
assert doc, "Can't fine document in database."
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
doc.token_num,
chunk_num=Knowledgebase.chunk_num -
doc.chunk_num,
doc_num=Knowledgebase.doc_num - 1
).where(
Knowledgebase.id == doc.kb_id).execute()
num = (
Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1)
.where(Knowledgebase.id == doc.kb_id)
.execute()
)
return num
@classmethod
@DB.connection_context()
def clear_chunk_num_when_rerun(cls, doc_id):
@@ -578,15 +606,10 @@ class DocumentService(CommonService):
)
return num
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return None
@@ -604,11 +627,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_tenant_id_by_name(cls, name):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return None
@@ -617,12 +636,13 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = (
cls.model.select(cls.model.id)
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
.join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
.where(cls.model.id == doc_id, UserTenant.user_id == user_id)
.paginate(0, 1)
)
docs = docs.dicts()
if not docs:
return False
@@ -631,18 +651,13 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
docs = cls.model.select(cls.model.id
).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(
UserTenant, on=(
(UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
).where(
cls.model.id == doc_id,
UserTenant.status == StatusEnum.VALID.value,
((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))
).paginate(0, 1)
docs = (
cls.model.select(cls.model.id)
.join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
.join(UserTenant, on=((UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id)))
.where(cls.model.id == doc_id, UserTenant.status == StatusEnum.VALID.value, ((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER)))
.paginate(0, 1)
)
docs = docs.dicts()
if not docs:
return False
@@ -651,11 +666,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_embd_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = cls.model.select(Knowledgebase.embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return None
@@ -664,11 +675,9 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_tenant_embd_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.tenant_embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = (
cls.model.select(Knowledgebase.tenant_embd_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
)
docs = docs.dicts()
if not docs:
return None
@@ -705,8 +714,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
fields = [cls.model.id]
doc_id = cls.model.select(*fields) \
.where(cls.model.name == doc_name)
doc_id = cls.model.select(*fields).where(cls.model.name == doc_name)
doc_id = doc_id.dicts()
if not doc_id:
return None
@@ -725,8 +733,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_thumbnails(cls, docids):
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
return list(cls.model.select(
*fields).where(cls.model.id.in_(docids)).dicts())
return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
@classmethod
@DB.connection_context()
@@ -755,9 +762,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_doc_count(cls, tenant_id):
docs = cls.model.select(cls.model.id).join(Knowledgebase,
on=(Knowledgebase.id == cls.model.kb_id)).where(
Knowledgebase.tenant_id == tenant_id)
docs = cls.model.select(cls.model.id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(Knowledgebase.tenant_id == tenant_id)
return len(docs)
@classmethod
@@ -768,7 +773,7 @@ class DocumentService(CommonService):
"process_begin_at": get_format_time(),
}
if not keep_progress:
info["progress"] = random.random() * 1 / 100.
info["progress"] = random.random() * 1 / 100.0
info["run"] = TaskStatus.RUNNING.value
# keep the doc in DONE state when keep_progress=True for GraphRAG, RAPTOR and Mindmap tasks
@@ -781,19 +786,17 @@ class DocumentService(CommonService):
cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def update_progress_immediately(cls, docs:list[dict]):
def update_progress_immediately(cls, docs: list[dict]):
if not docs:
return
cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def _sync_progress(cls, docs:list[dict]):
def _sync_progress(cls, docs: list[dict]):
from api.db.services.task_service import TaskService
for d in docs:
@@ -841,27 +844,18 @@ class DocumentService(CommonService):
# fallback
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
info = {
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
"run": status}
info = {"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), "run": status}
if prg != 0 and not freeze_progress:
info["progress"] = prg
if msg:
info["progress_msg"] = msg
if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
info["progress_msg"] += "\n%d tasks are ahead in the queue..." % get_queue_length(priority)
else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
info["progress_msg"] = "%d tasks are ahead in the queue..." % get_queue_length(priority)
info["update_time"] = current_timestamp()
info["update_date"] = get_format_time()
(
cls.model.update(info)
.where(
(cls.model.id == d["id"])
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
)
.execute()
)
(cls.model.update(info).where((cls.model.id == d["id"]) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))).execute())
except Exception as e:
if str(e).find("'0'") < 0:
logging.exception("fetch task exception")
@@ -875,7 +869,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_all_kb_doc_count(cls):
result = {}
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias("count")).group_by(cls.model.kb_id)
for row in rows:
result[row.kb_id] = row.count
return result
@@ -890,33 +884,19 @@ class DocumentService(CommonService):
pass
return False
@classmethod
@DB.connection_context()
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
# cancelled: run == "2"
cancelled = (
cls.model.select(fn.COUNT(1))
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
.scalar()
)
downloaded = (
cls.model.select(fn.COUNT(1))
.where(
cls.model.kb_id == kb_id,
cls.model.source_type != "local"
)
.scalar()
)
cancelled = cls.model.select(fn.COUNT(1)).where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL)).scalar()
downloaded = cls.model.select(fn.COUNT(1)).where(cls.model.kb_id == kb_id, cls.model.source_type != "local").scalar()
row = (
cls.model.select(
# finished: progress == 1
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == 1, 1)], 0)), 0).alias("finished"),
# failed: progress == -1
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == -1, 1)], 0)), 0).alias("failed"),
# processing: 0 <= progress < 1
fn.COALESCE(
fn.SUM(
@@ -931,24 +911,15 @@ class DocumentService(CommonService):
0,
).alias("processing"),
)
.where(
(cls.model.kb_id == kb_id)
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL))
)
.where((cls.model.kb_id == kb_id) & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL)))
.dicts()
.get()
)
return {
"processing": int(row["processing"]),
"finished": int(row["finished"]),
"failed": int(row["failed"]),
"cancelled": int(cancelled),
"downloaded": int(downloaded)
}
return {"processing": int(row["processing"]), "finished": int(row["finished"]), "failed": int(row["failed"]), "cancelled": int(cancelled), "downloaded": int(downloaded)}
@classmethod
def run(cls, tenant_id:str, doc:dict, kb_table_num_map:dict):
def run(cls, tenant_id: str, doc: dict, kb_table_num_map: dict):
from api.db.services.task_service import queue_dataflow, queue_tasks
from api.db.services.file2document_service import File2DocumentService
@@ -990,7 +961,7 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d
"from_page": 100000000,
"to_page": 100000000,
"task_type": ty,
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
"begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
@@ -1032,8 +1003,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
e, dia = DialogService.get_by_id(conv.dialog_id)
if not dia.kb_ids:
raise LookupError("No dataset associated with this conversation. "
"Please add a dataset before uploading documents")
raise LookupError("No dataset associated with this conversation. Please add a dataset before uploading documents")
kb_id = dia.kb_ids[0]
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
@@ -1050,12 +1020,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
def dummy(prog=None, msg=""):
pass
FACTORY = {
ParserType.PRESENTATION.value: presentation,
ParserType.PICTURE.value: picture,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
@@ -1063,22 +1028,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
for d, blob in files:
doc_nm[d["id"]] = d["name"]
for d, blob in files:
kwargs = {
"callback": dummy,
"parser_config": parser_config,
"from_page": 0,
"to_page": 100000,
"tenant_id": kb.tenant_id,
"lang": kb.language
}
kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
for (docinfo, _), th in zip(files, threads):
docs = []
doc = {
"doc_id": docinfo["id"],
"kb_id": [kb.id]
}
doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
for ck in th.result():
d = deepcopy(doc)
d.update(ck)
@@ -1093,7 +1048,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if isinstance(d["image"], bytes):
output_buffer = BytesIO(d["image"])
else:
d["image"].save(output_buffer, format='JPEG')
d["image"].save(output_buffer, format="JPEG")
settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["id"])
@@ -1110,9 +1065,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
nonlocal embd_mdl, chunk_counts, token_counts
vectors = []
for i in range(0, len(cnts), batch_size):
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
vts, c = embd_mdl.encode(cnts[i : i + batch_size])
vectors.extend(vts.tolist())
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
chunk_counts[doc_id] += len(cnts[i : i + batch_size])
token_counts[doc_id] += c
return vectors
@@ -1127,22 +1082,25 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if parser_ids[doc_id] != ParserType.PICTURE.value:
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)
cks.append({
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map"
})
cks.append(
{
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map",
}
)
except Exception:
logging.exception("Mind map generation error")
@@ -1156,9 +1114,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
return [d["id"] for d, _ in files]

View File

@@ -34,7 +34,9 @@ from werkzeug.exceptions import BadRequest, UnsupportedMediaType
from api.constants import DATASET_NAME_LIMIT
async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
async def validate_and_parse_json_request(
request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False
) -> tuple[dict[str, Any] | None, str | None]:
"""
Validates and parses JSON requests through a multi-stage validation pipeline.
@@ -742,4 +744,5 @@ class BaseListReq(BaseModel):
return validate_uuid1_hex(v)
class ListDatasetReq(BaseListReq): ...
class ListDatasetReq(BaseListReq):
include_parsing_status: Annotated[bool, Field(default=False)]