mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-04 01:29:35 +08:00
feat: add Go MCP server update API (#15261)
## What #15240 implementation for PUT /api/v1/mcp/servers/:mcp_id ## Changes - Adds the Go implementation for `PUT /api/v1/mcp/servers/:mcp_id`. - Wires MCP service and handler into the Go server/router for the update route. - Preserves Python-style behavior for ownership checks, partial update fields, MCP type/name/URL validation, `headers`/`variables` normalization, and tool metadata scrubbing.
This commit is contained in:
committed by
GitHub
parent
2819d0ea24
commit
a98889cd76
@@ -14,7 +14,7 @@ load_env_file() {
|
||||
echo "Loading environment variables from: $env_file"
|
||||
# Source the .env file
|
||||
set -a
|
||||
source "$env_file"
|
||||
source "$env_file"
|
||||
set +a
|
||||
else
|
||||
echo "Warning: .env file not found at: $env_file"
|
||||
|
||||
@@ -12,7 +12,7 @@ Integrate an RSS feed as a data source.
|
||||
|
||||
---
|
||||
|
||||
This guide explains how to add an RSS feed as a data source to your dataset in RAGFlow.
|
||||
This guide explains how to add an RSS feed as a data source to your dataset in RAGFlow.
|
||||
|
||||
RSS (Really Simple Syndication) is a standardized web feed format used to publish frequently updated content—such as blog entries, news headlines, and podcasts. By connecting an RSS feed to RAGFlow, you can automatically ingest new content from a website as soon as it is published.
|
||||
|
||||
|
||||
@@ -110,6 +110,15 @@ func (dao *MCPServerDAO) ListMCPServers(tenantID string, ids []string, keywords
|
||||
return servers, total, nil
|
||||
}
|
||||
|
||||
// GetByIDAndTenant returns an MCP server owned by a tenant.
|
||||
func (dao *MCPServerDAO) GetByIDAndTenant(id, tenantID string) (*entity.MCPServer, error) {
|
||||
var server entity.MCPServer
|
||||
if err := DB.Where("id = ? AND tenant_id = ?", id, tenantID).First(&server).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &server, nil
|
||||
}
|
||||
|
||||
// DeleteMCPServer deletes an MCP server owned by a tenant.
|
||||
func (dao *MCPServerDAO) DeleteMCPServer(id, tenantID string) (bool, error) {
|
||||
result := DB.Where("id = ? AND tenant_id = ?", id, tenantID).Delete(&entity.MCPServer{})
|
||||
@@ -119,6 +128,17 @@ func (dao *MCPServerDAO) DeleteMCPServer(id, tenantID string) (bool, error) {
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateMCPServer updates an MCP server owned by a tenant.
|
||||
func (dao *MCPServerDAO) UpdateMCPServer(id, tenantID string, updates map[string]interface{}) (bool, error) {
|
||||
result := DB.Model(&entity.MCPServer{}).
|
||||
Where("id = ? AND tenant_id = ?", id, tenantID).
|
||||
Updates(updates)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func mcpServerOrderColumn(orderby string) (string, error) {
|
||||
switch orderby {
|
||||
case "id":
|
||||
|
||||
@@ -21,10 +21,12 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
@@ -32,6 +34,7 @@ const (
|
||||
defaultMCPServerPage = 0
|
||||
defaultMCPServerPageSize = 0
|
||||
maxMCPServerPageSize = 100
|
||||
mcpServerDateFormat = "2006-01-02T15:04:05"
|
||||
)
|
||||
|
||||
// MCPHandler handles MCP server requests.
|
||||
@@ -39,6 +42,21 @@ type MCPHandler struct {
|
||||
mcpService *service.MCPService
|
||||
}
|
||||
|
||||
type mcpServerResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
TenantID string `json:"tenant_id"`
|
||||
URL string `json:"url"`
|
||||
ServerType string `json:"server_type"`
|
||||
Description *string `json:"description"`
|
||||
Variables map[string]interface{} `json:"variables"`
|
||||
Headers map[string]interface{} `json:"headers"`
|
||||
CreateTime *int64 `json:"create_time"`
|
||||
CreateDate string `json:"create_date"`
|
||||
UpdateTime *int64 `json:"update_time"`
|
||||
UpdateDate string `json:"update_date"`
|
||||
}
|
||||
|
||||
// NewMCPHandler creates an MCP handler.
|
||||
func NewMCPHandler(mcpService *service.MCPService) *MCPHandler {
|
||||
return &MCPHandler{
|
||||
@@ -118,6 +136,34 @@ func (h *MCPHandler) ListMCPServers(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateMCPServer updates an MCP server for the current user.
|
||||
func (h *MCPHandler) UpdateMCPServer(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
mcpID := c.Param("mcp_id")
|
||||
var req service.UpdateMCPServerRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.mcpService.UpdateMCPServer(user.ID, mcpID, req)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": newMCPServerResponse(result),
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteMCPServer deletes an MCP server for the current user.
|
||||
func (h *MCPHandler) DeleteMCPServer(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
@@ -139,6 +185,34 @@ func (h *MCPHandler) DeleteMCPServer(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func newMCPServerResponse(server *entity.MCPServer) *mcpServerResponse {
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &mcpServerResponse{
|
||||
ID: server.ID,
|
||||
Name: server.Name,
|
||||
TenantID: server.TenantID,
|
||||
URL: server.URL,
|
||||
ServerType: server.ServerType,
|
||||
Description: server.Description,
|
||||
Variables: map[string]interface{}(server.Variables),
|
||||
Headers: map[string]interface{}(server.Headers),
|
||||
CreateTime: server.CreateTime,
|
||||
CreateDate: formatMCPServerDate(server.CreateDate),
|
||||
UpdateTime: server.UpdateTime,
|
||||
UpdateDate: formatMCPServerDate(server.UpdateDate),
|
||||
}
|
||||
}
|
||||
|
||||
func formatMCPServerDate(date *time.Time) string {
|
||||
if date == nil {
|
||||
return ""
|
||||
}
|
||||
return date.Format(mcpServerDateFormat)
|
||||
}
|
||||
|
||||
func parseMCPServerPage(value string) (int, error) {
|
||||
if value == "" {
|
||||
return defaultMCPServerPage, nil
|
||||
|
||||
71
internal/handler/mcp_test.go
Normal file
71
internal/handler/mcp_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func TestNewMCPServerResponsePreservesNullDescriptionAndFormatsDates(t *testing.T) {
|
||||
createTime := int64(1779949820000)
|
||||
updateTime := int64(1779953420000)
|
||||
location := time.FixedZone("UTC+8", 8*60*60)
|
||||
createDate := time.Date(2026, 5, 28, 10, 30, 20, 0, location)
|
||||
updateDate := time.Date(2026, 5, 28, 11, 30, 20, 0, location)
|
||||
|
||||
response := newMCPServerResponse(&entity.MCPServer{
|
||||
ID: "mcp-id",
|
||||
Name: "server",
|
||||
TenantID: "tenant-id",
|
||||
URL: "https://example.com/mcp",
|
||||
ServerType: "sse",
|
||||
Variables: entity.JSONMap{"tools": map[string]interface{}{}},
|
||||
Headers: entity.JSONMap{"Authorization": "Bearer token"},
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &createTime,
|
||||
CreateDate: &createDate,
|
||||
UpdateTime: &updateTime,
|
||||
UpdateDate: &updateDate,
|
||||
},
|
||||
})
|
||||
|
||||
if response.Description != nil {
|
||||
t.Fatalf("description = %q, want nil", *response.Description)
|
||||
}
|
||||
if response.CreateDate != "2026-05-28T10:30:20" {
|
||||
t.Fatalf("create_date = %q, want date without timezone", response.CreateDate)
|
||||
}
|
||||
if response.UpdateDate != "2026-05-28T11:30:20" {
|
||||
t.Fatalf("update_date = %q, want date without timezone", response.UpdateDate)
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal response: %v", err)
|
||||
}
|
||||
if !bytes.Contains(payload, []byte(`"description":null`)) {
|
||||
t.Fatalf("payload %s does not preserve description:null", payload)
|
||||
}
|
||||
if bytes.Contains(payload, []byte(`+08:00`)) {
|
||||
t.Fatalf("payload %s includes timezone in date fields", payload)
|
||||
}
|
||||
}
|
||||
@@ -280,6 +280,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
{
|
||||
mcp.POST("/servers", r.mcpHandler.CreateMCPServer)
|
||||
mcp.GET("/servers", r.mcpHandler.ListMCPServers)
|
||||
mcp.PUT("/servers/:mcp_id", r.mcpHandler.UpdateMCPServer)
|
||||
mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,11 +21,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -71,6 +74,9 @@ type CreateMCPServerResponse struct {
|
||||
Headers entity.JSONMap `json:"headers"`
|
||||
}
|
||||
|
||||
// UpdateMCPServerRequest is the raw request payload for updating an MCP server.
|
||||
type UpdateMCPServerRequest map[string]json.RawMessage
|
||||
|
||||
// MCPServerListItem is an MCP server item in the list response.
|
||||
type MCPServerListItem struct {
|
||||
ID string `json:"id"`
|
||||
@@ -147,6 +153,120 @@ func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// UpdateMCPServer updates an MCP server owned by a tenant.
|
||||
func (s *MCPService) UpdateMCPServer(tenantID, mcpID string, req UpdateMCPServerRequest) (*entity.MCPServer, common.ErrorCode, error) {
|
||||
server, err := s.mcpServerDAO.GetByIDAndTenant(mcpID, tenantID)
|
||||
if err != nil {
|
||||
if isMCPServerNotFound(err) {
|
||||
return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID)
|
||||
}
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to get MCP server %s: %w", mcpID, err)
|
||||
}
|
||||
if server == nil {
|
||||
return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID)
|
||||
}
|
||||
|
||||
serverType := server.ServerType
|
||||
serverTypeProvided := false
|
||||
if value, ok, err := optionalString(req, "server_type"); err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
} else if ok {
|
||||
serverType = value
|
||||
serverTypeProvided = true
|
||||
}
|
||||
if serverTypeProvided && !isValidMCPServerType(serverType) {
|
||||
return nil, common.CodeDataError, errors.New("Unsupported MCP server type.")
|
||||
}
|
||||
|
||||
serverName := server.Name
|
||||
serverNameProvided := false
|
||||
if value, ok, err := optionalString(req, "name"); err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
} else if ok {
|
||||
serverName = value
|
||||
serverNameProvided = true
|
||||
}
|
||||
if serverName != "" && len([]byte(serverName)) > mcpServerNameLimit {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Invalid MCP name or length is %d which is large than 255.", len([]byte(serverName)))
|
||||
}
|
||||
|
||||
serverURL := server.URL
|
||||
serverURLProvided := false
|
||||
if value, ok, err := optionalString(req, "url"); err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
} else if ok {
|
||||
serverURL = strings.TrimSpace(value)
|
||||
if serverURL == "" {
|
||||
return nil, common.CodeDataError, errors.New("Invalid url.")
|
||||
}
|
||||
serverURLProvided = true
|
||||
}
|
||||
if serverURL == "" {
|
||||
return nil, common.CodeDataError, errors.New("Invalid url.")
|
||||
}
|
||||
|
||||
headers := server.Headers
|
||||
if raw, ok := req["headers"]; ok {
|
||||
headers = safeJSONMap(raw)
|
||||
}
|
||||
if headers == nil {
|
||||
headers = entity.JSONMap{}
|
||||
}
|
||||
|
||||
variables := server.Variables
|
||||
if raw, ok := req["variables"]; ok {
|
||||
variables = safeJSONMap(raw)
|
||||
}
|
||||
if variables == nil {
|
||||
variables = entity.JSONMap{}
|
||||
}
|
||||
delete(variables, "tools")
|
||||
variables["tools"] = map[string]interface{}{}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"id": mcpID,
|
||||
"tenant_id": tenantID,
|
||||
"headers": headers,
|
||||
"variables": variables,
|
||||
}
|
||||
if serverNameProvided {
|
||||
updates["name"] = serverName
|
||||
}
|
||||
if serverURLProvided {
|
||||
updates["url"] = serverURL
|
||||
}
|
||||
if serverTypeProvided {
|
||||
updates["server_type"] = serverType
|
||||
}
|
||||
if raw, ok := req["description"]; ok {
|
||||
description, err := optionalNullableString(raw, "description")
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
updates["description"] = description
|
||||
}
|
||||
|
||||
if _, err := s.mcpServerDAO.UpdateMCPServer(mcpID, tenantID, updates); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
updatedServer, err := s.mcpServerDAO.GetByIDAndTenant(mcpID, tenantID)
|
||||
if err != nil {
|
||||
if isMCPServerNotFound(err) {
|
||||
return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID)
|
||||
}
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to fetch updated MCP server %s: %w", mcpID, err)
|
||||
}
|
||||
if updatedServer == nil {
|
||||
return nil, common.CodeDataError, mcpServerNotFoundError(mcpID, tenantID)
|
||||
}
|
||||
return updatedServer, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func isMCPServerNotFound(err error) bool {
|
||||
return errors.Is(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
// ListMCPServers lists MCP servers owned by a tenant.
|
||||
func (s *MCPService) ListMCPServers(tenantID string, ids []string, keywords string, page, pageSize int, orderby string, desc bool) (*ListMCPServersResponse, common.ErrorCode, error) {
|
||||
servers, total, err := s.mcpServerDAO.ListMCPServers(tenantID, ids, keywords, orderby, desc)
|
||||
@@ -215,6 +335,31 @@ func isValidMCPServerType(serverType string) bool {
|
||||
return serverType == mcpServerTypeSSE || serverType == mcpServerTypeStreamableHTTP
|
||||
}
|
||||
|
||||
func optionalString(req UpdateMCPServerRequest, key string) (string, bool, error) {
|
||||
raw, ok := req[key]
|
||||
if !ok || bytes.Equal(bytes.TrimSpace(raw), []byte("null")) {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
var value string
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return "", true, fmt.Errorf("%s must be a string", key)
|
||||
}
|
||||
return value, true, nil
|
||||
}
|
||||
|
||||
func optionalNullableString(raw json.RawMessage, key string) (*string, error) {
|
||||
if bytes.Equal(bytes.TrimSpace(raw), []byte("null")) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var value string
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return nil, fmt.Errorf("%s must be a string", key)
|
||||
}
|
||||
return &value, nil
|
||||
}
|
||||
|
||||
func safeJSONMap(raw json.RawMessage) entity.JSONMap {
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||
|
||||
@@ -310,4 +310,3 @@ func TestNewOIDCMissingIssuer(t *testing.T) {
|
||||
t.Errorf("expected Missing issuer error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user