mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Add scheduled tasks (#13470)
### What problem does this PR solve? 1. RAGFlow server will send heartbeat periodically. 2. This PR will including: - Scheduled task - API server message sending - Admin server API to receive the message. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/init_data"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/utility"
|
||||
@@ -30,7 +31,7 @@ import (
|
||||
func main() {
|
||||
// Initialize logger with default level
|
||||
// logger.Init("info"); // set debug log level
|
||||
if err := logger.Init("debug"); err != nil {
|
||||
if err := logger.Init("info"); err != nil {
|
||||
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
|
||||
}
|
||||
|
||||
@@ -45,28 +46,21 @@ func main() {
|
||||
}
|
||||
logger.Info("Model providers loaded", zap.Int("count", len(server.GetModelProviders())))
|
||||
|
||||
cfg := server.GetConfig()
|
||||
config := server.GetConfig()
|
||||
|
||||
// Reinitialize logger with configured level if different
|
||||
if cfg.Log.Level != "" && cfg.Log.Level != "info" {
|
||||
if err := logger.Init(cfg.Log.Level); err != nil {
|
||||
if config.Log.Level != "" && config.Log.Level != "info" {
|
||||
if err := logger.Init(config.Log.Level); err != nil {
|
||||
logger.Error("Failed to reinitialize logger with configured level", err)
|
||||
}
|
||||
}
|
||||
server.SetLogger(logger.Logger)
|
||||
|
||||
logger.Info("Server mode", zap.String("mode", cfg.Server.Mode))
|
||||
logger.Info("Server mode", zap.String("mode", config.Server.Mode))
|
||||
|
||||
// Print all configuration settings
|
||||
server.PrintAll()
|
||||
|
||||
// Set Gin mode
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
} else {
|
||||
gin.SetMode(gin.DebugMode)
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
if err := dao.InitDB(); err != nil {
|
||||
logger.Fatal("Failed to initialize database", zap.Error(err))
|
||||
@@ -80,13 +74,13 @@ func main() {
|
||||
}
|
||||
|
||||
// Initialize doc engine
|
||||
if err := engine.Init(&cfg.DocEngine); err != nil {
|
||||
if err := engine.Init(&config.DocEngine); err != nil {
|
||||
logger.Fatal("Failed to initialize doc engine", zap.Error(err))
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
// Initialize Redis cache
|
||||
if err := cache.Init(&cfg.Redis); err != nil {
|
||||
if err := cache.Init(&config.Redis); err != nil {
|
||||
logger.Fatal("Failed to initialize Redis", zap.Error(err))
|
||||
}
|
||||
defer cache.Close()
|
||||
@@ -112,6 +106,20 @@ func main() {
|
||||
logger.Fatal("Failed to initialize query builder", zap.Error(err))
|
||||
}
|
||||
|
||||
startServer(config)
|
||||
|
||||
logger.Info("Server exited")
|
||||
}
|
||||
|
||||
func startServer(config *server.Config) {
|
||||
|
||||
// Set Gin mode
|
||||
if config.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
} else {
|
||||
gin.SetMode(gin.DebugMode)
|
||||
}
|
||||
|
||||
// Initialize service layer
|
||||
userService := service.NewUserService()
|
||||
documentService := service.NewDocumentService()
|
||||
@@ -147,7 +155,7 @@ func main() {
|
||||
ginEngine := gin.New()
|
||||
|
||||
// Middleware
|
||||
if cfg.Server.Mode == "debug" {
|
||||
if config.Server.Mode == "debug" {
|
||||
ginEngine.Use(gin.Logger())
|
||||
}
|
||||
ginEngine.Use(gin.Recovery())
|
||||
@@ -156,7 +164,7 @@ func main() {
|
||||
r.Setup(ginEngine)
|
||||
|
||||
// Create HTTP server
|
||||
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
||||
addr := fmt.Sprintf(":%d", config.Server.Port)
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: ginEngine,
|
||||
@@ -172,12 +180,39 @@ func main() {
|
||||
" /_/ |_|/_/ |_|\\____//_/ /_/ \\____/ |__/|__/\n",
|
||||
)
|
||||
logger.Info(fmt.Sprintf("Version: %s", utility.GetRAGFlowVersion()))
|
||||
logger.Info(fmt.Sprintf("Server starting on port: %d", cfg.Server.Port))
|
||||
logger.Info(fmt.Sprintf("Server starting on port: %d", config.Server.Port))
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Fatal("Failed to start server", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// Get local IP address for heartbeat reporting
|
||||
localIP := utility.GetLocalIP()
|
||||
if localIP == "" {
|
||||
localIP = "127.0.0.1"
|
||||
}
|
||||
|
||||
// Initialize and start heartbeat reporter to admin server
|
||||
heartbeatService := service.NewHeartbeatSender(
|
||||
logger.Logger,
|
||||
common.ServerTypeAPI,
|
||||
fmt.Sprintf("ragflow-server-%d", config.Server.Port),
|
||||
localIP,
|
||||
config.Server.Port,
|
||||
)
|
||||
if err := heartbeatService.InitHTTPClient(); err != nil {
|
||||
logger.Warn("Failed to initialize heartbeat service", zap.Error(err))
|
||||
} else {
|
||||
// Start heartbeat reporter with 30 seconds interval
|
||||
heartbeatReporter := utility.NewScheduledTask("Heartbeat reporter", 3*time.Second, func() {
|
||||
if err := heartbeatService.SendHeartbeat(); err != nil {
|
||||
logger.Warn("Failed to send heartbeat", zap.Error(err))
|
||||
}
|
||||
})
|
||||
heartbeatReporter.Start()
|
||||
defer heartbeatReporter.Stop()
|
||||
}
|
||||
|
||||
// Wait for interrupt signal to gracefully shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2)
|
||||
@@ -194,6 +229,4 @@ func main() {
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
logger.Fatal("Server forced to shutdown", zap.Error(err))
|
||||
}
|
||||
|
||||
logger.Info("Server exited")
|
||||
}
|
||||
|
||||
@@ -19,10 +19,12 @@ package admin
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/service"
|
||||
"ragflow/internal/utility"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -111,7 +113,7 @@ func (h *Handler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, code, err := h.userService.LoginByEmail(&req)
|
||||
user, code, err := h.userService.LoginByEmail(&req, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": code,
|
||||
@@ -135,8 +137,9 @@ func (h *Handler) Login(c *gin.Context) {
|
||||
c.Header("Access-Control-Expose-Headers", "Authorization")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "Login successful",
|
||||
"code": common.CodeSuccess,
|
||||
"message": "Welcome back!",
|
||||
"data": user,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -943,3 +946,31 @@ func (h *Handler) HandleNoRoute(c *gin.Context) {
|
||||
Message: "The requested resource was not found",
|
||||
})
|
||||
}
|
||||
|
||||
// Reports handle heartbeat reports from servers
|
||||
func (h *Handler) Reports(c *gin.Context) {
|
||||
var req common.BaseMessage
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
errorResponse(c, "Invalid request body: "+err.Error(), 400)
|
||||
return
|
||||
}
|
||||
|
||||
// Set default timestamp if not provided
|
||||
if req.Timestamp.IsZero() {
|
||||
req.Timestamp = time.Now()
|
||||
}
|
||||
|
||||
// Only process heartbeat messages for now
|
||||
if req.MessageType != common.MessageHeartbeat {
|
||||
errorResponse(c, "Unsupported report type: "+string(req.MessageType), 400)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle the heartbeat
|
||||
if err := h.service.HandleHeartbeat(&req); err != nil {
|
||||
errorResponse(c, "Failed to process heartbeat: "+err.Error(), 500)
|
||||
return
|
||||
}
|
||||
|
||||
successNoData(c, "Heartbeat received successfully")
|
||||
}
|
||||
|
||||
76
internal/admin/heartbeat.go
Normal file
76
internal/admin/heartbeat.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"ragflow/internal/common"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ServerStatusStore is a thread-safe global server status storage
|
||||
type ServerStatusStore struct {
|
||||
mu sync.RWMutex
|
||||
servers map[string]*common.BaseMessage // key: server_id
|
||||
}
|
||||
|
||||
// GlobalServerStatusStore is the global instance
|
||||
var GlobalServerStatusStore = &ServerStatusStore{
|
||||
servers: make(map[string]*common.BaseMessage),
|
||||
}
|
||||
|
||||
// UpdateStatus updates or adds a server status
|
||||
func (s *ServerStatusStore) UpdateStatus(serverID string, status *common.BaseMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.servers[serverID] = status
|
||||
}
|
||||
|
||||
// GetStatus gets a single server status
|
||||
func (s *ServerStatusStore) GetStatus(serverID string) (*common.BaseMessage, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
status, ok := s.servers[serverID]
|
||||
return status, ok
|
||||
}
|
||||
|
||||
// GetAllStatuses gets all server statuses
|
||||
func (s *ServerStatusStore) GetAllStatuses() []*common.BaseMessage {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
result := make([]*common.BaseMessage, 0, len(s.servers))
|
||||
for _, status := range s.servers {
|
||||
result = append(result, status)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetStatusesByType gets server statuses by type
|
||||
func (s *ServerStatusStore) GetStatusesByType(serverType common.ServerType) []*common.BaseMessage {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
result := make([]*common.BaseMessage, 0)
|
||||
for _, status := range s.servers {
|
||||
if status.ServerType == serverType {
|
||||
result = append(result, status)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// RemoveStatus removes a server status
|
||||
func (s *ServerStatusStore) RemoveStatus(serverID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.servers, serverID)
|
||||
}
|
||||
|
||||
// CleanupStaleStatuses cleans up servers that haven't reported for a specified duration
|
||||
func (s *ServerStatusStore) CleanupStaleStatuses(maxAge time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
now := time.Now()
|
||||
for id, status := range s.servers {
|
||||
if now.Sub(status.Timestamp) > maxAge {
|
||||
delete(s.servers, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -47,6 +47,8 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
admin.GET("/ping", r.handler.Ping)
|
||||
admin.POST("/login", r.handler.Login)
|
||||
|
||||
admin.POST("/reports", r.handler.Reports)
|
||||
|
||||
// Protected routes
|
||||
protected := admin.Group("")
|
||||
protected.Use(r.handler.AuthMiddleware())
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"ragflow/internal/cache"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine/elasticsearch"
|
||||
"ragflow/internal/model"
|
||||
@@ -732,3 +733,18 @@ func (s *Service) TestSandboxConnection(providerType string, config map[string]i
|
||||
"connected": true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HandleHeartbeat handle heartbeat
|
||||
func (s *Service) HandleHeartbeat(msg *common.BaseMessage) error {
|
||||
status := &common.BaseMessage{
|
||||
ServerName: msg.ServerName,
|
||||
ServerType: msg.ServerType,
|
||||
Host: msg.Host,
|
||||
Port: msg.Port,
|
||||
Version: msg.Version,
|
||||
Timestamp: msg.Timestamp,
|
||||
Ext: msg.Ext,
|
||||
}
|
||||
GlobalServerStatusStore.UpdateStatus(msg.ServerName, status)
|
||||
return nil
|
||||
}
|
||||
|
||||
33
internal/common/status_message.go
Normal file
33
internal/common/status_message.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type MessageType string
|
||||
|
||||
const (
|
||||
MessageHeartbeat MessageType = "heartbeat"
|
||||
MessageMetric MessageType = "metric"
|
||||
MessageEvent MessageType = "event"
|
||||
)
|
||||
|
||||
type ServerType string
|
||||
|
||||
const (
|
||||
ServerTypeAPI ServerType = "api_server" // API server
|
||||
ServerTypeWorker ServerType = "ingestor" // Ingestion server
|
||||
ServerTypeScheduler ServerType = "data_collector" // Data collection server
|
||||
)
|
||||
|
||||
type BaseMessage struct {
|
||||
MessageID int64 `json:"report_id"`
|
||||
MessageType MessageType `json:"report_type"`
|
||||
ServerName string `json:"server_id"`
|
||||
ServerType ServerType `json:"server_type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Version string `json:"version"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Ext map[string]interface{} `json:"ext,omitempty"`
|
||||
}
|
||||
@@ -51,7 +51,7 @@ c_api_debug: $(BUILD_DIR)
|
||||
|
||||
# Test the Go bindings
|
||||
test_go: c_api
|
||||
cd bindings/example && go run main.go ../../$(BUILD_DIR) "这是一个测试文本。This is a test."
|
||||
cd bindings/example && go run main.go ../../$(BUILD_DIR) "This is a test."
|
||||
|
||||
# Run memory test
|
||||
test_memory: c_api
|
||||
|
||||
@@ -164,7 +164,7 @@ func (h *UserHandler) LoginByEmail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, code, err := h.userService.LoginByEmail(&req)
|
||||
user, code, err := h.userService.LoginByEmail(&req, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
|
||||
@@ -40,9 +40,16 @@ type Config struct {
|
||||
DocEngine DocEngineConfig `mapstructure:"doc_engine"`
|
||||
RegisterEnabled int `mapstructure:"register_enabled"`
|
||||
OAuth map[string]OAuthConfig `mapstructure:"oauth"`
|
||||
Admin AdminConfig `mapstructure:"admin"`
|
||||
UserDefaultLLM UserDefaultLLMConfig `mapstructure:"user_default_llm"`
|
||||
}
|
||||
|
||||
// AdminConfig admin server configuration
|
||||
type AdminConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"http_port"`
|
||||
}
|
||||
|
||||
// UserDefaultLLMConfig user default LLM configuration
|
||||
type UserDefaultLLMConfig struct {
|
||||
DefaultModels DefaultModelsConfig `mapstructure:"default_models"`
|
||||
@@ -50,19 +57,19 @@ type UserDefaultLLMConfig struct {
|
||||
|
||||
// DefaultModelsConfig default models configuration
|
||||
type DefaultModelsConfig struct {
|
||||
ChatModel ModelConfig `mapstructure:"chat_model"`
|
||||
EmbeddingModel ModelConfig `mapstructure:"embedding_model"`
|
||||
RerankModel ModelConfig `mapstructure:"rerank_model"`
|
||||
ASRModel ModelConfig `mapstructure:"asr_model"`
|
||||
ChatModel ModelConfig `mapstructure:"chat_model"`
|
||||
EmbeddingModel ModelConfig `mapstructure:"embedding_model"`
|
||||
RerankModel ModelConfig `mapstructure:"rerank_model"`
|
||||
ASRModel ModelConfig `mapstructure:"asr_model"`
|
||||
Image2TextModel ModelConfig `mapstructure:"image2text_model"`
|
||||
}
|
||||
|
||||
// ModelConfig model configuration
|
||||
type ModelConfig struct {
|
||||
Name string `mapstructure:"name"`
|
||||
APIKey string `mapstructure:"api_key"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Factory string `mapstructure:"factory"`
|
||||
Name string `mapstructure:"name"`
|
||||
APIKey string `mapstructure:"api_key"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Factory string `mapstructure:"factory"`
|
||||
}
|
||||
|
||||
// OAuthConfig OAuth configuration for a channel
|
||||
@@ -325,6 +332,20 @@ func Init(configPath string) error {
|
||||
return fmt.Errorf("unmarshal config error: %w", err)
|
||||
}
|
||||
|
||||
// Set default values for admin configuration if not configured
|
||||
if globalConfig.Admin.Host == "" {
|
||||
globalConfig.Admin.Host = v.GetString("admin.host")
|
||||
}
|
||||
if globalConfig.Admin.Host == "" {
|
||||
globalConfig.Admin.Host = "127.0.0.1"
|
||||
}
|
||||
if globalConfig.Admin.Port == 0 {
|
||||
globalConfig.Admin.Port = v.GetInt("admin.http_port")
|
||||
}
|
||||
if globalConfig.Admin.Port == 0 {
|
||||
globalConfig.Admin.Port = 9381
|
||||
}
|
||||
|
||||
// Load REGISTER_ENABLED from environment variable (default: 1)
|
||||
registerEnabled := 1
|
||||
if envVal := os.Getenv("REGISTER_ENABLED"); envVal != "" {
|
||||
@@ -357,8 +378,8 @@ func Init(configPath string) error {
|
||||
if v.IsSet("ragflow") {
|
||||
ragflowConfig := v.Sub("ragflow")
|
||||
if ragflowConfig != nil {
|
||||
globalConfig.Server.Port = ragflowConfig.GetInt("http_port") + 2 // 9382, by default
|
||||
// globalConfig.Server.Port = ragflowConfig.GetInt("http_port") // Correct
|
||||
//globalConfig.Server.Port = ragflowConfig.GetInt("http_port") + 2 // 9382, by default
|
||||
globalConfig.Server.Port = ragflowConfig.GetInt("http_port") // Correct
|
||||
// If mode is not set, default to debug
|
||||
if globalConfig.Server.Mode == "" {
|
||||
globalConfig.Server.Mode = "release"
|
||||
@@ -484,6 +505,14 @@ func GetConfig() *Config {
|
||||
return globalConfig
|
||||
}
|
||||
|
||||
// GetAdminConfig gets the admin server configuration
|
||||
func GetAdminConfig() *AdminConfig {
|
||||
if globalConfig == nil {
|
||||
return nil
|
||||
}
|
||||
return &globalConfig.Admin
|
||||
}
|
||||
|
||||
// SetLogger sets the logger instance
|
||||
func SetLogger(l *zap.Logger) {
|
||||
zapLogger = l
|
||||
|
||||
136
internal/service/heartbeat_sender.go
Normal file
136
internal/service/heartbeat_sender.go
Normal file
@@ -0,0 +1,136 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/utility"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HeartbeatSender is responsible for sending heartbeat reports to the admin server
|
||||
type HeartbeatSender struct {
|
||||
client *utility.HTTPClient
|
||||
logger *zap.Logger
|
||||
serverType common.ServerType
|
||||
serverName string
|
||||
host string
|
||||
port int
|
||||
version string
|
||||
lastSuccess bool
|
||||
attemptCount int
|
||||
}
|
||||
|
||||
// NewHeartbeatSender creates a new heartbeat service instance
|
||||
func NewHeartbeatSender(logger *zap.Logger, serverType common.ServerType, serverName, host string, port int) *HeartbeatSender {
|
||||
return &HeartbeatSender{
|
||||
logger: logger,
|
||||
serverType: serverType,
|
||||
serverName: serverName,
|
||||
host: host,
|
||||
port: port,
|
||||
version: utility.GetRAGFlowVersion(),
|
||||
lastSuccess: false,
|
||||
attemptCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// InitHTTPClient initializes the HTTP client with admin server configuration
|
||||
func (h *HeartbeatSender) InitHTTPClient() error {
|
||||
adminConfig := server.GetAdminConfig()
|
||||
if adminConfig == nil {
|
||||
return fmt.Errorf("admin configuration not found")
|
||||
}
|
||||
|
||||
h.client = utility.NewHTTPClientBuilder().
|
||||
WithHost(adminConfig.Host).
|
||||
WithPort(adminConfig.Port).
|
||||
WithTimeout(10 * time.Second).
|
||||
Build()
|
||||
|
||||
h.logger.Info("Heartbeat HTTP client initialized",
|
||||
zap.String("admin_host", adminConfig.Host),
|
||||
zap.Int("admin_port", adminConfig.Port+2),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendHeartbeat sends a heartbeat message to the admin server
|
||||
func (h *HeartbeatSender) SendHeartbeat() error {
|
||||
|
||||
if h.attemptCount < 10 {
|
||||
if h.lastSuccess {
|
||||
h.attemptCount++
|
||||
return nil
|
||||
}
|
||||
}
|
||||
h.attemptCount = 0
|
||||
h.lastSuccess = false
|
||||
|
||||
if h.client == nil {
|
||||
if err := h.InitHTTPClient(); err != nil {
|
||||
h.logger.Error("Failed to initialize HTTP client", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
message := &common.BaseMessage{
|
||||
MessageID: time.Now().UnixNano(),
|
||||
MessageType: common.MessageHeartbeat,
|
||||
ServerName: h.serverName,
|
||||
ServerType: h.serverType,
|
||||
Host: h.host,
|
||||
Port: h.port,
|
||||
Version: h.version,
|
||||
Timestamp: time.Now(),
|
||||
Ext: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to marshal heartbeat message", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := h.client.PostJSON("/api/v1/admin/reports", jsonData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
h.logger.Error("Heartbeat request failed",
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
)
|
||||
return fmt.Errorf("heartbeat request failed with status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
h.logger.Debug("Heartbeat sent successfully",
|
||||
zap.String("server_id", h.serverName),
|
||||
zap.String("server_type", string(h.serverType)),
|
||||
)
|
||||
|
||||
h.lastSuccess = true
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -309,8 +309,8 @@ func (s *UserService) Login(req *LoginRequest) (*model.User, common.ErrorCode, e
|
||||
// - CodeAuthenticationError (109): Email not registered or password mismatch
|
||||
// - CodeServerError (500): Password decryption failure
|
||||
// - CodeForbidden (403): Account disabled
|
||||
func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*model.User, common.ErrorCode, error) {
|
||||
if req.Email == "admin@ragflow.io" {
|
||||
func (s *UserService) LoginByEmail(req *EmailLoginRequest, adminLogin bool) (*model.User, common.ErrorCode, error) {
|
||||
if !adminLogin && req.Email == "admin@ragflow.io" {
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("default admin account cannot be used to login normal services")
|
||||
}
|
||||
|
||||
|
||||
274
internal/utility/http_client.go
Normal file
274
internal/utility/http_client.go
Normal file
@@ -0,0 +1,274 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package utility
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTPClient is a configurable HTTP client
|
||||
type HTTPClient struct {
|
||||
host string
|
||||
port int
|
||||
useSSL bool
|
||||
timeout time.Duration
|
||||
headers map[string]string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// HTTPClientBuilder is a builder for HTTPClient
|
||||
type HTTPClientBuilder struct {
|
||||
client *HTTPClient
|
||||
}
|
||||
|
||||
// NewHTTPClientBuilder creates a new HTTPClientBuilder with default values
|
||||
func NewHTTPClientBuilder() *HTTPClientBuilder {
|
||||
return &HTTPClientBuilder{
|
||||
client: &HTTPClient{
|
||||
host: "localhost",
|
||||
port: 80,
|
||||
useSSL: false,
|
||||
timeout: 30 * time.Second,
|
||||
headers: make(map[string]string),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// WithHost sets the host
|
||||
func (b *HTTPClientBuilder) WithHost(host string) *HTTPClientBuilder {
|
||||
b.client.host = host
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPort sets the port
|
||||
func (b *HTTPClientBuilder) WithPort(port int) *HTTPClientBuilder {
|
||||
b.client.port = port
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSSL enables or disables SSL
|
||||
func (b *HTTPClientBuilder) WithSSL(useSSL bool) *HTTPClientBuilder {
|
||||
b.client.useSSL = useSSL
|
||||
return b
|
||||
}
|
||||
|
||||
// WithTimeout sets the timeout duration
|
||||
func (b *HTTPClientBuilder) WithTimeout(timeout time.Duration) *HTTPClientBuilder {
|
||||
b.client.timeout = timeout
|
||||
return b
|
||||
}
|
||||
|
||||
// WithHeader adds a single header
|
||||
func (b *HTTPClientBuilder) WithHeader(key, value string) *HTTPClientBuilder {
|
||||
b.client.headers[key] = value
|
||||
return b
|
||||
}
|
||||
|
||||
// WithHeaders sets multiple headers
|
||||
func (b *HTTPClientBuilder) WithHeaders(headers map[string]string) *HTTPClientBuilder {
|
||||
for key, value := range headers {
|
||||
b.client.headers[key] = value
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Build creates the HTTPClient
|
||||
func (b *HTTPClientBuilder) Build() *HTTPClient {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
},
|
||||
}
|
||||
|
||||
// If SSL is disabled, allow insecure connections
|
||||
if !b.client.useSSL {
|
||||
transport.TLSClientConfig.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
b.client.httpClient = &http.Client{
|
||||
Timeout: b.client.timeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
return b.client
|
||||
}
|
||||
|
||||
// SetHost sets the host
|
||||
func (c *HTTPClient) SetHost(host string) {
|
||||
c.host = host
|
||||
}
|
||||
|
||||
// SetPort sets the port
|
||||
func (c *HTTPClient) SetPort(port int) {
|
||||
c.port = port
|
||||
}
|
||||
|
||||
// SetSSL enables or disables SSL
|
||||
func (c *HTTPClient) SetSSL(useSSL bool) {
|
||||
c.useSSL = useSSL
|
||||
}
|
||||
|
||||
// SetTimeout sets the timeout duration
|
||||
func (c *HTTPClient) SetTimeout(timeout time.Duration) {
|
||||
c.timeout = timeout
|
||||
c.httpClient.Timeout = timeout
|
||||
}
|
||||
|
||||
// SetHeader sets a single header
|
||||
func (c *HTTPClient) SetHeader(key, value string) {
|
||||
c.headers[key] = value
|
||||
}
|
||||
|
||||
// SetHeaders sets multiple headers
|
||||
func (c *HTTPClient) SetHeaders(headers map[string]string) {
|
||||
c.headers = headers
|
||||
}
|
||||
|
||||
// AddHeader adds a header without removing existing ones
|
||||
func (c *HTTPClient) AddHeader(key, value string) {
|
||||
c.headers[key] = value
|
||||
}
|
||||
|
||||
// GetHeaders returns a copy of all headers
|
||||
func (c *HTTPClient) GetHeaders() map[string]string {
|
||||
headersCopy := make(map[string]string)
|
||||
for k, v := range c.headers {
|
||||
headersCopy[k] = v
|
||||
}
|
||||
return headersCopy
|
||||
}
|
||||
|
||||
// GetBaseURL returns the base URL
|
||||
func (c *HTTPClient) GetBaseURL() string {
|
||||
scheme := "http"
|
||||
if c.useSSL {
|
||||
scheme = "https"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s:%d", scheme, c.host, c.port)
|
||||
}
|
||||
|
||||
// GetFullURL returns the full URL for a given path
|
||||
func (c *HTTPClient) GetFullURL(path string) string {
|
||||
baseURL := c.GetBaseURL()
|
||||
// Ensure path starts with /
|
||||
if path != "" && path[0] != '/' {
|
||||
path = "/" + path
|
||||
}
|
||||
return baseURL + path
|
||||
}
|
||||
|
||||
// prepareRequest creates an HTTP request with configured headers
|
||||
func (c *HTTPClient) prepareRequest(method, urlStr string, body io.Reader) (*http.Request, error) {
|
||||
req, err := http.NewRequest(method, urlStr, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add configured headers
|
||||
for key, value := range c.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Get performs a GET request
|
||||
func (c *HTTPClient) Get(path string) (*http.Response, error) {
|
||||
urlStr := c.GetFullURL(path)
|
||||
req, err := c.prepareRequest(http.MethodGet, urlStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
// GetWithParams performs a GET request with query parameters
|
||||
func (c *HTTPClient) GetWithParams(path string, params map[string]string) (*http.Response, error) {
|
||||
urlStr := c.GetFullURL(path)
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query := u.Query()
|
||||
for key, value := range params {
|
||||
query.Set(key, value)
|
||||
}
|
||||
u.RawQuery = query.Encode()
|
||||
|
||||
req, err := c.prepareRequest(http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
// Post performs a POST request
|
||||
func (c *HTTPClient) Post(path string, body []byte) (*http.Response, error) {
|
||||
urlStr := c.GetFullURL(path)
|
||||
req, err := c.prepareRequest(http.MethodPost, urlStr, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
// PostJSON performs a POST request with JSON content type
|
||||
func (c *HTTPClient) PostJSON(path string, body []byte) (*http.Response, error) {
|
||||
c.SetHeader("Content-Type", "application/json")
|
||||
return c.Post(path, body)
|
||||
}
|
||||
|
||||
// Put performs a PUT request
|
||||
func (c *HTTPClient) Put(path string, body []byte) (*http.Response, error) {
|
||||
urlStr := c.GetFullURL(path)
|
||||
req, err := c.prepareRequest(http.MethodPut, urlStr, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
// Delete performs a DELETE request
|
||||
func (c *HTTPClient) Delete(path string) (*http.Response, error) {
|
||||
urlStr := c.GetFullURL(path)
|
||||
req, err := c.prepareRequest(http.MethodDelete, urlStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
// Do performs a request with the given method
|
||||
func (c *HTTPClient) Do(method, path string, body []byte) (*http.Response, error) {
|
||||
urlStr := c.GetFullURL(path)
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
bodyReader = bytes.NewReader(body)
|
||||
}
|
||||
req, err := c.prepareRequest(method, urlStr, bodyReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
49
internal/utility/network.go
Normal file
49
internal/utility/network.go
Normal file
@@ -0,0 +1,49 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package utility
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// GetLocalIP returns the first non-loopback local IP address of the host
|
||||
func GetLocalIP() string {
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
// Check the address type and skip loopback addresses
|
||||
if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
|
||||
if ipnet.IP.To4() != nil {
|
||||
return ipnet.IP.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetLocalIPWithFallback returns the local IP address with a fallback value
|
||||
func GetLocalIPWithFallback(fallback string) string {
|
||||
ip := GetLocalIP()
|
||||
if ip == "" {
|
||||
return fallback
|
||||
}
|
||||
return ip
|
||||
}
|
||||
156
internal/utility/scheduled_task.go
Normal file
156
internal/utility/scheduled_task.go
Normal file
@@ -0,0 +1,156 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package utility
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"ragflow/internal/logger"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type StatusMessage struct {
|
||||
ID int `json:"id"`
|
||||
Version string `json:"version"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
NodeName string `json:"node_name"`
|
||||
ExtInfo string `json:"ext_info"`
|
||||
}
|
||||
|
||||
func NewStatusMessage(id int, version string, nodeName string, extInfo string) *StatusMessage {
|
||||
return &StatusMessage{
|
||||
ID: id,
|
||||
Version: version,
|
||||
Timestamp: time.Now(),
|
||||
NodeName: nodeName,
|
||||
ExtInfo: extInfo,
|
||||
}
|
||||
}
|
||||
|
||||
func StatusMessageSending() {
|
||||
// Construct status message
|
||||
statusMessage := NewStatusMessage(0, "v1", "ragflow", "")
|
||||
|
||||
// Serialize to JSON
|
||||
jsonData, err := json.Marshal(statusMessage)
|
||||
if err != nil {
|
||||
logger.Error("Failed to marshal status message", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create HTTP client
|
||||
client := NewHTTPClientBuilder().
|
||||
WithHost("127.0.0.1").
|
||||
WithPort(9381).
|
||||
WithSSL(false).
|
||||
WithTimeout(10 * time.Second).
|
||||
Build()
|
||||
|
||||
// Send POST request
|
||||
resp, err := client.PostJSON("/v1/admin/status", jsonData)
|
||||
if err != nil {
|
||||
logger.Error("Error sending status message", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
logger.Error("Failed to send status message", fmt.Errorf("status: %d", resp.StatusCode))
|
||||
}
|
||||
}
|
||||
|
||||
// ScheduledTask represents a periodic task
|
||||
type ScheduledTask struct {
|
||||
Name string
|
||||
Interval time.Duration
|
||||
Job func()
|
||||
stop chan struct{}
|
||||
running bool
|
||||
executing int32 // atomic flag: 0 - not executed, 1 running
|
||||
}
|
||||
|
||||
// NewScheduledTask creates a new simple task
|
||||
func NewScheduledTask(name string, interval time.Duration, job func()) *ScheduledTask {
|
||||
return &ScheduledTask{
|
||||
Name: name,
|
||||
Interval: interval,
|
||||
Job: job,
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the periodic task
|
||||
func (t *ScheduledTask) Start() {
|
||||
if t.running {
|
||||
return
|
||||
}
|
||||
t.running = true
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(t.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
logger.Info("Task started", zap.String("name", t.Name))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.runSafely()
|
||||
case <-t.stop:
|
||||
logger.Info("Task stopped", zap.String("name", t.Name))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// runSafely executes the job with panic recovery and prevents overlap
|
||||
func (t *ScheduledTask) runSafely() {
|
||||
// Attempt to set the flag
|
||||
if !atomic.CompareAndSwapInt32(&t.executing, 0, 1) {
|
||||
logger.Warn("Task skipped - previous execution still running", zap.String("name", t.Name))
|
||||
return
|
||||
}
|
||||
|
||||
// Clear atomic flag after execution
|
||||
defer atomic.StoreInt32(&t.executing, 0)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Fatal("Task panicked", zap.String("name", t.Name), zap.Any("recover", r))
|
||||
}
|
||||
}()
|
||||
|
||||
t.Job()
|
||||
}
|
||||
|
||||
// Stop stops the periodic task
|
||||
func (t *ScheduledTask) Stop() {
|
||||
if !t.running {
|
||||
return
|
||||
}
|
||||
t.running = false
|
||||
close(t.stop)
|
||||
}
|
||||
|
||||
// IsExecuting returns whether the task is currently executing
|
||||
func (t *ScheduledTask) IsExecuting() bool {
|
||||
return atomic.LoadInt32(&t.executing) == 1
|
||||
}
|
||||
Reference in New Issue
Block a user