From a98889cd7651a1fe99b6087cd84c67fec69725a4 Mon Sep 17 00:00:00 2001 From: Alexander Laurent <57456290+Helios531@users.noreply.github.com> Date: Mon, 1 Jun 2026 21:58:44 -1000 Subject: [PATCH] feat: add Go MCP server update API (#15261) ## What #15240 implementation for PUT /api/v1/mcp/servers/:mcp_id ## Changes - Adds the Go implementation for `PUT /api/v1/mcp/servers/:mcp_id`. - Wires MCP service and handler into the Go server/router for the update route. - Preserves Python-style behavior for ownership checks, partial update fields, MCP type/name/URL validation, `headers`/`variables` normalization, and tool metadata scrubbing. --- docker/launch_admin_service.sh | 2 +- .../guides/dataset/add_data_source/add_rss.md | 2 +- internal/dao/mcp.go | 20 +++ internal/handler/mcp.go | 74 +++++++++ internal/handler/mcp_test.go | 71 +++++++++ internal/router/router.go | 1 + internal/service/mcp.go | 145 ++++++++++++++++++ internal/utility/oauth/client_test.go | 1 - 8 files changed, 313 insertions(+), 3 deletions(-) create mode 100644 internal/handler/mcp_test.go diff --git a/docker/launch_admin_service.sh b/docker/launch_admin_service.sh index 0afba88b2d..7d906cd259 100755 --- a/docker/launch_admin_service.sh +++ b/docker/launch_admin_service.sh @@ -14,7 +14,7 @@ load_env_file() { echo "Loading environment variables from: $env_file" # Source the .env file set -a - source "$env_file" + source "$env_file" set +a else echo "Warning: .env file not found at: $env_file" diff --git a/docs/guides/dataset/add_data_source/add_rss.md b/docs/guides/dataset/add_data_source/add_rss.md index 8d5d3fd6e0..be060bd7be 100644 --- a/docs/guides/dataset/add_data_source/add_rss.md +++ b/docs/guides/dataset/add_data_source/add_rss.md @@ -12,7 +12,7 @@ Integrate an RSS feed as a data source. --- -This guide explains how to add an RSS feed as a data source to your dataset in RAGFlow. +This guide explains how to add an RSS feed as a data source to your dataset in RAGFlow. RSS (Really Simple Syndication) is a standardized web feed format used to publish frequently updated content—such as blog entries, news headlines, and podcasts. By connecting an RSS feed to RAGFlow, you can automatically ingest new content from a website as soon as it is published. diff --git a/internal/dao/mcp.go b/internal/dao/mcp.go index 65ecd8827c..e95fb9b0f2 100644 --- a/internal/dao/mcp.go +++ b/internal/dao/mcp.go @@ -110,6 +110,15 @@ func (dao *MCPServerDAO) ListMCPServers(tenantID string, ids []string, keywords return servers, total, nil } +// GetByIDAndTenant returns an MCP server owned by a tenant. +func (dao *MCPServerDAO) GetByIDAndTenant(id, tenantID string) (*entity.MCPServer, error) { + var server entity.MCPServer + if err := DB.Where("id = ? AND tenant_id = ?", id, tenantID).First(&server).Error; err != nil { + return nil, err + } + return &server, nil +} + // DeleteMCPServer deletes an MCP server owned by a tenant. func (dao *MCPServerDAO) DeleteMCPServer(id, tenantID string) (bool, error) { result := DB.Where("id = ? AND tenant_id = ?", id, tenantID).Delete(&entity.MCPServer{}) @@ -119,6 +128,17 @@ func (dao *MCPServerDAO) DeleteMCPServer(id, tenantID string) (bool, error) { return result.RowsAffected > 0, nil } +// UpdateMCPServer updates an MCP server owned by a tenant. +func (dao *MCPServerDAO) UpdateMCPServer(id, tenantID string, updates map[string]interface{}) (bool, error) { + result := DB.Model(&entity.MCPServer{}). + Where("id = ? AND tenant_id = ?", id, tenantID). + Updates(updates) + if result.Error != nil { + return false, result.Error + } + return result.RowsAffected > 0, nil +} + func mcpServerOrderColumn(orderby string) (string, error) { switch orderby { case "id": diff --git a/internal/handler/mcp.go b/internal/handler/mcp.go index 4a67e68aef..7071ec4935 100644 --- a/internal/handler/mcp.go +++ b/internal/handler/mcp.go @@ -21,10 +21,12 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/gin-gonic/gin" "ragflow/internal/common" + "ragflow/internal/entity" "ragflow/internal/service" ) @@ -32,6 +34,7 @@ const ( defaultMCPServerPage = 0 defaultMCPServerPageSize = 0 maxMCPServerPageSize = 100 + mcpServerDateFormat = "2006-01-02T15:04:05" ) // MCPHandler handles MCP server requests. @@ -39,6 +42,21 @@ type MCPHandler struct { mcpService *service.MCPService } +type mcpServerResponse struct { + ID string `json:"id"` + Name string `json:"name"` + TenantID string `json:"tenant_id"` + URL string `json:"url"` + ServerType string `json:"server_type"` + Description *string `json:"description"` + Variables map[string]interface{} `json:"variables"` + Headers map[string]interface{} `json:"headers"` + CreateTime *int64 `json:"create_time"` + CreateDate string `json:"create_date"` + UpdateTime *int64 `json:"update_time"` + UpdateDate string `json:"update_date"` +} + // NewMCPHandler creates an MCP handler. func NewMCPHandler(mcpService *service.MCPService) *MCPHandler { return &MCPHandler{ @@ -118,6 +136,34 @@ func (h *MCPHandler) ListMCPServers(c *gin.Context) { }) } +// UpdateMCPServer updates an MCP server for the current user. +func (h *MCPHandler) UpdateMCPServer(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + mcpID := c.Param("mcp_id") + var req service.UpdateMCPServerRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + result, code, err := h.mcpService.UpdateMCPServer(user.ID, mcpID, req) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + "data": newMCPServerResponse(result), + }) +} + // DeleteMCPServer deletes an MCP server for the current user. func (h *MCPHandler) DeleteMCPServer(c *gin.Context) { user, errorCode, errorMessage := GetUser(c) @@ -139,6 +185,34 @@ func (h *MCPHandler) DeleteMCPServer(c *gin.Context) { }) } +func newMCPServerResponse(server *entity.MCPServer) *mcpServerResponse { + if server == nil { + return nil + } + + return &mcpServerResponse{ + ID: server.ID, + Name: server.Name, + TenantID: server.TenantID, + URL: server.URL, + ServerType: server.ServerType, + Description: server.Description, + Variables: map[string]interface{}(server.Variables), + Headers: map[string]interface{}(server.Headers), + CreateTime: server.CreateTime, + CreateDate: formatMCPServerDate(server.CreateDate), + UpdateTime: server.UpdateTime, + UpdateDate: formatMCPServerDate(server.UpdateDate), + } +} + +func formatMCPServerDate(date *time.Time) string { + if date == nil { + return "" + } + return date.Format(mcpServerDateFormat) +} + func parseMCPServerPage(value string) (int, error) { if value == "" { return defaultMCPServerPage, nil diff --git a/internal/handler/mcp_test.go b/internal/handler/mcp_test.go new file mode 100644 index 0000000000..7538bbfa7c --- /dev/null +++ b/internal/handler/mcp_test.go @@ -0,0 +1,71 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "bytes" + "encoding/json" + "testing" + "time" + + "ragflow/internal/entity" +) + +func TestNewMCPServerResponsePreservesNullDescriptionAndFormatsDates(t *testing.T) { + createTime := int64(1779949820000) + updateTime := int64(1779953420000) + location := time.FixedZone("UTC+8", 8*60*60) + createDate := time.Date(2026, 5, 28, 10, 30, 20, 0, location) + updateDate := time.Date(2026, 5, 28, 11, 30, 20, 0, location) + + response := newMCPServerResponse(&entity.MCPServer{ + ID: "mcp-id", + Name: "server", + TenantID: "tenant-id", + URL: "https://example.com/mcp", + ServerType: "sse", + Variables: entity.JSONMap{"tools": map[string]interface{}{}}, + Headers: entity.JSONMap{"Authorization": "Bearer token"}, + BaseModel: entity.BaseModel{ + CreateTime: &createTime, + CreateDate: &createDate, + UpdateTime: &updateTime, + UpdateDate: &updateDate, + }, + }) + + if response.Description != nil { + t.Fatalf("description = %q, want nil", *response.Description) + } + if response.CreateDate != "2026-05-28T10:30:20" { + t.Fatalf("create_date = %q, want date without timezone", response.CreateDate) + } + if response.UpdateDate != "2026-05-28T11:30:20" { + t.Fatalf("update_date = %q, want date without timezone", response.UpdateDate) + } + + payload, err := json.Marshal(response) + if err != nil { + t.Fatalf("marshal response: %v", err) + } + if !bytes.Contains(payload, []byte(`"description":null`)) { + t.Fatalf("payload %s does not preserve description:null", payload) + } + if bytes.Contains(payload, []byte(`+08:00`)) { + t.Fatalf("payload %s includes timezone in date fields", payload) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 498555f0ff..9ea1075b4d 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -280,6 +280,7 @@ func (r *Router) Setup(engine *gin.Engine) { { mcp.POST("/servers", r.mcpHandler.CreateMCPServer) mcp.GET("/servers", r.mcpHandler.ListMCPServers) + mcp.PUT("/servers/:mcp_id", r.mcpHandler.UpdateMCPServer) mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer) } diff --git a/internal/service/mcp.go b/internal/service/mcp.go index 41f0bc5bd0..f30d87959c 100644 --- a/internal/service/mcp.go +++ b/internal/service/mcp.go @@ -21,11 +21,14 @@ import ( "encoding/json" "errors" "fmt" + "strings" "time" "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" + + "gorm.io/gorm" ) const ( @@ -71,6 +74,9 @@ type CreateMCPServerResponse struct { Headers entity.JSONMap `json:"headers"` } +// UpdateMCPServerRequest is the raw request payload for updating an MCP server. +type UpdateMCPServerRequest map[string]json.RawMessage + // MCPServerListItem is an MCP server item in the list response. type MCPServerListItem struct { ID string `json:"id"` @@ -147,6 +153,120 @@ func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest }, common.CodeSuccess, nil } +// UpdateMCPServer updates an MCP server owned by a tenant. +func (s *MCPService) UpdateMCPServer(tenantID, mcpID string, req UpdateMCPServerRequest) (*entity.MCPServer, common.ErrorCode, error) { + server, err := s.mcpServerDAO.GetByIDAndTenant(mcpID, tenantID) + if err != nil { + if isMCPServerNotFound(err) { + return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID) + } + return nil, common.CodeServerError, fmt.Errorf("failed to get MCP server %s: %w", mcpID, err) + } + if server == nil { + return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID) + } + + serverType := server.ServerType + serverTypeProvided := false + if value, ok, err := optionalString(req, "server_type"); err != nil { + return nil, common.CodeDataError, err + } else if ok { + serverType = value + serverTypeProvided = true + } + if serverTypeProvided && !isValidMCPServerType(serverType) { + return nil, common.CodeDataError, errors.New("Unsupported MCP server type.") + } + + serverName := server.Name + serverNameProvided := false + if value, ok, err := optionalString(req, "name"); err != nil { + return nil, common.CodeDataError, err + } else if ok { + serverName = value + serverNameProvided = true + } + if serverName != "" && len([]byte(serverName)) > mcpServerNameLimit { + return nil, common.CodeDataError, fmt.Errorf("Invalid MCP name or length is %d which is large than 255.", len([]byte(serverName))) + } + + serverURL := server.URL + serverURLProvided := false + if value, ok, err := optionalString(req, "url"); err != nil { + return nil, common.CodeDataError, err + } else if ok { + serverURL = strings.TrimSpace(value) + if serverURL == "" { + return nil, common.CodeDataError, errors.New("Invalid url.") + } + serverURLProvided = true + } + if serverURL == "" { + return nil, common.CodeDataError, errors.New("Invalid url.") + } + + headers := server.Headers + if raw, ok := req["headers"]; ok { + headers = safeJSONMap(raw) + } + if headers == nil { + headers = entity.JSONMap{} + } + + variables := server.Variables + if raw, ok := req["variables"]; ok { + variables = safeJSONMap(raw) + } + if variables == nil { + variables = entity.JSONMap{} + } + delete(variables, "tools") + variables["tools"] = map[string]interface{}{} + + updates := map[string]interface{}{ + "id": mcpID, + "tenant_id": tenantID, + "headers": headers, + "variables": variables, + } + if serverNameProvided { + updates["name"] = serverName + } + if serverURLProvided { + updates["url"] = serverURL + } + if serverTypeProvided { + updates["server_type"] = serverType + } + if raw, ok := req["description"]; ok { + description, err := optionalNullableString(raw, "description") + if err != nil { + return nil, common.CodeDataError, err + } + updates["description"] = description + } + + if _, err := s.mcpServerDAO.UpdateMCPServer(mcpID, tenantID, updates); err != nil { + return nil, common.CodeServerError, err + } + + updatedServer, err := s.mcpServerDAO.GetByIDAndTenant(mcpID, tenantID) + if err != nil { + if isMCPServerNotFound(err) { + return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID) + } + return nil, common.CodeServerError, fmt.Errorf("failed to fetch updated MCP server %s: %w", mcpID, err) + } + if updatedServer == nil { + return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID) + } + return updatedServer, common.CodeSuccess, nil +} + +func isMCPServerNotFound(err error) bool { + return errors.Is(err, gorm.ErrRecordNotFound) +} + // ListMCPServers lists MCP servers owned by a tenant. func (s *MCPService) ListMCPServers(tenantID string, ids []string, keywords string, page, pageSize int, orderby string, desc bool) (*ListMCPServersResponse, common.ErrorCode, error) { servers, total, err := s.mcpServerDAO.ListMCPServers(tenantID, ids, keywords, orderby, desc) @@ -215,6 +335,31 @@ func isValidMCPServerType(serverType string) bool { return serverType == mcpServerTypeSSE || serverType == mcpServerTypeStreamableHTTP } +func optionalString(req UpdateMCPServerRequest, key string) (string, bool, error) { + raw, ok := req[key] + if !ok || bytes.Equal(bytes.TrimSpace(raw), []byte("null")) { + return "", false, nil + } + + var value string + if err := json.Unmarshal(raw, &value); err != nil { + return "", true, fmt.Errorf("%s must be a string", key) + } + return value, true, nil +} + +func optionalNullableString(raw json.RawMessage, key string) (*string, error) { + if bytes.Equal(bytes.TrimSpace(raw), []byte("null")) { + return nil, nil + } + + var value string + if err := json.Unmarshal(raw, &value); err != nil { + return nil, fmt.Errorf("%s must be a string", key) + } + return &value, nil +} + func safeJSONMap(raw json.RawMessage) entity.JSONMap { raw = bytes.TrimSpace(raw) if len(raw) == 0 || bytes.Equal(raw, []byte("null")) { diff --git a/internal/utility/oauth/client_test.go b/internal/utility/oauth/client_test.go index 7bfe02f358..69b6b1543f 100644 --- a/internal/utility/oauth/client_test.go +++ b/internal/utility/oauth/client_test.go @@ -310,4 +310,3 @@ func TestNewOIDCMissingIssuer(t *testing.T) { t.Errorf("expected Missing issuer error, got %v", err) } } -