From ee942711c4bb8438e327cd6ab1e307e7aaead9c5 Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Fri, 3 Jul 2026 14:05:42 +0800 Subject: [PATCH] fix: unable to fetch tools for MCP (#16583) --- internal/router/router_test.go | 28 +++++++++++ internal/service/mcp.go | 91 +++++++++++++++++++++++++++++++++- 2 files changed, 117 insertions(+), 2 deletions(-) diff --git a/internal/router/router_test.go b/internal/router/router_test.go index 3447b52298..23ea1a859e 100644 --- a/internal/router/router_test.go +++ b/internal/router/router_test.go @@ -133,6 +133,34 @@ func TestRouterSetupRegistersSearchbotMindMapRoute(t *testing.T) { } } +func TestRouterSetupRegistersChatbotInfoOnce(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + r := &Router{ + authHandler: handler.NewAuthHandler(), + botHandler: handler.NewBotHandler(nil), + } + r.Setup(engine) + + resp := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/chatbots/dialog-1/info", nil) + engine.ServeHTTP(resp, req) + + if resp.Code == http.StatusNotFound { + t.Fatalf("GET /api/v1/chatbots/:dialog_id/info returned 404; ChatbotInfo route is not registered") + } + var body struct { + Code common.ErrorCode `json:"code"` + } + if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { + t.Fatalf("failed to decode response body: %v", err) + } + if body.Code != common.CodeUnauthorized { + t.Fatalf("status=%d body=%s; want beta auth middleware to handle registered ChatbotInfo route", resp.Code, resp.Body.String()) + } +} + func TestRouterSetupRegistersChatMindMapRoute(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/internal/service/mcp.go b/internal/service/mcp.go index 77a51690ab..fa5f6752f2 100644 --- a/internal/service/mcp.go +++ b/internal/service/mcp.go @@ -22,6 +22,7 @@ import ( "encoding/json" "errors" "fmt" + "strconv" "strings" "time" @@ -63,6 +64,7 @@ type CreateMCPServerRequest struct { Description *string `json:"description,omitempty"` Variables json.RawMessage `json:"variables,omitempty"` Headers json.RawMessage `json:"headers,omitempty"` + Timeout float64 `json:"timeout,omitempty"` } // CreateMCPServerResponse is the response payload for creating an MCP server. @@ -110,8 +112,13 @@ type ListMCPServersResponse struct { Total int64 `json:"total"` } +const maxMCPFetchTimeoutSec = 60 + // CreateMCPServer creates an MCP server owned by a tenant. func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest) (*CreateMCPServerResponse, common.ErrorCode, error) { + if req.Timeout < 0 || req.Timeout > maxMCPFetchTimeoutSec { + return nil, common.CodeDataError, errors.New("Invalid timeout.") + } if !isValidMCPServerType(req.ServerType) { return nil, common.CodeDataError, errors.New("Unsupported MCP server type.") } @@ -139,7 +146,12 @@ func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest headers := safeJSONMap(req.Headers) variables := safeJSONMap(req.Variables) delete(variables, "tools") - variables["tools"] = map[string]interface{}{} + + tools, err := fetchMCPTools(context.Background(), req.URL, req.ServerType, headers, variables, req.Timeout) + if err != nil { + return nil, common.CodeDataError, err + } + variables["tools"] = tools server := &entity.MCPServer{ ID: common.GenerateUUID(), @@ -168,6 +180,57 @@ func (s *MCPService) CreateMCPServer(tenantID string, req CreateMCPServerRequest }, common.CodeSuccess, nil } +func optionalFloat64(req UpdateMCPServerRequest, key string, defaultValue float64) (float64, error) { + raw, ok := req[key] + if !ok || bytes.Equal(bytes.TrimSpace(raw), []byte("null")) { + return defaultValue, nil + } + + var value float64 + if err := json.Unmarshal(raw, &value); err == nil { + return value, nil + } + + var text string + if err := json.Unmarshal(raw, &text); err != nil { + return defaultValue, fmt.Errorf("%s must be a number", key) + } + value, err := strconv.ParseFloat(strings.TrimSpace(text), 64) + if err != nil { + return defaultValue, fmt.Errorf("%s must be a number", key) + } + return value, nil +} + +func fetchMCPTools(ctx context.Context, url, serverType string, headers, variables entity.JSONMap, timeoutSeconds float64) (map[string]interface{}, error) { + if timeoutSeconds <= 0 { + timeoutSeconds = defaultMCPFetchTimeoutSec + } + timeout := time.Duration(timeoutSeconds * float64(time.Second)) + + tools, err := utility.FetchTools(ctx, utility.FetchOptions{ + URL: url, + ServerType: serverType, + Headers: jsonMapStringValues(headers), + Variables: jsonMapStringValues(variables), + Timeout: timeout, + }) + if err != nil { + return nil, err + } + return toolsAsMap(tools), nil +} + +func jsonMapStringValues(values entity.JSONMap) map[string]string { + out := map[string]string{} + for key, value := range values { + if text, ok := value.(string); ok { + out[key] = text + } + } + return out +} + func (s *MCPService) GetMCPServer(tenantID, mcpID string) (*entity.MCPServer, common.ErrorCode, error) { server, err := s.mcpServerDAO.GetByIDAndTenant(mcpID, tenantID) if err != nil { @@ -284,8 +347,22 @@ func (s *MCPService) UpdateMCPServer(tenantID, mcpID string, req UpdateMCPServer if variables == nil { variables = entity.JSONMap{} } + existingTools := server.Variables["tools"] delete(variables, "tools") - variables["tools"] = map[string]interface{}{} + needsRefresh := serverURLProvided || serverTypeProvided || headerOrVariablesChanged(req) + if needsRefresh { + timeoutSeconds, err := optionalFloat64(req, "timeout", defaultMCPFetchTimeoutSec) + if err != nil { + return nil, common.CodeDataError, err + } + tools, err := fetchMCPTools(context.Background(), serverURL, serverType, headers, variables, timeoutSeconds) + if err != nil { + return nil, common.CodeDataError, err + } + variables["tools"] = tools + } else if existingTools != nil { + variables["tools"] = existingTools + } updates := map[string]interface{}{ "id": mcpID, @@ -327,6 +404,16 @@ func (s *MCPService) UpdateMCPServer(tenantID, mcpID string, req UpdateMCPServer return updatedServer, common.CodeSuccess, nil } +func headerOrVariablesChanged(req UpdateMCPServerRequest) bool { + if _, ok := req["headers"]; ok { + return true + } + if _, ok := req["variables"]; ok { + return true + } + return false +} + func isMCPServerNotFound(err error) bool { return errors.Is(err, gorm.ErrRecordNotFound) }