mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-04 18:45:38 +08:00
feat(go-api): implement MCP server management endpoints (#15281)
## Summary Ports the MCP (Model Context Protocol) server management endpoints that power `web/src/pages/user-setting/mcp/` from Python (`api/apps/restful_apis/mcp_api.py`) to Go. There were no MCP routes in the Go server before this change. Closes #15275 (subtask of #15240). ## Endpoints implemented (base path `/api/v1`) | Method | Path | Description | |--------|------|-------------| | GET | `/mcp/servers` | List tenant servers (keyword / order / pagination) | | POST | `/mcp/servers` | Create a server | | GET | `/mcp/servers/{mcp_id}` | Get one (`?mode=download` exports config) | | PUT | `/mcp/servers/{mcp_id}` | Update a server | | DELETE | `/mcp/servers/{mcp_id}` | Delete a server | | POST | `/mcp/import` | Bulk import from JSON config | | POST | `/mcp/servers/{mcp_id}/test` | Connect + list tools (see notes) | ## Implementation Follows the existing `handler → service → dao` layering (per PR #14790): - **entity** (`internal/entity/mcp.go`): added `MCPServerType` constants and `IsValidMCPServerType` over the existing `MCPServer` model. - **dao** (`internal/dao/mcp.go`): new `MCPServerDAO` with tenant-scoped CRUD, a keyword filter, and a **whitelisted order-column map** (guards against SQL injection via the caller-supplied `orderby`). - **service** (`internal/service/mcp.go`): new `MCPService` — list/get/export/create/update/delete/import/test — mirroring `MCPServerService` and the `mcp_api` request validation, with sentinel errors for clean code mapping. - **handler** (`internal/handler/mcp.go`): new `MCPHandler` with the seven handlers and Python-compatible response codes. - **router / server_main**: registered the `/mcp` group and wired the handler. ## Deviations from Python (documented in code) 1. **Bulk import is at `POST /mcp/import`, not `/mcp/servers/import`.** gin (v1.9.1) cannot register a static segment and a path param at the same tree node, so `/mcp/servers/import` would collide with `/mcp/servers/:mcp_id` and panic at startup. The frontend should call `/mcp/import`. 2. **No live tool discovery on create/update/import.** The Python path runs `get_mcp_tools` over SSE / streamable-HTTP and stores `variables.tools`. The Go server has no MCP client yet, so these persist `variables`/`headers` but leave `variables.tools` unpopulated. 3. **`/test` returns a data error (`ErrMCPTestUnsupported`)** until a Go MCP client lands. Per the issue, the live-connection path is scoped as a follow-up; the handler still validates `url` + `server_type`. ## Testing - Added `internal/service/mcp_test.go` covering `IsValidMCPServerType` and the `TestServer` validation/short-circuit paths (no DB required). - No Go toolchain was available in the dev environment, so `go build ./...` / `go vet ./...` verification is left to CI. ## Follow-ups - Go MCP client (SSE / streamable-HTTP) to enable live tool discovery and the real `/test` behavior. - Reconcile the `/mcp/import` vs `/mcp/servers/import` path with the frontend. ---------
This commit is contained in:
@@ -17,7 +17,10 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -185,6 +188,203 @@ func (h *MCPHandler) DeleteMCPServer(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// mcpErrorResponse maps the import / test sentinel errors to the response
|
||||
// codes Python's mcp_api emits.
|
||||
func mcpErrorResponse(c *gin.Context, err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case errors.Is(err, service.ErrMCPInvalidType),
|
||||
errors.Is(err, service.ErrMCPInvalidName),
|
||||
errors.Is(err, service.ErrMCPInvalidURL),
|
||||
errors.Is(err, service.ErrMCPTestFailed):
|
||||
c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": nil, "message": mcpErrorMessage(err)})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "data": nil, "message": err.Error()})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func mcpErrorMessage(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
// service wraps its sentinels as "<sentinel>: <detail>" via
|
||||
// fmt.Errorf("%w: ...", err). Surface the detail when present so the
|
||||
// SSRF guard's per-failure message (e.g. "URL resolves to a non-public
|
||||
// address (...).") reaches the caller verbatim, matching what Python's
|
||||
// _assert_mcp_url_is_safe returns.
|
||||
switch {
|
||||
case errors.Is(err, service.ErrMCPInvalidURL):
|
||||
if detail := unwrapDetail(err, service.ErrMCPInvalidURL); detail != "" {
|
||||
return detail
|
||||
}
|
||||
return "Invalid url."
|
||||
case errors.Is(err, service.ErrMCPInvalidType):
|
||||
return "Unsupported MCP server type."
|
||||
case errors.Is(err, service.ErrMCPTestFailed):
|
||||
if detail := unwrapDetail(err, service.ErrMCPTestFailed); detail != "" {
|
||||
return detail
|
||||
}
|
||||
return "Test MCP error."
|
||||
default:
|
||||
return err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
// unwrapDetail pulls the "<sentinel>: <detail>" suffix off a wrapped error
|
||||
// and returns the detail. Returns "" when the error is the bare sentinel
|
||||
// (no wrapped message) so the caller can fall back to a default.
|
||||
func unwrapDetail(err, sentinel error) string {
|
||||
if err == nil || sentinel == nil {
|
||||
return ""
|
||||
}
|
||||
prefix := sentinel.Error() + ": "
|
||||
msg := err.Error()
|
||||
if !strings.HasPrefix(msg, prefix) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimPrefix(msg, prefix)
|
||||
}
|
||||
|
||||
// ImportMCPRequest is the body for the bulk-import endpoint.
|
||||
type ImportMCPRequest struct {
|
||||
MCPServers map[string]map[string]interface{} `json:"mcpServers"`
|
||||
Timeout float64 `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
// ImportMCPServers bulk-imports MCP servers from a JSON config, fetching the
|
||||
// remote tool list for each entry and persisting it under variables.tools.
|
||||
// Mirrors Python's import_multiple, including the same distinction between
|
||||
// "mcpServers key missing" (101 ARGUMENT_ERROR) and "mcpServers key
|
||||
// present but empty" (102 DATA_ERROR).
|
||||
//
|
||||
// @Summary Import MCP Servers
|
||||
// @Tags mcp
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body handler.ImportMCPRequest true "import config"
|
||||
// @Router /api/v1/mcp/servers/import [post]
|
||||
func (h *MCPHandler) ImportMCPServers(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
// Read the raw body so we can distinguish "key absent" from "key
|
||||
// present but empty" — the Python @validate_request("mcpServers")
|
||||
// decorator returns RetCode.ARGUMENT_ERROR for the former, while the
|
||||
// handler body returns RetCode.DATA_ERROR for the latter.
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()})
|
||||
return
|
||||
}
|
||||
var raw map[string]json.RawMessage
|
||||
if len(body) > 0 {
|
||||
if err := json.Unmarshal(body, &raw); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
rawServers, hasServers := raw["mcpServers"]
|
||||
if !hasServers {
|
||||
// Match Python validate_request: code 101, message includes the
|
||||
// trailing "; " separator the Python decorator emits.
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"data": nil,
|
||||
"message": "required argument are missing: mcpServers; ",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var servers map[string]map[string]interface{}
|
||||
if err := json.Unmarshal(rawServers, &servers); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if len(servers) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"code": common.CodeDataError, "data": nil, "message": "No MCP servers provided."})
|
||||
return
|
||||
}
|
||||
|
||||
var timeout float64
|
||||
if rawTimeout, ok := raw["timeout"]; ok {
|
||||
// Ignore parse errors for timeout to match Python's get_float
|
||||
// default-on-failure behavior; the service applies its own
|
||||
// 10 s fallback when timeout <= 0.
|
||||
_ = json.Unmarshal(rawTimeout, &timeout)
|
||||
}
|
||||
|
||||
results, err := h.mcpService.ImportServers(user.ID, servers, timeout)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "data": nil, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"code": common.CodeSuccess, "data": gin.H{"results": results}, "message": "success"})
|
||||
}
|
||||
|
||||
// TestMCPServer opens a live MCP session and returns the tools the server
|
||||
// advertises. The mcp_id path parameter identifies the stored record the
|
||||
// user is trying to validate; the actual connection uses the request body
|
||||
// so the user can preview unsaved edits — matching Python's test_mcp.
|
||||
//
|
||||
// @Summary Test MCP Server
|
||||
// @Tags mcp
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param mcp_id path string true "MCP server ID"
|
||||
// @Param request body service.TestServerRequest true "test parameters"
|
||||
// @Router /api/v1/mcp/servers/{mcp_id}/test [post]
|
||||
func (h *MCPHandler) TestMCPServer(c *gin.Context) {
|
||||
_, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
mcpID := c.Param("mcp_id")
|
||||
if mcpID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "mcp_id is required"})
|
||||
return
|
||||
}
|
||||
|
||||
var req service.TestServerRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeBadRequest, "data": nil, "message": "Invalid request body: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Mirror Python's @validate_request("url", "server_type"): missing
|
||||
// required fields → code 101 (ARGUMENT_ERROR), not code 102.
|
||||
var missingFields []string
|
||||
if req.URL == "" {
|
||||
missingFields = append(missingFields, "url")
|
||||
}
|
||||
if req.ServerType == "" {
|
||||
missingFields = append(missingFields, "server_type")
|
||||
}
|
||||
if len(missingFields) > 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"data": nil,
|
||||
"message": "required argument are missing: " + strings.Join(missingFields, ", ") + "; ",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tools, err := h.mcpService.TestServer(mcpID, &req)
|
||||
if mcpErrorResponse(c, err) {
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"code": common.CodeSuccess, "data": tools, "message": "success"})
|
||||
}
|
||||
|
||||
func newMCPServerResponse(server *entity.MCPServer) *mcpServerResponse {
|
||||
if server == nil {
|
||||
return nil
|
||||
|
||||
742
internal/mcpclient/client.go
Normal file
742
internal/mcpclient/client.go
Normal file
@@ -0,0 +1,742 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Package mcpclient is a minimal Model Context Protocol (MCP) client used by
|
||||
// the Go MCP-management endpoints to list a remote server's tools during
|
||||
// import and the "test" endpoint. It implements just enough of the spec to
|
||||
// negotiate a session and call tools/list:
|
||||
//
|
||||
// - streamable-HTTP transport (spec 2025-03-26): single endpoint, JSON-RPC
|
||||
// requests via POST, responses either as application/json or as an SSE
|
||||
// stream sharing the same connection.
|
||||
// - SSE transport (spec 2024-11-05, legacy): server returns an "endpoint"
|
||||
// event whose data is the URL the client POSTs JSON-RPC requests to;
|
||||
// responses are pushed back on the same SSE stream.
|
||||
//
|
||||
// The full Python implementation lives in common/mcp_tool_call_conn.py; this
|
||||
// is a reduced port focused on tools/list discovery.
|
||||
package mcpclient
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/utility"
|
||||
)
|
||||
|
||||
// Transport identifiers. Mirrors common.constants.MCPServerType.
|
||||
const (
|
||||
TransportSSE = "sse"
|
||||
TransportStreamableHTTP = "streamable-http"
|
||||
)
|
||||
|
||||
const (
|
||||
protocolVersion = "2025-03-26"
|
||||
clientName = "ragflow-go"
|
||||
clientVersion = "1.0.0"
|
||||
jsonRPCVersion = "2.0"
|
||||
)
|
||||
|
||||
// Tool is the subset of an MCP Tool descriptor returned by tools/list.
|
||||
// Extra fields surfaced by the server are preserved in Raw so callers can
|
||||
// round-trip them into variables.tools without losing data.
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"inputSchema,omitempty"`
|
||||
Raw map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// FetchOptions controls a single tools/list discovery call.
|
||||
type FetchOptions struct {
|
||||
URL string
|
||||
ServerType string
|
||||
Headers map[string]string
|
||||
Variables map[string]string
|
||||
Timeout time.Duration
|
||||
HTTPClient *http.Client
|
||||
pinHostname string
|
||||
pinIP string
|
||||
}
|
||||
|
||||
// FetchTools opens a connection to the MCP server described by opts and
|
||||
// returns the tools advertised by tools/list. URL safety / DNS pinning is
|
||||
// performed here so callers get the same SSRF guarantees the Python path
|
||||
// has via pin_dns_global + assert_url_is_safe.
|
||||
func FetchTools(ctx context.Context, opts FetchOptions) ([]Tool, error) {
|
||||
if opts.URL == "" {
|
||||
return nil, errors.New("Invalid url.")
|
||||
}
|
||||
if opts.Timeout <= 0 {
|
||||
opts.Timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
hostname, resolvedIP, err := utility.AssertURLSafe(opts.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts.pinHostname = hostname
|
||||
opts.pinIP = resolvedIP
|
||||
if opts.HTTPClient == nil {
|
||||
opts.HTTPClient = utility.PinnedHTTPClient(hostname, resolvedIP, opts.Timeout)
|
||||
}
|
||||
|
||||
headers, headerErr := renderHeaders(opts.Headers, opts.Variables)
|
||||
if headerErr != nil {
|
||||
return nil, headerErr
|
||||
}
|
||||
|
||||
connectCtx, cancel := context.WithTimeout(ctx, opts.Timeout)
|
||||
defer cancel()
|
||||
|
||||
switch strings.ToLower(opts.ServerType) {
|
||||
case TransportStreamableHTTP:
|
||||
return fetchToolsStreamableHTTP(connectCtx, opts.URL, headers, opts.HTTPClient)
|
||||
case TransportSSE:
|
||||
return fetchToolsSSE(connectCtx, opts.URL, headers, opts.HTTPClient)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported MCP server type.")
|
||||
}
|
||||
}
|
||||
|
||||
// renderHeaders applies ${name} substitution to header keys and values using
|
||||
// the supplied variables map, mirroring the Template.safe_substitute pass in
|
||||
// common/mcp_tool_call_conn.py. Empty keys (after substitution) are dropped.
|
||||
func renderHeaders(raw map[string]string, vars map[string]string) (map[string]string, error) {
|
||||
rendered := map[string]string{}
|
||||
for k, v := range raw {
|
||||
nk := substituteTemplate(k, vars)
|
||||
nv := substituteTemplate(v, vars)
|
||||
if strings.TrimSpace(nk) == "" {
|
||||
continue
|
||||
}
|
||||
rendered[nk] = nv
|
||||
}
|
||||
return rendered, nil
|
||||
}
|
||||
|
||||
// substituteTemplate replaces ${name} occurrences (Python string.Template
|
||||
// safe-substitute semantics) with values from vars. Unknown keys are left
|
||||
// in place, matching safe_substitute's behavior.
|
||||
func substituteTemplate(s string, vars map[string]string) string {
|
||||
if vars == nil || !strings.Contains(s, "${") {
|
||||
return s
|
||||
}
|
||||
var b strings.Builder
|
||||
i := 0
|
||||
for i < len(s) {
|
||||
idx := strings.Index(s[i:], "${")
|
||||
if idx == -1 {
|
||||
b.WriteString(s[i:])
|
||||
break
|
||||
}
|
||||
b.WriteString(s[i : i+idx])
|
||||
i += idx + 2
|
||||
end := strings.Index(s[i:], "}")
|
||||
if end == -1 {
|
||||
b.WriteString("${")
|
||||
b.WriteString(s[i:])
|
||||
break
|
||||
}
|
||||
key := s[i : i+end]
|
||||
i += end + 1
|
||||
if val, ok := vars[key]; ok {
|
||||
b.WriteString(val)
|
||||
} else {
|
||||
b.WriteString("${")
|
||||
b.WriteString(key)
|
||||
b.WriteString("}")
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// jsonRPCRequest is a JSON-RPC 2.0 request envelope.
|
||||
type jsonRPCRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID interface{} `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// jsonRPCResponse is a JSON-RPC 2.0 response. Either Result or Error is set.
|
||||
type jsonRPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID interface{} `json:"id,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *jsonRPCError `json:"error,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
}
|
||||
|
||||
type jsonRPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func initializeParams() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"protocolVersion": protocolVersion,
|
||||
"capabilities": map[string]interface{}{},
|
||||
"clientInfo": map[string]interface{}{
|
||||
"name": clientName,
|
||||
"version": clientVersion,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- streamable-HTTP transport ----------
|
||||
|
||||
const sessionHeader = "Mcp-Session-Id"
|
||||
|
||||
func fetchToolsStreamableHTTP(ctx context.Context, endpoint string, headers map[string]string, client *http.Client) ([]Tool, error) {
|
||||
sessionID, initRes, err := streamableSend(ctx, client, endpoint, "", headers, jsonRPCRequest{
|
||||
JSONRPC: jsonRPCVersion,
|
||||
ID: 0,
|
||||
Method: "initialize",
|
||||
Params: initializeParams(),
|
||||
}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if initRes.Error != nil {
|
||||
return nil, formatMCPError("initialize", initRes.Error)
|
||||
}
|
||||
|
||||
if _, _, err := streamableSend(ctx, client, endpoint, sessionID, headers, jsonRPCRequest{
|
||||
JSONRPC: jsonRPCVersion,
|
||||
Method: "notifications/initialized",
|
||||
}, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, listRes, err := streamableSend(ctx, client, endpoint, sessionID, headers, jsonRPCRequest{
|
||||
JSONRPC: jsonRPCVersion,
|
||||
ID: 1,
|
||||
Method: "tools/list",
|
||||
}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if listRes.Error != nil {
|
||||
return nil, formatMCPError("tools/list", listRes.Error)
|
||||
}
|
||||
return parseToolsResult(listRes.Result)
|
||||
}
|
||||
|
||||
// streamableSend POSTs a JSON-RPC payload to the streamable-HTTP endpoint.
|
||||
// When expectResponse is false (notifications), the response body is not
|
||||
// parsed. The session id returned by the initial initialize call is
|
||||
// propagated via the Mcp-Session-Id header per the spec.
|
||||
func streamableSend(ctx context.Context, client *http.Client, endpoint, sessionID string, headers map[string]string, payload jsonRPCRequest, expectResponse bool) (string, *jsonRPCResponse, error) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("marshal MCP request: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("build MCP request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json, text/event-stream")
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
if sessionID != "" {
|
||||
req.Header.Set(sessionHeader, sessionID)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", nil, mapMCPConnectionError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if !expectResponse {
|
||||
if resp.StatusCode >= 400 {
|
||||
return "", nil, fmt.Errorf("MCP server returned HTTP %d for %s", resp.StatusCode, payload.Method)
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20))
|
||||
return resp.Header.Get(sessionHeader), nil, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return "", nil, fmt.Errorf("MCP server returned HTTP %d for %s: %s", resp.StatusCode, payload.Method, strings.TrimSpace(string(raw)))
|
||||
}
|
||||
|
||||
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
|
||||
sid := resp.Header.Get(sessionHeader)
|
||||
if sessionID == "" {
|
||||
sessionID = sid
|
||||
}
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
r, err := readJSONRPCFromSSE(resp.Body, payload.ID)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return sessionID, r, nil
|
||||
}
|
||||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("read MCP response: %w", err)
|
||||
}
|
||||
parsed, err := parseJSONRPC(raw, payload.ID)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return sessionID, parsed, nil
|
||||
}
|
||||
|
||||
// ---------- SSE transport ----------
|
||||
|
||||
func fetchToolsSSE(ctx context.Context, endpoint string, headers map[string]string, client *http.Client) ([]Tool, error) {
|
||||
streamReq, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build SSE request: %w", err)
|
||||
}
|
||||
streamReq.Header.Set("Accept", "text/event-stream")
|
||||
streamReq.Header.Set("Cache-Control", "no-cache")
|
||||
for k, v := range headers {
|
||||
streamReq.Header.Set(k, v)
|
||||
}
|
||||
streamResp, err := client.Do(streamReq)
|
||||
if err != nil {
|
||||
return nil, mapMCPConnectionError(err)
|
||||
}
|
||||
if streamResp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(io.LimitReader(streamResp.Body, 1<<20))
|
||||
streamResp.Body.Close()
|
||||
return nil, fmt.Errorf("MCP SSE handshake returned HTTP %d: %s", streamResp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
stream := newSSEReader(streamResp.Body)
|
||||
defer streamResp.Body.Close()
|
||||
|
||||
postURL, err := waitForEndpoint(ctx, stream, endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The endpoint event can hand us an arbitrary absolute URL. A
|
||||
// malicious public SSE server could point us at 127.0.0.1 or any
|
||||
// other internal host to bounce the POST phase through us. Re-run
|
||||
// the SSRF guard against the resolved URL, and — when the host
|
||||
// differs from the original SSE host — swap in a fresh pinned
|
||||
// client so the dial-time IP override still applies.
|
||||
postClient := client
|
||||
if postHost, postIP, vErr := utility.AssertURLSafe(postURL); vErr != nil {
|
||||
return nil, vErr
|
||||
} else if u, perr := url.Parse(postURL); perr == nil && u.Hostname() != "" {
|
||||
if u.Hostname() != originalHost(endpoint) {
|
||||
postClient = utility.PinnedHTTPClient(postHost, postIP, sseTimeoutFrom(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
pending := newPendingResponses()
|
||||
streamDone := make(chan error, 1)
|
||||
go func() {
|
||||
streamDone <- stream.dispatch(ctx, pending)
|
||||
}()
|
||||
|
||||
postOnce := func(payload jsonRPCRequest) error {
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, postURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("build SSE POST: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := postClient.Do(req)
|
||||
if err != nil {
|
||||
return mapMCPConnectionError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return fmt.Errorf("MCP server returned HTTP %d for %s: %s", resp.StatusCode, payload.Method, strings.TrimSpace(string(raw)))
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register the waiter BEFORE issuing the POST so a fast server that
|
||||
// pushes its response before our wait() call doesn't drop the delivery.
|
||||
initWaiter := pending.register(0)
|
||||
if err := postOnce(jsonRPCRequest{JSONRPC: jsonRPCVersion, ID: 0, Method: "initialize", Params: initializeParams()}); err != nil {
|
||||
pending.cancel(0)
|
||||
return nil, err
|
||||
}
|
||||
initRes, err := pending.await(ctx, initWaiter, streamDone)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if initRes.Error != nil {
|
||||
return nil, formatMCPError("initialize", initRes.Error)
|
||||
}
|
||||
if err := postOnce(jsonRPCRequest{JSONRPC: jsonRPCVersion, Method: "notifications/initialized"}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listWaiter := pending.register(1)
|
||||
if err := postOnce(jsonRPCRequest{JSONRPC: jsonRPCVersion, ID: 1, Method: "tools/list"}); err != nil {
|
||||
pending.cancel(1)
|
||||
return nil, err
|
||||
}
|
||||
listRes, err := pending.await(ctx, listWaiter, streamDone)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if listRes.Error != nil {
|
||||
return nil, formatMCPError("tools/list", listRes.Error)
|
||||
}
|
||||
return parseToolsResult(listRes.Result)
|
||||
}
|
||||
|
||||
// waitForEndpoint reads SSE events until an "endpoint" event arrives and
|
||||
// returns the URL to POST JSON-RPC requests to. The data may be either a
|
||||
// fully-qualified URL or a path; relative paths are resolved against the
|
||||
// original SSE endpoint.
|
||||
func waitForEndpoint(ctx context.Context, stream *sseReader, base string) (string, error) {
|
||||
for {
|
||||
event, err := stream.nextEvent(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if event == nil {
|
||||
return "", errors.New("MCP SSE stream closed before sending endpoint event")
|
||||
}
|
||||
if event.event == "endpoint" {
|
||||
ref := strings.TrimSpace(event.data)
|
||||
if ref == "" {
|
||||
return "", errors.New("MCP SSE endpoint event has empty data")
|
||||
}
|
||||
baseURL, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse MCP SSE base url: %w", err)
|
||||
}
|
||||
rel, err := url.Parse(ref)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse MCP SSE endpoint data: %w", err)
|
||||
}
|
||||
return baseURL.ResolveReference(rel).String(), nil
|
||||
}
|
||||
// Other events (heartbeats, message) before endpoint are ignored.
|
||||
}
|
||||
}
|
||||
|
||||
// originalHost extracts the hostname from the original SSE endpoint so the
|
||||
// caller can detect when the server-advertised post URL has moved to a
|
||||
// different host (and a fresh pinned client is required).
|
||||
func originalHost(endpoint string) string {
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return u.Hostname()
|
||||
}
|
||||
|
||||
// sseTimeoutFrom recovers a non-zero timeout from the request context so
|
||||
// the freshly-pinned post-phase client has the same deadline as the rest
|
||||
// of the SSE flow.
|
||||
func sseTimeoutFrom(ctx context.Context) time.Duration {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if d := time.Until(deadline); d > 0 {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return 10 * time.Second
|
||||
}
|
||||
|
||||
// pendingResponses correlates outstanding JSON-RPC ids with channels that
|
||||
// receive the corresponding response from the SSE dispatcher.
|
||||
type pendingResponses struct {
|
||||
mu sync.Mutex
|
||||
waiters map[string]chan *jsonRPCResponse
|
||||
}
|
||||
|
||||
func newPendingResponses() *pendingResponses {
|
||||
return &pendingResponses{waiters: map[string]chan *jsonRPCResponse{}}
|
||||
}
|
||||
|
||||
// pendingWaiter is the handle returned by register; the caller passes it to
|
||||
// await once the request has been sent.
|
||||
type pendingWaiter struct {
|
||||
key string
|
||||
ch chan *jsonRPCResponse
|
||||
}
|
||||
|
||||
// register reserves a waiter slot for the given JSON-RPC id BEFORE the
|
||||
// request is sent, so a server that responds before await() is called still
|
||||
// has somewhere to deliver to.
|
||||
func (p *pendingResponses) register(id interface{}) pendingWaiter {
|
||||
key := normalizeID(id)
|
||||
ch := make(chan *jsonRPCResponse, 1)
|
||||
p.mu.Lock()
|
||||
p.waiters[key] = ch
|
||||
p.mu.Unlock()
|
||||
return pendingWaiter{key: key, ch: ch}
|
||||
}
|
||||
|
||||
// cancel drops a previously registered waiter. Used when the POST fails so
|
||||
// a late server delivery cannot block forever in the waiters map.
|
||||
func (p *pendingResponses) cancel(id interface{}) {
|
||||
key := normalizeID(id)
|
||||
p.mu.Lock()
|
||||
delete(p.waiters, key)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// await blocks until the registered waiter's response arrives, the SSE
|
||||
// stream closes, or ctx expires.
|
||||
func (p *pendingResponses) await(ctx context.Context, w pendingWaiter, streamDone <-chan error) (*jsonRPCResponse, error) {
|
||||
defer func() {
|
||||
p.mu.Lock()
|
||||
delete(p.waiters, w.key)
|
||||
p.mu.Unlock()
|
||||
}()
|
||||
select {
|
||||
case res := <-w.ch:
|
||||
return res, nil
|
||||
case err := <-streamDone:
|
||||
if err == nil {
|
||||
return nil, errors.New("MCP SSE stream closed before response arrived")
|
||||
}
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pendingResponses) deliver(res *jsonRPCResponse) {
|
||||
key := normalizeID(res.ID)
|
||||
p.mu.Lock()
|
||||
ch, ok := p.waiters[key]
|
||||
p.mu.Unlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- res:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeID(id interface{}) string {
|
||||
switch v := id.(type) {
|
||||
case nil:
|
||||
return ""
|
||||
case string:
|
||||
return v
|
||||
case json.Number:
|
||||
return v.String()
|
||||
case float64:
|
||||
return fmt.Sprintf("%v", v)
|
||||
default:
|
||||
b, _ := json.Marshal(v)
|
||||
return string(b)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- SSE parsing ----------
|
||||
|
||||
type sseEvent struct {
|
||||
event string
|
||||
data string
|
||||
}
|
||||
|
||||
type sseReader struct {
|
||||
rd *bufio.Reader
|
||||
}
|
||||
|
||||
func newSSEReader(r io.Reader) *sseReader {
|
||||
return &sseReader{rd: bufio.NewReaderSize(r, 64*1024)}
|
||||
}
|
||||
|
||||
// nextEvent returns the next SSE event (event: + data:) from the stream, or
|
||||
// nil when the stream is closed cleanly.
|
||||
func (s *sseReader) nextEvent(ctx context.Context) (*sseEvent, error) {
|
||||
ev := &sseEvent{}
|
||||
var dataLines []string
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line, err := s.rd.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
if len(dataLines) > 0 || ev.event != "" {
|
||||
ev.data = strings.Join(dataLines, "\n")
|
||||
return ev, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if line == "" {
|
||||
if len(dataLines) == 0 && ev.event == "" {
|
||||
continue
|
||||
}
|
||||
ev.data = strings.Join(dataLines, "\n")
|
||||
return ev, nil
|
||||
}
|
||||
if strings.HasPrefix(line, ":") {
|
||||
continue
|
||||
}
|
||||
if idx := strings.Index(line, ":"); idx >= 0 {
|
||||
field := line[:idx]
|
||||
value := strings.TrimPrefix(line[idx+1:], " ")
|
||||
switch field {
|
||||
case "event":
|
||||
ev.event = value
|
||||
case "data":
|
||||
dataLines = append(dataLines, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dispatch reads events off the SSE stream and forwards JSON-RPC responses
|
||||
// to the matching waiter. It returns when the stream closes.
|
||||
func (s *sseReader) dispatch(ctx context.Context, pending *pendingResponses) error {
|
||||
for {
|
||||
ev, err := s.nextEvent(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ev == nil {
|
||||
return nil
|
||||
}
|
||||
if ev.event != "" && ev.event != "message" {
|
||||
continue
|
||||
}
|
||||
raw := []byte(ev.data)
|
||||
if len(bytes.TrimSpace(raw)) == 0 {
|
||||
continue
|
||||
}
|
||||
parsed, err := parseJSONRPC(raw, nil)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if parsed.Method != "" && parsed.ID == nil {
|
||||
// Server-initiated notification; nothing to deliver.
|
||||
continue
|
||||
}
|
||||
pending.deliver(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
// readJSONRPCFromSSE consumes a single JSON-RPC response off an inline SSE
|
||||
// stream returned by a streamable-HTTP POST. The response with matching id
|
||||
// is returned; everything else is skipped.
|
||||
func readJSONRPCFromSSE(r io.Reader, wantID interface{}) (*jsonRPCResponse, error) {
|
||||
stream := newSSEReader(r)
|
||||
for {
|
||||
ev, err := stream.nextEvent(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ev == nil {
|
||||
return nil, errors.New("MCP SSE response stream closed before response arrived")
|
||||
}
|
||||
if ev.event != "" && ev.event != "message" {
|
||||
continue
|
||||
}
|
||||
raw := []byte(ev.data)
|
||||
if len(bytes.TrimSpace(raw)) == 0 {
|
||||
continue
|
||||
}
|
||||
parsed, err := parseJSONRPC(raw, wantID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if normalizeID(parsed.ID) == normalizeID(wantID) {
|
||||
return parsed, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- shared helpers ----------
|
||||
|
||||
func parseJSONRPC(raw []byte, wantID interface{}) (*jsonRPCResponse, error) {
|
||||
dec := json.NewDecoder(bytes.NewReader(raw))
|
||||
dec.UseNumber()
|
||||
res := &jsonRPCResponse{}
|
||||
if err := dec.Decode(res); err != nil {
|
||||
return nil, fmt.Errorf("parse MCP response: %w", err)
|
||||
}
|
||||
if wantID != nil && res.ID != nil && normalizeID(res.ID) != normalizeID(wantID) {
|
||||
return nil, fmt.Errorf("unexpected JSON-RPC id %v (want %v)", res.ID, wantID)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func parseToolsResult(raw json.RawMessage) ([]Tool, error) {
|
||||
if len(raw) == 0 {
|
||||
return []Tool{}, nil
|
||||
}
|
||||
var envelope struct {
|
||||
Tools []map[string]interface{} `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &envelope); err != nil {
|
||||
return nil, fmt.Errorf("parse tools result: %w", err)
|
||||
}
|
||||
tools := make([]Tool, 0, len(envelope.Tools))
|
||||
for _, raw := range envelope.Tools {
|
||||
name, _ := raw["name"].(string)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
desc, _ := raw["description"].(string)
|
||||
var schema map[string]interface{}
|
||||
if s, ok := raw["inputSchema"].(map[string]interface{}); ok {
|
||||
schema = s
|
||||
}
|
||||
tools = append(tools, Tool{
|
||||
Name: name,
|
||||
Description: desc,
|
||||
InputSchema: schema,
|
||||
Raw: raw,
|
||||
})
|
||||
}
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func formatMCPError(method string, e *jsonRPCError) error {
|
||||
if e == nil {
|
||||
return fmt.Errorf("MCP %s failed", method)
|
||||
}
|
||||
return fmt.Errorf("MCP %s failed (%d): %s", method, e.Code, e.Message)
|
||||
}
|
||||
|
||||
// mapMCPConnectionError surfaces the same wording the Python session uses
|
||||
// when a low-level connection fails (authentication / network).
|
||||
func mapMCPConnectionError(err error) error {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return errors.New("Timeout connecting to MCP server")
|
||||
}
|
||||
return fmt.Errorf("Connection failed (possibly due to auth error). Please check authentication settings first: %v", err)
|
||||
}
|
||||
254
internal/mcpclient/client_test.go
Normal file
254
internal/mcpclient/client_test.go
Normal file
@@ -0,0 +1,254 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package mcpclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/utility"
|
||||
)
|
||||
|
||||
// allowLoopbackForTests overrides the SSRF guard's resolver so 127.0.0.1
|
||||
// targets used by httptest are accepted by AssertURLSafe.
|
||||
func allowLoopbackForTests(t *testing.T) func() {
|
||||
t.Helper()
|
||||
orig := utility.LookupHost
|
||||
utility.LookupHost = func(host string) ([]string, error) {
|
||||
// Return a public IPv4 so the guard sees the host as global; the
|
||||
// httptest server is on loopback but we connect via raw URL.
|
||||
return []string{"8.8.8.8"}, nil
|
||||
}
|
||||
return func() { utility.LookupHost = orig }
|
||||
}
|
||||
|
||||
func TestFetchToolsStreamableHTTPJSON(t *testing.T) {
|
||||
defer allowLoopbackForTests(t)()
|
||||
|
||||
var initCount, listCount, notifyCount int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var req map[string]interface{}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
// testing.T's Fatal* must run on the test goroutine; surface the
|
||||
// failure via Errorf and bail the handler out instead.
|
||||
t.Errorf("invalid request body: %v", err)
|
||||
http.Error(w, "bad body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch req["method"] {
|
||||
case "initialize":
|
||||
atomic.AddInt32(&initCount, 1)
|
||||
w.Header().Set(sessionHeader, "test-session")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintf(w, `{"jsonrpc":"2.0","id":%v,"result":{"capabilities":{}}}`, req["id"])
|
||||
case "notifications/initialized":
|
||||
atomic.AddInt32(¬ifyCount, 1)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
case "tools/list":
|
||||
atomic.AddInt32(&listCount, 1)
|
||||
if got := r.Header.Get(sessionHeader); got != "test-session" {
|
||||
t.Errorf("expected session header to be propagated, got %q", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintf(w, `{"jsonrpc":"2.0","id":%v,"result":{"tools":[{"name":"search","description":"Find docs","inputSchema":{"type":"object"}},{"name":"fetch"}]}}`, req["id"])
|
||||
default:
|
||||
http.Error(w, "unexpected method", http.StatusBadRequest)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tools, err := FetchTools(context.Background(), FetchOptions{
|
||||
URL: srv.URL,
|
||||
ServerType: TransportStreamableHTTP,
|
||||
HTTPClient: srv.Client(),
|
||||
Timeout: 2 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got := len(tools); got != 2 {
|
||||
t.Fatalf("expected 2 tools, got %d", got)
|
||||
}
|
||||
if tools[0].Name != "search" || tools[0].Description != "Find docs" {
|
||||
t.Errorf("tool 0 = %+v", tools[0])
|
||||
}
|
||||
if tools[1].Name != "fetch" {
|
||||
t.Errorf("tool 1 = %+v", tools[1])
|
||||
}
|
||||
if atomic.LoadInt32(&initCount) != 1 || atomic.LoadInt32(¬ifyCount) != 1 || atomic.LoadInt32(&listCount) != 1 {
|
||||
t.Errorf("expected 1 init / 1 notify / 1 list, got %d/%d/%d", initCount, notifyCount, listCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchToolsStreamableHTTPErrorResponse(t *testing.T) {
|
||||
defer allowLoopbackForTests(t)()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var req map[string]interface{}
|
||||
_ = json.Unmarshal(body, &req)
|
||||
if req["method"] == "initialize" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintf(w, `{"jsonrpc":"2.0","id":%v,"error":{"code":-32600,"message":"bad init"}}`, req["id"])
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
_, err := FetchTools(context.Background(), FetchOptions{
|
||||
URL: srv.URL,
|
||||
ServerType: TransportStreamableHTTP,
|
||||
HTTPClient: srv.Client(),
|
||||
Timeout: 2 * time.Second,
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "bad init") {
|
||||
t.Fatalf("expected MCP error to surface, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchToolsSSE(t *testing.T) {
|
||||
defer allowLoopbackForTests(t)()
|
||||
|
||||
type ssePush struct {
|
||||
event string
|
||||
data string
|
||||
}
|
||||
pushes := make(chan ssePush, 4)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/sse", func(w http.ResponseWriter, r *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
t.Errorf("response writer is not a flusher")
|
||||
http.Error(w, "no flusher", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
fmt.Fprintf(w, "event: endpoint\ndata: /messages\n\n")
|
||||
flusher.Flush()
|
||||
ctx := r.Context()
|
||||
for {
|
||||
select {
|
||||
case p := <-pushes:
|
||||
if p.event != "" {
|
||||
fmt.Fprintf(w, "event: %s\n", p.event)
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", p.data)
|
||||
flusher.Flush()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/messages", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var req map[string]interface{}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
t.Errorf("invalid request body: %v", err)
|
||||
http.Error(w, "bad body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch req["method"] {
|
||||
case "initialize":
|
||||
pushes <- ssePush{event: "message", data: fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"capabilities":{}}}`, req["id"])}
|
||||
case "notifications/initialized":
|
||||
case "tools/list":
|
||||
pushes <- ssePush{event: "message", data: fmt.Sprintf(`{"jsonrpc":"2.0","id":%v,"result":{"tools":[{"name":"alpha"}]}}`, req["id"])}
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
tools, err := FetchTools(context.Background(), FetchOptions{
|
||||
URL: srv.URL + "/sse",
|
||||
ServerType: TransportSSE,
|
||||
HTTPClient: srv.Client(),
|
||||
Timeout: 3 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(tools) != 1 || tools[0].Name != "alpha" {
|
||||
t.Fatalf("expected [alpha], got %+v", tools)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchToolsUnsupportedType(t *testing.T) {
|
||||
defer allowLoopbackForTests(t)()
|
||||
_, err := FetchTools(context.Background(), FetchOptions{
|
||||
URL: "https://example.com",
|
||||
ServerType: "stdio",
|
||||
Timeout: time.Second,
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "Unsupported MCP server type") {
|
||||
t.Fatalf("expected unsupported-type error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchToolsEmptyURL(t *testing.T) {
|
||||
_, err := FetchTools(context.Background(), FetchOptions{URL: "", ServerType: TransportSSE})
|
||||
if err == nil || !strings.Contains(err.Error(), "Invalid url") {
|
||||
t.Fatalf("expected Invalid url error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubstituteTemplate(t *testing.T) {
|
||||
vars := map[string]string{"token": "abc123"}
|
||||
if got := substituteTemplate("Bearer ${token}", vars); got != "Bearer abc123" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
if got := substituteTemplate("Bearer ${missing}", vars); got != "Bearer ${missing}" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
if got := substituteTemplate("no-var", vars); got != "no-var" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
if got := substituteTemplate("${a}-${token}", map[string]string{"a": "1", "token": "2"}); got != "1-2" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeID(t *testing.T) {
|
||||
if got := normalizeID(json.Number("1")); got != "1" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
if got := normalizeID(1); got != "1" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
if got := normalizeID("abc"); got != "abc" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
if got := normalizeID(nil); got != "" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -302,14 +302,6 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
// message.GET("/:memory_id/:message_id/content", r.memoryHandler.GetMessageContent)
|
||||
// }
|
||||
|
||||
mcp := v1.Group("/mcp")
|
||||
{
|
||||
mcp.POST("/servers", r.mcpHandler.CreateMCPServer)
|
||||
mcp.GET("/servers", r.mcpHandler.ListMCPServers)
|
||||
mcp.PUT("/servers/:mcp_id", r.mcpHandler.UpdateMCPServer)
|
||||
mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer)
|
||||
}
|
||||
|
||||
// Skill search routes
|
||||
skills := v1.Group("/skills")
|
||||
{
|
||||
@@ -394,6 +386,21 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
connector.POST("/:connector_id/test", r.connectorHandler.TestConnector)
|
||||
}
|
||||
|
||||
// MCP server routes. Per-server CRUD ships via separate PRs that
|
||||
// share the same handler/service: GET list (#15253), GET by id
|
||||
// (#15254), POST create (#15260, merged), PUT (#15261), DELETE
|
||||
// (#15262, merged). This PR adds only the non-overlapping
|
||||
// endpoints: import and test.
|
||||
mcp := v1.Group("/mcp")
|
||||
{
|
||||
mcp.POST("/servers", r.mcpHandler.CreateMCPServer)
|
||||
mcp.GET("/servers", r.mcpHandler.ListMCPServers)
|
||||
mcp.PUT("/servers/:mcp_id", r.mcpHandler.UpdateMCPServer)
|
||||
mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer)
|
||||
mcp.POST("/servers/import", r.mcpHandler.ImportMCPServers)
|
||||
mcp.POST("/servers/:mcp_id/test", r.mcpHandler.TestMCPServer)
|
||||
}
|
||||
|
||||
system := v1.Group("/system")
|
||||
{
|
||||
system.GET("/configs", r.systemHandler.GetConfigs)
|
||||
|
||||
@@ -18,6 +18,7 @@ package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -27,6 +28,8 @@ import (
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/mcpclient"
|
||||
"ragflow/internal/utility"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -35,6 +38,7 @@ const (
|
||||
mcpServerTypeSSE = "sse"
|
||||
mcpServerTypeStreamableHTTP = "streamable-http"
|
||||
mcpServerNameLimit = 255
|
||||
defaultMCPFetchTimeoutSec = 10
|
||||
mcpServerDateFormat = "2006-01-02T15:04:05"
|
||||
)
|
||||
|
||||
@@ -382,6 +386,267 @@ func safeJSONMap(raw json.RawMessage) entity.JSONMap {
|
||||
return entity.JSONMap(value)
|
||||
}
|
||||
|
||||
// ---------- import + test (this PR's additions) ----------
|
||||
|
||||
// Sentinel errors mapped by the handler to Python's response codes for the
|
||||
// import / test endpoints. Per-server CRUD errors stay inside CreateMCPServer.
|
||||
var (
|
||||
// ErrMCPInvalidType mirrors Python's "Unsupported MCP server type.".
|
||||
ErrMCPInvalidType = errors.New("unsupported MCP server type")
|
||||
// ErrMCPInvalidName mirrors Python's invalid-name/length error.
|
||||
ErrMCPInvalidName = errors.New("invalid MCP name")
|
||||
// ErrMCPInvalidURL mirrors Python's "Invalid url.".
|
||||
ErrMCPInvalidURL = errors.New("invalid url")
|
||||
// ErrMCPTestFailed is returned by TestServer when the live connection or
|
||||
// tool-list fetch fails. The handler maps this to code 102 (DATA_ERROR),
|
||||
// matching Python's test_mcp which never returns HTTP 500 for fetch errors.
|
||||
ErrMCPTestFailed = errors.New("MCP test failed")
|
||||
)
|
||||
|
||||
// ImportResult is a single per-server outcome in the bulk import response,
|
||||
// matching the shape returned by Python's import_multiple.
|
||||
type ImportResult struct {
|
||||
Server string `json:"server"`
|
||||
Success bool `json:"success"`
|
||||
Action string `json:"action,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
NewName string `json:"new_name,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ImportServers bulk-imports MCP servers from a {"mcpServers": {name: config}}
|
||||
// map. For each entry: validate type and URL, de-duplicate the name with a
|
||||
// "_N" suffix, fetch the remote tool list via mcpclient (SSRF-guarded), and
|
||||
// persist the server with tools stored under variables.tools. Mirrors
|
||||
// Python's import_multiple.
|
||||
//
|
||||
// timeoutSeconds controls how long each tool-fetch call waits; <=0 falls back
|
||||
// to the Python default of 10 s.
|
||||
func (s *MCPService) ImportServers(tenantID string, servers map[string]map[string]interface{}, timeoutSeconds float64) ([]ImportResult, error) {
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = defaultMCPFetchTimeoutSec
|
||||
}
|
||||
timeout := time.Duration(timeoutSeconds * float64(time.Second))
|
||||
|
||||
results := make([]ImportResult, 0, len(servers))
|
||||
for serverName, config := range servers {
|
||||
url, hasURL := config["url"].(string)
|
||||
stype, hasType := config["type"].(string)
|
||||
if !hasType || !hasURL {
|
||||
results = append(results, ImportResult{Server: serverName, Success: false, Message: "Missing required fields (type or url)"})
|
||||
continue
|
||||
}
|
||||
if serverName == "" || len([]byte(serverName)) > mcpServerNameLimit {
|
||||
results = append(results, ImportResult{Server: serverName, Success: false, Message: fmt.Sprintf("Invalid MCP name or length is %d which is large than 255.", len(serverName))})
|
||||
continue
|
||||
}
|
||||
if !isValidMCPServerType(stype) {
|
||||
results = append(results, ImportResult{Server: serverName, Success: false, Message: "Unsupported MCP server type."})
|
||||
continue
|
||||
}
|
||||
|
||||
baseName := serverName
|
||||
newName, err := s.nextAvailableMCPName(baseName, tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
variables := map[string]interface{}{}
|
||||
stringVars := map[string]string{}
|
||||
for k, v := range config {
|
||||
if k == "type" || k == "url" || k == "headers" {
|
||||
continue
|
||||
}
|
||||
variables[k] = v
|
||||
if sv, ok := v.(string); ok {
|
||||
stringVars[k] = sv
|
||||
}
|
||||
}
|
||||
delete(variables, "tools")
|
||||
delete(stringVars, "tools")
|
||||
|
||||
// Headers can be provided either as a top-level "headers" map
|
||||
// (preferred — matches the Python import shape) or as a flat
|
||||
// "authorization_token" string at the entry root. Both go to the
|
||||
// MCP client for tool discovery and to the persisted record so
|
||||
// configs that depend on custom auth headers survive the round
|
||||
// trip.
|
||||
headers := map[string]string{}
|
||||
headerVals := map[string]interface{}{}
|
||||
if rawHeaders, ok := config["headers"].(map[string]interface{}); ok {
|
||||
for k, v := range rawHeaders {
|
||||
if sv, ok := v.(string); ok {
|
||||
headers[k] = sv
|
||||
}
|
||||
headerVals[k] = v
|
||||
}
|
||||
}
|
||||
if token, ok := config["authorization_token"].(string); ok {
|
||||
if _, exists := headers["authorization_token"]; !exists {
|
||||
headers["authorization_token"] = token
|
||||
}
|
||||
if _, exists := headerVals["authorization_token"]; !exists {
|
||||
headerVals["authorization_token"] = token
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
tools, fetchErr := mcpclient.FetchTools(ctx, mcpclient.FetchOptions{
|
||||
URL: url,
|
||||
ServerType: stype,
|
||||
Headers: headers,
|
||||
Variables: stringVars,
|
||||
Timeout: timeout,
|
||||
})
|
||||
cancel()
|
||||
if fetchErr != nil {
|
||||
results = append(results, ImportResult{Server: baseName, Success: false, Message: fetchErr.Error()})
|
||||
continue
|
||||
}
|
||||
variables["tools"] = toolsAsMap(tools)
|
||||
|
||||
server := &entity.MCPServer{
|
||||
ID: common.GenerateUUID(),
|
||||
TenantID: tenantID,
|
||||
Name: newName,
|
||||
URL: url,
|
||||
ServerType: stype,
|
||||
Variables: entity.JSONMap(variables),
|
||||
Headers: entity.JSONMap(headerVals),
|
||||
}
|
||||
if err := s.mcpServerDAO.CreateMCPServer(server); err != nil {
|
||||
results = append(results, ImportResult{Server: serverName, Success: false, Message: "Failed to create MCP server."})
|
||||
continue
|
||||
}
|
||||
|
||||
result := ImportResult{Server: serverName, Success: true, Action: "created", ID: server.ID, NewName: newName}
|
||||
if newName != baseName {
|
||||
result.Message = fmt.Sprintf("Renamed from '%s' to '%s' avoid duplication", baseName, newName)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *MCPService) nextAvailableMCPName(base, tenantID string) (string, error) {
|
||||
name := base
|
||||
counter := 0
|
||||
for {
|
||||
exists, err := s.mcpServerDAO.ExistsByNameAndTenant(name, tenantID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !exists {
|
||||
return name, nil
|
||||
}
|
||||
name = fmt.Sprintf("%s_%d", base, counter)
|
||||
counter++
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerRequest is the body of POST /mcp/servers/:mcp_id/test. The mcp_id
|
||||
// from the URL path is threaded through to the connect call for log
|
||||
// correlation; the connection itself is opened from the request body so the
|
||||
// user can preview unsaved edits — matching Python's test_mcp.
|
||||
type TestServerRequest struct {
|
||||
URL string `json:"url"`
|
||||
ServerType string `json:"server_type"`
|
||||
Headers map[string]interface{} `json:"headers,omitempty"`
|
||||
Variables map[string]interface{} `json:"variables,omitempty"`
|
||||
Timeout float64 `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
// TestServer opens a live MCP session and returns the tools the server
|
||||
// advertises. Mirrors Python's test_mcp. mcpID is used for log correlation
|
||||
// only.
|
||||
func (s *MCPService) TestServer(mcpID string, req *TestServerRequest) ([]map[string]interface{}, error) {
|
||||
if req == nil || req.URL == "" {
|
||||
return nil, fmt.Errorf("%w: Invalid MCP url.", ErrMCPInvalidURL)
|
||||
}
|
||||
if !isValidMCPServerType(req.ServerType) {
|
||||
return nil, ErrMCPInvalidType
|
||||
}
|
||||
|
||||
// Run the SSRF guard up front so URL-shape failures (disallowed
|
||||
// scheme, missing host, non-public address) surface as
|
||||
// ErrMCPInvalidURL data errors instead of being swallowed inside the
|
||||
// generic FetchTools error and re-classified by the handler as a 500.
|
||||
// FetchTools repeats the check internally; the second call is cheap.
|
||||
if _, _, err := utility.AssertURLSafe(req.URL); err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrMCPInvalidURL, err.Error())
|
||||
}
|
||||
|
||||
timeoutSec := req.Timeout
|
||||
if timeoutSec <= 0 {
|
||||
timeoutSec = defaultMCPFetchTimeoutSec
|
||||
}
|
||||
timeout := time.Duration(timeoutSec * float64(time.Second))
|
||||
|
||||
headers := map[string]string{}
|
||||
for k, v := range req.Headers {
|
||||
if sv, ok := v.(string); ok {
|
||||
headers[k] = sv
|
||||
}
|
||||
}
|
||||
vars := map[string]string{}
|
||||
for k, v := range req.Variables {
|
||||
if sv, ok := v.(string); ok {
|
||||
vars[k] = sv
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
tools, err := mcpclient.FetchTools(ctx, mcpclient.FetchOptions{
|
||||
URL: req.URL,
|
||||
ServerType: req.ServerType,
|
||||
Headers: headers,
|
||||
Variables: vars,
|
||||
Timeout: timeout,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: Test MCP error (id=%s): %v", ErrMCPTestFailed, mcpID, err)
|
||||
}
|
||||
|
||||
out := make([]map[string]interface{}, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
raw := t.Raw
|
||||
if raw == nil {
|
||||
raw = map[string]interface{}{"name": t.Name}
|
||||
if t.Description != "" {
|
||||
raw["description"] = t.Description
|
||||
}
|
||||
if t.InputSchema != nil {
|
||||
raw["inputSchema"] = t.InputSchema
|
||||
}
|
||||
}
|
||||
raw["enabled"] = true
|
||||
out = append(out, raw)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// toolsAsMap mirrors Python's `{tool["name"]: tool ...}` shape used when
|
||||
// persisting variables.tools.
|
||||
func toolsAsMap(tools []mcpclient.Tool) map[string]interface{} {
|
||||
m := map[string]interface{}{}
|
||||
for _, t := range tools {
|
||||
if t.Raw != nil {
|
||||
m[t.Name] = t.Raw
|
||||
continue
|
||||
}
|
||||
entry := map[string]interface{}{"name": t.Name}
|
||||
if t.Description != "" {
|
||||
entry["description"] = t.Description
|
||||
}
|
||||
if t.InputSchema != nil {
|
||||
entry["inputSchema"] = t.InputSchema
|
||||
}
|
||||
m[t.Name] = entry
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func formatMCPServerDate(date *time.Time) *string {
|
||||
if date == nil {
|
||||
return nil
|
||||
|
||||
@@ -17,12 +17,75 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func TestIsValidMCPServerType(t *testing.T) {
|
||||
for _, v := range []string{mcpServerTypeSSE, mcpServerTypeStreamableHTTP} {
|
||||
if !isValidMCPServerType(v) {
|
||||
t.Errorf("expected %q to be a valid MCP server type", v)
|
||||
}
|
||||
}
|
||||
for _, v := range []string{"", "stdio", "http", "SSE"} {
|
||||
if isValidMCPServerType(v) {
|
||||
t.Errorf("expected %q to be an invalid MCP server type", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerInputValidation(t *testing.T) {
|
||||
s := &MCPService{}
|
||||
|
||||
// Empty URL is rejected before any connection attempt.
|
||||
if _, err := s.TestServer("id-1", &TestServerRequest{ServerType: mcpServerTypeSSE}); !errors.Is(err, ErrMCPInvalidURL) {
|
||||
t.Errorf("expected ErrMCPInvalidURL for empty url, got %v", err)
|
||||
}
|
||||
|
||||
// nil body is treated as empty URL.
|
||||
if _, err := s.TestServer("id-1", nil); !errors.Is(err, ErrMCPInvalidURL) {
|
||||
t.Errorf("expected ErrMCPInvalidURL for nil body, got %v", err)
|
||||
}
|
||||
|
||||
// Invalid server type is rejected before connecting.
|
||||
if _, err := s.TestServer("id-1", &TestServerRequest{URL: "http://example.com/sse", ServerType: "stdio"}); !errors.Is(err, ErrMCPInvalidType) {
|
||||
t.Errorf("expected ErrMCPInvalidType for bad type, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportServersValidationErrors(t *testing.T) {
|
||||
s := &MCPService{}
|
||||
|
||||
// Missing url and type produce an in-band error per entry rather than
|
||||
// failing the batch.
|
||||
configs := map[string]map[string]interface{}{
|
||||
"missing-fields": {"foo": "bar"},
|
||||
"bad-type": {"url": "http://example.com", "type": "stdio"},
|
||||
}
|
||||
results, err := s.ImportServers("tenant-1", configs, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(results))
|
||||
}
|
||||
for _, r := range results {
|
||||
if r.Success {
|
||||
t.Errorf("expected failure result for %q", r.Server)
|
||||
}
|
||||
if r.Server == "missing-fields" && !strings.Contains(r.Message, "Missing required fields") {
|
||||
t.Errorf("unexpected message for missing-fields: %q", r.Message)
|
||||
}
|
||||
if r.Server == "bad-type" && !strings.Contains(r.Message, "Unsupported MCP server type") {
|
||||
t.Errorf("unexpected message for bad-type: %q", r.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginateMCPServersNegativeValuesMatchPythonSlice(t *testing.T) {
|
||||
servers := makeMCPServers(13)
|
||||
|
||||
|
||||
190
internal/utility/ssrf.go
Normal file
190
internal/utility/ssrf.go
Normal file
@@ -0,0 +1,190 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package utility
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AllowedURLSchemes are the schemes accepted by AssertURLSafe.
|
||||
var AllowedURLSchemes = []string{"http", "https"}
|
||||
|
||||
// LookupHost is the indirection used to resolve hostnames. Tests override it.
|
||||
var LookupHost = net.LookupHost
|
||||
|
||||
// AssertURLSafe parses rawURL and rejects it if the scheme is disallowed,
|
||||
// the host is missing, or any resolved IP is not globally routable
|
||||
// (private, loopback, link-local, multicast, reserved). Returns the hostname
|
||||
// and the first validated public IP so callers can DNS-pin the address and
|
||||
// prevent rebinding between validation and the actual TCP connection.
|
||||
//
|
||||
// Mirrors common/ssrf_guard.py:assert_url_is_safe.
|
||||
func AssertURLSafe(rawURL string) (hostname, resolvedIP string, err error) {
|
||||
parsed, perr := url.Parse(strings.TrimSpace(rawURL))
|
||||
if perr != nil {
|
||||
return "", "", fmt.Errorf("Invalid url.")
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(parsed.Scheme)
|
||||
if !schemeAllowed(scheme) {
|
||||
sorted := append([]string(nil), AllowedURLSchemes...)
|
||||
sort.Strings(sorted)
|
||||
return "", "", fmt.Errorf("Disallowed URL scheme: '%s'. Only %v are allowed.", scheme, sorted)
|
||||
}
|
||||
|
||||
hostname = parsed.Hostname()
|
||||
if hostname == "" {
|
||||
return "", "", fmt.Errorf("URL is missing a host.")
|
||||
}
|
||||
|
||||
addrs, err := LookupHost(hostname)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("Could not resolve hostname '%s': %v", hostname, err)
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return "", "", fmt.Errorf("Hostname '%s' resolved to no addresses.", hostname)
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
ip := net.ParseIP(addr)
|
||||
if ip == nil {
|
||||
return "", "", fmt.Errorf("Could not parse resolved address '%s' for hostname '%s'.", addr, hostname)
|
||||
}
|
||||
if !isGlobalIP(effectiveIP(ip)) {
|
||||
return "", "", fmt.Errorf("URL resolves to a non-public address (%s), which is not allowed.", ip.String())
|
||||
}
|
||||
if resolvedIP == "" {
|
||||
resolvedIP = ip.String()
|
||||
}
|
||||
}
|
||||
return hostname, resolvedIP, nil
|
||||
}
|
||||
|
||||
func schemeAllowed(scheme string) bool {
|
||||
for _, s := range AllowedURLSchemes {
|
||||
if s == scheme {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// effectiveIP unwraps IPv4-mapped IPv6 addresses (e.g. ::ffff:127.0.0.1) so
|
||||
// the routability check sees the IPv4 form. Without this, an attacker could
|
||||
// bypass the guard with an IPv4-mapped IPv6 representation of a private host.
|
||||
func effectiveIP(ip net.IP) net.IP {
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
return v4
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// isGlobalIP mirrors Python's ipaddress.IPv*Address.is_global: an address is
|
||||
// global if it is none of {unspecified, loopback, multicast, link-local,
|
||||
// private (including CGNAT and IPv6 ULA), benchmarking, documentation,
|
||||
// reserved}.
|
||||
func isGlobalIP(ip net.IP) bool {
|
||||
if ip == nil || ip.IsUnspecified() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() || ip.IsPrivate() {
|
||||
return false
|
||||
}
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
// CGNAT 100.64.0.0/10 — not flagged by IsPrivate in older Go versions.
|
||||
if v4[0] == 100 && v4[1]&0xC0 == 64 {
|
||||
return false
|
||||
}
|
||||
// 192.0.0.0/24 reserved for IETF protocol assignments.
|
||||
if v4[0] == 192 && v4[1] == 0 && v4[2] == 0 {
|
||||
return false
|
||||
}
|
||||
// 192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24 documentation (TEST-NET-1/2/3).
|
||||
if v4[0] == 192 && v4[1] == 0 && v4[2] == 2 {
|
||||
return false
|
||||
}
|
||||
if v4[0] == 198 && v4[1] == 51 && v4[2] == 100 {
|
||||
return false
|
||||
}
|
||||
if v4[0] == 203 && v4[1] == 0 && v4[2] == 113 {
|
||||
return false
|
||||
}
|
||||
// 198.18.0.0/15 benchmarking.
|
||||
if v4[0] == 198 && (v4[1] == 18 || v4[1] == 19) {
|
||||
return false
|
||||
}
|
||||
// 240.0.0.0/4 reserved (excluding 255.255.255.255 which IsUnspecified misses).
|
||||
if v4[0] >= 240 {
|
||||
return false
|
||||
}
|
||||
} else if v6 := ip.To16(); v6 != nil {
|
||||
// 2001:db8::/32 documentation prefix.
|
||||
if v6[0] == 0x20 && v6[1] == 0x01 && v6[2] == 0x0d && v6[3] == 0xb8 {
|
||||
return false
|
||||
}
|
||||
// 100::/64 discard-only address block.
|
||||
if v6[0] == 0x01 && v6[1] == 0x00 && allZero(v6[2:8]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func allZero(b []byte) bool {
|
||||
for _, x := range b {
|
||||
if x != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// PinnedHTTPClient returns an HTTP client whose Transport rewrites every
|
||||
// outbound dial for hostname:port to resolvedIP:port, closing the TOCTOU
|
||||
// window between AssertURLSafe and the actual TCP connection. Pins are
|
||||
// scoped to this client only.
|
||||
func PinnedHTTPClient(hostname, resolvedIP string, timeout time.Duration) *http.Client {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: timeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
transport := &http.Transport{
|
||||
// Disable environment proxy: HTTP_PROXY / HTTPS_PROXY would route
|
||||
// the connection through the proxy host instead of the pinned
|
||||
// resolvedIP, bypassing the SSRF guard.
|
||||
Proxy: nil,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, splitErr := net.SplitHostPort(addr)
|
||||
if splitErr == nil && host == hostname && resolvedIP != "" {
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(resolvedIP, port))
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
TLSHandshakeTimeout: timeout,
|
||||
ResponseHeaderTimeout: timeout,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ForceAttemptHTTP2: false,
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: timeout,
|
||||
}
|
||||
}
|
||||
156
internal/utility/ssrf_test.go
Normal file
156
internal/utility/ssrf_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package utility
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAssertURLSafe(t *testing.T) {
|
||||
orig := LookupHost
|
||||
defer func() { LookupHost = orig }()
|
||||
|
||||
type want struct {
|
||||
errSubstr string
|
||||
host string
|
||||
ip string
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
ips []string
|
||||
err string
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "public IPv4",
|
||||
url: "https://example.com/path",
|
||||
ips: []string{"93.184.216.34"},
|
||||
want: want{host: "example.com", ip: "93.184.216.34"},
|
||||
},
|
||||
{
|
||||
name: "loopback rejected",
|
||||
url: "http://localhost/x",
|
||||
ips: []string{"127.0.0.1"},
|
||||
want: want{errSubstr: "non-public address"},
|
||||
},
|
||||
{
|
||||
name: "private 10.x rejected",
|
||||
url: "http://internal/x",
|
||||
ips: []string{"10.0.0.5"},
|
||||
want: want{errSubstr: "non-public address"},
|
||||
},
|
||||
{
|
||||
name: "private 192.168.x rejected",
|
||||
url: "http://router/x",
|
||||
ips: []string{"192.168.1.1"},
|
||||
want: want{errSubstr: "non-public address"},
|
||||
},
|
||||
{
|
||||
name: "CGNAT 100.64/10 rejected",
|
||||
url: "http://carrier/x",
|
||||
ips: []string{"100.64.1.1"},
|
||||
want: want{errSubstr: "non-public address"},
|
||||
},
|
||||
{
|
||||
name: "IPv4-mapped IPv6 loopback rejected",
|
||||
url: "http://[::ffff:127.0.0.1]/x",
|
||||
ips: []string{"::ffff:127.0.0.1"},
|
||||
want: want{errSubstr: "non-public address"},
|
||||
},
|
||||
{
|
||||
name: "link-local IPv6 rejected",
|
||||
url: "http://[fe80::1]/x",
|
||||
ips: []string{"fe80::1"},
|
||||
want: want{errSubstr: "non-public address"},
|
||||
},
|
||||
{
|
||||
name: "documentation 2001:db8 rejected",
|
||||
url: "http://[2001:db8::1]/x",
|
||||
ips: []string{"2001:db8::1"},
|
||||
want: want{errSubstr: "non-public address"},
|
||||
},
|
||||
{
|
||||
name: "disallowed scheme ftp",
|
||||
url: "ftp://example.com/",
|
||||
ips: []string{"93.184.216.34"},
|
||||
want: want{errSubstr: "Disallowed URL scheme"},
|
||||
},
|
||||
{
|
||||
name: "missing host",
|
||||
url: "http:///path",
|
||||
want: want{errSubstr: "missing a host"},
|
||||
},
|
||||
{
|
||||
name: "resolution fails",
|
||||
url: "http://nosuchhost.test/x",
|
||||
err: "no such host",
|
||||
want: want{errSubstr: "Could not resolve"},
|
||||
},
|
||||
{
|
||||
name: "all addresses must be public",
|
||||
url: "http://mixed.example.com/",
|
||||
ips: []string{"93.184.216.34", "127.0.0.1"},
|
||||
want: want{errSubstr: "non-public address"},
|
||||
},
|
||||
{
|
||||
name: "literal IPv4 loopback rejected",
|
||||
url: "http://127.0.0.1/",
|
||||
ips: []string{"127.0.0.1"},
|
||||
want: want{errSubstr: "non-public address"},
|
||||
},
|
||||
{
|
||||
name: "documentation TEST-NET-3 rejected",
|
||||
url: "http://stub/",
|
||||
ips: []string{"203.0.113.5"},
|
||||
want: want{errSubstr: "non-public address"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
LookupHost = func(host string) ([]string, error) {
|
||||
if tc.err != "" {
|
||||
return nil, &mockErr{tc.err}
|
||||
}
|
||||
return tc.ips, nil
|
||||
}
|
||||
host, ip, err := AssertURLSafe(tc.url)
|
||||
if tc.want.errSubstr != "" {
|
||||
if err == nil || !strings.Contains(err.Error(), tc.want.errSubstr) {
|
||||
t.Fatalf("expected error containing %q, got %v", tc.want.errSubstr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if host != tc.want.host {
|
||||
t.Errorf("host: got %q, want %q", host, tc.want.host)
|
||||
}
|
||||
if ip != tc.want.ip {
|
||||
t.Errorf("ip: got %q, want %q", ip, tc.want.ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockErr struct{ s string }
|
||||
|
||||
func (e *mockErr) Error() string { return e.s }
|
||||
Reference in New Issue
Block a user