diff --git a/cmd/admin_server.go b/cmd/admin_server.go index 3775d038b7..46c695fa60 100644 --- a/cmd/admin_server.go +++ b/cmd/admin_server.go @@ -27,6 +27,7 @@ import ( "ragflow/internal/cache" "ragflow/internal/common" "ragflow/internal/engine" + "ragflow/internal/utility" "syscall" "time" @@ -36,7 +37,6 @@ import ( "ragflow/internal/admin" "ragflow/internal/dao" "ragflow/internal/server" - "ragflow/internal/utility" ) func main() { @@ -135,9 +135,6 @@ func main() { Handler: ginEngine, } - // Print RAGFlow version - common.Info("RAGFlow version", zap.String("version", utility.GetRAGFlowVersion())) - // Print all configuration settings server.PrintAll() @@ -149,10 +146,22 @@ func main() { " / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /\n" + " /_/ |_/_/ |_\\____/_/ /_/\\____/|__/|__/ /_/ |_\\__,_/_/ /_/ /_/_/_/ /_/ \n") - // Start server in a goroutine + // Print RAGFlow version + common.Info(fmt.Sprintf("RAGFlow admin version: %s", utility.GetRAGFlowVersion())) + + // Start ingestion manager (gRPC) in a goroutine + ingestionMgr := admin.NewAdminServer() go func() { - common.Info(fmt.Sprintf("Admin Go Version: %s", utility.GetRAGFlowVersion())) - common.Info(fmt.Sprintf("Starting RAGFlow admin server on port: %d", cfg.Admin.Port)) + addr = fmt.Sprintf(":%d", cfg.Admin.IngestionManagerPort) + common.Info(fmt.Sprintf("Starting RAGFlow ingestion manager on port: %d", cfg.Admin.IngestionManagerPort)) + if err := ingestionMgr.Start(addr); err != nil { + common.Fatal("Failed to start RAGFlow ingestion manager", zap.Error(err)) + } + }() + + // Start HTTP server in a goroutine + go func() { + common.Info(fmt.Sprintf("Starting RAGFlow admin HTTP server on port: %d", cfg.Admin.Port)) if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { common.Fatal("Failed to start server", zap.Error(err)) } @@ -164,16 +173,19 @@ func main() { sig := <-quit common.Info("Received signal", zap.String("signal", sig.String())) - 
common.Info("Shutting down server...")
+	common.Info("Shutting down RAGFlow HTTP server...")
 
 	// Create context with timeout for graceful shutdown
 	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()
 
-	// Shutdown server
+	// Shutdown HTTP server
 	if err := srv.Shutdown(ctx); err != nil {
 		common.Fatal("Server forced to shutdown", zap.Error(err))
 	}
 
-	common.Info("Server exited")
+	common.Info("Admin HTTP server exited")
+
+	// Stop ingestion manager
+	ingestionMgr.Stop()
 }
diff --git a/cmd/ingestion_server.go b/cmd/ingestion_server.go
new file mode 100644
index 0000000000..0e0e136a03
--- /dev/null
+++ b/cmd/ingestion_server.go
@@ -0,0 +1,206 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package main
+
+import (
+	"context"
+	"flag"
+	"fmt"
+	"os"
+	"os/signal"
+	"ragflow/internal/ingestion"
+	"ragflow/internal/service/nlp"
+	"ragflow/internal/tokenizer"
+	"ragflow/internal/utility"
+	"syscall"
+	"time"
+
+	"ragflow/internal/cache"
+	"ragflow/internal/common"
+	"ragflow/internal/dao"
+	"ragflow/internal/engine"
+	"ragflow/internal/server"
+	"ragflow/internal/storage"
+
+	"go.uber.org/zap"
+)
+
+// printIngestionServerHelp writes the usage text for the ingestion server
+// binary to stderr. It is installed as flag.Usage in main.
+func printIngestionServerHelp() {
+	fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
+	fmt.Fprintf(os.Stderr, "RAGFlow Ingestion Worker - Document ingestion processing\n\n")
+	fmt.Fprintf(os.Stderr, "Options:\n")
+	fmt.Fprintf(os.Stderr, " -f string\t\tPath to config file (default: auto-detect)\n")
+	fmt.Fprintf(os.Stderr, " --name string\t\tIngestion server name (default: \"default_ingestion\")\n")
+	fmt.Fprintf(os.Stderr, " --admin-host string\tAdmin server host (overrides config file)\n")
+	fmt.Fprintf(os.Stderr, " --admin-port int\tAdmin server port (overrides config file)\n")
+	fmt.Fprintf(os.Stderr, " -h, --help\t\tShow this help message and exit\n")
+	fmt.Fprintf(os.Stderr, "\nExamples:\n")
+	fmt.Fprintf(os.Stderr, " %s # Start with default config\n", os.Args[0])
+	fmt.Fprintf(os.Stderr, " %s -f /path/to/config.yaml # Start with custom config file\n", os.Args[0])
+	fmt.Fprintf(os.Stderr, " %s --admin-host 10.0.0.1 --admin-port 9383\n", os.Args[0])
+}
+
+// main starts an ingestion worker: it loads the configuration, initializes the
+// shared subsystems (database, doc engine, Redis, storage, tokenizer), connects
+// to the admin server's ingestion manager, and then blocks until it receives an
+// OS signal or a shutdown command from the admin.
+func main() {
+	// Parse command line flags
+	var configPath string
+	var name string
+	var adminHost string
+	var adminPort int
+
+	flag.StringVar(&configPath, "f", "", "Path to config file (overrides auto-detect)")
+	flag.StringVar(&name, "name", "default_ingestion", "Ingestion server name")
+	flag.StringVar(&adminHost, "admin-host", "", "Admin server host (overrides config file)")
+	flag.IntVar(&adminPort, "admin-port", 0, "Admin server port (overrides config file)")
+
+	// Custom help message
+	flag.Usage = printIngestionServerHelp
+
+	flag.Parse()
+
+	// Initialize logger with default level
+	if err := common.Init("info"); err != nil {
+		panic(fmt.Sprintf("Failed to initialize logger: %v", err))
+	}
+
+	// Initialize configuration
+	if err := server.Init(configPath); err != nil {
+		common.Fatal("Failed to initialize config", zap.Error(err))
+	}
+
+	config := server.GetConfig()
+
+	// Override admin server host with command line argument if provided
+	if adminHost != "" {
+		config.Admin.Host = adminHost
+		common.Info("Admin host overridden by command line argument", zap.String("admin_host", adminHost))
+	}
+
+	// Override admin server port with command line argument if provided
+	// NOTE(review): this sets Admin.Port, but the address dialed below is built
+	// from Admin.IngestionManagerPort, so --admin-port does not change where the
+	// ingestor connects. Confirm which port the flag is meant to override.
+	if adminPort > 0 {
+		config.Admin.Port = adminPort
+		common.Info("Admin port overridden by command line argument", zap.Int("admin_port", adminPort))
+	}
+
+	// Reinitialize logger with configured level if different
+	if config.Log.Level != "" && config.Log.Level != "info" {
+		if err := common.Init(config.Log.Level); err != nil {
+			common.Error("Failed to reinitialize logger with configured level", err)
+		}
+	}
+	server.SetLogger(common.Logger)
+
+	common.Info("Starting RAGFlow Ingestion Worker")
+
+	// Initialize database
+	if err := dao.InitDB(); err != nil {
+		common.Fatal("Failed to initialize database", zap.Error(err))
+	}
+
+	// Initialize LLM factory data models from configuration file
+	if err := dao.InitLLMFactory(); err != nil {
+		common.Error("Failed to initialize LLM factory", err)
+	} else {
+		common.Info("LLM factory initialized successfully")
+	}
+
+	// Initialize doc engine
+	if err := engine.Init(&config.DocEngine); err != nil {
+		common.Fatal("Failed to initialize doc engine", zap.Error(err))
+	}
+	defer engine.Close()
+
+	// Initialize Redis cache
+	if err := cache.Init(&config.Redis); err != nil {
+		common.Fatal("Failed to initialize Redis", zap.Error(err))
+	}
+	defer cache.Close()
+
+	// Initialize storage factory
+	if err := storage.InitStorageFactory(); err != nil {
+		common.Fatal("Failed to initialize storage factory", zap.Error(err))
+	}
+
+	// Initialize server variables (runtime variables from Redis)
+	if err := server.InitVariables(cache.Get()); err != nil {
+		common.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error()))
+	}
+
+	// Initialize tokenizer (rag_analyzer)
+	tokenizerCfg := &tokenizer.PoolConfig{
+		DictPath: "/usr/share/infinity/resource",
+	}
+	if err := tokenizer.Init(tokenizerCfg); err != nil {
+		common.Fatal("Failed to initialize tokenizer", zap.Error(err))
+	}
+	defer tokenizer.Close()
+
+	// Initialize global QueryBuilder using tokenizer's DictPath
+	if err := nlp.InitQueryBuilderFromTokenizer(tokenizerCfg.DictPath); err != nil {
+		common.Fatal("Failed to initialize query builder", zap.Error(err))
+	}
+
+	ingestor := ingestion.NewIngestor(name, 2, []string{"pdf", "docx", "txt"})
+
+	// Connect to the admin server
+	serverAddress := fmt.Sprintf("%s:%d", config.Admin.Host, config.Admin.IngestionManagerPort)
+	if err := ingestor.Connect(serverAddress); err != nil {
+		common.Fatal(fmt.Sprintf("Error: %s", err.Error()))
+	}
+
+	quit := make(chan os.Signal, 1)
+	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2)
+
+	// Print all configuration settings
+	server.PrintAll()
+	common.Info("\n ____ __ _\n" +
+		" / _/___ ____ ____ _____/ /_(_)___ ____ ________ ______ _____ _____\n" +
+		" / // __ \\/ __ `/ _ \\/ ___/ __/ / __ \\/ __ \\ / ___/ _ \\/ ___/ | / / _ \\/ ___/\n" +
+		" _/ // / / / /_/ / __(__ ) /_/ / /_/ / / / / (__ ) __/ / | |/ / __/ /\n" +
+		"/___/_/ /_/\\__, /\\___/____/\\__/_/\\____/_/ /_/ /____/\\___/_/ |___/\\___/_/\n" +
+		" /____/\n")
+
+	// Print RAGFlow version
+	common.Info(fmt.Sprintf("RAGFlow ingestion server version: %s", utility.GetRAGFlowVersion()))
+
+	// Wait for either an OS signal or a shutdown command from the admin
+	select {
+	case sig := <-quit:
+		common.Info("Received signal", zap.String("signal", sig.String()))
+		common.Info(fmt.Sprintf("Shutting down RAGFlow ingestor %s ...", name))
+	case <-ingestor.ShutdownCh:
+		common.Info(fmt.Sprintf("Received shutdown command from admin, stopping ingestor %s ...", name))
+	}
+
+	// Create context with timeout for graceful shutdown
+	// NOTE(review): the context is discarded, so the 30s timeout is not applied
+	// to ingestor.Stop(); only cancel is kept to release the timer.
+	_, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	ingestor.Stop()
+
+	common.Info(fmt.Sprintf("Ingestor %s shutdown complete", name))
+}
diff --git a/cmd/ragflow_cli.go b/cmd/ragflow_cli.go
index cc2043687c..d8303ca9b0 100644
--- a/cmd/ragflow_cli.go
+++ b/cmd/ragflow_cli.go
@@ -1,3 +1,19 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
 package main
 
 import (
diff --git a/cmd/server_main.go b/cmd/server_main.go
index 8c5e53b8f4..01713473de 100644
--- a/cmd/server_main.go
+++ b/cmd/server_main.go
@@ -1,3 +1,19 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// + package main import ( diff --git a/go.mod b/go.mod index 1c1eca976e..1d417111f0 100644 --- a/go.mod +++ b/go.mod @@ -17,12 +17,16 @@ require ( github.com/minio/minio-go/v7 v7.0.99 github.com/peterh/liner v1.2.2 github.com/redis/go-redis/v9 v9.18.0 + github.com/shirou/gopsutil/v3 v3.24.5 github.com/siongui/gojianfan v0.0.0-20210926212422-2f175ac615de github.com/spf13/viper v1.18.2 go.uber.org/zap v1.27.1 golang.org/x/crypto v0.47.0 + golang.org/x/net v0.49.0 golang.org/x/term v0.41.0 google.golang.org/genai v1.54.0 + google.golang.org/grpc v1.79.3 + google.golang.org/protobuf v1.36.10 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.2 gorm.io/gorm v1.25.5 @@ -58,6 +62,7 @@ require ( github.com/go-ini/ini v1.67.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.16.0 // indirect @@ -76,6 +81,7 @@ require ( github.com/klauspost/cpuid/v2 v2.2.11 // indirect github.com/klauspost/crc32 v1.3.0 // indirect github.com/leodido/go-urn v1.2.4 // indirect + github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.3 // indirect @@ -86,17 +92,22 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.1.1 // indirect github.com/philhofer/fwd v1.2.0 // indirect + github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/rs/xid v1.6.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect 
github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tinylib/msgp v1.6.1 // indirect + github.com/tklauser/go-sysconf v0.3.12 // indirect + github.com/tklauser/numcpus v0.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel v1.41.0 // indirect @@ -107,12 +118,9 @@ require ( go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/arch v0.6.0 // indirect golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect - golang.org/x/net v0.49.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.33.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect - google.golang.org/grpc v1.79.3 // indirect - google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/go.sum b/go.sum index ca19d27134..70794b5f47 100644 --- a/go.sum +++ b/go.sum @@ -94,6 +94,8 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -128,6 +130,7 @@ github.com/google/go-cmp 
v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -166,6 +169,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -194,6 +199,8 @@ github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c 
h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= @@ -205,6 +212,12 @@ github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6ke github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= +github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= +github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= +github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= +github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= +github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= github.com/siongui/gojianfan v0.0.0-20210926212422-2f175ac615de h1:1/P9CcR8iENN9ybbSRWohRd3rsPp9tEWlTS/7ygvjHE= github.com/siongui/gojianfan v0.0.0-20210926212422-2f175ac615de/go.mod h1:TRwEEJlrSIv+jc66k48huOZ2aKVBPL8V29ZcsjUIH70= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= @@ -233,10 +246,16 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tinylib/msgp v1.6.1 h1:ESRv8eL3u+DNHUoSAAQRE50Hm162zqAnBoGv9PzScPY= github.com/tinylib/msgp 
v1.6.1/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= +github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= +github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= +github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= +github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= @@ -293,9 +312,13 @@ golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
 golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
 golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
diff --git a/internal/admin/handler.go b/internal/admin/handler.go
index b267baf5be..1ff2c7ff6c 100644
--- a/internal/admin/handler.go
+++ b/internal/admin/handler.go
@@ -33,11 +33,6 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-// Common errors
-var (
-	ErrUserNotFound = errors.New("user not found")
-)
-
 // Handler admin handler
 type Handler struct {
 	service *Service
@@ -1261,6 +1256,94 @@ func (h *Handler) SetLogLevel(c *gin.Context) {
 	success(c, gin.H{"level": req.Level}, "Log level updated successfully")
 }
 
+// StartIngestionTaskRequest is the JSON body accepted by StartIngestionTask.
+type StartIngestionTaskRequest struct {
+	FileURI string `json:"uri" binding:"required"`
+	From    string `json:"from" binding:"required"`
+}
+
+// StartIngestionTask queues a "start_ingestion_task" assignment for the given
+// file URI on the ingestion manager and responds with the generated task ID.
+// It responds with an error (code 400) when the request body fails to bind.
+func (h *Handler) StartIngestionTask(c *gin.Context) {
+	var req StartIngestionTaskRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		errorResponse(c, "file uri and from is required", 400)
+		return
+	}
+
+	taskID := common.GenerateUUID()
+	ingestionManager.SubmitTask(&common.TaskAssignment{
+		TaskId:   taskID,
+		TaskType: "start_ingestion_task",
+		Config:   req.FileURI,
+		ComeFrom: req.From,
+	})
+
+	success(c, gin.H{"task_id": taskID}, "Send task for ingestion successfully")
+}
+
+// StopIngestionTaskRequest is the JSON body accepted by StopIngestionTask.
+type StopIngestionTaskRequest struct {
+	TaskID string `json:"task_id" binding:"required"`
+	From   string `json:"from" binding:"required"`
+}
+
+// StopIngestionTask queues a "cancel_ingestion_task" assignment for the given
+// task ID on the ingestion manager.
+func (h *Handler) StopIngestionTask(c *gin.Context) {
+	var req StopIngestionTaskRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		errorResponse(c, "task id and from is required", 400)
+		return
+	}
+
+	ingestionManager.SubmitTask(&common.TaskAssignment{
+		TaskId:   req.TaskID,
+		TaskType: "cancel_ingestion_task",
+		ComeFrom: req.From,
+	})
+
+	success(c, gin.H{"task_id": req.TaskID}, "Cancel task successfully")
+}
+
+// ListIngestors responds with the ingestors known to the ingestion manager,
+// or with an error (code 500) if they cannot be listed.
+func (h *Handler) ListIngestors(c *gin.Context) {
+	ingestionMgr := GetIngestionManager()
+	ingestors, err := ingestionMgr.ListIngestors()
+	if err != nil {
+		errorResponse(c, err.Error(), 500)
+		// Stop here so a second (success) response is not written.
+		return
+	}
+	success(c, ingestors, "Get all ingestors")
+}
+
+// ShutdownIngestorRequest is the JSON body accepted by ShutdownIngestor.
+type ShutdownIngestorRequest struct {
+	IngestorID string `json:"ingestor_name" binding:"required"`
+}
+
+// ShutdownIngestor queues a "shutdown_ingestor" assignment addressed to the
+// named ingestor on the ingestion manager.
+func (h *Handler) ShutdownIngestor(c *gin.Context) {
+	var req ShutdownIngestorRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		errorResponse(c, "ingestor name is required", 400)
+		return
+	}
+
+	taskID := common.GenerateUUID()
+	ingestionManager.SubmitTask(&common.TaskAssignment{
+		TaskId:     taskID,
+		TaskType:   "shutdown_ingestor",
+		AssignedTo: req.IngestorID,
+	})
+
+	success(c, gin.H{"task_id": taskID, "ingestor_id": req.IngestorID}, "Shutdown ingestor")
+}
+
 // Reports handle heartbeat reports from servers
 func (h *Handler) Reports(c *gin.Context) {
 	var req common.BaseMessage
@@ -1295,3 +1378,15 @@ func (h *Handler) Reports(c *gin.Context) {
 
 	responseWithCode(c, message, http.StatusOK, errCode)
 }
+
+// ListIngestionTasks responds with the tasks returned by the service's
+// ListIngestionTasks, or with an error (code 400) if that call fails.
+func (h *Handler) ListIngestionTasks(c *gin.Context) {
+	tasks, err := h.service.ListIngestionTasks()
+	if err != nil {
+		errorResponse(c, err.Error(), 400)
+		return
+	}
+
+	success(c, tasks, "")
+}
diff --git a/internal/admin/heartbeat.go b/internal/admin/heartbeat.go
index fc8901f440..d78da5da34 100644
--- a/internal/admin/heartbeat.go
+++ b/internal/admin/heartbeat.go
@@ -1,76 +1 @@
 package admin
-
-import (
-	"ragflow/internal/common"
-	"sync"
-	"time"
-)
-
-// ServerStatusStore is a thread-safe global server status storage
-type ServerStatusStore struct {
-	mu      sync.RWMutex
-	servers
map[string]*common.BaseMessage // key: server_id -} - -// GlobalServerStatusStore is the global instance -var GlobalServerStatusStore = &ServerStatusStore{ - servers: make(map[string]*common.BaseMessage), -} - -// UpdateStatus updates or adds a server status -func (s *ServerStatusStore) UpdateStatus(serverName string, status *common.BaseMessage) { - s.mu.Lock() - defer s.mu.Unlock() - s.servers[serverName] = status -} - -// GetStatus gets a single server status -func (s *ServerStatusStore) GetStatus(serverName string) (*common.BaseMessage, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - status, ok := s.servers[serverName] - return status, ok -} - -// GetAllStatuses gets all server statuses -func (s *ServerStatusStore) GetAllStatuses() []*common.BaseMessage { - s.mu.RLock() - defer s.mu.RUnlock() - result := make([]*common.BaseMessage, 0, len(s.servers)) - for _, status := range s.servers { - result = append(result, status) - } - return result -} - -// GetStatusesByType gets server statuses by type -func (s *ServerStatusStore) GetStatusesByType(serverType common.ServerType) []*common.BaseMessage { - s.mu.RLock() - defer s.mu.RUnlock() - result := make([]*common.BaseMessage, 0) - for _, status := range s.servers { - if status.ServerType == serverType { - result = append(result, status) - } - } - return result -} - -// RemoveStatus removes a server status -func (s *ServerStatusStore) RemoveStatus(serverID string) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.servers, serverID) -} - -// CleanupStaleStatuses cleans up servers that haven't reported for a specified duration -func (s *ServerStatusStore) CleanupStaleStatuses(maxAge time.Duration) { - s.mu.Lock() - defer s.mu.Unlock() - now := time.Now() - for id, status := range s.servers { - if now.Sub(status.Timestamp) > maxAge { - delete(s.servers, id) - } - } -} diff --git a/internal/admin/ingestion_manager.go b/internal/admin/ingestion_manager.go new file mode 100644 index 0000000000..ceffdf843e --- /dev/null +++ 
b/internal/admin/ingestion_manager.go
@@ -0,0 +1,631 @@
+//
+// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package admin
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"sync"
+	"time"
+
+	"ragflow/internal/common"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/peer"
+)
+
+// heartbeatTimeout is the heartbeat age after which heartbeatCheckLoop (per its
+// doc comment) marks an ingestor as "timeout".
+const heartbeatTimeout = 30 * time.Second
+
+// IngestionManager is the admin-side gRPC service that ingestion servers
+// connect to. It keeps the registered ingestors and known task states in
+// memory and queues submitted tasks for dispatchLoop to hand out.
+type IngestionManager struct {
+	common.UnimplementedIngestionManagerServer
+	mu sync.RWMutex // guards ingestionServers and taskStates
+
+	// Registered ingestion servers
+	ingestionServers map[string]*IngestorState // ingestor id -> ingestor state
+
+	taskStates map[string]*TaskState // task_id -> task state
+
+	// In-memory task queue
+	taskQueue chan *pendingTask
+
+	// Notifies that an ingestor slot may have freed up
+	slotFreed chan struct{}
+
+	grpcServer *grpc.Server // gRPC server instance for graceful shutdown via Stop()
+
+	// ctx is cancelled (through cancel) to stop the background loops.
+	ctx    context.Context
+	cancel context.CancelFunc
+}
+
+// TaskState is the admin's in-memory view of one ingestion task, refreshed
+// from ingestor heartbeats.
+type TaskState struct {
+	taskID                 string // same as task_id in database
+	status                 string // created, assigned, processing, completed, failed
+	comeFrom               string // api server id
+	assignTo               string // ingestor id
+	lastUpdate             time.Time
+	startTime              *time.Time
+	estimatedRemainingTime time.Duration // estimated remaining time; NOTE(review): filled with a raw seconds count, see handleHeartbeat
+	errorMessage           string
+}
+
+// IngestorState is the admin's record of one registered ingestion server and
+// the resource usage it last reported.
+type IngestorState struct {
+	ID            string
+	Info          *common.RegisterInfo
+	LastHeartbeat time.Time
+	Stream        common.IngestionManager_ActionServer
+	Status        string // active, draining, timeout
+	Address       string
+	ProcessID     int64
+	cpuUsage      float64
+	vmsUsage      float64 // in MB
+	rssUsage      float64 // in MB
+}
+
+// pendingTask is a queued task together with the time it was submitted.
+type pendingTask struct {
+	Task      *common.TaskAssignment
+	CreatedAt time.Time
+}
+
+// ingestionManager is the package-level instance set by NewAdminServer.
+var ingestionManager *IngestionManager
+
+// GetIngestionManager returns the instance created by NewAdminServer, or nil
+// if NewAdminServer has not been called yet.
+func GetIngestionManager() *IngestionManager {
+	return ingestionManager
+}
+
+// NewAdminServer creates the IngestionManager, stores it in the package-level
+// ingestionManager variable, and starts dispatchLoop in a background goroutine.
+// heartbeatCheckLoop is currently not started.
+func NewAdminServer() *IngestionManager {
+	ctx, cancel := context.WithCancel(context.Background())
+	ingestionManager = &IngestionManager{
+		taskStates:       make(map[string]*TaskState),
+		ingestionServers: make(map[string]*IngestorState),
+		taskQueue:        make(chan *pendingTask, 10000),
+		slotFreed:        make(chan struct{}, 100),
+		ctx:              ctx,
+		cancel:           cancel,
+	}
+	go ingestionManager.dispatchLoop()
+	//go ingestionManager.heartbeatCheckLoop() no need to check heartbeat timeout
+	return ingestionManager
+}
+
+// Action handles the bidirectional streaming RPC from ingestion servers.
+// Incoming messages are processed on a receive goroutine; Action itself blocks
+// until the stream fails or is cancelled, then cleans up the ingestor that
+// registered on this stream.
+func (s *IngestionManager) Action(stream common.IngestionManager_ActionServer) error {
+	var ingestionServerID string
+	var state *IngestorState
+
+	common.Info("New ingestion_server connection")
+
+	// Start receive goroutine
+	receiveErrorCH := make(chan error, 1)
+	go func() {
+		for {
+			msg, err := stream.Recv()
+			if err != nil {
+				receiveErrorCH <- err
+				return
+			}
+			s.handleMessage(stream, msg, &ingestionServerID, &state)
+		}
+	}()
+
+	// Watch for cancellation of the stream or of the manager; nothing is sent from this goroutine
+	sendDone := make(chan struct{})
+	go func() {
+		defer close(sendDone)
+		for {
+			select {
+			case <-stream.Context().Done():
+				return
+			case <-s.ctx.Done():
+				return
+			}
+		}
+	}()
+
+	select {
+	case err := <-receiveErrorCH:
+		// Connection dropped, clean up
+		s.cleanupIngestionServer(ingestionServerID)
+		return err
+	case <-sendDone:
+		// Stream context canceled (client disconnect or server shutdown)
+		// NOTE(review): the receive goroutine may still be running here, so this
+		// read of ingestionServerID is not synchronized with its writes.
+		s.cleanupIngestionServer(ingestionServerID)
+		return nil
+	}
+}
+
+// handleMessage dispatches one message from an ingestor by its MessageType.
+// ingestionServerID and state are filled in by REGISTER and passed by value to
+// the other handlers; unknown types are answered with an ERROR message.
+func (s *IngestionManager) handleMessage(
+	stream common.IngestionManager_ActionServer,
+	msg *common.IngestionMessage,
+	ingestionServerID *string,
+	state **IngestorState,
+) {
+	switch msg.MessageType {
+	case "REGISTER":
+		s.handleRegister(stream, msg, ingestionServerID, state)
+
+	case "HEARTBEAT":
+		s.handleHeartbeat(msg, *ingestionServerID, *state)
+
+	case "TASK_RESULT":
+		s.handleTaskResult(msg, *ingestionServerID, *state)
+
+	case "TASK_PROGRESS":
+		s.handleTaskProgress(msg, *ingestionServerID, *state)
+
+	default:
+		common.Info(fmt.Sprintf("Unknown message type: %s", msg.MessageType))
+		if err := stream.Send(&common.AdminMessage{
+			MessageType:  "ERROR",
+			ErrorMessage: "unknown message type",
+		}); err != nil {
+			common.Error("Fail to send unknown message", err)
+		}
+	}
+}
+
+// handleRegister records a newly connected ingestor under msg.IngestorId,
+// stores its stream and peer address, and acknowledges the registration. It
+// replies with an ERROR message and registers nothing when RegisterInfo or the
+// peer address is missing.
+func (s *IngestionManager) handleRegister(
+	stream common.IngestionManager_ActionServer,
+	msg *common.IngestionMessage,
+	ingestionServerID *string,
+	state **IngestorState,
+) {
+	if msg.RegisterInfo == nil {
+		if err := stream.Send(&common.AdminMessage{
+			MessageType:  "ERROR",
+			ErrorMessage: "missing register info",
+		}); err != nil {
+			common.Error("Fail to send missing register info", err)
+		}
+		return
+	}
+
+	peerHost, ok := peer.FromContext(stream.Context())
+	if !ok {
+		if err := stream.Send(&common.AdminMessage{
+			MessageType:  "ERROR",
+			ErrorMessage: "peer not found in context",
+		}); err != nil {
+			common.Error("Fail to send 'peer not found' message", err)
+		}
+		return
+	}
+	clientAddr := peerHost.Addr.String()
+
+	*ingestionServerID = msg.IngestorId
+	*state = &IngestorState{
+		ID:            msg.IngestorId,
+		Info:          msg.RegisterInfo,
+		LastHeartbeat: time.Now().Truncate(time.Second),
+		Stream:        stream,
+		Status:        "active",
+		Address:       clientAddr,
+	}
+
+	s.mu.Lock()
+	s.ingestionServers[*ingestionServerID] = *state
+	s.mu.Unlock()
+
+	err := stream.Send(&common.AdminMessage{
+		MessageType: "ACK",
+		AckInfo: &common.AckInfo{
+			TaskId:  "",
+			Success: true,
+			Message: "registered successfully",
+		},
+	})
+	if err != nil {
+		common.Error("Fail to send ACK message", err)
+		return
+	}
+
+	common.Info(fmt.Sprintf("Ingestor %s registered, max_concurrency=%d, supported_types=%v",
+		*ingestionServerID, msg.RegisterInfo.MaxConcurrency, msg.RegisterInfo.SupportedDocTypes))
+}
+
+// handleHeartbeat refreshes the ingestor's liveness timestamp and, when the
+// message carries HeartbeatInfo, its reported resource usage and task states.
+// Heartbeats that arrive before REGISTER (state == nil) are ignored.
+func (s *IngestionManager) handleHeartbeat(msg *common.IngestionMessage, ingestorID string, state *IngestorState) {
+	if state == nil {
+		return
+	}
+
+	lastUpdateTime := time.Now().Truncate(time.Second)
+
+	// state is the same object that handleRegister stored in ingestionServers,
+	// so it is only written while holding the lock.
+	s.mu.Lock()
+	state.LastHeartbeat = lastUpdateTime
+	if msg.HeartbeatInfo == nil {
+		s.mu.Unlock()
+		return
+	}
+
+	ingestorState := s.ingestionServers[msg.IngestorId]
+	if ingestorState == nil {
+		// msg.IngestorId is not (or no longer) registered; using the missing
+		// entry would dereference a nil pointer.
+		s.mu.Unlock()
+		common.Info(fmt.Sprintf("Heartbeat for unregistered ingestor %s ignored", msg.IngestorId))
+		return
+	}
+	ingestorState.LastHeartbeat = lastUpdateTime
+	if ingestorState.Status == "timeout" {
+		ingestorState.Status = "active"
+		common.Info(fmt.Sprintf("Ingestor %s recovered from timeout, status set to active", msg.IngestorId))
+	}
+	ingestorState.ProcessID = msg.HeartbeatInfo.ProcessId
+	ingestorState.cpuUsage = float64(msg.HeartbeatInfo.CpuUsage)
+	ingestorState.vmsUsage = float64(msg.HeartbeatInfo.VmsUsage) / 1024 / 1024 // in MB
+	ingestorState.rssUsage = float64(msg.HeartbeatInfo.RssUsage) / 1024 / 1024 // in MB
+
+	// Delete expired terminal tasks from taskStates
+	for _, taskID := range msg.HeartbeatInfo.DeleteTaskIds {
+		delete(s.taskStates, taskID)
+	}
+
+	for _, ingestorTaskState := range msg.HeartbeatInfo.TaskStates {
+		localTaskState := s.taskStates[ingestorTaskState.TaskId]
+		if localTaskState == nil {
+			startTime := time.Unix(0, ingestorTaskState.StartTime)
+			localTaskState = &TaskState{
+				taskID:    ingestorTaskState.TaskId,
+				comeFrom:  ingestorTaskState.ComeFrom,
+				startTime: &startTime,
+			}
+			// Keep the newly seen task; otherwise the state built here is
+			// dropped at the end of the iteration.
+			s.taskStates[ingestorTaskState.TaskId] = localTaskState
+		}
+		// NOTE(review): this stores the raw seconds count as a time.Duration
+		// (i.e. nanoseconds); confirm how readers interpret the field.
+		localTaskState.estimatedRemainingTime = time.Duration(ingestorTaskState.EstimatedRemainingTimeSeconds)
+		localTaskState.lastUpdate = lastUpdateTime
+		localTaskState.status = ingestorTaskState.Status
+		localTaskState.errorMessage = ingestorTaskState.ErrorMessage
+		localTaskState.assignTo = msg.IngestorId
+	}
+	s.mu.Unlock()
+
+	common.Debug(fmt.Sprintf("Heartbeat from %s", ingestorID))
+}
+
+// handleTaskResult logs a task result reported by an ingestor and makes a
+// non-blocking send on slotFreed.
+func (s *IngestionManager) handleTaskResult(msg *common.IngestionMessage, ingestorID string, state *IngestorState) {
+	if msg.TaskResult == nil {
+		return
+	}
+
+	result := msg.TaskResult
+	common.Info(fmt.Sprintf("Task result from %s: task=%s, status=%s, message=%s", ingestorID, result.TaskId, result.Status, result.ErrorMessage))
+
+	// Signal that a slot may have freed up for pending tasks
+	select {
+	case s.slotFreed <- struct{}{}:
+	default:
+	}
+}
+
+// handleTaskProgress logs a progress update reported by an ingestor.
+func (s *IngestionManager) handleTaskProgress(msg *common.IngestionMessage, ingestorID string, state *IngestorState) {
+	if msg.TaskProgress == nil {
+		return
+	}
+
+	progress := msg.TaskProgress
+	common.Info(fmt.Sprintf("Task progress from %s: task=%s, progress=%d%%, detail=%s",
+		ingestorID, progress.TaskId, progress.Progress, progress.Info))
+}
+
+// SubmitTask is for API Server to call (non-gRPC, for testing only).
+// It enqueues the task for dispatchLoop and blocks while taskQueue is full.
+func (s *IngestionManager) SubmitTask(task *common.TaskAssignment) {
+	s.taskQueue <- &pendingTask{
+		Task:      task,
+		CreatedAt: time.Now().Truncate(time.Second),
+	}
+	common.Info(fmt.Sprintf("Task %s submitted to queue", task.TaskId))
+
+	// Wake up dispatchLoop if it's blocked waiting for a slot
+	select {
+	case s.slotFreed <- struct{}{}:
+	default:
+	}
+}
+
+// dispatchLoop pulls tasks from the queue and assigns them to available ingestors.
+// Runs in a background goroutine.
+func (s *IngestionManager) dispatchLoop() {
+	for {
+		select {
+		case <-s.ctx.Done():
+			return
+		case pending := <-s.taskQueue:
+			go s.tryAssign(pending.Task)
+		}
+	}
+}
+
+// heartbeatCheckLoop periodically checks all registered ingestors for heartbeat timeout.
+// If an ingestor's LastHeartbeat is older than heartbeatTimeout, its status is set to "timeout".
+func (s *IngestionManager) heartbeatCheckLoop() {
+	ticker := time.NewTicker(heartbeatTimeout / 3)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-s.ctx.Done():
+			return
+		case <-ticker.C:
+			s.checkHeartbeats()
+		}
+	}
+}
+
+// checkHeartbeats marks every registered ingestor whose last heartbeat is
+// older than heartbeatTimeout as "timeout". The transition is logged once.
+func (s *IngestionManager) checkHeartbeats() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	now := time.Now().Truncate(time.Second)
+	for id, state := range s.ingestionServers {
+		if now.Sub(state.LastHeartbeat) > heartbeatTimeout && state.Status != "timeout" {
+			state.Status = "timeout"
+			common.Info(fmt.Sprintf("Ingestor %s heartbeat timeout, marked as timeout", id))
+		}
+	}
+}
+
+// SelectIngestorForTask picks the ingestor that should receive task, or
+// returns nil when no ingestor applies. The choice depends on task.TaskType:
+//
+//   - "start_ingestion_task": the first ingestor found with status "active";
+//     the task is recorded as "DISPATCHED" and assigned to it.
+//   - "cancel_ingestion_task": the ingestor the task is assigned to, unless the
+//     task is unknown or already "COMPLETED"; a "DISPATCHED" task is moved to
+//     "CANCELING".
+//   - "shutdown_ingestor": the ingestor registered under task.AssignedTo.
+func (s *IngestionManager) SelectIngestorForTask(task *common.TaskAssignment) *IngestorState {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	switch task.TaskType {
+	case "start_ingestion_task":
+		// Record where the task really came from, as tryAssign does; "CLI"
+		// remains the label when the task does not say.
+		origin := task.ComeFrom
+		if origin == "" {
+			origin = "CLI"
+		}
+		for _, ingestor := range s.ingestionServers {
+			if ingestor.Status != "active" {
+				continue
+			}
+			s.taskStates[task.TaskId] = &TaskState{
+				taskID:     task.TaskId,
+				status:     "DISPATCHED",
+				comeFrom:   origin,
+				startTime:  nil,
+				lastUpdate: time.Now().Truncate(time.Second),
+				assignTo:   ingestor.ID,
+			}
+			return ingestor
+		}
+
+	case "cancel_ingestion_task":
+		taskState := s.taskStates[task.TaskId]
+		if taskState == nil {
+			break
+		}
+		switch taskState.status {
+		case "COMPLETED":
+			return nil
+		case "DISPATCHED":
+			taskState.status = "CANCELING"
+			return s.ingestionServers[taskState.assignTo]
+		default:
+			return s.ingestionServers[taskState.assignTo]
+		}
+
+	case "shutdown_ingestor":
+		return s.ingestionServers[task.AssignedTo]
+	}
+
+	return nil
+}
+
+// tryAssign repeatedly tries to find an available ingestor and assign the task.
+// Blocks until either the task is assigned or the context is canceled.
+func (s *IngestionManager) tryAssign(task *common.TaskAssignment) {
+	for {
+
+		target := s.SelectIngestorForTask(task)
+		if target != nil {
+			task.AssignedTo = target.ID
+			s.assignToIngestor(task, target)
+			return
+		}
+
+		if task.TaskType == "start_ingestion_task" {
+			// Receives a start ingestion task, save and change the states
+			s.mu.Lock()
+			s.taskStates[task.TaskId] = &TaskState{
+				taskID:     task.TaskId,
+				status:     "pending",
+				comeFrom:   task.ComeFrom,
+				lastUpdate: time.Now().Truncate(time.Second),
+				startTime:  nil,
+			}
+			s.mu.Unlock()
+		} else {
+			// shutdown ingestor or cancel task
+			common.Info("Task is completed, canceled, or ingestor is shutdown")
+			return
+		}
+
+		// No ingestor available, wait for a slot to free up
+		select {
+		case <-s.ctx.Done():
+			return
+		case <-s.slotFreed:
+			// A slot might be free, retry
+		case <-time.After(2 * time.Second):
+			// Periodic retry as fallback
+		}
+	}
+}
+
+// assignToIngestor sends task to the ingestor behind state as a
+// TASK_ASSIGNMENT message. When the send fails the task is put back on the
+// queue so it can be dispatched again.
+func (s *IngestionManager) assignToIngestor(task *common.TaskAssignment, state *IngestorState) {
+	err := state.Stream.Send(&common.AdminMessage{
+		MessageType:    "TASK_ASSIGNMENT",
+		TaskAssignment: task,
+	})
+	if err != nil {
+		common.Info(fmt.Sprintf("Failed to assign task %s to ingestor %s: %v", task.TaskId, state.ID, err))
+		// Re-queue the task, but never wait past shutdown: once the manager
+		// context is canceled nothing drains the queue any more.
+		select {
+		case s.taskQueue <- &pendingTask{Task: task, CreatedAt: time.Now().Truncate(time.Second)}:
+		case <-s.ctx.Done():
+			common.Info(fmt.Sprintf("Task %s not re-queued: ingestion manager is stopping", task.TaskId))
+		}
+		return
+	}
+	common.Info(fmt.Sprintf("Assigned task %s to ingestion_server %s", task.TaskId, state.ID))
+}
+
+// cleanupIngestionServer removes the ingestor registered under ingestorID and
+// every task state assigned to it. An empty ID (stream closed before REGISTER)
+// and an unknown ID are no-ops.
+func (s *IngestionManager) cleanupIngestionServer(ingestorID string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if ingestorID == "" {
+		// Client disconnected before REGISTER completed — nothing to clean up
+		common.Info("Unregistered ingestion server disconnected")
+		return
+	}
+
+	if _, exists := s.ingestionServers[ingestorID]; exists {
+		delete(s.ingestionServers, ingestorID)
+		common.Info(fmt.Sprintf("Ingestor %s cleaned up", ingestorID))
+
+		// Clean the tasks handled by this ingestor
+		var tasksToDelete []string
+		for _, taskState := range s.taskStates {
+			if taskState.assignTo == ingestorID {
+				tasksToDelete = append(tasksToDelete, taskState.taskID)
+			}
+		}
+		for _, taskID := range tasksToDelete {
+			delete(s.taskStates, taskID)
+		}
+	}
+}
+
+// ListIngestors returns one map per registered ingestor with its identity,
+// address, liveness, resource usage and the number of task states assigned to
+// it. The returned error is always nil.
+func (s *IngestionManager) ListIngestors() ([]map[string]interface{}, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	var result []map[string]interface{}
+	for ingestorID, state := range s.ingestionServers {
+
+		var taskCount int64
+		for _, task := range s.taskStates {
+			if task.assignTo == ingestorID {
+				taskCount++
+			}
+		}
+
+		result = append(result, map[string]interface{}{
+			"id":             ingestorID,
+			"name":           state.Info.Name,
+			"address":        state.Address,
+			"last_heartbeat": state.LastHeartbeat,
+			"task_count":     taskCount,
+			"status":         state.Status,
+			"cpu_usage":      state.cpuUsage,
+			"rss_usage":      state.rssUsage,
+			"vms_usage":      state.vmsUsage,
+			"process_id":     state.ProcessID,
+		})
+	}
+	return result, nil
+}
+
+// ListIngestionTasks returns one map per known task state. The returned error
+// is always nil.
+func (s *IngestionManager) ListIngestionTasks() ([]map[string]interface{}, error) {
+
+	var result []map[string]interface{}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	for index, taskState := range s.taskStates {
+		common.Info(fmt.Sprintf("Task %s: %s", index, taskState.taskID))
+		result = append(result, map[string]interface{}{
+			"id":          taskState.taskID,
+			"status":      taskState.status,
+			"from":        taskState.comeFrom,
+			"assign_to":   taskState.assignTo,
+			"last_update": taskState.lastUpdate,
+			"start_time":  taskState.startTime,
+			"ETA":         taskState.estimatedRemainingTime,
+			"error":       taskState.errorMessage,
+		})
+	}
+
+	return result, nil
+}
+
+// Start listens on the given TCP address (passed as port, e.g. ":9000"),
+// registers the manager as the gRPC service and serves until the server is
+// stopped. It returns the listen error or whatever Serve returns.
+func (s *IngestionManager) Start(port string) error {
+	lis, err := net.Listen("tcp", port)
+	if err != nil {
+		return err
+	}
+
+	s.grpcServer = grpc.NewServer()
+	common.RegisterIngestionManagerServer(s.grpcServer, s)
+
+	return s.grpcServer.Serve(lis)
+}
+
+// Stop gracefully shuts down the admin service
+func (s *IngestionManager) Stop() {
+	common.Info("Stopping RAGFlow ingestion manager...")
+
+	// Notify all goroutines to exit
+	s.cancel()
+
+	// Gracefully stop gRPC server (stop accepting new connections, wait for in-flight requests)
+	if s.grpcServer != nil {
+		s.grpcServer.GracefulStop()
+	}
+
+	// The task queue is deliberately left open. SubmitTask and the re-queue in
+	// assignToIngestor can still send on it, and a send on a closed channel
+	// panics; canceling the context above is what stops the consumers.
+
+	common.Info("RAGFlow ingestion manager stopped")
+}
diff --git a/internal/admin/router.go b/internal/admin/router.go index 03aa3300b6..dde8d93844 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -136,6 +136,13 @@ func (r *Router) Setup(engine *gin.Engine) { provider.GET("/:provider_name/models", r.handler.ListModels) provider.GET("/:provider_name/models/:model_name", r.handler.ShowModel) } + + protected.GET("/ingestors", r.handler.ListIngestors) + protected.DELETE("/ingestors", r.handler.ShutdownIngestor) + protected.POST("/ingestion", r.handler.StartIngestionTask) // start ingestion + protected.DELETE("/ingestion", r.handler.StopIngestionTask) // stop ingestion + protected.GET("/ingestion/tasks", r.handler.ListIngestionTasks) + } } diff --git a/internal/admin/service.go b/internal/admin/service.go index 4b30b4f26c..d496ffefc2 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -41,13 +41,6 @@ import ( "go.uber.org/zap" ) -// Service errors -var ( - ErrInvalidToken = errors.New("invalid token") - ErrNotAdmin = errors.New("user is not admin") - ErrUserInactive = errors.New("user is inactive") -) - // Service admin service layer type Service struct { userDAO *dao.UserDAO @@ -105,32 +98,38 @@ func (s *Service) Logout(user interface{}) error { // ListTasks func (s *Service) ListTasks() ([]map[string]interface{}, error) { - tasks, err := s.taskDAO.GetAllTasks() + //tasks, err := s.taskDAO.GetAllTasks() + //if err != nil { + // return nil, err + //} + // + //var result []map[string]interface{} + //for _, task := range tasks { + // // task.ChunkIDs is a string, delimiter is space, count the word count + // ChunkCount := 
strings.Count(*task.ChunkIDs, " ") + // result = append(result, map[string]interface{}{ + // "id": task.ID, + // "task_type": task.TaskType, + // "document_id": task.DocID, + // "chunk_count": ChunkCount, + // "from_page": task.FromPage, + // "to_page": task.ToPage, + // "priority": task.Priority, + // "duration": task.ProcessDuration, + // "progress": task.Progress, + // //"message": *task.ProgressMsg, + // "retry_count": task.RetryCount, + // "digest": task.Digest, + // }) + //} + + ingestionMgr := GetIngestionManager() + ingestionTasks, err := ingestionMgr.ListIngestionTasks() if err != nil { - return nil, err + return nil, fmt.Errorf("fail to list ingestion tasks") } - var result []map[string]interface{} - for _, task := range tasks { - // task.ChunkIDs is a string, delimiter is space, count the word count - ChunkCount := strings.Count(*task.ChunkIDs, " ") - result = append(result, map[string]interface{}{ - "id": task.ID, - "task_type": task.TaskType, - "document_id": task.DocID, - "chunk_count": ChunkCount, - "from_page": task.FromPage, - "to_page": task.ToPage, - "priority": task.Priority, - "duration": task.ProcessDuration, - "progress": task.Progress, - //"message": *task.ProgressMsg, - "retry_count": task.RetryCount, - "digest": task.Digest, - }) - } - - return result, nil + return ingestionTasks, nil } // GetUserByToken get user by access token @@ -1055,7 +1054,7 @@ func (s *Service) ListServices() ([]map[string]interface{}, error) { } id := len(result) - serverList := GlobalServerStatusStore.GetAllStatuses() + serverList := GlobalServerStore.ListInfos() for _, serverStatus := range serverList { serverItem := make(map[string]interface{}) serverItem["name"] = serverStatus.ServerName @@ -1698,10 +1697,15 @@ func (s *Service) HandleHeartbeat(message *common.BaseMessage) (common.ErrorCode Timestamp: message.Timestamp, Ext: message.Ext, } - GlobalServerStatusStore.UpdateStatus(message.ServerName, status) + GlobalServerStore.UpdateServerInfo(message.ServerName, 
status) return common.CodeLicenseValid, "" } +func (s *Service) ListIngestionTasks() ([]map[string]interface{}, error) { + // TODO: Implement with sandbox manager + return []map[string]interface{}{}, nil +} + // InitDefaultAdmin initialize default admin user // This matches Python's init_default_admin behavior func (s *Service) InitDefaultAdmin() error { diff --git a/internal/admin/state.go b/internal/admin/state.go new file mode 100644 index 0000000000..5475bf6467 --- /dev/null +++ b/internal/admin/state.go @@ -0,0 +1,114 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package admin + +import ( + "errors" + "ragflow/internal/common" + "sync" + "time" +) + +// Service errors +var ( + ErrInvalidToken = errors.New("invalid token") + ErrNotAdmin = errors.New("user is not admin") + ErrUserInactive = errors.New("user is inactive") + ErrUserNotFound = errors.New("user not found") +) + +// API server state + +// ServerStore is a thread-safe global server status storage +type ServerStore struct { + mu sync.RWMutex + servers map[string]*common.BaseMessage // key: server_id +} + +// GlobalServerStore is the global instance +var GlobalServerStore = &ServerStore{ + servers: make(map[string]*common.BaseMessage), +} + +// UpdateServerInfo updates or adds a server status +func (s *ServerStore) UpdateServerInfo(serverName string, status *common.BaseMessage) { + + //switch serviceType { + //case "meta_data": + // return s.getMySQLStatus(name) + + switch status.ServerType { + case common.ServerTypeAPI: + s.mu.Lock() + defer s.mu.Unlock() + s.servers[serverName] = status + return + case common.ServerTypeIngestion: + return + } +} + +// GetServerInfo gets a single server status +func (s *ServerStore) GetServerInfo(serverName string) (*common.BaseMessage, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + status, ok := s.servers[serverName] + return status, ok +} + +// ListInfos gets all server infos +func (s *ServerStore) ListInfos() []*common.BaseMessage { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]*common.BaseMessage, 0, len(s.servers)) + for _, status := range s.servers { + result = append(result, status) + } + return result +} + +// ListInfosByType gets server infos by type +func (s *ServerStore) ListInfosByType(serverType common.ServerType) []*common.BaseMessage { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]*common.BaseMessage, 0) + for _, status := range s.servers { + if status.ServerType == serverType { + result = append(result, status) + } + } + return result +} + +// RemoveStatus removes a server status +func 
(s *ServerStore) RemoveStatus(serverID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.servers, serverID) +} + +// CleanupStaleStatuses cleans up servers that haven't reported for a specified duration +func (s *ServerStore) CleanupStaleStatuses(maxAge time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + now := time.Now() + for id, status := range s.servers { + if now.Sub(status.Timestamp) > maxAge { + delete(s.servers, id) + } + } +} diff --git a/internal/cli/admin_command.go b/internal/cli/admin_command.go index d1c3763601..cf36daf129 100644 --- a/internal/cli/admin_command.go +++ b/internal/cli/admin_command.go @@ -1281,3 +1281,162 @@ func (c *RAGFlowClient) ListAdminTasks(cmd *Command) (ResponseIf, error) { result.Duration = resp.Duration return &result, nil } + +func (c *RAGFlowClient) ListAdminIngestors(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + resp, err := c.HTTPClient.Request("GET", "/admin/ingestors", "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list ingestors: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list ingestors: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list ingestors failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} +func (c *RAGFlowClient) ListAdminIngestionTasks(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + resp, err := c.HTTPClient.Request("GET", "/admin/ingestion/tasks", "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list admin tasks: %w", err) + } + + if resp.StatusCode != 200 { + return 
nil, fmt.Errorf("failed to list admin tasks: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list admin tasks failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +}
+
+// AdminStartIngestionCommand asks the admin server to start ingesting the file
+// named by the "uri" parameter (POST /admin/ingestion). It is only allowed in
+// ADMIN mode and fails when the parameter is missing, the request fails, the
+// status is not 200, the body is not valid JSON or the response code is
+// non-zero.
+func (c *RAGFlowClient) AdminStartIngestionCommand(cmd *Command) (ResponseIf, error) {
+	if c.ServerType != "admin" {
+		return nil, fmt.Errorf("this command is only allowed in ADMIN mode")
+	}
+
+	fileURI, ok := cmd.Params["uri"].(string)
+	if !ok {
+		return nil, fmt.Errorf("uri not provided")
+	}
+	payload := map[string]interface{}{
+		"uri":  fileURI,
+		"from": "CLI",
+	}
+
+	resp, err := c.HTTPClient.Request("POST", "/admin/ingestion", "admin", nil, payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to ingest file: %w", err)
+	}
+
+	if resp.StatusCode != 200 {
+		return nil, fmt.Errorf("failed to ingest file: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
+	}
+
+	var result CommonDataResponse
+	if err = json.Unmarshal(resp.Body, &result); err != nil {
+		return nil, fmt.Errorf("ingest file failed: invalid JSON (%w)", err)
+	}
+
+	if result.Code != 0 {
+		return nil, fmt.Errorf("%s", result.Message)
+	}
+
+	result.Duration = resp.Duration
+	return &result, nil
+}
+
+// AdminStopIngestionCommand asks the admin server to stop the ingestion task
+// named by the "task_id" parameter (DELETE /admin/ingestion). It is only
+// allowed in ADMIN mode and fails when the parameter is missing, the request
+// fails, the status is not 200, the body is not valid JSON or the response
+// code is non-zero.
+func (c *RAGFlowClient) AdminStopIngestionCommand(cmd *Command) (ResponseIf, error) {
+	if c.ServerType != "admin" {
+		return nil, fmt.Errorf("this command is only allowed in ADMIN mode")
+	}
+
+	taskID, ok := cmd.Params["task_id"].(string)
+	if !ok {
+		// Name the parameter that is actually missing.
+		return nil, fmt.Errorf("task_id not provided")
+	}
+	payload := map[string]interface{}{
+		"task_id": taskID,
+		"from":    "CLI",
+	}
+
+	resp, err := c.HTTPClient.Request("DELETE", "/admin/ingestion", "admin", nil, payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to stop ingestion: %w", err)
+	}
+
+	if resp.StatusCode != 200 {
+		return nil, fmt.Errorf("failed to stop ingestion: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
+	}
+
+	var result CommonDataResponse
+	if err = json.Unmarshal(resp.Body, &result); err != nil {
+		return nil, fmt.Errorf("stop ingestion failed: invalid JSON (%w)", err)
+	}
+
+	if result.Code != 0 {
+		return nil, fmt.Errorf("%s", result.Message)
+	}
+
+	result.Duration = resp.Duration
+	return &result, nil
+}
+
+// AdminShutdownIngestor asks the admin server to shut down the ingestor named
+// by the "ingestor_name" parameter (DELETE /admin/ingestors). It is only
+// allowed in ADMIN mode and fails when the parameter is missing, the request
+// fails, the status is not 200, the body is not valid JSON or the response
+// code is non-zero.
+func (c *RAGFlowClient) AdminShutdownIngestor(cmd *Command) (ResponseIf, error) {
+	if c.ServerType != "admin" {
+		return nil, fmt.Errorf("this command is only allowed in ADMIN mode")
+	}
+
+	ingestorName, ok := cmd.Params["ingestor_name"].(string)
+	if !ok {
+		return nil, fmt.Errorf("ingestor_name not provided")
+	}
+	payload := map[string]interface{}{
+		"ingestor_name": ingestorName,
+	}
+
+	resp, err := c.HTTPClient.Request("DELETE", "/admin/ingestors", "admin", nil, payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to shutdown ingestor: %w", err)
+	}
+
+	if resp.StatusCode != 200 {
+		return nil, fmt.Errorf("failed to shutdown ingestor: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
+	}
+
+	var result CommonDataResponse
+	if err = json.Unmarshal(resp.Body, &result); err != nil {
+		return nil, fmt.Errorf("shutdown ingestor failed: invalid JSON (%w)", err)
+	}
+
+	if result.Code != 0 {
+		return nil, fmt.Errorf("%s", result.Message)
+	}
+
+	result.Duration = resp.Duration
+	return &result, nil
+}
diff --git a/internal/cli/admin_parser.go b/internal/cli/admin_parser.go index 9f4c6228e8..ada5843b9d 100644 --- a/internal/cli/admin_parser.go +++ b/internal/cli/admin_parser.go @@ -192,6 +192,10 @@ func (p *Parser) parseAdminListCommand() (*Command, error) { return p.parseAdminListFiles() case TokenTasks: return p.parseAdminListTasks() + case TokenIngestors: + return p.parseAdminListIngestors() + case TokenIngestion: + return p.parseAdminListIngestionTasks() default: return nil, fmt.Errorf("unknown LIST target: %s", p.curToken.Value) } @@ 
-376,6 +380,24 @@ func (p *Parser) parseAdminListTasks() (*Command, error) { return cmd, nil } +func (p *Parser) parseAdminListIngestors() (*Command, error) { + p.nextToken() // consume TASKS + cmd := NewCommand("admin_list_ingestors") + + return cmd, nil +} + +func (p *Parser) parseAdminListIngestionTasks() (*Command, error) { + p.nextToken() // consume Ingestion + + if p.curToken.Type != TokenTasks { + return nil, fmt.Errorf("expected TASKS") + } + + cmd := NewCommand("list_admin_ingestion_tasks") + return cmd, nil +} + func (p *Parser) parseAdminShowCommand() (*Command, error) { p.nextToken() // consume SHOW @@ -1543,10 +1565,19 @@ func (p *Parser) parseAdminStartupCommand() (*Command, error) { func (p *Parser) parseAdminShutdownCommand() (*Command, error) { p.nextToken() // consume SHUTDOWN - if p.curToken.Type != TokenService { - return nil, fmt.Errorf("expected SERVICE") + + switch p.curToken.Type { + case TokenService: + return p.parseAdminShutdownServiceCommand() + case TokenIngestor: + return p.parseAdminShutdownIngestorCommand() + default: + return nil, fmt.Errorf("expected SERVICE or INGESTOR") } - p.nextToken() +} + +func (p *Parser) parseAdminShutdownServiceCommand() (*Command, error) { + p.nextToken() // consume SERVICE serviceNum, err := p.parseNumber() if err != nil { @@ -1564,6 +1595,25 @@ func (p *Parser) parseAdminShutdownCommand() (*Command, error) { return cmd, nil } +func (p *Parser) parseAdminShutdownIngestorCommand() (*Command, error) { + p.nextToken() // consume INGESTOR + + ingestorName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("admin_shutdown_ingestor_command") + cmd.Params["ingestor_name"] = ingestorName + + p.nextToken() + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + func (p *Parser) parseAdminRestartCommand() (*Command, error) { p.nextToken() // consume RESTART if p.curToken.Type != TokenService { @@ -1587,6 
+1637,73 @@ func (p *Parser) parseAdminRestartCommand() (*Command, error) { return cmd, nil } +func (p *Parser) parseStartIngestion() (*Command, error) { + p.nextToken() // consume Start + + if p.curToken.Type != TokenIngestion { + return nil, fmt.Errorf("expect INGESTION") + } + p.nextToken() // consume Ingest + + uri, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("admin_start_ingestion_command") + cmd.Params["uri"] = uri + p.nextToken() + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + +func (p *Parser) parseStopIngestion() (*Command, error) { + p.nextToken() // consume Stop + + if p.curToken.Type != TokenIngestion { + return nil, fmt.Errorf("expect INGESTION") + } + p.nextToken() // consume Ingest + + taskID, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("admin_stop_ingestion_command") + cmd.Params["task_id"] = taskID + p.nextToken() + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + +func (p *Parser) parseAdminIngestCommand() (*Command, error) { + p.nextToken() // consume Ingest + + uri, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("admin_ingest_command") + cmd.Params["uri"] = uri + p.nextToken() + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + func (p *Parser) parseAdminUnsetCommand() (*Command, error) { p.nextToken() // consume UNSET diff --git a/internal/cli/client.go b/internal/cli/client.go index 53e53bceea..20efd2e65c 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -185,6 +185,16 @@ func (c *RAGFlowClient) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) { return c.ShowModel(cmd) case "list_admin_tasks": return c.ListAdminTasks(cmd) + case "admin_list_ingestors": + 
return c.ListAdminIngestors(cmd) + case "admin_start_ingestion_command": + return c.AdminStartIngestionCommand(cmd) + case "admin_stop_ingestion_command": + return c.AdminStopIngestionCommand(cmd) + case "admin_shutdown_ingestor_command": + return c.AdminShutdownIngestor(cmd) + case "list_admin_ingestion_tasks": + return c.ListAdminIngestionTasks(cmd) // TODO: Implement other commands default: return nil, fmt.Errorf("command '%s' would be executed with API", cmd.Type) @@ -324,6 +334,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.RemoveChunks(cmd) case "list_metadata": return c.ListMetadata(cmd) + case "parse_documents_user_command": + return c.ParseDocumentsUserCommand(cmd) // ContextEngine commands case "ce_ls": return c.CEList(cmd) diff --git a/internal/cli/lexer.go b/internal/cli/lexer.go index 48e283a315..8676d8dbf3 100644 --- a/internal/cli/lexer.go +++ b/internal/cli/lexer.go @@ -447,6 +447,16 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenTask, Value: ident} case "TASKS": return Token{Type: TokenTasks, Value: ident} + case "START": + return Token{Type: TokenStart, Value: ident} + case "STOP": + return Token{Type: TokenStop, Value: ident} + case "INGESTOR": + return Token{Type: TokenIngestor, Value: ident} + case "INGESTORS": + return Token{Type: TokenIngestors, Value: ident} + case "INGESTION": + return Token{Type: TokenIngestion, Value: ident} case "LOG": return Token{Type: TokenLog, Value: ident} case "LEVEL": diff --git a/internal/cli/parser.go b/internal/cli/parser.go index a23a9f3589..214a1fa963 100644 --- a/internal/cli/parser.go +++ b/internal/cli/parser.go @@ -129,6 +129,10 @@ func (p *Parser) parseAdminCommand() (*Command, error) { return p.parseAdminShutdownCommand() case TokenRestart: return p.parseAdminRestartCommand() + case TokenStart: + return p.parseStartIngestion() + case TokenStop: + return p.parseStopIngestion() default: return nil, fmt.Errorf("unknown command: 
%s", p.curToken.Value) } @@ -259,7 +263,7 @@ func (p *Parser) expectSemicolon() error { } func isKeyword(tokenType int) bool { - return tokenType >= TokenLogin && tokenType <= TokenTag + return tokenType >= TokenLogin && tokenType <= TokenPanic } // isCECommand checks if the given string is a Filesystem command diff --git a/internal/cli/types.go b/internal/cli/types.go index df52d30d22..9e9f123ab9 100644 --- a/internal/cli/types.go +++ b/internal/cli/types.go @@ -157,6 +157,11 @@ const ( TokenURL TokenTask TokenTasks + TokenIngestor + TokenIngestors + TokenStart + TokenStop + TokenIngestion TokenLog TokenLevel TokenDebug diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index a20d99b93a..67618a5437 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -3232,6 +3232,54 @@ func (c *RAGFlowClient) RemoveChunks(cmd *Command) (ResponseIf, error) { return &result, nil } +func (c *RAGFlowClient) ParseDocumentsUserCommand(cmd *Command) (ResponseIf, error) { + if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("API token not set. 
Please login first") + } + + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + datasetID, ok := cmd.Params["dataset_id"].(string) + if !ok { + return nil, fmt.Errorf("dataset_id not provided") + } + + documents, ok := cmd.Params["documents"].([]string) + if !ok { + return nil, fmt.Errorf("documents not provided") + } + + url := fmt.Sprintf("/datasets/%s/documents/parse", datasetID) + + payload := map[string]interface{}{ + "documents": documents, + } + + // Normal mode + resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to list documents: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list documents: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list documents failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + + return &result, nil +} + // formatRequestError Uniformly handle and format network errors in HTTP requests func formatRequestError(action string, err error) error { if err == nil { diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index 307628ddf1..6b481af660 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -3159,8 +3159,10 @@ func (p *Parser) parseParseCommand() (*Command, error) { return p.parseParseDataset() case TokenWith: return p.parseModelParseCommand() - default: + case TokenDocument: return p.parseParseDocs() + default: + return nil, fmt.Errorf("expected DATASET, WITH, or DOCUMENT") } } @@ -3194,31 +3196,32 @@ func (p *Parser) parseParseDataset() (*Command, error) { } func (p *Parser) parseParseDocs() (*Command, error) { - documentNames, err := p.parseQuotedString() + p.nextToken() // consume document + + documentsStr, 
err := p.parseQuotedString() if err != nil { return nil, err } p.nextToken() - if p.curToken.Type != TokenOf { - return nil, fmt.Errorf("expected OF") - } - p.nextToken() - if p.curToken.Type != TokenDataset { - return nil, fmt.Errorf("expected DATASET") + if p.curToken.Type != TokenFrom { + return nil, fmt.Errorf("expected FROM") } p.nextToken() - datasetName, err := p.parseQuotedString() + datasetID, err := p.parseQuotedString() if err != nil { return nil, err } - - cmd := NewCommand("parse_dataset_docs") - cmd.Params["document_names"] = documentNames - cmd.Params["dataset_name"] = datasetName - p.nextToken() + + cmd := NewCommand("parse_documents_user_command") + + documents := strings.Split(documentsStr, " ") + + cmd.Params["documents"] = documents + cmd.Params["dataset_id"] = datasetID + // Semicolon is optional for UNSET TOKEN if p.curToken.Type == TokenSemicolon { p.nextToken() diff --git a/internal/common/idenfication.go b/internal/common/idenfication.go new file mode 100644 index 0000000000..7306bbccb2 --- /dev/null +++ b/internal/common/idenfication.go @@ -0,0 +1,35 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package common + +func Deduplicate(names []string) []string { + if names == nil { + return nil + } + + seen := make(map[string]bool) + result := make([]string, 0, len(names)) + + for _, name := range names { + if !seen[name] { + seen[name] = true + result = append(result, name) + } + } + + return result +} diff --git a/internal/common/ingestion.pb.go b/internal/common/ingestion.pb.go new file mode 100644 index 0000000000..a2211b2a64 --- /dev/null +++ b/internal/common/ingestion.pb.go @@ -0,0 +1,795 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: internal/proto/ingestion.proto + +package common + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. 
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type IngestionMessage struct { + state protoimpl.MessageState `protogen:"open.v1"` + IngestorId string `protobuf:"bytes,1,opt,name=ingestor_id,json=ingestorId,proto3" json:"ingestor_id,omitempty"` + MessageType string `protobuf:"bytes,2,opt,name=message_type,json=messageType,proto3" json:"message_type,omitempty"` // REGISTER, HEARTBEAT, TASK_RESULT, TASK_PROGRESS, PULL_REQUEST + RegisterInfo *RegisterInfo `protobuf:"bytes,3,opt,name=register_info,json=registerInfo,proto3" json:"register_info,omitempty"` + HeartbeatInfo *HeartbeatInfo `protobuf:"bytes,4,opt,name=heartbeat_info,json=heartbeatInfo,proto3" json:"heartbeat_info,omitempty"` + TaskResult *TaskResult `protobuf:"bytes,5,opt,name=task_result,json=taskResult,proto3" json:"task_result,omitempty"` + TaskProgress *TaskProgress `protobuf:"bytes,6,opt,name=task_progress,json=taskProgress,proto3" json:"task_progress,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *IngestionMessage) Reset() { + *x = IngestionMessage{} + mi := &file_internal_proto_ingestion_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *IngestionMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IngestionMessage) ProtoMessage() {} + +func (x *IngestionMessage) ProtoReflect() protoreflect.Message { + mi := &file_internal_proto_ingestion_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IngestionMessage.ProtoReflect.Descriptor instead. 
+func (*IngestionMessage) Descriptor() ([]byte, []int) { + return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{0} +} + +func (x *IngestionMessage) GetIngestorId() string { + if x != nil { + return x.IngestorId + } + return "" +} + +func (x *IngestionMessage) GetMessageType() string { + if x != nil { + return x.MessageType + } + return "" +} + +func (x *IngestionMessage) GetRegisterInfo() *RegisterInfo { + if x != nil { + return x.RegisterInfo + } + return nil +} + +func (x *IngestionMessage) GetHeartbeatInfo() *HeartbeatInfo { + if x != nil { + return x.HeartbeatInfo + } + return nil +} + +func (x *IngestionMessage) GetTaskResult() *TaskResult { + if x != nil { + return x.TaskResult + } + return nil +} + +func (x *IngestionMessage) GetTaskProgress() *TaskProgress { + if x != nil { + return x.TaskProgress + } + return nil +} + +type AdminMessage struct { + state protoimpl.MessageState `protogen:"open.v1"` + MessageType string `protobuf:"bytes,1,opt,name=message_type,json=messageType,proto3" json:"message_type,omitempty"` // TASK_ASSIGNMENT, ACK, PONG, RECONNECT + TaskAssignment *TaskAssignment `protobuf:"bytes,2,opt,name=task_assignment,json=taskAssignment,proto3" json:"task_assignment,omitempty"` + AckInfo *AckInfo `protobuf:"bytes,3,opt,name=ack_info,json=ackInfo,proto3" json:"ack_info,omitempty"` + ErrorMessage string `protobuf:"bytes,4,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AdminMessage) Reset() { + *x = AdminMessage{} + mi := &file_internal_proto_ingestion_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AdminMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AdminMessage) ProtoMessage() {} + +func (x *AdminMessage) ProtoReflect() protoreflect.Message { + mi := &file_internal_proto_ingestion_proto_msgTypes[1] + if x != nil { + 
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AdminMessage.ProtoReflect.Descriptor instead. +func (*AdminMessage) Descriptor() ([]byte, []int) { + return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{1} +} + +func (x *AdminMessage) GetMessageType() string { + if x != nil { + return x.MessageType + } + return "" +} + +func (x *AdminMessage) GetTaskAssignment() *TaskAssignment { + if x != nil { + return x.TaskAssignment + } + return nil +} + +func (x *AdminMessage) GetAckInfo() *AckInfo { + if x != nil { + return x.AckInfo + } + return nil +} + +func (x *AdminMessage) GetErrorMessage() string { + if x != nil { + return x.ErrorMessage + } + return "" +} + +type RegisterInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + MaxConcurrency int32 `protobuf:"varint,1,opt,name=max_concurrency,json=maxConcurrency,proto3" json:"max_concurrency,omitempty"` + SupportedDocTypes []string `protobuf:"bytes,2,rep,name=supported_doc_types,json=supportedDocTypes,proto3" json:"supported_doc_types,omitempty"` + Version string `protobuf:"bytes,3,opt,name=version,proto3" json:"version,omitempty"` + Name string `protobuf:"bytes,4,opt,name=name,proto3" json:"name,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RegisterInfo) Reset() { + *x = RegisterInfo{} + mi := &file_internal_proto_ingestion_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RegisterInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RegisterInfo) ProtoMessage() {} + +func (x *RegisterInfo) ProtoReflect() protoreflect.Message { + mi := &file_internal_proto_ingestion_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + 
return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RegisterInfo.ProtoReflect.Descriptor instead. +func (*RegisterInfo) Descriptor() ([]byte, []int) { + return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{2} +} + +func (x *RegisterInfo) GetMaxConcurrency() int32 { + if x != nil { + return x.MaxConcurrency + } + return 0 +} + +func (x *RegisterInfo) GetSupportedDocTypes() []string { + if x != nil { + return x.SupportedDocTypes + } + return nil +} + +func (x *RegisterInfo) GetVersion() string { + if x != nil { + return x.Version + } + return "" +} + +func (x *RegisterInfo) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +type HeartbeatInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + TaskStates []*TaskState `protobuf:"bytes,1,rep,name=task_states,json=taskStates,proto3" json:"task_states,omitempty"` + DeleteTaskIds []string `protobuf:"bytes,2,rep,name=delete_task_ids,json=deleteTaskIds,proto3" json:"delete_task_ids,omitempty"` + CpuUsage float32 `protobuf:"fixed32,3,opt,name=cpu_usage,json=cpuUsage,proto3" json:"cpu_usage,omitempty"` // percentage + VmsUsage float32 `protobuf:"fixed32,4,opt,name=vms_usage,json=vmsUsage,proto3" json:"vms_usage,omitempty"` // absolute value + RssUsage float32 `protobuf:"fixed32,5,opt,name=rss_usage,json=rssUsage,proto3" json:"rss_usage,omitempty"` // absolute value + ProcessId int64 `protobuf:"varint,6,opt,name=process_id,json=processId,proto3" json:"process_id,omitempty"` // pid + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HeartbeatInfo) Reset() { + *x = HeartbeatInfo{} + mi := &file_internal_proto_ingestion_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HeartbeatInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HeartbeatInfo) ProtoMessage() {} + +func (x *HeartbeatInfo) ProtoReflect() protoreflect.Message { + mi := 
&file_internal_proto_ingestion_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HeartbeatInfo.ProtoReflect.Descriptor instead. +func (*HeartbeatInfo) Descriptor() ([]byte, []int) { + return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{3} +} + +func (x *HeartbeatInfo) GetTaskStates() []*TaskState { + if x != nil { + return x.TaskStates + } + return nil +} + +func (x *HeartbeatInfo) GetDeleteTaskIds() []string { + if x != nil { + return x.DeleteTaskIds + } + return nil +} + +func (x *HeartbeatInfo) GetCpuUsage() float32 { + if x != nil { + return x.CpuUsage + } + return 0 +} + +func (x *HeartbeatInfo) GetVmsUsage() float32 { + if x != nil { + return x.VmsUsage + } + return 0 +} + +func (x *HeartbeatInfo) GetRssUsage() float32 { + if x != nil { + return x.RssUsage + } + return 0 +} + +func (x *HeartbeatInfo) GetProcessId() int64 { + if x != nil { + return x.ProcessId + } + return 0 +} + +type TaskState struct { + state protoimpl.MessageState `protogen:"open.v1"` + TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` + Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` // PENDING, RUNNING, COMPLETED, FAILED, CANCELLED + ErrorMessage string `protobuf:"bytes,3,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` + EstimatedRemainingTimeSeconds int64 `protobuf:"varint,4,opt,name=estimated_remaining_time_seconds,json=estimatedRemainingTimeSeconds,proto3" json:"estimated_remaining_time_seconds,omitempty"` + ComeFrom string `protobuf:"bytes,5,opt,name=come_from,json=comeFrom,proto3" json:"come_from,omitempty"` + StartTime int64 `protobuf:"varint,6,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + 
+func (x *TaskState) Reset() { + *x = TaskState{} + mi := &file_internal_proto_ingestion_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TaskState) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TaskState) ProtoMessage() {} + +func (x *TaskState) ProtoReflect() protoreflect.Message { + mi := &file_internal_proto_ingestion_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TaskState.ProtoReflect.Descriptor instead. +func (*TaskState) Descriptor() ([]byte, []int) { + return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{4} +} + +func (x *TaskState) GetTaskId() string { + if x != nil { + return x.TaskId + } + return "" +} + +func (x *TaskState) GetStatus() string { + if x != nil { + return x.Status + } + return "" +} + +func (x *TaskState) GetErrorMessage() string { + if x != nil { + return x.ErrorMessage + } + return "" +} + +func (x *TaskState) GetEstimatedRemainingTimeSeconds() int64 { + if x != nil { + return x.EstimatedRemainingTimeSeconds + } + return 0 +} + +func (x *TaskState) GetComeFrom() string { + if x != nil { + return x.ComeFrom + } + return "" +} + +func (x *TaskState) GetStartTime() int64 { + if x != nil { + return x.StartTime + } + return 0 +} + +type TaskAssignment struct { + state protoimpl.MessageState `protogen:"open.v1"` + TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` + TaskType string `protobuf:"bytes,2,opt,name=task_type,json=taskType,proto3" json:"task_type,omitempty"` + Config string `protobuf:"bytes,3,opt,name=config,proto3" json:"config,omitempty"` + ComeFrom string `protobuf:"bytes,4,opt,name=come_from,json=comeFrom,proto3" json:"come_from,omitempty"` + AssignedTo string 
`protobuf:"bytes,5,opt,name=assigned_to,json=assignedTo,proto3" json:"assigned_to,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TaskAssignment) Reset() { + *x = TaskAssignment{} + mi := &file_internal_proto_ingestion_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TaskAssignment) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TaskAssignment) ProtoMessage() {} + +func (x *TaskAssignment) ProtoReflect() protoreflect.Message { + mi := &file_internal_proto_ingestion_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TaskAssignment.ProtoReflect.Descriptor instead. +func (*TaskAssignment) Descriptor() ([]byte, []int) { + return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{5} +} + +func (x *TaskAssignment) GetTaskId() string { + if x != nil { + return x.TaskId + } + return "" +} + +func (x *TaskAssignment) GetTaskType() string { + if x != nil { + return x.TaskType + } + return "" +} + +func (x *TaskAssignment) GetConfig() string { + if x != nil { + return x.Config + } + return "" +} + +func (x *TaskAssignment) GetComeFrom() string { + if x != nil { + return x.ComeFrom + } + return "" +} + +func (x *TaskAssignment) GetAssignedTo() string { + if x != nil { + return x.AssignedTo + } + return "" +} + +type TaskResult struct { + state protoimpl.MessageState `protogen:"open.v1"` + TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` + Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` // COMPLETED, FAILED, CANCELLED + ErrorMessage string `protobuf:"bytes,3,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` + unknownFields protoimpl.UnknownFields + 
sizeCache protoimpl.SizeCache +} + +func (x *TaskResult) Reset() { + *x = TaskResult{} + mi := &file_internal_proto_ingestion_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TaskResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TaskResult) ProtoMessage() {} + +func (x *TaskResult) ProtoReflect() protoreflect.Message { + mi := &file_internal_proto_ingestion_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TaskResult.ProtoReflect.Descriptor instead. +func (*TaskResult) Descriptor() ([]byte, []int) { + return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{6} +} + +func (x *TaskResult) GetTaskId() string { + if x != nil { + return x.TaskId + } + return "" +} + +func (x *TaskResult) GetStatus() string { + if x != nil { + return x.Status + } + return "" +} + +func (x *TaskResult) GetErrorMessage() string { + if x != nil { + return x.ErrorMessage + } + return "" +} + +type TaskProgress struct { + state protoimpl.MessageState `protogen:"open.v1"` + TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` + Progress int32 `protobuf:"varint,2,opt,name=progress,proto3" json:"progress,omitempty"` + Info string `protobuf:"bytes,3,opt,name=info,proto3" json:"info,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TaskProgress) Reset() { + *x = TaskProgress{} + mi := &file_internal_proto_ingestion_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TaskProgress) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TaskProgress) ProtoMessage() {} + +func (x *TaskProgress) ProtoReflect() protoreflect.Message { + mi := 
&file_internal_proto_ingestion_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TaskProgress.ProtoReflect.Descriptor instead. +func (*TaskProgress) Descriptor() ([]byte, []int) { + return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{7} +} + +func (x *TaskProgress) GetTaskId() string { + if x != nil { + return x.TaskId + } + return "" +} + +func (x *TaskProgress) GetProgress() int32 { + if x != nil { + return x.Progress + } + return 0 +} + +func (x *TaskProgress) GetInfo() string { + if x != nil { + return x.Info + } + return "" +} + +type AckInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` + Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"` + Message string `protobuf:"bytes,3,opt,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AckInfo) Reset() { + *x = AckInfo{} + mi := &file_internal_proto_ingestion_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AckInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AckInfo) ProtoMessage() {} + +func (x *AckInfo) ProtoReflect() protoreflect.Message { + mi := &file_internal_proto_ingestion_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AckInfo.ProtoReflect.Descriptor instead. 
+func (*AckInfo) Descriptor() ([]byte, []int) { + return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{8} +} + +func (x *AckInfo) GetTaskId() string { + if x != nil { + return x.TaskId + } + return "" +} + +func (x *AckInfo) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *AckInfo) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +var File_internal_proto_ingestion_proto protoreflect.FileDescriptor + +const file_internal_proto_ingestion_proto_rawDesc = "" + + "\n" + + "\x1einternal/proto/ingestion.proto\x12\x06common\"\xbf\x02\n" + + "\x10IngestionMessage\x12\x1f\n" + + "\vingestor_id\x18\x01 \x01(\tR\n" + + "ingestorId\x12!\n" + + "\fmessage_type\x18\x02 \x01(\tR\vmessageType\x129\n" + + "\rregister_info\x18\x03 \x01(\v2\x14.common.RegisterInfoR\fregisterInfo\x12<\n" + + "\x0eheartbeat_info\x18\x04 \x01(\v2\x15.common.HeartbeatInfoR\rheartbeatInfo\x123\n" + + "\vtask_result\x18\x05 \x01(\v2\x12.common.TaskResultR\n" + + "taskResult\x129\n" + + "\rtask_progress\x18\x06 \x01(\v2\x14.common.TaskProgressR\ftaskProgress\"\xc3\x01\n" + + "\fAdminMessage\x12!\n" + + "\fmessage_type\x18\x01 \x01(\tR\vmessageType\x12?\n" + + "\x0ftask_assignment\x18\x02 \x01(\v2\x16.common.TaskAssignmentR\x0etaskAssignment\x12*\n" + + "\back_info\x18\x03 \x01(\v2\x0f.common.AckInfoR\aackInfo\x12#\n" + + "\rerror_message\x18\x04 \x01(\tR\ferrorMessage\"\x95\x01\n" + + "\fRegisterInfo\x12'\n" + + "\x0fmax_concurrency\x18\x01 \x01(\x05R\x0emaxConcurrency\x12.\n" + + "\x13supported_doc_types\x18\x02 \x03(\tR\x11supportedDocTypes\x12\x18\n" + + "\aversion\x18\x03 \x01(\tR\aversion\x12\x12\n" + + "\x04name\x18\x04 \x01(\tR\x04name\"\xe1\x01\n" + + "\rHeartbeatInfo\x122\n" + + "\vtask_states\x18\x01 \x03(\v2\x11.common.TaskStateR\n" + + "taskStates\x12&\n" + + "\x0fdelete_task_ids\x18\x02 \x03(\tR\rdeleteTaskIds\x12\x1b\n" + + "\tcpu_usage\x18\x03 \x01(\x02R\bcpuUsage\x12\x1b\n" + + "\tvms_usage\x18\x04 
\x01(\x02R\bvmsUsage\x12\x1b\n" + + "\trss_usage\x18\x05 \x01(\x02R\brssUsage\x12\x1d\n" + + "\n" + + "process_id\x18\x06 \x01(\x03R\tprocessId\"\xe6\x01\n" + + "\tTaskState\x12\x17\n" + + "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x16\n" + + "\x06status\x18\x02 \x01(\tR\x06status\x12#\n" + + "\rerror_message\x18\x03 \x01(\tR\ferrorMessage\x12G\n" + + " estimated_remaining_time_seconds\x18\x04 \x01(\x03R\x1destimatedRemainingTimeSeconds\x12\x1b\n" + + "\tcome_from\x18\x05 \x01(\tR\bcomeFrom\x12\x1d\n" + + "\n" + + "start_time\x18\x06 \x01(\x03R\tstartTime\"\x9c\x01\n" + + "\x0eTaskAssignment\x12\x17\n" + + "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x1b\n" + + "\ttask_type\x18\x02 \x01(\tR\btaskType\x12\x16\n" + + "\x06config\x18\x03 \x01(\tR\x06config\x12\x1b\n" + + "\tcome_from\x18\x04 \x01(\tR\bcomeFrom\x12\x1f\n" + + "\vassigned_to\x18\x05 \x01(\tR\n" + + "assignedTo\"b\n" + + "\n" + + "TaskResult\x12\x17\n" + + "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x16\n" + + "\x06status\x18\x02 \x01(\tR\x06status\x12#\n" + + "\rerror_message\x18\x03 \x01(\tR\ferrorMessage\"W\n" + + "\fTaskProgress\x12\x17\n" + + "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x1a\n" + + "\bprogress\x18\x02 \x01(\x05R\bprogress\x12\x12\n" + + "\x04info\x18\x03 \x01(\tR\x04info\"V\n" + + "\aAckInfo\x12\x17\n" + + "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x18\n" + + "\asuccess\x18\x02 \x01(\bR\asuccess\x12\x18\n" + + "\amessage\x18\x03 \x01(\tR\amessage2P\n" + + "\x10IngestionManager\x12<\n" + + "\x06Action\x12\x18.common.IngestionMessage\x1a\x14.common.AdminMessage(\x010\x01B\vZ\t./;commonb\x06proto3" + +var ( + file_internal_proto_ingestion_proto_rawDescOnce sync.Once + file_internal_proto_ingestion_proto_rawDescData []byte +) + +func file_internal_proto_ingestion_proto_rawDescGZIP() []byte { + file_internal_proto_ingestion_proto_rawDescOnce.Do(func() { + file_internal_proto_ingestion_proto_rawDescData = 
protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_internal_proto_ingestion_proto_rawDesc), len(file_internal_proto_ingestion_proto_rawDesc))) + }) + return file_internal_proto_ingestion_proto_rawDescData +} + +var file_internal_proto_ingestion_proto_msgTypes = make([]protoimpl.MessageInfo, 9) +var file_internal_proto_ingestion_proto_goTypes = []any{ + (*IngestionMessage)(nil), // 0: common.IngestionMessage + (*AdminMessage)(nil), // 1: common.AdminMessage + (*RegisterInfo)(nil), // 2: common.RegisterInfo + (*HeartbeatInfo)(nil), // 3: common.HeartbeatInfo + (*TaskState)(nil), // 4: common.TaskState + (*TaskAssignment)(nil), // 5: common.TaskAssignment + (*TaskResult)(nil), // 6: common.TaskResult + (*TaskProgress)(nil), // 7: common.TaskProgress + (*AckInfo)(nil), // 8: common.AckInfo +} +var file_internal_proto_ingestion_proto_depIdxs = []int32{ + 2, // 0: common.IngestionMessage.register_info:type_name -> common.RegisterInfo + 3, // 1: common.IngestionMessage.heartbeat_info:type_name -> common.HeartbeatInfo + 6, // 2: common.IngestionMessage.task_result:type_name -> common.TaskResult + 7, // 3: common.IngestionMessage.task_progress:type_name -> common.TaskProgress + 5, // 4: common.AdminMessage.task_assignment:type_name -> common.TaskAssignment + 8, // 5: common.AdminMessage.ack_info:type_name -> common.AckInfo + 4, // 6: common.HeartbeatInfo.task_states:type_name -> common.TaskState + 0, // 7: common.IngestionManager.Action:input_type -> common.IngestionMessage + 1, // 8: common.IngestionManager.Action:output_type -> common.AdminMessage + 8, // [8:9] is the sub-list for method output_type + 7, // [7:8] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name +} + +func init() { file_internal_proto_ingestion_proto_init() } +func file_internal_proto_ingestion_proto_init() { + if File_internal_proto_ingestion_proto != 
nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_internal_proto_ingestion_proto_rawDesc), len(file_internal_proto_ingestion_proto_rawDesc)), + NumEnums: 0, + NumMessages: 9, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_internal_proto_ingestion_proto_goTypes, + DependencyIndexes: file_internal_proto_ingestion_proto_depIdxs, + MessageInfos: file_internal_proto_ingestion_proto_msgTypes, + }.Build() + File_internal_proto_ingestion_proto = out.File + file_internal_proto_ingestion_proto_goTypes = nil + file_internal_proto_ingestion_proto_depIdxs = nil +} diff --git a/internal/common/ingestion_grpc.pb.go b/internal/common/ingestion_grpc.pb.go new file mode 100644 index 0000000000..5d43bce800 --- /dev/null +++ b/internal/common/ingestion_grpc.pb.go @@ -0,0 +1,115 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc v3.21.12 +// source: internal/proto/ingestion.proto + +package common + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + IngestionManager_Action_FullMethodName = "/common.IngestionManager/Action" +) + +// IngestionManagerClient is the client API for IngestionManager service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
+type IngestionManagerClient interface { + Action(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[IngestionMessage, AdminMessage], error) +} + +type ingestionManagerClient struct { + cc grpc.ClientConnInterface +} + +func NewIngestionManagerClient(cc grpc.ClientConnInterface) IngestionManagerClient { + return &ingestionManagerClient{cc} +} + +func (c *ingestionManagerClient) Action(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[IngestionMessage, AdminMessage], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &IngestionManager_ServiceDesc.Streams[0], IngestionManager_Action_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[IngestionMessage, AdminMessage]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type IngestionManager_ActionClient = grpc.BidiStreamingClient[IngestionMessage, AdminMessage] + +// IngestionManagerServer is the server API for IngestionManager service. +// All implementations must embed UnimplementedIngestionManagerServer +// for forward compatibility. +type IngestionManagerServer interface { + Action(grpc.BidiStreamingServer[IngestionMessage, AdminMessage]) error + mustEmbedUnimplementedIngestionManagerServer() +} + +// UnimplementedIngestionManagerServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. 
+type UnimplementedIngestionManagerServer struct{} + +func (UnimplementedIngestionManagerServer) Action(grpc.BidiStreamingServer[IngestionMessage, AdminMessage]) error { + return status.Error(codes.Unimplemented, "method Action not implemented") +} +func (UnimplementedIngestionManagerServer) mustEmbedUnimplementedIngestionManagerServer() {} +func (UnimplementedIngestionManagerServer) testEmbeddedByValue() {} + +// UnsafeIngestionManagerServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to IngestionManagerServer will +// result in compilation errors. +type UnsafeIngestionManagerServer interface { + mustEmbedUnimplementedIngestionManagerServer() +} + +func RegisterIngestionManagerServer(s grpc.ServiceRegistrar, srv IngestionManagerServer) { + // If the following call panics, it indicates UnimplementedIngestionManagerServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&IngestionManager_ServiceDesc, srv) +} + +func _IngestionManager_Action_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(IngestionManagerServer).Action(&grpc.GenericServerStream[IngestionMessage, AdminMessage]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type IngestionManager_ActionServer = grpc.BidiStreamingServer[IngestionMessage, AdminMessage] + +// IngestionManager_ServiceDesc is the grpc.ServiceDesc for IngestionManager service. 
+// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var IngestionManager_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "common.IngestionManager", + HandlerType: (*IngestionManagerServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Action", + Handler: _IngestionManager_Action_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "internal/proto/ingestion.proto", +} diff --git a/internal/common/status_message.go b/internal/common/status_message.go index d538848a9e..1412a5e05a 100644 --- a/internal/common/status_message.go +++ b/internal/common/status_message.go @@ -15,9 +15,9 @@ const ( type ServerType string const ( - ServerTypeAPI ServerType = "api_server" // API server - ServerTypeWorker ServerType = "ingestor" // Ingestion server - ServerTypeScheduler ServerType = "data_collector" // Data collection server + ServerTypeAPI ServerType = "api_server" // API server + ServerTypeIngestion ServerType = "ingestor" // Ingestion server + ServerTypeDataCollector ServerType = "data_collector" // Data collection server ) type BaseMessage struct { diff --git a/internal/dao/file.go b/internal/dao/file.go index e09a75fa56..28f1fe15f7 100644 --- a/internal/dao/file.go +++ b/internal/dao/file.go @@ -102,7 +102,7 @@ func (dao *FileDAO) GetRootFolder(tenantID string) (*entity.File, error) { } file.SourceType = "" - if err := DB.Create(&file).Error; err != nil { + if err = DB.Create(&file).Error; err != nil { return nil, err } return &file, nil @@ -427,7 +427,7 @@ func (dao *FileDAO) newAFileFromDataset(tenantID, name, parentID string) (*entit SourceType: "knowledgebase", } - if err := DB.Create(file).Error; err != nil { + if err = DB.Create(file).Error; err != nil { return nil, err } return file, nil @@ -470,7 +470,7 @@ func (dao *FileDAO) addFileFromKB(doc *entity.Document, datasetFolderID, tenantI SourceType: "knowledgebase", } - if err := 
DB.Create(file).Error; err != nil { + if err = DB.Create(file).Error; err != nil { return err } @@ -481,7 +481,7 @@ func (dao *FileDAO) addFileFromKB(doc *entity.Document, datasetFolderID, tenantI DocumentID: &doc.ID, } - if err := DB.Create(f2d).Error; err != nil { + if err = DB.Create(f2d).Error; err != nil { return err } diff --git a/internal/dao/ingestion.go b/internal/dao/ingestion.go new file mode 100644 index 0000000000..6aad681d66 --- /dev/null +++ b/internal/dao/ingestion.go @@ -0,0 +1,110 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package dao + +import ( + "ragflow/internal/entity" +) + +type IngestionDAO struct{} + +func NewIngestionDAO() *IngestionDAO { + return &IngestionDAO{} +} + +func (dao *IngestionDAO) Create(ingestionTask *entity.IngestionTask) error { + + tx := DB.Begin() + if tx.Error != nil { + return tx.Error + } + + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) + } + }() + + // create ingestion task + if err := DB.Create(ingestionTask).Error; err != nil { + tx.Rollback() + return err + } + + taskLog := &entity.IngestionTaskLog{ + TaskID: ingestionTask.ID, + Stage: 0, + } + + // create task log + if err := DB.Create(taskLog).Error; err != nil { + tx.Rollback() + return err + } + + if err := tx.Commit().Error; err != nil { + return err + } + + return nil +} + +func (dao *IngestionDAO) GetAllTasks() ([]*entity.IngestionTask, error) { + var tasks []*entity.IngestionTask + err := DB.Find(&tasks).Error + return tasks, err +} + +func (dao *IngestionDAO) ListByUserID(userID string) ([]*entity.IngestionTask, error) { + var tasks []*entity.IngestionTask + err := DB.Where("user_id = ?", userID).Find(&tasks).Error + return tasks, err +} + +func (dao *IngestionDAO) GetByID(id string) (*entity.IngestionTask, error) { + var task *entity.IngestionTask + err := DB.Where("id = ?", id).First(&task).Error + return task, err +} + +type IngestionLogDAO struct{} + +func NewIngestionLogDAO() *IngestionLogDAO { + return &IngestionLogDAO{} +} + +func (dao *IngestionLogDAO) Create(ingestionLog *entity.IngestionTaskLog) error { + return DB.Create(ingestionLog).Error +} + +func (dao *IngestionDAO) ListLogsByTaskID(taskID string) ([]*entity.IngestionTaskLog, error) { + var tasks []*entity.IngestionTaskLog + err := DB.Where("task_id = ?", taskID).Find(&tasks).Error + return tasks, err +} + +func (dao *IngestionDAO) GetLogByLogID(logID string) (*entity.IngestionTaskLog, error) { + var task *entity.IngestionTaskLog + err := DB.Where("id = ?", logID).First(&task).Error + 
return task, err +} + +func (dao *IngestionDAO) DeleteByTaskID(taskID string) (int64, error) { + result := DB.Unscoped().Where("task_id = ?", taskID).Delete(&entity.IngestionTaskLog{}) + return result.RowsAffected, result.Error +} diff --git a/internal/entity/ingestion_task.go b/internal/entity/ingestion_task.go new file mode 100644 index 0000000000..d061ac65c2 --- /dev/null +++ b/internal/entity/ingestion_task.go @@ -0,0 +1,70 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package entity + +type IngestionTask struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + DocumentID string `gorm:"column:document_id;size:32;not null;index" json:"document_id"` + UserID string `gorm:"column:user_id;size:32;not null;" json:"user_id"` + Config JSONMap `gorm:"column:config;type:longtext;not null" json:"config"` + TryCount int `gorm:"column:try_count;type:int;default:0" json:"try_count"` + BaseModel +} + +// TableName specify table name +func (IngestionTask) TableName() string { + return "ingestion_task" +} + +type IngestionTasklet struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + TaskID string `gorm:"column:task_id;size:32;not null;index" json:"task_id"` + Config JSONMap `gorm:"column:config;type:longtext;not null" json:"config"` + TryCount int `gorm:"column:try_count;type:int;default:0" json:"try_count"` + BaseModel +} + +// TableName specify table name +func (IngestionTasklet) TableName() string { + return "ingestion_tasklet" +} + +type IngestionTaskLog struct { + ID int `gorm:"column:id;primaryKey;autoIncrement" json:"id"` + TaskID string `gorm:"column:task_id;size:32;not null;index" json:"task_id"` + Stage int `gorm:"column:stage;type:int;default:0;not null;" json:"stage"` + DataSchema JSONMap `gorm:"column:config;type:longtext;not null" json:"data_schema"` + BaseModel +} + +// TableName specify table name +func (IngestionTaskLog) TableName() string { + return "ingestion_task_log" +} + +type IngestionTaskletLog struct { + ID int `gorm:"column:id;primaryKey;autoIncrement" json:"id"` + TaskletID string `gorm:"column:tasklet_id;size:32;not null;index" json:"task_id"` + Stage int `gorm:"column:stage;type:int;default:0;not null;" json:"stage"` + DataSchema JSONMap `gorm:"column:config;type:longtext;not null" json:"data_schema"` + BaseModel +} + +// TableName specify table name +func (IngestionTaskletLog) TableName() string { + return "ingestion_tasklet_log" +} diff --git a/internal/handler/document.go 
// taskWrapper wraps a task assignment.
// NOTE(review): nothing in the visible part of this file references this
// type - confirm whether it is still needed.
type taskWrapper struct {
	task *common.TaskAssignment
}

// Ingestor is the client side of the IngestionManager gRPC stream: it
// registers with the admin server, sends heartbeats, receives task
// assignments and runs them on a fixed-size pool of worker goroutines.
type Ingestor struct {
	id          string                               // generated once in NewIngestor
	name        string                               // reported to the admin in the REGISTER message
	serverAddr  string                               // admin address, remembered by Connect for reconnects
	conn        *grpc.ClientConn                     // current gRPC connection
	client      common.IngestionManagerClient        // client bound to conn
	stream      common.IngestionManager_ActionClient // current bidirectional stream
	ctx         context.Context                      // cancelled by Stop; parent of every task context
	cancel      context.CancelFunc                   // cancels ctx
	reconnectMu sync.Mutex                           // held for the duration of a reconnect attempt

	// Configuration
	maxConcurrency    int32    // number of worker goroutines
	supportedDocTypes []string // reported to the admin in the REGISTER message
	version           string   // reported to the admin in the REGISTER message

	// Runtime state
	currentTasks map[string]*TaskContext // tracked tasks keyed by task id
	tasksMu      sync.RWMutex            // guards currentTasks

	// Shutdown channel - a value is sent on it (without blocking) when the
	// admin asks this ingestor to shut down; the owner of the Ingestor is
	// expected to receive from it and then stop the process.
	ShutdownCh chan struct{}

	// Worker pool
	taskChan  chan *TaskContext // queue feeding the workers
	workerWg  sync.WaitGroup    // tracks running workers
	startOnce sync.Once         // ensures the pool is started only once
}

// TaskContext is the runtime record the ingestor keeps for one assigned task.
type TaskContext struct {
	Ctx                    context.Context    // cancelled when the task is cancelled or the ingestor stops
	CancelFunc             context.CancelFunc // cancels Ctx
	Task                   *common.TaskAssignment
	Status                 string        // values set in this file: QUEUED, RUNNING, COMPLETED, CANCELED, REJECTED
	StartTime              time.Time     // set when a worker picks the task up; zero until then
	EndTime                time.Time     // set when the task reaches a terminal state
	estimatedRemainingTime time.Duration // estimated time left to complete the task (a Duration, i.e. nanoseconds, not a count of seconds)
	Progress               int32         // last reported progress value
	ErrorMessage           string
}

// NewIngestor creates an Ingestor with a freshly generated id and a
// cancellable root context. The task queue is buffered to twice
// maxConcurrency and ShutdownCh has room for a single pending signal.
// The returned ingestor is not connected; call Connect to start it.
func NewIngestor(name string, maxConcurrency int32, supportedTypes []string) *Ingestor {
	ctx, cancel := context.WithCancel(context.Background())
	id := common.GenerateUUID()
	return &Ingestor{
		id:                id,
		name:              name,
		ctx:               ctx,
		cancel:            cancel,
		maxConcurrency:    maxConcurrency,
		supportedDocTypes: supportedTypes,
		version:           "1.0.0",
		currentTasks:      make(map[string]*TaskContext),
		taskChan:          make(chan *TaskContext, maxConcurrency*2),
		ShutdownCh:        make(chan struct{}, 1),
	}
}
// 1. Send registration message + if err = e.sendRegister(); err != nil { + conn.Close() + return err + } + + // Ensure worker pool is started on first task + e.startWorkerPool() + + // 2. Start receive loop + go e.receiveLoop() + + // 3. Start heartbeat loop + go e.heartbeatLoop() + + return nil +} + +func (e *Ingestor) sendRegister() error { + msg := &common.IngestionMessage{ + IngestorId: e.id, + MessageType: "REGISTER", + RegisterInfo: &common.RegisterInfo{ + MaxConcurrency: e.maxConcurrency, + SupportedDocTypes: e.supportedDocTypes, + Version: e.version, + Name: e.name, + }, + } + return e.stream.Send(msg) +} + +func (e *Ingestor) sendHeartbeat() error { + e.tasksMu.RLock() + + cutoff := time.Now().Add(-10 * time.Minute) + var toDeleteTask []string + taskStates := make([]*common.TaskState, 0, len(e.currentTasks)) + + for tid, tc := range e.currentTasks { + // Check if task is in a terminal state and expired beyond 10 minutes + if (tc.Status == "CANCELED" || tc.Status == "COMPLETED" || tc.Status == "REJECTED") && + !tc.EndTime.IsZero() && tc.EndTime.Before(cutoff) { + toDeleteTask = append(toDeleteTask, tid) + } else { + taskStates = append(taskStates, &common.TaskState{ + TaskId: tid, + Status: tc.Status, + EstimatedRemainingTimeSeconds: int64(tc.estimatedRemainingTime), + ErrorMessage: tc.ErrorMessage, + StartTime: tc.StartTime.UnixNano(), + ComeFrom: tc.Task.ComeFrom, + }) + } + } + e.tasksMu.RUnlock() + + // Delete expired terminal tasks from currentTasks + if len(toDeleteTask) > 0 { + e.tasksMu.Lock() + for _, id := range toDeleteTask { + delete(e.currentTasks, id) + } + e.tasksMu.Unlock() + } + + var pid = int64(os.Getpid()) + p, err := process.NewProcess(int32(pid)) + if err != nil { + log.Fatal(err) + } + + var cpuPercent float64 + cpuPercent, err = p.Percent(100 * time.Millisecond) + if err != nil { + cpuPercent = math.NaN() + common.Info(fmt.Sprintf("Fail to read CPU usage: %v", err)) + } + + RssUsage := math.NaN() + VmsUsage := math.NaN() + memInfo, 
err := p.MemoryInfo() + if err == nil { + RssUsage = float64(memInfo.RSS) + VmsUsage = float64(memInfo.VMS) + } else { + common.Info(fmt.Sprintf("Fail to read memory usage: %v", err)) + } + msg := &common.IngestionMessage{ + IngestorId: e.id, + MessageType: "HEARTBEAT", + HeartbeatInfo: &common.HeartbeatInfo{ + TaskStates: taskStates, + DeleteTaskIds: toDeleteTask, + CpuUsage: float32(cpuPercent), + VmsUsage: float32(VmsUsage), + RssUsage: float32(RssUsage), + ProcessId: pid, + }, + } + return e.stream.Send(msg) +} + +func (e *Ingestor) sendTaskResult(taskID, status, errorMsg string) error { + msg := &common.IngestionMessage{ + IngestorId: e.id, + MessageType: "TASK_RESULT", + TaskResult: &common.TaskResult{ + TaskId: taskID, + Status: status, + ErrorMessage: errorMsg, + }, + } + return e.stream.Send(msg) +} + +func (e *Ingestor) sendTaskProgress(taskID string, progress int32, info string) error { + msg := &common.IngestionMessage{ + IngestorId: e.id, + MessageType: "TASK_PROGRESS", + TaskProgress: &common.TaskProgress{ + TaskId: taskID, + Progress: progress, + Info: info, + }, + } + return e.stream.Send(msg) +} + +func (e *Ingestor) receiveLoop() { + for { + msg, err := e.stream.Recv() + if err != nil { + if e.ctx.Err() != nil { + common.Info(fmt.Sprintf("Ingestor %s context cancelled, receive loop exiting", e.id)) + return + } + common.Info(fmt.Sprintf("Receive error: %v", err)) + common.Info("Admin connection lost, attempting to reconnect") + e.reconnect() + return + } + + switch msg.MessageType { + case "TASK_ASSIGNMENT": + e.handleTaskAssignment(msg.TaskAssignment) + + case "ACK": + common.Info(fmt.Sprintf("Received ACK: task=%s, success=%v, msg=%s", + msg.AckInfo.TaskId, msg.AckInfo.Success, msg.AckInfo.Message)) + + case "ERROR": + common.Info(fmt.Sprintf("Received error from admin: %s", msg.ErrorMessage)) + + default: + common.Info(fmt.Sprintf("Unknown admin message type: %s", msg.MessageType)) + } + } +} + +func (e *Ingestor) handleTaskAssignment(task 
*common.TaskAssignment) { + if task == nil { + return + } + + common.Info(fmt.Sprintf("Received task: %s, task_type=%s", task.TaskId, task.TaskType)) + + switch task.TaskType { + case "shutdown_ingestor": + if e.id == task.AssignedTo { + e.handleShutdownIngestor() + return + } + + common.Error("unmatched ingestor id", fmt.Errorf("attempt to shutdown ingestor: %s, current ingestor: %s, mismatched", task.AssignedTo, e.id)) + return + case "cancel_ingestion_task": + e.handleCancelTask(task.TaskId) + return + } + + // Construct TaskContext with a cancellable context + ctx, cancel := context.WithCancel(e.ctx) + taskCtx := &TaskContext{ + Ctx: ctx, + CancelFunc: cancel, + Task: task, + Status: "QUEUED", + } + + // Register in currentTasks immediately so heartbeat sees PENDING state + e.tasksMu.Lock() + e.currentTasks[task.TaskId] = taskCtx + e.tasksMu.Unlock() + + common.Info("wait for 10 seconds") + time.Sleep(time.Second * 10) + // Push to task channel; if full, reject the task (backpressure) + select { + case e.taskChan <- taskCtx: + common.Info(fmt.Sprintf("Task %s queued (channel: %d/%d)", task.TaskId, len(e.taskChan), cap(e.taskChan))) + default: + common.Info(fmt.Sprintf("No available slot for task %s, rejecting", task.TaskId)) + //e.tasksMu.Lock() + //delete(e.currentTasks, task.TaskId) + //e.tasksMu.Unlock() + taskCtx.Status = "REJECTED" + taskCtx.EndTime = time.Now() + e.sendTaskResult(taskCtx.Task.TaskId, "REJECTED", "task rejected before execution") + } +} + +func (e *Ingestor) handleCancelTask(taskID string) { + e.tasksMu.Lock() + taskCtx, exists := e.currentTasks[taskID] + e.tasksMu.Unlock() + + if !exists { + common.Info(fmt.Sprintf("Cancel request for unknown task %s, ignoring", taskID)) + return + } + + common.Info(fmt.Sprintf("Cancelling task %s (current status: %s)", taskID, taskCtx.Status)) + taskCtx.CancelFunc() +} + +func (e *Ingestor) handleShutdownIngestor() { + common.Info(fmt.Sprintf("Shutdown task received, initiating graceful shutdown of 
ingestor %s", e.id)) + select { + case e.ShutdownCh <- struct{}{}: + default: + } + return +} + +func (e *Ingestor) startWorkerPool() { + e.startOnce.Do(func() { + for i := int32(0); i < e.maxConcurrency; i++ { + e.workerWg.Add(1) + go e.workerLoop(i) + } + common.Info(fmt.Sprintf("Worker pool started with %d workers", e.maxConcurrency)) + }) +} + +func (e *Ingestor) workerLoop(id int32) { + defer e.workerWg.Done() + common.Info(fmt.Sprintf("Worker %d started", id)) + for { + select { + case <-e.ctx.Done(): + return + case taskCtx := <-e.taskChan: + // Skip tasks that were canceled while queued + select { + case <-taskCtx.Ctx.Done(): + common.Info(fmt.Sprintf("Task %s was cancelled while queued, skipping", taskCtx.Task.TaskId)) + taskCtx.Status = "CANCELED" + taskCtx.EndTime = time.Now() + //e.tasksMu.Lock() + //delete(e.currentTasks, taskCtx.Task.TaskId) + //e.tasksMu.Unlock() + e.sendTaskResult(taskCtx.Task.TaskId, "CANCELED", "task cancelled before execution") + continue + default: + } + + // Mark as RUNNING + taskCtx.Status = "RUNNING" + taskCtx.StartTime = time.Now() + + // Execute the task (synchronously within worker) + e.executeTask(taskCtx) + } + } +} + +func (e *Ingestor) executeTask(taskCtx *TaskContext) { + defer func() { + //e.tasksMu.Lock() + //delete(e.currentTasks, taskCtx.Task.TaskId) + //e.tasksMu.Unlock() + }() + + ctx := taskCtx.Ctx + task := taskCtx.Task + common.Info(fmt.Sprintf("Starting task %s", task.TaskId)) + + // Simulate task execution progress + // In production, this would split into subtasks and execute in parallel + for progress := int32(0); progress <= 100; progress += 10 { + select { + case <-ctx.Done(): + // Task canceled + common.Info(fmt.Sprintf("Task %s cancelled", task.TaskId)) + taskCtx.Status = "CANCELED" + taskCtx.EndTime = time.Now() + e.sendTaskResult(task.TaskId, "CANCELED", "task cancelled") + return + case <-time.After(5000 * time.Millisecond): + // Simulate progress update + taskCtx.Progress = progress + 
e.sendTaskProgress(task.TaskId, progress, "processing...") + common.Info(fmt.Sprintf("Task %s progress: %d%%", task.TaskId, progress)) + } + } + + // Simulate subtask splitting and execution (demonstration) + e.executeWithSubTasks(task) + + taskCtx.Status = "COMPLETED" + taskCtx.EndTime = time.Now() + + time.Sleep(time.Second * 10) + + // Task completed + e.sendTaskResult(task.TaskId, "COMPLETED", "") + common.Info(fmt.Sprintf("Task %s completed", task.TaskId)) +} + +// executeWithSubTasks demonstrates subtask splitting and parallel execution +func (e *Ingestor) executeWithSubTasks(task *common.TaskAssignment) { + common.Info(fmt.Sprintf("Task %s: splitting into subtasks", task.TaskId)) + + // Simulate splitting into 4 subtasks + subTasks := []struct { + id string + index int + }{ + {task.TaskId + "-sub1", 1}, + {task.TaskId + "-sub2", 2}, + {task.TaskId + "-sub3", 3}, + {task.TaskId + "-sub4", 4}, + } + + // Wait for all subtasks to complete + var wg sync.WaitGroup + results := make(chan error, len(subTasks)) + + // Execute subtasks in parallel + for _, st := range subTasks { + wg.Add(1) + go func(subID string, idx int) { + defer wg.Done() + + common.Info(fmt.Sprintf("Subtask %s started", subID)) + // Simulate subtask execution + time.Sleep(1 * time.Second) + common.Info(fmt.Sprintf("Subtask %s completed", subID)) + results <- nil + }(st.id, st.index) + } + + // Wait for all subtasks to complete + wg.Wait() + close(results) + + // Check if any subtasks failed + failedCount := 0 + for err := range results { + if err != nil { + failedCount++ + } + } + + if failedCount > 0 { + common.Info(fmt.Sprintf("Task %s: %d subtasks failed", task.TaskId, failedCount)) + } else { + common.Info(fmt.Sprintf("Task %s: all subtasks completed successfully", task.TaskId)) + } +} + +func (e *Ingestor) heartbeatLoop() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-e.ctx.Done(): + return + case <-ticker.C: + if err := e.sendHeartbeat(); 
err != nil { + common.Info(fmt.Sprintf("Failed to send heartbeat: %v", err)) + if e.ctx.Err() != nil { + common.Info(fmt.Sprintf("Ingestor %s context cancelled, heartbeat loop exiting", e.id)) + return + } + common.Info(fmt.Sprintf("Admin connection lost, attempting to reconnect")) + e.reconnect() + return + } + } + } +} + +// reconnect closes the old connection and establishes a new one with exponential backoff. +// Only one reconnection attempt runs at a time; concurrent callers return immediately. +func (e *Ingestor) reconnect() { + if e.ctx.Err() != nil { + common.Info(fmt.Sprintf("Ingestor %s is shutting down, skipping reconnection", e.id)) + return + } + + if !e.reconnectMu.TryLock() { + return + } + defer e.reconnectMu.Unlock() + + common.Info(fmt.Sprintf("Ingestor %s attempting to reconnect to admin at %s", e.id, e.serverAddr)) + + // Close old stream and connection + if e.stream != nil { + e.stream.CloseSend() + } + if e.conn != nil { + e.conn.Close() + } + + backoff := 1 * time.Second + maxBackoff := 30 * time.Second + + for { + conn, err := grpc.Dial(e.serverAddr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + grpc.WithTimeout(5*time.Second), + ) + if err != nil { + common.Info(fmt.Sprintf("Reconnect dial failed: %v, retrying in %v", err, backoff)) + time.Sleep(backoff) + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } + e.conn = conn + e.client = common.NewIngestionManagerClient(conn) + + stream, err := e.client.Action(e.ctx) + if err != nil { + conn.Close() + common.Info(fmt.Sprintf("Reconnect create stream failed: %v, retrying in %v", err, backoff)) + time.Sleep(backoff) + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } + e.stream = stream + + if err = e.sendRegister(); err != nil { + stream.CloseSend() + conn.Close() + common.Info(fmt.Sprintf("Reconnect register failed: %v, retrying in %v", err, backoff)) + time.Sleep(backoff) + backoff *= 2 + if 
backoff > maxBackoff { + backoff = maxBackoff + } + continue + } + + common.Info(fmt.Sprintf("Ingestor %s reconnected to admin", e.id)) + break + } + + // Restart the loops on the new stream + go e.receiveLoop() + go e.heartbeatLoop() +} + +// Stop gracefully shuts down the ingestor +func (e *Ingestor) Stop() { + common.Info(fmt.Sprintf("Stopping ingestor %s", e.id)) + e.cancel() + + // Wait for all workers to finish (they exit on ctx.Done()) + e.workerWg.Wait() + common.Info("All tasks completed") + + if e.stream != nil { + e.stream.CloseSend() + } +} diff --git a/internal/proto/ingestion.proto b/internal/proto/ingestion.proto new file mode 100644 index 0000000000..4877a3b2d1 --- /dev/null +++ b/internal/proto/ingestion.proto @@ -0,0 +1,80 @@ +syntax = "proto3"; + +package common; + +option go_package = "./;common"; + +service IngestionManager { + rpc Action(stream IngestionMessage) returns (stream AdminMessage); +} + +message IngestionMessage { + string ingestor_id = 1; + string message_type = 2; // REGISTER, HEARTBEAT, TASK_RESULT, TASK_PROGRESS, PULL_REQUEST + + RegisterInfo register_info = 3; + + HeartbeatInfo heartbeat_info = 4; + + TaskResult task_result = 5; + + TaskProgress task_progress = 6; +} + +message AdminMessage { + string message_type = 1; // TASK_ASSIGNMENT, ACK, PONG, RECONNECT + TaskAssignment task_assignment = 2; + AckInfo ack_info = 3; + string error_message = 4; +} + +message RegisterInfo { + int32 max_concurrency = 1; + repeated string supported_doc_types = 2; + string version = 3; + string name = 4; +} + +message HeartbeatInfo { + repeated TaskState task_states = 1; + repeated string delete_task_ids = 2; + float cpu_usage = 3; // percentage + float vms_usage = 4; // absolute value + float rss_usage = 5; // absolute value + int64 process_id = 6; // pid +} + +message TaskState { + string task_id = 1; + string status = 2; // PENDING, RUNNING, COMPLETED, FAILED, CANCELLED + string error_message = 3; + int64 estimated_remaining_time_seconds = 4; 
// TaskResult is the final outcome an ingestor reports for one task.
message TaskResult {
  string task_id = 1;
  // Documented values: COMPLETED, FAILED, CANCELLED.
  // NOTE(review): the Go ingestor added alongside this schema sends
  // "COMPLETED", "CANCELED" (single L) and "REJECTED" - confirm the
  // canonical spellings and whether REJECTED belongs in this list.
  string status = 2;
  string error_message = 3;
}
globalConfig.Admin.IngestionManagerPort = 9385 + } // authentication section if globalConfig != nil { diff --git a/internal/service/chat.go b/internal/service/chat.go index 060bd3cc56..122b0c227b 100644 --- a/internal/service/chat.go +++ b/internal/service/chat.go @@ -19,12 +19,11 @@ package service import ( "errors" "fmt" + "ragflow/internal/common" "ragflow/internal/entity" "strings" "unicode/utf8" - "github.com/google/uuid" - "ragflow/internal/dao" ) @@ -446,11 +445,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo if isCreate { // Generate UUID for new chat - newID := uuid.New().String() - newID = strings.ReplaceAll(newID, "-", "") - if len(newID) > 32 { - newID = newID[:32] - } + newID := common.GenerateUUID() // Set default language language := "English" diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index 50402f9d7c..659838a72f 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -24,7 +24,6 @@ import ( "strings" "time" - "github.com/google/uuid" "go.uber.org/zap" "ragflow/internal/dao" @@ -77,8 +76,8 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe if !req.IsNew { // Update existing chat session updates := map[string]interface{}{ - "name": name, - "user_id": userID, + "name": name, + "user_id": userID, } if err := s.chatSessionDAO.UpdateByID(req.SessionID, updates); err != nil { @@ -102,11 +101,7 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe } // Generate UUID for new chat session - newID := uuid.New().String() - newID = strings.ReplaceAll(newID, "-", "") - if len(newID) > 32 { - newID = newID[:32] - } + newID := common.GenerateUUID() // Get prologue from dialog's prompt_config prologue := "Hi! I'm your assistant. What can I do for you?" 
// DocumentService document service
//
// It bundles the DAOs, the document engine and the metadata service that the
// document operations in this file work with.
type DocumentService struct {
	documentDAO      *dao.DocumentDAO
	kbDAO            *dao.KnowledgebaseDAO
	ingestionTaskDAO *dao.IngestionDAO    // used by ParseDocuments to persist ingestion tasks
	ingestionLogDAO  *dao.IngestionLogDAO // NOTE(review): wired here but not referenced in the visible code - confirm it is needed
	docEngine        engine.DocEngine
	engineType       server.EngineType
	metadataSvc      *MetadataService
}

// NewDocumentService create document service
//
// Every dependency is constructed here: fresh DAO instances, the engine
// returned by engine.Get(), and the engine type read from the server config.
func NewDocumentService() *DocumentService {
	cfg := server.GetConfig()
	return &DocumentService{
		documentDAO:      dao.NewDocumentDAO(),
		ingestionTaskDAO: dao.NewIngestionDAO(),
		ingestionLogDAO:  dao.NewIngestionLogDAO(),
		kbDAO:            dao.NewKnowledgebaseDAO(),
		docEngine:        engine.Get(),
		engineType:       cfg.DocEngine.Type,
		metadataSvc:      NewMetadataService(),
	}
}
docIDs []string) error { +type ParseDocumentResponse struct { + DocumentID string `json:"document_id"` + Result *string `json:"result"` +} + +func (s *DocumentService) ParseDocuments(datasetID, userID string, docIDs []string) ([]*ParseDocumentResponse, error) { // create document parse id // save to task table // send to message queue - return nil + + // deduplicate the document id + uniqueDocIDs := common.Deduplicate(docIDs) + if uniqueDocIDs == nil || len(uniqueDocIDs) == 0 { + return nil, fmt.Errorf("no documents to parse") + } + + var responses []*ParseDocumentResponse + + // query database, if the document ids are valid + for _, docID := range uniqueDocIDs { + doc, err := s.documentDAO.GetByID(docID) + if err != nil { + errorMessage := err.Error() + responses = append(responses, &ParseDocumentResponse{ + DocumentID: docID, + Result: &errorMessage, + }) + continue + } + if doc == nil { + errorMessage := "no such document" + responses = append(responses, &ParseDocumentResponse{ + DocumentID: docID, + Result: &errorMessage, + }) + continue + } + + if doc.Status != nil && *doc.Status != "0" { + errorMessage := fmt.Sprintf("document %s is already parsed", docID) + responses = append(responses, &ParseDocumentResponse{ + DocumentID: docID, + Result: &errorMessage, + }) + continue + } + + // create task for each document + task := &entity.IngestionTask{ + ID: uuid.New().String(), + DocumentID: docID, + UserID: userID, + Config: nil, + TryCount: 1, + } + + // save the task to database + err = s.ingestionTaskDAO.Create(task) + if err != nil { + errorMessage := err.Error() + responses = append(responses, &ParseDocumentResponse{ + DocumentID: docID, + Result: &errorMessage, + }) + continue + } + + // Send task to message queue + + } + + common.Info(fmt.Sprintf("parse documents, dataset: %s, documents: %v", datasetID, docIDs)) + return responses, nil } // toResponse convert model.Document to DocumentResponse diff --git a/tools/scripts/gen-proto.sh 
#!/usr/bin/env bash
#
# Generates the Go protobuf and gRPC stubs for internal/proto/ingestion.proto
# into internal/common.
#
# Requires protoc, protoc-gen-go and protoc-gen-go-grpc on PATH; the script
# exits with status 1 and prints install hints when one of them is missing.
set -euo pipefail

# All paths below are relative to the repository root. This script lives in
# tools/scripts/, so move there first to make it work from any directory.
cd "$(dirname "${BASH_SOURCE[0]}")/../.."

echo "Generating protobuf and gRPC code..."

if ! command -v protoc >/dev/null 2>&1; then
    echo "❌ protoc not found!"
    echo "Please install protoc first:"
    echo " - macOS: brew install protobuf"
    echo " - Ubuntu: apt install protobuf-compiler"
    echo " - Download: https://github.com/protocolbuffers/protobuf/releases"
    exit 1
fi
echo "✅ protoc: $(command -v protoc)"

# require_go_plugin TOOL INSTALL_PACKAGE
# Aborts with install instructions unless the protoc plugin TOOL is on PATH.
require_go_plugin() {
    local tool="$1" install_pkg="$2"

    if ! command -v "$tool" >/dev/null 2>&1; then
        echo ""
        echo "❌ $tool not found!"
        echo "Please install it:"
        echo " go install $install_pkg"
        echo ""
        echo "Or add Go bin to PATH:"
        echo " export PATH=\$PATH:$(go env GOPATH)/bin"
        exit 1
    fi
    echo "✅ $tool: $(command -v "$tool")"
}

require_go_plugin protoc-gen-go google.golang.org/protobuf/cmd/protoc-gen-go@latest
require_go_plugin protoc-gen-go-grpc google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest

mkdir -p internal/common

# Test the command directly: under `set -e` a separate `$?` check would never
# be reached after a failure.
if protoc --go_out=internal/common \
          --go-grpc_out=internal/common \
          internal/proto/ingestion.proto; then
    echo "✅ Generation successful!"
    echo "Generated files:"
    echo " - internal/common/ingestion.pb.go"
    echo " - internal/common/ingestion_grpc.pb.go"
else
    echo "❌ Generation failed!"
    exit 1
fi