From c2665d4ab146f55e745efddf895f4af2da2dfff8 Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Wed, 24 Jun 2026 19:09:43 +0800 Subject: [PATCH] implement: /embedding/check POST (#16266) --- internal/handler/datasets.go | 72 +++ internal/router/router.go | 2 + internal/service/dataset.go | 1040 +++++++++++++++++++++++++++++++++- 3 files changed, 1111 insertions(+), 3 deletions(-) diff --git a/internal/handler/datasets.go b/internal/handler/datasets.go index 3982de068a..95a0447a49 100644 --- a/internal/handler/datasets.go +++ b/internal/handler/datasets.go @@ -664,6 +664,78 @@ func (h *DatasetsHandler) RemoveTags(c *gin.Context) { jsonResponse(c, common.CodeSuccess, true, "success") } +// RunEmbedding Run embedding for all documents in a dataset. +func (h *DatasetsHandler) RunEmbedding(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeAuthenticationError, "user_id is required") + return + } + + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { + jsonError(c, common.CodeDataError, "dataset_id is required") + return + } + + result, errorCode, err := h.datasetsService.RunEmbedding(userID, datasetID) + if err != nil { + jsonError(c, errorCode, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// CheckEmbedding Check embedding model compatibility by sampling random chunks, +// re-embedding them with the new model, and computing cosine similarity. +func (h *DatasetsHandler) CheckEmbedding(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { + jsonError(c, common.CodeDataError, "dataset_id is required") + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeDataError, "user_id is required") + return + } + + var req service.CheckEmbeddingRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + if strings.TrimSpace(req.EmbeddingID) == "" { + jsonError(c, common.CodeDataError, "`embd_id` is required.") + return + } + + data, code, err := h.datasetsService.CheckEmbedding(userID, datasetID, &req) + if err != nil { + if code == common.CodeNotEffective { + jsonResponse(c, code, data, err.Error()) + return + } + jsonError(c, code, err.Error()) + return + } + jsonResponse(c, common.CodeSuccess, data, "success") +} + // AggregateTags handles GET /api/v1/datasets/tags/aggregation. // @Summary Aggregate dataset tags // @Description Aggregate tags across multiple datasets diff --git a/internal/router/router.go b/internal/router/router.go index ed51bc090c..730f08e370 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -279,6 +279,8 @@ func (r *Router) Setup(engine *gin.Engine) { datasets.GET("/:dataset_id/tags", r.datasetsHandler.ListTags) datasets.PUT("/:dataset_id/tags", r.datasetsHandler.RenameTag) datasets.DELETE("/:dataset_id/tags", r.datasetsHandler.RemoveTags) + datasets.POST("/:dataset_id/embedding", r.datasetsHandler.RunEmbedding) + datasets.POST("/:dataset_id/embedding/check", r.datasetsHandler.CheckEmbedding) datasets.POST("/:dataset_id/documents/batch-update-status", r.documentHandler.BatchUpdateDocumentStatus) datasets.GET("/:dataset_id/index", r.datasetsHandler.TraceIndex) datasets.POST("/:dataset_id/index", r.datasetsHandler.RunIndex) diff --git a/internal/service/dataset.go b/internal/service/dataset.go index 20536048e0..736f77ca24 100644 --- a/internal/service/dataset.go +++ b/internal/service/dataset.go @@ -17,21 +17,33 @@ package service import ( + "archive/zip" + "bytes" "context" + "encoding/csv" "encoding/json" + "encoding/xml" "errors" "fmt" + "io" + "math" + "math/rand" + "path/filepath" "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/engine" redisengine "ragflow/internal/engine/redis" "ragflow/internal/engine/types" + enginetypes "ragflow/internal/engine/types" "ragflow/internal/entity" "ragflow/internal/entity/models" "ragflow/internal/server" "ragflow/internal/service/nlp" + "ragflow/internal/storage" "ragflow/internal/utility" + "regexp" "sort" + "strconv" "strings" "time" @@ -80,9 +92,11 @@ var ( const ( // Keep the legacy worker marker in queue payloads; persisted tasks use a real document ID. - graphRaptorQueueDocID = "graph_raptor_x" - maximumTaskPageNumber = int64(100000000) - serverQueueNamePrefix = "te" + graphRaptorQueueDocID = "graph_raptor_x" + maximumPageNumber = int64(100000) + maximumTaskPageNumber = int64(100000000) + serverQueueNamePrefix = "te" + defaultEmbeddingCheckNum = 5 graphPhaseResolutionDone = "resolution_done" graphPhaseCommunityDone = "community_done" @@ -565,6 +579,1026 @@ func (s *DatasetService) TraceIndex(datasetID, userID, indexType string) (*entit return task, common.CodeSuccess, nil } +type CheckEmbeddingRequest struct { + EmbeddingID string `json:"embd_id" binding:"required"` + CheckNum *int `json:"check_num,omitempty"` +} + +type EmbeddingCheckSummary struct { + KbID string `json:"kb_id"` + Model string `json:"model"` + Sampled int `json:"sampled"` + Valid int `json:"valid"` + AvgCosSim float64 `json:"avg_cos_sim"` + MinCosSim float64 `json:"min_cos_sim"` + MaxCosSim float64 `json:"max_cos_sim"` + MatchMode string `json:"match_mode"` +} + +type EmbeddingCheckResult struct { + ChunkID string `json:"chunk_id"` + DocID string `json:"doc_id,omitempty"` + DocName string `json:"doc_name,omitempty"` + VectorField string `json:"vector_field,omitempty"` + VectorDim int `json:"vector_dim,omitempty"` + CosSim float64 `json:"cos_sim,omitempty"` + Reason string `json:"reason,omitempty"` +} + +type EmbeddingCheckResponse struct { + Summary EmbeddingCheckSummary `json:"summary"` + Results []EmbeddingCheckResult `json:"results"` +} + +type embeddingCheckSample struct { + ChunkID string + KbID string + DocID string + DocName string + VectorField string + Vector []float64 + PageNum interface{} + Position interface{} + Top interface{} + ContentWithWeight string + QuestionKeywords []string +} + +type datasetParsePageRange struct { + from int64 + to int64 +} + +// RunEmbedding runs embedding for all documents in a dataset. +func (s *DatasetService) RunEmbedding(userID, datasetID string) (map[string]interface{}, common.ErrorCode, error) { + if datasetID == "" { + return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`) + } + if !s.kbDAO.Accessible(datasetID, userID) { + return nil, common.CodeDataError, errors.New("No authorization.") + } + + kb, err := s.kbDAO.GetByID(datasetID) + if err != nil { + if dao.IsNotFoundErr(err) { + return nil, common.CodeDataError, errors.New("Invalid Dataset ID") + } + return nil, common.CodeServerError, errors.New("Internal server error") + } + + documents, _, err := s.documentDAO.GetByKBID(datasetID) + if err != nil { + return nil, common.CodeServerError, errors.New("Internal server error") + } + if len(documents) == 0 { + return nil, common.CodeDataError, fmt.Errorf("No documents in Dataset %s", datasetID) + } + + tableDoneCountByKB := make(map[string]int64) + scheduledCount := 0 + for _, doc := range documents { + if doc == nil { + continue + } + if err := s.runEmbeddingDocument(kb, doc, tableDoneCountByKB); err != nil { + common.Warn("Failed to schedule dataset embedding document", + zap.String("datasetID", datasetID), + zap.String("docID", doc.ID), + zap.Error(err)) + return nil, common.CodeServerError, errors.New("Internal server error") + } + scheduledCount++ + } + + return map[string]interface{}{ + "scheduled_count": scheduledCount, + }, common.CodeSuccess, nil +} + +func (s *DatasetService) runEmbeddingDocument(kb *entity.Knowledgebase, doc *entity.Document, tableDoneCountByKB map[string]int64) error { + if doc.PipelineID != nil && strings.TrimSpace(*doc.PipelineID) != "" { + return s.queueDatasetDataflowTask(kb, doc, strings.TrimSpace(*doc.PipelineID), 0) + } + + if doc.ParserID == string(entity.ParserTypeTable) { + doneCount, ok := tableDoneCountByKB[doc.KbID] + if !ok { + count, err := s.countDoneDocuments(doc.KbID) + if err != nil { + return err + } + doneCount = count + tableDoneCountByKB[doc.KbID] = doneCount + if doneCount <= 0 { + if err := s.kbDAO.DeleteFieldMap(doc.KbID); err != nil && !dao.IsNotFoundErr(err) { + return err + } + } + } + } + + indexName := fmt.Sprintf("ragflow_%s", kb.TenantID) + if s.docEngine != nil { + if _, err := s.docEngine.DeleteChunks(context.Background(), map[string]interface{}{"doc_id": doc.ID}, indexName, doc.KbID); err != nil { + return err + } + } + if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil { + return err + } + + bucket, objectName, err := NewDocumentService().GetDocumentStorageAddress(doc) + if err != nil { + return err + } + if err := s.queueDatasetParseTasks(doc, bucket, objectName, 0); err != nil { + return err + } + if err := s.beginDatasetParseDocument(doc.ID); err != nil { + if _, delErr := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); delErr != nil { + common.Warn("Failed to clean parse tasks after document state update failure", + zap.String("docID", doc.ID), + zap.Error(delErr)) + } + return err + } + return nil +} + +func (s *DatasetService) queueDatasetDataflowTask(kb *entity.Knowledgebase, doc *entity.Document, flowID string, priority int64) error { + if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil { + return err + } + if err := s.beginDatasetParseDocument(doc.ID); err != nil { + return err + } + + now := time.Now() + task := &entity.Task{ + ID: common.GenerateUUID(), + DocID: doc.ID, + FromPage: 0, + ToPage: maximumTaskPageNumber, + TaskType: "dataflow", + Priority: priority, + BeginAt: &now, + Progress: 0, + } + if err := s.taskDAO.CreateMany([]*entity.Task{task}); err != nil { + return err + } + + message := datasetParseTaskMessage(task) + message["task_type"] = task.TaskType + message["kb_id"] = doc.KbID + message["tenant_id"] = kb.TenantID + message["dataflow_id"] = flowID + message["file"] = nil + if redisClient := redisengine.Get(); redisClient == nil || !redisClient.QueueProduct(datasetParseQueueName(doc, priority), message) { + return fmt.Errorf("Can't access Redis. Please check the Redis' status.") + } + return nil +} + +func (s *DatasetService) countDoneDocuments(datasetID string) (int64, error) { + var count int64 + err := dao.GetDB().Model(&entity.Document{}). + Where("kb_id = ? AND run = ?", datasetID, string(entity.TaskStatusDone)). + Count(&count).Error + return count, err +} + +func (s *DatasetService) queueDatasetParseTasks(doc *entity.Document, bucket, objectName string, priority int64) error { + tasks, err := s.buildDatasetParseTasks(doc, bucket, objectName, priority) + if err != nil { + return err + } + if len(tasks) == 0 { + return nil + } + if err := s.taskDAO.CreateMany(tasks); err != nil { + return err + } + queueName := datasetParseQueueName(doc, priority) + for _, task := range tasks { + if task.Progress >= 1 { + continue + } + if redisClient := redisengine.Get(); redisClient == nil || !redisClient.QueueProduct(queueName, datasetParseTaskMessage(task)) { + if _, delErr := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); delErr != nil { + common.Warn("Failed to clean parse tasks after Redis enqueue failure", + zap.String("docID", doc.ID), + zap.Error(delErr)) + } + return fmt.Errorf("Can't access Redis. Please check the Redis' status.") + } + } + return nil +} + +func (s *DatasetService) buildDatasetParseTasks(doc *entity.Document, bucket, objectName string, priority int64) ([]*entity.Task, error) { + ranges, err := datasetParseTaskRanges(doc, bucket, objectName) + if err != nil { + return nil, err + } + now := time.Now() + tasks := make([]*entity.Task, 0, len(ranges)) + for _, pageRange := range ranges { + progressMsg := "" + digest := datasetParseTaskDigest(doc, pageRange.from, pageRange.to) + chunkIDs := "" + tasks = append(tasks, &entity.Task{ + ID: common.GenerateUUID(), + DocID: doc.ID, + FromPage: pageRange.from, + ToPage: pageRange.to, + TaskType: "", + Priority: priority, + BeginAt: &now, + Progress: 0, + ProgressMsg: &progressMsg, + Digest: &digest, + ChunkIDs: &chunkIDs, + }) + } + return tasks, nil +} + +func (s *DatasetService) beginDatasetParseDocument(docID string) error { + now := time.Now() + return dao.GetDB().Model(&entity.Document{}).Where("id = ?", docID).Updates(map[string]interface{}{ + "progress_msg": "Task is queued...", + "process_begin_at": now, + "progress": rand.Float64() * 0.01, + "run": string(entity.TaskStatusRunning), + "chunk_num": 0, + "token_num": 0, + }).Error +} + +// CheckEmbedding checks whether a new embedding model is compatible with stored vectors. +func (s *DatasetService) CheckEmbedding(userID, datasetID string, req *CheckEmbeddingRequest) (*EmbeddingCheckResponse, common.ErrorCode, error) { + if datasetID == "" { + return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`) + } + if !s.kbDAO.Accessible(datasetID, userID) { + return nil, common.CodeDataError, errors.New("No authorization.") + } + + kb, err := s.kbDAO.GetByID(datasetID) + if err != nil { + if dao.IsNotFoundErr(err) { + return nil, common.CodeDataError, errors.New("Invalid Dataset ID") + } + return nil, common.CodeServerError, errors.New("Internal server error") + } + + if req == nil || strings.TrimSpace(req.EmbeddingID) == "" { + return nil, common.CodeDataError, errors.New("`embd_id` is required.") + } + embeddingID := strings.TrimSpace(req.EmbeddingID) + if ok, message := s.verifyEmbeddingAvailability(embeddingID, userID); !ok { + return nil, common.CodeDataError, errors.New(message) + } + if s.docEngine == nil { + return nil, common.CodeServerError, errors.New("doc engine not initialized") + } + + driver, modelName, apiConfig, maxTokens, err := NewModelProviderService().GetModelConfigFromProviderInstance(kb.TenantID, entity.ModelTypeEmbedding, embeddingID) + if err != nil { + return nil, common.CodeDataError, err + } + embeddingModel := models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens) + + checkNum := defaultEmbeddingCheckNum + if req.CheckNum != nil { + checkNum = *req.CheckNum + } + if checkNum <= 0 { + checkNum = defaultEmbeddingCheckNum + } + + samples, err := s.sampleRandomChunksWithVectors(context.Background(), kb.TenantID, datasetID, checkNum) + if err != nil { + return nil, common.CodeServerError, err + } + + results := make([]EmbeddingCheckResult, 0, len(samples)) + effectiveSimilarities := make([]float64, 0, len(samples)) + matchMode := "content_only" + for _, sample := range samples { + title := sample.DocName + if strings.TrimSpace(title) == "" { + title = "Title" + } + + textInput := strings.Join(sample.QuestionKeywords, "\n") + if strings.TrimSpace(textInput) == "" { + textInput = sample.ContentWithWeight + } + textInput = datasetCleanEmbeddingText(textInput) + if textInput == "" { + results = append(results, EmbeddingCheckResult{ChunkID: sample.ChunkID, Reason: "no_text"}) + continue + } + if len(sample.Vector) == 0 { + results = append(results, EmbeddingCheckResult{ChunkID: sample.ChunkID, Reason: "no_stored_vector"}) + continue + } + + vectors, err := datasetEncodeEmbedding(embeddingModel, []string{title, textInput}) + if err != nil { + return nil, common.CodeDataError, fmt.Errorf("Embedding failure. %w", err) + } + if len(vectors) < 2 { + return nil, common.CodeDataError, errors.New("Embedding failure. embedding response is incomplete") + } + if len(vectors[1]) != len(sample.Vector) { + return nil, common.CodeDataError, fmt.Errorf("Embedding failure. The dimension (%d) of given embedding model is different from the original (%d)", len(vectors[1]), len(sample.Vector)) + } + + simContent := datasetCosSim(vectors[1], sample.Vector) + simMix := datasetCosSim(datasetMixVectors(vectors[0], vectors[1], 0.1), sample.Vector) + sim := simContent + matchMode = "content_only" + if simMix > sim { + sim = simMix + matchMode = "title+content" + } + sim = datasetRoundFloat(sim, 6) + + effectiveSimilarities = append(effectiveSimilarities, sim) + results = append(results, EmbeddingCheckResult{ + ChunkID: sample.ChunkID, + DocID: sample.DocID, + DocName: sample.DocName, + VectorField: sample.VectorField, + VectorDim: len(sample.Vector), + CosSim: sim, + }) + } + + summary := datasetEmbeddingCheckSummary(datasetID, embeddingID, len(samples), effectiveSimilarities, matchMode) + response := &EmbeddingCheckResponse{Summary: summary, Results: results} + if len(effectiveSimilarities) == 0 { + return nil, common.CodeDataError, errors.New("No embedded chunks are available to compare.") + } + if summary.AvgCosSim >= 0.9 { + return response, common.CodeSuccess, nil + } + return response, common.CodeNotEffective, errors.New("Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.") +} + +func (s *DatasetService) sampleRandomChunksWithVectors(ctx context.Context, tenantID, datasetID string, n int) ([]embeddingCheckSample, error) { + indexName := fmt.Sprintf("ragflow_%s", tenantID) + totalResult, err := s.docEngine.Search(ctx, &enginetypes.SearchRequest{ + IndexNames: []string{indexName}, + KbIDs: []string{datasetID}, + Offset: 0, + Limit: 1, + Filter: map[string]interface{}{ + "kb_id": datasetID, + "available_int": 1, + }, + }) + if err != nil { + return nil, err + } + if totalResult == nil || totalResult.Total <= 0 { + return []embeddingCheckSample{}, nil + } + + total := int(totalResult.Total) + if n > total { + n = total + } + limit := total + if limit > 1000 { + limit = 1000 + } + if n > limit { + n = limit + } + offsets := rand.Perm(limit) + offsets = offsets[:n] + sort.Ints(offsets) + + baseFields := []string{"docnm_kwd", "doc_id", "content_with_weight", "page_num_int", "position_int", "top_int"} + samples := make([]embeddingCheckSample, 0, n) + for _, offset := range offsets { + searchResult, err := s.docEngine.Search(ctx, &enginetypes.SearchRequest{ + IndexNames: []string{indexName}, + KbIDs: []string{datasetID}, + Offset: offset, + Limit: 1, + SelectFields: baseFields, + Filter: map[string]interface{}{ + "kb_id": datasetID, + "available_int": 1, + }, + }) + if err != nil { + return nil, err + } + if searchResult == nil || len(searchResult.Chunks) == 0 { + continue + } + chunkID := datasetChunkID(searchResult.Chunks[0]) + if chunkID == "" { + continue + } + fullChunk, err := s.docEngine.GetChunk(ctx, indexName, chunkID, []string{datasetID}) + if err != nil { + return nil, err + } + chunkMap := datasetMap(fullChunk) + if len(chunkMap) == 0 { + continue + } + vectorField := datasetGuessVecField(chunkMap) + vector := datasetAsFloatVec(chunkMap[vectorField]) + samples = append(samples, embeddingCheckSample{ + ChunkID: chunkID, + KbID: datasetID, + DocID: datasetString(chunkMap["doc_id"]), + DocName: datasetString(chunkMap["docnm_kwd"]), + VectorField: vectorField, + Vector: vector, + PageNum: chunkMap["page_num_int"], + Position: chunkMap["position_int"], + Top: chunkMap["top_int"], + ContentWithWeight: datasetString(chunkMap["content_with_weight"]), + QuestionKeywords: datasetStringSlice(chunkMap["question_kwd"]), + }) + } + return samples, nil +} + +func datasetGuessVecField(src map[string]interface{}) string { + for k := range src { + if strings.HasSuffix(k, "_vec") { + return k + } + } + return "" +} + +func datasetAsFloatVec(v interface{}) []float64 { + if v == nil { + return []float64{} + } + switch val := v.(type) { + case string: + parts := strings.Split(val, "\t") + res := make([]float64, 0, len(parts)) + for _, p := range parts { + if p == "" { + continue + } + f, err := strconv.ParseFloat(p, 64) + if err != nil { + continue + } + res = append(res, f) + } + return res + case []float64: + return val + case []float32: + res := make([]float64, len(val)) + for i, x := range val { + res[i] = float64(x) + } + return res + case []int: + res := make([]float64, len(val)) + for i, x := range val { + res[i] = float64(x) + } + return res + case []interface{}: + res := make([]float64, 0, len(val)) + for _, x := range val { + switch n := x.(type) { + case float64: + res = append(res, n) + case float32: + res = append(res, float64(n)) + case int: + res = append(res, float64(n)) + case string: + f, err := strconv.ParseFloat(n, 64) + if err == nil { + res = append(res, f) + } + } + } + return res + } + return []float64{} +} + +func datasetCosSim(a, b []float64) float64 { + if len(a) == 0 || len(b) == 0 { + return 0 + } + var dot, na, nb float64 + n := len(a) + if len(b) < n { + n = len(b) + } + for i := 0; i < n; i++ { + dot += a[i] * b[i] + } + for _, x := range a { + na += x * x + } + for _, x := range b { + nb += x * x + } + + if na == 0 || nb == 0 { + return 0 + } + return dot / (math.Sqrt(na) * math.Sqrt(nb)) +} + +func datasetCleanEmbeddingText(s string) string { + re := regexp.MustCompile(`]{0,12})?>`) + return strings.TrimSpace(re.ReplaceAllString(s, " ")) +} + +func datasetEncodeEmbedding(embeddingModel *models.EmbeddingModel, texts []string) ([][]float64, error) { + embeddingConfig := &models.EmbeddingConfig{Dimension: 0} + embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, texts, embeddingModel.APIConfig, embeddingConfig) + if err != nil { + return nil, err + } + vectors := make([][]float64, len(embeddings)) + for i, embedding := range embeddings { + vectors[i] = embedding.Embedding + } + return vectors, nil +} + +func datasetMixVectors(titleVector, contentVector []float64, titleWeight float64) []float64 { + if len(titleVector) != len(contentVector) { + return contentVector + } + mixed := make([]float64, len(contentVector)) + contentWeight := 1 - titleWeight + for i := range contentVector { + mixed[i] = titleWeight*titleVector[i] + contentWeight*contentVector[i] + } + return mixed +} + +func datasetEmbeddingCheckSummary(datasetID, embeddingID string, sampled int, similarities []float64, matchMode string) EmbeddingCheckSummary { + summary := EmbeddingCheckSummary{ + KbID: datasetID, + Model: embeddingID, + Sampled: sampled, + Valid: len(similarities), + MatchMode: matchMode, + } + if len(similarities) == 0 { + return summary + } + minValue := similarities[0] + maxValue := similarities[0] + total := 0.0 + for _, value := range similarities { + total += value + if value < minValue { + minValue = value + } + if value > maxValue { + maxValue = value + } + } + summary.AvgCosSim = datasetRoundFloat(total/float64(len(similarities)), 6) + summary.MinCosSim = datasetRoundFloat(minValue, 6) + summary.MaxCosSim = datasetRoundFloat(maxValue, 6) + return summary +} + +func datasetRoundFloat(value float64, places int) float64 { + factor := math.Pow10(places) + return math.Round(value*factor) / factor +} + +func datasetChunkID(chunk map[string]interface{}) string { + for _, key := range []string{"id", "_id"} { + if value := datasetString(chunk[key]); value != "" { + return value + } + } + return "" +} + +func datasetMap(value interface{}) map[string]interface{} { + switch typedValue := value.(type) { + case map[string]interface{}: + return typedValue + default: + return map[string]interface{}{} + } +} + +func datasetString(value interface{}) string { + switch typedValue := value.(type) { + case string: + return typedValue + case fmt.Stringer: + return typedValue.String() + case nil: + return "" + default: + return fmt.Sprint(typedValue) + } +} + +func datasetStringSlice(value interface{}) []string { + switch typedValue := value.(type) { + case []string: + return typedValue + case []interface{}: + values := make([]string, 0, len(typedValue)) + for _, item := range typedValue { + if s := strings.TrimSpace(datasetString(item)); s != "" { + values = append(values, s) + } + } + return values + case string: + if typedValue == "" { + return nil + } + return []string{typedValue} + default: + return nil + } +} + +func datasetParseQueueName(doc *entity.Document, priority int64) string { + suffix := "common" + if doc.ParserID == string(entity.ParserTypeResume) { + suffix = "resume" + } + return fmt.Sprintf("%s.%d.%s", serverQueueNamePrefix, priority, suffix) +} + +func datasetParseTaskMessage(task *entity.Task) map[string]interface{} { + beginAt := "" + if task.BeginAt != nil { + beginAt = task.BeginAt.Format("2006-01-02 15:04:05") + } + digest := "" + if task.Digest != nil { + digest = *task.Digest + } + return map[string]interface{}{ + "id": task.ID, + "doc_id": task.DocID, + "from_page": task.FromPage, + "to_page": task.ToPage, + "progress": task.Progress, + "priority": task.Priority, + "begin_at": beginAt, + "digest": digest, + } +} + +func datasetParseTaskDigest(doc *entity.Document, fromPage, toPage int64) string { + hasher := xxhash.New() + config := datasetChunkingConfigForDigest(doc) + keys := make([]string, 0, len(config)) + for key := range config { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + hasher.WriteString(datasetStableString(config[key])) + } + hasher.WriteString(doc.ID) + hasher.WriteString(strconv.FormatInt(fromPage, 10)) + hasher.WriteString(strconv.FormatInt(toPage, 10)) + return fmt.Sprintf("%x", hasher.Sum64()) +} + +func datasetChunkingConfigForDigest(doc *entity.Document) map[string]interface{} { + return map[string]interface{}{ + "doc_id": doc.ID, + "kb_id": doc.KbID, + "parser_id": doc.ParserID, + "parser_config": datasetCopyParserConfigForDigest(doc.ParserConfig), + } +} + +func datasetCopyParserConfigForDigest(config map[string]interface{}) map[string]interface{} { + copied := make(map[string]interface{}, len(config)) + for key, value := range config { + if key == "raptor" || key == "graphrag" { + continue + } + copied[key] = value + } + return copied +} + +func datasetStableString(value interface{}) string { + binary, err := json.Marshal(value) + if err != nil { + return fmt.Sprint(value) + } + return string(binary) +} + +func datasetParseTaskRanges(doc *entity.Document, bucket, objectName string) ([]datasetParsePageRange, error) { + if doc.Type == "pdf" { + return datasetPDFParseTaskRanges(doc, bucket, objectName) + } + if doc.ParserID == string(entity.ParserTypeTable) { + return datasetTableParseTaskRanges(doc, bucket, objectName) + } + return []datasetParsePageRange{{from: 0, to: maximumTaskPageNumber}}, nil +} + +func datasetPDFParseTaskRanges(doc *entity.Document, bucket, objectName string) ([]datasetParsePageRange, error) { + binary, err := datasetStorageBinary(bucket, objectName) + if err != nil { + return nil, err + } + pages := datasetEstimatePDFPageCount(binary) + pageSize := int64(datasetParserConfigInt(doc.ParserConfig, "task_page_size", 12)) + if doc.ParserID == string(entity.ParserTypePaper) { + pageSize = int64(datasetParserConfigInt(doc.ParserConfig, "task_page_size", 22)) + } + if doc.ParserID == string(entity.ParserTypeOne) || + doc.ParserID == string(entity.ParserTypeKG) || + datasetParserConfigString(doc.ParserConfig, "layout_recognize", "DeepDOC") != "DeepDOC" || + datasetParserConfigBool(doc.ParserConfig, "toc_extraction", false) { + pageSize = maximumTaskPageNumber + } + if pageSize <= 0 { + pageSize = 12 + } + + pageRanges := datasetParserConfigPageRanges(doc.ParserConfig) + ranges := make([]datasetParsePageRange, 0) + for _, configuredRange := range pageRanges { + start := configuredRange.from - 1 + if start < 0 { + start = 0 + } + end := configuredRange.to - 1 + if pages >= 0 && end > pages { + end = pages + } + for page := start; page < end; page += pageSize { + to := page + pageSize + if to > end { + to = end + } + ranges = append(ranges, datasetParsePageRange{from: page, to: to}) + } + } + if len(ranges) == 0 { + ranges = append(ranges, datasetParsePageRange{from: 0, to: maximumTaskPageNumber}) + } + return ranges, nil +} + +func datasetTableParseTaskRanges(doc *entity.Document, bucket, objectName string) ([]datasetParsePageRange, error) { + binary, err := datasetStorageBinary(bucket, objectName) + if err != nil { + return nil, err + } + rows := datasetEstimateTableRowCount(datasetDocName(doc), binary) + if rows <= 0 { + return []datasetParsePageRange{{from: 0, to: maximumTaskPageNumber}}, nil + } + ranges := make([]datasetParsePageRange, 0, (rows+2999)/3000) + for row := int64(0); row < int64(rows); row += 3000 { + to := row + 3000 + if to > int64(rows) { + to = int64(rows) + } + ranges = append(ranges, datasetParsePageRange{from: row, to: to}) + } + return ranges, nil +} + +func datasetStorageBinary(bucket, objectName string) ([]byte, error) { + storageImpl := storage.GetStorageFactory().GetStorage() + if storageImpl == nil { + return nil, fmt.Errorf("storage not initialized") + } + return storageImpl.Get(bucket, objectName) +} + +func datasetDocName(doc *entity.Document) string { + if doc == nil || doc.Name == nil { + return "" + } + return *doc.Name +} + +func datasetParserConfigInt(config map[string]interface{}, key string, fallback int) int { + value, ok := config[key] + if !ok || value == nil { + return fallback + } + switch typedValue := value.(type) { + case int: + return typedValue + case int64: + return int(typedValue) + case float64: + return int(typedValue) + case json.Number: + if intValue, err := typedValue.Int64(); err == nil { + return int(intValue) + } + case string: + if intValue, err := strconv.Atoi(strings.TrimSpace(typedValue)); err == nil { + return intValue + } + } + return fallback +} + +func datasetParserConfigString(config map[string]interface{}, key, fallback string) string { + value, ok := config[key] + if !ok || value == nil { + return fallback + } + if stringValue, ok := value.(string); ok { + return stringValue + } + return fmt.Sprint(value) +} + +func datasetParserConfigBool(config map[string]interface{}, key string, fallback bool) bool { + value, ok := config[key] + if !ok || value == nil { + return fallback + } + switch typedValue := value.(type) { + case bool: + return typedValue + case string: + switch strings.ToLower(strings.TrimSpace(typedValue)) { + case "true", "1", "yes", "on": + return true + case "false", "0", "no", "off": + return false + } + } + return fallback +} + +func datasetParserConfigPageRanges(config map[string]interface{}) []datasetParsePageRange { + defaultRanges := []datasetParsePageRange{{from: 1, to: maximumPageNumber}} + raw, ok := config["pages"] + if !ok || raw == nil { + return defaultRanges + } + rawRanges, ok := raw.([]interface{}) + if !ok || len(rawRanges) == 0 { + return defaultRanges + } + + ranges := make([]datasetParsePageRange, 0, len(rawRanges)) + for _, rawRange := range rawRanges { + rangeValues, ok := rawRange.([]interface{}) + if !ok || len(rangeValues) < 2 { + continue + } + from, okFrom := datasetToInt64(rangeValues[0]) + to, okTo := datasetToInt64(rangeValues[1]) + if okFrom && okTo && to > from { + ranges = append(ranges, datasetParsePageRange{from: from, to: to}) + } + } + if len(ranges) == 0 { + return defaultRanges + } + return ranges +} + +func datasetToInt64(value interface{}) (int64, bool) { + switch typedValue := value.(type) { + case int: + return int64(typedValue), true + case int64: + return typedValue, true + case float64: + return int64(typedValue), true + case json.Number: + intValue, err := typedValue.Int64() + return intValue, err == nil + case string: + intValue, err := strconv.ParseInt(strings.TrimSpace(typedValue), 10, 64) + return intValue, err == nil + default: + return 0, false + } +} + +var datasetPDFPagePattern = regexp.MustCompile(`/Type\s*/Page\b`) + +func datasetEstimatePDFPageCount(binary []byte) int64 { + if len(binary) == 0 { + return 0 + } + return int64(len(datasetPDFPagePattern.FindAll(binary, -1))) +} + +func datasetEstimateTableRowCount(name string, binary []byte) int { + switch strings.ToLower(filepath.Ext(name)) { + case ".xlsx": + if rows, err := datasetCountXLSXRows(binary); err == nil { + return rows + } + case ".csv", ".tsv", ".txt": + return datasetCountDelimitedRows(name, binary) + } + return 0 +} + +func datasetCountDelimitedRows(name string, binary []byte) int { + reader := csv.NewReader(bytes.NewReader(binary)) + reader.FieldsPerRecord = -1 + reader.ReuseRecord = true + if strings.EqualFold(filepath.Ext(name), ".tsv") { + reader.Comma = '\t' + } + rows := 0 + for { + _, err := reader.Read() + if err == nil { + rows++ + continue + } + if err == io.EOF { + break + } + rows += bytes.Count(binary, []byte{'\n'}) + if len(binary) > 0 && binary[len(binary)-1] != '\n' { + rows++ + } + break + } + return rows +} + +func datasetCountXLSXRows(binary []byte) (int, error) { + zipReader, err := zip.NewReader(bytes.NewReader(binary), int64(len(binary))) + if err != nil { + return 0, err + } + maxRows := 0 + for _, file := range zipReader.File { + if !strings.HasPrefix(file.Name, "xl/worksheets/") || !strings.HasSuffix(file.Name, ".xml") { + continue + } + rows, err := datasetCountWorksheetRows(file) + if err != nil { + return 0, err + } + if rows > maxRows { + maxRows = rows + } + } + return maxRows, nil +} + +func datasetCountWorksheetRows(file *zip.File) (int, error) { + reader, err := file.Open() + if err != nil { + return 0, err + } + defer reader.Close() + + decoder := xml.NewDecoder(reader) + rows := 0 + for { + token, err := decoder.Token() + if err == io.EOF { + break + } + if err != nil { + return 0, err + } + start, ok := token.(xml.StartElement) + if ok && start.Name.Local == "row" { + rows++ + } + } + return rows, nil +} + func (s *DatasetService) DeleteIndex(userID, datasetID, indexType string, wipe bool) (common.ErrorCode, error) { if !checkType(indexType) { return common.CodeArgumentError, fmt.Errorf("Invalid index type '%s'", indexType)