mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix: UserFillUp interactive forms not working in agent explore mode (#14589)
## Summary - **Backend**: `_iter_session_completion_events` in `agent_api.py` was filtering out `user_inputs` and `workflow_finished` SSE events, causing agents with UserFillUp components to silently fail in explore mode — the interactive form never appeared, while the same agent worked correctly in run (editor) mode. - **Frontend**: `SessionChat` component in explore mode was missing `DebugContent` children rendering inside `MessageItem`, so even if the backend forwarded the events, the form UI would not render. Added `DebugContent`, `MarkdownContent`, `useAwaitCompentData` hook, and input-disabling logic to match the run mode's `chat/box.tsx` behavior. ## What was changed ### Backend (`api/apps/restful_apis/agent_api.py`) - Line 266: Added `"user_inputs"` and `"workflow_finished"` to the allowed event filter in `_iter_session_completion_events` ### Frontend (`web/src/pages/agent/explore/components/session-chat.tsx`) - Added imports: `DebugContent`, `MarkdownContent`, `useAwaitCompentData`, `useParams` - Added `sendFormMessage` from `useSendSessionMessage()` hook - Added `useAwaitCompentData` hook for form state management - Added `DebugContent` as `MessageItem` children for the latest assistant message (renders UserFillUp form) - Added `MarkdownContent` + submitted values display for previous assistant messages - Updated `NextMessageInput` disabled states to respect `isWaitting` (form submission in progress) ## Test plan - [x] Agent with UserFillUp component (e.g., email draft with send/edit/cancel options) shows interactive form in **explore mode** - [x] Same agent continues to work correctly in **run (editor) mode** - [x] Form submission sends data back to the agent and workflow continues - [x] Input field is disabled while waiting for form submission - [ ] Agents without UserFillUp components are unaffected in explore mode 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
This commit is contained in:
@@ -575,7 +575,14 @@ async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace
|
||||
yield ans
|
||||
continue
|
||||
|
||||
if event in ["message", "message_end"]:
|
||||
if event in ["message", "message_end", "user_inputs", "workflow_finished"]:
|
||||
if event in ["user_inputs", "workflow_finished"]:
|
||||
logging.debug(
|
||||
"Forwarding session completion event: tenant_id=%s agent_id=%s event=%s",
|
||||
tenant_id,
|
||||
agent_id,
|
||||
event,
|
||||
)
|
||||
yield ans
|
||||
|
||||
|
||||
@@ -1564,7 +1571,10 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
final_ans = ans
|
||||
if ans.get("event") == "message_end":
|
||||
final_ans = ans
|
||||
elif ans.get("event") == "user_inputs" and not final_ans:
|
||||
final_ans = ans
|
||||
except Exception as exc:
|
||||
return get_result(data=f"**ERROR**: {str(exc)}")
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from quart import Response, request
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.apps import AUTH_BETA, login_required
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.canvas_service import completion as agent_completion
|
||||
@@ -53,6 +54,13 @@ from api.utils.reference_metadata_utils import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_sdk_authorization_token():
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return ""
|
||||
return auth_header[len("Bearer "):].strip()
|
||||
|
||||
|
||||
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@login_required(auth_types=AUTH_BETA)
|
||||
@add_tenant_id_to_kwargs
|
||||
|
||||
@@ -29,8 +29,8 @@ from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
|
||||
from api.apps.services.document_api_service import validate_document_update_fields, map_doc_keys, \
|
||||
map_doc_keys_with_run_status, update_document_name_only, update_chunk_method, update_document_status_only, \
|
||||
reset_document_for_reparse
|
||||
from api.db import VALID_FILE_TYPES, FileType, DB
|
||||
from api.db.db_models import API4Conversation
|
||||
from api.db import VALID_FILE_TYPES, FileType
|
||||
from api.db.db_models import API4Conversation, DB
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.doc_metadata_service import DocMetadataService
|
||||
from api.db.db_models import Task
|
||||
|
||||
@@ -16,21 +16,30 @@
|
||||
|
||||
import os
|
||||
|
||||
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
|
||||
PROJECT_BASE = None
|
||||
|
||||
|
||||
def _default_project_base_directory():
|
||||
return os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
os.pardir,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_project_base_directory(*args):
|
||||
global PROJECT_BASE
|
||||
if PROJECT_BASE is None:
|
||||
PROJECT_BASE = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
os.pardir,
|
||||
)
|
||||
)
|
||||
|
||||
project_base = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
|
||||
if not project_base:
|
||||
if PROJECT_BASE is None:
|
||||
PROJECT_BASE = _default_project_base_directory()
|
||||
project_base = PROJECT_BASE
|
||||
|
||||
if args:
|
||||
return os.path.join(PROJECT_BASE, *args)
|
||||
return PROJECT_BASE
|
||||
return os.path.join(project_base, *args)
|
||||
return project_base
|
||||
|
||||
def traversal_files(base):
|
||||
for root, ds, fs in os.walk(base):
|
||||
|
||||
@@ -14,13 +14,27 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tiktoken
|
||||
|
||||
from common.file_utils import get_project_base_directory
|
||||
|
||||
tiktoken_cache_dir = get_project_base_directory()
|
||||
|
||||
def _ensure_tiktoken_cache() -> str:
|
||||
cache_dir = get_project_base_directory()
|
||||
os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
|
||||
|
||||
bundled_encoding_path = get_project_base_directory("ragflow_deps", "cl100k_base.tiktoken")
|
||||
cached_encoding_path = os.path.join(cache_dir, "9b5ad71b2ce5302211f9c61530b329a4922fc6a4")
|
||||
|
||||
if os.path.exists(bundled_encoding_path) and not os.path.exists(cached_encoding_path):
|
||||
shutil.copyfile(bundled_encoding_path, cached_encoding_path)
|
||||
|
||||
return cache_dir
|
||||
|
||||
|
||||
tiktoken_cache_dir = _ensure_tiktoken_cache()
|
||||
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
|
||||
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
||||
encoder = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
@@ -98,7 +98,7 @@ class RAGFlowPdfParser:
|
||||
model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
|
||||
self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model"))
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False)
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"))
|
||||
self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model"))
|
||||
|
||||
self.page_from = 0
|
||||
|
||||
@@ -62,7 +62,7 @@ class LayoutRecognizer(Recognizer):
|
||||
model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
|
||||
super().__init__(self.labels, domain, model_dir)
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False)
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"))
|
||||
super().__init__(self.labels, domain, model_dir)
|
||||
|
||||
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
|
||||
|
||||
@@ -570,9 +570,10 @@ class OCR:
|
||||
self.text_recognizer = [TextRecognizer(model_dir)]
|
||||
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||
local_dir_use_symlinks=False)
|
||||
model_dir = snapshot_download(
|
||||
repo_id="InfiniFlow/deepdoc",
|
||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||
)
|
||||
|
||||
if settings.PARALLEL_DEVICES > 0:
|
||||
self.text_detector = []
|
||||
|
||||
@@ -47,7 +47,6 @@ class TableStructureRecognizer(Recognizer):
|
||||
snapshot_download(
|
||||
repo_id="InfiniFlow/deepdoc",
|
||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||
local_dir_use_symlinks=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ end-to-end behavior is intentionally left to higher-level integration tests.
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from importlib import import_module, reload
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -38,102 +39,90 @@ pytestmark = pytest.mark.p2
|
||||
# runtime environment. We track every key we install and restore the prior
|
||||
# sys.modules state in teardown_module so the stubs don't leak into other
|
||||
# test files.
|
||||
_SENTINEL_ABSENT = object()
|
||||
_INSTALLED_STUBS: dict[str, object] = {}
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_chunker_module():
|
||||
"""Import pipeline_chunker with rag.app parser modules stubbed locally."""
|
||||
stubbed_names = [
|
||||
"api.db.services.file_service",
|
||||
"deepdoc.vision.ocr",
|
||||
"deepdoc.parser.figure_parser",
|
||||
"rag.app.picture",
|
||||
"rag.app.audio",
|
||||
"rag.app.resume",
|
||||
"rag.app.naive",
|
||||
"rag.app.paper",
|
||||
"rag.app.book",
|
||||
"rag.app.presentation",
|
||||
"rag.app.manual",
|
||||
"rag.app.laws",
|
||||
"rag.app.qa",
|
||||
"rag.app.table",
|
||||
"rag.app.one",
|
||||
"rag.app.email",
|
||||
"rag.app.tag",
|
||||
]
|
||||
original_modules = {name: sys.modules.get(name) for name in stubbed_names}
|
||||
|
||||
file_service_stub = MagicMock()
|
||||
file_service_stub.FileService = MagicMock()
|
||||
|
||||
def _install_stub(name: str, stub: object) -> None:
|
||||
"""Insert ``stub`` into sys.modules and remember the prior entry."""
|
||||
if name in _INSTALLED_STUBS:
|
||||
return
|
||||
_INSTALLED_STUBS[name] = sys.modules.get(name, _SENTINEL_ABSENT)
|
||||
sys.modules[name] = stub
|
||||
try:
|
||||
sys.modules["api.db.services.file_service"] = file_service_stub
|
||||
for name in stubbed_names[1:]:
|
||||
stub = MagicMock()
|
||||
stub.chunk = MagicMock(return_value=[{"content_with_weight": "stub"}])
|
||||
sys.modules[name] = stub
|
||||
|
||||
|
||||
_file_service_stub = MagicMock()
|
||||
_file_service_stub.FileService = MagicMock()
|
||||
if "api.db.services.file_service" not in sys.modules:
|
||||
_install_stub("api.db.services.file_service", _file_service_stub)
|
||||
|
||||
for mod in [
|
||||
"deepdoc.vision.ocr",
|
||||
"deepdoc.parser.figure_parser",
|
||||
"rag.app.picture",
|
||||
"rag.app.audio",
|
||||
"rag.app.resume",
|
||||
"rag.app.naive",
|
||||
"rag.app.paper",
|
||||
"rag.app.book",
|
||||
"rag.app.presentation",
|
||||
"rag.app.manual",
|
||||
"rag.app.laws",
|
||||
"rag.app.qa",
|
||||
"rag.app.table",
|
||||
"rag.app.one",
|
||||
"rag.app.email",
|
||||
"rag.app.tag",
|
||||
]:
|
||||
if mod not in sys.modules:
|
||||
stub = MagicMock()
|
||||
stub.chunk = MagicMock(return_value=[{"content_with_weight": "stub"}])
|
||||
_install_stub(mod, stub)
|
||||
|
||||
|
||||
def teardown_module(module):
|
||||
"""Restore sys.modules to its pre-stub state when this file's tests finish."""
|
||||
for name, original in _INSTALLED_STUBS.items():
|
||||
if original is _SENTINEL_ABSENT:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = original
|
||||
_INSTALLED_STUBS.clear()
|
||||
|
||||
from agent.component.pipeline_chunker import ( # noqa: E402
|
||||
_PARSER_MODULES,
|
||||
PipelineChunkerParam,
|
||||
_load_chunker,
|
||||
)
|
||||
module = import_module("agent.component.pipeline_chunker")
|
||||
module = reload(module)
|
||||
yield module
|
||||
finally:
|
||||
for name, original in original_modules.items():
|
||||
if original is None:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = original
|
||||
|
||||
|
||||
class TestPipelineChunkerParam:
|
||||
"""Validate parameter parsing and the strategy whitelist."""
|
||||
|
||||
def test_default_param_validates(self):
|
||||
def test_default_param_validates(self, pipeline_chunker_module):
|
||||
"""A freshly constructed param object should pass ``check()``."""
|
||||
p = PipelineChunkerParam()
|
||||
p = pipeline_chunker_module.PipelineChunkerParam()
|
||||
assert p.check() is True
|
||||
|
||||
def test_accepts_each_known_parser(self):
|
||||
def test_accepts_each_known_parser(self, pipeline_chunker_module):
|
||||
"""Every parser id in the lookup table must validate."""
|
||||
for parser_id in _PARSER_MODULES:
|
||||
p = PipelineChunkerParam()
|
||||
for parser_id in pipeline_chunker_module._PARSER_MODULES:
|
||||
p = pipeline_chunker_module.PipelineChunkerParam()
|
||||
p.parser_id = parser_id
|
||||
assert p.check() is True
|
||||
|
||||
def test_rejects_unknown_parser(self):
|
||||
def test_rejects_unknown_parser(self, pipeline_chunker_module):
|
||||
"""Unknown parser ids must raise ``ValueError`` at validation time."""
|
||||
p = PipelineChunkerParam()
|
||||
p = pipeline_chunker_module.PipelineChunkerParam()
|
||||
p.parser_id = "nonsense-parser"
|
||||
with pytest.raises(ValueError):
|
||||
p.check()
|
||||
|
||||
def test_rejects_non_dict_parser_config(self):
|
||||
def test_rejects_non_dict_parser_config(self, pipeline_chunker_module):
|
||||
"""``parser_config`` must be a dict; anything else must raise."""
|
||||
p = PipelineChunkerParam()
|
||||
p = pipeline_chunker_module.PipelineChunkerParam()
|
||||
p.parser_config = "not a dict"
|
||||
with pytest.raises(ValueError):
|
||||
p.check()
|
||||
|
||||
def test_rejects_negative_pages(self):
|
||||
def test_rejects_negative_pages(self, pipeline_chunker_module):
|
||||
"""Negative page indices must raise ``ValueError``."""
|
||||
p = PipelineChunkerParam()
|
||||
p = pipeline_chunker_module.PipelineChunkerParam()
|
||||
p.from_page = -1
|
||||
with pytest.raises(ValueError):
|
||||
p.check()
|
||||
|
||||
def test_rejects_inverted_page_range(self):
|
||||
def test_rejects_inverted_page_range(self, pipeline_chunker_module):
|
||||
"""``from_page`` greater than ``to_page`` must raise ``ValueError``."""
|
||||
p = PipelineChunkerParam()
|
||||
p = pipeline_chunker_module.PipelineChunkerParam()
|
||||
p.from_page = 10
|
||||
p.to_page = 5
|
||||
with pytest.raises(ValueError, match="from_page must be <= to_page"):
|
||||
@@ -143,13 +132,13 @@ class TestPipelineChunkerParam:
|
||||
class TestLoadChunker:
|
||||
"""Verify the lazy parser-id -> chunker callable resolver."""
|
||||
|
||||
def test_load_chunker_returns_callable_for_each_known_parser(self):
|
||||
def test_load_chunker_returns_callable_for_each_known_parser(self, pipeline_chunker_module):
|
||||
"""Every known parser id should resolve to a callable ``chunk`` function."""
|
||||
for parser_id in _PARSER_MODULES:
|
||||
chunker = _load_chunker(parser_id)
|
||||
for parser_id in pipeline_chunker_module._PARSER_MODULES:
|
||||
chunker = pipeline_chunker_module._load_chunker(parser_id)
|
||||
assert callable(chunker)
|
||||
|
||||
def test_load_chunker_raises_for_unknown_parser(self):
|
||||
def test_load_chunker_raises_for_unknown_parser(self, pipeline_chunker_module):
|
||||
"""Unknown parser ids should raise ``KeyError`` from the lookup."""
|
||||
with pytest.raises(KeyError):
|
||||
_load_chunker("not-a-real-parser")
|
||||
pipeline_chunker_module._load_chunker("not-a-real-parser")
|
||||
|
||||
@@ -65,8 +65,8 @@ def _load_bot_api(monkeypatch, *, accessible, calls):
|
||||
calls["completion"] = True
|
||||
|
||||
async def _gen():
|
||||
if False:
|
||||
yield ""
|
||||
yield 'data: {"event":"message","data":{"content":"ok"}}\n\n'
|
||||
yield 'data: {"event":"message_end","data":{"content":"ok"}}\n\n'
|
||||
return _gen()
|
||||
|
||||
_stub(monkeypatch, "quart", Response=lambda *a, **k: SimpleNamespace(headers=SimpleNamespace(add_header=lambda *aa, **kk: None)), request=SimpleNamespace())
|
||||
@@ -83,7 +83,12 @@ def _load_bot_api(monkeypatch, *, accessible, calls):
|
||||
_stub(monkeypatch, "api.db.services.llm_service", LLMBundle=SimpleNamespace())
|
||||
_stub(monkeypatch, "common.metadata_utils", apply_meta_data_filter=lambda *_a, **_k: None)
|
||||
_stub(monkeypatch, "api.db.services.search_service", SearchService=SimpleNamespace())
|
||||
_stub(monkeypatch, "api.db.services.user_service", UserTenantService=SimpleNamespace())
|
||||
_stub(
|
||||
monkeypatch,
|
||||
"api.db.services.user_service",
|
||||
TenantService=SimpleNamespace(),
|
||||
UserTenantService=SimpleNamespace(),
|
||||
)
|
||||
_stub(monkeypatch, "api.db.joint_services.tenant_model_service", get_tenant_default_model_by_type=lambda *_a, **_k: None, get_model_config_from_provider_instance=lambda *_a, **_k: None)
|
||||
_stub(monkeypatch, "common.misc_utils", get_uuid=lambda: "uuid", thread_pool_exec=_passthrough_thread_pool_exec)
|
||||
_stub(
|
||||
@@ -118,7 +123,7 @@ def _load_bot_api(monkeypatch, *, accessible, calls):
|
||||
|
||||
|
||||
async def _async_empty_json():
|
||||
return {}
|
||||
return {"stream": False}
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
|
||||
@@ -53,23 +53,24 @@ class TestGetProjectBaseDirectory:
|
||||
def test_uses_environment_variable_when_available(self):
|
||||
"""Test that function uses RAG_PROJECT_BASE environment variable when set"""
|
||||
test_path = "/custom/project/path"
|
||||
|
||||
file_utils.PROJECT_BASE = test_path
|
||||
|
||||
result = get_project_base_directory()
|
||||
with patch.dict(os.environ, {"RAG_PROJECT_BASE": test_path}, clear=False):
|
||||
result = get_project_base_directory()
|
||||
assert result == test_path
|
||||
|
||||
def test_calculates_default_path_when_no_env_vars(self):
|
||||
"""Test that function calculates default path when no environment variables are set"""
|
||||
with patch.dict(os.environ, {}, clear=True): # Clear all environment variables
|
||||
# Reset the global variable to force re-initialization
|
||||
original_base = file_utils.PROJECT_BASE
|
||||
try:
|
||||
file_utils.PROJECT_BASE = None
|
||||
result = get_project_base_directory()
|
||||
|
||||
result = get_project_base_directory()
|
||||
|
||||
# Should return a valid absolute path
|
||||
assert result is not None
|
||||
assert os.path.isabs(result)
|
||||
assert os.path.basename(result) != "" # Should not be root directory
|
||||
# Should return a valid absolute path
|
||||
assert result is not None
|
||||
assert os.path.isabs(result)
|
||||
assert os.path.basename(result) != "" # Should not be root directory
|
||||
finally:
|
||||
file_utils.PROJECT_BASE = original_base
|
||||
|
||||
def test_caches_project_base_value(self):
|
||||
"""Test that PROJECT_BASE is cached after first calculation"""
|
||||
|
||||
@@ -28,25 +28,44 @@ from __future__ import annotations
|
||||
|
||||
import io
|
||||
import sys
|
||||
from importlib import import_module, reload
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Mock heavy modules that trigger ONNX/OCR model loading or optional parser
|
||||
# backends at import time so importing rag.app.naive stays lightweight.
|
||||
for _mod in [
|
||||
"deepdoc.vision.ocr",
|
||||
"deepdoc.parser.figure_parser",
|
||||
"deepdoc.parser.docling_parser",
|
||||
"deepdoc.parser.tcadp_parser",
|
||||
"rag.app.picture",
|
||||
]:
|
||||
if _mod not in sys.modules:
|
||||
sys.modules[_mod] = MagicMock()
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from common import ssrf_guard
|
||||
from rag.app.naive import MAX_IMAGE_REDIRECTS, Markdown
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def naive_module():
|
||||
"""Load rag.app.naive with heavy optional dependencies stubbed locally."""
|
||||
stub_names = [
|
||||
"deepdoc.vision.ocr",
|
||||
"deepdoc.parser.figure_parser",
|
||||
"deepdoc.parser.docling_parser",
|
||||
"deepdoc.parser.tcadp_parser",
|
||||
"rag.app.picture",
|
||||
]
|
||||
original_modules = {name: sys.modules.get(name) for name in stub_names}
|
||||
|
||||
try:
|
||||
for name in stub_names:
|
||||
sys.modules[name] = MagicMock()
|
||||
module = import_module("rag.app.naive")
|
||||
module = reload(module)
|
||||
yield module
|
||||
finally:
|
||||
for name, original in original_modules.items():
|
||||
if original is None:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = original
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def max_image_redirects(naive_module):
|
||||
return naive_module.MAX_IMAGE_REDIRECTS
|
||||
|
||||
|
||||
def _png_bytes() -> bytes:
|
||||
@@ -68,8 +87,8 @@ class _Resp:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser():
|
||||
return Markdown(128)
|
||||
def parser(naive_module):
|
||||
return naive_module.Markdown(128)
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
@@ -127,7 +146,7 @@ def test_fetches_legitimate_public_image(parser):
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_redirect_chain_is_bounded(parser):
|
||||
def test_redirect_chain_is_bounded(parser, max_image_redirects):
|
||||
"""An endless redirect loop is abandoned instead of being followed forever."""
|
||||
loop = _Resp(302, headers={"Location": "http://public.example/next"})
|
||||
with (
|
||||
@@ -136,5 +155,5 @@ def test_redirect_chain_is_bounded(parser):
|
||||
):
|
||||
images, _ = parser.load_images_from_urls(["http://public.example/start"])
|
||||
|
||||
assert get.call_count == MAX_IMAGE_REDIRECTS + 1
|
||||
assert get.call_count == max_image_redirects + 1
|
||||
assert images == []
|
||||
|
||||
@@ -20,18 +20,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from importlib import import_module, reload
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Mock heavy modules that trigger ONNX model loading at import time
|
||||
# table.py -> deepdoc.parser.figure_parser -> rag.app.picture -> OCR()
|
||||
for mod in [
|
||||
"deepdoc.vision.ocr",
|
||||
"deepdoc.parser.figure_parser",
|
||||
"rag.app.picture",
|
||||
]:
|
||||
if mod not in sys.modules:
|
||||
sys.modules[mod] = MagicMock()
|
||||
|
||||
import warnings
|
||||
|
||||
# Importing rag.app.table pulls api -> rag.llm -> deepdoc -> xgboost; xgboost may warn on
|
||||
@@ -42,7 +33,6 @@ import pkg_resources # noqa: F401 — stabilize xgboost import during collectio
|
||||
import pytest
|
||||
|
||||
import common.settings as settings
|
||||
from rag.app.table import chunk
|
||||
|
||||
# chunk() removes columns named id, _id, index, idx — use row_id instead of id.
|
||||
TEST_CSV = b"""row_id,title,content,country,category
|
||||
@@ -76,14 +66,38 @@ def _stub_rag_tokenizer(monkeypatch):
|
||||
monkeypatch.setattr("rag.nlp.rag_tokenizer.fine_grained_tokenize", fake_tokenize)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def table_module():
|
||||
"""Load rag.app.table with heavy optional dependencies stubbed locally."""
|
||||
stub_names = [
|
||||
"deepdoc.vision.ocr",
|
||||
"deepdoc.parser.figure_parser",
|
||||
"rag.app.picture",
|
||||
]
|
||||
original_modules = {name: sys.modules.get(name) for name in stub_names}
|
||||
|
||||
try:
|
||||
for name in stub_names:
|
||||
sys.modules[name] = MagicMock()
|
||||
module = import_module("rag.app.table")
|
||||
module = reload(module)
|
||||
yield module
|
||||
finally:
|
||||
for name, original in original_modules.items():
|
||||
if original is None:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = original
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_update_kb():
|
||||
with patch("rag.app.table.KnowledgebaseService.update_parser_config") as m:
|
||||
yield m
|
||||
|
||||
|
||||
def _run_chunk(parser_config: dict, mock_update_kb: MagicMock):
|
||||
return chunk(
|
||||
def _run_chunk(table_module, parser_config: dict, mock_update_kb: MagicMock):
|
||||
return table_module.chunk(
|
||||
FILENAME,
|
||||
binary=TEST_CSV,
|
||||
callback=_noop_callback,
|
||||
@@ -93,9 +107,9 @@ def _run_chunk(parser_config: dict, mock_update_kb: MagicMock):
|
||||
)
|
||||
|
||||
|
||||
def test_chunk_auto_mode_all_columns_in_text_and_stored(mock_update_kb: MagicMock):
|
||||
def test_chunk_auto_mode_all_columns_in_text_and_stored(table_module, mock_update_kb: MagicMock):
|
||||
parser_config: dict = {}
|
||||
chunks = _run_chunk(parser_config, mock_update_kb)
|
||||
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
|
||||
assert len(chunks) == 3
|
||||
first = chunks[0]
|
||||
cww = first["content_with_weight"]
|
||||
@@ -109,7 +123,7 @@ def test_chunk_auto_mode_all_columns_in_text_and_stored(mock_update_kb: MagicMoc
|
||||
assert "title_raw" in first and "country_raw" in first
|
||||
|
||||
|
||||
def test_chunk_manual_mode_indexing_only(mock_update_kb: MagicMock):
|
||||
def test_chunk_manual_mode_indexing_only(table_module, mock_update_kb: MagicMock):
|
||||
parser_config = {
|
||||
"table_column_mode": "manual",
|
||||
"table_column_roles": {
|
||||
@@ -120,7 +134,7 @@ def test_chunk_manual_mode_indexing_only(mock_update_kb: MagicMock):
|
||||
"category": "metadata",
|
||||
},
|
||||
}
|
||||
chunks = _run_chunk(parser_config, mock_update_kb)
|
||||
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
|
||||
first = chunks[0]
|
||||
cww = first["content_with_weight"]
|
||||
assert "- title:" in cww and "Earthquake" in cww
|
||||
@@ -135,7 +149,7 @@ def test_chunk_manual_mode_indexing_only(mock_update_kb: MagicMock):
|
||||
assert "row_id_long" in first
|
||||
|
||||
|
||||
def test_chunk_manual_mode_legacy_vectorize_role(mock_update_kb: MagicMock):
|
||||
def test_chunk_manual_mode_legacy_vectorize_role(table_module, mock_update_kb: MagicMock):
|
||||
"""Stored configs may still use role *vectorize*; chunking treats it like *indexing*."""
|
||||
parser_config = {
|
||||
"table_column_mode": "manual",
|
||||
@@ -147,7 +161,7 @@ def test_chunk_manual_mode_legacy_vectorize_role(mock_update_kb: MagicMock):
|
||||
"category": "metadata",
|
||||
},
|
||||
}
|
||||
chunks = _run_chunk(parser_config, mock_update_kb)
|
||||
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
|
||||
first = chunks[0]
|
||||
cww = first["content_with_weight"]
|
||||
assert "- title:" in cww and "Earthquake" in cww
|
||||
@@ -155,7 +169,7 @@ def test_chunk_manual_mode_legacy_vectorize_role(mock_update_kb: MagicMock):
|
||||
assert "- country:" not in cww
|
||||
|
||||
|
||||
def test_chunk_manual_mode_metadata_only(mock_update_kb: MagicMock):
|
||||
def test_chunk_manual_mode_metadata_only(table_module, mock_update_kb: MagicMock):
|
||||
parser_config = {
|
||||
"table_column_mode": "manual",
|
||||
"table_column_roles": {
|
||||
@@ -166,18 +180,18 @@ def test_chunk_manual_mode_metadata_only(mock_update_kb: MagicMock):
|
||||
"category": "metadata",
|
||||
},
|
||||
}
|
||||
chunks = _run_chunk(parser_config, mock_update_kb)
|
||||
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
|
||||
first = chunks[0]
|
||||
assert (first.get("content_with_weight") or "").strip() == ""
|
||||
assert "country_raw" in first and "title_raw" in first
|
||||
|
||||
|
||||
def test_chunk_manual_mode_both(mock_update_kb: MagicMock):
|
||||
def test_chunk_manual_mode_both(table_module, mock_update_kb: MagicMock):
|
||||
parser_config = {
|
||||
"table_column_mode": "manual",
|
||||
"table_column_roles": {c: "both" for c in ["title", "content", "country", "category", "row_id"]},
|
||||
}
|
||||
chunks = _run_chunk(parser_config, mock_update_kb)
|
||||
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
|
||||
first = chunks[0]
|
||||
cww = first["content_with_weight"]
|
||||
assert "Earthquake hits Turkey" in cww
|
||||
@@ -187,7 +201,7 @@ def test_chunk_manual_mode_both(mock_update_kb: MagicMock):
|
||||
assert "title_raw" in first and "country_raw" in first
|
||||
|
||||
|
||||
def test_chunk_manual_mode_partial_roles_default_to_both(mock_update_kb: MagicMock):
|
||||
def test_chunk_manual_mode_partial_roles_default_to_both(table_module, mock_update_kb: MagicMock):
|
||||
parser_config = {
|
||||
"table_column_mode": "manual",
|
||||
"table_column_roles": {
|
||||
@@ -195,7 +209,7 @@ def test_chunk_manual_mode_partial_roles_default_to_both(mock_update_kb: MagicMo
|
||||
"country": "metadata",
|
||||
},
|
||||
}
|
||||
chunks = _run_chunk(parser_config, mock_update_kb)
|
||||
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
|
||||
first = chunks[0]
|
||||
cww = first["content_with_weight"]
|
||||
assert "- title:" in cww and "Earthquake" in cww
|
||||
@@ -208,20 +222,20 @@ def test_chunk_manual_mode_partial_roles_default_to_both(mock_update_kb: MagicMo
|
||||
assert "content_raw" in first and "category_raw" in first
|
||||
|
||||
|
||||
def test_chunk_manual_mode_raw_fields_for_es(mock_update_kb: MagicMock):
|
||||
def test_chunk_manual_mode_raw_fields_for_es(table_module, mock_update_kb: MagicMock):
|
||||
parser_config = {
|
||||
"table_column_mode": "manual",
|
||||
"table_column_roles": {c: "both" for c in ["title", "content", "country", "category", "row_id"]},
|
||||
}
|
||||
chunks = _run_chunk(parser_config, mock_update_kb)
|
||||
chunks = _run_chunk(table_module, parser_config, mock_update_kb)
|
||||
first = chunks[0]
|
||||
for col in ("title", "content", "country", "category"):
|
||||
assert f"{col}_raw" in first
|
||||
assert f"{col}_tks" in first
|
||||
|
||||
|
||||
def test_chunk_updates_table_column_names(mock_update_kb: MagicMock):
|
||||
_run_chunk({}, mock_update_kb)
|
||||
def test_chunk_updates_table_column_names(table_module, mock_update_kb: MagicMock):
|
||||
_run_chunk(table_module, {}, mock_update_kb)
|
||||
mock_update_kb.assert_called_once()
|
||||
args, kwargs = mock_update_kb.call_args
|
||||
assert args[0] == KB_ID
|
||||
@@ -230,6 +244,6 @@ def test_chunk_updates_table_column_names(mock_update_kb: MagicMock):
|
||||
assert names == ["row_id", "title", "content", "country", "category"]
|
||||
|
||||
|
||||
def test_chunk_count_matches_row_count(mock_update_kb: MagicMock):
|
||||
chunks = _run_chunk({}, mock_update_kb)
|
||||
def test_chunk_count_matches_row_count(table_module, mock_update_kb: MagicMock):
|
||||
chunks = _run_chunk(table_module, {}, mock_update_kb)
|
||||
assert len(chunks) == 3
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import sys
|
||||
import types
|
||||
from importlib import import_module, reload
|
||||
from io import BytesIO
|
||||
|
||||
import pytest
|
||||
@@ -37,12 +38,32 @@ class _DummyBase:
|
||||
pass
|
||||
|
||||
|
||||
_stub("deepdoc.parser", PdfParser=_DummyBase, DocxParser=_DummyBase, HtmlParser=_DummyBase)
|
||||
_stub("deepdoc.parser.utils", get_text=lambda *a, **k: "")
|
||||
_stub("rag.app.naive", by_plaintext=lambda *a, **k: ([], [], None), PARSERS={})
|
||||
_stub("common.parser_config_utils", normalize_layout_recognizer=lambda x: (x, None))
|
||||
@pytest.fixture(scope="module")
|
||||
def docx_chunker():
|
||||
original_modules = {
|
||||
name: sys.modules.get(name)
|
||||
for name in (
|
||||
"deepdoc.parser",
|
||||
"deepdoc.parser.utils",
|
||||
"rag.app.naive",
|
||||
"common.parser_config_utils",
|
||||
)
|
||||
}
|
||||
|
||||
from rag.app.laws import Docx # noqa: E402
|
||||
try:
|
||||
_stub("deepdoc.parser", PdfParser=_DummyBase, DocxParser=_DummyBase, HtmlParser=_DummyBase)
|
||||
_stub("deepdoc.parser.utils", get_text=lambda *a, **k: "")
|
||||
_stub("rag.app.naive", by_plaintext=lambda *a, **k: ([], [], None), PARSERS={})
|
||||
_stub("common.parser_config_utils", normalize_layout_recognizer=lambda x: (x, None))
|
||||
module = import_module("rag.app.laws")
|
||||
module = reload(module)
|
||||
yield module.Docx
|
||||
finally:
|
||||
for name, original in original_modules.items():
|
||||
if original is None:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = original
|
||||
|
||||
|
||||
def _build_docx(builder):
|
||||
@@ -54,7 +75,7 @@ def _build_docx(builder):
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_laws_docx_preserves_table():
|
||||
def test_laws_docx_preserves_table(docx_chunker):
|
||||
"""Regression for #16008: the laws DOCX parser dropped tables entirely."""
|
||||
|
||||
def builder(d):
|
||||
@@ -67,7 +88,7 @@ def test_laws_docx_preserves_table():
|
||||
t.cell(1, 0).text = "Registration"
|
||||
t.cell(1, 1).text = "100"
|
||||
|
||||
chunks = Docx()("law.docx", _build_docx(builder))
|
||||
chunks = docx_chunker()("law.docx", _build_docx(builder))
|
||||
|
||||
assert any("<table>" in c for c in chunks)
|
||||
table_chunk = next(c for c in chunks if "<table>" in c)
|
||||
@@ -78,7 +99,7 @@ def test_laws_docx_preserves_table():
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_laws_docx_merged_cells_use_colspan():
|
||||
def test_laws_docx_merged_cells_use_colspan(docx_chunker):
|
||||
def builder(d):
|
||||
d.add_heading("Heading", level=1)
|
||||
t = d.add_table(rows=1, cols=3)
|
||||
@@ -87,20 +108,20 @@ def test_laws_docx_merged_cells_use_colspan():
|
||||
t.cell(0, 1).text = "Merged"
|
||||
t.cell(0, 2).text = "Other"
|
||||
|
||||
chunks = Docx()("law.docx", _build_docx(builder))
|
||||
chunks = docx_chunker()("law.docx", _build_docx(builder))
|
||||
table_chunk = next(c for c in chunks if "<table>" in c)
|
||||
assert "colspan='2'" in table_chunk
|
||||
assert "<td>Other</td>" in table_chunk
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_laws_docx_escapes_cell_html():
|
||||
def test_laws_docx_escapes_cell_html(docx_chunker):
|
||||
def builder(d):
|
||||
d.add_heading("Heading", level=1)
|
||||
t = d.add_table(rows=1, cols=1)
|
||||
t.cell(0, 0).text = "a < b & c > d"
|
||||
|
||||
chunks = Docx()("law.docx", _build_docx(builder))
|
||||
chunks = docx_chunker()("law.docx", _build_docx(builder))
|
||||
table_chunk = next(c for c in chunks if "<table>" in c)
|
||||
# Special characters are HTML-escaped so the table markup stays well-formed.
|
||||
assert "a < b & c > d" in table_chunk
|
||||
@@ -108,17 +129,17 @@ def test_laws_docx_escapes_cell_html():
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_laws_docx_tables_only_does_not_crash():
|
||||
def test_laws_docx_tables_only_does_not_crash(docx_chunker):
|
||||
def builder(d):
|
||||
t = d.add_table(rows=1, cols=2)
|
||||
t.cell(0, 0).text = "a"
|
||||
t.cell(0, 1).text = "b"
|
||||
|
||||
chunks = Docx()("law.docx", _build_docx(builder))
|
||||
chunks = docx_chunker()("law.docx", _build_docx(builder))
|
||||
assert any("<table>" in c for c in chunks)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_laws_docx_empty_doc_returns_empty():
|
||||
chunks = Docx()("law.docx", _build_docx(lambda d: None))
|
||||
def test_laws_docx_empty_doc_returns_empty(docx_chunker):
|
||||
chunks = docx_chunker()("law.docx", _build_docx(lambda d: None))
|
||||
assert chunks == []
|
||||
|
||||
@@ -132,6 +132,8 @@ function MessageItem({
|
||||
return null;
|
||||
}
|
||||
|
||||
const hasCustomChildren = item.data && !!children;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn({
|
||||
@@ -142,7 +144,7 @@ function MessageItem({
|
||||
})}
|
||||
dir={getDirAttribute(messageContent.replace(citationMarkerReg, ''))}
|
||||
>
|
||||
{item.data ? (
|
||||
{hasCustomChildren ? (
|
||||
children
|
||||
) : sendLoading && isEmpty(messageContent) ? (
|
||||
<>{!isShare && t('common.running')}</>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { FileUploadProps } from '@/components/file-upload';
|
||||
import { NextMessageInput } from '@/components/message-input/next';
|
||||
import MarkdownContent from '@/components/next-markdown-content';
|
||||
import MessageItem from '@/components/next-message-item';
|
||||
import PdfSheet from '@/components/pdf-drawer';
|
||||
import { useClickDrawer } from '@/components/pdf-drawer/hooks';
|
||||
@@ -8,6 +9,8 @@ import { useUploadAgentFileWithProgress } from '@/hooks/use-agent-request';
|
||||
import { useFetchUserInfo } from '@/hooks/use-user-setting-request';
|
||||
import { IAgentLogResponse } from '@/interfaces/database/agent';
|
||||
import { IMessage } from '@/interfaces/database/chat';
|
||||
import DebugContent from '@/pages/agent/debug-content';
|
||||
import { useAwaitComponentData } from '@/pages/agent/hooks/use-chat-logic';
|
||||
import { BeginQuery } from '@/pages/agent/interface';
|
||||
import { ParameterDialog } from '@/pages/agent/share/parameter-dialog';
|
||||
import { buildMessageUuidWithRole } from '@/utils/chat';
|
||||
@@ -37,6 +40,7 @@ export function SessionChat({ session }: SessionChatProps) {
|
||||
handleInputChange,
|
||||
handlePressEnter,
|
||||
stopOutputMessage,
|
||||
sendFormMessage,
|
||||
canvasInfo,
|
||||
findReferenceByMessageId,
|
||||
appendUploadResponseList,
|
||||
@@ -47,6 +51,11 @@ export function SessionChat({ session }: SessionChatProps) {
|
||||
shouldShowParameterDialog,
|
||||
setDerivedMessages,
|
||||
} = useSendSessionMessage();
|
||||
|
||||
const { buildInputList, handleOk, isWaiting } = useAwaitComponentData({
|
||||
derivedMessages,
|
||||
sendFormMessage,
|
||||
});
|
||||
const hasActiveSession = Boolean(
|
||||
sessionId || isNew || hasLocalMessageRef.current,
|
||||
);
|
||||
@@ -122,26 +131,58 @@ export function SessionChat({ session }: SessionChatProps) {
|
||||
</div>
|
||||
) : (
|
||||
<div className="w-full pr-5">
|
||||
{derivedMessages.map((message, i) => (
|
||||
<MessageItem
|
||||
loading={
|
||||
message.role === MessageType.Assistant &&
|
||||
sendLoading &&
|
||||
derivedMessages.length - 1 === i
|
||||
}
|
||||
key={buildMessageUuidWithRole(message)}
|
||||
item={message}
|
||||
nickname={userInfo.nickname}
|
||||
avatar={userInfo.avatar}
|
||||
avatarDialog={canvasInfo?.avatar || ''}
|
||||
reference={findReferenceByMessageId(message.id)}
|
||||
clickDocumentButton={clickDocumentButton}
|
||||
index={i}
|
||||
showLikeButton={false}
|
||||
sendLoading={sendLoading}
|
||||
showLog={false}
|
||||
/>
|
||||
))}
|
||||
{derivedMessages.map((message, i) => {
|
||||
const inputList = buildInputList(message);
|
||||
const hasUserFillUpInputs =
|
||||
message.role === MessageType.Assistant &&
|
||||
inputList.length > 0;
|
||||
|
||||
return (
|
||||
<MessageItem
|
||||
loading={
|
||||
message.role === MessageType.Assistant &&
|
||||
sendLoading &&
|
||||
derivedMessages.length - 1 === i
|
||||
}
|
||||
key={buildMessageUuidWithRole(message)}
|
||||
item={message}
|
||||
nickname={userInfo.nickname}
|
||||
avatar={userInfo.avatar}
|
||||
avatarDialog={canvasInfo?.avatar || ''}
|
||||
reference={findReferenceByMessageId(message.id)}
|
||||
clickDocumentButton={clickDocumentButton}
|
||||
index={i}
|
||||
showLikeButton={false}
|
||||
sendLoading={sendLoading}
|
||||
showLog={false}
|
||||
>
|
||||
{hasUserFillUpInputs &&
|
||||
derivedMessages.length - 1 === i && (
|
||||
<DebugContent
|
||||
parameters={inputList}
|
||||
message={message}
|
||||
ok={handleOk(message)}
|
||||
isNext={false}
|
||||
btnText={t('common.submit')}
|
||||
></DebugContent>
|
||||
)}
|
||||
{hasUserFillUpInputs &&
|
||||
derivedMessages.length - 1 !== i && (
|
||||
<div>
|
||||
<MarkdownContent
|
||||
content={message?.data?.tips}
|
||||
loading={false}
|
||||
></MarkdownContent>
|
||||
<div>
|
||||
{inputList.map((item) => (
|
||||
<div key={item.key}>{item.value}</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</MessageItem>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
<div ref={scrollRef} />
|
||||
@@ -151,9 +192,9 @@ export function SessionChat({ session }: SessionChatProps) {
|
||||
<NextMessageInput
|
||||
value={value}
|
||||
sendLoading={sendLoading}
|
||||
disabled={false}
|
||||
sendDisabled={sendLoading}
|
||||
isUploading={isUploading}
|
||||
disabled={isWaiting}
|
||||
sendDisabled={sendLoading || isWaiting}
|
||||
isUploading={isUploading || isWaiting}
|
||||
onPressEnter={handleSessionPressEnter}
|
||||
onInputChange={handleInputChange}
|
||||
stopOutputMessage={stopOutputMessage}
|
||||
|
||||
@@ -40,14 +40,15 @@ const useAwaitComponentData = (props: IAwaitCompentData) => {
|
||||
|
||||
const isWaiting = useMemo(() => {
|
||||
const temp = derivedMessages?.some((message, i) => {
|
||||
const hasInputs = Object.keys(getInputs(message)).length > 0;
|
||||
const flag =
|
||||
message.role === MessageType.Assistant &&
|
||||
derivedMessages.length - 1 === i &&
|
||||
message.data;
|
||||
hasInputs;
|
||||
return flag;
|
||||
});
|
||||
return temp;
|
||||
}, [derivedMessages]);
|
||||
}, [derivedMessages, getInputs]);
|
||||
return { getInputs, buildInputList, handleOk, isWaiting };
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user