Refa: migrate MCP APIs to RESTful api (#14317)

### What problem does this PR solve?

migrate MCP APIs to RESTful api

### Type of change

- [x] Refactoring
This commit is contained in:
buua436
2026-04-23 12:51:27 +08:00
committed by GitHub
parent dbf8c6ed90
commit aa4526266f
6 changed files with 481 additions and 317 deletions

View File

@@ -0,0 +1,331 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from quart import Response, request
from api.apps import current_user, login_required
from api.db.db_models import MCPServer
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
from api.utils.web_utils import get_float, safe_json_parse
from common.constants import VALID_MCP_SERVER_TYPES
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from common.misc_utils import get_uuid, thread_pool_exec
def _get_mcp_ids_from_args() -> list[str]:
mcp_ids = request.args.getlist("mcp_ids")
if mcp_ids:
return [mcp_id for item in mcp_ids for mcp_id in item.split(",") if mcp_id]
mcp_ids = request.args.get("mcp_id", "")
return [mcp_id for mcp_id in mcp_ids.split(",") if mcp_id]
def _export_mcp_servers(mcp_ids: list[str]) -> dict | None:
exported_servers = {}
for mcp_id in mcp_ids:
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if e and mcp_server.tenant_id == current_user.id:
server_key = mcp_server.name
exported_servers[server_key] = {
"type": mcp_server.server_type,
"url": mcp_server.url,
"name": mcp_server.name,
"authorization_token": mcp_server.variables.get("authorization_token", ""),
"tools": mcp_server.variables.get("tools", {}),
}
if not exported_servers:
return None
return {"mcpServers": exported_servers}
@manager.route("/mcp/servers", methods=["GET"]) # noqa: F821
@login_required
async def list_mcp() -> Response:
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
mcp_ids = _get_mcp_ids_from_args()
try:
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
total = len(servers)
if page_number and items_per_page:
servers = servers[(page_number - 1) * items_per_page : page_number * items_per_page]
return get_json_result(data={"mcp_servers": servers, "total": total})
except Exception as e:
return server_error_response(e)
@manager.route("/mcp/servers/<mcp_id>", methods=["GET"]) # noqa: F821
@login_required
def detail(mcp_id: str) -> Response:
try:
if request.args.get("mode") == "download":
exported_servers = _export_mcp_servers([mcp_id])
if exported_servers is None:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
return get_json_result(data=exported_servers)
mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id)
if mcp_server is None:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
return get_json_result(data=mcp_server.to_dict())
except Exception as e:
return server_error_response(e)
@manager.route("/mcp/servers", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "url", "server_type")
async def create() -> Response:
req = await get_request_json()
server_type = req.get("server_type", "")
if server_type not in VALID_MCP_SERVER_TYPES:
return get_data_error_result(message="Unsupported MCP server type.")
server_name = req.get("name", "")
if not server_name or len(server_name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.")
e, _ = MCPServerService.get_by_name_and_tenant(name=server_name, tenant_id=current_user.id)
if e:
return get_data_error_result(message="Duplicated MCP server name.")
url = req.get("url", "")
if not url:
return get_data_error_result(message="Invalid url.")
headers = safe_json_parse(req.get("headers", {}))
req["headers"] = headers
variables = safe_json_parse(req.get("variables", {}))
variables.pop("tools", None)
timeout = get_float(req, "timeout", 10)
try:
req["id"] = get_uuid()
req["tenant_id"] = current_user.id
e, _ = TenantService.get_by_id(current_user.id)
if not e:
return get_data_error_result(message="Tenant not found.")
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(message=err_message)
tools = server_tools[server_name]
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
variables["tools"] = tools
req["variables"] = variables
if not MCPServerService.insert(**req):
return get_data_error_result(message="Failed to create MCP server.")
return get_json_result(data=req)
except Exception as e:
return server_error_response(e)
@manager.route("/mcp/servers/<mcp_id>", methods=["PUT"]) # noqa: F821
@login_required
async def update(mcp_id: str) -> Response:
req = await get_request_json()
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
server_type = req.get("server_type", mcp_server.server_type)
if server_type and server_type not in VALID_MCP_SERVER_TYPES:
return get_data_error_result(message="Unsupported MCP server type.")
server_name = req.get("name", mcp_server.name)
if server_name and len(server_name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.")
url = req.get("url", mcp_server.url)
if not url:
return get_data_error_result(message="Invalid url.")
headers = safe_json_parse(req.get("headers", mcp_server.headers))
req["headers"] = headers
variables = safe_json_parse(req.get("variables", mcp_server.variables))
variables.pop("tools", None)
timeout = get_float(req, "timeout", 10)
try:
req["tenant_id"] = current_user.id
req["id"] = mcp_id
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(message=err_message)
tools = server_tools[server_name]
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
variables["tools"] = tools
req["variables"] = variables
if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req):
return get_data_error_result(message="Failed to updated MCP server.")
e, updated_mcp = MCPServerService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="Failed to fetch updated MCP server.")
return get_json_result(data=updated_mcp.to_dict())
except Exception as e:
return server_error_response(e)
@manager.route("/mcp/servers/<mcp_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def rm(mcp_id: str) -> Response:
try:
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
if not MCPServerService.delete_by_ids([mcp_id]):
return get_data_error_result(message=f"Failed to delete MCP servers {[mcp_id]}")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route("/mcp/servers/import", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcpServers")
async def import_multiple() -> Response:
req = await get_request_json()
servers = req.get("mcpServers", {})
if not servers:
return get_data_error_result(message="No MCP servers provided.")
timeout = get_float(req, "timeout", 10)
results = []
try:
for server_name, config in servers.items():
if not all(key in config for key in {"type", "url"}):
results.append({"server": server_name, "success": False, "message": "Missing required fields (type or url)"})
continue
if not server_name or len(server_name.encode("utf-8")) > 255:
results.append({"server": server_name, "success": False, "message": f"Invalid MCP name or length is {len(server_name)} which is large than 255."})
continue
base_name = server_name
new_name = base_name
counter = 0
while True:
e, _ = MCPServerService.get_by_name_and_tenant(name=new_name, tenant_id=current_user.id)
if not e:
break
new_name = f"{base_name}_{counter}"
counter += 1
create_data = {
"id": get_uuid(),
"tenant_id": current_user.id,
"name": new_name,
"url": config["url"],
"server_type": config["type"],
"variables": {"authorization_token": config.get("authorization_token", "")},
}
headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
results.append({"server": base_name, "success": False, "message": err_message})
continue
tools = server_tools[new_name]
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
create_data["variables"]["tools"] = tools
if MCPServerService.insert(**create_data):
result = {"server": server_name, "success": True, "action": "created", "id": create_data["id"], "new_name": new_name}
if new_name != base_name:
result["message"] = f"Renamed from '{base_name}' to '{new_name}' avoid duplication"
results.append(result)
else:
results.append({"server": server_name, "success": False, "message": "Failed to create MCP server."})
return get_json_result(data={"results": results})
except Exception as e:
return server_error_response(e)
@manager.route("/mcp/servers/<mcp_id>/test", methods=["POST"]) # noqa: F821
@login_required
@validate_request("url", "server_type")
async def test_mcp(mcp_id: str) -> Response:
req = await get_request_json()
url = req.get("url", "")
if not url:
return get_data_error_result(message="Invalid MCP url.")
server_type = req.get("server_type", "")
if server_type not in VALID_MCP_SERVER_TYPES:
return get_data_error_result(message="Unsupported MCP server type.")
timeout = get_float(req, "timeout", 10)
headers = safe_json_parse(req.get("headers", {}))
variables = safe_json_parse(req.get("variables", {}))
mcp_server = MCPServer(id=mcp_id, server_type=server_type, url=url, headers=headers, variables=variables)
result = []
try:
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
try:
tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
except Exception as e:
return get_data_error_result(message=f"Test MCP error: {e}")
finally:
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])
for tool in tools:
tool_dict = tool.model_dump()
tool_dict["enabled"] = True
result.append(tool_dict)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)

View File

@@ -33,6 +33,14 @@ class _DummyManager:
return decorator
class _Args(dict):
def getlist(self, key):
value = self.get(key, [])
if isinstance(value, list):
return value
return [value]
class _Field:
def __init__(self, name):
self.name = name
@@ -142,13 +150,22 @@ def set_tenant_info():
return None
def _load_mcp_server_app(monkeypatch):
def _load_mcp_api(monkeypatch):
repo_root = Path(__file__).resolve().parents[4]
quart_mod = ModuleType("quart")
quart_mod.Response = object
quart_mod.request = SimpleNamespace(args=_Args({}))
monkeypatch.setitem(sys.modules, "quart", quart_mod)
common_pkg = ModuleType("common")
common_pkg.__path__ = [str(repo_root / "common")]
monkeypatch.setitem(sys.modules, "common", common_pkg)
constants_mod = ModuleType("common.constants")
constants_mod.VALID_MCP_SERVER_TYPES = {"sse", "streamable-http"}
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
apps_mod = ModuleType("api.apps")
apps_mod.current_user = SimpleNamespace(id="tenant_1")
apps_mod.login_required = lambda func: func
@@ -230,8 +247,8 @@ def _load_mcp_server_app(monkeypatch):
web_utils_mod.safe_json_parse = _safe_json_parse
monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
module_name = "test_mcp_server_app_unit_module"
module_path = repo_root / "api" / "apps" / "mcp_server_app.py"
module_name = "test_mcp_api_unit_module"
module_path = repo_root / "api" / "apps" / "restful_apis" / "mcp_api.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
module.manager = _DummyManager()
@@ -242,12 +259,12 @@ def _load_mcp_server_app(monkeypatch):
@pytest.mark.p2
def test_list_mcp_desc_pagination_and_exception(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
module = _load_mcp_api(monkeypatch)
monkeypatch.setattr(
module,
"request",
SimpleNamespace(args={"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"}),
SimpleNamespace(args=_Args({"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"})),
)
_set_request_json(monkeypatch, module, {"mcp_ids": []})
monkeypatch.setattr(module.MCPServerService, "get_servers", lambda *_args, **_kwargs: [{"id": "a"}, {"id": "b"}])
@@ -257,7 +274,7 @@ def test_list_mcp_desc_pagination_and_exception(monkeypatch):
assert res["data"]["total"] == 2
assert res["data"]["mcp_servers"] == [{"id": "b"}]
monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({})))
_set_request_json(monkeypatch, module, {"mcp_ids": []})
def _raise_list(*_args, **_kwargs):
@@ -271,19 +288,20 @@ def test_list_mcp_desc_pagination_and_exception(monkeypatch):
@pytest.mark.p2
def test_detail_not_found_success_and_exception(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
monkeypatch.setattr(module, "request", SimpleNamespace(args={"mcp_id": "mcp-1"}))
module = _load_mcp_api(monkeypatch)
monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({})))
monkeypatch.setattr(module.MCPServerService, "get_or_none", lambda **_kwargs: None)
res = module.detail()
assert res["code"] == module.RetCode.NOT_FOUND
res = module.detail("mcp-1")
assert res["code"] == 102
assert "Cannot find MCP server mcp-1 for user tenant_1" in res["message"]
monkeypatch.setattr(
module.MCPServerService,
"get_or_none",
lambda **_kwargs: _DummyMCPServer(id="mcp-1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1"),
)
res = module.detail()
res = module.detail("mcp-1")
assert res["code"] == 0
assert res["data"]["id"] == "mcp-1"
@@ -291,14 +309,14 @@ def test_detail_not_found_success_and_exception(monkeypatch):
raise RuntimeError("detail explode")
monkeypatch.setattr(module.MCPServerService, "get_or_none", _raise_detail)
res = module.detail()
res = module.detail("mcp-1")
assert res["code"] == 100
assert "detail explode" in res["message"]
@pytest.mark.p2
def test_create_validation_guards(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
module = _load_mcp_api(monkeypatch)
monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
@@ -323,7 +341,7 @@ def test_create_validation_guards(monkeypatch):
@pytest.mark.p2
def test_create_service_paths(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
module = _load_mcp_api(monkeypatch)
base_payload = {
"name": "srv",
@@ -350,8 +368,8 @@ def test_create_service_paths(monkeypatch):
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
res = _run(module.create.__wrapped__())
assert res["code"] == "tools error"
assert "Sorry! Data missing!" in res["message"]
assert res["code"] == 102
assert "tools error" in res["message"]
_set_request_json(monkeypatch, module, dict(base_payload))
@@ -361,8 +379,8 @@ def test_create_service_paths(monkeypatch):
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: False)
res = _run(module.create.__wrapped__())
assert res["code"] == "Failed to create MCP server."
assert "Sorry! Data missing!" in res["message"]
assert res["code"] == 102
assert "Failed to create MCP server" in res["message"]
_set_request_json(monkeypatch, module, dict(base_payload))
monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: True)
@@ -385,13 +403,13 @@ def test_create_service_paths(monkeypatch):
@pytest.mark.p2
def test_update_validation_guards(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
module = _load_mcp_api(monkeypatch)
existing = _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="tenant_1", variables={}, headers={})
_set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
res = _run(module.update.__wrapped__())
res = _run(module.update("mcp-1"))
assert "Cannot find MCP server" in res["message"]
_set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
@@ -400,26 +418,26 @@ def test_update_validation_guards(monkeypatch):
"get_by_id",
lambda _mcp_id: (True, _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="other", variables={}, headers={})),
)
res = _run(module.update.__wrapped__())
res = _run(module.update("mcp-1"))
assert "Cannot find MCP server" in res["message"]
_set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "server_type": "invalid"})
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
res = _run(module.update.__wrapped__())
res = _run(module.update("mcp-1"))
assert "Unsupported MCP server type" in res["message"]
_set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "name": "a" * 256})
res = _run(module.update.__wrapped__())
res = _run(module.update("mcp-1"))
assert "Invalid MCP name" in res["message"]
_set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": ""})
res = _run(module.update.__wrapped__())
res = _run(module.update("mcp-1"))
assert "Invalid url" in res["message"]
@pytest.mark.p2
def test_update_service_paths(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
module = _load_mcp_api(monkeypatch)
existing = _DummyMCPServer(
id="mcp-1",
@@ -457,9 +475,9 @@ def test_update_service_paths(monkeypatch):
return None, "update tools error"
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
res = _run(module.update.__wrapped__())
assert res["code"] == "update tools error"
assert "Sorry! Data missing!" in res["message"]
res = _run(module.update("mcp-1"))
assert res["code"] == 102
assert "update tools error" in res["message"]
_set_request_json(monkeypatch, module, dict(base_payload))
@@ -468,7 +486,7 @@ def test_update_service_paths(monkeypatch):
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False)
res = _run(module.update.__wrapped__())
res = _run(module.update("mcp-1"))
assert "Failed to updated MCP server" in res["message"]
_set_request_json(monkeypatch, module, dict(base_payload))
@@ -482,7 +500,7 @@ def test_update_service_paths(monkeypatch):
_get_by_id_fetch_fail.calls = 0
monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_fetch_fail)
res = _run(module.update.__wrapped__())
res = _run(module.update("mcp-1"))
assert "Failed to fetch updated MCP server" in res["message"]
_set_request_json(monkeypatch, module, dict(base_payload))
@@ -495,7 +513,7 @@ def test_update_service_paths(monkeypatch):
_get_by_id_success.calls = 0
monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_success)
res = _run(module.update.__wrapped__())
res = _run(module.update("mcp-1"))
assert res["code"] == 0
assert res["data"]["id"] == "mcp-1"
@@ -506,23 +524,25 @@ def test_update_service_paths(monkeypatch):
raise RuntimeError("update explode")
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
res = _run(module.update.__wrapped__())
res = _run(module.update("mcp-1"))
assert res["code"] == 100
assert "update explode" in res["message"]
@pytest.mark.p2
def test_rm_failure_success_and_exception(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
module = _load_mcp_api(monkeypatch)
server = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server))
_set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: False)
res = _run(module.rm.__wrapped__())
res = _run(module.rm("id1"))
assert "Failed to delete MCP servers" in res["message"]
_set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: True)
res = _run(module.rm.__wrapped__())
res = _run(module.rm("id1"))
assert res["code"] == 0
assert res["data"] is True
@@ -532,14 +552,14 @@ def test_rm_failure_success_and_exception(monkeypatch):
raise RuntimeError("rm explode")
monkeypatch.setattr(module.MCPServerService, "delete_by_ids", _raise_rm)
res = _run(module.rm.__wrapped__())
res = _run(module.rm("id1"))
assert res["code"] == 100
assert "rm explode" in res["message"]
@pytest.mark.p2
def test_import_multiple_missing_servers_and_exception(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
module = _load_mcp_api(monkeypatch)
_set_request_json(monkeypatch, module, {"mcpServers": {}})
res = _run(module.import_multiple.__wrapped__())
@@ -558,7 +578,7 @@ def test_import_multiple_missing_servers_and_exception(monkeypatch):
@pytest.mark.p2
def test_import_multiple_mixed_results(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
module = _load_mcp_api(monkeypatch)
payload = {
"mcpServers": {
@@ -614,244 +634,72 @@ def test_import_multiple_mixed_results(monkeypatch):
@pytest.mark.p2
def test_export_multiple_missing_ids_success_and_exception(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
def test_detail_download_success_and_exception(monkeypatch):
module = _load_mcp_api(monkeypatch)
monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({"mode": "download"})))
_set_request_json(monkeypatch, module, {"mcp_ids": []})
res = _run(module.export_multiple.__wrapped__())
assert "No MCP server IDs provided" in res["message"]
_set_request_json(monkeypatch, module, {"mcp_ids": ["id1", "id2", "id3"]})
def _get_by_id(mcp_id):
if mcp_id == "id1":
return True, _DummyMCPServer(
monkeypatch.setattr(
module.MCPServerService,
"get_by_id",
lambda _mcp_id: (
True,
_DummyMCPServer(
id="id1",
name="srv-one",
url="http://one",
server_type="sse",
tenant_id="tenant_1",
variables={"authorization_token": "tok", "tools": {"tool_a": {"enabled": True}}},
)
if mcp_id == "id2":
return True, _DummyMCPServer(
),
),
)
res = module.detail("id1")
assert res["code"] == 0
assert list(res["data"]["mcpServers"].keys()) == ["srv-one"]
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
res = module.detail("missing")
assert res["code"] == 102
assert "Cannot find MCP server missing for user tenant_1" in res["message"]
monkeypatch.setattr(
module.MCPServerService,
"get_by_id",
lambda _mcp_id: (
True,
_DummyMCPServer(
id="id2",
name="srv-two",
url="http://two",
server_type="sse",
tenant_id="other",
variables={},
)
return False, None
monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id)
res = _run(module.export_multiple.__wrapped__())
assert res["code"] == 0
assert list(res["data"]["mcpServers"].keys()) == ["srv-one"]
_set_request_json(monkeypatch, module, {"mcp_ids": ["id1"]})
),
),
)
res = module.detail("id2")
assert res["code"] == 102
assert "Cannot find MCP server id2 for user tenant_1" in res["message"]
def _raise_export(_mcp_id):
raise RuntimeError("export explode")
monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_export)
res = _run(module.export_multiple.__wrapped__())
res = module.detail("id1")
assert res["code"] == 100
assert "export explode" in res["message"]
@pytest.mark.p2
def test_list_tools_missing_ids_success_inner_error_outer_error_and_finally_cleanup(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
_set_request_json(monkeypatch, module, {"mcp_ids": []})
res = _run(module.list_tools.__wrapped__())
assert "No MCP server IDs provided" in res["message"]
server = _DummyMCPServer(
id="id1",
name="srv-tools",
url="http://tools",
server_type="sse",
tenant_id="tenant_1",
variables={"tools": {"tool_a": {"enabled": False}}},
)
_set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"})
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server))
close_calls = []
async def _thread_pool_exec_success(func, *args):
if func is module.close_multiple_mcp_toolcall_sessions:
close_calls.append(args[0])
return None
return func(*args)
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
res = _run(module.list_tools.__wrapped__())
assert res["code"] == 0
assert res["data"]["id1"][0]["name"] == "tool_a"
assert res["data"]["id1"][0]["enabled"] is False
assert res["data"]["id1"][1]["enabled"] is True
assert close_calls and len(close_calls[-1]) == 1
_set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"})
close_calls_inner = []
async def _thread_pool_exec_inner_error(func, *args):
if func is module.close_multiple_mcp_toolcall_sessions:
close_calls_inner.append(args[0])
return None
raise RuntimeError("inner tools explode")
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error)
res = _run(module.list_tools.__wrapped__())
assert res["code"] == 102
assert "MCP list tools error" in res["message"]
assert close_calls_inner and len(close_calls_inner[-1]) == 1
_set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"})
close_calls_outer = []
def _raise_get_by_id(_mcp_id):
raise RuntimeError("outer explode")
monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_get_by_id)
async def _thread_pool_exec_outer(func, *args):
if func is module.close_multiple_mcp_toolcall_sessions:
close_calls_outer.append(args[0])
return None
return func(*args)
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_outer)
res = _run(module.list_tools.__wrapped__())
assert res["code"] == 100
assert "outer explode" in res["message"]
assert close_calls_outer
@pytest.mark.p2
def test_test_tool_missing_mcp_id(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
_set_request_json(monkeypatch, module, {"mcp_id": "", "tool_name": "tool_a", "arguments": {"x": 1}})
res = _run(module.test_tool.__wrapped__())
assert "No MCP server ID provided" in res["message"]
@pytest.mark.p2
def test_test_tool_route_matrix_unit(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
_set_request_json(monkeypatch, module, {"mcp_id": "", "tool_name": "tool_a", "arguments": {"x": 1}})
res = _run(module.test_tool.__wrapped__())
assert "No MCP server ID provided" in res["message"]
_set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "", "arguments": {"x": 1}})
res = _run(module.test_tool.__wrapped__())
assert "Require provide tool name and arguments" in res["message"]
_set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "tool_a", "arguments": {}})
res = _run(module.test_tool.__wrapped__())
assert "Require provide tool name and arguments" in res["message"]
_set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "tool_a", "arguments": {"x": 1}})
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
res = _run(module.test_tool.__wrapped__())
assert "Cannot find MCP server id1 for user tenant_1" in res["message"]
server_other = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="other", variables={})
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_other))
res = _run(module.test_tool.__wrapped__())
assert "Cannot find MCP server id1 for user tenant_1" in res["message"]
server_ok = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_ok))
close_calls = []
async def _thread_pool_exec_success(func, *args):
if func is module.close_multiple_mcp_toolcall_sessions:
close_calls.append(args[0])
return None
return func(*args)
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
res = _run(module.test_tool.__wrapped__())
assert res["code"] == 0
assert res["data"] == "ok"
assert close_calls and len(close_calls[-1]) == 1
async def _thread_pool_exec_raise(func, *args):
if func is module.close_multiple_mcp_toolcall_sessions:
return None
raise RuntimeError("tool call explode")
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_raise)
res = _run(module.test_tool.__wrapped__())
assert res["code"] == 100
assert "tool call explode" in res["message"]
@pytest.mark.p2
def test_cache_tool_route_matrix_unit(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
_set_request_json(monkeypatch, module, {"mcp_id": "", "tools": [{"name": "tool_a"}]})
res = _run(module.cache_tool.__wrapped__())
assert "No MCP server ID provided" in res["message"]
_set_request_json(monkeypatch, module, {"mcp_id": "id1", "tools": [{"name": "tool_a"}]})
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
res = _run(module.cache_tool.__wrapped__())
assert "Cannot find MCP server id1 for user tenant_1" in res["message"]
server_other = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="other", variables={})
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_other))
res = _run(module.cache_tool.__wrapped__())
assert "Cannot find MCP server id1 for user tenant_1" in res["message"]
server_fail = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_fail))
monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False)
res = _run(module.cache_tool.__wrapped__())
assert "Failed to updated MCP server" in res["message"]
server_ok = _DummyMCPServer(
id="id1",
name="srv",
url="http://a",
server_type="sse",
tenant_id="tenant_1",
variables={"tools": {"old_tool": {"name": "old_tool"}}},
)
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_ok))
monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True)
_set_request_json(
monkeypatch,
module,
{
"mcp_id": "id1",
"tools": [{"name": "tool_a", "enabled": True}, {"bad": 1}, "x", {"name": "tool_b", "enabled": False}],
},
)
res = _run(module.cache_tool.__wrapped__())
assert res["code"] == 0
assert sorted(res["data"].keys()) == ["tool_a", "tool_b"]
assert server_ok.variables["tools"]["tool_b"]["enabled"] is False
@pytest.mark.p2
def test_test_mcp_route_matrix_unit(monkeypatch):
module = _load_mcp_server_app(monkeypatch)
module = _load_mcp_api(monkeypatch)
_set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"})
res = _run(module.test_mcp.__wrapped__())
res = _run(module.test_mcp("mcp-1"))
assert "Invalid MCP url" in res["message"]
_set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"})
res = _run(module.test_mcp.__wrapped__())
res = _run(module.test_mcp("mcp-1"))
assert "Unsupported MCP server type" in res["message"]
close_calls = []
@@ -866,7 +714,7 @@ def test_test_mcp_route_matrix_unit(monkeypatch):
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error)
_set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
res = _run(module.test_mcp.__wrapped__())
res = _run(module.test_mcp("mcp-1"))
assert res["code"] == 102
assert "Test MCP error: get tools explode" in res["message"]
assert close_calls and len(close_calls[-1]) == 1
@@ -881,7 +729,7 @@ def test_test_mcp_route_matrix_unit(monkeypatch):
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
_set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
res = _run(module.test_mcp.__wrapped__())
res = _run(module.test_mcp("mcp-1"))
assert res["code"] == 0
assert res["data"][0]["name"] == "tool_a"
assert all(tool["enabled"] is True for tool in res["data"])
@@ -892,6 +740,6 @@ def test_test_mcp_route_matrix_unit(monkeypatch):
monkeypatch.setattr(module, "MCPToolCallSession", _raise_session)
_set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
res = _run(module.test_mcp.__wrapped__())
res = _run(module.test_mcp("mcp-1"))
assert res["code"] == 100
assert "session explode" in res["message"]

View File

@@ -141,8 +141,12 @@ export const useDeleteMcpServer = () => {
} = useMutation({
mutationKey: [McpApiAction.DeleteMcpServer],
mutationFn: async (ids: string[]) => {
const { data = {} } = await mcpServerService.delete({ mcp_ids: ids });
if (data.code === 0) {
const results = await Promise.all(
ids.map((id) => mcpServerService.delete({ mcp_id: id })),
);
const failed = results.find(({ data = {} }) => data.code !== 0);
const data = failed?.data ?? { code: 0, data: true };
if (!failed) {
message.success(i18n.t(`message.deleted`));
queryClient.invalidateQueries({
@@ -188,8 +192,23 @@ export const useExportMcpServer = () => {
} = useMutation<ResponseType<IExportedMcpServers>, Error, string[]>({
mutationKey: [McpApiAction.ExportMcpServer],
mutationFn: async (ids) => {
const { data = {} } = await mcpServerService.export({ mcp_ids: ids });
if (data.code === 0) {
const results = await Promise.all(
ids.map((id) => mcpServerService.export({ mcp_id: id })),
);
const failed = results.find(({ data = {} }) => data.code !== 0);
const data = (failed?.data ?? {
code: 0,
data: results.reduce<IExportedMcpServers>(
(acc, result) => ({
mcpServers: {
...acc.mcpServers,
...(result.data?.data?.mcpServers ?? {}),
},
}),
{ mcpServers: {} },
),
}) as ResponseType<IExportedMcpServers>;
if (!failed) {
message.success(i18n.t(`message.operated`));
}
return data;

View File

@@ -43,12 +43,7 @@ interface ISymbol {
}
export interface IExportedMcpServers {
mcpServers: McpServers;
}
interface McpServers {
fetch_2: IExportedMcpServer;
github_1: IExportedMcpServer;
mcpServers: Record<string, IExportedMcpServer>;
}
export interface IExportedMcpServer {

View File

@@ -1,57 +1,27 @@
import { IPaginationRequestBody } from '@/interfaces/request/base';
import api from '@/utils/api';
import registerServer from '@/utils/register-server';
import request from '@/utils/request';
const {
listMcpServer,
createMcpServer,
updateMcpServer,
deleteMcpServer,
getMcpServer,
importMcpServer,
exportMcpServer,
testMcpServer,
} = api;
const methods = {
list: {
url: listMcpServer,
method: 'post',
},
get: {
url: getMcpServer,
method: 'get',
},
create: {
url: createMcpServer,
method: 'post',
},
update: {
url: updateMcpServer,
method: 'post',
},
delete: {
url: deleteMcpServer,
method: 'post',
},
import: {
url: importMcpServer,
method: 'post',
},
export: {
url: exportMcpServer,
method: 'post',
},
test: {
url: testMcpServer,
method: 'post',
},
} as const;
const mcpServerService = registerServer<keyof typeof methods>(methods, request);
const mcpServerService = {
get: (params: { mcp_id: string }) =>
request.get(api.getMcpServer(params.mcp_id), {
params: { mode: 'preview' },
}),
create: (params?: Record<string, any>) =>
request.post(api.createMcpServer, { data: params }),
update: ({ mcp_id, ...params }: Record<string, any>) =>
request.put(api.updateMcpServer(mcp_id), { data: params }),
delete: ({ mcp_id }: { mcp_id: string }) =>
request.delete(api.deleteMcpServer(mcp_id)),
import: (params?: Record<string, any>) =>
request.post(api.importMcpServer, { data: params }),
export: ({ mcp_id }: { mcp_id: string }) =>
request.get(api.exportMcpServer(mcp_id)),
test: (params: Record<string, any>) =>
request.post(api.testMcpServer(params.name || 'preview'), { data: params }),
};
export default mcpServerService;
export const listMcpServers = (params?: IPaginationRequestBody, body?: any) =>
request.post(api.listMcpServer, { data: body || {}, params });
request.get(api.listMcpServer, { params: { ...params, ...(body || {}) } });

View File

@@ -220,14 +220,15 @@ export default {
`${webAPI}/canvas/${canvasId}/completion`,
// mcp server
listMcpServer: `${webAPI}/mcp_server/list`,
getMcpServer: `${webAPI}/mcp_server/detail`,
createMcpServer: `${webAPI}/mcp_server/create`,
updateMcpServer: `${webAPI}/mcp_server/update`,
deleteMcpServer: `${webAPI}/mcp_server/rm`,
importMcpServer: `${webAPI}/mcp_server/import`,
exportMcpServer: `${webAPI}/mcp_server/export`,
testMcpServer: `${webAPI}/mcp_server/test_mcp`,
listMcpServer: `${restAPIv1}/mcp/servers`,
getMcpServer: (id: string) => `${restAPIv1}/mcp/servers/${id}`,
createMcpServer: `${restAPIv1}/mcp/servers`,
updateMcpServer: (id: string) => `${restAPIv1}/mcp/servers/${id}`,
deleteMcpServer: (id: string) => `${restAPIv1}/mcp/servers/${id}`,
importMcpServer: `${restAPIv1}/mcp/servers/import`,
exportMcpServer: (id: string) =>
`${restAPIv1}/mcp/servers/${id}?mode=download`,
testMcpServer: (id: string) => `${restAPIv1}/mcp/servers/${id}/test`,
// next-search
createSearch: `${restAPIv1}/searches`,