mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat[Go]: implement datasets/<dataset_id>/index P/G (#16153)
### What problem does this PR solve? ``` POST: http://localhost:9384/api/v1/datasets/433b390c630411f1a13eab5f89540b2a/index?type=graph Output: { "code": 0, "data": { "task_id": "ff5a3546bafa49d794a9a050d99c4a52" }, "message": "success" } ``` --- ``` GET: http://localhost:9384/api/v1/datasets/433b390c630411f1a13eab5f89540b2a/index?type=graph Output: { "code": 0, "data": { "id": "ff5a3546bafa49d794a9a050d99c4a52", "doc_id": "graph_raptor_x", "from_page": 100000000, "to_page": 100000000, "task_type": "graphrag", "priority": 0, "begin_at": "2026-06-17T18:07:45+08:00", "process_duration": 4.108135, "progress": -1, "progress_msg": "18:07:45 created task graphrag\n18:07:47 Task has been received.\n18:07:49 [ERROR][Exception]: Model config not found: Qwen/Qwen3-235B-A22B@test@SILICONFLOW", "retry_count": 1, "digest": "f16fd067d5c92cec", "create_time": 1781690865552, "create_date": "2026-06-17T18:07:45+08:00", "update_time": 1781690869108, "update_date": "2026-06-17T18:07:49+08:00" }, "message": "success" } ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -109,6 +109,86 @@ func (dao *DocumentDAO) ListByKBID(kbID string, offset, limit int) ([]*entity.Do
|
||||
return documents, total, err
|
||||
}
|
||||
|
||||
// GetByKBID retrieves all documents in a knowledge base ordered by create time.
|
||||
func (dao *DocumentDAO) GetByKBID(kbID string) ([]*entity.Document, int64, error) {
|
||||
var documents []*entity.Document
|
||||
var total int64
|
||||
|
||||
query := DB.Model(&entity.Document{}).Where("kb_id = ?", kbID)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err := query.Order("create_time ASC").Find(&documents).Error
|
||||
return documents, total, err
|
||||
}
|
||||
|
||||
// GetChunkingConfig returns the document, dataset, and tenant fields used to
|
||||
// build a parsing task digest, mirroring DocumentService.get_chunking_config.
|
||||
func (dao *DocumentDAO) GetChunkingConfig(docID string) (map[string]interface{}, error) {
|
||||
var row struct {
|
||||
ID string `gorm:"column:id"`
|
||||
KbID string `gorm:"column:kb_id"`
|
||||
ParserID string `gorm:"column:parser_id"`
|
||||
ParserConfig entity.JSONMap `gorm:"column:parser_config"`
|
||||
Size int64 `gorm:"column:size"`
|
||||
ContentHash *string `gorm:"column:content_hash"`
|
||||
Language *string `gorm:"column:language"`
|
||||
EmbdID string `gorm:"column:embd_id"`
|
||||
TenantID string `gorm:"column:tenant_id"`
|
||||
Img2TxtID string `gorm:"column:img2txt_id"`
|
||||
ASRID string `gorm:"column:asr_id"`
|
||||
LLMID string `gorm:"column:llm_id"`
|
||||
}
|
||||
|
||||
err := DB.Table("document").
|
||||
Select(`
|
||||
document.id,
|
||||
document.kb_id,
|
||||
document.parser_id,
|
||||
document.parser_config,
|
||||
document.size,
|
||||
document.content_hash,
|
||||
knowledgebase.language,
|
||||
knowledgebase.embd_id,
|
||||
tenant.id AS tenant_id,
|
||||
tenant.img2txt_id,
|
||||
tenant.asr_id,
|
||||
tenant.llm_id
|
||||
`).
|
||||
Joins("JOIN knowledgebase ON document.kb_id = knowledgebase.id").
|
||||
Joins("JOIN tenant ON knowledgebase.tenant_id = tenant.id").
|
||||
Where("document.id = ?", docID).
|
||||
Take(&row).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config := map[string]interface{}{
|
||||
"id": row.ID,
|
||||
"kb_id": row.KbID,
|
||||
"parser_id": row.ParserID,
|
||||
"parser_config": row.ParserConfig,
|
||||
"size": row.Size,
|
||||
"embd_id": row.EmbdID,
|
||||
"tenant_id": row.TenantID,
|
||||
"img2txt_id": row.Img2TxtID,
|
||||
"asr_id": row.ASRID,
|
||||
"llm_id": row.LLMID,
|
||||
}
|
||||
if row.ContentHash != nil {
|
||||
config["content_hash"] = *row.ContentHash
|
||||
} else {
|
||||
config["content_hash"] = nil
|
||||
}
|
||||
if row.Language != nil {
|
||||
config["language"] = *row.Language
|
||||
} else {
|
||||
config["language"] = nil
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// DeleteByTenantID deletes all documents by tenant ID (hard delete)
|
||||
func (dao *DocumentDAO) DeleteByTenantID(tenantID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("tenant_id = ?", tenantID).Delete(&entity.Document{})
|
||||
|
||||
@@ -118,6 +118,32 @@ func TestDocumentGetByIDs_NoMatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocumentGetByKBIDOrdersByCreateTime(t *testing.T) {
|
||||
db := setupDocumentTestDB(t)
|
||||
pushDocDB(t, db)
|
||||
|
||||
createTime10 := int64(10)
|
||||
createTime20 := int64(20)
|
||||
createTime30 := int64(30)
|
||||
db.Create(&entity.Document{ID: "doc-later", KbID: "kb1", Name: sp("Doc Later"), CreatedBy: "user1", ParserConfig: entity.JSONMap{}, BaseModel: entity.BaseModel{CreateTime: &createTime30}})
|
||||
db.Create(&entity.Document{ID: "doc-other", KbID: "kb2", Name: sp("Doc Other"), CreatedBy: "user1", ParserConfig: entity.JSONMap{}, BaseModel: entity.BaseModel{CreateTime: &createTime10}})
|
||||
db.Create(&entity.Document{ID: "doc-earlier", KbID: "kb1", Name: sp("Doc Earlier"), CreatedBy: "user1", ParserConfig: entity.JSONMap{}, BaseModel: entity.BaseModel{CreateTime: &createTime20}})
|
||||
|
||||
docs, total, err := NewDocumentDAO().GetByKBID("kb1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByKBID failed: %v", err)
|
||||
}
|
||||
if total != 2 {
|
||||
t.Fatalf("expected total=2, got %d", total)
|
||||
}
|
||||
if len(docs) != 2 {
|
||||
t.Fatalf("expected 2 docs, got %d", len(docs))
|
||||
}
|
||||
if docs[0].ID != "doc-earlier" || docs[1].ID != "doc-later" {
|
||||
t.Fatalf("unexpected order: %s, %s", docs[0].ID, docs[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocumentGetByDocumentIDAndDatasetIDUsesKBID(t *testing.T) {
|
||||
db := setupDocumentTestDB(t)
|
||||
pushDocDB(t, db)
|
||||
|
||||
@@ -592,6 +592,74 @@ func (h *DatasetsHandler) RemoveTags(c *gin.Context) {
|
||||
jsonResponse(c, common.CodeSuccess, true, "success")
|
||||
}
|
||||
|
||||
// RunIndex Run an indexing task (graph/raptor/mindmap) for a dataset.
|
||||
func (h *DatasetsHandler) RunIndex(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
datasetID := strings.TrimSpace(c.Param("dataset_id"))
|
||||
if datasetID == "" {
|
||||
jsonError(c, common.CodeDataError, "dataset_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
jsonError(c, common.CodeDataError, "user_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
indexType := strings.ToLower(strings.TrimSpace(c.Query("type")))
|
||||
data, code, err := h.datasetsService.RunIndex(userID, datasetID, indexType)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, data, "success")
|
||||
}
|
||||
|
||||
// TraceIndex Trace an indexing task (graph/raptor/mindmap) for a dataset.
|
||||
func (h *DatasetsHandler) TraceIndex(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
datasetID := strings.TrimSpace(c.Param("dataset_id"))
|
||||
if datasetID == "" {
|
||||
jsonError(c, common.CodeDataError, "dataset_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
jsonError(c, common.CodeDataError, "user_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
indexType := strings.ToLower(strings.TrimSpace(c.Query("type")))
|
||||
result, code, err := h.datasetsService.TraceIndex(datasetID, userID, indexType)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
if result == nil {
|
||||
jsonResponse(c, common.CodeSuccess, map[string]interface{}{}, "success")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": result,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// ListMetadataFlattened handles GET /api/v1/datasets/metadata/flattened.
|
||||
// @Summary List flattened metadata for datasets
|
||||
// @Description Get flattened metadata for multiple datasets
|
||||
|
||||
@@ -255,6 +255,8 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
datasets.PUT("/:dataset_id", r.datasetsHandler.UpdateDataset)
|
||||
datasets.GET("/:dataset_id/graph", r.datasetsHandler.GetKnowledgeGraph)
|
||||
datasets.DELETE("/:dataset_id/tags", r.datasetsHandler.RemoveTags)
|
||||
datasets.GET("/:dataset_id/index", r.datasetsHandler.TraceIndex)
|
||||
datasets.POST("/:dataset_id/index", r.datasetsHandler.RunIndex)
|
||||
datasets.DELETE("/:dataset_id/graph", r.datasetsHandler.DeleteKnowledgeGraph)
|
||||
datasets.POST("", r.datasetsHandler.CreateDataset)
|
||||
datasets.DELETE("", r.datasetsHandler.DeleteDatasets)
|
||||
|
||||
@@ -24,17 +24,21 @@ import (
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
redisengine "ragflow/internal/engine/redis"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/entity/models"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/service/nlp"
|
||||
"ragflow/internal/utility"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -68,6 +72,16 @@ var (
|
||||
"number": {},
|
||||
}
|
||||
datasetChunkMethodErrorMessage = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'resume', 'table' or 'tag'"
|
||||
validIndexTypes = []string{"graph", "raptor", "mindmap"}
|
||||
indexTypeToTaskType = map[string]string{"graph": "graphrag", "raptor": "raptor", "mindmap": "mindmap"}
|
||||
indexTypeToDisplayName = map[string]string{"graph": "Graph", "raptor": "RAPTOR", "mindmap": "Mindmap"}
|
||||
)
|
||||
|
||||
const (
|
||||
// Keep the legacy worker marker in queue payloads; persisted tasks use a real document ID.
|
||||
graphRaptorQueueDocID = "graph_raptor_x"
|
||||
maximumTaskPageNumber = int64(100000000)
|
||||
serverQueueNamePrefix = "te"
|
||||
)
|
||||
|
||||
// DatasetService implements the RESTful dataset APIs from dataset_api.py.
|
||||
@@ -79,6 +93,7 @@ type DatasetService struct {
|
||||
tenantLLMDAO *dao.TenantLLMDAO
|
||||
pipelineLogDAO *dao.PipelineOperationLogDAO
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
taskDAO *dao.TaskDAO
|
||||
searchService *SearchService
|
||||
docEngine engine.DocEngine
|
||||
embeddingCache *utility.EmbeddingLRU
|
||||
@@ -100,6 +115,7 @@ func NewDatasetService() *DatasetService {
|
||||
tenantLLMDAO: dao.NewTenantLLMDAO(),
|
||||
pipelineLogDAO: dao.NewPipelineOperationLogDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
taskDAO: dao.NewTaskDAO(),
|
||||
searchService: NewSearchService(),
|
||||
docEngine: engine.Get(),
|
||||
embeddingCache: utility.NewEmbeddingLRU(1000),
|
||||
@@ -149,6 +165,369 @@ func (s *DatasetService) UpdateDocumentMetadataConfig(userID, datasetID, documen
|
||||
return updatedDoc, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func checkType(indexType string) bool {
|
||||
haveType := false
|
||||
for _, t := range validIndexTypes {
|
||||
if indexType == t {
|
||||
haveType = true
|
||||
}
|
||||
}
|
||||
return haveType
|
||||
}
|
||||
|
||||
func (s *DatasetService) newRaptorOrGraphRagTask(sampleDoc *entity.Document, taskType string, taskDocID string, queueDocID string, docIDs []string) (*entity.Task, map[string]interface{}, error) {
|
||||
if docIDs == nil || len(docIDs) == 0 {
|
||||
docIDs = make([]string, 0)
|
||||
}
|
||||
if !checkIndexTaskType(taskType) {
|
||||
return nil, nil, errors.New("type should be graphrag, raptor or mindmap")
|
||||
}
|
||||
|
||||
chunkingConfig, err := s.documentDAO.GetChunkingConfig(sampleDoc.ID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
hasher := xxhash.New()
|
||||
keys := make([]string, 0, len(chunkingConfig))
|
||||
for key := range chunkingConfig {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
_, _ = hasher.Write([]byte(key))
|
||||
_, _ = hasher.Write([]byte{0})
|
||||
v, mErr := json.Marshal(chunkingConfig[key])
|
||||
if mErr != nil {
|
||||
return nil, nil, mErr
|
||||
}
|
||||
_, _ = hasher.Write(v)
|
||||
_, _ = hasher.Write([]byte{0})
|
||||
}
|
||||
|
||||
taskID := strings.ReplaceAll(uuid.New().String(), "-", "")[:32]
|
||||
beginAt := time.Now().Truncate(time.Second)
|
||||
progressMsg := beginAt.Format("15:04:05") + " created task " + taskType
|
||||
|
||||
for _, field := range []interface{}{taskDocID, maximumTaskPageNumber, maximumTaskPageNumber, taskType} {
|
||||
_, _ = hasher.Write([]byte(fmt.Sprint(field)))
|
||||
}
|
||||
digest := fmt.Sprintf("%016x", hasher.Sum64())
|
||||
task := &entity.Task{
|
||||
ID: taskID,
|
||||
DocID: taskDocID,
|
||||
FromPage: maximumTaskPageNumber,
|
||||
ToPage: maximumTaskPageNumber,
|
||||
TaskType: taskType,
|
||||
ProgressMsg: &progressMsg,
|
||||
BeginAt: &beginAt,
|
||||
Digest: &digest,
|
||||
}
|
||||
|
||||
queueMessage := map[string]interface{}{
|
||||
"id": taskID,
|
||||
"doc_id": queueDocID,
|
||||
"from_page": maximumTaskPageNumber,
|
||||
"to_page": maximumTaskPageNumber,
|
||||
"task_type": taskType,
|
||||
"progress_msg": progressMsg,
|
||||
"begin_at": beginAt.Format("2006-01-02 15:04:05"),
|
||||
"digest": digest,
|
||||
"doc_ids": docIDs,
|
||||
}
|
||||
|
||||
return task, queueMessage, nil
|
||||
}
|
||||
|
||||
func createDatasetIndexTaskInTx(tx *gorm.DB, task *entity.Task, queueDocID string) (*entity.Document, error) {
|
||||
if task == nil {
|
||||
return nil, errors.New("task is required")
|
||||
}
|
||||
if err := tx.Create(task).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if queueDocID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var document entity.Document
|
||||
err := tx.Select("id", "progress_msg", "process_begin_at").Where("id = ?", queueDocID).First(&document).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
beginAt := time.Now().Truncate(time.Second)
|
||||
if task.BeginAt != nil {
|
||||
beginAt = *task.BeginAt
|
||||
}
|
||||
if err := tx.Model(&entity.Document{}).Where("id = ?", queueDocID).Updates(map[string]interface{}{
|
||||
"progress_msg": "Task is queued...",
|
||||
"process_begin_at": beginAt,
|
||||
}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &document, nil
|
||||
}
|
||||
|
||||
func enqueueDatasetIndexTask(priority int, queueMessage map[string]interface{}) error {
|
||||
redisClient := redisengine.Get()
|
||||
if redisClient == nil || !redisClient.QueueProduct(datasetIndexQueueName(priority), queueMessage) {
|
||||
return errors.New("Can't access Redis. Please check the Redis' status")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupFailedDatasetIndexTask(taskID string, updatedDocument *entity.Document, kbID string, indexType string) error {
|
||||
return dao.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Unscoped().Where("id = ?", taskID).Delete(&entity.Task{}).Error; err != nil {
|
||||
return fmt.Errorf("delete task %s: %w", taskID, err)
|
||||
}
|
||||
|
||||
if column := datasetIndexTaskIDColumn(indexType); kbID != "" && column != "" {
|
||||
if err := tx.Model(&entity.Knowledgebase{}).Where("id = ? AND "+column+" = ?", kbID, taskID).Update(column, nil).Error; err != nil {
|
||||
return fmt.Errorf("clear dataset task id %s: %w", taskID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if updatedDocument == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tx.Model(&entity.Document{}).Where("id = ?", updatedDocument.ID).Updates(map[string]interface{}{
|
||||
"progress_msg": updatedDocument.ProgressMsg,
|
||||
"process_begin_at": updatedDocument.ProcessBeginAt,
|
||||
}).Error
|
||||
})
|
||||
}
|
||||
|
||||
func datasetIndexTaskIDColumn(indexType string) string {
|
||||
switch indexType {
|
||||
case "graph":
|
||||
return "graphrag_task_id"
|
||||
case "raptor":
|
||||
return "raptor_task_id"
|
||||
case "mindmap":
|
||||
return "mindmap_task_id"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func checkIndexTaskType(taskType string) bool {
|
||||
switch taskType {
|
||||
case "graphrag", "raptor", "mindmap":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func datasetIndexTaskID(kb *entity.Knowledgebase, indexType string) string {
|
||||
if kb == nil {
|
||||
return ""
|
||||
}
|
||||
switch indexType {
|
||||
case "graph":
|
||||
if kb.GraphragTaskID != nil {
|
||||
return *kb.GraphragTaskID
|
||||
}
|
||||
case "raptor":
|
||||
if kb.RaptorTaskID != nil {
|
||||
return *kb.RaptorTaskID
|
||||
}
|
||||
case "mindmap":
|
||||
if kb.MindmapTaskID != nil {
|
||||
return *kb.MindmapTaskID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func datasetIndexTaskIDUpdate(indexType, taskID string) map[string]interface{} {
|
||||
switch indexType {
|
||||
case "graph":
|
||||
return map[string]interface{}{"graphrag_task_id": taskID}
|
||||
case "raptor":
|
||||
return map[string]interface{}{"raptor_task_id": taskID}
|
||||
case "mindmap":
|
||||
return map[string]interface{}{"mindmap_task_id": taskID}
|
||||
default:
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func datasetIndexTaskIDs(kb *entity.Knowledgebase) []string {
|
||||
if kb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
taskIDs := make([]string, 0, 3)
|
||||
for _, taskID := range []*string{kb.GraphragTaskID, kb.RaptorTaskID, kb.MindmapTaskID} {
|
||||
if taskID != nil && *taskID != "" {
|
||||
taskIDs = append(taskIDs, *taskID)
|
||||
}
|
||||
}
|
||||
return common.Deduplicate(taskIDs)
|
||||
}
|
||||
|
||||
func datasetIndexQueueName(priority int) string {
|
||||
return fmt.Sprintf("%s.%d.common", serverQueueNamePrefix, priority)
|
||||
}
|
||||
|
||||
// RunIndex Run an indexing task (graph/raptor/mindmap) for a dataset.
|
||||
func (s *DatasetService) RunIndex(userID, datasetID, indexType string) (map[string]interface{}, common.ErrorCode, error) {
|
||||
if !checkType(indexType) {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Invalid index type '%s'. Must be one of %v", indexType, validIndexTypes)
|
||||
}
|
||||
|
||||
if datasetID == "" {
|
||||
return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`)
|
||||
}
|
||||
if !s.kbDAO.Accessible(datasetID, userID) {
|
||||
return nil, common.CodeDataError, errors.New("No authorization.")
|
||||
}
|
||||
|
||||
kb, err := s.kbDAO.GetByID(datasetID)
|
||||
if err != nil {
|
||||
if dao.IsNotFoundErr(err) {
|
||||
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
||||
}
|
||||
return nil, common.CodeDataError, errors.New("Internal server error")
|
||||
}
|
||||
|
||||
taskType := indexTypeToTaskType[indexType]
|
||||
displayName := indexTypeToDisplayName[indexType]
|
||||
|
||||
documents, code, err := s.getDocumentsByDatasetForIndex(datasetID)
|
||||
if err != nil {
|
||||
return nil, code, err
|
||||
}
|
||||
_ = documents
|
||||
|
||||
sampleDocument := documents[0]
|
||||
documentIDs := make([]string, len(documents))
|
||||
|
||||
for i, doc := range documents {
|
||||
documentIDs[i] = doc.ID
|
||||
}
|
||||
|
||||
task, queueMessage, err := s.newRaptorOrGraphRagTask(sampleDocument, taskType, sampleDocument.ID, graphRaptorQueueDocID, documentIDs)
|
||||
if err != nil {
|
||||
common.Warn("Failed to build dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err))
|
||||
return nil, common.CodeDataError, errors.New("Internal server error")
|
||||
}
|
||||
|
||||
var updatedDocument *entity.Document
|
||||
var dataErr error
|
||||
err = dao.DB.Transaction(func(tx *gorm.DB) error {
|
||||
var lockedKB entity.Knowledgebase
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
Where("id = ? AND status = ?", kb.ID, string(entity.StatusValid)).
|
||||
First(&lockedKB).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existingTaskID := datasetIndexTaskID(&lockedKB, indexType)
|
||||
if existingTaskID != "" {
|
||||
var existingTask entity.Task
|
||||
taskErr := tx.Where("id = ?", existingTaskID).First(&existingTask).Error
|
||||
if taskErr != nil {
|
||||
if errors.Is(taskErr, gorm.ErrRecordNotFound) {
|
||||
} else {
|
||||
return taskErr
|
||||
}
|
||||
} else if existingTask.Progress != 1 && existingTask.Progress != -1 {
|
||||
dataErr = fmt.Errorf("Task %s in progress with status %v. A %s Task is already running.", existingTaskID, existingTask.Progress, displayName)
|
||||
return dataErr
|
||||
}
|
||||
}
|
||||
|
||||
updatedDocument, err = createDatasetIndexTaskInTx(tx, task, graphRaptorQueueDocID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Model(&entity.Knowledgebase{}).Where("id = ?", lockedKB.ID).Updates(datasetIndexTaskIDUpdate(indexType, task.ID)).Error
|
||||
})
|
||||
if err != nil {
|
||||
if dataErr != nil {
|
||||
return nil, common.CodeDataError, dataErr
|
||||
}
|
||||
common.Warn("Failed to create dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err))
|
||||
return nil, common.CodeDataError, errors.New("Internal server error")
|
||||
}
|
||||
|
||||
if err := enqueueDatasetIndexTask(0, queueMessage); err != nil {
|
||||
if cleanupErr := cleanupFailedDatasetIndexTask(task.ID, updatedDocument, kb.ID, indexType); cleanupErr != nil {
|
||||
err = errors.Join(err, cleanupErr)
|
||||
}
|
||||
common.Warn("Failed to queue dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err))
|
||||
return nil, common.CodeDataError, errors.New("Internal server error")
|
||||
}
|
||||
|
||||
return map[string]interface{}{"task_id": task.ID}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *DatasetService) getDocumentsByDatasetForIndex(datasetID string) ([]*entity.Document, common.ErrorCode, error) {
|
||||
documents, _, err := s.documentDAO.GetByKBID(datasetID)
|
||||
if err != nil {
|
||||
common.Warn("Failed to load dataset documents for index", zap.String("dataset_id", datasetID), zap.Error(err))
|
||||
return nil, common.CodeDataError, errors.New("Internal server error")
|
||||
}
|
||||
if len(documents) == 0 {
|
||||
return nil, common.CodeDataError, fmt.Errorf("No documents in Dataset %s", datasetID)
|
||||
}
|
||||
return documents, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
type TraceIndexRequest struct {
|
||||
Type string `json:"type" binding:"required"`
|
||||
}
|
||||
|
||||
// TraceIndex Trace an indexing task (graph/raptor/mindmap) for a dataset.
|
||||
func (s *DatasetService) TraceIndex(datasetID, userID, indexType string) (*entity.Task, common.ErrorCode, error) {
|
||||
if !checkType(indexType) {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Invalid index type '%s'. Must be one of %v", indexType, validIndexTypes)
|
||||
}
|
||||
|
||||
if datasetID == "" {
|
||||
return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`)
|
||||
}
|
||||
if !s.kbDAO.Accessible(datasetID, userID) {
|
||||
return nil, common.CodeDataError, errors.New("No authorization.")
|
||||
}
|
||||
|
||||
kb, err := s.kbDAO.GetByID(datasetID)
|
||||
if err != nil {
|
||||
if dao.IsNotFoundErr(err) {
|
||||
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
||||
}
|
||||
return nil, common.CodeDataError, errors.New("Internal server error")
|
||||
}
|
||||
|
||||
taskID := datasetIndexTaskID(kb, indexType)
|
||||
|
||||
var task *entity.Task
|
||||
if taskID != "" {
|
||||
task, err = s.taskDAO.GetByID(taskID)
|
||||
if err != nil {
|
||||
if dao.IsNotFoundErr(err) {
|
||||
return nil, common.CodeSuccess, nil
|
||||
}
|
||||
return nil, common.CodeServerError, errors.New("Internal server error")
|
||||
}
|
||||
if task == nil {
|
||||
return nil, common.CodeSuccess, nil
|
||||
}
|
||||
}
|
||||
|
||||
return task, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// SearchDatasetsRequest is the request structure for searching chunks across datasets.
|
||||
type SearchDatasetsRequest struct {
|
||||
DatasetIDs []string `json:"dataset_ids" binding:"required"`
|
||||
@@ -1638,6 +2017,12 @@ func jsonMapValue(m entity.JSONMap) interface{} {
|
||||
|
||||
func (s *DatasetService) deleteDataset(tenantID string, kb *entity.Knowledgebase) error {
|
||||
return dao.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if taskIDs := datasetIndexTaskIDs(kb); len(taskIDs) > 0 {
|
||||
if err := tx.Where("id IN ?", taskIDs).Delete(&entity.Task{}).Error; err != nil {
|
||||
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
||||
}
|
||||
}
|
||||
|
||||
var documents []entity.Document
|
||||
if err := tx.Where("kb_id = ?", kb.ID).Find(&documents).Error; err != nil {
|
||||
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
||||
|
||||
110
internal/service/dataset_task_cleanup_test.go
Normal file
110
internal/service/dataset_task_cleanup_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func TestCleanupFailedDatasetIndexTaskDeletesTaskAndRestoresDocument(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
previousMsg := "previous progress"
|
||||
previousBeginAt := time.Date(2026, 6, 18, 10, 0, 0, 0, time.UTC)
|
||||
queuedMsg := "Task is queued..."
|
||||
queuedBeginAt := previousBeginAt.Add(time.Hour)
|
||||
taskID := "task-1"
|
||||
|
||||
kb := &entity.Knowledgebase{
|
||||
ID: "kb-1",
|
||||
TenantID: "user-1",
|
||||
Name: "test-kb",
|
||||
EmbdID: "embedding@OpenAI",
|
||||
CreatedBy: "user-1",
|
||||
Permission: string(entity.TenantPermissionMe),
|
||||
ParserID: "naive",
|
||||
ParserConfig: entity.JSONMap{},
|
||||
GraphragTaskID: &taskID,
|
||||
Status: sptr("1"),
|
||||
}
|
||||
if err := dao.DB.Create(kb).Error; err != nil {
|
||||
t.Fatalf("insert kb: %v", err)
|
||||
}
|
||||
|
||||
doc := &entity.Document{
|
||||
ID: "doc-1",
|
||||
KbID: "kb-1",
|
||||
ParserID: "naive",
|
||||
ParserConfig: entity.JSONMap{},
|
||||
SourceType: "local",
|
||||
Type: "pdf",
|
||||
CreatedBy: "user-1",
|
||||
Suffix: ".pdf",
|
||||
ProgressMsg: &queuedMsg,
|
||||
ProcessBeginAt: &queuedBeginAt,
|
||||
}
|
||||
if err := dao.DB.Create(doc).Error; err != nil {
|
||||
t.Fatalf("insert document: %v", err)
|
||||
}
|
||||
|
||||
task := &entity.Task{ID: taskID, DocID: doc.ID, TaskType: "graphrag"}
|
||||
if err := dao.DB.Create(task).Error; err != nil {
|
||||
t.Fatalf("insert task: %v", err)
|
||||
}
|
||||
|
||||
snapshot := &entity.Document{
|
||||
ID: doc.ID,
|
||||
ProgressMsg: &previousMsg,
|
||||
ProcessBeginAt: &previousBeginAt,
|
||||
}
|
||||
if err := cleanupFailedDatasetIndexTask(task.ID, snapshot, kb.ID, "graph"); err != nil {
|
||||
t.Fatalf("cleanup failed: %v", err)
|
||||
}
|
||||
|
||||
var persistedTask entity.Task
|
||||
err := dao.DB.Where("id = ?", task.ID).First(&persistedTask).Error
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
t.Fatalf("expected task to be deleted, got err=%v task=%#v", err, persistedTask)
|
||||
}
|
||||
|
||||
persistedDoc, err := dao.NewDocumentDAO().GetByID(doc.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("fetch document: %v", err)
|
||||
}
|
||||
if persistedDoc.ProgressMsg == nil || *persistedDoc.ProgressMsg != previousMsg {
|
||||
t.Fatalf("expected progress_msg %q, got %#v", previousMsg, persistedDoc.ProgressMsg)
|
||||
}
|
||||
if persistedDoc.ProcessBeginAt == nil || !persistedDoc.ProcessBeginAt.Equal(previousBeginAt) {
|
||||
t.Fatalf("expected process_begin_at %v, got %#v", previousBeginAt, persistedDoc.ProcessBeginAt)
|
||||
}
|
||||
|
||||
var persistedKB entity.Knowledgebase
|
||||
if err := dao.DB.Where("id = ?", kb.ID).First(&persistedKB).Error; err != nil {
|
||||
t.Fatalf("fetch kb: %v", err)
|
||||
}
|
||||
if persistedKB.GraphragTaskID != nil {
|
||||
t.Fatalf("expected graphrag_task_id to be cleared, got %#v", *persistedKB.GraphragTaskID)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user