feat[go]: datasets/<dataset_id>/chunks DELETE (#16185)

### What problem does this PR solve?

As the title says, this PR adds two routes to the Go server (screenshots below):

`documents.POST("/ingest", r.documentHandler.Ingest)`:

---

<img width="3750" height="2039" alt="image"
src="https://github.com/user-attachments/assets/533c1c3d-af3e-47e6-9f51-a278539b7066"
/>

`datasets.DELETE("/:dataset_id/chunks", r.chunkHandler.StopParsing)`

---

<img width="3621" height="2040" alt="image"
src="https://github.com/user-attachments/assets/022adcdb-1e47-4883-9611-1a695c34007d"
/>


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Haruko386
2026-06-24 19:43:18 +08:00
committed by GitHub
parent c2665d4ab1
commit dd46ece3bc
8 changed files with 1175 additions and 4 deletions

View File

@@ -38,6 +38,7 @@ type chunkService interface {
UpdateChunk(req *service.UpdateChunkRequest, userID string) error
RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error)
Parse(userID, datasetID string, req *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error)
StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error)
}
// ChunkHandler chunk handler
@@ -224,8 +225,8 @@ func (h *ChunkHandler) Parse(c *gin.Context) {
})
return
}
datasetID := strings.TrimSpace(c.Param("dataset_id"))
if datasetID == "" {
datasetId := strings.TrimSpace(c.Param("dataset_id"))
if datasetId == "" {
c.JSON(http.StatusBadRequest, gin.H{
"code": common.CodeBadRequest,
"message": "dataset_id is required",
@@ -243,7 +244,7 @@ func (h *ChunkHandler) Parse(c *gin.Context) {
return
}
data, code, err := h.chunkService.Parse(userID, datasetID, &req)
data, code, err := h.chunkService.Parse(userID, datasetId, &req)
if code != common.CodeSuccess {
c.JSON(http.StatusOK, gin.H{
"code": code,
@@ -353,6 +354,59 @@ func parseAvailableQuery(raw string) (int, bool, error) {
}
}
// StopParsing handles DELETE /datasets/:dataset_id/chunks and asks the chunk
// service to stop parsing the documents listed in the JSON request body.
//
// The body must be a JSON object with a non-empty `document_ids` array.
// Validation failures are reported through jsonError with CodeDataError. When
// the service returns an error, the response carries the service's error
// code, the error text as message, and whatever data the service attached.
// On success the message is "success" unless the service supplied its own.
func (h *ChunkHandler) StopParsing(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	// Trim the path parameter the same way the Parse handler does, so a
	// whitespace-only segment is rejected here instead of reaching the service.
	datasetID := strings.TrimSpace(c.Param("dataset_id"))
	if datasetID == "" {
		jsonError(c, common.CodeDataError, "dataset_id is required")
		return
	}

	var req service.StopParsingRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		jsonError(c, common.CodeDataError, err.Error())
		return
	}
	if len(req.DocumentIDs) == 0 {
		jsonError(c, common.CodeDataError, "`document_ids` is required")
		return
	}

	resp, code, err := h.chunkService.StopParsing(user.ID, datasetID, req)

	// The service may return a nil response on both the success and the
	// failure path, so the payload is extracted once, defensively.
	var data interface{}
	message := "success"
	if resp != nil {
		data = resp.Data
		if resp.Message != "" {
			message = resp.Message
		}
	}

	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"code":    code,
			"data":    data,
			"message": err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code":    common.CodeSuccess,
		"data":    data,
		"message": message,
	})
}
// List retrieves chunks for a document.
// @Summary List Chunks
// @Description Retrieve paginated chunks for a document with optional filtering.

View File

@@ -23,6 +23,7 @@ type mockChunkSvc struct {
listFn func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error)
switchChunksFn func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error
updateChunkFn func(req *service.UpdateChunkRequest, userID string) error
stopParsingFn func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error)
}
func (m *mockChunkSvc) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
@@ -58,6 +59,12 @@ func (m *mockChunkSvc) UpdateChunk(req *service.UpdateChunkRequest, userID strin
// RemoveChunks is a stub that only exists so mockChunkSvc has the method;
// calling it panics.
func (m *mockChunkSvc) RemoveChunks(*service.RemoveChunksRequest, string) (int64, error) {
panic("not implemented")
}
// StopParsing forwards to stopParsingFn when a test has installed one and
// panics otherwise, so an unexpected call is impossible to miss.
func (m *mockChunkSvc) StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) {
	if m.stopParsingFn == nil {
		panic("not implemented")
	}
	return m.stopParsingFn(userID, datasetID, req)
}
// Parse is a stub that only exists so mockChunkSvc has the method; calling it
// panics.
func (m *mockChunkSvc) Parse(string, string, *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error) {
panic("not implemented")
}
@@ -84,6 +91,18 @@ func setupChunkRetrievalTestNoAuth() *gin.Engine {
return r
}
// setupChunkStopParsingTest builds a test-mode gin engine that serves the
// stop-parsing route backed by a fresh mockChunkSvc. A middleware stores an
// entity.User with the given ID under the "user" context key for every
// request. The mock is returned so the caller can install stopParsingFn.
func setupChunkStopParsingTest(userID string) (*gin.Engine, *mockChunkSvc) {
	svc := &mockChunkSvc{}
	handler := &ChunkHandler{chunkService: svc}

	gin.SetMode(gin.TestMode)
	router := gin.New()

	injectUser := func(c *gin.Context) {
		c.Set("user", &entity.User{ID: userID})
	}
	router.Use(injectUser)
	router.DELETE("/api/v1/datasets/:dataset_id/chunks", handler.StopParsing)

	return router, svc
}
func setupChunkHandlerWithUser(userID string, mock *mockChunkSvc) (*gin.Engine, *ChunkHandler) {
h := &ChunkHandler{chunkService: mock}
gin.SetMode(gin.TestMode)
@@ -254,6 +273,103 @@ func TestChunkRetrieval_EmptyQuestion(t *testing.T) {
}
}
// TestChunkStopParsing_Success checks that the handler passes the user ID, the
// dataset ID from the path and the document IDs from the body to the service,
// and that a nil service response is rendered as code 0 with the default
// "success" message.
func TestChunkStopParsing_Success(t *testing.T) {
	r, mock := setupChunkStopParsingTest("user1")
	mock.stopParsingFn = func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) {
		if userID != "user1" {
			t.Fatalf("expected user1, got %q", userID)
		}
		if datasetID != "kb1" {
			t.Fatalf("expected kb1, got %q", datasetID)
		}
		if len(req.DocumentIDs) != 2 || req.DocumentIDs[0] != "doc1" || req.DocumentIDs[1] != "doc2" {
			t.Fatalf("unexpected document IDs: %#v", req.DocumentIDs)
		}
		return nil, common.CodeSuccess, nil
	}

	// httptest.NewRequest builds a server-side request and has no error
	// return to discard, unlike http.NewRequest.
	req := httptest.NewRequest(http.MethodDelete, "/api/v1/datasets/kb1/chunks", strings.NewReader(`{"document_ids":["doc1","doc2"]}`))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if resp["code"] != float64(common.CodeSuccess) {
		t.Fatalf("expected code 0, got %v: %s", resp["code"], w.Body.String())
	}
	if resp["message"] != "success" {
		t.Fatalf("expected success message, got %v", resp["message"])
	}
}
// TestChunkStopParsingRouteRequiresDocumentIDs checks that a body without
// `document_ids` is rejected by the handler with CodeDataError and that the
// service is never invoked.
func TestChunkStopParsingRouteRequiresDocumentIDs(t *testing.T) {
	r, mock := setupChunkStopParsingTest("user1")
	mock.stopParsingFn = func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) {
		t.Fatal("service should not be called when document_ids is missing")
		return nil, common.CodeSuccess, nil
	}

	// httptest.NewRequest builds a server-side request and has no error
	// return to discard, unlike http.NewRequest.
	req := httptest.NewRequest(http.MethodDelete, "/api/v1/datasets/kb1/chunks", strings.NewReader(`{}`))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	// Validation errors are still delivered with HTTP 200; the failure is in
	// the JSON "code" field.
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if resp["code"] != float64(common.CodeDataError) {
		t.Fatalf("expected data error, got %v: %s", resp["code"], w.Body.String())
	}
	if resp["message"] != "`document_ids` is required" {
		t.Fatalf("unexpected message: %v", resp["message"])
	}
}
// TestChunkStopParsing_InvalidStateIncludesPythonErrorCode checks that when
// the service fails with a response payload, the handler forwards the
// service's error code, the error text as message, and the payload's
// `error_code` entry under "data".
func TestChunkStopParsing_InvalidStateIncludesPythonErrorCode(t *testing.T) {
	r, mock := setupChunkStopParsingTest("user1")
	mock.stopParsingFn = func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) {
		return &service.StopParsingResponse{
			Data: map[string]interface{}{"error_code": "DOC_STOP_PARSING_INVALID_STATE"},
		}, common.CodeDataError, errors.New("Can't stop parsing document that has not started or already completed")
	}

	// httptest.NewRequest builds a server-side request and has no error
	// return to discard, unlike http.NewRequest.
	req := httptest.NewRequest(http.MethodDelete, "/api/v1/datasets/kb1/chunks", strings.NewReader(`{"document_ids":["doc1"]}`))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if resp["code"] != float64(common.CodeDataError) {
		t.Fatalf("expected data error, got %v: %s", resp["code"], w.Body.String())
	}
	data, ok := resp["data"].(map[string]interface{})
	if !ok {
		t.Fatalf("expected data object, got %T", resp["data"])
	}
	if data["error_code"] != "DOC_STOP_PARSING_INVALID_STATE" {
		t.Fatalf("unexpected error_code: %v", data["error_code"])
	}
	if resp["message"] != "Can't stop parsing document that has not started or already completed" {
		t.Fatalf("unexpected message: %v", resp["message"])
	}
}
func TestChunkRetrieval_WhitespaceQuestion(t *testing.T) {
r, _ := setupChunkRetrievalTest("user1")

View File

@@ -67,6 +67,7 @@ type documentServiceIface interface {
ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error)
IngestDocuments(datasetID, userID string, docIDs []string) ([]*service.ParseDocumentResponse, error)
StopIngestionTasks(tasks []string, userID string) ([]*entity.IngestionTask, error)
Ingest(userID string, req *service.IngestDocumentRequest) (common.ErrorCode, error)
RemoveIngestionTasks(tasks []string, userID string) ([]map[string]string, error)
BatchUpdateDocumentStatus(userID, datasetID, status string, DocumentIDs []string) (map[string]interface{}, common.ErrorCode, error)
}
@@ -874,6 +875,37 @@ func (h *DocumentHandler) SetMeta(c *gin.Context) {
})
}
// Ingest handles POST /documents/ingest. It resolves the current user, binds
// the JSON body into a service.IngestDocumentRequest and hands both to the
// document service.
//
// Failures are reported through jsonError: the code from GetUser, an
// authentication error for a blank user ID, CodeBadRequest for a body that
// does not bind, or the code returned by the service. On success the response
// is code 0, message "success" and data true.
func (h *DocumentHandler) Ingest(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	userID := strings.TrimSpace(user.ID)
	if len(userID) == 0 {
		jsonError(c, common.CodeAuthenticationError, "No Authentication")
		return
	}

	req := service.IngestDocumentRequest{}
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		jsonError(c, common.CodeBadRequest, bindErr.Error())
		return
	}

	code, err := h.documentService.Ingest(userID, &req)
	if err != nil {
		jsonError(c, code, err.Error())
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code":    common.CodeSuccess,
		"message": "success",
		"data":    true,
	})
}
// DeleteMetaRequest represents the request for deleting document metadata
type DeleteMetaRequest struct {
DocID string `json:"doc_id" binding:"required"`

View File

@@ -45,6 +45,19 @@ type fakeDocumentService struct {
metadataErr error
metadataKBID string
metadataDocIDs []string
ingestCode common.ErrorCode
ingestErr error
ingestUserID string
ingestReq *service.IngestDocumentRequest
}
// Ingest records the arguments it was called with and then replays the
// configured outcome: ingestCode/ingestErr when either is set, otherwise
// CodeSuccess with a nil error.
func (f *fakeDocumentService) Ingest(userID string, req *service.IngestDocumentRequest) (common.ErrorCode, error) {
	f.ingestUserID, f.ingestReq = userID, req

	if f.ingestCode == 0 && f.ingestErr == nil {
		return common.CodeSuccess, nil
	}
	return f.ingestCode, f.ingestErr
}
func (f *fakeDocumentService) UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error) {
@@ -176,6 +189,21 @@ func setupGinContextWithUser(method, path, body string) (*gin.Context, *httptest
return c, w
}
// setupDocumentIngestRoute returns a test-mode gin engine exposing the ingest
// route, backed by the supplied fake document service. A middleware stores an
// entity.User under "user" and the raw ID under "user_id" for every request.
func setupDocumentIngestRoute(userID string, svc *fakeDocumentService) *gin.Engine {
	gin.SetMode(gin.TestMode)

	handler := &DocumentHandler{
		documentService: svc,
		datasetService:  service.NewDatasetService(),
	}

	injectUser := func(c *gin.Context) {
		c.Set("user", &entity.User{ID: userID})
		c.Set("user_id", userID)
	}

	router := gin.New()
	router.Use(injectUser)
	router.POST("/api/v1/documents/ingest", handler.Ingest)
	return router
}
func TestDeleteDocumentsHandler_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -323,6 +351,116 @@ func TestDeleteDocumentsHandler_MissingDatasetID(t *testing.T) {
}
}
// TestDocumentHandlerIngestMatchesPythonResponseShape calls the handler
// directly and checks the success envelope: top-level code 0 and data == true
// (not a nested object), with the user ID and doc IDs forwarded to the
// service.
func TestDocumentHandlerIngestMatchesPythonResponseShape(t *testing.T) {
	gin.SetMode(gin.TestMode)

	svc := &fakeDocumentService{}
	handler := &DocumentHandler{
		documentService: svc,
		datasetService:  service.NewDatasetService(),
	}

	ctx, rec := setupGinContextWithUser("POST", "/api/v1/documents/ingest", `{"doc_ids":["doc-1"],"run":"1"}`)
	handler.Ingest(ctx)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}

	var payload map[string]interface{}
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatal(err)
	}
	if payload["code"] != float64(common.CodeSuccess) {
		t.Fatalf("expected top-level code 0, got %v", payload["code"])
	}

	data := payload["data"]
	if data != true {
		t.Fatalf("expected top-level data=true, got %#v", data)
	}
	if _, nested := data.(map[string]interface{}); nested {
		t.Fatalf("response must not nest code/message under data: %#v", data)
	}

	if got := svc.ingestUserID; got != "user-1" {
		t.Fatalf("expected user-1, got %q", got)
	}
	got := svc.ingestReq
	if got == nil || len(got.DocIDs) != 1 || got.DocIDs[0] != "doc-1" {
		t.Fatalf("unexpected ingest request: %#v", got)
	}
}
// TestDocumentIngestRoutePassesPythonBodyToService drives the ingest route
// through a gin engine and checks that every body field (doc_ids, run,
// delete, apply_kb) reaches the service together with the user ID.
func TestDocumentIngestRoutePassesPythonBodyToService(t *testing.T) {
	svc := &fakeDocumentService{}
	router := setupDocumentIngestRoute("user-1", svc)

	body := strings.NewReader(`{"doc_ids":["doc-1","doc-2"],"run":1,"delete":true,"apply_kb":true}`)
	request := httptest.NewRequest("POST", "/api/v1/documents/ingest", body)
	request.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, request)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}

	var payload map[string]interface{}
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatal(err)
	}
	if payload["code"] != float64(common.CodeSuccess) || payload["data"] != true {
		t.Fatalf("unexpected response: %s", rec.Body.String())
	}

	if svc.ingestUserID != "user-1" {
		t.Fatalf("userID = %q, want user-1", svc.ingestUserID)
	}
	got := svc.ingestReq
	if got == nil {
		t.Fatal("service did not receive ingest request")
	}
	if len(got.DocIDs) != 2 || got.DocIDs[0] != "doc-1" || got.DocIDs[1] != "doc-2" {
		t.Fatalf("doc_ids = %#v, want [doc-1 doc-2]", got.DocIDs)
	}
	// Run is an interface{}; compare its printed form so a JSON number and a
	// JSON string both match.
	if fmt.Sprint(got.Run) != "1" {
		t.Fatalf("run = %#v, want 1", got.Run)
	}
	if !got.Delete {
		t.Fatal("delete = false, want true")
	}
	if !got.ApplyKB {
		t.Fatal("apply_kb = false, want true")
	}
}
// TestDocumentHandlerIngestPropagatesServiceErrorCode checks that a service
// failure is surfaced with the service's own error code and message, and with
// no data in the response.
func TestDocumentHandlerIngestPropagatesServiceErrorCode(t *testing.T) {
	gin.SetMode(gin.TestMode)

	svc := &fakeDocumentService{
		ingestCode: common.CodeAuthenticationError,
		ingestErr:  fmt.Errorf("No authorization."),
	}
	handler := &DocumentHandler{
		documentService: svc,
		datasetService:  service.NewDatasetService(),
	}

	ctx, rec := setupGinContextWithUser("POST", "/api/v1/documents/ingest", `{"doc_ids":["doc-1"],"run":"1"}`)
	handler.Ingest(ctx)

	// Service errors are still delivered with HTTP 200; the failure is in the
	// JSON "code" field.
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}

	var payload map[string]interface{}
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatal(err)
	}
	if payload["code"] != float64(common.CodeAuthenticationError) {
		t.Fatalf("expected auth error code, got %v", payload["code"])
	}
	if payload["message"] != "No authorization." {
		t.Fatalf("unexpected message: %v", payload["message"])
	}
	if payload["data"] != nil {
		t.Fatalf("expected nil data, got %#v", payload["data"])
	}
}
func TestStopParseDocumentsHandler_EmptyDocIDs(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -243,6 +243,7 @@ func (r *Router) Setup(engine *gin.Engine) {
documents.GET("/:id", r.documentHandler.GetDocumentByID)
documents.PUT("/:id", r.documentHandler.UpdateDocument)
documents.DELETE("/:id", r.documentHandler.DeleteDocument)
documents.POST("/ingest", r.documentHandler.Ingest)
}
// Chat routes
@@ -320,6 +321,7 @@ func (r *Router) Setup(engine *gin.Engine) {
datasets.DELETE("/ingestion/tasks", r.documentHandler.RemoveIngestionTasks)
//datasets.POST("/:dataset_id/documents/parse", r.documentHandler.ParseDocuments)
//datasets.POST("/:dataset_id/documents/stop", r.documentHandler.StopParseDocuments)
datasets.DELETE("/:dataset_id/chunks", r.chunkHandler.StopParsing)
datasets.DELETE("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.RemoveChunks)
datasets.PUT("/:dataset_id/documents/:document_id/metadata/config", r.datasetsHandler.UpdateDocumentMetadataConfig)
datasets.POST("/:dataset_id/metadata/update", r.documentHandler.MetadataBatchUpdate)

View File

@@ -28,6 +28,7 @@ import (
"math/rand"
"path/filepath"
"ragflow/internal/common"
"ragflow/internal/engine/redis"
"ragflow/internal/entity"
"ragflow/internal/entity/models"
"ragflow/internal/server"
@@ -42,7 +43,6 @@ import (
"ragflow/internal/dao"
"ragflow/internal/engine"
"ragflow/internal/engine/redis"
"ragflow/internal/engine/types"
"ragflow/internal/service"
"ragflow/internal/service/nlp"
@@ -64,6 +64,7 @@ type ChunkService struct {
kbDAO *dao.KnowledgebaseDAO
userTenantDAO *dao.UserTenantDAO
documentDAO *dao.DocumentDAO
taskDAO *dao.TaskDAO
searchService *service.SearchService
accessibleFunc func(string, string) bool
@@ -85,6 +86,7 @@ func NewChunkService() *ChunkService {
kbDAO: dao.NewKnowledgebaseDAO(),
userTenantDAO: dao.NewUserTenantDAO(),
documentDAO: dao.NewDocumentDAO(),
taskDAO: dao.NewTaskDAO(),
searchService: service.NewSearchService(),
}
}
@@ -571,6 +573,112 @@ func (s *ChunkService) Get(req *service.GetChunkRequest, userID string) (*servic
return &service.GetChunkResponse{Chunk: chunk}, nil
}
// Error message and machine-readable error code that StopParsing returns when
// a requested document is not in the running state.
const (
docStopParsingInvalidStateMessage = "Can't stop parsing document that has not started or already completed"
docStopParsingInvalidStateErrorCode = "DOC_STOP_PARSING_INVALID_STATE"
)
// cancelAllTasksOfDoc flags every task of the given document as cancelled by
// writing a "<task id>-cancel" key to Redis.
//
// It fails only when the task lookup fails. When no Redis client is available
// it logs a warning and returns nil. The result of each Set call is not
// inspected.
func (s *ChunkService) cancelAllTasksOfDoc(docID string) error {
	docTasks, err := s.taskDAO.GetByDocID(docID)
	if err != nil {
		return fmt.Errorf("failed to get tasks for document %s: %w", docID, err)
	}

	rdb := redis.Get()
	if rdb == nil {
		common.Warn(fmt.Sprintf("Redis unavailable; cannot cancel tasks for document %s", docID))
		return nil
	}

	for _, t := range docTasks {
		if t != nil {
			cancelKey := fmt.Sprintf("%s-cancel", t.ID)
			rdb.Set(cancelKey, "x", 0)
		}
	}
	return nil
}
// StopParsing stops parsing for the requested documents of a dataset.
//
// The caller must have access to the dataset and req.DocumentIDs must not be
// empty. Duplicate IDs are collapsed and reported. Each remaining document
// must belong to the dataset and currently be running; its tasks are flagged
// for cancellation, its run/progress/chunk_num fields are reset, and its
// chunks are deleted from the doc engine when the chunk store exists.
//
// Processing stops at the first failing document; no rollback is performed
// here, so documents handled earlier in the loop keep their changes.
//
// Returns:
//   - (nil, CodeSuccess, nil) when every document was stopped and there were
//     no duplicates;
//   - a response with Message/Data, CodeSuccess and a nil error when documents
//     were stopped but duplicates were reported;
//   - a response carrying error_code, CodeDataError and an error when a
//     document is not running;
//   - (nil, code, error) for every other failure.
func (s *ChunkService) StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) {
	if !s.kbDAO.Accessible(datasetID, userID) {
		return nil, common.CodeDataError, fmt.Errorf("You don't own the dataset %s", datasetID)
	}
	if len(req.DocumentIDs) == 0 {
		return nil, common.CodeDataError, fmt.Errorf("`document_ids` is required")
	}
	kb, err := s.kbDAO.GetByID(datasetID)
	if err != nil {
		return nil, common.CodeDataError, fmt.Errorf("You don't own the dataset %s", datasetID)
	}

	docIDs, duplicateMessages := service.CheckDuplicateIDs(req.DocumentIDs, "document")
	successCount := 0
	ctx := context.Background()
	indexName := service.IndexName(kb.TenantID)

	for _, docID := range docIDs {
		doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(docID, datasetID)
		if err != nil || doc == nil {
			return nil, common.CodeDataError, fmt.Errorf("You don't own the document %s", docID)
		}
		if doc.Run == nil || *doc.Run != string(entity.TaskStatusRunning) {
			return &service.StopParsingResponse{
				Data: map[string]interface{}{"error_code": docStopParsingInvalidStateErrorCode},
			}, common.CodeDataError, fmt.Errorf("%s", docStopParsingInvalidStateMessage)
		}

		if err := s.cancelAllTasksOfDoc(docID); err != nil {
			return nil, common.CodeServerError, err
		}

		updates := map[string]interface{}{
			"run":       string(entity.TaskStatusCancel),
			"progress":  0,
			"chunk_num": 0,
		}
		if err := s.documentDAO.UpdateByID(doc.ID, updates); err != nil {
			return nil, common.CodeServerError, fmt.Errorf("failed to update document %s: %w", doc.ID, err)
		}

		if s.docEngine == nil {
			// No engine is configured, so there is nothing to delete. Log
			// that fact instead of claiming the index does not exist.
			common.Info(fmt.Sprintf("Skipping chunk delete during stop_parsing for doc %s: doc engine is not configured", doc.ID))
			successCount++
			continue
		}

		exists, err := s.docEngine.ChunkStoreExists(ctx, indexName, datasetID)
		if err != nil {
			return nil, common.CodeServerError, fmt.Errorf("failed to check chunk store %s/%s: %w", indexName, datasetID, err)
		}
		if exists {
			if _, err := s.docEngine.DeleteChunks(ctx, map[string]interface{}{"doc_id": doc.ID}, indexName, datasetID); err != nil {
				return nil, common.CodeServerError, fmt.Errorf("failed to delete chunks for document %s: %w", doc.ID, err)
			}
		} else {
			common.Info(fmt.Sprintf("Skipping chunk delete during stop_parsing for doc %s: index %s/%s does not exist", doc.ID, indexName, datasetID))
		}
		successCount++
	}

	if len(duplicateMessages) == 0 {
		return nil, common.CodeSuccess, nil
	}
	if successCount == 0 {
		return nil, common.CodeDataError, fmt.Errorf("%s", strings.Join(duplicateMessages, ";"))
	}
	return &service.StopParsingResponse{
		Message: fmt.Sprintf("Partially stopped %d documents with %d errors", successCount, len(duplicateMessages)),
		Data: map[string]interface{}{
			"success_count": successCount,
			"errors":        duplicateMessages,
		},
	}, common.CodeSuccess, nil
}
func checkDuplicateIDs(documentIDs []string, idTypes string) ([]string, []string) {
idCount := make(map[string]int, len(documentIDs))
duplicateMessages := make([]string, 0)

View File

@@ -21,6 +21,7 @@ import (
"encoding/json"
"fmt"
"ragflow/internal/common"
"ragflow/internal/engine/redis"
"ragflow/internal/server"
"strings"
@@ -31,12 +32,22 @@ import (
"ragflow/internal/utility"
)
// String codes for a document's `run` field. StopParsing in this file
// compares a document's run value against RUNNING.
// NOTE(review): the remaining names suggest the rest of the task lifecycle,
// but only RUNNING is read in the visible code — confirm against the entity
// package before relying on them.
var (
UNSTART = "0"
RUNNING = "1"
CANCEL = "2"
DONE = "3"
FAIL = "4"
SCHEDULE = "5"
)
// ChunkService chunk service
type ChunkService struct {
docEngine engine.DocEngine
engineType server.EngineType
embeddingCache *utility.EmbeddingLRU
kbDAO *dao.KnowledgebaseDAO
taskDAO *dao.TaskDAO
userTenantDAO *dao.UserTenantDAO
documentDAO *dao.DocumentDAO
searchService *SearchService
@@ -162,6 +173,131 @@ func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkRespon
return &GetChunkResponse{Chunk: chunk}, nil
}
// StopParsingRequest is the JSON body of a stop-parsing call.
type StopParsingRequest struct {
// DocumentIDs lists the documents whose parsing should be stopped.
DocumentIDs []string `json:"document_ids"`
}
// StopParsingResponse carries an optional payload and message for a
// stop-parsing result. The fields have no JSON tags.
type StopParsingResponse struct {
Data map[string]interface{}
Message string
}
// cancelAllTasksOfDoc flags every task of the given document as cancelled by
// writing a "<task id>-cancel" key to Redis.
//
// It fails only when the task lookup fails. When no Redis client is available
// it logs a warning and returns nil. The result of each Set call is not
// inspected.
func (s *ChunkService) cancelAllTasksOfDoc(docID string) error {
	docTasks, err := s.taskDAO.GetByDocID(docID)
	if err != nil {
		return fmt.Errorf("failed to get tasks for document %s: %w", docID, err)
	}

	rdb := redis.Get()
	if rdb == nil {
		common.Logger.Warn(fmt.Sprintf("Redis unavailable; cannot cancel tasks for document %s", docID))
		return nil
	}

	for _, t := range docTasks {
		if t != nil {
			cancelKey := fmt.Sprintf("%s-cancel", t.ID)
			rdb.Set(cancelKey, "x", 0)
		}
	}
	return nil
}
// CheckDuplicateIDs de-duplicates docList and reports which IDs occurred more
// than once.
//
// It returns the unique IDs and one "Duplicate <idType> ids: <id>" message per
// duplicated ID. Both slices follow the order in which each ID first appears
// in docList, so the result is deterministic (ranging over the counting map
// would yield a random order on every call). Both slices are non-nil even for
// empty input.
func CheckDuplicateIDs(docList []string, idType string) ([]string, []string) {
	uniqueDocIDs := make([]string, 0, len(docList))
	duplicateMessages := make([]string, 0)
	idCount := make(map[string]int, len(docList))

	for _, docID := range docList {
		idCount[docID]++
		if idCount[docID] == 1 {
			uniqueDocIDs = append(uniqueDocIDs, docID)
		}
	}

	for _, docID := range uniqueDocIDs {
		if idCount[docID] > 1 {
			duplicateMessages = append(duplicateMessages, fmt.Sprintf("Duplicate %s ids: %s", idType, docID))
		}
	}

	return uniqueDocIDs, duplicateMessages
}
// IndexName returns the index name for the given ID: the ID prefixed with
// "ragflow_".
func IndexName(uid string) string {
	const indexPrefix = "ragflow_"
	return fmt.Sprintf("%s%s", indexPrefix, uid)
}
// StopParsing stops parsing for the requested documents of a dataset.
//
// The caller must have access to the dataset and req.DocumentIDs must not be
// empty. Duplicate IDs are collapsed and reported. Each remaining document
// must belong to the dataset and currently be RUNNING; its tasks are flagged
// for cancellation, its run/progress/chunk_num fields are reset, and its
// chunks are deleted from the doc engine when the chunk store exists.
//
// Processing stops at the first failing document; no rollback is performed
// here, so documents handled earlier in the loop keep their changes.
//
// Returns:
//   - (nil, CodeSuccess, nil) when every document was stopped and there were
//     no duplicates;
//   - a map with success_count, errors and message, CodeSuccess and a nil
//     error when documents were stopped but duplicates were reported;
//   - (nil, code, error) for every failure.
func (s *ChunkService) StopParsing(userID, datasetID string, req StopParsingRequest) (map[string]interface{}, common.ErrorCode, error) {
	if !s.kbDAO.Accessible(datasetID, userID) {
		return nil, common.CodeAuthenticationError, fmt.Errorf("You don't own the dataset %s", datasetID)
	}
	// len of a nil slice is 0, so one check covers both nil and empty.
	if len(req.DocumentIDs) == 0 {
		return nil, common.CodeDataError, fmt.Errorf("document_ids is required")
	}

	docList, duplicateMessages := CheckDuplicateIDs(req.DocumentIDs, "document")
	kb, err := s.kbDAO.GetByID(datasetID)
	if err != nil {
		return nil, common.CodeDataError, fmt.Errorf("You don't own the dataset %s", datasetID)
	}

	// Both values are the same for every document, so build them once.
	ctx := context.Background()
	indexName := IndexName(kb.TenantID)

	successCount := 0
	for _, id := range docList {
		doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(id, datasetID)
		if err != nil || doc == nil {
			return nil, common.CodeDataError, fmt.Errorf("You don't own the document %s", id)
		}
		if doc.Run == nil || *doc.Run != RUNNING {
			return nil, common.CodeDataError, fmt.Errorf("Can't stop parsing document that has not started or already completed")
		}

		if err := s.cancelAllTasksOfDoc(id); err != nil {
			return nil, common.CodeServerError, err
		}

		info := map[string]interface{}{
			"run":       CANCEL,
			"progress":  0,
			"chunk_num": 0,
		}
		if err := s.documentDAO.UpdateByID(doc.ID, info); err != nil {
			return nil, common.CodeServerError, fmt.Errorf("failed to update document %s: %w", doc.ID, err)
		}

		if s.docEngine == nil {
			// Without an engine there are no stored chunks to delete;
			// calling through the nil interface would panic.
			common.Logger.Info(fmt.Sprintf("Skipping chunk delete during stop_parsing for doc %s: doc engine is not configured", doc.ID))
			successCount++
			continue
		}

		exists, err := s.docEngine.ChunkStoreExists(ctx, indexName, datasetID)
		if err != nil {
			return nil, common.CodeServerError, fmt.Errorf("failed to check chunk store %s/%s: %w", indexName, datasetID, err)
		}
		if exists {
			if _, err := s.docEngine.DeleteChunks(ctx, map[string]interface{}{"doc_id": doc.ID}, indexName, datasetID); err != nil {
				return nil, common.CodeServerError, fmt.Errorf("failed to delete chunks for document %s: %w", doc.ID, err)
			}
		} else {
			common.Logger.Info(fmt.Sprintf("Skipping chunk delete during stop_parsing for doc %s: index %s/%s does not exist", doc.ID, indexName, datasetID))
		}
		successCount++
	}

	if len(duplicateMessages) == 0 {
		return nil, common.CodeSuccess, nil
	}
	if successCount == 0 {
		return nil, common.CodeDataError, fmt.Errorf("%s", strings.Join(duplicateMessages, ";"))
	}
	return map[string]interface{}{
		"success_count": successCount,
		"errors":        duplicateMessages,
		"message":       fmt.Sprintf("Partially stopped %d documents with %d errors", successCount, len(duplicateMessages)),
	}, common.CodeSuccess, nil
}
// ListChunksRequest request for listing chunks
type ListChunksRequest struct {
DatasetID string `json:"dataset_id,omitempty"`

View File

@@ -17,11 +17,17 @@
package service
import (
"archive/zip"
"bytes"
"context"
"encoding/csv"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"math"
"math/rand"
"mime/multipart"
"os"
"path/filepath"
@@ -43,6 +49,7 @@ import (
"ragflow/internal/tokenizer"
"ragflow/internal/utility"
"github.com/cespare/xxhash/v2"
"go.uber.org/zap"
"gorm.io/gorm"
)
@@ -987,6 +994,584 @@ func (s *DocumentService) RemoveIngestionTasks(tasks []string, userID string) ([
return deletedTasks, nil
}
// IngestDocumentRequest is the JSON body accepted by DocumentService.Ingest.
type IngestDocumentRequest struct {
// DocIDs lists the documents to act on; required by the binding tag.
DocIDs []string `json:"doc_ids" binding:"required"`
// Run is the requested run state. It is untyped and Ingest normalizes it
// with fmt.Sprint before comparing it to the task-status strings.
Run interface{} `json:"run" binding:"required"`
// Delete makes Ingest remove the document's tasks and stored chunks.
Delete bool `json:"delete"`
// ApplyKB makes Ingest copy llm_id, enable_metadata and metadata from the
// dataset's parser config onto the document before a run.
ApplyKB bool `json:"apply_kb"`
}
// documentParsePageRange is one from/to span used to split a document into
// parse tasks; the bounds are passed to newDocumentParseTask as fromPage and
// toPage.
type documentParsePageRange struct {
from int64
to int64
}
// Ingest applies the requested run state to every document in req.DocIDs and,
// when the state is "running", queues the parse work for it.
//
// For each document, in request order, it:
//  1. checks that the document exists, that its dataset can be loaded and
//     that userID may access that dataset;
//  2. cancels parsing when the requested state is "cancel";
//  3. writes the new run state and resets progress (plus progress_msg,
//     chunk_num and token_num on a rerun with delete);
//  4. deletes the document's tasks and stored chunks when req.Delete is set;
//  5. when the state is "running": optionally applies the dataset's parser
//     settings, then queues either a dataflow task (document has a pipeline
//     ID) or regular parse tasks.
//
// It returns on the first failure with an error code and error. No rollback
// is performed here, so documents handled earlier in the loop keep their
// changes. On success it returns CodeSuccess and a nil error.
func (s *DocumentService) Ingest(userID string, req *IngestDocumentRequest) (common.ErrorCode, error) {
// Run is an interface{}; normalize it to its printed form so it can be
// compared with the task-status strings below.
run := fmt.Sprint(req.Run)
docs, err := s.documentDAO.GetByIDs(req.DocIDs)
if err != nil {
return common.CodeExceptionError, fmt.Errorf("fail to get documents: %s", err.Error())
}
// Index the fetched documents by ID so the loop below can follow the request
// order and detect IDs that were not found.
docsByID := make(map[string]*entity.Document, len(docs))
for _, doc := range docs {
if doc != nil {
docsByID[doc.ID] = doc
}
}
// Caches the per-dataset count of done documents so it is queried at most
// once per dataset within this call.
tableDoneCountByKB := make(map[string]int64)
for _, docID := range req.DocIDs {
doc := docsByID[docID]
if doc == nil {
return common.CodeDataError, fmt.Errorf("Document not found!")
}
kb, err := s.kbDAO.GetByID(doc.KbID)
if err != nil {
return common.CodeDataError, fmt.Errorf("Tenant not found!")
}
if !s.kbDAO.Accessible(kb.ID, userID) {
return common.CodeAuthenticationError, fmt.Errorf("No authorization.")
}
updates := map[string]interface{}{
"run": run,
"progress": 0,
}
// A rerun that also deletes existing output resets the document's counters.
rerunWithDelete := run == string(entity.TaskStatusRunning) && req.Delete
if rerunWithDelete {
updates["progress_msg"] = ""
updates["chunk_num"] = 0
updates["token_num"] = 0
}
if run == string(entity.TaskStatusCancel) {
if err := s.cancelDocParse(doc); err != nil {
return common.CodeDataError, err
}
}
// Only a document that had finished parsing triggers the dataset-level
// chunk-count cleanup.
if rerunWithDelete && doc.Run != nil && *doc.Run == string(entity.TaskStatusDone) {
if err := s.clearKBChunkNumWhenRerun(doc); err != nil {
return common.CodeExceptionError, err
}
}
if err := s.documentDAO.UpdateByID(doc.ID, updates); err != nil {
return common.CodeExceptionError, err
}
if req.Delete {
// The result of this task deletion is deliberately discarded here.
_, _ = s.taskDAO.DeleteByDocIDs([]string{doc.ID})
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
if s.docEngine != nil {
exists, err := s.docEngine.ChunkStoreExists(context.Background(), indexName, doc.KbID)
if err != nil {
return common.CodeExceptionError, err
}
if exists {
if _, err := s.docEngine.DeleteChunks(context.Background(), map[string]interface{}{"doc_id": doc.ID}, indexName, doc.KbID); err != nil {
return common.CodeExceptionError, err
}
}
}
}
if run == string(entity.TaskStatusRunning) {
if req.ApplyKB {
if doc.ParserConfig == nil {
doc.ParserConfig = entity.JSONMap{}
}
// Start from defaults and overwrite them with the dataset's values when the
// dataset's parser config defines them.
config := map[string]interface{}{
"llm_id": kb.ParserConfig["llm_id"],
"enable_metadata": false,
"metadata": map[string]interface{}{},
}
if value, ok := kb.ParserConfig["enable_metadata"]; ok {
config["enable_metadata"] = value
}
if value, ok := kb.ParserConfig["metadata"]; ok {
config["metadata"] = value
}
if err := s.updateDocumentParserConfig(doc.ID, config); err != nil {
return common.CodeExceptionError, err
}
// Mirror the stored values onto the in-memory document used below.
for key, value := range config {
doc.ParserConfig[key] = value
}
}
// Documents with a pipeline ID are queued as a single dataflow task and skip
// the regular parse-task path.
if doc.PipelineID != nil && strings.TrimSpace(*doc.PipelineID) != "" {
if err := s.queueDocumentDataflowTask(kb, doc, strings.TrimSpace(*doc.PipelineID), 0); err != nil {
return common.CodeExceptionError, err
}
continue
}
if doc.ParserID == string(entity.ParserTypeTable) {
doneCount, ok := tableDoneCountByKB[doc.KbID]
if !ok {
count, err := s.countDoneDocuments(doc.KbID)
if err != nil {
return common.CodeExceptionError, err
}
doneCount = count
tableDoneCountByKB[doc.KbID] = doneCount
// With no done document in the dataset, the dataset's field map is deleted;
// a not-found error from that delete is tolerated.
if doneCount <= 0 {
if err := s.kbDAO.DeleteFieldMap(doc.KbID); err != nil && !dao.IsNotFoundErr(err) {
return common.CodeExceptionError, err
}
}
}
}
if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil {
return common.CodeExceptionError, err
}
bucket, objectName, err := s.GetDocumentStorageAddress(doc)
if err != nil {
return common.CodeExceptionError, err
}
if err := s.queueDocumentParseTasks(doc, bucket, objectName, 0); err != nil {
return common.CodeExceptionError, err
}
if err := s.beginDocumentParse(doc.ID); err != nil {
return common.CodeExceptionError, err
}
}
}
return common.CodeSuccess, nil
}
// countDoneDocuments returns how many documents of the given dataset have the
// "done" run status, together with any query error.
func (s *DocumentService) countDoneDocuments(datasetID string) (int64, error) {
	var doneCount int64
	query := dao.GetDB().
		Model(&entity.Document{}).
		Where("kb_id = ? AND run = ?", datasetID, string(entity.TaskStatusDone))
	result := query.Count(&doneCount)
	return doneCount, result.Error
}
// queueDocumentParseTasks replaces the document's tasks with freshly built
// parse tasks and pushes each unfinished one onto the Redis parse queue.
//
// It deletes the existing tasks, builds and stores the new ones, and then
// enqueues every task whose Progress is below 1. It returns an error when any
// of those steps fails, or when Redis is unavailable or refuses a message.
func (s *DocumentService) queueDocumentParseTasks(doc *entity.Document, bucket, objectName string, priority int64) error {
	_, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID})
	if err != nil {
		return err
	}

	parseTasks, err := s.newDocumentParseTasks(doc, bucket, objectName, priority)
	if err != nil {
		return err
	}
	if err = s.taskDAO.CreateMany(parseTasks); err != nil {
		return err
	}

	targetQueue := documentParseQueueName(doc, priority)
	for _, pending := range parseTasks {
		// Tasks already at full progress need no worker.
		if pending.Progress >= 1 {
			continue
		}
		rdb := redis.Get()
		if rdb != nil && rdb.QueueProduct(targetQueue, documentTaskMessage(pending)) {
			continue
		}
		return fmt.Errorf("Can't access Redis. Please check the Redis' status.")
	}
	return nil
}
// queueDocumentDataflowTask replaces the document's tasks with one "dataflow"
// task covering the maximum page range and pushes it onto the Redis parse
// queue.
//
// The queued message is the regular task message extended with task_type,
// kb_id, tenant_id, dataflow_id and a nil file entry. An error is returned
// when a DAO step fails, or when Redis is unavailable or refuses the message.
func (s *DocumentService) queueDocumentDataflowTask(kb *entity.Knowledgebase, doc *entity.Document, flowID string, priority int64) error {
	if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil {
		return err
	}
	if err := s.beginDocumentParse(doc.ID); err != nil {
		return err
	}

	flowTask := s.newDocumentParseTask(doc, 0, maximumTaskPageNumber, priority)
	flowTask.TaskType = "dataflow"
	if err := s.taskDAO.CreateMany([]*entity.Task{flowTask}); err != nil {
		return err
	}

	payload := documentTaskMessage(flowTask)
	payload["task_type"] = flowTask.TaskType
	payload["kb_id"] = doc.KbID
	payload["tenant_id"] = kb.TenantID
	payload["dataflow_id"] = flowID
	payload["file"] = nil

	rdb := redis.Get()
	if rdb != nil && rdb.QueueProduct(documentParseQueueName(doc, priority), payload) {
		return nil
	}
	return fmt.Errorf("Can't access Redis. Please check the Redis' status.")
}
// newDocumentParseTasks builds one parse task per page range computed for the
// document by documentParseTaskRanges. The tasks are only constructed here,
// not stored or queued.
func (s *DocumentService) newDocumentParseTasks(doc *entity.Document, bucket, objectName string, priority int64) ([]*entity.Task, error) {
	pageRanges, err := documentParseTaskRanges(doc, bucket, objectName)
	if err != nil {
		return nil, err
	}

	parseTasks := make([]*entity.Task, len(pageRanges))
	for i := range pageRanges {
		span := pageRanges[i]
		parseTasks[i] = s.newDocumentParseTask(doc, span.from, span.to, priority)
	}
	return parseTasks, nil
}
// newDocumentParseTask creates a fresh, not-yet-persisted task for the given
// range of doc.
//
// The task gets a new UUID, the current time as BeginAt, zero progress, an
// empty task type, empty progress message and chunk-id strings, and the digest
// computed by documentParseTaskDigest for the same document and range.
func (s *DocumentService) newDocumentParseTask(doc *entity.Document, fromPage, toPage, priority int64) *entity.Task {
	startedAt := time.Now()
	taskDigest := documentParseTaskDigest(doc, fromPage, toPage)

	// Distinct variables so the two pointer fields never alias each other.
	var emptyProgressMsg, emptyChunkIDs string

	task := new(entity.Task)
	task.ID = common.GenerateUUID()
	task.DocID = doc.ID
	task.FromPage = fromPage
	task.ToPage = toPage
	task.TaskType = ""
	task.Priority = priority
	task.BeginAt = &startedAt
	task.Progress = 0
	task.ProgressMsg = &emptyProgressMsg
	task.Digest = &taskDigest
	task.ChunkIDs = &emptyChunkIDs
	return task
}
// documentParseTaskRanges decides how doc is sliced into from/to ranges, one
// per parse task.
//
// PDF documents (doc.Type == "pdf"): the object is loaded from storage and its
// page count estimated. The slice size comes from the "task_page_size" parser
// setting (default 12, or 22 for the paper parser); the one and KG parsers, or
// an enabled "toc_extraction" setting, force maximumTaskPageNumber so each
// configured range yields a single slice. A non-positive size is replaced by
// 12. Every configured page range is shifted down by one, clamped to
// [0, pageCount] and cut into slices of that size.
//
// Table parser documents: the row count is estimated and cut into slices of
// 3000 rows.
//
// In every other case, and whenever the above produces no slice, a single
// range 0..maximumTaskPageNumber is returned. Storage errors are returned
// unchanged.
func documentParseTaskRanges(doc *entity.Document, bucket, objectName string) ([]documentParsePageRange, error) {
	wholeDocument := documentParsePageRange{from: 0, to: maximumTaskPageNumber}

	switch {
	case doc.Type == "pdf":
		content, err := documentStorageBinary(bucket, objectName)
		if err != nil {
			return nil, err
		}
		pageCount := documentEstimatePDFPageCount(content)

		defaultSize := 12
		if doc.ParserID == string(entity.ParserTypePaper) {
			defaultSize = 22
		}
		sliceSize := int64(documentParserConfigInt(doc.ParserConfig, "task_page_size", defaultSize))

		singleSlice := doc.ParserID == string(entity.ParserTypeOne) ||
			doc.ParserID == string(entity.ParserTypeKG) ||
			documentParserConfigBool(doc.ParserConfig, "toc_extraction", false)
		if singleSlice {
			sliceSize = maximumTaskPageNumber
		}
		if sliceSize <= 0 {
			// Guards the loop below against a zero or negative step.
			sliceSize = 12
		}

		slices := make([]documentParsePageRange, 0)
		for _, wanted := range documentParserConfigPageRanges(doc.ParserConfig) {
			first := wanted.from - 1
			if first < 0 {
				first = 0
			}
			limit := wanted.to - 1
			if pageCount >= 0 && limit > pageCount {
				limit = pageCount
			}
			for cursor := first; cursor < limit; cursor += sliceSize {
				upper := limit
				if next := cursor + sliceSize; next <= limit {
					upper = next
				}
				slices = append(slices, documentParsePageRange{from: cursor, to: upper})
			}
		}
		if len(slices) == 0 {
			slices = append(slices, wholeDocument)
		}
		return slices, nil

	case doc.ParserID == string(entity.ParserTypeTable):
		content, err := documentStorageBinary(bucket, objectName)
		if err != nil {
			return nil, err
		}
		rowCount := documentEstimateTableRowCount(documentName(doc), content)
		if rowCount <= 0 {
			return []documentParsePageRange{wholeDocument}, nil
		}

		const rowsPerTask = 3000
		total := int64(rowCount)
		slices := make([]documentParsePageRange, 0, (rowCount+2999)/3000)
		for cursor := int64(0); cursor < total; cursor += rowsPerTask {
			upper := total
			if next := cursor + rowsPerTask; next <= total {
				upper = next
			}
			slices = append(slices, documentParsePageRange{from: cursor, to: upper})
		}
		return slices, nil
	}

	return []documentParsePageRange{wholeDocument}, nil
}
// documentStorageBinary reads the object objectName from bucket through the
// storage implementation supplied by the storage factory. It returns an error
// when no storage implementation is available; otherwise the result of the
// storage Get call is passed through as-is.
func documentStorageBinary(bucket, objectName string) ([]byte, error) {
	backend := storage.GetStorageFactory().GetStorage()
	if backend != nil {
		return backend.Get(bucket, objectName)
	}
	return nil, fmt.Errorf("storage not initialized")
}
// documentName returns the document's name, or the empty string when doc or
// its Name pointer is nil.
func documentName(doc *entity.Document) string {
	if doc != nil && doc.Name != nil {
		return *doc.Name
	}
	return ""
}
// documentParserConfigInt reads config[key] as an int.
//
// Accepted value types are int, int64, float64 (truncated by conversion),
// json.Number and string (surrounding whitespace ignored, base 10). fallback
// is returned when the key is absent, the value is nil, the type is not one
// of the above, or a json.Number/string value cannot be parsed as an integer.
func documentParserConfigInt(config map[string]interface{}, key string, fallback int) int {
	raw, present := config[key]
	if !present || raw == nil {
		return fallback
	}

	switch v := raw.(type) {
	case int:
		return v
	case int64:
		return int(v)
	case float64:
		return int(v)
	case json.Number:
		parsed, err := v.Int64()
		if err != nil {
			return fallback
		}
		return int(parsed)
	case string:
		parsed, err := strconv.Atoi(strings.TrimSpace(v))
		if err != nil {
			return fallback
		}
		return parsed
	default:
		return fallback
	}
}
// documentParserConfigBool reads config[key] as a bool.
//
// A bool value is returned directly. A string is trimmed and lower-cased:
// "true", "1", "yes" and "on" give true; "false", "0", "no" and "off" give
// false. fallback is returned for a missing key, a nil value, any other type,
// or an unrecognised string.
func documentParserConfigBool(config map[string]interface{}, key string, fallback bool) bool {
	raw, present := config[key]
	if !present || raw == nil {
		return fallback
	}
	if flag, isBool := raw.(bool); isBool {
		return flag
	}
	text, isString := raw.(string)
	if !isString {
		return fallback
	}
	switch strings.ToLower(strings.TrimSpace(text)) {
	case "true", "1", "yes", "on":
		return true
	case "false", "0", "no", "off":
		return false
	default:
		return fallback
	}
}
// documentParserConfigPageRanges extracts the "pages" setting from config.
//
// The setting is expected to be a list of lists; the first two items of each
// inner list are converted with documentToInt64 and kept as a range when both
// convert and the second is greater than the first. Entries that do not match
// are skipped. When the setting is missing, not a list, empty, or yields no
// usable entry, the single default range {1, 100000} is returned.
func documentParserConfigPageRanges(config map[string]interface{}) []documentParsePageRange {
	fallback := []documentParsePageRange{{from: 1, to: 100000}}

	// A missing key or nil value fails the assertion as well.
	entries, isList := config["pages"].([]interface{})
	if !isList || len(entries) == 0 {
		return fallback
	}

	parsed := make([]documentParsePageRange, 0, len(entries))
	for _, entry := range entries {
		pair, isPair := entry.([]interface{})
		if !isPair || len(pair) < 2 {
			continue
		}
		first, firstOK := documentToInt64(pair[0])
		last, lastOK := documentToInt64(pair[1])
		if !firstOK || !lastOK || last <= first {
			continue
		}
		parsed = append(parsed, documentParsePageRange{from: first, to: last})
	}

	if len(parsed) > 0 {
		return parsed
	}
	return fallback
}
// documentToInt64 converts a loosely typed value to int64.
//
// int, int64 and float64 (truncated by conversion) always succeed.
// json.Number and string values (whitespace-trimmed, base 10) succeed only
// when they parse as a 64-bit integer. Any other type yields (0, false).
func documentToInt64(value interface{}) (int64, bool) {
	var (
		result int64
		err    error
	)
	switch v := value.(type) {
	case int:
		result = int64(v)
	case int64:
		result = v
	case float64:
		result = int64(v)
	case json.Number:
		result, err = v.Int64()
	case string:
		result, err = strconv.ParseInt(strings.TrimSpace(v), 10, 64)
	default:
		return 0, false
	}
	return result, err == nil
}
// documentPDFPagePattern matches a "/Type /Page" dictionary entry (optional
// whitespace between the two names). The trailing word boundary keeps
// "/Type /Pages" from matching.
var documentPDFPagePattern = regexp.MustCompile(`/Type\s*/Page\b`)

// documentEstimatePDFPageCount estimates the number of pages in a PDF by
// counting occurrences of documentPDFPagePattern in the raw bytes. Empty input
// yields 0. The bytes are scanned as-is; nothing is decoded or decompressed.
func documentEstimatePDFPageCount(binary []byte) int64 {
	if len(binary) > 0 {
		return int64(len(documentPDFPagePattern.FindAllIndex(binary, -1)))
	}
	return 0
}
// documentEstimateTableRowCount estimates the row count of a tabular file,
// choosing the counter by the lower-cased extension of name: ".xlsx" uses
// documentCountXLSXRows, while ".csv", ".tsv" and ".txt" use
// documentCountDelimitedRows. It returns 0 for any other extension and when
// the XLSX counter fails.
func documentEstimateTableRowCount(name string, binary []byte) int {
	ext := strings.ToLower(filepath.Ext(name))

	if ext == ".xlsx" {
		count, err := documentCountXLSXRows(binary)
		if err != nil {
			return 0
		}
		return count
	}
	if ext == ".csv" || ext == ".tsv" || ext == ".txt" {
		return documentCountDelimitedRows(name, binary)
	}
	return 0
}
// documentCountDelimitedRows counts the records of a delimited text file.
//
// The content is parsed with encoding/csv using a comma as separator, or a
// tab when name has a ".tsv" extension (case-insensitive). Records may have
// differing numbers of fields.
//
// If the parser reports an error before the end of the input (for example a
// stray quote), the records counted so far are discarded and the number of
// text lines in binary is returned instead: the newline count, plus one when
// the last line has no terminating newline. Previously this fallback was added
// on top of the partial record count, so every row before the malformed one
// was counted twice.
func documentCountDelimitedRows(name string, binary []byte) int {
	reader := csv.NewReader(bytes.NewReader(binary))
	reader.FieldsPerRecord = -1 // allow a varying number of fields per record
	reader.ReuseRecord = true   // records are only counted, never retained
	if strings.EqualFold(filepath.Ext(name), ".tsv") {
		reader.Comma = '\t'
	}

	rows := 0
	for {
		_, err := reader.Read()
		if err == nil {
			rows++
			continue
		}
		if err == io.EOF {
			return rows
		}

		// Not parseable as CSV: fall back to a plain line count over the
		// whole input, replacing (not extending) the partial record count.
		lines := bytes.Count(binary, []byte{'\n'})
		if len(binary) > 0 && binary[len(binary)-1] != '\n' {
			lines++
		}
		return lines
	}
}
// documentCountXLSXRows opens binary as a zip archive and returns the largest
// row count found among its worksheet parts, i.e. entries whose name starts
// with "xl/worksheets/" and ends with ".xml". It returns 0 when the archive
// has no such entry. An error opening the archive or counting any worksheet
// aborts the scan and is returned with a count of 0.
func documentCountXLSXRows(binary []byte) (int, error) {
	archive, err := zip.NewReader(bytes.NewReader(binary), int64(len(binary)))
	if err != nil {
		return 0, err
	}

	largest := 0
	for _, entry := range archive.File {
		isWorksheet := strings.HasPrefix(entry.Name, "xl/worksheets/") && strings.HasSuffix(entry.Name, ".xml")
		if !isWorksheet {
			continue
		}
		count, err := documentCountWorksheetRows(entry)
		if err != nil {
			return 0, err
		}
		if count > largest {
			largest = count
		}
	}
	return largest, nil
}
// documentCountWorksheetRows streams the XML of one zip entry and counts the
// start elements whose local name is "row" (the namespace is not checked).
// It returns (0, err) when the entry cannot be opened or the XML decoder
// reports anything other than a clean end of input.
func documentCountWorksheetRows(file *zip.File) (int, error) {
	stream, err := file.Open()
	if err != nil {
		return 0, err
	}
	defer stream.Close()

	var count int
	for parser := xml.NewDecoder(stream); ; {
		tok, err := parser.Token()
		switch {
		case err == io.EOF:
			return count, nil
		case err != nil:
			return 0, err
		}
		if elem, isStart := tok.(xml.StartElement); isStart && elem.Name.Local == "row" {
			count++
		}
	}
}
// beginDocumentParse resets the document row identified by docID for a new
// parse run: the progress message becomes "Task is queued...", the begin time
// is set to now, progress gets a small random value below 0.01, run is set to
// the running status, and the chunk and token counters are zeroed. The
// database error, if any, is returned.
func (s *DocumentService) beginDocumentParse(docID string) error {
	startedAt := time.Now()
	fields := map[string]interface{}{
		"progress_msg":     "Task is queued...",
		"process_begin_at": startedAt,
		"progress":         rand.Float64() * 0.01,
		"run":              string(entity.TaskStatusRunning),
		"chunk_num":        0,
		"token_num":        0,
	}

	query := dao.GetDB().Model(&entity.Document{}).Where("id = ?", docID)
	return query.Updates(fields).Error
}
// documentParseQueueName returns the queue name "te.<priority>.<suffix>",
// where the suffix is "resume" for documents using the resume parser and
// "common" for all others.
func documentParseQueueName(doc *entity.Document, priority int64) string {
	if doc.ParserID == string(entity.ParserTypeResume) {
		return fmt.Sprintf("te.%d.%s", priority, "resume")
	}
	return fmt.Sprintf("te.%d.%s", priority, "common")
}
// documentTaskMessage converts task into the map published on the task queue.
// begin_at is formatted as "2006-01-02 15:04:05" and, like digest, is the
// empty string when the corresponding pointer on the task is nil.
func documentTaskMessage(task *entity.Task) map[string]interface{} {
	message := map[string]interface{}{
		"id":        task.ID,
		"doc_id":    task.DocID,
		"from_page": task.FromPage,
		"to_page":   task.ToPage,
		"progress":  task.Progress,
		"priority":  task.Priority,
		"begin_at":  "",
		"digest":    "",
	}
	if task.BeginAt != nil {
		message["begin_at"] = task.BeginAt.Format("2006-01-02 15:04:05")
	}
	if task.Digest != nil {
		message["digest"] = *task.Digest
	}
	return message
}
// documentParseTaskDigest returns a hexadecimal xxhash digest identifying a
// parse task by its document settings and range.
//
// The hash input is, in this order: the JSON encoding of doc_id, kb_id,
// parser_config and parser_id (keys visited in sorted order; a value that
// cannot be marshalled is written with fmt.Sprint instead), then the document
// id again, then fromPage and toPage in decimal.
func documentParseTaskDigest(doc *entity.Document, fromPage, toPage int64) string {
	fields := map[string]interface{}{
		"doc_id":        doc.ID,
		"kb_id":         doc.KbID,
		"parser_id":     doc.ParserID,
		"parser_config": doc.ParserConfig,
	}

	// Map iteration order is random, so fix the order before hashing.
	names := make([]string, 0, len(fields))
	for name := range fields {
		names = append(names, name)
	}
	sort.Strings(names)

	digest := xxhash.New()
	for _, name := range names {
		encoded, err := json.Marshal(fields[name])
		if err == nil {
			digest.Write(encoded)
			continue
		}
		digest.WriteString(fmt.Sprint(fields[name]))
	}

	trailer := []string{
		doc.ID,
		strconv.FormatInt(fromPage, 10),
		strconv.FormatInt(toPage, 10),
	}
	for _, part := range trailer {
		digest.WriteString(part)
	}
	return fmt.Sprintf("%x", digest.Sum64())
}
// clearKBChunkNumWhenRerun subtracts the document's TokenNum and ChunkNum from
// the token_num and chunk_num columns of the knowledgebase row whose id is
// doc.KbID. It returns an error when doc is nil, otherwise the database error
// of the update, if any.
func (s *DocumentService) clearKBChunkNumWhenRerun(doc *entity.Document) error {
	if doc == nil {
		return fmt.Errorf("document is nil")
	}

	decrements := map[string]interface{}{
		"token_num": gorm.Expr("token_num - ?", doc.TokenNum),
		"chunk_num": gorm.Expr("chunk_num - ?", doc.ChunkNum),
	}
	query := dao.GetDB().Model(&entity.Knowledgebase{}).Where("id = ?", doc.KbID)
	return query.Updates(decrements).Error
}
func (s *DocumentService) ParseDocuments(datasetID, userID string, docIDs []string) ([]*ParseDocumentResponse, error) {
// create document parse id
// save to task table