From 52bcd98d296f1a398d2e6a630a302d095e6be8a6 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Mon, 9 Mar 2026 17:48:29 +0800 Subject: [PATCH] Add scheduled tasks (#13470) ### What problem does this PR solve? 1. RAGFlow server will send heartbeat periodically. 2. This PR will including: - Scheduled task - API server message sending - Admin server API to receive the message. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai --- cmd/server_main.go | 71 +++++-- internal/admin/handler.go | 37 +++- internal/admin/heartbeat.go | 76 ++++++++ internal/admin/router.go | 2 + internal/admin/service.go | 16 ++ internal/common/status_message.go | 33 ++++ internal/cpp/Makefile | 2 +- internal/handler/user.go | 2 +- internal/server/config.go | 49 ++++- internal/service/heartbeat_sender.go | 136 +++++++++++++ internal/service/user.go | 4 +- internal/utility/http_client.go | 274 +++++++++++++++++++++++++++ internal/utility/network.go | 49 +++++ internal/utility/scheduled_task.go | 156 +++++++++++++++ 14 files changed, 871 insertions(+), 36 deletions(-) create mode 100644 internal/admin/heartbeat.go create mode 100644 internal/common/status_message.go create mode 100644 internal/service/heartbeat_sender.go create mode 100644 internal/utility/http_client.go create mode 100644 internal/utility/network.go create mode 100644 internal/utility/scheduled_task.go diff --git a/cmd/server_main.go b/cmd/server_main.go index 7d6ac21e8d..dbae3efbea 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "os/signal" + "ragflow/internal/common" "ragflow/internal/init_data" "ragflow/internal/server" "ragflow/internal/utility" @@ -30,7 +31,7 @@ import ( func main() { // Initialize logger with default level // logger.Init("info"); // set debug log level - if err := logger.Init("debug"); err != nil { + if err := logger.Init("info"); err != nil { panic(fmt.Sprintf("Failed to initialize logger: %v", err)) } @@ -45,28 +46,21 @@ func main() { } logger.Info("Model providers loaded", zap.Int("count", len(server.GetModelProviders()))) - cfg := server.GetConfig() + config := server.GetConfig() // Reinitialize logger with configured level if different - if cfg.Log.Level != "" && cfg.Log.Level != "info" { - if err := logger.Init(cfg.Log.Level); err != nil { + if config.Log.Level != "" && config.Log.Level != "info" { + if err := logger.Init(config.Log.Level); err != nil { logger.Error("Failed to reinitialize logger with configured level", err) } } server.SetLogger(logger.Logger) - logger.Info("Server mode", zap.String("mode", cfg.Server.Mode)) + logger.Info("Server mode", zap.String("mode", config.Server.Mode)) // Print all configuration settings server.PrintAll() - // Set Gin mode - if cfg.Server.Mode == "release" { - gin.SetMode(gin.ReleaseMode) - } else { - gin.SetMode(gin.DebugMode) - } - // Initialize database if err := dao.InitDB(); err != nil { logger.Fatal("Failed to initialize database", zap.Error(err)) @@ -80,13 +74,13 @@ func main() { } // Initialize doc engine - if err := engine.Init(&cfg.DocEngine); err != nil { + if err := engine.Init(&config.DocEngine); err != nil { logger.Fatal("Failed to initialize doc engine", zap.Error(err)) } defer engine.Close() // Initialize Redis cache - if err := cache.Init(&cfg.Redis); err != nil { + if err := cache.Init(&config.Redis); err != nil { logger.Fatal("Failed to initialize Redis", zap.Error(err)) } defer cache.Close() @@ -112,6 +106,20 @@ func main() { logger.Fatal("Failed to initialize query builder", zap.Error(err)) } + startServer(config) + + logger.Info("Server exited") +} + +func startServer(config *server.Config) { + + // Set Gin mode + if config.Server.Mode == "release" { + gin.SetMode(gin.ReleaseMode) + } else { + gin.SetMode(gin.DebugMode) + } + // Initialize service layer userService := service.NewUserService() documentService := service.NewDocumentService() @@ -147,7 +155,7 @@ func main() { ginEngine := gin.New() // Middleware - if cfg.Server.Mode == "debug" { + if config.Server.Mode == "debug" { ginEngine.Use(gin.Logger()) } ginEngine.Use(gin.Recovery()) @@ -156,7 +164,7 @@ func main() { r.Setup(ginEngine) // Create HTTP server - addr := fmt.Sprintf(":%d", cfg.Server.Port) + addr := fmt.Sprintf(":%d", config.Server.Port) srv := &http.Server{ Addr: addr, Handler: ginEngine, @@ -172,12 +180,39 @@ func main() { " /_/ |_|/_/ |_|\\____//_/ /_/ \\____/ |__/|__/\n", ) logger.Info(fmt.Sprintf("Version: %s", utility.GetRAGFlowVersion())) - logger.Info(fmt.Sprintf("Server starting on port: %d", cfg.Server.Port)) + logger.Info(fmt.Sprintf("Server starting on port: %d", config.Server.Port)) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Fatal("Failed to start server", zap.Error(err)) } }() + // Get local IP address for heartbeat reporting + localIP := utility.GetLocalIP() + if localIP == "" { + localIP = "127.0.0.1" + } + + // Initialize and start heartbeat reporter to admin server + heartbeatService := service.NewHeartbeatSender( + logger.Logger, + common.ServerTypeAPI, + fmt.Sprintf("ragflow-server-%d", config.Server.Port), + localIP, + config.Server.Port, + ) + if err := heartbeatService.InitHTTPClient(); err != nil { + logger.Warn("Failed to initialize heartbeat service", zap.Error(err)) + } else { + // Start heartbeat reporter with 30 seconds interval + heartbeatReporter := utility.NewScheduledTask("Heartbeat reporter", 3*time.Second, func() { + if err := heartbeatService.SendHeartbeat(); err != nil { + logger.Warn("Failed to send heartbeat", zap.Error(err)) + } + }) + heartbeatReporter.Start() + defer heartbeatReporter.Stop() + } + // Wait for interrupt signal to gracefully shutdown quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2) @@ -194,6 +229,4 @@ func main() { if err := srv.Shutdown(ctx); err != nil { logger.Fatal("Server forced to shutdown", zap.Error(err)) } - - logger.Info("Server exited") } diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 526a22fa91..155ebe1685 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -19,10 +19,12 @@ package admin import ( "errors" "net/http" + "ragflow/internal/common" "ragflow/internal/server" "ragflow/internal/service" "ragflow/internal/utility" "strconv" + "time" "github.com/gin-gonic/gin" ) @@ -111,7 +113,7 @@ func (h *Handler) Login(c *gin.Context) { return } - user, code, err := h.userService.LoginByEmail(&req) + user, code, err := h.userService.LoginByEmail(&req, true) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ "code": code, @@ -135,8 +137,9 @@ func (h *Handler) Login(c *gin.Context) { c.Header("Access-Control-Expose-Headers", "Authorization") c.JSON(http.StatusOK, gin.H{ - "code": 0, - "message": "Login successful", + "code": common.CodeSuccess, + "message": "Welcome back!", + "data": user, }) } @@ -943,3 +946,31 @@ func (h *Handler) HandleNoRoute(c *gin.Context) { Message: "The requested resource was not found", }) } + +// Reports handle heartbeat reports from servers +func (h *Handler) Reports(c *gin.Context) { + var req common.BaseMessage + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "Invalid request body: "+err.Error(), 400) + return + } + + // Set default timestamp if not provided + if req.Timestamp.IsZero() { + req.Timestamp = time.Now() + } + + // Only process heartbeat messages for now + if req.MessageType != common.MessageHeartbeat { + errorResponse(c, "Unsupported report type: "+string(req.MessageType), 400) + return + } + + // Handle the heartbeat + if err := h.service.HandleHeartbeat(&req); err != nil { + errorResponse(c, "Failed to process heartbeat: "+err.Error(), 500) + return + } + + successNoData(c, "Heartbeat received successfully") +} diff --git a/internal/admin/heartbeat.go b/internal/admin/heartbeat.go new file mode 100644 index 0000000000..b7e41e6114 --- /dev/null +++ b/internal/admin/heartbeat.go @@ -0,0 +1,76 @@ +package admin + +import ( + "ragflow/internal/common" + "sync" + "time" +) + +// ServerStatusStore is a thread-safe global server status storage +type ServerStatusStore struct { + mu sync.RWMutex + servers map[string]*common.BaseMessage // key: server_id +} + +// GlobalServerStatusStore is the global instance +var GlobalServerStatusStore = &ServerStatusStore{ + servers: make(map[string]*common.BaseMessage), +} + +// UpdateStatus updates or adds a server status +func (s *ServerStatusStore) UpdateStatus(serverID string, status *common.BaseMessage) { + s.mu.Lock() + defer s.mu.Unlock() + s.servers[serverID] = status +} + +// GetStatus gets a single server status +func (s *ServerStatusStore) GetStatus(serverID string) (*common.BaseMessage, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + status, ok := s.servers[serverID] + return status, ok +} + +// GetAllStatuses gets all server statuses +func (s *ServerStatusStore) GetAllStatuses() []*common.BaseMessage { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]*common.BaseMessage, 0, len(s.servers)) + for _, status := range s.servers { + result = append(result, status) + } + return result +} + +// GetStatusesByType gets server statuses by type +func (s *ServerStatusStore) GetStatusesByType(serverType common.ServerType) []*common.BaseMessage { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]*common.BaseMessage, 0) + for _, status := range s.servers { + if status.ServerType == serverType { + result = append(result, status) + } + } + return result +} + +// RemoveStatus removes a server status +func (s *ServerStatusStore) RemoveStatus(serverID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.servers, serverID) +} + +// CleanupStaleStatuses cleans up servers that haven't reported for a specified duration +func (s *ServerStatusStore) CleanupStaleStatuses(maxAge time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + now := time.Now() + for id, status := range s.servers { + if now.Sub(status.Timestamp) > maxAge { + delete(s.servers, id) + } + } +} diff --git a/internal/admin/router.go b/internal/admin/router.go index 3dc03c2c14..e2d9cedf1c 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -47,6 +47,8 @@ func (r *Router) Setup(engine *gin.Engine) { admin.GET("/ping", r.handler.Ping) admin.POST("/login", r.handler.Login) + admin.POST("/reports", r.handler.Reports) + // Protected routes protected := admin.Group("") protected.Use(r.handler.AuthMiddleware()) diff --git a/internal/admin/service.go b/internal/admin/service.go index 2438ef6c85..b45714396a 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -25,6 +25,7 @@ import ( "net/http" "os" "ragflow/internal/cache" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/engine/elasticsearch" "ragflow/internal/model" @@ -732,3 +733,18 @@ func (s *Service) TestSandboxConnection(providerType string, config map[string]i "connected": true, }, nil } + +// HandleHeartbeat handle heartbeat +func (s *Service) HandleHeartbeat(msg *common.BaseMessage) error { + status := &common.BaseMessage{ + ServerName: msg.ServerName, + ServerType: msg.ServerType, + Host: msg.Host, + Port: msg.Port, + Version: msg.Version, + Timestamp: msg.Timestamp, + Ext: msg.Ext, + } + GlobalServerStatusStore.UpdateStatus(msg.ServerName, status) + return nil +} diff --git a/internal/common/status_message.go b/internal/common/status_message.go new file mode 100644 index 0000000000..76d29ac3eb --- /dev/null +++ b/internal/common/status_message.go @@ -0,0 +1,33 @@ +package common + +import ( + "time" +) + +type MessageType string + +const ( + MessageHeartbeat MessageType = "heartbeat" + MessageMetric MessageType = "metric" + MessageEvent MessageType = "event" +) + +type ServerType string + +const ( + ServerTypeAPI ServerType = "api_server" // API server + ServerTypeWorker ServerType = "ingestor" // Ingestion server + ServerTypeScheduler ServerType = "data_collector" // Data collection server +) + +type BaseMessage struct { + MessageID int64 `json:"report_id"` + MessageType MessageType `json:"report_type"` + ServerName string `json:"server_id"` + ServerType ServerType `json:"server_type"` + Host string `json:"host"` + Port int `json:"port"` + Version string `json:"version"` + Timestamp time.Time `json:"timestamp"` + Ext map[string]interface{} `json:"ext,omitempty"` +} diff --git a/internal/cpp/Makefile b/internal/cpp/Makefile index 9ddf024405..e45843e85d 100644 --- a/internal/cpp/Makefile +++ b/internal/cpp/Makefile @@ -51,7 +51,7 @@ c_api_debug: $(BUILD_DIR) # Test the Go bindings test_go: c_api - cd bindings/example && go run main.go ../../$(BUILD_DIR) "这是一个测试文本。This is a test." + cd bindings/example && go run main.go ../../$(BUILD_DIR) "This is a test." # Run memory test test_memory: c_api diff --git a/internal/handler/user.go b/internal/handler/user.go index 3651c29c14..8ec2d314f3 100644 --- a/internal/handler/user.go +++ b/internal/handler/user.go @@ -164,7 +164,7 @@ func (h *UserHandler) LoginByEmail(c *gin.Context) { return } - user, code, err := h.userService.LoginByEmail(&req) + user, code, err := h.userService.LoginByEmail(&req, false) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": code, diff --git a/internal/server/config.go b/internal/server/config.go index fe9cdea48e..b028ae76ce 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -40,9 +40,16 @@ type Config struct { DocEngine DocEngineConfig `mapstructure:"doc_engine"` RegisterEnabled int `mapstructure:"register_enabled"` OAuth map[string]OAuthConfig `mapstructure:"oauth"` + Admin AdminConfig `mapstructure:"admin"` UserDefaultLLM UserDefaultLLMConfig `mapstructure:"user_default_llm"` } +// AdminConfig admin server configuration +type AdminConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"http_port"` +} + // UserDefaultLLMConfig user default LLM configuration type UserDefaultLLMConfig struct { DefaultModels DefaultModelsConfig `mapstructure:"default_models"` @@ -50,19 +57,19 @@ type UserDefaultLLMConfig struct { // DefaultModelsConfig default models configuration type DefaultModelsConfig struct { - ChatModel ModelConfig `mapstructure:"chat_model"` - EmbeddingModel ModelConfig `mapstructure:"embedding_model"` - RerankModel ModelConfig `mapstructure:"rerank_model"` - ASRModel ModelConfig `mapstructure:"asr_model"` + ChatModel ModelConfig `mapstructure:"chat_model"` + EmbeddingModel ModelConfig `mapstructure:"embedding_model"` + RerankModel ModelConfig `mapstructure:"rerank_model"` + ASRModel ModelConfig `mapstructure:"asr_model"` Image2TextModel ModelConfig `mapstructure:"image2text_model"` } // ModelConfig model configuration type ModelConfig struct { - Name string `mapstructure:"name"` - APIKey string `mapstructure:"api_key"` - BaseURL string `mapstructure:"base_url"` - Factory string `mapstructure:"factory"` + Name string `mapstructure:"name"` + APIKey string `mapstructure:"api_key"` + BaseURL string `mapstructure:"base_url"` + Factory string `mapstructure:"factory"` } // OAuthConfig OAuth configuration for a channel @@ -325,6 +332,20 @@ func Init(configPath string) error { return fmt.Errorf("unmarshal config error: %w", err) } + // Set default values for admin configuration if not configured + if globalConfig.Admin.Host == "" { + globalConfig.Admin.Host = v.GetString("admin.host") + } + if globalConfig.Admin.Host == "" { + globalConfig.Admin.Host = "127.0.0.1" + } + if globalConfig.Admin.Port == 0 { + globalConfig.Admin.Port = v.GetInt("admin.http_port") + } + if globalConfig.Admin.Port == 0 { + globalConfig.Admin.Port = 9381 + } + // Load REGISTER_ENABLED from environment variable (default: 1) registerEnabled := 1 if envVal := os.Getenv("REGISTER_ENABLED"); envVal != "" { @@ -357,8 +378,8 @@ func Init(configPath string) error { if v.IsSet("ragflow") { ragflowConfig := v.Sub("ragflow") if ragflowConfig != nil { - globalConfig.Server.Port = ragflowConfig.GetInt("http_port") + 2 // 9382, by default - // globalConfig.Server.Port = ragflowConfig.GetInt("http_port") // Correct + //globalConfig.Server.Port = ragflowConfig.GetInt("http_port") + 2 // 9382, by default + globalConfig.Server.Port = ragflowConfig.GetInt("http_port") // Correct // If mode is not set, default to debug if globalConfig.Server.Mode == "" { globalConfig.Server.Mode = "release" @@ -484,6 +505,14 @@ func GetConfig() *Config { return globalConfig } +// GetAdminConfig gets the admin server configuration +func GetAdminConfig() *AdminConfig { + if globalConfig == nil { + return nil + } + return &globalConfig.Admin +} + // SetLogger sets the logger instance func SetLogger(l *zap.Logger) { zapLogger = l diff --git a/internal/service/heartbeat_sender.go b/internal/service/heartbeat_sender.go new file mode 100644 index 0000000000..ec2b198320 --- /dev/null +++ b/internal/service/heartbeat_sender.go @@ -0,0 +1,136 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "encoding/json" + "fmt" + "ragflow/internal/common" + "ragflow/internal/server" + "ragflow/internal/utility" + "time" + + "go.uber.org/zap" +) + +// HeartbeatSender is responsible for sending heartbeat reports to the admin server +type HeartbeatSender struct { + client *utility.HTTPClient + logger *zap.Logger + serverType common.ServerType + serverName string + host string + port int + version string + lastSuccess bool + attemptCount int +} + +// NewHeartbeatSender creates a new heartbeat service instance +func NewHeartbeatSender(logger *zap.Logger, serverType common.ServerType, serverName, host string, port int) *HeartbeatSender { + return &HeartbeatSender{ + logger: logger, + serverType: serverType, + serverName: serverName, + host: host, + port: port, + version: utility.GetRAGFlowVersion(), + lastSuccess: false, + attemptCount: 0, + } +} + +// InitHTTPClient initializes the HTTP client with admin server configuration +func (h *HeartbeatSender) InitHTTPClient() error { + adminConfig := server.GetAdminConfig() + if adminConfig == nil { + return fmt.Errorf("admin configuration not found") + } + + h.client = utility.NewHTTPClientBuilder(). + WithHost(adminConfig.Host). + WithPort(adminConfig.Port). + WithTimeout(10 * time.Second). + Build() + + h.logger.Info("Heartbeat HTTP client initialized", + zap.String("admin_host", adminConfig.Host), + zap.Int("admin_port", adminConfig.Port+2), + ) + + return nil +} + +// SendHeartbeat sends a heartbeat message to the admin server +func (h *HeartbeatSender) SendHeartbeat() error { + + if h.attemptCount < 10 { + if h.lastSuccess { + h.attemptCount++ + return nil + } + } + h.attemptCount = 0 + h.lastSuccess = false + + if h.client == nil { + if err := h.InitHTTPClient(); err != nil { + h.logger.Error("Failed to initialize HTTP client", zap.Error(err)) + return err + } + } + + message := &common.BaseMessage{ + MessageID: time.Now().UnixNano(), + MessageType: common.MessageHeartbeat, + ServerName: h.serverName, + ServerType: h.serverType, + Host: h.host, + Port: h.port, + Version: h.version, + Timestamp: time.Now(), + Ext: make(map[string]interface{}), + } + + jsonData, err := json.Marshal(message) + if err != nil { + h.logger.Error("Failed to marshal heartbeat message", zap.Error(err)) + return err + } + + resp, err := h.client.PostJSON("/api/v1/admin/reports", jsonData) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + h.logger.Error("Heartbeat request failed", + zap.Int("status_code", resp.StatusCode), + ) + return fmt.Errorf("heartbeat request failed with status code: %d", resp.StatusCode) + } + + h.logger.Debug("Heartbeat sent successfully", + zap.String("server_id", h.serverName), + zap.String("server_type", string(h.serverType)), + ) + + h.lastSuccess = true + + return nil +} diff --git a/internal/service/user.go b/internal/service/user.go index eb8b2e6f1e..a87260b680 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -309,8 +309,8 @@ func (s *UserService) Login(req *LoginRequest) (*model.User, common.ErrorCode, e // - CodeAuthenticationError (109): Email not registered or password mismatch // - CodeServerError (500): Password decryption failure // - CodeForbidden (403): Account disabled -func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, common.ErrorCode, error) { - if req.Email == "admin@ragflow.io" { +func (s *UserService) LoginByEmail(req *EmailLoginRequest, adminLogin bool) (*model.User, common.ErrorCode, error) { + if !adminLogin && req.Email == "admin@ragflow.io" { return nil, common.CodeAuthenticationError, fmt.Errorf("default admin account cannot be used to login normal services") } diff --git a/internal/utility/http_client.go b/internal/utility/http_client.go new file mode 100644 index 0000000000..464b5530af --- /dev/null +++ b/internal/utility/http_client.go @@ -0,0 +1,274 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "bytes" + "crypto/tls" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +// HTTPClient is a configurable HTTP client +type HTTPClient struct { + host string + port int + useSSL bool + timeout time.Duration + headers map[string]string + httpClient *http.Client +} + +// HTTPClientBuilder is a builder for HTTPClient +type HTTPClientBuilder struct { + client *HTTPClient +} + +// NewHTTPClientBuilder creates a new HTTPClientBuilder with default values +func NewHTTPClientBuilder() *HTTPClientBuilder { + return &HTTPClientBuilder{ + client: &HTTPClient{ + host: "localhost", + port: 80, + useSSL: false, + timeout: 30 * time.Second, + headers: make(map[string]string), + }, + } +} + +// WithHost sets the host +func (b *HTTPClientBuilder) WithHost(host string) *HTTPClientBuilder { + b.client.host = host + return b +} + +// WithPort sets the port +func (b *HTTPClientBuilder) WithPort(port int) *HTTPClientBuilder { + b.client.port = port + return b +} + +// WithSSL enables or disables SSL +func (b *HTTPClientBuilder) WithSSL(useSSL bool) *HTTPClientBuilder { + b.client.useSSL = useSSL + return b +} + +// WithTimeout sets the timeout duration +func (b *HTTPClientBuilder) WithTimeout(timeout time.Duration) *HTTPClientBuilder { + b.client.timeout = timeout + return b +} + +// WithHeader adds a single header +func (b *HTTPClientBuilder) WithHeader(key, value string) *HTTPClientBuilder { + b.client.headers[key] = value + return b +} + +// WithHeaders sets multiple headers +func (b *HTTPClientBuilder) WithHeaders(headers map[string]string) *HTTPClientBuilder { + for key, value := range headers { + b.client.headers[key] = value + } + return b +} + +// Build creates the HTTPClient +func (b *HTTPClientBuilder) Build() *HTTPClient { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: false, + }, + } + + // If SSL is disabled, allow insecure connections + if !b.client.useSSL { + transport.TLSClientConfig.InsecureSkipVerify = true + } + + b.client.httpClient = &http.Client{ + Timeout: b.client.timeout, + Transport: transport, + } + + return b.client +} + +// SetHost sets the host +func (c *HTTPClient) SetHost(host string) { + c.host = host +} + +// SetPort sets the port +func (c *HTTPClient) SetPort(port int) { + c.port = port +} + +// SetSSL enables or disables SSL +func (c *HTTPClient) SetSSL(useSSL bool) { + c.useSSL = useSSL +} + +// SetTimeout sets the timeout duration +func (c *HTTPClient) SetTimeout(timeout time.Duration) { + c.timeout = timeout + c.httpClient.Timeout = timeout +} + +// SetHeader sets a single header +func (c *HTTPClient) SetHeader(key, value string) { + c.headers[key] = value +} + +// SetHeaders sets multiple headers +func (c *HTTPClient) SetHeaders(headers map[string]string) { + c.headers = headers +} + +// AddHeader adds a header without removing existing ones +func (c *HTTPClient) AddHeader(key, value string) { + c.headers[key] = value +} + +// GetHeaders returns a copy of all headers +func (c *HTTPClient) GetHeaders() map[string]string { + headersCopy := make(map[string]string) + for k, v := range c.headers { + headersCopy[k] = v + } + return headersCopy +} + +// GetBaseURL returns the base URL +func (c *HTTPClient) GetBaseURL() string { + scheme := "http" + if c.useSSL { + scheme = "https" + } + return fmt.Sprintf("%s://%s:%d", scheme, c.host, c.port) +} + +// GetFullURL returns the full URL for a given path +func (c *HTTPClient) GetFullURL(path string) string { + baseURL := c.GetBaseURL() + // Ensure path starts with / + if path != "" && path[0] != '/' { + path = "/" + path + } + return baseURL + path +} + +// prepareRequest creates an HTTP request with configured headers +func (c *HTTPClient) prepareRequest(method, urlStr string, body io.Reader) (*http.Request, error) { + req, err := http.NewRequest(method, urlStr, body) + if err != nil { + return nil, err + } + + // Add configured headers + for key, value := range c.headers { + req.Header.Set(key, value) + } + + return req, nil +} + +// Get performs a GET request +func (c *HTTPClient) Get(path string) (*http.Response, error) { + urlStr := c.GetFullURL(path) + req, err := c.prepareRequest(http.MethodGet, urlStr, nil) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} + +// GetWithParams performs a GET request with query parameters +func (c *HTTPClient) GetWithParams(path string, params map[string]string) (*http.Response, error) { + urlStr := c.GetFullURL(path) + u, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + + query := u.Query() + for key, value := range params { + query.Set(key, value) + } + u.RawQuery = query.Encode() + + req, err := c.prepareRequest(http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} + +// Post performs a POST request +func (c *HTTPClient) Post(path string, body []byte) (*http.Response, error) { + urlStr := c.GetFullURL(path) + req, err := c.prepareRequest(http.MethodPost, urlStr, bytes.NewReader(body)) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} + +// PostJSON performs a POST request with JSON content type +func (c *HTTPClient) PostJSON(path string, body []byte) (*http.Response, error) { + c.SetHeader("Content-Type", "application/json") + return c.Post(path, body) +} + +// Put performs a PUT request +func (c *HTTPClient) Put(path string, body []byte) (*http.Response, error) { + urlStr := c.GetFullURL(path) + req, err := c.prepareRequest(http.MethodPut, urlStr, bytes.NewReader(body)) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} + +// Delete performs a DELETE request +func (c *HTTPClient) Delete(path string) (*http.Response, error) { + urlStr := c.GetFullURL(path) + req, err := c.prepareRequest(http.MethodDelete, urlStr, nil) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} + +// Do performs a request with the given method +func (c *HTTPClient) Do(method, path string, body []byte) (*http.Response, error) { + urlStr := c.GetFullURL(path) + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + req, err := c.prepareRequest(method, urlStr, bodyReader) + if err != nil { + return nil, err + } + return c.httpClient.Do(req) +} diff --git a/internal/utility/network.go b/internal/utility/network.go new file mode 100644 index 0000000000..bf8ad98201 --- /dev/null +++ b/internal/utility/network.go @@ -0,0 +1,49 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "net" +) + +// GetLocalIP returns the first non-loopback local IP address of the host +func GetLocalIP() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "" + } + + for _, addr := range addrs { + // Check the address type and skip loopback addresses + if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + return ipnet.IP.String() + } + } + } + + return "" +} + +// GetLocalIPWithFallback returns the local IP address with a fallback value +func GetLocalIPWithFallback(fallback string) string { + ip := GetLocalIP() + if ip == "" { + return fallback + } + return ip +} diff --git a/internal/utility/scheduled_task.go b/internal/utility/scheduled_task.go new file mode 100644 index 0000000000..88c9886d17 --- /dev/null +++ b/internal/utility/scheduled_task.go @@ -0,0 +1,156 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package utility + +import ( + "encoding/json" + "fmt" + "ragflow/internal/logger" + "sync/atomic" + "time" + + "go.uber.org/zap" +) + +type StatusMessage struct { + ID int `json:"id"` + Version string `json:"version"` + Timestamp time.Time `json:"timestamp"` + NodeName string `json:"node_name"` + ExtInfo string `json:"ext_info"` +} + +func NewStatusMessage(id int, version string, nodeName string, extInfo string) *StatusMessage { + return &StatusMessage{ + ID: id, + Version: version, + Timestamp: time.Now(), + NodeName: nodeName, + ExtInfo: extInfo, + } +} + +func StatusMessageSending() { + // Construct status message + statusMessage := NewStatusMessage(0, "v1", "ragflow", "") + + // Serialize to JSON + jsonData, err := json.Marshal(statusMessage) + if err != nil { + logger.Error("Failed to marshal status message", err) + return + } + + // Create HTTP client + client := NewHTTPClientBuilder(). + WithHost("127.0.0.1"). + WithPort(9381). + WithSSL(false). + WithTimeout(10 * time.Second). + Build() + + // Send POST request + resp, err := client.PostJSON("/v1/admin/status", jsonData) + if err != nil { + logger.Error("Error sending status message", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + logger.Error("Failed to send status message", fmt.Errorf("status: %d", resp.StatusCode)) + } +} + +// ScheduledTask represents a periodic task +type ScheduledTask struct { + Name string + Interval time.Duration + Job func() + stop chan struct{} + running bool + executing int32 // atomic flag: 0 - not executed, 1 running +} + +// NewScheduledTask creates a new simple task +func NewScheduledTask(name string, interval time.Duration, job func()) *ScheduledTask { + return &ScheduledTask{ + Name: name, + Interval: interval, + Job: job, + stop: make(chan struct{}), + } +} + +// Start begins the periodic task +func (t *ScheduledTask) Start() { + if t.running { + return + } + t.running = true + + go func() { + ticker := time.NewTicker(t.Interval) + defer ticker.Stop() + + logger.Info("Task started", zap.String("name", t.Name)) + + for { + select { + case <-ticker.C: + t.runSafely() + case <-t.stop: + logger.Info("Task stopped", zap.String("name", t.Name)) + return + } + } + }() +} + +// runSafely executes the job with panic recovery and prevents overlap +func (t *ScheduledTask) runSafely() { + // Attempt to set the flag + if !atomic.CompareAndSwapInt32(&t.executing, 0, 1) { + logger.Warn("Task skipped - previous execution still running", zap.String("name", t.Name)) + return + } + + // Clear atomic flag after execution + defer atomic.StoreInt32(&t.executing, 0) + + defer func() { + if r := recover(); r != nil { + logger.Fatal("Task panicked", zap.String("name", t.Name), zap.Any("recover", r)) + } + }() + + t.Job() +} + +// Stop stops the periodic task +func (t *ScheduledTask) Stop() { + if !t.running { + return + } + t.running = false + close(t.stop) +} + +// IsExecuting returns whether the task is currently executing +func (t *ScheduledTask) IsExecuting() bool { + return atomic.LoadInt32(&t.executing) == 1 +}