diff --git a/internal/handler/chunk.go b/internal/handler/chunk.go index fd8e63d7b7..fd667adb81 100644 --- a/internal/handler/chunk.go +++ b/internal/handler/chunk.go @@ -38,6 +38,7 @@ type chunkService interface { UpdateChunk(req *service.UpdateChunkRequest, userID string) error RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error) Parse(userID, datasetID string, req *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error) + StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) } // ChunkHandler chunk handler @@ -224,8 +225,8 @@ func (h *ChunkHandler) Parse(c *gin.Context) { }) return } - datasetID := strings.TrimSpace(c.Param("dataset_id")) - if datasetID == "" { + datasetId := strings.TrimSpace(c.Param("dataset_id")) + if datasetId == "" { c.JSON(http.StatusBadRequest, gin.H{ "code": common.CodeBadRequest, "message": "dataset_id is required", @@ -243,7 +244,7 @@ func (h *ChunkHandler) Parse(c *gin.Context) { return } - data, code, err := h.chunkService.Parse(userID, datasetID, &req) + data, code, err := h.chunkService.Parse(userID, datasetId, &req) if code != common.CodeSuccess { c.JSON(http.StatusOK, gin.H{ "code": code, @@ -353,6 +354,59 @@ func parseAvailableQuery(raw string) (int, bool, error) { } } +func (h *ChunkHandler) StopParsing(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := c.Param("dataset_id") + if datasetID == "" { + jsonError(c, common.CodeDataError, "dataset_id is required") + return + } + + var req service.StopParsingRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + if len(req.DocumentIDs) == 0 { + jsonError(c, common.CodeDataError, "`document_ids` is required") + return + } + + resp, code, err := h.chunkService.StopParsing(user.ID, 
datasetID, req) + if err != nil { + var data interface{} + if resp != nil { + data = resp.Data + } + c.JSON(http.StatusOK, gin.H{ + "code": code, + "data": data, + "message": err.Error(), + }) + return + } + + message := "success" + var data interface{} + if resp != nil { + if resp.Message != "" { + message = resp.Message + } + data = resp.Data + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "data": data, + "message": message, + }) +} + // List retrieves chunks for a document. // @Summary List Chunks // @Description Retrieve paginated chunks for a document with optional filtering. diff --git a/internal/handler/chunk_test.go b/internal/handler/chunk_test.go index 355a875b20..a5ced9f270 100644 --- a/internal/handler/chunk_test.go +++ b/internal/handler/chunk_test.go @@ -23,6 +23,7 @@ type mockChunkSvc struct { listFn func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error) switchChunksFn func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error updateChunkFn func(req *service.UpdateChunkRequest, userID string) error + stopParsingFn func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) } func (m *mockChunkSvc) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) { @@ -58,6 +59,12 @@ func (m *mockChunkSvc) UpdateChunk(req *service.UpdateChunkRequest, userID strin func (m *mockChunkSvc) RemoveChunks(*service.RemoveChunksRequest, string) (int64, error) { panic("not implemented") } +func (m *mockChunkSvc) StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) { + if m.stopParsingFn != nil { + return m.stopParsingFn(userID, datasetID, req) + } + panic("not implemented") +} func (m *mockChunkSvc) Parse(string, string, *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error) { panic("not 
implemented") } @@ -84,6 +91,18 @@ func setupChunkRetrievalTestNoAuth() *gin.Engine { return r } +func setupChunkStopParsingTest(userID string) (*gin.Engine, *mockChunkSvc) { + mock := &mockChunkSvc{} + h := &ChunkHandler{chunkService: mock} + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: userID}) + }) + r.DELETE("/api/v1/datasets/:dataset_id/chunks", h.StopParsing) + return r, mock +} + func setupChunkHandlerWithUser(userID string, mock *mockChunkSvc) (*gin.Engine, *ChunkHandler) { h := &ChunkHandler{chunkService: mock} gin.SetMode(gin.TestMode) @@ -254,6 +273,103 @@ func TestChunkRetrieval_EmptyQuestion(t *testing.T) { } } +func TestChunkStopParsing_Success(t *testing.T) { + r, mock := setupChunkStopParsingTest("user1") + mock.stopParsingFn = func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) { + if userID != "user1" { + t.Fatalf("expected user1, got %q", userID) + } + if datasetID != "kb1" { + t.Fatalf("expected kb1, got %q", datasetID) + } + if len(req.DocumentIDs) != 2 || req.DocumentIDs[0] != "doc1" || req.DocumentIDs[1] != "doc2" { + t.Fatalf("unexpected document IDs: %#v", req.DocumentIDs) + } + return nil, common.CodeSuccess, nil + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/api/v1/datasets/kb1/chunks", strings.NewReader(`{"document_ids":["doc1","doc2"]}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["code"] != float64(common.CodeSuccess) { + t.Fatalf("expected code 0, got %v: %s", resp["code"], w.Body.String()) + } + if resp["message"] != "success" { + t.Fatalf("expected success message, got %v", resp["message"]) + } +} + +func 
TestChunkStopParsingRouteRequiresDocumentIDs(t *testing.T) { + r, mock := setupChunkStopParsingTest("user1") + mock.stopParsingFn = func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) { + t.Fatal("service should not be called when document_ids is missing") + return nil, common.CodeSuccess, nil + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/api/v1/datasets/kb1/chunks", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["code"] != float64(common.CodeDataError) { + t.Fatalf("expected data error, got %v: %s", resp["code"], w.Body.String()) + } + if resp["message"] != "`document_ids` is required" { + t.Fatalf("unexpected message: %v", resp["message"]) + } +} + +func TestChunkStopParsing_InvalidStateIncludesPythonErrorCode(t *testing.T) { + r, mock := setupChunkStopParsingTest("user1") + mock.stopParsingFn = func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) { + return &service.StopParsingResponse{ + Data: map[string]interface{}{"error_code": "DOC_STOP_PARSING_INVALID_STATE"}, + }, common.CodeDataError, errors.New("Can't stop parsing document that has not started or already completed") + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/api/v1/datasets/kb1/chunks", strings.NewReader(`{"document_ids":["doc1"]}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if 
resp["code"] != float64(common.CodeDataError) { + t.Fatalf("expected data error, got %v: %s", resp["code"], w.Body.String()) + } + data, ok := resp["data"].(map[string]interface{}) + if !ok { + t.Fatalf("expected data object, got %T", resp["data"]) + } + if data["error_code"] != "DOC_STOP_PARSING_INVALID_STATE" { + t.Fatalf("unexpected error_code: %v", data["error_code"]) + } + if resp["message"] != "Can't stop parsing document that has not started or already completed" { + t.Fatalf("unexpected message: %v", resp["message"]) + } +} + func TestChunkRetrieval_WhitespaceQuestion(t *testing.T) { r, _ := setupChunkRetrievalTest("user1") diff --git a/internal/handler/document.go b/internal/handler/document.go index 384b75268b..11e260d6ab 100644 --- a/internal/handler/document.go +++ b/internal/handler/document.go @@ -67,6 +67,7 @@ type documentServiceIface interface { ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error) IngestDocuments(datasetID, userID string, docIDs []string) ([]*service.ParseDocumentResponse, error) StopIngestionTasks(tasks []string, userID string) ([]*entity.IngestionTask, error) + Ingest(userID string, req *service.IngestDocumentRequest) (common.ErrorCode, error) RemoveIngestionTasks(tasks []string, userID string) ([]map[string]string, error) BatchUpdateDocumentStatus(userID, datasetID, status string, DocumentIDs []string) (map[string]interface{}, common.ErrorCode, error) } @@ -874,6 +875,37 @@ func (h *DocumentHandler) SetMeta(c *gin.Context) { }) } +func (h *DocumentHandler) Ingest(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeAuthenticationError, "No Authentication") + return + } + + var req service.IngestDocumentRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, 
common.CodeBadRequest, err.Error()) + return + } + + if code, err := h.documentService.Ingest(userID, &req); err != nil { + jsonError(c, code, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + "data": true, + }) +} + // DeleteMetaRequest represents the request for deleting document metadata type DeleteMetaRequest struct { DocID string `json:"doc_id" binding:"required"` diff --git a/internal/handler/document_test.go b/internal/handler/document_test.go index aeedc2d7f1..07a844d138 100644 --- a/internal/handler/document_test.go +++ b/internal/handler/document_test.go @@ -45,6 +45,19 @@ type fakeDocumentService struct { metadataErr error metadataKBID string metadataDocIDs []string + ingestCode common.ErrorCode + ingestErr error + ingestUserID string + ingestReq *service.IngestDocumentRequest +} + +func (f *fakeDocumentService) Ingest(userID string, req *service.IngestDocumentRequest) (common.ErrorCode, error) { + f.ingestUserID = userID + f.ingestReq = req + if f.ingestCode != 0 || f.ingestErr != nil { + return f.ingestCode, f.ingestErr + } + return common.CodeSuccess, nil } func (f *fakeDocumentService) UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error) { @@ -176,6 +189,21 @@ func setupGinContextWithUser(method, path, body string) (*gin.Context, *httptest return c, w } +func setupDocumentIngestRoute(userID string, svc *fakeDocumentService) *gin.Engine { + gin.SetMode(gin.TestMode) + h := &DocumentHandler{ + documentService: svc, + datasetService: service.NewDatasetService(), + } + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: userID}) + c.Set("user_id", userID) + }) + r.POST("/api/v1/documents/ingest", h.Ingest) + return r +} + func TestDeleteDocumentsHandler_Success(t *testing.T) { gin.SetMode(gin.TestMode) @@ -323,6 +351,116 @@ 
func TestDeleteDocumentsHandler_MissingDatasetID(t *testing.T) { } } +func TestDocumentHandlerIngestMatchesPythonResponseShape(t *testing.T) { + gin.SetMode(gin.TestMode) + + fake := &fakeDocumentService{} + h := &DocumentHandler{ + documentService: fake, + datasetService: service.NewDatasetService(), + } + + c, w := setupGinContextWithUser("POST", "/api/v1/documents/ingest", `{"doc_ids":["doc-1"],"run":"1"}`) + h.Ingest(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["code"] != float64(common.CodeSuccess) { + t.Fatalf("expected top-level code 0, got %v", resp["code"]) + } + if resp["data"] != true { + t.Fatalf("expected top-level data=true, got %#v", resp["data"]) + } + if _, ok := resp["data"].(map[string]interface{}); ok { + t.Fatalf("response must not nest code/message under data: %#v", resp["data"]) + } + if fake.ingestUserID != "user-1" { + t.Fatalf("expected user-1, got %q", fake.ingestUserID) + } + if fake.ingestReq == nil || len(fake.ingestReq.DocIDs) != 1 || fake.ingestReq.DocIDs[0] != "doc-1" { + t.Fatalf("unexpected ingest request: %#v", fake.ingestReq) + } +} + +func TestDocumentIngestRoutePassesPythonBodyToService(t *testing.T) { + fake := &fakeDocumentService{} + r := setupDocumentIngestRoute("user-1", fake) + + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/v1/documents/ingest", strings.NewReader(`{"doc_ids":["doc-1","doc-2"],"run":1,"delete":true,"apply_kb":true}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["code"] != float64(common.CodeSuccess) || resp["data"] != true { + 
t.Fatalf("unexpected response: %s", w.Body.String()) + } + if fake.ingestUserID != "user-1" { + t.Fatalf("userID = %q, want user-1", fake.ingestUserID) + } + if fake.ingestReq == nil { + t.Fatal("service did not receive ingest request") + } + if len(fake.ingestReq.DocIDs) != 2 || fake.ingestReq.DocIDs[0] != "doc-1" || fake.ingestReq.DocIDs[1] != "doc-2" { + t.Fatalf("doc_ids = %#v, want [doc-1 doc-2]", fake.ingestReq.DocIDs) + } + if fmt.Sprint(fake.ingestReq.Run) != "1" { + t.Fatalf("run = %#v, want 1", fake.ingestReq.Run) + } + if !fake.ingestReq.Delete { + t.Fatal("delete = false, want true") + } + if !fake.ingestReq.ApplyKB { + t.Fatal("apply_kb = false, want true") + } +} + +func TestDocumentHandlerIngestPropagatesServiceErrorCode(t *testing.T) { + gin.SetMode(gin.TestMode) + + fake := &fakeDocumentService{ + ingestCode: common.CodeAuthenticationError, + ingestErr: fmt.Errorf("No authorization."), + } + h := &DocumentHandler{ + documentService: fake, + datasetService: service.NewDatasetService(), + } + + c, w := setupGinContextWithUser("POST", "/api/v1/documents/ingest", `{"doc_ids":["doc-1"],"run":"1"}`) + h.Ingest(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["code"] != float64(common.CodeAuthenticationError) { + t.Fatalf("expected auth error code, got %v", resp["code"]) + } + if resp["message"] != "No authorization." 
{ + t.Fatalf("unexpected message: %v", resp["message"]) + } + if resp["data"] != nil { + t.Fatalf("expected nil data, got %#v", resp["data"]) + } +} + func TestStopParseDocumentsHandler_EmptyDocIDs(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/internal/router/router.go b/internal/router/router.go index 730f08e370..d9fd3974f8 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -243,6 +243,7 @@ func (r *Router) Setup(engine *gin.Engine) { documents.GET("/:id", r.documentHandler.GetDocumentByID) documents.PUT("/:id", r.documentHandler.UpdateDocument) documents.DELETE("/:id", r.documentHandler.DeleteDocument) + documents.POST("/ingest", r.documentHandler.Ingest) } // Chat routes @@ -320,6 +321,7 @@ func (r *Router) Setup(engine *gin.Engine) { datasets.DELETE("/ingestion/tasks", r.documentHandler.RemoveIngestionTasks) //datasets.POST("/:dataset_id/documents/parse", r.documentHandler.ParseDocuments) //datasets.POST("/:dataset_id/documents/stop", r.documentHandler.StopParseDocuments) + datasets.DELETE("/:dataset_id/chunks", r.chunkHandler.StopParsing) datasets.DELETE("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.RemoveChunks) datasets.PUT("/:dataset_id/documents/:document_id/metadata/config", r.datasetsHandler.UpdateDocumentMetadataConfig) datasets.POST("/:dataset_id/metadata/update", r.documentHandler.MetadataBatchUpdate) diff --git a/internal/service/chunk/chunk.go b/internal/service/chunk/chunk.go index 1193b110b1..83e58cd917 100644 --- a/internal/service/chunk/chunk.go +++ b/internal/service/chunk/chunk.go @@ -28,6 +28,7 @@ import ( "math/rand" "path/filepath" "ragflow/internal/common" + "ragflow/internal/engine/redis" "ragflow/internal/entity" "ragflow/internal/entity/models" "ragflow/internal/server" @@ -42,7 +43,6 @@ import ( "ragflow/internal/dao" "ragflow/internal/engine" - "ragflow/internal/engine/redis" "ragflow/internal/engine/types" "ragflow/internal/service" "ragflow/internal/service/nlp" @@ -64,6 +64,7 @@ type 
ChunkService struct { kbDAO *dao.KnowledgebaseDAO userTenantDAO *dao.UserTenantDAO documentDAO *dao.DocumentDAO + taskDAO *dao.TaskDAO searchService *service.SearchService accessibleFunc func(string, string) bool @@ -85,6 +86,7 @@ func NewChunkService() *ChunkService { kbDAO: dao.NewKnowledgebaseDAO(), userTenantDAO: dao.NewUserTenantDAO(), documentDAO: dao.NewDocumentDAO(), + taskDAO: dao.NewTaskDAO(), searchService: service.NewSearchService(), } } @@ -571,6 +573,112 @@ func (s *ChunkService) Get(req *service.GetChunkRequest, userID string) (*servic return &service.GetChunkResponse{Chunk: chunk}, nil } +const ( + docStopParsingInvalidStateMessage = "Can't stop parsing document that has not started or already completed" + docStopParsingInvalidStateErrorCode = "DOC_STOP_PARSING_INVALID_STATE" +) + +func (s *ChunkService) cancelAllTasksOfDoc(docID string) error { + tasks, err := s.taskDAO.GetByDocID(docID) + if err != nil { + return fmt.Errorf("failed to get tasks for document %s: %w", docID, err) + } + + redisClient := redis.Get() + if redisClient == nil { + common.Warn(fmt.Sprintf("Redis unavailable; cannot cancel tasks for document %s", docID)) + return nil + } + + for _, task := range tasks { + if task == nil { + continue + } + redisClient.Set(fmt.Sprintf("%s-cancel", task.ID), "x", 0) + } + + return nil +} + +func (s *ChunkService) StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) { + if !s.kbDAO.Accessible(datasetID, userID) { + return nil, common.CodeDataError, fmt.Errorf("You don't own the dataset %s", datasetID) + } + + if len(req.DocumentIDs) == 0 { + return nil, common.CodeDataError, fmt.Errorf("`document_ids` is required") + } + + kb, err := s.kbDAO.GetByID(datasetID) + if err != nil { + return nil, common.CodeDataError, fmt.Errorf("You don't own the dataset %s", datasetID) + } + + docIDs, duplicateMessages := service.CheckDuplicateIDs(req.DocumentIDs, "document") + successCount 
:= 0 + ctx := context.Background() + indexName := service.IndexName(kb.TenantID) + + for _, docID := range docIDs { + doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(docID, datasetID) + if err != nil || doc == nil { + return nil, common.CodeDataError, fmt.Errorf("You don't own the document %s", docID) + } + + if doc.Run == nil || *doc.Run != string(entity.TaskStatusRunning) { + return &service.StopParsingResponse{ + Data: map[string]interface{}{"error_code": docStopParsingInvalidStateErrorCode}, + }, common.CodeDataError, fmt.Errorf("%s", docStopParsingInvalidStateMessage) + } + + if err := s.cancelAllTasksOfDoc(docID); err != nil { + return nil, common.CodeServerError, err + } + + updates := map[string]interface{}{ + "run": string(entity.TaskStatusCancel), + "progress": 0, + "chunk_num": 0, + } + if err := s.documentDAO.UpdateByID(doc.ID, updates); err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to update document %s: %w", doc.ID, err) + } + + if s.docEngine != nil { + exists, err := s.docEngine.ChunkStoreExists(ctx, indexName, datasetID) + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to check chunk store %s/%s: %w", indexName, datasetID, err) + } + if exists { + if _, err := s.docEngine.DeleteChunks(ctx, map[string]interface{}{"doc_id": doc.ID}, indexName, datasetID); err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to delete chunks for document %s: %w", doc.ID, err) + } + } else { + common.Info(fmt.Sprintf("Skipping chunk delete during stop_parsing for doc %s: index %s/%s does not exist", doc.ID, indexName, datasetID)) + } + } else { + common.Info(fmt.Sprintf("Skipping chunk delete during stop_parsing for doc %s: index %s/%s does not exist", doc.ID, indexName, datasetID)) + } + + successCount++ + } + + if len(duplicateMessages) > 0 { + if successCount > 0 { + return &service.StopParsingResponse{ + Message: fmt.Sprintf("Partially stopped %d documents with %d errors", successCount, 
len(duplicateMessages)), + Data: map[string]interface{}{ + "success_count": successCount, + "errors": duplicateMessages, + }, + }, common.CodeSuccess, nil + } + return nil, common.CodeDataError, fmt.Errorf("%s", strings.Join(duplicateMessages, ";")) + } + + return nil, common.CodeSuccess, nil +} + func checkDuplicateIDs(documentIDs []string, idTypes string) ([]string, []string) { idCount := make(map[string]int, len(documentIDs)) duplicateMessages := make([]string, 0) diff --git a/internal/service/chunk_types.go b/internal/service/chunk_types.go index 3628ea42c1..c04cfb3d4b 100644 --- a/internal/service/chunk_types.go +++ b/internal/service/chunk_types.go @@ -21,6 +21,7 @@ import ( "encoding/json" "fmt" "ragflow/internal/common" + "ragflow/internal/engine/redis" "ragflow/internal/server" "strings" @@ -31,12 +32,22 @@ import ( "ragflow/internal/utility" ) +var ( + UNSTART = "0" + RUNNING = "1" + CANCEL = "2" + DONE = "3" + FAIL = "4" + SCHEDULE = "5" +) + // ChunkService chunk service type ChunkService struct { docEngine engine.DocEngine engineType server.EngineType embeddingCache *utility.EmbeddingLRU kbDAO *dao.KnowledgebaseDAO + taskDAO *dao.TaskDAO userTenantDAO *dao.UserTenantDAO documentDAO *dao.DocumentDAO searchService *SearchService @@ -162,6 +173,131 @@ func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkRespon return &GetChunkResponse{Chunk: chunk}, nil } +type StopParsingRequest struct { + DocumentIDs []string `json:"document_ids"` +} + +type StopParsingResponse struct { + Data map[string]interface{} + Message string +} + +func (s *ChunkService) cancelAllTasksOfDoc(docID string) error { + tasks, err := s.taskDAO.GetByDocID(docID) + if err != nil { + return fmt.Errorf("failed to get tasks for document %s: %w", docID, err) + } + + redisClient := redis.Get() + if redisClient == nil { + common.Logger.Warn(fmt.Sprintf("Redis unavailable; cannot cancel tasks for document %s", docID)) + return nil + } + + for _, task := range tasks { + if 
task == nil { + continue + } + redisClient.Set(fmt.Sprintf("%s-cancel", task.ID), "x", 0) + } + + return nil +} + +func CheckDuplicateIDs(docList []string, idType string) ([]string, []string) { + uniqueDocIDs := make([]string, 0) + duplicateMessages := make([]string, 0) + idCount := make(map[string]int) + + for _, docID := range docList { + idCount[docID] += 1 + } + + for id, count := range idCount { + if count > 1 { + duplicateMessages = append(duplicateMessages, fmt.Sprintf("Duplicate %s ids: %s", idType, id)) + } + uniqueDocIDs = append(uniqueDocIDs, id) + } + return uniqueDocIDs, duplicateMessages +} + +func IndexName(uid string) string { + return fmt.Sprintf("ragflow_%s", uid) +} + +func (s *ChunkService) StopParsing(userID, datasetID string, req StopParsingRequest) (map[string]interface{}, common.ErrorCode, error) { + if !s.kbDAO.Accessible(datasetID, userID) { + return nil, common.CodeAuthenticationError, fmt.Errorf("You don't own the dataset %s", datasetID) + } + + if req.DocumentIDs == nil || len(req.DocumentIDs) == 0 { + return nil, common.CodeDataError, fmt.Errorf("document_ids is required") + } + + docList, duplicateMessages := CheckDuplicateIDs(req.DocumentIDs, "document") + + kb, err := s.kbDAO.GetByID(datasetID) + if err != nil { + return nil, common.CodeDataError, fmt.Errorf("You don't own the dataset %s", datasetID) + } + + successCount := 0 + for _, id := range docList { + doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(id, datasetID) + if err != nil { + return nil, common.CodeDataError, fmt.Errorf("You don't own the document %s", id) + } + if doc == nil { + return nil, common.CodeDataError, fmt.Errorf("You don't own the document %s", id) + } + + if doc.Run == nil || *doc.Run != RUNNING { + return nil, common.CodeDataError, fmt.Errorf("Can't stop parsing document that has not started or already completed") + } + + err = s.cancelAllTasksOfDoc(id) + if err != nil { + return nil, common.CodeServerError, err + } + + info := 
map[string]interface{}{ + "run": "2", + "progress": 0, + "chunk_num": 0, + } + if err := s.documentDAO.UpdateByID(doc.ID, info); err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to update document %s: %w", doc.ID, err) + } + + indexName := IndexName(kb.TenantID) + + exists, err := s.docEngine.ChunkStoreExists(context.Background(), indexName, datasetID) + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to check chunk store %s/%s: %w", indexName, datasetID, err) + } + if exists { + if _, err := s.docEngine.DeleteChunks(context.Background(), map[string]interface{}{"doc_id": doc.ID}, indexName, datasetID); err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to delete chunks for document %s: %w", doc.ID, err) + } + } else { + common.Logger.Info(fmt.Sprintf("Skipping chunk delete during stop_parsing for doc %s: index %s/%s does not exist", doc.ID, indexName, datasetID)) + } + successCount++ + } + if len(duplicateMessages) > 0 { + if successCount > 0 { + return map[string]interface{}{ + "success_count": successCount, + "errors": duplicateMessages, + "message": fmt.Sprintf("Partially stopped %d documents with %d errors", successCount, len(duplicateMessages)), + }, common.CodeSuccess, nil + } + return nil, common.CodeDataError, fmt.Errorf("%s", strings.Join(duplicateMessages, ";")) + } + return nil, common.CodeSuccess, nil +} + // ListChunksRequest request for listing chunks type ListChunksRequest struct { DatasetID string `json:"dataset_id,omitempty"` diff --git a/internal/service/document.go b/internal/service/document.go index fe5b0c615f..6eed85692c 100644 --- a/internal/service/document.go +++ b/internal/service/document.go @@ -17,11 +17,17 @@ package service import ( + "archive/zip" + "bytes" "context" + "encoding/csv" "encoding/json" + "encoding/xml" "errors" "fmt" + "io" "math" + "math/rand" "mime/multipart" "os" "path/filepath" @@ -43,6 +49,7 @@ import ( "ragflow/internal/tokenizer" 
"ragflow/internal/utility" + "github.com/cespare/xxhash/v2" "go.uber.org/zap" "gorm.io/gorm" ) @@ -987,6 +994,584 @@ func (s *DocumentService) RemoveIngestionTasks(tasks []string, userID string) ([ return deletedTasks, nil } +type IngestDocumentRequest struct { + DocIDs []string `json:"doc_ids" binding:"required"` + Run interface{} `json:"run" binding:"required"` + Delete bool `json:"delete"` + ApplyKB bool `json:"apply_kb"` +} + +type documentParsePageRange struct { + from int64 + to int64 +} + +func (s *DocumentService) Ingest(userID string, req *IngestDocumentRequest) (common.ErrorCode, error) { + run := fmt.Sprint(req.Run) + + docs, err := s.documentDAO.GetByIDs(req.DocIDs) + if err != nil { + return common.CodeExceptionError, fmt.Errorf("fail to get documents: %s", err.Error()) + } + + docsByID := make(map[string]*entity.Document, len(docs)) + for _, doc := range docs { + if doc != nil { + docsByID[doc.ID] = doc + } + } + + tableDoneCountByKB := make(map[string]int64) + + for _, docID := range req.DocIDs { + doc := docsByID[docID] + if doc == nil { + return common.CodeDataError, fmt.Errorf("Document not found!") + } + + kb, err := s.kbDAO.GetByID(doc.KbID) + if err != nil { + return common.CodeDataError, fmt.Errorf("Tenant not found!") + } + + if !s.kbDAO.Accessible(kb.ID, userID) { + return common.CodeAuthenticationError, fmt.Errorf("No authorization.") + } + + updates := map[string]interface{}{ + "run": run, + "progress": 0, + } + + rerunWithDelete := run == string(entity.TaskStatusRunning) && req.Delete + if rerunWithDelete { + updates["progress_msg"] = "" + updates["chunk_num"] = 0 + updates["token_num"] = 0 + } + + if run == string(entity.TaskStatusCancel) { + if err := s.cancelDocParse(doc); err != nil { + return common.CodeDataError, err + } + } + + if rerunWithDelete && doc.Run != nil && *doc.Run == string(entity.TaskStatusDone) { + if err := s.clearKBChunkNumWhenRerun(doc); err != nil { + return common.CodeExceptionError, err + } + } + + if err := 
s.documentDAO.UpdateByID(doc.ID, updates); err != nil { + return common.CodeExceptionError, err + } + + if req.Delete { + _, _ = s.taskDAO.DeleteByDocIDs([]string{doc.ID}) + indexName := fmt.Sprintf("ragflow_%s", kb.TenantID) + if s.docEngine != nil { + exists, err := s.docEngine.ChunkStoreExists(context.Background(), indexName, doc.KbID) + if err != nil { + return common.CodeExceptionError, err + } + if exists { + if _, err := s.docEngine.DeleteChunks(context.Background(), map[string]interface{}{"doc_id": doc.ID}, indexName, doc.KbID); err != nil { + return common.CodeExceptionError, err + } + } + } + } + + if run == string(entity.TaskStatusRunning) { + if req.ApplyKB { + if doc.ParserConfig == nil { + doc.ParserConfig = entity.JSONMap{} + } + config := map[string]interface{}{ + "llm_id": kb.ParserConfig["llm_id"], + "enable_metadata": false, + "metadata": map[string]interface{}{}, + } + if value, ok := kb.ParserConfig["enable_metadata"]; ok { + config["enable_metadata"] = value + } + if value, ok := kb.ParserConfig["metadata"]; ok { + config["metadata"] = value + } + if err := s.updateDocumentParserConfig(doc.ID, config); err != nil { + return common.CodeExceptionError, err + } + for key, value := range config { + doc.ParserConfig[key] = value + } + } + if doc.PipelineID != nil && strings.TrimSpace(*doc.PipelineID) != "" { + if err := s.queueDocumentDataflowTask(kb, doc, strings.TrimSpace(*doc.PipelineID), 0); err != nil { + return common.CodeExceptionError, err + } + continue + } + if doc.ParserID == string(entity.ParserTypeTable) { + doneCount, ok := tableDoneCountByKB[doc.KbID] + if !ok { + count, err := s.countDoneDocuments(doc.KbID) + if err != nil { + return common.CodeExceptionError, err + } + doneCount = count + tableDoneCountByKB[doc.KbID] = doneCount + if doneCount <= 0 { + if err := s.kbDAO.DeleteFieldMap(doc.KbID); err != nil && !dao.IsNotFoundErr(err) { + return common.CodeExceptionError, err + } + } + } + } + if _, err := 
s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil { + return common.CodeExceptionError, err + } + bucket, objectName, err := s.GetDocumentStorageAddress(doc) + if err != nil { + return common.CodeExceptionError, err + } + if err := s.queueDocumentParseTasks(doc, bucket, objectName, 0); err != nil { + return common.CodeExceptionError, err + } + if err := s.beginDocumentParse(doc.ID); err != nil { + return common.CodeExceptionError, err + } + } + } + + return common.CodeSuccess, nil +} + +func (s *DocumentService) countDoneDocuments(datasetID string) (int64, error) { + var count int64 + err := dao.GetDB().Model(&entity.Document{}). + Where("kb_id = ? AND run = ?", datasetID, string(entity.TaskStatusDone)). + Count(&count).Error + return count, err +} + +func (s *DocumentService) queueDocumentParseTasks(doc *entity.Document, bucket, objectName string, priority int64) error { + if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil { + return err + } + tasks, err := s.newDocumentParseTasks(doc, bucket, objectName, priority) + if err != nil { + return err + } + if err := s.taskDAO.CreateMany(tasks); err != nil { + return err + } + queueName := documentParseQueueName(doc, priority) + for _, task := range tasks { + if task.Progress >= 1 { + continue + } + if redisClient := redis.Get(); redisClient == nil || !redisClient.QueueProduct(queueName, documentTaskMessage(task)) { + return fmt.Errorf("Can't access Redis. 
Please check the Redis' status.") + } + } + return nil +} + +func (s *DocumentService) queueDocumentDataflowTask(kb *entity.Knowledgebase, doc *entity.Document, flowID string, priority int64) error { + if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil { + return err + } + if err := s.beginDocumentParse(doc.ID); err != nil { + return err + } + task := s.newDocumentParseTask(doc, 0, maximumTaskPageNumber, priority) + task.TaskType = "dataflow" + if err := s.taskDAO.CreateMany([]*entity.Task{task}); err != nil { + return err + } + message := documentTaskMessage(task) + message["task_type"] = task.TaskType + message["kb_id"] = doc.KbID + message["tenant_id"] = kb.TenantID + message["dataflow_id"] = flowID + message["file"] = nil + if redisClient := redis.Get(); redisClient == nil || !redisClient.QueueProduct(documentParseQueueName(doc, priority), message) { + return fmt.Errorf("Can't access Redis. Please check the Redis' status.") + } + return nil +} + +func (s *DocumentService) newDocumentParseTasks(doc *entity.Document, bucket, objectName string, priority int64) ([]*entity.Task, error) { + ranges, err := documentParseTaskRanges(doc, bucket, objectName) + if err != nil { + return nil, err + } + tasks := make([]*entity.Task, 0, len(ranges)) + for _, pageRange := range ranges { + tasks = append(tasks, s.newDocumentParseTask(doc, pageRange.from, pageRange.to, priority)) + } + return tasks, nil +} + +func (s *DocumentService) newDocumentParseTask(doc *entity.Document, fromPage, toPage, priority int64) *entity.Task { + now := time.Now() + progressMsg := "" + digest := documentParseTaskDigest(doc, fromPage, toPage) + chunkIDs := "" + return &entity.Task{ + ID: common.GenerateUUID(), + DocID: doc.ID, + FromPage: fromPage, + ToPage: toPage, + TaskType: "", + Priority: priority, + BeginAt: &now, + Progress: 0, + ProgressMsg: &progressMsg, + Digest: &digest, + ChunkIDs: &chunkIDs, + } +} + +func documentParseTaskRanges(doc *entity.Document, bucket, objectName 
string) ([]documentParsePageRange, error) { + if doc.Type == "pdf" { + binary, err := documentStorageBinary(bucket, objectName) + if err != nil { + return nil, err + } + pages := documentEstimatePDFPageCount(binary) + pageSize := int64(documentParserConfigInt(doc.ParserConfig, "task_page_size", 12)) + if doc.ParserID == string(entity.ParserTypePaper) { + pageSize = int64(documentParserConfigInt(doc.ParserConfig, "task_page_size", 22)) + } + if doc.ParserID == string(entity.ParserTypeOne) || + doc.ParserID == string(entity.ParserTypeKG) || + documentParserConfigBool(doc.ParserConfig, "toc_extraction", false) { + pageSize = maximumTaskPageNumber + } + if pageSize <= 0 { + pageSize = 12 + } + ranges := make([]documentParsePageRange, 0) + for _, configuredRange := range documentParserConfigPageRanges(doc.ParserConfig) { + start := configuredRange.from - 1 + if start < 0 { + start = 0 + } + end := configuredRange.to - 1 + if pages >= 0 && end > pages { + end = pages + } + for page := start; page < end; page += pageSize { + to := page + pageSize + if to > end { + to = end + } + ranges = append(ranges, documentParsePageRange{from: page, to: to}) + } + } + if len(ranges) == 0 { + ranges = append(ranges, documentParsePageRange{from: 0, to: maximumTaskPageNumber}) + } + return ranges, nil + } + if doc.ParserID == string(entity.ParserTypeTable) { + binary, err := documentStorageBinary(bucket, objectName) + if err != nil { + return nil, err + } + rows := documentEstimateTableRowCount(documentName(doc), binary) + if rows <= 0 { + return []documentParsePageRange{{from: 0, to: maximumTaskPageNumber}}, nil + } + ranges := make([]documentParsePageRange, 0, (rows+2999)/3000) + for row := int64(0); row < int64(rows); row += 3000 { + to := row + 3000 + if to > int64(rows) { + to = int64(rows) + } + ranges = append(ranges, documentParsePageRange{from: row, to: to}) + } + return ranges, nil + } + return []documentParsePageRange{{from: 0, to: maximumTaskPageNumber}}, nil +} + +func 
// documentParserConfigInt reads config[key] as an int.
//
// Accepted value types are int, int64, float64 (truncated), json.Number and
// decimal strings (surrounding whitespace ignored). A missing key, a nil
// value, an unsupported type, or a value that fails to parse yields fallback.
func documentParserConfigInt(config map[string]interface{}, key string, fallback int) int {
	// A missing key produces a nil interface, which lands in the default case.
	switch raw := config[key].(type) {
	case int:
		return raw
	case int64:
		return int(raw)
	case float64:
		return int(raw)
	case json.Number:
		parsed, err := raw.Int64()
		if err != nil {
			return fallback
		}
		return int(parsed)
	case string:
		parsed, err := strconv.Atoi(strings.TrimSpace(raw))
		if err != nil {
			return fallback
		}
		return parsed
	default:
		return fallback
	}
}

// documentParserConfigBool reads config[key] as a bool.
//
// A bool value is returned as is. A string is matched case-insensitively,
// after trimming whitespace, against "true"/"1"/"yes"/"on" and
// "false"/"0"/"no"/"off". Anything else, including a missing key or a nil
// value, yields fallback.
func documentParserConfigBool(config map[string]interface{}, key string, fallback bool) bool {
	raw := config[key]
	if flag, isBool := raw.(bool); isBool {
		return flag
	}
	text, isString := raw.(string)
	if !isString {
		return fallback
	}
	switch strings.ToLower(strings.TrimSpace(text)) {
	case "true", "1", "yes", "on":
		return true
	case "false", "0", "no", "off":
		return false
	default:
		return fallback
	}
}
// documentCountDelimitedRows estimates the number of rows in a delimited text
// payload.
//
// The payload is read with encoding/csv using a variable number of fields per
// record; a ".tsv" extension on name (case-insensitive) switches the
// delimiter to a tab. When every record parses, the record count is returned.
//
// If the reader reports a malformed record, the function falls back to
// counting physical lines of the whole payload: the number of '\n' bytes,
// plus one when the payload does not end with a newline. The fallback
// replaces the partial record count instead of being added to it, so rows
// parsed before the error are not counted twice.
func documentCountDelimitedRows(name string, binary []byte) int {
	reader := csv.NewReader(bytes.NewReader(binary))
	reader.FieldsPerRecord = -1
	reader.ReuseRecord = true
	if strings.EqualFold(filepath.Ext(name), ".tsv") {
		reader.Comma = '\t'
	}

	rows := 0
	for {
		_, err := reader.Read()
		if err == nil {
			rows++
			continue
		}
		if err == io.EOF {
			return rows
		}

		// Malformed input: discard the partial count and estimate from the
		// physical line count of the entire payload.
		lines := bytes.Count(binary, []byte{'\n'})
		if len(binary) > 0 && binary[len(binary)-1] != '\n' {
			lines++
		}
		return lines
	}
}
// documentCountWorksheetRows counts the XML start elements whose local name
// is "row" in the given zip entry.
//
// The entry is streamed token by token. If the entry cannot be opened or the
// XML decoder reports an error before the end of input, the count is
// discarded and 0 is returned together with that error.
func documentCountWorksheetRows(file *zip.File) (int, error) {
	content, openErr := file.Open()
	if openErr != nil {
		return 0, openErr
	}
	defer content.Close()

	tokens := xml.NewDecoder(content)
	count := 0
	for {
		tok, err := tokens.Token()
		switch {
		case err == io.EOF:
			return count, nil
		case err != nil:
			return 0, err
		}
		if element, isStart := tok.(xml.StartElement); isStart && element.Name.Local == "row" {
			count++
		}
	}
}
"parser_id": doc.ParserID, + "parser_config": doc.ParserConfig, + } + keys := make([]string, 0, len(config)) + for key := range config { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + b, err := json.Marshal(config[key]) + if err != nil { + hasher.WriteString(fmt.Sprint(config[key])) + } else { + hasher.Write(b) + } + } + hasher.WriteString(doc.ID) + hasher.WriteString(strconv.FormatInt(fromPage, 10)) + hasher.WriteString(strconv.FormatInt(toPage, 10)) + return fmt.Sprintf("%x", hasher.Sum64()) +} + +func (s *DocumentService) clearKBChunkNumWhenRerun(doc *entity.Document) error { + if doc == nil { + return fmt.Errorf("document is nil") + } + return dao.GetDB().Model(&entity.Knowledgebase{}).Where("id = ?", doc.KbID).Updates(map[string]interface{}{ + "token_num": gorm.Expr("token_num - ?", doc.TokenNum), + "chunk_num": gorm.Expr("chunk_num - ?", doc.ChunkNum), + }).Error +} + func (s *DocumentService) ParseDocuments(datasetID, userID string, docIDs []string) ([]*ParseDocumentResponse, error) { // create document parse id // save to task table