diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 6df12f47a8..041d06ecc2 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -308,9 +308,8 @@ def register_page(page_path): sys.modules[module_name] = page spec.loader.exec_module(page) page_name = getattr(page, "page_name", page_name) - sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/" restful_api_path = "\\restful_apis\\" if sys.platform.startswith("win") else "/restful_apis/" - url_prefix = f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}" + url_prefix = f"/api/{API_VERSION}" if restful_api_path in path else f"/{API_VERSION}/{page_name}" app.register_blueprint(page.manager, url_prefix=url_prefix) return url_prefix diff --git a/api/apps/sdk/session.py b/api/apps/restful_apis/bot_api.py similarity index 99% rename from api/apps/sdk/session.py rename to api/apps/restful_apis/bot_api.py index ba65db5f30..9e96a06931 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/restful_apis/bot_api.py @@ -555,7 +555,7 @@ async def retrieval_test_embedded(): try: return await _retrieval() except Exception as e: - if str(e).find("not_found") > 0: + if "not_found" in str(e): return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=RetCode.DATA_ERROR) return server_error_response(e) diff --git a/api/apps/restful_apis/chunk_api.py b/api/apps/restful_apis/chunk_api.py index fe45209dd0..3774a37461 100644 --- a/api/apps/restful_apis/chunk_api.py +++ b/api/apps/restful_apis/chunk_api.py @@ -15,6 +15,7 @@ # import base64 import datetime +import logging import re import xxhash @@ -25,24 +26,44 @@ from api.apps import login_required from api.db.joint_services.tenant_model_service import ( get_model_config_by_id, get_model_config_by_type_and_name, + get_tenant_default_model_by_type, ) +from api.db.db_models import Document, Task +from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.document_service import DocumentService +from api.db.services.file2document_service import File2DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.llm_service import LLMBundle +from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks from api.db.services.tenant_llm_service import TenantLLMService from api.utils.api_utils import ( add_tenant_id_to_kwargs, check_duplicate_ids, + construct_json_result, get_error_data_result, get_request_json, get_result, server_error_response, + token_required, ) from api.utils.image_utils import store_chunk_image +from api.utils.reference_metadata_utils import ( + enrich_chunks_with_document_metadata, + resolve_reference_metadata_preferences, +) from common import settings -from common.constants import LLMType, ParserType, RetCode +from common.constants import LLMType, ParserType, RetCode, TaskStatus +from common.metadata_utils import convert_conditions, meta_filter from common.misc_utils import thread_pool_exec from common.string_utils import is_content_empty, remove_redundant_spaces from common.tag_feature_utils import validate_tag_features +from rag.app.tag import label_question +from rag.nlp import search +from rag.prompts.generator import cross_languages, keyword_extraction + + +DOC_STOP_PARSING_INVALID_STATE_MESSAGE = "Can't stop parsing document that has not started or already completed" +DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE = "DOC_STOP_PARSING_INVALID_STATE" class Chunk(BaseModel): @@ -101,6 +122,232 @@ def _get_dataset_tenant_id(dataset_id): return kb.tenant_id +def _resolve_reference_metadata(req: dict, search_config: dict | None = None): + return resolve_reference_metadata_preferences(req, search_config) + + +def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=None) -> None: + enrich_chunks_with_document_metadata(chunks, metadata_fields) + + +@manager.route("/datasets//chunks", methods=["POST"]) # noqa: F821 +@token_required +async def parse(tenant_id, dataset_id): + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + req = await get_request_json() + if not req.get("document_ids"): + return get_error_data_result("`document_ids` is required") + doc_list = req.get("document_ids") + unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_list, "document") + doc_list = unique_doc_ids + + not_found = [] + success_count = 0 + for id in doc_list: + doc = DocumentService.query(id=id, kb_id=dataset_id) + if not doc: + not_found.append(id) + continue + if not doc: + return get_error_data_result(message=f"You don't own the document {id}.") + info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0} + if ( + DocumentService.filter_update( + [ + Document.id == id, + ((Document.run.is_null(True)) | (Document.run != TaskStatus.RUNNING.value)), + ], + info, + ) + == 0 + ): + return get_error_data_result("Can't parse document that is currently being processed") + settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) + TaskService.filter_delete([Task.doc_id == id]) + e, doc = DocumentService.get_by_id(id) + doc = doc.to_dict() + doc["tenant_id"] = tenant_id + bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"]) + queue_tasks(doc, bucket, name, 0) + success_count += 1 + if not_found: + return get_result(message=f"Documents not found: {not_found}", code=RetCode.DATA_ERROR) + if duplicate_messages: + if success_count > 0: + return get_result( + message=f"Partially parsed {success_count} documents with {len(duplicate_messages)} errors", + data={"success_count": success_count, "errors": duplicate_messages}, + ) + else: + return get_error_data_result(message=";".join(duplicate_messages)) + + return get_result() + + +@manager.route("/datasets//chunks", methods=["DELETE"]) # noqa: F821 +@token_required +async def stop_parsing(tenant_id, dataset_id): + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + req = await get_request_json() + + if not req.get("document_ids"): + return get_error_data_result("`document_ids` is required") + doc_list = req.get("document_ids") + unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_list, "document") + doc_list = unique_doc_ids + + success_count = 0 + for id in doc_list: + doc = DocumentService.query(id=id, kb_id=dataset_id) + if not doc: + return get_error_data_result(message=f"You don't own the document {id}.") + if doc[0].run != TaskStatus.RUNNING.value: + return construct_json_result( + code=RetCode.DATA_ERROR, + message=DOC_STOP_PARSING_INVALID_STATE_MESSAGE, + data={"error_code": DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE}, + ) + cancel_all_task_of(id) + info = {"run": "2", "progress": 0, "chunk_num": 0} + DocumentService.update_by_id(id, info) + settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id) + success_count += 1 + if duplicate_messages: + if success_count > 0: + return get_result( + message=f"Partially stopped {success_count} documents with {len(duplicate_messages)} errors", + data={"success_count": success_count, "errors": duplicate_messages}, + ) + else: + return get_error_data_result(message=";".join(duplicate_messages)) + return get_result() + + +@manager.route("/retrieval", methods=["POST"]) # noqa: F821 +@token_required +async def retrieval_test(tenant_id): + req = await get_request_json() + if not req.get("dataset_ids"): + return get_error_data_result("`dataset_ids` is required.") + kb_ids = req["dataset_ids"] + if not isinstance(kb_ids, list): + return get_error_data_result("`dataset_ids` should be a list") + for id in kb_ids: + if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id): + return get_error_data_result(f"You don't own the dataset {id}.") + kbs = KnowledgebaseService.get_by_ids(kb_ids) + embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) + if len(embd_nms) != 1: + return get_result(message="Datasets use different embedding models.", code=RetCode.DATA_ERROR) + if "question" not in req: + return get_error_data_result("`question` is required.") + page = int(req.get("page", 1)) + size = int(req.get("page_size", 30)) + question = req["question"].strip() if isinstance(req["question"], str) else req["question"] + if not question: + return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}}) + doc_ids = req.get("document_ids", []) + use_kg = req.get("use_kg", False) + toc_enhance = req.get("toc_enhance", False) + langs = req.get("cross_languages", []) + if not isinstance(doc_ids, list): + return get_error_data_result("`documents` should be a list") + if doc_ids: + doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids) + for doc_id in doc_ids: + if doc_id not in doc_ids_list: + return get_error_data_result(f"The datasets don't own the document {doc_id}") + if not doc_ids: + metadata_condition = req.get("metadata_condition") + if metadata_condition: + metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) + doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")) + if not doc_ids and metadata_condition.get("conditions"): + return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}}) + if metadata_condition and not doc_ids: + doc_ids = ["-999"] + else: + doc_ids = None + similarity_threshold = float(req.get("similarity_threshold", 0.2)) + vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) + top = int(req.get("top_k", 1024)) + if top <= 0: + return get_error_data_result("`top_k` must be greater than 0") + highlight_val = req.get("highlight", None) + if highlight_val is None: + highlight = False + elif isinstance(highlight_val, bool): + highlight = highlight_val + elif isinstance(highlight_val, str) and highlight_val.lower() in ["true", "false"]: + highlight = highlight_val.lower() == "true" + else: + return get_error_data_result("`highlight` should be a boolean") + include_metadata, metadata_fields = _resolve_reference_metadata(req) + try: + tenant_ids = list(set([kb.tenant_id for kb in kbs])) + e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) + if not e: + return get_error_data_result(message="Dataset not found!") + embd_model_config = get_model_config_by_id(kb.tenant_embd_id) if kb.tenant_embd_id else get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) + + rerank_mdl = None + if req.get("tenant_rerank_id"): + allowed_rerank_tenant_ids = {tenant_id, *[dataset.tenant_id for dataset in kbs]} + rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"], allowed_tenant_ids=allowed_rerank_tenant_ids, requester_tenant_id=tenant_id) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + elif req.get("rerank_id"): + rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK, req["rerank_id"]) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + + if langs: + question = await cross_languages(kb.tenant_id, None, question, langs) + if req.get("keyword", False): + chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) + question += await keyword_extraction(LLMBundle(kb.tenant_id, chat_model_config), question) + + ranks = await settings.retriever.retrieval( + question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, + vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, + highlight=highlight, rank_feature=label_question(question, kbs), + ) + if toc_enhance: + chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) + cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, LLMBundle(kb.tenant_id, chat_model_config), size) + if cks: + ranks["chunks"] = cks + ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) + if use_kg: + chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) + ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, chat_model_config)) + if ck["content_with_weight"]: + ranks["chunks"].insert(0, ck) + + for c in ranks["chunks"]: + c.pop("vector", None) + if include_metadata: + logging.info("sdk.retrieval reference_metadata enabled dataset_ids=%s fields=%s chunks=%s", kb_ids, sorted(metadata_fields) if metadata_fields else None, len(ranks["chunks"])) + enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields) + + key_mapping = { + "chunk_id": "id", + "content_with_weight": "content", + "doc_id": "document_id", + "important_kwd": "important_keywords", + "question_kwd": "questions", + "docnm_kwd": "document_keyword", + "kb_id": "dataset_id", + } + ranks["chunks"] = [{key_mapping.get(key, key): value for key, value in chunk.items()} for chunk in ranks["chunks"]] + return get_result(data=ranks) + except Exception as e: + if "not_found" in str(e): + return get_result(message="No chunk found! Check the chunk status please!", code=RetCode.DATA_ERROR) + return server_error_response(e) + + @manager.route("/datasets//documents//chunks", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/restful_apis/dify_retrieval_api.py similarity index 99% rename from api/apps/sdk/dify_retrieval.py rename to api/apps/restful_apis/dify_retrieval_api.py index 59ea268453..ffe9f247f9 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/restful_apis/dify_retrieval_api.py @@ -317,7 +317,7 @@ async def retrieval(tenant_id): return jsonify({"records": records}) except Exception as e: - if str(e).find("not_found") > 0: + if "not_found" in str(e): return build_error_result( message='No chunk found! Check the chunk status please!', code=RetCode.NOT_FOUND diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py deleted file mode 100644 index f4959f2b1c..0000000000 --- a/api/apps/sdk/doc.py +++ /dev/null @@ -1,468 +0,0 @@ -# -# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import logging -from api.db.db_models import Document, Task -from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type -from api.db.services.doc_metadata_service import DocMetadataService -from api.db.services.document_service import DocumentService -from api.db.services.file2document_service import File2DocumentService -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import LLMBundle -from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks -from api.db.services.tenant_llm_service import TenantLLMService -from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_request_json, get_result, server_error_response, token_required -from common import settings -from common.constants import LLMType, RetCode, TaskStatus -from common.metadata_utils import convert_conditions, meta_filter -from rag.app.tag import label_question -from rag.nlp import search -from rag.prompts.generator import cross_languages, keyword_extraction - -MAXIMUM_OF_UPLOADING_FILES = 256 - - -from api.utils.reference_metadata_utils import ( - enrich_chunks_with_document_metadata, - resolve_reference_metadata_preferences, -) - -def _resolve_reference_metadata(req: dict, search_config: dict | None = None): - return resolve_reference_metadata_preferences(req, search_config) - -def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=None) -> None: - enrich_chunks_with_document_metadata(chunks, metadata_fields) - -DOC_STOP_PARSING_INVALID_STATE_MESSAGE = "Can't stop parsing document that has not started or already completed" -DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE = "DOC_STOP_PARSING_INVALID_STATE" - -@manager.route("/datasets//chunks", methods=["POST"]) # noqa: F821 -@token_required -async def parse(tenant_id, dataset_id): - """ - Start parsing documents into chunks. - --- - tags: - - Chunks - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset. - - in: body - name: body - description: Parsing parameters. - required: true - schema: - type: object - properties: - document_ids: - type: array - items: - type: string - description: List of document IDs to parse. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Parsing started successfully. - schema: - type: object - """ - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): - return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - req = await get_request_json() - if not req.get("document_ids"): - return get_error_data_result("`document_ids` is required") - doc_list = req.get("document_ids") - unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_list, "document") - doc_list = unique_doc_ids - - not_found = [] - success_count = 0 - for id in doc_list: - doc = DocumentService.query(id=id, kb_id=dataset_id) - if not doc: - not_found.append(id) - continue - if not doc: - return get_error_data_result(message=f"You don't own the document {id}.") - info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0} - if ( - DocumentService.filter_update( - [ - Document.id == id, - ((Document.run.is_null(True)) | (Document.run != TaskStatus.RUNNING.value)), - ], - info, - ) - == 0 - ): - return get_error_data_result("Can't parse document that is currently being processed") - settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) - TaskService.filter_delete([Task.doc_id == id]) - e, doc = DocumentService.get_by_id(id) - doc = doc.to_dict() - doc["tenant_id"] = tenant_id - bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"]) - queue_tasks(doc, bucket, name, 0) - success_count += 1 - if not_found: - return get_result(message=f"Documents not found: {not_found}", code=RetCode.DATA_ERROR) - if duplicate_messages: - if success_count > 0: - return get_result( - message=f"Partially parsed {success_count} documents with {len(duplicate_messages)} errors", - data={"success_count": success_count, "errors": duplicate_messages}, - ) - else: - return get_error_data_result(message=";".join(duplicate_messages)) - - return get_result() - - -@manager.route("/datasets//chunks", methods=["DELETE"]) # noqa: F821 -@token_required -async def stop_parsing(tenant_id, dataset_id): - """ - Stop parsing documents into chunks. - --- - tags: - - Chunks - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset. - - in: body - name: body - description: Stop parsing parameters. - required: true - schema: - type: object - properties: - document_ids: - type: array - items: - type: string - description: List of document IDs to stop parsing. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Parsing stopped successfully. - schema: - type: object - """ - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): - return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - req = await get_request_json() - - if not req.get("document_ids"): - return get_error_data_result("`document_ids` is required") - doc_list = req.get("document_ids") - unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_list, "document") - doc_list = unique_doc_ids - - success_count = 0 - for id in doc_list: - doc = DocumentService.query(id=id, kb_id=dataset_id) - if not doc: - return get_error_data_result(message=f"You don't own the document {id}.") - if doc[0].run != TaskStatus.RUNNING.value: - return construct_json_result( - code=RetCode.DATA_ERROR, - message=DOC_STOP_PARSING_INVALID_STATE_MESSAGE, - data={"error_code": DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE}, - ) - # Send cancellation signal via Redis to stop background task - cancel_all_task_of(id) - info = {"run": "2", "progress": 0, "chunk_num": 0} - DocumentService.update_by_id(id, info) - settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id) - success_count += 1 - if duplicate_messages: - if success_count > 0: - return get_result( - message=f"Partially stopped {success_count} documents with {len(duplicate_messages)} errors", - data={"success_count": success_count, "errors": duplicate_messages}, - ) - else: - return get_error_data_result(message=";".join(duplicate_messages)) - return get_result() - - -@manager.route("/retrieval", methods=["POST"]) # noqa: F821 -@token_required -async def retrieval_test(tenant_id): - """ - Retrieve chunks based on a query. - --- - tags: - - Retrieval - security: - - ApiKeyAuth: [] - parameters: - - in: body - name: body - description: Retrieval parameters. - required: true - schema: - type: object - properties: - dataset_ids: - type: array - items: - type: string - required: true - description: List of dataset IDs to search in. - question: - type: string - required: true - description: Query string. - document_ids: - type: array - items: - type: string - description: List of document IDs to filter. - similarity_threshold: - type: number - format: float - description: Similarity threshold. - vector_similarity_weight: - type: number - format: float - description: Vector similarity weight. - top_k: - type: integer - description: Maximum number of chunks to return. - highlight: - type: boolean - description: Whether to highlight matched content. - metadata_condition: - type: object - description: metadata filter condition. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Retrieval results. - schema: - type: object - properties: - chunks: - type: array - items: - type: object - properties: - id: - type: string - description: Chunk ID. - content: - type: string - description: Chunk content. - document_id: - type: string - description: ID of the document. - dataset_id: - type: string - description: ID of the dataset. - similarity: - type: number - format: float - description: Similarity score. - """ - req = await get_request_json() - if not req.get("dataset_ids"): - return get_error_data_result("`dataset_ids` is required.") - kb_ids = req["dataset_ids"] - if not isinstance(kb_ids, list): - return get_error_data_result("`dataset_ids` should be a list") - for id in kb_ids: - if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id): - return get_error_data_result(f"You don't own the dataset {id}.") - kbs = KnowledgebaseService.get_by_ids(kb_ids) - embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) # remove vendor suffix for comparison - if len(embd_nms) != 1: - return get_result( - message='Datasets use different embedding models."', - code=RetCode.DATA_ERROR, - ) - if "question" not in req: - return get_error_data_result("`question` is required.") - page = int(req.get("page", 1)) - size = int(req.get("page_size", 30)) - question = req["question"] - # Trim whitespace and validate question - if isinstance(question, str): - question = question.strip() - # Return empty result if question is empty or whitespace-only - if not question: - return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}}) - doc_ids = req.get("document_ids", []) - use_kg = req.get("use_kg", False) - toc_enhance = req.get("toc_enhance", False) - langs = req.get("cross_languages", []) - if not isinstance(doc_ids, list): - return get_error_data_result("`documents` should be a list") - if doc_ids: - doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids) - for doc_id in doc_ids: - if doc_id not in doc_ids_list: - return get_error_data_result(f"The datasets don't own the document {doc_id}") - if not doc_ids: - metadata_condition = req.get("metadata_condition") - if metadata_condition: - metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) - doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")) - # If metadata_condition has conditions but no docs match, return empty result - if not doc_ids and metadata_condition.get("conditions"): - return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}}) - if metadata_condition and not doc_ids: - doc_ids = ["-999"] - else: - # If doc_ids is None all documents of the datasets are used - doc_ids = None - similarity_threshold = float(req.get("similarity_threshold", 0.2)) - vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) - top = int(req.get("top_k", 1024)) - if top <= 0: - return get_error_data_result("`top_k` must be greater than 0") - highlight_val = req.get("highlight", None) - if highlight_val is None: - highlight = False - elif isinstance(highlight_val, bool): - highlight = highlight_val - elif isinstance(highlight_val, str): - if highlight_val.lower() in ["true", "false"]: - highlight = highlight_val.lower() == "true" - else: - return get_error_data_result("`highlight` should be a boolean") - else: - return get_error_data_result("`highlight` should be a boolean") - include_metadata, metadata_fields = _resolve_reference_metadata(req) - try: - tenant_ids = list(set([kb.tenant_id for kb in kbs])) - e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) - if not e: - return get_error_data_result(message="Dataset not found!") - if kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) - else: - embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) - embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) - - rerank_mdl = None - if req.get("tenant_rerank_id"): - allowed_rerank_tenant_ids = {tenant_id, *[dataset.tenant_id for dataset in kbs]} - rerank_model_config = get_model_config_by_id( - req["tenant_rerank_id"], - allowed_tenant_ids=allowed_rerank_tenant_ids, - requester_tenant_id=tenant_id, - ) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - elif req.get("rerank_id"): - rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK, req["rerank_id"]) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - - if langs: - question = await cross_languages(kb.tenant_id, None, question, langs) - - if req.get("keyword", False): - chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(kb.tenant_id, chat_model_config) - question += await keyword_extraction(chat_mdl, question) - - ranks = await settings.retriever.retrieval( - question, - embd_mdl, - tenant_ids, - kb_ids, - page, - size, - similarity_threshold, - vector_similarity_weight, - top, - doc_ids, - rerank_mdl=rerank_mdl, - highlight=highlight, - rank_feature=label_question(question, kbs), - ) - if toc_enhance: - chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(kb.tenant_id, chat_model_config) - cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size) - if cks: - ranks["chunks"] = cks - ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) - if use_kg: - chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, chat_model_config)) - if ck["content_with_weight"]: - ranks["chunks"].insert(0, ck) - - for c in ranks["chunks"]: - c.pop("vector", None) - - if include_metadata: - logging.info( - "sdk.retrieval reference_metadata enabled dataset_ids=%s fields=%s chunks=%s", - kb_ids, - sorted(metadata_fields) if metadata_fields else None, - len(ranks["chunks"]), - ) - enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields) - - ##rename keys - renamed_chunks = [] - for chunk in ranks["chunks"]: - key_mapping = { - "chunk_id": "id", - "content_with_weight": "content", - "doc_id": "document_id", - "important_kwd": "important_keywords", - "question_kwd": "questions", - "docnm_kwd": "document_keyword", - "kb_id": "dataset_id", - } - rename_chunk = {} - for key, value in chunk.items(): - new_key = key_mapping.get(key, key) - rename_chunk[new_key] = value - renamed_chunks.append(rename_chunk) - ranks["chunks"] = renamed_chunks - return get_result(data=ranks) - except Exception as e: - if str(e).find("not_found") > 0: - return get_result( - message="No chunk found! Check the chunk status please!", - code=RetCode.DATA_ERROR, - ) - return server_error_response(e) diff --git a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py index b348503030..1fad7ebae5 100644 --- a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py +++ b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py @@ -242,7 +242,7 @@ def _load_dify_retrieval_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) module_name = "test_dify_retrieval_routes_unit_module" - module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py" + module_path = repo_root / "api" / "apps" / "restful_apis" / "dify_retrieval_api.py" spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) module.manager = _DummyManager() diff --git a/test/testcases/restful_api/test_retrieval.py b/test/testcases/restful_api/test_retrieval.py index 5f6531a8c3..6bfe40b28e 100644 --- a/test/testcases/restful_api/test_retrieval.py +++ b/test/testcases/restful_api/test_retrieval.py @@ -84,7 +84,7 @@ def test_multi_dataset_search_with_metadata_filter(rest_client, ensure_parsed_do @pytest.mark.p2 def test_retrieval_compatibility_endpoint(rest_client, ensure_parsed_document): dataset_id, _ = ensure_parsed_document() - # /api/v1/retrieval is SDK compatibility endpoint from api/apps/sdk/doc.py. + # /api/v1/retrieval is SDK compatibility endpoint registered from chunk_api.py. res = rest_client.post( "/retrieval", json={"dataset_ids": [dataset_id], "question": "test TXT file", "top_k": 5}, diff --git a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py index 6f4927b8d0..e73f18959c 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py +++ b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py @@ -253,7 +253,7 @@ def _load_dify_retrieval_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) module_name = "test_dify_retrieval_routes_unit_module" - module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py" + module_path = repo_root / "api" / "apps" / "restful_apis" / "dify_retrieval_api.py" spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) module.manager = _DummyManager() diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py index 08055a57e6..5b994ea525 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py @@ -137,12 +137,31 @@ def _load_doc_module(monkeypatch): common_pkg.__path__ = [str(repo_root / "common")] monkeypatch.setitem(sys.modules, "common", common_pkg) + apps_mod = ModuleType("api.apps") + apps_mod.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + common_settings_mod = ModuleType("common.settings") common_settings_mod.retriever = SimpleNamespace() common_settings_mod.kg_retriever = SimpleNamespace() common_settings_mod.STORAGE_IMPL = SimpleNamespace(get=lambda *_args, **_kwargs: b"", rm=lambda *_args, **_kwargs: None) monkeypatch.setitem(sys.modules, "common.settings", common_settings_mod) + common_misc_utils_mod = ModuleType("common.misc_utils") + async def _thread_pool_exec(func, *args, **kwargs): + return func(*args, **kwargs) + common_misc_utils_mod.thread_pool_exec = _thread_pool_exec + monkeypatch.setitem(sys.modules, "common.misc_utils", common_misc_utils_mod) + + common_string_utils_mod = ModuleType("common.string_utils") + common_string_utils_mod.is_content_empty = lambda content: content is None or not str(content).strip() + common_string_utils_mod.remove_redundant_spaces = lambda text: " ".join(str(text).split()) + monkeypatch.setitem(sys.modules, "common.string_utils", common_string_utils_mod) + + tag_feature_utils_mod = ModuleType("common.tag_feature_utils") + tag_feature_utils_mod.validate_tag_features = lambda value: value + monkeypatch.setitem(sys.modules, "common.tag_feature_utils", tag_feature_utils_mod) + class _FakeExpr: def __or__(self, other): return self @@ -219,6 +238,7 @@ def _load_doc_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service_mod) api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.add_tenant_id_to_kwargs = lambda func: func api_utils_mod.check_duplicate_ids = lambda ids, _kind="item": (ids, []) api_utils_mod.construct_json_result = lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data} api_utils_mod.get_error_data_result = lambda message="Sorry! Data missing!", code=102: {"code": code, "message": message} @@ -239,6 +259,32 @@ def _load_doc_module(monkeypatch): api_utils_mod.token_required = _token_required monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + image_utils_mod = ModuleType("api.utils.image_utils") + image_utils_mod.store_chunk_image = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.utils.image_utils", image_utils_mod) + + reference_metadata_utils_mod = ModuleType("api.utils.reference_metadata_utils") + reference_metadata_utils_mod.resolve_reference_metadata_preferences = ( + lambda req, *_args, **_kwargs: ( + bool((req.get("reference_metadata") or {}).get("include")), + set((req.get("reference_metadata") or {}).get("fields") or []), + ) + ) + def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None): + for chunk in chunks: + doc_id = chunk.get("doc_id") or chunk.get("document_id") + if not doc_id: + continue + metadata = doc_metadata_service_mod.DocMetadataService.get_metadata_for_documents([doc_id], chunk.get("kb_id")) + document_metadata = dict(metadata.get(doc_id, {})) + if metadata_fields: + document_metadata = {key: value for key, value in document_metadata.items() if key in metadata_fields} + if document_metadata: + chunk["document_metadata"] = document_metadata + + reference_metadata_utils_mod.enrich_chunks_with_document_metadata = _enrich_chunks_with_document_metadata + monkeypatch.setitem(sys.modules, "api.utils.reference_metadata_utils", reference_metadata_utils_mod) + common_metadata_utils_mod = ModuleType("common.metadata_utils") common_metadata_utils_mod.convert_conditions = lambda conditions: conditions common_metadata_utils_mod.meta_filter = lambda *_args, **_kwargs: [] @@ -446,7 +492,7 @@ def _load_doc_module(monkeypatch): tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) - module_path = repo_root / "api" / "apps" / "sdk" / "doc.py" + module_path = repo_root / "api" / "apps" / "restful_apis" / "chunk_api.py" spec = importlib.util.spec_from_file_location("test_doc_sdk_routes_unit", module_path) module = importlib.util.module_from_spec(spec) module.manager = _DummyManager() diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index 66e5ed7c07..aa28c8385e 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -690,7 +690,7 @@ def _load_session_module(monkeypatch): ) monkeypatch.setitem(sys.modules, "api.db.services.user_canvas_version", user_canvas_version_mod) - module_path = repo_root / "api" / "apps" / "sdk" / "session.py" + module_path = repo_root / "api" / "apps" / "restful_apis" / "bot_api.py" spec = importlib.util.spec_from_file_location("test_session_sdk_routes_unit_module", module_path) module = importlib.util.module_from_spec(spec) module.manager = _DummyManager() diff --git a/test/unit_test/api/apps/sdk/test_dify_retrieval.py b/test/unit_test/api/apps/sdk/test_dify_retrieval.py index ba830177af..72715d72f9 100644 --- a/test/unit_test/api/apps/sdk/test_dify_retrieval.py +++ b/test/unit_test/api/apps/sdk/test_dify_retrieval.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Regression tests for retrieval in api/apps/sdk/dify_retrieval.py. +"""Regression tests for retrieval in api/apps/restful_apis/dify_retrieval_api.py. Issue #15027: cross-tenant knowledge-base access via POST /api/v1/dify/retrieval. The handler authenticated the caller via @apikey_required (resolving @@ -84,7 +84,7 @@ class _FakeKGRetriever: def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=None): - """Load dify_retrieval.py with minimum stubs to exercise the retrieval handler.""" + """Load dify_retrieval_api.py with minimum stubs to exercise the retrieval handler.""" _stub( monkeypatch, "api.utils.api_utils", @@ -148,7 +148,7 @@ def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=No monkeypatch.setitem(sys.modules, "quart", quart_stub) repo_root = Path(__file__).resolve().parents[5] - module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py" + module_path = repo_root / "api" / "apps" / "restful_apis" / "dify_retrieval_api.py" spec = importlib.util.spec_from_file_location("test_dify_retrieval_module", module_path) module = importlib.util.module_from_spec(spec) module.manager = _PassthroughManager()