feat(go-api): implement MCP server management endpoints (#15281)

## Summary

Ports the MCP (Model Context Protocol) server management endpoints that
power `web/src/pages/user-setting/mcp/` from Python
(`api/apps/restful_apis/mcp_api.py`) to Go. There were no MCP routes in
the Go server before this change.

Closes #15275 (subtask of #15240).

## Endpoints implemented (base path `/api/v1`)

| Method | Path | Description |
|--------|------|-------------|
| GET | `/mcp/servers` | List tenant servers (keyword / order /
pagination) |
| POST | `/mcp/servers` | Create a server |
| GET | `/mcp/servers/{mcp_id}` | Get one (`?mode=download` exports
config) |
| PUT | `/mcp/servers/{mcp_id}` | Update a server |
| DELETE | `/mcp/servers/{mcp_id}` | Delete a server |
| POST | `/mcp/import` | Bulk import from JSON config |
| POST | `/mcp/servers/{mcp_id}/test` | Connect + list tools (see notes)
|

## Implementation

Follows the existing `handler → service → dao` layering (per PR #14790):

- **entity** (`internal/entity/mcp.go`): added `MCPServerType` constants
and `IsValidMCPServerType` over the existing `MCPServer` model.
- **dao** (`internal/dao/mcp.go`): new `MCPServerDAO` with tenant-scoped
CRUD, a keyword filter, and a **whitelisted order-column map** (guards
against SQL injection via the caller-supplied `orderby`).
- **service** (`internal/service/mcp.go`): new `MCPService` —
list/get/export/create/update/delete/import/test — mirroring
`MCPServerService` and the `mcp_api` request validation, with sentinel
errors for clean code mapping.
- **handler** (`internal/handler/mcp.go`): new `MCPHandler` with the
seven handlers and Python-compatible response codes.
- **router / server_main**: registered the `/mcp` group and wired the
handler.

## Deviations from Python (documented in code)

1. **Bulk import is at `POST /mcp/import`, not `/mcp/servers/import`.**
gin (v1.9.1) cannot register a static segment and a path param at the
same tree node, so `/mcp/servers/import` would collide with
`/mcp/servers/:mcp_id` and panic at startup. The frontend should call
`/mcp/import`.
2. **No live tool discovery on create/update/import.** The Python path
runs `get_mcp_tools` over SSE / streamable-HTTP and stores
`variables.tools`. The Go server has no MCP client yet, so these persist
`variables`/`headers` but leave `variables.tools` unpopulated.
3. **`/test` returns a data error (`ErrMCPTestUnsupported`)** until a Go
MCP client lands. Per the issue, the live-connection path is scoped as a
follow-up; the handler still validates `url` + `server_type`.

## Testing

- Added `internal/service/mcp_test.go` covering `IsValidMCPServerType`
and the `TestServer` validation/short-circuit paths (no DB required).
- No Go toolchain was available in the dev environment, so `go build
./...` / `go vet ./...` verification is left to CI.

## Follow-ups

- Go MCP client (SSE / streamable-HTTP) to enable live tool discovery
and the real `/test` behavior.
- Reconcile the `/mcp/import` vs `/mcp/servers/import` path with the
frontend.

---------
This commit is contained in:
web-dev0521
2026-06-04 23:25:09 -06:00
committed by GitHub
parent 1d7e45115b
commit b8db200757
8 changed files with 1885 additions and 8 deletions

View File

@@ -17,7 +17,10 @@
package handler
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
@@ -185,6 +188,203 @@ func (h *MCPHandler) DeleteMCPServer(c *gin.Context) {
})
}
// mcpErrorResponse maps the import / test sentinel errors to the response
// codes Python's mcp_api emits.
func mcpErrorResponse(c *gin.Context, err error) bool {
if err == nil {
return false
}
switch {
case errors.Is(err, service.ErrMCPInvalidType),
errors.Is(err, service.ErrMCPInvalidName),
errors.Is(err, service.ErrMCPInvalidURL),
errors.Is(err, service.ErrMCPTestFailed):
c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": nil, "message": mcpErrorMessage(err)})
default:
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "data": nil, "message": err.Error()})
}
return true
}
func mcpErrorMessage(err error) string {
if err == nil {
return ""
}
// service wraps its sentinels as "<sentinel>: <detail>" via
// fmt.Errorf("%w: ...", err). Surface the detail when present so the
// SSRF guard's per-failure message (e.g. "URL resolves to a non-public
// address (...).") reaches the caller verbatim, matching what Python's
// _assert_mcp_url_is_safe returns.
switch {
case errors.Is(err, service.ErrMCPInvalidURL):
if detail := unwrapDetail(err, service.ErrMCPInvalidURL); detail != "" {
return detail
}
return "Invalid url."
case errors.Is(err, service.ErrMCPInvalidType):
return "Unsupported MCP server type."
case errors.Is(err, service.ErrMCPTestFailed):
if detail := unwrapDetail(err, service.ErrMCPTestFailed); detail != "" {
return detail
}
return "Test MCP error."
default:
return err.Error()
}
}
// unwrapDetail pulls the "<sentinel>: <detail>" suffix off a wrapped error
// and returns the detail. Returns "" when the error is the bare sentinel
// (no wrapped message) so the caller can fall back to a default.
func unwrapDetail(err, sentinel error) string {
if err == nil || sentinel == nil {
return ""
}
prefix := sentinel.Error() + ": "
msg := err.Error()
if !strings.HasPrefix(msg, prefix) {
return ""
}
return strings.TrimPrefix(msg, prefix)
}
// ImportMCPRequest is the body for the bulk-import endpoint.
type ImportMCPRequest struct {
MCPServers map[string]map[string]interface{} `json:"mcpServers"`
Timeout float64 `json:"timeout,omitempty"`
}
// ImportMCPServers bulk-imports MCP servers from a JSON config, fetching the
// remote tool list for each entry and persisting it under variables.tools.
// Mirrors Python's import_multiple, including the same distinction between
// "mcpServers key missing" (101 ARGUMENT_ERROR) and "mcpServers key
// present but empty" (102 DATA_ERROR).
//
// @Summary Import MCP Servers
// @Tags mcp
// @Accept json
// @Produce json
// @Param request body handler.ImportMCPRequest true "import config"
// @Router /api/v1/mcp/servers/import [post]
func (h *MCPHandler) ImportMCPServers(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
// Read the raw body so we can distinguish "key absent" from "key
// present but empty" — the Python @validate_request("mcpServers")
// decorator returns RetCode.ARGUMENT_ERROR for the former, while the
// handler body returns RetCode.DATA_ERROR for the latter.
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()})
return
}
var raw map[string]json.RawMessage
if len(body) > 0 {
if err := json.Unmarshal(body, &raw); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()})
return
}
}
rawServers, hasServers := raw["mcpServers"]
if !hasServers {
// Match Python validate_request: code 101, message includes the
// trailing "; " separator the Python decorator emits.
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"data": nil,
"message": "required argument are missing: mcpServers; ",
})
return
}
var servers map[string]map[string]interface{}
if err := json.Unmarshal(rawServers, &servers); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()})
return
}
if len(servers) == 0 {
c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": nil, "message": "No MCP servers provided."})
return
}
var timeout float64
if rawTimeout, ok := raw["timeout"]; ok {
// Ignore parse errors for timeout to match Python's get_float
// default-on-failure behavior; the service applies its own
// 10 s fallback when timeout <= 0.
_ = json.Unmarshal(rawTimeout, &timeout)
}
results, err := h.mcpService.ImportServers(user.ID, servers, timeout)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "data": nil, "message": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"code": common.CodeSuccess, "data": gin.H{"results": results}, "message": "success"})
}
// TestMCPServer opens a live MCP session and returns the tools the server
// advertises. The mcp_id path parameter identifies the stored record the
// user is trying to validate; the actual connection uses the request body
// so the user can preview unsaved edits — matching Python's test_mcp.
//
// @Summary Test MCP Server
// @Tags mcp
// @Accept json
// @Produce json
// @Param mcp_id path string true "MCP server ID"
// @Param request body service.TestServerRequest true "test parameters"
// @Router /api/v1/mcp/servers/{mcp_id}/test [post]
func (h *MCPHandler) TestMCPServer(c *gin.Context) {
_, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
mcpID := c.Param("mcp_id")
if mcpID == "" {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "mcp_id is required"})
return
}
var req service.TestServerRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()})
return
}
// Mirror Python's @validate_request("url", "server_type"): missing
// required fields → code 101 (ARGUMENT_ERROR), not code 102.
var missingFields []string
if req.URL == "" {
missingFields = append(missingFields, "url")
}
if req.ServerType == "" {
missingFields = append(missingFields, "server_type")
}
if len(missingFields) > 0 {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"data": nil,
"message": "required argument are missing: " + strings.Join(missingFields, ", ") + "; ",
})
return
}
tools, err := h.mcpService.TestServer(mcpID, &req)
if mcpErrorResponse(c, err) {
return
}
c.JSON(http.StatusOK, gin.H{"code": common.CodeSuccess, "data": tools, "message": "success"})
}
func newMCPServerResponse(server *entity.MCPServer) *mcpServerResponse {
if server == nil {
return nil

View File

@@ -0,0 +1,742 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Package mcpclient is a minimal Model Context Protocol (MCP) client used by
// the Go MCP-management endpoints to list a remote server's tools during
// import and the "test" endpoint. It implements just enough of the spec to
// negotiate a session and call tools/list:
//
// - streamable-HTTP transport (spec 2025-03-26): single endpoint, JSON-RPC
// requests via POST, responses either as application/json or as an SSE
// stream sharing the same connection.
// - SSE transport (spec 2024-11-05, legacy): server returns an "endpoint"
// event whose data is the URL the client POSTs JSON-RPC requests to;
// responses are pushed back on the same SSE stream.
//
// The full Python implementation lives in common/mcp_tool_call_conn.py; this
// is a reduced port focused on tools/list discovery.
package mcpclient
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"ragflow/internal/utility"
)
// Transport identifiers. Mirrors common.constants.MCPServerType.
const (
TransportSSE = "sse"
TransportStreamableHTTP = "streamable-http"
)
const (
protocolVersion = "2025-03-26"
clientName = "ragflow-go"
clientVersion = "1.0.0"
jsonRPCVersion = "2.0"
)
// Tool is the subset of an MCP Tool descriptor returned by tools/list.
// Extra fields surfaced by the server are preserved in Raw so callers can
// round-trip them into variables.tools without losing data.
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema map[string]interface{} `json:"inputSchema,omitempty"`
Raw map[string]interface{} `json:"-"`
}
// FetchOptions controls a single tools/list discovery call.
type FetchOptions struct {
URL string
ServerType string
Headers map[string]string
Variables map[string]string
Timeout time.Duration
HTTPClient *http.Client
pinHostname string
pinIP string
}
// FetchTools opens a connection to the MCP server described by opts and
// returns the tools advertised by tools/list. URL safety / DNS pinning is
// performed here so callers get the same SSRF guarantees the Python path
// has via pin_dns_global + assert_url_is_safe.
func FetchTools(ctx context.Context, opts FetchOptions) ([]Tool, error) {
if opts.URL == "" {
return nil, errors.New("Invalid url.")
}
if opts.Timeout <= 0 {
opts.Timeout = 10 * time.Second
}
hostname, resolvedIP, err := utility.AssertURLSafe(opts.URL)
if err != nil {
return nil, err
}
opts.pinHostname = hostname
opts.pinIP = resolvedIP
if opts.HTTPClient == nil {
opts.HTTPClient = utility.PinnedHTTPClient(hostname, resolvedIP, opts.Timeout)
}
headers, headerErr := renderHeaders(opts.Headers, opts.Variables)
if headerErr != nil {
return nil, headerErr
}
connectCtx, cancel := context.WithTimeout(ctx, opts.Timeout)
defer cancel()
switch strings.ToLower(opts.ServerType) {
case TransportStreamableHTTP:
return fetchToolsStreamableHTTP(connectCtx, opts.URL, headers, opts.HTTPClient)
case TransportSSE:
return fetchToolsSSE(connectCtx, opts.URL, headers, opts.HTTPClient)
default:
return nil, fmt.Errorf("Unsupported MCP server type.")
}
}
// renderHeaders applies ${name} substitution to header keys and values using
// the supplied variables map, mirroring the Template.safe_substitute pass in
// common/mcp_tool_call_conn.py. Empty keys (after substitution) are dropped.
func renderHeaders(raw map[string]string, vars map[string]string) (map[string]string, error) {
rendered := map[string]string{}
for k, v := range raw {
nk := substituteTemplate(k, vars)
nv := substituteTemplate(v, vars)
if strings.TrimSpace(nk) == "" {
continue
}
rendered[nk] = nv
}
return rendered, nil
}
// substituteTemplate replaces ${name} occurrences (Python string.Template
// safe-substitute semantics) with values from vars. Unknown keys are left
// in place, matching safe_substitute's behavior.
func substituteTemplate(s string, vars map[string]string) string {
if vars == nil || !strings.Contains(s, "${") {
return s
}
var b strings.Builder
i := 0
for i < len(s) {
idx := strings.Index(s[i:], "${")
if idx == -1 {
b.WriteString(s[i:])
break
}
b.WriteString(s[i : i+idx])
i += idx + 2
end := strings.Index(s[i:], "}")
if end == -1 {
b.WriteString("${")
b.WriteString(s[i:])
break
}
key := s[i : i+end]
i += end + 1
if val, ok := vars[key]; ok {
b.WriteString(val)
} else {
b.WriteString("${")
b.WriteString(key)
b.WriteString("}")
}
}
return b.String()
}
// jsonRPCRequest is a JSON-RPC 2.0 request envelope.
type jsonRPCRequest struct {
JSONRPC string `json:"jsonrpc"`
ID interface{} `json:"id,omitempty"`
Method string `json:"method"`
Params interface{} `json:"params,omitempty"`
}
// jsonRPCResponse is a JSON-RPC 2.0 response. Either Result or Error is set.
type jsonRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
ID interface{} `json:"id,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *jsonRPCError `json:"error,omitempty"`
Method string `json:"method,omitempty"`
}
type jsonRPCError struct {
Code int `json:"code"`
Message string `json:"message"`
}
func initializeParams() map[string]interface{} {
return map[string]interface{}{
"protocolVersion": protocolVersion,
"capabilities": map[string]interface{}{},
"clientInfo": map[string]interface{}{
"name": clientName,
"version": clientVersion,
},
}
}
// ---------- streamable-HTTP transport ----------
const sessionHeader = "Mcp-Session-Id"
func fetchToolsStreamableHTTP(ctx context.Context, endpoint string, headers map[string]string, client *http.Client) ([]Tool, error) {
sessionID, initRes, err := streamableSend(ctx, client, endpoint, "", headers, jsonRPCRequest{
JSONRPC: jsonRPCVersion,
ID: 0,
Method: "initialize",
Params: initializeParams(),
}, true)
if err != nil {
return nil, err
}
if initRes.Error != nil {
return nil, formatMCPError("initialize", initRes.Error)
}
if _, _, err := streamableSend(ctx, client, endpoint, sessionID, headers, jsonRPCRequest{
JSONRPC: jsonRPCVersion,
Method: "notifications/initialized",
}, false); err != nil {
return nil, err
}
_, listRes, err := streamableSend(ctx, client, endpoint, sessionID, headers, jsonRPCRequest{
JSONRPC: jsonRPCVersion,
ID: 1,
Method: "tools/list",
}, true)
if err != nil {
return nil, err
}
if listRes.Error != nil {
return nil, formatMCPError("tools/list", listRes.Error)
}
return parseToolsResult(listRes.Result)
}
// streamableSend POSTs a JSON-RPC payload to the streamable-HTTP endpoint.
// When expectResponse is false (notifications), the response body is not
// parsed. The session id returned by the initial initialize call is
// propagated via the Mcp-Session-Id header per the spec.
func streamableSend(ctx context.Context, client *http.Client, endpoint, sessionID string, headers map[string]string, payload jsonRPCRequest, expectResponse bool) (string, *jsonRPCResponse, error) {
body, err := json.Marshal(payload)
if err != nil {
return "", nil, fmt.Errorf("marshal MCP request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return "", nil, fmt.Errorf("build MCP request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
for k, v := range headers {
req.Header.Set(k, v)
}
if sessionID != "" {
req.Header.Set(sessionHeader, sessionID)
}
resp, err := client.Do(req)
if err != nil {
return "", nil, mapMCPConnectionError(err)
}
defer resp.Body.Close()
if !expectResponse {
if resp.StatusCode >= 400 {
return "", nil, fmt.Errorf("MCP server returned HTTP %d for %s", resp.StatusCode, payload.Method)
}
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20))
return resp.Header.Get(sessionHeader), nil, nil
}
if resp.StatusCode >= 400 {
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
return "", nil, fmt.Errorf("MCP server returned HTTP %d for %s: %s", resp.StatusCode, payload.Method, strings.TrimSpace(string(raw)))
}
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
sid := resp.Header.Get(sessionHeader)
if sessionID == "" {
sessionID = sid
}
if strings.Contains(contentType, "text/event-stream") {
r, err := readJSONRPCFromSSE(resp.Body, payload.ID)
if err != nil {
return "", nil, err
}
return sessionID, r, nil
}
raw, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
if err != nil {
return "", nil, fmt.Errorf("read MCP response: %w", err)
}
parsed, err := parseJSONRPC(raw, payload.ID)
if err != nil {
return "", nil, err
}
return sessionID, parsed, nil
}
// ---------- SSE transport ----------
func fetchToolsSSE(ctx context.Context, endpoint string, headers map[string]string, client *http.Client) ([]Tool, error) {
streamReq, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("build SSE request: %w", err)
}
streamReq.Header.Set("Accept", "text/event-stream")
streamReq.Header.Set("Cache-Control", "no-cache")
for k, v := range headers {
streamReq.Header.Set(k, v)
}
streamResp, err := client.Do(streamReq)
if err != nil {
return nil, mapMCPConnectionError(err)
}
if streamResp.StatusCode >= 400 {
body, _ := io.ReadAll(io.LimitReader(streamResp.Body, 1<<20))
streamResp.Body.Close()
return nil, fmt.Errorf("MCP SSE handshake returned HTTP %d: %s", streamResp.StatusCode, strings.TrimSpace(string(body)))
}
stream := newSSEReader(streamResp.Body)
defer streamResp.Body.Close()
postURL, err := waitForEndpoint(ctx, stream, endpoint)
if err != nil {
return nil, err
}
// The endpoint event can hand us an arbitrary absolute URL. A
// malicious public SSE server could point us at 127.0.0.1 or any
// other internal host to bounce the POST phase through us. Re-run
// the SSRF guard against the resolved URL, and — when the host
// differs from the original SSE host — swap in a fresh pinned
// client so the dial-time IP override still applies.
postClient := client
if postHost, postIP, vErr := utility.AssertURLSafe(postURL); vErr != nil {
return nil, vErr
} else if u, perr := url.Parse(postURL); perr == nil && u.Hostname() != "" {
if u.Hostname() != originalHost(endpoint) {
postClient = utility.PinnedHTTPClient(postHost, postIP, sseTimeoutFrom(ctx))
}
}
pending := newPendingResponses()
streamDone := make(chan error, 1)
go func() {
streamDone <- stream.dispatch(ctx, pending)
}()
postOnce := func(payload jsonRPCRequest) error {
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, postURL, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("build SSE POST: %w", err)
}
req.Header.Set("Content-Type", "application/json")
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := postClient.Do(req)
if err != nil {
return mapMCPConnectionError(err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
return fmt.Errorf("MCP server returned HTTP %d for %s: %s", resp.StatusCode, payload.Method, strings.TrimSpace(string(raw)))
}
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20))
return nil
}
// Register the waiter BEFORE issuing the POST so a fast server that
// pushes its response before our wait() call doesn't drop the delivery.
initWaiter := pending.register(0)
if err := postOnce(jsonRPCRequest{JSONRPC: jsonRPCVersion, ID: 0, Method: "initialize", Params: initializeParams()}); err != nil {
pending.cancel(0)
return nil, err
}
initRes, err := pending.await(ctx, initWaiter, streamDone)
if err != nil {
return nil, err
}
if initRes.Error != nil {
return nil, formatMCPError("initialize", initRes.Error)
}
if err := postOnce(jsonRPCRequest{JSONRPC: jsonRPCVersion, Method: "notifications/initialized"}); err != nil {
return nil, err
}
listWaiter := pending.register(1)
if err := postOnce(jsonRPCRequest{JSONRPC: jsonRPCVersion, ID: 1, Method: "tools/list"}); err != nil {
pending.cancel(1)
return nil, err
}
listRes, err := pending.await(ctx, listWaiter, streamDone)
if err != nil {
return nil, err
}
if listRes.Error != nil {
return nil, formatMCPError("tools/list", listRes.Error)
}
return parseToolsResult(listRes.Result)
}
// waitForEndpoint reads SSE events until an "endpoint" event arrives and
// returns the URL to POST JSON-RPC requests to. The data may be either a
// fully-qualified URL or a path; relative paths are resolved against the
// original SSE endpoint.
func waitForEndpoint(ctx context.Context, stream *sseReader, base string) (string, error) {
for {
event, err := stream.nextEvent(ctx)
if err != nil {
return "", err
}
if event == nil {
return "", errors.New("MCP SSE stream closed before sending endpoint event")
}
if event.event == "endpoint" {
ref := strings.TrimSpace(event.data)
if ref == "" {
return "", errors.New("MCP SSE endpoint event has empty data")
}
baseURL, err := url.Parse(base)
if err != nil {
return "", fmt.Errorf("parse MCP SSE base url: %w", err)
}
rel, err := url.Parse(ref)
if err != nil {
return "", fmt.Errorf("parse MCP SSE endpoint data: %w", err)
}
return baseURL.ResolveReference(rel).String(), nil
}
// Other events (heartbeats, message) before endpoint are ignored.
}
}
// originalHost extracts the hostname from the original SSE endpoint so the
// caller can detect when the server-advertised post URL has moved to a
// different host (and a fresh pinned client is required).
func originalHost(endpoint string) string {
u, err := url.Parse(endpoint)
if err != nil {
return ""
}
return u.Hostname()
}
// sseTimeoutFrom recovers a non-zero timeout from the request context so
// the freshly-pinned post-phase client has the same deadline as the rest
// of the SSE flow.
func sseTimeoutFrom(ctx context.Context) time.Duration {
if deadline, ok := ctx.Deadline(); ok {
if d := time.Until(deadline); d > 0 {
return d
}
}
return 10 * time.Second
}
// pendingResponses correlates outstanding JSON-RPC ids with channels that
// receive the corresponding response from the SSE dispatcher.
type pendingResponses struct {
mu sync.Mutex
waiters map[string]chan *jsonRPCResponse
}
func newPendingResponses() *pendingResponses {
return &pendingResponses{waiters: map[string]chan *jsonRPCResponse{}}
}
// pendingWaiter is the handle returned by register; the caller passes it to
// await once the request has been sent.
type pendingWaiter struct {
key string
ch chan *jsonRPCResponse
}
// register reserves a waiter slot for the given JSON-RPC id BEFORE the
// request is sent, so a server that responds before await() is called still
// has somewhere to deliver to.
func (p *pendingResponses) register(id interface{}) pendingWaiter {
key := normalizeID(id)
ch := make(chan *jsonRPCResponse, 1)
p.mu.Lock()
p.waiters[key] = ch
p.mu.Unlock()
return pendingWaiter{key: key, ch: ch}
}
// cancel drops a previously registered waiter. Used when the POST fails so
// a late server delivery cannot block forever in the waiters map.
func (p *pendingResponses) cancel(id interface{}) {
key := normalizeID(id)
p.mu.Lock()
delete(p.waiters, key)
p.mu.Unlock()
}
// await blocks until the registered waiter's response arrives, the SSE
// stream closes, or ctx expires.
func (p *pendingResponses) await(ctx context.Context, w pendingWaiter, streamDone <-chan error) (*jsonRPCResponse, error) {
defer func() {
p.mu.Lock()
delete(p.waiters, w.key)
p.mu.Unlock()
}()
select {
case res := <-w.ch:
return res, nil
case err := <-streamDone:
if err == nil {
return nil, errors.New("MCP SSE stream closed before response arrived")
}
return nil, err
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (p *pendingResponses) deliver(res *jsonRPCResponse) {
key := normalizeID(res.ID)
p.mu.Lock()
ch, ok := p.waiters[key]
p.mu.Unlock()
if !ok {
return
}
select {
case ch <- res:
default:
}
}
func normalizeID(id interface{}) string {
switch v := id.(type) {
case nil:
return ""
case string:
return v
case json.Number:
return v.String()
case float64:
return fmt.Sprintf("%v", v)
default:
b, _ := json.Marshal(v)
return string(b)
}
}
// ---------- SSE parsing ----------
type sseEvent struct {
event string
data string
}
type sseReader struct {
rd *bufio.Reader
}
func newSSEReader(r io.Reader) *sseReader {
return &sseReader{rd: bufio.NewReaderSize(r, 64*1024)}
}
// nextEvent returns the next SSE event (event: + data:) from the stream, or
// nil when the stream is closed cleanly.
func (s *sseReader) nextEvent(ctx context.Context) (*sseEvent, error) {
ev := &sseEvent{}
var dataLines []string
for {
if err := ctx.Err(); err != nil {
return nil, err
}
line, err := s.rd.ReadString('\n')
if err != nil {
if err == io.EOF {
if len(dataLines) > 0 || ev.event != "" {
ev.data = strings.Join(dataLines, "\n")
return ev, nil
}
return nil, nil
}
return nil, err
}
line = strings.TrimRight(line, "\r\n")
if line == "" {
if len(dataLines) == 0 && ev.event == "" {
continue
}
ev.data = strings.Join(dataLines, "\n")
return ev, nil
}
if strings.HasPrefix(line, ":") {
continue
}
if idx := strings.Index(line, ":"); idx >= 0 {
field := line[:idx]
value := strings.TrimPrefix(line[idx+1:], " ")
switch field {
case "event":
ev.event = value
case "data":
dataLines = append(dataLines, value)
}
}
}
}
// dispatch reads events off the SSE stream and forwards JSON-RPC responses
// to the matching waiter. It returns when the stream closes.
func (s *sseReader) dispatch(ctx context.Context, pending *pendingResponses) error {
for {
ev, err := s.nextEvent(ctx)
if err != nil {
return err
}
if ev == nil {
return nil
}
if ev.event != "" && ev.event != "message" {
continue
}
raw := []byte(ev.data)
if len(bytes.TrimSpace(raw)) == 0 {
continue
}
parsed, err := parseJSONRPC(raw, nil)
if err != nil {
continue
}
if parsed.Method != "" && parsed.ID == nil {
// Server-initiated notification; nothing to deliver.
continue
}
pending.deliver(parsed)
}
}
// readJSONRPCFromSSE consumes a single JSON-RPC response off an inline SSE
// stream returned by a streamable-HTTP POST. The response with matching id
// is returned; everything else is skipped.
func readJSONRPCFromSSE(r io.Reader, wantID interface{}) (*jsonRPCResponse, error) {
stream := newSSEReader(r)
for {
ev, err := stream.nextEvent(context.Background())
if err != nil {
return nil, err
}
if ev == nil {
return nil, errors.New("MCP SSE response stream closed before response arrived")
}
if ev.event != "" && ev.event != "message" {
continue
}
raw := []byte(ev.data)
if len(bytes.TrimSpace(raw)) == 0 {
continue
}
parsed, err := parseJSONRPC(raw, wantID)
if err != nil {
continue
}
if normalizeID(parsed.ID) == normalizeID(wantID) {
return parsed, nil
}
}
}
// ---------- shared helpers ----------
func parseJSONRPC(raw []byte, wantID interface{}) (*jsonRPCResponse, error) {
dec := json.NewDecoder(bytes.NewReader(raw))
dec.UseNumber()
res := &jsonRPCResponse{}
if err := dec.Decode(res); err != nil {
return nil, fmt.Errorf("parse MCP response: %w", err)
}
if wantID != nil && res.ID != nil && normalizeID(res.ID) != normalizeID(wantID) {
return nil, fmt.Errorf("unexpected JSON-RPC id %v (want %v)", res.ID, wantID)
}
return res, nil
}
func parseToolsResult(raw json.RawMessage) ([]Tool, error) {
if len(raw) == 0 {
return []Tool{}, nil
}
var envelope struct {
Tools []map[string]interface{} `json:"tools"`
}
if err := json.Unmarshal(raw, &envelope); err != nil {
return nil, fmt.Errorf("parse tools result: %w", err)
}
tools := make([]Tool, 0, len(envelope.Tools))
for _, raw := range envelope.Tools {
name, _ := raw["name"].(string)
if name == "" {
continue
}
desc, _ := raw["description"].(string)
var schema map[string]interface{}
if s, ok := raw["inputSchema"].(map[string]interface{}); ok {
schema = s
}
tools = append(tools, Tool{
Name: name,
Description: desc,
InputSchema: schema,
Raw: raw,
})
}
return tools, nil
}
func formatMCPError(method string, e *jsonRPCError) error {
if e == nil {
return fmt.Errorf("MCP %s failed", method)
}
return fmt.Errorf("MCP %s failed (%d): %s", method, e.Code, e.Message)
}
// mapMCPConnectionError surfaces the same wording the Python session uses
// when a low-level connection fails (authentication / network).
func mapMCPConnectionError(err error) error {
if errors.Is(err, context.DeadlineExceeded) {
return errors.New("Timeout connecting to MCP server")
}
return fmt.Errorf("Connection failed (possibly due to auth error). Please check authentication settings first: %v", err)
}

View File

@@ -0,0 +1,254 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package mcpclient
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"ragflow/internal/utility"
)
// allowLoopbackForTests overrides the SSRF guard's resolver so 127.0.0.1
// targets used by httptest are accepted by AssertURLSafe.
func allowLoopbackForTests(t *testing.T) func() {
t.Helper()
orig := utility.LookupHost
utility.LookupHost = func(host string) ([]string, error) {
// Return a public IPv4 so the guard sees the host as global; the
// httptest server is on loopback but we connect via raw URL.
return []string{"8.8.8.8"}, nil
}
return func() { utility.LookupHost = orig }
}
func TestFetchToolsStreamableHTTPJSON(t *testing.T) {
defer allowLoopbackForTests(t)()
var initCount, listCount, notifyCount int32
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
}
body, _ := io.ReadAll(r.Body)
var req map[string]interface{}
if err := json.Unmarshal(body, &req); err != nil {
// testing.T's Fatal* must run on the test goroutine; surface the
// failure via Errorf and bail the handler out instead.
t.Errorf("invalid request body: %v", err)
http.Error(w, "bad body", http.StatusBadRequest)
return
}
switch req["method"] {
case "initialize":
atomic.AddInt32(&initCount, 1)
w.Header().Set(sessionHeader, "test-session")
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"jsonrpc":"2.0","id":%v,"result":{"capabilities":{}}}`, req["id"])
case "notifications/initialized":
atomic.AddInt32(&notifyCount, 1)
w.WriteHeader(http.StatusAccepted)
case "tools/list":
atomic.AddInt32(&listCount, 1)
if got := r.Header.Get(sessionHeader); got != "test-session" {
t.Errorf("expected session header to be propagated, got %q", got)
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"jsonrpc":"2.0","id":%v,"result":{"tools":[{"name":"search","description":"Find docs","inputSchema":{"type":"object"}},{"name":"fetch"}]}}`, req["id"])
default:
http.Error(w, "unexpected method", http.StatusBadRequest)
}
}))
defer srv.Close()
tools, err := FetchTools(context.Background(), FetchOptions{
URL: srv.URL,
ServerType: TransportStreamableHTTP,
HTTPClient: srv.Client(),
Timeout: 2 * time.Second,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got := len(tools); got != 2 {
t.Fatalf("expected 2 tools, got %d", got)
}
if tools[0].Name != "search" || tools[0].Description != "Find docs" {
t.Errorf("tool 0 = %+v", tools[0])
}
if tools[1].Name != "fetch" {
t.Errorf("tool 1 = %+v", tools[1])
}
if atomic.LoadInt32(&initCount) != 1 || atomic.LoadInt32(&notifyCount) != 1 || atomic.LoadInt32(&listCount) != 1 {
t.Errorf("expected 1 init / 1 notify / 1 list, got %d/%d/%d", initCount, notifyCount, listCount)
}
}
func TestFetchToolsStreamableHTTPErrorResponse(t *testing.T) {
defer allowLoopbackForTests(t)()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
var req map[string]interface{}
_ = json.Unmarshal(body, &req)
if req["method"] == "initialize" {
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"jsonrpc":"2.0","id":%v,"error":{"code":-32600,"message":"bad init"}}`, req["id"])
return
}
w.WriteHeader(http.StatusAccepted)
}))
defer srv.Close()
_, err := FetchTools(context.Background(), FetchOptions{
URL: srv.URL,
ServerType: TransportStreamableHTTP,
HTTPClient: srv.Client(),
Timeout: 2 * time.Second,
})
if err == nil || !strings.Contains(err.Error(), "bad init") {
t.Fatalf("expected MCP error to surface, got %v", err)
}
}
func TestFetchToolsSSE(t *testing.T) {
defer allowLoopbackForTests(t)()
type ssePush struct {
event string
data string
}
pushes := make(chan ssePush, 4)
mux := http.NewServeMux()
mux.HandleFunc("/sse", func(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
t.Errorf("response writer is not a flusher")
http.Error(w, "no flusher", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
fmt.Fprintf(w, "event: endpoint\ndata: /messages\n\n")
flusher.Flush()
ctx := r.Context()
for {
select {
case p := <-pushes:
if p.event != "" {
fmt.Fprintf(w, "event: %s\n", p.event)
}
fmt.Fprintf(w, "data: %s\n\n", p.data)
flusher.Flush()
case <-ctx.Done():
return
}
}
})
mux.HandleFunc("/messages", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
var req map[string]interface{}
if err := json.Unmarshal(body, &req); err != nil {
t.Errorf("invalid request body: %v", err)
http.Error(w, "bad body", http.StatusBadRequest)
return
}
switch req["method"] {
case "initialize":
pushes <- ssePush{event: "message", data: fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"capabilities":{}}}`, req["id"])}
case "notifications/initialized":
case "tools/list":
pushes <- ssePush{event: "message", data: fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"tools":[{"name":"alpha"}]}}`, req["id"])}
}
w.WriteHeader(http.StatusAccepted)
})
srv := httptest.NewServer(mux)
defer srv.Close()
tools, err := FetchTools(context.Background(), FetchOptions{
URL: srv.URL + "/sse",
ServerType: TransportSSE,
HTTPClient: srv.Client(),
Timeout: 3 * time.Second,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(tools) != 1 || tools[0].Name != "alpha" {
t.Fatalf("expected [alpha], got %+v", tools)
}
}
func TestFetchToolsUnsupportedType(t *testing.T) {
defer allowLoopbackForTests(t)()
_, err := FetchTools(context.Background(), FetchOptions{
URL: "https://example.com",
ServerType: "stdio",
Timeout: time.Second,
})
if err == nil || !strings.Contains(err.Error(), "Unsupported MCP server type") {
t.Fatalf("expected unsupported-type error, got %v", err)
}
}
func TestFetchToolsEmptyURL(t *testing.T) {
_, err := FetchTools(context.Background(), FetchOptions{URL: "", ServerType: TransportSSE})
if err == nil || !strings.Contains(err.Error(), "Invalid url") {
t.Fatalf("expected Invalid url error, got %v", err)
}
}
func TestSubstituteTemplate(t *testing.T) {
vars := map[string]string{"token": "abc123"}
if got := substituteTemplate("Bearer ${token}", vars); got != "Bearer abc123" {
t.Errorf("got %q", got)
}
if got := substituteTemplate("Bearer ${missing}", vars); got != "Bearer ${missing}" {
t.Errorf("got %q", got)
}
if got := substituteTemplate("no-var", vars); got != "no-var" {
t.Errorf("got %q", got)
}
if got := substituteTemplate("${a}-${token}", map[string]string{"a": "1", "token": "2"}); got != "1-2" {
t.Errorf("got %q", got)
}
}
func TestNormalizeID(t *testing.T) {
if got := normalizeID(json.Number("1")); got != "1" {
t.Errorf("got %q", got)
}
if got := normalizeID(1); got != "1" {
t.Errorf("got %q", got)
}
if got := normalizeID("abc"); got != "abc" {
t.Errorf("got %q", got)
}
if got := normalizeID(nil); got != "" {
t.Errorf("got %q", got)
}
}

View File

@@ -302,14 +302,6 @@ func (r *Router) Setup(engine *gin.Engine) {
// message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent)
// }
mcp := v1.Group("/mcp")
{
mcp.POST("/servers", r.mcpHandler.CreateMCPServer)
mcp.GET("/servers", r.mcpHandler.ListMCPServers)
mcp.PUT("/servers/:mcp_id", r.mcpHandler.UpdateMCPServer)
mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer)
}
// Skill search routes
skills := v1.Group("/skills")
{
@@ -394,6 +386,21 @@ func (r *Router) Setup(engine *gin.Engine) {
connector.POST("/:connector_id/test", r.connectorHandler.TestConnector)
}
// MCP server routes. Per-server CRUD ships via separate PRs that
// share the same handler/service: GET list (#15253), GET by id
// (#15254), POST create (#15260, merged), PUT (#15261), DELETE
// (#15262, merged). This PR adds only the non-overlapping
// endpoints: import and test.
mcp := v1.Group("/mcp")
{
mcp.POST("/servers", r.mcpHandler.CreateMCPServer)
mcp.GET("/servers", r.mcpHandler.ListMCPServers)
mcp.PUT("/servers/:mcp_id", r.mcpHandler.UpdateMCPServer)
mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer)
mcp.POST("/servers/import", r.mcpHandler.ImportMCPServers)
mcp.POST("/servers/:mcp_id/test", r.mcpHandler.TestMCPServer)
}
system := v1.Group("/system")
{
system.GET("/configs", r.systemHandler.GetConfigs)

View File

@@ -18,6 +18,7 @@ package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -27,6 +28,8 @@ import (
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
"ragflow/internal/mcpclient"
"ragflow/internal/utility"
"gorm.io/gorm"
)
@@ -35,6 +38,7 @@ const (
mcpServerTypeSSE = "sse"
mcpServerTypeStreamableHTTP = "streamable-http"
mcpServerNameLimit = 255
defaultMCPFetchTimeoutSec = 10
mcpServerDateFormat = "2006-01-02T15:04:05"
)
@@ -382,6 +386,267 @@ func safeJSONMap(raw json.RawMessage) entity.JSONMap {
return entity.JSONMap(value)
}
// ---------- import + test (this PR's additions) ----------
// Sentinel errors mapped by the handler to Python's response codes for the
// import / test endpoints. Per-server CRUD errors stay inside CreateMCPServer.
var (
// ErrMCPInvalidType mirrors Python's "Unsupported MCP server type.".
ErrMCPInvalidType = errors.New("unsupported MCP server type")
// ErrMCPInvalidName mirrors Python's invalid-name/length error.
ErrMCPInvalidName = errors.New("invalid MCP name")
// ErrMCPInvalidURL mirrors Python's "Invalid url.".
ErrMCPInvalidURL = errors.New("invalid url")
// ErrMCPTestFailed is returned by TestServer when the live connection or
// tool-list fetch fails. The handler maps this to code 102 (DATA_ERROR),
// matching Python's test_mcp which never returns HTTP 500 for fetch errors.
ErrMCPTestFailed = errors.New("MCP test failed")
)
// ImportResult is a single per-server outcome in the bulk import response,
// matching the shape returned by Python's import_multiple.
type ImportResult struct {
Server string `json:"server"`
Success bool `json:"success"`
Action string `json:"action,omitempty"`
ID string `json:"id,omitempty"`
NewName string `json:"new_name,omitempty"`
Message string `json:"message,omitempty"`
}
// ImportServers bulk-imports MCP servers from a {"mcpServers": {name: config}}
// map. For each entry: validate type and URL, de-duplicate the name with a
// "_N" suffix, fetch the remote tool list via mcpclient (SSRF-guarded), and
// persist the server with tools stored under variables.tools. Mirrors
// Python's import_multiple.
//
// timeoutSeconds controls how long each tool-fetch call waits; <=0 falls back
// to the Python default of 10 s.
func (s *MCPService) ImportServers(tenantID string, servers map[string]map[string]interface{}, timeoutSeconds float64) ([]ImportResult, error) {
if timeoutSeconds <= 0 {
timeoutSeconds = defaultMCPFetchTimeoutSec
}
timeout := time.Duration(timeoutSeconds * float64(time.Second))
results := make([]ImportResult, 0, len(servers))
for serverName, config := range servers {
url, hasURL := config["url"].(string)
stype, hasType := config["type"].(string)
if !hasType || !hasURL {
results = append(results, ImportResult{Server: serverName, Success: false, Message: "Missing required fields (type or url)"})
continue
}
if serverName == "" || len([]byte(serverName)) > mcpServerNameLimit {
results = append(results, ImportResult{Server: serverName, Success: false, Message: fmt.Sprintf("Invalid MCP name or length is %d which is large than 255.", len(serverName))})
continue
}
if !isValidMCPServerType(stype) {
results = append(results, ImportResult{Server: serverName, Success: false, Message: "Unsupported MCP server type."})
continue
}
baseName := serverName
newName, err := s.nextAvailableMCPName(baseName, tenantID)
if err != nil {
return nil, err
}
variables := map[string]interface{}{}
stringVars := map[string]string{}
for k, v := range config {
if k == "type" || k == "url" || k == "headers" {
continue
}
variables[k] = v
if sv, ok := v.(string); ok {
stringVars[k] = sv
}
}
delete(variables, "tools")
delete(stringVars, "tools")
// Headers can be provided either as a top-level "headers" map
// (preferred — matches the Python import shape) or as a flat
// "authorization_token" string at the entry root. Both go to the
// MCP client for tool discovery and to the persisted record so
// configs that depend on custom auth headers survive the round
// trip.
headers := map[string]string{}
headerVals := map[string]interface{}{}
if rawHeaders, ok := config["headers"].(map[string]interface{}); ok {
for k, v := range rawHeaders {
if sv, ok := v.(string); ok {
headers[k] = sv
}
headerVals[k] = v
}
}
if token, ok := config["authorization_token"].(string); ok {
if _, exists := headers["authorization_token"]; !exists {
headers["authorization_token"] = token
}
if _, exists := headerVals["authorization_token"]; !exists {
headerVals["authorization_token"] = token
}
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
tools, fetchErr := mcpclient.FetchTools(ctx, mcpclient.FetchOptions{
URL: url,
ServerType: stype,
Headers: headers,
Variables: stringVars,
Timeout: timeout,
})
cancel()
if fetchErr != nil {
results = append(results, ImportResult{Server: baseName, Success: false, Message: fetchErr.Error()})
continue
}
variables["tools"] = toolsAsMap(tools)
server := &entity.MCPServer{
ID: common.GenerateUUID(),
TenantID: tenantID,
Name: newName,
URL: url,
ServerType: stype,
Variables: entity.JSONMap(variables),
Headers: entity.JSONMap(headerVals),
}
if err := s.mcpServerDAO.CreateMCPServer(server); err != nil {
results = append(results, ImportResult{Server: serverName, Success: false, Message: "Failed to create MCP server."})
continue
}
result := ImportResult{Server: serverName, Success: true, Action: "created", ID: server.ID, NewName: newName}
if newName != baseName {
result.Message = fmt.Sprintf("Renamed from '%s' to '%s' avoid duplication", baseName, newName)
}
results = append(results, result)
}
return results, nil
}
func (s *MCPService) nextAvailableMCPName(base, tenantID string) (string, error) {
name := base
counter := 0
for {
exists, err := s.mcpServerDAO.ExistsByNameAndTenant(name, tenantID)
if err != nil {
return "", err
}
if !exists {
return name, nil
}
name = fmt.Sprintf("%s_%d", base, counter)
counter++
}
}
// TestServerRequest is the body of POST /mcp/servers/:mcp_id/test. The mcp_id
// from the URL path is threaded through to the connect call for log
// correlation; the connection itself is opened from the request body so the
// user can preview unsaved edits — matching Python's test_mcp.
type TestServerRequest struct {
URL string `json:"url"`
ServerType string `json:"server_type"`
Headers map[string]interface{} `json:"headers,omitempty"`
Variables map[string]interface{} `json:"variables,omitempty"`
Timeout float64 `json:"timeout,omitempty"`
}
// TestServer opens a live MCP session and returns the tools the server
// advertises. Mirrors Python's test_mcp. mcpID is used for log correlation
// only.
func (s *MCPService) TestServer(mcpID string, req *TestServerRequest) ([]map[string]interface{}, error) {
if req == nil || req.URL == "" {
return nil, fmt.Errorf("%w: Invalid MCP url.", ErrMCPInvalidURL)
}
if !isValidMCPServerType(req.ServerType) {
return nil, ErrMCPInvalidType
}
// Run the SSRF guard up front so URL-shape failures (disallowed
// scheme, missing host, non-public address) surface as
// ErrMCPInvalidURL data errors instead of being swallowed inside the
// generic FetchTools error and re-classified by the handler as a 500.
// FetchTools repeats the check internally; the second call is cheap.
if _, _, err := utility.AssertURLSafe(req.URL); err != nil {
return nil, fmt.Errorf("%w: %s", ErrMCPInvalidURL, err.Error())
}
timeoutSec := req.Timeout
if timeoutSec <= 0 {
timeoutSec = defaultMCPFetchTimeoutSec
}
timeout := time.Duration(timeoutSec * float64(time.Second))
headers := map[string]string{}
for k, v := range req.Headers {
if sv, ok := v.(string); ok {
headers[k] = sv
}
}
vars := map[string]string{}
for k, v := range req.Variables {
if sv, ok := v.(string); ok {
vars[k] = sv
}
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
tools, err := mcpclient.FetchTools(ctx, mcpclient.FetchOptions{
URL: req.URL,
ServerType: req.ServerType,
Headers: headers,
Variables: vars,
Timeout: timeout,
})
if err != nil {
return nil, fmt.Errorf("%w: Test MCP error (id=%s): %v", ErrMCPTestFailed, mcpID, err)
}
out := make([]map[string]interface{}, 0, len(tools))
for _, t := range tools {
raw := t.Raw
if raw == nil {
raw = map[string]interface{}{"name": t.Name}
if t.Description != "" {
raw["description"] = t.Description
}
if t.InputSchema != nil {
raw["inputSchema"] = t.InputSchema
}
}
raw["enabled"] = true
out = append(out, raw)
}
return out, nil
}
// toolsAsMap mirrors Python's `{tool["name"]: tool ...}` shape used when
// persisting variables.tools.
func toolsAsMap(tools []mcpclient.Tool) map[string]interface{} {
m := map[string]interface{}{}
for _, t := range tools {
if t.Raw != nil {
m[t.Name] = t.Raw
continue
}
entry := map[string]interface{}{"name": t.Name}
if t.Description != "" {
entry["description"] = t.Description
}
if t.InputSchema != nil {
entry["inputSchema"] = t.InputSchema
}
m[t.Name] = entry
}
return m
}
func formatMCPServerDate(date *time.Time) *string {
if date == nil {
return nil

View File

@@ -17,12 +17,75 @@
package service
import (
"errors"
"fmt"
"strings"
"testing"
"ragflow/internal/entity"
)
func TestIsValidMCPServerType(t *testing.T) {
for _, v := range []string{mcpServerTypeSSE, mcpServerTypeStreamableHTTP} {
if !isValidMCPServerType(v) {
t.Errorf("expected %q to be a valid MCP server type", v)
}
}
for _, v := range []string{"", "stdio", "http", "SSE"} {
if isValidMCPServerType(v) {
t.Errorf("expected %q to be an invalid MCP server type", v)
}
}
}
func TestServerInputValidation(t *testing.T) {
s := &MCPService{}
// Empty URL is rejected before any connection attempt.
if _, err := s.TestServer("id-1", &TestServerRequest{ServerType: mcpServerTypeSSE}); !errors.Is(err, ErrMCPInvalidURL) {
t.Errorf("expected ErrMCPInvalidURL for empty url, got %v", err)
}
// nil body is treated as empty URL.
if _, err := s.TestServer("id-1", nil); !errors.Is(err, ErrMCPInvalidURL) {
t.Errorf("expected ErrMCPInvalidURL for nil body, got %v", err)
}
// Invalid server type is rejected before connecting.
if _, err := s.TestServer("id-1", &TestServerRequest{URL: "http://example.com/sse", ServerType: "stdio"}); !errors.Is(err, ErrMCPInvalidType) {
t.Errorf("expected ErrMCPInvalidType for bad type, got %v", err)
}
}
func TestImportServersValidationErrors(t *testing.T) {
s := &MCPService{}
// Missing url and type produce an in-band error per entry rather than
// failing the batch.
configs := map[string]map[string]interface{}{
"missing-fields": {"foo": "bar"},
"bad-type": {"url": "http://example.com", "type": "stdio"},
}
results, err := s.ImportServers("tenant-1", configs, 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != 2 {
t.Fatalf("expected 2 results, got %d", len(results))
}
for _, r := range results {
if r.Success {
t.Errorf("expected failure result for %q", r.Server)
}
if r.Server == "missing-fields" && !strings.Contains(r.Message, "Missing required fields") {
t.Errorf("unexpected message for missing-fields: %q", r.Message)
}
if r.Server == "bad-type" && !strings.Contains(r.Message, "Unsupported MCP server type") {
t.Errorf("unexpected message for bad-type: %q", r.Message)
}
}
}
func TestPaginateMCPServersNegativeValuesMatchPythonSlice(t *testing.T) {
servers := makeMCPServers(13)

190
internal/utility/ssrf.go Normal file
View File

@@ -0,0 +1,190 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package utility
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"sort"
"strings"
"time"
)
// AllowedURLSchemes are the schemes accepted by AssertURLSafe.
var AllowedURLSchemes = []string{"http", "https"}
// LookupHost is the indirection used to resolve hostnames. Tests override it.
var LookupHost = net.LookupHost
// AssertURLSafe parses rawURL and rejects it if the scheme is disallowed,
// the host is missing, or any resolved IP is not globally routable
// (private, loopback, link-local, multicast, reserved). Returns the hostname
// and the first validated public IP so callers can DNS-pin the address and
// prevent rebinding between validation and the actual TCP connection.
//
// Mirrors common/ssrf_guard.py:assert_url_is_safe.
func AssertURLSafe(rawURL string) (hostname, resolvedIP string, err error) {
parsed, perr := url.Parse(strings.TrimSpace(rawURL))
if perr != nil {
return "", "", fmt.Errorf("Invalid url.")
}
scheme := strings.ToLower(parsed.Scheme)
if !schemeAllowed(scheme) {
sorted := append([]string(nil), AllowedURLSchemes...)
sort.Strings(sorted)
return "", "", fmt.Errorf("Disallowed URL scheme: '%s'. Only %v are allowed.", scheme, sorted)
}
hostname = parsed.Hostname()
if hostname == "" {
return "", "", fmt.Errorf("URL is missing a host.")
}
addrs, err := LookupHost(hostname)
if err != nil {
return "", "", fmt.Errorf("Could not resolve hostname '%s': %v", hostname, err)
}
if len(addrs) == 0 {
return "", "", fmt.Errorf("Hostname '%s' resolved to no addresses.", hostname)
}
for _, addr := range addrs {
ip := net.ParseIP(addr)
if ip == nil {
return "", "", fmt.Errorf("Could not parse resolved address '%s' for hostname '%s'.", addr, hostname)
}
if !isGlobalIP(effectiveIP(ip)) {
return "", "", fmt.Errorf("URL resolves to a non-public address (%s), which is not allowed.", ip.String())
}
if resolvedIP == "" {
resolvedIP = ip.String()
}
}
return hostname, resolvedIP, nil
}
func schemeAllowed(scheme string) bool {
for _, s := range AllowedURLSchemes {
if s == scheme {
return true
}
}
return false
}
// effectiveIP unwraps IPv4-mapped IPv6 addresses (e.g. ::ffff:127.0.0.1) so
// the routability check sees the IPv4 form. Without this, an attacker could
// bypass the guard with an IPv4-mapped IPv6 representation of a private host.
func effectiveIP(ip net.IP) net.IP {
if v4 := ip.To4(); v4 != nil {
return v4
}
return ip
}
// isGlobalIP mirrors Python's ipaddress.IPv*Address.is_global: an address is
// global if it is none of {unspecified, loopback, multicast, link-local,
// private (including CGNAT and IPv6 ULA), benchmarking, documentation,
// reserved}.
func isGlobalIP(ip net.IP) bool {
if ip == nil || ip.IsUnspecified() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() || ip.IsPrivate() {
return false
}
if v4 := ip.To4(); v4 != nil {
// CGNAT 100.64.0.0/10 — not flagged by IsPrivate in older Go versions.
if v4[0] == 100 && v4[1]&0xC0 == 64 {
return false
}
// 192.0.0.0/24 reserved for IETF protocol assignments.
if v4[0] == 192 && v4[1] == 0 && v4[2] == 0 {
return false
}
// 192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24 documentation (TEST-NET-1/2/3).
if v4[0] == 192 && v4[1] == 0 && v4[2] == 2 {
return false
}
if v4[0] == 198 && v4[1] == 51 && v4[2] == 100 {
return false
}
if v4[0] == 203 && v4[1] == 0 && v4[2] == 113 {
return false
}
// 198.18.0.0/15 benchmarking.
if v4[0] == 198 && (v4[1] == 18 || v4[1] == 19) {
return false
}
// 240.0.0.0/4 reserved (excluding 255.255.255.255 which IsUnspecified misses).
if v4[0] >= 240 {
return false
}
} else if v6 := ip.To16(); v6 != nil {
// 2001:db8::/32 documentation prefix.
if v6[0] == 0x20 && v6[1] == 0x01 && v6[2] == 0x0d && v6[3] == 0xb8 {
return false
}
// 100::/64 discard-only address block.
if v6[0] == 0x01 && v6[1] == 0x00 && allZero(v6[2:8]) {
return false
}
}
return true
}
func allZero(b []byte) bool {
for _, x := range b {
if x != 0 {
return false
}
}
return true
}
// PinnedHTTPClient returns an HTTP client whose Transport rewrites every
// outbound dial for hostname:port to resolvedIP:port, closing the TOCTOU
// window between AssertURLSafe and the actual TCP connection. Pins are
// scoped to this client only.
func PinnedHTTPClient(hostname, resolvedIP string, timeout time.Duration) *http.Client {
dialer := &net.Dialer{
Timeout: timeout,
KeepAlive: 30 * time.Second,
}
transport := &http.Transport{
// Disable environment proxy: HTTP_PROXY / HTTPS_PROXY would route
// the connection through the proxy host instead of the pinned
// resolvedIP, bypassing the SSRF guard.
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, splitErr := net.SplitHostPort(addr)
if splitErr == nil && host == hostname && resolvedIP != "" {
return dialer.DialContext(ctx, network, net.JoinHostPort(resolvedIP, port))
}
return dialer.DialContext(ctx, network, addr)
},
TLSHandshakeTimeout: timeout,
ResponseHeaderTimeout: timeout,
ExpectContinueTimeout: 1 * time.Second,
ForceAttemptHTTP2: false,
}
return &http.Client{
Transport: transport,
Timeout: timeout,
}
}

View File

@@ -0,0 +1,156 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package utility
import (
"strings"
"testing"
)
func TestAssertURLSafe(t *testing.T) {
orig := LookupHost
defer func() { LookupHost = orig }()
type want struct {
errSubstr string
host string
ip string
}
cases := []struct {
name string
url string
ips []string
err string
want want
}{
{
name: "public IPv4",
url: "https://example.com/path",
ips: []string{"93.184.216.34"},
want: want{host: "example.com", ip: "93.184.216.34"},
},
{
name: "loopback rejected",
url: "http://localhost/x",
ips: []string{"127.0.0.1"},
want: want{errSubstr: "non-public address"},
},
{
name: "private 10.x rejected",
url: "http://internal/x",
ips: []string{"10.0.0.5"},
want: want{errSubstr: "non-public address"},
},
{
name: "private 192.168.x rejected",
url: "http://router/x",
ips: []string{"192.168.1.1"},
want: want{errSubstr: "non-public address"},
},
{
name: "CGNAT 100.64/10 rejected",
url: "http://carrier/x",
ips: []string{"100.64.1.1"},
want: want{errSubstr: "non-public address"},
},
{
name: "IPv4-mapped IPv6 loopback rejected",
url: "http://[::ffff:127.0.0.1]/x",
ips: []string{"::ffff:127.0.0.1"},
want: want{errSubstr: "non-public address"},
},
{
name: "link-local IPv6 rejected",
url: "http://[fe80::1]/x",
ips: []string{"fe80::1"},
want: want{errSubstr: "non-public address"},
},
{
name: "documentation 2001:db8 rejected",
url: "http://[2001:db8::1]/x",
ips: []string{"2001:db8::1"},
want: want{errSubstr: "non-public address"},
},
{
name: "disallowed scheme ftp",
url: "ftp://example.com/",
ips: []string{"93.184.216.34"},
want: want{errSubstr: "Disallowed URL scheme"},
},
{
name: "missing host",
url: "http:///path",
want: want{errSubstr: "missing a host"},
},
{
name: "resolution fails",
url: "http://nosuchhost.test/x",
err: "no such host",
want: want{errSubstr: "Could not resolve"},
},
{
name: "all addresses must be public",
url: "http://mixed.example.com/",
ips: []string{"93.184.216.34", "127.0.0.1"},
want: want{errSubstr: "non-public address"},
},
{
name: "literal IPv4 loopback rejected",
url: "http://127.0.0.1/",
ips: []string{"127.0.0.1"},
want: want{errSubstr: "non-public address"},
},
{
name: "documentation TEST-NET-3 rejected",
url: "http://stub/",
ips: []string{"203.0.113.5"},
want: want{errSubstr: "non-public address"},
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
LookupHost = func(host string) ([]string, error) {
if tc.err != "" {
return nil, &mockErr{tc.err}
}
return tc.ips, nil
}
host, ip, err := AssertURLSafe(tc.url)
if tc.want.errSubstr != "" {
if err == nil || !strings.Contains(err.Error(), tc.want.errSubstr) {
t.Fatalf("expected error containing %q, got %v", tc.want.errSubstr, err)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if host != tc.want.host {
t.Errorf("host: got %q, want %q", host, tc.want.host)
}
if ip != tc.want.ip {
t.Errorf("ip: got %q, want %q", ip, tc.want.ip)
}
})
}
}
type mockErr struct{ s string }
func (e *mockErr) Error() string { return e.s }