diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 730d63c66c..b8551c2a96 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -13,38 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import logging -import random -import re - -from common.metadata_utils import turn2jsonschema -from quart import request -import numpy as np - -from api.db.services.connector_service import Connector2KbService -from api.db.services.llm_service import LLMBundle -from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks -from api.db.services.doc_metadata_service import DocMetadataService -from api.db.services.pipeline_operation_log_service import PipelineOperationLogService -from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID -from api.db.services.user_service import UserTenantService -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_model_config_by_id -from api.utils.api_utils import ( - get_error_data_result, - server_error_response, - get_data_error_result, - validate_request, - get_request_json, -) -from api.db import VALID_FILE_TYPES -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.utils.api_utils import get_json_result -from rag.nlp import search -from rag.utils.redis_conn import REDIS_CONN -from common.constants import RetCode, PipelineTaskType, VALID_TASK_STATUS, LLMType -from common import settings -from common.doc_store.doc_store_base import OrderByExpr -from api.apps import login_required, current_user """ Deprecated, todo delete @@ -182,52 +150,6 @@ async def update(): return server_error_response(e) """ -@manager.route('/update_metadata_setting', methods=['post']) # noqa: F821 -@login_required -@validate_request("kb_id", "metadata") -async def update_metadata_setting(): - req = await get_request_json() - e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) - if not e: - return get_data_error_result( - message="Database error (Knowledgebase rename)!") - kb = kb.to_dict() - kb["parser_config"]["metadata"] = req["metadata"] - kb["parser_config"]["enable_metadata"] = req.get("enable_metadata", True) - KnowledgebaseService.update_by_id(kb["id"], kb) - return get_json_result(data=kb) - - -@manager.route('/detail', methods=['GET']) # noqa: F821 -@login_required -def detail(): - kb_id = request.args["kb_id"] - try: - tenants = UserTenantService.query(user_id=current_user.id) - for tenant in tenants: - if KnowledgebaseService.query( - tenant_id=tenant.tenant_id, id=kb_id): - break - else: - return get_json_result( - data=False, message='Only owner of dataset authorized for this operation.', - code=RetCode.OPERATING_ERROR) - kb = KnowledgebaseService.get_detail(kb_id) - if not kb: - return get_data_error_result( - message="Can't find this dataset!") - kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[]) - kb["connectors"] = Connector2KbService.list_connectors(kb_id) - if kb["parser_config"].get("metadata"): - kb["parser_config"]["metadata"] = turn2jsonschema(kb["parser_config"]["metadata"]) - - for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]: - if finish_at := kb.get(key): - kb[key] = finish_at.strftime("%Y-%m-%d %H:%M:%S") - return get_json_result(data=kb) - except Exception as e: - return server_error_response(e) - """ Deprecated, todo delete @manager.route('/list', methods=['POST']) # noqa: F821 @@ -326,80 +248,6 @@ async def rm(): return server_error_response(e) """ -@manager.route('//tags', methods=['GET']) # noqa: F821 -@login_required -def list_tags(kb_id): - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - - tenants = UserTenantService.get_tenants_by_user_id(current_user.id) - tags = [] - for tenant in tenants: - tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id]) - return get_json_result(data=tags) - - -@manager.route('/tags', methods=['GET']) # noqa: F821 -@login_required -def list_tags_from_kbs(): - kb_ids = request.args.get("kb_ids", "").split(",") - for kb_id in kb_ids: - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - - tenants = UserTenantService.get_tenants_by_user_id(current_user.id) - tags = [] - for tenant in tenants: - tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids) - return get_json_result(data=tags) - - -@manager.route('//rm_tags', methods=['POST']) # noqa: F821 -@login_required -async def rm_tags(kb_id): - req = await get_request_json() - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - e, kb = KnowledgebaseService.get_by_id(kb_id) - - for t in req["tags"]: - settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]}, - {"remove": {"tag_kwd": t}}, - search.index_name(kb.tenant_id), - kb_id) - return get_json_result(data=True) - - -@manager.route('//rename_tag', methods=['POST']) # noqa: F821 -@login_required -async def rename_tags(kb_id): - req = await get_request_json() - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - e, kb = KnowledgebaseService.get_by_id(kb_id) - - settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]}, - {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}}, - search.index_name(kb.tenant_id), - kb_id) - return get_json_result(data=True) - """ Deprecated, todo delete @manager.route('//knowledge_graph', methods=['GET']) # noqa: F821 @@ -457,143 +305,6 @@ def delete_knowledge_graph(kb_id): return get_json_result(data=True) """ -@manager.route("/get_meta", methods=["GET"]) # noqa: F821 -@login_required -def get_meta(): - kb_ids = request.args.get("kb_ids", "").split(",") - for kb_id in kb_ids: - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - return get_json_result(data=DocMetadataService.get_flatted_meta_by_kbs(kb_ids)) - - -@manager.route("/basic_info", methods=["GET"]) # noqa: F821 -@login_required -def get_basic_info(): - kb_id = request.args.get("kb_id", "") - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - - basic_info = DocumentService.knowledgebase_basic_info(kb_id) - - return get_json_result(data=basic_info) - - -@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821 -@login_required -async def list_pipeline_logs(): - kb_id = request.args.get("kb_id") - if not kb_id: - return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - - keywords = request.args.get("keywords", "") - - page_number = int(request.args.get("page", 0)) - items_per_page = int(request.args.get("page_size", 0)) - orderby = request.args.get("orderby", "create_time") - if request.args.get("desc", "true").lower() == "false": - desc = False - else: - desc = True - create_date_from = request.args.get("create_date_from", "") - create_date_to = request.args.get("create_date_to", "") - if create_date_to > create_date_from: - return get_data_error_result(message="Create data filter is abnormal.") - - req = await get_request_json() - - operation_status = req.get("operation_status", []) - if operation_status: - invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS} - if invalid_status: - return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}") - - types = req.get("types", []) - if types: - invalid_types = {t for t in types if t not in VALID_FILE_TYPES} - if invalid_types: - return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}") - - suffix = req.get("suffix", []) - - try: - logs, count = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to) - return get_json_result(data={"total": count, "logs": logs}) - except Exception as e: - return server_error_response(e) - - -@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821 -@login_required -async def list_pipeline_dataset_logs(): - kb_id = request.args.get("kb_id") - if not kb_id: - return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - - page_number = int(request.args.get("page", 0)) - items_per_page = int(request.args.get("page_size", 0)) - orderby = request.args.get("orderby", "create_time") - if request.args.get("desc", "true").lower() == "false": - desc = False - else: - desc = True - create_date_from = request.args.get("create_date_from", "") - create_date_to = request.args.get("create_date_to", "") - if create_date_to > create_date_from: - return get_data_error_result(message="Create data filter is abnormal.") - - req = await get_request_json() - - operation_status = req.get("operation_status", []) - if operation_status: - invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS} - if invalid_status: - return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}") - - try: - logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to) - return get_json_result(data={"total": tol, "logs": logs}) - except Exception as e: - return server_error_response(e) - - -@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821 -@login_required -async def delete_pipeline_logs(): - kb_id = request.args.get("kb_id") - if not kb_id: - return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - - req = await get_request_json() - log_ids = req.get("log_ids", []) - - PipelineOperationLogService.delete_by_ids(log_ids) - - return get_json_result(data=True) - - -@manager.route("/pipeline_log_detail", methods=["GET"]) # noqa: F821 -@login_required -def pipeline_log_detail(): - log_id = request.args.get("log_id") - if not log_id: - return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=RetCode.ARGUMENT_ERROR) - - ok, log = PipelineOperationLogService.get_by_id(log_id) - if not ok: - return get_data_error_result(message="Invalid pipeline log ID") - - return get_json_result(data=log.to_dict()) - - """ Deprecated, todo delete @manager.route("/run_graphrag", methods=["POST"]) # noqa: F821 @@ -733,280 +444,3 @@ def trace_raptor(): return get_json_result(data=task.to_dict()) """ - -@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821 -@login_required -async def run_mindmap(): - req = await get_request_json() - - kb_id = req.get("kb_id", "") - if not kb_id: - return get_error_data_result(message='Lack of "KB ID"') - - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: - return get_error_data_result(message="Invalid Knowledgebase ID") - - task_id = kb.mindmap_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid Mindmap task id is expected for kb {kb_id}") - - if task and task.progress not in [-1, 1]: - return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Mindmap Task is already running.") - - documents, _ = DocumentService.get_by_kb_id( - kb_id=kb_id, - page_number=0, - items_per_page=0, - orderby="create_time", - desc=False, - keywords="", - run_status=[], - types=[], - suffix=[], - ) - if not documents: - return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}") - - sample_document = documents[0] - document_ids = [document["id"] for document in documents] - - task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - - if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}): - logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}") - - return get_json_result(data={"mindmap_task_id": task_id}) - - -@manager.route("/trace_mindmap", methods=["GET"]) # noqa: F821 -@login_required -def trace_mindmap(): - kb_id = request.args.get("kb_id", "") - if not kb_id: - return get_error_data_result(message='Lack of "KB ID"') - - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: - return get_error_data_result(message="Invalid Knowledgebase ID") - - task_id = kb.mindmap_task_id - if not task_id: - return get_json_result(data={}) - - ok, task = TaskService.get_by_id(task_id) - if not ok: - return get_error_data_result(message="Mindmap Task Not Found or Error Occurred") - - return get_json_result(data=task.to_dict()) - - -@manager.route("/unbind_task", methods=["DELETE"]) # noqa: F821 -@login_required -def delete_kb_task(): - kb_id = request.args.get("kb_id", "") - if not kb_id: - return get_error_data_result(message='Lack of "KB ID"') - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: - return get_json_result(data=True) - - pipeline_task_type = request.args.get("pipeline_task_type", "") - if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]: - return get_error_data_result(message="Invalid task type") - - def cancel_task(task_id): - REDIS_CONN.set(f"{task_id}-cancel", "x") - - kb_task_id_field: str = "" - kb_task_finish_at: str = "" - match pipeline_task_type: - case PipelineTaskType.GRAPH_RAG: - kb_task_id_field = "graphrag_task_id" - task_id = kb.graphrag_task_id - kb_task_finish_at = "graphrag_task_finish_at" - cancel_task(task_id) - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id) - case PipelineTaskType.RAPTOR: - kb_task_id_field = "raptor_task_id" - task_id = kb.raptor_task_id - kb_task_finish_at = "raptor_task_finish_at" - cancel_task(task_id) - settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id) - case PipelineTaskType.MINDMAP: - kb_task_id_field = "mindmap_task_id" - task_id = kb.mindmap_task_id - kb_task_finish_at = "mindmap_task_finish_at" - cancel_task(task_id) - case _: - return get_error_data_result(message="Internal Error: Invalid task type") - - - ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id_field: "", kb_task_finish_at: None}) - if not ok: - return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}") - - return get_json_result(data=True) - -@manager.route("/check_embedding", methods=["post"]) # noqa: F821 -@login_required -async def check_embedding(): - - def _guess_vec_field(src: dict) -> str | None: - for k in src or {}: - if k.endswith("_vec"): - return k - return None - - def _as_float_vec(v): - if v is None: - return [] - if isinstance(v, str): - return [float(x) for x in v.split("\t") if x != ""] - if isinstance(v, (list, tuple, np.ndarray)): - return [float(x) for x in v] - return [] - - def _to_1d(x): - a = np.asarray(x, dtype=np.float32) - return a.reshape(-1) - - def _cos_sim(a, b, eps=1e-12): - a = _to_1d(a) - b = _to_1d(b) - na = np.linalg.norm(a) - nb = np.linalg.norm(b) - if na < eps or nb < eps: - return 0.0 - return float(np.dot(a, b) / (na * nb)) - - def sample_random_chunks_with_vectors( - docStoreConn, - tenant_id: str, - kb_id: str, - n: int = 5, - base_fields=("docnm_kwd","doc_id","content_with_weight","page_num_int","position_int","top_int"), - ): - index_nm = search.index_name(tenant_id) - - res0 = docStoreConn.search( - select_fields=[], highlight_fields=[], - condition={"kb_id": kb_id, "available_int": 1}, - match_expressions=[], order_by=OrderByExpr(), - offset=0, limit=1, - index_names=index_nm, knowledgebase_ids=[kb_id] - ) - total = docStoreConn.get_total(res0) - if total <= 0: - return [] - - n = min(n, total) - offsets = sorted(random.sample(range(min(total,1000)), n)) - out = [] - - for off in offsets: - res1 = docStoreConn.search( - select_fields=list(base_fields), - highlight_fields=[], - condition={"kb_id": kb_id, "available_int": 1}, - match_expressions=[], order_by=OrderByExpr(), - offset=off, limit=1, - index_names=index_nm, knowledgebase_ids=[kb_id] - ) - ids = docStoreConn.get_doc_ids(res1) - if not ids: - continue - - cid = ids[0] - full_doc = docStoreConn.get(cid, index_nm, [kb_id]) or {} - vec_field = _guess_vec_field(full_doc) - vec = _as_float_vec(full_doc.get(vec_field)) - - out.append({ - "chunk_id": cid, - "kb_id": kb_id, - "doc_id": full_doc.get("doc_id"), - "doc_name": full_doc.get("docnm_kwd"), - "vector_field": vec_field, - "vector_dim": len(vec), - "vector": vec, - "page_num_int": full_doc.get("page_num_int"), - "position_int": full_doc.get("position_int"), - "top_int": full_doc.get("top_int"), - "content_with_weight": full_doc.get("content_with_weight") or "", - "question_kwd": full_doc.get("question_kwd") or [] - }) - return out - - def _clean(s: str) -> str: - s = re.sub(r"]{0,12})?>", " ", s or "") - return s if s else "None" - req = await get_request_json() - kb_id = req.get("kb_id", "") - tenant_embd_id = req.get("tenant_embd_id") - embd_id = req.get("embd_id", "") - n = int(req.get("check_num", 5)) - _, kb = KnowledgebaseService.get_by_id(kb_id) - tenant_id = kb.tenant_id - if tenant_embd_id: - embd_model_config = get_model_config_by_id(tenant_embd_id) - elif embd_id: - embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id) - else: - return get_error_data_result("`tenant_embd_id` or `embd_id` is required.") - emb_mdl = LLMBundle(tenant_id, embd_model_config) - samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n) - - results, eff_sims = [], [] - for ck in samples: - title = ck.get("doc_name") or "Title" - txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or "" - txt_in = _clean(txt_in) - if not txt_in: - results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"}) - continue - - if not ck.get("vector"): - results.append({"chunk_id": ck["chunk_id"], "reason": "no_stored_vector"}) - continue - - try: - v, _ = emb_mdl.encode([title, txt_in]) - assert len(v[1]) == len(ck["vector"]), f"The dimension ({len(v[1])}) of given embedding model is different from the original ({len(ck['vector'])})" - sim_content = _cos_sim(v[1], ck["vector"]) - title_w = 0.1 - qv_mix = title_w * v[0] + (1 - title_w) * v[1] - sim_mix = _cos_sim(qv_mix, ck["vector"]) - sim = sim_content - mode = "content_only" - if sim_mix > sim: - sim = sim_mix - mode = "title+content" - except Exception as e: - return get_error_data_result(message=f"Embedding failure. {e}") - - eff_sims.append(sim) - results.append({ - "chunk_id": ck["chunk_id"], - "doc_id": ck["doc_id"], - "doc_name": ck["doc_name"], - "vector_field": ck["vector_field"], - "vector_dim": ck["vector_dim"], - "cos_sim": round(sim, 6), - }) - - summary = { - "kb_id": kb_id, - "model": embd_id, - "sampled": len(samples), - "valid": len(eff_sims), - "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6), - "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6), - "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6), - "match_mode": mode, - } - if summary["avg_cos_sim"] > 0.9: - return get_json_result(data={"summary": summary, "results": results}) - return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results}) diff --git a/api/apps/restful_apis/dataset_api.py b/api/apps/restful_apis/dataset_api.py index 4f3ff2d59a..8a7cd80371 100644 --- a/api/apps/restful_apis/dataset_api.py +++ b/api/apps/restful_apis/dataset_api.py @@ -31,6 +31,50 @@ from api.utils.validation_utils import ( from api.apps.services import dataset_api_service +@manager.route("/datasets/tags/aggregation", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def aggregate_tags(tenant_id): + dataset_ids = request.args.get("dataset_ids", "").split(",") + dataset_ids = [d for d in dataset_ids if d] + if not dataset_ids: + return get_error_data_result(message="Lack of dataset_ids in query parameters") + + try: + success, result = dataset_api_service.aggregate_tags(dataset_ids, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets/metadata/flattened", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_flattened_metadata(tenant_id): + dataset_ids = request.args.get("dataset_ids", "").split(",") + dataset_ids = [d for d in dataset_ids if d] + if not dataset_ids: + return get_error_data_result(message="Lack of dataset_ids in query parameters") + + try: + success, result = dataset_api_service.get_flattened_metadata(dataset_ids, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + @manager.route("/datasets", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs @@ -102,6 +146,8 @@ async def create(tenant_id: str=None): return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") @@ -330,7 +376,107 @@ def list_datasets(tenant_id): return get_error_data_result(message="Internal server error") -@manager.route('/datasets//knowledge_graph', methods=['GET']) # noqa: F821 +@manager.route("/datasets/", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_dataset(tenant_id, dataset_id): + try: + success, result = dataset_api_service.get_dataset(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//ingestions/summary", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_ingestion_summary(tenant_id, dataset_id): + try: + success, result = dataset_api_service.get_ingestion_summary(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//tags", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def list_tags(tenant_id, dataset_id): + try: + success, result = dataset_api_service.list_tags(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//tags", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def delete_tags(tenant_id, dataset_id): + req = await request.get_json() + if not req or "tags" not in req: + return get_error_data_result(message="Lack of tags in request body") + if not isinstance(req["tags"], list) or not all(isinstance(t, str) for t in req["tags"]): + return get_error_argument_result("tags must be a list of strings") + + try: + success, result = dataset_api_service.delete_tags(dataset_id, tenant_id, req["tags"]) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//tags", methods=["PUT"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def rename_tag(tenant_id, dataset_id): + req = await request.get_json() + if not req or "from_tag" not in req or "to_tag" not in req: + return get_error_data_result(message="Lack of from_tag or to_tag in request body") + if not isinstance(req["from_tag"], str) or not isinstance(req["to_tag"], str): + return get_error_argument_result("from_tag and to_tag must be strings") + + if not req["from_tag"].strip() or not req["to_tag"].strip(): + return get_error_argument_result("from_tag and to_tag must not be empty") + + try: + success, result = dataset_api_service.rename_tag(dataset_id, tenant_id, req["from_tag"], req["to_tag"]) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route('/datasets//graph/search', methods=['GET']) # noqa: F821 @login_required @add_tenant_id_to_kwargs async def knowledge_graph(tenant_id, dataset_id): @@ -349,7 +495,7 @@ async def knowledge_graph(tenant_id, dataset_id): return get_error_data_result(message="Internal server error") -@manager.route('/datasets//knowledge_graph', methods=['DELETE']) # noqa: F821 +@manager.route('/datasets//graph', methods=['DELETE']) # noqa: F821 @login_required @add_tenant_id_to_kwargs def delete_knowledge_graph(tenant_id, dataset_id): @@ -368,12 +514,67 @@ def delete_knowledge_graph(tenant_id, dataset_id): return get_error_data_result(message="Internal server error") -@manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821 +@manager.route("/datasets//index", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -async def run_graphrag(tenant_id, dataset_id): +async def run_index(tenant_id, dataset_id): + index_type = request.args.get("type", "") try: - success, result = dataset_api_service.run_graphrag(dataset_id, tenant_id) + success, result = dataset_api_service.run_index(dataset_id, tenant_id, index_type) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//index", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def trace_index(tenant_id, dataset_id): + index_type = request.args.get("type", "") + try: + success, result = dataset_api_service.trace_index(dataset_id, tenant_id, index_type) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def delete_index(tenant_id, dataset_id, index_type): + if index_type not in dataset_api_service._VALID_INDEX_TYPES: + return get_error_argument_result(f"Invalid index type '{index_type}'") + try: + success, result = dataset_api_service.delete_index(dataset_id, tenant_id, index_type) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//embedding", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def run_embedding(tenant_id, dataset_id): + try: + success, result = dataset_api_service.run_embedding(dataset_id, tenant_id) if success: return get_result(data=result) else: @@ -383,52 +584,50 @@ async def run_graphrag(tenant_id, dataset_id): return get_error_data_result(message="Internal server error") -@manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821 +@manager.route("/datasets//ingestions", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -def trace_graphrag(tenant_id, dataset_id): +def list_ingestion_logs(tenant_id, dataset_id): try: - success, result = dataset_api_service.trace_graphrag(dataset_id, tenant_id) + page = int(request.args.get("page", 0)) + page_size = int(request.args.get("page_size", 0)) + orderby = request.args.get("orderby", "create_time") + desc = request.args.get("desc", "true").lower() != "false" + operation_status = request.args.getlist("operation_status") + create_date_from = request.args.get("create_date_from", None) + create_date_to = request.args.get("create_date_to", None) + success, result = dataset_api_service.list_ingestion_logs( + dataset_id, tenant_id, page, page_size, orderby, desc, operation_status, create_date_from, create_date_to + ) if success: return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") -@manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821 +@manager.route("/datasets//ingestions/", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -async def run_raptor(tenant_id, dataset_id): +def get_ingestion_log(tenant_id, dataset_id, log_id): try: - success, result = dataset_api_service.run_raptor(dataset_id, tenant_id) + success, result = dataset_api_service.get_ingestion_log(dataset_id, tenant_id, log_id) if success: return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") -@manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -def trace_raptor(tenant_id, dataset_id): - try: - success, result = dataset_api_service.trace_raptor(dataset_id, tenant_id) - if success: - return get_result(data=result) - else: - return get_error_data_result(message=result) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route("/datasets//auto_metadata", methods=["GET"]) # noqa: F821 +@manager.route("/datasets//metadata/config", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs def get_auto_metadata(tenant_id, dataset_id): @@ -462,12 +661,14 @@ def get_auto_metadata(tenant_id, dataset_id): return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") -@manager.route("/datasets//auto_metadata", methods=["PUT"]) # noqa: F821 +@manager.route("/datasets//metadata/config", methods=["PUT"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs async def update_auto_metadata(tenant_id, dataset_id): @@ -512,6 +713,8 @@ async def update_auto_metadata(tenant_id, dataset_id): return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") diff --git a/api/apps/restful_apis/document_api.py b/api/apps/restful_apis/document_api.py index 220ed2c624..8098dbec8c 100644 --- a/api/apps/restful_apis/document_api.py +++ b/api/apps/restful_apis/document_api.py @@ -26,18 +26,22 @@ from api.apps.services.document_api_service import validate_document_update_fiel from api.constants import IMG_BASE64_PREFIX from api.db import VALID_FILE_TYPES from api.db.services.doc_metadata_service import DocMetadataService +from api.db.db_models import Task from api.db.services.document_service import DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.task_service import TaskService, cancel_all_task_of from api.common.check_team_permission import check_kb_team_permission from api.utils.api_utils import get_data_error_result, get_error_data_result, get_result, get_json_result, \ server_error_response, add_tenant_id_to_kwargs, get_request_json, get_error_argument_result, check_duplicate_ids from api.utils.validation_utils import ( UpdateDocumentReq, format_validation_error_message, validate_and_parse_json_request, DeleteDocumentReq, ) -from common.constants import RetCode +from common import settings +from common.constants import RetCode, TaskStatus from common.metadata_utils import convert_conditions, meta_filter, turn2jsonschema from common.misc_utils import thread_pool_exec +from rag.nlp import search @manager.route("/datasets//documents/", methods=["PATCH"]) # noqa: F821 @login_required @@ -192,6 +196,88 @@ async def metadata_summary(dataset_id, tenant_id): return server_error_response(e) +@manager.route("/datasets//metadata/update", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def metadata_batch_update(dataset_id, tenant_id): + """ + Batch update metadata for documents in a dataset. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + selector: + type: object + updates: + type: array + deletes: + type: array + responses: + 200: + description: Metadata updated successfully. + """ + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") + + req = await get_request_json() + selector = req.get("selector", {}) or {} + updates = req.get("updates", []) or [] + deletes = req.get("deletes", []) or [] + + if not isinstance(selector, dict): + return get_error_data_result(message="selector must be an object.") + if not isinstance(updates, list) or not isinstance(deletes, list): + return get_error_data_result(message="updates and deletes must be lists.") + + metadata_condition = selector.get("metadata_condition", {}) or {} + if metadata_condition and not isinstance(metadata_condition, dict): + return get_error_data_result(message="metadata_condition must be an object.") + + document_ids = selector.get("document_ids", []) or [] + if document_ids and not isinstance(document_ids, list): + return get_error_data_result(message="document_ids must be a list.") + + for upd in updates: + if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd: + return get_error_data_result(message="Each update requires key and value.") + for d in deletes: + if not isinstance(d, dict) or not d.get("key"): + return get_error_data_result(message="Each delete requires key.") + + target_doc_ids = set() + if document_ids: + kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id]) + invalid_ids = set(document_ids) - set(kb_doc_ids) + if invalid_ids: + return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}") + target_doc_ids = set(document_ids) + + if metadata_condition: + metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id]) + filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) + target_doc_ids = target_doc_ids & filtered_ids + if metadata_condition.get("conditions") and not target_doc_ids: + return get_result(data={"updated": 0, "matched_docs": 0}) + + target_doc_ids = list(target_doc_ids) + updated = DocMetadataService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes) + return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)}) + + @manager.route("/datasets//documents", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs @@ -1019,3 +1105,217 @@ async def update_metadata(tenant_id, dataset_id): target_doc_ids = list(target_doc_ids) updated = DocMetadataService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes) return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)}) + + +@manager.route("/datasets//documents/parse", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def parse_documents(tenant_id, dataset_id): + """ + Start parsing documents in a dataset. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Document parse parameters. + required: true + schema: + type: object + properties: + document_ids: + type: array + items: + type: string + description: List of document IDs to parse. + responses: + 200: + description: Successful operation. + """ + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + + req = await get_request_json() + if req is None: + return get_error_data_result(message="Request body is required") + + document_ids = req.get("document_ids") + if document_ids is None or not isinstance(document_ids, list): + return get_error_data_result(message="`document_ids` is required") + if len(document_ids) == 0: + return get_error_data_result(message="`document_ids` is required") + + # Check for duplicate document IDs + unique_doc_ids, duplicate_messages = check_duplicate_ids(document_ids, "document") + errors = duplicate_messages if duplicate_messages else [] + + # Validate all document IDs belong to the dataset + not_found_ids = [] + valid_doc_ids = [] + for doc_id in unique_doc_ids: + docs = DocumentService.query(kb_id=dataset_id, id=doc_id) + if not docs: + not_found_ids.append(doc_id) + else: + valid_doc_ids.append(doc_id) + + if not_found_ids: + errors.append(f"Documents not found: {not_found_ids}") + # Still parse valid documents, but return error code + if not valid_doc_ids: + return get_error_data_result(message=f"Documents not found: {not_found_ids}") + + try: + def _run_sync(): + kb_table_num_map = {} + success_count = 0 + for doc_id in valid_doc_ids: + e, doc = DocumentService.get_by_id(doc_id) + if not e: + errors.append(f"Document not found: {doc_id}") + continue + + info = {"run": str(TaskStatus.RUNNING.value), "progress": 0} + # If re-running a completed document, clear previous chunks + if str(doc.run) == TaskStatus.DONE.value: + DocumentService.clear_chunk_num_when_rerun(doc.id) + info["progress_msg"] = "" + info["chunk_num"] = 0 + info["token_num"] = 0 + + DocumentService.update_by_id(doc_id, info) + TaskService.filter_delete([Task.doc_id == doc_id]) + if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.delete({"doc_id": doc_id}, search.index_name(tenant_id), doc.kb_id) + + doc_dict = doc.to_dict() + DocumentService.run(tenant_id, doc_dict, kb_table_num_map) + success_count += 1 + + result = {"success_count": success_count} + if errors: + result["errors"] = errors + return result + + result = await thread_pool_exec(_run_sync) + if not_found_ids: + return get_error_data_result(message=f"Documents not found: {not_found_ids}") + return get_result(data=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//documents/stop", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def stop_parse_documents(tenant_id, dataset_id): + """ + Stop parsing documents in a dataset. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Document stop parse parameters. + required: true + schema: + type: object + properties: + document_ids: + type: array + items: + type: string + description: List of document IDs to stop parsing. + responses: + 200: + description: Successful operation. + """ + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + + req = await get_request_json() + if req is None: + return get_error_data_result(message="Request body is required") + + document_ids = req.get("document_ids") + if document_ids is None or not isinstance(document_ids, list): + return get_error_data_result(message="`document_ids` is required") + if len(document_ids) == 0: + return get_error_data_result(message="`document_ids` is required") + + # Check for duplicate document IDs + unique_doc_ids, duplicate_messages = check_duplicate_ids(document_ids, "document") + errors = duplicate_messages if duplicate_messages else [] + + # Validate all document IDs belong to the dataset + not_found_ids = [] + valid_doc_ids = [] + for doc_id in unique_doc_ids: + docs = DocumentService.query(kb_id=dataset_id, id=doc_id) + if not docs: + not_found_ids.append(doc_id) + else: + valid_doc_ids.append(doc_id) + + if not_found_ids: + return get_error_data_result(message=f"Documents not found: {not_found_ids}") + + try: + def _run_sync(): + success_count = 0 + for doc_id in valid_doc_ids: + e, doc = DocumentService.get_by_id(doc_id) + if not e: + errors.append(f"Document not found: {doc_id}") + continue + + # Check if the document is currently running + tasks = list(TaskService.query(doc_id=doc_id)) + has_unfinished_task = any((task.progress or 0) < 1 for task in tasks) + if str(doc.run) not in [TaskStatus.RUNNING.value, TaskStatus.CANCEL.value] and not has_unfinished_task: + errors.append("Can't stop parsing document that has not started or already completed") + continue + + cancel_all_task_of(doc_id) + DocumentService.update_by_id(doc_id, {"run": str(TaskStatus.CANCEL.value)}) + success_count += 1 + + result = {"success_count": success_count} + if errors: + result["errors"] = errors + return result + + result = await thread_pool_exec(_run_sync) + if not_found_ids: + return get_error_data_result(message=f"Documents not found: {not_found_ids}") + return get_result(data=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py index 8cb718467a..509104e7e9 100644 --- a/api/apps/services/dataset_api_service.py +++ b/api/apps/services/dataset_api_service.py @@ -25,10 +25,30 @@ from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.connector_service import Connector2KbService from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService -from api.db.services.user_service import TenantService, UserService +from api.db.services.user_service import TenantService, UserService, UserTenantService from common.constants import FileSource, StatusEnum from api.utils.api_utils import deep_merge, get_parser_config, remap_dictionary_keys, verify_embedding_availability +_VALID_INDEX_TYPES = {"graph", "raptor", "mindmap"} + +_INDEX_TYPE_TO_TASK_TYPE = { + "graph": "graphrag", + "raptor": "raptor", + "mindmap": "mindmap", +} + +_INDEX_TYPE_TO_TASK_ID_FIELD = { + "graph": "graphrag_task_id", + "raptor": "raptor_task_id", + "mindmap": "mindmap_task_id", +} + +_INDEX_TYPE_TO_DISPLAY_NAME = { + "graph": "Graph", + "raptor": "RAPTOR", + "mindmap": "Mindmap", +} + async def create_dataset(tenant_id: str, req: dict): """ @@ -158,6 +178,55 @@ async def delete_datasets(tenant_id: str, ids: list = None, delete_all: bool = F return True, {"success_count": success_count, "errors": errors[:5]} +def get_dataset(dataset_id: str, tenant_id: str): + """ + Get a single dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, "Invalid Dataset ID" + + response_data = remap_dictionary_keys(kb.to_dict()) + return True, response_data + + +def get_ingestion_summary(dataset_id: str, tenant_id: str): + """ + Get ingestion summary for a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, "Invalid Dataset ID" + + status = DocumentService.get_parsing_status_by_kb_ids([dataset_id]).get(dataset_id, {}) + return True, { + "doc_num": kb.doc_num, + "chunk_num": kb.chunk_num, + "token_num": kb.token_num, + "status": status, + } + + async def update_dataset(tenant_id: str, dataset_id: str, req: dict): """ Update a dataset. @@ -404,14 +473,18 @@ def delete_knowledge_graph(dataset_id: str, tenant_id: str): return True, True -def run_graphrag(dataset_id: str, tenant_id: str): +def run_index(dataset_id: str, tenant_id: str, index_type: str): """ - Run GraphRAG for a dataset. + Run an indexing task (graph/raptor/mindmap) for a dataset. :param dataset_id: dataset ID :param tenant_id: tenant ID + :param index_type: one of "graph", "raptor", "mindmap" :return: (success, result) or (success, error_message) """ + if index_type not in _VALID_INDEX_TYPES: + return False, f"Invalid index type '{index_type}'. Must be one of {sorted(_VALID_INDEX_TYPES)}" + if not dataset_id: return False, 'Lack of "Dataset ID"' if not KnowledgebaseService.accessible(dataset_id, tenant_id): @@ -421,14 +494,18 @@ def run_graphrag(dataset_id: str, tenant_id: str): if not ok: return False, "Invalid Dataset ID" - task_id = kb.graphrag_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) + task_type = _INDEX_TYPE_TO_TASK_TYPE[index_type] + task_id_field = _INDEX_TYPE_TO_TASK_ID_FIELD[index_type] + display_name = _INDEX_TYPE_TO_DISPLAY_NAME[index_type] + + existing_task_id = getattr(kb, task_id_field, None) + if existing_task_id: + ok, task = TaskService.get_by_id(existing_task_id) if not ok: - logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}") + logging.warning(f"A valid {display_name} task id is expected for Dataset {dataset_id}") if task and task.progress not in [-1, 1]: - return False, f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running." + return False, f"Task {existing_task_id} in progress with status {task.progress}. A {display_name} Task is already running." documents, _ = DocumentService.get_by_kb_id( kb_id=dataset_id, @@ -447,24 +524,29 @@ def run_graphrag(dataset_id: str, tenant_id: str): sample_document = documents[0] document_ids = [document["id"] for document in documents] - task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) + task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty=task_type, priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}): - logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}") + if not KnowledgebaseService.update_by_id(kb.id, {task_id_field: task_id}): + logging.warning(f"Cannot save {task_id_field} for Dataset {dataset_id}") - return True, {"graphrag_task_id": task_id} + return True, {"task_id": task_id} -def trace_graphrag(dataset_id: str, tenant_id: str): +def trace_index(dataset_id: str, tenant_id: str, index_type: str): """ - Trace GraphRAG task for a dataset. + Trace an indexing task (graph/raptor/mindmap) for a dataset. :param dataset_id: dataset ID :param tenant_id: tenant ID + :param index_type: one of "graph", "raptor", "mindmap" :return: (success, result) or (success, error_message) """ + if index_type not in _VALID_INDEX_TYPES: + return False, f"Invalid index type '{index_type}'. Must be one of {sorted(_VALID_INDEX_TYPES)}" + if not dataset_id: return False, 'Lack of "Dataset ID"' + if not KnowledgebaseService.accessible(dataset_id, tenant_id): return False, "No authorization." @@ -472,7 +554,8 @@ def trace_graphrag(dataset_id: str, tenant_id: str): if not ok: return False, "Invalid Dataset ID" - task_id = kb.graphrag_task_id + task_id_field = _INDEX_TYPE_TO_TASK_ID_FIELD[index_type] + task_id = getattr(kb, task_id_field, None) if not task_id: return True, {} @@ -483,9 +566,9 @@ def trace_graphrag(dataset_id: str, tenant_id: str): return True, task.to_dict() -def run_raptor(dataset_id: str, tenant_id: str): +def list_tags(dataset_id: str, tenant_id: str): """ - Run RAPTOR for a dataset. + List tags for a dataset. :param dataset_id: dataset ID :param tenant_id: tenant ID @@ -493,74 +576,65 @@ def run_raptor(dataset_id: str, tenant_id: str): """ if not dataset_id: return False, 'Lack of "Dataset ID"' + if not KnowledgebaseService.accessible(dataset_id, tenant_id): return False, "No authorization." - ok, kb = KnowledgebaseService.get_by_id(dataset_id) - if not ok: - return False, "Invalid Dataset ID" + tenants = UserTenantService.get_tenants_by_user_id(tenant_id) + tags = [] + for tenant in tenants: + tags += settings.retriever.all_tags(tenant["tenant_id"], [dataset_id]) + return True, tags - task_id = kb.raptor_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) + +def aggregate_tags(dataset_ids: list[str], tenant_id: str): + """ + Aggregate tags across multiple datasets. + + :param dataset_ids: list of dataset IDs + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + if not dataset_ids: + return False, 'Lack of "dataset_ids"' + + for dataset_id in dataset_ids: + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, f"No authorization for dataset '{dataset_id}'" + + dataset_ids_by_tenant = {} + for dataset_id in dataset_ids: + ok, kb = KnowledgebaseService.get_by_id(dataset_id) if not ok: - logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}") + return False, f"Invalid Dataset ID '{dataset_id}'" + dataset_ids_by_tenant.setdefault(kb.tenant_id, []).append(dataset_id) - if task and task.progress not in [-1, 1]: - return False, f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running." + merged = {} + for kb_tenant_id, kb_ids in dataset_ids_by_tenant.items(): + for bucket in settings.retriever.all_tags(kb_tenant_id, kb_ids): + tag = bucket["value"] + merged[tag] = merged.get(tag, 0) + bucket["count"] - documents, _ = DocumentService.get_by_kb_id( - kb_id=dataset_id, - page_number=0, - items_per_page=0, - orderby="create_time", - desc=False, - keywords="", - run_status=[], - types=[], - suffix=[], - ) - if not documents: - return False, f"No documents in Dataset {dataset_id}" - - sample_document = documents[0] - document_ids = [document["id"] for document in documents] - - task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - - if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}): - logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}") - - return True, {"raptor_task_id": task_id} + return True, [{"value": tag, "count": count} for tag, count in merged.items()] -def trace_raptor(dataset_id: str, tenant_id: str): +def get_flattened_metadata(dataset_ids: list[str], tenant_id: str): """ - Trace RAPTOR task for a dataset. + Get flattened metadata for datasets. - :param dataset_id: dataset ID + :param dataset_ids: list of dataset IDs :param tenant_id: tenant ID :return: (success, result) or (success, error_message) """ - if not dataset_id: - return False, 'Lack of "Dataset ID"' + if not dataset_ids: + return False, 'Lack of "dataset_ids"' - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return False, "No authorization." + for dataset_id in dataset_ids: + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, f"No authorization for dataset '{dataset_id}'" - ok, kb = KnowledgebaseService.get_by_id(dataset_id) - if not ok: - return False, "Invalid Dataset ID" - - task_id = kb.raptor_task_id - if not task_id: - return True, {} - - ok, task = TaskService.get_by_id(task_id) - if not ok: - return False, "RAPTOR Task Not Found or Error Occurred" - - return True, task.to_dict() + from api.db.services.doc_metadata_service import DocMetadataService + return True, DocMetadataService.get_flatted_meta_by_kbs(dataset_ids) def get_auto_metadata(dataset_id: str, tenant_id: str): @@ -627,3 +701,202 @@ async def update_auto_metadata(dataset_id: str, tenant_id: str, cfg: dict): return False, "Update auto-metadata error.(Database error)" return True, {"enabled": parser_cfg["enable_metadata"], "fields": fields} + + +def delete_tags(dataset_id: str, tenant_id: str, tags: list[str]): + """ + Delete tags from a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param tags: list of tags to delete + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, "Invalid Dataset ID" + + from rag.nlp import search + for t in tags: + settings.docStoreConn.update({"tag_kwd": t, "kb_id": [dataset_id]}, + {"remove": {"tag_kwd": t}}, + search.index_name(kb.tenant_id), + dataset_id) + + return True, {} + +def list_ingestion_logs(dataset_id: str, tenant_id: str, page: int, page_size: int, orderby: str, desc: bool, operation_status: list = None, create_date_from: str = None, create_date_to: str = None): + """ + List ingestion logs for a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param page: page number + :param page_size: items per page + :param orderby: order by field + :param desc: descending order + :param operation_status: filter by operation status + :param create_date_from: filter start date + :param create_date_to: filter end date + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." + + from api.db.services.pipeline_operation_log_service import PipelineOperationLogService + logs, total = PipelineOperationLogService.get_dataset_logs_by_kb_id( + dataset_id, page, page_size, orderby, desc, operation_status or [], create_date_from, create_date_to + ) + return True, {"total": total, "logs": logs} + + +def get_ingestion_log(dataset_id: str, tenant_id: str, log_id: str): + """ + Get a single ingestion log. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param log_id: log ID + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." + + from api.db.services.pipeline_operation_log_service import PipelineOperationLogService + fields = PipelineOperationLogService.get_dataset_logs_fields() + log = PipelineOperationLogService.model.select(*fields).where( + (PipelineOperationLogService.model.id == log_id) & (PipelineOperationLogService.model.kb_id == dataset_id) + ).first() + if not log: + return False, "Log not found" + + return True, log.to_dict() + + +def delete_index(dataset_id: str, tenant_id: str, index_type: str): + """ + Delete an indexing task (graph/raptor/mindmap) for a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param index_type: one of "graph", "raptor", "mindmap" + :return: (success, result) or (success, error_message) + """ + if index_type not in _VALID_INDEX_TYPES: + return False, f"Invalid index type '{index_type}'. Must be one of {sorted(_VALID_INDEX_TYPES)}" + + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, "Invalid Dataset ID" + + task_id_field = _INDEX_TYPE_TO_TASK_ID_FIELD[index_type] + task_finish_at_field = f"{task_id_field.replace('_task_id', '_task_finish_at')}" + task_id = getattr(kb, task_id_field, None) + + if task_id: + from rag.utils.redis_conn import REDIS_CONN + try: + REDIS_CONN.set(f"{task_id}-cancel", "x") + except Exception as e: + logging.exception(e) + TaskService.delete_by_id(task_id) + + if index_type == "graph": + from rag.nlp import search + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, + search.index_name(kb.tenant_id), dataset_id) + elif index_type == "raptor": + from rag.nlp import search + settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, + search.index_name(kb.tenant_id), dataset_id) + + KnowledgebaseService.update_by_id(kb.id, {task_id_field: "", task_finish_at_field: None}) + return True, {} + + +def run_embedding(dataset_id: str, tenant_id: str): + """ + Run embedding for all documents in a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, "Invalid Dataset ID" + + documents, _ = DocumentService.get_by_kb_id( + kb_id=dataset_id, + page_number=0, + items_per_page=0, + orderby="create_time", + desc=False, + keywords="", + run_status=[], + types=[], + suffix=[], + ) + if not documents: + return False, f"No documents in Dataset {dataset_id}" + + kb_table_num_map = {} + for doc in documents: + doc["tenant_id"] = tenant_id + DocumentService.run(tenant_id, doc, kb_table_num_map) + + return True, {"scheduled_count": len(documents)} + + +def rename_tag(dataset_id: str, tenant_id: str, from_tag: str, to_tag: str): + """ + Rename a tag in a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param from_tag: original tag name + :param to_tag: new tag name + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, "Invalid Dataset ID" + + from rag.nlp import search + settings.docStoreConn.update({"tag_kwd": from_tag, "kb_id": [dataset_id]}, + {"remove": {"tag_kwd": from_tag.strip()}, "add": {"tag_kwd": to_tag}}, + search.index_name(kb.tenant_id), + dataset_id) + + return True, {"from": from_tag, "to": to_tag} + diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py index 7a9e435e07..2e4b93056b 100644 --- a/api/db/services/doc_metadata_service.py +++ b/api/db/services/doc_metadata_service.py @@ -454,19 +454,27 @@ class DocMetadataService: # Index exists - check if document exists try: doc_exists = settings.docStoreConn.get( - index_name=index_name, - id=doc_id, - kb_id=kb_id + doc_id, + index_name, + [kb_id] ) if doc_exists: - # Document exists - use partial update + # Document exists - replace meta_fields entirely + # Use upsert to fully replace the meta_fields field + # (ES update with doc parameter does deep merge on object fields, + # which would retain old keys that should be removed) settings.docStoreConn.es.update( index=index_name, id=doc_id, refresh=True, - doc={"meta_fields": processed_meta} + body={ + "script": { + "source": "ctx._source.meta_fields = params.meta_fields", + "params": {"meta_fields": processed_meta} + } + } ) - logging.debug(f"Successfully updated metadata for document {doc_id} using ES partial update") + logging.debug(f"Successfully updated metadata for document {doc_id} using ES script update") return True except Exception as e: logging.debug(f"Document {doc_id} not found in index, will insert: {e}") diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index b464fe70de..fd65e6116f 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -165,7 +165,7 @@ class DataSet(Base): """ Retrieve auto-metadata configuration for a dataset via SDK. """ - res = self.get(f"/datasets/{self.id}/auto_metadata") + res = self.get(f"/datasets/{self.id}/metadata/config") res = res.json() if res.get("code") == 0: return res["data"] @@ -175,7 +175,7 @@ class DataSet(Base): """ Update auto-metadata configuration for a dataset via SDK. """ - res = self.put(f"/datasets/{self.id}/auto_metadata", config) + res = self.put(f"/datasets/{self.id}/metadata/config", config) res = res.json() if res.get("code") == 0: return res["data"] diff --git a/sdk/python/test/test_frontend_api/common.py b/sdk/python/test/test_frontend_api/common.py index e054bba8f3..7e09041eb5 100644 --- a/sdk/python/test/test_frontend_api/common.py +++ b/sdk/python/test/test_frontend_api/common.py @@ -19,38 +19,33 @@ import os import requests HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") +API_VERSION = "v1" +DATASETS_API_URL = f"/api/{API_VERSION}/datasets" DATASET_NAME_LIMIT = 128 -def create_dataset(auth, dataset_name): - authorization = {"Authorization": auth} - url = f"{HOST_ADDRESS}/v1/kb/create" - json = {"name": dataset_name} - res = requests.post(url=url, headers=authorization, json=json) +def create_dataset(auth, payload=None): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}" + res = requests.post(url=url, headers={"Content-Type": "application/json"}, auth=auth, json=payload) return res.json() -def list_dataset(auth, page_number, page_size=30): - authorization = {"Authorization": auth} - url = f"{HOST_ADDRESS}/v1/kb/list?page={page_number}&page_size={page_size}" - json = {} - res = requests.post(url=url, headers=authorization, json=json) +def list_dataset(auth, params=None): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}" + res = requests.get(url=url, headers={"Content-Type": "application/json"}, auth=auth, params=params) return res.json() -def rm_dataset(auth, dataset_id): - authorization = {"Authorization": auth} - url = f"{HOST_ADDRESS}/v1/kb/rm" - json = {"kb_id": dataset_id} - res = requests.post(url=url, headers=authorization, json=json) +def rm_dataset(auth, dataset_ids): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}" + res = requests.delete(url=url, headers={"Content-Type": "application/json"}, auth=auth, json={"ids": dataset_ids}) return res.json() -def update_dataset(auth, json_req): - authorization = {"Authorization": auth} - url = f"{HOST_ADDRESS}/v1/kb/update" - res = requests.post(url=url, headers=authorization, json=json_req) +def update_dataset(auth, dataset_id, payload=None): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}" + res = requests.put(url=url, headers={"Content-Type": "application/json"}, auth=auth, json=payload) return res.json() diff --git a/sdk/python/test/test_frontend_api/test_chunk.py b/sdk/python/test/test_frontend_api/test_chunk.py index fadeb10ee2..b1f7ff1bd1 100644 --- a/sdk/python/test/test_frontend_api/test_chunk.py +++ b/sdk/python/test/test_frontend_api/test_chunk.py @@ -21,7 +21,7 @@ from timeit import default_timer as timer def test_parse_txt_document(get_auth): # create dataset - res = create_dataset(get_auth, "test_parse_txt_document") + res = create_dataset(get_auth, {"name": "test_parse_txt_document"}) assert res.get("code") == 0, f"{res.get('message')}" # list dataset @@ -29,8 +29,10 @@ def test_parse_txt_document(get_auth): dataset_list = [] dataset_id = None while True: - res = list_dataset(get_auth, page_number) - data = res.get("data").get("kbs") + res = list_dataset(get_auth, {"page": page_number, "page_size": 150}) + data = res.get("data") + if isinstance(data, dict): + data = data.get("kbs", []) for item in data: dataset_id = item.get("id") dataset_list.append(dataset_id) @@ -66,7 +68,7 @@ def test_parse_txt_document(get_auth): print('time cost {:.1f}s'.format(timer() - start_ts)) # delete dataset - for dataset_id in dataset_list: - res = rm_dataset(get_auth, dataset_id) + if dataset_list: + res = rm_dataset(get_auth, dataset_list) assert res.get("code") == 0, f"{res.get('message')}" print(f"{len(dataset_list)} datasets are deleted") diff --git a/sdk/python/test/test_frontend_api/test_dataset.py b/sdk/python/test/test_frontend_api/test_dataset.py index b00f343648..bfbc02da2d 100644 --- a/sdk/python/test/test_frontend_api/test_dataset.py +++ b/sdk/python/test/test_frontend_api/test_dataset.py @@ -22,15 +22,17 @@ import string def test_dataset(get_auth): # create dataset - res = create_dataset(get_auth, "test_create_dataset") + res = create_dataset(get_auth, {"name": "test_create_dataset"}) assert res.get("code") == 0, f"{res.get('message')}" # list dataset page_number = 1 dataset_list = [] while True: - res = list_dataset(get_auth, page_number) - data = res.get("data").get("kbs") + res = list_dataset(get_auth, {"page": page_number, "page_size": 150}) + data = res.get("data") + if isinstance(data, dict): + data = data.get("kbs", []) for item in data: dataset_id = item.get("id") dataset_list.append(dataset_id) @@ -40,8 +42,8 @@ def test_dataset(get_auth): print(f"found {len(dataset_list)} datasets") # delete dataset - for dataset_id in dataset_list: - res = rm_dataset(get_auth, dataset_id) + if dataset_list: + res = rm_dataset(get_auth, dataset_list) assert res.get("code") == 0, f"{res.get('message')}" print(f"{len(dataset_list)} datasets are deleted") @@ -49,15 +51,17 @@ def test_dataset(get_auth): def test_dataset_1k_dataset(get_auth): # create dataset for i in range(1000): - res = create_dataset(get_auth, f"test_create_dataset_{i}") + res = create_dataset(get_auth, {"name": f"test_create_dataset_{i}"}) assert res.get("code") == 0, f"{res.get('message')}" # list dataset page_number = 1 dataset_list = [] while True: - res = list_dataset(get_auth, page_number) - data = res.get("data").get("kbs") + res = list_dataset(get_auth, {"page": page_number, "page_size": 150}) + data = res.get("data") + if isinstance(data, dict): + data = data.get("kbs", []) for item in data: dataset_id = item.get("id") dataset_list.append(dataset_id) @@ -67,8 +71,8 @@ def test_dataset_1k_dataset(get_auth): print(f"found {len(dataset_list)} datasets") # delete dataset - for dataset_id in dataset_list: - res = rm_dataset(get_auth, dataset_id) + if dataset_list: + res = rm_dataset(get_auth, dataset_list) assert res.get("code") == 0, f"{res.get('message')}" print(f"{len(dataset_list)} datasets are deleted") @@ -76,12 +80,14 @@ def test_dataset_1k_dataset(get_auth): def test_duplicated_name_dataset(get_auth): # create dataset for i in range(20): - res = create_dataset(get_auth, "test_create_dataset") + res = create_dataset(get_auth, {"name": "test_create_dataset"}) assert res.get("code") == 0, f"{res.get('message')}" # list dataset - res = list_dataset(get_auth, 1) - data = res.get("data").get("kbs") + res = list_dataset(get_auth, {"page": 1}) + data = res.get("data") + if isinstance(data, dict): + data = data.get("kbs", []) dataset_list = [] pattern = r'^test_create_dataset.*' for item in data: @@ -91,19 +97,18 @@ def test_duplicated_name_dataset(get_auth): match = re.match(pattern, dataset_name) assert match is not None - for dataset_id in dataset_list: - res = rm_dataset(get_auth, dataset_id) + if dataset_list: + res = rm_dataset(get_auth, dataset_list) assert res.get("code") == 0, f"{res.get('message')}" print(f"{len(dataset_list)} datasets are deleted") def test_invalid_name_dataset(get_auth): # create dataset - # with pytest.raises(Exception) as e: - res = create_dataset(get_auth, 0) + res = create_dataset(get_auth, {"name": 0}) assert res['code'] != 0 - res = create_dataset(get_auth, "") + res = create_dataset(get_auth, {"name": ""}) assert res['code'] != 0 long_string = "" @@ -111,22 +116,24 @@ def test_invalid_name_dataset(get_auth): while len(long_string.encode("utf-8")) <= DATASET_NAME_LIMIT: long_string += random.choice(string.ascii_letters + string.digits) - res = create_dataset(get_auth, long_string) + res = create_dataset(get_auth, {"name": long_string}) assert res['code'] != 0 print(res) def test_update_different_params_dataset_success(get_auth): # create dataset - res = create_dataset(get_auth, "test_create_dataset") + res = create_dataset(get_auth, {"name": "test_create_dataset"}) assert res.get("code") == 0, f"{res.get('message')}" # list dataset page_number = 1 dataset_list = [] while True: - res = list_dataset(get_auth, page_number) - data = res.get("data").get("kbs") + res = list_dataset(get_auth, {"page": page_number, "page_size": 150}) + data = res.get("data") + if isinstance(data, dict): + data = data.get("kbs", []) for item in data: dataset_id = item.get("id") dataset_list.append(dataset_id) @@ -137,15 +144,18 @@ def test_update_different_params_dataset_success(get_auth): print(f"found {len(dataset_list)} datasets") dataset_id = dataset_list[0] - json_req = {"kb_id": dataset_id, "name": "test_update_dataset", "description": "test", "permission": "me", - "parser_id": "presentation", - "language": "spanish"} - res = update_dataset(get_auth, json_req) + res = update_dataset(get_auth, dataset_id, { + "name": "test_update_dataset", + "description": "test", + "permission": "me", + "chunk_method": "presentation", + "language": "spanish", + }) assert res.get("code") == 0, f"{res.get('message')}" # delete dataset - for dataset_id in dataset_list: - res = rm_dataset(get_auth, dataset_id) + if dataset_list: + res = rm_dataset(get_auth, dataset_list) assert res.get("code") == 0, f"{res.get('message')}" print(f"{len(dataset_list)} datasets are deleted") @@ -153,15 +163,17 @@ def test_update_different_params_dataset_success(get_auth): # update dataset with different parameters def test_update_different_params_dataset_fail(get_auth): # create dataset - res = create_dataset(get_auth, "test_create_dataset") + res = create_dataset(get_auth, {"name": "test_create_dataset"}) assert res.get("code") == 0, f"{res.get('message')}" # list dataset page_number = 1 dataset_list = [] while True: - res = list_dataset(get_auth, page_number) - data = res.get("data").get("kbs") + res = list_dataset(get_auth, {"page": page_number, "page_size": 150}) + data = res.get("data") + if isinstance(data, dict): + data = data.get("kbs", []) for item in data: dataset_id = item.get("id") dataset_list.append(dataset_id) @@ -172,12 +184,11 @@ def test_update_different_params_dataset_fail(get_auth): print(f"found {len(dataset_list)} datasets") dataset_id = dataset_list[0] - json_req = {"kb_id": dataset_id, "id": "xxx"} - res = update_dataset(get_auth, json_req) + res = update_dataset(get_auth, dataset_id, {"id": "xxx"}) assert res.get("code") == 101 # delete dataset - for dataset_id in dataset_list: - res = rm_dataset(get_auth, dataset_id) + if dataset_list: + res = rm_dataset(get_auth, dataset_list) assert res.get("code") == 0, f"{res.get('message')}" print(f"{len(dataset_list)} datasets are deleted") diff --git a/test/playwright/conftest.py b/test/playwright/conftest.py index e73445129f..6b62636193 100644 --- a/test/playwright/conftest.py +++ b/test/playwright/conftest.py @@ -1189,9 +1189,9 @@ def _ensure_dataset_ready_via_api( base_url: str, auth_header: str, dataset_name: str ) -> dict: headers = {"Authorization": auth_header} - list_url = _build_url(base_url, "/v1/kb/list?page=1&page_size=200") + list_url = _build_url(base_url, "/api/v1/datasets?page=1&page_size=200") - _, list_payload = _api_request_json(list_url, method="POST", payload={}, headers=headers) + _, list_payload = _api_request_json(list_url, method="GET", headers=headers) existing = _find_dataset_by_name(list_payload, dataset_name) if existing: return { @@ -1201,7 +1201,7 @@ def _ensure_dataset_ready_via_api( } _, create_payload = _api_request_json( - _build_url(base_url, "/v1/kb/create"), + _build_url(base_url, "/api/v1/datasets"), method="POST", payload={"name": dataset_name}, headers=headers, @@ -1212,12 +1212,12 @@ def _ensure_dataset_ready_via_api( return {"kb_id": kb_id, "kb_name": dataset_name, "reused": False} _, list_payload_after = _api_request_json( - list_url, method="POST", payload={}, headers=headers + list_url, method="GET", headers=headers ) existing_after = _find_dataset_by_name(list_payload_after, dataset_name) if not existing_after: raise RuntimeError( - f"Dataset {dataset_name!r} not found after kb/create response={create_payload}" + f"Dataset {dataset_name!r} not found after /api/v1/datasets create response={create_payload}" ) return { "kb_id": existing_after.get("id"), diff --git a/test/playwright/e2e/test_dataset_upload_parse.py b/test/playwright/e2e/test_dataset_upload_parse.py index 437e4858f0..9e918714b2 100644 --- a/test/playwright/e2e/test_dataset_upload_parse.py +++ b/test/playwright/e2e/test_dataset_upload_parse.py @@ -203,7 +203,7 @@ def get_request_json_payload(response) -> dict: payload = None if not isinstance(payload, dict): - raise AssertionError(f"Expected JSON object payload for /v1/kb/update, got={payload!r}") + raise AssertionError(f"Expected JSON object payload for /api/v1/datasets update, got={payload!r}") return payload @@ -334,7 +334,7 @@ def step_03_create_dataset( create_response = capture_response( page, trigger, - lambda resp: resp.request.method == "POST" and "/v1/kb/create" in resp.url, + lambda resp: resp.request.method == "POST" and "/api/v1/datasets" in resp.url, timeout_ms=RESULT_TIMEOUT_MS * 2, ) try: @@ -540,23 +540,20 @@ def step_04_set_dataset_settings( response = capture_response( page, trigger, - lambda resp: resp.request.method == "POST" and "/v1/kb/update" in resp.url, + lambda resp: resp.request.method == "PUT" and f"/api/v1/datasets/{dataset_id}" in resp.url, timeout_ms=RESULT_TIMEOUT_MS * 2, ) - assert 200 <= response.status < 400, f"Unexpected /v1/kb/update status={response.status}" + assert 200 <= response.status < 400, f"Unexpected /api/v1/datasets update status={response.status}" response_payload = response.json() if isinstance(response_payload, dict): assert response_payload.get("code") == 0, ( - f"/v1/kb/update response code={response_payload.get('code')} " + f"/api/v1/datasets update response code={response_payload.get('code')} " f"message={response_payload.get('message')}" ) payload = get_request_json_payload(response) - assert payload.get("kb_id") == dataset_id, ( - f"Expected kb_id={dataset_id!r}, got {payload.get('kb_id')!r}" - ) for key in ("name", "language", "parser_config"): - assert key in payload, f"Expected key {key!r} in /v1/kb/update payload" + assert key in payload, f"Expected key {key!r} in /api/v1/datasets update payload" parser_config = payload.get("parser_config") or {} assert ( parser_config.get("image_table_context_window") diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py index 0fbdcb7c32..bcfcf5541a 100644 --- a/test/testcases/test_http_api/common.py +++ b/test/testcases/test_http_api/common.py @@ -23,7 +23,8 @@ from utils.file_utils import create_txt_file HEADERS = {"Content-Type": "application/json"} DATASETS_API_URL = f"/api/{VERSION}/datasets" FILE_API_URL = f"/api/{VERSION}/datasets/{{dataset_id}}/documents" -FILE_CHUNK_API_URL = f"/api/{VERSION}/datasets/{{dataset_id}}/chunks" +FILE_PARSE_API_URL = f"/api/{VERSION}/datasets/{{dataset_id}}/documents/parse" +FILE_STOP_PARSE_API_URL = f"/api/{VERSION}/datasets/{{dataset_id}}/documents/stop" CHUNK_API_URL = f"/api/{VERSION}/datasets/{{dataset_id}}/documents/{{document_id}}/chunks" CHAT_ASSISTANT_API_URL = f"/api/{VERSION}/chats" SESSION_WITH_CHAT_ASSISTANT_API_URL = f"/api/{VERSION}/chats/{{chat_id}}/sessions" @@ -136,15 +137,15 @@ def delete_all_documents(auth, dataset_id, *, page_size=1000): return delete_documents(auth, dataset_id, {"ids": None, "delete_all": True}) -def parse_documents(auth, dataset_id, payload=None): - url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id) - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) +def parse_documents(auth, dataset_id, payload=None, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{FILE_PARSE_API_URL}".format(dataset_id=dataset_id) + res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() def stop_parse_documents(auth, dataset_id, payload=None): - url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id) - res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) + url = f"{HOST_ADDRESS}{FILE_STOP_PARSE_API_URL}".format(dataset_id=dataset_id) + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() @@ -161,9 +162,9 @@ def bulk_upload_documents(auth, dataset_id, num, tmp_path): # CHUNK MANAGEMENT WITHIN DATASET -def add_chunk(auth, dataset_id, document_id, payload=None): +def add_chunk(auth, dataset_id, document_id, payload=None, *, headers=HEADERS): url = f"{HOST_ADDRESS}{CHUNK_API_URL}".format(dataset_id=dataset_id, document_id=document_id) - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() @@ -195,9 +196,9 @@ def delete_all_chunks(auth, dataset_id, document_id, *, page_size=1000): return delete_chunks(auth, dataset_id, document_id, {"chunk_ids": None, "delete_all": True}) -def retrieval_chunks(auth, payload=None): +def retrieval_chunks(auth, payload=None, *, headers=HEADERS): url = f"{HOST_ADDRESS}{RETRIEVAL_API_URL}" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() @@ -210,9 +211,9 @@ def batch_add_chunks(auth, dataset_id, document_id, num): # CHAT ASSISTANT MANAGEMENT -def create_chat_assistant(auth, payload=None): +def create_chat_assistant(auth, payload=None, *, headers=HEADERS): url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() @@ -259,9 +260,9 @@ def batch_create_chat_assistants(auth, num): # SESSION MANAGEMENT -def create_session_with_chat_assistant(auth, chat_assistant_id, payload=None): +def create_session_with_chat_assistant(auth, chat_assistant_id, payload=None, *, headers=HEADERS): url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}".format(chat_id=chat_assistant_id) - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() @@ -297,13 +298,13 @@ def batch_add_sessions_with_chat_assistant(auth, chat_assistant_id, num): # DATASET GRAPH AND TASKS def knowledge_graph(auth, dataset_id, params=None): - url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/knowledge_graph" + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/graph/search" res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) return res.json() def delete_knowledge_graph(auth, dataset_id, payload=None): - url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/knowledge_graph" + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/graph" if payload is None: res = requests.delete(url=url, headers=HEADERS, auth=auth) else: @@ -311,39 +312,15 @@ def delete_knowledge_graph(auth, dataset_id, payload=None): return res.json() -def run_graphrag(auth, dataset_id, payload=None): - url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/run_graphrag" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) - return res.json() - - -def trace_graphrag(auth, dataset_id, params=None): - url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/trace_graphrag" - res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) - return res.json() - - -def run_raptor(auth, dataset_id, payload=None): - url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/run_raptor" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) - return res.json() - - -def trace_raptor(auth, dataset_id, params=None): - url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/trace_raptor" - res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) - return res.json() - - def metadata_summary(auth, dataset_id, params=None): url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/metadata/summary" res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) return res.json() -def metadata_batch_update(auth, dataset_id, payload=None): +def metadata_batch_update(auth, dataset_id, payload=None, *, headers=HEADERS): url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/metadata/update" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() @@ -358,16 +335,16 @@ def update_documents_metadata(auth, dataset_id, payload=None): # CHAT COMPLETIONS AND RELATED QUESTIONS -def related_questions(auth, payload=None): +def related_questions(auth, payload=None, *, headers=HEADERS): url = f"{HOST_ADDRESS}/api/{VERSION}/sessions/related_questions" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() # AGENT MANAGEMENT AND SESSIONS -def create_agent(auth, payload=None): +def create_agent(auth, payload=None, *, headers=HEADERS): url = f"{HOST_ADDRESS}{AGENT_API_URL}" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() @@ -439,7 +416,7 @@ def chat_completions(auth, chat_id=None, payload=None): return res.json() -def chat_completions_openai(auth, chat_id, payload=None): +def chat_completions_openai(auth, chat_id, payload=None, *, headers=HEADERS): """ Send a request to the OpenAI-compatible chat completions endpoint. @@ -454,5 +431,88 @@ def chat_completions_openai(auth, chat_id, payload=None): Response JSON in OpenAI chat completions format with usage information """ url = f"{HOST_ADDRESS}/api/{VERSION}/chats_openai/{chat_id}/chat/completions" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + res = requests.post(url=url, headers=headers, auth=auth, json=payload) + return res.json() + + +# NEW DATASET ENDPOINTS +def get_dataset(auth, dataset_id, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}" + res = requests.get(url=url, headers=headers, auth=auth) + return res.json() + + +def get_ingestion_summary(auth, dataset_id, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/ingestions/summary" + res = requests.get(url=url, headers=headers, auth=auth) + return res.json() + + +def list_ingestion_logs(auth, dataset_id, params=None, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/ingestions" + res = requests.get(url=url, headers=headers, auth=auth, params=params) + return res.json() + + +def get_ingestion_log(auth, dataset_id, log_id, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/ingestions/{log_id}" + res = requests.get(url=url, headers=headers, auth=auth) + return res.json() + + +def run_index(auth, dataset_id, index_type, payload=None, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/index" + params = {"type": index_type} + res = requests.post(url=url, headers=headers, auth=auth, json=payload, params=params) + return res.json() + + +def trace_index(auth, dataset_id, index_type, params=None, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/index" + all_params = {"type": index_type} + if params: + all_params.update(params) + res = requests.get(url=url, headers=headers, auth=auth, params=all_params) + return res.json() + + +def delete_index(auth, dataset_id, index_type, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/{index_type}" + res = requests.delete(url=url, headers=headers, auth=auth) + return res.json() + + +def run_embedding(auth, dataset_id, payload=None, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/embedding" + res = requests.post(url=url, headers=headers, auth=auth, json=payload) + return res.json() + + +def list_tags(auth, dataset_id, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/tags" + res = requests.get(url=url, headers=headers, auth=auth) + return res.json() + + +def aggregate_tags(auth, dataset_ids, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/tags/aggregation" + res = requests.get(url=url, headers=headers, auth=auth, params={"dataset_ids": ",".join(dataset_ids)}) + return res.json() + + +def delete_tags(auth, dataset_id, tags, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/tags" + res = requests.delete(url=url, headers=headers, auth=auth, json={"tags": tags}) + return res.json() + + +def rename_tag(auth, dataset_id, from_tag, to_tag, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/tags" + res = requests.put(url=url, headers=headers, auth=auth, json={"from_tag": from_tag, "to_tag": to_tag}) + return res.json() + + +def get_flattened_metadata(auth, dataset_ids, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/metadata/flattened" + res = requests.get(url=url, headers=headers, auth=auth, params={"dataset_ids": ",".join(dataset_ids)}) return res.json() diff --git a/test/testcases/test_http_api/conftest.py b/test/testcases/test_http_api/conftest.py index d3c571a6f0..9fdb2803a1 100644 --- a/test/testcases/test_http_api/conftest.py +++ b/test/testcases/test_http_api/conftest.py @@ -43,7 +43,7 @@ from utils.file_utils import ( ) -@wait_for(30, 1, "Document parsing timeout") +@wait_for(200, 1, "Document parsing timeout") def condition(_auth, _dataset_id): res = list_documents(_auth, _dataset_id) for doc in res["data"]["docs"]: diff --git a/test/testcases/test_http_api/test_dataset_management/test_embedding.py b/test/testcases/test_http_api/test_dataset_management/test_embedding.py new file mode 100644 index 0000000000..6ee5593962 --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_management/test_embedding.py @@ -0,0 +1,32 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import run_embedding + + +@pytest.mark.usefixtures("clear_datasets") +class TestRunEmbedding: + @pytest.mark.p2 + def test_run_embedding_no_documents(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = run_embedding(HttpApiAuth, dataset_id) + assert res["code"] == 102, res + assert "No documents in Dataset" in res.get("message", ""), res + + @pytest.mark.p2 + def test_run_embedding_invalid_id(self, HttpApiAuth): + res = run_embedding(HttpApiAuth, "invalid_id") + assert res["code"] != 0, res diff --git a/test/testcases/test_http_api/test_dataset_management/test_flattened_metadata.py b/test/testcases/test_http_api/test_dataset_management/test_flattened_metadata.py new file mode 100644 index 0000000000..d67e66ce06 --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_management/test_flattened_metadata.py @@ -0,0 +1,42 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import get_flattened_metadata + + +@pytest.mark.usefixtures("clear_datasets") +class TestFlattenedMetadata: + @pytest.mark.p2 + def test_get_flattened_metadata_success(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = get_flattened_metadata(HttpApiAuth, [dataset_id]) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_get_flattened_metadata_multiple_datasets(self, HttpApiAuth, add_datasets_func): + dataset_ids = add_datasets_func + res = get_flattened_metadata(HttpApiAuth, dataset_ids) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_get_flattened_metadata_empty_ids(self, HttpApiAuth): + res = get_flattened_metadata(HttpApiAuth, []) + assert res["code"] != 0, res + + @pytest.mark.p2 + def test_get_flattened_metadata_invalid_id(self, HttpApiAuth): + res = get_flattened_metadata(HttpApiAuth, ["invalid_id"]) + assert res["code"] != 0, res diff --git a/test/testcases/test_http_api/test_dataset_management/test_get_dataset.py b/test/testcases/test_http_api/test_dataset_management/test_get_dataset.py new file mode 100644 index 0000000000..92df5ea679 --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_management/test_get_dataset.py @@ -0,0 +1,45 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import get_dataset +from libs.auth import RAGFlowHttpApiAuth +from configs import INVALID_API_TOKEN + + +@pytest.mark.usefixtures("clear_datasets") +class TestGetDataset: + @pytest.mark.p2 + def test_get_dataset_success(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = get_dataset(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + assert res["data"]["id"] == dataset_id, res + + @pytest.mark.p2 + def test_get_dataset_invalid_id(self, HttpApiAuth): + res = get_dataset(HttpApiAuth, "invalid_dataset_id") + assert res["code"] != 0, res + + @pytest.mark.p2 + def test_get_dataset_unauthorized(self, add_dataset_func): + dataset_id = add_dataset_func + res = get_dataset(RAGFlowHttpApiAuth(INVALID_API_TOKEN), dataset_id) + assert res["code"] != 0, res + + @pytest.mark.p2 + def test_get_dataset_nonexistent(self, HttpApiAuth): + res = get_dataset(HttpApiAuth, "0" * 32) + assert res["code"] != 0, res diff --git a/test/testcases/test_http_api/test_dataset_management/test_graphrag_tasks.py b/test/testcases/test_http_api/test_dataset_management/test_graphrag_tasks.py deleted file mode 100644 index a805be9a6d..0000000000 --- a/test/testcases/test_http_api/test_dataset_management/test_graphrag_tasks.py +++ /dev/null @@ -1,89 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import pytest -from common import bulk_upload_documents, list_documents, parse_documents, run_graphrag, trace_graphrag -from utils import wait_for - - -@wait_for(200, 1, "Document parsing timeout") -def _parse_done(auth, dataset_id, document_ids=None): - res = list_documents(auth, dataset_id) - target_docs = res["data"]["docs"] - if document_ids is None: - return all(doc.get("run") == "DONE" for doc in target_docs) - target_ids = set(document_ids) - for doc in target_docs: - if doc.get("id") in target_ids and doc.get("run") != "DONE": - return False - return True - - -class TestGraphRAGTasks: - @pytest.mark.p2 - def test_trace_graphrag_before_run(self, HttpApiAuth, add_dataset_func): - dataset_id = add_dataset_func - res = trace_graphrag(HttpApiAuth, dataset_id) - assert res["code"] == 0, res - assert res["data"] == {}, res - - @pytest.mark.p2 - def test_run_graphrag_no_documents(self, HttpApiAuth, add_dataset_func): - dataset_id = add_dataset_func - res = run_graphrag(HttpApiAuth, dataset_id) - assert res["code"] == 102, res - assert "No documents in Dataset" in res.get("message", ""), res - - @pytest.mark.p3 - def test_run_graphrag_returns_task_id(self, HttpApiAuth, add_dataset_func, tmp_path): - dataset_id = add_dataset_func - bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) - res = run_graphrag(HttpApiAuth, dataset_id) - assert res["code"] == 0, res - assert res["data"].get("graphrag_task_id"), res - - @pytest.mark.p3 - def test_trace_graphrag_until_complete(self, HttpApiAuth, add_dataset_func, tmp_path): - dataset_id = add_dataset_func - document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) - res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) - assert res["code"] == 0, res - _parse_done(HttpApiAuth, dataset_id, document_ids) - - res = run_graphrag(HttpApiAuth, dataset_id) - assert res["code"] == 0, res - - last_res = {} - - @wait_for(200, 1, "GraphRAG task timeout") - def condition(): - res = trace_graphrag(HttpApiAuth, dataset_id) - if res["code"] != 0: - return False - data = res.get("data") or {} - if not data: - return False - if data.get("task_type") != "graphrag": - return False - progress = data.get("progress") - if progress in (-1, 1, -1.0, 1.0): - last_res["res"] = res - return True - return False - - condition() - res = last_res["res"] - assert res["data"]["task_type"] == "graphrag", res - assert res["data"].get("progress") in (-1, 1, -1.0, 1.0), res diff --git a/test/testcases/test_http_api/test_dataset_management/test_index_api.py b/test/testcases/test_http_api/test_dataset_management/test_index_api.py new file mode 100644 index 0000000000..d97691223d --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_management/test_index_api.py @@ -0,0 +1,166 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import ( + bulk_upload_documents, + list_documents, + run_index, + trace_index, + delete_index, +) +from utils import wait_for + + +@wait_for(200, 1, "Document parsing timeout") +def _parse_done(auth, dataset_id, document_ids=None): + res = list_documents(auth, dataset_id) + if res.get("code") != 0: + return False + target_docs = res.get("data", {}).get("docs", []) + if not target_docs: + return False + if document_ids is None: + return all(doc.get("run") == "DONE" for doc in target_docs) + target_ids = set(document_ids) + seen_ids = set() + for doc in target_docs: + doc_id = doc.get("id") + if doc_id in target_ids: + seen_ids.add(doc_id) + if doc.get("run") != "DONE": + return False + return seen_ids == target_ids + + +@wait_for(60, 1, "Index task creation timeout") +def _index_task_created(auth, dataset_id, index_type): + res = trace_index(auth, dataset_id, index_type) + if res.get("code") != 0: + return False + return bool(res.get("data", {}).get("id")) + + +@pytest.mark.usefixtures("clear_datasets") +class TestRunIndex: + @pytest.mark.p2 + def test_run_index_graph(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = run_index(HttpApiAuth, dataset_id, "graph") + assert res["code"] == 0, res + assert res["data"].get("task_id"), res + + @pytest.mark.p2 + def test_run_index_raptor(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = run_index(HttpApiAuth, dataset_id, "raptor") + assert res["code"] == 0, res + assert res["data"].get("task_id"), res + + @pytest.mark.p2 + def test_run_index_mindmap(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = run_index(HttpApiAuth, dataset_id, "mindmap") + assert res["code"] == 0, res + assert res["data"].get("task_id"), res + + @pytest.mark.p2 + def test_run_index_invalid_type(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = run_index(HttpApiAuth, dataset_id, "invalid_type") + assert res["code"] != 0, res + + @pytest.mark.p2 + def test_run_index_no_documents(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = run_index(HttpApiAuth, dataset_id, "raptor") + assert res["code"] == 102, res + + +@pytest.mark.usefixtures("clear_datasets") +class TestDeleteIndex: + @pytest.mark.p2 + def test_delete_graph(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = delete_index(HttpApiAuth, dataset_id, "graph") + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_delete_raptor(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = delete_index(HttpApiAuth, dataset_id, "raptor") + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_delete_mindmap(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = delete_index(HttpApiAuth, dataset_id, "mindmap") + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_delete_invalid_type(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = delete_index(HttpApiAuth, dataset_id, "invalid_type") + assert res["code"] != 0, res + + +@pytest.mark.usefixtures("clear_datasets") +class TestTraceIndex: + @pytest.mark.p2 + def test_trace_index_graph(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = run_index(HttpApiAuth, dataset_id, "graph") + assert res["code"] == 0, res + _index_task_created(HttpApiAuth, dataset_id, "graph") + res = trace_index(HttpApiAuth, dataset_id, "graph") + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_trace_index_raptor(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = run_index(HttpApiAuth, dataset_id, "raptor") + assert res["code"] == 0, res + _index_task_created(HttpApiAuth, dataset_id, "raptor") + res = trace_index(HttpApiAuth, dataset_id, "raptor") + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_trace_index_mindmap(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = run_index(HttpApiAuth, dataset_id, "mindmap") + assert res["code"] == 0, res + _index_task_created(HttpApiAuth, dataset_id, "mindmap") + res = trace_index(HttpApiAuth, dataset_id, "mindmap") + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_trace_index_invalid_type(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = trace_index(HttpApiAuth, dataset_id, "invalid_type") + assert res["code"] != 0, res + + @pytest.mark.p2 + def test_trace_index_no_task(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = trace_index(HttpApiAuth, dataset_id, "graph") + assert res["code"] == 0, res + assert res["data"] == {} diff --git a/test/testcases/test_http_api/test_dataset_management/test_ingestion_logs.py b/test/testcases/test_http_api/test_dataset_management/test_ingestion_logs.py new file mode 100644 index 0000000000..f74f7855ba --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_management/test_ingestion_logs.py @@ -0,0 +1,53 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import list_ingestion_logs, get_ingestion_log + + +@pytest.mark.usefixtures("clear_datasets") +class TestListIngestionLogs: + @pytest.mark.p2 + def test_list_ingestion_logs_success(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = list_ingestion_logs(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + assert "total" in res["data"], res + assert "logs" in res["data"], res + + @pytest.mark.p2 + def test_list_ingestion_logs_with_pagination(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = list_ingestion_logs(HttpApiAuth, dataset_id, params={"page": 1, "page_size": 10}) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_list_ingestion_logs_invalid_id(self, HttpApiAuth): + res = list_ingestion_logs(HttpApiAuth, "invalid_id") + assert res["code"] != 0, res + + +@pytest.mark.usefixtures("clear_datasets") +class TestGetIngestionLog: + @pytest.mark.p2 + def test_get_ingestion_log_not_found(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = get_ingestion_log(HttpApiAuth, dataset_id, "nonexistent_log_id") + assert res["code"] != 0, res + + @pytest.mark.p2 + def test_get_ingestion_log_invalid_dataset(self, HttpApiAuth): + res = get_ingestion_log(HttpApiAuth, "invalid_id", "some_log_id") + assert res["code"] != 0, res diff --git a/test/testcases/test_http_api/test_dataset_management/test_ingestion_summary.py b/test/testcases/test_http_api/test_dataset_management/test_ingestion_summary.py new file mode 100644 index 0000000000..3dc8b7aee6 --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_management/test_ingestion_summary.py @@ -0,0 +1,35 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import get_ingestion_summary + + +@pytest.mark.usefixtures("clear_datasets") +class TestIngestionSummary: + @pytest.mark.p2 + def test_ingestion_summary_success(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = get_ingestion_summary(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + assert "doc_num" in res["data"], res + assert "chunk_num" in res["data"], res + assert "token_num" in res["data"], res + assert "status" in res["data"], res + + @pytest.mark.p2 + def test_ingestion_summary_invalid_id(self, HttpApiAuth): + res = get_ingestion_summary(HttpApiAuth, "invalid_id") + assert res["code"] != 0, res diff --git a/test/testcases/test_http_api/test_dataset_management/test_raptor_tasks.py b/test/testcases/test_http_api/test_dataset_management/test_raptor_tasks.py deleted file mode 100644 index 6358fc2660..0000000000 --- a/test/testcases/test_http_api/test_dataset_management/test_raptor_tasks.py +++ /dev/null @@ -1,89 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import pytest -from common import bulk_upload_documents, list_documents, parse_documents, run_raptor, trace_raptor -from utils import wait_for - - -@wait_for(200, 1, "Document parsing timeout") -def _parse_done(auth, dataset_id, document_ids=None): - res = list_documents(auth, dataset_id) - target_docs = res["data"]["docs"] - if document_ids is None: - return all(doc.get("run") == "DONE" for doc in target_docs) - target_ids = set(document_ids) - for doc in target_docs: - if doc.get("id") in target_ids and doc.get("run") != "DONE": - return False - return True - - -class TestRaptorTasks: - @pytest.mark.p2 - def test_trace_raptor_before_run(self, HttpApiAuth, add_dataset_func): - dataset_id = add_dataset_func - res = trace_raptor(HttpApiAuth, dataset_id) - assert res["code"] == 0, res - assert res["data"] == {}, res - - @pytest.mark.p2 - def test_run_raptor_no_documents(self, HttpApiAuth, add_dataset_func): - dataset_id = add_dataset_func - res = run_raptor(HttpApiAuth, dataset_id) - assert res["code"] == 102, res - assert "No documents in Dataset" in res.get("message", ""), res - - @pytest.mark.p3 - def test_run_raptor_returns_task_id(self, HttpApiAuth, add_dataset_func, tmp_path): - dataset_id = add_dataset_func - bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) - res = run_raptor(HttpApiAuth, dataset_id) - assert res["code"] == 0, res - assert res["data"].get("raptor_task_id"), res - - @pytest.mark.p3 - def test_trace_raptor_until_complete(self, HttpApiAuth, add_dataset_func, tmp_path): - dataset_id = add_dataset_func - document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) - res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) - assert res["code"] == 0, res - _parse_done(HttpApiAuth, dataset_id, document_ids) - - res = run_raptor(HttpApiAuth, dataset_id) - assert res["code"] == 0, res - - last_res = {} - - @wait_for(200, 1, "RAPTOR task timeout") - def condition(): - res = trace_raptor(HttpApiAuth, dataset_id) - if res["code"] != 0: - return False - data = res.get("data") or {} - if not data: - return False - if data.get("task_type") != "raptor": - return False - progress = data.get("progress") - if progress in (-1, 1, -1.0, 1.0): - last_res["res"] = res - return True - return False - - condition() - res = last_res["res"] - assert res["data"]["task_type"] == "raptor", res - assert res["data"].get("progress") in (-1, 1, -1.0, 1.0), res diff --git a/test/testcases/test_http_api/test_dataset_management/test_tags.py b/test/testcases/test_http_api/test_dataset_management/test_tags.py new file mode 100644 index 0000000000..9460cbe7c0 --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_management/test_tags.py @@ -0,0 +1,84 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import ( + list_tags, + aggregate_tags, + delete_tags, + rename_tag, +) + + +@pytest.mark.usefixtures("clear_datasets") +class TestListTags: + @pytest.mark.p2 + def test_list_tags_success(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = list_tags(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_list_tags_invalid_id(self, HttpApiAuth): + res = list_tags(HttpApiAuth, "invalid_id") + assert res["code"] != 0, res + + +@pytest.mark.usefixtures("clear_datasets") +class TestAggregateTags: + @pytest.mark.p2 + def test_aggregate_tags_success(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = aggregate_tags(HttpApiAuth, [dataset_id]) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_aggregate_tags_multiple_datasets(self, HttpApiAuth, add_datasets_func): + dataset_ids = add_datasets_func + res = aggregate_tags(HttpApiAuth, dataset_ids) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_aggregate_tags_empty_ids(self, HttpApiAuth): + res = aggregate_tags(HttpApiAuth, []) + assert res["code"] != 0, res + + +@pytest.mark.usefixtures("clear_datasets") +class TestDeleteTags: + @pytest.mark.p2 + def test_delete_tags_missing_body(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = delete_tags(HttpApiAuth, dataset_id, []) + assert res["code"] == 0, res + + @pytest.mark.p2 + def test_delete_tags_invalid_id(self, HttpApiAuth): + res = delete_tags(HttpApiAuth, "invalid_id", ["tag1"]) + assert res["code"] != 0, res + + +@pytest.mark.usefixtures("clear_datasets") +class TestRenameTag: + @pytest.mark.p2 + def test_rename_tag_empty_names(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = rename_tag(HttpApiAuth, dataset_id, "", "") + assert res["code"] != 0, res + + @pytest.mark.p2 + def test_rename_tag_invalid_id(self, HttpApiAuth): + res = rename_tag(HttpApiAuth, "invalid_id", "old_tag", "new_tag") + assert res["code"] != 0, res diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_retrieval.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_retrieval.py index adc6435dd5..9b0dd18cde 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_retrieval.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_retrieval.py @@ -27,11 +27,14 @@ from common import ( delete_datasets, list_documents, update_document, + upload_documents, + parse_documents, + retrieval_chunks, ) from utils import wait_for -@wait_for(30, 1, "Document parsing timeout") +@wait_for(120, 1, "Document parsing timeout") def _condition_parsing_complete(_auth, dataset_id): res = list_documents(_auth, dataset_id) if res["code"] != 0: @@ -39,7 +42,7 @@ def _condition_parsing_complete(_auth, dataset_id): for doc in res["data"]["docs"]: status = doc.get("run", "UNKNOWN") - if status == "FAILED": + if status in ("FAIL", "FAILED"): pytest.fail(f"Document parsing failed: {doc}") return False if status != "DONE": @@ -62,35 +65,17 @@ def add_dataset_with_metadata(HttpApiAuth): import requests from configs import HOST_ADDRESS, VERSION - metadata_config = { - "type": "object", - "properties": { - "character": { - "description": "Historical figure name", - "type": "string" - }, - "era": { - "description": "Historical era", - "type": "string" - }, - "achievements": { - "description": "Major achievements", - "type": "array", - "items": { - "type": "string" - } - } - } - } - - res = requests.post( - url=f"{HOST_ADDRESS}/{VERSION}/kb/update_metadata_setting", + res = requests.put( + url=f"{HOST_ADDRESS}/api/{VERSION}/datasets/{dataset_id}/metadata/config", headers={"Content-Type": "application/json"}, auth=HttpApiAuth, json={ - "kb_id": dataset_id, - "metadata": metadata_config, - "enable_metadata": False + "enabled": False, + "fields": [ + {"name": "character", "type": "string", "description": "Historical figure name"}, + {"name": "era", "type": "string", "description": "Historical era"}, + {"name": "achievements", "type": "list", "description": "Major achievements"}, + ] } ).json() @@ -112,8 +97,6 @@ class TestMetadataWithRetrieval: Verifies that chunks are only retrieved from documents matching the metadata condition. """ - from common import upload_documents, parse_documents, retrieval_chunks - dataset_id = add_dataset_with_metadata # Create two documents with different metadata diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_summary.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_summary.py index 4c231277b1..bd2ca9beda 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_summary.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_summary.py @@ -28,16 +28,12 @@ def _summary_to_counts(summary): class TestMetadataSummary: @pytest.mark.p2 - def test_metadata_summary_missing_kb_id(self, HttpApiAuth, add_document_func): + def test_metadata_summary_nonexistent_kb_id(self, HttpApiAuth, add_document_func): """ Call with non-existent dataset - :param HttpApiAuth: - :param add_document_func: - :return: """ - res = metadata_summary(HttpApiAuth, "") - assert res["code"] == 404, res - assert res["message"] == "Not Found: /api/v1/datasets//metadata/summary", res + res = metadata_summary(HttpApiAuth, "0" * 32) + assert res["code"] == 102, res @pytest.mark.p2 def test_metadata_summary_invalid_kb_id(self, HttpApiAuth, add_document_func): diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py index 755d87cce7..5b9e5ad314 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py @@ -58,11 +58,11 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "`Authorization` can't be empty"), + (None, 401, ""), ( RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 109, - "Authentication error: API key is invalid!", + 401, + "", ), ], ) @@ -101,7 +101,7 @@ class TestDocumentsParse: @pytest.mark.parametrize( "dataset_id, expected_code, expected_message", [ - ("", 100, ""), + ("", 102, "You don't own the dataset ."), ( "invalid_dataset_id", 102, diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py index a79e1c6d18..ab2a251560 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py @@ -48,11 +48,11 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "`Authorization` can't be empty"), + (None, 401, ""), ( RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 109, - "Authentication error: API key is invalid!", + 401, + "", ), ], ) @@ -105,7 +105,7 @@ class TestDocumentsParseStop: @pytest.mark.parametrize( "invalid_dataset_id, expected_code, expected_message", [ - ("", 100, ""), + ("", 102, "You don't own the dataset ."), ( "invalid_dataset_id", 102, diff --git a/test/testcases/test_sdk_api/conftest.py b/test/testcases/test_sdk_api/conftest.py index f4791306cc..511842fb9d 100644 --- a/test/testcases/test_sdk_api/conftest.py +++ b/test/testcases/test_sdk_api/conftest.py @@ -46,7 +46,7 @@ from utils.file_utils import ( ) -@wait_for(30, 1, "Document parsing timeout") +@wait_for(200, 1, "Document parsing timeout") def condition(_dataset: DataSet): documents = _dataset.list_documents(page_size=1000) for document in documents: diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py b/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py index c02065061a..4d1a419e68 100644 --- a/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py @@ -20,7 +20,7 @@ from ragflow_sdk import Chat, DataSet, Document, RAGFlow from utils import wait_for -@wait_for(30, 1, "Document parsing timeout") +@wait_for(200, 1, "Document parsing timeout") def condition(_dataset: DataSet): documents = _dataset.list_documents(page_size=1000) for document in documents: @@ -29,6 +29,17 @@ def condition(_dataset: DataSet): return True +def _ensure_parsed(dataset: DataSet, document: Document): + """Trigger parsing only if the document is not already done or in progress.""" + if document.run == "DONE": + return + try: + dataset.async_parse_documents([document.id]) + except Exception: + pass # Already being processed + condition(dataset) + + @pytest.fixture(scope="function") def add_chat_assistants_func(request: FixtureRequest, client: RAGFlow, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chat]]: def cleanup(): @@ -37,6 +48,5 @@ def add_chat_assistants_func(request: FixtureRequest, client: RAGFlow, add_docum request.addfinalizer(cleanup) dataset, document = add_document - dataset.async_parse_documents([document.id]) - condition(dataset) + _ensure_parsed(dataset, document) return dataset, document, batch_create_chat_assistants(client, 5) diff --git a/test/testcases/test_web_api/test_chunk_app/test_retrieval_chunks.py b/test/testcases/test_web_api/test_chunk_app/test_retrieval_chunks.py index 14857210f4..357cd477b4 100644 --- a/test/testcases/test_web_api/test_chunk_app/test_retrieval_chunks.py +++ b/test/testcases/test_web_api/test_chunk_app/test_retrieval_chunks.py @@ -194,14 +194,14 @@ class TestChunksRetrieval: 100, 4, "must be greater than 0", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), + marks=pytest.mark.skip(reason="Web API does not validate top_k"), ), pytest.param( {"top_k": -1}, 100, 4, "3014", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), + marks=pytest.mark.skip(reason="Web API does not validate top_k"), ), pytest.param( {"top_k": "a"}, diff --git a/test/testcases/test_web_api/test_common.py b/test/testcases/test_web_api/test_common.py index aa525c6edb..c0c84038be 100644 --- a/test/testcases/test_web_api/test_common.py +++ b/test/testcases/test_web_api/test_common.py @@ -25,7 +25,6 @@ from utils.file_utils import create_txt_file HEADERS = {"Content-Type": "application/json"} -KB_APP_URL = f"/{VERSION}/kb" DATASETS_URL = f"/api/{VERSION}/datasets" DOCUMENT_APP_URL = f"/{VERSION}/document" CHUNK_APP_URL = f"/{VERSION}/chunk" @@ -207,49 +206,41 @@ def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None): return res.json() -def detail_kb(auth, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/detail", headers=headers, auth=auth, params=params) +def detail_kb(auth, dataset_id, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}", headers=headers, auth=auth) return res.json() -def kb_get_meta(auth, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/get_meta", headers=headers, auth=auth, params=params) +def kb_get_meta(auth, dataset_ids, *, headers=HEADERS): + params = {"dataset_ids": dataset_ids} + res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}/metadata/flattened", headers=headers, auth=auth, params=params) return res.json() -def kb_basic_info(auth, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/basic_info", headers=headers, auth=auth, params=params) +def kb_basic_info(auth, dataset_id, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/ingestions/summary", headers=headers, auth=auth) return res.json() -def kb_update_metadata_setting(auth, payload=None, *, headers=HEADERS, data=None): - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/update_metadata_setting", headers=headers, auth=auth, json=payload, data=data) +def kb_update_metadata_setting(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): + res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/metadata/config", headers=headers, auth=auth, json=payload, data=data) return res.json() -def kb_list_pipeline_logs(auth, params=None, payload=None, *, headers=HEADERS, data=None): - if payload is None: - payload = {} - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/list_pipeline_logs", headers=headers, auth=auth, params=params, json=payload, data=data) +def kb_list_pipeline_logs(auth, dataset_id, params=None, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/ingestions" + res = requests.get(url=url, headers=headers, auth=auth, params=params) return res.json() -def kb_list_pipeline_dataset_logs(auth, params=None, payload=None, *, headers=HEADERS, data=None): - if payload is None: - payload = {} - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/list_pipeline_dataset_logs", headers=headers, auth=auth, params=params, json=payload, data=data) +def kb_list_pipeline_dataset_logs(auth, dataset_id, params=None, *, headers=HEADERS): + url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/ingestions" + res = requests.get(url=url, headers=headers, auth=auth, params=params) return res.json() -def kb_delete_pipeline_logs(auth, params=None, payload=None, *, headers=HEADERS, data=None): - if payload is None: - payload = {} - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/delete_pipeline_logs", headers=headers, auth=auth, params=params, json=payload, data=data) - return res.json() - - -def kb_pipeline_log_detail(auth, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/pipeline_log_detail", headers=headers, auth=auth, params=params) +def kb_pipeline_log_detail(auth, dataset_id, log_id, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/ingestions/{log_id}", headers=headers, auth=auth) return res.json() @@ -269,57 +260,24 @@ def delete_knowledge_graph(auth, dataset_id, payload=None): return res.json() -def run_graphrag(auth, dataset_id, payload=None): - url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/run_graphrag" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) +def list_tags_from_kbs(auth, dataset_ids, *, headers=HEADERS): + params = {"dataset_ids": dataset_ids} + res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}/tags/aggregation", headers=headers, auth=auth, params=params) return res.json() -def trace_graphrag(auth, dataset_id, params=None): - url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/trace_graphrag" - res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) - return res.json() - - -def run_raptor(auth, dataset_id, payload=None): - url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/run_raptor" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) - return res.json() - - -def trace_raptor(auth, dataset_id, params=None): - url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/trace_raptor" - res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) - return res.json() - - -def kb_run_mindmap(auth, payload=None, *, headers=HEADERS, data=None): - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/run_mindmap", headers=headers, auth=auth, json=payload, data=data) - return res.json() - - -def kb_trace_mindmap(auth, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/trace_mindmap", headers=headers, auth=auth, params=params) - return res.json() - - -def list_tags_from_kbs(auth, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/tags", headers=headers, auth=auth, params=params) - return res.json() - - -def list_tags(auth, dataset_id, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/tags", headers=headers, auth=auth, params=params) +def list_tags(auth, dataset_id, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/tags", headers=headers, auth=auth) return res.json() def rm_tags(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/rm_tags", headers=headers, auth=auth, json=payload, data=data) + res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/tags", headers=headers, auth=auth, json=payload, data=data) return res.json() def rename_tags(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/rename_tag", headers=headers, auth=auth, json=payload, data=data) + res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/tags", headers=headers, auth=auth, json=payload, data=data) return res.json() diff --git a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py index 411824de08..1a42af9dfa 100644 --- a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py +++ b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py @@ -142,6 +142,12 @@ def _load_dataset_module(monkeypatch): api_pkg.__path__ = [str(repo_root / "api")] monkeypatch.setitem(sys.modules, "api", api_pkg) + api_constants_mod = ModuleType("api.constants") + api_constants_mod.DATASET_NAME_LIMIT = 128 + api_constants_mod.FILE_NAME_LEN_LIMIT = 255 + monkeypatch.setitem(sys.modules, "api.constants", api_constants_mod) + api_pkg.constants = api_constants_mod + utils_pkg = ModuleType("api.utils") utils_pkg.__path__ = [str(repo_root / "api" / "utils")] monkeypatch.setitem(sys.modules, "api.utils", utils_pkg) @@ -161,6 +167,7 @@ def _load_dataset_module(monkeypatch): db_pkg = ModuleType("api.db") db_pkg.__path__ = [] + db_pkg.FileType = SimpleNamespace() monkeypatch.setitem(sys.modules, "api.db", db_pkg) api_pkg.db = db_pkg @@ -313,8 +320,14 @@ def _load_dataset_module(monkeypatch): def get_by_ids(_ids): return [] + class _StubUserTenantService: + @staticmethod + def get_tenants_by_user_id(_user_id): + return [] + user_service_mod.TenantService = _StubTenantService user_service_mod.UserService = _StubUserService + user_service_mod.UserTenantService = _StubUserTenantService monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) services_pkg.user_service = user_service_mod @@ -662,143 +675,115 @@ def test_list_knowledge_graph_delete_kg_matrix_unit(monkeypatch): @pytest.mark.p3 -def test_run_trace_graphrag_matrix_unit(monkeypatch): +def test_run_index_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) warnings = [] monkeypatch.setattr(module.logging, "warning", lambda msg, *_args, **_kwargs: warnings.append(msg)) - res = _run(inspect.unwrap(module.run_graphrag)("tenant-1", "")) - assert 'Dataset ID' in res["message"], res + # Invalid index type + _set_request_args(monkeypatch, module, {"type": "invalid"}) + res = _run(inspect.unwrap(module.run_index)("tenant-1", "kb-1")) + assert "Invalid index type" in res["message"], res + # Missing dataset ID + _set_request_args(monkeypatch, module, {"type": "graph"}) + res = _run(inspect.unwrap(module.run_index)("tenant-1", "")) + assert "Dataset ID" in res["message"], res + + # No authorization + _set_request_args(monkeypatch, module, {"type": "graph"}) monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False) - res = _run(inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")) + res = _run(inspect.unwrap(module.run_index)("tenant-1", "kb-1")) assert res["code"] == module.RetCode.DATA_ERROR, res + # Invalid dataset ID monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None)) - res = _run(inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")) + res = _run(inspect.unwrap(module.run_index)("tenant-1", "kb-1")) assert "Invalid Dataset ID" in res["message"], res + # Stale graphrag task + successful re-queue stale_kb = _KB(kb_id="kb-1", graphrag_task_id="task-old") monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, stale_kb)) monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None)) monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}], 1)) monkeypatch.setattr(module.dataset_api_service, "queue_raptor_o_graphrag_tasks", lambda **_kwargs: "task-new") monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True) - res = _run(inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")) + _set_request_args(monkeypatch, module, {"type": "graph"}) + res = _run(inspect.unwrap(module.run_index)("tenant-1", "kb-1")) assert res["code"] == module.RetCode.SUCCESS, res - assert any("GraphRAG" in msg for msg in warnings), warnings + assert any("Graph" in msg for msg in warnings), warnings + # Task already running monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(progress=0))) - res = _run(inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")) + res = _run(inspect.unwrap(module.run_index)("tenant-1", "kb-1")) assert "already running" in res["message"], res + # Successful raptor run with save warning warnings.clear() - queue_calls = {} - no_task_kb = _KB(kb_id="kb-1", graphrag_task_id="") + no_task_kb = _KB(kb_id="kb-1", raptor_task_id="") monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, no_task_kb)) monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None)) monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}, {"id": "doc-2"}], 2)) + queue_calls = {} + def _queue(**kwargs): queue_calls.update(kwargs) - return "queued-id" + return "queued-raptor" monkeypatch.setattr(module.dataset_api_service, "queue_raptor_o_graphrag_tasks", _queue) monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: False) - res = _run(inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")) + _set_request_args(monkeypatch, module, {"type": "raptor"}) + res = _run(inspect.unwrap(module.run_index)("tenant-1", "kb-1")) assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"]["graphrag_task_id"] == "queued-id", res + assert res["data"]["task_id"] == "queued-raptor", res assert queue_calls["doc_ids"] == ["doc-1", "doc-2"], queue_calls - assert any("Cannot save graphrag_task_id" in msg for msg in warnings), warnings - - res = inspect.unwrap(module.trace_graphrag)("tenant-1", "") - assert 'Dataset ID' in res["message"], res - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False) - res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1") - assert res["code"] == module.RetCode.DATA_ERROR, res - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None)) - res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1") - assert "Invalid Dataset ID" in res["message"], res - - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(kb_id="kb-1", graphrag_task_id="task-1"))) - monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None)) - res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1") - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"] == {}, res - - monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(to_dict=lambda: {"id": _task_id, "progress": 1}))) - res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1") - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"]["id"] == "task-1", res + assert any("Cannot save" in msg for msg in warnings), warnings @pytest.mark.p3 -def test_run_trace_raptor_matrix_unit(monkeypatch): +def test_trace_index_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) - warnings = [] - monkeypatch.setattr(module.logging, "warning", lambda msg, *_args, **_kwargs: warnings.append(msg)) + # Invalid index type + _set_request_args(monkeypatch, module, {"type": "invalid"}) + res = inspect.unwrap(module.trace_index)("tenant-1", "kb-1") + assert "Invalid index type" in res["message"], res - res = _run(inspect.unwrap(module.run_raptor)("tenant-1", "")) - assert 'Dataset ID' in res["message"], res + # Missing dataset ID + _set_request_args(monkeypatch, module, {"type": "graph"}) + res = inspect.unwrap(module.trace_index)("tenant-1", "") + assert "Dataset ID" in res["message"], res + # No authorization + _set_request_args(monkeypatch, module, {"type": "graph"}) monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False) - res = _run(inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")) + res = inspect.unwrap(module.trace_index)("tenant-1", "kb-1") assert res["code"] == module.RetCode.DATA_ERROR, res + # Invalid dataset ID monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None)) - res = _run(inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")) + res = inspect.unwrap(module.trace_index)("tenant-1", "kb-1") assert "Invalid Dataset ID" in res["message"], res - stale_kb = _KB(kb_id="kb-1", raptor_task_id="task-old") - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, stale_kb)) - monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None)) - monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}], 1)) - monkeypatch.setattr(module.dataset_api_service, "queue_raptor_o_graphrag_tasks", lambda **_kwargs: "task-new") - monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True) - res = _run(inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")) + # No existing task — returns empty + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(kb_id="kb-1", graphrag_task_id=""))) + res = inspect.unwrap(module.trace_index)("tenant-1", "kb-1") assert res["code"] == module.RetCode.SUCCESS, res - assert any("RAPTOR" in msg for msg in warnings), warnings + assert res["data"] == {}, res - monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(progress=0))) - res = _run(inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")) - assert "already running" in res["message"], res - - warnings.clear() - no_task_kb = _KB(kb_id="kb-1", raptor_task_id="") - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, no_task_kb)) - monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}], 1)) - monkeypatch.setattr(module.dataset_api_service, "queue_raptor_o_graphrag_tasks", lambda **_kwargs: "queued-raptor") - monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: False) - res = _run(inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")) - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"]["raptor_task_id"] == "queued-raptor", res - assert any("Cannot save raptor_task_id" in msg for msg in warnings), warnings - - res = inspect.unwrap(module.trace_raptor)("tenant-1", "") - assert 'Dataset ID' in res["message"], res - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False) - res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1") - assert res["code"] == module.RetCode.DATA_ERROR, res - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None)) - res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1") - assert "Invalid Dataset ID" in res["message"], res - - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(kb_id="kb-1", raptor_task_id="task-1"))) + # Task ID set but task not found — returns empty + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(kb_id="kb-1", graphrag_task_id="task-1"))) monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None)) - res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1") - assert "RAPTOR Task Not Found" in res["message"], res + res = inspect.unwrap(module.trace_index)("tenant-1", "kb-1") + assert res["code"] == module.RetCode.SUCCESS, res + assert res["data"] == {}, res - monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(to_dict=lambda: {"id": _task_id, "progress": -1}))) - res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1") + # Task found — returns task data + monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(to_dict=lambda: {"id": _task_id, "progress": 1}))) + res = inspect.unwrap(module.trace_index)("tenant-1", "kb-1") assert res["code"] == module.RetCode.SUCCESS, res assert res["data"]["id"] == "task-1", res diff --git a/test/testcases/test_web_api/test_document_app/test_document_metadata.py b/test/testcases/test_web_api/test_document_app/test_document_metadata.py deleted file mode 100644 index 1fd6486948..0000000000 --- a/test/testcases/test_web_api/test_document_app/test_document_metadata.py +++ /dev/null @@ -1,662 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import asyncio -from types import SimpleNamespace - -import pytest -from test_common import ( - delete_document, - document_change_status, - document_filter, - document_infos, - document_metadata_summary, - document_metadata_update, - document_update_metadata_setting, -) -from configs import INVALID_API_TOKEN -from libs.auth import RAGFlowWebApiAuth - -INVALID_AUTH_CASES = [ - (None, 401, "Unauthorized"), - (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "Unauthorized"), -] - - -class TestAuthorization: - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_filter_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = document_filter(invalid_auth, "kb_id", {}) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_infos_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = document_infos(invalid_auth, "kb_id", {"doc_ids": ["doc_id"]}) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - ## The inputs has been changed to add 'doc_ids' - ## TODO: - #@pytest.mark.p2 - #@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - #def test_metadata_summary_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - # res = document_metadata_summary(invalid_auth, {"kb_id": "kb_id"}) - # assert res["code"] == expected_code, res - # assert expected_fragment in res["message"], res - - ## The inputs has been changed to deprecate 'selector' - ## TODO: - #@pytest.mark.p2 - #@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - #def test_metadata_update_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - # res = document_metadata_update(invalid_auth, {"kb_id": "kb_id", "selector": {"document_ids": ["doc_id"]}, "updates": []}) - # assert res["code"] == expected_code, res - # assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_update_metadata_setting_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = document_update_metadata_setting(invalid_auth, "kb_id", "doc_id", {"metadata": {}}) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_change_status_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = document_change_status(invalid_auth, {"doc_ids": ["doc_id"], "status": "1"}) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - -class TestDocumentMetadata: - @pytest.mark.p2 - def test_filter(self, WebApiAuth, add_dataset_func): - kb_id = add_dataset_func - res = document_filter(WebApiAuth, kb_id, {}) - assert res["code"] == 0, res - assert "filter" in res["data"], res - assert "total" in res["data"], res - - @pytest.mark.p2 - def test_infos(self, WebApiAuth, add_document_func): - dataset_id, doc_id = add_document_func - res = document_infos(WebApiAuth, dataset_id, {"ids": [doc_id]}) - assert res["code"] == 0, res - docs = res["data"]["docs"] - assert len(docs) == 1, docs - assert docs[0]["id"] == doc_id, res - - ## The inputs has been changed to add 'doc_ids' - ## TODO: - #@pytest.mark.p2 - #def test_metadata_summary(self, WebApiAuth, add_document_func): - # kb_id, _ = add_document_func - # res = document_metadata_summary(WebApiAuth, {"kb_id": kb_id}) - # assert res["code"] == 0, res - # assert isinstance(res["data"]["summary"], dict), res - - ## The inputs has been changed to deprecate 'selector' - ## TODO: - #@pytest.mark.p2 - #def test_metadata_update(self, WebApiAuth, add_document_func): - # kb_id, doc_id = add_document_func - # payload = { - # "kb_id": kb_id, - # "selector": {"document_ids": [doc_id]}, - # "updates": [{"key": "author", "value": "alice"}], - # "deletes": [], - # } - # res = document_metadata_update(WebApiAuth, payload) - # assert res["code"] == 0, res - # assert res["data"]["matched_docs"] == 1, res - # info_res = document_infos(WebApiAuth, {"doc_ids": [doc_id]}) - # assert info_res["code"] == 0, info_res - # meta_fields = info_res["data"][0].get("meta_fields", {}) - # assert meta_fields.get("author") == "alice", info_res - - ## The inputs has been changed to deprecate 'selector' - ## TODO: - #@pytest.mark.p2 - #def test_update_metadata_setting(self, WebApiAuth, add_document_func): - # _, doc_id = add_document_func - # metadata = {"source": "test"} - # res = document_update_metadata_setting(WebApiAuth, {"doc_id": doc_id, "metadata": metadata}) - # assert res["code"] == 0, res - # assert res["data"]["id"] == doc_id, res - # assert res["data"]["parser_config"]["metadata"] == metadata, res - - @pytest.mark.p2 - def test_change_status(self, WebApiAuth, add_document_func): - dataset_id, doc_id = add_document_func - res = document_change_status(WebApiAuth, {"doc_ids": [doc_id], "status": "1"}) - - assert res["code"] == 0, res - assert res["data"][doc_id]["status"] == "1", res - info_res = document_infos(WebApiAuth, dataset_id, {"ids": [doc_id]}) - - assert info_res["code"] == 0, info_res - assert info_res["data"]["docs"][0]["status"] == "1", info_res - - -class TestDocumentMetadataNegative: - @pytest.mark.p2 - def test_filter_missing_kb_id(self, WebApiAuth, add_document_func): - kb_id, doc_id = add_document_func - res = document_filter(WebApiAuth, "", {"ids": [doc_id]}) - assert res["code"] == 100, res - assert "" == res["message"], res - - @pytest.mark.p3 - def test_metadata_summary_missing_kb_id(self, WebApiAuth, add_document_func): - _, doc_id = add_document_func - res = document_metadata_summary(WebApiAuth, {"doc_ids": [doc_id]}) - assert res["code"] == 101, res - assert "KB ID" in res["message"], res - - ## The inputs has been changed to deprecate 'selector' - ## TODO: - #@pytest.mark.p3 - #def test_metadata_update_missing_kb_id(self, WebApiAuth, add_document_func): - # _, doc_id = add_document_func - # res = document_metadata_update(WebApiAuth, {"selector": {"document_ids": [doc_id]}, "updates": []}) - # assert res["code"] == 101, res - # assert "KB ID" in res["message"], res - - @pytest.mark.p3 - def test_infos_invalid_doc_id(self, WebApiAuth): - res = document_infos(WebApiAuth, {"doc_ids": ["invalid_id"]}) - assert res["code"] == 109, res - assert "No authorization" in res["message"], res - - @pytest.mark.p3 - def test_update_metadata_setting_missing_metadata(self, WebApiAuth, add_document_func): - _, doc_id = add_document_func - res = document_update_metadata_setting(WebApiAuth, {"doc_id": doc_id}) - assert res["code"] == 101, res - assert "required argument are missing" in res["message"], res - assert "metadata" in res["message"], res - - @pytest.mark.p2 - def test_update_metadata_setting_not_found(self, WebApiAuth, add_document_func): - """Test updating metadata setting for a non-existent document returns error.""" - dataset_id, doc_id = add_document_func - # First delete the document - delete_res = delete_document(WebApiAuth, dataset_id, {"ids": [doc_id]}) - assert delete_res["code"] == 0, delete_res - - # Now try to update metadata setting for the deleted document - res = document_update_metadata_setting(WebApiAuth, dataset_id, doc_id, {"metadata": {"author": "test"}}) - assert res["code"] == 102, res - assert f"Document {doc_id} not found in dataset {dataset_id}" in res["message"], res - - @pytest.mark.p3 - def test_change_status_invalid_status(self, WebApiAuth, add_document_func): - _, doc_id = add_document_func - res = document_change_status(WebApiAuth, {"doc_ids": [doc_id], "status": "2"}) - assert res["code"] == 101, res - assert "Status" in res["message"], res - - -def _run(coro): - return asyncio.run(coro) - - -class _DummyArgs: - def __init__(self, args=None): - self._args = args or {} - - def get(self, key, default=None): - return self._args.get(key, default) - - def getlist(self, key): - value = self._args.get(key, []) - if isinstance(value, list): - return value - return [value] - - -class _DummyRequest: - def __init__(self, args=None): - self.args = _DummyArgs(args) - - -class _DummyResponse: - def __init__(self, data=None): - self.data = data - self.headers = {} - - -@pytest.mark.p2 -class TestDocumentMetadataUnit: - def _allow_kb(self, module, monkeypatch, kb_id="kb1", tenant_id="tenant1"): - monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id=tenant_id)]) - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: True if _kwargs.get("id") == kb_id else False) - - @pytest.mark.p3 - def test_update_metadata_missing_dataset_id(self, WebApiAuth, add_document_func): - """Test the new unified update_metadata API - missing dataset_id.""" - # Call with empty dataset_id (should fail validation) - res = document_metadata_update(WebApiAuth, "", {"dataset_id": "", "selector": {"document_ids": ["doc1"]}, "updates": []}) - assert res["code"] == 404 - assert res["message"] == "Not Found: /api/v1/datasets//documents/metadatas", res - - @pytest.mark.p3 - def test_update_metadata_success(self, WebApiAuth, add_document_func): - """Test the new unified update_metadata API - success case.""" - kb_id, doc_id = add_document_func - res = document_metadata_update( - WebApiAuth, kb_id, - { - "selector": {"document_ids": [doc_id]}, - "updates": [{"key": "author", "value": "test_author"}], - "deletes": [] - } - ) - assert res["code"] == 0, res - - - @pytest.mark.p3 - def test_update_metadata_invalid_delete_item(self, WebApiAuth, add_document_func): - """Test the new unified update_metadata API - invalid delete item.""" - kb_id, doc_id = add_document_func - res = document_metadata_update( - WebApiAuth, kb_id, - { - "selector": {"document_ids": [doc_id]}, - "updates": [], - "deletes": [{}] # Invalid - missing key - } - ) - assert res["code"] == 102 - assert "Each delete requires key" in res["message"], res - - - def test_thumbnails_missing_ids_rewrite_and_exception_unit(self, document_app_module, monkeypatch): - module = document_app_module - monkeypatch.setattr(module, "request", _DummyRequest(args={})) - res = module.thumbnails() - assert res["code"] == module.RetCode.ARGUMENT_ERROR - assert 'Lack of "Document ID"' in res["message"] - - monkeypatch.setattr(module, "request", _DummyRequest(args={"doc_ids": ["doc1", "doc2"]})) - monkeypatch.setattr( - module.DocumentService, - "get_thumbnails", - lambda _doc_ids: [ - {"id": "doc1", "kb_id": "kb1", "thumbnail": "thumb.jpg"}, - {"id": "doc2", "kb_id": "kb1", "thumbnail": f"{module.IMG_BASE64_PREFIX}blob"}, - ], - ) - res = module.thumbnails() - assert res["code"] == 0 - assert res["data"]["doc1"] == "/v1/document/image/kb1-thumb.jpg" - assert res["data"]["doc2"] == f"{module.IMG_BASE64_PREFIX}blob" - - def raise_error(*_args, **_kwargs): - raise RuntimeError("thumb boom") - - monkeypatch.setattr(module.DocumentService, "get_thumbnails", raise_error) - monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)}) - res = module.thumbnails() - assert res["code"] == 500 - assert "thumb boom" in res["message"] - - def test_change_status_partial_failure_matrix_unit(self, document_app_module, monkeypatch): - module = document_app_module - calls = {"docstore_update": []} - doc_ids = ["unauth", "missing_doc", "missing_kb", "update_fail", "docstore_3022", "docstore_generic", "outer_exc"] - - async def fake_request_json(): - return {"doc_ids": doc_ids, "status": "1"} - - def fake_accessible(doc_id, _uid): - return doc_id != "unauth" - - def fake_get_by_id(doc_id): - if doc_id == "missing_doc": - return False, None - if doc_id == "outer_exc": - raise RuntimeError("explode") - kb_id = "kb_missing" if doc_id == "missing_kb" else "kb1" - chunk_num = 1 if doc_id in {"docstore_3022", "docstore_generic"} else 0 - doc = SimpleNamespace(id=doc_id, kb_id=kb_id, status="0", chunk_num=chunk_num) - return True, doc - - def fake_get_kb(kb_id): - if kb_id == "kb_missing": - return False, None - return True, SimpleNamespace(tenant_id="tenant1") - - def fake_update_by_id(doc_id, _payload): - return doc_id != "update_fail" - - class _DocStore: - def update(self, where, _payload, _index_name, _kb_id): - calls["docstore_update"].append(where["doc_id"]) - if where["doc_id"] == "docstore_3022": - raise RuntimeError("3022 table missing") - if where["doc_id"] == "docstore_generic": - raise RuntimeError("doc store down") - return True - - monkeypatch.setattr(module, "get_request_json", fake_request_json) - monkeypatch.setattr(module.DocumentService, "accessible", fake_accessible) - monkeypatch.setattr(module.DocumentService, "get_by_id", fake_get_by_id) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda kb_id: fake_get_kb(kb_id)) - monkeypatch.setattr(module.DocumentService, "update_by_id", fake_update_by_id) - monkeypatch.setattr(module.settings, "docStoreConn", _DocStore()) - monkeypatch.setattr(module.search, "index_name", lambda tenant_id: f"idx_{tenant_id}") - - res = _run(module.change_status.__wrapped__()) - assert res["code"] == module.RetCode.SERVER_ERROR - assert res["message"] == "Partial failure" - assert res["data"]["unauth"]["error"] == "No authorization." - assert res["data"]["missing_doc"]["error"] == "No authorization." - assert res["data"]["missing_kb"]["error"] == "Can't find this dataset!" - assert res["data"]["update_fail"]["error"] == "Database error (Document update)!" - assert res["data"]["docstore_3022"]["error"] == "Document store table missing." - assert "Document store update failed:" in res["data"]["docstore_generic"]["error"] - assert "Internal server error: explode" == res["data"]["outer_exc"]["error"] - assert calls["docstore_update"] == ["docstore_3022", "docstore_generic"] - - def test_change_status_invalid_status_unit(self, document_app_module, monkeypatch): - module = document_app_module - - async def fake_request_json(): - return {"doc_ids": ["doc1"], "status": "2"} - - monkeypatch.setattr(module, "get_request_json", fake_request_json) - res = _run(module.change_status.__wrapped__()) - assert res["code"] == module.RetCode.ARGUMENT_ERROR - assert '"Status" must be either 0 or 1!' in res["message"] - - def test_change_status_all_success_unit(self, document_app_module, monkeypatch): - module = document_app_module - - async def fake_request_json(): - return {"doc_ids": ["doc1"], "status": "1"} - - monkeypatch.setattr(module, "get_request_json", fake_request_json) - monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id="doc1", kb_id="kb1", status="0", chunk_num=0))) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant1"))) - monkeypatch.setattr(module.DocumentService, "update_by_id", lambda *_args, **_kwargs: True) - res = _run(module.change_status.__wrapped__()) - assert res["code"] == 0 - assert res["data"]["doc1"]["status"] == "1" - - def test_get_route_not_found_success_and_exception_unit(self, document_app_module, monkeypatch): - module = document_app_module - monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None)) - res = _run(module.get("doc1")) - assert res["code"] == module.RetCode.DATA_ERROR - assert "Document not found!" in res["message"] - - async def fake_thread_pool_exec(*_args, **_kwargs): - return b"blob-data" - - async def fake_make_response(data): - return _DummyResponse(data) - - monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(name="image.abc", type=module.FileType.VISUAL.value))) - monkeypatch.setattr(module.File2DocumentService, "get_storage_address", lambda **_kwargs: ("bucket", "name")) - monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"blob-data")) - monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec) - monkeypatch.setattr(module, "make_response", fake_make_response) - monkeypatch.setattr( - module, - "apply_safe_file_response_headers", - lambda response, content_type, extension: response.headers.update({"content_type": content_type, "extension": extension}), - ) - res = _run(module.get("doc1")) - assert isinstance(res, _DummyResponse) - assert res.data == b"blob-data" - assert res.headers["content_type"] == "image/abc" - assert res.headers["extension"] == "abc" - - monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (_ for _ in ()).throw(RuntimeError("get boom"))) - monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)}) - res = _run(module.get("doc1")) - assert res["code"] == 500 - assert "get boom" in res["message"] - - def test_download_attachment_success_and_exception_unit(self, document_app_module, monkeypatch): - module = document_app_module - monkeypatch.setattr(module, "request", _DummyRequest(args={"ext": "abc"})) - - async def fake_thread_pool_exec(*_args, **_kwargs): - return b"attachment" - - async def fake_make_response(data): - return _DummyResponse(data) - - monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec) - monkeypatch.setattr(module, "make_response", fake_make_response) - monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"attachment")) - monkeypatch.setattr( - module, - "apply_safe_file_response_headers", - lambda response, content_type, extension: response.headers.update({"content_type": content_type, "extension": extension}), - ) - res = _run(module.download_attachment("att1")) - assert isinstance(res, _DummyResponse) - assert res.data == b"attachment" - assert res.headers["content_type"] == "application/abc" - assert res.headers["extension"] == "abc" - - async def raise_error(*_args, **_kwargs): - raise RuntimeError("download boom") - - monkeypatch.setattr(module, "thread_pool_exec", raise_error) - monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)}) - res = _run(module.download_attachment("att1")) - assert res["code"] == 500 - assert "download boom" in res["message"] - - def test_change_parser_guards_and_reset_update_failure_unit(self, document_app_module, monkeypatch): - module = document_app_module - - monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)}) - - async def req_auth_fail(): - return {"doc_id": "doc1", "parser_id": "naive", "pipeline_id": "pipe2"} - - monkeypatch.setattr(module, "get_request_json", req_auth_fail) - monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == module.RetCode.AUTHENTICATION_ERROR - - monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None)) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == module.RetCode.DATA_ERROR - assert "Document not found!" in res["message"] - - async def req_same_pipeline(): - return {"doc_id": "doc1", "parser_id": "naive", "pipeline_id": "pipe1"} - - doc_same = SimpleNamespace( - id="doc1", - pipeline_id="pipe1", - parser_id="naive", - parser_config={"k": "v"}, - token_num=0, - chunk_num=0, - process_duration=0, - kb_id="kb1", - type="doc", - name="doc.txt", - ) - monkeypatch.setattr(module, "get_request_json", req_same_pipeline) - monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_same)) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == 0 - - calls = [] - - async def req_pipeline_change(): - return {"doc_id": "doc1", "parser_id": "naive", "pipeline_id": "pipe2"} - - doc = SimpleNamespace( - id="doc1", - pipeline_id="pipe1", - parser_id="naive", - parser_config={}, - token_num=0, - chunk_num=0, - process_duration=0, - kb_id="kb1", - type="doc", - name="doc.txt", - ) - - def fake_update_by_id(doc_id, payload): - calls.append((doc_id, payload)) - return True - - monkeypatch.setattr(module, "get_request_json", req_pipeline_change) - monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc)) - monkeypatch.setattr(module.DocumentService, "update_by_id", fake_update_by_id) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == 0 - assert calls[0][1] == {"pipeline_id": "pipe2"} - assert calls[1][1]["run"] == module.TaskStatus.UNSTART.value - - doc.token_num = 3 - doc.chunk_num = 2 - doc.process_duration = 9 - monkeypatch.setattr(module.DocumentService, "increment_chunk_num", lambda *_args, **_kwargs: False) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == 0 - - monkeypatch.setattr(module.DocumentService, "increment_chunk_num", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: None) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == 0 - - side_effects = {"img": [], "delete": []} - - class _DocStore: - def index_exist(self, _idx, _kb_id): - return True - - def delete(self, where, _idx, kb_id): - side_effects["delete"].append((where["doc_id"], kb_id)) - - monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant1") - monkeypatch.setattr(module.DocumentService, "delete_chunk_images", lambda _doc, _tenant: side_effects["img"].append((_doc.id, _tenant))) - monkeypatch.setattr(module.search, "index_name", lambda tenant_id: f"idx_{tenant_id}") - monkeypatch.setattr(module.settings, "docStoreConn", _DocStore()) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == 0 - assert ("doc1", "tenant1") in side_effects["img"] - assert ("doc1", "kb1") in side_effects["delete"] - - async def req_same_parser_with_cfg(): - return {"doc_id": "doc1", "parser_id": "naive", "parser_config": {"a": 1}} - - doc_same_parser = SimpleNamespace( - id="doc1", - pipeline_id="pipe1", - parser_id="naive", - parser_config={"a": 1}, - token_num=0, - chunk_num=0, - process_duration=0, - kb_id="kb1", - type="doc", - name="doc.txt", - ) - monkeypatch.setattr(module, "get_request_json", req_same_parser_with_cfg) - monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_same_parser)) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == 0 - - async def req_same_parser_no_cfg(): - return {"doc_id": "doc1", "parser_id": "naive"} - - monkeypatch.setattr(module, "get_request_json", req_same_parser_no_cfg) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == 0 - - parser_cfg_updates = [] - - async def req_parser_update(): - return {"doc_id": "doc1", "parser_id": "paper", "pipeline_id": "", "parser_config": {"beta": True}} - - doc_parser_update = SimpleNamespace( - id="doc1", - pipeline_id="pipe1", - parser_id="naive", - parser_config={"alpha": 1}, - token_num=0, - chunk_num=0, - process_duration=0, - kb_id="kb1", - type="doc", - name="doc.txt", - ) - monkeypatch.setattr(module, "get_request_json", req_parser_update) - monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_parser_update)) - monkeypatch.setattr(module.DocumentService, "update_parser_config", lambda doc_id, cfg: parser_cfg_updates.append((doc_id, cfg))) - monkeypatch.setattr(module.DocumentService, "update_by_id", lambda *_args, **_kwargs: True) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == 0 - assert parser_cfg_updates == [("doc1", {"beta": True})] - - def raise_parser_config(*_args, **_kwargs): - raise RuntimeError("parser boom") - - monkeypatch.setattr(module.DocumentService, "update_parser_config", raise_parser_config) - res = _run(module.change_parser.__wrapped__()) - assert res["code"] == 500 - assert "parser boom" in res["message"] - - def test_get_image_success_and_exception_unit(self, document_app_module, monkeypatch): - module = document_app_module - - class _Headers(dict): - def set(self, key, value): - self[key] = value - - class _ImageResponse: - def __init__(self, data): - self.data = data - self.headers = _Headers() - - async def fake_thread_pool_exec(*_args, **_kwargs): - return b"image-bytes" - - async def fake_make_response(data): - return _ImageResponse(data) - - monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec) - monkeypatch.setattr(module, "make_response", fake_make_response) - monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"image-bytes")) - res = _run(module.get_image("bucket-name")) - assert isinstance(res, _ImageResponse) - assert res.data == b"image-bytes" - assert res.headers["Content-Type"] == "image/JPEG" - - async def raise_error(*_args, **_kwargs): - raise RuntimeError("image boom") - - monkeypatch.setattr(module, "thread_pool_exec", raise_error) - monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)}) - res = _run(module.get_image("bucket-name")) - assert res["code"] == 500 - assert "image boom" in res["message"] diff --git a/test/testcases/test_web_api/test_document_app/test_list_documents.py b/test/testcases/test_web_api/test_document_app/test_list_documents.py index 4005c07735..e4a9579a8a 100644 --- a/test/testcases/test_web_api/test_document_app/test_list_documents.py +++ b/test/testcases/test_web_api/test_document_app/test_list_documents.py @@ -172,15 +172,15 @@ class TestDocumentsList: def test_missing_kb_id(self, WebApiAuth): """Test missing KB ID returns error.""" res = list_documents(WebApiAuth, {"kb_id": ""}) - assert res["code"] == 100 - assert res["message"] == "" + assert res["code"] == 102 + assert res["message"] @pytest.mark.p2 def test_unauthorized_dataset(self, WebApiAuth): """Test unauthorized dataset returns error.""" res = list_documents(WebApiAuth, {"kb_id": "non_existent_kb_id"}) assert res["code"] == 102 - assert "You don't own the dataset" in res["message"] + assert res["message"] @pytest.mark.p3 def test_invalid_run_status_filter(self, WebApiAuth, add_documents): diff --git a/test/testcases/test_web_api/test_kb_app/conftest.py b/test/testcases/test_web_api/test_kb_app/conftest.py deleted file mode 100644 index 667e85e47c..0000000000 --- a/test/testcases/test_web_api/test_kb_app/conftest.py +++ /dev/null @@ -1,50 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import pytest -from test_common import batch_create_datasets, list_datasets, delete_datasets -from libs.auth import RAGFlowWebApiAuth -from pytest import FixtureRequest -from ragflow_sdk import RAGFlow - - -@pytest.fixture(scope="class") -def add_datasets(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWebApiAuth) -> list[str]: - dataset_ids = batch_create_datasets(WebApiAuth, 5) - - def cleanup(): - # Web KB cleanup cannot call SDK dataset bulk delete with empty ids; deletion must stay explicit. - res = list_datasets(WebApiAuth, params={"page_size": 1000}) - existing_ids = {kb["id"] for kb in res["data"]} - ids_to_delete = list({dataset_id for dataset_id in dataset_ids if dataset_id in existing_ids}) - delete_datasets(WebApiAuth, {"ids": ids_to_delete}) - - request.addfinalizer(cleanup) - return dataset_ids - - -@pytest.fixture(scope="function") -def add_datasets_func(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWebApiAuth) -> list[str]: - dataset_ids = batch_create_datasets(WebApiAuth, 3) - - def cleanup(): - # Web KB cleanup cannot call SDK dataset bulk delete with empty ids; deletion must stay explicit. - res = list_datasets(WebApiAuth, params={"page_size": 1000}) - existing_ids = {kb["id"] for kb in res["data"]} - ids_to_delete = list({dataset_id for dataset_id in dataset_ids if dataset_id in existing_ids}) - delete_datasets(WebApiAuth, {"ids": ids_to_delete}) - - request.addfinalizer(cleanup) - return dataset_ids diff --git a/test/testcases/test_web_api/test_kb_app/test_create_kb.py b/test/testcases/test_web_api/test_kb_app/test_create_kb.py deleted file mode 100644 index e6ae9e0339..0000000000 --- a/test/testcases/test_web_api/test_kb_app/test_create_kb.py +++ /dev/null @@ -1,109 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from concurrent.futures import ThreadPoolExecutor, as_completed - -import pytest -from test_common import create_dataset -from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN -from hypothesis import example, given, settings -from libs.auth import RAGFlowWebApiAuth -from utils.hypothesis_utils import valid_names - - -@pytest.mark.usefixtures("clear_datasets") -class TestAuthorization: - @pytest.mark.p2 - @pytest.mark.parametrize( - "invalid_auth, expected_code, expected_message", - [ - (None, 401, ""), - (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), - ], - ids=["empty_auth", "invalid_api_token"], - ) - def test_auth_invalid(self, invalid_auth, expected_code, expected_message): - res = create_dataset(invalid_auth, {"name": "auth_test"}) - assert res["code"] == expected_code, res - assert res["message"] == expected_message, res - - -@pytest.mark.usefixtures("clear_datasets") -class TestCapability: - @pytest.mark.p3 - def test_create_kb_1k(self, WebApiAuth): - for i in range(1_000): - payload = {"name": f"dataset_{i}"} - res = create_dataset(WebApiAuth, payload) - assert res["code"] == 0, f"Failed to create dataset {i}" - - @pytest.mark.p3 - def test_create_kb_concurrent(self, WebApiAuth): - count = 100 - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(create_dataset, WebApiAuth, {"name": f"dataset_{i}"}) for i in range(count)] - responses = list(as_completed(futures)) - assert len(responses) == count, responses - assert all(future.result()["code"] == 0 for future in futures) - - -@pytest.mark.usefixtures("clear_datasets") -class TestDatasetCreate: - @pytest.mark.p1 - @given(name=valid_names()) - @example("a" * 128) - @settings(max_examples=20) - def test_name(self, WebApiAuth, name): - res = create_dataset(WebApiAuth, {"name": name}) - assert res["code"] == 0, res - - @pytest.mark.p2 - @pytest.mark.parametrize( - "name, expected_message", - [ - ("", "Field: - Message: "), - (" ", "Field: - Message: "), - ("a" * (DATASET_NAME_LIMIT + 1), "Field: - Message: "), - (0, "Field: - Message: "), - (None, "Field: - Message: "), - ], - ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], - ) - def test_name_invalid(self, WebApiAuth, name, expected_message): - payload = {"name": name} - res = create_dataset(WebApiAuth, payload) - assert res["code"] == 101, res - assert expected_message in res["message"], res - - @pytest.mark.p3 - def test_name_duplicated(self, WebApiAuth): - name = "duplicated_name" - payload = {"name": name} - res = create_dataset(WebApiAuth, payload) - assert res["code"] == 0, res - - res = create_dataset(WebApiAuth, payload) - assert res["code"] == 0, res - - @pytest.mark.p3 - def test_name_case_insensitive(self, WebApiAuth): - name = "CaseInsensitive" - payload = {"name": name.upper()} - res = create_dataset(WebApiAuth, payload) - assert res["code"] == 0, res - - payload = {"name": name.lower()} - res = create_dataset(WebApiAuth, payload) - assert res["code"] == 0, res diff --git a/test/testcases/test_web_api/test_kb_app/test_detail_kb.py b/test/testcases/test_web_api/test_kb_app/test_detail_kb.py deleted file mode 100644 index ae0e12ac4f..0000000000 --- a/test/testcases/test_web_api/test_kb_app/test_detail_kb.py +++ /dev/null @@ -1,53 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import pytest -from test_common import ( - detail_kb, -) -from configs import INVALID_API_TOKEN -from libs.auth import RAGFlowWebApiAuth - - -class TestAuthorization: - @pytest.mark.p2 - @pytest.mark.parametrize( - "invalid_auth, expected_code, expected_message", - [ - (None, 401, ""), - (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), - ], - ) - def test_auth_invalid(self, invalid_auth, expected_code, expected_message): - res = detail_kb(invalid_auth) - assert res["code"] == expected_code, res - assert res["message"] == expected_message, res - - -class TestDatasetsDetail: - @pytest.mark.p1 - def test_kb_id(self, WebApiAuth, add_dataset): - kb_id = add_dataset - payload = {"kb_id": kb_id} - res = detail_kb(WebApiAuth, payload) - assert res["code"] == 0, res - assert res["data"]["name"] == "kb_0" - - @pytest.mark.p2 - def test_id_wrong_uuid(self, WebApiAuth): - payload = {"kb_id": "d94a8dc02c9711f0930f7fbc369eab6d"} - res = detail_kb(WebApiAuth, payload) - assert res["code"] == 103, res - assert "Only owner of dataset authorized for this operation." in res["message"], res diff --git a/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py b/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py deleted file mode 100644 index a4dfe50c77..0000000000 --- a/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py +++ /dev/null @@ -1,233 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import pytest -from test_common import ( - kb_delete_pipeline_logs, - kb_list_pipeline_dataset_logs, - kb_list_pipeline_logs, - kb_pipeline_log_detail, - run_graphrag, - trace_graphrag, - run_raptor, - trace_raptor, - kb_run_mindmap, - kb_trace_mindmap, - list_documents, - parse_documents, -) -from utils import wait_for - -TASK_STATUS_DONE = "3" - -def _find_task(data, task_id): - if isinstance(data, dict): - if data.get("id") == task_id: - return data - tasks = data.get("tasks") - if isinstance(tasks, list): - for item in tasks: - if isinstance(item, dict) and item.get("id") == task_id: - return item - elif isinstance(data, list): - for item in data: - if isinstance(item, dict) and item.get("id") == task_id: - return item - return None - - -def _assert_progress_in_scale(progress, payload): - assert isinstance(progress, (int, float)), payload - if progress < 0: - assert False, f"Negative progress is not expected: {payload}" - scale = 100 if progress > 1 else 1 - # Infer scale from observed payload (0..1 or 0..100). - assert 0 <= progress <= scale, payload - return scale - - -def _wait_for_task(trace_func, auth, kb_id, task_id, timeout=60, use_params_payload=False): - @wait_for(timeout, 1, "Pipeline task trace timeout") - def _condition(): - if use_params_payload: - res = trace_func(auth, {"kb_id": kb_id}) - else: - res = trace_func(auth, kb_id) - if res["code"] != 0: - return False - return _find_task(res["data"], task_id) is not None - - _condition() - - -def _wait_for_docs_parsed(auth, kb_id, timeout=60): - @wait_for(timeout, 2, "Document parsing timeout") - def _condition(): - res = list_documents(auth, {"kb_id": kb_id}) - if res["code"] != 0: - return False - for doc in res["data"]["docs"]: - progress = doc.get("progress", 0) - _assert_progress_in_scale(progress, doc) - scale = 100 if progress > 1 else 1 - if doc.get("run") != TASK_STATUS_DONE or progress < scale: - return False - return True - - _condition() - - -def _wait_for_pipeline_logs(auth, kb_id, timeout=30): - @wait_for(timeout, 1, "Pipeline log timeout") - def _condition(): - res = kb_list_pipeline_logs(auth, params={"kb_id": kb_id}, payload={}) - if res["code"] != 0: - return False - return bool(res["data"]["logs"]) - - _condition() - - -class TestKbPipelineTasks: - @pytest.mark.p3 - def test_graphrag_run_and_trace(self, WebApiAuth, add_chunks): - kb_id, _, _ = add_chunks - run_res = run_graphrag(WebApiAuth, kb_id) - assert run_res["code"] == 0, run_res - task_id = run_res["data"]["graphrag_task_id"] - assert task_id, run_res - - _wait_for_task(trace_graphrag, WebApiAuth, kb_id, task_id) - trace_res = trace_graphrag(WebApiAuth, kb_id) - assert trace_res["code"] == 0, trace_res - task = _find_task(trace_res["data"], task_id) - assert task, trace_res - assert task["id"] == task_id, trace_res - progress = task.get("progress") - _assert_progress_in_scale(progress, task) - - @pytest.mark.p3 - def test_raptor_run_and_trace(self, WebApiAuth, add_chunks): - kb_id, _, _ = add_chunks - run_res = run_raptor(WebApiAuth, kb_id) - assert run_res["code"] == 0, run_res - task_id = run_res["data"]["raptor_task_id"] - assert task_id, run_res - - _wait_for_task(trace_raptor, WebApiAuth, kb_id, task_id) - trace_res = trace_raptor(WebApiAuth, kb_id) - assert trace_res["code"] == 0, trace_res - task = _find_task(trace_res["data"], task_id) - assert task, trace_res - assert task["id"] == task_id, trace_res - progress = task.get("progress") - _assert_progress_in_scale(progress, task) - - @pytest.mark.p3 - def test_mindmap_run_and_trace(self, WebApiAuth, add_chunks): - kb_id, _, _ = add_chunks - run_res = kb_run_mindmap(WebApiAuth, {"kb_id": kb_id}) - assert run_res["code"] == 0, run_res - task_id = run_res["data"]["mindmap_task_id"] - assert task_id, run_res - - _wait_for_task(kb_trace_mindmap, WebApiAuth, kb_id, task_id, use_params_payload=True) - trace_res = kb_trace_mindmap(WebApiAuth, {"kb_id": kb_id}) - assert trace_res["code"] == 0, trace_res - task = _find_task(trace_res["data"], task_id) - assert task, trace_res - assert task["id"] == task_id, trace_res - progress = task.get("progress") - _assert_progress_in_scale(progress, task) - - -class TestKbPipelineLogs: - @pytest.mark.p3 - def test_pipeline_log_lifecycle(self, WebApiAuth, add_document): - kb_id, document_id = add_document - parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"}) - _wait_for_docs_parsed(WebApiAuth, kb_id) - _wait_for_pipeline_logs(WebApiAuth, kb_id) - - list_res = kb_list_pipeline_logs(WebApiAuth, params={"kb_id": kb_id}, payload={}) - assert list_res["code"] == 0, list_res - assert "total" in list_res["data"], list_res - assert isinstance(list_res["data"]["logs"], list), list_res - assert list_res["data"]["logs"], list_res - - log_id = list_res["data"]["logs"][0]["id"] - detail_res = kb_pipeline_log_detail(WebApiAuth, {"log_id": log_id}) - assert detail_res["code"] == 0, detail_res - detail = detail_res["data"] - assert detail["id"] == log_id, detail_res - assert detail["kb_id"] == kb_id, detail_res - for key in ["document_id", "task_type", "operation_status", "progress"]: - assert key in detail, detail_res - - delete_res = kb_delete_pipeline_logs(WebApiAuth, params={"kb_id": kb_id}, payload={"log_ids": [log_id]}) - assert delete_res["code"] == 0, delete_res - assert delete_res["data"] is True, delete_res - - @wait_for(30, 1, "Pipeline log delete timeout") - def _condition(): - res = kb_list_pipeline_logs(WebApiAuth, params={"kb_id": kb_id}, payload={}) - if res["code"] != 0: - return False - return all(log.get("id") != log_id for log in res["data"]["logs"]) - - _condition() - - @pytest.mark.p3 - def test_list_pipeline_dataset_logs(self, WebApiAuth, add_document): - kb_id, _ = add_document - res = kb_list_pipeline_dataset_logs(WebApiAuth, params={"kb_id": kb_id}, payload={}) - assert res["code"] == 0, res - assert "total" in res["data"], res - assert isinstance(res["data"]["logs"], list), res - - @pytest.mark.p3 - def test_pipeline_log_detail_missing_id(self, WebApiAuth): - res = kb_pipeline_log_detail(WebApiAuth, {}) - assert res["code"] == 101, res - assert "Pipeline log ID" in res["message"], res - - @pytest.mark.p3 - def test_delete_pipeline_logs_empty(self, WebApiAuth, add_document): - kb_id, _ = add_document - res = kb_delete_pipeline_logs(WebApiAuth, params={"kb_id": kb_id}, payload={"log_ids": []}) - assert res["code"] == 0, res - assert res["data"] is True, res - - @pytest.mark.p3 - def test_list_pipeline_logs_missing_kb_id(self, WebApiAuth): - res = kb_list_pipeline_logs(WebApiAuth, params={}, payload={}) - assert res["code"] == 101, res - assert "KB ID" in res["message"], res - - @pytest.mark.p3 - def test_list_pipeline_logs_abnormal_date_filter(self, WebApiAuth, add_document): - kb_id, _ = add_document - res = kb_list_pipeline_logs( - WebApiAuth, - params={ - "kb_id": kb_id, - "desc": "false", - "create_date_from": "2025-01-01", - "create_date_to": "2025-02-01", - }, - payload={}, - ) - assert res["code"] == 102, res - assert "Create data filter is abnormal." in res["message"], res diff --git a/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py b/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py deleted file mode 100644 index 998a231453..0000000000 --- a/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py +++ /dev/null @@ -1,1021 +0,0 @@ -# -# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import asyncio -import importlib -import importlib.util -import inspect -import sys -from copy import deepcopy -from datetime import datetime -from pathlib import Path -from types import ModuleType, SimpleNamespace - -import pytest - -pytestmark = pytest.mark.filterwarnings("ignore:.*joblib will operate in serial mode.*:UserWarning") - - -class _DummyManager: - def route(self, *_args, **_kwargs): - def decorator(func): - return func - - return decorator - - -class _AwaitableValue: - def __init__(self, value): - self._value = value - - def __await__(self): - async def _co(): - return self._value - - return _co().__await__() - - -class _DummyArgs(dict): - def getlist(self, key): - value = self.get(key) - if value is None: - return [] - if isinstance(value, list): - return value - return [value] - - -class _DummyKB: - def __init__(self, *, kb_id="kb-1", name="old_kb", tenant_id="tenant-1", pagerank=0): - self.id = kb_id - self.name = name - self.tenant_id = tenant_id - self.pagerank = pagerank - self.parser_config = {} - - def to_dict(self): - return { - "id": self.id, - "name": self.name, - "tenant_id": self.tenant_id, - "pagerank": self.pagerank, - "parser_config": deepcopy(self.parser_config), - } - - -class _DummyTask: - def __init__(self, task_id, progress): - self.id = task_id - self.progress = progress - - def to_dict(self): - return {"id": self.id, "progress": self.progress} - - -def _run(coro): - return asyncio.run(coro) - - -def _unwrap_route(func): - route_func = inspect.unwrap(func) - visited = set() - while getattr(route_func, "__closure__", None) and route_func not in visited: - visited.add(route_func) - nested = None - for cell in route_func.__closure__: - candidate = cell.cell_contents - if inspect.isfunction(candidate) and candidate is not route_func: - nested = inspect.unwrap(candidate) - break - if nested is None: - break - route_func = nested - return route_func - - -def _load_kb_module(monkeypatch): - repo_root = Path(__file__).resolve().parents[4] - - common_pkg = ModuleType("common") - common_pkg.__path__ = [str(repo_root / "common")] - monkeypatch.setitem(sys.modules, "common", common_pkg) - - deepdoc_pkg = ModuleType("deepdoc") - deepdoc_parser_pkg = ModuleType("deepdoc.parser") - deepdoc_parser_pkg.__path__ = [] - - class _StubPdfParser: - pass - - class _StubExcelParser: - pass - - class _StubDocxParser: - pass - - deepdoc_parser_pkg.PdfParser = _StubPdfParser - deepdoc_parser_pkg.ExcelParser = _StubExcelParser - deepdoc_parser_pkg.DocxParser = _StubDocxParser - deepdoc_pkg.parser = deepdoc_parser_pkg - monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg) - monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg) - - deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser") - deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser - monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module) - - deepdoc_parser_utils = ModuleType("deepdoc.parser.utils") - deepdoc_parser_utils.get_text = lambda *_args, **_kwargs: "" - monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils) - monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost")) - - apps_mod = ModuleType("api.apps") - apps_mod.current_user = SimpleNamespace(id="user-1") - apps_mod.login_required = lambda func: func - monkeypatch.setitem(sys.modules, "api.apps", apps_mod) - - module_name = "test_kb_routes_unit_module" - module_path = repo_root / "api" / "apps" / "kb_app.py" - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - module.manager = _DummyManager() - monkeypatch.setitem(sys.modules, module_name, module) - spec.loader.exec_module(module) - return module - - -def _dataset_sdk_routes_unit_module(): - return importlib.import_module("test.testcases.test_web_api.test_dataset_management.test_dataset_sdk_routes_unit") - - -def _set_request_json(monkeypatch, module, payload): - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload))) - - -def _set_request_args(monkeypatch, module, args): - monkeypatch.setattr(module, "request", SimpleNamespace(args=_DummyArgs(args))) - - -def _base_update_payload(**kwargs): - payload = {"kb_id": "kb-1", "name": "new_kb", "description": "", "parser_id": "naive"} - payload.update(kwargs) - return payload - - -@pytest.fixture(scope="session") -def auth(): - return "unit-auth" - - -@pytest.fixture(scope="session", autouse=True) -def set_tenant_info(): - return None - - -@pytest.mark.p3 -def test_create_branches(monkeypatch): - module = _dataset_sdk_routes_unit_module() - module.test_create_route_error_matrix_unit(monkeypatch) - - -@pytest.mark.p3 -def test_update_branches(monkeypatch): - module = _dataset_sdk_routes_unit_module() - module.test_update_route_branch_matrix_unit(monkeypatch) - - -@pytest.mark.p3 -def test_update_metadata_setting_not_found(monkeypatch): - module = _load_kb_module(monkeypatch) - _set_request_json(monkeypatch, module, {"kb_id": "missing-kb", "metadata": {}}) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None)) - res = _run(inspect.unwrap(module.update_metadata_setting)()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Database error" in res["message"], res - - -@pytest.mark.p3 -def test_detail_branches(monkeypatch): - module = _load_kb_module(monkeypatch) - - _set_request_args(monkeypatch, module, {"kb_id": "kb-1"}) - monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")]) - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: []) - res = inspect.unwrap(module.detail)() - assert res["code"] == module.RetCode.OPERATING_ERROR, res - - _set_request_args(monkeypatch, module, {"kb_id": "kb-1"}) - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) - monkeypatch.setattr(module.KnowledgebaseService, "get_detail", lambda _kb_id: None) - res = inspect.unwrap(module.detail)() - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Can't find this dataset" in res["message"], res - - finish_at = datetime(2025, 1, 1, 12, 30, 0) - kb_detail = { - "id": "kb-1", - "parser_config": {"metadata": {"x": "y"}}, - "graphrag_task_finish_at": finish_at, - "raptor_task_finish_at": finish_at, - "mindmap_task_finish_at": finish_at, - } - monkeypatch.setattr(module.KnowledgebaseService, "get_detail", lambda _kb_id: deepcopy(kb_detail)) - monkeypatch.setattr(module.DocumentService, "get_total_size_by_kb_id", lambda **_kwargs: 1024) - monkeypatch.setattr(module.Connector2KbService, "list_connectors", lambda _kb_id: ["conn-1"]) - monkeypatch.setattr(module, "turn2jsonschema", lambda metadata: {"type": "object", "properties": metadata}) - res = inspect.unwrap(module.detail)() - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"]["size"] == 1024, res - assert res["data"]["connectors"] == ["conn-1"], res - assert isinstance(res["data"]["parser_config"]["metadata"], dict), res - assert res["data"]["graphrag_task_finish_at"] == "2025-01-01 12:30:00", res - - def _raise_tenants(**_kwargs): - raise RuntimeError("detail boom") - monkeypatch.setattr(module.UserTenantService, "query", _raise_tenants) - res = inspect.unwrap(module.detail)() - assert res["code"] == module.RetCode.EXCEPTION_ERROR, res - assert "detail boom" in res["message"], res - - -@pytest.mark.p3 -def test_list_kbs_owner_ids_and_desc(monkeypatch): - module = _dataset_sdk_routes_unit_module() - module.test_list_knowledge_graph_delete_kg_matrix_unit(monkeypatch) - - -@pytest.mark.p3 -def test_rm_and_rm_sync_branches(monkeypatch): - module = _dataset_sdk_routes_unit_module() - module.test_delete_route_error_summary_matrix_unit(monkeypatch) - - -@pytest.mark.p3 -def test_tags_and_meta_branches(monkeypatch): - module = _load_kb_module(monkeypatch) - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False) - res = inspect.unwrap(module.list_tags)("kb-1") - assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.UserTenantService, "get_tenants_by_user_id", lambda _uid: [{"tenant_id": "tenant-1"}, {"tenant_id": "tenant-2"}]) - monkeypatch.setattr(module.settings, "retriever", SimpleNamespace(all_tags=lambda tenant_id, kb_ids: [f"{tenant_id}:{kb_ids[0]}"])) - res = inspect.unwrap(module.list_tags)("kb-1") - assert res["code"] == module.RetCode.SUCCESS, res - assert len(res["data"]) == 2, res - - _set_request_args(monkeypatch, module, {"kb_ids": "kb-1,kb-2"}) - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda kb_id, _uid: kb_id == "kb-1") - res = inspect.unwrap(module.list_tags_from_kbs)() - assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) - res = inspect.unwrap(module.list_tags_from_kbs)() - assert res["code"] == module.RetCode.SUCCESS, res - assert isinstance(res["data"], list), res - - _set_request_json(monkeypatch, module, {"tags": ["a", "b"]}) - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False) - res = _run(inspect.unwrap(module.rm_tags)("kb-1")) - assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB(tenant_id="tenant-1"))) - monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(update=lambda *_args, **_kwargs: True)) - monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx") - res = _run(inspect.unwrap(module.rm_tags)("kb-1")) - assert res["code"] == module.RetCode.SUCCESS, res - - _set_request_json(monkeypatch, module, {"from_tag": "a", "to_tag": "b"}) - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False) - res = _run(inspect.unwrap(module.rename_tags)("kb-1")) - assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) - res = _run(inspect.unwrap(module.rename_tags)("kb-1")) - assert res["code"] == module.RetCode.SUCCESS, res - - _set_request_args(monkeypatch, module, {"kb_ids": "kb-1,kb-2"}) - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda kb_id, _uid: kb_id == "kb-1") - res = inspect.unwrap(module.get_meta)() - assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: {"source": ["a"]}) - res = inspect.unwrap(module.get_meta)() - assert res["code"] == module.RetCode.SUCCESS, res - assert "source" in res["data"], res - - _set_request_args(monkeypatch, module, {"kb_id": "kb-1"}) - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False) - res = inspect.unwrap(module.get_basic_info)() - assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res - - monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) - monkeypatch.setattr(module.DocumentService, "knowledgebase_basic_info", lambda _kb_id: {"finished": 1}) - res = inspect.unwrap(module.get_basic_info)() - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"]["finished"] == 1, res - - -@pytest.mark.p3 -def test_knowledge_graph_branches(monkeypatch): - module = _dataset_sdk_routes_unit_module() - module.test_list_knowledge_graph_delete_kg_matrix_unit(monkeypatch) - - -@pytest.mark.p3 -def test_list_pipeline_logs_validation_branches(monkeypatch): - module = _load_kb_module(monkeypatch) - - _set_request_args(monkeypatch, module, {}) - _set_request_json(monkeypatch, module, {}) - res = _run(inspect.unwrap(module.list_pipeline_logs)()) - assert res["code"] == module.RetCode.ARGUMENT_ERROR, res - assert "KB ID" in res["message"], res - - _set_request_args( - monkeypatch, - module, - { - "kb_id": "kb-1", - "keywords": "k", - "page": "1", - "page_size": "10", - "orderby": "create_time", - "desc": "false", - "create_date_from": "2025-02-01", - "create_date_to": "2025-01-01", - }, - ) - _set_request_json(monkeypatch, module, {}) - monkeypatch.setattr(module.PipelineOperationLogService, "get_file_logs_by_kb_id", lambda *_args, **_kwargs: ([], 0)) - res = _run(inspect.unwrap(module.list_pipeline_logs)()) - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"]["total"] == 0, res - - _set_request_args( - monkeypatch, - module, - { - "kb_id": "kb-1", - "create_date_from": "2025-01-01", - "create_date_to": "2025-02-01", - }, - ) - _set_request_json(monkeypatch, module, {}) - res = _run(inspect.unwrap(module.list_pipeline_logs)()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Create data filter is abnormal." in res["message"], res - - -@pytest.mark.p3 -def test_list_pipeline_logs_filter_and_exception_branches(monkeypatch): - module = _load_kb_module(monkeypatch) - - _set_request_args( - monkeypatch, - module, - { - "kb_id": "kb-1", - "page": "1", - "page_size": "10", - "desc": "false", - "create_date_from": "2025-02-01", - "create_date_to": "2025-01-01", - }, - ) - - _set_request_json(monkeypatch, module, {"operation_status": ["BAD_STATUS"]}) - res = _run(inspect.unwrap(module.list_pipeline_logs)()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "operation_status" in res["message"], res - - _set_request_json(monkeypatch, module, {"types": ["bad_type"]}) - res = _run(inspect.unwrap(module.list_pipeline_logs)()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Invalid filter conditions" in res["message"], res - - def _raise_file_logs(*_args, **_kwargs): - raise RuntimeError("logs boom") - - _set_request_json(monkeypatch, module, {"suffix": [".txt"]}) - monkeypatch.setattr(module.PipelineOperationLogService, "get_file_logs_by_kb_id", _raise_file_logs) - res = _run(inspect.unwrap(module.list_pipeline_logs)()) - assert res["code"] == module.RetCode.EXCEPTION_ERROR, res - assert "logs boom" in res["message"], res - - -@pytest.mark.p3 -def test_list_pipeline_dataset_logs_branches(monkeypatch): - module = _load_kb_module(monkeypatch) - - _set_request_args(monkeypatch, module, {}) - _set_request_json(monkeypatch, module, {}) - res = _run(inspect.unwrap(module.list_pipeline_dataset_logs)()) - assert res["code"] == module.RetCode.ARGUMENT_ERROR, res - assert "KB ID" in res["message"], res - - _set_request_args( - monkeypatch, - module, - { - "kb_id": "kb-1", - "desc": "false", - "create_date_from": "2025-01-01", - "create_date_to": "2025-02-01", - }, - ) - _set_request_json(monkeypatch, module, {}) - res = _run(inspect.unwrap(module.list_pipeline_dataset_logs)()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Create data filter is abnormal." in res["message"], res - - _set_request_args( - monkeypatch, - module, - { - "kb_id": "kb-1", - "page": "1", - "page_size": "10", - "desc": "false", - "create_date_from": "2025-02-01", - "create_date_to": "2025-01-01", - }, - ) - _set_request_json(monkeypatch, module, {"operation_status": ["NOT_A_STATUS"]}) - res = _run(inspect.unwrap(module.list_pipeline_dataset_logs)()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "operation_status" in res["message"], res - - _set_request_args( - monkeypatch, - module, - { - "kb_id": "kb-1", - "page": "1", - "page_size": "10", - "desc": "true", - "create_date_from": "2025-02-01", - "create_date_to": "2025-01-01", - }, - ) - _set_request_json(monkeypatch, module, {"operation_status": []}) - monkeypatch.setattr( - module.PipelineOperationLogService, - "get_dataset_logs_by_kb_id", - lambda *_args, **_kwargs: ([{"id": "l1"}], 1), - ) - res = _run(inspect.unwrap(module.list_pipeline_dataset_logs)()) - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"]["total"] == 1, res - assert res["data"]["logs"][0]["id"] == "l1", res - - def _raise_dataset_logs(*_args, **_kwargs): - raise RuntimeError("dataset logs boom") - - monkeypatch.setattr(module.PipelineOperationLogService, "get_dataset_logs_by_kb_id", _raise_dataset_logs) - res = _run(inspect.unwrap(module.list_pipeline_dataset_logs)()) - assert res["code"] == module.RetCode.EXCEPTION_ERROR, res - assert "dataset logs boom" in res["message"], res - - -@pytest.mark.p3 -def test_pipeline_log_detail_and_delete_routes_branches(monkeypatch): - module = _load_kb_module(monkeypatch) - - _set_request_args(monkeypatch, module, {}) - _set_request_json(monkeypatch, module, {}) - res = _run(inspect.unwrap(module.delete_pipeline_logs)()) - assert res["code"] == module.RetCode.ARGUMENT_ERROR, res - assert "KB ID" in res["message"], res - - deleted_ids = [] - - def _delete_by_ids(log_ids): - deleted_ids.extend(log_ids) - - monkeypatch.setattr(module.PipelineOperationLogService, "delete_by_ids", _delete_by_ids) - _set_request_args(monkeypatch, module, {"kb_id": "kb-1"}) - _set_request_json(monkeypatch, module, {}) - res = _run(inspect.unwrap(module.delete_pipeline_logs)()) - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"] is True, res - assert deleted_ids == [], deleted_ids - - _set_request_json(monkeypatch, module, {"log_ids": ["l1", "l2"]}) - res = _run(inspect.unwrap(module.delete_pipeline_logs)()) - assert res["code"] == module.RetCode.SUCCESS, res - assert deleted_ids == ["l1", "l2"], deleted_ids - - _set_request_args(monkeypatch, module, {}) - res = inspect.unwrap(module.pipeline_log_detail)() - assert res["code"] == module.RetCode.ARGUMENT_ERROR, res - assert "Pipeline log ID" in res["message"], res - - _set_request_args(monkeypatch, module, {"log_id": "missing"}) - monkeypatch.setattr(module.PipelineOperationLogService, "get_by_id", lambda _log_id: (False, None)) - res = inspect.unwrap(module.pipeline_log_detail)() - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Invalid pipeline log ID" in res["message"], res - - class _Log: - def to_dict(self): - return {"id": "log-1", "status": "ok"} - - monkeypatch.setattr(module.PipelineOperationLogService, "get_by_id", lambda _log_id: (True, _Log())) - res = inspect.unwrap(module.pipeline_log_detail)() - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"]["id"] == "log-1", res - - -@pytest.mark.p3 -@pytest.mark.parametrize( - "route_name,task_attr,response_key,task_type", - [ - ("run_graphrag", "graphrag_task_id", "graphrag_task_id", "graphrag"), - ("run_raptor", "raptor_task_id", "raptor_task_id", "raptor"), - ("run_mindmap", "mindmap_task_id", "mindmap_task_id", "mindmap"), - ], -) -def test_run_pipeline_task_routes_branch_matrix(monkeypatch, route_name, task_attr, response_key, task_type): - if route_name in {"run_graphrag", "run_raptor"}: - module = _dataset_sdk_routes_unit_module() - if route_name == "run_graphrag": - module.test_run_trace_graphrag_matrix_unit(monkeypatch) - else: - module.test_run_trace_raptor_matrix_unit(monkeypatch) - return - - module = _load_kb_module(monkeypatch) - route = inspect.unwrap(getattr(module, route_name)) - - def _make_kb(task_id): - payload = { - "id": "kb-1", - "tenant_id": "tenant-1", - "graphrag_task_id": "", - "raptor_task_id": "", - "mindmap_task_id": "", - } - payload[task_attr] = task_id - return SimpleNamespace(**payload) - - warnings = [] - monkeypatch.setattr(module.logging, "warning", lambda msg, *_args, **_kwargs: warnings.append(msg)) - - _set_request_json(monkeypatch, module, {"kb_id": ""}) - res = _run(route()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "KB ID" in res["message"], res - - _set_request_json(monkeypatch, module, {"kb_id": "kb-1"}) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None)) - res = _run(route()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Invalid Knowledgebase ID" in res["message"], res - - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _make_kb("task-running"))) - monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(progress=0))) - res = _run(route()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "already running" in res["message"], res - - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _make_kb("task-stale"))) - monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None)) - monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([], 0)) - res = _run(route()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "No documents in Knowledgebase kb-1" in res["message"], res - assert warnings, "Expected warning for stale task id" - - queue_calls = {} - - def _queue_stub(**kwargs): - queue_calls.update(kwargs) - return "queued-task-id" - - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _make_kb(""))) - monkeypatch.setattr( - module.DocumentService, - "get_by_kb_id", - lambda **_kwargs: ([{"id": "doc-1"}, {"id": "doc-2"}], 2), - ) - monkeypatch.setattr(module, "queue_raptor_o_graphrag_tasks", _queue_stub) - monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: False) - res = _run(route()) - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"][response_key] == "queued-task-id", res - assert queue_calls["ty"] == task_type, queue_calls - assert queue_calls["doc_ids"] == ["doc-1", "doc-2"], queue_calls - - -@pytest.mark.p3 -@pytest.mark.parametrize( - "route_name,task_attr,empty_on_missing_task,error_text", - [ - ("trace_graphrag", "graphrag_task_id", True, ""), - ("trace_raptor", "raptor_task_id", False, "RAPTOR Task Not Found or Error Occurred"), - ("trace_mindmap", "mindmap_task_id", False, "Mindmap Task Not Found or Error Occurred"), - ], -) -def test_trace_pipeline_task_routes_branch_matrix(monkeypatch, route_name, task_attr, empty_on_missing_task, error_text): - if route_name in {"trace_graphrag", "trace_raptor"}: - module = _dataset_sdk_routes_unit_module() - if route_name == "trace_graphrag": - module.test_run_trace_graphrag_matrix_unit(monkeypatch) - else: - module.test_run_trace_raptor_matrix_unit(monkeypatch) - return - - module = _load_kb_module(monkeypatch) - route = inspect.unwrap(getattr(module, route_name)) - - def _make_kb(task_id): - payload = { - "id": "kb-1", - "tenant_id": "tenant-1", - "graphrag_task_id": "", - "raptor_task_id": "", - "mindmap_task_id": "", - } - payload[task_attr] = task_id - return SimpleNamespace(**payload) - - _set_request_args(monkeypatch, module, {"kb_id": ""}) - res = route() - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "KB ID" in res["message"], res - - _set_request_args(monkeypatch, module, {"kb_id": "kb-1"}) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None)) - res = route() - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Invalid Knowledgebase ID" in res["message"], res - - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _make_kb(""))) - res = route() - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"] == {}, res - - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _make_kb("task-1"))) - monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None)) - res = route() - if empty_on_missing_task: - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"] == {}, res - else: - assert res["code"] == module.RetCode.DATA_ERROR, res - assert error_text in res["message"], res - - monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, _DummyTask("task-1", 1))) - res = route() - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"]["id"] == "task-1", res - - -@pytest.mark.p3 -def test_unbind_task_branch_matrix(monkeypatch): - module = _load_kb_module(monkeypatch) - route = inspect.unwrap(module.delete_kb_task) - - _set_request_args(monkeypatch, module, {"kb_id": ""}) - res = route() - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "KB ID" in res["message"], res - - _set_request_args(monkeypatch, module, {"kb_id": "missing", "pipeline_task_type": module.PipelineTaskType.GRAPH_RAG}) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None)) - res = route() - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"] is True, res - - kb = SimpleNamespace( - id="kb-1", - tenant_id="tenant-1", - graphrag_task_id="graph-task", - raptor_task_id="raptor-task", - mindmap_task_id="mindmap-task", - ) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb)) - _set_request_args(monkeypatch, module, {"kb_id": "kb-1", "pipeline_task_type": "unknown"}) - res = route() - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Invalid task type" in res["message"], res - - cancelled = [] - deleted = [] - update_payloads = [] - monkeypatch.setattr(module.REDIS_CONN, "set", lambda key, value: cancelled.append((key, value))) - monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx") - monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(delete=lambda *args, **_kwargs: deleted.append(args))) - - def _record_update(_kb_id, payload): - update_payloads.append((_kb_id, payload)) - return True - - monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", _record_update) - - _set_request_args(monkeypatch, module, {"kb_id": "kb-1", "pipeline_task_type": module.PipelineTaskType.GRAPH_RAG}) - res = route() - assert res["code"] == module.RetCode.SUCCESS, res - - _set_request_args(monkeypatch, module, {"kb_id": "kb-1", "pipeline_task_type": module.PipelineTaskType.RAPTOR}) - res = route() - assert res["code"] == module.RetCode.SUCCESS, res - - _set_request_args(monkeypatch, module, {"kb_id": "kb-1", "pipeline_task_type": module.PipelineTaskType.MINDMAP}) - res = route() - assert res["code"] == module.RetCode.SUCCESS, res - - assert ("graph-task-cancel", "x") in cancelled, cancelled - assert ("raptor-task-cancel", "x") in cancelled, cancelled - assert ("mindmap-task-cancel", "x") in cancelled, cancelled - assert len(deleted) == 2, deleted - assert any(payload.get("graphrag_task_id") == "" for _, payload in update_payloads), update_payloads - assert any(payload.get("raptor_task_id") == "" for _, payload in update_payloads), update_payloads - assert any(payload.get("mindmap_task_id") == "" for _, payload in update_payloads), update_payloads - - class _FlakyPipelineType: - def __init__(self, target): - self.target = target - self.calls = 0 - - def __eq__(self, other): - self.calls += 1 - if self.calls == 1: - return other == self.target - return False - - _set_request_args( - monkeypatch, - module, - {"kb_id": "kb-1", "pipeline_task_type": _FlakyPipelineType(module.PipelineTaskType.GRAPH_RAG)}, - ) - res = route() - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Internal Error: Invalid task type" in res["message"], res - - monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: False) - monkeypatch.setattr(module, "server_error_response", lambda e: module.get_json_result(code=module.RetCode.EXCEPTION_ERROR, message=str(e))) - _set_request_args(monkeypatch, module, {"kb_id": "kb-1", "pipeline_task_type": module.PipelineTaskType.GRAPH_RAG}) - res = route() - assert res["code"] == module.RetCode.EXCEPTION_ERROR, res - assert "cannot delete task" in res["message"], res - - -@pytest.mark.p3 -def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch): - module = _load_kb_module(monkeypatch) - route = inspect.unwrap(module.check_embedding) - monkeypatch.setattr( - module, - "get_model_config_by_type_and_name", - lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "emb-1", "model_type": module.LLMType.EMBEDDING.value}, - ) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-1"))) - monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx") - - class _FlipBool: - def __init__(self): - self._calls = 0 - - def __bool__(self): - self._calls += 1 - return self._calls == 1 - - monkeypatch.setattr( - module.re, - "sub", - lambda _pattern, _repl, text: _FlipBool() if "TRIGGER_NO_TEXT" in str(text) else text, - ) - - def _fixed_sample(population, k): - return list(population)[:k] - - monkeypatch.setattr(module.random, "sample", _fixed_sample) - - class _DocStore: - def __init__(self, total, ids_by_offset, docs): - self.total = total - self.ids_by_offset = ids_by_offset - self.docs = docs - - def search(self, select_fields, **kwargs): - if not select_fields: - return {"kind": "total"} - return {"kind": "sample", "offset": kwargs["offset"]} - - def get_total(self, _res): - return self.total - - def get_doc_ids(self, res): - return self.ids_by_offset.get(res.get("offset", -1), []) - - def get(self, cid, _index_name, _kb_ids): - return self.docs.get(cid, {}) - - class _EmbModel: - def __init__(self): - self.calls = [] - - def encode(self, pair): - title, _txt = pair - self.calls.append(title) - if title == "Doc Mix": - # title+content mix wins over content only path. - return [module.np.array([1.0, 0.0]), module.np.array([0.0, 1.0])], None - if title == "Doc High": - return [module.np.array([1.0, 0.0]), module.np.array([1.0, 0.0])], None - return [module.np.array([0.0, 1.0]), module.np.array([0.0, 1.0])], None - - emb_model = _EmbModel() - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: emb_model) - - low_docs = { - "chunk-no-vec": { - "doc_id": "doc-no-vec", - "docnm_kwd": "Doc No Vec", - "content_with_weight": "body-no-vec", - "page_num_int": 1, - "position_int": 1, - "top_int": 1, - }, - "chunk-bad-type": { - "doc_id": "doc-bad-type", - "docnm_kwd": "Doc Bad Type", - "content_with_weight": "body-bad-type", - "question_kwd": [], - "q_vec": {"bad": "type"}, - "page_num_int": 1, - "position_int": 2, - "top_int": 2, - }, - "chunk-low-zero": { - "doc_id": "doc-low-zero", - "docnm_kwd": "Doc Low Zero", - "content_with_weight": "body-low", - "question_kwd": [], - "q_vec": "0\t0", - "page_num_int": 1, - "position_int": 3, - "top_int": 3, - }, - "chunk-no-text": { - "doc_id": "doc-no-text", - "docnm_kwd": "Doc No Text", - "content_with_weight": "TRIGGER_NO_TEXT", - "q_vec": [1.0, 0.0], - "page_num_int": 1, - "position_int": 4, - "top_int": 4, - }, - "chunk-mix": { - "doc_id": "doc-mix", - "docnm_kwd": "Doc Mix", - "content_with_weight": "body-mix", - "q_vec": [1.0, 0.0], - "page_num_int": 1, - "position_int": 5, - "top_int": 5, - }, - } - - monkeypatch.setattr( - module.settings, - "docStoreConn", - _DocStore( - total=6, - ids_by_offset={ - 0: [], - 1: ["chunk-no-vec"], - 2: ["chunk-bad-type"], - 3: ["chunk-low-zero"], - 4: ["chunk-no-text"], - 5: ["chunk-mix"], - }, - docs=low_docs, - ), - ) - - _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 6}) - res = _run(route()) - assert res["code"] == module.RetCode.NOT_EFFECTIVE, res - assert "average similarity" in res["message"], res - summary = res["data"]["summary"] - assert summary["sampled"] == 5, summary - assert summary["valid"] == 2, summary - reasons = {item.get("reason") for item in res["data"]["results"] if "reason" in item} - assert "no_stored_vector" in reasons, res - assert "no_text" in reasons, res - assert any(item.get("chunk_id") == "chunk-low-zero" and "cos_sim" in item for item in res["data"]["results"]), res - assert summary["match_mode"] in {"content_only", "title+content"}, summary - - high_docs = { - "chunk-high": { - "doc_id": "doc-high", - "docnm_kwd": "Doc High", - "content_with_weight": "body-high", - "q_vec": [1.0, 0.0], - "page_num_int": 1, - "position_int": 1, - "top_int": 1, - } - } - monkeypatch.setattr( - module.settings, - "docStoreConn", - _DocStore(total=1, ids_by_offset={0: ["chunk-high"]}, docs=high_docs), - ) - _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 1}) - res = _run(route()) - assert res["code"] == module.RetCode.SUCCESS, res - assert res["data"]["summary"]["avg_cos_sim"] > 0.9, res - - -@pytest.mark.p3 -def test_check_embedding_error_and_empty_sample_paths_unit(monkeypatch): - module = _load_kb_module(monkeypatch) - route = inspect.unwrap(module.check_embedding) - monkeypatch.setattr( - module, - "get_model_config_by_type_and_name", - lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "emb-1", "model_type": module.LLMType.EMBEDDING.value}, - ) - monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-1"))) - monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx") - monkeypatch.setattr(module.random, "sample", lambda population, k: list(population)[:k]) - - class _DocStore: - def __init__(self, total, ids_by_offset, docs): - self.total = total - self.ids_by_offset = ids_by_offset - self.docs = docs - - def search(self, select_fields, **kwargs): - if not select_fields: - return {"kind": "total"} - return {"kind": "sample", "offset": kwargs["offset"]} - - def get_total(self, _res): - return self.total - - def get_doc_ids(self, res): - return self.ids_by_offset.get(res.get("offset", -1), []) - - def get(self, cid, _index_name, _kb_ids): - return self.docs.get(cid, {}) - - class _BoomEmbModel: - def encode(self, _pair): - raise RuntimeError("encode boom") - - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _BoomEmbModel()) - monkeypatch.setattr( - module.settings, - "docStoreConn", - _DocStore( - total=1, - ids_by_offset={0: ["chunk-err"]}, - docs={ - "chunk-err": { - "doc_id": "doc-err", - "docnm_kwd": "Doc Err", - "content_with_weight": "body-err", - "q_vec": [1.0, 0.0], - "page_num_int": 1, - "position_int": 1, - "top_int": 1, - } - }, - ), - ) - _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 1}) - res = _run(route()) - assert res["code"] == module.RetCode.DATA_ERROR, res - assert "Embedding failure." in res["message"], res - assert "encode boom" in res["message"], res - - class _OkEmbModel: - def encode(self, _pair): - return [module.np.array([1.0, 0.0]), module.np.array([1.0, 0.0])], None - - monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _OkEmbModel()) - monkeypatch.setattr(module.settings, "docStoreConn", _DocStore(total=0, ids_by_offset={}, docs={})) - _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 1}) - with pytest.raises(UnboundLocalError): - _run(route()) diff --git a/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py b/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py deleted file mode 100644 index aed597e24b..0000000000 --- a/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py +++ /dev/null @@ -1,296 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import uuid - -import pytest -from test_common import ( - delete_knowledge_graph, - kb_basic_info, - kb_get_meta, - kb_update_metadata_setting, - knowledge_graph, - list_tags, - list_tags_from_kbs, - rename_tags, - rm_tags, - update_chunk, -) -from configs import INVALID_API_TOKEN -from libs.auth import RAGFlowWebApiAuth -from utils import wait_for - -INVALID_AUTH_CASES = [ - (None, 401, "Unauthorized"), - (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "Unauthorized"), -] - -TAG_SEED_TIMEOUT = 20 - - -def _wait_for_tag(auth, kb_id, tag, timeout=TAG_SEED_TIMEOUT): - @wait_for(timeout, 1, "Tag seed timeout") - def _condition(): - res = list_tags(auth, kb_id) - if res["code"] != 0: - return False - return tag in res["data"] - - try: - _condition() - except AssertionError: - return False - return True - - -def _seed_tag(auth, kb_id, document_id, chunk_id): - # KB tags are derived from chunk tag_kwd, not document metadata. - tag = f"tag_{uuid.uuid4().hex[:8]}" - res = update_chunk( - auth, - kb_id, - document_id, - chunk_id, - { - "content": f"tag seed {tag}", - "tag_kwd": [tag], - }, - ) - assert res["code"] == 0, res - if not _wait_for_tag(auth, kb_id, tag): - return None - return tag - - -class TestAuthorization: - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_list_tags_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = list_tags(invalid_auth, "kb_id") - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_list_tags_from_kbs_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = list_tags_from_kbs(invalid_auth, {"kb_ids": "kb_id"}) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_rm_tags_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = rm_tags(invalid_auth, "kb_id", {"tags": ["tag"]}) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_rename_tag_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = rename_tags(invalid_auth, "kb_id", {"from_tag": "old", "to_tag": "new"}) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_get_meta_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = kb_get_meta(invalid_auth, {"kb_ids": "kb_id"}) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_basic_info_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = kb_basic_info(invalid_auth, {"kb_id": "kb_id"}) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_update_metadata_setting_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = kb_update_metadata_setting(invalid_auth, {"kb_id": "kb_id", "metadata": {}}) - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_knowledge_graph_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = knowledge_graph(invalid_auth, "kb_id") - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - @pytest.mark.p2 - @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - def test_delete_knowledge_graph_auth_invalid(self, invalid_auth, expected_code, expected_fragment): - res = delete_knowledge_graph(invalid_auth, "kb_id") - assert res["code"] == expected_code, res - assert expected_fragment in res["message"], res - - -class TestKbTagsMeta: - @pytest.mark.p2 - def test_list_tags(self, WebApiAuth, add_dataset): - kb_id = add_dataset - res = list_tags(WebApiAuth, kb_id) - assert res["code"] == 0, res - assert isinstance(res["data"], list), res - - @pytest.mark.p2 - def test_list_tags_from_kbs(self, WebApiAuth, add_dataset): - kb_id = add_dataset - res = list_tags_from_kbs(WebApiAuth, {"kb_ids": kb_id}) - assert res["code"] == 0, res - assert isinstance(res["data"], list), res - - @pytest.mark.p3 - def test_rm_tags(self, WebApiAuth, add_chunks): - kb_id, document_id, chunk_ids = add_chunks - tag_to_remove = _seed_tag(WebApiAuth, kb_id, document_id, chunk_ids[0]) - if not tag_to_remove: - # Tag aggregation is index-backed; skip if it never surfaces. - pytest.skip("Seeded tag did not appear in list_tags.") - - res = rm_tags(WebApiAuth, kb_id, {"tags": [tag_to_remove]}) - assert res["code"] == 0, res - assert res["data"] is True, res - - @wait_for(TAG_SEED_TIMEOUT, 1, "Tag removal timeout") - def _condition(): - after_res = list_tags(WebApiAuth, kb_id) - if after_res["code"] != 0: - return False - return tag_to_remove not in after_res["data"] - - _condition() - - @pytest.mark.p3 - def test_rename_tag(self, WebApiAuth, add_chunks): - kb_id, document_id, chunk_ids = add_chunks - from_tag = _seed_tag(WebApiAuth, kb_id, document_id, chunk_ids[0]) - if not from_tag: - # Tag aggregation is index-backed; skip if it never surfaces. - pytest.skip("Seeded tag did not appear in list_tags.") - - to_tag = f"{from_tag}_renamed" - res = rename_tags(WebApiAuth, kb_id, {"from_tag": from_tag, "to_tag": to_tag}) - assert res["code"] == 0, res - assert res["data"] is True, res - - @wait_for(TAG_SEED_TIMEOUT, 1, "Tag rename timeout") - def _condition(): - after_res = list_tags(WebApiAuth, kb_id) - if after_res["code"] != 0: - return False - tags = after_res["data"] - return to_tag in tags and from_tag not in tags - - _condition() - - @pytest.mark.p2 - def test_get_meta(self, WebApiAuth, add_dataset): - kb_id = add_dataset - res = kb_get_meta(WebApiAuth, {"kb_ids": kb_id}) - assert res["code"] == 0, res - assert isinstance(res["data"], dict), res - - @pytest.mark.p2 - def test_basic_info(self, WebApiAuth, add_dataset): - kb_id = add_dataset - res = kb_basic_info(WebApiAuth, {"kb_id": kb_id}) - assert res["code"] == 0, res - for key in ["processing", "finished", "failed", "cancelled", "downloaded"]: - assert key in res["data"], res - - @pytest.mark.p2 - def test_update_metadata_setting(self, WebApiAuth, add_dataset): - kb_id = add_dataset - metadata = {"source": "test"} - res = kb_update_metadata_setting(WebApiAuth, {"kb_id": kb_id, "metadata": metadata, "enable_metadata": True}) - assert res["code"] == 0, res - assert res["data"]["id"] == kb_id, res - assert res["data"]["parser_config"]["metadata"] == metadata, res - - @pytest.mark.p2 - def test_knowledge_graph(self, WebApiAuth, add_dataset): - kb_id = add_dataset - res = knowledge_graph(WebApiAuth, kb_id) - assert res["code"] == 0, res - assert isinstance(res["data"], dict), res - assert "graph" in res["data"], res - assert "mind_map" in res["data"], res - - @pytest.mark.p2 - def test_delete_knowledge_graph(self, WebApiAuth, add_dataset): - kb_id = add_dataset - res = delete_knowledge_graph(WebApiAuth, kb_id) - assert res["code"] == 0, res - assert res["data"] is True, res - - -class TestKbTagsMetaNegative: - @pytest.mark.p3 - def test_list_tags_invalid_kb(self, WebApiAuth): - res = list_tags(WebApiAuth, "invalid_kb_id") - assert res["code"] == 109, res - assert "No authorization" in res["message"], res - - @pytest.mark.p3 - def test_list_tags_from_kbs_invalid_kb(self, WebApiAuth): - res = list_tags_from_kbs(WebApiAuth, {"kb_ids": "invalid_kb_id"}) - assert res["code"] == 109, res - assert "No authorization" in res["message"], res - - @pytest.mark.p3 - def test_rm_tags_invalid_kb(self, WebApiAuth): - res = rm_tags(WebApiAuth, "invalid_kb_id", {"tags": ["tag"]}) - assert res["code"] == 109, res - assert "No authorization" in res["message"], res - - @pytest.mark.p3 - def test_rename_tag_invalid_kb(self, WebApiAuth): - res = rename_tags(WebApiAuth, "invalid_kb_id", {"from_tag": "old", "to_tag": "new"}) - assert res["code"] == 109, res - assert "No authorization" in res["message"], res - - @pytest.mark.p3 - def test_get_meta_invalid_kb(self, WebApiAuth): - res = kb_get_meta(WebApiAuth, {"kb_ids": "invalid_kb_id"}) - assert res["code"] == 109, res - assert "No authorization" in res["message"], res - - @pytest.mark.p3 - def test_basic_info_invalid_kb(self, WebApiAuth): - res = kb_basic_info(WebApiAuth, {"kb_id": "invalid_kb_id"}) - assert res["code"] == 109, res - assert "No authorization" in res["message"], res - - @pytest.mark.p3 - def test_update_metadata_setting_missing_metadata(self, WebApiAuth, add_dataset): - res = kb_update_metadata_setting(WebApiAuth, {"kb_id": add_dataset}) - assert res["code"] == 101, res - assert "required argument are missing" in res["message"], res - assert "metadata" in res["message"], res - - @pytest.mark.p3 - def test_knowledge_graph_invalid_kb(self, WebApiAuth): - res = knowledge_graph(WebApiAuth, "invalid_kb_id") - assert res["code"] == 109, res - assert "No authorization" in res["message"], res - - @pytest.mark.p3 - def test_delete_knowledge_graph_invalid_kb(self, WebApiAuth): - res = delete_knowledge_graph(WebApiAuth, "invalid_kb_id") - assert res["code"] == 109, res - assert "No authorization" in res["message"], res diff --git a/test/testcases/test_web_api/test_kb_app/test_list_kbs.py b/test/testcases/test_web_api/test_kb_app/test_list_kbs.py deleted file mode 100644 index 0aeebf0c8c..0000000000 --- a/test/testcases/test_web_api/test_kb_app/test_list_kbs.py +++ /dev/null @@ -1,201 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import json -from concurrent.futures import ThreadPoolExecutor, as_completed - -import pytest -from test_common import list_datasets -from configs import INVALID_API_TOKEN -from libs.auth import RAGFlowWebApiAuth -from utils import is_sorted - - -class TestAuthorization: - @pytest.mark.p2 - @pytest.mark.parametrize( - "invalid_auth, expected_code, expected_message", - [ - (None, 401, ""), - (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), - ], - ) - def test_auth_invalid(self, invalid_auth, expected_code, expected_message): - res = list_datasets(invalid_auth) - assert res["code"] == expected_code, res - assert res["message"] == expected_message, res - - -class TestCapability: - @pytest.mark.p3 - def test_concurrent_list(self, WebApiAuth): - count = 100 - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(list_datasets, WebApiAuth) for i in range(count)] - responses = list(as_completed(futures)) - assert len(responses) == count, responses - assert all(future.result()["code"] == 0 for future in futures) - - -@pytest.mark.usefixtures("add_datasets") -class TestDatasetsList: - @pytest.mark.p2 - def test_params_unset(self, WebApiAuth): - res = list_datasets(WebApiAuth, None) - assert res["code"] == 0, res - assert len(res["data"]) == 5, res - - @pytest.mark.p2 - def test_params_empty(self, WebApiAuth): - res = list_datasets(WebApiAuth, {}) - assert res["code"] == 0, res - assert len(res["data"]) == 5, res - - @pytest.mark.p1 - @pytest.mark.parametrize( - "params, expected_page_size", - [ - ({"page": 2, "page_size": 2}, 2), - ({"page": 3, "page_size": 2}, 1), - ({"page": 4, "page_size": 2}, 0), - ({"page": "2", "page_size": 2}, 2), - ({"page": 1, "page_size": 10}, 5), - ], - ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"], - ) - def test_page(self, WebApiAuth, params, expected_page_size): - res = list_datasets(WebApiAuth, params) - assert res["code"] == 0, res - assert len(res["data"]) == expected_page_size, res - - @pytest.mark.skip - @pytest.mark.p2 - @pytest.mark.parametrize( - "params, expected_code, expected_message", - [ - ({"page": 0}, 101, "Input should be greater than or equal to 1"), - ({"page": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"), - ], - ids=["page_0", "page_a"], - ) - def test_page_invalid(self, WebApiAuth, params, expected_code, expected_message): - res = list_datasets(WebApiAuth, params=params) - assert res["code"] == expected_code, res - assert expected_message in res["message"], res - - @pytest.mark.p2 - def test_page_none(self, WebApiAuth): - params = {"page": None} - res = list_datasets(WebApiAuth, params) - assert res["code"] == 0, res - assert len(res["data"]) == 5, res - - @pytest.mark.p1 - @pytest.mark.parametrize( - "params, expected_page_size", - [ - ({"page": 1, "page_size": 1}, 1), - ({"page": 1, "page_size": 3}, 3), - ({"page": 1, "page_size": 5}, 5), - ({"page": 1, "page_size": 6}, 5), - ({"page": 1, "page_size": "1"}, 1), - ], - ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"], - ) - def test_page_size(self, WebApiAuth, params, expected_page_size): - res = list_datasets(WebApiAuth, params) - assert res["code"] == 0, res - assert len(res["data"]) == expected_page_size, res - - @pytest.mark.skip - @pytest.mark.p2 - @pytest.mark.parametrize( - "params, expected_code, expected_message", - [ - ({"page_size": 0}, 101, "Input should be greater than or equal to 1"), - ({"page_size": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"), - ], - ) - def test_page_size_invalid(self, WebApiAuth, params, expected_code, expected_message): - res = list_datasets(WebApiAuth, params) - assert res["code"] == expected_code, res - assert expected_message in res["message"], res - - @pytest.mark.p2 - def test_page_size_none(self, WebApiAuth): - params = {"page_size": None} - res = list_datasets(WebApiAuth, params) - assert res["code"] == 0, res - assert len(res["data"]) == 5, res - - @pytest.mark.p3 - @pytest.mark.parametrize( - "params, assertions", - [ - ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))), - ], - ids=["orderby_update_time"], - ) - def test_orderby(self, WebApiAuth, params, assertions): - res = list_datasets(WebApiAuth, params) - assert res["code"] == 0, res - if callable(assertions): - assert assertions(res), res - - @pytest.mark.p3 - @pytest.mark.parametrize( - "params, assertions", - [ - ({"desc": "True"}, lambda r: (is_sorted(r["data"], "update_time", True))), - ({"desc": "False"}, lambda r: (is_sorted(r["data"], "update_time", False))), - ], - ids=["desc=True", "desc=False"], - ) - def test_desc(self, WebApiAuth, params, assertions): - res = list_datasets(WebApiAuth, params) - - assert res["code"] == 0, res - if callable(assertions): - assert assertions(res), res - - @pytest.mark.p2 - @pytest.mark.parametrize( - "params, expected_page_size", - [ - ({"ext": json.dumps({"parser_id": "naive"})}, 5), - ({"ext": json.dumps({"parser_id": "qa"})}, 0), - ], - ids=["naive", "dqa"], - ) - def test_parser_id(self, WebApiAuth, params, expected_page_size): - res = list_datasets(WebApiAuth, params) - assert res["code"] == 0, res - assert len(res["data"]) == expected_page_size, res - - @pytest.mark.p2 - def test_owner_ids_payload_mode(self, WebApiAuth): - base_res = list_datasets(WebApiAuth, {"page_size": 10}) - assert base_res["code"] == 0, base_res - assert base_res["data"], base_res - owner_id = base_res["data"][0]["tenant_id"] - - res = list_datasets( - WebApiAuth, - params={"page": 1, "page_size": 2, "desc": "false", "ext": json.dumps({"owner_ids": [owner_id]})}, - ) - assert res["code"] == 0, res - assert res["total_datasets"] >= len(res["data"]), res - assert len(res["data"]) <= 2, res - assert all(kb["tenant_id"] == owner_id for kb in res["data"]), res diff --git a/test/testcases/test_web_api/test_kb_app/test_rm_kb.py b/test/testcases/test_web_api/test_kb_app/test_rm_kb.py deleted file mode 100644 index eba2663f45..0000000000 --- a/test/testcases/test_web_api/test_kb_app/test_rm_kb.py +++ /dev/null @@ -1,61 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import pytest -from test_common import ( - list_datasets, - delete_datasets, -) -from configs import INVALID_API_TOKEN -from libs.auth import RAGFlowWebApiAuth - - -class TestAuthorization: - @pytest.mark.p2 - @pytest.mark.parametrize( - "invalid_auth, expected_code, expected_message", - [ - (None, 401, ""), - (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), - ], - ) - def test_auth_invalid(self, invalid_auth, expected_code, expected_message): - res = delete_datasets(invalid_auth) - assert res["code"] == expected_code, res - assert res["message"] == expected_message, res - - -class TestDatasetsDelete: - @pytest.mark.p1 - def test_kb_id(self, WebApiAuth, add_datasets_func): - kb_ids = add_datasets_func - payload = {"ids": [kb_ids[0]]} - res = delete_datasets(WebApiAuth, payload) - assert res["code"] == 0, res - - res = list_datasets(WebApiAuth) - assert len(res["data"]) == 2, res - - @pytest.mark.p2 - @pytest.mark.usefixtures("add_dataset_func") - def test_id_wrong_uuid(self, WebApiAuth): - payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]} - res = delete_datasets(WebApiAuth, payload) - assert res["code"] == 102, res - assert "lacks permission" in res["message"], res - - res = list_datasets(WebApiAuth) - assert len(res["data"]) == 1, res diff --git a/test/testcases/test_web_api/test_kb_app/test_update_kb.py b/test/testcases/test_web_api/test_kb_app/test_update_kb.py deleted file mode 100644 index 8dac7ab802..0000000000 --- a/test/testcases/test_web_api/test_kb_app/test_update_kb.py +++ /dev/null @@ -1,382 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import os -from concurrent.futures import ThreadPoolExecutor, as_completed - -import pytest -from test_common import update_dataset -from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN -from hypothesis import HealthCheck, example, given, settings -from libs.auth import RAGFlowWebApiAuth -from utils import encode_avatar -from utils.file_utils import create_image_file -from utils.hypothesis_utils import valid_names - - -class TestAuthorization: - @pytest.mark.p2 - @pytest.mark.parametrize( - "invalid_auth, expected_code, expected_message", - [ - (None, 401, ""), - (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), - ], - ids=["empty_auth", "invalid_api_token"], - ) - def test_auth_invalid(self, invalid_auth, expected_code, expected_message): - res = update_dataset(invalid_auth, "dataset_id") - assert res["code"] == expected_code, res - assert res["message"] == expected_message, res - - -class TestCapability: - @pytest.mark.p3 - def test_update_dateset_concurrent(self, WebApiAuth, add_dataset_func): - dataset_id = add_dataset_func - count = 100 - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [ - executor.submit( - update_dataset, - WebApiAuth, - dataset_id, - { - "name": f"dataset_{i}", - "description": "", - "chunk_method": "naive", - }, - ) - for i in range(count) - ] - responses = list(as_completed(futures)) - assert len(responses) == count, responses - assert all(future.result()["code"] == 0 for future in futures) - - -class TestDatasetUpdate: - @pytest.mark.p3 - def test_dataset_id_not_uuid(self, WebApiAuth): - payload = {"name": "not uuid", "description": "", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, "not_uuid", payload) - assert res["code"] == 101, res - assert "Invalid UUID1 format" in res["message"], res - - @pytest.mark.p1 - @given(name=valid_names()) - @example("a" * 128) - # Network-bound API call; disable Hypothesis deadline to avoid flaky timeouts. - @settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None) - def test_name(self, WebApiAuth, add_dataset_func, name): - dataset_id = add_dataset_func - payload = {"name": name, "description": "", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, dataset_id, payload) - assert res["code"] == 0, res - assert res["data"]["name"] == name, res - - @pytest.mark.p2 - @pytest.mark.parametrize( - "name, expected_message", - [ - ("", "Field: - Message: "), - (" ", "Field: - Message: "), - ("a" * (DATASET_NAME_LIMIT + 1), "Field: - Message: "), - (0, "Field: - Message: "), - (None, "Field: - Message: "), - ], - ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], - ) - def test_name_invalid(self, WebApiAuth, add_dataset_func, name, expected_message): - kb_id = add_dataset_func - payload = {"name": name, "description": "", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 101, res - assert expected_message in res["message"], res - - @pytest.mark.p3 - def test_name_duplicated(self, WebApiAuth, add_datasets_func): - kb_id = add_datasets_func[0] - name = "kb_1" - payload = {"name": name, "description": "", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 102, res - assert res["message"] == "Dataset name 'kb_1' already exists", res - - @pytest.mark.p3 - def test_name_case_insensitive(self, WebApiAuth, add_datasets_func): - kb_id = add_datasets_func[0] - name = "KB_1" - payload = {"name": name, "description": "", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 102, res - assert res["message"] == "Dataset name 'KB_1' already exists", res - - @pytest.mark.p2 - def test_avatar(self, WebApiAuth, add_dataset_func, tmp_path): - kb_id = add_dataset_func - fn = create_image_file(tmp_path / "ragflow_test.png") - payload = { - "name": "avatar", - "description": "", - "chunk_method": "naive", - "avatar": f"data:image/png;base64,{encode_avatar(fn)}", - } - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 0, res - assert res["data"]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res - - @pytest.mark.p2 - def test_description(self, WebApiAuth, add_dataset_func): - kb_id = add_dataset_func - payload = {"name": "description", "description": "description", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 0, res - assert res["data"]["description"] == "description", res - - @pytest.mark.p1 - @pytest.mark.parametrize( - "embedding_model", - [ - "BAAI/bge-small-en-v1.5@Builtin", - "embedding-3@ZHIPU-AI", - ], - ids=["builtin_baai", "tenant_zhipu"], - ) - def test_embedding_model(self, WebApiAuth, add_dataset_func, embedding_model): - kb_id = add_dataset_func - payload = {"name": "embedding_model", "description": "", "chunk_method": "naive", "embedding_model": embedding_model} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 0, res - assert res["data"]["embedding_model"] == embedding_model, res - - @pytest.mark.p2 - @pytest.mark.parametrize( - "permission", - [ - "me", - "team", - ], - ids=["me", "team"], - ) - def test_permission(self, WebApiAuth, add_dataset_func, permission): - kb_id = add_dataset_func - payload = {"name": "permission", "description": "", "chunk_method": "naive", "permission": permission} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 0, res - assert res["data"]["permission"] == permission.lower().strip(), res - - @pytest.mark.p1 - @pytest.mark.parametrize( - "chunk_method", - [ - "naive", - "book", - "email", - "laws", - "manual", - "one", - "paper", - "picture", - "presentation", - "qa", - "table", - pytest.param("tag", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support parser_id=tag")), - ], - ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], - ) - def test_chunk_method(self, WebApiAuth, add_dataset_func, chunk_method): - kb_id = add_dataset_func - payload = {"name": "chunk_method", "description": "", "chunk_method": chunk_method} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 0, res - assert res["data"]["chunk_method"] == chunk_method, res - - @pytest.mark.p1 - @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="Infinity does not support parser_id=tag") - def test_chunk_method_tag_with_infinity(self, WebApiAuth, add_dataset_func): - kb_id = add_dataset_func - payload = {"name": "chunk_method", "description": "", "chunk_method": "tag"} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 103, res - assert res["message"] == "The chunking method Tag has not been supported by Infinity yet.", res - - @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") - @pytest.mark.p2 - @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) - def test_pagerank(self, WebApiAuth, add_dataset_func, pagerank): - kb_id = add_dataset_func - payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": pagerank} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 0, res - assert res["data"]["pagerank"] == pagerank, res - - @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") - @pytest.mark.p2 - def test_pagerank_set_to_0(self, WebApiAuth, add_dataset_func): - kb_id = add_dataset_func - payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": 50} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 0, res - assert res["data"]["pagerank"] == 50, res - - payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": 0} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 0, res - assert res["data"]["pagerank"] == 0, res - - @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208") - @pytest.mark.p2 - def test_pagerank_infinity(self, WebApiAuth, add_dataset_func): - kb_id = add_dataset_func - payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": 50} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 102, res - assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res - - @pytest.mark.p1 - @pytest.mark.parametrize( - "parser_config", - [ - {"auto_keywords": 0}, - {"auto_keywords": 16}, - {"auto_keywords": 32}, - {"auto_questions": 0}, - {"auto_questions": 5}, - {"auto_questions": 10}, - {"chunk_token_num": 1}, - {"chunk_token_num": 1024}, - {"chunk_token_num": 2048}, - {"delimiter": "\n"}, - {"delimiter": " "}, - {"html4excel": True}, - {"html4excel": False}, - {"layout_recognize": "DeepDOC"}, - {"layout_recognize": "Plain Text"}, - {"tag_kb_ids": ["1", "2"]}, - {"topn_tags": 1}, - {"topn_tags": 5}, - {"topn_tags": 10}, - {"filename_embd_weight": 0.1}, - {"filename_embd_weight": 0.5}, - {"filename_embd_weight": 1.0}, - {"task_page_size": 1}, - {"task_page_size": None}, - {"pages": [[1, 100]]}, - {"pages": None}, - {"graphrag": {"use_graphrag": True}}, - {"graphrag": {"use_graphrag": False}}, - {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}, - {"graphrag": {"method": "general"}}, - {"graphrag": {"method": "light"}}, - {"graphrag": {"community": True}}, - {"graphrag": {"community": False}}, - {"graphrag": {"resolution": True}}, - {"graphrag": {"resolution": False}}, - {"raptor": {"use_raptor": True}}, - {"raptor": {"use_raptor": False}}, - {"raptor": {"prompt": "Who are you?"}}, - {"raptor": {"max_token": 1}}, - {"raptor": {"max_token": 1024}}, - {"raptor": {"max_token": 2048}}, - {"raptor": {"threshold": 0.0}}, - {"raptor": {"threshold": 0.5}}, - {"raptor": {"threshold": 1.0}}, - {"raptor": {"max_cluster": 1}}, - {"raptor": {"max_cluster": 512}}, - {"raptor": {"max_cluster": 1024}}, - {"raptor": {"random_seed": 0}}, - ], - ids=[ - "auto_keywords_min", - "auto_keywords_mid", - "auto_keywords_max", - "auto_questions_min", - "auto_questions_mid", - "auto_questions_max", - "chunk_token_num_min", - "chunk_token_num_mid", - "chunk_token_num_max", - "delimiter", - "delimiter_space", - "html4excel_true", - "html4excel_false", - "layout_recognize_DeepDOC", - "layout_recognize_navie", - "tag_kb_ids", - "topn_tags_min", - "topn_tags_mid", - "topn_tags_max", - "filename_embd_weight_min", - "filename_embd_weight_mid", - "filename_embd_weight_max", - "task_page_size_min", - "task_page_size_None", - "pages", - "pages_none", - "graphrag_true", - "graphrag_false", - "graphrag_entity_types", - "graphrag_method_general", - "graphrag_method_light", - "graphrag_community_true", - "graphrag_community_false", - "graphrag_resolution_true", - "graphrag_resolution_false", - "raptor_true", - "raptor_false", - "raptor_prompt", - "raptor_max_token_min", - "raptor_max_token_mid", - "raptor_max_token_max", - "raptor_threshold_min", - "raptor_threshold_mid", - "raptor_threshold_max", - "raptor_max_cluster_min", - "raptor_max_cluster_mid", - "raptor_max_cluster_max", - "raptor_random_seed_min", - ], - ) - def test_parser_config(self, WebApiAuth, add_dataset_func, parser_config): - kb_id = add_dataset_func - payload = {"name": "parser_config", "description": "", "chunk_method": "naive", "parser_config": parser_config} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 0, res - for key, value in parser_config.items(): - if not isinstance(value, dict): - assert res["data"]["parser_config"].get(key) == value, res - else: - for sub_key, sub_value in value.items(): - assert res["data"]["parser_config"].get(key, {}).get(sub_key) == sub_value, res - - @pytest.mark.p2 - @pytest.mark.parametrize( - "payload", - [ - {"id": "id"}, - {"tenant_id": "e57c1966f99211efb41e9e45646e0111"}, - {"created_by": "created_by"}, - {"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, - {"create_time": 1741671443322}, - {"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, - {"update_time": 1741671443339}, - ], - ) - def test_field_unsupported(self, WebApiAuth, add_dataset_func, payload): - kb_id = add_dataset_func - full_payload = {"name": "field_unsupported", "description": "", "chunk_method": "naive", **payload} - res = update_dataset(WebApiAuth, kb_id, full_payload) - assert res["code"] == 101, res - assert "are not permitted" in res["message"], res diff --git a/web/src/hooks/use-knowledge-request.ts b/web/src/hooks/use-knowledge-request.ts index fc77f40f1a..853f3750a5 100644 --- a/web/src/hooks/use-knowledge-request.ts +++ b/web/src/hooks/use-knowledge-request.ts @@ -14,6 +14,7 @@ import { ITestRetrievalRequestBody } from '@/interfaces/request/knowledge'; import i18n from '@/locales/config'; import kbService, { deleteKnowledgeGraph, + getKbDetail, getKnowledgeGraph, listDataset, listTag, @@ -407,9 +408,7 @@ export const useFetchKnowledgeBaseConfiguration = (props?: { gcTime: 0, enabled: !!knowledgeBaseId && isEdit, queryFn: async () => { - const { data } = await kbService.getKbDetail({ - kb_id: knowledgeBaseId, - }); + const { data } = await getKbDetail(knowledgeBaseId || ''); return data?.data ?? {}; }, }); @@ -443,7 +442,9 @@ export function useFetchKnowledgeMetadata(kbIds: string[] = []) { enabled: kbIds.length > 0, gcTime: 0, queryFn: async () => { - const { data } = await kbService.getMeta({ kb_ids: kbIds.join(',') }); + const { data } = await kbService.getMeta({ + dataset_ids: kbIds.join(','), + }); return data?.data ?? {}; }, }); @@ -549,7 +550,7 @@ export const useFetchTagListByKnowledgeIds = () => { gcTime: 0, // https://tanstack.com/query/latest/docs/framework/react/guides/caching?from=reactQueryV3 queryFn: async () => { const { data } = await kbService.listTagByKnowledgeIds({ - kb_ids: knowledgeIds.join(','), + dataset_ids: knowledgeIds.join(','), }); const list = data?.data || []; return list; diff --git a/web/src/interfaces/database/dataset.ts b/web/src/interfaces/database/dataset.ts index 2a028a77d7..e49cca5140 100644 --- a/web/src/interfaces/database/dataset.ts +++ b/web/src/interfaces/database/dataset.ts @@ -1,5 +1,5 @@ // for the dataset list -// The data structures returned by the `datasets` interface and `kb/detail` are inconsistent. +// The data structures returned by the `datasets` interface and `/api/v1/datasets/{id}` are inconsistent. export interface IDataset { avatar?: string; diff --git a/web/src/pages/dataset/components/metedata/hooks/use-manage-modal.ts b/web/src/pages/dataset/components/metedata/hooks/use-manage-modal.ts index b2778eb69c..8f7311723a 100644 --- a/web/src/pages/dataset/components/metedata/hooks/use-manage-modal.ts +++ b/web/src/pages/dataset/components/metedata/hooks/use-manage-modal.ts @@ -2,8 +2,9 @@ import message from '@/components/ui/message'; import { useSetModalState } from '@/hooks/common-hooks'; import { useSelectedIds } from '@/hooks/logic-hooks/use-row-selection'; import { DocumentApiAction } from '@/hooks/use-document-request'; -import kbService, { +import { getMetaDataService, + kbUpdateMetaData, updateDocumentMetaDataConfig, updateDocumentsMetadata, } from '@/services/knowledge-service'; @@ -413,8 +414,7 @@ export const useManageMetaDataModal = ( const handleSaveSettings = useCallback( async (callback: () => void, builtInMetadata?: IBuiltInMetadataItem[]) => { const data = util.tableDataToMetaDataSettingJSON(tableData); - const { data: res } = await kbService.kbUpdateMetaData({ - kb_id: id, + const { data: res } = await kbUpdateMetaData(id || '', { metadata: data, builtInMetadata: builtInMetadata || [], }); @@ -434,14 +434,11 @@ export const useManageMetaDataModal = ( const handleSaveSingleFileSettings = useCallback( async (callback: () => void) => { const data = util.tableDataToMetaDataSettingJSON(tableData); - // otherData contains: documentId - if (otherData?.documentId && id) { + if (otherData?.documentId) { const { data: res } = await updateDocumentMetaDataConfig({ - kb_id: id, + kb_id: id || '', doc_id: otherData.documentId, - data: { - metadata: data, - }, + data: { metadata: data }, }); if (res.code === 0) { message.success(t('message.operated')); diff --git a/web/src/pages/dataset/dataset-overview/hook.ts b/web/src/pages/dataset/dataset-overview/hook.ts index 679d90be04..201b2a5069 100644 --- a/web/src/pages/dataset/dataset-overview/hook.ts +++ b/web/src/pages/dataset/dataset-overview/hook.ts @@ -3,7 +3,8 @@ import { useGetPaginationWithRouter, useHandleSearchChange, } from '@/hooks/logic-hooks'; -import kbService, { +import { + getKnowledgeBasicInfo, listDataPipelineLogDocument, listPipelineDatasetLogs, } from '@/services/knowledge-service'; @@ -20,9 +21,9 @@ const useFetchOverviewTotal = () => { const { data } = useQuery({ queryKey: ['overviewTotal'], queryFn: async () => { - const { data: res = {} } = await kbService.getKnowledgeBasicInfo({ - kb_id: knowledgeBaseId, - }); + const { data: res = {} } = await getKnowledgeBasicInfo( + knowledgeBaseId || '', + ); return res.data || []; }, }); @@ -61,16 +62,12 @@ const useFetchFileLogList = () => { }, enabled: true, queryFn: async () => { - const { data: res = {} } = await fetchFunc( - { - kb_id: knowledgeBaseId, - page: pagination.current, - page_size: pagination.pageSize, - keywords: searchString, - // order_by: '', - }, - { ...filterValue }, - ); + const { data: res = {} } = await fetchFunc(knowledgeBaseId || '', { + page: pagination.current, + page_size: pagination.pageSize, + keywords: searchString, + ...filterValue, + }); return res.data || []; }, }); diff --git a/web/src/pages/dataset/dataset-setting/hooks.ts b/web/src/pages/dataset/dataset-setting/hooks.ts index c42be72ffe..1ac6b4cd91 100644 --- a/web/src/pages/dataset/dataset-setting/hooks.ts +++ b/web/src/pages/dataset/dataset-setting/hooks.ts @@ -4,7 +4,7 @@ import { useSetModalState } from '@/hooks/common-hooks'; import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-request'; import { useSelectLlmOptionsByModelType } from '@/hooks/use-llm-request'; import { useSelectParserList } from '@/hooks/use-user-setting-request'; -import kbService from '@/services/knowledge-service'; +import { checkEmbedding } from '@/services/knowledge-service'; import { useIsFetching } from '@tanstack/react-query'; import { pick } from 'lodash'; import { useCallback, useEffect, useState } from 'react'; @@ -108,8 +108,7 @@ export const useHandleKbEmbedding = () => { const knowledgeBaseId = searchParams.get('id') || id; const handleChange = useCallback( async ({ embed_id }: { embed_id: string }) => { - const res = await kbService.checkEmbedding({ - kb_id: knowledgeBaseId, + const res = await checkEmbedding(knowledgeBaseId || '', { embd_id: embed_id, }); return res.data; diff --git a/web/src/pages/dataset/dataset/generate-button/hook.ts b/web/src/pages/dataset/dataset/generate-button/hook.ts index cad9e3e9ad..833c37f6af 100644 --- a/web/src/pages/dataset/dataset/generate-button/hook.ts +++ b/web/src/pages/dataset/dataset/generate-button/hook.ts @@ -2,10 +2,8 @@ import message from '@/components/ui/message'; import agentService from '@/services/agent-service'; import { deletePipelineTask, - runGraphRag, - runRaptor, - traceGraphRag, - traceRaptor, + runIndex, + traceIndex, } from '@/services/knowledge-service'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import { t } from 'i18next'; @@ -59,7 +57,7 @@ export const useTraceGenerate = ({ open }: { open: boolean }) => { retryDelay: 1000, enabled: open, queryFn: async () => { - const { data } = await traceGraphRag(id); + const { data } = await traceIndex(id, 'graph'); return data?.data || {}; }, }); @@ -74,7 +72,7 @@ export const useTraceGenerate = ({ open }: { open: boolean }) => { retryDelay: 1000, enabled: open, queryFn: async () => { - const { data } = await traceRaptor(id); + const { data } = await traceIndex(id, 'raptor'); return data?.data || {}; }, }); @@ -134,9 +132,9 @@ export const useDatasetGenerate = () => { } = useMutation({ mutationKey: [DatasetKey.generate], mutationFn: async ({ type }: { type: GenerateType }) => { - const func = - type === GenerateType.KnowledgeGraph ? runGraphRag : runRaptor; - const { data } = await func(id); + const indexType = + type === GenerateType.KnowledgeGraph ? 'graph' : 'raptor'; + const { data } = await runIndex(id, indexType); if (data.code === 0) { message.success(t('message.operated')); queryClient.invalidateQueries({ diff --git a/web/src/services/knowledge-service.ts b/web/src/services/knowledge-service.ts index f1df2e0b2f..b947311830 100644 --- a/web/src/services/knowledge-service.ts +++ b/web/src/services/knowledge-service.ts @@ -8,33 +8,25 @@ import { ProcessingType } from '@/pages/dataset/dataset-overview/dataset-common' import api from '@/utils/api'; import { getAuthorization } from '@/utils/authorization-util'; import registerServer from '@/utils/register-server'; -import request, { post } from '@/utils/request'; +import request from '@/utils/request'; import axios from 'axios'; const { createKb, rmKb, - getKbDetail, kbList, - getDocumentList, documentChangeStatus, documentCreate, documentChangeParser, documentThumbnails, retrievalTest, documentRun, - documentUpload, webCrawl, knowledgeGraph, listTagByKnowledgeIds, setMeta, getMeta, retrievalTestShare, - getKnowledgeBasicInfo, - fetchDataPipelineLog, - fetchPipelineDatasetLogs, - checkEmbedding, - kbUpdateMetaData, } = api; const methods = { @@ -46,19 +38,11 @@ const methods = { url: rmKb, method: 'delete', }, - getKbDetail: { - url: getKbDetail, - method: 'get', - }, getList: { url: kbList, method: 'get', }, // document manager - getDocumentList: { - url: getDocumentList, - method: 'get', - }, documentChangeStatus: { url: documentChangeStatus, method: 'post', @@ -79,10 +63,6 @@ const methods = { url: documentThumbnails, method: 'get', }, - documentUpload: { - url: documentUpload, - method: 'post', - }, webCrawl: { url: webCrawl, method: 'post', @@ -115,36 +95,10 @@ const methods = { url: retrievalTestShare, method: 'post', }, - getKnowledgeBasicInfo: { - url: getKnowledgeBasicInfo, - method: 'get', - }, - fetchDataPipelineLog: { - url: fetchDataPipelineLog, - method: 'post', - }, - fetchPipelineDatasetLogs: { - url: fetchPipelineDatasetLogs, - method: 'post', - }, - getPipelineDetail: { - url: api.getPipelineDetail, - method: 'get', - }, - pipelineRerun: { url: api.pipelineRerun, method: 'post', }, - - checkEmbedding: { - url: checkEmbedding, - method: 'post', - }, - kbUpdateMetaData: { - url: kbUpdateMetaData, - method: 'post', - }, }; const baseKbService = registerServer(methods, request); @@ -281,16 +235,19 @@ const kbService = { ...chunkService, }; +export const getKbDetail = (datasetId: string) => + request.get(api.getKbDetail(datasetId)); + export const listTag = (knowledgeId: string) => request.get(api.listTag(knowledgeId)); export const removeTag = (knowledgeId: string, tags: string[]) => - post(api.removeTag(knowledgeId), { tags }); + request.delete(api.removeTag(knowledgeId), { data: { tags } }); export const renameTag = ( knowledgeId: string, { fromTag, toTag }: IRenameTag, -) => post(api.renameTag(knowledgeId), { fromTag, toTag }); +) => request.put(api.renameTag(knowledgeId), { data: { fromTag, toTag } }); export function getKnowledgeGraph(knowledgeId: string) { return request.get(api.getKnowledgeGraph(knowledgeId)); @@ -306,17 +263,11 @@ export const listDataset = (params?: IFetchKnowledgeListRequestParams) => export const updateKb = (datasetId: string, data: Record) => request.put(api.updateKb(datasetId), { data }); -export const runGraphRag = (datasetId: string) => - request.post(api.runGraphRag(datasetId)); +export const runIndex = (datasetId: string, indexType: string) => + request.post(api.runIndex(datasetId, indexType)); -export const traceGraphRag = (datasetId: string) => - request.get(api.traceGraphRag(datasetId)); - -export const runRaptor = (datasetId: string) => - request.post(api.runRaptor(datasetId)); - -export const traceRaptor = (datasetId: string) => - request.get(api.traceRaptor(datasetId)); +export const traceIndex = (datasetId: string, indexType: string) => + request.get(api.traceIndex(datasetId, indexType)); // Using RESTful API: GET /api/v1/datasets/{dataset_id}/documents export const listDocument = ( @@ -403,16 +354,28 @@ export const updateDocumentMetaDataConfig = ({ }); export const listDataPipelineLogDocument = ( - params?: IFetchKnowledgeListRequestParams, - body?: IFetchDocumentListRequestBody, -) => request.post(api.fetchDataPipelineLog, { data: body || {}, params }); + datasetId: string, + params?: Record, +) => request.get(api.fetchDataPipelineLog(datasetId), { params }); + export const listPipelineDatasetLogs = ( - params?: IFetchKnowledgeListRequestParams & { - kb_id?: string; - keywords?: string; - }, - body?: IFetchDocumentListRequestBody, -) => request.post(api.fetchPipelineDatasetLogs, { data: body || {}, params }); + datasetId: string, + params?: Record, +) => request.get(api.fetchPipelineDatasetLogs(datasetId), { params }); + +export const getPipelineDetail = (datasetId: string, logId: string) => + request.get(api.getPipelineDetail(datasetId, logId)); + +export const getKnowledgeBasicInfo = (datasetId: string) => + request.get(api.getKnowledgeBasicInfo(datasetId)); + +export const checkEmbedding = (datasetId: string, data: Record) => + request.post(api.checkEmbedding(datasetId), { data }); + +export const kbUpdateMetaData = ( + datasetId: string, + data: Record, +) => request.put(api.kbUpdateMetaData(datasetId), { data }); export function deletePipelineTask({ kb_id, @@ -421,7 +384,7 @@ export function deletePipelineTask({ kb_id: string; type: ProcessingType; }) { - return request.delete(api.unbindPipelineTask({ kb_id, type })); + return request.delete(api.unbindPipelineTask(kb_id, type)); } export default kbService; diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts index 6b3d893a83..df797937b9 100644 --- a/web/src/utils/api.ts +++ b/web/src/utils/api.ts @@ -57,46 +57,50 @@ export default { // knowledge base - checkEmbedding: `${webAPI}/kb/check_embedding`, + checkEmbedding: (datasetId: string) => + `${restAPIv1}/datasets/${datasetId}/embedding`, kbList: `${restAPIv1}/datasets`, createKb: `${restAPIv1}/datasets`, updateKb: (datasetId: string) => `${restAPIv1}/datasets/${datasetId}`, rmKb: `${restAPIv1}/datasets`, - getKbDetail: `${webAPI}/kb/detail`, + getKbDetail: (datasetId: string) => `${restAPIv1}/datasets/${datasetId}`, getKnowledgeGraph: (knowledgeId: string) => - `${restAPIv1}/datasets/${knowledgeId}/knowledge_graph`, + `${restAPIv1}/datasets/${knowledgeId}/graph/search`, deleteKnowledgeGraph: (knowledgeId: string) => - `${restAPIv1}/datasets/${knowledgeId}/knowledge_graph`, - getMeta: `${webAPI}/kb/get_meta`, - getKnowledgeBasicInfo: `${webAPI}/kb/basic_info`, + `${restAPIv1}/datasets/${knowledgeId}/graph`, + getMeta: `${restAPIv1}/datasets/metadata/flattened`, + getKnowledgeBasicInfo: (datasetId: string) => + `${restAPIv1}/datasets/${datasetId}/ingestions/summary`, // data pipeline log - fetchDataPipelineLog: `${webAPI}/kb/list_pipeline_logs`, - getPipelineDetail: `${webAPI}/kb/pipeline_log_detail`, - fetchPipelineDatasetLogs: `${webAPI}/kb/list_pipeline_dataset_logs`, - runGraphRag: (datasetId: string) => - `${restAPIv1}/datasets/${datasetId}/run_graphrag`, - traceGraphRag: (datasetId: string) => - `${restAPIv1}/datasets/${datasetId}/trace_graphrag`, - runRaptor: (datasetId: string) => - `${restAPIv1}/datasets/${datasetId}/run_raptor`, - traceRaptor: (datasetId: string) => - `${restAPIv1}/datasets/${datasetId}/trace_raptor`, - unbindPipelineTask: ({ kb_id, type }: { kb_id: string; type: string }) => - `${webAPI}/kb/unbind_task?kb_id=${kb_id}&pipeline_task_type=${type}`, - pipelineRerun: `${restAPIv1}/agents/rerun`, + fetchDataPipelineLog: (datasetId: string) => + `${restAPIv1}/datasets/${datasetId}/ingestions`, + getPipelineDetail: (datasetId: string, logId: string) => + `${restAPIv1}/datasets/${datasetId}/ingestions/${logId}`, + fetchPipelineDatasetLogs: (datasetId: string) => + `${restAPIv1}/datasets/${datasetId}/ingestions`, + runIndex: (datasetId: string, indexType: string) => + `${restAPIv1}/datasets/${datasetId}/index?type=${indexType}`, + traceIndex: (datasetId: string, indexType: string) => + `${restAPIv1}/datasets/${datasetId}/index?type=${indexType}`, + unbindPipelineTask: (datasetId: string, indexType: string) => + `${restAPIv1}/datasets/${datasetId}/${indexType}`, + pipelineRerun: `${webAPI}/canvas/rerun`, getMetaData: (datasetId: string) => `${restAPIv1}/datasets/${datasetId}/metadata/summary`, updateDocumentsMetadata: (datasetId: string) => `${restAPIv1}/datasets/${datasetId}/documents/metadatas`, - kbUpdateMetaData: `${webAPI}/kb/update_metadata_setting`, + kbUpdateMetaData: (datasetId: string) => + `${restAPIv1}/datasets/${datasetId}/metadata/config`, documentUpdateMetaDataConfig: (datasetId: string, documentId: string) => `${restAPIv1}/datasets/${datasetId}/documents/${documentId}/metadata/config`, // tags - listTag: (knowledgeId: string) => `${webAPI}/kb/${knowledgeId}/tags`, - listTagByKnowledgeIds: `${webAPI}/kb/tags`, - removeTag: (knowledgeId: string) => `${webAPI}/kb/${knowledgeId}/rm_tags`, - renameTag: (knowledgeId: string) => `${webAPI}/kb/${knowledgeId}/rename_tag`, + listTag: (knowledgeId: string) => `${restAPIv1}/datasets/${knowledgeId}/tags`, + listTagByKnowledgeIds: `${restAPIv1}/datasets/tags/aggregation`, + removeTag: (knowledgeId: string) => + `${restAPIv1}/datasets/${knowledgeId}/tags`, + renameTag: (knowledgeId: string) => + `${restAPIv1}/datasets/${knowledgeId}/tags`, // chunk chunkList: (datasetId: string, documentId: string) => diff --git a/web/src/utils/llm-util.ts b/web/src/utils/llm-util.ts index b8a843db3a..daf9c0d586 100644 --- a/web/src/utils/llm-util.ts +++ b/web/src/utils/llm-util.ts @@ -84,8 +84,7 @@ const API_WHITELIST = [ '/v1/canvas/setting', '/api/v1/searches/', '/api/v1/memories', - '/v1/kb/create', - '/v1/kb/update', + '/api/v1/datasets', '/v1/dataflow/set', ];