diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py index 0837b23ae1..a4eac92d0e 100644 --- a/api/apps/restful_apis/agent_api.py +++ b/api/apps/restful_apis/agent_api.py @@ -575,7 +575,14 @@ async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace yield ans continue - if event in ["message", "message_end"]: + if event in ["message", "message_end", "user_inputs", "workflow_finished"]: + if event in ["user_inputs", "workflow_finished"]: + logging.debug( + "Forwarding session completion event: tenant_id=%s agent_id=%s event=%s", + tenant_id, + agent_id, + event, + ) yield ans @@ -1564,7 +1571,10 @@ async def agent_chat_completion(tenant_id, agent_id=None): "trace": [copy.deepcopy(data)], } ) - final_ans = ans + if ans.get("event") == "message_end": + final_ans = ans + elif ans.get("event") == "user_inputs" and not final_ans: + final_ans = ans except Exception as exc: return get_result(data=f"**ERROR**: {str(exc)}") diff --git a/api/apps/restful_apis/bot_api.py b/api/apps/restful_apis/bot_api.py index efc1836ab5..a1878be692 100644 --- a/api/apps/restful_apis/bot_api.py +++ b/api/apps/restful_apis/bot_api.py @@ -24,6 +24,7 @@ from quart import Response, request from agent.canvas import Canvas from api.apps import AUTH_BETA, login_required +from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService from api.db.services.canvas_service import completion as agent_completion @@ -53,6 +54,13 @@ from api.utils.reference_metadata_utils import ( logger = logging.getLogger(__name__) +def _get_sdk_authorization_token(): + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return "" + return auth_header[len("Bearer "):].strip() + + @manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821 @login_required(auth_types=AUTH_BETA) @add_tenant_id_to_kwargs diff --git a/api/apps/restful_apis/document_api.py b/api/apps/restful_apis/document_api.py index 86de8b4582..07edadfe5c 100644 --- a/api/apps/restful_apis/document_api.py +++ b/api/apps/restful_apis/document_api.py @@ -29,8 +29,8 @@ from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX from api.apps.services.document_api_service import validate_document_update_fields, map_doc_keys, \ map_doc_keys_with_run_status, update_document_name_only, update_chunk_method, update_document_status_only, \ reset_document_for_reparse -from api.db import VALID_FILE_TYPES, FileType, DB -from api.db.db_models import API4Conversation +from api.db import VALID_FILE_TYPES, FileType +from api.db.db_models import API4Conversation, DB from api.db.services import duplicate_name from api.db.services.doc_metadata_service import DocMetadataService from api.db.db_models import Task diff --git a/common/file_utils.py b/common/file_utils.py index 3d7455b6b4..af691f9fee 100644 --- a/common/file_utils.py +++ b/common/file_utils.py @@ -16,21 +16,30 @@ import os -PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") +PROJECT_BASE = None + + +def _default_project_base_directory(): + return os.path.abspath( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + os.pardir, + ) + ) + def get_project_base_directory(*args): global PROJECT_BASE - if PROJECT_BASE is None: - PROJECT_BASE = os.path.abspath( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.pardir, - ) - ) + + project_base = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") + if not project_base: + if PROJECT_BASE is None: + PROJECT_BASE = _default_project_base_directory() + project_base = PROJECT_BASE if args: - return os.path.join(PROJECT_BASE, *args) - return PROJECT_BASE + return os.path.join(project_base, *args) + return project_base def traversal_files(base): for root, ds, fs in os.walk(base): diff --git a/common/token_utils.py b/common/token_utils.py index 981e98a1b5..67e421dd1f 100644 --- a/common/token_utils.py +++ b/common/token_utils.py @@ -14,13 +14,27 @@ # limitations under the License. # - import os +import shutil import tiktoken from common.file_utils import get_project_base_directory -tiktoken_cache_dir = get_project_base_directory() + +def _ensure_tiktoken_cache() -> str: + cache_dir = get_project_base_directory() + os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir + + bundled_encoding_path = get_project_base_directory("ragflow_deps", "cl100k_base.tiktoken") + cached_encoding_path = os.path.join(cache_dir, "9b5ad71b2ce5302211f9c61530b329a4922fc6a4") + + if os.path.exists(bundled_encoding_path) and not os.path.exists(cached_encoding_path): + shutil.copyfile(bundled_encoding_path, cached_encoding_path) + + return cache_dir + + +tiktoken_cache_dir = _ensure_tiktoken_cache() os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir # encoder = tiktoken.encoding_for_model("gpt-3.5-turbo") encoder = tiktoken.get_encoding("cl100k_base") diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index cfe2c3e203..47d02513c7 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -98,7 +98,7 @@ class RAGFlowPdfParser: model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc") self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model")) except Exception: - model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False) + model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc")) self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model")) self.page_from = 0 diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py index 9befbe2936..21bc1f4bad 100644 --- a/deepdoc/vision/layout_recognizer.py +++ b/deepdoc/vision/layout_recognizer.py @@ -62,7 +62,7 @@ class LayoutRecognizer(Recognizer): model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc") super().__init__(self.labels, domain, model_dir) except Exception: - model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False) + model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc")) super().__init__(self.labels, domain, model_dir) def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True): diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index d5e546a3c5..f5508dc4f1 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -570,9 +570,10 @@ class OCR: self.text_recognizer = [TextRecognizer(model_dir)] except Exception: - model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", - local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), - local_dir_use_symlinks=False) + model_dir = snapshot_download( + repo_id="InfiniFlow/deepdoc", + local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), + ) if settings.PARALLEL_DEVICES > 0: self.text_detector = [] diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py index 997bc84b62..6d2e0c906d 100644 --- a/deepdoc/vision/table_structure_recognizer.py +++ b/deepdoc/vision/table_structure_recognizer.py @@ -47,7 +47,6 @@ class TableStructureRecognizer(Recognizer): snapshot_download( repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), - local_dir_use_symlinks=False, ), ) diff --git a/test/unit_test/agent/test_pipeline_chunker.py b/test/unit_test/agent/test_pipeline_chunker.py index d19981a73d..b03c15ddc9 100644 --- a/test/unit_test/agent/test_pipeline_chunker.py +++ b/test/unit_test/agent/test_pipeline_chunker.py @@ -24,6 +24,7 @@ end-to-end behavior is intentionally left to higher-level integration tests. from __future__ import annotations import sys +from importlib import import_module, reload from unittest.mock import MagicMock import pytest @@ -38,102 +39,90 @@ pytestmark = pytest.mark.p2 # runtime environment. We track every key we install and restore the prior # sys.modules state in teardown_module so the stubs don't leak into other # test files. -_SENTINEL_ABSENT = object() -_INSTALLED_STUBS: dict[str, object] = {} +@pytest.fixture(scope="module") +def pipeline_chunker_module(): + """Import pipeline_chunker with rag.app parser modules stubbed locally.""" + stubbed_names = [ + "api.db.services.file_service", + "deepdoc.vision.ocr", + "deepdoc.parser.figure_parser", + "rag.app.picture", + "rag.app.audio", + "rag.app.resume", + "rag.app.naive", + "rag.app.paper", + "rag.app.book", + "rag.app.presentation", + "rag.app.manual", + "rag.app.laws", + "rag.app.qa", + "rag.app.table", + "rag.app.one", + "rag.app.email", + "rag.app.tag", + ] + original_modules = {name: sys.modules.get(name) for name in stubbed_names} + file_service_stub = MagicMock() + file_service_stub.FileService = MagicMock() -def _install_stub(name: str, stub: object) -> None: - """Insert ``stub`` into sys.modules and remember the prior entry.""" - if name in _INSTALLED_STUBS: - return - _INSTALLED_STUBS[name] = sys.modules.get(name, _SENTINEL_ABSENT) - sys.modules[name] = stub + try: + sys.modules["api.db.services.file_service"] = file_service_stub + for name in stubbed_names[1:]: + stub = MagicMock() + stub.chunk = MagicMock(return_value=[{"content_with_weight": "stub"}]) + sys.modules[name] = stub - -_file_service_stub = MagicMock() -_file_service_stub.FileService = MagicMock() -if "api.db.services.file_service" not in sys.modules: - _install_stub("api.db.services.file_service", _file_service_stub) - -for mod in [ - "deepdoc.vision.ocr", - "deepdoc.parser.figure_parser", - "rag.app.picture", - "rag.app.audio", - "rag.app.resume", - "rag.app.naive", - "rag.app.paper", - "rag.app.book", - "rag.app.presentation", - "rag.app.manual", - "rag.app.laws", - "rag.app.qa", - "rag.app.table", - "rag.app.one", - "rag.app.email", - "rag.app.tag", -]: - if mod not in sys.modules: - stub = MagicMock() - stub.chunk = MagicMock(return_value=[{"content_with_weight": "stub"}]) - _install_stub(mod, stub) - - -def teardown_module(module): - """Restore sys.modules to its pre-stub state when this file's tests finish.""" - for name, original in _INSTALLED_STUBS.items(): - if original is _SENTINEL_ABSENT: - sys.modules.pop(name, None) - else: - sys.modules[name] = original - _INSTALLED_STUBS.clear() - -from agent.component.pipeline_chunker import ( # noqa: E402 - _PARSER_MODULES, - PipelineChunkerParam, - _load_chunker, -) + module = import_module("agent.component.pipeline_chunker") + module = reload(module) + yield module + finally: + for name, original in original_modules.items(): + if original is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = original class TestPipelineChunkerParam: """Validate parameter parsing and the strategy whitelist.""" - def test_default_param_validates(self): + def test_default_param_validates(self, pipeline_chunker_module): """A freshly constructed param object should pass ``check()``.""" - p = PipelineChunkerParam() + p = pipeline_chunker_module.PipelineChunkerParam() assert p.check() is True - def test_accepts_each_known_parser(self): + def test_accepts_each_known_parser(self, pipeline_chunker_module): """Every parser id in the lookup table must validate.""" - for parser_id in _PARSER_MODULES: - p = PipelineChunkerParam() + for parser_id in pipeline_chunker_module._PARSER_MODULES: + p = pipeline_chunker_module.PipelineChunkerParam() p.parser_id = parser_id assert p.check() is True - def test_rejects_unknown_parser(self): + def test_rejects_unknown_parser(self, pipeline_chunker_module): """Unknown parser ids must raise ``ValueError`` at validation time.""" - p = PipelineChunkerParam() + p = pipeline_chunker_module.PipelineChunkerParam() p.parser_id = "nonsense-parser" with pytest.raises(ValueError): p.check() - def test_rejects_non_dict_parser_config(self): + def test_rejects_non_dict_parser_config(self, pipeline_chunker_module): """``parser_config`` must be a dict; anything else must raise.""" - p = PipelineChunkerParam() + p = pipeline_chunker_module.PipelineChunkerParam() p.parser_config = "not a dict" with pytest.raises(ValueError): p.check() - def test_rejects_negative_pages(self): + def test_rejects_negative_pages(self, pipeline_chunker_module): """Negative page indices must raise ``ValueError``.""" - p = PipelineChunkerParam() + p = pipeline_chunker_module.PipelineChunkerParam() p.from_page = -1 with pytest.raises(ValueError): p.check() - def test_rejects_inverted_page_range(self): + def test_rejects_inverted_page_range(self, pipeline_chunker_module): """``from_page`` greater than ``to_page`` must raise ``ValueError``.""" - p = PipelineChunkerParam() + p = pipeline_chunker_module.PipelineChunkerParam() p.from_page = 10 p.to_page = 5 with pytest.raises(ValueError, match="from_page must be <= to_page"): @@ -143,13 +132,13 @@ class TestPipelineChunkerParam: class TestLoadChunker: """Verify the lazy parser-id -> chunker callable resolver.""" - def test_load_chunker_returns_callable_for_each_known_parser(self): + def test_load_chunker_returns_callable_for_each_known_parser(self, pipeline_chunker_module): """Every known parser id should resolve to a callable ``chunk`` function.""" - for parser_id in _PARSER_MODULES: - chunker = _load_chunker(parser_id) + for parser_id in pipeline_chunker_module._PARSER_MODULES: + chunker = pipeline_chunker_module._load_chunker(parser_id) assert callable(chunker) - def test_load_chunker_raises_for_unknown_parser(self): + def test_load_chunker_raises_for_unknown_parser(self, pipeline_chunker_module): """Unknown parser ids should raise ``KeyError`` from the lookup.""" with pytest.raises(KeyError): - _load_chunker("not-a-real-parser") + pipeline_chunker_module._load_chunker("not-a-real-parser") diff --git a/test/unit_test/api/apps/restful_apis/test_agentbots_access_control.py b/test/unit_test/api/apps/restful_apis/test_agentbots_access_control.py index a778147cb9..beff9cba9a 100644 --- a/test/unit_test/api/apps/restful_apis/test_agentbots_access_control.py +++ b/test/unit_test/api/apps/restful_apis/test_agentbots_access_control.py @@ -65,8 +65,8 @@ def _load_bot_api(monkeypatch, *, accessible, calls): calls["completion"] = True async def _gen(): - if False: - yield "" + yield 'data: {"event":"message","data":{"content":"ok"}}\n\n' + yield 'data: {"event":"message_end","data":{"content":"ok"}}\n\n' return _gen() _stub(monkeypatch, "quart", Response=lambda *a, **k: SimpleNamespace(headers=SimpleNamespace(add_header=lambda *aa, **kk: None)), request=SimpleNamespace()) @@ -83,7 +83,12 @@ def _load_bot_api(monkeypatch, *, accessible, calls): _stub(monkeypatch, "api.db.services.llm_service", LLMBundle=SimpleNamespace()) _stub(monkeypatch, "common.metadata_utils", apply_meta_data_filter=lambda *_a, **_k: None) _stub(monkeypatch, "api.db.services.search_service", SearchService=SimpleNamespace()) - _stub(monkeypatch, "api.db.services.user_service", UserTenantService=SimpleNamespace()) + _stub( + monkeypatch, + "api.db.services.user_service", + TenantService=SimpleNamespace(), + UserTenantService=SimpleNamespace(), + ) _stub(monkeypatch, "api.db.joint_services.tenant_model_service", get_tenant_default_model_by_type=lambda *_a, **_k: None, get_model_config_from_provider_instance=lambda *_a, **_k: None) _stub(monkeypatch, "common.misc_utils", get_uuid=lambda: "uuid", thread_pool_exec=_passthrough_thread_pool_exec) _stub( @@ -118,7 +123,7 @@ def _load_bot_api(monkeypatch, *, accessible, calls): async def _async_empty_json(): - return {} + return {"stream": False} @pytest.mark.p1 diff --git a/test/unit_test/common/test_file_utils.py b/test/unit_test/common/test_file_utils.py index 616312cde2..6a38f51ad0 100644 --- a/test/unit_test/common/test_file_utils.py +++ b/test/unit_test/common/test_file_utils.py @@ -53,23 +53,24 @@ class TestGetProjectBaseDirectory: def test_uses_environment_variable_when_available(self): """Test that function uses RAG_PROJECT_BASE environment variable when set""" test_path = "/custom/project/path" - - file_utils.PROJECT_BASE = test_path - - result = get_project_base_directory() + with patch.dict(os.environ, {"RAG_PROJECT_BASE": test_path}, clear=False): + result = get_project_base_directory() assert result == test_path def test_calculates_default_path_when_no_env_vars(self): """Test that function calculates default path when no environment variables are set""" with patch.dict(os.environ, {}, clear=True): # Clear all environment variables - # Reset the global variable to force re-initialization + original_base = file_utils.PROJECT_BASE + try: + file_utils.PROJECT_BASE = None + result = get_project_base_directory() - result = get_project_base_directory() - - # Should return a valid absolute path - assert result is not None - assert os.path.isabs(result) - assert os.path.basename(result) != "" # Should not be root directory + # Should return a valid absolute path + assert result is not None + assert os.path.isabs(result) + assert os.path.basename(result) != "" # Should not be root directory + finally: + file_utils.PROJECT_BASE = original_base def test_caches_project_base_value(self): """Test that PROJECT_BASE is cached after first calculation""" diff --git a/test/unit_test/rag/app/test_markdown_image_ssrf.py b/test/unit_test/rag/app/test_markdown_image_ssrf.py index 6ef2487bab..2c3d2d119f 100644 --- a/test/unit_test/rag/app/test_markdown_image_ssrf.py +++ b/test/unit_test/rag/app/test_markdown_image_ssrf.py @@ -28,25 +28,44 @@ from __future__ import annotations import io import sys +from importlib import import_module, reload from unittest.mock import MagicMock, patch -# Mock heavy modules that trigger ONNX/OCR model loading or optional parser -# backends at import time so importing rag.app.naive stays lightweight. -for _mod in [ - "deepdoc.vision.ocr", - "deepdoc.parser.figure_parser", - "deepdoc.parser.docling_parser", - "deepdoc.parser.tcadp_parser", - "rag.app.picture", -]: - if _mod not in sys.modules: - sys.modules[_mod] = MagicMock() - import pytest from PIL import Image from common import ssrf_guard -from rag.app.naive import MAX_IMAGE_REDIRECTS, Markdown + + +@pytest.fixture(scope="module") +def naive_module(): + """Load rag.app.naive with heavy optional dependencies stubbed locally.""" + stub_names = [ + "deepdoc.vision.ocr", + "deepdoc.parser.figure_parser", + "deepdoc.parser.docling_parser", + "deepdoc.parser.tcadp_parser", + "rag.app.picture", + ] + original_modules = {name: sys.modules.get(name) for name in stub_names} + + try: + for name in stub_names: + sys.modules[name] = MagicMock() + module = import_module("rag.app.naive") + module = reload(module) + yield module + finally: + for name, original in original_modules.items(): + if original is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = original + + +@pytest.fixture(scope="module") +def max_image_redirects(naive_module): + return naive_module.MAX_IMAGE_REDIRECTS def _png_bytes() -> bytes: @@ -68,8 +87,8 @@ class _Resp: @pytest.fixture -def parser(): - return Markdown(128) +def parser(naive_module): + return naive_module.Markdown(128) @pytest.mark.p1 @@ -127,7 +146,7 @@ def test_fetches_legitimate_public_image(parser): @pytest.mark.p1 -def test_redirect_chain_is_bounded(parser): +def test_redirect_chain_is_bounded(parser, max_image_redirects): """An endless redirect loop is abandoned instead of being followed forever.""" loop = _Resp(302, headers={"Location": "http://public.example/next"}) with ( @@ -136,5 +155,5 @@ def test_redirect_chain_is_bounded(parser): ): images, _ = parser.load_images_from_urls(["http://public.example/start"]) - assert get.call_count == MAX_IMAGE_REDIRECTS + 1 + assert get.call_count == max_image_redirects + 1 assert images == [] diff --git a/test/unit_test/rag/app/test_table_chunk_column_roles.py b/test/unit_test/rag/app/test_table_chunk_column_roles.py index 40eed2ae5b..20025c96c7 100644 --- a/test/unit_test/rag/app/test_table_chunk_column_roles.py +++ b/test/unit_test/rag/app/test_table_chunk_column_roles.py @@ -20,18 +20,9 @@ from __future__ import annotations import sys +from importlib import import_module, reload from unittest.mock import MagicMock, patch -# Mock heavy modules that trigger ONNX model loading at import time -# table.py -> deepdoc.parser.figure_parser -> rag.app.picture -> OCR() -for mod in [ - "deepdoc.vision.ocr", - "deepdoc.parser.figure_parser", - "rag.app.picture", -]: - if mod not in sys.modules: - sys.modules[mod] = MagicMock() - import warnings # Importing rag.app.table pulls api -> rag.llm -> deepdoc -> xgboost; xgboost may warn on @@ -42,7 +33,6 @@ import pkg_resources # noqa: F401 — stabilize xgboost import during collectio import pytest import common.settings as settings -from rag.app.table import chunk # chunk() removes columns named id, _id, index, idx — use row_id instead of id. TEST_CSV = b"""row_id,title,content,country,category @@ -76,14 +66,38 @@ def _stub_rag_tokenizer(monkeypatch): monkeypatch.setattr("rag.nlp.rag_tokenizer.fine_grained_tokenize", fake_tokenize) +@pytest.fixture(scope="module") +def table_module(): + """Load rag.app.table with heavy optional dependencies stubbed locally.""" + stub_names = [ + "deepdoc.vision.ocr", + "deepdoc.parser.figure_parser", + "rag.app.picture", + ] + original_modules = {name: sys.modules.get(name) for name in stub_names} + + try: + for name in stub_names: + sys.modules[name] = MagicMock() + module = import_module("rag.app.table") + module = reload(module) + yield module + finally: + for name, original in original_modules.items(): + if original is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = original + + @pytest.fixture def mock_update_kb(): with patch("rag.app.table.KnowledgebaseService.update_parser_config") as m: yield m -def _run_chunk(parser_config: dict, mock_update_kb: MagicMock): - return chunk( +def _run_chunk(table_module, parser_config: dict, mock_update_kb: MagicMock): + return table_module.chunk( FILENAME, binary=TEST_CSV, callback=_noop_callback, @@ -93,9 +107,9 @@ def _run_chunk(parser_config: dict, mock_update_kb: MagicMock): ) -def test_chunk_auto_mode_all_columns_in_text_and_stored(mock_update_kb: MagicMock): +def test_chunk_auto_mode_all_columns_in_text_and_stored(table_module, mock_update_kb: MagicMock): parser_config: dict = {} - chunks = _run_chunk(parser_config, mock_update_kb) + chunks = _run_chunk(table_module, parser_config, mock_update_kb) assert len(chunks) == 3 first = chunks[0] cww = first["content_with_weight"] @@ -109,7 +123,7 @@ def test_chunk_auto_mode_all_columns_in_text_and_stored(mock_update_kb: MagicMoc assert "title_raw" in first and "country_raw" in first -def test_chunk_manual_mode_indexing_only(mock_update_kb: MagicMock): +def test_chunk_manual_mode_indexing_only(table_module, mock_update_kb: MagicMock): parser_config = { "table_column_mode": "manual", "table_column_roles": { @@ -120,7 +134,7 @@ def test_chunk_manual_mode_indexing_only(mock_update_kb: MagicMock): "category": "metadata", }, } - chunks = _run_chunk(parser_config, mock_update_kb) + chunks = _run_chunk(table_module, parser_config, mock_update_kb) first = chunks[0] cww = first["content_with_weight"] assert "- title:" in cww and "Earthquake" in cww @@ -135,7 +149,7 @@ def test_chunk_manual_mode_indexing_only(mock_update_kb: MagicMock): assert "row_id_long" in first -def test_chunk_manual_mode_legacy_vectorize_role(mock_update_kb: MagicMock): +def test_chunk_manual_mode_legacy_vectorize_role(table_module, mock_update_kb: MagicMock): """Stored configs may still use role *vectorize*; chunking treats it like *indexing*.""" parser_config = { "table_column_mode": "manual", @@ -147,7 +161,7 @@ def test_chunk_manual_mode_legacy_vectorize_role(mock_update_kb: MagicMock): "category": "metadata", }, } - chunks = _run_chunk(parser_config, mock_update_kb) + chunks = _run_chunk(table_module, parser_config, mock_update_kb) first = chunks[0] cww = first["content_with_weight"] assert "- title:" in cww and "Earthquake" in cww @@ -155,7 +169,7 @@ def test_chunk_manual_mode_legacy_vectorize_role(mock_update_kb: MagicMock): assert "- country:" not in cww -def test_chunk_manual_mode_metadata_only(mock_update_kb: MagicMock): +def test_chunk_manual_mode_metadata_only(table_module, mock_update_kb: MagicMock): parser_config = { "table_column_mode": "manual", "table_column_roles": { @@ -166,18 +180,18 @@ def test_chunk_manual_mode_metadata_only(mock_update_kb: MagicMock): "category": "metadata", }, } - chunks = _run_chunk(parser_config, mock_update_kb) + chunks = _run_chunk(table_module, parser_config, mock_update_kb) first = chunks[0] assert (first.get("content_with_weight") or "").strip() == "" assert "country_raw" in first and "title_raw" in first -def test_chunk_manual_mode_both(mock_update_kb: MagicMock): +def test_chunk_manual_mode_both(table_module, mock_update_kb: MagicMock): parser_config = { "table_column_mode": "manual", "table_column_roles": {c: "both" for c in ["title", "content", "country", "category", "row_id"]}, } - chunks = _run_chunk(parser_config, mock_update_kb) + chunks = _run_chunk(table_module, parser_config, mock_update_kb) first = chunks[0] cww = first["content_with_weight"] assert "Earthquake hits Turkey" in cww @@ -187,7 +201,7 @@ def test_chunk_manual_mode_both(mock_update_kb: MagicMock): assert "title_raw" in first and "country_raw" in first -def test_chunk_manual_mode_partial_roles_default_to_both(mock_update_kb: MagicMock): +def test_chunk_manual_mode_partial_roles_default_to_both(table_module, mock_update_kb: MagicMock): parser_config = { "table_column_mode": "manual", "table_column_roles": { @@ -195,7 +209,7 @@ def test_chunk_manual_mode_partial_roles_default_to_both(mock_update_kb: MagicMo "country": "metadata", }, } - chunks = _run_chunk(parser_config, mock_update_kb) + chunks = _run_chunk(table_module, parser_config, mock_update_kb) first = chunks[0] cww = first["content_with_weight"] assert "- title:" in cww and "Earthquake" in cww @@ -208,20 +222,20 @@ def test_chunk_manual_mode_partial_roles_default_to_both(mock_update_kb: MagicMo assert "content_raw" in first and "category_raw" in first -def test_chunk_manual_mode_raw_fields_for_es(mock_update_kb: MagicMock): +def test_chunk_manual_mode_raw_fields_for_es(table_module, mock_update_kb: MagicMock): parser_config = { "table_column_mode": "manual", "table_column_roles": {c: "both" for c in ["title", "content", "country", "category", "row_id"]}, } - chunks = _run_chunk(parser_config, mock_update_kb) + chunks = _run_chunk(table_module, parser_config, mock_update_kb) first = chunks[0] for col in ("title", "content", "country", "category"): assert f"{col}_raw" in first assert f"{col}_tks" in first -def test_chunk_updates_table_column_names(mock_update_kb: MagicMock): - _run_chunk({}, mock_update_kb) +def test_chunk_updates_table_column_names(table_module, mock_update_kb: MagicMock): + _run_chunk(table_module, {}, mock_update_kb) mock_update_kb.assert_called_once() args, kwargs = mock_update_kb.call_args assert args[0] == KB_ID @@ -230,6 +244,6 @@ def test_chunk_updates_table_column_names(mock_update_kb: MagicMock): assert names == ["row_id", "title", "content", "country", "category"] -def test_chunk_count_matches_row_count(mock_update_kb: MagicMock): - chunks = _run_chunk({}, mock_update_kb) +def test_chunk_count_matches_row_count(table_module, mock_update_kb: MagicMock): + chunks = _run_chunk(table_module, {}, mock_update_kb) assert len(chunks) == 3 diff --git a/test/unit_test/rag/test_laws_docx_tables.py b/test/unit_test/rag/test_laws_docx_tables.py index 0a49a791a9..341e7f4847 100644 --- a/test/unit_test/rag/test_laws_docx_tables.py +++ b/test/unit_test/rag/test_laws_docx_tables.py @@ -16,6 +16,7 @@ import sys import types +from importlib import import_module, reload from io import BytesIO import pytest @@ -37,12 +38,32 @@ class _DummyBase: pass -_stub("deepdoc.parser", PdfParser=_DummyBase, DocxParser=_DummyBase, HtmlParser=_DummyBase) -_stub("deepdoc.parser.utils", get_text=lambda *a, **k: "") -_stub("rag.app.naive", by_plaintext=lambda *a, **k: ([], [], None), PARSERS={}) -_stub("common.parser_config_utils", normalize_layout_recognizer=lambda x: (x, None)) +@pytest.fixture(scope="module") +def docx_chunker(): + original_modules = { + name: sys.modules.get(name) + for name in ( + "deepdoc.parser", + "deepdoc.parser.utils", + "rag.app.naive", + "common.parser_config_utils", + ) + } -from rag.app.laws import Docx # noqa: E402 + try: + _stub("deepdoc.parser", PdfParser=_DummyBase, DocxParser=_DummyBase, HtmlParser=_DummyBase) + _stub("deepdoc.parser.utils", get_text=lambda *a, **k: "") + _stub("rag.app.naive", by_plaintext=lambda *a, **k: ([], [], None), PARSERS={}) + _stub("common.parser_config_utils", normalize_layout_recognizer=lambda x: (x, None)) + module = import_module("rag.app.laws") + module = reload(module) + yield module.Docx + finally: + for name, original in original_modules.items(): + if original is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = original def _build_docx(builder): @@ -54,7 +75,7 @@ def _build_docx(builder): @pytest.mark.p2 -def test_laws_docx_preserves_table(): +def test_laws_docx_preserves_table(docx_chunker): """Regression for #16008: the laws DOCX parser dropped tables entirely.""" def builder(d): @@ -67,7 +88,7 @@ def test_laws_docx_preserves_table(): t.cell(1, 0).text = "Registration" t.cell(1, 1).text = "100" - chunks = Docx()("law.docx", _build_docx(builder)) + chunks = docx_chunker()("law.docx", _build_docx(builder)) assert any("" in c for c in chunks) table_chunk = next(c for c in chunks if "
" in c) @@ -78,7 +99,7 @@ def test_laws_docx_preserves_table(): @pytest.mark.p2 -def test_laws_docx_merged_cells_use_colspan(): +def test_laws_docx_merged_cells_use_colspan(docx_chunker): def builder(d): d.add_heading("Heading", level=1) t = d.add_table(rows=1, cols=3) @@ -87,20 +108,20 @@ def test_laws_docx_merged_cells_use_colspan(): t.cell(0, 1).text = "Merged" t.cell(0, 2).text = "Other" - chunks = Docx()("law.docx", _build_docx(builder)) + chunks = docx_chunker()("law.docx", _build_docx(builder)) table_chunk = next(c for c in chunks if "
" in c) assert "colspan='2'" in table_chunk assert "" in table_chunk @pytest.mark.p2 -def test_laws_docx_escapes_cell_html(): +def test_laws_docx_escapes_cell_html(docx_chunker): def builder(d): d.add_heading("Heading", level=1) t = d.add_table(rows=1, cols=1) t.cell(0, 0).text = "a < b & c > d" - chunks = Docx()("law.docx", _build_docx(builder)) + chunks = docx_chunker()("law.docx", _build_docx(builder)) table_chunk = next(c for c in chunks if "
Other
" in c) # Special characters are HTML-escaped so the table markup stays well-formed. assert "a < b & c > d" in table_chunk @@ -108,17 +129,17 @@ def test_laws_docx_escapes_cell_html(): @pytest.mark.p2 -def test_laws_docx_tables_only_does_not_crash(): +def test_laws_docx_tables_only_does_not_crash(docx_chunker): def builder(d): t = d.add_table(rows=1, cols=2) t.cell(0, 0).text = "a" t.cell(0, 1).text = "b" - chunks = Docx()("law.docx", _build_docx(builder)) + chunks = docx_chunker()("law.docx", _build_docx(builder)) assert any("
" in c for c in chunks) @pytest.mark.p2 -def test_laws_docx_empty_doc_returns_empty(): - chunks = Docx()("law.docx", _build_docx(lambda d: None)) +def test_laws_docx_empty_doc_returns_empty(docx_chunker): + chunks = docx_chunker()("law.docx", _build_docx(lambda d: None)) assert chunks == [] diff --git a/web/src/components/next-message-item/index.tsx b/web/src/components/next-message-item/index.tsx index 53e4ef4a45..48db836ed9 100644 --- a/web/src/components/next-message-item/index.tsx +++ b/web/src/components/next-message-item/index.tsx @@ -132,6 +132,8 @@ function MessageItem({ return null; } + const hasCustomChildren = item.data && !!children; + return (
- {item.data ? ( + {hasCustomChildren ? ( children ) : sendLoading && isEmpty(messageContent) ? ( <>{!isShare && t('common.running')} diff --git a/web/src/pages/agent/explore/components/session-chat.tsx b/web/src/pages/agent/explore/components/session-chat.tsx index 4353325135..e5bd0bb0d9 100644 --- a/web/src/pages/agent/explore/components/session-chat.tsx +++ b/web/src/pages/agent/explore/components/session-chat.tsx @@ -1,5 +1,6 @@ import { FileUploadProps } from '@/components/file-upload'; import { NextMessageInput } from '@/components/message-input/next'; +import MarkdownContent from '@/components/next-markdown-content'; import MessageItem from '@/components/next-message-item'; import PdfSheet from '@/components/pdf-drawer'; import { useClickDrawer } from '@/components/pdf-drawer/hooks'; @@ -8,6 +9,8 @@ import { useUploadAgentFileWithProgress } from '@/hooks/use-agent-request'; import { useFetchUserInfo } from '@/hooks/use-user-setting-request'; import { IAgentLogResponse } from '@/interfaces/database/agent'; import { IMessage } from '@/interfaces/database/chat'; +import DebugContent from '@/pages/agent/debug-content'; +import { useAwaitComponentData } from '@/pages/agent/hooks/use-chat-logic'; import { BeginQuery } from '@/pages/agent/interface'; import { ParameterDialog } from '@/pages/agent/share/parameter-dialog'; import { buildMessageUuidWithRole } from '@/utils/chat'; @@ -37,6 +40,7 @@ export function SessionChat({ session }: SessionChatProps) { handleInputChange, handlePressEnter, stopOutputMessage, + sendFormMessage, canvasInfo, findReferenceByMessageId, appendUploadResponseList, @@ -47,6 +51,11 @@ export function SessionChat({ session }: SessionChatProps) { shouldShowParameterDialog, setDerivedMessages, } = useSendSessionMessage(); + + const { buildInputList, handleOk, isWaiting } = useAwaitComponentData({ + derivedMessages, + sendFormMessage, + }); const hasActiveSession = Boolean( sessionId || isNew || hasLocalMessageRef.current, ); @@ -122,26 +131,58 @@ export function SessionChat({ session }: SessionChatProps) {
) : (
- {derivedMessages.map((message, i) => ( - - ))} + {derivedMessages.map((message, i) => { + const inputList = buildInputList(message); + const hasUserFillUpInputs = + message.role === MessageType.Assistant && + inputList.length > 0; + + return ( + + {hasUserFillUpInputs && + derivedMessages.length - 1 === i && ( + + )} + {hasUserFillUpInputs && + derivedMessages.length - 1 !== i && ( +
+ +
+ {inputList.map((item) => ( +
{item.value}
+ ))} +
+
+ )} +
+ ); + })}
)}
@@ -151,9 +192,9 @@ export function SessionChat({ session }: SessionChatProps) { { const isWaiting = useMemo(() => { const temp = derivedMessages?.some((message, i) => { + const hasInputs = Object.keys(getInputs(message)).length > 0; const flag = message.role === MessageType.Assistant && derivedMessages.length - 1 === i && - message.data; + hasInputs; return flag; }); return temp; - }, [derivedMessages]); + }, [derivedMessages, getInputs]); return { getInputs, buildInputList, handleOk, isWaiting }; };