diff --git a/internal/dao/mcp.go b/internal/dao/mcp.go index 8ac75c4be5..65ecd8827c 100644 --- a/internal/dao/mcp.go +++ b/internal/dao/mcp.go @@ -18,6 +18,8 @@ package dao import ( "errors" + "fmt" + "strings" "ragflow/internal/entity" @@ -27,6 +29,16 @@ import ( // MCPServerDAO MCP server data access object. type MCPServerDAO struct{} +// InvalidMCPServerOrderByError matches the Python list endpoint's error shape +// for unknown MCPServer ordering fields. +type InvalidMCPServerOrderByError struct { + Field string +} + +func (e *InvalidMCPServerOrderByError) Error() string { + return fmt.Sprintf("AttributeError(%q)", fmt.Sprintf("type object 'MCPServer' has no attribute '%s'", e.Field)) +} + // NewMCPServerDAO creates an MCP server DAO. func NewMCPServerDAO() *MCPServerDAO { return &MCPServerDAO{} @@ -60,6 +72,44 @@ func (dao *MCPServerDAO) CreateMCPServer(server *entity.MCPServer) error { return DB.Create(server).Error } +// ListMCPServers returns MCP servers for a tenant with optional filtering. +func (dao *MCPServerDAO) ListMCPServers(tenantID string, ids []string, keywords string, orderby string, desc bool) ([]*entity.MCPServer, int64, error) { + var servers []*entity.MCPServer + var total int64 + + query := DB.Model(&entity.MCPServer{}).Where("tenant_id = ?", tenantID) + + if len(ids) > 0 { + query = query.Where("id IN ?", ids) + } + + if keywords != "" { + query = query.Where("LOWER(name) LIKE ?", "%"+strings.ToLower(keywords)+"%") + } + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + orderColumn, err := mcpServerOrderColumn(orderby) + if err != nil { + return nil, 0, err + } + orderDirection := "ASC" + if desc { + orderDirection = "DESC" + } + query = query.Order(orderColumn + " " + orderDirection) + + if err := query. + Select("id", "name", "server_type", "url", "description", "variables", "create_date", "update_date"). + Find(&servers).Error; err != nil { + return nil, 0, err + } + + return servers, total, nil +} + // DeleteMCPServer deletes an MCP server owned by a tenant. func (dao *MCPServerDAO) DeleteMCPServer(id, tenantID string) (bool, error) { result := DB.Where("id = ? AND tenant_id = ?", id, tenantID).Delete(&entity.MCPServer{}) @@ -68,3 +118,22 @@ func (dao *MCPServerDAO) DeleteMCPServer(id, tenantID string) (bool, error) { } return result.RowsAffected > 0, nil } + +func mcpServerOrderColumn(orderby string) (string, error) { + switch orderby { + case "id": + return "id", nil + case "name": + return "name", nil + case "server_type": + return "server_type", nil + case "url": + return "url", nil + case "update_time", "update_date": + return "update_date", nil + case "create_time", "create_date": + return "create_date", nil + default: + return "", &InvalidMCPServerOrderByError{Field: orderby} + } +} diff --git a/internal/dao/mcp_test.go b/internal/dao/mcp_test.go new file mode 100644 index 0000000000..54b4dd5e5a --- /dev/null +++ b/internal/dao/mcp_test.go @@ -0,0 +1,39 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dao + +import ( + "errors" + "testing" +) + +func TestMCPServerOrderColumnInvalidField(t *testing.T) { + _, err := mcpServerOrderColumn("bad_field") + if err == nil { + t.Fatal("expected invalid orderby error") + } + + var orderbyErr *InvalidMCPServerOrderByError + if !errors.As(err, &orderbyErr) { + t.Fatalf("expected InvalidMCPServerOrderByError, got %T", err) + } + + want := `AttributeError("type object 'MCPServer' has no attribute 'bad_field'")` + if err.Error() != want { + t.Fatalf("expected %q, got %q", want, err.Error()) + } +} diff --git a/internal/handler/mcp.go b/internal/handler/mcp.go index 23612b6fb6..4a67e68aef 100644 --- a/internal/handler/mcp.go +++ b/internal/handler/mcp.go @@ -17,7 +17,10 @@ package handler import ( + "fmt" "net/http" + "strconv" + "strings" "github.com/gin-gonic/gin" @@ -25,6 +28,12 @@ import ( "ragflow/internal/service" ) +const ( + defaultMCPServerPage = 0 + defaultMCPServerPageSize = 0 + maxMCPServerPageSize = 100 +) + // MCPHandler handles MCP server requests. type MCPHandler struct { mcpService *service.MCPService @@ -64,6 +73,51 @@ func (h *MCPHandler) CreateMCPServer(c *gin.Context) { }) } +// ListMCPServers lists MCP servers for the current user. +func (h *MCPHandler) ListMCPServers(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + page, err := parseMCPServerPage(c.Query("page")) + if err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + pageSize, err := parseMCPServerPageSize(c.Query("page_size")) + if err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + orderby := c.DefaultQuery("orderby", "create_time") + desc := strings.ToLower(c.DefaultQuery("desc", "true")) != "false" + keywords := c.Query("keywords") + mcpIDs := getMCPIDsFromQuery(c) + + result, code, err := h.mcpService.ListMCPServers(user.ID, mcpIDs, keywords, page, pageSize, orderby, desc) + if err != nil { + if code == common.CodeServerError { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": code, + "message": err.Error(), + "data": nil, + }) + return + } + jsonError(c, code, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + "data": result, + }) +} + // DeleteMCPServer deletes an MCP server for the current user. func (h *MCPHandler) DeleteMCPServer(c *gin.Context) { user, errorCode, errorMessage := GetUser(c) @@ -84,3 +138,46 @@ func (h *MCPHandler) DeleteMCPServer(c *gin.Context) { "data": result, }) } + +func parseMCPServerPage(value string) (int, error) { + if value == "" { + return defaultMCPServerPage, nil + } + page, err := strconv.Atoi(value) + if err != nil { + return 0, fmt.Errorf("page must be an integer") + } + return page, nil +} + +func parseMCPServerPageSize(value string) (int, error) { + if value == "" { + return defaultMCPServerPageSize, nil + } + pageSize, err := strconv.Atoi(value) + if err != nil { + return 0, fmt.Errorf("page_size must be an integer") + } + if pageSize > maxMCPServerPageSize { + return 0, fmt.Errorf("page_size must be less than or equal to %d", maxMCPServerPageSize) + } + return pageSize, nil +} + +func getMCPIDsFromQuery(c *gin.Context) []string { + rawValues := c.QueryArray("mcp_ids") + if len(rawValues) == 0 { + rawValues = []string{c.Query("mcp_id")} + } + + ids := make([]string, 0) + for _, rawValue := range rawValues { + for _, item := range strings.Split(rawValue, ",") { + id := strings.TrimSpace(item) + if id != "" { + ids = append(ids, id) + } + } + } + return ids +} diff --git a/internal/router/router.go b/internal/router/router.go index 93615d501e..28f6614eac 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -268,6 +268,7 @@ func (r *Router) Setup(engine *gin.Engine) { mcp := v1.Group("/mcp") { mcp.POST("/servers", r.mcpHandler.CreateMCPServer) + mcp.GET("/servers", r.mcpHandler.ListMCPServers) mcp.DELETE("/servers/:mcp_id", r.mcpHandler.DeleteMCPServer) } diff --git a/internal/service/mcp.go b/internal/service/mcp.go index c1bc86fe69..41f0bc5bd0 100644 --- a/internal/service/mcp.go +++ b/internal/service/mcp.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "time" "ragflow/internal/common" "ragflow/internal/dao" @@ -31,6 +32,7 @@ const ( mcpServerTypeSSE = "sse" mcpServerTypeStreamableHTTP = "streamable-http" mcpServerNameLimit = 255 + mcpServerDateFormat = "2006-01-02T15:04:05" ) // MCPService handles MCP server operations. @@ -69,6 +71,24 @@ type CreateMCPServerResponse struct { Headers entity.JSONMap `json:"headers"` } +// MCPServerListItem is an MCP server item in the list response. +type MCPServerListItem struct { + ID string `json:"id"` + Name string `json:"name"` + ServerType string `json:"server_type"` + URL string `json:"url"` + Description *string `json:"description"` + Variables entity.JSONMap `json:"variables"` + CreateDate *string `json:"create_date"` + UpdateDate *string `json:"update_date"` +} + +// ListMCPServersResponse is the response payload for listing MCP servers. +type ListMCPServersResponse struct { + MCPServers []*MCPServerListItem `json:"mcp_servers"` + Total int64 `json:"total"` +} + // CreateMCPServer creates an MCP server owned by a tenant. func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest) (*CreateMCPServerResponse, common.ErrorCode, error) { if !isValidMCPServerType(req.ServerType) { @@ -127,6 +147,45 @@ func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest }, common.CodeSuccess, nil } +// ListMCPServers lists MCP servers owned by a tenant. +func (s *MCPService) ListMCPServers(tenantID string, ids []string, keywords string, page, pageSize int, orderby string, desc bool) (*ListMCPServersResponse, common.ErrorCode, error) { + servers, total, err := s.mcpServerDAO.ListMCPServers(tenantID, ids, keywords, orderby, desc) + if err != nil { + var orderbyErr *dao.InvalidMCPServerOrderByError + if errors.As(err, &orderbyErr) { + return nil, common.CodeExceptionError, err + } + return nil, common.CodeServerError, err + } + if servers == nil { + servers = []*entity.MCPServer{} + } + servers = paginateMCPServers(servers, page, pageSize) + + items := make([]*MCPServerListItem, 0, len(servers)) + for _, server := range servers { + variables := server.Variables + if variables == nil { + variables = entity.JSONMap{} + } + items = append(items, &MCPServerListItem{ + ID: server.ID, + Name: server.Name, + ServerType: server.ServerType, + URL: server.URL, + Description: server.Description, + Variables: variables, + CreateDate: formatMCPServerDate(server.CreateDate), + UpdateDate: formatMCPServerDate(server.UpdateDate), + }) + } + + return &ListMCPServersResponse{ + MCPServers: items, + Total: total, + }, common.CodeSuccess, nil +} + // DeleteMCPServer deletes an MCP server owned by a tenant. func (s *MCPService) DeleteMCPServer(tenantID, mcpID string) (bool, common.ErrorCode, error) { server, err := s.mcpServerDAO.GetByID(mcpID) @@ -177,3 +236,44 @@ func safeJSONMap(raw json.RawMessage) entity.JSONMap { } return entity.JSONMap(value) } + +func formatMCPServerDate(date *time.Time) *string { + if date == nil { + return nil + } + formatted := date.Format(mcpServerDateFormat) + return &formatted +} + +func paginateMCPServers(servers []*entity.MCPServer, page, pageSize int) []*entity.MCPServer { + if page == 0 || pageSize == 0 { + return servers + } + + start := (page - 1) * pageSize + stop := page * pageSize + return sliceMCPServers(servers, start, stop) +} + +func sliceMCPServers(servers []*entity.MCPServer, start, stop int) []*entity.MCPServer { + length := len(servers) + start = normalizeMCPServerSliceIndex(start, length) + stop = normalizeMCPServerSliceIndex(stop, length) + if stop < start { + return []*entity.MCPServer{} + } + return servers[start:stop] +} + +func normalizeMCPServerSliceIndex(index, length int) int { + if index < 0 { + index += length + } + if index < 0 { + return 0 + } + if index > length { + return length + } + return index +} diff --git a/internal/service/mcp_test.go b/internal/service/mcp_test.go new file mode 100644 index 0000000000..c2de799477 --- /dev/null +++ b/internal/service/mcp_test.go @@ -0,0 +1,65 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "fmt" + "testing" + + "ragflow/internal/entity" +) + +func TestPaginateMCPServersNegativeValuesMatchPythonSlice(t *testing.T) { + servers := makeMCPServers(13) + + got := paginateMCPServers(servers, -1, -2) + + if len(got) != 0 { + t.Fatalf("expected empty page for negative pagination, got %d servers", len(got)) + } +} + +func TestPaginateMCPServersKeepsUnpagedList(t *testing.T) { + servers := makeMCPServers(3) + + got := paginateMCPServers(servers, 0, 0) + + if len(got) != len(servers) { + t.Fatalf("expected unpaged list length %d, got %d", len(servers), len(got)) + } +} + +func TestPaginateMCPServersPositiveValues(t *testing.T) { + servers := makeMCPServers(5) + + got := paginateMCPServers(servers, 2, 2) + + if len(got) != 2 { + t.Fatalf("expected 2 servers, got %d", len(got)) + } + if got[0].ID != "server-3" || got[1].ID != "server-4" { + t.Fatalf("expected second page servers, got %q and %q", got[0].ID, got[1].ID) + } +} + +func makeMCPServers(count int) []*entity.MCPServer { + servers := make([]*entity.MCPServer, 0, count) + for i := 1; i <= count; i++ { + servers = append(servers, &entity.MCPServer{ID: fmt.Sprintf("server-%d", i)}) + } + return servers +}