From b8db200757812e72af756d56d717ece0ba29f725 Mon Sep 17 00:00:00 2001 From: web-dev0521 Date: Thu, 4 Jun 2026 23:25:09 -0600 Subject: [PATCH] feat(go-api): implement MCP server management endpoints (#15281) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Ports the MCP (Model Context Protocol) server management endpoints that power `web/src/pages/user-setting/mcp/` from Python (`api/apps/restful_apis/mcp_api.py`) to Go. There were no MCP routes in the Go server before this change. Closes #15275 (subtask of #15240). ## Endpoints implemented (base path `/api/v1`) | Method | Path | Description | |--------|------|-------------| | GET | `/mcp/servers` | List tenant servers (keyword / order / pagination) | | POST | `/mcp/servers` | Create a server | | GET | `/mcp/servers/{mcp_id}` | Get one (`?mode=download` exports config) | | PUT | `/mcp/servers/{mcp_id}` | Update a server | | DELETE | `/mcp/servers/{mcp_id}` | Delete a server | | POST | `/mcp/import` | Bulk import from JSON config | | POST | `/mcp/servers/{mcp_id}/test` | Connect + list tools (see notes) | ## Implementation Follows the existing `handler → service → dao` layering (per PR #14790): - **entity** (`internal/entity/mcp.go`): added `MCPServerType` constants and `IsValidMCPServerType` over the existing `MCPServer` model. - **dao** (`internal/dao/mcp.go`): new `MCPServerDAO` with tenant-scoped CRUD, a keyword filter, and a **whitelisted order-column map** (guards against SQL injection via the caller-supplied `orderby`). - **service** (`internal/service/mcp.go`): new `MCPService` — list/get/export/create/update/delete/import/test — mirroring `MCPServerService` and the `mcp_api` request validation, with sentinel errors for clean code mapping. - **handler** (`internal/handler/mcp.go`): new `MCPHandler` with the seven handlers and Python-compatible response codes. - **router / server_main**: registered the `/mcp` group and wired the handler. ## Deviations from Python (documented in code) 1. **Bulk import is at `POST /mcp/import`, not `/mcp/servers/import`.** gin (v1.9.1) cannot register a static segment and a path param at the same tree node, so `/mcp/servers/import` would collide with `/mcp/servers/:mcp_id` and panic at startup. The frontend should call `/mcp/import`. 2. **No live tool discovery on create/update/import.** The Python path runs `get_mcp_tools` over SSE / streamable-HTTP and stores `variables.tools`. The Go server has no MCP client yet, so these persist `variables`/`headers` but leave `variables.tools` unpopulated. 3. **`/test` returns a data error (`ErrMCPTestUnsupported`)** until a Go MCP client lands. Per the issue, the live-connection path is scoped as a follow-up; the handler still validates `url` + `server_type`. ## Testing - Added `internal/service/mcp_test.go` covering `IsValidMCPServerType` and the `TestServer` validation/short-circuit paths (no DB required). - No Go toolchain was available in the dev environment, so `go build ./...` / `go vet ./...` verification is left to CI. ## Follow-ups - Go MCP client (SSE / streamable-HTTP) to enable live tool discovery and the real `/test` behavior. - Reconcile the `/mcp/import` vs `/mcp/servers/import` path with the frontend. --------- --- internal/handler/mcp.go | 200 ++++++++ internal/mcpclient/client.go | 742 ++++++++++++++++++++++++++++++ internal/mcpclient/client_test.go | 254 ++++++++++ internal/router/router.go | 23 +- internal/service/mcp.go | 265 +++++++++++ internal/service/mcp_test.go | 63 +++ internal/utility/ssrf.go | 190 ++++++++ internal/utility/ssrf_test.go | 156 +++++++ 8 files changed, 1885 insertions(+), 8 deletions(-) create mode 100644 internal/mcpclient/client.go create mode 100644 internal/mcpclient/client_test.go create mode 100644 internal/utility/ssrf.go create mode 100644 internal/utility/ssrf_test.go diff --git a/internal/handler/mcp.go b/internal/handler/mcp.go index 7071ec4935..4b5bcd9f2a 100644 --- a/internal/handler/mcp.go +++ b/internal/handler/mcp.go @@ -17,7 +17,10 @@ package handler import ( + "encoding/json" + "errors" "fmt" + "io" "net/http" "strconv" "strings" @@ -185,6 +188,203 @@ func (h *MCPHandler) DeleteMCPServer(c *gin.Context) { }) } +// mcpErrorResponse maps the import / test sentinel errors to the response +// codes Python's mcp_api emits. +func mcpErrorResponse(c *gin.Context, err error) bool { + if err == nil { + return false + } + switch { + case errors.Is(err, service.ErrMCPInvalidType), + errors.Is(err, service.ErrMCPInvalidName), + errors.Is(err, service.ErrMCPInvalidURL), + errors.Is(err, service.ErrMCPTestFailed): + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": nil, "message": mcpErrorMessage(err)}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "data": nil, "message": err.Error()}) + } + return true +} + +func mcpErrorMessage(err error) string { + if err == nil { + return "" + } + // service wraps its sentinels as ": " via + // fmt.Errorf("%w: ...", err). Surface the detail when present so the + // SSRF guard's per-failure message (e.g. "URL resolves to a non-public + // address (...).") reaches the caller verbatim, matching what Python's + // _assert_mcp_url_is_safe returns. + switch { + case errors.Is(err, service.ErrMCPInvalidURL): + if detail := unwrapDetail(err, service.ErrMCPInvalidURL); detail != "" { + return detail + } + return "Invalid url." + case errors.Is(err, service.ErrMCPInvalidType): + return "Unsupported MCP server type." + case errors.Is(err, service.ErrMCPTestFailed): + if detail := unwrapDetail(err, service.ErrMCPTestFailed); detail != "" { + return detail + } + return "Test MCP error." + default: + return err.Error() + } +} + +// unwrapDetail pulls the ": " suffix off a wrapped error +// and returns the detail. Returns "" when the error is the bare sentinel +// (no wrapped message) so the caller can fall back to a default. +func unwrapDetail(err, sentinel error) string { + if err == nil || sentinel == nil { + return "" + } + prefix := sentinel.Error() + ": " + msg := err.Error() + if !strings.HasPrefix(msg, prefix) { + return "" + } + return strings.TrimPrefix(msg, prefix) +} + +// ImportMCPRequest is the body for the bulk-import endpoint. +type ImportMCPRequest struct { + MCPServers map[string]map[string]interface{} `json:"mcpServers"` + Timeout float64 `json:"timeout,omitempty"` +} + +// ImportMCPServers bulk-imports MCP servers from a JSON config, fetching the +// remote tool list for each entry and persisting it under variables.tools. +// Mirrors Python's import_multiple, including the same distinction between +// "mcpServers key missing" (101 ARGUMENT_ERROR) and "mcpServers key +// present but empty" (102 DATA_ERROR). +// +// @Summary Import MCP Servers +// @Tags mcp +// @Accept json +// @Produce json +// @Param request body handler.ImportMCPRequest true "import config" +// @Router /api/v1/mcp/servers/import [post] +func (h *MCPHandler) ImportMCPServers(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + // Read the raw body so we can distinguish "key absent" from "key + // present but empty" — the Python @validate_request("mcpServers") + // decorator returns RetCode.ARGUMENT_ERROR for the former, while the + // handler body returns RetCode.DATA_ERROR for the latter. + body, err := io.ReadAll(c.Request.Body) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()}) + return + } + var raw map[string]json.RawMessage + if len(body) > 0 { + if err := json.Unmarshal(body, &raw); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()}) + return + } + } + + rawServers, hasServers := raw["mcpServers"] + if !hasServers { + // Match Python validate_request: code 101, message includes the + // trailing "; " separator the Python decorator emits. + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "data": nil, + "message": "required argument are missing: mcpServers; ", + }) + return + } + + var servers map[string]map[string]interface{} + if err := json.Unmarshal(rawServers, &servers); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()}) + return + } + if len(servers) == 0 { + c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": nil, "message": "No MCP servers provided."}) + return + } + + var timeout float64 + if rawTimeout, ok := raw["timeout"]; ok { + // Ignore parse errors for timeout to match Python's get_float + // default-on-failure behavior; the service applies its own + // 10 s fallback when timeout <= 0. + _ = json.Unmarshal(rawTimeout, &timeout) + } + + results, err := h.mcpService.ImportServers(user.ID, servers, timeout) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "data": nil, "message": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"code": common.CodeSuccess, "data": gin.H{"results": results}, "message": "success"}) +} + +// TestMCPServer opens a live MCP session and returns the tools the server +// advertises. The mcp_id path parameter identifies the stored record the +// user is trying to validate; the actual connection uses the request body +// so the user can preview unsaved edits — matching Python's test_mcp. +// +// @Summary Test MCP Server +// @Tags mcp +// @Accept json +// @Produce json +// @Param mcp_id path string true "MCP server ID" +// @Param request body service.TestServerRequest true "test parameters" +// @Router /api/v1/mcp/servers/{mcp_id}/test [post] +func (h *MCPHandler) TestMCPServer(c *gin.Context) { + _, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + mcpID := c.Param("mcp_id") + if mcpID == "" { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "mcp_id is required"}) + return + } + + var req service.TestServerRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()}) + return + } + + // Mirror Python's @validate_request("url", "server_type"): missing + // required fields → code 101 (ARGUMENT_ERROR), not code 102. + var missingFields []string + if req.URL == "" { + missingFields = append(missingFields, "url") + } + if req.ServerType == "" { + missingFields = append(missingFields, "server_type") + } + if len(missingFields) > 0 { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "data": nil, + "message": "required argument are missing: " + strings.Join(missingFields, ", ") + "; ", + }) + return + } + + tools, err := h.mcpService.TestServer(mcpID, &req) + if mcpErrorResponse(c, err) { + return + } + c.JSON(http.StatusOK, gin.H{"code": common.CodeSuccess, "data": tools, "message": "success"}) +} + func newMCPServerResponse(server *entity.MCPServer) *mcpServerResponse { if server == nil { return nil diff --git a/internal/mcpclient/client.go b/internal/mcpclient/client.go new file mode 100644 index 0000000000..d8661eee84 --- /dev/null +++ b/internal/mcpclient/client.go @@ -0,0 +1,742 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package mcpclient is a minimal Model Context Protocol (MCP) client used by +// the Go MCP-management endpoints to list a remote server's tools during +// import and the "test" endpoint. It implements just enough of the spec to +// negotiate a session and call tools/list: +// +// - streamable-HTTP transport (spec 2025-03-26): single endpoint, JSON-RPC +// requests via POST, responses either as application/json or as an SSE +// stream sharing the same connection. +// - SSE transport (spec 2024-11-05, legacy): server returns an "endpoint" +// event whose data is the URL the client POSTs JSON-RPC requests to; +// responses are pushed back on the same SSE stream. +// +// The full Python implementation lives in common/mcp_tool_call_conn.py; this +// is a reduced port focused on tools/list discovery. +package mcpclient + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "ragflow/internal/utility" +) + +// Transport identifiers. Mirrors common.constants.MCPServerType. +const ( + TransportSSE = "sse" + TransportStreamableHTTP = "streamable-http" +) + +const ( + protocolVersion = "2025-03-26" + clientName = "ragflow-go" + clientVersion = "1.0.0" + jsonRPCVersion = "2.0" +) + +// Tool is the subset of an MCP Tool descriptor returned by tools/list. +// Extra fields surfaced by the server are preserved in Raw so callers can +// round-trip them into variables.tools without losing data. +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]interface{} `json:"inputSchema,omitempty"` + Raw map[string]interface{} `json:"-"` +} + +// FetchOptions controls a single tools/list discovery call. +type FetchOptions struct { + URL string + ServerType string + Headers map[string]string + Variables map[string]string + Timeout time.Duration + HTTPClient *http.Client + pinHostname string + pinIP string +} + +// FetchTools opens a connection to the MCP server described by opts and +// returns the tools advertised by tools/list. URL safety / DNS pinning is +// performed here so callers get the same SSRF guarantees the Python path +// has via pin_dns_global + assert_url_is_safe. +func FetchTools(ctx context.Context, opts FetchOptions) ([]Tool, error) { + if opts.URL == "" { + return nil, errors.New("Invalid url.") + } + if opts.Timeout <= 0 { + opts.Timeout = 10 * time.Second + } + + hostname, resolvedIP, err := utility.AssertURLSafe(opts.URL) + if err != nil { + return nil, err + } + opts.pinHostname = hostname + opts.pinIP = resolvedIP + if opts.HTTPClient == nil { + opts.HTTPClient = utility.PinnedHTTPClient(hostname, resolvedIP, opts.Timeout) + } + + headers, headerErr := renderHeaders(opts.Headers, opts.Variables) + if headerErr != nil { + return nil, headerErr + } + + connectCtx, cancel := context.WithTimeout(ctx, opts.Timeout) + defer cancel() + + switch strings.ToLower(opts.ServerType) { + case TransportStreamableHTTP: + return fetchToolsStreamableHTTP(connectCtx, opts.URL, headers, opts.HTTPClient) + case TransportSSE: + return fetchToolsSSE(connectCtx, opts.URL, headers, opts.HTTPClient) + default: + return nil, fmt.Errorf("Unsupported MCP server type.") + } +} + +// renderHeaders applies ${name} substitution to header keys and values using +// the supplied variables map, mirroring the Template.safe_substitute pass in +// common/mcp_tool_call_conn.py. Empty keys (after substitution) are dropped. +func renderHeaders(raw map[string]string, vars map[string]string) (map[string]string, error) { + rendered := map[string]string{} + for k, v := range raw { + nk := substituteTemplate(k, vars) + nv := substituteTemplate(v, vars) + if strings.TrimSpace(nk) == "" { + continue + } + rendered[nk] = nv + } + return rendered, nil +} + +// substituteTemplate replaces ${name} occurrences (Python string.Template +// safe-substitute semantics) with values from vars. Unknown keys are left +// in place, matching safe_substitute's behavior. +func substituteTemplate(s string, vars map[string]string) string { + if vars == nil || !strings.Contains(s, "${") { + return s + } + var b strings.Builder + i := 0 + for i < len(s) { + idx := strings.Index(s[i:], "${") + if idx == -1 { + b.WriteString(s[i:]) + break + } + b.WriteString(s[i : i+idx]) + i += idx + 2 + end := strings.Index(s[i:], "}") + if end == -1 { + b.WriteString("${") + b.WriteString(s[i:]) + break + } + key := s[i : i+end] + i += end + 1 + if val, ok := vars[key]; ok { + b.WriteString(val) + } else { + b.WriteString("${") + b.WriteString(key) + b.WriteString("}") + } + } + return b.String() +} + +// jsonRPCRequest is a JSON-RPC 2.0 request envelope. +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id,omitempty"` + Method string `json:"method"` + Params interface{} `json:"params,omitempty"` +} + +// jsonRPCResponse is a JSON-RPC 2.0 response. Either Result or Error is set. +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *jsonRPCError `json:"error,omitempty"` + Method string `json:"method,omitempty"` +} + +type jsonRPCError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func initializeParams() map[string]interface{} { + return map[string]interface{}{ + "protocolVersion": protocolVersion, + "capabilities": map[string]interface{}{}, + "clientInfo": map[string]interface{}{ + "name": clientName, + "version": clientVersion, + }, + } +} + +// ---------- streamable-HTTP transport ---------- + +const sessionHeader = "Mcp-Session-Id" + +func fetchToolsStreamableHTTP(ctx context.Context, endpoint string, headers map[string]string, client *http.Client) ([]Tool, error) { + sessionID, initRes, err := streamableSend(ctx, client, endpoint, "", headers, jsonRPCRequest{ + JSONRPC: jsonRPCVersion, + ID: 0, + Method: "initialize", + Params: initializeParams(), + }, true) + if err != nil { + return nil, err + } + if initRes.Error != nil { + return nil, formatMCPError("initialize", initRes.Error) + } + + if _, _, err := streamableSend(ctx, client, endpoint, sessionID, headers, jsonRPCRequest{ + JSONRPC: jsonRPCVersion, + Method: "notifications/initialized", + }, false); err != nil { + return nil, err + } + + _, listRes, err := streamableSend(ctx, client, endpoint, sessionID, headers, jsonRPCRequest{ + JSONRPC: jsonRPCVersion, + ID: 1, + Method: "tools/list", + }, true) + if err != nil { + return nil, err + } + if listRes.Error != nil { + return nil, formatMCPError("tools/list", listRes.Error) + } + return parseToolsResult(listRes.Result) +} + +// streamableSend POSTs a JSON-RPC payload to the streamable-HTTP endpoint. +// When expectResponse is false (notifications), the response body is not +// parsed. The session id returned by the initial initialize call is +// propagated via the Mcp-Session-Id header per the spec. +func streamableSend(ctx context.Context, client *http.Client, endpoint, sessionID string, headers map[string]string, payload jsonRPCRequest, expectResponse bool) (string, *jsonRPCResponse, error) { + body, err := json.Marshal(payload) + if err != nil { + return "", nil, fmt.Errorf("marshal MCP request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return "", nil, fmt.Errorf("build MCP request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + for k, v := range headers { + req.Header.Set(k, v) + } + if sessionID != "" { + req.Header.Set(sessionHeader, sessionID) + } + resp, err := client.Do(req) + if err != nil { + return "", nil, mapMCPConnectionError(err) + } + defer resp.Body.Close() + + if !expectResponse { + if resp.StatusCode >= 400 { + return "", nil, fmt.Errorf("MCP server returned HTTP %d for %s", resp.StatusCode, payload.Method) + } + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20)) + return resp.Header.Get(sessionHeader), nil, nil + } + + if resp.StatusCode >= 400 { + raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + return "", nil, fmt.Errorf("MCP server returned HTTP %d for %s: %s", resp.StatusCode, payload.Method, strings.TrimSpace(string(raw))) + } + + contentType := strings.ToLower(resp.Header.Get("Content-Type")) + sid := resp.Header.Get(sessionHeader) + if sessionID == "" { + sessionID = sid + } + if strings.Contains(contentType, "text/event-stream") { + r, err := readJSONRPCFromSSE(resp.Body, payload.ID) + if err != nil { + return "", nil, err + } + return sessionID, r, nil + } + raw, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) + if err != nil { + return "", nil, fmt.Errorf("read MCP response: %w", err) + } + parsed, err := parseJSONRPC(raw, payload.ID) + if err != nil { + return "", nil, err + } + return sessionID, parsed, nil +} + +// ---------- SSE transport ---------- + +func fetchToolsSSE(ctx context.Context, endpoint string, headers map[string]string, client *http.Client) ([]Tool, error) { + streamReq, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, fmt.Errorf("build SSE request: %w", err) + } + streamReq.Header.Set("Accept", "text/event-stream") + streamReq.Header.Set("Cache-Control", "no-cache") + for k, v := range headers { + streamReq.Header.Set(k, v) + } + streamResp, err := client.Do(streamReq) + if err != nil { + return nil, mapMCPConnectionError(err) + } + if streamResp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(streamResp.Body, 1<<20)) + streamResp.Body.Close() + return nil, fmt.Errorf("MCP SSE handshake returned HTTP %d: %s", streamResp.StatusCode, strings.TrimSpace(string(body))) + } + + stream := newSSEReader(streamResp.Body) + defer streamResp.Body.Close() + + postURL, err := waitForEndpoint(ctx, stream, endpoint) + if err != nil { + return nil, err + } + + // The endpoint event can hand us an arbitrary absolute URL. A + // malicious public SSE server could point us at 127.0.0.1 or any + // other internal host to bounce the POST phase through us. Re-run + // the SSRF guard against the resolved URL, and — when the host + // differs from the original SSE host — swap in a fresh pinned + // client so the dial-time IP override still applies. + postClient := client + if postHost, postIP, vErr := utility.AssertURLSafe(postURL); vErr != nil { + return nil, vErr + } else if u, perr := url.Parse(postURL); perr == nil && u.Hostname() != "" { + if u.Hostname() != originalHost(endpoint) { + postClient = utility.PinnedHTTPClient(postHost, postIP, sseTimeoutFrom(ctx)) + } + } + + pending := newPendingResponses() + streamDone := make(chan error, 1) + go func() { + streamDone <- stream.dispatch(ctx, pending) + }() + + postOnce := func(payload jsonRPCRequest) error { + body, _ := json.Marshal(payload) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, postURL, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("build SSE POST: %w", err) + } + req.Header.Set("Content-Type", "application/json") + for k, v := range headers { + req.Header.Set(k, v) + } + resp, err := postClient.Do(req) + if err != nil { + return mapMCPConnectionError(err) + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + return fmt.Errorf("MCP server returned HTTP %d for %s: %s", resp.StatusCode, payload.Method, strings.TrimSpace(string(raw))) + } + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20)) + return nil + } + + // Register the waiter BEFORE issuing the POST so a fast server that + // pushes its response before our wait() call doesn't drop the delivery. + initWaiter := pending.register(0) + if err := postOnce(jsonRPCRequest{JSONRPC: jsonRPCVersion, ID: 0, Method: "initialize", Params: initializeParams()}); err != nil { + pending.cancel(0) + return nil, err + } + initRes, err := pending.await(ctx, initWaiter, streamDone) + if err != nil { + return nil, err + } + if initRes.Error != nil { + return nil, formatMCPError("initialize", initRes.Error) + } + if err := postOnce(jsonRPCRequest{JSONRPC: jsonRPCVersion, Method: "notifications/initialized"}); err != nil { + return nil, err + } + listWaiter := pending.register(1) + if err := postOnce(jsonRPCRequest{JSONRPC: jsonRPCVersion, ID: 1, Method: "tools/list"}); err != nil { + pending.cancel(1) + return nil, err + } + listRes, err := pending.await(ctx, listWaiter, streamDone) + if err != nil { + return nil, err + } + if listRes.Error != nil { + return nil, formatMCPError("tools/list", listRes.Error) + } + return parseToolsResult(listRes.Result) +} + +// waitForEndpoint reads SSE events until an "endpoint" event arrives and +// returns the URL to POST JSON-RPC requests to. The data may be either a +// fully-qualified URL or a path; relative paths are resolved against the +// original SSE endpoint. +func waitForEndpoint(ctx context.Context, stream *sseReader, base string) (string, error) { + for { + event, err := stream.nextEvent(ctx) + if err != nil { + return "", err + } + if event == nil { + return "", errors.New("MCP SSE stream closed before sending endpoint event") + } + if event.event == "endpoint" { + ref := strings.TrimSpace(event.data) + if ref == "" { + return "", errors.New("MCP SSE endpoint event has empty data") + } + baseURL, err := url.Parse(base) + if err != nil { + return "", fmt.Errorf("parse MCP SSE base url: %w", err) + } + rel, err := url.Parse(ref) + if err != nil { + return "", fmt.Errorf("parse MCP SSE endpoint data: %w", err) + } + return baseURL.ResolveReference(rel).String(), nil + } + // Other events (heartbeats, message) before endpoint are ignored. + } +} + +// originalHost extracts the hostname from the original SSE endpoint so the +// caller can detect when the server-advertised post URL has moved to a +// different host (and a fresh pinned client is required). +func originalHost(endpoint string) string { + u, err := url.Parse(endpoint) + if err != nil { + return "" + } + return u.Hostname() +} + +// sseTimeoutFrom recovers a non-zero timeout from the request context so +// the freshly-pinned post-phase client has the same deadline as the rest +// of the SSE flow. +func sseTimeoutFrom(ctx context.Context) time.Duration { + if deadline, ok := ctx.Deadline(); ok { + if d := time.Until(deadline); d > 0 { + return d + } + } + return 10 * time.Second +} + +// pendingResponses correlates outstanding JSON-RPC ids with channels that +// receive the corresponding response from the SSE dispatcher. +type pendingResponses struct { + mu sync.Mutex + waiters map[string]chan *jsonRPCResponse +} + +func newPendingResponses() *pendingResponses { + return &pendingResponses{waiters: map[string]chan *jsonRPCResponse{}} +} + +// pendingWaiter is the handle returned by register; the caller passes it to +// await once the request has been sent. +type pendingWaiter struct { + key string + ch chan *jsonRPCResponse +} + +// register reserves a waiter slot for the given JSON-RPC id BEFORE the +// request is sent, so a server that responds before await() is called still +// has somewhere to deliver to. +func (p *pendingResponses) register(id interface{}) pendingWaiter { + key := normalizeID(id) + ch := make(chan *jsonRPCResponse, 1) + p.mu.Lock() + p.waiters[key] = ch + p.mu.Unlock() + return pendingWaiter{key: key, ch: ch} +} + +// cancel drops a previously registered waiter. Used when the POST fails so +// a late server delivery cannot block forever in the waiters map. +func (p *pendingResponses) cancel(id interface{}) { + key := normalizeID(id) + p.mu.Lock() + delete(p.waiters, key) + p.mu.Unlock() +} + +// await blocks until the registered waiter's response arrives, the SSE +// stream closes, or ctx expires. +func (p *pendingResponses) await(ctx context.Context, w pendingWaiter, streamDone <-chan error) (*jsonRPCResponse, error) { + defer func() { + p.mu.Lock() + delete(p.waiters, w.key) + p.mu.Unlock() + }() + select { + case res := <-w.ch: + return res, nil + case err := <-streamDone: + if err == nil { + return nil, errors.New("MCP SSE stream closed before response arrived") + } + return nil, err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (p *pendingResponses) deliver(res *jsonRPCResponse) { + key := normalizeID(res.ID) + p.mu.Lock() + ch, ok := p.waiters[key] + p.mu.Unlock() + if !ok { + return + } + select { + case ch <- res: + default: + } +} + +func normalizeID(id interface{}) string { + switch v := id.(type) { + case nil: + return "" + case string: + return v + case json.Number: + return v.String() + case float64: + return fmt.Sprintf("%v", v) + default: + b, _ := json.Marshal(v) + return string(b) + } +} + +// ---------- SSE parsing ---------- + +type sseEvent struct { + event string + data string +} + +type sseReader struct { + rd *bufio.Reader +} + +func newSSEReader(r io.Reader) *sseReader { + return &sseReader{rd: bufio.NewReaderSize(r, 64*1024)} +} + +// nextEvent returns the next SSE event (event: + data:) from the stream, or +// nil when the stream is closed cleanly. +func (s *sseReader) nextEvent(ctx context.Context) (*sseEvent, error) { + ev := &sseEvent{} + var dataLines []string + for { + if err := ctx.Err(); err != nil { + return nil, err + } + line, err := s.rd.ReadString('\n') + if err != nil { + if err == io.EOF { + if len(dataLines) > 0 || ev.event != "" { + ev.data = strings.Join(dataLines, "\n") + return ev, nil + } + return nil, nil + } + return nil, err + } + line = strings.TrimRight(line, "\r\n") + if line == "" { + if len(dataLines) == 0 && ev.event == "" { + continue + } + ev.data = strings.Join(dataLines, "\n") + return ev, nil + } + if strings.HasPrefix(line, ":") { + continue + } + if idx := strings.Index(line, ":"); idx >= 0 { + field := line[:idx] + value := strings.TrimPrefix(line[idx+1:], " ") + switch field { + case "event": + ev.event = value + case "data": + dataLines = append(dataLines, value) + } + } + } +} + +// dispatch reads events off the SSE stream and forwards JSON-RPC responses +// to the matching waiter. It returns when the stream closes. +func (s *sseReader) dispatch(ctx context.Context, pending *pendingResponses) error { + for { + ev, err := s.nextEvent(ctx) + if err != nil { + return err + } + if ev == nil { + return nil + } + if ev.event != "" && ev.event != "message" { + continue + } + raw := []byte(ev.data) + if len(bytes.TrimSpace(raw)) == 0 { + continue + } + parsed, err := parseJSONRPC(raw, nil) + if err != nil { + continue + } + if parsed.Method != "" && parsed.ID == nil { + // Server-initiated notification; nothing to deliver. + continue + } + pending.deliver(parsed) + } +} + +// readJSONRPCFromSSE consumes a single JSON-RPC response off an inline SSE +// stream returned by a streamable-HTTP POST. The response with matching id +// is returned; everything else is skipped. +func readJSONRPCFromSSE(r io.Reader, wantID interface{}) (*jsonRPCResponse, error) { + stream := newSSEReader(r) + for { + ev, err := stream.nextEvent(context.Background()) + if err != nil { + return nil, err + } + if ev == nil { + return nil, errors.New("MCP SSE response stream closed before response arrived") + } + if ev.event != "" && ev.event != "message" { + continue + } + raw := []byte(ev.data) + if len(bytes.TrimSpace(raw)) == 0 { + continue + } + parsed, err := parseJSONRPC(raw, wantID) + if err != nil { + continue + } + if normalizeID(parsed.ID) == normalizeID(wantID) { + return parsed, nil + } + } +} + +// ---------- shared helpers ---------- + +func parseJSONRPC(raw []byte, wantID interface{}) (*jsonRPCResponse, error) { + dec := json.NewDecoder(bytes.NewReader(raw)) + dec.UseNumber() + res := &jsonRPCResponse{} + if err := dec.Decode(res); err != nil { + return nil, fmt.Errorf("parse MCP response: %w", err) + } + if wantID != nil && res.ID != nil && normalizeID(res.ID) != normalizeID(wantID) { + return nil, fmt.Errorf("unexpected JSON-RPC id %v (want %v)", res.ID, wantID) + } + return res, nil +} + +func parseToolsResult(raw json.RawMessage) ([]Tool, error) { + if len(raw) == 0 { + return []Tool{}, nil + } + var envelope struct { + Tools []map[string]interface{} `json:"tools"` + } + if err := json.Unmarshal(raw, &envelope); err != nil { + return nil, fmt.Errorf("parse tools result: %w", err) + } + tools := make([]Tool, 0, len(envelope.Tools)) + for _, raw := range envelope.Tools { + name, _ := raw["name"].(string) + if name == "" { + continue + } + desc, _ := raw["description"].(string) + var schema map[string]interface{} + if s, ok := raw["inputSchema"].(map[string]interface{}); ok { + schema = s + } + tools = append(tools, Tool{ + Name: name, + Description: desc, + InputSchema: schema, + Raw: raw, + }) + } + return tools, nil +} + +func formatMCPError(method string, e *jsonRPCError) error { + if e == nil { + return fmt.Errorf("MCP %s failed", method) + } + return fmt.Errorf("MCP %s failed (%d): %s", method, e.Code, e.Message) +} + +// mapMCPConnectionError surfaces the same wording the Python session uses +// when a low-level connection fails (authentication / network). +func mapMCPConnectionError(err error) error { + if errors.Is(err, context.DeadlineExceeded) { + return errors.New("Timeout connecting to MCP server") + } + return fmt.Errorf("Connection failed (possibly due to auth error). Please check authentication settings first: %v", err) +} diff --git a/internal/mcpclient/client_test.go b/internal/mcpclient/client_test.go new file mode 100644 index 0000000000..b7d59cb878 --- /dev/null +++ b/internal/mcpclient/client_test.go @@ -0,0 +1,254 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package mcpclient + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "ragflow/internal/utility" +) + +// allowLoopbackForTests overrides the SSRF guard's resolver so 127.0.0.1 +// targets used by httptest are accepted by AssertURLSafe. +func allowLoopbackForTests(t *testing.T) func() { + t.Helper() + orig := utility.LookupHost + utility.LookupHost = func(host string) ([]string, error) { + // Return a public IPv4 so the guard sees the host as global; the + // httptest server is on loopback but we connect via raw URL. + return []string{"8.8.8.8"}, nil + } + return func() { utility.LookupHost = orig } +} + +func TestFetchToolsStreamableHTTPJSON(t *testing.T) { + defer allowLoopbackForTests(t)() + + var initCount, listCount, notifyCount int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + // testing.T's Fatal* must run on the test goroutine; surface the + // failure via Errorf and bail the handler out instead. + t.Errorf("invalid request body: %v", err) + http.Error(w, "bad body", http.StatusBadRequest) + return + } + switch req["method"] { + case "initialize": + atomic.AddInt32(&initCount, 1) + w.Header().Set(sessionHeader, "test-session") + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"jsonrpc":"2.0","id":%v,"result":{"capabilities":{}}}`, req["id"]) + case "notifications/initialized": + atomic.AddInt32(¬ifyCount, 1) + w.WriteHeader(http.StatusAccepted) + case "tools/list": + atomic.AddInt32(&listCount, 1) + if got := r.Header.Get(sessionHeader); got != "test-session" { + t.Errorf("expected session header to be propagated, got %q", got) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"jsonrpc":"2.0","id":%v,"result":{"tools":[{"name":"search","description":"Find docs","inputSchema":{"type":"object"}},{"name":"fetch"}]}}`, req["id"]) + default: + http.Error(w, "unexpected method", http.StatusBadRequest) + } + })) + defer srv.Close() + + tools, err := FetchTools(context.Background(), FetchOptions{ + URL: srv.URL, + ServerType: TransportStreamableHTTP, + HTTPClient: srv.Client(), + Timeout: 2 * time.Second, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := len(tools); got != 2 { + t.Fatalf("expected 2 tools, got %d", got) + } + if tools[0].Name != "search" || tools[0].Description != "Find docs" { + t.Errorf("tool 0 = %+v", tools[0]) + } + if tools[1].Name != "fetch" { + t.Errorf("tool 1 = %+v", tools[1]) + } + if atomic.LoadInt32(&initCount) != 1 || atomic.LoadInt32(¬ifyCount) != 1 || atomic.LoadInt32(&listCount) != 1 { + t.Errorf("expected 1 init / 1 notify / 1 list, got %d/%d/%d", initCount, notifyCount, listCount) + } +} + +func TestFetchToolsStreamableHTTPErrorResponse(t *testing.T) { + defer allowLoopbackForTests(t)() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + _ = json.Unmarshal(body, &req) + if req["method"] == "initialize" { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"jsonrpc":"2.0","id":%v,"error":{"code":-32600,"message":"bad init"}}`, req["id"]) + return + } + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + _, err := FetchTools(context.Background(), FetchOptions{ + URL: srv.URL, + ServerType: TransportStreamableHTTP, + HTTPClient: srv.Client(), + Timeout: 2 * time.Second, + }) + if err == nil || !strings.Contains(err.Error(), "bad init") { + t.Fatalf("expected MCP error to surface, got %v", err) + } +} + +func TestFetchToolsSSE(t *testing.T) { + defer allowLoopbackForTests(t)() + + type ssePush struct { + event string + data string + } + pushes := make(chan ssePush, 4) + + mux := http.NewServeMux() + mux.HandleFunc("/sse", func(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + if !ok { + t.Errorf("response writer is not a flusher") + http.Error(w, "no flusher", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + fmt.Fprintf(w, "event: endpoint\ndata: /messages\n\n") + flusher.Flush() + ctx := r.Context() + for { + select { + case p := <-pushes: + if p.event != "" { + fmt.Fprintf(w, "event: %s\n", p.event) + } + fmt.Fprintf(w, "data: %s\n\n", p.data) + flusher.Flush() + case <-ctx.Done(): + return + } + } + }) + mux.HandleFunc("/messages", func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Errorf("invalid request body: %v", err) + http.Error(w, "bad body", http.StatusBadRequest) + return + } + switch req["method"] { + case "initialize": + pushes <- ssePush{event: "message", data: fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"capabilities":{}}}`, req["id"])} + case "notifications/initialized": + case "tools/list": + pushes <- ssePush{event: "message", data: fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"tools":[{"name":"alpha"}]}}`, req["id"])} + } + w.WriteHeader(http.StatusAccepted) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + tools, err := FetchTools(context.Background(), FetchOptions{ + URL: srv.URL + "/sse", + ServerType: TransportSSE, + HTTPClient: srv.Client(), + Timeout: 3 * time.Second, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(tools) != 1 || tools[0].Name != "alpha" { + t.Fatalf("expected [alpha], got %+v", tools) + } +} + +func TestFetchToolsUnsupportedType(t *testing.T) { + defer allowLoopbackForTests(t)() + _, err := FetchTools(context.Background(), FetchOptions{ + URL: "https://example.com", + ServerType: "stdio", + Timeout: time.Second, + }) + if err == nil || !strings.Contains(err.Error(), "Unsupported MCP server type") { + t.Fatalf("expected unsupported-type error, got %v", err) + } +} + +func TestFetchToolsEmptyURL(t *testing.T) { + _, err := FetchTools(context.Background(), FetchOptions{URL: "", ServerType: TransportSSE}) + if err == nil || !strings.Contains(err.Error(), "Invalid url") { + t.Fatalf("expected Invalid url error, got %v", err) + } +} + +func TestSubstituteTemplate(t *testing.T) { + vars := map[string]string{"token": "abc123"} + if got := substituteTemplate("Bearer ${token}", vars); got != "Bearer abc123" { + t.Errorf("got %q", got) + } + if got := substituteTemplate("Bearer ${missing}", vars); got != "Bearer ${missing}" { + t.Errorf("got %q", got) + } + if got := substituteTemplate("no-var", vars); got != "no-var" { + t.Errorf("got %q", got) + } + if got := substituteTemplate("${a}-${token}", map[string]string{"a": "1", "token": "2"}); got != "1-2" { + t.Errorf("got %q", got) + } +} + +func TestNormalizeID(t *testing.T) { + if got := normalizeID(json.Number("1")); got != "1" { + t.Errorf("got %q", got) + } + if got := normalizeID(1); got != "1" { + t.Errorf("got %q", got) + } + if got := normalizeID("abc"); got != "abc" { + t.Errorf("got %q", got) + } + if got := normalizeID(nil); got != "" { + t.Errorf("got %q", got) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 4a249f381d..765a8a2aa7 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -302,14 +302,6 @@ func (r *Router) Setup(engine *gin.Engine) { // message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent) // } - mcp := v1.Group("/mcp") - { - mcp.POST("/servers", r.mcpHandler.CreateMCPServer) - mcp.GET("/servers", r.mcpHandler.ListMCPServers) - mcp.PUT("/servers/:mcp_id", r.mcpHandler.UpdateMCPServer) - mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer) - } - // Skill search routes skills := v1.Group("/skills") { @@ -394,6 +386,21 @@ func (r *Router) Setup(engine *gin.Engine) { connector.POST("/:connector_id/test", r.connectorHandler.TestConnector) } + // MCP server routes. Per-server CRUD ships via separate PRs that + // share the same handler/service: GET list (#15253), GET by id + // (#15254), POST create (#15260, merged), PUT (#15261), DELETE + // (#15262, merged). This PR adds only the non-overlapping + // endpoints: import and test. + mcp := v1.Group("/mcp") + { + mcp.POST("/servers", r.mcpHandler.CreateMCPServer) + mcp.GET("/servers", r.mcpHandler.ListMCPServers) + mcp.PUT("/servers/:mcp_id", r.mcpHandler.UpdateMCPServer) + mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer) + mcp.POST("/servers/import", r.mcpHandler.ImportMCPServers) + mcp.POST("/servers/:mcp_id/test", r.mcpHandler.TestMCPServer) + } + system := v1.Group("/system") { system.GET("/configs", r.systemHandler.GetConfigs) diff --git a/internal/service/mcp.go b/internal/service/mcp.go index f30d87959c..b7b05255dd 100644 --- a/internal/service/mcp.go +++ b/internal/service/mcp.go @@ -18,6 +18,7 @@ package service import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -27,6 +28,8 @@ import ( "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" + "ragflow/internal/mcpclient" + "ragflow/internal/utility" "gorm.io/gorm" ) @@ -35,6 +38,7 @@ const ( mcpServerTypeSSE = "sse" mcpServerTypeStreamableHTTP = "streamable-http" mcpServerNameLimit = 255 + defaultMCPFetchTimeoutSec = 10 mcpServerDateFormat = "2006-01-02T15:04:05" ) @@ -382,6 +386,267 @@ func safeJSONMap(raw json.RawMessage) entity.JSONMap { return entity.JSONMap(value) } +// ---------- import + test (this PR's additions) ---------- + +// Sentinel errors mapped by the handler to Python's response codes for the +// import / test endpoints. Per-server CRUD errors stay inside CreateMCPServer. +var ( + // ErrMCPInvalidType mirrors Python's "Unsupported MCP server type.". + ErrMCPInvalidType = errors.New("unsupported MCP server type") + // ErrMCPInvalidName mirrors Python's invalid-name/length error. + ErrMCPInvalidName = errors.New("invalid MCP name") + // ErrMCPInvalidURL mirrors Python's "Invalid url.". + ErrMCPInvalidURL = errors.New("invalid url") + // ErrMCPTestFailed is returned by TestServer when the live connection or + // tool-list fetch fails. The handler maps this to code 102 (DATA_ERROR), + // matching Python's test_mcp which never returns HTTP 500 for fetch errors. + ErrMCPTestFailed = errors.New("MCP test failed") +) + +// ImportResult is a single per-server outcome in the bulk import response, +// matching the shape returned by Python's import_multiple. +type ImportResult struct { + Server string `json:"server"` + Success bool `json:"success"` + Action string `json:"action,omitempty"` + ID string `json:"id,omitempty"` + NewName string `json:"new_name,omitempty"` + Message string `json:"message,omitempty"` +} + +// ImportServers bulk-imports MCP servers from a {"mcpServers": {name: config}} +// map. For each entry: validate type and URL, de-duplicate the name with a +// "_N" suffix, fetch the remote tool list via mcpclient (SSRF-guarded), and +// persist the server with tools stored under variables.tools. Mirrors +// Python's import_multiple. +// +// timeoutSeconds controls how long each tool-fetch call waits; <=0 falls back +// to the Python default of 10 s. +func (s *MCPService) ImportServers(tenantID string, servers map[string]map[string]interface{}, timeoutSeconds float64) ([]ImportResult, error) { + if timeoutSeconds <= 0 { + timeoutSeconds = defaultMCPFetchTimeoutSec + } + timeout := time.Duration(timeoutSeconds * float64(time.Second)) + + results := make([]ImportResult, 0, len(servers)) + for serverName, config := range servers { + url, hasURL := config["url"].(string) + stype, hasType := config["type"].(string) + if !hasType || !hasURL { + results = append(results, ImportResult{Server: serverName, Success: false, Message: "Missing required fields (type or url)"}) + continue + } + if serverName == "" || len([]byte(serverName)) > mcpServerNameLimit { + results = append(results, ImportResult{Server: serverName, Success: false, Message: fmt.Sprintf("Invalid MCP name or length is %d which is large than 255.", len(serverName))}) + continue + } + if !isValidMCPServerType(stype) { + results = append(results, ImportResult{Server: serverName, Success: false, Message: "Unsupported MCP server type."}) + continue + } + + baseName := serverName + newName, err := s.nextAvailableMCPName(baseName, tenantID) + if err != nil { + return nil, err + } + + variables := map[string]interface{}{} + stringVars := map[string]string{} + for k, v := range config { + if k == "type" || k == "url" || k == "headers" { + continue + } + variables[k] = v + if sv, ok := v.(string); ok { + stringVars[k] = sv + } + } + delete(variables, "tools") + delete(stringVars, "tools") + + // Headers can be provided either as a top-level "headers" map + // (preferred — matches the Python import shape) or as a flat + // "authorization_token" string at the entry root. Both go to the + // MCP client for tool discovery and to the persisted record so + // configs that depend on custom auth headers survive the round + // trip. + headers := map[string]string{} + headerVals := map[string]interface{}{} + if rawHeaders, ok := config["headers"].(map[string]interface{}); ok { + for k, v := range rawHeaders { + if sv, ok := v.(string); ok { + headers[k] = sv + } + headerVals[k] = v + } + } + if token, ok := config["authorization_token"].(string); ok { + if _, exists := headers["authorization_token"]; !exists { + headers["authorization_token"] = token + } + if _, exists := headerVals["authorization_token"]; !exists { + headerVals["authorization_token"] = token + } + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + tools, fetchErr := mcpclient.FetchTools(ctx, mcpclient.FetchOptions{ + URL: url, + ServerType: stype, + Headers: headers, + Variables: stringVars, + Timeout: timeout, + }) + cancel() + if fetchErr != nil { + results = append(results, ImportResult{Server: baseName, Success: false, Message: fetchErr.Error()}) + continue + } + variables["tools"] = toolsAsMap(tools) + + server := &entity.MCPServer{ + ID: common.GenerateUUID(), + TenantID: tenantID, + Name: newName, + URL: url, + ServerType: stype, + Variables: entity.JSONMap(variables), + Headers: entity.JSONMap(headerVals), + } + if err := s.mcpServerDAO.CreateMCPServer(server); err != nil { + results = append(results, ImportResult{Server: serverName, Success: false, Message: "Failed to create MCP server."}) + continue + } + + result := ImportResult{Server: serverName, Success: true, Action: "created", ID: server.ID, NewName: newName} + if newName != baseName { + result.Message = fmt.Sprintf("Renamed from '%s' to '%s' avoid duplication", baseName, newName) + } + results = append(results, result) + } + return results, nil +} + +func (s *MCPService) nextAvailableMCPName(base, tenantID string) (string, error) { + name := base + counter := 0 + for { + exists, err := s.mcpServerDAO.ExistsByNameAndTenant(name, tenantID) + if err != nil { + return "", err + } + if !exists { + return name, nil + } + name = fmt.Sprintf("%s_%d", base, counter) + counter++ + } +} + +// TestServerRequest is the body of POST /mcp/servers/:mcp_id/test. The mcp_id +// from the URL path is threaded through to the connect call for log +// correlation; the connection itself is opened from the request body so the +// user can preview unsaved edits — matching Python's test_mcp. +type TestServerRequest struct { + URL string `json:"url"` + ServerType string `json:"server_type"` + Headers map[string]interface{} `json:"headers,omitempty"` + Variables map[string]interface{} `json:"variables,omitempty"` + Timeout float64 `json:"timeout,omitempty"` +} + +// TestServer opens a live MCP session and returns the tools the server +// advertises. Mirrors Python's test_mcp. mcpID is used for log correlation +// only. +func (s *MCPService) TestServer(mcpID string, req *TestServerRequest) ([]map[string]interface{}, error) { + if req == nil || req.URL == "" { + return nil, fmt.Errorf("%w: Invalid MCP url.", ErrMCPInvalidURL) + } + if !isValidMCPServerType(req.ServerType) { + return nil, ErrMCPInvalidType + } + + // Run the SSRF guard up front so URL-shape failures (disallowed + // scheme, missing host, non-public address) surface as + // ErrMCPInvalidURL data errors instead of being swallowed inside the + // generic FetchTools error and re-classified by the handler as a 500. + // FetchTools repeats the check internally; the second call is cheap. + if _, _, err := utility.AssertURLSafe(req.URL); err != nil { + return nil, fmt.Errorf("%w: %s", ErrMCPInvalidURL, err.Error()) + } + + timeoutSec := req.Timeout + if timeoutSec <= 0 { + timeoutSec = defaultMCPFetchTimeoutSec + } + timeout := time.Duration(timeoutSec * float64(time.Second)) + + headers := map[string]string{} + for k, v := range req.Headers { + if sv, ok := v.(string); ok { + headers[k] = sv + } + } + vars := map[string]string{} + for k, v := range req.Variables { + if sv, ok := v.(string); ok { + vars[k] = sv + } + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + tools, err := mcpclient.FetchTools(ctx, mcpclient.FetchOptions{ + URL: req.URL, + ServerType: req.ServerType, + Headers: headers, + Variables: vars, + Timeout: timeout, + }) + if err != nil { + return nil, fmt.Errorf("%w: Test MCP error (id=%s): %v", ErrMCPTestFailed, mcpID, err) + } + + out := make([]map[string]interface{}, 0, len(tools)) + for _, t := range tools { + raw := t.Raw + if raw == nil { + raw = map[string]interface{}{"name": t.Name} + if t.Description != "" { + raw["description"] = t.Description + } + if t.InputSchema != nil { + raw["inputSchema"] = t.InputSchema + } + } + raw["enabled"] = true + out = append(out, raw) + } + return out, nil +} + +// toolsAsMap mirrors Python's `{tool["name"]: tool ...}` shape used when +// persisting variables.tools. +func toolsAsMap(tools []mcpclient.Tool) map[string]interface{} { + m := map[string]interface{}{} + for _, t := range tools { + if t.Raw != nil { + m[t.Name] = t.Raw + continue + } + entry := map[string]interface{}{"name": t.Name} + if t.Description != "" { + entry["description"] = t.Description + } + if t.InputSchema != nil { + entry["inputSchema"] = t.InputSchema + } + m[t.Name] = entry + } + return m +} + func formatMCPServerDate(date *time.Time) *string { if date == nil { return nil diff --git a/internal/service/mcp_test.go b/internal/service/mcp_test.go index c2de799477..507fbd1889 100644 --- a/internal/service/mcp_test.go +++ b/internal/service/mcp_test.go @@ -17,12 +17,75 @@ package service import ( + "errors" "fmt" + "strings" "testing" "ragflow/internal/entity" ) +func TestIsValidMCPServerType(t *testing.T) { + for _, v := range []string{mcpServerTypeSSE, mcpServerTypeStreamableHTTP} { + if !isValidMCPServerType(v) { + t.Errorf("expected %q to be a valid MCP server type", v) + } + } + for _, v := range []string{"", "stdio", "http", "SSE"} { + if isValidMCPServerType(v) { + t.Errorf("expected %q to be an invalid MCP server type", v) + } + } +} + +func TestServerInputValidation(t *testing.T) { + s := &MCPService{} + + // Empty URL is rejected before any connection attempt. + if _, err := s.TestServer("id-1", &TestServerRequest{ServerType: mcpServerTypeSSE}); !errors.Is(err, ErrMCPInvalidURL) { + t.Errorf("expected ErrMCPInvalidURL for empty url, got %v", err) + } + + // nil body is treated as empty URL. + if _, err := s.TestServer("id-1", nil); !errors.Is(err, ErrMCPInvalidURL) { + t.Errorf("expected ErrMCPInvalidURL for nil body, got %v", err) + } + + // Invalid server type is rejected before connecting. + if _, err := s.TestServer("id-1", &TestServerRequest{URL: "http://example.com/sse", ServerType: "stdio"}); !errors.Is(err, ErrMCPInvalidType) { + t.Errorf("expected ErrMCPInvalidType for bad type, got %v", err) + } +} + +func TestImportServersValidationErrors(t *testing.T) { + s := &MCPService{} + + // Missing url and type produce an in-band error per entry rather than + // failing the batch. + configs := map[string]map[string]interface{}{ + "missing-fields": {"foo": "bar"}, + "bad-type": {"url": "http://example.com", "type": "stdio"}, + } + results, err := s.ImportServers("tenant-1", configs, 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + for _, r := range results { + if r.Success { + t.Errorf("expected failure result for %q", r.Server) + } + if r.Server == "missing-fields" && !strings.Contains(r.Message, "Missing required fields") { + t.Errorf("unexpected message for missing-fields: %q", r.Message) + } + if r.Server == "bad-type" && !strings.Contains(r.Message, "Unsupported MCP server type") { + t.Errorf("unexpected message for bad-type: %q", r.Message) + } + } +} + func TestPaginateMCPServersNegativeValuesMatchPythonSlice(t *testing.T) { servers := makeMCPServers(13) diff --git a/internal/utility/ssrf.go b/internal/utility/ssrf.go new file mode 100644 index 0000000000..252eb9d87d --- /dev/null +++ b/internal/utility/ssrf.go @@ -0,0 +1,190 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "sort" + "strings" + "time" +) + +// AllowedURLSchemes are the schemes accepted by AssertURLSafe. +var AllowedURLSchemes = []string{"http", "https"} + +// LookupHost is the indirection used to resolve hostnames. Tests override it. +var LookupHost = net.LookupHost + +// AssertURLSafe parses rawURL and rejects it if the scheme is disallowed, +// the host is missing, or any resolved IP is not globally routable +// (private, loopback, link-local, multicast, reserved). Returns the hostname +// and the first validated public IP so callers can DNS-pin the address and +// prevent rebinding between validation and the actual TCP connection. +// +// Mirrors common/ssrf_guard.py:assert_url_is_safe. +func AssertURLSafe(rawURL string) (hostname, resolvedIP string, err error) { + parsed, perr := url.Parse(strings.TrimSpace(rawURL)) + if perr != nil { + return "", "", fmt.Errorf("Invalid url.") + } + + scheme := strings.ToLower(parsed.Scheme) + if !schemeAllowed(scheme) { + sorted := append([]string(nil), AllowedURLSchemes...) + sort.Strings(sorted) + return "", "", fmt.Errorf("Disallowed URL scheme: '%s'. Only %v are allowed.", scheme, sorted) + } + + hostname = parsed.Hostname() + if hostname == "" { + return "", "", fmt.Errorf("URL is missing a host.") + } + + addrs, err := LookupHost(hostname) + if err != nil { + return "", "", fmt.Errorf("Could not resolve hostname '%s': %v", hostname, err) + } + if len(addrs) == 0 { + return "", "", fmt.Errorf("Hostname '%s' resolved to no addresses.", hostname) + } + + for _, addr := range addrs { + ip := net.ParseIP(addr) + if ip == nil { + return "", "", fmt.Errorf("Could not parse resolved address '%s' for hostname '%s'.", addr, hostname) + } + if !isGlobalIP(effectiveIP(ip)) { + return "", "", fmt.Errorf("URL resolves to a non-public address (%s), which is not allowed.", ip.String()) + } + if resolvedIP == "" { + resolvedIP = ip.String() + } + } + return hostname, resolvedIP, nil +} + +func schemeAllowed(scheme string) bool { + for _, s := range AllowedURLSchemes { + if s == scheme { + return true + } + } + return false +} + +// effectiveIP unwraps IPv4-mapped IPv6 addresses (e.g. ::ffff:127.0.0.1) so +// the routability check sees the IPv4 form. Without this, an attacker could +// bypass the guard with an IPv4-mapped IPv6 representation of a private host. +func effectiveIP(ip net.IP) net.IP { + if v4 := ip.To4(); v4 != nil { + return v4 + } + return ip +} + +// isGlobalIP mirrors Python's ipaddress.IPv*Address.is_global: an address is +// global if it is none of {unspecified, loopback, multicast, link-local, +// private (including CGNAT and IPv6 ULA), benchmarking, documentation, +// reserved}. +func isGlobalIP(ip net.IP) bool { + if ip == nil || ip.IsUnspecified() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() || ip.IsPrivate() { + return false + } + if v4 := ip.To4(); v4 != nil { + // CGNAT 100.64.0.0/10 — not flagged by IsPrivate in older Go versions. + if v4[0] == 100 && v4[1]&0xC0 == 64 { + return false + } + // 192.0.0.0/24 reserved for IETF protocol assignments. + if v4[0] == 192 && v4[1] == 0 && v4[2] == 0 { + return false + } + // 192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24 documentation (TEST-NET-1/2/3). + if v4[0] == 192 && v4[1] == 0 && v4[2] == 2 { + return false + } + if v4[0] == 198 && v4[1] == 51 && v4[2] == 100 { + return false + } + if v4[0] == 203 && v4[1] == 0 && v4[2] == 113 { + return false + } + // 198.18.0.0/15 benchmarking. + if v4[0] == 198 && (v4[1] == 18 || v4[1] == 19) { + return false + } + // 240.0.0.0/4 reserved (excluding 255.255.255.255 which IsUnspecified misses). + if v4[0] >= 240 { + return false + } + } else if v6 := ip.To16(); v6 != nil { + // 2001:db8::/32 documentation prefix. + if v6[0] == 0x20 && v6[1] == 0x01 && v6[2] == 0x0d && v6[3] == 0xb8 { + return false + } + // 100::/64 discard-only address block. + if v6[0] == 0x01 && v6[1] == 0x00 && allZero(v6[2:8]) { + return false + } + } + return true +} + +func allZero(b []byte) bool { + for _, x := range b { + if x != 0 { + return false + } + } + return true +} + +// PinnedHTTPClient returns an HTTP client whose Transport rewrites every +// outbound dial for hostname:port to resolvedIP:port, closing the TOCTOU +// window between AssertURLSafe and the actual TCP connection. Pins are +// scoped to this client only. +func PinnedHTTPClient(hostname, resolvedIP string, timeout time.Duration) *http.Client { + dialer := &net.Dialer{ + Timeout: timeout, + KeepAlive: 30 * time.Second, + } + transport := &http.Transport{ + // Disable environment proxy: HTTP_PROXY / HTTPS_PROXY would route + // the connection through the proxy host instead of the pinned + // resolvedIP, bypassing the SSRF guard. + Proxy: nil, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, splitErr := net.SplitHostPort(addr) + if splitErr == nil && host == hostname && resolvedIP != "" { + return dialer.DialContext(ctx, network, net.JoinHostPort(resolvedIP, port)) + } + return dialer.DialContext(ctx, network, addr) + }, + TLSHandshakeTimeout: timeout, + ResponseHeaderTimeout: timeout, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: false, + } + return &http.Client{ + Transport: transport, + Timeout: timeout, + } +} diff --git a/internal/utility/ssrf_test.go b/internal/utility/ssrf_test.go new file mode 100644 index 0000000000..4e4e7b3c82 --- /dev/null +++ b/internal/utility/ssrf_test.go @@ -0,0 +1,156 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "strings" + "testing" +) + +func TestAssertURLSafe(t *testing.T) { + orig := LookupHost + defer func() { LookupHost = orig }() + + type want struct { + errSubstr string + host string + ip string + } + cases := []struct { + name string + url string + ips []string + err string + want want + }{ + { + name: "public IPv4", + url: "https://example.com/path", + ips: []string{"93.184.216.34"}, + want: want{host: "example.com", ip: "93.184.216.34"}, + }, + { + name: "loopback rejected", + url: "http://localhost/x", + ips: []string{"127.0.0.1"}, + want: want{errSubstr: "non-public address"}, + }, + { + name: "private 10.x rejected", + url: "http://internal/x", + ips: []string{"10.0.0.5"}, + want: want{errSubstr: "non-public address"}, + }, + { + name: "private 192.168.x rejected", + url: "http://router/x", + ips: []string{"192.168.1.1"}, + want: want{errSubstr: "non-public address"}, + }, + { + name: "CGNAT 100.64/10 rejected", + url: "http://carrier/x", + ips: []string{"100.64.1.1"}, + want: want{errSubstr: "non-public address"}, + }, + { + name: "IPv4-mapped IPv6 loopback rejected", + url: "http://[::ffff:127.0.0.1]/x", + ips: []string{"::ffff:127.0.0.1"}, + want: want{errSubstr: "non-public address"}, + }, + { + name: "link-local IPv6 rejected", + url: "http://[fe80::1]/x", + ips: []string{"fe80::1"}, + want: want{errSubstr: "non-public address"}, + }, + { + name: "documentation 2001:db8 rejected", + url: "http://[2001:db8::1]/x", + ips: []string{"2001:db8::1"}, + want: want{errSubstr: "non-public address"}, + }, + { + name: "disallowed scheme ftp", + url: "ftp://example.com/", + ips: []string{"93.184.216.34"}, + want: want{errSubstr: "Disallowed URL scheme"}, + }, + { + name: "missing host", + url: "http:///path", + want: want{errSubstr: "missing a host"}, + }, + { + name: "resolution fails", + url: "http://nosuchhost.test/x", + err: "no such host", + want: want{errSubstr: "Could not resolve"}, + }, + { + name: "all addresses must be public", + url: "http://mixed.example.com/", + ips: []string{"93.184.216.34", "127.0.0.1"}, + want: want{errSubstr: "non-public address"}, + }, + { + name: "literal IPv4 loopback rejected", + url: "http://127.0.0.1/", + ips: []string{"127.0.0.1"}, + want: want{errSubstr: "non-public address"}, + }, + { + name: "documentation TEST-NET-3 rejected", + url: "http://stub/", + ips: []string{"203.0.113.5"}, + want: want{errSubstr: "non-public address"}, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + LookupHost = func(host string) ([]string, error) { + if tc.err != "" { + return nil, &mockErr{tc.err} + } + return tc.ips, nil + } + host, ip, err := AssertURLSafe(tc.url) + if tc.want.errSubstr != "" { + if err == nil || !strings.Contains(err.Error(), tc.want.errSubstr) { + t.Fatalf("expected error containing %q, got %v", tc.want.errSubstr, err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if host != tc.want.host { + t.Errorf("host: got %q, want %q", host, tc.want.host) + } + if ip != tc.want.ip { + t.Errorf("ip: got %q, want %q", ip, tc.want.ip) + } + }) + } +} + +type mockErr struct{ s string } + +func (e *mockErr) Error() string { return e.s }