diff --git a/internal/dao/document.go b/internal/dao/document.go index 5ac325f112..72f197452b 100644 --- a/internal/dao/document.go +++ b/internal/dao/document.go @@ -211,3 +211,9 @@ func (dao *DocumentDAO) GetParsingStatusByKBID(kbID string) (map[string]int64, e } return result, nil } + +func (dao *DocumentDAO) GetByNameAndKBID(name, kbID string) ([]*entity.Document, error) { + var docs []*entity.Document + err := DB.Where("name = ? AND kb_id = ?", name, kbID).Find(&docs).Error + return docs, err +} diff --git a/internal/handler/document.go b/internal/handler/document.go index 874542ec4d..e637ceac4f 100644 --- a/internal/handler/document.go +++ b/internal/handler/document.go @@ -59,6 +59,7 @@ type documentServiceIface interface { GetDocumentArtifact(filename string) (*service.ArtifactResponse, error) GetDocumentPreview(docID string) (*service.DocumentPreview, error) DownloadDocument(datasetID, docID string) (*service.DownloadDocumentResp, error) + UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error) ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error) IngestDocuments(datasetID, userID string, docIDs []string) ([]*service.ParseDocumentResponse, error) StopIngestionTasks(tasks []string, userID string) ([]*entity.IngestionTask, error) @@ -1171,3 +1172,53 @@ func (h *DocumentHandler) MetadataSummaryByDataset(c *gin.Context) { "data": gin.H{"summary": summary}, }) } + +func (h *DocumentHandler) UpdateDatasetDocument(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { + jsonError(c, common.CodeArgumentError, "dataset_id is required") + return + } + documentID := strings.TrimSpace(c.Param("document_id")) + if documentID == "" { + jsonError(c, common.CodeArgumentError, "document_id is required") + return + } + + body, err := c.GetRawData() + if err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + var raw map[string]json.RawMessage + if err := json.Unmarshal(body, &raw); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + present := make(map[string]bool, len(raw)) + for key := range raw { + present[key] = true + } + var req service.UpdateDatasetDocumentRequest + if err := json.Unmarshal(body, &req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + + data, code, err := h.documentService.UpdateDatasetDocument(user.ID, datasetID, documentID, &req, present) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "data": data, + }) +} diff --git a/internal/handler/document_test.go b/internal/handler/document_test.go index 6c2c5d314f..472ccb7ac4 100644 --- a/internal/handler/document_test.go +++ b/internal/handler/document_test.go @@ -46,6 +46,10 @@ type fakeDocumentService struct { metadataDocIDs []string } +func (f *fakeDocumentService) UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error) { + return nil, common.CodeSuccess, nil +} + func (f *fakeDocumentService) GetDocumentArtifact(filename string) (*service.ArtifactResponse, error) { if filename == "error.txt" { return nil, service.ErrArtifactNotFound diff --git a/internal/router/router.go b/internal/router/router.go index 70f35dd65c..2dffcce29f 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -273,6 +273,7 @@ func (r *Router) Setup(engine *gin.Engine) { // Dataset documents datasets.GET("/:dataset_id/documents", r.documentHandler.ListDocuments) datasets.GET("/:dataset_id/documents/:document_id", r.documentHandler.DownloadDocument) + datasets.PATCH("/:dataset_id/documents/:document_id", r.documentHandler.UpdateDatasetDocument) datasets.DELETE("/:dataset_id/documents", r.documentHandler.DeleteDocuments) // Dataset document chunk diff --git a/internal/service/document.go b/internal/service/document.go index d661df75b5..1ba84d09a5 100644 --- a/internal/service/document.go +++ b/internal/service/document.go @@ -21,22 +21,25 @@ import ( "encoding/json" "errors" "fmt" + "math" "os" "path/filepath" - "ragflow/internal/common" - "ragflow/internal/engine/redis" - "ragflow/internal/entity" - "ragflow/internal/storage" - "ragflow/internal/utility" "regexp" "sort" + "strconv" "strings" "time" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/engine" - + "ragflow/internal/engine/redis" + enginetypes "ragflow/internal/engine/types" + "ragflow/internal/entity" "ragflow/internal/server" + "ragflow/internal/storage" + "ragflow/internal/tokenizer" + "ragflow/internal/utility" "gorm.io/gorm" ) @@ -52,6 +55,7 @@ type DocumentService struct { metadataSvc *MetadataService taskDAO *dao.TaskDAO file2DocumentDAO *dao.File2DocumentDAO + fileDAO *dao.FileDAO } // NewDocumentService create document service @@ -67,6 +71,7 @@ func NewDocumentService() *DocumentService { metadataSvc: NewMetadataService(), taskDAO: dao.NewTaskDAO(), file2DocumentDAO: dao.NewFile2DocumentDAO(), + fileDAO: dao.NewFileDAO(), } } @@ -127,6 +132,50 @@ type ArtifactResponse struct { ForceAttachment bool } +type UpdateDatasetDocumentRequest struct { + Name *string `json:"name"` + ChunkMethod *string `json:"chunk_method"` + ParserID *string `json:"parser_id"` + ChunkCount *int64 `json:"chunk_count"` + TokenCount *int64 `json:"token_count"` + PipelineID *string `json:"pipeline_id"` + Enabled *int `json:"enabled"` + Progress *float64 `json:"progress"` + ParserConfig map[string]any `json:"parser_config"` + MetaFields map[string]any `json:"meta_fields"` +} + +// PATCH /api/v1/datasets/:dataset_id/documents/:document_id. +type UpdateDatasetDocumentResponse struct { + ID string `json:"id"` + Thumbnail *string `json:"thumbnail,omitempty"` + DatasetID string `json:"dataset_id"` + ChunkMethod string `json:"chunk_method"` + PipelineID *string `json:"pipeline_id,omitempty"` + ParserConfig map[string]interface{} `json:"parser_config"` + SourceType string `json:"source_type"` + Type string `json:"type"` + CreatedBy string `json:"created_by"` + Name *string `json:"name,omitempty"` + Location *string `json:"location,omitempty"` + Size int64 `json:"size"` + TokenCount int64 `json:"token_count"` + ChunkCount int64 `json:"chunk_count"` + Progress float64 `json:"progress"` + ProgressMsg *string `json:"progress_msg,omitempty"` + ProcessBeginAt *time.Time `json:"process_begin_at,omitempty"` + ProcessDuration float64 `json:"process_duration"` + ContentHash *string `json:"content_hash,omitempty"` + MetaFields map[string]interface{} `json:"meta_fields,omitempty"` + Suffix string `json:"suffix"` + Run string `json:"run"` + Status *string `json:"status,omitempty"` + CreateTime *int64 `json:"create_time,omitempty"` + CreateDate *time.Time `json:"create_date,omitempty"` + UpdateTime *int64 `json:"update_time,omitempty"` + UpdateDate *time.Time `json:"update_date,omitempty"` +} + var ( ErrArtifactInvalidFilename = errors.New("Invalid filename.") ErrArtifactInvalidFileType = errors.New("Invalid file type.") @@ -1543,3 +1592,498 @@ func isTimeString(s string) bool { matched, _ := regexp.MatchString(`^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$`, s) return matched } + +func (s *DocumentService) UpdateDatasetDocument(userID, datasetID, documentID string, req *UpdateDatasetDocumentRequest, present map[string]bool) (*UpdateDatasetDocumentResponse, common.ErrorCode, error) { + tenantID := userID + kb, err := s.kbDAO.GetByIDAndTenantID(datasetID, tenantID) + if err != nil { + if dao.IsNotFoundErr(err) { + return nil, common.CodeDataError, errors.New("You don't own the dataset.") + } + return nil, common.CodeDataError, errors.New("Can't find this dataset!") + } + + doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(documentID, datasetID) + if err != nil { + if dao.IsNotFoundErr(err) { + return nil, common.CodeDataError, errors.New("The dataset doesn't own the document.") + } + return nil, common.CodeServerError, err + } + + if code, err := s.validateDatasetDocumentUpdate(doc, req, present); err != nil { + return nil, code, err + } + + if present["meta_fields"] { + if err := s.replaceDocumentMetadata(documentID, req.MetaFields); err != nil { + return nil, common.CodeDataError, err + } + } + + if present["name"] && req.Name != nil && (doc.Name == nil || *req.Name != *doc.Name) { + if err := s.updateDocumentNameOnly(doc, kb.TenantID, *req.Name); err != nil { + return nil, common.CodeDataError, err + } + } + + if present["parser_config"] && req.ParserConfig != nil { + if err := s.updateDocumentParserConfig(doc.ID, req.ParserConfig); err != nil { + return nil, common.CodeDataError, err + } + } + + if req.PipelineID != nil && *req.PipelineID != "" { + if err := s.resetDocumentForReparse(doc, kb.TenantID, nil, req.PipelineID); err != nil { + return nil, common.CodeDataError, err + } + } else if present["parser_id"] && req.ParserID != nil && strings.TrimSpace(*req.ParserID) != "" { + parserID := strings.TrimSpace(*req.ParserID) + if err := s.resetDocumentForReparse(doc, kb.TenantID, &parserID, nil); err != nil { + return nil, common.CodeDataError, err + } + } else if req.ChunkMethod != nil && *req.ChunkMethod != "" { + if err := s.updateChunkMethod(doc, kb.TenantID, *req.ChunkMethod, req.ParserConfig, present["parser_config"]); err != nil { + return nil, common.CodeDataError, err + } + } + + if present["enabled"] && req.Enabled != nil { + if err := s.updateDocumentStatusOnly(doc, kb, *req.Enabled); err != nil { + return nil, common.CodeServerError, err + } + } + + updatedDoc, err := s.documentDAO.GetByID(doc.ID) + if err != nil { + if dao.IsNotFoundErr(err) { + return nil, common.CodeDataError, fmt.Errorf("Can not get document by id:%s", doc.ID) + } + return nil, common.CodeDataError, errors.New("Database operation failed") + } + + metaFields := map[string]interface{}{} + if s.docEngine != nil && s.metadataSvc != nil { + metaFields, _ = s.GetDocumentMetadataByID(updatedDoc.ID) + } + + return s.toUpdateDatasetDocumentResponse(updatedDoc, metaFields), common.CodeSuccess, nil +} + +var allowedDocumentChunkMethods = map[string]struct{}{ + "naive": {}, + "manual": {}, + "qa": {}, + "table": {}, + "paper": {}, + "book": {}, + "laws": {}, + "presentation": {}, + "picture": {}, + "one": {}, + "knowledge_graph": {}, + "email": {}, + "tag": {}, +} + +func (s *DocumentService) validateDatasetDocumentUpdate(doc *entity.Document, req *UpdateDatasetDocumentRequest, present map[string]bool) (common.ErrorCode, error) { + if req == nil { + return common.CodeDataError, errors.New("Invalid request payload") + } + if present["chunk_count"] && req.ChunkCount != nil && *req.ChunkCount != 0 && *req.ChunkCount != doc.ChunkNum { + return common.CodeDataError, errors.New("Can't change `chunk_count`.") + } + if present["token_count"] && req.TokenCount != nil && *req.TokenCount != 0 && *req.TokenCount != doc.TokenNum { + return common.CodeDataError, errors.New("Can't change `token_count`.") + } + if present["progress"] && req.Progress != nil && *req.Progress != 0 && math.Abs(*req.Progress-doc.Progress) > 1e-9 { + return common.CodeDataError, errors.New("Can't change `progress`.") + } + + if present["enabled"] { + if req.Enabled == nil || (*req.Enabled != 0 && *req.Enabled != 1) { + return common.CodeDataError, errors.New("`enabled` value invalid, only accept 0 or 1") + } + } + + if present["chunk_method"] { + if req.ChunkMethod == nil || strings.TrimSpace(*req.ChunkMethod) == "" { + return common.CodeDataError, errors.New("`chunk_method` (empty string) is not valid") + } + chunkMethod := strings.TrimSpace(*req.ChunkMethod) + if _, ok := allowedDocumentChunkMethods[chunkMethod]; !ok { + return common.CodeDataError, fmt.Errorf("`chunk_method` %s doesn't exist", chunkMethod) + } + if doc.Type == "visual" || isPresentationFile(doc.Name) { + return common.CodeDataError, errors.New("Not supported yet!") + } + } + if present["parser_id"] && req.ParserID != nil { + parserID := strings.TrimSpace(*req.ParserID) + if (doc.Type == "visual" && parserID != "picture") || (isPresentationFile(doc.Name) && parserID != "presentation") { + return common.CodeDataError, errors.New("Not supported yet!") + } + } + if present["name"] && req.Name != nil { + if err := s.validateDocumentName(doc, *req.Name); err != nil { + return common.CodeDataError, err + } + } + + if present["meta_fields"] { + if err := validateMetaFields(req.MetaFields); err != nil { + return common.CodeDataError, err + } + } + + return common.CodeSuccess, nil +} + +func (s *DocumentService) validateDocumentName(doc *entity.Document, newName string) error { + if strings.TrimSpace(newName) == "" { + return errors.New("File name can't be empty.") + } + if len([]byte(newName)) > 255 { + return errors.New("File name must be 255 bytes or less.") + } + + oldName := "" + if doc.Name != nil { + oldName = *doc.Name + } + + if strings.ToLower(filepath.Ext(newName)) != strings.ToLower(filepath.Ext(oldName)) { + return errors.New("The extension of file can't be changed") + } + + docs, err := s.documentDAO.GetByNameAndKBID(newName, doc.KbID) + if err != nil { + return err + } + for _, d := range docs { + if d.ID != doc.ID && d.Name != nil && *d.Name == newName { + return errors.New("Duplicated document name in the same dataset.") + } + } + + return nil +} + +func isPresentationFile(name *string) bool { + if name == nil { + return false + } + ext := strings.ToLower(filepath.Ext(*name)) + return ext == ".ppt" || ext == ".pptx" || ext == ".pages" +} + +func validateMetaFields(meta map[string]any) error { + if meta == nil { + return nil + } + + for _, v := range meta { + switch typed := v.(type) { + case string, float64, int, int64, float32: + continue + case []any: + for _, item := range typed { + switch item.(type) { + case string, float64, int, int64, float32: + continue + default: + return fmt.Errorf("The type is not supported in list: %v", typed) + } + } + default: + return fmt.Errorf("The type is not supported: %v", v) + } + } + + return nil +} + +func (s *DocumentService) replaceDocumentMetadata(docID string, meta map[string]any) error { + if s.docEngine == nil || s.metadataSvc == nil { + return nil + } + if err := s.DeleteDocumentAllMetadata(docID); err != nil { + return err + } + return s.SetDocumentMetadata(docID, map[string]interface{}(meta)) +} + +func (s *DocumentService) updateDocumentNameOnly(doc *entity.Document, tenantID, newName string) error { + if err := s.documentDAO.UpdateByID(doc.ID, map[string]interface{}{"name": newName}); err != nil { + return errors.New("Database error (Document rename)!") + } + + mappings, err := s.file2DocumentDAO.GetByDocumentID(doc.ID) + if err == nil && len(mappings) > 0 && mappings[0].FileID != nil && s.fileDAO != nil { + _ = s.fileDAO.UpdateByID(*mappings[0].FileID, map[string]interface{}{"name": newName}) + } + + if s.docEngine == nil { + return nil + } + + titleTks, _ := tokenizer.Tokenize(newName) + titleSmTks, _ := tokenizer.FineGrainedTokenize(titleTks) + indexName := fmt.Sprintf("ragflow_%s", tenantID) + return s.docEngine.UpdateChunks( + context.Background(), + map[string]interface{}{"doc_id": doc.ID}, + map[string]interface{}{ + "docnm_kwd": newName, + "title_tks": titleTks, + "title_sm_tks": titleSmTks, + }, + indexName, + doc.KbID, + ) +} + +func (s *DocumentService) updateDocumentParserConfig(documentID string, config map[string]any) error { + if len(config) == 0 { + return nil + } + + doc, err := s.documentDAO.GetByID(documentID) + if err != nil { + return fmt.Errorf("Document(%s) not found.", documentID) + } + + merged := common.DeepMergeMaps(map[string]interface{}(doc.ParserConfig), map[string]interface{}(config)) + if _, ok := config["raptor"]; !ok { + delete(merged, "raptor") + } + + return s.documentDAO.UpdateByID(documentID, map[string]interface{}{ + "parser_config": entity.JSONMap(merged), + }) +} + +func (s *DocumentService) resetDocumentForReparse(doc *entity.Document, tenantID string, parserID *string, pipelineID *string) error { + progressMsg := "" + run := string(entity.TaskStatusUnstart) + updates := map[string]interface{}{ + "progress": 0, + "progress_msg": progressMsg, + "run": run, + } + if parserID != nil { + updates["parser_id"] = *parserID + } + if pipelineID != nil { + updates["pipeline_id"] = *pipelineID + } + + if err := s.documentDAO.UpdateByID(doc.ID, updates); err != nil { + return errors.New("Document not found!") + } + + if doc.TokenNum > 0 { + decremented, err := s.decrementDocumentAndKBCountersForReparse(doc) + if err != nil { + return errors.New("Document not found!") + } + if !decremented { + return nil + } + if s.docEngine != nil { + indexName := fmt.Sprintf("ragflow_%s", tenantID) + s.deleteChunkImages(doc, indexName) + if _, err := s.docEngine.DeleteChunks(context.Background(), map[string]interface{}{"doc_id": doc.ID}, indexName, doc.KbID); err != nil { + return err + } + } + } + + return nil +} + +func (s *DocumentService) deleteChunkImages(doc *entity.Document, indexName string) { + if s.docEngine == nil { + return + } + storageImpl := storage.GetStorageFactory().GetStorage() + if storageImpl == nil { + return + } + + const pageSize = 1000 + for offset := 0; ; offset += pageSize { + result, err := s.docEngine.Search(context.Background(), &enginetypes.SearchRequest{ + IndexNames: []string{indexName}, + KbIDs: []string{doc.KbID}, + Offset: offset, + Limit: pageSize, + SelectFields: []string{"id", "img_id"}, + Filter: map[string]interface{}{"doc_id": doc.ID}, + MatchExprs: nil, + OrderBy: nil, + RankFeature: nil, + }) + if err != nil || result == nil || len(result.Chunks) == 0 { + return + } + for _, chunk := range result.Chunks { + imageKey, ok := chunkImageStorageKey(doc.KbID, chunk) + if !ok { + continue + } + if storageImpl.ObjExist(doc.KbID, imageKey) { + _ = storageImpl.Remove(doc.KbID, imageKey) + } + } + } +} + +func chunkImageStorageKey(defaultBucket string, chunk map[string]interface{}) (string, bool) { + imgID := firstStringField(chunk, "img_id") + if imgID != "" { + prefix := defaultBucket + "-" + if strings.HasPrefix(imgID, prefix) && len(imgID) > len(prefix) { + return strings.TrimPrefix(imgID, prefix), true + } + return imgID, true + } + + chunkID := firstStringField(chunk, "id", "_id") + if chunkID == "" { + return "", false + } + return chunkID, true +} + +func firstStringField(m map[string]interface{}, keys ...string) string { + for _, key := range keys { + if value, ok := m[key]; ok { + if s, ok := value.(string); ok { + return s + } + } + } + return "" +} + +func (s *DocumentService) decrementDocumentAndKBCountersForReparse(doc *entity.Document) (bool, error) { + decremented := false + err := dao.DB.Transaction(func(tx *gorm.DB) error { + result := tx.Model(&entity.Document{}). + Where("id = ? AND kb_id = ? AND token_num = ? AND chunk_num = ?", doc.ID, doc.KbID, doc.TokenNum, doc.ChunkNum). + Updates(map[string]interface{}{ + "token_num": gorm.Expr("token_num - ?", doc.TokenNum), + "chunk_num": gorm.Expr("chunk_num - ?", doc.ChunkNum), + "process_duration": gorm.Expr("process_duration - ?", doc.ProcessDuration), + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return nil + } + decremented = true + + return tx.Model(&entity.Knowledgebase{}). + Where("id = ?", doc.KbID). + Updates(map[string]interface{}{ + "token_num": gorm.Expr("token_num - ?", doc.TokenNum), + "chunk_num": gorm.Expr("chunk_num - ?", doc.ChunkNum), + }).Error + }) + return decremented, err +} + +func (s *DocumentService) updateChunkMethod(doc *entity.Document, tenantID string, chunkMethod string, parserConfig map[string]any, hasParserConfig bool) error { + chunkMethod = strings.TrimSpace(chunkMethod) + if !strings.EqualFold(doc.ParserID, chunkMethod) { + if err := s.resetDocumentForReparse(doc, tenantID, &chunkMethod, nil); err != nil { + return err + } + } + if !hasParserConfig { + defaultConfig := common.GetParserConfig(chunkMethod, nil) + if err := s.updateDocumentParserConfig(doc.ID, defaultConfig); err != nil { + return err + } + } + return nil +} + +func (s *DocumentService) updateDocumentStatusOnly(doc *entity.Document, kb *entity.Knowledgebase, status int) error { + statusStr := strconv.Itoa(status) + if doc.Status != nil && *doc.Status == statusStr { + return nil + } + + if err := s.documentDAO.UpdateByID(doc.ID, map[string]interface{}{"status": statusStr}); err != nil { + return errors.New("Database error (Document update)!") + } + + if s.docEngine == nil { + return nil + } + + indexName := fmt.Sprintf("ragflow_%s", kb.TenantID) + return s.docEngine.UpdateChunks( + context.Background(), + map[string]interface{}{"doc_id": doc.ID}, + map[string]interface{}{"available_int": status}, + indexName, + doc.KbID, + ) +} + +func (s *DocumentService) toUpdateDatasetDocumentResponse(doc *entity.Document, metaFields map[string]interface{}) *UpdateDatasetDocumentResponse { + if metaFields == nil { + metaFields = map[string]interface{}{} + } + return &UpdateDatasetDocumentResponse{ + ID: doc.ID, + Thumbnail: doc.Thumbnail, + DatasetID: doc.KbID, + ChunkMethod: doc.ParserID, + PipelineID: doc.PipelineID, + ParserConfig: map[string]interface{}(doc.ParserConfig), + SourceType: doc.SourceType, + Type: doc.Type, + CreatedBy: doc.CreatedBy, + Name: doc.Name, + Location: doc.Location, + Size: doc.Size, + TokenCount: doc.TokenNum, + ChunkCount: doc.ChunkNum, + Progress: doc.Progress, + ProgressMsg: doc.ProgressMsg, + ProcessBeginAt: doc.ProcessBeginAt, + ProcessDuration: doc.ProcessDuration, + ContentHash: doc.ContentHash, + MetaFields: metaFields, + Suffix: doc.Suffix, + Run: mapDocumentRunStatus(doc.Run), + Status: doc.Status, + CreateTime: doc.CreateTime, + CreateDate: doc.CreateDate, + UpdateTime: doc.UpdateTime, + UpdateDate: doc.UpdateDate, + } +} + +func mapDocumentRunStatus(run *string) string { + if run == nil { + return "UNSTART" + } + switch *run { + case string(entity.TaskStatusRunning): + return "RUNNING" + case string(entity.TaskStatusCancel): + return "CANCEL" + case string(entity.TaskStatusDone): + return "DONE" + case string(entity.TaskStatusFail): + return "FAIL" + default: + return "UNSTART" + } +} diff --git a/internal/service/document_test.go b/internal/service/document_test.go index 59de93cf8a..93d75d38ee 100644 --- a/internal/service/document_test.go +++ b/internal/service/document_test.go @@ -17,15 +17,34 @@ package service import ( + "context" + "errors" + "path/filepath" "testing" "github.com/glebarez/sqlite" "gorm.io/gorm" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" ) +type failingDeleteMetadataEngine struct { + fakeChatDocEngine + deleteErr error + updateCalled bool +} + +func (f *failingDeleteMetadataEngine) DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) { + return 0, f.deleteErr +} + +func (f *failingDeleteMetadataEngine) UpdateMetadata(ctx context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error { + f.updateCalled = true + return nil +} + // setupServiceTestDB initializes an in-memory SQLite database for service tests. func setupServiceTestDB(t *testing.T) *gorm.DB { t.Helper() @@ -72,6 +91,7 @@ func testDocumentService(t *testing.T) *DocumentService { kbDAO: dao.NewKnowledgebaseDAO(), taskDAO: dao.NewTaskDAO(), file2DocumentDAO: dao.NewFile2DocumentDAO(), + fileDAO: dao.NewFileDAO(), docEngine: nil, metadataSvc: nil, // nil engine → metadata ops skipped } @@ -83,16 +103,16 @@ func sptr(s string) *string { return &s } func insertTestKB(t *testing.T, id, tenantID string, docNum, tokenNum, chunkNum int64) { t.Helper() kb := &entity.Knowledgebase{ - ID: id, - TenantID: tenantID, - Name: "test-kb", - EmbdID: "embd-1", - CreatedBy: "user-1", + ID: id, + TenantID: tenantID, + Name: "test-kb", + EmbdID: "embd-1", + CreatedBy: "user-1", Permission: string(entity.TenantPermissionTeam), - DocNum: docNum, - TokenNum: tokenNum, - ChunkNum: chunkNum, - Status: sptr(string(entity.StatusValid)), + DocNum: docNum, + TokenNum: tokenNum, + ChunkNum: chunkNum, + Status: sptr(string(entity.StatusValid)), } if err := dao.DB.Create(kb).Error; err != nil { t.Fatalf("insert test kb: %v", err) @@ -919,3 +939,339 @@ func TestDownloadDocument_WrongDataset(t *testing.T) { t.Error("expected error for wrong dataset") } } + +func TestUpdateDatasetDocumentRejectsNonOwner(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 0, 0) + insertTestDoc(t, "doc-1", "kb-1", 0, 0) + + svc := testDocumentService(t) + _, code, err := svc.UpdateDatasetDocument("tenant-2", "kb-1", "doc-1", &UpdateDatasetDocumentRequest{}, map[string]bool{}) + if err == nil { + t.Fatal("expected ownership error") + } + if code != common.CodeDataError { + t.Fatalf("code = %v, want %v", code, common.CodeDataError) + } + if err.Error() != "You don't own the dataset." { + t.Fatalf("err = %q", err.Error()) + } +} + +func TestUpdateDatasetDocumentRejectsCounterMutation(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 10, 5) + insertTestDoc(t, "doc-1", "kb-1", 10, 5) + + chunkCount := int64(6) + svc := testDocumentService(t) + _, code, err := svc.UpdateDatasetDocument("tenant-1", "kb-1", "doc-1", &UpdateDatasetDocumentRequest{ + ChunkCount: &chunkCount, + }, map[string]bool{"chunk_count": true}) + if err == nil { + t.Fatal("expected chunk_count mutation error") + } + if code != common.CodeDataError { + t.Fatalf("code = %v, want %v", code, common.CodeDataError) + } + if err.Error() != "Can't change `chunk_count`." { + t.Fatalf("err = %q", err.Error()) + } +} + +func TestUpdateDatasetDocumentAllowsZeroCounterLikePythonTruthyCheck(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 10, 5) + insertTestDoc(t, "doc-1", "kb-1", 10, 5) + + chunkCount := int64(0) + svc := testDocumentService(t) + _, code, err := svc.UpdateDatasetDocument("tenant-1", "kb-1", "doc-1", &UpdateDatasetDocumentRequest{ + ChunkCount: &chunkCount, + }, map[string]bool{"chunk_count": true}) + if err != nil { + t.Fatalf("UpdateDatasetDocument failed: code=%v err=%v", code, err) + } +} + +func TestUpdateDatasetDocumentRejectsUnsupportedParserIDForVisualDoc(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 0, 0) + insertNamedTestDoc(t, "doc-1", "kb-1", "image.png", 0, 0) + if err := dao.DB.Model(&entity.Document{}).Where("id = ?", "doc-1").Update("type", "visual").Error; err != nil { + t.Fatalf("update doc type: %v", err) + } + + parserID := "naive" + svc := testDocumentService(t) + _, code, err := svc.UpdateDatasetDocument("tenant-1", "kb-1", "doc-1", &UpdateDatasetDocumentRequest{ + ParserID: &parserID, + }, map[string]bool{"parser_id": true}) + if err == nil { + t.Fatal("expected parser_id visual error") + } + if code != common.CodeDataError { + t.Fatalf("code = %v, want %v", code, common.CodeDataError) + } + if err.Error() != "Not supported yet!" { + t.Fatalf("err = %q", err.Error()) + } +} + +func TestUpdateDatasetDocumentRenameUpdatesDocumentAndFile(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 0, 0) + insertNamedTestDoc(t, "doc-1", "kb-1", "old.pdf", 0, 0) + insertTestFile(t, "file-1", "folder-1", "old.pdf", sptr("old.pdf")) + insertTestFile2Document(t, "f2d-1", "file-1", "doc-1") + + newName := "new.pdf" + svc := testDocumentService(t) + resp, code, err := svc.UpdateDatasetDocument("tenant-1", "kb-1", "doc-1", &UpdateDatasetDocumentRequest{ + Name: &newName, + }, map[string]bool{"name": true}) + if err != nil { + t.Fatalf("UpdateDatasetDocument failed: code=%v err=%v", code, err) + } + if resp == nil || resp.Name == nil || *resp.Name != newName { + t.Fatalf("response name = %#v, want %q", resp, newName) + } + + doc, _ := dao.NewDocumentDAO().GetByID("doc-1") + if doc.Name == nil || *doc.Name != newName { + t.Fatalf("document name = %v, want %q", doc.Name, newName) + } + file, _ := dao.NewFileDAO().GetByID("file-1") + if file.Name != newName { + t.Fatalf("file name = %q, want %q", file.Name, newName) + } +} + +func TestUpdateDatasetDocumentChunkMethodResetsForReparse(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 10, 5) + insertNamedTestDoc(t, "doc-1", "kb-1", "doc.txt", 10, 5) + + chunkMethod := "manual" + svc := testDocumentService(t) + resp, code, err := svc.UpdateDatasetDocument("tenant-1", "kb-1", "doc-1", &UpdateDatasetDocumentRequest{ + ChunkMethod: &chunkMethod, + }, map[string]bool{"chunk_method": true}) + if err != nil { + t.Fatalf("UpdateDatasetDocument failed: code=%v err=%v", code, err) + } + if resp.ChunkMethod != chunkMethod || resp.Run != "UNSTART" || resp.TokenCount != 0 || resp.ChunkCount != 0 { + t.Fatalf("response = %+v, want method=%s run=UNSTART counts=0", resp, chunkMethod) + } + + doc, _ := dao.NewDocumentDAO().GetByID("doc-1") + if doc.ParserID != chunkMethod { + t.Fatalf("parser_id = %q, want %q", doc.ParserID, chunkMethod) + } + if doc.TokenNum != 0 || doc.ChunkNum != 0 || doc.Progress != 0 { + t.Fatalf("doc counters/progress = token:%d chunk:%d progress:%f, want zero", doc.TokenNum, doc.ChunkNum, doc.Progress) + } + kb, _ := dao.NewKnowledgebaseDAO().GetByID("kb-1") + if kb.TokenNum != 0 || kb.ChunkNum != 0 { + t.Fatalf("kb counters = token:%d chunk:%d, want zero", kb.TokenNum, kb.ChunkNum) + } +} + +func TestUpdateDatasetDocumentParserIDResetsForReparse(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 10, 5) + insertNamedTestDoc(t, "doc-1", "kb-1", "doc.txt", 10, 5) + + parserID := "manual" + svc := testDocumentService(t) + resp, code, err := svc.UpdateDatasetDocument("tenant-1", "kb-1", "doc-1", &UpdateDatasetDocumentRequest{ + ParserID: &parserID, + }, map[string]bool{"parser_id": true}) + if err != nil { + t.Fatalf("UpdateDatasetDocument failed: code=%v err=%v", code, err) + } + if resp.ChunkMethod != parserID || resp.Run != "UNSTART" || resp.TokenCount != 0 || resp.ChunkCount != 0 { + t.Fatalf("response = %+v, want parser_id=%s run=UNSTART counts=0", resp, parserID) + } + + doc, _ := dao.NewDocumentDAO().GetByID("doc-1") + if doc.ParserID != parserID { + t.Fatalf("parser_id = %q, want %q", doc.ParserID, parserID) + } + if doc.TokenNum != 0 || doc.ChunkNum != 0 { + t.Fatalf("doc counters = token:%d chunk:%d, want zero", doc.TokenNum, doc.ChunkNum) + } + kb, _ := dao.NewKnowledgebaseDAO().GetByID("kb-1") + if kb.TokenNum != 0 || kb.ChunkNum != 0 { + t.Fatalf("kb counters = token:%d chunk:%d, want zero", kb.TokenNum, kb.ChunkNum) + } +} + +func TestResetDocumentForReparseSkipsSecondCounterDecrement(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 10, 5) + insertNamedTestDoc(t, "doc-1", "kb-1", "doc.txt", 10, 5) + + staleDoc, err := dao.NewDocumentDAO().GetByID("doc-1") + if err != nil { + t.Fatalf("get doc: %v", err) + } + + svc := testDocumentService(t) + parserID := "manual" + if err := svc.resetDocumentForReparse(staleDoc, "tenant-1", &parserID, nil); err != nil { + t.Fatalf("first resetDocumentForReparse failed: %v", err) + } + if err := svc.resetDocumentForReparse(staleDoc, "tenant-1", &parserID, nil); err != nil { + t.Fatalf("second resetDocumentForReparse failed: %v", err) + } + + doc, _ := dao.NewDocumentDAO().GetByID("doc-1") + if doc.TokenNum != 0 || doc.ChunkNum != 0 { + t.Fatalf("doc counters = token:%d chunk:%d, want zero", doc.TokenNum, doc.ChunkNum) + } + kb, _ := dao.NewKnowledgebaseDAO().GetByID("kb-1") + if kb.TokenNum != 0 || kb.ChunkNum != 0 { + t.Fatalf("kb counters = token:%d chunk:%d, want zero after duplicate reset", kb.TokenNum, kb.ChunkNum) + } +} + +func TestUpdateDatasetDocumentPropagatesMetadataDeleteFailure(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 0, 0) + insertNamedTestDoc(t, "doc-1", "kb-1", "doc.txt", 0, 0) + + engine := &failingDeleteMetadataEngine{deleteErr: errors.New("delete failed")} + svc := testDocumentService(t) + svc.docEngine = engine + svc.metadataSvc = &MetadataService{} + + _, code, err := svc.UpdateDatasetDocument("tenant-1", "kb-1", "doc-1", &UpdateDatasetDocumentRequest{ + MetaFields: map[string]any{"new": "value"}, + }, map[string]bool{"meta_fields": true}) + if err == nil { + t.Fatal("expected metadata delete error") + } + if code != common.CodeDataError { + t.Fatalf("code = %v, want %v", code, common.CodeDataError) + } + if err.Error() != "failed to delete document metadata: delete failed" { + t.Fatalf("err = %q", err.Error()) + } + if engine.updateCalled { + t.Fatal("metadata update should not run after delete failure") + } +} + +func TestChunkImageStorageKeyUsesImgIDWithDatasetPrefix(t *testing.T) { + key, ok := chunkImageStorageKey("kb-1", map[string]interface{}{ + "id": "chunk-1", + "img_id": "kb-1-image-001", + }) + if !ok { + t.Fatal("expected image storage key") + } + if key != "image-001" { + t.Fatalf("key = %q, want %q", key, "image-001") + } +} + +func TestChunkImageStorageKeyHandlesHyphenatedDatasetID(t *testing.T) { + key, ok := chunkImageStorageKey("dataset-abc-123", map[string]interface{}{ + "id": "chunk-1", + "img_id": "dataset-abc-123-page-1-image", + }) + if !ok { + t.Fatal("expected image storage key") + } + if key != "page-1-image" { + t.Fatalf("key = %q, want %q", key, "page-1-image") + } +} + +func TestChunkImageStorageKeyFallsBackToChunkID(t *testing.T) { + key, ok := chunkImageStorageKey("kb-1", map[string]interface{}{ + "_id": "chunk-fallback", + }) + if !ok { + t.Fatal("expected fallback storage key") + } + if key != "chunk-fallback" { + t.Fatalf("key = %q, want %q", key, "chunk-fallback") + } +} + +func TestUpdateDatasetDocumentPipelineIDTakesPrecedenceOverChunkMethod(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 10, 5) + insertNamedTestDoc(t, "doc-1", "kb-1", "doc.txt", 10, 5) + + pipelineID := "1234567890abcdef1234567890abcdef" + chunkMethod := "manual" + svc := testDocumentService(t) + resp, code, err := svc.UpdateDatasetDocument("tenant-1", "kb-1", "doc-1", &UpdateDatasetDocumentRequest{ + PipelineID: &pipelineID, + ChunkMethod: &chunkMethod, + }, map[string]bool{"pipeline_id": true, "chunk_method": true}) + if err != nil { + t.Fatalf("UpdateDatasetDocument failed: code=%v err=%v", code, err) + } + if resp.PipelineID == nil || *resp.PipelineID != pipelineID { + t.Fatalf("pipeline_id = %v, want %q", resp.PipelineID, pipelineID) + } + if resp.ChunkMethod != "naive" { + t.Fatalf("chunk_method = %q, want original naive", resp.ChunkMethod) + } +} + +func TestUpdateDatasetDocumentEnabledUpdatesStatus(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 0, 0) + insertTestDoc(t, "doc-1", "kb-1", 0, 0) + + enabled := 0 + svc := testDocumentService(t) + resp, code, err := svc.UpdateDatasetDocument("tenant-1", "kb-1", "doc-1", &UpdateDatasetDocumentRequest{ + Enabled: &enabled, + }, map[string]bool{"enabled": true}) + if err != nil { + t.Fatalf("UpdateDatasetDocument failed: code=%v err=%v", code, err) + } + if resp.Status == nil || *resp.Status != "0" { + t.Fatalf("status = %v, want 0", resp.Status) + } +} + +func insertNamedTestDoc(t *testing.T, id, kbID, name string, tokenNum, chunkNum int64) { + t.Helper() + doc := &entity.Document{ + ID: id, + KbID: kbID, + ParserID: "naive", + ParserConfig: entity.JSONMap{}, + TokenNum: tokenNum, + ChunkNum: chunkNum, + Progress: 0.75, + Name: sptr(name), + Type: "doc", + SourceType: "local", + CreatedBy: "tenant-1", + Suffix: filepath.Ext(name), + Status: sptr("1"), + Run: sptr(string(entity.TaskStatusDone)), + } + if err := dao.DB.Create(doc).Error; err != nil { + t.Fatalf("insert named test doc: %v", err) + } +}