diff --git a/api/apps/restful_apis/mcp_api.py b/api/apps/restful_apis/mcp_api.py index ec384f6074..b3f39fa4bf 100644 --- a/api/apps/restful_apis/mcp_api.py +++ b/api/apps/restful_apis/mcp_api.py @@ -25,6 +25,7 @@ from api.utils.web_utils import get_float, safe_json_parse from common.constants import VALID_MCP_SERVER_TYPES from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from common.misc_utils import get_uuid, thread_pool_exec +from common.ssrf_guard import assert_url_is_safe, pin_dns_global def _get_mcp_ids_from_args() -> list[str]: @@ -55,6 +56,16 @@ def _export_mcp_servers(mcp_ids: list[str]) -> dict | None: return {"mcpServers": exported_servers} +def _assert_mcp_url_is_safe(url, invalid_message: str = "Invalid url.") -> tuple[str, str, str | None]: + if not isinstance(url, str) or not url: + return "", "", invalid_message + try: + hostname, resolved_ip = assert_url_is_safe(url) + except ValueError as exc: + return "", "", str(exc) + return hostname, resolved_ip, None + + @manager.route("/mcp/servers", methods=["GET"]) # noqa: F821 @login_required async def list_mcp() -> Response: @@ -119,8 +130,9 @@ async def create() -> Response: return get_data_error_result(message="Duplicated MCP server name.") url = req.get("url", "") - if not url: - return get_data_error_result(message="Invalid url.") + hostname, resolved_ip, url_error = _assert_mcp_url_is_safe(url) + if url_error: + return get_data_error_result(message=url_error) headers = safe_json_parse(req.get("headers", {})) req["headers"] = headers @@ -138,7 +150,8 @@ async def create() -> Response: return get_data_error_result(message="Tenant not found.") mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers) - server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) + with pin_dns_global(hostname, resolved_ip): + server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) if err_message: return get_data_error_result(message=err_message) @@ -171,8 +184,9 @@ async def update(mcp_id: str) -> Response: if server_name and len(server_name.encode("utf-8")) > 255: return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.") url = req.get("url", mcp_server.url) - if not url: - return get_data_error_result(message="Invalid url.") + hostname, resolved_ip, url_error = _assert_mcp_url_is_safe(url) + if url_error: + return get_data_error_result(message=url_error) headers = safe_json_parse(req.get("headers", mcp_server.headers)) req["headers"] = headers @@ -187,7 +201,8 @@ async def update(mcp_id: str) -> Response: req["id"] = mcp_id mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers) - server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) + with pin_dns_global(hostname, resolved_ip): + server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) if err_message: return get_data_error_result(message=err_message) @@ -244,6 +259,13 @@ async def import_multiple() -> Response: if not server_name or len(server_name.encode("utf-8")) > 255: results.append({"server": server_name, "success": False, "message": f"Invalid MCP name or length is {len(server_name)} which is large than 255."}) continue + if config["type"] not in VALID_MCP_SERVER_TYPES: + results.append({"server": server_name, "success": False, "message": "Unsupported MCP server type."}) + continue + hostname, resolved_ip, url_error = _assert_mcp_url_is_safe(config["url"]) + if url_error: + results.append({"server": server_name, "success": False, "message": url_error}) + continue base_name = server_name new_name = base_name @@ -268,7 +290,8 @@ async def import_multiple() -> Response: headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {} variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}} mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers) - server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) + with pin_dns_global(hostname, resolved_ip): + server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) if err_message: results.append({"server": base_name, "success": False, "message": err_message}) continue @@ -297,13 +320,17 @@ async def test_mcp(mcp_id: str) -> Response: req = await get_request_json() url = req.get("url", "") - if not url: + if not isinstance(url, str) or not url: return get_data_error_result(message="Invalid MCP url.") server_type = req.get("server_type", "") if server_type not in VALID_MCP_SERVER_TYPES: return get_data_error_result(message="Unsupported MCP server type.") + hostname, resolved_ip, url_error = _assert_mcp_url_is_safe(url, "Invalid MCP url.") + if url_error: + return get_data_error_result(message=url_error) + timeout = get_float(req, "timeout", 10) headers = safe_json_parse(req.get("headers", {})) variables = safe_json_parse(req.get("variables", {})) @@ -312,14 +339,15 @@ async def test_mcp(mcp_id: str) -> Response: result = [] try: - tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) + with pin_dns_global(hostname, resolved_ip): + tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) - try: - tools = await thread_pool_exec(tool_call_session.get_tools, timeout) - except Exception as e: - return get_data_error_result(message=f"Test MCP error: {e}") - finally: - await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session]) + try: + tools = await thread_pool_exec(tool_call_session.get_tools, timeout) + except Exception as e: + return get_data_error_result(message=f"Test MCP error: {e}") + finally: + await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session]) for tool in tools: tool_dict = tool.model_dump() diff --git a/deepdoc/parser/mineru_parser.py b/deepdoc/parser/mineru_parser.py index 2c35ead98c..b0c0eafb6e 100644 --- a/deepdoc/parser/mineru_parser.py +++ b/deepdoc/parser/mineru_parser.py @@ -14,6 +14,7 @@ # limitations under the License. # import json +import html import logging import os import re @@ -205,6 +206,26 @@ class MinerUParser(RAGFlowPdfParser): except Exception: return False + @staticmethod + def _sanitize_section_text(section: str) -> str: + """Normalize MinerU text blocks before chunking. + + MinerU may return HTML fragments (e.g. table_body with //
). + Keep human-readable text while removing tag noise that hurts chunking. + """ + if not section: + return "" + section = html.unescape(section) + # Preserve rough structure before dropping tags. + section = re.sub(r"(?is)<\s*br\s*/?\s*>", "\n", section) + section = re.sub(r"(?is)", "\n", section) + section = re.sub(r"(?is)<[^>]+>", "", section) + # Collapse whitespace while preserving line boundaries. + section = re.sub(r"[ \t]+\n", "\n", section) + section = re.sub(r"\n{3,}", "\n\n", section) + section = re.sub(r"[ \t]{2,}", " ", section) + return section.strip() + def check_installation(self, backend: str = "pipeline", server_url: Optional[str] = None) -> tuple[bool, str]: reason = "" @@ -659,6 +680,11 @@ class MinerUParser(RAGFlowPdfParser): case MinerUContentType.DISCARDED: continue # Skip discarded blocks entirely + section = self._sanitize_section_text(section) + if not section: + self.logger.debug("[MinerU] Skip section after sanitization: type=%s", output.get("type")) + continue + if section and parse_method in {"manual", "pipeline"}: sections.append((section, output["type"], self._line_tag(output))) elif section and parse_method == "paper": diff --git a/test/testcases/restful_api/test_mcp_routes_unit.py b/test/testcases/restful_api/test_mcp_routes_unit.py index ccd628f0fd..2278149881 100644 --- a/test/testcases/restful_api/test_mcp_routes_unit.py +++ b/test/testcases/restful_api/test_mcp_routes_unit.py @@ -18,6 +18,7 @@ import importlib.util import inspect import json import sys +from contextlib import nullcontext from functools import wraps from pathlib import Path from types import ModuleType, SimpleNamespace @@ -140,6 +141,18 @@ def _set_request_json(monkeypatch, module, payload): monkeypatch.setattr(module, "get_request_json", _request_json) +def _stub_url_safety(monkeypatch, module, unsafe_urls=None): + unsafe_urls = set(unsafe_urls or []) + + def _assert_url_is_safe(url): + if url in unsafe_urls: + raise ValueError("blocked unsafe url") + return "safe.example", "93.184.216.34" + + monkeypatch.setattr(module, "assert_url_is_safe", _assert_url_is_safe) + monkeypatch.setattr(module, "pin_dns_global", lambda *_args, **_kwargs: nullcontext()) + + @pytest.fixture(scope="session") def auth(): return "unit-auth" @@ -338,10 +351,20 @@ def test_create_validation_guards(monkeypatch): res = _run(module.create.__wrapped__()) assert "Invalid url" in res["message"] + _set_request_json(monkeypatch, module, {"name": "srv", "url": 123, "server_type": "sse"}) + res = _run(module.create.__wrapped__()) + assert "Invalid url" in res["message"] + + _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://unsafe", "server_type": "sse"}) + _stub_url_safety(monkeypatch, module, {"http://unsafe"}) + res = _run(module.create.__wrapped__()) + assert "blocked unsafe url" in res["message"] + @pytest.mark.p2 def test_create_service_paths(monkeypatch): module = _load_mcp_api(monkeypatch) + _stub_url_safety(monkeypatch, module) base_payload = { "name": "srv", @@ -434,10 +457,20 @@ def test_update_validation_guards(monkeypatch): res = _run(module.update("mcp-1")) assert "Invalid url" in res["message"] + _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": {"raw": "http://a"}}) + res = _run(module.update("mcp-1")) + assert "Invalid url" in res["message"] + + _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": "http://unsafe"}) + _stub_url_safety(monkeypatch, module, {"http://unsafe"}) + res = _run(module.update("mcp-1")) + assert "blocked unsafe url" in res["message"] + @pytest.mark.p2 def test_update_service_paths(monkeypatch): module = _load_mcp_api(monkeypatch) + _stub_url_safety(monkeypatch, module) existing = _DummyMCPServer( id="mcp-1", @@ -560,6 +593,7 @@ def test_rm_failure_success_and_exception(monkeypatch): @pytest.mark.p2 def test_import_multiple_missing_servers_and_exception(monkeypatch): module = _load_mcp_api(monkeypatch) + _stub_url_safety(monkeypatch, module) _set_request_json(monkeypatch, module, {"mcpServers": {}}) res = _run(module.import_multiple.__wrapped__()) @@ -579,11 +613,15 @@ def test_import_multiple_missing_servers_and_exception(monkeypatch): @pytest.mark.p2 def test_import_multiple_mixed_results(monkeypatch): module = _load_mcp_api(monkeypatch) + _stub_url_safety(monkeypatch, module, {"http://unsafe"}) payload = { "mcpServers": { "missing_fields": {"type": "sse"}, "": {"type": "sse", "url": "http://empty"}, + "invalid_type": {"type": "invalid", "url": "http://invalid"}, + "non_string_url": {"type": "sse", "url": True}, + "unsafe": {"type": "sse", "url": "http://unsafe"}, "dup": {"type": "sse", "url": "http://dup", "authorization_token": "dup-token"}, "tool_err": {"type": "sse", "url": "http://err"}, "insert_fail": {"type": "sse", "url": "http://fail"}, @@ -624,6 +662,12 @@ def test_import_multiple_mixed_results(monkeypatch): assert "Missing required fields" in results["missing_fields"]["message"] assert results[""]["success"] is False assert "Invalid MCP name" in results[""]["message"] + assert results["invalid_type"]["success"] is False + assert "Unsupported MCP server type" in results["invalid_type"]["message"] + assert results["non_string_url"]["success"] is False + assert "Invalid url" in results["non_string_url"]["message"] + assert results["unsafe"]["success"] is False + assert "blocked unsafe url" in results["unsafe"]["message"] assert results["tool_err"]["success"] is False assert "tool call failed" in results["tool_err"]["message"] assert results["insert_fail"]["success"] is False @@ -693,11 +737,16 @@ def test_detail_download_success_and_exception(monkeypatch): @pytest.mark.p2 def test_test_mcp_route_matrix_unit(monkeypatch): module = _load_mcp_api(monkeypatch) + _stub_url_safety(monkeypatch, module) _set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"}) res = _run(module.test_mcp("mcp-1")) assert "Invalid MCP url" in res["message"] + _set_request_json(monkeypatch, module, {"url": ["http://a"], "server_type": "sse"}) + res = _run(module.test_mcp("mcp-1")) + assert "Invalid MCP url" in res["message"] + _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"}) res = _run(module.test_mcp("mcp-1")) assert "Unsupported MCP server type" in res["message"] diff --git a/test/testcases/test_web_api/test_mcp_server_app/test_mcp_server_app_unit.py b/test/testcases/test_web_api/test_mcp_server_app/test_mcp_server_app_unit.py index ac8a580c38..6fcc1fc727 100644 --- a/test/testcases/test_web_api/test_mcp_server_app/test_mcp_server_app_unit.py +++ b/test/testcases/test_web_api/test_mcp_server_app/test_mcp_server_app_unit.py @@ -18,6 +18,7 @@ import importlib.util import inspect import json import sys +from contextlib import nullcontext from functools import wraps from pathlib import Path from types import ModuleType, SimpleNamespace @@ -140,6 +141,18 @@ def _set_request_json(monkeypatch, module, payload): monkeypatch.setattr(module, "get_request_json", _request_json) +def _stub_url_safety(monkeypatch, module, unsafe_urls=None): + unsafe_urls = set(unsafe_urls or []) + + def _assert_url_is_safe(url): + if url in unsafe_urls: + raise ValueError("blocked unsafe url") + return "safe.example", "93.184.216.34" + + monkeypatch.setattr(module, "assert_url_is_safe", _assert_url_is_safe) + monkeypatch.setattr(module, "pin_dns_global", lambda *_args, **_kwargs: nullcontext()) + + @pytest.fixture(scope="session") def auth(): return "unit-auth" @@ -338,10 +351,20 @@ def test_create_validation_guards(monkeypatch): res = _run(module.create.__wrapped__()) assert "Invalid url" in res["message"] + _set_request_json(monkeypatch, module, {"name": "srv", "url": 123, "server_type": "sse"}) + res = _run(module.create.__wrapped__()) + assert "Invalid url" in res["message"] + + _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://unsafe", "server_type": "sse"}) + _stub_url_safety(monkeypatch, module, {"http://unsafe"}) + res = _run(module.create.__wrapped__()) + assert "blocked unsafe url" in res["message"] + @pytest.mark.p2 def test_create_service_paths(monkeypatch): module = _load_mcp_api(monkeypatch) + _stub_url_safety(monkeypatch, module) base_payload = { "name": "srv", @@ -434,10 +457,20 @@ def test_update_validation_guards(monkeypatch): res = _run(module.update("mcp-1")) assert "Invalid url" in res["message"] + _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": {"raw": "http://a"}}) + res = _run(module.update("mcp-1")) + assert "Invalid url" in res["message"] + + _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": "http://unsafe"}) + _stub_url_safety(monkeypatch, module, {"http://unsafe"}) + res = _run(module.update("mcp-1")) + assert "blocked unsafe url" in res["message"] + @pytest.mark.p2 def test_update_service_paths(monkeypatch): module = _load_mcp_api(monkeypatch) + _stub_url_safety(monkeypatch, module) existing = _DummyMCPServer( id="mcp-1", @@ -560,6 +593,7 @@ def test_rm_failure_success_and_exception(monkeypatch): @pytest.mark.p2 def test_import_multiple_missing_servers_and_exception(monkeypatch): module = _load_mcp_api(monkeypatch) + _stub_url_safety(monkeypatch, module) _set_request_json(monkeypatch, module, {"mcpServers": {}}) res = _run(module.import_multiple.__wrapped__()) @@ -579,11 +613,15 @@ def test_import_multiple_missing_servers_and_exception(monkeypatch): @pytest.mark.p2 def test_import_multiple_mixed_results(monkeypatch): module = _load_mcp_api(monkeypatch) + _stub_url_safety(monkeypatch, module, {"http://unsafe"}) payload = { "mcpServers": { "missing_fields": {"type": "sse"}, "": {"type": "sse", "url": "http://empty"}, + "invalid_type": {"type": "invalid", "url": "http://invalid"}, + "non_string_url": {"type": "sse", "url": True}, + "unsafe": {"type": "sse", "url": "http://unsafe"}, "dup": {"type": "sse", "url": "http://dup", "authorization_token": "dup-token"}, "tool_err": {"type": "sse", "url": "http://err"}, "insert_fail": {"type": "sse", "url": "http://fail"}, @@ -624,6 +662,12 @@ def test_import_multiple_mixed_results(monkeypatch): assert "Missing required fields" in results["missing_fields"]["message"] assert results[""]["success"] is False assert "Invalid MCP name" in results[""]["message"] + assert results["invalid_type"]["success"] is False + assert "Unsupported MCP server type" in results["invalid_type"]["message"] + assert results["non_string_url"]["success"] is False + assert "Invalid url" in results["non_string_url"]["message"] + assert results["unsafe"]["success"] is False + assert "blocked unsafe url" in results["unsafe"]["message"] assert results["tool_err"]["success"] is False assert "tool call failed" in results["tool_err"]["message"] assert results["insert_fail"]["success"] is False @@ -693,11 +737,16 @@ def test_detail_download_success_and_exception(monkeypatch): @pytest.mark.p2 def test_test_mcp_route_matrix_unit(monkeypatch): module = _load_mcp_api(monkeypatch) + _stub_url_safety(monkeypatch, module) _set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"}) res = _run(module.test_mcp("mcp-1")) assert "Invalid MCP url" in res["message"] + _set_request_json(monkeypatch, module, {"url": ["http://a"], "server_type": "sse"}) + res = _run(module.test_mcp("mcp-1")) + assert "Invalid MCP url" in res["message"] + _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"}) res = _run(module.test_mcp("mcp-1")) assert "Unsupported MCP server type" in res["message"] diff --git a/test/unit_test/deepdoc/parser/test_mineru_parser.py b/test/unit_test/deepdoc/parser/test_mineru_parser.py new file mode 100644 index 0000000000..9e624e51b1 --- /dev/null +++ b/test/unit_test/deepdoc/parser/test_mineru_parser.py @@ -0,0 +1,68 @@ +import importlib.util +import logging +import sys +from pathlib import Path +from types import ModuleType + + +def _load_mineru_parser(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + + deepdoc_mod = ModuleType("deepdoc") + deepdoc_mod.__path__ = [str(repo_root / "deepdoc")] + monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_mod) + + parser_mod = ModuleType("deepdoc.parser") + parser_mod.__path__ = [str(repo_root / "deepdoc" / "parser")] + monkeypatch.setitem(sys.modules, "deepdoc.parser", parser_mod) + + pdf_parser_mod = ModuleType("deepdoc.parser.pdf_parser") + + class _RAGFlowPdfParser: + pass + + pdf_parser_mod.RAGFlowPdfParser = _RAGFlowPdfParser + monkeypatch.setitem(sys.modules, "deepdoc.parser.pdf_parser", pdf_parser_mod) + + utils_mod = ModuleType("deepdoc.parser.utils") + utils_mod.extract_pdf_outlines = lambda *_args, **_kwargs: [] + monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", utils_mod) + + module_name = "test_mineru_parser_unit_module" + module_path = repo_root / "deepdoc" / "parser" / "mineru_parser.py" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +def test_sanitize_section_text_removes_escaped_html_tags(monkeypatch): + module = _load_mineru_parser(monkeypatch) + text = "<table><tr><td>Alpha</td><td>Beta</td></tr></table>" + + sanitized = module.MinerUParser._sanitize_section_text(text) + + assert sanitized == "AlphaBeta" + assert "" not in sanitized + assert "" not in sanitized + + +def test_transfer_to_sections_logs_sections_dropped_after_sanitization(monkeypatch, caplog): + module = _load_mineru_parser(monkeypatch) + parser = module.MinerUParser() + outputs = [ + { + "type": module.MinerUContentType.TEXT, + "text": "<td></td>", + "page_idx": 0, + "bbox": (0, 0, 1, 1), + } + ] + + with caplog.at_level(logging.DEBUG, logger=parser.logger.name): + sections = parser._transfer_to_sections(outputs, parse_method="pipeline") + + assert sections == [] + assert "Skip section after sanitization" in caplog.text + assert f"type={module.MinerUContentType.TEXT}" in caplog.text