mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 00:05:43 +08:00
feat: add Go MCP server list API (#15253)
## What #15240 Implements `GET /api/v1/mcp/servers` in the Go API server. ## Changes - Added MCP server DAO list query with tenant scoping. - Added MCP service response wrapper. - Added MCP handler for list request parsing and response formatting. - Wired `GET /api/v1/mcp/servers` under authenticated `/api/v1` routes. - Initialized MCP service and handler in the Go server startup. - update_time and update_date now both map to update_date - create_time and create_date now both map to create_date - default ordering now returns create_date ## API Behavior Matches the Python endpoint behavior: - Requires authenticated user. - Lists MCP servers for the current user tenant. - Supports `keywords`. - Supports `mcp_id` and repeated/comma-separated `mcp_ids`. - Supports `page`, `page_size`, `orderby`, and `desc`. - Returns: ```json { "code": 0, "message": "success", "data": { "mcp_servers": [], "total": 0 } } ```
This commit is contained in:
committed by
GitHub
parent
3aea80f5f5
commit
1748723971
@@ -18,6 +18,8 @@ package dao
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
|
||||
@@ -27,6 +29,16 @@ import (
|
||||
// MCPServerDAO MCP server data access object.
|
||||
type MCPServerDAO struct{}
|
||||
|
||||
// InvalidMCPServerOrderByError matches the Python list endpoint's error shape
|
||||
// for unknown MCPServer ordering fields.
|
||||
type InvalidMCPServerOrderByError struct {
|
||||
Field string
|
||||
}
|
||||
|
||||
func (e *InvalidMCPServerOrderByError) Error() string {
|
||||
return fmt.Sprintf("AttributeError(%q)", fmt.Sprintf("type object 'MCPServer' has no attribute '%s'", e.Field))
|
||||
}
|
||||
|
||||
// NewMCPServerDAO creates an MCP server DAO.
|
||||
func NewMCPServerDAO() *MCPServerDAO {
|
||||
return &MCPServerDAO{}
|
||||
@@ -60,6 +72,44 @@ func (dao *MCPServerDAO) CreateMCPServer(server *entity.MCPServer) error {
|
||||
return DB.Create(server).Error
|
||||
}
|
||||
|
||||
// ListMCPServers returns MCP servers for a tenant with optional filtering.
|
||||
func (dao *MCPServerDAO) ListMCPServers(tenantID string, ids []string, keywords string, orderby string, desc bool) ([]*entity.MCPServer, int64, error) {
|
||||
var servers []*entity.MCPServer
|
||||
var total int64
|
||||
|
||||
query := DB.Model(&entity.MCPServer{}).Where("tenant_id = ?", tenantID)
|
||||
|
||||
if len(ids) > 0 {
|
||||
query = query.Where("id IN ?", ids)
|
||||
}
|
||||
|
||||
if keywords != "" {
|
||||
query = query.Where("LOWER(name) LIKE ?", "%"+strings.ToLower(keywords)+"%")
|
||||
}
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
orderColumn, err := mcpServerOrderColumn(orderby)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
orderDirection := "ASC"
|
||||
if desc {
|
||||
orderDirection = "DESC"
|
||||
}
|
||||
query = query.Order(orderColumn + " " + orderDirection)
|
||||
|
||||
if err := query.
|
||||
Select("id", "name", "server_type", "url", "description", "variables", "create_date", "update_date").
|
||||
Find(&servers).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return servers, total, nil
|
||||
}
|
||||
|
||||
// DeleteMCPServer deletes an MCP server owned by a tenant.
|
||||
func (dao *MCPServerDAO) DeleteMCPServer(id, tenantID string) (bool, error) {
|
||||
result := DB.Where("id = ? AND tenant_id = ?", id, tenantID).Delete(&entity.MCPServer{})
|
||||
@@ -68,3 +118,22 @@ func (dao *MCPServerDAO) DeleteMCPServer(id, tenantID string) (bool, error) {
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func mcpServerOrderColumn(orderby string) (string, error) {
|
||||
switch orderby {
|
||||
case "id":
|
||||
return "id", nil
|
||||
case "name":
|
||||
return "name", nil
|
||||
case "server_type":
|
||||
return "server_type", nil
|
||||
case "url":
|
||||
return "url", nil
|
||||
case "update_time", "update_date":
|
||||
return "update_date", nil
|
||||
case "create_time", "create_date":
|
||||
return "create_date", nil
|
||||
default:
|
||||
return "", &InvalidMCPServerOrderByError{Field: orderby}
|
||||
}
|
||||
}
|
||||
|
||||
39
internal/dao/mcp_test.go
Normal file
39
internal/dao/mcp_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package dao
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMCPServerOrderColumnInvalidField(t *testing.T) {
|
||||
_, err := mcpServerOrderColumn("bad_field")
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid orderby error")
|
||||
}
|
||||
|
||||
var orderbyErr *InvalidMCPServerOrderByError
|
||||
if !errors.As(err, &orderbyErr) {
|
||||
t.Fatalf("expected InvalidMCPServerOrderByError, got %T", err)
|
||||
}
|
||||
|
||||
want := `AttributeError("type object 'MCPServer' has no attribute 'bad_field'")`
|
||||
if err.Error() != want {
|
||||
t.Fatalf("expected %q, got %q", want, err.Error())
|
||||
}
|
||||
}
|
||||
@@ -17,7 +17,10 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -25,6 +28,12 @@ import (
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMCPServerPage = 0
|
||||
defaultMCPServerPageSize = 0
|
||||
maxMCPServerPageSize = 100
|
||||
)
|
||||
|
||||
// MCPHandler handles MCP server requests.
|
||||
type MCPHandler struct {
|
||||
mcpService *service.MCPService
|
||||
@@ -64,6 +73,51 @@ func (h *MCPHandler) CreateMCPServer(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ListMCPServers lists MCP servers for the current user.
|
||||
func (h *MCPHandler) ListMCPServers(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
page, err := parseMCPServerPage(c.Query("page"))
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
pageSize, err := parseMCPServerPageSize(c.Query("page_size"))
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
orderby := c.DefaultQuery("orderby", "create_time")
|
||||
desc := strings.ToLower(c.DefaultQuery("desc", "true")) != "false"
|
||||
keywords := c.Query("keywords")
|
||||
mcpIDs := getMCPIDsFromQuery(c)
|
||||
|
||||
result, code, err := h.mcpService.ListMCPServers(user.ID, mcpIDs, keywords, page, pageSize, orderby, desc)
|
||||
if err != nil {
|
||||
if code == common.CodeServerError {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteMCPServer deletes an MCP server for the current user.
|
||||
func (h *MCPHandler) DeleteMCPServer(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
@@ -84,3 +138,46 @@ func (h *MCPHandler) DeleteMCPServer(c *gin.Context) {
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
|
||||
func parseMCPServerPage(value string) (int, error) {
|
||||
if value == "" {
|
||||
return defaultMCPServerPage, nil
|
||||
}
|
||||
page, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("page must be an integer")
|
||||
}
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func parseMCPServerPageSize(value string) (int, error) {
|
||||
if value == "" {
|
||||
return defaultMCPServerPageSize, nil
|
||||
}
|
||||
pageSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("page_size must be an integer")
|
||||
}
|
||||
if pageSize > maxMCPServerPageSize {
|
||||
return 0, fmt.Errorf("page_size must be less than or equal to %d", maxMCPServerPageSize)
|
||||
}
|
||||
return pageSize, nil
|
||||
}
|
||||
|
||||
func getMCPIDsFromQuery(c *gin.Context) []string {
|
||||
rawValues := c.QueryArray("mcp_ids")
|
||||
if len(rawValues) == 0 {
|
||||
rawValues = []string{c.Query("mcp_id")}
|
||||
}
|
||||
|
||||
ids := make([]string, 0)
|
||||
for _, rawValue := range rawValues {
|
||||
for _, item := range strings.Split(rawValue, ",") {
|
||||
id := strings.TrimSpace(item)
|
||||
if id != "" {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
@@ -268,6 +268,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
mcp := v1.Group("/mcp")
|
||||
{
|
||||
mcp.POST("/servers", r.mcpHandler.CreateMCPServer)
|
||||
mcp.GET("/servers", r.mcpHandler.ListMCPServers)
|
||||
mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
@@ -31,6 +32,7 @@ const (
|
||||
mcpServerTypeSSE = "sse"
|
||||
mcpServerTypeStreamableHTTP = "streamable-http"
|
||||
mcpServerNameLimit = 255
|
||||
mcpServerDateFormat = "2006-01-02T15:04:05"
|
||||
)
|
||||
|
||||
// MCPService handles MCP server operations.
|
||||
@@ -69,6 +71,24 @@ type CreateMCPServerResponse struct {
|
||||
Headers entity.JSONMap `json:"headers"`
|
||||
}
|
||||
|
||||
// MCPServerListItem is an MCP server item in the list response.
|
||||
type MCPServerListItem struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ServerType string `json:"server_type"`
|
||||
URL string `json:"url"`
|
||||
Description *string `json:"description"`
|
||||
Variables entity.JSONMap `json:"variables"`
|
||||
CreateDate *string `json:"create_date"`
|
||||
UpdateDate *string `json:"update_date"`
|
||||
}
|
||||
|
||||
// ListMCPServersResponse is the response payload for listing MCP servers.
|
||||
type ListMCPServersResponse struct {
|
||||
MCPServers []*MCPServerListItem `json:"mcp_servers"`
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
|
||||
// CreateMCPServer creates an MCP server owned by a tenant.
|
||||
func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest) (*CreateMCPServerResponse, common.ErrorCode, error) {
|
||||
if !isValidMCPServerType(req.ServerType) {
|
||||
@@ -127,6 +147,45 @@ func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// ListMCPServers lists MCP servers owned by a tenant.
|
||||
func (s *MCPService) ListMCPServers(tenantID string, ids []string, keywords string, page, pageSize int, orderby string, desc bool) (*ListMCPServersResponse, common.ErrorCode, error) {
|
||||
servers, total, err := s.mcpServerDAO.ListMCPServers(tenantID, ids, keywords, orderby, desc)
|
||||
if err != nil {
|
||||
var orderbyErr *dao.InvalidMCPServerOrderByError
|
||||
if errors.As(err, &orderbyErr) {
|
||||
return nil, common.CodeExceptionError, err
|
||||
}
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if servers == nil {
|
||||
servers = []*entity.MCPServer{}
|
||||
}
|
||||
servers = paginateMCPServers(servers, page, pageSize)
|
||||
|
||||
items := make([]*MCPServerListItem, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
variables := server.Variables
|
||||
if variables == nil {
|
||||
variables = entity.JSONMap{}
|
||||
}
|
||||
items = append(items, &MCPServerListItem{
|
||||
ID: server.ID,
|
||||
Name: server.Name,
|
||||
ServerType: server.ServerType,
|
||||
URL: server.URL,
|
||||
Description: server.Description,
|
||||
Variables: variables,
|
||||
CreateDate: formatMCPServerDate(server.CreateDate),
|
||||
UpdateDate: formatMCPServerDate(server.UpdateDate),
|
||||
})
|
||||
}
|
||||
|
||||
return &ListMCPServersResponse{
|
||||
MCPServers: items,
|
||||
Total: total,
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// DeleteMCPServer deletes an MCP server owned by a tenant.
|
||||
func (s *MCPService) DeleteMCPServer(tenantID, mcpID string) (bool, common.ErrorCode, error) {
|
||||
server, err := s.mcpServerDAO.GetByID(mcpID)
|
||||
@@ -177,3 +236,44 @@ func safeJSONMap(raw json.RawMessage) entity.JSONMap {
|
||||
}
|
||||
return entity.JSONMap(value)
|
||||
}
|
||||
|
||||
func formatMCPServerDate(date *time.Time) *string {
|
||||
if date == nil {
|
||||
return nil
|
||||
}
|
||||
formatted := date.Format(mcpServerDateFormat)
|
||||
return &formatted
|
||||
}
|
||||
|
||||
func paginateMCPServers(servers []*entity.MCPServer, page, pageSize int) []*entity.MCPServer {
|
||||
if page == 0 || pageSize == 0 {
|
||||
return servers
|
||||
}
|
||||
|
||||
start := (page - 1) * pageSize
|
||||
stop := page * pageSize
|
||||
return sliceMCPServers(servers, start, stop)
|
||||
}
|
||||
|
||||
func sliceMCPServers(servers []*entity.MCPServer, start, stop int) []*entity.MCPServer {
|
||||
length := len(servers)
|
||||
start = normalizeMCPServerSliceIndex(start, length)
|
||||
stop = normalizeMCPServerSliceIndex(stop, length)
|
||||
if stop < start {
|
||||
return []*entity.MCPServer{}
|
||||
}
|
||||
return servers[start:stop]
|
||||
}
|
||||
|
||||
func normalizeMCPServerSliceIndex(index, length int) int {
|
||||
if index < 0 {
|
||||
index += length
|
||||
}
|
||||
if index < 0 {
|
||||
return 0
|
||||
}
|
||||
if index > length {
|
||||
return length
|
||||
}
|
||||
return index
|
||||
}
|
||||
|
||||
65
internal/service/mcp_test.go
Normal file
65
internal/service/mcp_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func TestPaginateMCPServersNegativeValuesMatchPythonSlice(t *testing.T) {
|
||||
servers := makeMCPServers(13)
|
||||
|
||||
got := paginateMCPServers(servers, -1, -2)
|
||||
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty page for negative pagination, got %d servers", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginateMCPServersKeepsUnpagedList(t *testing.T) {
|
||||
servers := makeMCPServers(3)
|
||||
|
||||
got := paginateMCPServers(servers, 0, 0)
|
||||
|
||||
if len(got) != len(servers) {
|
||||
t.Fatalf("expected unpaged list length %d, got %d", len(servers), len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginateMCPServersPositiveValues(t *testing.T) {
|
||||
servers := makeMCPServers(5)
|
||||
|
||||
got := paginateMCPServers(servers, 2, 2)
|
||||
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 servers, got %d", len(got))
|
||||
}
|
||||
if got[0].ID != "server-3" || got[1].ID != "server-4" {
|
||||
t.Fatalf("expected second page servers, got %q and %q", got[0].ID, got[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func makeMCPServers(count int) []*entity.MCPServer {
|
||||
servers := make([]*entity.MCPServer, 0, count)
|
||||
for i := 1; i <= count; i++ {
|
||||
servers = append(servers, &entity.MCPServer{ID: fmt.Sprintf("server-%d", i)})
|
||||
}
|
||||
return servers
|
||||
}
|
||||
Reference in New Issue
Block a user