diff --git a/.gitignore b/.gitignore index 097a885152..3d23e6f1bb 100644 --- a/.gitignore +++ b/.gitignore @@ -232,4 +232,6 @@ internal/cpp/cmake-build-debug/ # Go server build output bin/* !bin/.gitkeep -.claude/settings.local.json \ No newline at end of file +.claude/settings.local.json + +.run/ \ No newline at end of file diff --git a/cmd/admin_server.go b/cmd/admin_server.go index 940606ab93..e8d1e1b96b 100644 --- a/cmd/admin_server.go +++ b/cmd/admin_server.go @@ -94,6 +94,10 @@ func main() { } defer cache.Close() + if err := engine.InitMessageQueueEngine(cfg.TaskExecutor.MessageQueueType); err != nil { + common.Error("Failed to initialize message queue engine", err) + } + // Initialize server variables (runtime variables that can change during operation) // This must be done after Cache is initialized if err := server.InitVariables(cache.Get()); err != nil { @@ -149,16 +153,6 @@ func main() { // Print RAGFlow version common.Info(fmt.Sprintf("RAGFlow admin version: %s", utility.GetRAGFlowVersion())) - // Start ingestion manager (gRPC) in a goroutine - ingestionMgr := admin.NewAdminServer() - go func() { - addr = fmt.Sprintf(":%d", cfg.Admin.IngestionManagerPort) - common.Info(fmt.Sprintf("Starting RAGFlow ingestion manager on port: %d", cfg.Admin.IngestionManagerPort)) - if err := ingestionMgr.Start(addr); err != nil { - common.Fatal("Failed to start RAGFlow ingestion manager", zap.Error(err)) - } - }() - // Start HTTP server in a goroutine go func() { common.Info(fmt.Sprintf("Starting RAGFlow admin HTTP server on port: %d", cfg.Admin.Port)) @@ -185,7 +179,4 @@ func main() { } common.Info("Admin HTTP server exited") - - // Stop ingestion manager - ingestionMgr.Stop() } diff --git a/cmd/ingestion_server.go b/cmd/ingestor.go similarity index 81% rename from cmd/ingestion_server.go rename to cmd/ingestor.go index a46a70e40f..0e020a9170 100644 --- a/cmd/ingestion_server.go +++ b/cmd/ingestor.go @@ -23,6 +23,8 @@ import ( "os" "os/signal" 
"ragflow/internal/ingestion" + "ragflow/internal/server/local" + "ragflow/internal/service" "ragflow/internal/service/nlp" "ragflow/internal/tokenizer" "ragflow/internal/utility" @@ -129,6 +131,10 @@ func main() { common.Fatal("Failed to initialize storage factory", zap.Error(err)) } + if err := engine.InitMessageQueueEngine(config.TaskExecutor.MessageQueueType); err != nil { + common.Fatal(fmt.Sprintf("Failed to initialize message queue engine: %w", err)) + } + // Initialize server variables (runtime variables from Redis) if err := server.InitVariables(cache.Get()); err != nil { common.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error())) @@ -150,11 +156,13 @@ func main() { ingestor := ingestion.NewIngestor(name, 2, []string{"pdf", "docx", "txt"}) - // Connect to the admin server - serverAddress := fmt.Sprintf("%s:%d", config.Admin.Host, config.Admin.IngestionManagerPort) - if err := ingestor.Connect(serverAddress); err != nil { - common.Fatal(fmt.Sprintf("Error: %s", err.Error())) - } + go func() { + err := ingestor.Start() + if err != nil { + common.Error("Failed to initialize ingestor", err) + return + } + }() quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2) @@ -169,7 +177,37 @@ func main() { " /____/\n") // Print RAGFlow version - common.Info(fmt.Sprintf("RAGFlow admin version: %s", utility.GetRAGFlowVersion())) + common.Info(fmt.Sprintf("RAGFlow ingestion service version: %s", utility.GetRAGFlowVersion())) + + // Get local IP address for heartbeat reporting + localIP, err := utility.GetLocalIP() + if err != nil { + common.Fatal("fail to get local ip address") + } + + // Initialize and start heartbeat reporter to admin server + service.AdminServiceClient = service.NewAdminClient( + common.Logger, + common.ServerTypeIngestion, + fmt.Sprintf("ingestor-%s", ingestor.ID()), + localIP, + -1, + ) + if err = 
service.AdminServiceClient.InitHTTPClient(); err != nil { + common.Warn("Failed to initialize heartbeat service", zap.Error(err)) + } else { + // Start heartbeat reporter with 30 seconds interval + heartbeatReporter := utility.NewScheduledTask("Heartbeat reporter", 3*time.Second, func() { + if err = service.AdminServiceClient.SendHeartbeat(); err == nil { + local.SetAdminStatus(0, "") + } else { + local.SetAdminStatus(1, err.Error()) + //logger.Warn(fmt.Sprintf(err.Error())) + } + }) + heartbeatReporter.Start() + defer heartbeatReporter.Stop() + } // Wait for either an OS signal or a shutdown command from the admin select { diff --git a/cmd/server_main.go b/cmd/server_main.go index a90504d369..22bfc39fa8 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -40,9 +40,9 @@ import ( "ragflow/internal/dao" "ragflow/internal/engine" "ragflow/internal/handler" - "ragflow/internal/service/chunk" "ragflow/internal/router" "ragflow/internal/service" + "ragflow/internal/service/chunk" "ragflow/internal/service/nlp" "ragflow/internal/tokenizer" ) @@ -131,6 +131,10 @@ func main() { common.Fatal("Failed to initialize storage factory", zap.Error(err)) } + if err := engine.InitMessageQueueEngine(config.TaskExecutor.MessageQueueType); err != nil { + common.Error("Failed to initialize message queue engine", err) + } + // Initialize server variables (runtime variables that can change during operation) // This must be done after Cache is initialized if err := server.InitVariables(cache.Get()); err != nil { @@ -288,19 +292,19 @@ func startServer(config *server.Config) { } // Initialize and start heartbeat reporter to admin server - heartbeatService := service.NewHeartbeatSender( + service.AdminServiceClient = service.NewAdminClient( common.Logger, common.ServerTypeAPI, fmt.Sprintf("ragflow-server-%d", config.Server.Port), localIP, config.Server.Port, ) - if err = heartbeatService.InitHTTPClient(); err != nil { + if err = service.AdminServiceClient.InitHTTPClient(); err != nil { 
common.Warn("Failed to initialize heartbeat service", zap.Error(err)) } else { // Start heartbeat reporter with 30 seconds interval heartbeatReporter := utility.NewScheduledTask("Heartbeat reporter", 3*time.Second, func() { - if err = heartbeatService.SendHeartbeat(); err == nil { + if err = service.AdminServiceClient.SendHeartbeat(); err == nil { local.SetAdminStatus(0, "") } else { local.SetAdminStatus(1, err.Error()) diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index 5e48a5ecd0..b535595e52 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -48,8 +48,11 @@ redis: username: '' password: 'infini_rag_flow' host: 'localhost:6379' +nats: + host: "0.0.0.0" + port: 4222 task_executor: - message_queue_type: 'redis' + message_queue_type: 'nats' user_default_llm: default_models: embedding_model: diff --git a/docker/.env b/docker/.env index 719857f720..5448ac5054 100644 --- a/docker/.env +++ b/docker/.env @@ -145,6 +145,9 @@ REDIS_PORT=6379 # The password for Redis. REDIS_PASSWORD=infini_rag_flow +NATS_HOST=nats +NATS_PORT=4222 + # The port used to expose RAGFlow's HTTP API service to the host machine, # allowing EXTERNAL access to the service running inside the Docker container. 
SVR_WEB_HTTP_PORT=80 diff --git a/docker/docker-compose-base.yml b/docker/docker-compose-base.yml index 22c7a7b482..fbc13b5d05 100644 --- a/docker/docker-compose-base.yml +++ b/docker/docker-compose-base.yml @@ -250,6 +250,25 @@ services: timeout: 10s retries: 120 + nats: + profiles: + - ragflow-go + image: nats:2.14.1 + ports: + - ${NATS_PORT}:4222 + - "8222:8222" + volumes: + - nats_data:/data + command: -js -sd /data + env_file: .env + networks: + - ragflow + restart: unless-stopped + healthcheck: + test: ["CMD", "nc", "-z", "localhost", "${NATS_PORT}"] + interval: 10s + timeout: 10s + retries: 120 tei-cpu: profiles: @@ -329,6 +348,8 @@ volumes: driver: local kibana_data: driver: local + nats_data: + driver: local networks: ragflow: diff --git a/docker/service_conf.yaml.template b/docker/service_conf.yaml.template index 032ccaa051..21bbfaca27 100644 --- a/docker/service_conf.yaml.template +++ b/docker/service_conf.yaml.template @@ -60,6 +60,9 @@ redis: username: '${REDIS_USERNAME:-}' password: '${REDIS_PASSWORD:-infini_rag_flow}' host: '${REDIS_HOST:-redis}:6379' +nats: + host: ${NATS_HOST:-0.0.0.0} + port: ${NATS_PORT:-4222} user_default_llm: default_models: embedding_model: diff --git a/go.mod b/go.mod index 54ce2bee7a..82949ed981 100644 --- a/go.mod +++ b/go.mod @@ -4,30 +4,34 @@ go 1.25.0 require ( github.com/aws/aws-sdk-go-v2 v1.41.3 + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 github.com/aws/aws-sdk-go-v2/config v1.32.11 github.com/aws/aws-sdk-go-v2/credentials v1.19.11 github.com/aws/aws-sdk-go-v2/service/s3 v1.96.4 + github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 github.com/aws/smithy-go v1.24.2 github.com/cespare/xxhash/v2 v2.3.0 github.com/elastic/go-elasticsearch/v8 v8.19.1 github.com/gin-gonic/gin v1.9.1 + github.com/glebarez/sqlite v1.11.0 + github.com/go-sql-driver/mysql v1.7.0 + github.com/goccy/go-json v0.10.2 github.com/google/uuid v1.6.0 github.com/infiniflow/infinity-go-sdk v0.0.0-00010101000000-000000000000 
github.com/iromli/go-itsdangerous v0.0.0-20220223194502-9c8bef8dac6a github.com/json-iterator/go v1.1.12 github.com/minio/minio-go/v7 v7.0.99 + github.com/nats-io/nats.go v1.52.0 github.com/peterh/liner v1.2.2 github.com/redis/go-redis/v9 v9.18.0 - github.com/shirou/gopsutil/v3 v3.24.5 github.com/siongui/gojianfan v0.0.0-20210926212422-2f175ac615de github.com/spf13/viper v1.18.2 go.uber.org/zap v1.27.1 - golang.org/x/crypto v0.47.0 - golang.org/x/net v0.49.0 + golang.org/x/crypto v0.49.0 + golang.org/x/net v0.51.0 golang.org/x/term v0.41.0 google.golang.org/genai v1.54.0 google.golang.org/grpc v1.79.3 - google.golang.org/protobuf v1.36.10 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.2 gorm.io/gorm v1.25.7 @@ -38,7 +42,6 @@ require ( cloud.google.com/go/auth v0.9.3 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect github.com/apache/thrift v0.22.0 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 // indirect @@ -51,7 +54,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect @@ -61,16 +63,12 @@ require ( github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/glebarez/go-sqlite v1.21.2 // indirect - github.com/glebarez/sqlite v1.11.0 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-logr/logr v1.4.3 // indirect 
github.com/go-logr/stdr v1.2.2 // indirect - github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.16.0 // indirect - github.com/go-sql-driver/mysql v1.7.0 // indirect - github.com/goccy/go-json v0.10.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.8 // indirect @@ -79,11 +77,10 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - github.com/klauspost/compress v1.18.2 // indirect + github.com/klauspost/compress v1.18.5 // indirect github.com/klauspost/cpuid/v2 v2.2.11 // indirect github.com/klauspost/crc32 v1.3.0 // indirect github.com/leodido/go-urn v1.2.4 // indirect - github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.3 // indirect @@ -92,25 +89,22 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/nats-io/nkeys v0.4.15 // indirect + github.com/nats-io/nuid v1.0.1 // indirect github.com/pelletier/go-toml/v2 v2.1.1 // indirect github.com/philhofer/fwd v1.2.0 // indirect - github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rs/xid v1.6.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect - github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // 
indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tinylib/msgp v1.6.1 // indirect - github.com/tklauser/go-sysconf v0.3.12 // indirect - github.com/tklauser/numcpus v0.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect - github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel v1.41.0 // indirect @@ -122,8 +116,9 @@ require ( golang.org/x/arch v0.6.0 // indirect golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect golang.org/x/sys v0.42.0 // indirect - golang.org/x/text v0.33.0 // indirect + golang.org/x/text v0.35.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect + google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.22.5 // indirect modernc.org/mathutil v1.5.0 // indirect diff --git a/go.sum b/go.sum index dc90061519..7eb3d719dc 100644 --- a/go.sum +++ b/go.sum @@ -98,8 +98,6 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= -github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -134,10 +132,11 @@ github.com/google/go-cmp 
v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -159,8 +158,8 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod 
h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU= @@ -173,8 +172,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= -github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= -github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -194,6 +191,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nats-io/nats.go v1.52.0 h1:n3avV4VBsCgsdwh71TppsTwtv+QdPs7ntSKM8qJLGsc= +github.com/nats-io/nats.go v1.52.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno= +github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4= +github.com/nats-io/nkeys v0.4.15/go.mod h1:CpMchTXC9fxA5zrMo4KpySxNjiDVvr8ANOSZdiNfUrs= +github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI= 
github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/peterh/liner v1.2.2 h1:aJ4AOodmL+JxOZZEL2u9iJf8omNRpqHc/EbrK+3mAXw= @@ -203,8 +206,6 @@ github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= -github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= @@ -219,12 +220,6 @@ github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6ke github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= -github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= -github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= -github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= -github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= -github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= -github.com/shoenig/test 
v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= github.com/siongui/gojianfan v0.0.0-20210926212422-2f175ac615de h1:1/P9CcR8iENN9ybbSRWohRd3rsPp9tEWlTS/7ygvjHE= github.com/siongui/gojianfan v0.0.0-20210926212422-2f175ac615de/go.mod h1:TRwEEJlrSIv+jc66k48huOZ2aKVBPL8V29ZcsjUIH70= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= @@ -253,16 +248,10 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tinylib/msgp v1.6.1 h1:ESRv8eL3u+DNHUoSAAQRE50Hm162zqAnBoGv9PzScPY= github.com/tinylib/msgp v1.6.1/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= -github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= -github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= -github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= -github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= -github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= @@ -294,8 +283,8 @@ golang.org/x/arch v0.6.0 h1:S0JTfE48HbRj80+4tbvZDYsJ3tGv6BUU3XxyZ7CirAc= golang.org/x/arch v0.6.0/go.mod 
h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20231226003508-02704c960a9b h1:kLiC65FbiHWFAOu+lxwNPujcsl8VYyTYYEZnsOO1WK4= golang.org/x/exp v0.0.0-20231226003508-02704c960a9b/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= @@ -308,32 +297,28 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod 
h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -380,8 +365,6 @@ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs= gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8= gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= -gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= -gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 1ff2c7ff6c..fc1aa6847e 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -17,12 +17,14 @@ package admin import ( + "encoding/json" "errors" "fmt" "net/http" "ragflow/internal/cache" "ragflow/internal/common" "ragflow/internal/dao" + "ragflow/internal/engine" "ragflow/internal/server" "ragflow/internal/service" "ragflow/internal/utility" @@ -203,15 +205,6 @@ func (h *Handler) AuthCheck(c *gin.Context) { successNoData(c, "Admin is authorized") } -// ListTasks handle list tasks -func (h *Handler) ListTasks(c *gin.Context) { - tasks, err := h.service.ListTasks() - if err != nil { - errorResponse(c, 
err.Error(), 500) - } - success(c, tasks, "Get all tasks") -} - // ListUsers handle list users func (h *Handler) ListUsers(c *gin.Context) { users, err := h.service.ListUsers() @@ -261,7 +254,7 @@ func (h *Handler) GetUser(c *gin.Context) { userDetails, err := h.service.GetUserDetails(username) if err != nil { - if errors.Is(err, ErrUserNotFound) { + if errors.Is(err, common.ErrUserNotFound) { errorResponse(c, "User not found", 404) return } @@ -1256,57 +1249,206 @@ func (h *Handler) SetLogLevel(c *gin.Context) { success(c, gin.H{"level": req.Level}, "Log level updated successfully") } -type StartIngestionTaskRequest struct { - FileURI string `json:"uri" binding:"required"` - From string `json:"from" binding:"required"` +func (h *Handler) ListMessagesFromQueue(c *gin.Context) { + + msgQueueEngine := engine.GetMessageQueueEngine() + messages, err := msgQueueEngine.ListMessages("ingestion", false) + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + var result []map[string]string + for _, message := range messages { + var taskMessage common.TaskMessage + err = json.Unmarshal([]byte(message["message"]), &taskMessage) + if err != nil { + return + } + result = append(result, map[string]string{ + "subject": message["subject"], + "id": taskMessage.TaskID, + "type": taskMessage.TaskType, + }) + } + + success(c, result, "List messages from queue successfully") } -func (h *Handler) StartIngestionTask(c *gin.Context) { - var req StartIngestionTaskRequest +type PublishMessageToQueueRequest struct { + Message string `json:"message" binding:"required"` +} + +func (h *Handler) PublishMessageToQueue(c *gin.Context) { + var req PublishMessageToQueueRequest if err := c.ShouldBindJSON(&req); err != nil { - errorResponse(c, "file uri and from is required", 400) + errorResponse(c, "message is required", 400) return } - taskID := common.GenerateUUID() - ingestionManager.SubmitTask(&common.TaskAssignment{ - TaskId: taskID, - TaskType: "start_ingestion_task", - Config: 
req.FileURI, - ComeFrom: req.From, - }) + taskMessage := common.TaskMessage{ + TaskID: req.Message, + TaskType: common.TaskTypeIngestionTest, + } - success(c, gin.H{"task_id": taskID}, "Send task for ingestion successfully") + // convert task + taskMessageStr, err := json.Marshal(taskMessage) + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + + msgQueueEngine := engine.GetMessageQueueEngine() + err = msgQueueEngine.PublishTask("tasks.RAGFLOW", taskMessageStr) + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + + success(c, nil, "Publish message successfully") +} + +type PullMessageFromQueueRequest struct { + MessageCount int `json:"message_count" binding:"required"` + AckPolicy string `json:"ack_policy" binding:"required"` +} + +func (h *Handler) PullMessageFromQueue(c *gin.Context) { + var req PullMessageFromQueueRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, fmt.Sprintf("message count and ack_policy are required, error: %s", err.Error()), 400) + return + } + + msgQueueEngine := engine.GetMessageQueueEngine() + err := msgQueueEngine.InitConsumer("tasks.RAGFLOW") + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + messages, err := msgQueueEngine.GetMessages(req.MessageCount) + var result []map[string]string + if req.AckPolicy == "ACK" { + for _, message := range messages { + taskMessage := message.GetMessage() + resultMessage := map[string]string{ + "id": taskMessage.TaskID, + "type": taskMessage.TaskType, + } + err = message.Ack() + if err == nil { + resultMessage["ack"] = "true" + } else { + resultMessage["ack"] = "false" + } + result = append(result, resultMessage) + } + } else { + for _, message := range messages { + taskMessage := message.GetMessage() + resultMessage := map[string]string{ + "id": taskMessage.TaskID, + "type": taskMessage.TaskType, + } + if err == nil { + resultMessage["nack"] = "true" + } else { + resultMessage["nack"] = "false" + } + result = append(result, 
resultMessage) + } + } + + success(c, result, "Pull messages from queue successfully") +} + +func (h *Handler) ShowMessageQueue(c *gin.Context) { + + msgQueueEngine := engine.GetMessageQueueEngine() + result, err := msgQueueEngine.ShowMessageQueue() + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + + success(c, result, "show message queue successfully") +} + +type RemoveIngestionTaskRequest struct { + Tasks []string `json:"tasks" binding:"required"` +} + +func (h *Handler) RemoveIngestionTasks(c *gin.Context) { + var req RemoveIngestionTaskRequest + if err := c.ShouldBindJSON(&req); err != nil { + errorResponse(c, "task id is required", 400) + return + } + + tasks, err := h.service.RemoveIngestionTasks(req.Tasks) + if err != nil { + errorResponse(c, err.Error(), 400) + return + } + + success(c, tasks, "Remove tasks successfully") } type StopIngestionTaskRequest struct { - TaskID string `json:"task_id" binding:"required"` - From string `json:"from" binding:"required"` + Tasks []string `json:"tasks" binding:"required"` } -func (h *Handler) StopIngestionTask(c *gin.Context) { +func (h *Handler) StopIngestionTasks(c *gin.Context) { var req StopIngestionTaskRequest if err := c.ShouldBindJSON(&req); err != nil { errorResponse(c, "task id and from is required", 400) return } - ingestionManager.SubmitTask(&common.TaskAssignment{ - TaskId: req.TaskID, - TaskType: "cancel_ingestion_task", - ComeFrom: req.From, - }) + tasks, err := h.service.StopIngestionTasks(req.Tasks) + if err != nil { + errorResponse(c, err.Error(), 400) + return + } - success(c, gin.H{"task_id": req.TaskID}, "Cancel task successfully") + var result []map[string]string + for _, task := range tasks { + result = append(result, map[string]string{ + "task_id": task.ID, + "status": task.Status, + }) + } + + success(c, result, "Stop tasks successfully") } -func (h *Handler) ListIngestors(c *gin.Context) { - ingestionMgr := GetIngestionManager() - ingestors, err := 
ingestionMgr.ListIngestors() +// ListIngestionTasks +func (h *Handler) ListIngestionTasks(c *gin.Context) { + tasks, err := h.service.ListIngestionTasks() if err != nil { errorResponse(c, err.Error(), 500) } - success(c, ingestors, "Get all tasks") + success(c, tasks, "Get all tasks") +} + +func (h *Handler) ListIngestors(c *gin.Context) { + serverList := GlobalServerStore.ListInfos() + var ingestorResults []map[string]string + now := time.Now() + for _, ingestorServer := range serverList { + if ingestorServer.ServerType == common.ServerTypeIngestion { + ingestorResult := map[string]string{} + ingestorResult["name"] = ingestorServer.ServerName + ingestorResult["host"] = ingestorServer.Host + ingestorResult["status"] = ingestorServer.Version + if now.Sub(ingestorServer.Timestamp) < 30*time.Second { + ingestorResult["status"] = "alive" + } else { + ingestorResult["status"] = "timeout" + } + ingestorResults = append(ingestorResults, ingestorResult) + } + } + success(c, ingestorResults, "Get all tasks") } type ShutdownIngestorRequest struct { @@ -1321,11 +1463,11 @@ func (h *Handler) ShutdownIngestor(c *gin.Context) { } taskID := common.GenerateUUID() - ingestionManager.SubmitTask(&common.TaskAssignment{ - TaskId: taskID, - TaskType: "shutdown_ingestor", - AssignedTo: req.IngestorID, - }) + //ingestionManager.SubmitTask(&common.TaskAssignment{ + // TaskId: taskID, + // TaskType: "SHUTDOWN", + // AssignedTo: req.IngestorID, + //}) success(c, gin.H{"task_id": taskID, "ingestor_id": req.IngestorID}, "Shutdown ingestor") } @@ -1364,14 +1506,3 @@ func (h *Handler) Reports(c *gin.Context) { responseWithCode(c, message, http.StatusOK, errCode) } - -// ListIngestionTasks -func (h *Handler) ListIngestionTasks(c *gin.Context) { - tasks, err := h.service.ListIngestionTasks() - if err != nil { - errorResponse(c, err.Error(), 400) - return - } - - success(c, tasks, "") -} diff --git a/internal/admin/ingestion_manager.go b/internal/admin/ingestion_manager.go deleted file mode 100644 
index ceffdf843e..0000000000 --- a/internal/admin/ingestion_manager.go +++ /dev/null @@ -1,587 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package admin - -import ( - "context" - "fmt" - "net" - "sync" - "time" - - "ragflow/internal/common" - - "google.golang.org/grpc" - "google.golang.org/grpc/peer" -) - -const heartbeatTimeout = 30 * time.Second - -type IngestionManager struct { - common.UnimplementedIngestionManagerServer - mu sync.RWMutex - - // Registered ingestion servers - ingestionServers map[string]*IngestorState // ingestor id -> ingestor id - - taskStates map[string]*TaskState // task_id -> task state - - // In-memory task queue - taskQueue chan *pendingTask - - // Notifies that an ingestor slot may have freed up - slotFreed chan struct{} - - grpcServer *grpc.Server // gRPC server instance for graceful shutdown via Stop() - - ctx context.Context - cancel context.CancelFunc -} - -type TaskState struct { - taskID string // same as task_id in database - status string // created, assigned, processing, completed, failed - comeFrom string // api server id - assignTo string // ingestor id - lastUpdate time.Time - startTime *time.Time - estimatedRemainingTime time.Duration // estimated cost in seconds to complete the task - errorMessage string -} - -type IngestorState struct { - ID string - Info *common.RegisterInfo - LastHeartbeat time.Time - Stream 
common.IngestionManager_ActionServer - Status string // active, draining - Address string - ProcessID int64 - cpuUsage float64 - vmsUsage float64 - rssUsage float64 -} - -type pendingTask struct { - Task *common.TaskAssignment - CreatedAt time.Time -} - -var ingestionManager *IngestionManager - -func GetIngestionManager() *IngestionManager { - return ingestionManager -} - -func NewAdminServer() *IngestionManager { - ctx, cancel := context.WithCancel(context.Background()) - ingestionManager = &IngestionManager{ - taskStates: make(map[string]*TaskState), - ingestionServers: make(map[string]*IngestorState), - taskQueue: make(chan *pendingTask, 10000), - slotFreed: make(chan struct{}, 100), - ctx: ctx, - cancel: cancel, - } - go ingestionManager.dispatchLoop() - //go ingestionManager.heartbeatCheckLoop() no need to check heartbeat timeout - return ingestionManager -} - -// Action handles the bidirectional streaming RPC from ingestion servers -func (s *IngestionManager) Action(stream common.IngestionManager_ActionServer) error { - var ingestionServerID string - var state *IngestorState - - common.Info("New ingestion_server connection") - - // Start receive goroutine - receiveErrorCH := make(chan error, 1) - go func() { - for { - msg, err := stream.Recv() - if err != nil { - receiveErrorCH <- err - return - } - s.handleMessage(stream, msg, &ingestionServerID, &state) - } - }() - - // Start send goroutine: send tasks immediately when assigned to this ingestion_server - sendDone := make(chan struct{}) - go func() { - defer close(sendDone) - for { - select { - case <-stream.Context().Done(): - return - case <-s.ctx.Done(): - return - } - } - }() - - select { - case err := <-receiveErrorCH: - // Connection dropped, clean up - s.cleanupIngestionServer(ingestionServerID) - return err - case <-sendDone: - // Stream context canceled (client disconnect or server shutdown) - s.cleanupIngestionServer(ingestionServerID) - return nil - } -} - -func (s *IngestionManager) 
handleMessage( - stream common.IngestionManager_ActionServer, - msg *common.IngestionMessage, - ingestionServerID *string, - state **IngestorState, -) { - switch msg.MessageType { - case "REGISTER": - s.handleRegister(stream, msg, ingestionServerID, state) - - case "HEARTBEAT": - s.handleHeartbeat(msg, *ingestionServerID, *state) - - case "TASK_RESULT": - s.handleTaskResult(msg, *ingestionServerID, *state) - - case "TASK_PROGRESS": - s.handleTaskProgress(msg, *ingestionServerID, *state) - - default: - common.Info(fmt.Sprintf("Unknown message type: %s", msg.MessageType)) - err := stream.Send(&common.AdminMessage{ - MessageType: "ERROR", - ErrorMessage: "unknown message type", - }) - if err != nil { - common.Error("Fail to send unknown message", err) - return - } - } -} - -func (s *IngestionManager) handleRegister( - stream common.IngestionManager_ActionServer, - msg *common.IngestionMessage, - ingestionServerID *string, - state **IngestorState, -) { - if msg.RegisterInfo == nil { - err := stream.Send(&common.AdminMessage{ - MessageType: "ERROR", - ErrorMessage: "missing register info", - }) - if err != nil { - common.Error("Fail to send missing register info", err) - return - } - return - } - - peerHost, ok := peer.FromContext(stream.Context()) - if !ok { - err := stream.Send(&common.AdminMessage{ - MessageType: "ERROR", - ErrorMessage: "peer not found in context", - }) - if err != nil { - common.Error("Fail to send 'peer not found' message", err) - return - } - return - } - clientAddr := peerHost.Addr.String() - - *ingestionServerID = msg.IngestorId - *state = &IngestorState{ - ID: msg.IngestorId, - Info: msg.RegisterInfo, - LastHeartbeat: time.Now().Truncate(time.Second), - Stream: stream, - Status: "active", - Address: clientAddr, - } - - s.mu.Lock() - s.ingestionServers[*ingestionServerID] = *state - s.mu.Unlock() - - err := stream.Send(&common.AdminMessage{ - MessageType: "ACK", - AckInfo: &common.AckInfo{ - TaskId: "", - Success: true, - Message: "registered 
successfully", - }, - }) - if err != nil { - common.Error("Fail to send ACK message", err) - return - } - - common.Info(fmt.Sprintf("Ingestor %s registered, max_concurrency=%d, supported_types=%v", - *ingestionServerID, msg.RegisterInfo.MaxConcurrency, msg.RegisterInfo.SupportedDocTypes)) -} - -func (s *IngestionManager) handleHeartbeat(msg *common.IngestionMessage, ingestorID string, state *IngestorState) { - if state == nil { - return - } - - state.LastHeartbeat = time.Now().Truncate(time.Second) - - if msg.HeartbeatInfo != nil { - - lastUpdateTime := time.Now().Truncate(time.Second) - s.mu.Lock() - ingestorState := s.ingestionServers[msg.IngestorId] - ingestorState.LastHeartbeat = lastUpdateTime - if ingestorState.Status == "timeout" { - ingestorState.Status = "active" - common.Info(fmt.Sprintf("Ingestor %s recovered from timeout, status set to active", msg.IngestorId)) - } - ingestorState.ProcessID = msg.HeartbeatInfo.ProcessId - ingestorState.cpuUsage = float64(msg.HeartbeatInfo.CpuUsage) - ingestorState.vmsUsage = float64(msg.HeartbeatInfo.VmsUsage) / 1024 / 1024 // in MB - ingestorState.rssUsage = float64(msg.HeartbeatInfo.RssUsage) / 1024 / 1024 // in MB - - // Delete expired terminal tasks from currentTasks - for _, taskID := range msg.HeartbeatInfo.DeleteTaskIds { - delete(s.taskStates, taskID) - } - - for _, ingestorTaskState := range msg.HeartbeatInfo.TaskStates { - localTaskState := s.taskStates[ingestorTaskState.TaskId] - if localTaskState == nil { - startTime := time.Unix(0, ingestorTaskState.StartTime) - localTaskState = &TaskState{ - taskID: ingestorTaskState.TaskId, - comeFrom: ingestorTaskState.ComeFrom, - startTime: &startTime, - } - } - localTaskState.estimatedRemainingTime = time.Duration(ingestorTaskState.EstimatedRemainingTimeSeconds) - localTaskState.lastUpdate = lastUpdateTime - localTaskState.status = ingestorTaskState.Status - localTaskState.errorMessage = ingestorTaskState.ErrorMessage - localTaskState.assignTo = msg.IngestorId - } - 
s.mu.Unlock() - - common.Debug(fmt.Sprintf("Heartbeat from %s", ingestorID)) - } -} - -func (s *IngestionManager) handleTaskResult(msg *common.IngestionMessage, ingestorID string, state *IngestorState) { - if msg.TaskResult == nil { - return - } - - result := msg.TaskResult - common.Info(fmt.Sprintf("Task result from %s: task=%s, status=%s, message=%s", ingestorID, result.TaskId, result.Status, result.ErrorMessage)) - - // Signal that a slot may have freed up for pending tasks - select { - case s.slotFreed <- struct{}{}: - default: - } -} - -func (s *IngestionManager) handleTaskProgress(msg *common.IngestionMessage, ingestorID string, state *IngestorState) { - if msg.TaskProgress == nil { - return - } - - progress := msg.TaskProgress - common.Info(fmt.Sprintf("Task progress from %s: task=%s, progress=%d%%, detail=%s", - ingestorID, progress.TaskId, progress.Progress, progress.Info)) -} - -// SubmitTask is for API Server to call (non-gRPC, for testing only) -func (s *IngestionManager) SubmitTask(task *common.TaskAssignment) { - s.taskQueue <- &pendingTask{ - Task: task, - CreatedAt: time.Now().Truncate(time.Second), - } - common.Info(fmt.Sprintf("Task %s submitted to queue", task.TaskId)) - - // Wake up dispatchLoop if it's blocked waiting for a slot - select { - case s.slotFreed <- struct{}{}: - default: - } -} - -// dispatchLoop pulls tasks from the queue and assigns them to available ingestors. -// Runs in a background goroutine. -func (s *IngestionManager) dispatchLoop() { - for { - select { - case <-s.ctx.Done(): - return - case pending := <-s.taskQueue: - go s.tryAssign(pending.Task) - } - } -} - -// heartbeatCheckLoop periodically checks all registered ingestors for heartbeat timeout. -// If an ingestor's LastHeartbeat is older than heartbeatTimeout, its status is set to "timeout". 
-func (s *IngestionManager) heartbeatCheckLoop() { - ticker := time.NewTicker(heartbeatTimeout / 3) - defer ticker.Stop() - - for { - select { - case <-s.ctx.Done(): - return - case <-ticker.C: - s.checkHeartbeats() - } - } -} - -func (s *IngestionManager) checkHeartbeats() { - s.mu.Lock() - defer s.mu.Unlock() - - now := time.Now().Truncate(time.Second) - for id, state := range s.ingestionServers { - if now.Sub(state.LastHeartbeat) > heartbeatTimeout { - if state.Status != "timeout" { - state.Status = "timeout" - common.Info(fmt.Sprintf("Ingestor %s heartbeat timeout, marked as timeout", id)) - } - } - } -} - -func (s *IngestionManager) SelectIngestorForTask(task *common.TaskAssignment) *IngestorState { - s.mu.Lock() - defer s.mu.Unlock() - - switch task.TaskType { - case "start_ingestion_task": - for _, ingestor := range s.ingestionServers { - if ingestor.Status == "active" { - s.taskStates[task.TaskId] = &TaskState{ - taskID: task.TaskId, - status: "DISPATCHED", - comeFrom: "CLI", - startTime: nil, - lastUpdate: time.Now().Truncate(time.Second), - assignTo: ingestor.ID, - } - return ingestor - } - } - case "cancel_ingestion_task": - taskState := s.taskStates[task.TaskId] - if taskState != nil { - switch taskState.status { - case "COMPLETED": - return nil - case "DISPATCHED": - { - taskState.status = "CANCELING" - return s.ingestionServers[taskState.assignTo] - } - default: - return s.ingestionServers[taskState.assignTo] - } - } - - case "shutdown_ingestor": - return s.ingestionServers[task.AssignedTo] - } - - return nil -} - -// tryAssign repeatedly tries to find an available ingestor and assign the task. -// Blocks until either the task is assigned or the context is canceled. 
-func (s *IngestionManager) tryAssign(task *common.TaskAssignment) { - for { - - target := s.SelectIngestorForTask(task) - if target != nil { - task.AssignedTo = target.ID - s.assignToIngestor(task, target) - return - } - - if task.TaskType == "start_ingestion_task" { - // Receives a start ingestion task, save and change the states - s.mu.Lock() - s.taskStates[task.TaskId] = &TaskState{ - taskID: task.TaskId, - status: "pending", - comeFrom: task.ComeFrom, - lastUpdate: time.Now().Truncate(time.Second), - startTime: nil, - } - s.mu.Unlock() - } else { - // shutdown ingestor or cancel task - common.Info("Task is completed, canceled, or ingestor is shutdown") - return - } - - // No ingestor available, wait for a slot to free up - select { - case <-s.ctx.Done(): - return - case <-s.slotFreed: - // A slot might be free, retry - case <-time.After(2 * time.Second): - // Periodic retry as fallback - } - } -} - -func (s *IngestionManager) assignToIngestor(task *common.TaskAssignment, state *IngestorState) { - err := state.Stream.Send(&common.AdminMessage{ - MessageType: "TASK_ASSIGNMENT", - TaskAssignment: task, - }) - if err != nil { - common.Info(fmt.Sprintf("Failed to assign task %s to ingestor %s: %v", task.TaskId, state.ID, err)) - // Re-queue the task - s.taskQueue <- &pendingTask{Task: task, CreatedAt: time.Now().Truncate(time.Second)} - return - } - common.Info(fmt.Sprintf("Assigned task %s to ingestion_server %s", task.TaskId, state.ID)) -} - -func (s *IngestionManager) cleanupIngestionServer(ingestorID string) { - s.mu.Lock() - defer s.mu.Unlock() - - if ingestorID == "" { - // Client disconnected before REGISTER completed — nothing to clean up - common.Info("Unregistered ingestion server disconnected") - return - } - - if _, exists := s.ingestionServers[ingestorID]; exists { - delete(s.ingestionServers, ingestorID) - common.Info(fmt.Sprintf("Ingestor %s cleaned up", ingestorID)) - - // Clean the tasks handled by this ingestor - var tasksToDelete []string - for 
_, taskState := range s.taskStates { - if taskState.assignTo == ingestorID { - tasksToDelete = append(tasksToDelete, taskState.taskID) - } - } - for _, taskID := range tasksToDelete { - delete(s.taskStates, taskID) - } - } -} - -func (s *IngestionManager) ListIngestors() ([]map[string]interface{}, error) { - s.mu.Lock() - defer s.mu.Unlock() - - var result []map[string]interface{} - for ingestorID, state := range s.ingestionServers { - - var taskCount int64 - for _, task := range s.taskStates { - if task.assignTo == ingestorID { - taskCount++ - } - } - - result = append(result, map[string]interface{}{ - "id": ingestorID, - "name": state.Info.Name, - "address": state.Address, - "last_heartbeat": state.LastHeartbeat, - "task_count": taskCount, - "status": state.Status, - "cpu_usage": state.cpuUsage, - "rss_usage": state.rssUsage, - "vms_usage": state.vmsUsage, - "process_id": state.ProcessID, - }) - } - return result, nil -} - -func (s *IngestionManager) ListIngestionTasks() ([]map[string]interface{}, error) { - - var result []map[string]interface{} - - s.mu.Lock() - defer s.mu.Unlock() - for index, taskState := range s.taskStates { - common.Info(fmt.Sprintf("Task %s: %s", index, taskState.taskID)) - result = append(result, map[string]interface{}{ - "id": taskState.taskID, - "status": taskState.status, - "from": taskState.comeFrom, - "assign_to": taskState.assignTo, - "last_update": taskState.lastUpdate, - "start_time": taskState.startTime, - "ETA": taskState.estimatedRemainingTime, - "error": taskState.errorMessage, - }) - } - - return result, nil -} - -// Start starts the admin service -func (s *IngestionManager) Start(port string) error { - lis, err := net.Listen("tcp", port) - if err != nil { - return err - } - - s.grpcServer = grpc.NewServer() - common.RegisterIngestionManagerServer(s.grpcServer, s) - - return s.grpcServer.Serve(lis) -} - -// Stop gracefully shuts down the admin service -func (s *IngestionManager) Stop() { - common.Info("Stopping RAGFlow 
ingestion manager...") - - // Notify all goroutines to exit - s.cancel() - - // Gracefully stop gRPC server (stop accepting new connections, wait for in-flight requests) - if s.grpcServer != nil { - s.grpcServer.GracefulStop() - } - - // Close the task queue - s.mu.Lock() - close(s.taskQueue) - s.mu.Unlock() - - common.Info("RAGFlow ingestion manager stopped") -} diff --git a/internal/admin/router.go b/internal/admin/router.go index 500b03ad86..a42f764d5a 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -46,6 +46,10 @@ func (r *Router) Setup(engine *gin.Engine) { admin.POST("/reports", r.handler.Reports) + //admin.POST("/ingestion/tasks", r.handler.StartIngestionTask) + //admin.DELETE("/ingestion", r.handler.CancelIngestionTask) // cancel ingestion + //admin.GET("/ingestion/tasks", r.handler.ListIngestionTasks) + // Protected routes protected := admin.Group("") protected.Use(r.handler.AuthMiddleware()) @@ -55,9 +59,6 @@ func (r *Router) Setup(engine *gin.Engine) { // Auth protected.GET("/auth", r.handler.AuthCheck) - // Tasks - protected.GET("/tasks", r.handler.ListTasks) - // User management protected.GET("/users", r.handler.ListUsers) protected.POST("/users", r.handler.CreateUser) @@ -137,12 +138,19 @@ func (r *Router) Setup(engine *gin.Engine) { provider.GET("/:provider_name/models/:model_name", r.handler.ShowModel) } + queue := protected.Group("/queue") + { + queue.GET("/", r.handler.ShowMessageQueue) + queue.POST("/messages", r.handler.PublishMessageToQueue) + queue.GET("/messages", r.handler.ListMessagesFromQueue) + queue.PUT("/messages", r.handler.PullMessageFromQueue) + } + protected.GET("/ingestors", r.handler.ListIngestors) protected.DELETE("/ingestors", r.handler.ShutdownIngestor) - protected.POST("/ingestion", r.handler.StartIngestionTask) // start ingestion - protected.DELETE("/ingestion", r.handler.StopIngestionTask) // stop ingestion + protected.DELETE("/ingestion/tasks", r.handler.RemoveIngestionTasks) + 
protected.PUT("/ingestion/tasks", r.handler.StopIngestionTasks) protected.GET("/ingestion/tasks", r.handler.ListIngestionTasks) - } } diff --git a/internal/admin/service.go b/internal/admin/service.go index d496ffefc2..d43fcfb31c 100644 --- a/internal/admin/service.go +++ b/internal/admin/service.go @@ -29,6 +29,7 @@ import ( "ragflow/internal/cache" "ragflow/internal/common" "ragflow/internal/dao" + "ragflow/internal/engine" "ragflow/internal/engine/elasticsearch" "ragflow/internal/entity" "ragflow/internal/server" @@ -43,45 +44,49 @@ import ( // Service admin service layer type Service struct { - userDAO *dao.UserDAO - licenseDAO *dao.LicenseDAO - timeRecordDAO *dao.TimeRecordDAO - systemSettingsDAO *dao.SystemSettingsDAO - tenantDAO *dao.TenantDAO - userTenantDAO *dao.UserTenantDAO - tenantLLMDAO *dao.TenantLLMDAO - fileDAO *dao.FileDAO - documentDAO *dao.DocumentDAO - taskDAO *dao.TaskDAO - kbDAO *dao.KnowledgebaseDAO - canvasDAO *dao.UserCanvasDAO - chatDAO *dao.ChatDAO - chatSessionDAO *dao.ChatSessionDAO - apiTokenDAO *dao.APITokenDAO - api4ConvDAO *dao.API4ConversationDAO - llmDAO *dao.LLMDAO + userDAO *dao.UserDAO + licenseDAO *dao.LicenseDAO + timeRecordDAO *dao.TimeRecordDAO + systemSettingsDAO *dao.SystemSettingsDAO + tenantDAO *dao.TenantDAO + userTenantDAO *dao.UserTenantDAO + tenantLLMDAO *dao.TenantLLMDAO + fileDAO *dao.FileDAO + documentDAO *dao.DocumentDAO + taskDAO *dao.TaskDAO + kbDAO *dao.KnowledgebaseDAO + canvasDAO *dao.UserCanvasDAO + chatDAO *dao.ChatDAO + chatSessionDAO *dao.ChatSessionDAO + apiTokenDAO *dao.APITokenDAO + api4ConvDAO *dao.API4ConversationDAO + llmDAO *dao.LLMDAO + ingestionTaskDAO *dao.IngestionTaskDAO + ingestionTaskLogDao *dao.IngestionTaskLogDAO } // NewService create admin service func NewService() *Service { return &Service{ - userDAO: dao.NewUserDAO(), - licenseDAO: dao.NewLicenseDAO(), - timeRecordDAO: dao.NewTimeRecordDAO(), - systemSettingsDAO: dao.NewSystemSettingsDAO(), - tenantDAO: dao.NewTenantDAO(), - 
userTenantDAO: dao.NewUserTenantDAO(), - tenantLLMDAO: dao.NewTenantLLMDAO(), - fileDAO: dao.NewFileDAO(), - documentDAO: dao.NewDocumentDAO(), - taskDAO: dao.NewTaskDAO(), - kbDAO: dao.NewKnowledgebaseDAO(), - canvasDAO: dao.NewUserCanvasDAO(), - chatDAO: dao.NewChatDAO(), - chatSessionDAO: dao.NewChatSessionDAO(), - apiTokenDAO: dao.NewAPITokenDAO(), - api4ConvDAO: dao.NewAPI4ConversationDAO(), - llmDAO: dao.NewLLMDAO(), + userDAO: dao.NewUserDAO(), + licenseDAO: dao.NewLicenseDAO(), + timeRecordDAO: dao.NewTimeRecordDAO(), + systemSettingsDAO: dao.NewSystemSettingsDAO(), + tenantDAO: dao.NewTenantDAO(), + userTenantDAO: dao.NewUserTenantDAO(), + tenantLLMDAO: dao.NewTenantLLMDAO(), + fileDAO: dao.NewFileDAO(), + documentDAO: dao.NewDocumentDAO(), + taskDAO: dao.NewTaskDAO(), + kbDAO: dao.NewKnowledgebaseDAO(), + canvasDAO: dao.NewUserCanvasDAO(), + chatDAO: dao.NewChatDAO(), + chatSessionDAO: dao.NewChatSessionDAO(), + apiTokenDAO: dao.NewAPITokenDAO(), + api4ConvDAO: dao.NewAPI4ConversationDAO(), + llmDAO: dao.NewLLMDAO(), + ingestionTaskDAO: dao.NewIngestionTaskDAO(), + ingestionTaskLogDao: dao.NewIngestionTaskLogDAO(), } } @@ -96,51 +101,99 @@ func (s *Service) Logout(user interface{}) error { } // ListTasks -func (s *Service) ListTasks() ([]map[string]interface{}, error) { +func (s *Service) ListIngestionTasks() ([]map[string]interface{}, error) { - //tasks, err := s.taskDAO.GetAllTasks() - //if err != nil { - // return nil, err - //} - // - //var result []map[string]interface{} - //for _, task := range tasks { - // // task.ChunkIDs is a string, delimiter is space, count the word count - // ChunkCount := strings.Count(*task.ChunkIDs, " ") - // result = append(result, map[string]interface{}{ - // "id": task.ID, - // "task_type": task.TaskType, - // "document_id": task.DocID, - // "chunk_count": ChunkCount, - // "from_page": task.FromPage, - // "to_page": task.ToPage, - // "priority": task.Priority, - // "duration": task.ProcessDuration, - // "progress": 
task.Progress, - // //"message": *task.ProgressMsg, - // "retry_count": task.RetryCount, - // "digest": task.Digest, - // }) - //} - - ingestionMgr := GetIngestionManager() - ingestionTasks, err := ingestionMgr.ListIngestionTasks() + ingestionTasks, err := s.ingestionTaskDAO.GetAllTasks(0, 0) if err != nil { - return nil, fmt.Errorf("fail to list ingestion tasks") + return nil, err } - return ingestionTasks, nil + showTasks := []map[string]interface{}{} + for _, task := range ingestionTasks { + var user *entity.User + user, err = s.userDAO.GetByTenantID(task.UserID) + if err != nil { + return nil, err + } + //var document *entity.Document + //document, err = s.documentDAO.GetByID(task.DocumentID) + //if err != nil { + // return nil, err + //} + + var showTask map[string]interface{} + var latestLog *entity.IngestionTaskLog + latestLog, err = s.ingestionTaskLogDao.LatestLogByTaskID(task.ID) + showTask = map[string]interface{}{ + "id": task.ID, + "user_id": task.UserID, + "user": user.Email, + "document_id": task.DocumentID, + "status": task.Status, + } + if err == nil { + showTask = map[string]interface{}{ + "id": task.ID, + "user_id": task.UserID, + "user": user.Email, + "document_id": task.DocumentID, + "status": task.Status, + "step": int(latestLog.Checkpoint["current_step"].(float64)), + } + } + + showTasks = append(showTasks, showTask) + } + return showTasks, nil +} + +func (s *Service) RemoveIngestionTasks(tasks []string) ([]map[string]string, error) { + var deletedTasks []map[string]string + for _, taskID := range tasks { + taskRecord := map[string]string{ + "task_id": taskID, + } + _, err := s.ingestionTaskDAO.RemoveByAPIServerOrAdminServer(taskID, nil) + if err != nil { + taskRecord["remove"] = fmt.Sprintf("fail: %s", err.Error()) + } else { + taskRecord["remove"] = "success" + } + deletedTasks = append(deletedTasks, taskRecord) + } + return deletedTasks, nil +} + +func (s *Service) StopIngestionTasks(tasks []string) ([]*entity.IngestionTask, error) { + var 
taskResponses []*entity.IngestionTask + for _, taskID := range tasks { + task, err := s.ingestionTaskDAO.SetStoppingByAPIServer(taskID) + if err != nil { + return nil, err + } + + if task.Status == common.STOPPING { + msgQueueEngine := engine.GetMessageQueueEngine() + err = msgQueueEngine.PublishTask("tasks.RAGFLOW", []byte(task.ID)) + if err != nil { + return nil, err + } + } + + taskResponses = append(taskResponses, task) + } + return taskResponses, nil } // GetUserByToken get user by access token func (s *Service) GetUserByToken(token string) (*entity.User, error) { user, err := s.userDAO.GetByAccessToken(token) if err != nil { - return nil, ErrInvalidToken + return nil, common.ErrInvalidToken } if user.IsSuperuser == nil || !*user.IsSuperuser { - return nil, ErrNotAdmin + return nil, common.ErrNotAdmin } if user.IsActive != "1" { @@ -477,7 +530,7 @@ func (s *Service) GetUserDetails(username string) (map[string]interface{}, error var user entity.User err := dao.DB.Where("email = ?", username).First(&user).Error if err != nil { - return nil, ErrUserNotFound + return nil, common.ErrUserNotFound } return map[string]interface{}{ @@ -1050,11 +1103,11 @@ func (s *Service) ListServices() ([]map[string]interface{}, error) { } result = append(result, configDict) } - } id := len(result) serverList := GlobalServerStore.ListInfos() + now := time.Now() for _, serverStatus := range serverList { serverItem := make(map[string]interface{}) serverItem["name"] = serverStatus.ServerName @@ -1063,7 +1116,12 @@ func (s *Service) ListServices() ([]map[string]interface{}, error) { id++ serverItem["host"] = serverStatus.Host serverItem["port"] = serverStatus.Port - serverItem["status"] = "alive" + // the difference between now and serverStatus.Timestamp is less than 5 seconds, then the server is alive + if now.Sub(serverStatus.Timestamp) < 30*time.Second { + serverItem["status"] = "alive" + } else { + serverItem["status"] = "timeout" + } result = append(result, serverItem) } return 
result, nil @@ -1701,11 +1759,6 @@ func (s *Service) HandleHeartbeat(message *common.BaseMessage) (common.ErrorCode return common.CodeLicenseValid, "" } -func (s *Service) ListIngestionTasks() ([]map[string]interface{}, error) { - // TODO: Implement with sandbox manager - return []map[string]interface{}{}, nil -} - // InitDefaultAdmin initialize default admin user // This matches Python's init_default_admin behavior func (s *Service) InitDefaultAdmin() error { diff --git a/internal/admin/state.go b/internal/admin/state.go index 5475bf6467..0bf94ef15f 100644 --- a/internal/admin/state.go +++ b/internal/admin/state.go @@ -17,20 +17,11 @@ package admin import ( - "errors" "ragflow/internal/common" "sync" "time" ) -// Service errors -var ( - ErrInvalidToken = errors.New("invalid token") - ErrNotAdmin = errors.New("user is not admin") - ErrUserInactive = errors.New("user is inactive") - ErrUserNotFound = errors.New("user not found") -) - // API server state // ServerStore is a thread-safe global server status storage @@ -58,6 +49,9 @@ func (s *ServerStore) UpdateServerInfo(serverName string, status *common.BaseMes s.servers[serverName] = status return case common.ServerTypeIngestion: + s.mu.Lock() + defer s.mu.Unlock() + s.servers[serverName] = status return } } diff --git a/internal/cli/admin_command.go b/internal/cli/admin_command.go index 7ebb3a0a7a..00f2827c99 100644 --- a/internal/cli/admin_command.go +++ b/internal/cli/admin_command.go @@ -1307,6 +1307,7 @@ func (c *CLI) ListAdminIngestors(cmd *Command) (ResponseIf, error) { result.Duration = resp.Duration return &result, nil } + func (c *CLI) ListAdminIngestionTasks(cmd *Command) (ResponseIf, error) { if c.Config.CLIMode != AdminMode || c.AdminServerClient.LoginToken == nil { return nil, fmt.Errorf("this command is only allowed in ADMIN mode or already login") @@ -1334,21 +1335,20 @@ func (c *CLI) ListAdminIngestionTasks(cmd *Command) (ResponseIf, error) { return &result, nil } -func (c *CLI) 
AdminStartIngestionCommand(cmd *Command) (ResponseIf, error) { +func (c *CLI) AdminStopIngestionCommand(cmd *Command) (ResponseIf, error) { if c.Config.CLIMode != AdminMode || c.AdminServerClient.LoginToken == nil { return nil, fmt.Errorf("this command is only allowed in ADMIN mode or already login") } - fileURI, ok := cmd.Params["uri"].(string) + tasks, ok := cmd.Params["tasks"].([]string) if !ok { return nil, fmt.Errorf("uri not provided") } payload := map[string]interface{}{ - "uri": fileURI, - "from": "CLI", + "tasks": tasks, } - resp, err := c.AdminServerClient.Request("POST", "/admin/ingestion", "admin", nil, payload) + resp, err := c.AdminServerClient.Request("PUT", "/admin/ingestion/tasks", "admin", nil, payload) if err != nil { return nil, fmt.Errorf("failed to ingest file: %w", err) } @@ -1357,7 +1357,7 @@ func (c *CLI) AdminStartIngestionCommand(cmd *Command) (ResponseIf, error) { return nil, fmt.Errorf("failed to ingest file: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) } - var result CommonDataResponse + var result CommonResponse if err = json.Unmarshal(resp.Body, &result); err != nil { return nil, fmt.Errorf("ingest file failed: invalid JSON (%w)", err) } @@ -1370,21 +1370,20 @@ func (c *CLI) AdminStartIngestionCommand(cmd *Command) (ResponseIf, error) { return &result, nil } -func (c *CLI) AdminStopIngestionCommand(cmd *Command) (ResponseIf, error) { +func (c *CLI) AdminRemoveIngestionCommand(cmd *Command) (ResponseIf, error) { if c.Config.CLIMode != AdminMode || c.AdminServerClient.LoginToken == nil { return nil, fmt.Errorf("this command is only allowed in ADMIN mode or already login") } - taskID, ok := cmd.Params["task_id"].(string) + tasks, ok := cmd.Params["tasks"].([]string) if !ok { return nil, fmt.Errorf("uri not provided") } payload := map[string]interface{}{ - "task_id": taskID, - "from": "CLI", + "tasks": tasks, } - resp, err := c.AdminServerClient.Request("DELETE", "/admin/ingestion", "admin", nil, payload) + resp, err := 
c.AdminServerClient.Request("DELETE", "/admin/ingestion/tasks", "admin", nil, payload) if err != nil { return nil, fmt.Errorf("failed to ingest file: %w", err) } @@ -1393,7 +1392,7 @@ func (c *CLI) AdminStopIngestionCommand(cmd *Command) (ResponseIf, error) { return nil, fmt.Errorf("failed to ingest file: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) } - var result CommonDataResponse + var result CommonResponse if err = json.Unmarshal(resp.Body, &result); err != nil { return nil, fmt.Errorf("ingest file failed: invalid JSON (%w)", err) } @@ -1440,3 +1439,176 @@ func (c *CLI) AdminShutdownIngestor(cmd *Command) (ResponseIf, error) { result.Duration = resp.Duration return &result, nil } + +func (c *CLI) UserListMessageQueueCommand(cmd *Command) (ResponseIf, error) { + if c.Config.CLIMode != AdminMode || c.AdminServerClient.LoginToken == nil { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode or already login") + } + + pending, ok := cmd.Params["pending"].(bool) + if !ok { + return nil, fmt.Errorf("pending not provided") + } + payload := map[string]interface{}{ + "pending": pending, + } + + resp, err := c.AdminServerClient.Request("GET", "/admin/queue/messages", "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to list tasks in message queue: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list tasks in message queue: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list tasks in message queue failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + +func (c *CLI) UserPublishMessageCommand(cmd *Command) (ResponseIf, error) { + if c.Config.CLIMode != AdminMode || c.AdminServerClient.LoginToken == nil { + return nil, fmt.Errorf("this command 
is only allowed in ADMIN mode or already login") + } + + message, ok := cmd.Params["message"].(string) + if !ok { + return nil, fmt.Errorf("message not provided") + } + payload := map[string]interface{}{ + "message": message, + } + + resp, err := c.AdminServerClient.Request("POST", "/admin/queue/messages", "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to publish message: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to publish message: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("publish message failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + +func (c *CLI) UserPullMessageCommand(cmd *Command) (ResponseIf, error) { + if c.Config.CLIMode != AdminMode || c.AdminServerClient.LoginToken == nil { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode or already login") + } + + messageCount, ok := cmd.Params["message_count"].(int) + if !ok { + return nil, fmt.Errorf("message_count not provided") + } + ackPolicy, ok := cmd.Params["ack_policy"].(string) + if !ok { + return nil, fmt.Errorf("ack_policy not provided") + } + + payload := map[string]interface{}{ + "message_count": messageCount, + "ack_policy": ackPolicy, + } + + resp, err := c.AdminServerClient.Request("PUT", "/admin/queue/messages", "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to pull message: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to pull message: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("pull message failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, 
fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + +func (c *CLI) UserShowMessageQueueCommand(cmd *Command) (ResponseIf, error) { + if c.Config.CLIMode != AdminMode || c.AdminServerClient.LoginToken == nil { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode or already login") + } + + resp, err := c.AdminServerClient.Request("GET", "/admin/queue", "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to show message queue: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show message queue: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonDataResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("show message queue failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + +func (c *CLI) AdminRemoveServiceCommand(cmd *Command) (ResponseIf, error) { + if c.Config.CLIMode != AdminMode || c.AdminServerClient.LoginToken == nil { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode or already login") + } + serviceNumber, ok := cmd.Params["service_number"].(int) + if !ok { + return nil, fmt.Errorf("service_number not provided") + } + + payload := map[string]interface{}{ + "service_number": serviceNumber, + } + + resp, err := c.AdminServerClient.Request("DELETE", "/admin/services", "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to remove unavailable service: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to remove unavailable service: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result SimpleResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("remove unavailable service failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + 
return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} diff --git a/internal/cli/admin_parser.go b/internal/cli/admin_parser.go index b4bf964cec..08654a898c 100644 --- a/internal/cli/admin_parser.go +++ b/internal/cli/admin_parser.go @@ -16,7 +16,10 @@ package cli -import "fmt" +import ( + "fmt" + "strings" +) // Command parsers func (p *Parser) parseAdminLoginUser() (*Command, error) { @@ -190,8 +193,6 @@ func (p *Parser) parseAdminListCommand() (*Command, error) { return NewCommand("list_user_chats"), nil case TokenFiles: return p.parseAdminListFiles() - case TokenTasks: - return p.parseAdminListTasks() case TokenIngestors: return p.parseAdminListIngestors() case TokenIngestion: @@ -376,12 +377,6 @@ func (p *Parser) parseAdminListFiles() (*Command, error) { return cmd, nil } -func (p *Parser) parseAdminListTasks() (*Command, error) { - p.nextToken() // consume TASKS - cmd := NewCommand("list_admin_tasks") - return cmd, nil -} - func (p *Parser) parseAdminListIngestors() (*Command, error) { p.nextToken() // consume TASKS cmd := NewCommand("admin_list_ingestors") @@ -389,14 +384,80 @@ func (p *Parser) parseAdminListIngestors() (*Command, error) { return cmd, nil } +func (p *Parser) parseAdminStopIngestionTasks() (*Command, error) { + p.nextToken() // consume STOP + + if p.curToken.Type != TokenIngestion { + return nil, fmt.Errorf("expected INGESTION") + } + p.nextToken() + + if p.curToken.Type != TokenTasks { + return nil, fmt.Errorf("expected TASKS") + } + p.nextToken() // consume TASK + + taskString, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + tasks := strings.Split(taskString, " ") + p.nextToken() // consume TASK + + cmd := NewCommand("admin_stop_ingestion_tasks") + cmd.Params["tasks"] = tasks + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + return cmd, nil +} + +func (p *Parser) parseAdminRemoveIngestionTasks() 
(*Command, error) { + p.nextToken() // consume Ingestion + + if p.curToken.Type != TokenTasks { + return nil, fmt.Errorf("expected TASKS") + } + p.nextToken() // consume TASKS + + taskString, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + tasks := strings.Split(taskString, " ") + p.nextToken() // consume TASKS + + cmd := NewCommand("admin_remove_ingestion_tasks") + cmd.Params["tasks"] = tasks + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + return cmd, nil +} + func (p *Parser) parseAdminListIngestionTasks() (*Command, error) { p.nextToken() // consume Ingestion if p.curToken.Type != TokenTasks { return nil, fmt.Errorf("expected TASKS") } + p.nextToken() // consume TASKS cmd := NewCommand("list_admin_ingestion_tasks") + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil } @@ -1752,7 +1813,6 @@ func (p *Parser) parseAdminIngestCommand() (*Command, error) { } return cmd, nil } - func (p *Parser) parseAdminUnsetCommand() (*Command, error) { p.nextToken() // consume UNSET @@ -1767,3 +1827,94 @@ func (p *Parser) parseAdminUnsetCommand() (*Command, error) { } return NewCommand("unset_token"), nil } + +func (p *Parser) parseMessageQueueCommand() (*Command, error) { + p.nextToken() // consume MESSAGE_QUEUE + + var cmd *Command + switch p.curToken.Type { + case TokenShow: + p.nextToken() + cmd = NewCommand("user_show_message_queue_command") + + case TokenList: + p.nextToken() // consume LIST + + cmd = NewCommand("user_list_message_queue_command") + if p.curToken.Type == TokenPending { + cmd.Params["pending"] = true + p.nextToken() // consume PENDING + } else { + cmd.Params["pending"] = false + } + case TokenPublish: + p.nextToken() // consume PUBLISH + + message, err := p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("expected message after PUBLISH") + } + p.nextToken() // consume message + + cmd = 
NewCommand("user_publish_message_command") + cmd.Params["message"] = message + case TokenPull: + p.nextToken() // consume PULL + + messageCount, err := p.parseNumber() + if err != nil { + messageCount = 1 + } else { + p.nextToken() // consume NUMBER + } + + if messageCount <= 0 || messageCount > 100 { + return nil, fmt.Errorf("message count cannot be less than 0 or greater than 100") + } + + cmd = NewCommand("user_pull_message_command") + cmd.Params["message_count"] = messageCount + + if p.curToken.Type == TokenNoACK { + cmd.Params["ack_policy"] = "NOACK" + p.nextToken() // consume NOACK + } else { + cmd.Params["ack_policy"] = "ACK" + } + + default: + return nil, fmt.Errorf("expected WITH") + } + + // Semicolon is optional + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + +func (p *Parser) parseAdminRemoveCommand() (*Command, error) { + p.nextToken() // consume MESSAGE_QUEUE + + var cmd *Command + switch p.curToken.Type { + case TokenService: + p.nextToken() // consume SERVICE + serviceNum, err := p.parseNumber() + if err != nil { + return nil, fmt.Errorf("expected service number after SERVICE") + } + p.nextToken() // consume service number + cmd = NewCommand("admin_remove_service_command") + cmd.Params["service_number"] = serviceNum + case TokenIngestion: + return p.parseAdminRemoveIngestionTasks() + default: + return nil, fmt.Errorf("expected SERVICE") + } + // Semicolon is optional + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} diff --git a/internal/cli/cli_http.go b/internal/cli/cli_http.go index 54f489e873..6c94a0816f 100644 --- a/internal/cli/cli_http.go +++ b/internal/cli/cli_http.go @@ -16,7 +16,9 @@ package cli -import "fmt" +import ( + "fmt" +) // ExecuteCommand executes a parsed command // Returns benchmark result map for commands that support it (e.g., ping_server with iterations > 1) @@ -103,14 +105,25 @@ func (c *CLI) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) { return 
c.ListAdminTasks(cmd) case "admin_list_ingestors": return c.ListAdminIngestors(cmd) - case "admin_start_ingestion_command": - return c.AdminStartIngestionCommand(cmd) - case "admin_stop_ingestion_command": + case "admin_stop_ingestion_tasks": return c.AdminStopIngestionCommand(cmd) + case "admin_remove_ingestion_tasks": + return c.AdminRemoveIngestionCommand(cmd) case "admin_shutdown_ingestor_command": return c.AdminShutdownIngestor(cmd) case "list_admin_ingestion_tasks": return c.ListAdminIngestionTasks(cmd) + case "user_list_message_queue_command": + return c.UserListMessageQueueCommand(cmd) + case "user_publish_message_command": + return c.UserPublishMessageCommand(cmd) + case "user_pull_message_command": + return c.UserPullMessageCommand(cmd) + case "user_show_message_queue_command": + return c.UserShowMessageQueueCommand(cmd) + case "admin_remove_service_command": + return c.AdminRemoveServiceCommand(cmd) + // TODO: Implement other commands case "show_admin_server": return c.ShowAdminServer(cmd) case "show_api_server": @@ -285,6 +298,15 @@ func (c *CLI) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.GetMetadata(cmd) case "parse_documents_user_command": return c.ParseDocumentsUserCommand(cmd) + case "user_start_ingestion_command": + return c.UserStartIngestionCommand(cmd) + case "user_stop_ingestion_command": + return c.UserStopIngestionCommand(cmd) + case "user_list_ingestion_tasks": + return c.ListUserIngestionTasks(cmd) + case "user_remove_task_command": + return c.UserRemoveTaskCommand(cmd) + // TODO: Implement other commands case "user_parse_local_file_command": return c.UserParseLocalFile(cmd) case "show_admin_server": diff --git a/internal/cli/lexer.go b/internal/cli/lexer.go index ae1159b461..ceefad9992 100644 --- a/internal/cli/lexer.go +++ b/internal/cli/lexer.go @@ -477,6 +477,16 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenIngestors, Value: ident} case "INGESTION": return Token{Type: TokenIngestion, 
Value: ident} + case "MQ": + return Token{Type: TokenMQ, Value: ident} + case "PUBLISH": + return Token{Type: TokenPublish, Value: ident} + case "PULL": + return Token{Type: TokenPull, Value: ident} + case "PENDING": + return Token{Type: TokenPending, Value: ident} + case "NOACK": + return Token{Type: TokenNoACK, Value: ident} case "LOG": return Token{Type: TokenLog, Value: ident} case "LEVEL": diff --git a/internal/cli/parser.go b/internal/cli/parser.go index 3ff50d1ee8..e3ceffb31e 100644 --- a/internal/cli/parser.go +++ b/internal/cli/parser.go @@ -124,10 +124,12 @@ func (p *Parser) parseAdminCommand() (*Command, error) { return p.parseAdminShutdownCommand() case TokenRestart: return p.parseAdminRestartCommand() - case TokenStart: - return p.parseStartIngestion() + case TokenMQ: + return p.parseMessageQueueCommand() + case TokenRemove: + return p.parseAdminRemoveCommand() case TokenStop: - return p.parseStopIngestion() + return p.parseAdminStopIngestionTasks() case TokenAdd: return p.parseAdminAddCommand() case TokenDelete: @@ -216,6 +218,11 @@ func (p *Parser) parseUserCommand() (*Command, error) { return p.parseOCRCommand() case TokenCheck: return p.parseCheckCommand() + case TokenStart: + return p.parseUserStartIngestion() + case TokenStop: + return p.parseUserStopIngestion() + case TokenSave: return p.parseUserSaveCommand() case TokenUse: diff --git a/internal/cli/types.go b/internal/cli/types.go index 329ff70d71..ebff29bbcf 100644 --- a/internal/cli/types.go +++ b/internal/cli/types.go @@ -170,6 +170,11 @@ const ( TokenStart TokenStop TokenIngestion + TokenMQ + TokenPublish + TokenPull + TokenPending + TokenNoACK TokenLog TokenLevel TokenDebug diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index 1d1c85b06c..858e72ebef 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -3273,6 +3273,173 @@ func formatRequestError(action string, err error) error { } } +func (c *CLI) ListUserIngestionTasks(cmd *Command) 
(ResponseIf, error) { + if c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].APIToken == nil && c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].LoginToken == nil { + return nil, fmt.Errorf("API token not set. Please login first") + } + + if c.Config.CLIMode != APIMode { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + datasetID, ok := cmd.Params["dataset_id"].(*string) + if !ok { + datasetID = nil + } + + payload := map[string]interface{}{ + "dataset_id": datasetID, + } + + resp, err := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].Request("GET", "/datasets/ingestion/tasks", "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to list ingestion tasks: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list ingestion tasks:: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list ingestion tasks: failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + +func (c *CLI) UserStartIngestionCommand(cmd *Command) (ResponseIf, error) { + if c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].APIToken == nil && c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].LoginToken == nil { + return nil, fmt.Errorf("API token not set. 
Please login first") + } + + if c.Config.CLIMode != APIMode { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + documentID, ok := cmd.Params["document_id"].(string) + if !ok { + return nil, fmt.Errorf("document_id not provided") + } + + datasetID, ok := cmd.Params["dataset_id"].(string) + if !ok { + return nil, fmt.Errorf("dataset_id not provided") + } + + payload := map[string]interface{}{ + "documents": []string{documentID}, + "dataset_id": datasetID, + } + + url := fmt.Sprintf("/datasets/%s/documents/parse", datasetID) + + resp, err := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].Request("POST", url, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to ingest file: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to ingest file: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("ingest file failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + +func (c *CLI) UserStopIngestionCommand(cmd *Command) (ResponseIf, error) { + if c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].APIToken == nil && c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].LoginToken == nil { + return nil, fmt.Errorf("API token not set. 
Please login first") + } + + if c.Config.CLIMode != APIMode { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + tasks, ok := cmd.Params["tasks"].([]string) + if !ok { + return nil, fmt.Errorf("uri not provided") + } + payload := map[string]interface{}{ + "tasks": tasks, + } + + resp, err := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].Request("PUT", "/datasets/ingestion/tasks", "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to ingest file: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to ingest file: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("ingest file failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + +func (c *CLI) UserRemoveTaskCommand(cmd *Command) (ResponseIf, error) { + if c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].APIToken == nil && c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].LoginToken == nil { + return nil, fmt.Errorf("API token not set. 
Please login first") + } + + if c.Config.CLIMode != APIMode { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + tasks, ok := cmd.Params["tasks"].([]string) + if !ok { + return nil, fmt.Errorf("tasks not provided") + } + + payload := map[string]interface{}{ + "tasks": tasks, + } + + resp, err := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer].Request("DELETE", "/datasets/ingestion/tasks", "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to remove tasks: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to remove tasks: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("remove tasks failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + func (c *CLI) ChunkCommand(cmd *Command) (ResponseIf, error) { if c.Config.CLIMode != APIMode { return nil, fmt.Errorf("this command is only allowed in USER mode") @@ -3294,12 +3461,12 @@ func (c *CLI) ChunkCommand(cmd *Command) (ResponseIf, error) { if explain { fmt.Printf("Explain chunk file: %s, DSL: %s\n", filename, dsl) } else { + // TODO: not implemented fmt.Printf("Chunk file: %s, DSL: %s\n", filename, dsl) } var result SimpleResponse result.Code = 0 result.Message = fmt.Sprintf("Success to chunk %s", filename) - return &result, nil } diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index 693dce6b47..8efdd19fbb 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -159,6 +159,8 @@ func (p *Parser) parseListCommand() (*Command, error) { return p.parseListProviders() case TokenInstances: return p.parseListInstances() + case TokenIngestion: + return p.parseUserListIngestionTasks() case TokenDefault: return p.parseListDefaultModels() case 
TokenAvailable: @@ -1387,6 +1389,8 @@ func (p *Parser) parseRemoveCommand() (*Command, error) { return p.parseRemoveTags() case TokenChunks, TokenAll: return p.parseRemoveChunk() + case TokenIngestion: + return p.parseUserRemoveTask() case TokenModel: return p.parseRemoveInstanceModel() default: @@ -3892,7 +3896,7 @@ optionsLoop: if err != nil { return nil, err } - cmd.Params["embed_model"] = embedModel + cmd.Params["embedding_model"] = embedModel p.nextToken() case TokenDocParse: p.nextToken() @@ -4474,6 +4478,42 @@ func (p *Parser) parseRemoveChunk() (*Command, error) { return cmd, nil } +func (p *Parser) parseUserStartIngestion() (*Command, error) { + p.nextToken() // consume Start + + if p.curToken.Type != TokenIngestion { + return nil, fmt.Errorf("expect INGESTION") + } + p.nextToken() // consume Ingestion + + documentID, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + if p.curToken.Type != TokenFrom { + return nil, fmt.Errorf("expect FROM") + } + p.nextToken() // consume FROM + + datasetID, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("user_start_ingestion_command") + cmd.Params["document_id"] = documentID + cmd.Params["dataset_id"] = datasetID + p.nextToken() + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + // parseShowTask parses SHOW ADMIN SERVER func (p *Parser) parseUserShowAdmin() (*Command, error) { p.nextToken() // consume ADMIN @@ -4494,6 +4534,90 @@ func (p *Parser) parseUserShowAdmin() (*Command, error) { return cmd, nil } +func (p *Parser) parseUserStopIngestion() (*Command, error) { + p.nextToken() // consume Stop + + if p.curToken.Type != TokenIngestion { + return nil, fmt.Errorf("expect INGESTION") + } + p.nextToken() // consume Ingestion + + if p.curToken.Type != TokenTasks { + return nil, fmt.Errorf("expect TASKS") + } + p.nextToken() + + taskStr, err := p.parseQuotedString() + 
if err != nil { + return nil, err + } + + tasks := strings.Split(taskStr, " ") + + cmd := NewCommand("user_stop_ingestion_command") + cmd.Params["tasks"] = tasks + p.nextToken() + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + +func (p *Parser) parseUserListIngestionTasks() (*Command, error) { + p.nextToken() // consume Ingestion + + if p.curToken.Type != TokenTasks { + return nil, fmt.Errorf("expected TASKS") + } + p.nextToken() // consume TASKS + + cmd := NewCommand("user_list_ingestion_tasks") + + if p.curToken.Type == TokenFrom { + p.nextToken() + datasetID, err := p.parseQuotedString() + if err != nil { + return nil, err + } + cmd.Params["dataset_id"] = datasetID + } + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + +func (p *Parser) parseUserRemoveTask() (*Command, error) { + p.nextToken() // consume Ingestion + + if p.curToken.Type != TokenTasks { + return nil, fmt.Errorf("expected TASKS") + } + p.nextToken() // consume TASKS + + taskStr, err := p.parseQuotedString() + if err != nil { + return nil, err + } + + cmd := NewCommand("user_remove_task_command") + + tasks := strings.Split(taskStr, " ") + + cmd.Params["tasks"] = tasks + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + return cmd, nil +} + // parseShowTask parses SHOW API SERVER func (p *Parser) parseUserShowAPI() (*Command, error) { p.nextToken() // consume API diff --git a/internal/common/constants.go b/internal/common/constants.go index 4c3283b3a7..64f4a9f0bf 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -1,3 +1,19 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + package common const ( @@ -11,3 +27,13 @@ const ( // request can return per search_after iteration. SearchAfterBatchSize = 1000 ) + +// task status +const ( + CREATED = "CREATED" + RUNNING = "RUNNING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + STOPPED = "STOPPED" + STOPPING = "STOPPING" +) diff --git a/internal/common/error_code.go b/internal/common/error_code.go index 912d7bb6d7..d19e461d5e 100644 --- a/internal/common/error_code.go +++ b/internal/common/error_code.go @@ -16,6 +16,8 @@ package common +import "errors" + type ErrorCode int const ( @@ -82,3 +84,15 @@ func (e ErrorCode) Message() string { } return "Unknown error" } + +var ( + ErrInvalidToken = errors.New("invalid token") + ErrNotAdmin = errors.New("user is not admin") + ErrUserInactive = errors.New("user is inactive") + ErrUserNotFound = errors.New("user not found") + // ErrNotFound is returned when an object is not found + ErrNotFound = errors.New("object not found") + // ErrBucketNotFound is returned when a bucket is not found + ErrBucketNotFound = errors.New("bucket not found") + ErrTaskNotFound = errors.New("task id not found") +) diff --git a/internal/common/ingestion.pb.go b/internal/common/ingestion.pb.go deleted file mode 100644 index a2211b2a64..0000000000 --- a/internal/common/ingestion.pb.go +++ /dev/null @@ -1,795 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. 
-// versions: -// protoc-gen-go v1.36.11 -// protoc v3.21.12 -// source: internal/proto/ingestion.proto - -package common - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" - unsafe "unsafe" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type IngestionMessage struct { - state protoimpl.MessageState `protogen:"open.v1"` - IngestorId string `protobuf:"bytes,1,opt,name=ingestor_id,json=ingestorId,proto3" json:"ingestor_id,omitempty"` - MessageType string `protobuf:"bytes,2,opt,name=message_type,json=messageType,proto3" json:"message_type,omitempty"` // REGISTER, HEARTBEAT, TASK_RESULT, TASK_PROGRESS, PULL_REQUEST - RegisterInfo *RegisterInfo `protobuf:"bytes,3,opt,name=register_info,json=registerInfo,proto3" json:"register_info,omitempty"` - HeartbeatInfo *HeartbeatInfo `protobuf:"bytes,4,opt,name=heartbeat_info,json=heartbeatInfo,proto3" json:"heartbeat_info,omitempty"` - TaskResult *TaskResult `protobuf:"bytes,5,opt,name=task_result,json=taskResult,proto3" json:"task_result,omitempty"` - TaskProgress *TaskProgress `protobuf:"bytes,6,opt,name=task_progress,json=taskProgress,proto3" json:"task_progress,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *IngestionMessage) Reset() { - *x = IngestionMessage{} - mi := &file_internal_proto_ingestion_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *IngestionMessage) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*IngestionMessage) ProtoMessage() {} - -func (x *IngestionMessage) ProtoReflect() protoreflect.Message { - mi := 
&file_internal_proto_ingestion_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use IngestionMessage.ProtoReflect.Descriptor instead. -func (*IngestionMessage) Descriptor() ([]byte, []int) { - return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{0} -} - -func (x *IngestionMessage) GetIngestorId() string { - if x != nil { - return x.IngestorId - } - return "" -} - -func (x *IngestionMessage) GetMessageType() string { - if x != nil { - return x.MessageType - } - return "" -} - -func (x *IngestionMessage) GetRegisterInfo() *RegisterInfo { - if x != nil { - return x.RegisterInfo - } - return nil -} - -func (x *IngestionMessage) GetHeartbeatInfo() *HeartbeatInfo { - if x != nil { - return x.HeartbeatInfo - } - return nil -} - -func (x *IngestionMessage) GetTaskResult() *TaskResult { - if x != nil { - return x.TaskResult - } - return nil -} - -func (x *IngestionMessage) GetTaskProgress() *TaskProgress { - if x != nil { - return x.TaskProgress - } - return nil -} - -type AdminMessage struct { - state protoimpl.MessageState `protogen:"open.v1"` - MessageType string `protobuf:"bytes,1,opt,name=message_type,json=messageType,proto3" json:"message_type,omitempty"` // TASK_ASSIGNMENT, ACK, PONG, RECONNECT - TaskAssignment *TaskAssignment `protobuf:"bytes,2,opt,name=task_assignment,json=taskAssignment,proto3" json:"task_assignment,omitempty"` - AckInfo *AckInfo `protobuf:"bytes,3,opt,name=ack_info,json=ackInfo,proto3" json:"ack_info,omitempty"` - ErrorMessage string `protobuf:"bytes,4,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *AdminMessage) Reset() { - *x = AdminMessage{} - mi := &file_internal_proto_ingestion_proto_msgTypes[1] - ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *AdminMessage) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*AdminMessage) ProtoMessage() {} - -func (x *AdminMessage) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_ingestion_proto_msgTypes[1] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use AdminMessage.ProtoReflect.Descriptor instead. -func (*AdminMessage) Descriptor() ([]byte, []int) { - return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{1} -} - -func (x *AdminMessage) GetMessageType() string { - if x != nil { - return x.MessageType - } - return "" -} - -func (x *AdminMessage) GetTaskAssignment() *TaskAssignment { - if x != nil { - return x.TaskAssignment - } - return nil -} - -func (x *AdminMessage) GetAckInfo() *AckInfo { - if x != nil { - return x.AckInfo - } - return nil -} - -func (x *AdminMessage) GetErrorMessage() string { - if x != nil { - return x.ErrorMessage - } - return "" -} - -type RegisterInfo struct { - state protoimpl.MessageState `protogen:"open.v1"` - MaxConcurrency int32 `protobuf:"varint,1,opt,name=max_concurrency,json=maxConcurrency,proto3" json:"max_concurrency,omitempty"` - SupportedDocTypes []string `protobuf:"bytes,2,rep,name=supported_doc_types,json=supportedDocTypes,proto3" json:"supported_doc_types,omitempty"` - Version string `protobuf:"bytes,3,opt,name=version,proto3" json:"version,omitempty"` - Name string `protobuf:"bytes,4,opt,name=name,proto3" json:"name,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *RegisterInfo) Reset() { - *x = RegisterInfo{} - mi := &file_internal_proto_ingestion_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *RegisterInfo) String() string { 
- return protoimpl.X.MessageStringOf(x) -} - -func (*RegisterInfo) ProtoMessage() {} - -func (x *RegisterInfo) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_ingestion_proto_msgTypes[2] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use RegisterInfo.ProtoReflect.Descriptor instead. -func (*RegisterInfo) Descriptor() ([]byte, []int) { - return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{2} -} - -func (x *RegisterInfo) GetMaxConcurrency() int32 { - if x != nil { - return x.MaxConcurrency - } - return 0 -} - -func (x *RegisterInfo) GetSupportedDocTypes() []string { - if x != nil { - return x.SupportedDocTypes - } - return nil -} - -func (x *RegisterInfo) GetVersion() string { - if x != nil { - return x.Version - } - return "" -} - -func (x *RegisterInfo) GetName() string { - if x != nil { - return x.Name - } - return "" -} - -type HeartbeatInfo struct { - state protoimpl.MessageState `protogen:"open.v1"` - TaskStates []*TaskState `protobuf:"bytes,1,rep,name=task_states,json=taskStates,proto3" json:"task_states,omitempty"` - DeleteTaskIds []string `protobuf:"bytes,2,rep,name=delete_task_ids,json=deleteTaskIds,proto3" json:"delete_task_ids,omitempty"` - CpuUsage float32 `protobuf:"fixed32,3,opt,name=cpu_usage,json=cpuUsage,proto3" json:"cpu_usage,omitempty"` // percentage - VmsUsage float32 `protobuf:"fixed32,4,opt,name=vms_usage,json=vmsUsage,proto3" json:"vms_usage,omitempty"` // absolute value - RssUsage float32 `protobuf:"fixed32,5,opt,name=rss_usage,json=rssUsage,proto3" json:"rss_usage,omitempty"` // absolute value - ProcessId int64 `protobuf:"varint,6,opt,name=process_id,json=processId,proto3" json:"process_id,omitempty"` // pid - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *HeartbeatInfo) Reset() { - *x = HeartbeatInfo{} - mi := 
&file_internal_proto_ingestion_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *HeartbeatInfo) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*HeartbeatInfo) ProtoMessage() {} - -func (x *HeartbeatInfo) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_ingestion_proto_msgTypes[3] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use HeartbeatInfo.ProtoReflect.Descriptor instead. -func (*HeartbeatInfo) Descriptor() ([]byte, []int) { - return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{3} -} - -func (x *HeartbeatInfo) GetTaskStates() []*TaskState { - if x != nil { - return x.TaskStates - } - return nil -} - -func (x *HeartbeatInfo) GetDeleteTaskIds() []string { - if x != nil { - return x.DeleteTaskIds - } - return nil -} - -func (x *HeartbeatInfo) GetCpuUsage() float32 { - if x != nil { - return x.CpuUsage - } - return 0 -} - -func (x *HeartbeatInfo) GetVmsUsage() float32 { - if x != nil { - return x.VmsUsage - } - return 0 -} - -func (x *HeartbeatInfo) GetRssUsage() float32 { - if x != nil { - return x.RssUsage - } - return 0 -} - -func (x *HeartbeatInfo) GetProcessId() int64 { - if x != nil { - return x.ProcessId - } - return 0 -} - -type TaskState struct { - state protoimpl.MessageState `protogen:"open.v1"` - TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` - Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` // PENDING, RUNNING, COMPLETED, FAILED, CANCELLED - ErrorMessage string `protobuf:"bytes,3,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` - EstimatedRemainingTimeSeconds int64 `protobuf:"varint,4,opt,name=estimated_remaining_time_seconds,json=estimatedRemainingTimeSeconds,proto3" 
json:"estimated_remaining_time_seconds,omitempty"` - ComeFrom string `protobuf:"bytes,5,opt,name=come_from,json=comeFrom,proto3" json:"come_from,omitempty"` - StartTime int64 `protobuf:"varint,6,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *TaskState) Reset() { - *x = TaskState{} - mi := &file_internal_proto_ingestion_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *TaskState) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*TaskState) ProtoMessage() {} - -func (x *TaskState) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_ingestion_proto_msgTypes[4] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use TaskState.ProtoReflect.Descriptor instead. 
-func (*TaskState) Descriptor() ([]byte, []int) { - return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{4} -} - -func (x *TaskState) GetTaskId() string { - if x != nil { - return x.TaskId - } - return "" -} - -func (x *TaskState) GetStatus() string { - if x != nil { - return x.Status - } - return "" -} - -func (x *TaskState) GetErrorMessage() string { - if x != nil { - return x.ErrorMessage - } - return "" -} - -func (x *TaskState) GetEstimatedRemainingTimeSeconds() int64 { - if x != nil { - return x.EstimatedRemainingTimeSeconds - } - return 0 -} - -func (x *TaskState) GetComeFrom() string { - if x != nil { - return x.ComeFrom - } - return "" -} - -func (x *TaskState) GetStartTime() int64 { - if x != nil { - return x.StartTime - } - return 0 -} - -type TaskAssignment struct { - state protoimpl.MessageState `protogen:"open.v1"` - TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` - TaskType string `protobuf:"bytes,2,opt,name=task_type,json=taskType,proto3" json:"task_type,omitempty"` - Config string `protobuf:"bytes,3,opt,name=config,proto3" json:"config,omitempty"` - ComeFrom string `protobuf:"bytes,4,opt,name=come_from,json=comeFrom,proto3" json:"come_from,omitempty"` - AssignedTo string `protobuf:"bytes,5,opt,name=assigned_to,json=assignedTo,proto3" json:"assigned_to,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *TaskAssignment) Reset() { - *x = TaskAssignment{} - mi := &file_internal_proto_ingestion_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *TaskAssignment) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*TaskAssignment) ProtoMessage() {} - -func (x *TaskAssignment) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_ingestion_proto_msgTypes[5] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == 
nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use TaskAssignment.ProtoReflect.Descriptor instead. -func (*TaskAssignment) Descriptor() ([]byte, []int) { - return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{5} -} - -func (x *TaskAssignment) GetTaskId() string { - if x != nil { - return x.TaskId - } - return "" -} - -func (x *TaskAssignment) GetTaskType() string { - if x != nil { - return x.TaskType - } - return "" -} - -func (x *TaskAssignment) GetConfig() string { - if x != nil { - return x.Config - } - return "" -} - -func (x *TaskAssignment) GetComeFrom() string { - if x != nil { - return x.ComeFrom - } - return "" -} - -func (x *TaskAssignment) GetAssignedTo() string { - if x != nil { - return x.AssignedTo - } - return "" -} - -type TaskResult struct { - state protoimpl.MessageState `protogen:"open.v1"` - TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` - Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` // COMPLETED, FAILED, CANCELLED - ErrorMessage string `protobuf:"bytes,3,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *TaskResult) Reset() { - *x = TaskResult{} - mi := &file_internal_proto_ingestion_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *TaskResult) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*TaskResult) ProtoMessage() {} - -func (x *TaskResult) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_ingestion_proto_msgTypes[6] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use TaskResult.ProtoReflect.Descriptor instead. 
-func (*TaskResult) Descriptor() ([]byte, []int) { - return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{6} -} - -func (x *TaskResult) GetTaskId() string { - if x != nil { - return x.TaskId - } - return "" -} - -func (x *TaskResult) GetStatus() string { - if x != nil { - return x.Status - } - return "" -} - -func (x *TaskResult) GetErrorMessage() string { - if x != nil { - return x.ErrorMessage - } - return "" -} - -type TaskProgress struct { - state protoimpl.MessageState `protogen:"open.v1"` - TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` - Progress int32 `protobuf:"varint,2,opt,name=progress,proto3" json:"progress,omitempty"` - Info string `protobuf:"bytes,3,opt,name=info,proto3" json:"info,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *TaskProgress) Reset() { - *x = TaskProgress{} - mi := &file_internal_proto_ingestion_proto_msgTypes[7] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *TaskProgress) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*TaskProgress) ProtoMessage() {} - -func (x *TaskProgress) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_ingestion_proto_msgTypes[7] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use TaskProgress.ProtoReflect.Descriptor instead. 
-func (*TaskProgress) Descriptor() ([]byte, []int) { - return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{7} -} - -func (x *TaskProgress) GetTaskId() string { - if x != nil { - return x.TaskId - } - return "" -} - -func (x *TaskProgress) GetProgress() int32 { - if x != nil { - return x.Progress - } - return 0 -} - -func (x *TaskProgress) GetInfo() string { - if x != nil { - return x.Info - } - return "" -} - -type AckInfo struct { - state protoimpl.MessageState `protogen:"open.v1"` - TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` - Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"` - Message string `protobuf:"bytes,3,opt,name=message,proto3" json:"message,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *AckInfo) Reset() { - *x = AckInfo{} - mi := &file_internal_proto_ingestion_proto_msgTypes[8] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *AckInfo) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*AckInfo) ProtoMessage() {} - -func (x *AckInfo) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_ingestion_proto_msgTypes[8] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use AckInfo.ProtoReflect.Descriptor instead. 
-func (*AckInfo) Descriptor() ([]byte, []int) { - return file_internal_proto_ingestion_proto_rawDescGZIP(), []int{8} -} - -func (x *AckInfo) GetTaskId() string { - if x != nil { - return x.TaskId - } - return "" -} - -func (x *AckInfo) GetSuccess() bool { - if x != nil { - return x.Success - } - return false -} - -func (x *AckInfo) GetMessage() string { - if x != nil { - return x.Message - } - return "" -} - -var File_internal_proto_ingestion_proto protoreflect.FileDescriptor - -const file_internal_proto_ingestion_proto_rawDesc = "" + - "\n" + - "\x1einternal/proto/ingestion.proto\x12\x06common\"\xbf\x02\n" + - "\x10IngestionMessage\x12\x1f\n" + - "\vingestor_id\x18\x01 \x01(\tR\n" + - "ingestorId\x12!\n" + - "\fmessage_type\x18\x02 \x01(\tR\vmessageType\x129\n" + - "\rregister_info\x18\x03 \x01(\v2\x14.common.RegisterInfoR\fregisterInfo\x12<\n" + - "\x0eheartbeat_info\x18\x04 \x01(\v2\x15.common.HeartbeatInfoR\rheartbeatInfo\x123\n" + - "\vtask_result\x18\x05 \x01(\v2\x12.common.TaskResultR\n" + - "taskResult\x129\n" + - "\rtask_progress\x18\x06 \x01(\v2\x14.common.TaskProgressR\ftaskProgress\"\xc3\x01\n" + - "\fAdminMessage\x12!\n" + - "\fmessage_type\x18\x01 \x01(\tR\vmessageType\x12?\n" + - "\x0ftask_assignment\x18\x02 \x01(\v2\x16.common.TaskAssignmentR\x0etaskAssignment\x12*\n" + - "\back_info\x18\x03 \x01(\v2\x0f.common.AckInfoR\aackInfo\x12#\n" + - "\rerror_message\x18\x04 \x01(\tR\ferrorMessage\"\x95\x01\n" + - "\fRegisterInfo\x12'\n" + - "\x0fmax_concurrency\x18\x01 \x01(\x05R\x0emaxConcurrency\x12.\n" + - "\x13supported_doc_types\x18\x02 \x03(\tR\x11supportedDocTypes\x12\x18\n" + - "\aversion\x18\x03 \x01(\tR\aversion\x12\x12\n" + - "\x04name\x18\x04 \x01(\tR\x04name\"\xe1\x01\n" + - "\rHeartbeatInfo\x122\n" + - "\vtask_states\x18\x01 \x03(\v2\x11.common.TaskStateR\n" + - "taskStates\x12&\n" + - "\x0fdelete_task_ids\x18\x02 \x03(\tR\rdeleteTaskIds\x12\x1b\n" + - "\tcpu_usage\x18\x03 \x01(\x02R\bcpuUsage\x12\x1b\n" + - "\tvms_usage\x18\x04 
\x01(\x02R\bvmsUsage\x12\x1b\n" + - "\trss_usage\x18\x05 \x01(\x02R\brssUsage\x12\x1d\n" + - "\n" + - "process_id\x18\x06 \x01(\x03R\tprocessId\"\xe6\x01\n" + - "\tTaskState\x12\x17\n" + - "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x16\n" + - "\x06status\x18\x02 \x01(\tR\x06status\x12#\n" + - "\rerror_message\x18\x03 \x01(\tR\ferrorMessage\x12G\n" + - " estimated_remaining_time_seconds\x18\x04 \x01(\x03R\x1destimatedRemainingTimeSeconds\x12\x1b\n" + - "\tcome_from\x18\x05 \x01(\tR\bcomeFrom\x12\x1d\n" + - "\n" + - "start_time\x18\x06 \x01(\x03R\tstartTime\"\x9c\x01\n" + - "\x0eTaskAssignment\x12\x17\n" + - "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x1b\n" + - "\ttask_type\x18\x02 \x01(\tR\btaskType\x12\x16\n" + - "\x06config\x18\x03 \x01(\tR\x06config\x12\x1b\n" + - "\tcome_from\x18\x04 \x01(\tR\bcomeFrom\x12\x1f\n" + - "\vassigned_to\x18\x05 \x01(\tR\n" + - "assignedTo\"b\n" + - "\n" + - "TaskResult\x12\x17\n" + - "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x16\n" + - "\x06status\x18\x02 \x01(\tR\x06status\x12#\n" + - "\rerror_message\x18\x03 \x01(\tR\ferrorMessage\"W\n" + - "\fTaskProgress\x12\x17\n" + - "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x1a\n" + - "\bprogress\x18\x02 \x01(\x05R\bprogress\x12\x12\n" + - "\x04info\x18\x03 \x01(\tR\x04info\"V\n" + - "\aAckInfo\x12\x17\n" + - "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x18\n" + - "\asuccess\x18\x02 \x01(\bR\asuccess\x12\x18\n" + - "\amessage\x18\x03 \x01(\tR\amessage2P\n" + - "\x10IngestionManager\x12<\n" + - "\x06Action\x12\x18.common.IngestionMessage\x1a\x14.common.AdminMessage(\x010\x01B\vZ\t./;commonb\x06proto3" - -var ( - file_internal_proto_ingestion_proto_rawDescOnce sync.Once - file_internal_proto_ingestion_proto_rawDescData []byte -) - -func file_internal_proto_ingestion_proto_rawDescGZIP() []byte { - file_internal_proto_ingestion_proto_rawDescOnce.Do(func() { - file_internal_proto_ingestion_proto_rawDescData = 
protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_internal_proto_ingestion_proto_rawDesc), len(file_internal_proto_ingestion_proto_rawDesc))) - }) - return file_internal_proto_ingestion_proto_rawDescData -} - -var file_internal_proto_ingestion_proto_msgTypes = make([]protoimpl.MessageInfo, 9) -var file_internal_proto_ingestion_proto_goTypes = []any{ - (*IngestionMessage)(nil), // 0: common.IngestionMessage - (*AdminMessage)(nil), // 1: common.AdminMessage - (*RegisterInfo)(nil), // 2: common.RegisterInfo - (*HeartbeatInfo)(nil), // 3: common.HeartbeatInfo - (*TaskState)(nil), // 4: common.TaskState - (*TaskAssignment)(nil), // 5: common.TaskAssignment - (*TaskResult)(nil), // 6: common.TaskResult - (*TaskProgress)(nil), // 7: common.TaskProgress - (*AckInfo)(nil), // 8: common.AckInfo -} -var file_internal_proto_ingestion_proto_depIdxs = []int32{ - 2, // 0: common.IngestionMessage.register_info:type_name -> common.RegisterInfo - 3, // 1: common.IngestionMessage.heartbeat_info:type_name -> common.HeartbeatInfo - 6, // 2: common.IngestionMessage.task_result:type_name -> common.TaskResult - 7, // 3: common.IngestionMessage.task_progress:type_name -> common.TaskProgress - 5, // 4: common.AdminMessage.task_assignment:type_name -> common.TaskAssignment - 8, // 5: common.AdminMessage.ack_info:type_name -> common.AckInfo - 4, // 6: common.HeartbeatInfo.task_states:type_name -> common.TaskState - 0, // 7: common.IngestionManager.Action:input_type -> common.IngestionMessage - 1, // 8: common.IngestionManager.Action:output_type -> common.AdminMessage - 8, // [8:9] is the sub-list for method output_type - 7, // [7:8] is the sub-list for method input_type - 7, // [7:7] is the sub-list for extension type_name - 7, // [7:7] is the sub-list for extension extendee - 0, // [0:7] is the sub-list for field type_name -} - -func init() { file_internal_proto_ingestion_proto_init() } -func file_internal_proto_ingestion_proto_init() { - if File_internal_proto_ingestion_proto != 
nil { - return - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_internal_proto_ingestion_proto_rawDesc), len(file_internal_proto_ingestion_proto_rawDesc)), - NumEnums: 0, - NumMessages: 9, - NumExtensions: 0, - NumServices: 1, - }, - GoTypes: file_internal_proto_ingestion_proto_goTypes, - DependencyIndexes: file_internal_proto_ingestion_proto_depIdxs, - MessageInfos: file_internal_proto_ingestion_proto_msgTypes, - }.Build() - File_internal_proto_ingestion_proto = out.File - file_internal_proto_ingestion_proto_goTypes = nil - file_internal_proto_ingestion_proto_depIdxs = nil -} diff --git a/internal/common/ingestion_grpc.pb.go b/internal/common/ingestion_grpc.pb.go deleted file mode 100644 index 5d43bce800..0000000000 --- a/internal/common/ingestion_grpc.pb.go +++ /dev/null @@ -1,115 +0,0 @@ -// Code generated by protoc-gen-go-grpc. DO NOT EDIT. -// versions: -// - protoc-gen-go-grpc v1.6.1 -// - protoc v3.21.12 -// source: internal/proto/ingestion.proto - -package common - -import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" -) - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.64.0 or later. -const _ = grpc.SupportPackageIsVersion9 - -const ( - IngestionManager_Action_FullMethodName = "/common.IngestionManager/Action" -) - -// IngestionManagerClient is the client API for IngestionManager service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
-type IngestionManagerClient interface { - Action(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[IngestionMessage, AdminMessage], error) -} - -type ingestionManagerClient struct { - cc grpc.ClientConnInterface -} - -func NewIngestionManagerClient(cc grpc.ClientConnInterface) IngestionManagerClient { - return &ingestionManagerClient{cc} -} - -func (c *ingestionManagerClient) Action(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[IngestionMessage, AdminMessage], error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &IngestionManager_ServiceDesc.Streams[0], IngestionManager_Action_FullMethodName, cOpts...) - if err != nil { - return nil, err - } - x := &grpc.GenericClientStream[IngestionMessage, AdminMessage]{ClientStream: stream} - return x, nil -} - -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type IngestionManager_ActionClient = grpc.BidiStreamingClient[IngestionMessage, AdminMessage] - -// IngestionManagerServer is the server API for IngestionManager service. -// All implementations must embed UnimplementedIngestionManagerServer -// for forward compatibility. -type IngestionManagerServer interface { - Action(grpc.BidiStreamingServer[IngestionMessage, AdminMessage]) error - mustEmbedUnimplementedIngestionManagerServer() -} - -// UnimplementedIngestionManagerServer must be embedded to have -// forward compatible implementations. -// -// NOTE: this should be embedded by value instead of pointer to avoid a nil -// pointer dereference when methods are called. 
-type UnimplementedIngestionManagerServer struct{} - -func (UnimplementedIngestionManagerServer) Action(grpc.BidiStreamingServer[IngestionMessage, AdminMessage]) error { - return status.Error(codes.Unimplemented, "method Action not implemented") -} -func (UnimplementedIngestionManagerServer) mustEmbedUnimplementedIngestionManagerServer() {} -func (UnimplementedIngestionManagerServer) testEmbeddedByValue() {} - -// UnsafeIngestionManagerServer may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to IngestionManagerServer will -// result in compilation errors. -type UnsafeIngestionManagerServer interface { - mustEmbedUnimplementedIngestionManagerServer() -} - -func RegisterIngestionManagerServer(s grpc.ServiceRegistrar, srv IngestionManagerServer) { - // If the following call panics, it indicates UnimplementedIngestionManagerServer was - // embedded by pointer and is nil. This will cause panics if an - // unimplemented method is ever invoked, so we test this at initialization - // time to prevent it from happening at runtime later due to I/O. - if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { - t.testEmbeddedByValue() - } - s.RegisterService(&IngestionManager_ServiceDesc, srv) -} - -func _IngestionManager_Action_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(IngestionManagerServer).Action(&grpc.GenericServerStream[IngestionMessage, AdminMessage]{ServerStream: stream}) -} - -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type IngestionManager_ActionServer = grpc.BidiStreamingServer[IngestionMessage, AdminMessage] - -// IngestionManager_ServiceDesc is the grpc.ServiceDesc for IngestionManager service. 
-// It's only intended for direct use with grpc.RegisterService, -// and not to be introspected or modified (even as a copy) -var IngestionManager_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "common.IngestionManager", - HandlerType: (*IngestionManagerServer)(nil), - Methods: []grpc.MethodDesc{}, - Streams: []grpc.StreamDesc{ - { - StreamName: "Action", - Handler: _IngestionManager_Action_Handler, - ServerStreams: true, - ClientStreams: true, - }, - }, - Metadata: "internal/proto/ingestion.proto", -} diff --git a/internal/common/float.go b/internal/common/number.go similarity index 96% rename from internal/common/float.go rename to internal/common/number.go index fd989f1000..f78d541844 100644 --- a/internal/common/float.go +++ b/internal/common/number.go @@ -139,3 +139,14 @@ func PairwiseSum(xs []float64) float64 { } return xs[0] } + +func GetInt(value interface{}) (int, bool) { + switch v := value.(type) { + case int: + return v, true + case float64: + return int(v), true + default: + return 0, false + } +} diff --git a/internal/common/status_message.go b/internal/common/status_message.go index 1412a5e05a..8be706274e 100644 --- a/internal/common/status_message.go +++ b/internal/common/status_message.go @@ -31,3 +31,10 @@ type BaseMessage struct { Timestamp time.Time `json:"timestamp"` Ext interface{} `json:"ext,omitempty"` } + +type StartIngestionRequest struct { + TaskID string `json:"task_id" binding:"required"` + TaskType string `json:"task_type" binding:"required"` + From string `json:"from" binding:"required"` + UserID string `json:"user_id" binding:"required"` +} diff --git a/internal/common/task.go b/internal/common/task.go new file mode 100644 index 0000000000..24132264e4 --- /dev/null +++ b/internal/common/task.go @@ -0,0 +1,34 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package common + +const ( + TaskTypeIngestionTask = "ingestion_task" + TaskTypeIngestionTasklet = "ingestion_tasklet" + TaskTypeIngestionTest = "ingestion_test" +) + +type TaskMessage struct { + TaskID string `json:"task_id" binding:"required"` + TaskType string `json:"task_type" binding:"required"` +} + +type TaskHandle interface { + GetMessage() TaskMessage + Ack() error + Nack() error +} diff --git a/internal/dao/database.go b/internal/dao/database.go index a692bfd782..6466a7743b 100644 --- a/internal/dao/database.go +++ b/internal/dao/database.go @@ -149,6 +149,10 @@ func InitDB() error { &entity.TenantModelGroupMapping{}, &entity.TenantModelProvider{}, &entity.TenantModelGroup{}, + &entity.IngestionTask{}, + &entity.IngestionTaskLog{}, + &entity.IngestionTasklet{}, + &entity.IngestionTaskletLog{}, } for _, m := range dataModels { diff --git a/internal/dao/ingestion.go b/internal/dao/ingestion.go index 6aad681d66..02e7d085c1 100644 --- a/internal/dao/ingestion.go +++ b/internal/dao/ingestion.go @@ -17,20 +17,33 @@ package dao import ( + "errors" + "fmt" + "ragflow/internal/common" "ragflow/internal/entity" ) -type IngestionDAO struct{} +type IngestionTaskDAO struct{} -func NewIngestionDAO() *IngestionDAO { - return &IngestionDAO{} +func NewIngestionTaskDAO() *IngestionTaskDAO { + return &IngestionTaskDAO{} } -func (dao *IngestionDAO) Create(ingestionTask *entity.IngestionTask) error { +// Used by the API server to create a task +// created → running : After the ingestor component assigns the task, it changes the status to running +// running → 
completed : Task executes successfully +// running → failed : Error occurs during execution +// created → canceling : User cancels before the task is picked up by the ingestor +// running → canceling : User cancels during execution +// completed → canceling : User cancels a completed task (e.g., for cleanup/rollback) +// canceling → canceled : Cancellation completes +// failed → created : Retry (back to start) +// canceled → created : Retry/re-execute (back to start) +func (dao *IngestionTaskDAO) CheckAndCreate(ingestionTask *entity.IngestionTask) (*entity.IngestionTask, error) { tx := DB.Begin() if tx.Error != nil { - return tx.Error + return nil, tx.Error } defer func() { @@ -40,71 +53,425 @@ func (dao *IngestionDAO) Create(ingestionTask *entity.IngestionTask) error { } }() - // create ingestion task - if err := DB.Create(ingestionTask).Error; err != nil { - tx.Rollback() - return err + // Check if the task is created + var taskRecord *entity.IngestionTask + err := tx.Where("document_id = ?", ingestionTask.DocumentID).First(&taskRecord).Error + if err == nil { + // found + if taskRecord.Status == common.FAILED || taskRecord.Status == common.STOPPED { + // restart the task + err = tx.Model(&entity.IngestionTask{}).Where("id = ?", taskRecord.ID).Update("status", common.CREATED).Error + if err != nil { + tx.Rollback() + return nil, err + } + } else { + return nil, fmt.Errorf("document id %s already exists, status: %s, task id: %s", ingestionTask.DocumentID, taskRecord.Status, taskRecord.ID) + } + } else { + // create ingestion task + ingestionTask.ID = common.GenerateUUID() + if err = tx.Create(ingestionTask).Error; err != nil { + tx.Rollback() + return nil, err + } + taskRecord = ingestionTask } - taskLog := &entity.IngestionTaskLog{ - TaskID: ingestionTask.ID, - Stage: 0, + if err = tx.Commit().Error; err != nil { + return nil, err } - // create task log - if err := DB.Create(taskLog).Error; err != nil { - tx.Rollback() - return err - } - - if err := 
tx.Commit().Error; err != nil { - return err - } - - return nil + return taskRecord, nil } -func (dao *IngestionDAO) GetAllTasks() ([]*entity.IngestionTask, error) { +// UpdateStatus updates the status of an ingestion task +func (dao *IngestionTaskDAO) UpdateStatus(taskID, status string) error { + return DB.Model(&entity.IngestionTask{}).Where("id = ?", taskID).Update("status", status).Error +} + +// SetRunningByIngestor is called by the ingestor +// if task status is RUNNING, COMPLETED, STOPPED, FAILED, just return without error +// if task status is CREATED, update to RUNNING +// if task status is STOPPING, update to STOPPED +func (dao *IngestionTaskDAO) SetRunningByIngestor(taskID string) (*entity.IngestionTask, error) { + + tx := DB.Begin() + if tx.Error != nil { + return nil, tx.Error + } + var committed bool + + defer func() { + if committed { + tx.Commit() + } else { + tx.Rollback() + if r := recover(); r != nil { + panic(r) + } + } + }() + var tasks []*entity.IngestionTask - err := DB.Find(&tasks).Error + err := tx.Where("id = ?", taskID).Find(&tasks).Error + if err != nil { + return nil, err + } + + if len(tasks) == 0 { + return nil, common.ErrTaskNotFound + } + + if len(tasks) != 1 { + return nil, fmt.Errorf("task %s has multiple records", taskID) + } + + taskStatus := tasks[0].Status + switch taskStatus { + case common.CREATED: + tasks[0].Status = common.RUNNING + err = tx.Model(&entity.IngestionTask{}).Where("id = ?", taskID).Update("status", common.RUNNING).Error + if err != nil { + return nil, err + } + committed = true + return tasks[0], nil + case common.STOPPING: + tasks[0].Status = common.STOPPED + err = tx.Model(&entity.IngestionTask{}).Where("id = ?", taskID).Update("status", common.STOPPED).Error + if err != nil { + return nil, err + } + committed = true + return tasks[0], nil + case common.RUNNING: + // this task was executing before, just return without error + committed = true + return tasks[0], nil + default: + return tasks[0], nil + } +} + +func (dao *IngestionTaskDAO) 
SetStoppingByAPIServer(taskID string) (*entity.IngestionTask, error) { + + tx := DB.Begin() + if tx.Error != nil { + return nil, tx.Error + } + var committed bool + + defer func() { + if committed { + tx.Commit() + } else { + tx.Rollback() + if r := recover(); r != nil { + panic(r) + } + } + }() + + var tasks []*entity.IngestionTask + err := tx.Where("id = ?", taskID).Find(&tasks).Error + if err != nil { + return nil, err + } + + if len(tasks) == 0 { + return nil, fmt.Errorf("task %s not found", taskID) + } + + if len(tasks) != 1 { + return nil, fmt.Errorf("task %s has multiple records", taskID) + } + + taskStatus := tasks[0].Status + switch taskStatus { + case common.CREATED: + tasks[0].Status = common.STOPPED + err = tx.Model(&entity.IngestionTask{}).Where("id = ?", taskID).Update("status", common.STOPPED).Error + if err != nil { + return nil, err + } + committed = true + return tasks[0], nil + case common.RUNNING: + tasks[0].Status = common.STOPPING + err = tx.Model(&entity.IngestionTask{}).Where("id = ?", taskID).Update("status", common.STOPPING).Error + if err != nil { + return nil, err + } + committed = true + return tasks[0], nil + default: + return tasks[0], nil + } +} + +type TaskletInfo struct { + TaskletID string `json:"tasklet_id"` + FilesToDelete []string `json:"files_to_delete"` +} + +type TaskInfo struct { + TaskID string `json:"task_id"` + FilesToDelete []string `json:"files_to_delete"` + Tasklets []TaskletInfo `json:"tasklets"` +} + +func (dao *IngestionTaskDAO) RemoveByAPIServerOrAdminServer(taskID string, userID *string) (*TaskInfo, error) { + + tx := DB.Begin() + if tx.Error != nil { + return nil, tx.Error + } + var committed bool + + defer func() { + if committed { + tx.Commit() + } else { + tx.Rollback() + if r := recover(); r != nil { + panic(r) + } + } + }() + + var tasks []*entity.IngestionTask + err := tx.Where("id = ?", taskID).Find(&tasks).Error + if err != nil { + return nil, err + } + + if len(tasks) == 0 { + return nil, 
fmt.Errorf("task %s not found", taskID) + } + + if len(tasks) != 1 { + return nil, fmt.Errorf("task %s has multiple records", taskID) + } + + if userID != nil { + if tasks[0].UserID != *userID { + return nil, errors.New("task does not belong to the user") + } + } + + taskStatus := tasks[0].Status + switch taskStatus { + case common.CREATED, common.STOPPED, common.COMPLETED, common.FAILED: + // get all ingestion tasklets + var tasklets []*entity.IngestionTasklet + err = tx.Where("task_id = ?", taskID).Find(&tasklets).Error + if err != nil { + return nil, err + } + var TaskletInfos []TaskletInfo + for _, tasklet := range tasklets { + // get all ingestion tasklet log + var taskletLogs []*entity.IngestionTaskletLog + err = tx.Where("tasklet_id = ?", tasklet.ID).Find(&taskletLogs).Error + + fileMap := make(map[string]bool) + for _, taskletLog := range taskletLogs { + files, ok := taskletLog.Checkpoint["files"].([]string) + if ok { + for _, file := range files { + fileMap[file] = true + } + } + } + var filesToDelete []string + for file := range fileMap { + filesToDelete = append(filesToDelete, file) + } + TaskletInfos = append(TaskletInfos, TaskletInfo{ + TaskletID: tasklet.ID, + FilesToDelete: filesToDelete, + }) + } + + // get all ingestion task log + var taskLogs []*entity.IngestionTaskLog + err = tx.Where("task_id = ?", taskID).Find(&taskLogs).Error + if err != nil { + return nil, err + } + + fileMap := make(map[string]bool) + for _, taskLog := range taskLogs { + files, ok := taskLog.Checkpoint["files"].([]string) + if ok { + for _, file := range files { + fileMap[file] = true + } + } + } + var filesToDelete []string + for file := range fileMap { + filesToDelete = append(filesToDelete, file) + } + + err = tx.Model(&entity.IngestionTask{}).Where("id = ?", taskID).Delete(&entity.IngestionTask{}).Error + if err != nil { + return nil, err + } + + taskInfo := &TaskInfo{ + TaskID: taskID, + FilesToDelete: filesToDelete, + Tasklets: TaskletInfos, + } + committed = true + 
return taskInfo, nil + default: + return nil, fmt.Errorf("task %s is executing, cannot be removed", taskID) + } +} + +func (dao *IngestionTaskDAO) GetAllTasks(page, pageSize int) ([]*entity.IngestionTask, error) { + var tasks []*entity.IngestionTask + var err error + if pageSize == 0 { + err = DB.Find(&tasks).Error + } else { + err = DB.Order("create_time DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&tasks).Error + } return tasks, err } -func (dao *IngestionDAO) ListByUserID(userID string) ([]*entity.IngestionTask, error) { +func (dao *IngestionTaskDAO) ListByUserID(userID string, page, pageSize int) ([]*entity.IngestionTask, error) { var tasks []*entity.IngestionTask - err := DB.Where("user_id = ?", userID).Find(&tasks).Error + var err error + if pageSize == 0 { + err = DB.Where("user_id = ?", userID).Order("create_time DESC").Find(&tasks).Error + } else { + err = DB.Where("user_id = ?", userID).Order("create_time DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&tasks).Error + } + return tasks, err } -func (dao *IngestionDAO) GetByID(id string) (*entity.IngestionTask, error) { +func (dao *IngestionTaskDAO) ListByUserIDAndDatasetID(userID, datasetID string, page, pageSize int) ([]*entity.IngestionTask, error) { + var tasks []*entity.IngestionTask + var err error + if pageSize == 0 { + err = DB.Where("user_id = ? AND dataset_id = ?", userID, datasetID).Order("create_time DESC").Find(&tasks).Error + } else { + err = DB.Where("user_id = ? 
AND dataset_id = ?", userID, datasetID).Order("create_time DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&tasks).Error + } + + return tasks, err +} + +func (dao *IngestionTaskDAO) GetByID(id string) (*entity.IngestionTask, error) { var task *entity.IngestionTask err := DB.Where("id = ?", id).First(&task).Error return task, err } -type IngestionLogDAO struct{} - -func NewIngestionLogDAO() *IngestionLogDAO { - return &IngestionLogDAO{} +func (dao *IngestionTaskDAO) GetByDocumentID(documentId string) (*entity.IngestionTask, error) { + var task *entity.IngestionTask + err := DB.Where("document_id = ?", documentId).First(&task).Error + return task, err } -func (dao *IngestionLogDAO) Create(ingestionLog *entity.IngestionTaskLog) error { +type IngestionTaskLogDAO struct{} + +func NewIngestionTaskLogDAO() *IngestionTaskLogDAO { + return &IngestionTaskLogDAO{} +} + +func (dao *IngestionTaskLogDAO) Create(ingestionLog *entity.IngestionTaskLog) error { return DB.Create(ingestionLog).Error } -func (dao *IngestionDAO) ListLogsByTaskID(taskID string) ([]*entity.IngestionTaskLog, error) { +func (dao *IngestionTaskLogDAO) ListLogsByTaskID(taskID string) ([]*entity.IngestionTaskLog, error) { var tasks []*entity.IngestionTaskLog - err := DB.Where("task_id = ?", taskID).Find(&tasks).Error + err := DB.Where("task_id = ?", taskID).Order("create_time DESC").Find(&tasks).Error return tasks, err } -func (dao *IngestionDAO) GetLogByLogID(logID string) (*entity.IngestionTaskLog, error) { +func (dao *IngestionTaskLogDAO) LatestLogByTaskID(taskID string) (*entity.IngestionTaskLog, error) { + var task *entity.IngestionTaskLog + err := DB.Where("task_id = ?", taskID).Order("create_time DESC").First(&task).Error + return task, err +} + +func (dao *IngestionTaskLogDAO) GetLogByLogID(logID string) (*entity.IngestionTaskLog, error) { var task *entity.IngestionTaskLog err := DB.Where("id = ?", logID).First(&task).Error return task, err } -func (dao *IngestionDAO) DeleteByTaskID(taskID 
string) (int64, error) { +func (dao *IngestionTaskLogDAO) DeleteByTaskID(taskID string) (int64, error) { result := DB.Unscoped().Where("task_id = ?", taskID).Delete(&entity.IngestionTaskLog{}) return result.RowsAffected, result.Error } + +type IngestionTaskletDAO struct{} + +func NewIngestionTaskletDAO() *IngestionTaskletDAO { + return &IngestionTaskletDAO{} +} + +func (dao *IngestionTaskletDAO) Create(ingestionTasklet *entity.IngestionTasklet) error { + return DB.Create(ingestionTasklet).Error +} + +func (dao *IngestionTaskletDAO) UpdateStatus(taskletID, status string) error { + return DB.Model(&entity.IngestionTasklet{}).Where("id = ?", taskletID).Update("status", status).Error +} +func (dao *IngestionTaskletDAO) GetAllTasklets() ([]*entity.IngestionTasklet, error) { + var tasks []*entity.IngestionTasklet + err := DB.Find(&tasks).Error + return tasks, err +} + +func (dao *IngestionTaskletDAO) ListByUserID(userID string) ([]*entity.IngestionTasklet, error) { + var tasks []*entity.IngestionTasklet + err := DB.Where("user_id = ?", userID).Find(&tasks).Error + return tasks, err +} + +func (dao *IngestionTaskletDAO) GetByID(id string) (*entity.IngestionTasklet, error) { + var task *entity.IngestionTasklet + err := DB.Where("id = ?", id).First(&task).Error + return task, err +} + +type IngestionTaskletLogDAO struct{} + +func NewIngestionTaskletLogDAO() *IngestionTaskletLogDAO { + return &IngestionTaskletLogDAO{} +} + +func (dao *IngestionTaskletLogDAO) Create(ingestionLog *entity.IngestionTaskletLog) error { + return DB.Create(ingestionLog).Error +} + +func (dao *IngestionTaskletLogDAO) ListLogsByTaskletID(taskID string) ([]*entity.IngestionTaskletLog, error) { + var tasks []*entity.IngestionTaskletLog + err := DB.Where("task_id = ?", taskID).Find(&tasks).Error + return tasks, err +} + +func (dao *IngestionTaskletLogDAO) GetLogByLogID(logID string) (*entity.IngestionTaskletLog, error) { + var task *entity.IngestionTaskletLog + err := DB.Where("id = ?", 
logID).First(&task).Error + return task, err +} + +func (dao *IngestionTaskletLogDAO) LatestLogByTaskletID(taskletID string) (*entity.IngestionTaskletLog, error) { + var tasklet *entity.IngestionTaskletLog + err := DB.Where("tasklet_id = ?", taskletID).Order("create_time DESC").First(&tasklet).Error + return tasklet, err +} + +func (dao *IngestionTaskletLogDAO) DeleteByTaskletID(taskID string) (int64, error) { + result := DB.Unscoped().Where("task_id = ?", taskID).Delete(&entity.IngestionTaskletLog{}) + return result.RowsAffected, result.Error +} diff --git a/internal/engine/engine.go b/internal/engine/engine.go index 6cd02d4cd3..789fc7c29b 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -18,7 +18,7 @@ package engine import ( "context" - + "ragflow/internal/common" "ragflow/internal/engine/types" ) @@ -87,3 +87,12 @@ func Type(docEngine DocEngine) EngineType { // or rely on configuration to know the type return EngineType("unknown") } + +type MessageQueue interface { + Init() error + InitConsumer(subject string) error + PublishTask(subject string, payload []byte) error + GetMessages(messageCount int) ([]common.TaskHandle, error) + ListMessages(messageType string, pending bool) ([]map[string]string, error) + ShowMessageQueue() (map[string]string, error) +} diff --git a/internal/engine/global.go b/internal/engine/global.go index baf178e61f..151de6fef4 100644 --- a/internal/engine/global.go +++ b/internal/engine/global.go @@ -19,19 +19,21 @@ package engine import ( "fmt" "ragflow/internal/common" + "ragflow/internal/engine/nats" "ragflow/internal/server" "sync" - "go.uber.org/zap" - "ragflow/internal/engine/elasticsearch" "ragflow/internal/engine/infinity" + + "go.uber.org/zap" ) var ( - globalEngine DocEngine - engineType EngineType - once sync.Once + globalEngine DocEngine + engineType EngineType + messageQueueEngine MessageQueue + once sync.Once ) // Init initializes document engine @@ -75,3 +77,22 @@ func Close() error { } return nil } + 
+func GetMessageQueueEngine() MessageQueue { + return messageQueueEngine +} + +func InitMessageQueueEngine(messageQueueType string) error { + config := server.GetConfig() + switch messageQueueType { + case "nats": + messageQueueEngine = nats.NewNatsEngine(config.Nats.Host, config.Nats.Port) + err := messageQueueEngine.Init() + if err != nil { + return err + } + default: + return fmt.Errorf("unsupported message queue type: %s", messageQueueType) + } + return nil +} diff --git a/internal/engine/nats/nats.go b/internal/engine/nats/nats.go new file mode 100644 index 0000000000..8e0ebf27de --- /dev/null +++ b/internal/engine/nats/nats.go @@ -0,0 +1,246 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package nats + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "ragflow/internal/common" + "strconv" + "strings" + "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +type NatsEngine struct { + host string + port int + nc *nats.Conn + jetStream jetstream.JetStream + stream jetstream.Stream + consumer jetstream.Consumer +} + +func NewNatsEngine(host string, port int) *NatsEngine { + return &NatsEngine{ + host: host, + port: port, + } +} + +func (n *NatsEngine) Init() error { + var err error + n.nc, err = nats.Connect(nats.DefaultURL) + if err != nil { + return fmt.Errorf("failed to connect to NATS: %w", err) + } + + n.jetStream, err = jetstream.New(n.nc) + if err != nil { + n.nc.Close() + return fmt.Errorf("failed to create JetStream context: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + streamCfg := jetstream.StreamConfig{ + Name: "RAGFLOW_TASKS", + Subjects: []string{"tasks.>"}, + Retention: jetstream.WorkQueuePolicy, + Storage: jetstream.FileStorage, + MaxMsgs: 1024 * 128, + MaxBytes: 1024 * 1024, + } + + n.stream, err = n.jetStream.CreateStream(ctx, streamCfg) + if err != nil { + if err.Error() != "stream already exists" { + n.nc.Close() + return fmt.Errorf("fail to create stream: %w", err) + } + + common.Info("NATS stream already exists, use existing stream") + n.stream, err = n.jetStream.Stream(ctx, "RAGFLOW_TASKS") + if err != nil { + n.nc.Close() + return fmt.Errorf("fail to get existing stream: %w", err) + } + } else { + common.Info("NATS stream create successfully") + } + + return nil +} + +func (n *NatsEngine) PublishTask(subject string, payload []byte) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ack, err := n.jetStream.Publish(ctx, subject, payload) + if err != nil { + return err + } + common.Info(fmt.Sprintf("Task published, stream seq: %d", ack.Sequence)) + return nil +} + 
+func (n *NatsEngine) ShowMessageQueue() (map[string]string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + accountInfo, err := n.jetStream.AccountInfo(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get account info: %w", err) + } + result := make(map[string]string) + result["consumer_count"] = strconv.Itoa(accountInfo.Consumers) + result["memory"] = strconv.FormatUint(accountInfo.Memory, 10) + + subjectFilter := "tasks.>" + info, err := n.stream.Info(ctx, jetstream.WithSubjectFilter(subjectFilter)) + if err != nil { + return nil, fmt.Errorf("failed to get stream info: %w", err) + } + result["message_count"] = strconv.FormatUint(info.State.Msgs, 10) + + consumer, err := n.stream.Consumer(ctx, "RAGFLOW_CONSUMER") + if err != nil { + return nil, fmt.Errorf("failed to get existing consumer: %w", err) + } + + consumerInfo, err := consumer.Info(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get consumer info: %w", err) + } + result["pending_count"] = strconv.FormatUint(consumerInfo.NumPending, 10) + result["waiting_count"] = strconv.Itoa(consumerInfo.NumWaiting) + result["ack_pending_count"] = strconv.Itoa(consumerInfo.NumAckPending) + result["redelivered_count"] = strconv.Itoa(consumerInfo.NumRedelivered) + return result, nil +} + +func (n *NatsEngine) ListMessages(messageType string, pending bool) ([]map[string]string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if n.stream == nil { + return nil, fmt.Errorf("NATS stream not initialized") + } + + subjectFilter := "tasks.>" + + info, err := n.stream.Info(ctx, jetstream.WithSubjectFilter(subjectFilter)) + if err != nil { + return nil, fmt.Errorf("failed to get stream info: %w", err) + } + + if info.State.Msgs == 0 { + return nil, nil + } + + var messages []map[string]string + seq := info.State.FirstSeq + lastSeq := info.State.LastSeq + + for seq <= lastSeq { + var msg 
*jetstream.RawStreamMsg + msg, err = n.stream.GetMsg(ctx, seq, jetstream.WithGetMsgSubject(subjectFilter)) + if err != nil { + if errors.Is(err, jetstream.ErrMsgNotFound) { + break + } + return nil, fmt.Errorf("failed to get message at seq %d: %w", seq, err) + } + messageMap := make(map[string]string) + messageMap["subject"] = msg.Subject + messageMap["message"] = string(msg.Data) + messages = append(messages, messageMap) + seq = msg.Sequence + 1 + } + + common.Info(fmt.Sprintf("Listed %d messages for subject: %s", len(messages), subjectFilter)) + return messages, nil +} + +func (n *NatsEngine) InitConsumer(subject string) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var err error + n.consumer, err = n.stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{ + Name: "RAGFLOW_CONSUMER", + AckPolicy: jetstream.AckExplicitPolicy, + MaxDeliver: 16, + MaxAckPending: 1024 * 128, + FilterSubject: "tasks.>", + }) + if err != nil { + // MaxAckPending is immutable after consumer creation. + // If the consumer already exists, fall back to fetching it. 
+ if strings.Contains(err.Error(), "max waiting can not be updated") { + n.consumer, err = n.stream.Consumer(ctx, "RAGFLOW_CONSUMER") + if err != nil { + return fmt.Errorf("failed to get existing consumer: %w", err) + } + } else { + return fmt.Errorf("failed to create Consumer: %w", err) + } + } + return nil +} +func (n *NatsEngine) GetMessages(messageCount int) ([]common.TaskHandle, error) { + resultMessages := make([]common.TaskHandle, 0) + messages, err := n.consumer.Fetch(messageCount, jetstream.FetchMaxWait(1*time.Second)) + if err != nil { + return nil, fmt.Errorf("failed to fetch messages: %w", err) + } + for msg := range messages.Messages() { + resultMessages = append(resultMessages, NewNatsMessageHandle(msg)) + } + return resultMessages, nil +} + +type NatsMessageHandle struct { + message jetstream.Msg +} + +func NewNatsMessageHandle(message jetstream.Msg) *NatsMessageHandle { + return &NatsMessageHandle{ + message: message, + } +} + +func (m *NatsMessageHandle) GetMessage() common.TaskMessage { + // convert to task message + var taskMessage common.TaskMessage + if err := json.Unmarshal(m.message.Data(), &taskMessage); err != nil { + common.Error("failed to unmarshal message", err) + } + return taskMessage +} + +func (m *NatsMessageHandle) Ack() error { + return m.message.Ack() +} + +func (m *NatsMessageHandle) Nack() error { + return m.message.Nak() +} diff --git a/internal/entity/base.go b/internal/entity/base.go index 774be466f3..d2aabb5e7d 100644 --- a/internal/entity/base.go +++ b/internal/entity/base.go @@ -119,6 +119,20 @@ func (m *BaseModel) BeforeUpdate(tx *gorm.DB) error { return nil } +func (m *BaseModel) UpdateCreateDateAndTime() error { + timestamp, dateTime := autoModelTime() + m.CreateTime = ×tamp + m.UpdateDate = &dateTime + return nil +} + +func (m *BaseModel) UpdateUpdateDateAndTime() error { + timestamp, dateTime := autoModelTime() + m.UpdateTime = ×tamp + m.UpdateDate = &dateTime + return nil +} + // JSONMap is a map type that can store 
JSON data type JSONMap map[string]interface{} diff --git a/internal/entity/ingestion_task.go b/internal/entity/ingestion_task.go index d061ac65c2..0ee5f19ca3 100644 --- a/internal/entity/ingestion_task.go +++ b/internal/entity/ingestion_task.go @@ -18,10 +18,11 @@ package entity type IngestionTask struct { ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + UserID string `gorm:"column:user_id;size:32;not null" json:"user_id"` DocumentID string `gorm:"column:document_id;size:32;not null;index" json:"document_id"` - UserID string `gorm:"column:user_id;size:32;not null;" json:"user_id"` - Config JSONMap `gorm:"column:config;type:longtext;not null" json:"config"` - TryCount int `gorm:"column:try_count;type:int;default:0" json:"try_count"` + DatasetID string `gorm:"column:dataset_id;size:32;not null" json:"dataset_id"` + Schema JSONMap `gorm:"column:schema;type:longtext" json:"schema"` + Status string `gorm:"column:status;size:32;not null;" json:"status"` BaseModel } @@ -30,24 +31,10 @@ func (IngestionTask) TableName() string { return "ingestion_task" } -type IngestionTasklet struct { - ID string `gorm:"column:id;primaryKey;size:32" json:"id"` - TaskID string `gorm:"column:task_id;size:32;not null;index" json:"task_id"` - Config JSONMap `gorm:"column:config;type:longtext;not null" json:"config"` - TryCount int `gorm:"column:try_count;type:int;default:0" json:"try_count"` - BaseModel -} - -// TableName specify table name -func (IngestionTasklet) TableName() string { - return "ingestion_tasklet" -} - type IngestionTaskLog struct { ID int `gorm:"column:id;primaryKey;autoIncrement" json:"id"` TaskID string `gorm:"column:task_id;size:32;not null;index" json:"task_id"` - Stage int `gorm:"column:stage;type:int;default:0;not null;" json:"stage"` - DataSchema JSONMap `gorm:"column:config;type:longtext;not null" json:"data_schema"` + Checkpoint JSONMap `gorm:"column:checkpoint;type:longtext;not null" json:"checkpoint"` BaseModel } @@ -56,11 +43,23 @@ func 
(IngestionTaskLog) TableName() string { return "ingestion_task_log" } +type IngestionTasklet struct { + ID string `gorm:"column:id;primaryKey;size:32" json:"id"` + TaskID string `gorm:"column:task_id;size:32;not null;index" json:"task_id"` + Schema JSONMap `gorm:"column:schema;type:longtext" json:"schema"` + Status string `gorm:"column:status;size:32;not null;" json:"status"` + BaseModel +} + +// TableName specify table name +func (IngestionTasklet) TableName() string { + return "ingestion_tasklet" +} + type IngestionTaskletLog struct { ID int `gorm:"column:id;primaryKey;autoIncrement" json:"id"` - TaskletID string `gorm:"column:tasklet_id;size:32;not null;index" json:"task_id"` - Stage int `gorm:"column:stage;type:int;default:0;not null;" json:"stage"` - DataSchema JSONMap `gorm:"column:config;type:longtext;not null" json:"data_schema"` + TaskletID string `gorm:"column:tasklet_id;size:32;not null;index" json:"tasklet_id"` + Checkpoint JSONMap `gorm:"column:checkpoint;type:longtext;not null" json:"checkpoint"` BaseModel } diff --git a/internal/handler/document.go b/internal/handler/document.go index dcbcbf5670..874542ec4d 100644 --- a/internal/handler/document.go +++ b/internal/handler/document.go @@ -59,6 +59,10 @@ type documentServiceIface interface { GetDocumentArtifact(filename string) (*service.ArtifactResponse, error) GetDocumentPreview(docID string) (*service.DocumentPreview, error) DownloadDocument(datasetID, docID string) (*service.DownloadDocumentResp, error) + ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error) + IngestDocuments(datasetID, userID string, docIDs []string) ([]*service.ParseDocumentResponse, error) + StopIngestionTasks(tasks []string, userID string) ([]*entity.IngestionTask, error) + RemoveIngestionTasks(tasks []string, userID string) ([]map[string]string, error) } // DocumentHandler document handler @@ -68,7 +72,7 @@ type DocumentHandler struct { } // NewDocumentHandler create document 
handler -func NewDocumentHandler(documentService *service.DocumentService, datasetService *service.DatasetService) *DocumentHandler { +func NewDocumentHandler(documentService documentServiceIface, datasetService *service.DatasetService) *DocumentHandler { return &DocumentHandler{ documentService: documentService, datasetService: datasetService, @@ -909,6 +913,143 @@ func (h *DocumentHandler) DeleteMeta(c *gin.Context) { }) } +type ListIngestionsRequest struct { + DatasetID *string `json:"dataset_id"` +} + +func (h *DocumentHandler) ListIngestionTasks(c *gin.Context) { + var req ListIngestionsRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + userID := c.GetString("user_id") + + var parseResult []*entity.IngestionTask + var err error + if req.DatasetID != nil { + if !h.datasetService.Accessible(*req.DatasetID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization to access the dataset.") + return + } + } + + parseResult, err = h.documentService.ListIngestionTasks(userID, req.DatasetID, 0, 0) + if err != nil { + jsonError(c, common.CodeExceptionError, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": parseResult, + }) +} + +type StartParseDocumentsRequest struct { + DatasetID string `json:"dataset_id" binding:"required"` + Documents []string `json:"documents" binding:"required"` +} + +func (h *DocumentHandler) StartIngestionTask(c *gin.Context) { + var req StartParseDocumentsRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + userID := c.GetString("user_id") + + if !h.datasetService.Accessible(req.DatasetID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization to access the dataset.") + return + } + + parseResult, err := 
h.documentService.IngestDocuments(req.DatasetID, userID, req.Documents) + if err != nil { + jsonError(c, common.CodeExceptionError, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": parseResult, + }) +} + +type StopIngestionsRequest struct { + Tasks []string `json:"tasks" binding:"required"` +} + +func (h *DocumentHandler) StopIngestionTasks(c *gin.Context) { + var req StopIngestionsRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + userID := c.GetString("user_id") + + parseResult, err := h.documentService.StopIngestionTasks(req.Tasks, userID) + if err != nil { + jsonError(c, common.CodeExceptionError, err.Error()) + return + } + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": parseResult, + }) +} + +type RemoveIngestionsRequest struct { + Tasks []string `json:"tasks" binding:"required"` +} + +func (h *DocumentHandler) RemoveIngestionTasks(c *gin.Context) { + var req RemoveIngestionsRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + if req.Tasks == nil || len(req.Tasks) == 0 { + c.JSON(http.StatusOK, gin.H{ + "code": 1, + "message": "task_ids is required", + }) + return + } + + userID := c.GetString("user_id") + + deletedTasks, err := h.documentService.RemoveIngestionTasks(req.Tasks, userID) + if err != nil { + jsonError(c, common.CodeExceptionError, err.Error()) + return + } + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": deletedTasks, + }) +} + type ParseDocumentRequest struct { Documents []string `json:"documents" binding:"required"` } diff --git a/internal/handler/document_test.go b/internal/handler/document_test.go index 5ac370f6f5..6c2c5d314f 100644 --- a/internal/handler/document_test.go +++ 
b/internal/handler/document_test.go @@ -134,6 +134,19 @@ func (f *fakeDocumentService) GetDocumentMetadataByID(docID string) (map[string] return nil, nil } +func (f *fakeDocumentService) ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error) { + return nil, nil +} +func (f *fakeDocumentService) IngestDocuments(datasetID, userID string, docIDs []string) ([]*service.ParseDocumentResponse, error) { + return nil, nil +} +func (f *fakeDocumentService) StopIngestionTasks(tasks []string, userID string) ([]*entity.IngestionTask, error) { + return nil, nil +} +func (f *fakeDocumentService) RemoveIngestionTasks(tasks []string, userID string) ([]map[string]string, error) { + return nil, nil +} + func setupGinContextWithUser(method, path, body string) (*gin.Context, *httptest.ResponseRecorder) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() diff --git a/internal/ingestion/ingestion_service.go b/internal/ingestion/ingestion_service.go index e29f13aa43..1947332cd0 100644 --- a/internal/ingestion/ingestion_service.go +++ b/internal/ingestion/ingestion_service.go @@ -2,31 +2,24 @@ package ingestion import ( "context" + "errors" "fmt" - "log" - "math" - "os" + "ragflow/internal/dao" + "ragflow/internal/engine" + "ragflow/internal/entity" "sync" "time" - "github.com/shirou/gopsutil/v3/process" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" - "ragflow/internal/common" -) -type taskWrapper struct { - task *common.TaskAssignment -} + "google.golang.org/grpc" +) type Ingestor struct { id string name string serverAddr string conn *grpc.ClientConn - client common.IngestionManagerClient - stream common.IngestionManager_ActionClient ctx context.Context cancel context.CancelFunc reconnectMu sync.Mutex @@ -47,295 +40,432 @@ type Ingestor struct { taskChan chan *TaskContext workerWg sync.WaitGroup startOnce sync.Once + + ingestionTaskDAO *dao.IngestionTaskDAO + ingestionTaskLogDAO *dao.IngestionTaskLogDAO + 
ingestionTaskletDAO *dao.IngestionTaskletDAO + ingestionTaskletLogDAO *dao.IngestionTaskletLogDAO +} + +type TaskLog struct { + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + Description string `json:"description"` + Details map[string]interface{} `json:"details"` } type TaskContext struct { - Ctx context.Context - CancelFunc context.CancelFunc - Task *common.TaskAssignment - Status string // PENDING, RUNNING, COMPLETED, FAILED, CANCELLING, CANCELLED - StartTime time.Time - EndTime time.Time + Ctx context.Context + CancelFunc context.CancelFunc + // if tasklet is nil, this context is belonged to a task + // if task and tasklet are both not nil, this context is belonged to a tasklet, the task is the parent task of the tasklet + Task *entity.IngestionTask + Tasklet *entity.IngestionTasklet + Logs []*TaskLog estimatedRemainingTime time.Duration // estimated cost in seconds to complete the task Progress int32 ErrorMessage string + TaskHandle common.TaskHandle } func NewIngestor(name string, maxConcurrency int32, supportedTypes []string) *Ingestor { ctx, cancel := context.WithCancel(context.Background()) id := common.GenerateUUID() return &Ingestor{ - id: id, - name: name, - ctx: ctx, - cancel: cancel, - maxConcurrency: maxConcurrency, - supportedDocTypes: supportedTypes, - version: "1.0.0", - currentTasks: make(map[string]*TaskContext), - taskChan: make(chan *TaskContext, maxConcurrency*2), - ShutdownCh: make(chan struct{}, 1), + id: id, + name: name, + ctx: ctx, + cancel: cancel, + maxConcurrency: maxConcurrency, + supportedDocTypes: supportedTypes, + version: "1.0.0", + currentTasks: make(map[string]*TaskContext), + taskChan: make(chan *TaskContext, maxConcurrency*2), + ShutdownCh: make(chan struct{}, 1), + ingestionTaskDAO: dao.NewIngestionTaskDAO(), + ingestionTaskLogDAO: dao.NewIngestionTaskLogDAO(), + ingestionTaskletDAO: dao.NewIngestionTaskletDAO(), + ingestionTaskletLogDAO: dao.NewIngestionTaskletLogDAO(), } } -// Connect 
connects to the admin and establishes a bidirectional stream -func (e *Ingestor) Connect(serverAddr string) error { - e.serverAddr = serverAddr - conn, err := grpc.Dial(serverAddr, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithBlock(), - grpc.WithTimeout(5*time.Second), - ) +func (e *Ingestor) ID() string { + return e.id +} + +func (e *Ingestor) Start() error { + common.Info(fmt.Sprintf("Ingestor %s initialized", e.id)) + msgQueueEngine := engine.GetMessageQueueEngine() + err := msgQueueEngine.InitConsumer("tasks.RAGFLOW") if err != nil { - return fmt.Errorf("fail to connect admin server: %s", err.Error()) - } - e.conn = conn - - e.client = common.NewIngestionManagerClient(conn) - - stream, err := e.client.Action(e.ctx) - if err != nil { - conn.Close() - return err - } - e.stream = stream - - common.Info(fmt.Sprintf("Ingestor %s connected to admin", e.id)) - - // 1. Send registration message - if err = e.sendRegister(); err != nil { - conn.Close() return err } // Ensure worker pool is started on first task - e.startWorkerPool() + go e.startWorkerPool() - // 2. Start receive loop - go e.receiveLoop() - - // 3. 
Start heartbeat loop - go e.heartbeatLoop() - - return nil -} - -func (e *Ingestor) sendRegister() error { - msg := &common.IngestionMessage{ - IngestorId: e.id, - MessageType: "REGISTER", - RegisterInfo: &common.RegisterInfo{ - MaxConcurrency: e.maxConcurrency, - SupportedDocTypes: e.supportedDocTypes, - Version: e.version, - Name: e.name, - }, - } - return e.stream.Send(msg) -} - -func (e *Ingestor) sendHeartbeat() error { - e.tasksMu.RLock() - - cutoff := time.Now().Add(-10 * time.Minute) - var toDeleteTask []string - taskStates := make([]*common.TaskState, 0, len(e.currentTasks)) - - for tid, tc := range e.currentTasks { - // Check if task is in a terminal state and expired beyond 10 minutes - if (tc.Status == "CANCELED" || tc.Status == "COMPLETED" || tc.Status == "REJECTED") && - !tc.EndTime.IsZero() && tc.EndTime.Before(cutoff) { - toDeleteTask = append(toDeleteTask, tid) - } else { - taskStates = append(taskStates, &common.TaskState{ - TaskId: tid, - Status: tc.Status, - EstimatedRemainingTimeSeconds: int64(tc.estimatedRemainingTime), - ErrorMessage: tc.ErrorMessage, - StartTime: tc.StartTime.UnixNano(), - ComeFrom: tc.Task.ComeFrom, - }) - } - } - e.tasksMu.RUnlock() - - // Delete expired terminal tasks from currentTasks - if len(toDeleteTask) > 0 { - e.tasksMu.Lock() - for _, id := range toDeleteTask { - delete(e.currentTasks, id) - } - e.tasksMu.Unlock() - } - - var pid = int64(os.Getpid()) - p, err := process.NewProcess(int32(pid)) - if err != nil { - log.Fatal(err) - } - - var cpuPercent float64 - cpuPercent, err = p.Percent(100 * time.Millisecond) - if err != nil { - cpuPercent = math.NaN() - common.Info(fmt.Sprintf("Fail to read CPU usage: %v", err)) - } - - RssUsage := math.NaN() - VmsUsage := math.NaN() - memInfo, err := p.MemoryInfo() - if err == nil { - RssUsage = float64(memInfo.RSS) - VmsUsage = float64(memInfo.VMS) - } else { - common.Info(fmt.Sprintf("Fail to read memory usage: %v", err)) - } - msg := &common.IngestionMessage{ - IngestorId: 
e.id, - MessageType: "HEARTBEAT", - HeartbeatInfo: &common.HeartbeatInfo{ - TaskStates: taskStates, - DeleteTaskIds: toDeleteTask, - CpuUsage: float32(cpuPercent), - VmsUsage: float32(VmsUsage), - RssUsage: float32(RssUsage), - ProcessId: pid, - }, - } - return e.stream.Send(msg) -} - -func (e *Ingestor) sendTaskResult(taskID, status, errorMsg string) error { - msg := &common.IngestionMessage{ - IngestorId: e.id, - MessageType: "TASK_RESULT", - TaskResult: &common.TaskResult{ - TaskId: taskID, - Status: status, - ErrorMessage: errorMsg, - }, - } - return e.stream.Send(msg) -} - -func (e *Ingestor) sendTaskProgress(taskID string, progress int32, info string) error { - msg := &common.IngestionMessage{ - IngestorId: e.id, - MessageType: "TASK_PROGRESS", - TaskProgress: &common.TaskProgress{ - TaskId: taskID, - Progress: progress, - Info: info, - }, - } - return e.stream.Send(msg) -} - -func (e *Ingestor) receiveLoop() { for { - msg, err := e.stream.Recv() + var taskHandles []common.TaskHandle + taskHandles, err = msgQueueEngine.GetMessages(4) if err != nil { - if e.ctx.Err() != nil { - common.Info(fmt.Sprintf("Ingestor %s context cancelled, receive loop exiting", e.id)) - return + common.Error("error consuming message", err) + continue + } + for _, taskHandle := range taskHandles { + taskMessage := taskHandle.GetMessage() + common.Info(fmt.Sprintf("Received task id: %s, type: %s", taskMessage.TaskID, taskMessage.TaskType)) + if taskMessage.TaskType != common.TaskTypeIngestionTask { + common.Info(fmt.Sprintf("task %s is not an ingestion task", taskMessage.TaskID)) + err = taskHandle.Ack() + if err != nil { + common.Error(fmt.Sprintf("error ack task %s", taskMessage.TaskID), err) + return err + } + continue + } + var task *entity.IngestionTask + task, err = e.ingestionTaskDAO.SetRunningByIngestor(taskMessage.TaskID) + if err != nil { + if errors.Is(err, common.ErrTaskNotFound) { + common.Warn(fmt.Sprintf("task %s not found, skipping", taskMessage.TaskID)) + err = 
taskHandle.Ack() + if err != nil { + common.Error(fmt.Sprintf("error ack task %s", taskMessage.TaskID), err) + return err + } + continue + } else { + common.Error(fmt.Sprintf("error setting task %s to running", taskMessage.TaskID), err) + return err + } + } + if task == nil { + common.Info(fmt.Sprintf("task %s is already removed", taskMessage.TaskID)) + err = taskHandle.Ack() + if err != nil { + return err + } + continue } - common.Info(fmt.Sprintf("Receive error: %v", err)) - common.Info("Admin connection lost, attempting to reconnect") - e.reconnect() - return - } - switch msg.MessageType { - case "TASK_ASSIGNMENT": - e.handleTaskAssignment(msg.TaskAssignment) + switch task.Status { + case common.COMPLETED, common.STOPPED, common.FAILED: + common.Info(fmt.Sprintf("task %s is already %s", taskMessage.TaskID, task.Status)) + err = taskHandle.Ack() + if err != nil { + common.Error(fmt.Sprintf("error nack task %s", taskMessage.TaskID), err) + return err + } + continue + case common.STOPPING, common.CREATED: + err = fmt.Errorf("task %s is in unexpected status %s", taskMessage.TaskID, task.Status) + return err + case common.RUNNING: + } - case "ACK": - common.Info(fmt.Sprintf("Received ACK: task=%s, success=%v, msg=%s", - msg.AckInfo.TaskId, msg.AckInfo.Success, msg.AckInfo.Message)) + // Construct TaskContext with a cancellable context + ctx, cancel := context.WithCancel(e.ctx) + taskCtx := &TaskContext{ + Ctx: ctx, + CancelFunc: cancel, + Task: task, + TaskHandle: taskHandle, + } - case "ERROR": - common.Info(fmt.Sprintf("Received error from admin: %s", msg.ErrorMessage)) + // Register in currentTasks immediately so heartbeat sees PENDING state + //e.tasksMu.Lock() + //e.currentTasks[task.ID] = taskCtx + //e.tasksMu.Unlock() - default: - common.Info(fmt.Sprintf("Unknown admin message type: %s", msg.MessageType)) + // Push to task channel; if full, reject the task (backpressure) + select { + case e.taskChan <- taskCtx: + common.Info(fmt.Sprintf("Task %s queued 
(channel: %d/%d)", task.ID, len(e.taskChan), cap(e.taskChan))) + default: + common.Info(fmt.Sprintf("No available slot for task %s, failed", task.ID)) + + //e.tasksMu.Lock() + //delete(e.currentTasks, task.ID) + //e.tasksMu.Unlock() + + err = taskHandle.Nack() + if err != nil { + common.Error(fmt.Sprintf("error nack task %s", taskMessage.TaskID), err) + return err + } + } } } } -func (e *Ingestor) handleTaskAssignment(task *common.TaskAssignment) { - if task == nil { - return - } +//// Connect connects to the admin and establishes a bidirectional stream +//func (e *Ingestor) Connect(serverAddr string) error { +// e.serverAddr = serverAddr +// conn, err := grpc.Dial(serverAddr, +// grpc.WithTransportCredentials(insecure.NewCredentials()), +// grpc.WithBlock(), +// grpc.WithTimeout(5*time.Second), +// ) +// if err != nil { +// return fmt.Errorf("fail to connect admin server: %s", err.Error()) +// } +// e.conn = conn +// +// e.client = common.NewIngestionManagerClient(conn) +// +// stream, err := e.client.Action(e.ctx) +// if err != nil { +// conn.Close() +// return err +// } +// e.stream = stream +// +// common.Info(fmt.Sprintf("Ingestor %s connected to admin", e.id)) +// +// // 1. Send registration message +// if err = e.sendRegister(); err != nil { +// conn.Close() +// return err +// } +// +// // Ensure worker pool is started on first task +// e.startWorkerPool() +// +// // 2. Start receive loop +// go e.receiveLoop() +// +// // 3. 
Start heartbeat loop +// go e.heartbeatLoop() +// +// return nil +//} - common.Info(fmt.Sprintf("Received task: %s, task_type=%s", task.TaskId, task.TaskType)) - - switch task.TaskType { - case "shutdown_ingestor": - if e.id == task.AssignedTo { - e.handleShutdownIngestor() - return - } - - common.Error("unmatched ingestor id", fmt.Errorf("attempt to shutdown ingestor: %s, current ingestor: %s, mismatched", task.AssignedTo, e.id)) - return - case "cancel_ingestion_task": - e.handleCancelTask(task.TaskId) - return - } - - // Construct TaskContext with a cancellable context - ctx, cancel := context.WithCancel(e.ctx) - taskCtx := &TaskContext{ - Ctx: ctx, - CancelFunc: cancel, - Task: task, - Status: "QUEUED", - } - - // Register in currentTasks immediately so heartbeat sees PENDING state - e.tasksMu.Lock() - e.currentTasks[task.TaskId] = taskCtx - e.tasksMu.Unlock() - - common.Info("wait for 10 seconds") - time.Sleep(time.Second * 10) - // Push to task channel; if full, reject the task (backpressure) - select { - case e.taskChan <- taskCtx: - common.Info(fmt.Sprintf("Task %s queued (channel: %d/%d)", task.TaskId, len(e.taskChan), cap(e.taskChan))) - default: - common.Info(fmt.Sprintf("No available slot for task %s, rejecting", task.TaskId)) - //e.tasksMu.Lock() - //delete(e.currentTasks, task.TaskId) - //e.tasksMu.Unlock() - taskCtx.Status = "REJECTED" - taskCtx.EndTime = time.Now() - e.sendTaskResult(taskCtx.Task.TaskId, "REJECTED", "task rejected before execution") - } -} - -func (e *Ingestor) handleCancelTask(taskID string) { - e.tasksMu.Lock() - taskCtx, exists := e.currentTasks[taskID] - e.tasksMu.Unlock() - - if !exists { - common.Info(fmt.Sprintf("Cancel request for unknown task %s, ignoring", taskID)) - return - } - - common.Info(fmt.Sprintf("Cancelling task %s (current status: %s)", taskID, taskCtx.Status)) - taskCtx.CancelFunc() -} - -func (e *Ingestor) handleShutdownIngestor() { - common.Info(fmt.Sprintf("Shutdown task received, initiating graceful 
shutdown of ingestor %s", e.id)) - select { - case e.ShutdownCh <- struct{}{}: - default: - } - return -} +//func (e *Ingestor) sendRegister() error { +// msg := &common.IngestionMessage{ +// IngestorId: e.id, +// MessageType: "REGISTER", +// RegisterInfo: &common.RegisterInfo{ +// MaxConcurrency: e.maxConcurrency, +// SupportedDocTypes: e.supportedDocTypes, +// Version: e.version, +// Name: e.name, +// }, +// } +// return e.stream.Send(msg) +//} +// +//func (e *Ingestor) sendHeartbeat() error { +// e.tasksMu.RLock() +// +// cutoff := time.Now().Add(-10 * time.Minute) +// var toDeleteTask []string +// taskStates := make([]*common.TaskState, 0, len(e.currentTasks)) +// +// for tid, tc := range e.currentTasks { +// // Check if task is in a terminal state and expired beyond 10 minutes +// if (tc.Status == "CANCELED" || tc.Status == "COMPLETED" || tc.Status == "REJECTED") && +// !tc.EndTime.IsZero() && tc.EndTime.Before(cutoff) { +// toDeleteTask = append(toDeleteTask, tid) +// } else { +// taskStates = append(taskStates, &common.TaskState{ +// TaskId: tid, +// Status: tc.Status, +// EstimatedRemainingTimeSeconds: int64(tc.estimatedRemainingTime), +// ErrorMessage: tc.ErrorMessage, +// StartTime: tc.StartTime.UnixNano(), +// ComeFrom: tc.Task.ComeFrom, +// }) +// } +// } +// e.tasksMu.RUnlock() +// +// // Delete expired terminal tasks from currentTasks +// if len(toDeleteTask) > 0 { +// e.tasksMu.Lock() +// for _, id := range toDeleteTask { +// delete(e.currentTasks, id) +// } +// e.tasksMu.Unlock() +// } +// +// var pid = int64(os.Getpid()) +// p, err := process.NewProcess(int32(pid)) +// if err != nil { +// log.Fatal(err) +// } +// +// var cpuPercent float64 +// cpuPercent, err = p.Percent(100 * time.Millisecond) +// if err != nil { +// cpuPercent = math.NaN() +// common.Info(fmt.Sprintf("Fail to read CPU usage: %v", err)) +// } +// +// RssUsage := math.NaN() +// VmsUsage := math.NaN() +// memInfo, err := p.MemoryInfo() +// if err == nil { +// RssUsage = 
float64(memInfo.RSS) +// VmsUsage = float64(memInfo.VMS) +// } else { +// common.Info(fmt.Sprintf("Fail to read memory usage: %v", err)) +// } +// msg := &common.IngestionMessage{ +// IngestorId: e.id, +// MessageType: "HEARTBEAT", +// HeartbeatInfo: &common.HeartbeatInfo{ +// TaskStates: taskStates, +// DeleteTaskIds: toDeleteTask, +// CpuUsage: float32(cpuPercent), +// VmsUsage: float32(VmsUsage), +// RssUsage: float32(RssUsage), +// ProcessId: pid, +// }, +// } +// return e.stream.Send(msg) +//} +// +//func (e *Ingestor) sendTaskResult(taskID, status, errorMsg string) error { +// msg := &common.IngestionMessage{ +// IngestorId: e.id, +// MessageType: "TASK_RESULT", +// TaskResult: &common.TaskResult{ +// TaskId: taskID, +// Status: status, +// ErrorMessage: errorMsg, +// }, +// } +// return e.stream.Send(msg) +//} +// +//func (e *Ingestor) sendTaskProgress(taskID string, progress int32, info string) error { +// msg := &common.IngestionMessage{ +// IngestorId: e.id, +// MessageType: "TASK_PROGRESS", +// TaskProgress: &common.TaskProgress{ +// TaskId: taskID, +// Progress: progress, +// Info: info, +// }, +// } +// return e.stream.Send(msg) +//} +// +//func (e *Ingestor) receiveLoop() { +// for { +// msg, err := e.stream.Recv() +// if err != nil { +// if e.ctx.Err() != nil { +// common.Info(fmt.Sprintf("Ingestor %s context cancelled, receive loop exiting", e.id)) +// return +// } +// common.Info(fmt.Sprintf("Receive error: %v", err)) +// common.Info("Admin connection lost, attempting to reconnect") +// e.reconnect() +// return +// } +// +// switch msg.MessageType { +// case "TASK_ASSIGNMENT": +// e.handleTaskAssignment(msg.TaskAssignment) +// +// case "ACK": +// common.Info(fmt.Sprintf("Received ACK: task=%s, success=%v, msg=%s", +// msg.AckInfo.TaskId, msg.AckInfo.Success, msg.AckInfo.Message)) +// +// case "ERROR": +// common.Info(fmt.Sprintf("Received error from admin: %s", msg.ErrorMessage)) +// +// default: +// common.Info(fmt.Sprintf("Unknown admin message 
type: %s", msg.MessageType)) +// } +// } +//} +// +//func (e *Ingestor) handleTaskAssignment(task *common.TaskAssignment) { +// if task == nil { +// return +// } +// +// common.Info(fmt.Sprintf("Received task: %s, task_type=%s", task.TaskId, task.TaskType)) +// +// switch task.TaskType { +// case "shutdown_ingestor": +// if e.id == task.AssignedTo { +// e.handleShutdownIngestor() +// return +// } +// +// common.Error("unmatched ingestor id", fmt.Errorf("attempt to shutdown ingestor: %s, current ingestor: %s, mismatched", task.AssignedTo, e.id)) +// return +// case "cancel_ingestion_task": +// e.handleCancelTask(task.TaskId) +// return +// case "start_ingestion_task": +// // create ingestion task log +// err := e.ingestionTaskLogDAO.Create(&entity.IngestionTaskLog{ +// TaskID: task.TaskId, +// Action: "CREATED", +// }) +// if err != nil { +// common.Fatal(fmt.Sprintf("Failed to create ingestion task log for task %s: %v", task.TaskId, err)) +// return +// } +// } +// +// // Construct TaskContext with a cancellable context +// ctx, cancel := context.WithCancel(e.ctx) +// taskCtx := &TaskContext{ +// Ctx: ctx, +// CancelFunc: cancel, +// Task: task, +// Status: "QUEUED", +// } +// +// // Register in currentTasks immediately so heartbeat sees PENDING state +// e.tasksMu.Lock() +// e.currentTasks[task.TaskId] = taskCtx +// e.tasksMu.Unlock() +// +// common.Info("wait for 10 seconds") +// time.Sleep(time.Second * 10) +// // Push to task channel; if full, reject the task (backpressure) +// select { +// case e.taskChan <- taskCtx: +// common.Info(fmt.Sprintf("Task %s queued (channel: %d/%d)", task.TaskId, len(e.taskChan), cap(e.taskChan))) +// default: +// common.Info(fmt.Sprintf("No available slot for task %s, rejecting", task.TaskId)) +// //e.tasksMu.Lock() +// //delete(e.currentTasks, task.TaskId) +// //e.tasksMu.Unlock() +// taskCtx.Status = "REJECTED" +// taskCtx.EndTime = time.Now() +// e.sendTaskResult(taskCtx.Task.TaskId, "REJECTED", "task rejected before 
execution") +// } +//} +// +//func (e *Ingestor) handleCancelTask(taskID string) { +// e.tasksMu.Lock() +// taskCtx, exists := e.currentTasks[taskID] +// e.tasksMu.Unlock() +// +// if !exists { +// common.Info(fmt.Sprintf("Cancel request for unknown task %s, ignoring", taskID)) +// return +// } +// +// common.Info(fmt.Sprintf("Cancelling task %s (current status: %s)", taskID, taskCtx.Status)) +// taskCtx.CancelFunc() +//} +// +//func (e *Ingestor) handleShutdownIngestor() { +// common.Info(fmt.Sprintf("Shutdown task received, initiating graceful shutdown of ingestor %s", e.id)) +// select { +// case e.ShutdownCh <- struct{}{}: +// default: +// } +// return +//} func (e *Ingestor) startWorkerPool() { e.startOnce.Do(func() { @@ -355,26 +485,11 @@ func (e *Ingestor) workerLoop(id int32) { case <-e.ctx.Done(): return case taskCtx := <-e.taskChan: - // Skip tasks that were canceled while queued - select { - case <-taskCtx.Ctx.Done(): - common.Info(fmt.Sprintf("Task %s was cancelled while queued, skipping", taskCtx.Task.TaskId)) - taskCtx.Status = "CANCELED" - taskCtx.EndTime = time.Now() - //e.tasksMu.Lock() - //delete(e.currentTasks, taskCtx.Task.TaskId) - //e.tasksMu.Unlock() - e.sendTaskResult(taskCtx.Task.TaskId, "CANCELED", "task cancelled before execution") - continue - default: + if taskCtx.Tasklet != nil { + e.executeTasklet(taskCtx) + } else { + e.executeTask(taskCtx) } - - // Mark as RUNNING - taskCtx.Status = "RUNNING" - taskCtx.StartTime = time.Now() - - // Execute the task (synchronously within worker) - e.executeTask(taskCtx) } } } @@ -388,193 +503,225 @@ func (e *Ingestor) executeTask(taskCtx *TaskContext) { ctx := taskCtx.Ctx task := taskCtx.Task - common.Info(fmt.Sprintf("Starting task %s", task.TaskId)) + common.Info(fmt.Sprintf("Starting task %s", task.ID)) - // Simulate task execution progress - // In production, this would split into subtasks and execute in parallel - for progress := int32(0); progress <= 100; progress += 10 { + latestLog, err := 
e.ingestionTaskLogDAO.LatestLogByTaskID(task.ID) + if err != nil { + latestLog = &entity.IngestionTaskLog{ + ID: 0, + TaskID: task.ID, + Checkpoint: entity.JSONMap{ + "current_step": 1, + "total_step": 5, + }, + } + err = e.ingestionTaskLogDAO.Create(latestLog) + if err != nil { + common.Error(fmt.Sprintf("Failed to create task log for task %s", task.ID), err) + return + } + } + + var checkpointMap map[string]interface{} + checkpointMap = latestLog.Checkpoint + currentStep, ok := common.GetInt(checkpointMap["current_step"]) + if !ok { + common.Fatal(fmt.Sprintf("Failed to get current step from task log for task %s", task.ID)) + return + } + totalStep, ok := common.GetInt(checkpointMap["total_step"]) + if !ok { + common.Fatal(fmt.Sprintf("Failed to get current step from task log for task %s", task.ID)) + return + } + for i := currentStep; i < totalStep; i++ { select { case <-ctx.Done(): // Task canceled - common.Info(fmt.Sprintf("Task %s cancelled", task.TaskId)) - taskCtx.Status = "CANCELED" - taskCtx.EndTime = time.Now() - e.sendTaskResult(task.TaskId, "CANCELED", "task cancelled") + common.Info(fmt.Sprintf("Task %s stopped", task.ID)) return case <-time.After(5000 * time.Millisecond): - // Simulate progress update - taskCtx.Progress = progress - e.sendTaskProgress(task.TaskId, progress, "processing...") - common.Info(fmt.Sprintf("Task %s progress: %d%%", task.TaskId, progress)) - } - } + common.Info(fmt.Sprintf("Task %s is running step %d", task.ID, i)) + checkpointMap["current_step"] = i + 1 + latestLog.Checkpoint = checkpointMap + latestLog.ID++ + err = latestLog.UpdateCreateDateAndTime() + if err != nil { + common.Error(fmt.Sprintf("Failed to update date and time of task log for task %s", task.ID), err) + return + } - // Simulate subtask splitting and execution (demonstration) - e.executeWithSubTasks(task) - - taskCtx.Status = "COMPLETED" - taskCtx.EndTime = time.Now() - - time.Sleep(time.Second * 10) - - // Task completed - e.sendTaskResult(task.TaskId, 
"COMPLETED", "") - common.Info(fmt.Sprintf("Task %s completed", task.TaskId)) -} - -// executeWithSubTasks demonstrates subtask splitting and parallel execution -func (e *Ingestor) executeWithSubTasks(task *common.TaskAssignment) { - common.Info(fmt.Sprintf("Task %s: splitting into subtasks", task.TaskId)) - - // Simulate splitting into 4 subtasks - subTasks := []struct { - id string - index int - }{ - {task.TaskId + "-sub1", 1}, - {task.TaskId + "-sub2", 2}, - {task.TaskId + "-sub3", 3}, - {task.TaskId + "-sub4", 4}, - } - - // Wait for all subtasks to complete - var wg sync.WaitGroup - results := make(chan error, len(subTasks)) - - // Execute subtasks in parallel - for _, st := range subTasks { - wg.Add(1) - go func(subID string, idx int) { - defer wg.Done() - - common.Info(fmt.Sprintf("Subtask %s started", subID)) - // Simulate subtask execution - time.Sleep(1 * time.Second) - common.Info(fmt.Sprintf("Subtask %s completed", subID)) - results <- nil - }(st.id, st.index) - } - - // Wait for all subtasks to complete - wg.Wait() - close(results) - - // Check if any subtasks failed - failedCount := 0 - for err := range results { - if err != nil { - failedCount++ - } - } - - if failedCount > 0 { - common.Info(fmt.Sprintf("Task %s: %d subtasks failed", task.TaskId, failedCount)) - } else { - common.Info(fmt.Sprintf("Task %s: all subtasks completed successfully", task.TaskId)) - } -} - -func (e *Ingestor) heartbeatLoop() { - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - - for { - select { - case <-e.ctx.Done(): - return - case <-ticker.C: - if err := e.sendHeartbeat(); err != nil { - common.Info(fmt.Sprintf("Failed to send heartbeat: %v", err)) - if e.ctx.Err() != nil { - common.Info(fmt.Sprintf("Ingestor %s context cancelled, heartbeat loop exiting", e.id)) - return - } - common.Info(fmt.Sprintf("Admin connection lost, attempting to reconnect")) - e.reconnect() + err = e.ingestionTaskLogDAO.Create(latestLog) + if err != nil { + 
common.Error(fmt.Sprintf("Failed to create task log for task %s", task.ID), err) return } } } -} -// reconnect closes the old connection and establishes a new one with exponential backoff. -// Only one reconnection attempt runs at a time; concurrent callers return immediately. -func (e *Ingestor) reconnect() { - if e.ctx.Err() != nil { - common.Info(fmt.Sprintf("Ingestor %s is shutting down, skipping reconnection", e.id)) + err = e.ingestionTaskDAO.UpdateStatus(task.ID, common.COMPLETED) + if err != nil { + common.Error(fmt.Sprintf("Task %s update status failed", task.ID), err) return } - if !e.reconnectMu.TryLock() { + common.Info(fmt.Sprintf("Task %s completed", task.ID)) +} + +func (e *Ingestor) executeTasklet(taskCtx *TaskContext) { + ctx := taskCtx.Ctx + tasklet := taskCtx.Tasklet + common.Info(fmt.Sprintf("Starting tasklet %s", tasklet.ID)) + + latestLog, err := e.ingestionTaskletLogDAO.LatestLogByTaskletID(tasklet.ID) + if err != nil { + latestLog = &entity.IngestionTaskletLog{ + TaskletID: tasklet.ID, + Checkpoint: entity.JSONMap{ + "current_step": 0, + "total_step": 3, + }, + } + err = e.ingestionTaskletLogDAO.Create(latestLog) + if err != nil { + common.Error(fmt.Sprintf("Failed to create task log for tasklet %s", tasklet.ID), err) + return + } + } + + var checkpointMap map[string]interface{} + checkpointMap = latestLog.Checkpoint + currentStep := checkpointMap["current_step"].(int) + totalStep := checkpointMap["total_step"].(int) + for i := currentStep; i < totalStep; i++ { + select { + case <-ctx.Done(): + // Task canceled + common.Info(fmt.Sprintf("Tasklet %s stopped", tasklet.ID)) + return + case <-time.After(3000 * time.Millisecond): + common.Info(fmt.Sprintf("Tasklet %s is running step %d", tasklet.ID, i)) + checkpointMap["current_step"] = i + 1 + latestLog.Checkpoint = checkpointMap + err = e.ingestionTaskletLogDAO.Create(latestLog) + if err != nil { + common.Error(fmt.Sprintf("Failed to update task log for tasklet %s", tasklet.ID), err) + return + 
} + } + } + + err = e.ingestionTaskletDAO.UpdateStatus(tasklet.ID, common.STOPPED) + if err != nil { + common.Error(fmt.Sprintf("Tasklet %s update status failed", tasklet.ID), err) return } - defer e.reconnectMu.Unlock() - common.Info(fmt.Sprintf("Ingestor %s attempting to reconnect to admin at %s", e.id, e.serverAddr)) - - // Close old stream and connection - if e.stream != nil { - e.stream.CloseSend() - } - if e.conn != nil { - e.conn.Close() - } - - backoff := 1 * time.Second - maxBackoff := 30 * time.Second - - for { - conn, err := grpc.Dial(e.serverAddr, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithBlock(), - grpc.WithTimeout(5*time.Second), - ) - if err != nil { - common.Info(fmt.Sprintf("Reconnect dial failed: %v, retrying in %v", err, backoff)) - time.Sleep(backoff) - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } - continue - } - e.conn = conn - e.client = common.NewIngestionManagerClient(conn) - - stream, err := e.client.Action(e.ctx) - if err != nil { - conn.Close() - common.Info(fmt.Sprintf("Reconnect create stream failed: %v, retrying in %v", err, backoff)) - time.Sleep(backoff) - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } - continue - } - e.stream = stream - - if err = e.sendRegister(); err != nil { - stream.CloseSend() - conn.Close() - common.Info(fmt.Sprintf("Reconnect register failed: %v, retrying in %v", err, backoff)) - time.Sleep(backoff) - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } - continue - } - - common.Info(fmt.Sprintf("Ingestor %s reconnected to admin", e.id)) - break - } - - // Restart the loops on the new stream - go e.receiveLoop() - go e.heartbeatLoop() + common.Info(fmt.Sprintf("Tasklet %s completed", tasklet.ID)) } +// +//func (e *Ingestor) heartbeatLoop() { +// ticker := time.NewTicker(5 * time.Second) +// defer ticker.Stop() +// +// for { +// select { +// case <-e.ctx.Done(): +// return +// case <-ticker.C: +// if err := 
e.sendHeartbeat(); err != nil { +// common.Info(fmt.Sprintf("Failed to send heartbeat: %v", err)) +// if e.ctx.Err() != nil { +// common.Info(fmt.Sprintf("Ingestor %s context cancelled, heartbeat loop exiting", e.id)) +// return +// } +// common.Info(fmt.Sprintf("Admin connection lost, attempting to reconnect")) +// e.reconnect() +// return +// } +// } +// } +//} +// +//// reconnect closes the old connection and establishes a new one with exponential backoff. +//// Only one reconnection attempt runs at a time; concurrent callers return immediately. +//func (e *Ingestor) reconnect() { +// if e.ctx.Err() != nil { +// common.Info(fmt.Sprintf("Ingestor %s is shutting down, skipping reconnection", e.id)) +// return +// } +// +// if !e.reconnectMu.TryLock() { +// return +// } +// defer e.reconnectMu.Unlock() +// +// common.Info(fmt.Sprintf("Ingestor %s attempting to reconnect to admin at %s", e.id, e.serverAddr)) +// +// // Close old stream and connection +// if e.stream != nil { +// e.stream.CloseSend() +// } +// if e.conn != nil { +// e.conn.Close() +// } +// +// backoff := 1 * time.Second +// maxBackoff := 30 * time.Second +// +// for { +// conn, err := grpc.Dial(e.serverAddr, +// grpc.WithTransportCredentials(insecure.NewCredentials()), +// grpc.WithBlock(), +// grpc.WithTimeout(5*time.Second), +// ) +// if err != nil { +// common.Info(fmt.Sprintf("Reconnect dial failed: %v, retrying in %v", err, backoff)) +// time.Sleep(backoff) +// backoff *= 2 +// if backoff > maxBackoff { +// backoff = maxBackoff +// } +// continue +// } +// e.conn = conn +// e.client = common.NewIngestionManagerClient(conn) +// +// stream, err := e.client.Action(e.ctx) +// if err != nil { +// conn.Close() +// common.Info(fmt.Sprintf("Reconnect create stream failed: %v, retrying in %v", err, backoff)) +// time.Sleep(backoff) +// backoff *= 2 +// if backoff > maxBackoff { +// backoff = maxBackoff +// } +// continue +// } +// e.stream = stream +// +// if err = e.sendRegister(); err != nil { +// 
stream.CloseSend() +// conn.Close() +// common.Info(fmt.Sprintf("Reconnect register failed: %v, retrying in %v", err, backoff)) +// time.Sleep(backoff) +// backoff *= 2 +// if backoff > maxBackoff { +// backoff = maxBackoff +// } +// continue +// } +// +// common.Info(fmt.Sprintf("Ingestor %s reconnected to admin", e.id)) +// break +// } +// +// // Restart the loops on the new stream +// go e.receiveLoop() +// go e.heartbeatLoop() +//} + // Stop gracefully shuts down the ingestor func (e *Ingestor) Stop() { common.Info(fmt.Sprintf("Stopping ingestor %s", e.id)) @@ -583,8 +730,4 @@ func (e *Ingestor) Stop() { // Wait for all workers to finish (they exit on ctx.Done()) e.workerWg.Wait() common.Info("All tasks completed") - - if e.stream != nil { - e.stream.CloseSend() - } } diff --git a/internal/router/router.go b/internal/router/router.go index e798654460..acc355f5b4 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -271,8 +271,12 @@ func (r *Router) Setup(engine *gin.Engine) { // Dataset document chunk datasets.GET("/:dataset_id/documents/:document_id/chunks/:chunk_id", r.chunkHandler.Get) - datasets.POST("/:dataset_id/documents/parse", r.documentHandler.ParseDocuments) - datasets.POST("/:dataset_id/documents/stop", r.documentHandler.StopParseDocuments) + datasets.POST("/:dataset_id/documents/parse", r.documentHandler.StartIngestionTask) + datasets.GET("/ingestion/tasks", r.documentHandler.ListIngestionTasks) + datasets.PUT("/ingestion/tasks", r.documentHandler.StopIngestionTasks) + datasets.DELETE("/ingestion/tasks", r.documentHandler.RemoveIngestionTasks) + //datasets.POST("/:dataset_id/documents/parse", r.documentHandler.ParseDocuments) + //datasets.POST("/:dataset_id/documents/stop", r.documentHandler.StopParseDocuments) datasets.DELETE("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.RemoveChunks) datasets.PUT("/:dataset_id/documents/:document_id/metadata/config", r.datasetsHandler.UpdateDocumentMetadataConfig) } diff --git 
a/internal/server/config.go b/internal/server/config.go index dc11d617af..620d341199 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -41,6 +41,7 @@ type Config struct { Authentication AuthenticationConfig `mapstructure:"authentication"` Database DatabaseConfig `mapstructure:"database"` Redis RedisConfig `mapstructure:"redis"` + Nats NatsConfig `mapstructure:"nats"` Log LogConfig `mapstructure:"log"` DocEngine DocEngineConfig `mapstructure:"doc_engine"` StorageEngine StorageConfig `mapstructure:"storage_engine"` @@ -51,13 +52,13 @@ type Config struct { UserDefaultLLM UserDefaultLLMConfig `mapstructure:"user_default_llm"` DefaultSuperUser DefaultSuperUser `mapstructure:"default_super_user"` Language string `mapstructure:"language"` + TaskExecutor TaskExecutorConfig `mapstructure:"task_executor"` } // AdminConfig admin server configuration type AdminConfig struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"http_port"` - IngestionManagerPort int `mapstructure:"ingestion_manager_port"` + Host string `mapstructure:"host"` + Port int `mapstructure:"http_port"` } type AuthenticationConfig struct { @@ -71,6 +72,10 @@ type DefaultSuperUser struct { Nickname string `mapstructure:"nickname"` } +type TaskExecutorConfig struct { + MessageQueueType string `mapstructure:"message_queue_type"` +} + // UserDefaultLLMConfig user default LLM configuration type UserDefaultLLMConfig struct { DefaultModels DefaultModelsConfig `mapstructure:"default_models"` @@ -230,6 +235,11 @@ type RedisConfig struct { DB int `mapstructure:"db"` } +type NatsConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` +} + var ( globalConfig *Config globalViper *viper.Viper @@ -371,6 +381,13 @@ func Init(configPath string) error { "message_queue_type": mqType, } delete(configDict, "message_queue_type") + case "nats": + host := getString(configDict, "host") + port := getInt(configDict, "port") + configDict["id"] = id + configDict["name"] 
= "nats" + configDict["host"] = host + configDict["port"] = port case "admin": // Skip admin section continue @@ -584,9 +601,6 @@ func FromConfigFile(configPath string) error { } else { globalConfig.Admin.Port += 2 } - if globalConfig.Admin.IngestionManagerPort == 0 { - globalConfig.Admin.IngestionManagerPort = 9385 - } // authentication section if globalConfig != nil { diff --git a/internal/service/heartbeat_sender.go b/internal/service/admin_client.go similarity index 88% rename from internal/service/heartbeat_sender.go rename to internal/service/admin_client.go index 63709559fb..05c3be0057 100644 --- a/internal/service/heartbeat_sender.go +++ b/internal/service/admin_client.go @@ -28,8 +28,10 @@ import ( "go.uber.org/zap" ) -// HeartbeatSender is responsible for sending heartbeat reports to the admin server -type HeartbeatSender struct { +var AdminServiceClient *AdminClient + +// AdminClient is responsible for sending heartbeat reports to the admin server +type AdminClient struct { client *utility.HTTPClient logger *zap.Logger serverType common.ServerType @@ -41,9 +43,9 @@ type HeartbeatSender struct { attemptCount int } -// NewHeartbeatSender creates a new heartbeat service instance -func NewHeartbeatSender(logger *zap.Logger, serverType common.ServerType, serverName, host string, port int) *HeartbeatSender { - return &HeartbeatSender{ +// NewAdminClient creates a new heartbeat service instance +func NewAdminClient(logger *zap.Logger, serverType common.ServerType, serverName, host string, port int) *AdminClient { + return &AdminClient{ logger: logger, serverType: serverType, serverName: serverName, @@ -56,7 +58,7 @@ func NewHeartbeatSender(logger *zap.Logger, serverType common.ServerType, server } // InitHTTPClient initializes the HTTP client with admin server configuration -func (h *HeartbeatSender) InitHTTPClient() error { +func (h *AdminClient) InitHTTPClient() error { adminConfig := server.GetAdminConfig() if adminConfig == nil { return fmt.Errorf("admin 
configuration not found") @@ -77,7 +79,7 @@ func (h *HeartbeatSender) InitHTTPClient() error { } // SendHeartbeat sends a heartbeat message to the admin server -func (h *HeartbeatSender) SendHeartbeat() error { +func (h *AdminClient) SendHeartbeat() error { if h.attemptCount < 10 { if h.lastSuccess { diff --git a/internal/service/document.go b/internal/service/document.go index 891a4dd35f..b9ec341527 100644 --- a/internal/service/document.go +++ b/internal/service/document.go @@ -36,38 +36,36 @@ import ( "ragflow/internal/dao" "ragflow/internal/engine" - "ragflow/internal/server" - - "github.com/google/uuid" "gorm.io/gorm" + "ragflow/internal/server" ) // DocumentService document service type DocumentService struct { - documentDAO *dao.DocumentDAO - kbDAO *dao.KnowledgebaseDAO - ingestionTaskDAO *dao.IngestionDAO - ingestionLogDAO *dao.IngestionLogDAO - docEngine engine.DocEngine - engineType server.EngineType - metadataSvc *MetadataService - taskDAO *dao.TaskDAO - file2DocumentDAO *dao.File2DocumentDAO + documentDAO *dao.DocumentDAO + kbDAO *dao.KnowledgebaseDAO + ingestionTaskDAO *dao.IngestionTaskDAO + ingestionTaskLogDAO *dao.IngestionTaskLogDAO + docEngine engine.DocEngine + engineType server.EngineType + metadataSvc *MetadataService + taskDAO *dao.TaskDAO + file2DocumentDAO *dao.File2DocumentDAO } // NewDocumentService create document service func NewDocumentService() *DocumentService { cfg := server.GetConfig() return &DocumentService{ - documentDAO: dao.NewDocumentDAO(), - ingestionTaskDAO: dao.NewIngestionDAO(), - ingestionLogDAO: dao.NewIngestionLogDAO(), - kbDAO: dao.NewKnowledgebaseDAO(), - docEngine: engine.Get(), - engineType: cfg.DocEngine.Type, - metadataSvc: NewMetadataService(), - taskDAO: dao.NewTaskDAO(), - file2DocumentDAO: dao.NewFile2DocumentDAO(), + documentDAO: dao.NewDocumentDAO(), + ingestionTaskDAO: dao.NewIngestionTaskDAO(), + ingestionTaskLogDAO: dao.NewIngestionTaskLogDAO(), + kbDAO: dao.NewKnowledgebaseDAO(), + docEngine: 
engine.Get(), + engineType: cfg.DocEngine.Type, + metadataSvc: NewMetadataService(), + taskDAO: dao.NewTaskDAO(), + file2DocumentDAO: dao.NewFile2DocumentDAO(), } } @@ -715,9 +713,136 @@ func (s *DocumentService) GetDocumentsByAuthorID(authorID, page, pageSize int) ( return responses, total, nil } +func (s *DocumentService) ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error) { + offset := (page - 1) * pageSize + + var tasks []*entity.IngestionTask + var err error + if datasetID == nil { + tasks, err = s.ingestionTaskDAO.ListByUserID(userID, offset, pageSize) + } else { + tasks, err = s.ingestionTaskDAO.ListByUserIDAndDatasetID(userID, *datasetID, offset, pageSize) + } + + if err != nil { + return nil, err + } + + return tasks, nil +} + type ParseDocumentResponse struct { - DocumentID string `json:"document_id"` - Result *string `json:"result"` + DocumentID string `json:"document_id"` + Result string `json:"result"` +} + +func (s *DocumentService) IngestDocuments(datasetID, userID string, docIDs []string) ([]*ParseDocumentResponse, error) { + // deduplicate the document id + uniqueDocIDs := common.Deduplicate(docIDs) + if uniqueDocIDs == nil || len(uniqueDocIDs) == 0 { + return nil, fmt.Errorf("no documents to parse") + } + + var responses []*ParseDocumentResponse + + // query database, if the document ids are valid + for _, docID := range uniqueDocIDs { + doc, err := s.documentDAO.GetByID(docID) + + if err != nil { + errorMessage := err.Error() + responses = append(responses, &ParseDocumentResponse{ + DocumentID: docID, + Result: errorMessage, + }) + continue + } + + if doc == nil { + errorMessage := "no such document" + responses = append(responses, &ParseDocumentResponse{ + DocumentID: docID, + Result: errorMessage, + }) + continue + } + + task := &entity.IngestionTask{ + DocumentID: docID, + UserID: userID, + DatasetID: datasetID, + Schema: nil, + Status: common.CREATED, + } + + // save the task to database + 
task, err = s.ingestionTaskDAO.CheckAndCreate(task) + if err != nil { + errorMessage := err.Error() + responses = append(responses, &ParseDocumentResponse{ + DocumentID: docID, + Result: errorMessage, + }) + continue + } + + msgQueueEngine := engine.GetMessageQueueEngine() + + taskMessage := common.TaskMessage{ + TaskID: task.ID, + TaskType: common.TaskTypeIngestionTask, + } + + // convert task + taskMessageStr, err := json.Marshal(taskMessage) + if err != nil { + return nil, err + } + + err = msgQueueEngine.PublishTask("tasks.RAGFLOW", taskMessageStr) + if err != nil { + return nil, err + } + + responses = append(responses, &ParseDocumentResponse{ + DocumentID: docID, + Result: fmt.Sprintf("task_id: %s", task.ID), + }) + } + + common.Info(fmt.Sprintf("parse documents, dataset: %s, documents: %v", datasetID, docIDs)) + return responses, nil +} + +func (s *DocumentService) StopIngestionTasks(tasks []string, userID string) ([]*entity.IngestionTask, error) { + + var taskResponses []*entity.IngestionTask + for _, taskID := range tasks { + task, err := s.ingestionTaskDAO.SetStoppingByAPIServer(taskID) + if err != nil { + return nil, err + } + taskResponses = append(taskResponses, task) + } + return taskResponses, nil +} + +func (s *DocumentService) RemoveIngestionTasks(tasks []string, userID string) ([]map[string]string, error) { + + var deletedTasks []map[string]string + for _, taskID := range tasks { + taskRecord := map[string]string{ + "task_id": taskID, + } + _, err := s.ingestionTaskDAO.RemoveByAPIServerOrAdminServer(taskID, &userID) + if err != nil { + taskRecord["remove"] = fmt.Sprintf("fail: %s", err.Error()) + } else { + taskRecord["remove"] = "success" + } + deletedTasks = append(deletedTasks, taskRecord) + } + return deletedTasks, nil } func (s *DocumentService) ParseDocuments(datasetID, userID string, docIDs []string) ([]*ParseDocumentResponse, error) { @@ -740,7 +865,7 @@ func (s *DocumentService) ParseDocuments(datasetID, userID string, docIDs []stri 
errorMessage := err.Error() responses = append(responses, &ParseDocumentResponse{ DocumentID: docID, - Result: &errorMessage, + Result: errorMessage, }) continue } @@ -748,7 +873,7 @@ func (s *DocumentService) ParseDocuments(datasetID, userID string, docIDs []stri errorMessage := "no such document" responses = append(responses, &ParseDocumentResponse{ DocumentID: docID, - Result: &errorMessage, + Result: errorMessage, }) continue } @@ -757,30 +882,28 @@ func (s *DocumentService) ParseDocuments(datasetID, userID string, docIDs []stri errorMessage := fmt.Sprintf("document %s is already parsed", docID) responses = append(responses, &ParseDocumentResponse{ DocumentID: docID, - Result: &errorMessage, + Result: errorMessage, }) continue } // create task for each document - task := &entity.IngestionTask{ - ID: uuid.New().String(), - DocumentID: docID, - UserID: userID, - Config: nil, - TryCount: 1, - } + //task := &entity.IngestionTask{ + // ID: utility.GenerateToken(), + // DocumentID: docID, + // UserID: userID, + //} // save the task to database - err = s.ingestionTaskDAO.Create(task) - if err != nil { - errorMessage := err.Error() - responses = append(responses, &ParseDocumentResponse{ - DocumentID: docID, - Result: &errorMessage, - }) - continue - } + //err = s.ingestionTaskDAO.Create(task) + //if err != nil { + // errorMessage := err.Error() + // responses = append(responses, &ParseDocumentResponse{ + // DocumentID: docID, + // Result: &errorMessage, + // }) + // continue + //} // Send task to message queue diff --git a/internal/storage/types.go b/internal/storage/types.go index 0d15ba5556..4d48ea8ac1 100644 --- a/internal/storage/types.go +++ b/internal/storage/types.go @@ -17,17 +17,9 @@ package storage import ( - "errors" "time" ) -var ( - // ErrNotFound is returned when an object is not found - ErrNotFound = errors.New("object not found") - // ErrBucketNotFound is returned when a bucket is not found - ErrBucketNotFound = errors.New("bucket not found") -) - // 
StorageType represents the type of storage backend type StorageType int