mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-04 09:39:32 +08:00
fix: unable to fetch tools for MCP (#16583)
This commit is contained in:
@@ -133,6 +133,34 @@ func TestRouterSetupRegistersSearchbotMindMapRoute(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterSetupRegistersChatbotInfoOnce(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
r := &Router{
|
||||
authHandler: handler.NewAuthHandler(),
|
||||
botHandler: handler.NewBotHandler(nil),
|
||||
}
|
||||
r.Setup(engine)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/chatbots/dialog-1/info", nil)
|
||||
engine.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code == http.StatusNotFound {
|
||||
t.Fatalf("GET /api/v1/chatbots/:dialog_id/info returned 404; ChatbotInfo route is not registered")
|
||||
}
|
||||
var body struct {
|
||||
Code common.ErrorCode `json:"code"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("failed to decode response body: %v", err)
|
||||
}
|
||||
if body.Code != common.CodeUnauthorized {
|
||||
t.Fatalf("status=%d body=%s; want beta auth middleware to handle registered ChatbotInfo route", resp.Code, resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterSetupRegistersChatMindMapRoute(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -63,6 +64,7 @@ type CreateMCPServerRequest struct {
|
||||
Description *string `json:"description,omitempty"`
|
||||
Variables json.RawMessage `json:"variables,omitempty"`
|
||||
Headers json.RawMessage `json:"headers,omitempty"`
|
||||
Timeout float64 `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
// CreateMCPServerResponse is the response payload for creating an MCP server.
|
||||
@@ -110,8 +112,13 @@ type ListMCPServersResponse struct {
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
|
||||
const maxMCPFetchTimeoutSec = 60
|
||||
|
||||
// CreateMCPServer creates an MCP server owned by a tenant.
|
||||
func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest) (*CreateMCPServerResponse, common.ErrorCode, error) {
|
||||
if req.Timeout < 0 || req.Timeout > maxMCPFetchTimeoutSec {
|
||||
return nil, common.CodeDataError, errors.New("Invalid timeout.")
|
||||
}
|
||||
if !isValidMCPServerType(req.ServerType) {
|
||||
return nil, common.CodeDataError, errors.New("Unsupported MCP server type.")
|
||||
}
|
||||
@@ -139,7 +146,12 @@ func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest
|
||||
headers := safeJSONMap(req.Headers)
|
||||
variables := safeJSONMap(req.Variables)
|
||||
delete(variables, "tools")
|
||||
variables["tools"] = map[string]interface{}{}
|
||||
|
||||
tools, err := fetchMCPTools(context.Background(), req.URL, req.ServerType, headers, variables, req.Timeout)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
variables["tools"] = tools
|
||||
|
||||
server := &entity.MCPServer{
|
||||
ID: common.GenerateUUID(),
|
||||
@@ -168,6 +180,57 @@ func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func optionalFloat64(req UpdateMCPServerRequest, key string, defaultValue float64) (float64, error) {
|
||||
raw, ok := req[key]
|
||||
if !ok || bytes.Equal(bytes.TrimSpace(raw), []byte("null")) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
var value float64
|
||||
if err := json.Unmarshal(raw, &value); err == nil {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(raw, &text); err != nil {
|
||||
return defaultValue, fmt.Errorf("%s must be a number", key)
|
||||
}
|
||||
value, err := strconv.ParseFloat(strings.TrimSpace(text), 64)
|
||||
if err != nil {
|
||||
return defaultValue, fmt.Errorf("%s must be a number", key)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func fetchMCPTools(ctx context.Context, url, serverType string, headers, variables entity.JSONMap, timeoutSeconds float64) (map[string]interface{}, error) {
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = defaultMCPFetchTimeoutSec
|
||||
}
|
||||
timeout := time.Duration(timeoutSeconds * float64(time.Second))
|
||||
|
||||
tools, err := utility.FetchTools(ctx, utility.FetchOptions{
|
||||
URL: url,
|
||||
ServerType: serverType,
|
||||
Headers: jsonMapStringValues(headers),
|
||||
Variables: jsonMapStringValues(variables),
|
||||
Timeout: timeout,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toolsAsMap(tools), nil
|
||||
}
|
||||
|
||||
func jsonMapStringValues(values entity.JSONMap) map[string]string {
|
||||
out := map[string]string{}
|
||||
for key, value := range values {
|
||||
if text, ok := value.(string); ok {
|
||||
out[key] = text
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *MCPService) GetMCPServer(tenantID, mcpID string) (*entity.MCPServer, common.ErrorCode, error) {
|
||||
server, err := s.mcpServerDAO.GetByIDAndTenant(mcpID, tenantID)
|
||||
if err != nil {
|
||||
@@ -284,8 +347,22 @@ func (s *MCPService) UpdateMCPServer(tenantID, mcpID string, req UpdateMCPServer
|
||||
if variables == nil {
|
||||
variables = entity.JSONMap{}
|
||||
}
|
||||
existingTools := server.Variables["tools"]
|
||||
delete(variables, "tools")
|
||||
variables["tools"] = map[string]interface{}{}
|
||||
needsRefresh := serverURLProvided || serverTypeProvided || headerOrVariablesChanged(req)
|
||||
if needsRefresh {
|
||||
timeoutSeconds, err := optionalFloat64(req, "timeout", defaultMCPFetchTimeoutSec)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
tools, err := fetchMCPTools(context.Background(), serverURL, serverType, headers, variables, timeoutSeconds)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
variables["tools"] = tools
|
||||
} else if existingTools != nil {
|
||||
variables["tools"] = existingTools
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"id": mcpID,
|
||||
@@ -327,6 +404,16 @@ func (s *MCPService) UpdateMCPServer(tenantID, mcpID string, req UpdateMCPServer
|
||||
return updatedServer, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func headerOrVariablesChanged(req UpdateMCPServerRequest) bool {
|
||||
if _, ok := req["headers"]; ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := req["variables"]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isMCPServerNotFound(err error) bool {
|
||||
return errors.Is(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user