mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-05 10:58:34 +08:00
feat[go]: datasets/<dataset_id>/chunks DELETE (#16185)
### What problem does this PR solve?
As title:
`documents.POST("/ingest", r.documentHandler.Ingest)`:
---
<img width="3750" height="2039" alt="image"
src="https://github.com/user-attachments/assets/533c1c3d-af3e-47e6-9f51-a278539b7066"
/>
`datasets.DELETE("/:dataset_id/chunks", r.chunkHandler.StopParsing)`
---
<img width="3621" height="2040" alt="image"
src="https://github.com/user-attachments/assets/022adcdb-1e47-4883-9611-1a695c34007d"
/>
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -38,6 +38,7 @@ type chunkService interface {
|
||||
UpdateChunk(req *service.UpdateChunkRequest, userID string) error
|
||||
RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error)
|
||||
Parse(userID, datasetID string, req *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error)
|
||||
StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error)
|
||||
}
|
||||
|
||||
// ChunkHandler chunk handler
|
||||
@@ -224,8 +225,8 @@ func (h *ChunkHandler) Parse(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
datasetID := strings.TrimSpace(c.Param("dataset_id"))
|
||||
if datasetID == "" {
|
||||
datasetId := strings.TrimSpace(c.Param("dataset_id"))
|
||||
if datasetId == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": "dataset_id is required",
|
||||
@@ -243,7 +244,7 @@ func (h *ChunkHandler) Parse(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
data, code, err := h.chunkService.Parse(userID, datasetID, &req)
|
||||
data, code, err := h.chunkService.Parse(userID, datasetId, &req)
|
||||
if code != common.CodeSuccess {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
@@ -353,6 +354,59 @@ func parseAvailableQuery(raw string) (int, bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ChunkHandler) StopParsing(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
datasetID := c.Param("dataset_id")
|
||||
if datasetID == "" {
|
||||
jsonError(c, common.CodeDataError, "dataset_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req service.StopParsingRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.DocumentIDs) == 0 {
|
||||
jsonError(c, common.CodeDataError, "`document_ids` is required")
|
||||
return
|
||||
}
|
||||
|
||||
resp, code, err := h.chunkService.StopParsing(user.ID, datasetID, req)
|
||||
if err != nil {
|
||||
var data interface{}
|
||||
if resp != nil {
|
||||
data = resp.Data
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"data": data,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
message := "success"
|
||||
var data interface{}
|
||||
if resp != nil {
|
||||
if resp.Message != "" {
|
||||
message = resp.Message
|
||||
}
|
||||
data = resp.Data
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": data,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
// List retrieves chunks for a document.
|
||||
// @Summary List Chunks
|
||||
// @Description Retrieve paginated chunks for a document with optional filtering.
|
||||
|
||||
@@ -23,6 +23,7 @@ type mockChunkSvc struct {
|
||||
listFn func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error)
|
||||
switchChunksFn func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error
|
||||
updateChunkFn func(req *service.UpdateChunkRequest, userID string) error
|
||||
stopParsingFn func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error)
|
||||
}
|
||||
|
||||
func (m *mockChunkSvc) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
@@ -58,6 +59,12 @@ func (m *mockChunkSvc) UpdateChunk(req *service.UpdateChunkRequest, userID strin
|
||||
func (m *mockChunkSvc) RemoveChunks(*service.RemoveChunksRequest, string) (int64, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (m *mockChunkSvc) StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) {
|
||||
if m.stopParsingFn != nil {
|
||||
return m.stopParsingFn(userID, datasetID, req)
|
||||
}
|
||||
panic("not implemented")
|
||||
}
|
||||
func (m *mockChunkSvc) Parse(string, string, *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -84,6 +91,18 @@ func setupChunkRetrievalTestNoAuth() *gin.Engine {
|
||||
return r
|
||||
}
|
||||
|
||||
func setupChunkStopParsingTest(userID string) (*gin.Engine, *mockChunkSvc) {
|
||||
mock := &mockChunkSvc{}
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: userID})
|
||||
})
|
||||
r.DELETE("/api/v1/datasets/:dataset_id/chunks", h.StopParsing)
|
||||
return r, mock
|
||||
}
|
||||
|
||||
func setupChunkHandlerWithUser(userID string, mock *mockChunkSvc) (*gin.Engine, *ChunkHandler) {
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
gin.SetMode(gin.TestMode)
|
||||
@@ -254,6 +273,103 @@ func TestChunkRetrieval_EmptyQuestion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkStopParsing_Success(t *testing.T) {
|
||||
r, mock := setupChunkStopParsingTest("user1")
|
||||
mock.stopParsingFn = func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) {
|
||||
if userID != "user1" {
|
||||
t.Fatalf("expected user1, got %q", userID)
|
||||
}
|
||||
if datasetID != "kb1" {
|
||||
t.Fatalf("expected kb1, got %q", datasetID)
|
||||
}
|
||||
if len(req.DocumentIDs) != 2 || req.DocumentIDs[0] != "doc1" || req.DocumentIDs[1] != "doc2" {
|
||||
t.Fatalf("unexpected document IDs: %#v", req.DocumentIDs)
|
||||
}
|
||||
return nil, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("DELETE", "/api/v1/datasets/kb1/chunks", strings.NewReader(`{"document_ids":["doc1","doc2"]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %s", resp["code"], w.Body.String())
|
||||
}
|
||||
if resp["message"] != "success" {
|
||||
t.Fatalf("expected success message, got %v", resp["message"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkStopParsingRouteRequiresDocumentIDs(t *testing.T) {
|
||||
r, mock := setupChunkStopParsingTest("user1")
|
||||
mock.stopParsingFn = func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) {
|
||||
t.Fatal("service should not be called when document_ids is missing")
|
||||
return nil, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("DELETE", "/api/v1/datasets/kb1/chunks", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeDataError) {
|
||||
t.Fatalf("expected data error, got %v: %s", resp["code"], w.Body.String())
|
||||
}
|
||||
if resp["message"] != "`document_ids` is required" {
|
||||
t.Fatalf("unexpected message: %v", resp["message"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkStopParsing_InvalidStateIncludesPythonErrorCode(t *testing.T) {
|
||||
r, mock := setupChunkStopParsingTest("user1")
|
||||
mock.stopParsingFn = func(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) {
|
||||
return &service.StopParsingResponse{
|
||||
Data: map[string]interface{}{"error_code": "DOC_STOP_PARSING_INVALID_STATE"},
|
||||
}, common.CodeDataError, errors.New("Can't stop parsing document that has not started or already completed")
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("DELETE", "/api/v1/datasets/kb1/chunks", strings.NewReader(`{"document_ids":["doc1"]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeDataError) {
|
||||
t.Fatalf("expected data error, got %v: %s", resp["code"], w.Body.String())
|
||||
}
|
||||
data, ok := resp["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data object, got %T", resp["data"])
|
||||
}
|
||||
if data["error_code"] != "DOC_STOP_PARSING_INVALID_STATE" {
|
||||
t.Fatalf("unexpected error_code: %v", data["error_code"])
|
||||
}
|
||||
if resp["message"] != "Can't stop parsing document that has not started or already completed" {
|
||||
t.Fatalf("unexpected message: %v", resp["message"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_WhitespaceQuestion(t *testing.T) {
|
||||
r, _ := setupChunkRetrievalTest("user1")
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@ type documentServiceIface interface {
|
||||
ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error)
|
||||
IngestDocuments(datasetID, userID string, docIDs []string) ([]*service.ParseDocumentResponse, error)
|
||||
StopIngestionTasks(tasks []string, userID string) ([]*entity.IngestionTask, error)
|
||||
Ingest(userID string, req *service.IngestDocumentRequest) (common.ErrorCode, error)
|
||||
RemoveIngestionTasks(tasks []string, userID string) ([]map[string]string, error)
|
||||
BatchUpdateDocumentStatus(userID, datasetID, status string, DocumentIDs []string) (map[string]interface{}, common.ErrorCode, error)
|
||||
}
|
||||
@@ -874,6 +875,37 @@ func (h *DocumentHandler) SetMeta(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) Ingest(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
jsonError(c, common.CodeAuthenticationError, "No Authentication")
|
||||
return
|
||||
}
|
||||
|
||||
var req service.IngestDocumentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if code, err := h.documentService.Ingest(userID, &req); err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
"data": true,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteMetaRequest represents the request for deleting document metadata
|
||||
type DeleteMetaRequest struct {
|
||||
DocID string `json:"doc_id" binding:"required"`
|
||||
|
||||
@@ -45,6 +45,19 @@ type fakeDocumentService struct {
|
||||
metadataErr error
|
||||
metadataKBID string
|
||||
metadataDocIDs []string
|
||||
ingestCode common.ErrorCode
|
||||
ingestErr error
|
||||
ingestUserID string
|
||||
ingestReq *service.IngestDocumentRequest
|
||||
}
|
||||
|
||||
func (f *fakeDocumentService) Ingest(userID string, req *service.IngestDocumentRequest) (common.ErrorCode, error) {
|
||||
f.ingestUserID = userID
|
||||
f.ingestReq = req
|
||||
if f.ingestCode != 0 || f.ingestErr != nil {
|
||||
return f.ingestCode, f.ingestErr
|
||||
}
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentService) UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error) {
|
||||
@@ -176,6 +189,21 @@ func setupGinContextWithUser(method, path, body string) (*gin.Context, *httptest
|
||||
return c, w
|
||||
}
|
||||
|
||||
func setupDocumentIngestRoute(userID string, svc *fakeDocumentService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &DocumentHandler{
|
||||
documentService: svc,
|
||||
datasetService: service.NewDatasetService(),
|
||||
}
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: userID})
|
||||
c.Set("user_id", userID)
|
||||
})
|
||||
r.POST("/api/v1/documents/ingest", h.Ingest)
|
||||
return r
|
||||
}
|
||||
|
||||
func TestDeleteDocumentsHandler_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -323,6 +351,116 @@ func TestDeleteDocumentsHandler_MissingDatasetID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocumentHandlerIngestMatchesPythonResponseShape(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
fake := &fakeDocumentService{}
|
||||
h := &DocumentHandler{
|
||||
documentService: fake,
|
||||
datasetService: service.NewDatasetService(),
|
||||
}
|
||||
|
||||
c, w := setupGinContextWithUser("POST", "/api/v1/documents/ingest", `{"doc_ids":["doc-1"],"run":"1"}`)
|
||||
h.Ingest(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected top-level code 0, got %v", resp["code"])
|
||||
}
|
||||
if resp["data"] != true {
|
||||
t.Fatalf("expected top-level data=true, got %#v", resp["data"])
|
||||
}
|
||||
if _, ok := resp["data"].(map[string]interface{}); ok {
|
||||
t.Fatalf("response must not nest code/message under data: %#v", resp["data"])
|
||||
}
|
||||
if fake.ingestUserID != "user-1" {
|
||||
t.Fatalf("expected user-1, got %q", fake.ingestUserID)
|
||||
}
|
||||
if fake.ingestReq == nil || len(fake.ingestReq.DocIDs) != 1 || fake.ingestReq.DocIDs[0] != "doc-1" {
|
||||
t.Fatalf("unexpected ingest request: %#v", fake.ingestReq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocumentIngestRoutePassesPythonBodyToService(t *testing.T) {
|
||||
fake := &fakeDocumentService{}
|
||||
r := setupDocumentIngestRoute("user-1", fake)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/v1/documents/ingest", strings.NewReader(`{"doc_ids":["doc-1","doc-2"],"run":1,"delete":true,"apply_kb":true}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) || resp["data"] != true {
|
||||
t.Fatalf("unexpected response: %s", w.Body.String())
|
||||
}
|
||||
if fake.ingestUserID != "user-1" {
|
||||
t.Fatalf("userID = %q, want user-1", fake.ingestUserID)
|
||||
}
|
||||
if fake.ingestReq == nil {
|
||||
t.Fatal("service did not receive ingest request")
|
||||
}
|
||||
if len(fake.ingestReq.DocIDs) != 2 || fake.ingestReq.DocIDs[0] != "doc-1" || fake.ingestReq.DocIDs[1] != "doc-2" {
|
||||
t.Fatalf("doc_ids = %#v, want [doc-1 doc-2]", fake.ingestReq.DocIDs)
|
||||
}
|
||||
if fmt.Sprint(fake.ingestReq.Run) != "1" {
|
||||
t.Fatalf("run = %#v, want 1", fake.ingestReq.Run)
|
||||
}
|
||||
if !fake.ingestReq.Delete {
|
||||
t.Fatal("delete = false, want true")
|
||||
}
|
||||
if !fake.ingestReq.ApplyKB {
|
||||
t.Fatal("apply_kb = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocumentHandlerIngestPropagatesServiceErrorCode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
fake := &fakeDocumentService{
|
||||
ingestCode: common.CodeAuthenticationError,
|
||||
ingestErr: fmt.Errorf("No authorization."),
|
||||
}
|
||||
h := &DocumentHandler{
|
||||
documentService: fake,
|
||||
datasetService: service.NewDatasetService(),
|
||||
}
|
||||
|
||||
c, w := setupGinContextWithUser("POST", "/api/v1/documents/ingest", `{"doc_ids":["doc-1"],"run":"1"}`)
|
||||
h.Ingest(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeAuthenticationError) {
|
||||
t.Fatalf("expected auth error code, got %v", resp["code"])
|
||||
}
|
||||
if resp["message"] != "No authorization." {
|
||||
t.Fatalf("unexpected message: %v", resp["message"])
|
||||
}
|
||||
if resp["data"] != nil {
|
||||
t.Fatalf("expected nil data, got %#v", resp["data"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopParseDocumentsHandler_EmptyDocIDs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -243,6 +243,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
documents.GET("/:id", r.documentHandler.GetDocumentByID)
|
||||
documents.PUT("/:id", r.documentHandler.UpdateDocument)
|
||||
documents.DELETE("/:id", r.documentHandler.DeleteDocument)
|
||||
documents.POST("/ingest", r.documentHandler.Ingest)
|
||||
}
|
||||
|
||||
// Chat routes
|
||||
@@ -320,6 +321,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
datasets.DELETE("/ingestion/tasks", r.documentHandler.RemoveIngestionTasks)
|
||||
//datasets.POST("/:dataset_id/documents/parse", r.documentHandler.ParseDocuments)
|
||||
//datasets.POST("/:dataset_id/documents/stop", r.documentHandler.StopParseDocuments)
|
||||
datasets.DELETE("/:dataset_id/chunks", r.chunkHandler.StopParsing)
|
||||
datasets.DELETE("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.RemoveChunks)
|
||||
datasets.PUT("/:dataset_id/documents/:document_id/metadata/config", r.datasetsHandler.UpdateDocumentMetadataConfig)
|
||||
datasets.POST("/:dataset_id/metadata/update", r.documentHandler.MetadataBatchUpdate)
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"math/rand"
|
||||
"path/filepath"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine/redis"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/entity/models"
|
||||
"ragflow/internal/server"
|
||||
@@ -42,7 +43,6 @@ import (
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/engine/redis"
|
||||
"ragflow/internal/engine/types"
|
||||
"ragflow/internal/service"
|
||||
"ragflow/internal/service/nlp"
|
||||
@@ -64,6 +64,7 @@ type ChunkService struct {
|
||||
kbDAO *dao.KnowledgebaseDAO
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
documentDAO *dao.DocumentDAO
|
||||
taskDAO *dao.TaskDAO
|
||||
searchService *service.SearchService
|
||||
|
||||
accessibleFunc func(string, string) bool
|
||||
@@ -85,6 +86,7 @@ func NewChunkService() *ChunkService {
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
documentDAO: dao.NewDocumentDAO(),
|
||||
taskDAO: dao.NewTaskDAO(),
|
||||
searchService: service.NewSearchService(),
|
||||
}
|
||||
}
|
||||
@@ -571,6 +573,112 @@ func (s *ChunkService) Get(req *service.GetChunkRequest, userID string) (*servic
|
||||
return &service.GetChunkResponse{Chunk: chunk}, nil
|
||||
}
|
||||
|
||||
const (
|
||||
docStopParsingInvalidStateMessage = "Can't stop parsing document that has not started or already completed"
|
||||
docStopParsingInvalidStateErrorCode = "DOC_STOP_PARSING_INVALID_STATE"
|
||||
)
|
||||
|
||||
func (s *ChunkService) cancelAllTasksOfDoc(docID string) error {
|
||||
tasks, err := s.taskDAO.GetByDocID(docID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tasks for document %s: %w", docID, err)
|
||||
}
|
||||
|
||||
redisClient := redis.Get()
|
||||
if redisClient == nil {
|
||||
common.Warn(fmt.Sprintf("Redis unavailable; cannot cancel tasks for document %s", docID))
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, task := range tasks {
|
||||
if task == nil {
|
||||
continue
|
||||
}
|
||||
redisClient.Set(fmt.Sprintf("%s-cancel", task.ID), "x", 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ChunkService) StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) {
|
||||
if !s.kbDAO.Accessible(datasetID, userID) {
|
||||
return nil, common.CodeDataError, fmt.Errorf("You don't own the dataset %s", datasetID)
|
||||
}
|
||||
|
||||
if len(req.DocumentIDs) == 0 {
|
||||
return nil, common.CodeDataError, fmt.Errorf("`document_ids` is required")
|
||||
}
|
||||
|
||||
kb, err := s.kbDAO.GetByID(datasetID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, fmt.Errorf("You don't own the dataset %s", datasetID)
|
||||
}
|
||||
|
||||
docIDs, duplicateMessages := service.CheckDuplicateIDs(req.DocumentIDs, "document")
|
||||
successCount := 0
|
||||
ctx := context.Background()
|
||||
indexName := service.IndexName(kb.TenantID)
|
||||
|
||||
for _, docID := range docIDs {
|
||||
doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(docID, datasetID)
|
||||
if err != nil || doc == nil {
|
||||
return nil, common.CodeDataError, fmt.Errorf("You don't own the document %s", docID)
|
||||
}
|
||||
|
||||
if doc.Run == nil || *doc.Run != string(entity.TaskStatusRunning) {
|
||||
return &service.StopParsingResponse{
|
||||
Data: map[string]interface{}{"error_code": docStopParsingInvalidStateErrorCode},
|
||||
}, common.CodeDataError, fmt.Errorf("%s", docStopParsingInvalidStateMessage)
|
||||
}
|
||||
|
||||
if err := s.cancelAllTasksOfDoc(docID); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"run": string(entity.TaskStatusCancel),
|
||||
"progress": 0,
|
||||
"chunk_num": 0,
|
||||
}
|
||||
if err := s.documentDAO.UpdateByID(doc.ID, updates); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to update document %s: %w", doc.ID, err)
|
||||
}
|
||||
|
||||
if s.docEngine != nil {
|
||||
exists, err := s.docEngine.ChunkStoreExists(ctx, indexName, datasetID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to check chunk store %s/%s: %w", indexName, datasetID, err)
|
||||
}
|
||||
if exists {
|
||||
if _, err := s.docEngine.DeleteChunks(ctx, map[string]interface{}{"doc_id": doc.ID}, indexName, datasetID); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to delete chunks for document %s: %w", doc.ID, err)
|
||||
}
|
||||
} else {
|
||||
common.Info(fmt.Sprintf("Skipping chunk delete during stop_parsing for doc %s: index %s/%s does not exist", doc.ID, indexName, datasetID))
|
||||
}
|
||||
} else {
|
||||
common.Info(fmt.Sprintf("Skipping chunk delete during stop_parsing for doc %s: index %s/%s does not exist", doc.ID, indexName, datasetID))
|
||||
}
|
||||
|
||||
successCount++
|
||||
}
|
||||
|
||||
if len(duplicateMessages) > 0 {
|
||||
if successCount > 0 {
|
||||
return &service.StopParsingResponse{
|
||||
Message: fmt.Sprintf("Partially stopped %d documents with %d errors", successCount, len(duplicateMessages)),
|
||||
Data: map[string]interface{}{
|
||||
"success_count": successCount,
|
||||
"errors": duplicateMessages,
|
||||
},
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
return nil, common.CodeDataError, fmt.Errorf("%s", strings.Join(duplicateMessages, ";"))
|
||||
}
|
||||
|
||||
return nil, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func checkDuplicateIDs(documentIDs []string, idTypes string) ([]string, []string) {
|
||||
idCount := make(map[string]int, len(documentIDs))
|
||||
duplicateMessages := make([]string, 0)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine/redis"
|
||||
"ragflow/internal/server"
|
||||
"strings"
|
||||
|
||||
@@ -31,12 +32,22 @@ import (
|
||||
"ragflow/internal/utility"
|
||||
)
|
||||
|
||||
var (
|
||||
UNSTART = "0"
|
||||
RUNNING = "1"
|
||||
CANCEL = "2"
|
||||
DONE = "3"
|
||||
FAIL = "4"
|
||||
SCHEDULE = "5"
|
||||
)
|
||||
|
||||
// ChunkService chunk service
|
||||
type ChunkService struct {
|
||||
docEngine engine.DocEngine
|
||||
engineType server.EngineType
|
||||
embeddingCache *utility.EmbeddingLRU
|
||||
kbDAO *dao.KnowledgebaseDAO
|
||||
taskDAO *dao.TaskDAO
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
documentDAO *dao.DocumentDAO
|
||||
searchService *SearchService
|
||||
@@ -162,6 +173,131 @@ func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkRespon
|
||||
return &GetChunkResponse{Chunk: chunk}, nil
|
||||
}
|
||||
|
||||
type StopParsingRequest struct {
|
||||
DocumentIDs []string `json:"document_ids"`
|
||||
}
|
||||
|
||||
type StopParsingResponse struct {
|
||||
Data map[string]interface{}
|
||||
Message string
|
||||
}
|
||||
|
||||
func (s *ChunkService) cancelAllTasksOfDoc(docID string) error {
|
||||
tasks, err := s.taskDAO.GetByDocID(docID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tasks for document %s: %w", docID, err)
|
||||
}
|
||||
|
||||
redisClient := redis.Get()
|
||||
if redisClient == nil {
|
||||
common.Logger.Warn(fmt.Sprintf("Redis unavailable; cannot cancel tasks for document %s", docID))
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, task := range tasks {
|
||||
if task == nil {
|
||||
continue
|
||||
}
|
||||
redisClient.Set(fmt.Sprintf("%s-cancel", task.ID), "x", 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckDuplicateIDs(docList []string, idType string) ([]string, []string) {
|
||||
uniqueDocIDs := make([]string, 0)
|
||||
duplicateMessages := make([]string, 0)
|
||||
idCount := make(map[string]int)
|
||||
|
||||
for _, docID := range docList {
|
||||
idCount[docID] += 1
|
||||
}
|
||||
|
||||
for id, count := range idCount {
|
||||
if count > 1 {
|
||||
duplicateMessages = append(duplicateMessages, fmt.Sprintf("Duplicate %s ids: %s", idType, id))
|
||||
}
|
||||
uniqueDocIDs = append(uniqueDocIDs, id)
|
||||
}
|
||||
return uniqueDocIDs, duplicateMessages
|
||||
}
|
||||
|
||||
func IndexName(uid string) string {
|
||||
return fmt.Sprintf("ragflow_%s", uid)
|
||||
}
|
||||
|
||||
func (s *ChunkService) StopParsing(userID, datasetID string, req StopParsingRequest) (map[string]interface{}, common.ErrorCode, error) {
|
||||
if !s.kbDAO.Accessible(datasetID, userID) {
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("You don't own the dataset %s", datasetID)
|
||||
}
|
||||
|
||||
if req.DocumentIDs == nil || len(req.DocumentIDs) == 0 {
|
||||
return nil, common.CodeDataError, fmt.Errorf("document_ids is required")
|
||||
}
|
||||
|
||||
docList, duplicateMessages := CheckDuplicateIDs(req.DocumentIDs, "document")
|
||||
|
||||
kb, err := s.kbDAO.GetByID(datasetID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, fmt.Errorf("You don't own the dataset %s", datasetID)
|
||||
}
|
||||
|
||||
successCount := 0
|
||||
for _, id := range docList {
|
||||
doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(id, datasetID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, fmt.Errorf("You don't own the document %s", id)
|
||||
}
|
||||
if doc == nil {
|
||||
return nil, common.CodeDataError, fmt.Errorf("You don't own the document %s", id)
|
||||
}
|
||||
|
||||
if doc.Run == nil || *doc.Run != RUNNING {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Can't stop parsing document that has not started or already completed")
|
||||
}
|
||||
|
||||
err = s.cancelAllTasksOfDoc(id)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
info := map[string]interface{}{
|
||||
"run": "2",
|
||||
"progress": 0,
|
||||
"chunk_num": 0,
|
||||
}
|
||||
if err := s.documentDAO.UpdateByID(doc.ID, info); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to update document %s: %w", doc.ID, err)
|
||||
}
|
||||
|
||||
indexName := IndexName(kb.TenantID)
|
||||
|
||||
exists, err := s.docEngine.ChunkStoreExists(context.Background(), indexName, datasetID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to check chunk store %s/%s: %w", indexName, datasetID, err)
|
||||
}
|
||||
if exists {
|
||||
if _, err := s.docEngine.DeleteChunks(context.Background(), map[string]interface{}{"doc_id": doc.ID}, indexName, datasetID); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to delete chunks for document %s: %w", doc.ID, err)
|
||||
}
|
||||
} else {
|
||||
common.Logger.Info(fmt.Sprintf("Skipping chunk delete during stop_parsing for doc %s: index %s/%s does not exist", doc.ID, indexName, datasetID))
|
||||
}
|
||||
successCount++
|
||||
}
|
||||
if len(duplicateMessages) > 0 {
|
||||
if successCount > 0 {
|
||||
return map[string]interface{}{
|
||||
"success_count": successCount,
|
||||
"errors": duplicateMessages,
|
||||
"message": fmt.Sprintf("Partially stopped %d documents with %d errors", successCount, len(duplicateMessages)),
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
return nil, common.CodeDataError, fmt.Errorf("%s", strings.Join(duplicateMessages, ";"))
|
||||
}
|
||||
return nil, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// ListChunksRequest request for listing chunks
|
||||
type ListChunksRequest struct {
|
||||
DatasetID string `json:"dataset_id,omitempty"`
|
||||
|
||||
@@ -17,11 +17,17 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -43,6 +49,7 @@ import (
|
||||
"ragflow/internal/tokenizer"
|
||||
"ragflow/internal/utility"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -987,6 +994,584 @@ func (s *DocumentService) RemoveIngestionTasks(tasks []string, userID string) ([
|
||||
return deletedTasks, nil
|
||||
}
|
||||
|
||||
type IngestDocumentRequest struct {
|
||||
DocIDs []string `json:"doc_ids" binding:"required"`
|
||||
Run interface{} `json:"run" binding:"required"`
|
||||
Delete bool `json:"delete"`
|
||||
ApplyKB bool `json:"apply_kb"`
|
||||
}
|
||||
|
||||
type documentParsePageRange struct {
|
||||
from int64
|
||||
to int64
|
||||
}
|
||||
|
||||
func (s *DocumentService) Ingest(userID string, req *IngestDocumentRequest) (common.ErrorCode, error) {
|
||||
run := fmt.Sprint(req.Run)
|
||||
|
||||
docs, err := s.documentDAO.GetByIDs(req.DocIDs)
|
||||
if err != nil {
|
||||
return common.CodeExceptionError, fmt.Errorf("fail to get documents: %s", err.Error())
|
||||
}
|
||||
|
||||
docsByID := make(map[string]*entity.Document, len(docs))
|
||||
for _, doc := range docs {
|
||||
if doc != nil {
|
||||
docsByID[doc.ID] = doc
|
||||
}
|
||||
}
|
||||
|
||||
tableDoneCountByKB := make(map[string]int64)
|
||||
|
||||
for _, docID := range req.DocIDs {
|
||||
doc := docsByID[docID]
|
||||
if doc == nil {
|
||||
return common.CodeDataError, fmt.Errorf("Document not found!")
|
||||
}
|
||||
|
||||
kb, err := s.kbDAO.GetByID(doc.KbID)
|
||||
if err != nil {
|
||||
return common.CodeDataError, fmt.Errorf("Tenant not found!")
|
||||
}
|
||||
|
||||
if !s.kbDAO.Accessible(kb.ID, userID) {
|
||||
return common.CodeAuthenticationError, fmt.Errorf("No authorization.")
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"run": run,
|
||||
"progress": 0,
|
||||
}
|
||||
|
||||
rerunWithDelete := run == string(entity.TaskStatusRunning) && req.Delete
|
||||
if rerunWithDelete {
|
||||
updates["progress_msg"] = ""
|
||||
updates["chunk_num"] = 0
|
||||
updates["token_num"] = 0
|
||||
}
|
||||
|
||||
if run == string(entity.TaskStatusCancel) {
|
||||
if err := s.cancelDocParse(doc); err != nil {
|
||||
return common.CodeDataError, err
|
||||
}
|
||||
}
|
||||
|
||||
if rerunWithDelete && doc.Run != nil && *doc.Run == string(entity.TaskStatusDone) {
|
||||
if err := s.clearKBChunkNumWhenRerun(doc); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.documentDAO.UpdateByID(doc.ID, updates); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
|
||||
if req.Delete {
|
||||
_, _ = s.taskDAO.DeleteByDocIDs([]string{doc.ID})
|
||||
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
||||
if s.docEngine != nil {
|
||||
exists, err := s.docEngine.ChunkStoreExists(context.Background(), indexName, doc.KbID)
|
||||
if err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
if exists {
|
||||
if _, err := s.docEngine.DeleteChunks(context.Background(), map[string]interface{}{"doc_id": doc.ID}, indexName, doc.KbID); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if run == string(entity.TaskStatusRunning) {
|
||||
if req.ApplyKB {
|
||||
if doc.ParserConfig == nil {
|
||||
doc.ParserConfig = entity.JSONMap{}
|
||||
}
|
||||
config := map[string]interface{}{
|
||||
"llm_id": kb.ParserConfig["llm_id"],
|
||||
"enable_metadata": false,
|
||||
"metadata": map[string]interface{}{},
|
||||
}
|
||||
if value, ok := kb.ParserConfig["enable_metadata"]; ok {
|
||||
config["enable_metadata"] = value
|
||||
}
|
||||
if value, ok := kb.ParserConfig["metadata"]; ok {
|
||||
config["metadata"] = value
|
||||
}
|
||||
if err := s.updateDocumentParserConfig(doc.ID, config); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
for key, value := range config {
|
||||
doc.ParserConfig[key] = value
|
||||
}
|
||||
}
|
||||
if doc.PipelineID != nil && strings.TrimSpace(*doc.PipelineID) != "" {
|
||||
if err := s.queueDocumentDataflowTask(kb, doc, strings.TrimSpace(*doc.PipelineID), 0); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if doc.ParserID == string(entity.ParserTypeTable) {
|
||||
doneCount, ok := tableDoneCountByKB[doc.KbID]
|
||||
if !ok {
|
||||
count, err := s.countDoneDocuments(doc.KbID)
|
||||
if err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
doneCount = count
|
||||
tableDoneCountByKB[doc.KbID] = doneCount
|
||||
if doneCount <= 0 {
|
||||
if err := s.kbDAO.DeleteFieldMap(doc.KbID); err != nil && !dao.IsNotFoundErr(err) {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
bucket, objectName, err := s.GetDocumentStorageAddress(doc)
|
||||
if err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
if err := s.queueDocumentParseTasks(doc, bucket, objectName, 0); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
if err := s.beginDocumentParse(doc.ID); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) countDoneDocuments(datasetID string) (int64, error) {
|
||||
var count int64
|
||||
err := dao.GetDB().Model(&entity.Document{}).
|
||||
Where("kb_id = ? AND run = ?", datasetID, string(entity.TaskStatusDone)).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *DocumentService) queueDocumentParseTasks(doc *entity.Document, bucket, objectName string, priority int64) error {
|
||||
if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil {
|
||||
return err
|
||||
}
|
||||
tasks, err := s.newDocumentParseTasks(doc, bucket, objectName, priority)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.taskDAO.CreateMany(tasks); err != nil {
|
||||
return err
|
||||
}
|
||||
queueName := documentParseQueueName(doc, priority)
|
||||
for _, task := range tasks {
|
||||
if task.Progress >= 1 {
|
||||
continue
|
||||
}
|
||||
if redisClient := redis.Get(); redisClient == nil || !redisClient.QueueProduct(queueName, documentTaskMessage(task)) {
|
||||
return fmt.Errorf("Can't access Redis. Please check the Redis' status.")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) queueDocumentDataflowTask(kb *entity.Knowledgebase, doc *entity.Document, flowID string, priority int64) error {
|
||||
if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.beginDocumentParse(doc.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
task := s.newDocumentParseTask(doc, 0, maximumTaskPageNumber, priority)
|
||||
task.TaskType = "dataflow"
|
||||
if err := s.taskDAO.CreateMany([]*entity.Task{task}); err != nil {
|
||||
return err
|
||||
}
|
||||
message := documentTaskMessage(task)
|
||||
message["task_type"] = task.TaskType
|
||||
message["kb_id"] = doc.KbID
|
||||
message["tenant_id"] = kb.TenantID
|
||||
message["dataflow_id"] = flowID
|
||||
message["file"] = nil
|
||||
if redisClient := redis.Get(); redisClient == nil || !redisClient.QueueProduct(documentParseQueueName(doc, priority), message) {
|
||||
return fmt.Errorf("Can't access Redis. Please check the Redis' status.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) newDocumentParseTasks(doc *entity.Document, bucket, objectName string, priority int64) ([]*entity.Task, error) {
|
||||
ranges, err := documentParseTaskRanges(doc, bucket, objectName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tasks := make([]*entity.Task, 0, len(ranges))
|
||||
for _, pageRange := range ranges {
|
||||
tasks = append(tasks, s.newDocumentParseTask(doc, pageRange.from, pageRange.to, priority))
|
||||
}
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) newDocumentParseTask(doc *entity.Document, fromPage, toPage, priority int64) *entity.Task {
|
||||
now := time.Now()
|
||||
progressMsg := ""
|
||||
digest := documentParseTaskDigest(doc, fromPage, toPage)
|
||||
chunkIDs := ""
|
||||
return &entity.Task{
|
||||
ID: common.GenerateUUID(),
|
||||
DocID: doc.ID,
|
||||
FromPage: fromPage,
|
||||
ToPage: toPage,
|
||||
TaskType: "",
|
||||
Priority: priority,
|
||||
BeginAt: &now,
|
||||
Progress: 0,
|
||||
ProgressMsg: &progressMsg,
|
||||
Digest: &digest,
|
||||
ChunkIDs: &chunkIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func documentParseTaskRanges(doc *entity.Document, bucket, objectName string) ([]documentParsePageRange, error) {
|
||||
if doc.Type == "pdf" {
|
||||
binary, err := documentStorageBinary(bucket, objectName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pages := documentEstimatePDFPageCount(binary)
|
||||
pageSize := int64(documentParserConfigInt(doc.ParserConfig, "task_page_size", 12))
|
||||
if doc.ParserID == string(entity.ParserTypePaper) {
|
||||
pageSize = int64(documentParserConfigInt(doc.ParserConfig, "task_page_size", 22))
|
||||
}
|
||||
if doc.ParserID == string(entity.ParserTypeOne) ||
|
||||
doc.ParserID == string(entity.ParserTypeKG) ||
|
||||
documentParserConfigBool(doc.ParserConfig, "toc_extraction", false) {
|
||||
pageSize = maximumTaskPageNumber
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 12
|
||||
}
|
||||
ranges := make([]documentParsePageRange, 0)
|
||||
for _, configuredRange := range documentParserConfigPageRanges(doc.ParserConfig) {
|
||||
start := configuredRange.from - 1
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
end := configuredRange.to - 1
|
||||
if pages >= 0 && end > pages {
|
||||
end = pages
|
||||
}
|
||||
for page := start; page < end; page += pageSize {
|
||||
to := page + pageSize
|
||||
if to > end {
|
||||
to = end
|
||||
}
|
||||
ranges = append(ranges, documentParsePageRange{from: page, to: to})
|
||||
}
|
||||
}
|
||||
if len(ranges) == 0 {
|
||||
ranges = append(ranges, documentParsePageRange{from: 0, to: maximumTaskPageNumber})
|
||||
}
|
||||
return ranges, nil
|
||||
}
|
||||
if doc.ParserID == string(entity.ParserTypeTable) {
|
||||
binary, err := documentStorageBinary(bucket, objectName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows := documentEstimateTableRowCount(documentName(doc), binary)
|
||||
if rows <= 0 {
|
||||
return []documentParsePageRange{{from: 0, to: maximumTaskPageNumber}}, nil
|
||||
}
|
||||
ranges := make([]documentParsePageRange, 0, (rows+2999)/3000)
|
||||
for row := int64(0); row < int64(rows); row += 3000 {
|
||||
to := row + 3000
|
||||
if to > int64(rows) {
|
||||
to = int64(rows)
|
||||
}
|
||||
ranges = append(ranges, documentParsePageRange{from: row, to: to})
|
||||
}
|
||||
return ranges, nil
|
||||
}
|
||||
return []documentParsePageRange{{from: 0, to: maximumTaskPageNumber}}, nil
|
||||
}
|
||||
|
||||
func documentStorageBinary(bucket, objectName string) ([]byte, error) {
|
||||
storageImpl := storage.GetStorageFactory().GetStorage()
|
||||
if storageImpl == nil {
|
||||
return nil, fmt.Errorf("storage not initialized")
|
||||
}
|
||||
return storageImpl.Get(bucket, objectName)
|
||||
}
|
||||
|
||||
func documentName(doc *entity.Document) string {
|
||||
if doc == nil || doc.Name == nil {
|
||||
return ""
|
||||
}
|
||||
return *doc.Name
|
||||
}
|
||||
|
||||
func documentParserConfigInt(config map[string]interface{}, key string, fallback int) int {
|
||||
value, ok := config[key]
|
||||
if !ok || value == nil {
|
||||
return fallback
|
||||
}
|
||||
switch typedValue := value.(type) {
|
||||
case int:
|
||||
return typedValue
|
||||
case int64:
|
||||
return int(typedValue)
|
||||
case float64:
|
||||
return int(typedValue)
|
||||
case json.Number:
|
||||
if intValue, err := typedValue.Int64(); err == nil {
|
||||
return int(intValue)
|
||||
}
|
||||
case string:
|
||||
if intValue, err := strconv.Atoi(strings.TrimSpace(typedValue)); err == nil {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func documentParserConfigBool(config map[string]interface{}, key string, fallback bool) bool {
|
||||
value, ok := config[key]
|
||||
if !ok || value == nil {
|
||||
return fallback
|
||||
}
|
||||
switch typedValue := value.(type) {
|
||||
case bool:
|
||||
return typedValue
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(typedValue)) {
|
||||
case "true", "1", "yes", "on":
|
||||
return true
|
||||
case "false", "0", "no", "off":
|
||||
return false
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func documentParserConfigPageRanges(config map[string]interface{}) []documentParsePageRange {
|
||||
defaultRanges := []documentParsePageRange{{from: 1, to: 100000}}
|
||||
raw, ok := config["pages"]
|
||||
if !ok || raw == nil {
|
||||
return defaultRanges
|
||||
}
|
||||
rawRanges, ok := raw.([]interface{})
|
||||
if !ok || len(rawRanges) == 0 {
|
||||
return defaultRanges
|
||||
}
|
||||
ranges := make([]documentParsePageRange, 0, len(rawRanges))
|
||||
for _, rawRange := range rawRanges {
|
||||
rangeValues, ok := rawRange.([]interface{})
|
||||
if !ok || len(rangeValues) < 2 {
|
||||
continue
|
||||
}
|
||||
from, okFrom := documentToInt64(rangeValues[0])
|
||||
to, okTo := documentToInt64(rangeValues[1])
|
||||
if okFrom && okTo && to > from {
|
||||
ranges = append(ranges, documentParsePageRange{from: from, to: to})
|
||||
}
|
||||
}
|
||||
if len(ranges) == 0 {
|
||||
return defaultRanges
|
||||
}
|
||||
return ranges
|
||||
}
|
||||
|
||||
func documentToInt64(value interface{}) (int64, bool) {
|
||||
switch typedValue := value.(type) {
|
||||
case int:
|
||||
return int64(typedValue), true
|
||||
case int64:
|
||||
return typedValue, true
|
||||
case float64:
|
||||
return int64(typedValue), true
|
||||
case json.Number:
|
||||
intValue, err := typedValue.Int64()
|
||||
return intValue, err == nil
|
||||
case string:
|
||||
intValue, err := strconv.ParseInt(strings.TrimSpace(typedValue), 10, 64)
|
||||
return intValue, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
var documentPDFPagePattern = regexp.MustCompile(`/Type\s*/Page\b`)
|
||||
|
||||
func documentEstimatePDFPageCount(binary []byte) int64 {
|
||||
if len(binary) == 0 {
|
||||
return 0
|
||||
}
|
||||
return int64(len(documentPDFPagePattern.FindAll(binary, -1)))
|
||||
}
|
||||
|
||||
func documentEstimateTableRowCount(name string, binary []byte) int {
|
||||
switch strings.ToLower(filepath.Ext(name)) {
|
||||
case ".xlsx":
|
||||
if rows, err := documentCountXLSXRows(binary); err == nil {
|
||||
return rows
|
||||
}
|
||||
case ".csv", ".tsv", ".txt":
|
||||
return documentCountDelimitedRows(name, binary)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func documentCountDelimitedRows(name string, binary []byte) int {
|
||||
reader := csv.NewReader(bytes.NewReader(binary))
|
||||
reader.FieldsPerRecord = -1
|
||||
reader.ReuseRecord = true
|
||||
if strings.EqualFold(filepath.Ext(name), ".tsv") {
|
||||
reader.Comma = '\t'
|
||||
}
|
||||
rows := 0
|
||||
for {
|
||||
_, err := reader.Read()
|
||||
if err == nil {
|
||||
rows++
|
||||
continue
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
rows += bytes.Count(binary, []byte{'\n'})
|
||||
if len(binary) > 0 && binary[len(binary)-1] != '\n' {
|
||||
rows++
|
||||
}
|
||||
break
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func documentCountXLSXRows(binary []byte) (int, error) {
|
||||
zipReader, err := zip.NewReader(bytes.NewReader(binary), int64(len(binary)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
maxRows := 0
|
||||
for _, file := range zipReader.File {
|
||||
if !strings.HasPrefix(file.Name, "xl/worksheets/") || !strings.HasSuffix(file.Name, ".xml") {
|
||||
continue
|
||||
}
|
||||
rows, err := documentCountWorksheetRows(file)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if rows > maxRows {
|
||||
maxRows = rows
|
||||
}
|
||||
}
|
||||
return maxRows, nil
|
||||
}
|
||||
|
||||
func documentCountWorksheetRows(file *zip.File) (int, error) {
|
||||
reader, err := file.Open()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer reader.Close()
|
||||
decoder := xml.NewDecoder(reader)
|
||||
rows := 0
|
||||
for {
|
||||
token, err := decoder.Token()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
start, ok := token.(xml.StartElement)
|
||||
if ok && start.Name.Local == "row" {
|
||||
rows++
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) beginDocumentParse(docID string) error {
|
||||
now := time.Now()
|
||||
return dao.GetDB().Model(&entity.Document{}).Where("id = ?", docID).Updates(map[string]interface{}{
|
||||
"progress_msg": "Task is queued...",
|
||||
"process_begin_at": now,
|
||||
"progress": rand.Float64() * 0.01,
|
||||
"run": string(entity.TaskStatusRunning),
|
||||
"chunk_num": 0,
|
||||
"token_num": 0,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func documentParseQueueName(doc *entity.Document, priority int64) string {
|
||||
suffix := "common"
|
||||
if doc.ParserID == string(entity.ParserTypeResume) {
|
||||
suffix = "resume"
|
||||
}
|
||||
return fmt.Sprintf("te.%d.%s", priority, suffix)
|
||||
}
|
||||
|
||||
func documentTaskMessage(task *entity.Task) map[string]interface{} {
|
||||
beginAt := ""
|
||||
if task.BeginAt != nil {
|
||||
beginAt = task.BeginAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
digest := ""
|
||||
if task.Digest != nil {
|
||||
digest = *task.Digest
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"id": task.ID,
|
||||
"doc_id": task.DocID,
|
||||
"from_page": task.FromPage,
|
||||
"to_page": task.ToPage,
|
||||
"progress": task.Progress,
|
||||
"priority": task.Priority,
|
||||
"begin_at": beginAt,
|
||||
"digest": digest,
|
||||
}
|
||||
}
|
||||
|
||||
func documentParseTaskDigest(doc *entity.Document, fromPage, toPage int64) string {
|
||||
hasher := xxhash.New()
|
||||
config := map[string]interface{}{
|
||||
"doc_id": doc.ID,
|
||||
"kb_id": doc.KbID,
|
||||
"parser_id": doc.ParserID,
|
||||
"parser_config": doc.ParserConfig,
|
||||
}
|
||||
keys := make([]string, 0, len(config))
|
||||
for key := range config {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
b, err := json.Marshal(config[key])
|
||||
if err != nil {
|
||||
hasher.WriteString(fmt.Sprint(config[key]))
|
||||
} else {
|
||||
hasher.Write(b)
|
||||
}
|
||||
}
|
||||
hasher.WriteString(doc.ID)
|
||||
hasher.WriteString(strconv.FormatInt(fromPage, 10))
|
||||
hasher.WriteString(strconv.FormatInt(toPage, 10))
|
||||
return fmt.Sprintf("%x", hasher.Sum64())
|
||||
}
|
||||
|
||||
func (s *DocumentService) clearKBChunkNumWhenRerun(doc *entity.Document) error {
|
||||
if doc == nil {
|
||||
return fmt.Errorf("document is nil")
|
||||
}
|
||||
return dao.GetDB().Model(&entity.Knowledgebase{}).Where("id = ?", doc.KbID).Updates(map[string]interface{}{
|
||||
"token_num": gorm.Expr("token_num - ?", doc.TokenNum),
|
||||
"chunk_num": gorm.Expr("chunk_num - ?", doc.ChunkNum),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (s *DocumentService) ParseDocuments(datasetID, userID string, docIDs []string) ([]*ParseDocumentResponse, error) {
|
||||
// create document parse id
|
||||
// save to task table
|
||||
|
||||
Reference in New Issue
Block a user