feat[Go]: implement datasets/<dataset_id>/index P/G (#16153)

### What problem does this PR solve?

```
POST: http://localhost:9384/api/v1/datasets/433b390c630411f1a13eab5f89540b2a/index?type=graph

Output: {
    "code": 0,
    "data": {
        "task_id": "ff5a3546bafa49d794a9a050d99c4a52"
    },
    "message": "success"
}
```

---

```
GET: http://localhost:9384/api/v1/datasets/433b390c630411f1a13eab5f89540b2a/index?type=graph

Output: {
    "code": 0,
    "data": {
        "id": "ff5a3546bafa49d794a9a050d99c4a52",
        "doc_id": "graph_raptor_x",
        "from_page": 100000000,
        "to_page": 100000000,
        "task_type": "graphrag",
        "priority": 0,
        "begin_at": "2026-06-17T18:07:45+08:00",
        "process_duration": 4.108135,
        "progress": -1,
        "progress_msg": "18:07:45 created task graphrag\n18:07:47 Task has been received.\n18:07:49 [ERROR][Exception]: Model config not found: Qwen/Qwen3-235B-A22B@test@SILICONFLOW",
        "retry_count": 1,
        "digest": "f16fd067d5c92cec",
        "create_time": 1781690865552,
        "create_date": "2026-06-17T18:07:45+08:00",
        "update_time": 1781690869108,
        "update_date": "2026-06-17T18:07:49+08:00"
    },
    "message": "success"
}

```

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Haruko386
2026-06-18 17:57:24 +08:00
committed by GitHub
parent 5f6ebc97c6
commit 217c2a94c2
6 changed files with 671 additions and 0 deletions

View File

@@ -109,6 +109,86 @@ func (dao *DocumentDAO) ListByKBID(kbID string, offset, limit int) ([]*entity.Do
return documents, total, err
}
// GetByKBID retrieves all documents in a knowledge base ordered by create time.
func (dao *DocumentDAO) GetByKBID(kbID string) ([]*entity.Document, int64, error) {
var documents []*entity.Document
var total int64
query := DB.Model(&entity.Document{}).Where("kb_id = ?", kbID)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
err := query.Order("create_time ASC").Find(&documents).Error
return documents, total, err
}
// GetChunkingConfig returns the document, dataset, and tenant fields used to
// build a parsing task digest, mirroring DocumentService.get_chunking_config.
func (dao *DocumentDAO) GetChunkingConfig(docID string) (map[string]interface{}, error) {
var row struct {
ID string `gorm:"column:id"`
KbID string `gorm:"column:kb_id"`
ParserID string `gorm:"column:parser_id"`
ParserConfig entity.JSONMap `gorm:"column:parser_config"`
Size int64 `gorm:"column:size"`
ContentHash *string `gorm:"column:content_hash"`
Language *string `gorm:"column:language"`
EmbdID string `gorm:"column:embd_id"`
TenantID string `gorm:"column:tenant_id"`
Img2TxtID string `gorm:"column:img2txt_id"`
ASRID string `gorm:"column:asr_id"`
LLMID string `gorm:"column:llm_id"`
}
err := DB.Table("document").
Select(`
document.id,
document.kb_id,
document.parser_id,
document.parser_config,
document.size,
document.content_hash,
knowledgebase.language,
knowledgebase.embd_id,
tenant.id AS tenant_id,
tenant.img2txt_id,
tenant.asr_id,
tenant.llm_id
`).
Joins("JOIN knowledgebase ON document.kb_id = knowledgebase.id").
Joins("JOIN tenant ON knowledgebase.tenant_id = tenant.id").
Where("document.id = ?", docID).
Take(&row).Error
if err != nil {
return nil, err
}
config := map[string]interface{}{
"id": row.ID,
"kb_id": row.KbID,
"parser_id": row.ParserID,
"parser_config": row.ParserConfig,
"size": row.Size,
"embd_id": row.EmbdID,
"tenant_id": row.TenantID,
"img2txt_id": row.Img2TxtID,
"asr_id": row.ASRID,
"llm_id": row.LLMID,
}
if row.ContentHash != nil {
config["content_hash"] = *row.ContentHash
} else {
config["content_hash"] = nil
}
if row.Language != nil {
config["language"] = *row.Language
} else {
config["language"] = nil
}
return config, nil
}
// DeleteByTenantID deletes all documents by tenant ID (hard delete)
func (dao *DocumentDAO) DeleteByTenantID(tenantID string) (int64, error) {
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&entity.Document{})

View File

@@ -118,6 +118,32 @@ func TestDocumentGetByIDs_NoMatch(t *testing.T) {
}
}
func TestDocumentGetByKBIDOrdersByCreateTime(t *testing.T) {
db := setupDocumentTestDB(t)
pushDocDB(t, db)
createTime10 := int64(10)
createTime20 := int64(20)
createTime30 := int64(30)
db.Create(&entity.Document{ID: "doc-later", KbID: "kb1", Name: sp("Doc Later"), CreatedBy: "user1", ParserConfig: entity.JSONMap{}, BaseModel: entity.BaseModel{CreateTime: &createTime30}})
db.Create(&entity.Document{ID: "doc-other", KbID: "kb2", Name: sp("Doc Other"), CreatedBy: "user1", ParserConfig: entity.JSONMap{}, BaseModel: entity.BaseModel{CreateTime: &createTime10}})
db.Create(&entity.Document{ID: "doc-earlier", KbID: "kb1", Name: sp("Doc Earlier"), CreatedBy: "user1", ParserConfig: entity.JSONMap{}, BaseModel: entity.BaseModel{CreateTime: &createTime20}})
docs, total, err := NewDocumentDAO().GetByKBID("kb1")
if err != nil {
t.Fatalf("GetByKBID failed: %v", err)
}
if total != 2 {
t.Fatalf("expected total=2, got %d", total)
}
if len(docs) != 2 {
t.Fatalf("expected 2 docs, got %d", len(docs))
}
if docs[0].ID != "doc-earlier" || docs[1].ID != "doc-later" {
t.Fatalf("unexpected order: %s, %s", docs[0].ID, docs[1].ID)
}
}
func TestDocumentGetByDocumentIDAndDatasetIDUsesKBID(t *testing.T) {
db := setupDocumentTestDB(t)
pushDocDB(t, db)

View File

@@ -592,6 +592,74 @@ func (h *DatasetsHandler) RemoveTags(c *gin.Context) {
jsonResponse(c, common.CodeSuccess, true, "success")
}
// RunIndex Run an indexing task (graph/raptor/mindmap) for a dataset.
func (h *DatasetsHandler) RunIndex(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
datasetID := strings.TrimSpace(c.Param("dataset_id"))
if datasetID == "" {
jsonError(c, common.CodeDataError, "dataset_id is required")
return
}
userID := strings.TrimSpace(user.ID)
if userID == "" {
jsonError(c, common.CodeDataError, "user_id is required")
return
}
indexType := strings.ToLower(strings.TrimSpace(c.Query("type")))
data, code, err := h.datasetsService.RunIndex(userID, datasetID, indexType)
if err != nil {
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, data, "success")
}
// TraceIndex Trace an indexing task (graph/raptor/mindmap) for a dataset.
func (h *DatasetsHandler) TraceIndex(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
datasetID := strings.TrimSpace(c.Param("dataset_id"))
if datasetID == "" {
jsonError(c, common.CodeDataError, "dataset_id is required")
return
}
userID := strings.TrimSpace(user.ID)
if userID == "" {
jsonError(c, common.CodeDataError, "user_id is required")
return
}
indexType := strings.ToLower(strings.TrimSpace(c.Query("type")))
result, code, err := h.datasetsService.TraceIndex(datasetID, userID, indexType)
if err != nil {
jsonError(c, code, err.Error())
return
}
if result == nil {
jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success")
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"data": result,
"message": "success",
})
}
// ListMetadataFlattened handles GET /api/v1/datasets/metadata/flattened.
// @Summary List flattened metadata for datasets
// @Description Get flattened metadata for multiple datasets

View File

@@ -255,6 +255,8 @@ func (r *Router) Setup(engine *gin.Engine) {
datasets.PUT("/:dataset_id", r.datasetsHandler.UpdateDataset)
datasets.GET("/:dataset_id/graph", r.datasetsHandler.GetKnowledgeGraph)
datasets.DELETE("/:dataset_id/tags", r.datasetsHandler.RemoveTags)
datasets.GET("/:dataset_id/index", r.datasetsHandler.TraceIndex)
datasets.POST("/:dataset_id/index", r.datasetsHandler.RunIndex)
datasets.DELETE("/:dataset_id/graph", r.datasetsHandler.DeleteKnowledgeGraph)
datasets.POST("", r.datasetsHandler.CreateDataset)
datasets.DELETE("", r.datasetsHandler.DeleteDatasets)

View File

@@ -24,17 +24,21 @@ import (
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/engine"
redisengine "ragflow/internal/engine/redis"
"ragflow/internal/entity"
"ragflow/internal/entity/models"
"ragflow/internal/server"
"ragflow/internal/service/nlp"
"ragflow/internal/utility"
"sort"
"strings"
"time"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid"
"go.uber.org/zap"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
var (
@@ -68,6 +72,16 @@ var (
"number": {},
}
datasetChunkMethodErrorMessage = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'resume', 'table' or 'tag'"
validIndexTypes = []string{"graph", "raptor", "mindmap"}
indexTypeToTaskType = map[string]string{"graph": "graphrag", "raptor": "raptor", "mindmap": "mindmap"}
indexTypeToDisplayName = map[string]string{"graph": "Graph", "raptor": "RAPTOR", "mindmap": "Mindmap"}
)
const (
// Keep the legacy worker marker in queue payloads; persisted tasks use a real document ID.
graphRaptorQueueDocID = "graph_raptor_x"
maximumTaskPageNumber = int64(100000000)
serverQueueNamePrefix = "te"
)
// DatasetService implements the RESTful dataset APIs from dataset_api.py.
@@ -79,6 +93,7 @@ type DatasetService struct {
tenantLLMDAO *dao.TenantLLMDAO
pipelineLogDAO *dao.PipelineOperationLogDAO
userTenantDAO *dao.UserTenantDAO
taskDAO *dao.TaskDAO
searchService *SearchService
docEngine engine.DocEngine
embeddingCache *utility.EmbeddingLRU
@@ -100,6 +115,7 @@ func NewDatasetService() *DatasetService {
tenantLLMDAO: dao.NewTenantLLMDAO(),
pipelineLogDAO: dao.NewPipelineOperationLogDAO(),
userTenantDAO: dao.NewUserTenantDAO(),
taskDAO: dao.NewTaskDAO(),
searchService: NewSearchService(),
docEngine: engine.Get(),
embeddingCache: utility.NewEmbeddingLRU(1000),
@@ -149,6 +165,369 @@ func (s *DatasetService) UpdateDocumentMetadataConfig(userID, datasetID, documen
return updatedDoc, common.CodeSuccess, nil
}
func checkType(indexType string) bool {
haveType := false
for _, t := range validIndexTypes {
if indexType == t {
haveType = true
}
}
return haveType
}
func (s *DatasetService) newRaptorOrGraphRagTask(sampleDoc *entity.Document, taskType string, taskDocID string, queueDocID string, docIDs []string) (*entity.Task, map[string]interface{}, error) {
if docIDs == nil || len(docIDs) == 0 {
docIDs = make([]string, 0)
}
if !checkIndexTaskType(taskType) {
return nil, nil, errors.New("type should be graphrag, raptor or mindmap")
}
chunkingConfig, err := s.documentDAO.GetChunkingConfig(sampleDoc.ID)
if err != nil {
return nil, nil, err
}
hasher := xxhash.New()
keys := make([]string, 0, len(chunkingConfig))
for key := range chunkingConfig {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
_, _ = hasher.Write([]byte(key))
_, _ = hasher.Write([]byte{0})
v, mErr := json.Marshal(chunkingConfig[key])
if mErr != nil {
return nil, nil, mErr
}
_, _ = hasher.Write(v)
_, _ = hasher.Write([]byte{0})
}
taskID := strings.ReplaceAll(uuid.New().String(), "-", "")[:32]
beginAt := time.Now().Truncate(time.Second)
progressMsg := beginAt.Format("15:04:05") + " created task " + taskType
for _, field := range []interface{}{taskDocID, maximumTaskPageNumber, maximumTaskPageNumber, taskType} {
_, _ = hasher.Write([]byte(fmt.Sprint(field)))
}
digest := fmt.Sprintf("%016x", hasher.Sum64())
task := &entity.Task{
ID: taskID,
DocID: taskDocID,
FromPage: maximumTaskPageNumber,
ToPage: maximumTaskPageNumber,
TaskType: taskType,
ProgressMsg: &progressMsg,
BeginAt: &beginAt,
Digest: &digest,
}
queueMessage := map[string]interface{}{
"id": taskID,
"doc_id": queueDocID,
"from_page": maximumTaskPageNumber,
"to_page": maximumTaskPageNumber,
"task_type": taskType,
"progress_msg": progressMsg,
"begin_at": beginAt.Format("2006-01-02 15:04:05"),
"digest": digest,
"doc_ids": docIDs,
}
return task, queueMessage, nil
}
func createDatasetIndexTaskInTx(tx *gorm.DB, task *entity.Task, queueDocID string) (*entity.Document, error) {
if task == nil {
return nil, errors.New("task is required")
}
if err := tx.Create(task).Error; err != nil {
return nil, err
}
if queueDocID == "" {
return nil, nil
}
var document entity.Document
err := tx.Select("id", "progress_msg", "process_begin_at").Where("id = ?", queueDocID).First(&document).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
beginAt := time.Now().Truncate(time.Second)
if task.BeginAt != nil {
beginAt = *task.BeginAt
}
if err := tx.Model(&entity.Document{}).Where("id = ?", queueDocID).Updates(map[string]interface{}{
"progress_msg": "Task is queued...",
"process_begin_at": beginAt,
}).Error; err != nil {
return nil, err
}
return &document, nil
}
func enqueueDatasetIndexTask(priority int, queueMessage map[string]interface{}) error {
redisClient := redisengine.Get()
if redisClient == nil || !redisClient.QueueProduct(datasetIndexQueueName(priority), queueMessage) {
return errors.New("Can't access Redis. Please check the Redis' status")
}
return nil
}
func cleanupFailedDatasetIndexTask(taskID string, updatedDocument *entity.Document, kbID string, indexType string) error {
return dao.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Unscoped().Where("id = ?", taskID).Delete(&entity.Task{}).Error; err != nil {
return fmt.Errorf("delete task %s: %w", taskID, err)
}
if column := datasetIndexTaskIDColumn(indexType); kbID != "" && column != "" {
if err := tx.Model(&entity.Knowledgebase{}).Where("id = ? AND "+column+" = ?", kbID, taskID).Update(column, nil).Error; err != nil {
return fmt.Errorf("clear dataset task id %s: %w", taskID, err)
}
}
if updatedDocument == nil {
return nil
}
return tx.Model(&entity.Document{}).Where("id = ?", updatedDocument.ID).Updates(map[string]interface{}{
"progress_msg": updatedDocument.ProgressMsg,
"process_begin_at": updatedDocument.ProcessBeginAt,
}).Error
})
}
func datasetIndexTaskIDColumn(indexType string) string {
switch indexType {
case "graph":
return "graphrag_task_id"
case "raptor":
return "raptor_task_id"
case "mindmap":
return "mindmap_task_id"
default:
return ""
}
}
func checkIndexTaskType(taskType string) bool {
switch taskType {
case "graphrag", "raptor", "mindmap":
return true
default:
return false
}
}
func datasetIndexTaskID(kb *entity.Knowledgebase, indexType string) string {
if kb == nil {
return ""
}
switch indexType {
case "graph":
if kb.GraphragTaskID != nil {
return *kb.GraphragTaskID
}
case "raptor":
if kb.RaptorTaskID != nil {
return *kb.RaptorTaskID
}
case "mindmap":
if kb.MindmapTaskID != nil {
return *kb.MindmapTaskID
}
}
return ""
}
func datasetIndexTaskIDUpdate(indexType, taskID string) map[string]interface{} {
switch indexType {
case "graph":
return map[string]interface{}{"graphrag_task_id": taskID}
case "raptor":
return map[string]interface{}{"raptor_task_id": taskID}
case "mindmap":
return map[string]interface{}{"mindmap_task_id": taskID}
default:
return map[string]interface{}{}
}
}
func datasetIndexTaskIDs(kb *entity.Knowledgebase) []string {
if kb == nil {
return nil
}
taskIDs := make([]string, 0, 3)
for _, taskID := range []*string{kb.GraphragTaskID, kb.RaptorTaskID, kb.MindmapTaskID} {
if taskID != nil && *taskID != "" {
taskIDs = append(taskIDs, *taskID)
}
}
return common.Deduplicate(taskIDs)
}
func datasetIndexQueueName(priority int) string {
return fmt.Sprintf("%s.%d.common", serverQueueNamePrefix, priority)
}
// RunIndex Run an indexing task (graph/raptor/mindmap) for a dataset.
func (s *DatasetService) RunIndex(userID, datasetID, indexType string) (map[string]interface{}, common.ErrorCode, error) {
if !checkType(indexType) {
return nil, common.CodeDataError, fmt.Errorf("Invalid index type '%s'. Must be one of %v", indexType, validIndexTypes)
}
if datasetID == "" {
return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`)
}
if !s.kbDAO.Accessible(datasetID, userID) {
return nil, common.CodeDataError, errors.New("No authorization.")
}
kb, err := s.kbDAO.GetByID(datasetID)
if err != nil {
if dao.IsNotFoundErr(err) {
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
}
return nil, common.CodeDataError, errors.New("Internal server error")
}
taskType := indexTypeToTaskType[indexType]
displayName := indexTypeToDisplayName[indexType]
documents, code, err := s.getDocumentsByDatasetForIndex(datasetID)
if err != nil {
return nil, code, err
}
_ = documents
sampleDocument := documents[0]
documentIDs := make([]string, len(documents))
for i, doc := range documents {
documentIDs[i] = doc.ID
}
task, queueMessage, err := s.newRaptorOrGraphRagTask(sampleDocument, taskType, sampleDocument.ID, graphRaptorQueueDocID, documentIDs)
if err != nil {
common.Warn("Failed to build dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err))
return nil, common.CodeDataError, errors.New("Internal server error")
}
var updatedDocument *entity.Document
var dataErr error
err = dao.DB.Transaction(func(tx *gorm.DB) error {
var lockedKB entity.Knowledgebase
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Where("id = ? AND status = ?", kb.ID, string(entity.StatusValid)).
First(&lockedKB).Error; err != nil {
return err
}
existingTaskID := datasetIndexTaskID(&lockedKB, indexType)
if existingTaskID != "" {
var existingTask entity.Task
taskErr := tx.Where("id = ?", existingTaskID).First(&existingTask).Error
if taskErr != nil {
if errors.Is(taskErr, gorm.ErrRecordNotFound) {
} else {
return taskErr
}
} else if existingTask.Progress != 1 && existingTask.Progress != -1 {
dataErr = fmt.Errorf("Task %s in progress with status %v. A %s Task is already running.", existingTaskID, existingTask.Progress, displayName)
return dataErr
}
}
updatedDocument, err = createDatasetIndexTaskInTx(tx, task, graphRaptorQueueDocID)
if err != nil {
return err
}
return tx.Model(&entity.Knowledgebase{}).Where("id = ?", lockedKB.ID).Updates(datasetIndexTaskIDUpdate(indexType, task.ID)).Error
})
if err != nil {
if dataErr != nil {
return nil, common.CodeDataError, dataErr
}
common.Warn("Failed to create dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err))
return nil, common.CodeDataError, errors.New("Internal server error")
}
if err := enqueueDatasetIndexTask(0, queueMessage); err != nil {
if cleanupErr := cleanupFailedDatasetIndexTask(task.ID, updatedDocument, kb.ID, indexType); cleanupErr != nil {
err = errors.Join(err, cleanupErr)
}
common.Warn("Failed to queue dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err))
return nil, common.CodeDataError, errors.New("Internal server error")
}
return map[string]interface{}{"task_id": task.ID}, common.CodeSuccess, nil
}
func (s *DatasetService) getDocumentsByDatasetForIndex(datasetID string) ([]*entity.Document, common.ErrorCode, error) {
documents, _, err := s.documentDAO.GetByKBID(datasetID)
if err != nil {
common.Warn("Failed to load dataset documents for index", zap.String("dataset_id", datasetID), zap.Error(err))
return nil, common.CodeDataError, errors.New("Internal server error")
}
if len(documents) == 0 {
return nil, common.CodeDataError, fmt.Errorf("No documents in Dataset %s", datasetID)
}
return documents, common.CodeSuccess, nil
}
type TraceIndexRequest struct {
Type string `json:"type" binding:"required"`
}
// TraceIndex Trace an indexing task (graph/raptor/mindmap) for a dataset.
func (s *DatasetService) TraceIndex(datasetID, userID, indexType string) (*entity.Task, common.ErrorCode, error) {
if !checkType(indexType) {
return nil, common.CodeDataError, fmt.Errorf("Invalid index type '%s'. Must be one of %v", indexType, validIndexTypes)
}
if datasetID == "" {
return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`)
}
if !s.kbDAO.Accessible(datasetID, userID) {
return nil, common.CodeDataError, errors.New("No authorization.")
}
kb, err := s.kbDAO.GetByID(datasetID)
if err != nil {
if dao.IsNotFoundErr(err) {
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
}
return nil, common.CodeDataError, errors.New("Internal server error")
}
taskID := datasetIndexTaskID(kb, indexType)
var task *entity.Task
if taskID != "" {
task, err = s.taskDAO.GetByID(taskID)
if err != nil {
if dao.IsNotFoundErr(err) {
return nil, common.CodeSuccess, nil
}
return nil, common.CodeServerError, errors.New("Internal server error")
}
if task == nil {
return nil, common.CodeSuccess, nil
}
}
return task, common.CodeSuccess, nil
}
// SearchDatasetsRequest is the request structure for searching chunks across datasets.
type SearchDatasetsRequest struct {
DatasetIDs []string `json:"dataset_ids" binding:"required"`
@@ -1638,6 +2017,12 @@ func jsonMapValue(m entity.JSONMap) interface{} {
func (s *DatasetService) deleteDataset(tenantID string, kb *entity.Knowledgebase) error {
return dao.DB.Transaction(func(tx *gorm.DB) error {
if taskIDs := datasetIndexTaskIDs(kb); len(taskIDs) > 0 {
if err := tx.Where("id IN ?", taskIDs).Delete(&entity.Task{}).Error; err != nil {
return fmt.Errorf("Delete dataset error for %s", kb.ID)
}
}
var documents []entity.Document
if err := tx.Where("kb_id = ?", kb.ID).Find(&documents).Error; err != nil {
return fmt.Errorf("Delete dataset error for %s", kb.ID)

View File

@@ -0,0 +1,110 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"errors"
"testing"
"time"
"gorm.io/gorm"
"ragflow/internal/dao"
"ragflow/internal/entity"
)
func TestCleanupFailedDatasetIndexTaskDeletesTaskAndRestoresDocument(t *testing.T) {
db := setupServiceTestDB(t)
pushServiceDB(t, db)
previousMsg := "previous progress"
previousBeginAt := time.Date(2026, 6, 18, 10, 0, 0, 0, time.UTC)
queuedMsg := "Task is queued..."
queuedBeginAt := previousBeginAt.Add(time.Hour)
taskID := "task-1"
kb := &entity.Knowledgebase{
ID: "kb-1",
TenantID: "user-1",
Name: "test-kb",
EmbdID: "embedding@OpenAI",
CreatedBy: "user-1",
Permission: string(entity.TenantPermissionMe),
ParserID: "naive",
ParserConfig: entity.JSONMap{},
GraphragTaskID: &taskID,
Status: sptr("1"),
}
if err := dao.DB.Create(kb).Error; err != nil {
t.Fatalf("insert kb: %v", err)
}
doc := &entity.Document{
ID: "doc-1",
KbID: "kb-1",
ParserID: "naive",
ParserConfig: entity.JSONMap{},
SourceType: "local",
Type: "pdf",
CreatedBy: "user-1",
Suffix: ".pdf",
ProgressMsg: &queuedMsg,
ProcessBeginAt: &queuedBeginAt,
}
if err := dao.DB.Create(doc).Error; err != nil {
t.Fatalf("insert document: %v", err)
}
task := &entity.Task{ID: taskID, DocID: doc.ID, TaskType: "graphrag"}
if err := dao.DB.Create(task).Error; err != nil {
t.Fatalf("insert task: %v", err)
}
snapshot := &entity.Document{
ID: doc.ID,
ProgressMsg: &previousMsg,
ProcessBeginAt: &previousBeginAt,
}
if err := cleanupFailedDatasetIndexTask(task.ID, snapshot, kb.ID, "graph"); err != nil {
t.Fatalf("cleanup failed: %v", err)
}
var persistedTask entity.Task
err := dao.DB.Where("id = ?", task.ID).First(&persistedTask).Error
if !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatalf("expected task to be deleted, got err=%v task=%#v", err, persistedTask)
}
persistedDoc, err := dao.NewDocumentDAO().GetByID(doc.ID)
if err != nil {
t.Fatalf("fetch document: %v", err)
}
if persistedDoc.ProgressMsg == nil || *persistedDoc.ProgressMsg != previousMsg {
t.Fatalf("expected progress_msg %q, got %#v", previousMsg, persistedDoc.ProgressMsg)
}
if persistedDoc.ProcessBeginAt == nil || !persistedDoc.ProcessBeginAt.Equal(previousBeginAt) {
t.Fatalf("expected process_begin_at %v, got %#v", previousBeginAt, persistedDoc.ProcessBeginAt)
}
var persistedKB entity.Knowledgebase
if err := dao.DB.Where("id = ?", kb.ID).First(&persistedKB).Error; err != nil {
t.Fatalf("fetch kb: %v", err)
}
if persistedKB.GraphragTaskID != nil {
t.Fatalf("expected graphrag_task_id to be cleared, got %#v", *persistedKB.GraphragTaskID)
}
}