Fix: UserFillUp interactive forms not working in agent explore mode (#14589)

## Summary

- **Backend**: `_iter_session_completion_events` in `agent_api.py` was
filtering out `user_inputs` and `workflow_finished` SSE events, causing
agents with UserFillUp components to silently fail in explore mode — the
interactive form never appeared, while the same agent worked correctly
in run (editor) mode.
- **Frontend**: `SessionChat` component in explore mode was missing
`DebugContent` children rendering inside `MessageItem`, so even if the
backend forwarded the events, the form UI would not render. Added
`DebugContent`, `MarkdownContent`, `useAwaitCompentData` hook, and
input-disabling logic to match the run mode's `chat/box.tsx` behavior.

## What was changed

### Backend (`api/apps/restful_apis/agent_api.py`)
- Line 266: Added `"user_inputs"` and `"workflow_finished"` to the
allowed event filter in `_iter_session_completion_events`

### Frontend (`web/src/pages/agent/explore/components/session-chat.tsx`)
- Added imports: `DebugContent`, `MarkdownContent`,
`useAwaitCompentData`, `useParams`
- Added `sendFormMessage` from `useSendSessionMessage()` hook
- Added `useAwaitCompentData` hook for form state management
- Added `DebugContent` as `MessageItem` children for the latest
assistant message (renders UserFillUp form)
- Added `MarkdownContent` + submitted values display for previous
assistant messages
- Updated `NextMessageInput` disabled states to respect `isWaitting`
(form submission in progress)

## Test plan

- [x] Agent with UserFillUp component (e.g., email draft with
send/edit/cancel options) shows interactive form in **explore mode**
- [x] Same agent continues to work correctly in **run (editor) mode**
- [x] Form submission sends data back to the agent and workflow
continues
- [x] Input field is disabled while waiting for form submission
- [ ] Agents without UserFillUp components are unaffected in explore
mode

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
This commit is contained in:
Tim Wang
2026-06-28 21:57:57 +08:00
committed by yzc
parent 212429bf9d
commit f0f10b6092
18 changed files with 330 additions and 196 deletions

View File

@@ -575,7 +575,14 @@ async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace
yield ans
continue
if event in ["message", "message_end"]:
if event in ["message", "message_end", "user_inputs", "workflow_finished"]:
if event in ["user_inputs", "workflow_finished"]:
logging.debug(
"Forwarding session completion event: tenant_id=%s agent_id=%s event=%s",
tenant_id,
agent_id,
event,
)
yield ans
@@ -1564,7 +1571,10 @@ async def agent_chat_completion(tenant_id, agent_id=None):
"trace": [copy.deepcopy(data)],
}
)
final_ans = ans
if ans.get("event") == "message_end":
final_ans = ans
elif ans.get("event") == "user_inputs" and not final_ans:
final_ans = ans
except Exception as exc:
return get_result(data=f"**ERROR**: {str(exc)}")

View File

@@ -24,6 +24,7 @@ from quart import Response, request
from agent.canvas import Canvas
from api.apps import AUTH_BETA, login_required
from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.canvas_service import completion as agent_completion
@@ -53,6 +54,13 @@ from api.utils.reference_metadata_utils import (
logger = logging.getLogger(__name__)
def _get_sdk_authorization_token():
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return ""
return auth_header[len("Bearer "):].strip()
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs

View File

@@ -29,8 +29,8 @@ from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
from api.apps.services.document_api_service import validate_document_update_fields, map_doc_keys, \
map_doc_keys_with_run_status, update_document_name_only, update_chunk_method, update_document_status_only, \
reset_document_for_reparse
from api.db import VALID_FILE_TYPES, FileType, DB
from api.db.db_models import API4Conversation
from api.db import VALID_FILE_TYPES, FileType
from api.db.db_models import API4Conversation, DB
from api.db.services import duplicate_name
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.db_models import Task

View File

@@ -16,21 +16,30 @@
import os
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
PROJECT_BASE = None
def _default_project_base_directory():
return os.path.abspath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
os.pardir,
)
)
def get_project_base_directory(*args):
global PROJECT_BASE
if PROJECT_BASE is None:
PROJECT_BASE = os.path.abspath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
os.pardir,
)
)
project_base = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
if not project_base:
if PROJECT_BASE is None:
PROJECT_BASE = _default_project_base_directory()
project_base = PROJECT_BASE
if args:
return os.path.join(PROJECT_BASE, *args)
return PROJECT_BASE
return os.path.join(project_base, *args)
return project_base
def traversal_files(base):
for root, ds, fs in os.walk(base):

View File

@@ -14,13 +14,27 @@
# limitations under the License.
#
import os
import shutil
import tiktoken
from common.file_utils import get_project_base_directory
tiktoken_cache_dir = get_project_base_directory()
def _ensure_tiktoken_cache() -> str:
cache_dir = get_project_base_directory()
os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
bundled_encoding_path = get_project_base_directory("ragflow_deps", "cl100k_base.tiktoken")
cached_encoding_path = os.path.join(cache_dir, "9b5ad71b2ce5302211f9c61530b329a4922fc6a4")
if os.path.exists(bundled_encoding_path) and not os.path.exists(cached_encoding_path):
shutil.copyfile(bundled_encoding_path, cached_encoding_path)
return cache_dir
tiktoken_cache_dir = _ensure_tiktoken_cache()
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
encoder = tiktoken.get_encoding("cl100k_base")

View File

@@ -98,7 +98,7 @@ class RAGFlowPdfParser:
model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model"))
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False)
model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"))
self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model"))
self.page_from = 0

View File

@@ -62,7 +62,7 @@ class LayoutRecognizer(Recognizer):
model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
super().__init__(self.labels, domain, model_dir)
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False)
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"))
super().__init__(self.labels, domain, model_dir)
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):

View File

@@ -570,9 +570,10 @@ class OCR:
self.text_recognizer = [TextRecognizer(model_dir)]
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)
model_dir = snapshot_download(
repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
)
if settings.PARALLEL_DEVICES > 0:
self.text_detector = []

View File

@@ -47,7 +47,6 @@ class TableStructureRecognizer(Recognizer):
snapshot_download(
repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False,
),
)

View File

@@ -24,6 +24,7 @@ end-to-end behavior is intentionally left to higher-level integration tests.
from __future__ import annotations
import sys
from importlib import import_module, reload
from unittest.mock import MagicMock
import pytest
@@ -38,102 +39,90 @@ pytestmark = pytest.mark.p2
# runtime environment. We track every key we install and restore the prior
# sys.modules state in teardown_module so the stubs don't leak into other
# test files.
_SENTINEL_ABSENT = object()
_INSTALLED_STUBS: dict[str, object] = {}
@pytest.fixture(scope="module")
def pipeline_chunker_module():
"""Import pipeline_chunker with rag.app parser modules stubbed locally."""
stubbed_names = [
"api.db.services.file_service",
"deepdoc.vision.ocr",
"deepdoc.parser.figure_parser",
"rag.app.picture",
"rag.app.audio",
"rag.app.resume",
"rag.app.naive",
"rag.app.paper",
"rag.app.book",
"rag.app.presentation",
"rag.app.manual",
"rag.app.laws",
"rag.app.qa",
"rag.app.table",
"rag.app.one",
"rag.app.email",
"rag.app.tag",
]
original_modules = {name: sys.modules.get(name) for name in stubbed_names}
file_service_stub = MagicMock()
file_service_stub.FileService = MagicMock()
def _install_stub(name: str, stub: object) -> None:
"""Insert ``stub`` into sys.modules and remember the prior entry."""
if name in _INSTALLED_STUBS:
return
_INSTALLED_STUBS[name] = sys.modules.get(name, _SENTINEL_ABSENT)
sys.modules[name] = stub
try:
sys.modules["api.db.services.file_service"] = file_service_stub
for name in stubbed_names[1:]:
stub = MagicMock()
stub.chunk = MagicMock(return_value=[{"content_with_weight": "stub"}])
sys.modules[name] = stub
_file_service_stub = MagicMock()
_file_service_stub.FileService = MagicMock()
if "api.db.services.file_service" not in sys.modules:
_install_stub("api.db.services.file_service", _file_service_stub)
for mod in [
"deepdoc.vision.ocr",
"deepdoc.parser.figure_parser",
"rag.app.picture",
"rag.app.audio",
"rag.app.resume",
"rag.app.naive",
"rag.app.paper",
"rag.app.book",
"rag.app.presentation",
"rag.app.manual",
"rag.app.laws",
"rag.app.qa",
"rag.app.table",
"rag.app.one",
"rag.app.email",
"rag.app.tag",
]:
if mod not in sys.modules:
stub = MagicMock()
stub.chunk = MagicMock(return_value=[{"content_with_weight": "stub"}])
_install_stub(mod, stub)
def teardown_module(module):
"""Restore sys.modules to its pre-stub state when this file's tests finish."""
for name, original in _INSTALLED_STUBS.items():
if original is _SENTINEL_ABSENT:
sys.modules.pop(name, None)
else:
sys.modules[name] = original
_INSTALLED_STUBS.clear()
from agent.component.pipeline_chunker import ( # noqa: E402
_PARSER_MODULES,
PipelineChunkerParam,
_load_chunker,
)
module = import_module("agent.component.pipeline_chunker")
module = reload(module)
yield module
finally:
for name, original in original_modules.items():
if original is None:
sys.modules.pop(name, None)
else:
sys.modules[name] = original
class TestPipelineChunkerParam:
"""Validate parameter parsing and the strategy whitelist."""
def test_default_param_validates(self):
def test_default_param_validates(self, pipeline_chunker_module):
"""A freshly constructed param object should pass ``check()``."""
p = PipelineChunkerParam()
p = pipeline_chunker_module.PipelineChunkerParam()
assert p.check() is True
def test_accepts_each_known_parser(self):
def test_accepts_each_known_parser(self, pipeline_chunker_module):
"""Every parser id in the lookup table must validate."""
for parser_id in _PARSER_MODULES:
p = PipelineChunkerParam()
for parser_id in pipeline_chunker_module._PARSER_MODULES:
p = pipeline_chunker_module.PipelineChunkerParam()
p.parser_id = parser_id
assert p.check() is True
def test_rejects_unknown_parser(self):
def test_rejects_unknown_parser(self, pipeline_chunker_module):
"""Unknown parser ids must raise ``ValueError`` at validation time."""
p = PipelineChunkerParam()
p = pipeline_chunker_module.PipelineChunkerParam()
p.parser_id = "nonsense-parser"
with pytest.raises(ValueError):
p.check()
def test_rejects_non_dict_parser_config(self):
def test_rejects_non_dict_parser_config(self, pipeline_chunker_module):
"""``parser_config`` must be a dict; anything else must raise."""
p = PipelineChunkerParam()
p = pipeline_chunker_module.PipelineChunkerParam()
p.parser_config = "not a dict"
with pytest.raises(ValueError):
p.check()
def test_rejects_negative_pages(self):
def test_rejects_negative_pages(self, pipeline_chunker_module):
"""Negative page indices must raise ``ValueError``."""
p = PipelineChunkerParam()
p = pipeline_chunker_module.PipelineChunkerParam()
p.from_page = -1
with pytest.raises(ValueError):
p.check()
def test_rejects_inverted_page_range(self):
def test_rejects_inverted_page_range(self, pipeline_chunker_module):
"""``from_page`` greater than ``to_page`` must raise ``ValueError``."""
p = PipelineChunkerParam()
p = pipeline_chunker_module.PipelineChunkerParam()
p.from_page = 10
p.to_page = 5
with pytest.raises(ValueError, match="from_page must be <= to_page"):
@@ -143,13 +132,13 @@ class TestPipelineChunkerParam:
class TestLoadChunker:
"""Verify the lazy parser-id -> chunker callable resolver."""
def test_load_chunker_returns_callable_for_each_known_parser(self):
def test_load_chunker_returns_callable_for_each_known_parser(self, pipeline_chunker_module):
"""Every known parser id should resolve to a callable ``chunk`` function."""
for parser_id in _PARSER_MODULES:
chunker = _load_chunker(parser_id)
for parser_id in pipeline_chunker_module._PARSER_MODULES:
chunker = pipeline_chunker_module._load_chunker(parser_id)
assert callable(chunker)
def test_load_chunker_raises_for_unknown_parser(self):
def test_load_chunker_raises_for_unknown_parser(self, pipeline_chunker_module):
"""Unknown parser ids should raise ``KeyError`` from the lookup."""
with pytest.raises(KeyError):
_load_chunker("not-a-real-parser")
pipeline_chunker_module._load_chunker("not-a-real-parser")

View File

@@ -65,8 +65,8 @@ def _load_bot_api(monkeypatch, *, accessible, calls):
calls["completion"] = True
async def _gen():
if False:
yield ""
yield 'data: {"event":"message","data":{"content":"ok"}}\n\n'
yield 'data: {"event":"message_end","data":{"content":"ok"}}\n\n'
return _gen()
_stub(monkeypatch, "quart", Response=lambda *a, **k: SimpleNamespace(headers=SimpleNamespace(add_header=lambda *aa, **kk: None)), request=SimpleNamespace())
@@ -83,7 +83,12 @@ def _load_bot_api(monkeypatch, *, accessible, calls):
_stub(monkeypatch, "api.db.services.llm_service", LLMBundle=SimpleNamespace())
_stub(monkeypatch, "common.metadata_utils", apply_meta_data_filter=lambda *_a, **_k: None)
_stub(monkeypatch, "api.db.services.search_service", SearchService=SimpleNamespace())
_stub(monkeypatch, "api.db.services.user_service", UserTenantService=SimpleNamespace())
_stub(
monkeypatch,
"api.db.services.user_service",
TenantService=SimpleNamespace(),
UserTenantService=SimpleNamespace(),
)
_stub(monkeypatch, "api.db.joint_services.tenant_model_service", get_tenant_default_model_by_type=lambda *_a, **_k: None, get_model_config_from_provider_instance=lambda *_a, **_k: None)
_stub(monkeypatch, "common.misc_utils", get_uuid=lambda: "uuid", thread_pool_exec=_passthrough_thread_pool_exec)
_stub(
@@ -118,7 +123,7 @@ def _load_bot_api(monkeypatch, *, accessible, calls):
async def _async_empty_json():
return {}
return {"stream": False}
@pytest.mark.p1

View File

@@ -53,23 +53,24 @@ class TestGetProjectBaseDirectory:
def test_uses_environment_variable_when_available(self):
"""Test that function uses RAG_PROJECT_BASE environment variable when set"""
test_path = "/custom/project/path"
file_utils.PROJECT_BASE = test_path
result = get_project_base_directory()
with patch.dict(os.environ, {"RAG_PROJECT_BASE": test_path}, clear=False):
result = get_project_base_directory()
assert result == test_path
def test_calculates_default_path_when_no_env_vars(self):
"""Test that function calculates default path when no environment variables are set"""
with patch.dict(os.environ, {}, clear=True): # Clear all environment variables
# Reset the global variable to force re-initialization
original_base = file_utils.PROJECT_BASE
try:
file_utils.PROJECT_BASE = None
result = get_project_base_directory()
result = get_project_base_directory()
# Should return a valid absolute path
assert result is not None
assert os.path.isabs(result)
assert os.path.basename(result) != "" # Should not be root directory
# Should return a valid absolute path
assert result is not None
assert os.path.isabs(result)
assert os.path.basename(result) != "" # Should not be root directory
finally:
file_utils.PROJECT_BASE = original_base
def test_caches_project_base_value(self):
"""Test that PROJECT_BASE is cached after first calculation"""

View File

@@ -28,25 +28,44 @@ from __future__ import annotations
import io
import sys
from importlib import import_module, reload
from unittest.mock import MagicMock, patch
# Mock heavy modules that trigger ONNX/OCR model loading or optional parser
# backends at import time so importing rag.app.naive stays lightweight.
for _mod in [
"deepdoc.vision.ocr",
"deepdoc.parser.figure_parser",
"deepdoc.parser.docling_parser",
"deepdoc.parser.tcadp_parser",
"rag.app.picture",
]:
if _mod not in sys.modules:
sys.modules[_mod] = MagicMock()
import pytest
from PIL import Image
from common import ssrf_guard
from rag.app.naive import MAX_IMAGE_REDIRECTS, Markdown
@pytest.fixture(scope="module")
def naive_module():
"""Load rag.app.naive with heavy optional dependencies stubbed locally."""
stub_names = [
"deepdoc.vision.ocr",
"deepdoc.parser.figure_parser",
"deepdoc.parser.docling_parser",
"deepdoc.parser.tcadp_parser",
"rag.app.picture",
]
original_modules = {name: sys.modules.get(name) for name in stub_names}
try:
for name in stub_names:
sys.modules[name] = MagicMock()
module = import_module("rag.app.naive")
module = reload(module)
yield module
finally:
for name, original in original_modules.items():
if original is None:
sys.modules.pop(name, None)
else:
sys.modules[name] = original
@pytest.fixture(scope="module")
def max_image_redirects(naive_module):
return naive_module.MAX_IMAGE_REDIRECTS
def _png_bytes() -> bytes:
@@ -68,8 +87,8 @@ class _Resp:
@pytest.fixture
def parser():
return Markdown(128)
def parser(naive_module):
return naive_module.Markdown(128)
@pytest.mark.p1
@@ -127,7 +146,7 @@ def test_fetches_legitimate_public_image(parser):
@pytest.mark.p1
def test_redirect_chain_is_bounded(parser):
def test_redirect_chain_is_bounded(parser, max_image_redirects):
"""An endless redirect loop is abandoned instead of being followed forever."""
loop = _Resp(302, headers={"Location": "http://public.example/next"})
with (
@@ -136,5 +155,5 @@ def test_redirect_chain_is_bounded(parser):
):
images, _ = parser.load_images_from_urls(["http://public.example/start"])
assert get.call_count == MAX_IMAGE_REDIRECTS + 1
assert get.call_count == max_image_redirects + 1
assert images == []

View File

@@ -20,18 +20,9 @@
from __future__ import annotations
import sys
from importlib import import_module, reload
from unittest.mock import MagicMock, patch
# Mock heavy modules that trigger ONNX model loading at import time
# table.py -> deepdoc.parser.figure_parser -> rag.app.picture -> OCR()
for mod in [
"deepdoc.vision.ocr",
"deepdoc.parser.figure_parser",
"rag.app.picture",
]:
if mod not in sys.modules:
sys.modules[mod] = MagicMock()
import warnings
# Importing rag.app.table pulls api -> rag.llm -> deepdoc -> xgboost; xgboost may warn on
@@ -42,7 +33,6 @@ import pkg_resources # noqa: F401 — stabilize xgboost import during collectio
import pytest
import common.settings as settings
from rag.app.table import chunk
# chunk() removes columns named id, _id, index, idx — use row_id instead of id.
TEST_CSV = b"""row_id,title,content,country,category
@@ -76,14 +66,38 @@ def _stub_rag_tokenizer(monkeypatch):
monkeypatch.setattr("rag.nlp.rag_tokenizer.fine_grained_tokenize", fake_tokenize)
@pytest.fixture(scope="module")
def table_module():
"""Load rag.app.table with heavy optional dependencies stubbed locally."""
stub_names = [
"deepdoc.vision.ocr",
"deepdoc.parser.figure_parser",
"rag.app.picture",
]
original_modules = {name: sys.modules.get(name) for name in stub_names}
try:
for name in stub_names:
sys.modules[name] = MagicMock()
module = import_module("rag.app.table")
module = reload(module)
yield module
finally:
for name, original in original_modules.items():
if original is None:
sys.modules.pop(name, None)
else:
sys.modules[name] = original
@pytest.fixture
def mock_update_kb():
with patch("rag.app.table.KnowledgebaseService.update_parser_config") as m:
yield m
def _run_chunk(parser_config: dict, mock_update_kb: MagicMock):
return chunk(
def _run_chunk(table_module, parser_config: dict, mock_update_kb: MagicMock):
return table_module.chunk(
FILENAME,
binary=TEST_CSV,
callback=_noop_callback,
@@ -93,9 +107,9 @@ def _run_chunk(parser_config: dict, mock_update_kb: MagicMock):
)
def test_chunk_auto_mode_all_columns_in_text_and_stored(mock_update_kb: MagicMock):
def test_chunk_auto_mode_all_columns_in_text_and_stored(table_module, mock_update_kb: MagicMock):
parser_config: dict = {}
chunks = _run_chunk(parser_config, mock_update_kb)
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
assert len(chunks) == 3
first = chunks[0]
cww = first["content_with_weight"]
@@ -109,7 +123,7 @@ def test_chunk_auto_mode_all_columns_in_text_and_stored(mock_update_kb: MagicMoc
assert "title_raw" in first and "country_raw" in first
def test_chunk_manual_mode_indexing_only(mock_update_kb: MagicMock):
def test_chunk_manual_mode_indexing_only(table_module, mock_update_kb: MagicMock):
parser_config = {
"table_column_mode": "manual",
"table_column_roles": {
@@ -120,7 +134,7 @@ def test_chunk_manual_mode_indexing_only(mock_update_kb: MagicMock):
"category": "metadata",
},
}
chunks = _run_chunk(parser_config, mock_update_kb)
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
first = chunks[0]
cww = first["content_with_weight"]
assert "- title:" in cww and "Earthquake" in cww
@@ -135,7 +149,7 @@ def test_chunk_manual_mode_indexing_only(mock_update_kb: MagicMock):
assert "row_id_long" in first
def test_chunk_manual_mode_legacy_vectorize_role(mock_update_kb: MagicMock):
def test_chunk_manual_mode_legacy_vectorize_role(table_module, mock_update_kb: MagicMock):
"""Stored configs may still use role *vectorize*; chunking treats it like *indexing*."""
parser_config = {
"table_column_mode": "manual",
@@ -147,7 +161,7 @@ def test_chunk_manual_mode_legacy_vectorize_role(mock_update_kb: MagicMock):
"category": "metadata",
},
}
chunks = _run_chunk(parser_config, mock_update_kb)
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
first = chunks[0]
cww = first["content_with_weight"]
assert "- title:" in cww and "Earthquake" in cww
@@ -155,7 +169,7 @@ def test_chunk_manual_mode_legacy_vectorize_role(mock_update_kb: MagicMock):
assert "- country:" not in cww
def test_chunk_manual_mode_metadata_only(mock_update_kb: MagicMock):
def test_chunk_manual_mode_metadata_only(table_module, mock_update_kb: MagicMock):
parser_config = {
"table_column_mode": "manual",
"table_column_roles": {
@@ -166,18 +180,18 @@ def test_chunk_manual_mode_metadata_only(mock_update_kb: MagicMock):
"category": "metadata",
},
}
chunks = _run_chunk(parser_config, mock_update_kb)
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
first = chunks[0]
assert (first.get("content_with_weight") or "").strip() == ""
assert "country_raw" in first and "title_raw" in first
def test_chunk_manual_mode_both(mock_update_kb: MagicMock):
def test_chunk_manual_mode_both(table_module, mock_update_kb: MagicMock):
parser_config = {
"table_column_mode": "manual",
"table_column_roles": {c: "both" for c in ["title", "content", "country", "category", "row_id"]},
}
chunks = _run_chunk(parser_config, mock_update_kb)
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
first = chunks[0]
cww = first["content_with_weight"]
assert "Earthquake hits Turkey" in cww
@@ -187,7 +201,7 @@ def test_chunk_manual_mode_both(mock_update_kb: MagicMock):
assert "title_raw" in first and "country_raw" in first
def test_chunk_manual_mode_partial_roles_default_to_both(mock_update_kb: MagicMock):
def test_chunk_manual_mode_partial_roles_default_to_both(table_module, mock_update_kb: MagicMock):
parser_config = {
"table_column_mode": "manual",
"table_column_roles": {
@@ -195,7 +209,7 @@ def test_chunk_manual_mode_partial_roles_default_to_both(mock_update_kb: MagicMo
"country": "metadata",
},
}
chunks = _run_chunk(parser_config, mock_update_kb)
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
first = chunks[0]
cww = first["content_with_weight"]
assert "- title:" in cww and "Earthquake" in cww
@@ -208,20 +222,20 @@ def test_chunk_manual_mode_partial_roles_default_to_both(mock_update_kb: MagicMo
assert "content_raw" in first and "category_raw" in first
def test_chunk_manual_mode_raw_fields_for_es(mock_update_kb: MagicMock):
def test_chunk_manual_mode_raw_fields_for_es(table_module, mock_update_kb: MagicMock):
parser_config = {
"table_column_mode": "manual",
"table_column_roles": {c: "both" for c in ["title", "content", "country", "category", "row_id"]},
}
chunks = _run_chunk(parser_config, mock_update_kb)
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
first = chunks[0]
for col in ("title", "content", "country", "category"):
assert f"{col}_raw" in first
assert f"{col}_tks" in first
def test_chunk_updates_table_column_names(mock_update_kb: MagicMock):
_run_chunk({}, mock_update_kb)
def test_chunk_updates_table_column_names(table_module, mock_update_kb: MagicMock):
_run_chunk(table_module, {}, mock_update_kb)
mock_update_kb.assert_called_once()
args, kwargs = mock_update_kb.call_args
assert args[0] == KB_ID
@@ -230,6 +244,6 @@ def test_chunk_updates_table_column_names(mock_update_kb: MagicMock):
assert names == ["row_id", "title", "content", "country", "category"]
def test_chunk_count_matches_row_count(mock_update_kb: MagicMock):
chunks = _run_chunk({}, mock_update_kb)
def test_chunk_count_matches_row_count(table_module, mock_update_kb: MagicMock):
chunks = _run_chunk(table_module, {}, mock_update_kb)
assert len(chunks) == 3

View File

@@ -16,6 +16,7 @@
import sys
import types
from importlib import import_module, reload
from io import BytesIO
import pytest
@@ -37,12 +38,32 @@ class _DummyBase:
pass
_stub("deepdoc.parser", PdfParser=_DummyBase, DocxParser=_DummyBase, HtmlParser=_DummyBase)
_stub("deepdoc.parser.utils", get_text=lambda *a, **k: "")
_stub("rag.app.naive", by_plaintext=lambda *a, **k: ([], [], None), PARSERS={})
_stub("common.parser_config_utils", normalize_layout_recognizer=lambda x: (x, None))
@pytest.fixture(scope="module")
def docx_chunker():
original_modules = {
name: sys.modules.get(name)
for name in (
"deepdoc.parser",
"deepdoc.parser.utils",
"rag.app.naive",
"common.parser_config_utils",
)
}
from rag.app.laws import Docx # noqa: E402
try:
_stub("deepdoc.parser", PdfParser=_DummyBase, DocxParser=_DummyBase, HtmlParser=_DummyBase)
_stub("deepdoc.parser.utils", get_text=lambda *a, **k: "")
_stub("rag.app.naive", by_plaintext=lambda *a, **k: ([], [], None), PARSERS={})
_stub("common.parser_config_utils", normalize_layout_recognizer=lambda x: (x, None))
module = import_module("rag.app.laws")
module = reload(module)
yield module.Docx
finally:
for name, original in original_modules.items():
if original is None:
sys.modules.pop(name, None)
else:
sys.modules[name] = original
def _build_docx(builder):
@@ -54,7 +75,7 @@ def _build_docx(builder):
@pytest.mark.p2
def test_laws_docx_preserves_table():
def test_laws_docx_preserves_table(docx_chunker):
"""Regression for #16008: the laws DOCX parser dropped tables entirely."""
def builder(d):
@@ -67,7 +88,7 @@ def test_laws_docx_preserves_table():
t.cell(1, 0).text = "Registration"
t.cell(1, 1).text = "100"
chunks = Docx()("law.docx", _build_docx(builder))
chunks = docx_chunker()("law.docx", _build_docx(builder))
assert any("<table>" in c for c in chunks)
table_chunk = next(c for c in chunks if "<table>" in c)
@@ -78,7 +99,7 @@ def test_laws_docx_preserves_table():
@pytest.mark.p2
def test_laws_docx_merged_cells_use_colspan():
def test_laws_docx_merged_cells_use_colspan(docx_chunker):
def builder(d):
d.add_heading("Heading", level=1)
t = d.add_table(rows=1, cols=3)
@@ -87,20 +108,20 @@ def test_laws_docx_merged_cells_use_colspan():
t.cell(0, 1).text = "Merged"
t.cell(0, 2).text = "Other"
chunks = Docx()("law.docx", _build_docx(builder))
chunks = docx_chunker()("law.docx", _build_docx(builder))
table_chunk = next(c for c in chunks if "<table>" in c)
assert "colspan='2'" in table_chunk
assert "<td>Other</td>" in table_chunk
@pytest.mark.p2
def test_laws_docx_escapes_cell_html():
def test_laws_docx_escapes_cell_html(docx_chunker):
def builder(d):
d.add_heading("Heading", level=1)
t = d.add_table(rows=1, cols=1)
t.cell(0, 0).text = "a < b & c > d"
chunks = Docx()("law.docx", _build_docx(builder))
chunks = docx_chunker()("law.docx", _build_docx(builder))
table_chunk = next(c for c in chunks if "<table>" in c)
# Special characters are HTML-escaped so the table markup stays well-formed.
assert "a &lt; b &amp; c &gt; d" in table_chunk
@@ -108,17 +129,17 @@ def test_laws_docx_escapes_cell_html():
@pytest.mark.p2
def test_laws_docx_tables_only_does_not_crash():
def test_laws_docx_tables_only_does_not_crash(docx_chunker):
def builder(d):
t = d.add_table(rows=1, cols=2)
t.cell(0, 0).text = "a"
t.cell(0, 1).text = "b"
chunks = Docx()("law.docx", _build_docx(builder))
chunks = docx_chunker()("law.docx", _build_docx(builder))
assert any("<table>" in c for c in chunks)
@pytest.mark.p2
def test_laws_docx_empty_doc_returns_empty():
chunks = Docx()("law.docx", _build_docx(lambda d: None))
def test_laws_docx_empty_doc_returns_empty(docx_chunker):
chunks = docx_chunker()("law.docx", _build_docx(lambda d: None))
assert chunks == []

View File

@@ -132,6 +132,8 @@ function MessageItem({
return null;
}
const hasCustomChildren = item.data && !!children;
return (
<div
className={cn({
@@ -142,7 +144,7 @@ function MessageItem({
})}
dir={getDirAttribute(messageContent.replace(citationMarkerReg, ''))}
>
{item.data ? (
{hasCustomChildren ? (
children
) : sendLoading && isEmpty(messageContent) ? (
<>{!isShare && t('common.running')}</>

View File

@@ -1,5 +1,6 @@
import { FileUploadProps } from '@/components/file-upload';
import { NextMessageInput } from '@/components/message-input/next';
import MarkdownContent from '@/components/next-markdown-content';
import MessageItem from '@/components/next-message-item';
import PdfSheet from '@/components/pdf-drawer';
import { useClickDrawer } from '@/components/pdf-drawer/hooks';
@@ -8,6 +9,8 @@ import { useUploadAgentFileWithProgress } from '@/hooks/use-agent-request';
import { useFetchUserInfo } from '@/hooks/use-user-setting-request';
import { IAgentLogResponse } from '@/interfaces/database/agent';
import { IMessage } from '@/interfaces/database/chat';
import DebugContent from '@/pages/agent/debug-content';
import { useAwaitComponentData } from '@/pages/agent/hooks/use-chat-logic';
import { BeginQuery } from '@/pages/agent/interface';
import { ParameterDialog } from '@/pages/agent/share/parameter-dialog';
import { buildMessageUuidWithRole } from '@/utils/chat';
@@ -37,6 +40,7 @@ export function SessionChat({ session }: SessionChatProps) {
handleInputChange,
handlePressEnter,
stopOutputMessage,
sendFormMessage,
canvasInfo,
findReferenceByMessageId,
appendUploadResponseList,
@@ -47,6 +51,11 @@ export function SessionChat({ session }: SessionChatProps) {
shouldShowParameterDialog,
setDerivedMessages,
} = useSendSessionMessage();
const { buildInputList, handleOk, isWaiting } = useAwaitComponentData({
derivedMessages,
sendFormMessage,
});
const hasActiveSession = Boolean(
sessionId || isNew || hasLocalMessageRef.current,
);
@@ -122,26 +131,58 @@ export function SessionChat({ session }: SessionChatProps) {
</div>
) : (
<div className="w-full pr-5">
{derivedMessages.map((message, i) => (
<MessageItem
loading={
message.role === MessageType.Assistant &&
sendLoading &&
derivedMessages.length - 1 === i
}
key={buildMessageUuidWithRole(message)}
item={message}
nickname={userInfo.nickname}
avatar={userInfo.avatar}
avatarDialog={canvasInfo?.avatar || ''}
reference={findReferenceByMessageId(message.id)}
clickDocumentButton={clickDocumentButton}
index={i}
showLikeButton={false}
sendLoading={sendLoading}
showLog={false}
/>
))}
{derivedMessages.map((message, i) => {
const inputList = buildInputList(message);
const hasUserFillUpInputs =
message.role === MessageType.Assistant &&
inputList.length > 0;
return (
<MessageItem
loading={
message.role === MessageType.Assistant &&
sendLoading &&
derivedMessages.length - 1 === i
}
key={buildMessageUuidWithRole(message)}
item={message}
nickname={userInfo.nickname}
avatar={userInfo.avatar}
avatarDialog={canvasInfo?.avatar || ''}
reference={findReferenceByMessageId(message.id)}
clickDocumentButton={clickDocumentButton}
index={i}
showLikeButton={false}
sendLoading={sendLoading}
showLog={false}
>
{hasUserFillUpInputs &&
derivedMessages.length - 1 === i && (
<DebugContent
parameters={inputList}
message={message}
ok={handleOk(message)}
isNext={false}
btnText={t('common.submit')}
></DebugContent>
)}
{hasUserFillUpInputs &&
derivedMessages.length - 1 !== i && (
<div>
<MarkdownContent
content={message?.data?.tips}
loading={false}
></MarkdownContent>
<div>
{inputList.map((item) => (
<div key={item.key}>{item.value}</div>
))}
</div>
</div>
)}
</MessageItem>
);
})}
</div>
)}
<div ref={scrollRef} />
@@ -151,9 +192,9 @@ export function SessionChat({ session }: SessionChatProps) {
<NextMessageInput
value={value}
sendLoading={sendLoading}
disabled={false}
sendDisabled={sendLoading}
isUploading={isUploading}
disabled={isWaiting}
sendDisabled={sendLoading || isWaiting}
isUploading={isUploading || isWaiting}
onPressEnter={handleSessionPressEnter}
onInputChange={handleInputChange}
stopOutputMessage={stopOutputMessage}

View File

@@ -40,14 +40,15 @@ const useAwaitComponentData = (props: IAwaitCompentData) => {
const isWaiting = useMemo(() => {
const temp = derivedMessages?.some((message, i) => {
const hasInputs = Object.keys(getInputs(message)).length > 0;
const flag =
message.role === MessageType.Assistant &&
derivedMessages.length - 1 === i &&
message.data;
hasInputs;
return flag;
});
return temp;
}, [derivedMessages]);
}, [derivedMessages, getInputs]);
return { getInputs, buildInputList, handleOk, isWaiting };
};