feat: add Go MCP server update API (#15261)

## What

#15240
implementation for PUT /api/v1/mcp/servers/:mcp_id

## Changes

- Adds the Go implementation for `PUT /api/v1/mcp/servers/:mcp_id`.
- Wires MCP service and handler into the Go server/router for the update
route.
- Preserves Python-style behavior for ownership checks, partial update
fields, MCP type/name/URL validation, `headers`/`variables`
normalization, and tool metadata scrubbing.
This commit is contained in:
Alexander Laurent
2026-06-01 21:58:44 -10:00
committed by GitHub
parent 2819d0ea24
commit a98889cd76
8 changed files with 313 additions and 3 deletions

View File

@@ -14,7 +14,7 @@ load_env_file() {
echo "Loading environment variables from: $env_file"
# Source the .env file
set -a
source "$env_file"
source "$env_file"
set +a
else
echo "Warning: .env file not found at: $env_file"

View File

@@ -12,7 +12,7 @@ Integrate an RSS feed as a data source.
---
This guide explains how to add an RSS feed as a data source to your dataset in RAGFlow.
This guide explains how to add an RSS feed as a data source to your dataset in RAGFlow.
RSS (Really Simple Syndication) is a standardized web feed format used to publish frequently updated content—such as blog entries, news headlines, and podcasts. By connecting an RSS feed to RAGFlow, you can automatically ingest new content from a website as soon as it is published.

View File

@@ -110,6 +110,15 @@ func (dao *MCPServerDAO) ListMCPServers(tenantID string, ids []string, keywords
return servers, total, nil
}
// GetByIDAndTenant returns an MCP server owned by a tenant.
func (dao *MCPServerDAO) GetByIDAndTenant(id, tenantID string) (*entity.MCPServer, error) {
var server entity.MCPServer
if err := DB.Where("id = ? AND tenant_id = ?", id, tenantID).First(&server).Error; err != nil {
return nil, err
}
return &server, nil
}
// DeleteMCPServer deletes an MCP server owned by a tenant.
func (dao *MCPServerDAO) DeleteMCPServer(id, tenantID string) (bool, error) {
result := DB.Where("id = ? AND tenant_id = ?", id, tenantID).Delete(&entity.MCPServer{})
@@ -119,6 +128,17 @@ func (dao *MCPServerDAO) DeleteMCPServer(id, tenantID string) (bool, error) {
return result.RowsAffected > 0, nil
}
// UpdateMCPServer updates an MCP server owned by a tenant.
func (dao *MCPServerDAO) UpdateMCPServer(id, tenantID string, updates map[string]interface{}) (bool, error) {
result := DB.Model(&entity.MCPServer{}).
Where("id = ? AND tenant_id = ?", id, tenantID).
Updates(updates)
if result.Error != nil {
return false, result.Error
}
return result.RowsAffected > 0, nil
}
func mcpServerOrderColumn(orderby string) (string, error) {
switch orderby {
case "id":

View File

@@ -21,10 +21,12 @@ import (
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"ragflow/internal/common"
"ragflow/internal/entity"
"ragflow/internal/service"
)
@@ -32,6 +34,7 @@ const (
defaultMCPServerPage = 0
defaultMCPServerPageSize = 0
maxMCPServerPageSize = 100
mcpServerDateFormat = "2006-01-02T15:04:05"
)
// MCPHandler handles MCP server requests.
@@ -39,6 +42,21 @@ type MCPHandler struct {
mcpService *service.MCPService
}
type mcpServerResponse struct {
ID string `json:"id"`
Name string `json:"name"`
TenantID string `json:"tenant_id"`
URL string `json:"url"`
ServerType string `json:"server_type"`
Description *string `json:"description"`
Variables map[string]interface{} `json:"variables"`
Headers map[string]interface{} `json:"headers"`
CreateTime *int64 `json:"create_time"`
CreateDate string `json:"create_date"`
UpdateTime *int64 `json:"update_time"`
UpdateDate string `json:"update_date"`
}
// NewMCPHandler creates an MCP handler.
func NewMCPHandler(mcpService *service.MCPService) *MCPHandler {
return &MCPHandler{
@@ -118,6 +136,34 @@ func (h *MCPHandler) ListMCPServers(c *gin.Context) {
})
}
// UpdateMCPServer updates an MCP server for the current user.
func (h *MCPHandler) UpdateMCPServer(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
mcpID := c.Param("mcp_id")
var req service.UpdateMCPServerRequest
if err := c.ShouldBindJSON(&req); err != nil {
jsonError(c, common.CodeDataError, err.Error())
return
}
result, code, err := h.mcpService.UpdateMCPServer(user.ID, mcpID, req)
if err != nil {
jsonError(c, code, err.Error())
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": "success",
"data": newMCPServerResponse(result),
})
}
// DeleteMCPServer deletes an MCP server for the current user.
func (h *MCPHandler) DeleteMCPServer(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
@@ -139,6 +185,34 @@ func (h *MCPHandler) DeleteMCPServer(c *gin.Context) {
})
}
func newMCPServerResponse(server *entity.MCPServer) *mcpServerResponse {
if server == nil {
return nil
}
return &mcpServerResponse{
ID: server.ID,
Name: server.Name,
TenantID: server.TenantID,
URL: server.URL,
ServerType: server.ServerType,
Description: server.Description,
Variables: map[string]interface{}(server.Variables),
Headers: map[string]interface{}(server.Headers),
CreateTime: server.CreateTime,
CreateDate: formatMCPServerDate(server.CreateDate),
UpdateTime: server.UpdateTime,
UpdateDate: formatMCPServerDate(server.UpdateDate),
}
}
func formatMCPServerDate(date *time.Time) string {
if date == nil {
return ""
}
return date.Format(mcpServerDateFormat)
}
func parseMCPServerPage(value string) (int, error) {
if value == "" {
return defaultMCPServerPage, nil

View File

@@ -0,0 +1,71 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"bytes"
"encoding/json"
"testing"
"time"
"ragflow/internal/entity"
)
func TestNewMCPServerResponsePreservesNullDescriptionAndFormatsDates(t *testing.T) {
createTime := int64(1779949820000)
updateTime := int64(1779953420000)
location := time.FixedZone("UTC+8", 8*60*60)
createDate := time.Date(2026, 5, 28, 10, 30, 20, 0, location)
updateDate := time.Date(2026, 5, 28, 11, 30, 20, 0, location)
response := newMCPServerResponse(&entity.MCPServer{
ID: "mcp-id",
Name: "server",
TenantID: "tenant-id",
URL: "https://example.com/mcp",
ServerType: "sse",
Variables: entity.JSONMap{"tools": map[string]interface{}{}},
Headers: entity.JSONMap{"Authorization": "Bearer token"},
BaseModel: entity.BaseModel{
CreateTime: &createTime,
CreateDate: &createDate,
UpdateTime: &updateTime,
UpdateDate: &updateDate,
},
})
if response.Description != nil {
t.Fatalf("description = %q, want nil", *response.Description)
}
if response.CreateDate != "2026-05-28T10:30:20" {
t.Fatalf("create_date = %q, want date without timezone", response.CreateDate)
}
if response.UpdateDate != "2026-05-28T11:30:20" {
t.Fatalf("update_date = %q, want date without timezone", response.UpdateDate)
}
payload, err := json.Marshal(response)
if err != nil {
t.Fatalf("marshal response: %v", err)
}
if !bytes.Contains(payload, []byte(`"description":null`)) {
t.Fatalf("payload %s does not preserve description:null", payload)
}
if bytes.Contains(payload, []byte(`+08:00`)) {
t.Fatalf("payload %s includes timezone in date fields", payload)
}
}

View File

@@ -280,6 +280,7 @@ func (r *Router) Setup(engine *gin.Engine) {
{
mcp.POST("/servers", r.mcpHandler.CreateMCPServer)
mcp.GET("/servers", r.mcpHandler.ListMCPServers)
mcp.PUT("/servers/:mcp_id", r.mcpHandler.UpdateMCPServer)
mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer)
}

View File

@@ -21,11 +21,14 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
"gorm.io/gorm"
)
const (
@@ -71,6 +74,9 @@ type CreateMCPServerResponse struct {
Headers entity.JSONMap `json:"headers"`
}
// UpdateMCPServerRequest is the raw request payload for updating an MCP server.
type UpdateMCPServerRequest map[string]json.RawMessage
// MCPServerListItem is an MCP server item in the list response.
type MCPServerListItem struct {
ID string `json:"id"`
@@ -147,6 +153,120 @@ func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest
}, common.CodeSuccess, nil
}
// UpdateMCPServer updates an MCP server owned by a tenant.
func (s *MCPService) UpdateMCPServer(tenantID, mcpID string, req UpdateMCPServerRequest) (*entity.MCPServer, common.ErrorCode, error) {
server, err := s.mcpServerDAO.GetByIDAndTenant(mcpID, tenantID)
if err != nil {
if isMCPServerNotFound(err) {
return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID)
}
return nil, common.CodeServerError, fmt.Errorf("failed to get MCP server %s: %w", mcpID, err)
}
if server == nil {
return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID)
}
serverType := server.ServerType
serverTypeProvided := false
if value, ok, err := optionalString(req, "server_type"); err != nil {
return nil, common.CodeDataError, err
} else if ok {
serverType = value
serverTypeProvided = true
}
if serverTypeProvided && !isValidMCPServerType(serverType) {
return nil, common.CodeDataError, errors.New("Unsupported MCP server type.")
}
serverName := server.Name
serverNameProvided := false
if value, ok, err := optionalString(req, "name"); err != nil {
return nil, common.CodeDataError, err
} else if ok {
serverName = value
serverNameProvided = true
}
if serverName != "" && len([]byte(serverName)) > mcpServerNameLimit {
return nil, common.CodeDataError, fmt.Errorf("Invalid MCP name or length is %d which is large than 255.", len([]byte(serverName)))
}
serverURL := server.URL
serverURLProvided := false
if value, ok, err := optionalString(req, "url"); err != nil {
return nil, common.CodeDataError, err
} else if ok {
serverURL = strings.TrimSpace(value)
if serverURL == "" {
return nil, common.CodeDataError, errors.New("Invalid url.")
}
serverURLProvided = true
}
if serverURL == "" {
return nil, common.CodeDataError, errors.New("Invalid url.")
}
headers := server.Headers
if raw, ok := req["headers"]; ok {
headers = safeJSONMap(raw)
}
if headers == nil {
headers = entity.JSONMap{}
}
variables := server.Variables
if raw, ok := req["variables"]; ok {
variables = safeJSONMap(raw)
}
if variables == nil {
variables = entity.JSONMap{}
}
delete(variables, "tools")
variables["tools"] = map[string]interface{}{}
updates := map[string]interface{}{
"id": mcpID,
"tenant_id": tenantID,
"headers": headers,
"variables": variables,
}
if serverNameProvided {
updates["name"] = serverName
}
if serverURLProvided {
updates["url"] = serverURL
}
if serverTypeProvided {
updates["server_type"] = serverType
}
if raw, ok := req["description"]; ok {
description, err := optionalNullableString(raw, "description")
if err != nil {
return nil, common.CodeDataError, err
}
updates["description"] = description
}
if _, err := s.mcpServerDAO.UpdateMCPServer(mcpID, tenantID, updates); err != nil {
return nil, common.CodeServerError, err
}
updatedServer, err := s.mcpServerDAO.GetByIDAndTenant(mcpID, tenantID)
if err != nil {
if isMCPServerNotFound(err) {
return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID)
}
return nil, common.CodeServerError, fmt.Errorf("failed to fetch updated MCP server %s: %w", mcpID, err)
}
if updatedServer == nil {
return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID)
}
return updatedServer, common.CodeSuccess, nil
}
func isMCPServerNotFound(err error) bool {
return errors.Is(err, gorm.ErrRecordNotFound)
}
// ListMCPServers lists MCP servers owned by a tenant.
func (s *MCPService) ListMCPServers(tenantID string, ids []string, keywords string, page, pageSize int, orderby string, desc bool) (*ListMCPServersResponse, common.ErrorCode, error) {
servers, total, err := s.mcpServerDAO.ListMCPServers(tenantID, ids, keywords, orderby, desc)
@@ -215,6 +335,31 @@ func isValidMCPServerType(serverType string) bool {
return serverType == mcpServerTypeSSE || serverType == mcpServerTypeStreamableHTTP
}
func optionalString(req UpdateMCPServerRequest, key string) (string, bool, error) {
raw, ok := req[key]
if !ok || bytes.Equal(bytes.TrimSpace(raw), []byte("null")) {
return "", false, nil
}
var value string
if err := json.Unmarshal(raw, &value); err != nil {
return "", true, fmt.Errorf("%s must be a string", key)
}
return value, true, nil
}
func optionalNullableString(raw json.RawMessage, key string) (*string, error) {
if bytes.Equal(bytes.TrimSpace(raw), []byte("null")) {
return nil, nil
}
var value string
if err := json.Unmarshal(raw, &value); err != nil {
return nil, fmt.Errorf("%s must be a string", key)
}
return &value, nil
}
func safeJSONMap(raw json.RawMessage) entity.JSONMap {
raw = bytes.TrimSpace(raw)
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {

View File

@@ -310,4 +310,3 @@ func TestNewOIDCMissingIssuer(t *testing.T) {
t.Errorf("expected Missing issuer error, got %v", err)
}
}