From 217c2a94c21337d4e4c6ae82eb9e84e582c2d0c1 Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Thu, 18 Jun 2026 17:57:24 +0800 Subject: [PATCH] feat[Go]: implement datasets//index P/G (#16153) ### What problem does this PR solve? ``` POST: http://localhost:9384/api/v1/datasets/433b390c630411f1a13eab5f89540b2a/index?type=graph Output: { "code": 0, "data": { "task_id": "ff5a3546bafa49d794a9a050d99c4a52" }, "message": "success" } ``` --- ``` GET: http://localhost:9384/api/v1/datasets/433b390c630411f1a13eab5f89540b2a/index?type=graph Output: { "code": 0, "data": { "id": "ff5a3546bafa49d794a9a050d99c4a52", "doc_id": "graph_raptor_x", "from_page": 100000000, "to_page": 100000000, "task_type": "graphrag", "priority": 0, "begin_at": "2026-06-17T18:07:45+08:00", "process_duration": 4.108135, "progress": -1, "progress_msg": "18:07:45 created task graphrag\n18:07:47 Task has been received.\n18:07:49 [ERROR][Exception]: Model config not found: Qwen/Qwen3-235B-A22B@test@SILICONFLOW", "retry_count": 1, "digest": "f16fd067d5c92cec", "create_time": 1781690865552, "create_date": "2026-06-17T18:07:45+08:00", "update_time": 1781690869108, "update_date": "2026-06-17T18:07:49+08:00" }, "message": "success" } ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- internal/dao/document.go | 80 ++++ internal/dao/document_test.go | 26 ++ internal/handler/datasets.go | 68 ++++ internal/router/router.go | 2 + internal/service/dataset.go | 385 ++++++++++++++++++ internal/service/dataset_task_cleanup_test.go | 110 +++++ 6 files changed, 671 insertions(+) create mode 100644 internal/service/dataset_task_cleanup_test.go diff --git a/internal/dao/document.go b/internal/dao/document.go index 72f197452b..16c9671ad0 100644 --- a/internal/dao/document.go +++ b/internal/dao/document.go @@ -109,6 +109,86 @@ func (dao *DocumentDAO) ListByKBID(kbID string, offset, limit int) ([]*entity.Do return documents, total, err } +// GetByKBID retrieves all documents in a knowledge base ordered by create time. +func (dao *DocumentDAO) GetByKBID(kbID string) ([]*entity.Document, int64, error) { + var documents []*entity.Document + var total int64 + + query := DB.Model(&entity.Document{}).Where("kb_id = ?", kbID) + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + err := query.Order("create_time ASC").Find(&documents).Error + return documents, total, err +} + +// GetChunkingConfig returns the document, dataset, and tenant fields used to +// build a parsing task digest, mirroring DocumentService.get_chunking_config. +func (dao *DocumentDAO) GetChunkingConfig(docID string) (map[string]interface{}, error) { + var row struct { + ID string `gorm:"column:id"` + KbID string `gorm:"column:kb_id"` + ParserID string `gorm:"column:parser_id"` + ParserConfig entity.JSONMap `gorm:"column:parser_config"` + Size int64 `gorm:"column:size"` + ContentHash *string `gorm:"column:content_hash"` + Language *string `gorm:"column:language"` + EmbdID string `gorm:"column:embd_id"` + TenantID string `gorm:"column:tenant_id"` + Img2TxtID string `gorm:"column:img2txt_id"` + ASRID string `gorm:"column:asr_id"` + LLMID string `gorm:"column:llm_id"` + } + + err := DB.Table("document"). + Select(` + document.id, + document.kb_id, + document.parser_id, + document.parser_config, + document.size, + document.content_hash, + knowledgebase.language, + knowledgebase.embd_id, + tenant.id AS tenant_id, + tenant.img2txt_id, + tenant.asr_id, + tenant.llm_id + `). + Joins("JOIN knowledgebase ON document.kb_id = knowledgebase.id"). + Joins("JOIN tenant ON knowledgebase.tenant_id = tenant.id"). + Where("document.id = ?", docID). + Take(&row).Error + if err != nil { + return nil, err + } + + config := map[string]interface{}{ + "id": row.ID, + "kb_id": row.KbID, + "parser_id": row.ParserID, + "parser_config": row.ParserConfig, + "size": row.Size, + "embd_id": row.EmbdID, + "tenant_id": row.TenantID, + "img2txt_id": row.Img2TxtID, + "asr_id": row.ASRID, + "llm_id": row.LLMID, + } + if row.ContentHash != nil { + config["content_hash"] = *row.ContentHash + } else { + config["content_hash"] = nil + } + if row.Language != nil { + config["language"] = *row.Language + } else { + config["language"] = nil + } + return config, nil +} + // DeleteByTenantID deletes all documents by tenant ID (hard delete) func (dao *DocumentDAO) DeleteByTenantID(tenantID string) (int64, error) { result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&entity.Document{}) diff --git a/internal/dao/document_test.go b/internal/dao/document_test.go index 5cef89baac..31e8f36a99 100644 --- a/internal/dao/document_test.go +++ b/internal/dao/document_test.go @@ -118,6 +118,32 @@ func TestDocumentGetByIDs_NoMatch(t *testing.T) { } } +func TestDocumentGetByKBIDOrdersByCreateTime(t *testing.T) { + db := setupDocumentTestDB(t) + pushDocDB(t, db) + + createTime10 := int64(10) + createTime20 := int64(20) + createTime30 := int64(30) + db.Create(&entity.Document{ID: "doc-later", KbID: "kb1", Name: sp("Doc Later"), CreatedBy: "user1", ParserConfig: entity.JSONMap{}, BaseModel: entity.BaseModel{CreateTime: &createTime30}}) + db.Create(&entity.Document{ID: "doc-other", KbID: "kb2", Name: sp("Doc Other"), CreatedBy: "user1", ParserConfig: entity.JSONMap{}, BaseModel: entity.BaseModel{CreateTime: &createTime10}}) + db.Create(&entity.Document{ID: "doc-earlier", KbID: "kb1", Name: sp("Doc Earlier"), CreatedBy: "user1", ParserConfig: entity.JSONMap{}, BaseModel: entity.BaseModel{CreateTime: &createTime20}}) + + docs, total, err := NewDocumentDAO().GetByKBID("kb1") + if err != nil { + t.Fatalf("GetByKBID failed: %v", err) + } + if total != 2 { + t.Fatalf("expected total=2, got %d", total) + } + if len(docs) != 2 { + t.Fatalf("expected 2 docs, got %d", len(docs)) + } + if docs[0].ID != "doc-earlier" || docs[1].ID != "doc-later" { + t.Fatalf("unexpected order: %s, %s", docs[0].ID, docs[1].ID) + } +} + func TestDocumentGetByDocumentIDAndDatasetIDUsesKBID(t *testing.T) { db := setupDocumentTestDB(t) pushDocDB(t, db) diff --git a/internal/handler/datasets.go b/internal/handler/datasets.go index d430741a83..2fa78b9d4b 100644 --- a/internal/handler/datasets.go +++ b/internal/handler/datasets.go @@ -592,6 +592,74 @@ func (h *DatasetsHandler) RemoveTags(c *gin.Context) { jsonResponse(c, common.CodeSuccess, true, "success") } +// RunIndex Run an indexing task (graph/raptor/mindmap) for a dataset. +func (h *DatasetsHandler) RunIndex(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { + jsonError(c, common.CodeDataError, "dataset_id is required") + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeDataError, "user_id is required") + return + } + + indexType := strings.ToLower(strings.TrimSpace(c.Query("type"))) + data, code, err := h.datasetsService.RunIndex(userID, datasetID, indexType) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, data, "success") +} + +// TraceIndex Trace an indexing task (graph/raptor/mindmap) for a dataset. +func (h *DatasetsHandler) TraceIndex(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { + jsonError(c, common.CodeDataError, "dataset_id is required") + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeDataError, "user_id is required") + return + } + + indexType := strings.ToLower(strings.TrimSpace(c.Query("type"))) + result, code, err := h.datasetsService.TraceIndex(datasetID, userID, indexType) + if err != nil { + jsonError(c, code, err.Error()) + return + } + if result == nil { + jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success") + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "data": result, + "message": "success", + }) +} + // ListMetadataFlattened handles GET /api/v1/datasets/metadata/flattened. // @Summary List flattened metadata for datasets // @Description Get flattened metadata for multiple datasets diff --git a/internal/router/router.go b/internal/router/router.go index 29f6898cac..f1e9f6ec5a 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -255,6 +255,8 @@ func (r *Router) Setup(engine *gin.Engine) { datasets.PUT("/:dataset_id", r.datasetsHandler.UpdateDataset) datasets.GET("/:dataset_id/graph", r.datasetsHandler.GetKnowledgeGraph) datasets.DELETE("/:dataset_id/tags", r.datasetsHandler.RemoveTags) + datasets.GET("/:dataset_id/index", r.datasetsHandler.TraceIndex) + datasets.POST("/:dataset_id/index", r.datasetsHandler.RunIndex) datasets.DELETE("/:dataset_id/graph", r.datasetsHandler.DeleteKnowledgeGraph) datasets.POST("", r.datasetsHandler.CreateDataset) datasets.DELETE("", r.datasetsHandler.DeleteDatasets) diff --git a/internal/service/dataset.go b/internal/service/dataset.go index d7c6c870ff..17b759d5cb 100644 --- a/internal/service/dataset.go +++ b/internal/service/dataset.go @@ -24,17 +24,21 @@ import ( "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/engine" + redisengine "ragflow/internal/engine/redis" "ragflow/internal/entity" "ragflow/internal/entity/models" "ragflow/internal/server" "ragflow/internal/service/nlp" "ragflow/internal/utility" + "sort" "strings" "time" + "github.com/cespare/xxhash/v2" "github.com/google/uuid" "go.uber.org/zap" "gorm.io/gorm" + "gorm.io/gorm/clause" ) var ( @@ -68,6 +72,16 @@ var ( "number": {}, } datasetChunkMethodErrorMessage = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'resume', 'table' or 'tag'" + validIndexTypes = []string{"graph", "raptor", "mindmap"} + indexTypeToTaskType = map[string]string{"graph": "graphrag", "raptor": "raptor", "mindmap": "mindmap"} + indexTypeToDisplayName = map[string]string{"graph": "Graph", "raptor": "RAPTOR", "mindmap": "Mindmap"} +) + +const ( + // Keep the legacy worker marker in queue payloads; persisted tasks use a real document ID. + graphRaptorQueueDocID = "graph_raptor_x" + maximumTaskPageNumber = int64(100000000) + serverQueueNamePrefix = "te" ) // DatasetService implements the RESTful dataset APIs from dataset_api.py. @@ -79,6 +93,7 @@ type DatasetService struct { tenantLLMDAO *dao.TenantLLMDAO pipelineLogDAO *dao.PipelineOperationLogDAO userTenantDAO *dao.UserTenantDAO + taskDAO *dao.TaskDAO searchService *SearchService docEngine engine.DocEngine embeddingCache *utility.EmbeddingLRU @@ -100,6 +115,7 @@ func NewDatasetService() *DatasetService { tenantLLMDAO: dao.NewTenantLLMDAO(), pipelineLogDAO: dao.NewPipelineOperationLogDAO(), userTenantDAO: dao.NewUserTenantDAO(), + taskDAO: dao.NewTaskDAO(), searchService: NewSearchService(), docEngine: engine.Get(), embeddingCache: utility.NewEmbeddingLRU(1000), @@ -149,6 +165,369 @@ func (s *DatasetService) UpdateDocumentMetadataConfig(userID, datasetID, documen return updatedDoc, common.CodeSuccess, nil } +func checkType(indexType string) bool { + haveType := false + for _, t := range validIndexTypes { + if indexType == t { + haveType = true + } + } + return haveType +} + +func (s *DatasetService) newRaptorOrGraphRagTask(sampleDoc *entity.Document, taskType string, taskDocID string, queueDocID string, docIDs []string) (*entity.Task, map[string]interface{}, error) { + if docIDs == nil || len(docIDs) == 0 { + docIDs = make([]string, 0) + } + if !checkIndexTaskType(taskType) { + return nil, nil, errors.New("type should be graphrag, raptor or mindmap") + } + + chunkingConfig, err := s.documentDAO.GetChunkingConfig(sampleDoc.ID) + if err != nil { + return nil, nil, err + } + + hasher := xxhash.New() + keys := make([]string, 0, len(chunkingConfig)) + for key := range chunkingConfig { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + _, _ = hasher.Write([]byte(key)) + _, _ = hasher.Write([]byte{0}) + v, mErr := json.Marshal(chunkingConfig[key]) + if mErr != nil { + return nil, nil, mErr + } + _, _ = hasher.Write(v) + _, _ = hasher.Write([]byte{0}) + } + + taskID := strings.ReplaceAll(uuid.New().String(), "-", "")[:32] + beginAt := time.Now().Truncate(time.Second) + progressMsg := beginAt.Format("15:04:05") + " created task " + taskType + + for _, field := range []interface{}{taskDocID, maximumTaskPageNumber, maximumTaskPageNumber, taskType} { + _, _ = hasher.Write([]byte(fmt.Sprint(field))) + } + digest := fmt.Sprintf("%016x", hasher.Sum64()) + task := &entity.Task{ + ID: taskID, + DocID: taskDocID, + FromPage: maximumTaskPageNumber, + ToPage: maximumTaskPageNumber, + TaskType: taskType, + ProgressMsg: &progressMsg, + BeginAt: &beginAt, + Digest: &digest, + } + + queueMessage := map[string]interface{}{ + "id": taskID, + "doc_id": queueDocID, + "from_page": maximumTaskPageNumber, + "to_page": maximumTaskPageNumber, + "task_type": taskType, + "progress_msg": progressMsg, + "begin_at": beginAt.Format("2006-01-02 15:04:05"), + "digest": digest, + "doc_ids": docIDs, + } + + return task, queueMessage, nil +} + +func createDatasetIndexTaskInTx(tx *gorm.DB, task *entity.Task, queueDocID string) (*entity.Document, error) { + if task == nil { + return nil, errors.New("task is required") + } + if err := tx.Create(task).Error; err != nil { + return nil, err + } + + if queueDocID == "" { + return nil, nil + } + + var document entity.Document + err := tx.Select("id", "progress_msg", "process_begin_at").Where("id = ?", queueDocID).First(&document).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + + beginAt := time.Now().Truncate(time.Second) + if task.BeginAt != nil { + beginAt = *task.BeginAt + } + if err := tx.Model(&entity.Document{}).Where("id = ?", queueDocID).Updates(map[string]interface{}{ + "progress_msg": "Task is queued...", + "process_begin_at": beginAt, + }).Error; err != nil { + return nil, err + } + + return &document, nil +} + +func enqueueDatasetIndexTask(priority int, queueMessage map[string]interface{}) error { + redisClient := redisengine.Get() + if redisClient == nil || !redisClient.QueueProduct(datasetIndexQueueName(priority), queueMessage) { + return errors.New("Can't access Redis. Please check the Redis' status") + } + return nil +} + +func cleanupFailedDatasetIndexTask(taskID string, updatedDocument *entity.Document, kbID string, indexType string) error { + return dao.DB.Transaction(func(tx *gorm.DB) error { + if err := tx.Unscoped().Where("id = ?", taskID).Delete(&entity.Task{}).Error; err != nil { + return fmt.Errorf("delete task %s: %w", taskID, err) + } + + if column := datasetIndexTaskIDColumn(indexType); kbID != "" && column != "" { + if err := tx.Model(&entity.Knowledgebase{}).Where("id = ? AND "+column+" = ?", kbID, taskID).Update(column, nil).Error; err != nil { + return fmt.Errorf("clear dataset task id %s: %w", taskID, err) + } + } + + if updatedDocument == nil { + return nil + } + + return tx.Model(&entity.Document{}).Where("id = ?", updatedDocument.ID).Updates(map[string]interface{}{ + "progress_msg": updatedDocument.ProgressMsg, + "process_begin_at": updatedDocument.ProcessBeginAt, + }).Error + }) +} + +func datasetIndexTaskIDColumn(indexType string) string { + switch indexType { + case "graph": + return "graphrag_task_id" + case "raptor": + return "raptor_task_id" + case "mindmap": + return "mindmap_task_id" + default: + return "" + } +} + +func checkIndexTaskType(taskType string) bool { + switch taskType { + case "graphrag", "raptor", "mindmap": + return true + default: + return false + } +} + +func datasetIndexTaskID(kb *entity.Knowledgebase, indexType string) string { + if kb == nil { + return "" + } + switch indexType { + case "graph": + if kb.GraphragTaskID != nil { + return *kb.GraphragTaskID + } + case "raptor": + if kb.RaptorTaskID != nil { + return *kb.RaptorTaskID + } + case "mindmap": + if kb.MindmapTaskID != nil { + return *kb.MindmapTaskID + } + } + return "" +} + +func datasetIndexTaskIDUpdate(indexType, taskID string) map[string]interface{} { + switch indexType { + case "graph": + return map[string]interface{}{"graphrag_task_id": taskID} + case "raptor": + return map[string]interface{}{"raptor_task_id": taskID} + case "mindmap": + return map[string]interface{}{"mindmap_task_id": taskID} + default: + return map[string]interface{}{} + } +} + +func datasetIndexTaskIDs(kb *entity.Knowledgebase) []string { + if kb == nil { + return nil + } + + taskIDs := make([]string, 0, 3) + for _, taskID := range []*string{kb.GraphragTaskID, kb.RaptorTaskID, kb.MindmapTaskID} { + if taskID != nil && *taskID != "" { + taskIDs = append(taskIDs, *taskID) + } + } + return common.Deduplicate(taskIDs) +} + +func datasetIndexQueueName(priority int) string { + return fmt.Sprintf("%s.%d.common", serverQueueNamePrefix, priority) +} + +// RunIndex Run an indexing task (graph/raptor/mindmap) for a dataset. +func (s *DatasetService) RunIndex(userID, datasetID, indexType string) (map[string]interface{}, common.ErrorCode, error) { + if !checkType(indexType) { + return nil, common.CodeDataError, fmt.Errorf("Invalid index type '%s'. Must be one of %v", indexType, validIndexTypes) + } + + if datasetID == "" { + return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`) + } + if !s.kbDAO.Accessible(datasetID, userID) { + return nil, common.CodeDataError, errors.New("No authorization.") + } + + kb, err := s.kbDAO.GetByID(datasetID) + if err != nil { + if dao.IsNotFoundErr(err) { + return nil, common.CodeDataError, errors.New("Invalid Dataset ID") + } + return nil, common.CodeDataError, errors.New("Internal server error") + } + + taskType := indexTypeToTaskType[indexType] + displayName := indexTypeToDisplayName[indexType] + + documents, code, err := s.getDocumentsByDatasetForIndex(datasetID) + if err != nil { + return nil, code, err + } + _ = documents + + sampleDocument := documents[0] + documentIDs := make([]string, len(documents)) + + for i, doc := range documents { + documentIDs[i] = doc.ID + } + + task, queueMessage, err := s.newRaptorOrGraphRagTask(sampleDocument, taskType, sampleDocument.ID, graphRaptorQueueDocID, documentIDs) + if err != nil { + common.Warn("Failed to build dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err)) + return nil, common.CodeDataError, errors.New("Internal server error") + } + + var updatedDocument *entity.Document + var dataErr error + err = dao.DB.Transaction(func(tx *gorm.DB) error { + var lockedKB entity.Knowledgebase + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Where("id = ? AND status = ?", kb.ID, string(entity.StatusValid)). + First(&lockedKB).Error; err != nil { + return err + } + + existingTaskID := datasetIndexTaskID(&lockedKB, indexType) + if existingTaskID != "" { + var existingTask entity.Task + taskErr := tx.Where("id = ?", existingTaskID).First(&existingTask).Error + if taskErr != nil { + if errors.Is(taskErr, gorm.ErrRecordNotFound) { + } else { + return taskErr + } + } else if existingTask.Progress != 1 && existingTask.Progress != -1 { + dataErr = fmt.Errorf("Task %s in progress with status %v. A %s Task is already running.", existingTaskID, existingTask.Progress, displayName) + return dataErr + } + } + + updatedDocument, err = createDatasetIndexTaskInTx(tx, task, graphRaptorQueueDocID) + if err != nil { + return err + } + return tx.Model(&entity.Knowledgebase{}).Where("id = ?", lockedKB.ID).Updates(datasetIndexTaskIDUpdate(indexType, task.ID)).Error + }) + if err != nil { + if dataErr != nil { + return nil, common.CodeDataError, dataErr + } + common.Warn("Failed to create dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err)) + return nil, common.CodeDataError, errors.New("Internal server error") + } + + if err := enqueueDatasetIndexTask(0, queueMessage); err != nil { + if cleanupErr := cleanupFailedDatasetIndexTask(task.ID, updatedDocument, kb.ID, indexType); cleanupErr != nil { + err = errors.Join(err, cleanupErr) + } + common.Warn("Failed to queue dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err)) + return nil, common.CodeDataError, errors.New("Internal server error") + } + + return map[string]interface{}{"task_id": task.ID}, common.CodeSuccess, nil +} + +func (s *DatasetService) getDocumentsByDatasetForIndex(datasetID string) ([]*entity.Document, common.ErrorCode, error) { + documents, _, err := s.documentDAO.GetByKBID(datasetID) + if err != nil { + common.Warn("Failed to load dataset documents for index", zap.String("dataset_id", datasetID), zap.Error(err)) + return nil, common.CodeDataError, errors.New("Internal server error") + } + if len(documents) == 0 { + return nil, common.CodeDataError, fmt.Errorf("No documents in Dataset %s", datasetID) + } + return documents, common.CodeSuccess, nil +} + +type TraceIndexRequest struct { + Type string `json:"type" binding:"required"` +} + +// TraceIndex Trace an indexing task (graph/raptor/mindmap) for a dataset. +func (s *DatasetService) TraceIndex(datasetID, userID, indexType string) (*entity.Task, common.ErrorCode, error) { + if !checkType(indexType) { + return nil, common.CodeDataError, fmt.Errorf("Invalid index type '%s'. Must be one of %v", indexType, validIndexTypes) + } + + if datasetID == "" { + return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`) + } + if !s.kbDAO.Accessible(datasetID, userID) { + return nil, common.CodeDataError, errors.New("No authorization.") + } + + kb, err := s.kbDAO.GetByID(datasetID) + if err != nil { + if dao.IsNotFoundErr(err) { + return nil, common.CodeDataError, errors.New("Invalid Dataset ID") + } + return nil, common.CodeDataError, errors.New("Internal server error") + } + + taskID := datasetIndexTaskID(kb, indexType) + + var task *entity.Task + if taskID != "" { + task, err = s.taskDAO.GetByID(taskID) + if err != nil { + if dao.IsNotFoundErr(err) { + return nil, common.CodeSuccess, nil + } + return nil, common.CodeServerError, errors.New("Internal server error") + } + if task == nil { + return nil, common.CodeSuccess, nil + } + } + + return task, common.CodeSuccess, nil +} + // SearchDatasetsRequest is the request structure for searching chunks across datasets. type SearchDatasetsRequest struct { DatasetIDs []string `json:"dataset_ids" binding:"required"` @@ -1638,6 +2017,12 @@ func jsonMapValue(m entity.JSONMap) interface{} { func (s *DatasetService) deleteDataset(tenantID string, kb *entity.Knowledgebase) error { return dao.DB.Transaction(func(tx *gorm.DB) error { + if taskIDs := datasetIndexTaskIDs(kb); len(taskIDs) > 0 { + if err := tx.Where("id IN ?", taskIDs).Delete(&entity.Task{}).Error; err != nil { + return fmt.Errorf("Delete dataset error for %s", kb.ID) + } + } + var documents []entity.Document if err := tx.Where("kb_id = ?", kb.ID).Find(&documents).Error; err != nil { return fmt.Errorf("Delete dataset error for %s", kb.ID) diff --git a/internal/service/dataset_task_cleanup_test.go b/internal/service/dataset_task_cleanup_test.go new file mode 100644 index 0000000000..4f547b72a2 --- /dev/null +++ b/internal/service/dataset_task_cleanup_test.go @@ -0,0 +1,110 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "errors" + "testing" + "time" + + "gorm.io/gorm" + + "ragflow/internal/dao" + "ragflow/internal/entity" +) + +func TestCleanupFailedDatasetIndexTaskDeletesTaskAndRestoresDocument(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + previousMsg := "previous progress" + previousBeginAt := time.Date(2026, 6, 18, 10, 0, 0, 0, time.UTC) + queuedMsg := "Task is queued..." + queuedBeginAt := previousBeginAt.Add(time.Hour) + taskID := "task-1" + + kb := &entity.Knowledgebase{ + ID: "kb-1", + TenantID: "user-1", + Name: "test-kb", + EmbdID: "embedding@OpenAI", + CreatedBy: "user-1", + Permission: string(entity.TenantPermissionMe), + ParserID: "naive", + ParserConfig: entity.JSONMap{}, + GraphragTaskID: &taskID, + Status: sptr("1"), + } + if err := dao.DB.Create(kb).Error; err != nil { + t.Fatalf("insert kb: %v", err) + } + + doc := &entity.Document{ + ID: "doc-1", + KbID: "kb-1", + ParserID: "naive", + ParserConfig: entity.JSONMap{}, + SourceType: "local", + Type: "pdf", + CreatedBy: "user-1", + Suffix: ".pdf", + ProgressMsg: &queuedMsg, + ProcessBeginAt: &queuedBeginAt, + } + if err := dao.DB.Create(doc).Error; err != nil { + t.Fatalf("insert document: %v", err) + } + + task := &entity.Task{ID: taskID, DocID: doc.ID, TaskType: "graphrag"} + if err := dao.DB.Create(task).Error; err != nil { + t.Fatalf("insert task: %v", err) + } + + snapshot := &entity.Document{ + ID: doc.ID, + ProgressMsg: &previousMsg, + ProcessBeginAt: &previousBeginAt, + } + if err := cleanupFailedDatasetIndexTask(task.ID, snapshot, kb.ID, "graph"); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + + var persistedTask entity.Task + err := dao.DB.Where("id = ?", task.ID).First(&persistedTask).Error + if !errors.Is(err, gorm.ErrRecordNotFound) { + t.Fatalf("expected task to be deleted, got err=%v task=%#v", err, persistedTask) + } + + persistedDoc, err := dao.NewDocumentDAO().GetByID(doc.ID) + if err != nil { + t.Fatalf("fetch document: %v", err) + } + if persistedDoc.ProgressMsg == nil || *persistedDoc.ProgressMsg != previousMsg { + t.Fatalf("expected progress_msg %q, got %#v", previousMsg, persistedDoc.ProgressMsg) + } + if persistedDoc.ProcessBeginAt == nil || !persistedDoc.ProcessBeginAt.Equal(previousBeginAt) { + t.Fatalf("expected process_begin_at %v, got %#v", previousBeginAt, persistedDoc.ProcessBeginAt) + } + + var persistedKB entity.Knowledgebase + if err := dao.DB.Where("id = ?", kb.ID).First(&persistedKB).Error; err != nil { + t.Fatalf("fetch kb: %v", err) + } + if persistedKB.GraphragTaskID != nil { + t.Fatalf("expected graphrag_task_id to be cleared, got %#v", *persistedKB.GraphragTaskID) + } +}