mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve? 1. Add CREATE and DROP DATASET / MEMORY / AGENT / SEARCH / CHAT. 2. Add option to build.sh to strip RAGFlow binary. ### Type of change - [x] Refactoring --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
3830 lines
119 KiB
Go
3830 lines
119 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package service
|
|
|
|
import (
|
|
"archive/zip"
|
|
"bytes"
|
|
"context"
|
|
"encoding/csv"
|
|
"encoding/json"
|
|
"encoding/xml"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"math/rand"
|
|
"path/filepath"
|
|
"ragflow/internal/common"
|
|
"ragflow/internal/dao"
|
|
"ragflow/internal/engine"
|
|
redisengine "ragflow/internal/engine/redis"
|
|
"ragflow/internal/engine/types"
|
|
enginetypes "ragflow/internal/engine/types"
|
|
"ragflow/internal/entity"
|
|
"ragflow/internal/entity/models"
|
|
"ragflow/internal/server"
|
|
"ragflow/internal/service/nlp"
|
|
"ragflow/internal/storage"
|
|
"ragflow/internal/utility"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/cespare/xxhash/v2"
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
// Validation sets and lookup tables shared by the dataset APIs.
var (
	// datasetAllowedChunkMethods is the set of chunk method names; the same
	// names are spelled out in datasetChunkMethodErrorMessage.
	datasetAllowedChunkMethods = map[string]struct{}{
		"naive": {}, "book": {}, "email": {}, "laws": {}, "manual": {},
		"one": {}, "paper": {}, "picture": {}, "presentation": {},
		"qa": {}, "resume": {}, "table": {}, "tag": {},
	}

	// datasetSupportedAvatarMIMETypes is the set of avatar MIME types.
	datasetSupportedAvatarMIMETypes = map[string]struct{}{"image/jpeg": {}, "image/png": {}}

	// datasetAllowedOrderByFields is the set of order-by field names.
	datasetAllowedOrderByFields = map[string]struct{}{"create_time": {}, "update_time": {}}

	// datasetAllowedMetadataTypes is the set of metadata type names.
	datasetAllowedMetadataTypes = map[string]struct{}{
		"string": {}, "list": {}, "time": {}, "number": {},
	}

	// datasetChunkMethodErrorMessage lists every allowed chunk method.
	datasetChunkMethodErrorMessage = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'resume', 'table' or 'tag'"

	// validIndexTypes are the index types accepted by checkType.
	validIndexTypes = []string{"graph", "raptor", "mindmap"}

	// indexTypeToTaskType maps an index type to the task type stored on the task.
	indexTypeToTaskType = map[string]string{
		"graph":   "graphrag",
		"raptor":  "raptor",
		"mindmap": "mindmap",
	}

	// indexTypeToDisplayName maps an index type to the name used in messages.
	indexTypeToDisplayName = map[string]string{
		"graph":   "Graph",
		"raptor":  "RAPTOR",
		"mindmap": "Mindmap",
	}
)
|
|
|
|
const (
	// graphRaptorQueueDocID is the legacy worker marker kept in queue
	// payloads; persisted tasks use a real document ID instead.
	graphRaptorQueueDocID = "graph_raptor_x"

	maximumPageNumber int64 = 100000
	// maximumTaskPageNumber is used as the page bound of dataset index and
	// dataflow tasks.
	maximumTaskPageNumber int64 = 100000000

	// serverQueueNamePrefix is the first segment of dataset index queue names.
	serverQueueNamePrefix = "te"
	// defaultEmbeddingCheckNum is the sample count CheckEmbedding falls back
	// to when the request gives none or a non-positive one.
	defaultEmbeddingCheckNum = 5

	// GraphRAG phase names used as the suffix of phase-marker keys.
	graphPhaseResolutionDone = "resolution_done"
	graphPhaseCommunityDone  = "community_done"
)
|
|
|
|
// DatasetService implements the RESTful dataset APIs from dataset_api.py.
|
|
type DatasetService struct {
|
|
kbDAO *dao.KnowledgebaseDAO
|
|
documentDAO *dao.DocumentDAO
|
|
connectorDAO *dao.ConnectorDAO
|
|
tenantDAO *dao.TenantDAO
|
|
tenantLLMDAO *dao.TenantLLMDAO
|
|
pipelineLogDAO *dao.PipelineOperationLogDAO
|
|
userTenantDAO *dao.UserTenantDAO
|
|
taskDAO *dao.TaskDAO
|
|
searchService *SearchService
|
|
docEngine engine.DocEngine
|
|
embeddingCache *utility.EmbeddingLRU
|
|
engineType server.EngineType
|
|
}
|
|
|
|
// NewDatasetService creates a new datasets service.
|
|
func NewDatasetService() *DatasetService {
|
|
cfg := server.GetConfig()
|
|
engineType := server.EngineType("")
|
|
if cfg != nil {
|
|
engineType = cfg.DocEngine.Type
|
|
}
|
|
return &DatasetService{
|
|
kbDAO: dao.NewKnowledgebaseDAO(),
|
|
documentDAO: dao.NewDocumentDAO(),
|
|
connectorDAO: dao.NewConnectorDAO(),
|
|
tenantDAO: dao.NewTenantDAO(),
|
|
tenantLLMDAO: dao.NewTenantLLMDAO(),
|
|
pipelineLogDAO: dao.NewPipelineOperationLogDAO(),
|
|
userTenantDAO: dao.NewUserTenantDAO(),
|
|
taskDAO: dao.NewTaskDAO(),
|
|
searchService: NewSearchService(),
|
|
docEngine: engine.Get(),
|
|
embeddingCache: utility.NewEmbeddingLRU(1000),
|
|
engineType: engineType,
|
|
}
|
|
}
|
|
|
|
func (s *DatasetService) UpdateDocumentMetadataConfig(userID, datasetID, documentID string, req map[string]interface{}) (*entity.Document, common.ErrorCode, error) {
|
|
if _, err := s.kbDAO.GetByIDAndTenantID(datasetID, userID); err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, errors.New("You don't own the dataset.")
|
|
}
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
|
|
doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(documentID, datasetID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, fmt.Errorf("Document %s not found in dataset %s", documentID, datasetID)
|
|
}
|
|
return nil, common.CodeServerError, err
|
|
}
|
|
|
|
metadata, ok := req["metadata"]
|
|
if !ok {
|
|
return nil, common.CodeArgumentError, errors.New("metadata is required")
|
|
}
|
|
|
|
parserConfig := doc.ParserConfig
|
|
if parserConfig == nil {
|
|
parserConfig = entity.JSONMap{}
|
|
}
|
|
parserConfig["metadata"] = metadata
|
|
|
|
if err := s.documentDAO.UpdateByID(doc.ID, map[string]interface{}{"parser_config": parserConfig}); err != nil {
|
|
return nil, common.CodeExceptionError, err
|
|
}
|
|
|
|
updatedDoc, err := s.documentDAO.GetByID(doc.ID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, errors.New("Document not found!")
|
|
}
|
|
return nil, common.CodeExceptionError, err
|
|
}
|
|
|
|
return updatedDoc, common.CodeSuccess, nil
|
|
}
|
|
|
|
// checkType reports whether indexType is supported by dataset index APIs.
|
|
func checkType(indexType string) bool {
|
|
haveType := false
|
|
for _, t := range validIndexTypes {
|
|
if indexType == t {
|
|
haveType = true
|
|
}
|
|
}
|
|
return haveType
|
|
}
|
|
|
|
func (s *DatasetService) newRaptorOrGraphRagTask(sampleDoc *entity.Document, taskType string, taskDocID string, queueDocID string, docIDs []string) (*entity.Task, map[string]interface{}, error) {
|
|
if docIDs == nil || len(docIDs) == 0 {
|
|
docIDs = make([]string, 0)
|
|
}
|
|
if !checkIndexTaskType(taskType) {
|
|
return nil, nil, errors.New("type should be graphrag, raptor or mindmap")
|
|
}
|
|
|
|
chunkingConfig, err := s.documentDAO.GetChunkingConfig(sampleDoc.ID)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
hasher := xxhash.New()
|
|
keys := make([]string, 0, len(chunkingConfig))
|
|
for key := range chunkingConfig {
|
|
keys = append(keys, key)
|
|
}
|
|
sort.Strings(keys)
|
|
for _, key := range keys {
|
|
_, _ = hasher.Write([]byte(key))
|
|
_, _ = hasher.Write([]byte{0})
|
|
v, mErr := json.Marshal(chunkingConfig[key])
|
|
if mErr != nil {
|
|
return nil, nil, mErr
|
|
}
|
|
_, _ = hasher.Write(v)
|
|
_, _ = hasher.Write([]byte{0})
|
|
}
|
|
|
|
taskID := strings.ReplaceAll(uuid.New().String(), "-", "")[:32]
|
|
beginAt := time.Now().Truncate(time.Second)
|
|
progressMsg := beginAt.Format("15:04:05") + " created task " + taskType
|
|
|
|
for _, field := range []interface{}{taskDocID, maximumTaskPageNumber, maximumTaskPageNumber, taskType} {
|
|
_, _ = hasher.Write([]byte(fmt.Sprint(field)))
|
|
}
|
|
digest := fmt.Sprintf("%016x", hasher.Sum64())
|
|
task := &entity.Task{
|
|
ID: taskID,
|
|
DocID: taskDocID,
|
|
FromPage: maximumTaskPageNumber,
|
|
ToPage: maximumTaskPageNumber,
|
|
TaskType: taskType,
|
|
ProgressMsg: &progressMsg,
|
|
BeginAt: &beginAt,
|
|
Digest: &digest,
|
|
}
|
|
|
|
queueMessage := map[string]interface{}{
|
|
"id": taskID,
|
|
"doc_id": queueDocID,
|
|
"from_page": maximumTaskPageNumber,
|
|
"to_page": maximumTaskPageNumber,
|
|
"task_type": taskType,
|
|
"progress_msg": progressMsg,
|
|
"begin_at": beginAt.Format("2006-01-02 15:04:05"),
|
|
"digest": digest,
|
|
"doc_ids": docIDs,
|
|
}
|
|
|
|
return task, queueMessage, nil
|
|
}
|
|
|
|
func createDatasetIndexTaskInTx(tx *gorm.DB, task *entity.Task, queueDocID string) (*entity.Document, error) {
|
|
if task == nil {
|
|
return nil, errors.New("task is required")
|
|
}
|
|
if err := tx.Create(task).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if queueDocID == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
var document entity.Document
|
|
err := tx.Select("id", "progress_msg", "process_begin_at").Where("id = ?", queueDocID).First(&document).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
beginAt := time.Now().Truncate(time.Second)
|
|
if task.BeginAt != nil {
|
|
beginAt = *task.BeginAt
|
|
}
|
|
if err := tx.Model(&entity.Document{}).Where("id = ?", queueDocID).Updates(map[string]interface{}{
|
|
"progress_msg": "Task is queued...",
|
|
"process_begin_at": beginAt,
|
|
}).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &document, nil
|
|
}
|
|
|
|
func enqueueDatasetIndexTask(priority int, queueMessage map[string]interface{}) error {
|
|
redisClient := redisengine.Get()
|
|
if redisClient == nil || !redisClient.QueueProduct(datasetIndexQueueName(priority), queueMessage) {
|
|
return errors.New("Can't access Redis. Please check the Redis' status")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func cleanupFailedDatasetIndexTask(taskID string, updatedDocument *entity.Document, kbID string, indexType string) error {
|
|
return dao.DB.Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Unscoped().Where("id = ?", taskID).Delete(&entity.Task{}).Error; err != nil {
|
|
return fmt.Errorf("delete task %s: %w", taskID, err)
|
|
}
|
|
|
|
if column := datasetIndexTaskIDColumn(indexType); kbID != "" && column != "" {
|
|
if err := tx.Model(&entity.Knowledgebase{}).Where("id = ? AND "+column+" = ?", kbID, taskID).Update(column, nil).Error; err != nil {
|
|
return fmt.Errorf("clear dataset task id %s: %w", taskID, err)
|
|
}
|
|
}
|
|
|
|
if updatedDocument == nil {
|
|
return nil
|
|
}
|
|
|
|
return tx.Model(&entity.Document{}).Where("id = ?", updatedDocument.ID).Updates(map[string]interface{}{
|
|
"progress_msg": updatedDocument.ProgressMsg,
|
|
"process_begin_at": updatedDocument.ProcessBeginAt,
|
|
}).Error
|
|
})
|
|
}
|
|
|
|
// datasetIndexTaskIDColumn returns the dataset column that stores the task ID
// for indexType, or "" when the index type is unknown.
func datasetIndexTaskIDColumn(indexType string) string {
	if indexType == "graph" {
		return "graphrag_task_id"
	}
	if indexType == "raptor" {
		return "raptor_task_id"
	}
	if indexType == "mindmap" {
		return "mindmap_task_id"
	}
	return ""
}
|
|
|
|
// datasetIndexTaskFinishAtColumn returns the dataset column that stores the
// task finish time for indexType, or "" when the index type is unknown.
func datasetIndexTaskFinishAtColumn(indexType string) string {
	if indexType == "graph" {
		return "graphrag_task_finish_at"
	}
	if indexType == "raptor" {
		return "raptor_task_finish_at"
	}
	if indexType == "mindmap" {
		return "mindmap_task_finish_at"
	}
	return ""
}
|
|
|
|
// checkIndexTaskType reports whether taskType is one of the dataset index task
// types: "graphrag", "raptor" or "mindmap".
func checkIndexTaskType(taskType string) bool {
	return taskType == "graphrag" || taskType == "raptor" || taskType == "mindmap"
}
|
|
|
|
func datasetIndexTaskID(kb *entity.Knowledgebase, indexType string) string {
|
|
if kb == nil {
|
|
return ""
|
|
}
|
|
switch indexType {
|
|
case "graph":
|
|
if kb.GraphragTaskID != nil {
|
|
return *kb.GraphragTaskID
|
|
}
|
|
case "raptor":
|
|
if kb.RaptorTaskID != nil {
|
|
return *kb.RaptorTaskID
|
|
}
|
|
case "mindmap":
|
|
if kb.MindmapTaskID != nil {
|
|
return *kb.MindmapTaskID
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// datasetIndexTaskIDUpdate builds the column update that records taskID as the
// current task for indexType. An unknown index type yields an empty, non-nil
// map.
func datasetIndexTaskIDUpdate(indexType, taskID string) map[string]interface{} {
	update := map[string]interface{}{}
	switch indexType {
	case "graph":
		update["graphrag_task_id"] = taskID
	case "raptor":
		update["raptor_task_id"] = taskID
	case "mindmap":
		update["mindmap_task_id"] = taskID
	}
	return update
}
|
|
|
|
func datasetIndexTaskIDs(kb *entity.Knowledgebase) []string {
|
|
if kb == nil {
|
|
return nil
|
|
}
|
|
|
|
taskIDs := make([]string, 0, 3)
|
|
for _, taskID := range []*string{kb.GraphragTaskID, kb.RaptorTaskID, kb.MindmapTaskID} {
|
|
if taskID != nil && *taskID != "" {
|
|
taskIDs = append(taskIDs, *taskID)
|
|
}
|
|
}
|
|
return common.Deduplicate(taskIDs)
|
|
}
|
|
|
|
func datasetIndexQueueName(priority int) string {
|
|
return fmt.Sprintf("%s.%d.common", serverQueueNamePrefix, priority)
|
|
}
|
|
|
|
// interfaceSlice converts its string arguments into a []interface{} of the
// same length and order. With no arguments it returns an empty, non-nil slice.
func interfaceSlice(items ...string) []interface{} {
	boxed := make([]interface{}, 0, len(items))
	for _, item := range items {
		boxed = append(boxed, item)
	}
	return boxed
}
|
|
|
|
func clearGraphPhaseMarkers(redisClient *redisengine.RedisClient, datasetID string) {
|
|
if redisClient == nil || datasetID == "" {
|
|
return
|
|
}
|
|
for _, phase := range []string{graphPhaseResolutionDone, graphPhaseCommunityDone} {
|
|
if !redisClient.Delete(fmt.Sprintf("graphrag:phase:%s:%s", datasetID, phase)) {
|
|
common.Warn("Failed to clear GraphRAG phase marker", zap.String("dataset_id", datasetID), zap.String("phase", phase))
|
|
}
|
|
}
|
|
}
|
|
|
|
// RunIndex Run an indexing task (graph/raptor/mindmap) for a dataset.
|
|
func (s *DatasetService) RunIndex(userID, datasetID, indexType string) (map[string]interface{}, common.ErrorCode, error) {
|
|
if !checkType(indexType) {
|
|
return nil, common.CodeDataError, fmt.Errorf("Invalid index type '%s'. Must be one of %v", indexType, validIndexTypes)
|
|
}
|
|
|
|
if datasetID == "" {
|
|
return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`)
|
|
}
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, errors.New("No authorization.")
|
|
}
|
|
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
|
}
|
|
return nil, common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
|
|
taskType := indexTypeToTaskType[indexType]
|
|
displayName := indexTypeToDisplayName[indexType]
|
|
|
|
documents, code, err := s.getDocumentsByDatasetForIndex(datasetID)
|
|
if err != nil {
|
|
return nil, code, err
|
|
}
|
|
_ = documents
|
|
|
|
sampleDocument := documents[0]
|
|
documentIDs := make([]string, len(documents))
|
|
|
|
for i, doc := range documents {
|
|
documentIDs[i] = doc.ID
|
|
}
|
|
|
|
task, queueMessage, err := s.newRaptorOrGraphRagTask(sampleDocument, taskType, sampleDocument.ID, graphRaptorQueueDocID, documentIDs)
|
|
if err != nil {
|
|
common.Warn("Failed to build dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err))
|
|
return nil, common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
|
|
var updatedDocument *entity.Document
|
|
var dataErr error
|
|
err = dao.DB.Transaction(func(tx *gorm.DB) error {
|
|
var lockedKB entity.Knowledgebase
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("id = ? AND status = ?", kb.ID, string(entity.StatusValid)).
|
|
First(&lockedKB).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
existingTaskID := datasetIndexTaskID(&lockedKB, indexType)
|
|
if existingTaskID != "" {
|
|
var existingTask entity.Task
|
|
taskErr := tx.Where("id = ?", existingTaskID).First(&existingTask).Error
|
|
if taskErr != nil {
|
|
if errors.Is(taskErr, gorm.ErrRecordNotFound) {
|
|
} else {
|
|
return taskErr
|
|
}
|
|
} else if existingTask.Progress != 1 && existingTask.Progress != -1 {
|
|
dataErr = fmt.Errorf("Task %s in progress with status %v. A %s Task is already running.", existingTaskID, existingTask.Progress, displayName)
|
|
return dataErr
|
|
}
|
|
}
|
|
|
|
updatedDocument, err = createDatasetIndexTaskInTx(tx, task, graphRaptorQueueDocID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return tx.Model(&entity.Knowledgebase{}).Where("id = ?", lockedKB.ID).Updates(datasetIndexTaskIDUpdate(indexType, task.ID)).Error
|
|
})
|
|
if err != nil {
|
|
if dataErr != nil {
|
|
return nil, common.CodeDataError, dataErr
|
|
}
|
|
common.Warn("Failed to create dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err))
|
|
return nil, common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
|
|
if err := enqueueDatasetIndexTask(0, queueMessage); err != nil {
|
|
if cleanupErr := cleanupFailedDatasetIndexTask(task.ID, updatedDocument, kb.ID, indexType); cleanupErr != nil {
|
|
err = errors.Join(err, cleanupErr)
|
|
}
|
|
common.Warn("Failed to queue dataset index task", zap.String("dataset_id", datasetID), zap.String("task_type", taskType), zap.Error(err))
|
|
return nil, common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
|
|
return map[string]interface{}{"task_id": task.ID}, common.CodeSuccess, nil
|
|
}
|
|
|
|
func (s *DatasetService) getDocumentsByDatasetForIndex(datasetID string) ([]*entity.Document, common.ErrorCode, error) {
|
|
documents, _, err := s.documentDAO.GetByKBID(datasetID)
|
|
if err != nil {
|
|
common.Warn("Failed to load dataset documents for index", zap.String("dataset_id", datasetID), zap.Error(err))
|
|
return nil, common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
if len(documents) == 0 {
|
|
return nil, common.CodeDataError, fmt.Errorf("No documents in Dataset %s", datasetID)
|
|
}
|
|
return documents, common.CodeSuccess, nil
|
|
}
|
|
|
|
// TraceIndexRequest is the request payload for tracing an index task; the
// required "type" field carries the index type.
type TraceIndexRequest struct {
	Type string `json:"type" binding:"required"`
}
|
|
|
|
// TraceIndex Trace an indexing task (graph/raptor/mindmap) for a dataset.
|
|
func (s *DatasetService) TraceIndex(datasetID, userID, indexType string) (*entity.Task, common.ErrorCode, error) {
|
|
if !checkType(indexType) {
|
|
return nil, common.CodeDataError, fmt.Errorf("Invalid index type '%s'. Must be one of %v", indexType, validIndexTypes)
|
|
}
|
|
|
|
if datasetID == "" {
|
|
return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`)
|
|
}
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, errors.New("No authorization.")
|
|
}
|
|
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
|
}
|
|
return nil, common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
|
|
taskID := datasetIndexTaskID(kb, indexType)
|
|
|
|
var task *entity.Task
|
|
if taskID != "" {
|
|
task, err = s.taskDAO.GetByID(taskID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeSuccess, nil
|
|
}
|
|
return nil, common.CodeServerError, errors.New("Internal server error")
|
|
}
|
|
if task == nil {
|
|
return nil, common.CodeSuccess, nil
|
|
}
|
|
}
|
|
|
|
return task, common.CodeSuccess, nil
|
|
}
|
|
|
|
// CheckEmbeddingRequest is the request payload of CheckEmbedding.
type CheckEmbeddingRequest struct {
	// EmbeddingID identifies the embedding model to check (required).
	EmbeddingID string `json:"embd_id" binding:"required"`
	// CheckNum is the optional number of chunks to sample; CheckEmbedding
	// falls back to a default when it is nil or not positive.
	CheckNum *int `json:"check_num,omitempty"`
}
|
|
|
|
// EmbeddingCheckSummary is the aggregate section of an embedding check
// response. CheckEmbedding compares AvgCosSim against 0.9 to decide whether the
// check passes.
type EmbeddingCheckSummary struct {
	KbID      string  `json:"kb_id"`
	Model     string  `json:"model"`
	Sampled   int     `json:"sampled"`
	Valid     int     `json:"valid"`
	AvgCosSim float64 `json:"avg_cos_sim"`
	MinCosSim float64 `json:"min_cos_sim"`
	MaxCosSim float64 `json:"max_cos_sim"`
	MatchMode string  `json:"match_mode"`
}
|
|
|
|
// EmbeddingCheckResult is the per-chunk entry of an embedding check response.
// For chunks CheckEmbedding skips, only ChunkID and Reason ("no_text" or
// "no_stored_vector") are set.
type EmbeddingCheckResult struct {
	ChunkID     string  `json:"chunk_id"`
	DocID       string  `json:"doc_id,omitempty"`
	DocName     string  `json:"doc_name,omitempty"`
	VectorField string  `json:"vector_field,omitempty"`
	VectorDim   int     `json:"vector_dim,omitempty"`
	CosSim      float64 `json:"cos_sim,omitempty"`
	Reason      string  `json:"reason,omitempty"`
}
|
|
|
|
type EmbeddingCheckResponse struct {
|
|
Summary EmbeddingCheckSummary `json:"summary"`
|
|
Results []EmbeddingCheckResult `json:"results"`
|
|
}
|
|
|
|
// embeddingCheckSample is one stored chunk sampled for an embedding check,
// holding its identifiers, stored vector and the text fields CheckEmbedding
// re-embeds.
type embeddingCheckSample struct {
	ChunkID, KbID, DocID, DocName, VectorField string

	// Vector is the stored embedding of the chunk.
	Vector []float64

	PageNum, Position, Top interface{}

	// ContentWithWeight is used as the text to embed when QuestionKeywords
	// joins to a blank string.
	ContentWithWeight string
	QuestionKeywords  []string
}
|
|
|
|
// datasetParsePageRange is a from/to page pair; buildDatasetParseTasks copies
// it into a parse task's FromPage and ToPage.
type datasetParsePageRange struct {
	from, to int64
}
|
|
|
|
// RunEmbedding runs embedding for all documents in a dataset.
|
|
func (s *DatasetService) RunEmbedding(userID, datasetID string) (map[string]interface{}, common.ErrorCode, error) {
|
|
if datasetID == "" {
|
|
return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`)
|
|
}
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, errors.New("No authorization.")
|
|
}
|
|
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
|
}
|
|
return nil, common.CodeServerError, errors.New("Internal server error")
|
|
}
|
|
|
|
documents, _, err := s.documentDAO.GetByKBID(datasetID)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, errors.New("Internal server error")
|
|
}
|
|
if len(documents) == 0 {
|
|
return nil, common.CodeDataError, fmt.Errorf("No documents in Dataset %s", datasetID)
|
|
}
|
|
|
|
tableDoneCountByKB := make(map[string]int64)
|
|
scheduledCount := 0
|
|
for _, doc := range documents {
|
|
if doc == nil {
|
|
continue
|
|
}
|
|
if err := s.runEmbeddingDocument(kb, doc, tableDoneCountByKB); err != nil {
|
|
common.Warn("Failed to schedule dataset embedding document",
|
|
zap.String("datasetID", datasetID),
|
|
zap.String("docID", doc.ID),
|
|
zap.Error(err))
|
|
return nil, common.CodeServerError, errors.New("Internal server error")
|
|
}
|
|
scheduledCount++
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"scheduled_count": scheduledCount,
|
|
}, common.CodeSuccess, nil
|
|
}
|
|
|
|
func (s *DatasetService) runEmbeddingDocument(kb *entity.Knowledgebase, doc *entity.Document, tableDoneCountByKB map[string]int64) error {
|
|
if doc.PipelineID != nil && strings.TrimSpace(*doc.PipelineID) != "" {
|
|
return s.queueDatasetDataflowTask(kb, doc, strings.TrimSpace(*doc.PipelineID), 0)
|
|
}
|
|
|
|
if doc.ParserID == string(entity.ParserTypeTable) {
|
|
doneCount, ok := tableDoneCountByKB[doc.KbID]
|
|
if !ok {
|
|
count, err := s.countDoneDocuments(doc.KbID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
doneCount = count
|
|
tableDoneCountByKB[doc.KbID] = doneCount
|
|
if doneCount <= 0 {
|
|
if err := s.kbDAO.DeleteFieldMap(doc.KbID); err != nil && !dao.IsNotFoundErr(err) {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
|
if s.docEngine != nil {
|
|
if _, err := s.docEngine.DeleteChunks(context.Background(), map[string]interface{}{"doc_id": doc.ID}, indexName, doc.KbID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil {
|
|
return err
|
|
}
|
|
|
|
bucket, objectName, err := NewDocumentService().GetDocumentStorageAddress(doc)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := s.queueDatasetParseTasks(doc, bucket, objectName, 0); err != nil {
|
|
return err
|
|
}
|
|
if err := s.beginDatasetParseDocument(doc.ID); err != nil {
|
|
if _, delErr := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); delErr != nil {
|
|
common.Warn("Failed to clean parse tasks after document state update failure",
|
|
zap.String("docID", doc.ID),
|
|
zap.Error(delErr))
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *DatasetService) queueDatasetDataflowTask(kb *entity.Knowledgebase, doc *entity.Document, flowID string, priority int64) error {
|
|
if _, err := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); err != nil {
|
|
return err
|
|
}
|
|
if err := s.beginDatasetParseDocument(doc.ID); err != nil {
|
|
return err
|
|
}
|
|
|
|
now := time.Now()
|
|
task := &entity.Task{
|
|
ID: common.GenerateUUID(),
|
|
DocID: doc.ID,
|
|
FromPage: 0,
|
|
ToPage: maximumTaskPageNumber,
|
|
TaskType: "dataflow",
|
|
Priority: priority,
|
|
BeginAt: &now,
|
|
Progress: 0,
|
|
}
|
|
if err := s.taskDAO.CreateMany([]*entity.Task{task}); err != nil {
|
|
return err
|
|
}
|
|
|
|
message := datasetParseTaskMessage(task)
|
|
message["task_type"] = task.TaskType
|
|
message["kb_id"] = doc.KbID
|
|
message["tenant_id"] = kb.TenantID
|
|
message["dataflow_id"] = flowID
|
|
message["file"] = nil
|
|
if redisClient := redisengine.Get(); redisClient == nil || !redisClient.QueueProduct(datasetParseQueueName(doc, priority), message) {
|
|
return fmt.Errorf("Can't access Redis. Please check the Redis' status.")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *DatasetService) countDoneDocuments(datasetID string) (int64, error) {
|
|
var count int64
|
|
err := dao.GetDB().Model(&entity.Document{}).
|
|
Where("kb_id = ? AND run = ?", datasetID, string(entity.TaskStatusDone)).
|
|
Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
func (s *DatasetService) queueDatasetParseTasks(doc *entity.Document, bucket, objectName string, priority int64) error {
|
|
tasks, err := s.buildDatasetParseTasks(doc, bucket, objectName, priority)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(tasks) == 0 {
|
|
return nil
|
|
}
|
|
if err := s.taskDAO.CreateMany(tasks); err != nil {
|
|
return err
|
|
}
|
|
queueName := datasetParseQueueName(doc, priority)
|
|
for _, task := range tasks {
|
|
if task.Progress >= 1 {
|
|
continue
|
|
}
|
|
if redisClient := redisengine.Get(); redisClient == nil || !redisClient.QueueProduct(queueName, datasetParseTaskMessage(task)) {
|
|
if _, delErr := s.taskDAO.DeleteByDocIDs([]string{doc.ID}); delErr != nil {
|
|
common.Warn("Failed to clean parse tasks after Redis enqueue failure",
|
|
zap.String("docID", doc.ID),
|
|
zap.Error(delErr))
|
|
}
|
|
return fmt.Errorf("Can't access Redis. Please check the Redis' status.")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *DatasetService) buildDatasetParseTasks(doc *entity.Document, bucket, objectName string, priority int64) ([]*entity.Task, error) {
|
|
ranges, err := datasetParseTaskRanges(doc, bucket, objectName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
now := time.Now()
|
|
tasks := make([]*entity.Task, 0, len(ranges))
|
|
for _, pageRange := range ranges {
|
|
progressMsg := ""
|
|
digest := datasetParseTaskDigest(doc, pageRange.from, pageRange.to)
|
|
chunkIDs := ""
|
|
tasks = append(tasks, &entity.Task{
|
|
ID: common.GenerateUUID(),
|
|
DocID: doc.ID,
|
|
FromPage: pageRange.from,
|
|
ToPage: pageRange.to,
|
|
TaskType: "",
|
|
Priority: priority,
|
|
BeginAt: &now,
|
|
Progress: 0,
|
|
ProgressMsg: &progressMsg,
|
|
Digest: &digest,
|
|
ChunkIDs: &chunkIDs,
|
|
})
|
|
}
|
|
return tasks, nil
|
|
}
|
|
|
|
func (s *DatasetService) beginDatasetParseDocument(docID string) error {
|
|
now := time.Now()
|
|
return dao.GetDB().Model(&entity.Document{}).Where("id = ?", docID).Updates(map[string]interface{}{
|
|
"progress_msg": "Task is queued...",
|
|
"process_begin_at": now,
|
|
"progress": rand.Float64() * 0.01,
|
|
"run": string(entity.TaskStatusRunning),
|
|
"chunk_num": 0,
|
|
"token_num": 0,
|
|
}).Error
|
|
}
|
|
|
|
// CheckEmbedding checks whether a new embedding model is compatible with stored vectors.
|
|
func (s *DatasetService) CheckEmbedding(userID, datasetID string, req *CheckEmbeddingRequest) (*EmbeddingCheckResponse, common.ErrorCode, error) {
|
|
if datasetID == "" {
|
|
return nil, common.CodeDataError, errors.New(`Lack of "Dataset ID"`)
|
|
}
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, errors.New("No authorization.")
|
|
}
|
|
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
|
}
|
|
return nil, common.CodeServerError, errors.New("Internal server error")
|
|
}
|
|
|
|
if req == nil || strings.TrimSpace(req.EmbeddingID) == "" {
|
|
return nil, common.CodeDataError, errors.New("`embd_id` is required.")
|
|
}
|
|
embeddingID := strings.TrimSpace(req.EmbeddingID)
|
|
if ok, message := s.verifyEmbeddingAvailability(embeddingID, userID); !ok {
|
|
return nil, common.CodeDataError, errors.New(message)
|
|
}
|
|
if s.docEngine == nil {
|
|
return nil, common.CodeServerError, errors.New("doc engine not initialized")
|
|
}
|
|
|
|
driver, modelName, apiConfig, maxTokens, err := NewModelProviderService().GetModelConfigFromProviderInstance(kb.TenantID, entity.ModelTypeEmbedding, embeddingID)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
embeddingModel := models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
|
|
|
checkNum := defaultEmbeddingCheckNum
|
|
if req.CheckNum != nil {
|
|
checkNum = *req.CheckNum
|
|
}
|
|
if checkNum <= 0 {
|
|
checkNum = defaultEmbeddingCheckNum
|
|
}
|
|
|
|
samples, err := s.sampleRandomChunksWithVectors(context.Background(), kb.TenantID, datasetID, checkNum)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, err
|
|
}
|
|
|
|
results := make([]EmbeddingCheckResult, 0, len(samples))
|
|
effectiveSimilarities := make([]float64, 0, len(samples))
|
|
matchMode := "content_only"
|
|
for _, sample := range samples {
|
|
title := sample.DocName
|
|
if strings.TrimSpace(title) == "" {
|
|
title = "Title"
|
|
}
|
|
|
|
textInput := strings.Join(sample.QuestionKeywords, "\n")
|
|
if strings.TrimSpace(textInput) == "" {
|
|
textInput = sample.ContentWithWeight
|
|
}
|
|
textInput = datasetCleanEmbeddingText(textInput)
|
|
if textInput == "" {
|
|
results = append(results, EmbeddingCheckResult{ChunkID: sample.ChunkID, Reason: "no_text"})
|
|
continue
|
|
}
|
|
if len(sample.Vector) == 0 {
|
|
results = append(results, EmbeddingCheckResult{ChunkID: sample.ChunkID, Reason: "no_stored_vector"})
|
|
continue
|
|
}
|
|
|
|
vectors, err := datasetEncodeEmbedding(embeddingModel, []string{title, textInput})
|
|
if err != nil {
|
|
return nil, common.CodeDataError, fmt.Errorf("Embedding failure. %w", err)
|
|
}
|
|
if len(vectors) < 2 {
|
|
return nil, common.CodeDataError, errors.New("Embedding failure. embedding response is incomplete")
|
|
}
|
|
if len(vectors[1]) != len(sample.Vector) {
|
|
return nil, common.CodeDataError, fmt.Errorf("Embedding failure. The dimension (%d) of given embedding model is different from the original (%d)", len(vectors[1]), len(sample.Vector))
|
|
}
|
|
|
|
simContent := datasetCosSim(vectors[1], sample.Vector)
|
|
simMix := datasetCosSim(datasetMixVectors(vectors[0], vectors[1], 0.1), sample.Vector)
|
|
sim := simContent
|
|
matchMode = "content_only"
|
|
if simMix > sim {
|
|
sim = simMix
|
|
matchMode = "title+content"
|
|
}
|
|
sim = datasetRoundFloat(sim, 6)
|
|
|
|
effectiveSimilarities = append(effectiveSimilarities, sim)
|
|
results = append(results, EmbeddingCheckResult{
|
|
ChunkID: sample.ChunkID,
|
|
DocID: sample.DocID,
|
|
DocName: sample.DocName,
|
|
VectorField: sample.VectorField,
|
|
VectorDim: len(sample.Vector),
|
|
CosSim: sim,
|
|
})
|
|
}
|
|
|
|
summary := datasetEmbeddingCheckSummary(datasetID, embeddingID, len(samples), effectiveSimilarities, matchMode)
|
|
response := &EmbeddingCheckResponse{Summary: summary, Results: results}
|
|
if len(effectiveSimilarities) == 0 {
|
|
return nil, common.CodeDataError, errors.New("No embedded chunks are available to compare.")
|
|
}
|
|
if summary.AvgCosSim >= 0.9 {
|
|
return response, common.CodeSuccess, nil
|
|
}
|
|
return response, common.CodeNotEffective, errors.New("Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.")
|
|
}
|
|
|
|
func (s *DatasetService) sampleRandomChunksWithVectors(ctx context.Context, tenantID, datasetID string, n int) ([]embeddingCheckSample, error) {
|
|
indexName := fmt.Sprintf("ragflow_%s", tenantID)
|
|
totalResult, err := s.docEngine.Search(ctx, &enginetypes.SearchRequest{
|
|
IndexNames: []string{indexName},
|
|
KbIDs: []string{datasetID},
|
|
Offset: 0,
|
|
Limit: 1,
|
|
Filter: map[string]interface{}{
|
|
"kb_id": datasetID,
|
|
"available_int": 1,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if totalResult == nil || totalResult.Total <= 0 {
|
|
return []embeddingCheckSample{}, nil
|
|
}
|
|
|
|
total := int(totalResult.Total)
|
|
// Cap n to a sane upper bound so a hostile caller can't force a
|
|
// huge preallocation. The downstream `samples` slice is sized
|
|
// directly from n.
|
|
const maxEmbeddingSamples = 1024
|
|
if n < 0 {
|
|
return nil, fmt.Errorf("invalid sample size: %d", n)
|
|
}
|
|
if n > maxEmbeddingSamples {
|
|
n = maxEmbeddingSamples
|
|
}
|
|
if n > total {
|
|
n = total
|
|
}
|
|
limit := total
|
|
if limit > 1000 {
|
|
limit = 1000
|
|
}
|
|
if n > limit {
|
|
n = limit
|
|
}
|
|
offsets := rand.Perm(limit)
|
|
offsets = offsets[:n]
|
|
sort.Ints(offsets)
|
|
|
|
baseFields := []string{"docnm_kwd", "doc_id", "content_with_weight", "page_num_int", "position_int", "top_int"}
|
|
// codeql[go/uncontrolled-allocation-size] False positive: n is
|
|
// bounded to maxEmbeddingSamples (1024) at the top of this
|
|
// function, so the samples slice cannot exceed ~1 MiB
|
|
// (embeddingCheckSample is a small struct).
|
|
samples := make([]embeddingCheckSample, 0, n)
|
|
for _, offset := range offsets {
|
|
searchResult, err := s.docEngine.Search(ctx, &enginetypes.SearchRequest{
|
|
IndexNames: []string{indexName},
|
|
KbIDs: []string{datasetID},
|
|
Offset: offset,
|
|
Limit: 1,
|
|
SelectFields: baseFields,
|
|
Filter: map[string]interface{}{
|
|
"kb_id": datasetID,
|
|
"available_int": 1,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if searchResult == nil || len(searchResult.Chunks) == 0 {
|
|
continue
|
|
}
|
|
chunkID := datasetChunkID(searchResult.Chunks[0])
|
|
if chunkID == "" {
|
|
continue
|
|
}
|
|
fullChunk, err := s.docEngine.GetChunk(ctx, indexName, chunkID, []string{datasetID})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
chunkMap := datasetMap(fullChunk)
|
|
if len(chunkMap) == 0 {
|
|
continue
|
|
}
|
|
vectorField := datasetGuessVecField(chunkMap)
|
|
vector := datasetAsFloatVec(chunkMap[vectorField])
|
|
samples = append(samples, embeddingCheckSample{
|
|
ChunkID: chunkID,
|
|
KbID: datasetID,
|
|
DocID: datasetString(chunkMap["doc_id"]),
|
|
DocName: datasetString(chunkMap["docnm_kwd"]),
|
|
VectorField: vectorField,
|
|
Vector: vector,
|
|
PageNum: chunkMap["page_num_int"],
|
|
Position: chunkMap["position_int"],
|
|
Top: chunkMap["top_int"],
|
|
ContentWithWeight: datasetString(chunkMap["content_with_weight"]),
|
|
QuestionKeywords: datasetStringSlice(chunkMap["question_kwd"]),
|
|
})
|
|
}
|
|
return samples, nil
|
|
}
|
|
|
|
// datasetGuessVecField returns the name of the key in src that ends with
// "_vec", which is how the stored embedding vector field is recognised. It
// returns "" when src has no such key.
//
// Go randomises map iteration order, so returning the first match found would
// make the result vary between calls whenever several keys carry the suffix.
// The lexicographically smallest matching key is returned instead, which keeps
// the choice stable.
func datasetGuessVecField(src map[string]interface{}) string {
	best := ""
	for key := range src {
		if !strings.HasSuffix(key, "_vec") {
			continue
		}
		// best == "" means "nothing found yet": a matching key is never empty.
		if best == "" || key < best {
			best = key
		}
	}
	return best
}
|
|
|
|
// datasetAsFloatVec converts a stored vector value into a []float64.
//
// Supported inputs:
//   - string: tab-separated numbers; surrounding whitespace on each part is
//     ignored, and empty or unparsable parts are skipped
//   - []float64: returned as-is (not copied)
//   - []float32 and []int: converted element by element
//   - []interface{}: elements of type float64, float32, int, int64, int32,
//     json.Number and numeric string are converted; anything else is skipped
//
// Any other input, including nil, yields an empty non-nil slice. Because
// unparsable elements are skipped, the result may be shorter than the input.
func datasetAsFloatVec(v interface{}) []float64 {
	switch val := v.(type) {
	case nil:
		return []float64{}
	case string:
		parts := strings.Split(val, "\t")
		res := make([]float64, 0, len(parts))
		for _, part := range parts {
			part = strings.TrimSpace(part)
			if part == "" {
				continue
			}
			if f, err := strconv.ParseFloat(part, 64); err == nil {
				res = append(res, f)
			}
		}
		return res
	case []float64:
		return val
	case []float32:
		res := make([]float64, len(val))
		for i, x := range val {
			res[i] = float64(x)
		}
		return res
	case []int:
		res := make([]float64, len(val))
		for i, x := range val {
			res[i] = float64(x)
		}
		return res
	case []interface{}:
		res := make([]float64, 0, len(val))
		for _, x := range val {
			switch n := x.(type) {
			case float64:
				res = append(res, n)
			case float32:
				res = append(res, float64(n))
			case int:
				res = append(res, float64(n))
			case int64:
				res = append(res, float64(n))
			case int32:
				res = append(res, float64(n))
			case json.Number:
				// Produced when JSON is decoded with Decoder.UseNumber.
				if f, err := n.Float64(); err == nil {
					res = append(res, f)
				}
			case string:
				if f, err := strconv.ParseFloat(strings.TrimSpace(n), 64); err == nil {
					res = append(res, f)
				}
			}
		}
		return res
	}
	return []float64{}
}
|
|
|
|
// datasetCosSim returns the cosine similarity of a and b.
//
// The dot product runs over the shared prefix of the two vectors while each
// norm uses the full vector, which is equivalent to zero-padding the shorter
// one. The result is 0 when either vector is empty or has zero magnitude.
func datasetCosSim(a, b []float64) float64 {
	if len(a) == 0 || len(b) == 0 {
		return 0
	}

	shared := len(a)
	if len(b) < shared {
		shared = len(b)
	}

	var dot float64
	for i := 0; i < shared; i++ {
		dot += a[i] * b[i]
	}

	var normSqA float64
	for i := range a {
		normSqA += a[i] * a[i]
	}
	var normSqB float64
	for i := range b {
		normSqB += b[i] * b[i]
	}

	if normSqA == 0 || normSqB == 0 {
		return 0
	}
	return dot / (math.Sqrt(normSqA) * math.Sqrt(normSqB))
}
|
|
|
|
// datasetEmbeddingTableTagPattern matches opening and closing table, td,
// caption, tr and th tags, optionally carrying up to 12 characters of
// attributes after a single space. It is compiled once at package scope
// instead of on every call.
var datasetEmbeddingTableTagPattern = regexp.MustCompile(`</?(table|td|caption|tr|th)( [^<>]{0,12})?>`)

// datasetCleanEmbeddingText replaces every table-related HTML tag in s with a
// single space and trims leading and trailing whitespace from the result.
// Other markup is left untouched.
func datasetCleanEmbeddingText(s string) string {
	return strings.TrimSpace(datasetEmbeddingTableTagPattern.ReplaceAllString(s, " "))
}
|
|
|
|
func datasetEncodeEmbedding(embeddingModel *models.EmbeddingModel, texts []string) ([][]float64, error) {
|
|
embeddingConfig := &models.EmbeddingConfig{Dimension: 0}
|
|
embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, texts, embeddingModel.APIConfig, embeddingConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vectors := make([][]float64, len(embeddings))
|
|
for i, embedding := range embeddings {
|
|
vectors[i] = embedding.Embedding
|
|
}
|
|
return vectors, nil
|
|
}
|
|
|
|
// datasetMixVectors blends a title vector into a content vector:
//
//	mixed[i] = titleWeight*titleVector[i] + (1-titleWeight)*contentVector[i]
//
// When the two vectors differ in length no blending is possible and
// contentVector itself is returned.
func datasetMixVectors(titleVector, contentVector []float64, titleWeight float64) []float64 {
	dim := len(contentVector)
	if len(titleVector) != dim {
		return contentVector
	}

	contentWeight := 1 - titleWeight
	blended := make([]float64, 0, dim)
	for i, content := range contentVector {
		blended = append(blended, titleWeight*titleVector[i]+contentWeight*content)
	}
	return blended
}
|
|
|
|
func datasetEmbeddingCheckSummary(datasetID, embeddingID string, sampled int, similarities []float64, matchMode string) EmbeddingCheckSummary {
|
|
summary := EmbeddingCheckSummary{
|
|
KbID: datasetID,
|
|
Model: embeddingID,
|
|
Sampled: sampled,
|
|
Valid: len(similarities),
|
|
MatchMode: matchMode,
|
|
}
|
|
if len(similarities) == 0 {
|
|
return summary
|
|
}
|
|
minValue := similarities[0]
|
|
maxValue := similarities[0]
|
|
total := 0.0
|
|
for _, value := range similarities {
|
|
total += value
|
|
if value < minValue {
|
|
minValue = value
|
|
}
|
|
if value > maxValue {
|
|
maxValue = value
|
|
}
|
|
}
|
|
summary.AvgCosSim = datasetRoundFloat(total/float64(len(similarities)), 6)
|
|
summary.MinCosSim = datasetRoundFloat(minValue, 6)
|
|
summary.MaxCosSim = datasetRoundFloat(maxValue, 6)
|
|
return summary
|
|
}
|
|
|
|
// datasetRoundFloat rounds value to the given number of decimal places.
// Halves are rounded away from zero, as math.Round does.
func datasetRoundFloat(value float64, places int) float64 {
	scale := math.Pow10(places)
	scaled := math.Round(value * scale)
	return scaled / scale
}
|
|
|
|
func datasetChunkID(chunk map[string]interface{}) string {
|
|
for _, key := range []string{"id", "_id"} {
|
|
if value := datasetString(chunk[key]); value != "" {
|
|
return value
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// datasetMap returns value as a map[string]interface{} when it has exactly
// that dynamic type. Every other value, including nil, yields a new empty
// map, so the result of a failed conversion is never nil.
func datasetMap(value interface{}) map[string]interface{} {
	if fields, ok := value.(map[string]interface{}); ok {
		return fields
	}
	return map[string]interface{}{}
}
|
|
|
|
// datasetString renders an arbitrary value as a string: nil becomes "",
// strings are returned unchanged, fmt.Stringer values use their String
// method, and everything else is formatted with fmt.Sprint.
func datasetString(value interface{}) string {
	if value == nil {
		return ""
	}
	if text, ok := value.(string); ok {
		return text
	}
	if stringer, ok := value.(fmt.Stringer); ok {
		return stringer.String()
	}
	return fmt.Sprint(value)
}
|
|
|
|
func datasetStringSlice(value interface{}) []string {
|
|
switch typedValue := value.(type) {
|
|
case []string:
|
|
return typedValue
|
|
case []interface{}:
|
|
values := make([]string, 0, len(typedValue))
|
|
for _, item := range typedValue {
|
|
if s := strings.TrimSpace(datasetString(item)); s != "" {
|
|
values = append(values, s)
|
|
}
|
|
}
|
|
return values
|
|
case string:
|
|
if typedValue == "" {
|
|
return nil
|
|
}
|
|
return []string{typedValue}
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func datasetParseQueueName(doc *entity.Document, priority int64) string {
|
|
suffix := "common"
|
|
if doc.ParserID == string(entity.ParserTypeResume) {
|
|
suffix = "resume"
|
|
}
|
|
return fmt.Sprintf("%s.%d.%s", serverQueueNamePrefix, priority, suffix)
|
|
}
|
|
|
|
func datasetParseTaskMessage(task *entity.Task) map[string]interface{} {
|
|
beginAt := ""
|
|
if task.BeginAt != nil {
|
|
beginAt = task.BeginAt.Format("2006-01-02 15:04:05")
|
|
}
|
|
digest := ""
|
|
if task.Digest != nil {
|
|
digest = *task.Digest
|
|
}
|
|
return map[string]interface{}{
|
|
"id": task.ID,
|
|
"doc_id": task.DocID,
|
|
"from_page": task.FromPage,
|
|
"to_page": task.ToPage,
|
|
"progress": task.Progress,
|
|
"priority": task.Priority,
|
|
"begin_at": beginAt,
|
|
"digest": digest,
|
|
}
|
|
}
|
|
|
|
func datasetParseTaskDigest(doc *entity.Document, fromPage, toPage int64) string {
|
|
hasher := xxhash.New()
|
|
config := datasetChunkingConfigForDigest(doc)
|
|
keys := make([]string, 0, len(config))
|
|
for key := range config {
|
|
keys = append(keys, key)
|
|
}
|
|
sort.Strings(keys)
|
|
for _, key := range keys {
|
|
hasher.WriteString(datasetStableString(config[key]))
|
|
}
|
|
hasher.WriteString(doc.ID)
|
|
hasher.WriteString(strconv.FormatInt(fromPage, 10))
|
|
hasher.WriteString(strconv.FormatInt(toPage, 10))
|
|
return fmt.Sprintf("%x", hasher.Sum64())
|
|
}
|
|
|
|
func datasetChunkingConfigForDigest(doc *entity.Document) map[string]interface{} {
|
|
return map[string]interface{}{
|
|
"doc_id": doc.ID,
|
|
"kb_id": doc.KbID,
|
|
"parser_id": doc.ParserID,
|
|
"parser_config": datasetCopyParserConfigForDigest(doc.ParserConfig),
|
|
}
|
|
}
|
|
|
|
// datasetCopyParserConfigForDigest returns a shallow copy of config without
// the "raptor" and "graphrag" entries. The result is never nil, and values
// are shared with the input rather than deep-copied.
func datasetCopyParserConfigForDigest(config map[string]interface{}) map[string]interface{} {
	filtered := make(map[string]interface{}, len(config))
	for name, setting := range config {
		switch name {
		case "raptor", "graphrag":
			// Left out of the digest input.
		default:
			filtered[name] = setting
		}
	}
	return filtered
}
|
|
|
|
// datasetStableString renders value as its JSON encoding, which sorts map
// keys and is therefore stable across runs. Values that cannot be marshalled
// fall back to fmt.Sprint.
func datasetStableString(value interface{}) string {
	if encoded, err := json.Marshal(value); err == nil {
		return string(encoded)
	}
	return fmt.Sprint(value)
}
|
|
|
|
func datasetParseTaskRanges(doc *entity.Document, bucket, objectName string) ([]datasetParsePageRange, error) {
|
|
if doc.Type == "pdf" {
|
|
return datasetPDFParseTaskRanges(doc, bucket, objectName)
|
|
}
|
|
if doc.ParserID == string(entity.ParserTypeTable) {
|
|
return datasetTableParseTaskRanges(doc, bucket, objectName)
|
|
}
|
|
return []datasetParsePageRange{{from: 0, to: maximumTaskPageNumber}}, nil
|
|
}
|
|
|
|
func datasetPDFParseTaskRanges(doc *entity.Document, bucket, objectName string) ([]datasetParsePageRange, error) {
|
|
binary, err := datasetStorageBinary(bucket, objectName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pages := datasetEstimatePDFPageCount(binary)
|
|
pageSize := int64(datasetParserConfigInt(doc.ParserConfig, "task_page_size", 12))
|
|
if doc.ParserID == string(entity.ParserTypePaper) {
|
|
pageSize = int64(datasetParserConfigInt(doc.ParserConfig, "task_page_size", 22))
|
|
}
|
|
if doc.ParserID == string(entity.ParserTypeOne) ||
|
|
doc.ParserID == string(entity.ParserTypeKG) ||
|
|
datasetParserConfigString(doc.ParserConfig, "layout_recognize", "DeepDOC") != "DeepDOC" ||
|
|
datasetParserConfigBool(doc.ParserConfig, "toc_extraction", false) {
|
|
pageSize = maximumTaskPageNumber
|
|
}
|
|
if pageSize <= 0 {
|
|
pageSize = 12
|
|
}
|
|
|
|
pageRanges := datasetParserConfigPageRanges(doc.ParserConfig)
|
|
ranges := make([]datasetParsePageRange, 0)
|
|
for _, configuredRange := range pageRanges {
|
|
start := configuredRange.from - 1
|
|
if start < 0 {
|
|
start = 0
|
|
}
|
|
end := configuredRange.to - 1
|
|
if pages >= 0 && end > pages {
|
|
end = pages
|
|
}
|
|
for page := start; page < end; page += pageSize {
|
|
to := page + pageSize
|
|
if to > end {
|
|
to = end
|
|
}
|
|
ranges = append(ranges, datasetParsePageRange{from: page, to: to})
|
|
}
|
|
}
|
|
if len(ranges) == 0 {
|
|
ranges = append(ranges, datasetParsePageRange{from: 0, to: maximumTaskPageNumber})
|
|
}
|
|
return ranges, nil
|
|
}
|
|
|
|
func datasetTableParseTaskRanges(doc *entity.Document, bucket, objectName string) ([]datasetParsePageRange, error) {
|
|
binary, err := datasetStorageBinary(bucket, objectName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rows := datasetEstimateTableRowCount(datasetDocName(doc), binary)
|
|
if rows <= 0 {
|
|
return []datasetParsePageRange{{from: 0, to: maximumTaskPageNumber}}, nil
|
|
}
|
|
ranges := make([]datasetParsePageRange, 0, (rows+2999)/3000)
|
|
for row := int64(0); row < int64(rows); row += 3000 {
|
|
to := row + 3000
|
|
if to > int64(rows) {
|
|
to = int64(rows)
|
|
}
|
|
ranges = append(ranges, datasetParsePageRange{from: row, to: to})
|
|
}
|
|
return ranges, nil
|
|
}
|
|
|
|
func datasetStorageBinary(bucket, objectName string) ([]byte, error) {
|
|
storageImpl := storage.GetStorageFactory().GetStorage()
|
|
if storageImpl == nil {
|
|
return nil, fmt.Errorf("storage not initialized")
|
|
}
|
|
return storageImpl.Get(bucket, objectName)
|
|
}
|
|
|
|
func datasetDocName(doc *entity.Document) string {
|
|
if doc == nil || doc.Name == nil {
|
|
return ""
|
|
}
|
|
return *doc.Name
|
|
}
|
|
|
|
// datasetParserConfigInt reads an integer setting from a parser config.
//
// int, int64 and float64 values are converted directly (floats are
// truncated); json.Number and string values are parsed as integers, with
// surrounding whitespace ignored for strings. fallback is returned when the
// key is missing, the value is nil, the type is unsupported, or parsing
// fails.
func datasetParserConfigInt(config map[string]interface{}, key string, fallback int) int {
	// A missing key reads as a nil interface, so "case nil" covers both a
	// missing key and an explicit nil value.
	switch raw := config[key].(type) {
	case nil:
		return fallback
	case int:
		return raw
	case int64:
		return int(raw)
	case float64:
		return int(raw)
	case json.Number:
		parsed, err := raw.Int64()
		if err != nil {
			return fallback
		}
		return int(parsed)
	case string:
		parsed, err := strconv.Atoi(strings.TrimSpace(raw))
		if err != nil {
			return fallback
		}
		return parsed
	default:
		return fallback
	}
}
|
|
|
|
// datasetParserConfigString reads a string setting from a parser config.
// String values are returned unchanged, other non-nil values are formatted
// with fmt.Sprint, and fallback is returned when the key is missing or nil.
func datasetParserConfigString(config map[string]interface{}, key, fallback string) string {
	// A missing key reads as a nil interface, so "case nil" covers both a
	// missing key and an explicit nil value.
	switch raw := config[key].(type) {
	case nil:
		return fallback
	case string:
		return raw
	default:
		return fmt.Sprint(raw)
	}
}
|
|
|
|
// datasetParserConfigBool reads a boolean setting from a parser config.
//
// bool values are returned directly. Strings are compared case-insensitively
// after trimming: "true", "1", "yes" and "on" mean true; "false", "0", "no"
// and "off" mean false. fallback is returned when the key is missing, the
// value is nil, the string is not recognised, or the type is unsupported.
func datasetParserConfigBool(config map[string]interface{}, key string, fallback bool) bool {
	raw, present := config[key]
	if !present || raw == nil {
		return fallback
	}

	if flag, isBool := raw.(bool); isBool {
		return flag
	}
	if text, isString := raw.(string); isString {
		switch strings.ToLower(strings.TrimSpace(text)) {
		case "true", "1", "yes", "on":
			return true
		case "false", "0", "no", "off":
			return false
		}
	}
	return fallback
}
|
|
|
|
func datasetParserConfigPageRanges(config map[string]interface{}) []datasetParsePageRange {
|
|
defaultRanges := []datasetParsePageRange{{from: 1, to: maximumPageNumber}}
|
|
raw, ok := config["pages"]
|
|
if !ok || raw == nil {
|
|
return defaultRanges
|
|
}
|
|
rawRanges, ok := raw.([]interface{})
|
|
if !ok || len(rawRanges) == 0 {
|
|
return defaultRanges
|
|
}
|
|
|
|
ranges := make([]datasetParsePageRange, 0, len(rawRanges))
|
|
for _, rawRange := range rawRanges {
|
|
rangeValues, ok := rawRange.([]interface{})
|
|
if !ok || len(rangeValues) < 2 {
|
|
continue
|
|
}
|
|
from, okFrom := datasetToInt64(rangeValues[0])
|
|
to, okTo := datasetToInt64(rangeValues[1])
|
|
if okFrom && okTo && to > from {
|
|
ranges = append(ranges, datasetParsePageRange{from: from, to: to})
|
|
}
|
|
}
|
|
if len(ranges) == 0 {
|
|
return defaultRanges
|
|
}
|
|
return ranges
|
|
}
|
|
|
|
// datasetToInt64 converts a loosely typed value to int64 and reports whether
// the conversion succeeded.
//
// int, int64 and float64 always succeed (floats are truncated). json.Number
// and string values are parsed as base-10 integers, with surrounding
// whitespace ignored for strings; on a parse failure the parser's own result
// is returned with false. Any other type yields (0, false).
func datasetToInt64(value interface{}) (int64, bool) {
	switch v := value.(type) {
	case int64:
		return v, true
	case int:
		return int64(v), true
	case float64:
		return int64(v), true
	case string:
		parsed, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64)
		return parsed, err == nil
	case json.Number:
		parsed, err := v.Int64()
		return parsed, err == nil
	}
	return 0, false
}
|
|
|
|
// datasetPDFPagePattern matches a "/Type /Page" dictionary entry in raw PDF
// bytes. The trailing word boundary keeps "/Type /Pages" from matching.
var datasetPDFPagePattern = regexp.MustCompile(`/Type\s*/Page\b`)

// datasetEstimatePDFPageCount estimates the number of pages in a PDF by
// counting "/Type /Page" entries in its raw bytes. Empty input yields 0.
func datasetEstimatePDFPageCount(binary []byte) int64 {
	if len(binary) == 0 {
		return 0
	}
	pageMarkers := datasetPDFPagePattern.FindAll(binary, -1)
	return int64(len(pageMarkers))
}
|
|
|
|
func datasetEstimateTableRowCount(name string, binary []byte) int {
|
|
switch strings.ToLower(filepath.Ext(name)) {
|
|
case ".xlsx":
|
|
if rows, err := datasetCountXLSXRows(binary); err == nil {
|
|
return rows
|
|
}
|
|
case ".csv", ".tsv", ".txt":
|
|
return datasetCountDelimitedRows(name, binary)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// datasetCountDelimitedRows counts the records of a delimited text file.
//
// The content is parsed as CSV with a variable number of fields per record;
// files whose name ends in ".tsv" (case-insensitive) are tab-delimited and
// everything else is comma-delimited. Quoted fields spanning several lines
// therefore count as one record.
//
// When the content is not valid CSV, the count falls back to the number of
// physical lines in the whole input: the newline count, plus one if the last
// line is not newline-terminated. That fallback replaces the records counted
// before the parse error. Adding it to them, as an earlier version did,
// counted the leading rows twice.
func datasetCountDelimitedRows(name string, binary []byte) int {
	reader := csv.NewReader(bytes.NewReader(binary))
	reader.FieldsPerRecord = -1 // allow rows of differing width
	reader.ReuseRecord = true   // only counting, so records need not be kept
	if strings.EqualFold(filepath.Ext(name), ".tsv") {
		reader.Comma = '\t'
	}

	rows := 0
	for {
		_, err := reader.Read()
		if err == nil {
			rows++
			continue
		}
		if err == io.EOF {
			return rows
		}

		// Malformed CSV: discard the partial record count and count lines.
		lines := bytes.Count(binary, []byte{'\n'})
		if len(binary) > 0 && binary[len(binary)-1] != '\n' {
			lines++
		}
		return lines
	}
}
|
|
|
|
func datasetCountXLSXRows(binary []byte) (int, error) {
|
|
zipReader, err := zip.NewReader(bytes.NewReader(binary), int64(len(binary)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
maxRows := 0
|
|
for _, file := range zipReader.File {
|
|
if !strings.HasPrefix(file.Name, "xl/worksheets/") || !strings.HasSuffix(file.Name, ".xml") {
|
|
continue
|
|
}
|
|
rows, err := datasetCountWorksheetRows(file)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if rows > maxRows {
|
|
maxRows = rows
|
|
}
|
|
}
|
|
return maxRows, nil
|
|
}
|
|
|
|
// datasetCountWorksheetRows streams a worksheet XML entry and counts its
// <row> start elements, matched by local name at any depth. It returns an
// error when the entry cannot be opened or the XML is malformed.
func datasetCountWorksheetRows(file *zip.File) (int, error) {
	content, err := file.Open()
	if err != nil {
		return 0, err
	}
	defer content.Close()

	decoder := xml.NewDecoder(content)
	count := 0
	for {
		token, err := decoder.Token()
		switch {
		case err == io.EOF:
			return count, nil
		case err != nil:
			return 0, err
		}
		if element, isStart := token.(xml.StartElement); isStart && element.Name.Local == "row" {
			count++
		}
	}
}
|
|
|
|
func (s *DatasetService) DeleteIndex(userID, datasetID, indexType string, wipe bool) (common.ErrorCode, error) {
|
|
if !checkType(indexType) {
|
|
return common.CodeArgumentError, fmt.Errorf("Invalid index type '%s'", indexType)
|
|
}
|
|
|
|
if datasetID == "" {
|
|
return common.CodeDataError, errors.New(`Lack of "Dataset ID"`)
|
|
}
|
|
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return common.CodeDataError, errors.New("No authorization.")
|
|
}
|
|
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return common.CodeDataError, errors.New("Invalid Dataset ID")
|
|
}
|
|
return common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
|
|
taskIDField := datasetIndexTaskIDColumn(indexType)
|
|
taskFinishAtField := datasetIndexTaskFinishAtColumn(indexType)
|
|
taskID := datasetIndexTaskID(kb, indexType)
|
|
|
|
common.Info("delete_index", zap.String("dataset_id", datasetID), zap.String("index_type", indexType), zap.Bool("wipe", wipe))
|
|
|
|
if taskID != "" {
|
|
redisClient := redisengine.Get()
|
|
if redisClient == nil || !redisClient.Set(fmt.Sprintf("%s-cancel", taskID), "x", 0) {
|
|
common.Warn("Failed to set dataset index cancellation marker", zap.String("dataset_id", datasetID), zap.String("task_id", taskID))
|
|
}
|
|
if err := dao.DB.Unscoped().Where("id = ?", taskID).Delete(&entity.Task{}).Error; err != nil {
|
|
common.Warn("Failed to delete dataset index task", zap.String("dataset_id", datasetID), zap.String("task_id", taskID), zap.Error(err))
|
|
return common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
}
|
|
|
|
if wipe && indexType == "graph" {
|
|
if s.docEngine == nil {
|
|
return common.CodeServerError, errors.New("Document engine is not initialized")
|
|
}
|
|
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
|
_, err = s.docEngine.DeleteChunks(context.Background(), map[string]interface{}{
|
|
"knowledge_graph_kwd": interfaceSlice("graph", "subgraph", "entity", "relation", "community_report"),
|
|
"kb_id": datasetID,
|
|
}, indexName, datasetID)
|
|
if err != nil {
|
|
common.Warn("Failed to delete GraphRAG artefacts", zap.String("dataset_id", datasetID), zap.Error(err))
|
|
return common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
clearGraphPhaseMarkers(redisengine.Get(), datasetID)
|
|
common.Info("delete_index: cleared GraphRAG artefacts and phase markers", zap.String("dataset_id", datasetID))
|
|
} else if wipe && indexType == "raptor" {
|
|
if s.docEngine == nil {
|
|
return common.CodeServerError, errors.New("Document engine is not initialized")
|
|
}
|
|
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
|
_, err = s.docEngine.DeleteChunks(context.Background(), map[string]interface{}{
|
|
"raptor_kwd": interfaceSlice("raptor"),
|
|
"kb_id": datasetID,
|
|
}, indexName, datasetID)
|
|
if err != nil {
|
|
common.Warn("Failed to delete RAPTOR artefacts", zap.String("dataset_id", datasetID), zap.Error(err))
|
|
return common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
}
|
|
|
|
if err := dao.DB.Model(&entity.Knowledgebase{}).Where("id = ?", kb.ID).Updates(map[string]interface{}{
|
|
taskIDField: "",
|
|
taskFinishAtField: nil,
|
|
}).Error; err != nil {
|
|
common.Warn("Failed to clear dataset index task fields", zap.String("dataset_id", datasetID), zap.String("index_type", indexType), zap.Error(err))
|
|
return common.CodeDataError, errors.New("Internal server error")
|
|
}
|
|
|
|
return common.CodeSuccess, nil
|
|
}
|
|
|
|
// SearchDatasetsRequest is the request structure for searching chunks across datasets.
//
// DatasetIDs and Question are required. The pointer fields are optional, and
// SearchDatasets applies these defaults when they are nil: Page 1, Size 30,
// UseKG false, TopK 1024, Keyword false, SimilarityThreshold 0.0 and
// VectorSimilarityWeight 0.3. TopK is clamped to the range [1, 2048].
//
// When SearchID is set, SearchDatasets loads the saved search configuration
// and lets it override MetadataFilter, SimilarityThreshold,
// VectorSimilarityWeight, TopK, UseKG, CrossLanguages, Keyword and RerankID.
type SearchDatasetsRequest struct {
	DatasetIDs             []string               `json:"dataset_ids" binding:"required"`
	Question               string                 `json:"question" binding:"required"`
	Page                   *int                   `json:"page,omitempty"`
	Size                   *int                   `json:"size,omitempty"`
	DocIDs                 []string               `json:"doc_ids,omitempty"`
	UseKG                  *bool                  `json:"use_kg,omitempty"`
	TopK                   *int                   `json:"top_k,omitempty"`
	CrossLanguages         []string               `json:"cross_languages,omitempty"`
	SearchID               *string                `json:"search_id,omitempty"`
	MetadataFilter         map[string]interface{} `json:"meta_data_filter,omitempty"`
	RerankID               *string                `json:"rerank_id,omitempty"`
	Keyword                *bool                  `json:"keyword,omitempty"`
	SimilarityThreshold    *float64               `json:"similarity_threshold,omitempty"`
	VectorSimilarityWeight *float64               `json:"vector_similarity_weight,omitempty"`
}
|
|
|
|
// SearchDatasetsResponse is the response structure for dataset search results.
//
// Chunks and DocAggs are lists of loosely typed field maps, Labels is an
// optional map from label to score (serialised as null when nil), and Total
// is serialised as "total".
// NOTE(review): the exact meaning of each field is set by the code that fills
// the response, which is not visible here; confirm before relying on it.
type SearchDatasetsResponse struct {
	Chunks  []map[string]interface{} `json:"chunks"`
	DocAggs []map[string]interface{} `json:"doc_aggs"`
	Labels  *map[string]float64      `json:"labels"`
	Total   int64                    `json:"total"`
}
|
|
|
|
// SearchDatasetRequest is the request structure for searching chunks within one dataset.
//
// It mirrors SearchDatasetsRequest without the DatasetIDs field, because the
// dataset ID is supplied separately and added by ToSearchDatasetsRequest.
// Question carries no binding tag here; an empty question is rejected later
// by SearchDatasets.
type SearchDatasetRequest struct {
	Question               string                 `json:"question"`
	Page                   *int                   `json:"page,omitempty"`
	Size                   *int                   `json:"size,omitempty"`
	DocIDs                 []string               `json:"doc_ids,omitempty"`
	UseKG                  *bool                  `json:"use_kg,omitempty"`
	TopK                   *int                   `json:"top_k,omitempty"`
	CrossLanguages         []string               `json:"cross_languages,omitempty"`
	SearchID               *string                `json:"search_id,omitempty"`
	MetadataFilter         map[string]interface{} `json:"meta_data_filter,omitempty"`
	RerankID               *string                `json:"rerank_id,omitempty"`
	Keyword                *bool                  `json:"keyword,omitempty"`
	SimilarityThreshold    *float64               `json:"similarity_threshold,omitempty"`
	VectorSimilarityWeight *float64               `json:"vector_similarity_weight,omitempty"`
}
|
|
|
|
// ToSearchDatasetsRequest converts a single-dataset search request into the multi-dataset form.
|
|
func (req *SearchDatasetRequest) ToSearchDatasetsRequest(datasetID string) *SearchDatasetsRequest {
|
|
if req == nil {
|
|
return &SearchDatasetsRequest{DatasetIDs: []string{datasetID}}
|
|
}
|
|
return &SearchDatasetsRequest{
|
|
DatasetIDs: []string{datasetID},
|
|
Question: req.Question,
|
|
Page: req.Page,
|
|
Size: req.Size,
|
|
DocIDs: req.DocIDs,
|
|
UseKG: req.UseKG,
|
|
TopK: req.TopK,
|
|
CrossLanguages: req.CrossLanguages,
|
|
SearchID: req.SearchID,
|
|
MetadataFilter: req.MetadataFilter,
|
|
RerankID: req.RerankID,
|
|
Keyword: req.Keyword,
|
|
SimilarityThreshold: req.SimilarityThreshold,
|
|
VectorSimilarityWeight: req.VectorSimilarityWeight,
|
|
}
|
|
}
|
|
|
|
// SearchDataset searches chunks within one knowledge base based on a question.
|
|
func (s *DatasetService) SearchDataset(datasetID, userID string, req *SearchDatasetRequest) (*SearchDatasetsResponse, error) {
|
|
if datasetID == "" {
|
|
return nil, fmt.Errorf("dataset_id is required")
|
|
}
|
|
return s.SearchDatasets(req.ToSearchDatasetsRequest(datasetID), userID)
|
|
}
|
|
|
|
// SearchDatasets searches chunks across one or more knowledge bases based on a question.
|
|
// It retrieves relevant chunks using embedding and optional reranking, applying filters,
|
|
// cross-language translation, and keyword extraction as configured.
|
|
func (s *DatasetService) SearchDatasets(req *SearchDatasetsRequest, userID string) (*SearchDatasetsResponse, error) {
|
|
if req.Question == "" {
|
|
return nil, fmt.Errorf("question is required")
|
|
}
|
|
if len(req.DatasetIDs) == 0 {
|
|
return nil, fmt.Errorf("dataset_ids is required")
|
|
}
|
|
common.Info("SearchDatasets started", zap.String("userID", userID), zap.Any("datasets", req.DatasetIDs), zap.String("question", req.Question))
|
|
|
|
page := 1
|
|
if req.Page != nil {
|
|
page = *req.Page
|
|
}
|
|
pageSize := 30
|
|
if req.Size != nil {
|
|
pageSize = *req.Size
|
|
}
|
|
useKG := false
|
|
if req.UseKG != nil {
|
|
useKG = *req.UseKG
|
|
}
|
|
similarityThreshold := 0.0
|
|
if req.SimilarityThreshold != nil {
|
|
similarityThreshold = *req.SimilarityThreshold
|
|
}
|
|
vectorSimilarityWeight := 0.3
|
|
if req.VectorSimilarityWeight != nil {
|
|
vectorSimilarityWeight = *req.VectorSimilarityWeight
|
|
}
|
|
topK := 1024
|
|
if req.TopK != nil {
|
|
topK = *req.TopK
|
|
}
|
|
if topK < 1 {
|
|
topK = 1
|
|
} else if topK > 2048 {
|
|
topK = 2048
|
|
}
|
|
keyword := false
|
|
if req.Keyword != nil {
|
|
keyword = *req.Keyword
|
|
}
|
|
searchID := ""
|
|
if req.SearchID != nil {
|
|
searchID = *req.SearchID
|
|
}
|
|
rerankID := ""
|
|
if req.RerankID != nil {
|
|
rerankID = *req.RerankID
|
|
}
|
|
|
|
question := req.Question
|
|
datasetIDs := req.DatasetIDs
|
|
metadataFilter := req.MetadataFilter
|
|
crossLanguages := req.CrossLanguages
|
|
|
|
common.Debug(fmt.Sprintf("SearchDatasets request:\n"+
|
|
" datasetIDs=%v\n"+
|
|
" question=%s\n"+
|
|
" page=%v, pageSize=%v\n"+
|
|
" docIDs=%v\n"+
|
|
" useKG=%v, topK=%v\n"+
|
|
" crossLanguages=%v\n"+
|
|
" searchID=%v\n"+
|
|
" metadataFilter=%v\n"+
|
|
" rerankID=%v\n"+
|
|
" keyword=%v\n"+
|
|
" similarityThreshold=%v, vectorSimilarityWeight=%v",
|
|
datasetIDs, req.Question,
|
|
common.PtrString(req.Page), common.PtrString(req.Size), req.DocIDs,
|
|
useKG, topK, crossLanguages, searchID,
|
|
metadataFilter,
|
|
rerankID,
|
|
keyword,
|
|
similarityThreshold, vectorSimilarityWeight))
|
|
|
|
ctx := context.Background()
|
|
modelProviderSvc := NewModelProviderService()
|
|
|
|
// Access check for all datasets
|
|
var tenantIDs []string
|
|
var kbRecords []*entity.Knowledgebase
|
|
seenTenants := make(map[string]bool)
|
|
for _, datasetID := range datasetIDs {
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
common.Warn("SearchDatasets access denied", zap.String("datasetID", datasetID), zap.String("userID", userID))
|
|
return nil, fmt.Errorf("only owner of dataset %s is authorized for this operation", datasetID)
|
|
}
|
|
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil || kb == nil {
|
|
common.Warn("SearchDatasets dataset not found", zap.String("datasetID", datasetID))
|
|
return nil, fmt.Errorf("dataset %s not found", datasetID)
|
|
}
|
|
if !seenTenants[kb.TenantID] {
|
|
seenTenants[kb.TenantID] = true
|
|
tenantIDs = append(tenantIDs, kb.TenantID)
|
|
}
|
|
kbRecords = append(kbRecords, kb)
|
|
}
|
|
|
|
// Check if all kbs have the same embedding model
|
|
if len(kbRecords) > 1 {
|
|
firstEmbdID := kbRecords[0].EmbdID
|
|
for i := 1; i < len(kbRecords); i++ {
|
|
if kbRecords[i].EmbdID != firstEmbdID {
|
|
return nil, fmt.Errorf("Datasets use different embedding models.")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Override request fields with values from saved search config (if search_id is provided)
|
|
var chatID string
|
|
if searchID != "" {
|
|
if s.searchService == nil {
|
|
common.Warn("Search service is not initialized for search_id", zap.String("searchID", searchID))
|
|
return nil, fmt.Errorf("Invalid search_id")
|
|
}
|
|
searchDetail, err := s.searchService.GetDetail(searchID)
|
|
if err != nil || searchDetail == nil || len(searchDetail) == 0 {
|
|
common.Warn("Invalid search_id", zap.String("searchID", searchID), zap.Error(err))
|
|
return nil, fmt.Errorf("Invalid search_id")
|
|
} else if searchConfig, ok := searchDetail["search_config"].(map[string]interface{}); ok && searchConfig != nil {
|
|
if scMetadataFilter, ok := searchConfig["meta_data_filter"].(map[string]interface{}); ok {
|
|
metadataFilter = scMetadataFilter
|
|
}
|
|
if scST, ok := searchConfig["similarity_threshold"].(float64); ok {
|
|
similarityThreshold = scST
|
|
}
|
|
if scVSW, ok := searchConfig["vector_similarity_weight"].(float64); ok {
|
|
vectorSimilarityWeight = scVSW
|
|
}
|
|
if scTopK, ok := searchConfig["top_k"].(float64); ok {
|
|
topK = int(scTopK)
|
|
if topK < 1 {
|
|
topK = 1
|
|
} else if topK > 2048 {
|
|
topK = 2048
|
|
}
|
|
}
|
|
if scUseKG, ok := searchConfig["use_kg"].(bool); ok {
|
|
useKG = scUseKG
|
|
}
|
|
if scLangs, ok := searchConfig["cross_languages"].([]interface{}); ok {
|
|
crossLanguages = make([]string, len(scLangs))
|
|
for i, l := range scLangs {
|
|
if s, ok := l.(string); ok {
|
|
crossLanguages[i] = s
|
|
}
|
|
}
|
|
}
|
|
if scKeyword, ok := searchConfig["keyword"].(bool); ok {
|
|
keyword = scKeyword
|
|
}
|
|
if scRerankID, ok := searchConfig["rerank_id"].(string); ok {
|
|
rerankID = scRerankID
|
|
}
|
|
chatID, _ = searchConfig["chat_id"].(string)
|
|
|
|
common.Debug("SearchDatasets loaded Search config",
|
|
zap.String("searchID", searchID),
|
|
zap.Strings("datasetIDs", datasetIDs),
|
|
zap.Float64("vectorSimilarityWeight", vectorSimilarityWeight),
|
|
zap.Float64("fullTextWeight", 1-vectorSimilarityWeight),
|
|
zap.Float64("similarityThreshold", similarityThreshold),
|
|
zap.Int("topK", topK),
|
|
zap.Strings("crossLanguages", crossLanguages),
|
|
zap.Bool("keyword", keyword),
|
|
zap.String("rerankID", rerankID),
|
|
zap.String("chatID", chatID),
|
|
zap.Bool("useKG", useKG))
|
|
} else {
|
|
common.Warn("Invalid search_id: search_config missing or invalid", zap.String("searchID", searchID))
|
|
return nil, fmt.Errorf("Invalid search_id")
|
|
}
|
|
}
|
|
|
|
// If meta_data_filter method is auto/semi_auto, get chat model
|
|
var err error
|
|
var chatModelForFilter *models.ChatModel
|
|
if metadataFilter != nil {
|
|
method, _ := metadataFilter["method"].(string)
|
|
if method == "auto" || method == "semi_auto" {
|
|
if chatID != "" {
|
|
driver, modelName, apiConfig, _, err := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeChat, chatID)
|
|
if err != nil {
|
|
common.Warn("Failed to get chat model config from search_config chat_id, using tenant default", zap.String("chatID", chatID), zap.Error(err))
|
|
} else {
|
|
chatModelForFilter = models.NewChatModel(driver, &modelName, apiConfig)
|
|
common.Info("Fetched chat model (from search_config) for metadata filter",
|
|
zap.String("chatID", chatID),
|
|
zap.String("tenantID", tenantIDs[0]))
|
|
}
|
|
}
|
|
|
|
if chatModelForFilter == nil {
|
|
driver, modelName, apiConfig, _, err := modelProviderSvc.GetTenantDefaultModelByType(tenantIDs[0], entity.ModelTypeChat)
|
|
if err != nil {
|
|
common.Warn("Failed to get tenant default chat model for meta_data_filter", zap.Error(err))
|
|
} else {
|
|
chatModelForFilter = models.NewChatModel(driver, &modelName, apiConfig)
|
|
common.Info("Fetched chat model (tenant default) for metadata filter",
|
|
zap.String("tenantID", tenantIDs[0]))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Apply meta_data_filter to get filtered doc_ids
|
|
docIDs := make([]string, len(req.DocIDs))
|
|
copy(docIDs, req.DocIDs)
|
|
if metadataFilter != nil {
|
|
metadataSvc := NewMetadataService()
|
|
flattedMeta, err := metadataSvc.GetFlattedMetaByKBs(datasetIDs)
|
|
if err != nil {
|
|
common.Warn("Failed to get flatted metadata, using empty metadata for filter", zap.Error(err))
|
|
flattedMeta = make(common.MetaData)
|
|
}
|
|
common.Info("Metadata filter conditions", zap.Any("filter", metadataFilter))
|
|
filteredDocIDs, _ := ApplyMetaDataFilter(ctx, metadataFilter, flattedMeta, question, chatModelForFilter, req.DocIDs, datasetIDs)
|
|
docIDs = filteredDocIDs
|
|
common.Info("ApplyMetaDataFilter result", zap.Strings("docIDs", docIDs))
|
|
}
|
|
|
|
// Apply cross_languages and keyword extraction
|
|
modifiedQuestion := question
|
|
if len(crossLanguages) > 0 {
|
|
// Pass tenantID and empty llmID so CrossLanguages can fetch default if needed
|
|
// This matches Python's cross_languages(tenant_id, llm_id, query, languages)
|
|
common.Info("CrossLanguages: dispatching translation",
|
|
zap.String("tenantID", tenantIDs[0]),
|
|
zap.String("llmID", ""),
|
|
zap.Strings("crossLanguages", crossLanguages))
|
|
translated, err := CrossLanguages(ctx, tenantIDs[0], "", question, crossLanguages)
|
|
if err != nil {
|
|
common.Warn("Failed to translate question", zap.String("llmID", ""), zap.Error(err))
|
|
} else {
|
|
modifiedQuestion = translated
|
|
}
|
|
}
|
|
if keyword {
|
|
driver, modelName, apiConfig, _, err := modelProviderSvc.GetTenantDefaultModelByType(tenantIDs[0], entity.ModelTypeChat)
|
|
if err != nil {
|
|
common.Warn("Failed to get default chat model for LLM transformations", zap.Error(err))
|
|
} else {
|
|
chatModel := models.NewChatModel(driver, &modelName, apiConfig)
|
|
common.Info("Fetched chat model (tenant default) for keyword_extraction",
|
|
zap.String("tenantID", tenantIDs[0]))
|
|
|
|
extractedKeywords, err := KeywordExtraction(ctx, chatModel, modifiedQuestion, 3)
|
|
if err != nil {
|
|
common.Warn("Failed to extract keywords from question", zap.Error(err))
|
|
} else if extractedKeywords != "" {
|
|
modifiedQuestion = modifiedQuestion + extractedKeywords
|
|
}
|
|
}
|
|
}
|
|
if modifiedQuestion != question {
|
|
common.Info("Modified question after transformations",
|
|
zap.String("originalQuestion", question),
|
|
zap.String("modifiedQuestion", modifiedQuestion),
|
|
zap.Strings("crossLanguages", crossLanguages),
|
|
zap.Bool("keywordExtraction", keyword))
|
|
}
|
|
|
|
// Get tag-based rank features via LabelQuestion
|
|
metadataSvc := NewMetadataService()
|
|
labels := metadataSvc.LabelQuestion(modifiedQuestion, kbRecords)
|
|
if len(labels) > 0 {
|
|
common.Debug("LabelQuestion result", zap.Any("labels", labels))
|
|
}
|
|
|
|
// Determine embedding model
|
|
var embeddingModel *models.EmbeddingModel
|
|
if kbRecords[0].EmbdID != "" {
|
|
driver, modelName, apiConfig, maxTokens, embErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, kbRecords[0].EmbdID)
|
|
if embErr != nil {
|
|
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", embErr)
|
|
}
|
|
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
|
} else {
|
|
driver, modelName, apiConfig, maxTokens, err := modelProviderSvc.GetTenantDefaultModelByType(tenantIDs[0], entity.ModelTypeEmbedding)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get tenant default embedding model: %w", err)
|
|
}
|
|
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
|
}
|
|
modelNameStr := ""
|
|
if embeddingModel.ModelName != nil {
|
|
modelNameStr = *embeddingModel.ModelName
|
|
}
|
|
common.Info("Fetched embedding model for retrieval",
|
|
zap.String("tenantID", tenantIDs[0]),
|
|
zap.String("modelName", modelNameStr))
|
|
|
|
// Get rerank model if rerankID is specified
|
|
var rerankModel *models.RerankModel
|
|
|
|
if rerankID != "" {
|
|
driver, modelName, apiConfig, _, rErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankID)
|
|
if rErr != nil {
|
|
return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", rErr)
|
|
}
|
|
rerankModel = models.NewRerankModel(driver, &modelName, apiConfig)
|
|
}
|
|
|
|
if rerankModel != nil {
|
|
common.Info("Fetched rerank model",
|
|
zap.String("tenantID", tenantIDs[0]),
|
|
zap.String("modelName", *rerankModel.ModelName))
|
|
}
|
|
|
|
retrievalReq := &nlp.RetrievalRequest{
|
|
TenantIDs: tenantIDs,
|
|
Question: modifiedQuestion,
|
|
KbIDs: datasetIDs,
|
|
DocIDs: docIDs,
|
|
Page: page,
|
|
PageSize: pageSize,
|
|
Top: &topK,
|
|
SimilarityThreshold: &similarityThreshold,
|
|
VectorSimilarityWeight: &vectorSimilarityWeight,
|
|
RerankModel: rerankModel,
|
|
RankFeature: &labels,
|
|
EmbeddingModel: embeddingModel,
|
|
}
|
|
|
|
retrievalResult, err := nlp.NewRetrievalService(s.docEngine, s.documentDAO).Retrieval(ctx, retrievalReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("retrieval search failed: %w", err)
|
|
}
|
|
|
|
filteredChunks := retrievalResult.Chunks
|
|
|
|
if useKG {
|
|
common.Warn("use_kg is not yet implemented in Go - skipping KG retrieval")
|
|
}
|
|
|
|
filteredChunks = nlp.RetrievalByChildren(filteredChunks, tenantIDs, s.docEngine, ctx)
|
|
|
|
for i := range filteredChunks {
|
|
delete(filteredChunks[i], "vector")
|
|
}
|
|
|
|
common.Info("SearchDatasets completed", zap.String("userID", userID), zap.Any("kbID", datasetIDs), zap.String("question", question), zap.Int64("chunkCount", int64(len(filteredChunks))))
|
|
|
|
// Convert all float64 values to PyFloat64 for Python-compatible JSON serialization
|
|
pyChunks := common.ConvertFloatsToPyFormat(filteredChunks).([]map[string]interface{})
|
|
|
|
return &SearchDatasetsResponse{
|
|
Chunks: pyChunks,
|
|
DocAggs: retrievalResult.DocAggs,
|
|
Labels: &labels,
|
|
Total: retrievalResult.Total,
|
|
}, nil
|
|
}
|
|
|
|
// AutoMetadataField is one field entry of the REST dataset auto-metadata
// schema.
type AutoMetadataField struct {
	// Name is always serialized as "name".
	Name string `json:"name"`
	// Type is always serialized as "type".
	Type string `json:"type"`
	// Description is optional; a nil pointer is omitted from the JSON output.
	Description *string `json:"description,omitempty"`
	// Examples carries an arbitrary JSON value and is omitted when nil.
	Examples interface{} `json:"examples,omitempty"`
	// RestrictValues is omitted from the JSON output when false.
	RestrictValues bool `json:"restrict_values,omitempty"`
}
|
|
|
|
// AutoMetadataConfig mirrors the REST dataset auto metadata schema.
|
|
type AutoMetadataConfig struct {
|
|
Enabled *bool `json:"enabled,omitempty"`
|
|
Fields []AutoMetadataField `json:"fields,omitempty"`
|
|
}
|
|
|
|
// MetadataConfigField is a single field entry of the dataset metadata config
// API. None of its JSON keys use omitempty, so every key is always emitted.
type MetadataConfigField struct {
	// Key is serialized as "key".
	Key string `json:"key"`
	// Type is serialized as "type".
	Type string `json:"type"`
	// Description is serialized as "description"; a nil pointer encodes as null.
	Description *string `json:"description"`
	// Enum is serialized as "enum"; a nil slice encodes as null.
	Enum []string `json:"enum"`
}
|
|
|
|
// MetadataConfigRequest mirrors PUT /datasets/:dataset_id/metadata/config.
|
|
type MetadataConfigRequest struct {
|
|
Metadata []MetadataConfigField `json:"metadata"`
|
|
BuiltInMetadata []MetadataConfigField `json:"built_in_metadata"`
|
|
}
|
|
|
|
// CreateDatasetRequest represents the request for creating a dataset.
|
|
type CreateDatasetRequest struct {
|
|
Name string `json:"name" binding:"required"`
|
|
Avatar *string `json:"avatar,omitempty"`
|
|
Description *string `json:"description,omitempty"`
|
|
EmbeddingModel *string `json:"embedding_model,omitempty"`
|
|
Permission *string `json:"permission,omitempty"`
|
|
ChunkMethod *string `json:"chunk_method,omitempty"`
|
|
ParseType *int `json:"parse_type,omitempty"`
|
|
PipelineID *string `json:"pipeline_id,omitempty"`
|
|
ParserConfig map[string]interface{} `json:"parser_config,omitempty"`
|
|
AutoMetadataConfig *AutoMetadataConfig `json:"auto_metadata_config,omitempty"`
|
|
Ext map[string]interface{} `json:"ext,omitempty"`
|
|
}
|
|
|
|
// ListDatasets lists datasets with pagination and filtering.
|
|
func (s *DatasetService) ListDatasets(id, name string, page, pageSize int, orderby string, desc bool, keywords string, ownerIDs []string, parserID, userID string) ([]map[string]interface{}, int64, common.ErrorCode, error) {
|
|
id = strings.TrimSpace(id)
|
|
if id != "" {
|
|
normalizedID, err := normalizeDatasetID(id)
|
|
if err != nil {
|
|
return nil, 0, common.CodeDataError, err
|
|
}
|
|
id = normalizedID
|
|
|
|
kbs, err := s.kbDAO.GetKBByIDAndUserID(id, userID)
|
|
if err != nil {
|
|
return nil, 0, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
if len(kbs) == 0 {
|
|
return nil, 0, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for dataset '%s'", userID, id)
|
|
}
|
|
}
|
|
|
|
name = strings.TrimSpace(name)
|
|
if name != "" {
|
|
kbs, err := s.kbDAO.GetKBByNameAndUserID(name, userID)
|
|
if err != nil {
|
|
return nil, 0, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
if len(kbs) == 0 {
|
|
return nil, 0, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for dataset '%s'", userID, name)
|
|
}
|
|
}
|
|
|
|
if page <= 0 {
|
|
page = 1
|
|
}
|
|
if pageSize <= 0 {
|
|
pageSize = 30
|
|
}
|
|
|
|
orderby = strings.TrimSpace(orderby)
|
|
if _, ok := datasetAllowedOrderByFields[orderby]; !ok {
|
|
orderby = "create_time"
|
|
}
|
|
|
|
keywords = strings.TrimSpace(keywords)
|
|
parserID = strings.TrimSpace(parserID)
|
|
|
|
// Empty owner ids do not change the query, so only keep the meaningful ones.
|
|
tenantIDs := make([]string, 0, len(ownerIDs))
|
|
for _, ownerID := range ownerIDs {
|
|
ownerID = strings.TrimSpace(ownerID)
|
|
if ownerID != "" {
|
|
tenantIDs = append(tenantIDs, ownerID)
|
|
}
|
|
}
|
|
if len(tenantIDs) == 0 {
|
|
joinedTenants, err := s.tenantDAO.GetJoinedTenantsByUserID(userID)
|
|
if err != nil {
|
|
return nil, 0, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
for _, joinedTenant := range joinedTenants {
|
|
if joinedTenant == nil || joinedTenant.TenantID == "" {
|
|
continue
|
|
}
|
|
tenantIDs = append(tenantIDs, joinedTenant.TenantID)
|
|
}
|
|
}
|
|
|
|
kbs, total, err := s.kbDAO.GetByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords, parserID)
|
|
if err != nil {
|
|
return nil, 0, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
|
|
data := make([]map[string]interface{}, 0, len(kbs))
|
|
for _, kb := range kbs {
|
|
if kb == nil {
|
|
continue
|
|
}
|
|
data = append(data, datasetListItemToMap(kb))
|
|
}
|
|
|
|
return data, total, common.CodeSuccess, nil
|
|
}
|
|
|
|
// CreateDataset creates a new dataset.
|
|
func (s *DatasetService) CreateDataset(req *CreateDatasetRequest, tenantID string) (map[string]interface{}, common.ErrorCode, error) {
|
|
if !isValidString(req.Name) {
|
|
return nil, common.CodeDataError, errors.New("Dataset name must be string.")
|
|
}
|
|
|
|
name := strings.TrimSpace(req.Name)
|
|
if name == "" {
|
|
return nil, common.CodeDataError, errors.New("Dataset name can't be empty.")
|
|
}
|
|
if len(name) > entity.DatasetNameLimit {
|
|
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(name), entity.DatasetNameLimit)
|
|
}
|
|
|
|
tenant, err := s.tenantDAO.GetByID(tenantID)
|
|
if err != nil || tenant == nil {
|
|
return nil, common.CodeDataError, errors.New("Tenant not found.")
|
|
}
|
|
|
|
parserID := ""
|
|
permission := "me"
|
|
embeddingModel := ""
|
|
parserConfig := req.ParserConfig
|
|
pipelineID := req.PipelineID
|
|
description := req.Description
|
|
avatar := req.Avatar
|
|
var language *string
|
|
|
|
if req.Description != nil && len(*req.Description) > 65535 {
|
|
return nil, common.CodeDataError, errors.New("String should have at most 65535 characters")
|
|
}
|
|
if req.Avatar != nil {
|
|
if len(*req.Avatar) > 65535 {
|
|
return nil, common.CodeDataError, errors.New("String should have at most 65535 characters")
|
|
}
|
|
if err := validateDatasetAvatar(*req.Avatar); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
}
|
|
if req.Permission != nil {
|
|
permission = strings.TrimSpace(*req.Permission)
|
|
if permission != "me" && permission != "team" {
|
|
return nil, common.CodeDataError, errors.New("Input should be 'me' or 'team'")
|
|
}
|
|
}
|
|
if req.ChunkMethod != nil {
|
|
parserID = strings.TrimSpace(*req.ChunkMethod)
|
|
if err := validateDatasetChunkMethod(parserID); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
pipelineID = nil
|
|
}
|
|
if req.ParseType != nil && (*req.ParseType < 0 || *req.ParseType > 64) {
|
|
return nil, common.CodeDataError, fmt.Errorf("Input should be between 0 and 64")
|
|
}
|
|
if req.PipelineID != nil {
|
|
normalizedPipelineID, err := normalizeDatasetPipelineID(*req.PipelineID)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
pipelineID = normalizedPipelineID
|
|
}
|
|
if req.EmbeddingModel != nil {
|
|
embeddingModel = strings.TrimSpace(*req.EmbeddingModel)
|
|
if err := validateDatasetEmbeddingModel(embeddingModel); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
}
|
|
if err := validateDatasetParserConfigSize(parserConfig); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
|
|
// ext mirrors the Python REST implementation and overrides known top-level fields.
|
|
for key, value := range req.Ext {
|
|
switch key {
|
|
case "name":
|
|
nameValue, ok := value.(string)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New("Dataset name must be string.")
|
|
}
|
|
nameValue = strings.TrimSpace(nameValue)
|
|
if nameValue == "" {
|
|
return nil, common.CodeDataError, errors.New("Dataset name can't be empty.")
|
|
}
|
|
if len(nameValue) > entity.DatasetNameLimit {
|
|
return nil, common.CodeDataError, fmt.Errorf("Dataset name length is %d which is large than %d", len(nameValue), entity.DatasetNameLimit)
|
|
}
|
|
name = nameValue
|
|
case "description":
|
|
descriptionValue, ok := value.(string)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New("Description must be string.")
|
|
}
|
|
if len(descriptionValue) > 65535 {
|
|
return nil, common.CodeDataError, errors.New("String should have at most 65535 characters")
|
|
}
|
|
description = &descriptionValue
|
|
case "avatar":
|
|
avatarValue, ok := value.(string)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New("Avatar must be string.")
|
|
}
|
|
if len(avatarValue) > 65535 {
|
|
return nil, common.CodeDataError, errors.New("String should have at most 65535 characters")
|
|
}
|
|
if err := validateDatasetAvatar(avatarValue); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
avatar = &avatarValue
|
|
case "language":
|
|
languageValue, ok := value.(string)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New("Language must be string.")
|
|
}
|
|
languageValue = strings.TrimSpace(languageValue)
|
|
language = &languageValue
|
|
case "permission":
|
|
permissionValue, ok := value.(string)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New("Permission must be string.")
|
|
}
|
|
permissionValue = strings.TrimSpace(permissionValue)
|
|
if permissionValue != "me" && permissionValue != "team" {
|
|
return nil, common.CodeDataError, errors.New("Input should be 'me' or 'team'")
|
|
}
|
|
permission = permissionValue
|
|
case "embedding_model", "embd_id":
|
|
embeddingModelValue, ok := value.(string)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New("Embedding model identifier must follow <model_name>@<provider> format")
|
|
}
|
|
embeddingModelValue = strings.TrimSpace(embeddingModelValue)
|
|
if err := validateDatasetEmbeddingModel(embeddingModelValue); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
embeddingModel = embeddingModelValue
|
|
case "chunk_method", "parser_id":
|
|
parserIDValue, ok := value.(string)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New(datasetChunkMethodErrorMessage)
|
|
}
|
|
parserIDValue = strings.TrimSpace(parserIDValue)
|
|
if err := validateDatasetChunkMethod(parserIDValue); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
parserID = parserIDValue
|
|
pipelineID = nil
|
|
case "pipeline_id":
|
|
pipelineIDValue, ok := value.(string)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New("pipeline_id must be 32 hex characters")
|
|
}
|
|
normalizedPipelineID, err := normalizeDatasetPipelineID(pipelineIDValue)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
pipelineID = normalizedPipelineID
|
|
case "parser_config":
|
|
parserConfigValue, ok := value.(map[string]interface{})
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New("parser_config must be valid JSON")
|
|
}
|
|
if err := validateDatasetParserConfigSize(parserConfigValue); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
parserConfig = parserConfigValue
|
|
}
|
|
}
|
|
|
|
// parser_id wins when it is present; otherwise parse_type and pipeline_id must arrive together.
|
|
if parserID == "" {
|
|
if req.ParseType == nil && pipelineID == nil {
|
|
parserID = "naive"
|
|
} else if req.ParseType == nil || pipelineID == nil {
|
|
missingFields := make([]string, 0, 2)
|
|
if req.ParseType == nil {
|
|
missingFields = append(missingFields, "parse_type")
|
|
}
|
|
if pipelineID == nil {
|
|
missingFields = append(missingFields, "pipeline_id")
|
|
}
|
|
return nil, common.CodeDataError, fmt.Errorf("parser_id omitted -> required fields missing: %s", strings.Join(missingFields, ", "))
|
|
}
|
|
}
|
|
|
|
if req.AutoMetadataConfig != nil {
|
|
parserConfig = applyAutoMetadataConfig(parserConfig, req.AutoMetadataConfig)
|
|
}
|
|
|
|
parserConfig = common.GetParserConfig(parserID, parserConfig)
|
|
parserConfig["llm_id"] = tenant.LLMID
|
|
|
|
embdID := tenant.EmbdID
|
|
if embeddingModel != "" {
|
|
ok, message := s.verifyEmbeddingAvailability(embeddingModel, tenantID)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New(message)
|
|
}
|
|
embdID = embeddingModel
|
|
}
|
|
|
|
kbID := utility.GenerateToken()
|
|
|
|
status := string(entity.StatusValid)
|
|
// Deduplicate name within tenant
|
|
duplicateName, err := common.DuplicateName(func(n, tid string) bool {
|
|
existing, err := s.kbDAO.GetByName(n, tid)
|
|
return err == nil && existing != nil
|
|
}, name, tenantID)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
|
|
kb := &entity.Knowledgebase{
|
|
ID: kbID,
|
|
Name: duplicateName,
|
|
TenantID: tenantID,
|
|
CreatedBy: tenantID,
|
|
ParserID: parserID,
|
|
PipelineID: pipelineID,
|
|
ParserConfig: parserConfig,
|
|
Permission: permission,
|
|
EmbdID: embdID,
|
|
Status: &status,
|
|
}
|
|
|
|
if description != nil {
|
|
kb.Description = description
|
|
}
|
|
if avatar != nil {
|
|
kb.Avatar = avatar
|
|
}
|
|
if language != nil {
|
|
kb.Language = language
|
|
}
|
|
|
|
if err = s.kbDAO.Create(kb); err != nil {
|
|
return nil, common.CodeServerError, errors.New("Failed to save dataset")
|
|
}
|
|
|
|
createdKB, err := s.kbDAO.GetByID(kbID)
|
|
if err != nil || createdKB == nil {
|
|
return nil, common.CodeServerError, errors.New("Dataset created failed")
|
|
}
|
|
|
|
return datasetToMap(createdKB), common.CodeSuccess, nil
|
|
}
|
|
|
|
// DeleteDatasets deletes multiple datasets.
|
|
func (s *DatasetService) DeleteDatasets(ids []string, deleteAll bool, tenantID string) (map[string]interface{}, common.ErrorCode, error) {
|
|
normalizedIDs := make([]string, 0, len(ids))
|
|
seenIDs := make(map[string]struct{}, len(ids))
|
|
|
|
// Canonicalize ids once so every downstream DAO call sees the same 32-char hex format.
|
|
for _, id := range ids {
|
|
normalizedID, err := normalizeDatasetID(strings.TrimSpace(id))
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
if _, exists := seenIDs[normalizedID]; exists {
|
|
return nil, common.CodeDataError, fmt.Errorf("Duplicate ids: '%s'", normalizedID)
|
|
}
|
|
seenIDs[normalizedID] = struct{}{}
|
|
normalizedIDs = append(normalizedIDs, normalizedID)
|
|
}
|
|
|
|
if len(normalizedIDs) == 0 {
|
|
if !deleteAll {
|
|
return map[string]interface{}{"success_count": 0}, common.CodeSuccess, nil
|
|
}
|
|
|
|
kbs, err := s.kbDAO.Query(map[string]interface{}{"tenant_id": tenantID})
|
|
if err != nil {
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
for _, kb := range kbs {
|
|
normalizedIDs = append(normalizedIDs, kb.ID)
|
|
}
|
|
}
|
|
|
|
kbs := make([]*entity.Knowledgebase, 0, len(normalizedIDs))
|
|
unauthorizedIDs := make([]string, 0)
|
|
for _, id := range normalizedIDs {
|
|
kb, err := s.kbDAO.GetByIDAndTenantID(id, tenantID)
|
|
if err != nil || kb == nil {
|
|
unauthorizedIDs = append(unauthorizedIDs, id)
|
|
continue
|
|
}
|
|
kbs = append(kbs, kb)
|
|
}
|
|
if len(unauthorizedIDs) > 0 {
|
|
return nil, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for datasets: '%s'", tenantID, strings.Join(unauthorizedIDs, ", "))
|
|
}
|
|
|
|
errorsList := make([]string, 0)
|
|
successCount := 0
|
|
for _, kb := range kbs {
|
|
if err := s.deleteDataset(tenantID, kb); err != nil {
|
|
errorsList = append(errorsList, err.Error())
|
|
continue
|
|
}
|
|
successCount++
|
|
}
|
|
|
|
if len(errorsList) == 0 {
|
|
return map[string]interface{}{"success_count": successCount}, common.CodeSuccess, nil
|
|
}
|
|
|
|
details := strings.Join(errorsList, "; ")
|
|
if len(details) > 128 {
|
|
details = details[:128]
|
|
}
|
|
errorMessage := fmt.Sprintf(
|
|
"Successfully deleted %d datasets, %d failed. Details: %s...",
|
|
successCount,
|
|
len(errorsList),
|
|
details,
|
|
)
|
|
if successCount == 0 {
|
|
return nil, common.CodeDataError, errors.New(errorMessage)
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"success_count": successCount,
|
|
"errors": limitStrings(errorsList, 5),
|
|
}, common.CodeSuccess, nil
|
|
}
|
|
|
|
// GetDataset gets a single dataset with its size and linked connectors.
|
|
func (s *DatasetService) GetDataset(datasetID, userID string) (map[string]interface{}, common.ErrorCode, error) {
|
|
datasetID = strings.TrimSpace(datasetID)
|
|
if datasetID == "" {
|
|
return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"")
|
|
}
|
|
|
|
normalizedID, err := normalizeDatasetID(datasetID)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
datasetID = normalizedID
|
|
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for dataset '%s'", userID, datasetID)
|
|
}
|
|
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil || kb == nil {
|
|
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
|
}
|
|
|
|
data := datasetToMap(kb)
|
|
|
|
size, err := s.documentDAO.SumSizeByDatasetID(datasetID)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
data["size"] = size
|
|
|
|
connectors, err := s.connectorDAO.ListByDatasetID(datasetID)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
data["connectors"] = datasetConnectorsOrEmpty(connectors)
|
|
|
|
return data, common.CodeSuccess, nil
|
|
}
|
|
|
|
// DatasetConnectorRequest is one connector entry inside an
// UpdateDatasetRequest.
type DatasetConnectorRequest struct {
	// ID is serialized as "id".
	ID string `json:"id"`
	// AutoParse is serialized as "auto_parse" and omitted when empty.
	AutoParse string `json:"auto_parse,omitempty"`
}
|
|
|
|
type UpdateDatasetRequest struct {
|
|
Name *string `json:"name,omitempty"`
|
|
Avatar *string `json:"avatar,omitempty"`
|
|
Description *string `json:"description,omitempty"`
|
|
Language *string `json:"language,omitempty"`
|
|
Connectors *[]DatasetConnectorRequest `json:"connectors,omitempty"`
|
|
EmbdID *string `json:"embd_id,omitempty"`
|
|
EmbeddingModel *string `json:"embedding_model,omitempty"`
|
|
Permission *string `json:"permission,omitempty"`
|
|
ParserID *string `json:"parser_id,omitempty"`
|
|
ChunkMethod *string `json:"chunk_method,omitempty"`
|
|
Pagerank *int64 `json:"pagerank,omitempty"`
|
|
ParserConfig map[string]interface{} `json:"parser_config,omitempty"`
|
|
PipelineID *string `json:"pipeline_id,omitempty"`
|
|
AutoMetadataConfig *AutoMetadataConfig `json:"auto_metadata_config,omitempty"`
|
|
Ext map[string]interface{} `json:"ext,omitempty"`
|
|
}
|
|
|
|
// UpdateDataset Update a dataset
|
|
func (s *DatasetService) UpdateDataset(datasetID, tenantID string, req UpdateDatasetRequest) (map[string]interface{}, common.ErrorCode, error) {
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, errors.New("Dataset not found")
|
|
}
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
|
|
if kb == nil || kb.TenantID != tenantID {
|
|
return nil, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for dataset '%s'", tenantID, datasetID)
|
|
}
|
|
|
|
connectorsProvided := req.Connectors != nil
|
|
connectors := make([]DatasetConnectorRequest, 0)
|
|
if req.Connectors != nil {
|
|
connectors = *req.Connectors
|
|
}
|
|
|
|
updates := make(map[string]interface{})
|
|
extUpdates := normalizeDatasetUpdateExt(req.Ext)
|
|
|
|
if req.Name != nil {
|
|
name := strings.TrimSpace(*req.Name)
|
|
if name == "" {
|
|
return nil, common.CodeDataError, errors.New("String should have at least 1 character")
|
|
}
|
|
if len(name) > 128 {
|
|
return nil, common.CodeDataError, errors.New("String should have at most 128 characters")
|
|
}
|
|
updates["name"] = name
|
|
}
|
|
if req.Avatar != nil {
|
|
if len(*req.Avatar) > 65535 {
|
|
return nil, common.CodeDataError, errors.New("String should have at most 65535 characters")
|
|
}
|
|
if err := validateDatasetAvatar(*req.Avatar); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
updates["avatar"] = *req.Avatar
|
|
}
|
|
if req.Description != nil {
|
|
if len(*req.Description) > 65535 {
|
|
return nil, common.CodeDataError, errors.New("String should have at most 65535 characters")
|
|
}
|
|
updates["description"] = *req.Description
|
|
}
|
|
if req.Language != nil {
|
|
language := strings.TrimSpace(*req.Language)
|
|
if len(language) > 32 {
|
|
return nil, common.CodeDataError, errors.New("String should have at most 32 characters")
|
|
}
|
|
updates["language"] = language
|
|
}
|
|
if req.Permission != nil {
|
|
permission := strings.TrimSpace(*req.Permission)
|
|
if permission != "me" && permission != "team" {
|
|
return nil, common.CodeDataError, errors.New("Input should be 'me' or 'team'")
|
|
}
|
|
updates["permission"] = permission
|
|
}
|
|
if req.PipelineID != nil {
|
|
pipelineID, err := normalizeDatasetPipelineID(*req.PipelineID)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
if pipelineID != nil {
|
|
updates["pipeline_id"] = *pipelineID
|
|
}
|
|
}
|
|
|
|
for key, value := range extUpdates {
|
|
if _, exists := updates[key]; !exists {
|
|
updates[key] = value
|
|
}
|
|
}
|
|
|
|
parserID, parserIDProvided, err := datasetUpdateParserID(req)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
if !parserIDProvided {
|
|
if extParserID, ok := updates["parser_id"]; ok {
|
|
parserIDValue, ok := extParserID.(string)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New(datasetChunkMethodErrorMessage)
|
|
}
|
|
parserID = strings.TrimSpace(parserIDValue)
|
|
if err := validateDatasetChunkMethod(parserID); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
parserIDProvided = true
|
|
}
|
|
}
|
|
if parserIDProvided {
|
|
updates["parser_id"] = parserID
|
|
}
|
|
|
|
embdID, embdIDProvided, err := datasetUpdateEmbeddingID(req)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
if !embdIDProvided {
|
|
if extEmbdID, ok := updates["embd_id"]; ok {
|
|
embdIDValue, ok := extEmbdID.(string)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New("Embedding model identifier must follow <model_name>@<provider> format")
|
|
}
|
|
embdID = strings.TrimSpace(embdIDValue)
|
|
if embdID != "" {
|
|
if err := validateDatasetEmbeddingModel(embdID); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
}
|
|
embdIDProvided = true
|
|
}
|
|
}
|
|
if embdIDProvided {
|
|
if embdID == "" {
|
|
embdID = kb.EmbdID
|
|
}
|
|
ok, message := s.verifyEmbeddingAvailability(embdID, tenantID)
|
|
if !ok {
|
|
return nil, common.CodeDataError, errors.New(message)
|
|
}
|
|
updates["embd_id"] = embdID
|
|
}
|
|
|
|
if req.AutoMetadataConfig != nil {
|
|
req.ParserConfig = applyAutoMetadataConfig(req.ParserConfig, req.AutoMetadataConfig)
|
|
}
|
|
if req.ParserConfig != nil {
|
|
if err := validateDatasetParserConfigSize(req.ParserConfig); err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
if len(req.ParserConfig) > 0 {
|
|
parserConfig := normalizeDatasetUpdateParserConfig(req.ParserConfig)
|
|
updates["parser_config"] = entity.JSONMap(common.DeepMergeMaps(kb.ParserConfig, parserConfig))
|
|
}
|
|
}
|
|
|
|
if req.Pagerank != nil && *req.Pagerank != kb.Pagerank {
|
|
if *req.Pagerank < 0 || *req.Pagerank > 100 {
|
|
return nil, common.CodeDataError, errors.New("Input should be less than or equal to 100")
|
|
}
|
|
if s.engineType == server.EngineInfinity {
|
|
return nil, common.CodeDataError, errors.New("'pagerank' can only be set when doc_engine is elasticsearch")
|
|
}
|
|
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
|
if *req.Pagerank > 0 {
|
|
err = s.docEngine.UpdateChunks(context.Background(), map[string]interface{}{"kb_id": kb.ID}, map[string]interface{}{common.PAGERANK_FLD: *req.Pagerank}, indexName, kb.ID)
|
|
} else {
|
|
err = s.docEngine.UpdateChunks(context.Background(), map[string]interface{}{"exists": common.PAGERANK_FLD}, map[string]interface{}{"remove": common.PAGERANK_FLD}, indexName, kb.ID)
|
|
}
|
|
if err != nil {
|
|
return nil, common.CodeServerError, err
|
|
}
|
|
updates["pagerank"] = *req.Pagerank
|
|
}
|
|
|
|
if parserIDProvided && parserID != kb.ParserID {
|
|
if _, ok := updates["parser_config"]; !ok {
|
|
updates["parser_config"] = entity.JSONMap(common.GetParserConfig(parserID, nil))
|
|
}
|
|
}
|
|
if kb.PipelineID != nil && parserIDProvided {
|
|
if _, ok := updates["pipeline_id"]; !ok {
|
|
updates["pipeline_id"] = ""
|
|
}
|
|
}
|
|
|
|
if nameValue, ok := updates["name"].(string); ok && strings.ToLower(nameValue) != strings.ToLower(kb.Name) {
|
|
existing, lookupErr := s.kbDAO.GetByName(nameValue, tenantID)
|
|
if lookupErr != nil && !dao.IsNotFoundErr(lookupErr) {
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
if existing != nil {
|
|
return nil, common.CodeDataError, fmt.Errorf("Dataset name '%s' already exists", nameValue)
|
|
}
|
|
}
|
|
|
|
if len(updates) == 0 && !connectorsProvided {
|
|
return nil, common.CodeDataError, errors.New("No properties were modified")
|
|
}
|
|
|
|
if len(updates) > 0 {
|
|
if err = s.kbDAO.UpdateByID(kb.ID, updates); err != nil {
|
|
return nil, common.CodeServerError, errors.New("Update dataset error.(Database error)")
|
|
}
|
|
}
|
|
|
|
if connectorsProvided {
|
|
connectorLinks := make([]dao.DatasetConnectorLink, 0, len(connectors))
|
|
for _, connector := range connectors {
|
|
connectorID := strings.TrimSpace(connector.ID)
|
|
if connectorID == "" {
|
|
return nil, common.CodeDataError, errors.New("connector id is required")
|
|
}
|
|
connectorLinks = append(connectorLinks, dao.DatasetConnectorLink{
|
|
ID: connectorID,
|
|
AutoParse: connector.AutoParse,
|
|
})
|
|
}
|
|
if err = s.connectorDAO.LinkDatasetConnectors(kb.ID, connectorLinks); err != nil {
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
}
|
|
|
|
updatedKB, err := s.kbDAO.GetByID(kb.ID)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, errors.New("Dataset updated failed")
|
|
}
|
|
|
|
data := datasetToMap(updatedKB)
|
|
linkedConnectors, err := s.connectorDAO.ListByDatasetID(kb.ID)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
data["connectors"] = datasetConnectorsOrEmpty(linkedConnectors)
|
|
return data, common.CodeSuccess, nil
|
|
}
|
|
|
|
func datasetConnectorsOrEmpty(connectors []*dao.ConnectorDatasetListItem) []*dao.ConnectorDatasetListItem {
|
|
if connectors == nil {
|
|
return make([]*dao.ConnectorDatasetListItem, 0)
|
|
}
|
|
return connectors
|
|
}
|
|
|
|
func datasetUpdateParserID(req UpdateDatasetRequest) (string, bool, error) {
|
|
parserID := ""
|
|
provided := false
|
|
if req.ParserID != nil {
|
|
parserID = strings.TrimSpace(*req.ParserID)
|
|
provided = true
|
|
}
|
|
if req.ChunkMethod != nil {
|
|
parserID = strings.TrimSpace(*req.ChunkMethod)
|
|
provided = true
|
|
}
|
|
if !provided {
|
|
return "", false, nil
|
|
}
|
|
if err := validateDatasetChunkMethod(parserID); err != nil {
|
|
return "", true, err
|
|
}
|
|
return parserID, true, nil
|
|
}
|
|
|
|
func datasetUpdateEmbeddingID(req UpdateDatasetRequest) (string, bool, error) {
|
|
embdID := ""
|
|
provided := false
|
|
if req.EmbdID != nil {
|
|
embdID = strings.TrimSpace(*req.EmbdID)
|
|
provided = true
|
|
}
|
|
if req.EmbeddingModel != nil {
|
|
embdID = strings.TrimSpace(*req.EmbeddingModel)
|
|
provided = true
|
|
}
|
|
if !provided {
|
|
return "", false, nil
|
|
}
|
|
if embdID != "" {
|
|
if err := validateDatasetEmbeddingModel(embdID); err != nil {
|
|
return "", true, err
|
|
}
|
|
}
|
|
return embdID, true, nil
|
|
}
|
|
|
|
// normalizeDatasetUpdateExt converts the free-form "ext" part of an update
// request into a column-update map.
//
//   - "embedding_model" is stored under "embd_id".
//   - "chunk_method" is stored under "parser_id".
//   - "connectors", "auto_metadata_config", "ext" and "parse_type" are dropped.
//   - every other key is copied through unchanged.
//
// A nil input yields nil; any other input yields a freshly allocated map. If
// the input contains both an alias and its target key, which value survives
// depends on map iteration order.
func normalizeDatasetUpdateExt(ext map[string]interface{}) map[string]interface{} {
	if ext == nil {
		return nil
	}

	out := make(map[string]interface{}, len(ext))
	for name, val := range ext {
		if name == "connectors" || name == "auto_metadata_config" || name == "ext" || name == "parse_type" {
			continue
		}
		target := name
		if name == "embedding_model" {
			target = "embd_id"
		} else if name == "chunk_method" {
			target = "parser_id"
		}
		out[target] = val
	}
	return out
}
|
|
|
|
func normalizeDatasetUpdateParserConfig(parserConfig map[string]interface{}) map[string]interface{} {
|
|
normalized := common.DeepMergeMaps(nil, parserConfig)
|
|
parentChild, _ := normalized["parent_child"].(map[string]interface{})
|
|
if parentChild == nil {
|
|
parentChild = map[string]interface{}{}
|
|
}
|
|
|
|
if datasetBoolValue(parentChild["use_parent_child"]) {
|
|
childrenDelimiter, ok := parentChild["children_delimiter"]
|
|
if !ok {
|
|
childrenDelimiter = "\n"
|
|
}
|
|
normalized["children_delimiter"] = childrenDelimiter
|
|
enableChildren, ok := parentChild["use_parent_child"]
|
|
if !ok {
|
|
enableChildren = true
|
|
}
|
|
normalized["enable_children"] = enableChildren
|
|
} else {
|
|
normalized["children_delimiter"] = ""
|
|
normalized["enable_children"] = false
|
|
normalized["parent_child"] = map[string]interface{}{}
|
|
}
|
|
|
|
if extFields, ok := normalized["ext"].(map[string]interface{}); ok {
|
|
delete(normalized, "ext")
|
|
for key, value := range extFields {
|
|
normalized[key] = value
|
|
}
|
|
}
|
|
|
|
return normalized
|
|
}
|
|
|
|
// datasetBoolValue interprets a loosely typed value as a boolean.
//
//   - bool: returned as is.
//   - string: true for "1" or any case variant of "true"; no trimming is done.
//   - int, int64, float64: true when non-zero.
//   - anything else (including nil and other numeric types): false.
func datasetBoolValue(value interface{}) bool {
	if flag, ok := value.(bool); ok {
		return flag
	}
	if text, ok := value.(string); ok {
		return text == "1" || strings.EqualFold(text, "true")
	}
	if n, ok := value.(int); ok {
		return n != 0
	}
	if n, ok := value.(int64); ok {
		return n != 0
	}
	if f, ok := value.(float64); ok {
		return f != 0
	}
	return false
}
|
|
|
|
// GetMetadataConfig gets the auto-metadata configuration for a dataset.
|
|
func (s *DatasetService) GetMetadataConfig(datasetID, tenantID string) (map[string]interface{}, common.ErrorCode, error) {
|
|
kb, err := s.kbDAO.GetByIDAndTenantID(datasetID, tenantID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for dataset '%s'", tenantID, datasetID)
|
|
}
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
if kb == nil {
|
|
return nil, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for dataset '%s'", tenantID, datasetID)
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"metadata": parserConfigValueOrEmptyList(kb.ParserConfig, "metadata"),
|
|
"built_in_metadata": parserConfigValueOrEmptyList(kb.ParserConfig, "built_in_metadata"),
|
|
}, common.CodeSuccess, nil
|
|
}
|
|
|
|
// UpdateMetadataConfig updates the auto-metadata configuration for a dataset.
|
|
func (s *DatasetService) UpdateMetadataConfig(datasetID, tenantID string, req *MetadataConfigRequest) (map[string]interface{}, common.ErrorCode, error) {
|
|
datasetID = strings.TrimSpace(datasetID)
|
|
tenantID = strings.TrimSpace(tenantID)
|
|
|
|
kb, err := s.kbDAO.GetByIDAndTenantID(datasetID, tenantID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for dataset '%s'", tenantID, datasetID)
|
|
}
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
if kb == nil {
|
|
return nil, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for dataset '%s'", tenantID, datasetID)
|
|
}
|
|
|
|
if req == nil {
|
|
req = &MetadataConfigRequest{}
|
|
}
|
|
|
|
metadata, err := normalizeMetadataConfigFields(req.Metadata, "metadata")
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
builtInMetadata, err := normalizeMetadataConfigFields(req.BuiltInMetadata, "built_in_metadata")
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
|
|
parserConfig := kb.ParserConfig
|
|
if parserConfig == nil {
|
|
parserConfig = entity.JSONMap{}
|
|
}
|
|
parserConfig["metadata"] = metadata
|
|
parserConfig["built_in_metadata"] = builtInMetadata
|
|
|
|
if err := s.kbDAO.UpdateByID(kb.ID, map[string]interface{}{"parser_config": parserConfig}); err != nil {
|
|
return nil, common.CodeServerError, errors.New("Update auto-metadata error.(Database error)")
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"metadata": metadata,
|
|
"built_in_metadata": builtInMetadata,
|
|
}, common.CodeSuccess, nil
|
|
}
|
|
|
|
// Accessible checks if a knowledge base is accessible by a user
|
|
func (s *DatasetService) Accessible(kbID, userID string) bool {
|
|
return s.kbDAO.Accessible(kbID, userID)
|
|
}
|
|
|
|
// GetKnowledgebaseByID resolves a dataset entity without applying permission
|
|
// checks. Upload needs the same existence-then-auth ordering as Python.
|
|
func (s *DatasetService) GetKnowledgebaseByID(datasetID string) (*entity.Knowledgebase, error) {
|
|
datasetID = strings.TrimSpace(datasetID)
|
|
if datasetID == "" {
|
|
return nil, errors.New("Lack of \"Dataset ID\"")
|
|
}
|
|
normalizedID, err := normalizeDatasetID(datasetID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s.kbDAO.GetByID(normalizedID)
|
|
}
|
|
|
|
// CheckKBTeamPermission mirrors Python check_kb_team_permission.
|
|
func (s *DatasetService) CheckKBTeamPermission(kb *entity.Knowledgebase, userID string) bool {
|
|
return hasKBTeamPermission(kb, userID, s.tenantDAO)
|
|
}
|
|
|
|
func (s *DatasetService) AggregateTags(datasetIDs []string, userID string) ([]map[string]interface{}, common.ErrorCode, error) {
|
|
if len(datasetIDs) == 0 {
|
|
return nil, common.CodeDataError, errors.New("Lack of dataset_ids in query parameters")
|
|
}
|
|
if s.docEngine == nil {
|
|
return nil, common.CodeServerError, errors.New("Document engine is not initialized")
|
|
}
|
|
|
|
datasetIDsByTenant := make(map[string][]string)
|
|
for _, rawID := range datasetIDs {
|
|
rawID = strings.TrimSpace(rawID)
|
|
if rawID == "" {
|
|
continue
|
|
}
|
|
datasetID, err := normalizeDatasetID(rawID)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, fmt.Errorf("No authorization for dataset '%s'", datasetID)
|
|
}
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, common.CodeDataError, fmt.Errorf("Invalid Dataset ID '%s'", datasetID)
|
|
}
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
if kb.DocNum <= 0 {
|
|
continue
|
|
}
|
|
datasetIDsByTenant[kb.TenantID] = append(datasetIDsByTenant[kb.TenantID], datasetID)
|
|
}
|
|
|
|
const pageSize = 10000
|
|
merged := make(map[string]int)
|
|
for tenantID, kbIDs := range datasetIDsByTenant {
|
|
for offset := 0; ; offset += pageSize {
|
|
searchResp, err := s.docEngine.Search(context.Background(), &types.SearchRequest{
|
|
IndexNames: []string{fmt.Sprintf("ragflow_%s", tenantID)},
|
|
KbIDs: kbIDs,
|
|
Offset: offset,
|
|
Limit: pageSize,
|
|
SelectFields: []string{"tag_kwd"},
|
|
})
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to aggregate tags: %w", err)
|
|
}
|
|
for _, agg := range s.docEngine.GetAggregation(searchResp.Chunks, "tag_kwd") {
|
|
tag, _ := agg["key"].(string)
|
|
if tag == "" {
|
|
continue
|
|
}
|
|
switch count := agg["count"].(type) {
|
|
case int:
|
|
merged[tag] += count
|
|
case int32:
|
|
merged[tag] += int(count)
|
|
case int64:
|
|
merged[tag] += int(count)
|
|
case float64:
|
|
merged[tag] += int(count)
|
|
}
|
|
}
|
|
|
|
chunkCount := len(searchResp.Chunks)
|
|
if chunkCount == 0 || chunkCount < pageSize {
|
|
break
|
|
}
|
|
if searchResp.Total > 0 && int64(offset+chunkCount) >= searchResp.Total {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
result := make([]map[string]interface{}, 0, len(merged))
|
|
for tag, count := range merged {
|
|
result = append(result, map[string]interface{}{
|
|
"value": tag,
|
|
"count": count,
|
|
})
|
|
}
|
|
return result, common.CodeSuccess, nil
|
|
}
|
|
|
|
func (s *DatasetService) ListTags(datasetID, userID string) ([]map[string]interface{}, common.ErrorCode, error) {
|
|
datasetID = strings.TrimSpace(datasetID)
|
|
if datasetID == "" {
|
|
return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"")
|
|
}
|
|
|
|
normalizedID, err := normalizeDatasetID(datasetID)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
datasetID = normalizedID
|
|
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, errors.New("No authorization.")
|
|
}
|
|
if s.docEngine == nil {
|
|
return nil, common.CodeServerError, errors.New("Document engine is not initialized")
|
|
}
|
|
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil || kb == nil {
|
|
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
|
}
|
|
|
|
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
exists, err := s.docEngine.ChunkStoreExists(ctx, indexName, datasetID)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to inspect chunk store: %w", err)
|
|
}
|
|
if !exists {
|
|
return []map[string]interface{}{}, common.CodeSuccess, nil
|
|
}
|
|
|
|
const pageSize = 10000
|
|
counts := make(map[string]int)
|
|
for offset := 0; ; offset += pageSize {
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("list tags timeout or canceled: %w", err)
|
|
}
|
|
|
|
searchResp, err := s.docEngine.Search(ctx, &types.SearchRequest{
|
|
IndexNames: []string{indexName},
|
|
KbIDs: []string{datasetID},
|
|
Offset: offset,
|
|
Limit: pageSize,
|
|
SelectFields: []string{"tag_kwd"},
|
|
})
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to list tags: %w", err)
|
|
}
|
|
|
|
for _, agg := range s.docEngine.GetAggregation(searchResp.Chunks, "tag_kwd") {
|
|
tag, _ := agg["key"].(string)
|
|
if tag == "" {
|
|
continue
|
|
}
|
|
switch count := agg["count"].(type) {
|
|
case int:
|
|
counts[tag] += count
|
|
case int32:
|
|
counts[tag] += int(count)
|
|
case int64:
|
|
counts[tag] += int(count)
|
|
case float64:
|
|
counts[tag] += int(count)
|
|
}
|
|
}
|
|
|
|
chunkCount := len(searchResp.Chunks)
|
|
if chunkCount == 0 || chunkCount < pageSize {
|
|
break
|
|
}
|
|
if searchResp.Total > 0 && int64(offset+chunkCount) >= searchResp.Total {
|
|
break
|
|
}
|
|
}
|
|
|
|
if len(counts) == 0 {
|
|
return []map[string]interface{}{}, common.CodeSuccess, nil
|
|
}
|
|
|
|
tags := make([]string, 0, len(counts))
|
|
for tag := range counts {
|
|
tags = append(tags, tag)
|
|
}
|
|
sort.Slice(tags, func(i, j int) bool {
|
|
if counts[tags[i]] != counts[tags[j]] {
|
|
return counts[tags[i]] > counts[tags[j]]
|
|
}
|
|
return tags[i] < tags[j]
|
|
})
|
|
|
|
result := make([]map[string]interface{}, 0, len(tags))
|
|
for _, tag := range tags {
|
|
result = append(result, map[string]interface{}{
|
|
"key": tag,
|
|
"count": counts[tag],
|
|
})
|
|
}
|
|
|
|
return result, common.CodeSuccess, nil
|
|
}
|
|
|
|
// GetIngestionSummary returns dataset-level ingestion counters together with
|
|
// the aggregated document parsing status, mirroring
|
|
// dataset_api_service.get_ingestion_summary.
|
|
func (s *DatasetService) GetIngestionSummary(datasetID, userID string) (map[string]interface{}, common.ErrorCode, error) {
|
|
datasetID = strings.TrimSpace(datasetID)
|
|
if datasetID == "" {
|
|
return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"")
|
|
}
|
|
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for dataset '%s'", userID, datasetID)
|
|
}
|
|
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil || kb == nil {
|
|
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
|
}
|
|
|
|
status, err := s.documentDAO.GetParsingStatusByKBID(datasetID)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"doc_num": kb.DocNum,
|
|
"chunk_num": kb.ChunkNum,
|
|
"token_num": kb.TokenNum,
|
|
"status": status,
|
|
}, common.CodeSuccess, nil
|
|
}
|
|
|
|
// ListIngestionLogs lists ingestion logs for a dataset, mirroring
|
|
// dataset_api_service.list_ingestion_logs. log_type selects between
|
|
// dataset-level logs ("dataset") and per-file logs ("file").
|
|
func (s *DatasetService) ListIngestionLogs(datasetID, userID string, page, pageSize int, orderby string, desc bool, operationStatus []string, createDateFrom, createDateTo, logType, keywords string) (map[string]interface{}, common.ErrorCode, error) {
|
|
datasetID = strings.TrimSpace(datasetID)
|
|
if datasetID == "" {
|
|
return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"")
|
|
}
|
|
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, errors.New("No authorization.")
|
|
}
|
|
|
|
if logType != "dataset" && logType != "file" {
|
|
return nil, common.CodeDataError, errors.New("Invalid \"log_type\", expected \"dataset\" or \"file\"")
|
|
}
|
|
|
|
var (
|
|
logs []*entity.PipelineOperationLog
|
|
total int64
|
|
err error
|
|
)
|
|
if logType == "file" {
|
|
logs, total, err = s.pipelineLogDAO.GetFileLogsByKBID(datasetID, page, pageSize, orderby, desc, keywords, operationStatus, createDateFrom, createDateTo)
|
|
} else {
|
|
logs, total, err = s.pipelineLogDAO.GetDatasetLogsByKBID(datasetID, page, pageSize, orderby, desc, operationStatus, createDateFrom, createDateTo, keywords)
|
|
}
|
|
if err != nil {
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
|
|
items := make([]map[string]interface{}, 0, len(logs))
|
|
for _, log := range logs {
|
|
if log == nil {
|
|
continue
|
|
}
|
|
if logType == "file" {
|
|
items = append(items, fileIngestionLogToMap(log))
|
|
} else {
|
|
items = append(items, datasetIngestionLogToMap(log))
|
|
}
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"total": total,
|
|
"logs": items,
|
|
}, common.CodeSuccess, nil
|
|
}
|
|
|
|
// GetIngestionLog returns a single ingestion log, mirroring
|
|
// dataset_api_service.get_ingestion_log. It returns the full record (including
|
|
// the `dsl`, `document_id`, `parser_id`, etc.) so that the front-end
|
|
// dataflow-result page can render the pipeline timeline and chunks. The
|
|
// file-level converter is a superset of the dataset-level fields, so it is
|
|
// correct for both dataset-level (graph/raptor/mindmap) and per-file logs.
|
|
func (s *DatasetService) GetIngestionLog(datasetID, userID, logID string) (map[string]interface{}, common.ErrorCode, error) {
|
|
datasetID = strings.TrimSpace(datasetID)
|
|
if datasetID == "" {
|
|
return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"")
|
|
}
|
|
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, errors.New("No authorization.")
|
|
}
|
|
|
|
log, err := s.pipelineLogDAO.GetByIDAndKBID(logID, datasetID)
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, common.CodeDataError, errors.New("Log not found")
|
|
}
|
|
return nil, common.CodeServerError, errors.New("Database operation failed")
|
|
}
|
|
|
|
return fileIngestionLogToMap(log), common.CodeSuccess, nil
|
|
}
|
|
|
|
func datasetIngestionLogToMap(log *entity.PipelineOperationLog) map[string]interface{} {
|
|
return map[string]interface{}{
|
|
"id": log.ID,
|
|
"tenant_id": log.TenantID,
|
|
"kb_id": log.KbID,
|
|
"progress": log.Progress,
|
|
"progress_msg": stringPointerValue(log.ProgressMsg),
|
|
"process_begin_at": timePointerValue(log.ProcessBeginAt),
|
|
"process_duration": log.ProcessDuration,
|
|
"task_type": log.TaskType,
|
|
"operation_status": log.OperationStatus,
|
|
"avatar": stringPointerValue(log.Avatar),
|
|
"status": stringPointerValue(log.Status),
|
|
"create_time": int64PointerValue(log.CreateTime),
|
|
"create_date": timePointerValue(log.CreateDate),
|
|
"update_time": int64PointerValue(log.UpdateTime),
|
|
"update_date": timePointerValue(log.UpdateDate),
|
|
}
|
|
}
|
|
|
|
func fileIngestionLogToMap(log *entity.PipelineOperationLog) map[string]interface{} {
|
|
return map[string]interface{}{
|
|
"id": log.ID,
|
|
"document_id": log.DocumentID,
|
|
"tenant_id": log.TenantID,
|
|
"kb_id": log.KbID,
|
|
"pipeline_id": stringPointerValue(log.PipelineID),
|
|
"pipeline_title": stringPointerValue(log.PipelineTitle),
|
|
"parser_id": log.ParserID,
|
|
"document_name": log.DocumentName,
|
|
"document_suffix": log.DocumentSuffix,
|
|
"document_type": log.DocumentType,
|
|
"source_from": log.SourceFrom,
|
|
"progress": log.Progress,
|
|
"progress_msg": stringPointerValue(log.ProgressMsg),
|
|
"process_begin_at": timePointerValue(log.ProcessBeginAt),
|
|
"process_duration": log.ProcessDuration,
|
|
"dsl": jsonMapValue(log.DSL),
|
|
"task_type": log.TaskType,
|
|
"operation_status": log.OperationStatus,
|
|
"avatar": stringPointerValue(log.Avatar),
|
|
"status": stringPointerValue(log.Status),
|
|
"create_time": int64PointerValue(log.CreateTime),
|
|
"create_date": timePointerValue(log.CreateDate),
|
|
"update_time": int64PointerValue(log.UpdateTime),
|
|
"update_date": timePointerValue(log.UpdateDate),
|
|
}
|
|
}
|
|
|
|
// stringPointerValue unwraps an optional string: the pointed-to string when s
// is set, or an untyped nil interface value when s is nil.
func stringPointerValue(s *string) interface{} {
	if s != nil {
		return *s
	}
	return nil
}
|
|
|
|
// int64PointerValue unwraps an optional int64: the pointed-to value when i is
// set, or an untyped nil interface value when i is nil.
func int64PointerValue(i *int64) interface{} {
	if i != nil {
		return *i
	}
	return nil
}
|
|
|
|
// timePointerValue renders an optional timestamp as "YYYY-MM-DD HH:MM:SS"
// (Go layout "2006-01-02 15:04:05"). A nil pointer yields an untyped nil
// interface value. No time-zone conversion is applied.
func timePointerValue(t *time.Time) interface{} {
	const layout = "2006-01-02 15:04:05"
	if t != nil {
		return t.Format(layout)
	}
	return nil
}
|
|
|
|
func jsonMapValue(m entity.JSONMap) interface{} {
|
|
if m == nil {
|
|
return nil
|
|
}
|
|
return m
|
|
}
|
|
|
|
func (s *DatasetService) deleteDataset(tenantID string, kb *entity.Knowledgebase) error {
|
|
return dao.DB.Transaction(func(tx *gorm.DB) error {
|
|
if taskIDs := datasetIndexTaskIDs(kb); len(taskIDs) > 0 {
|
|
if err := tx.Where("id IN ?", taskIDs).Delete(&entity.Task{}).Error; err != nil {
|
|
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
|
}
|
|
}
|
|
|
|
var documents []entity.Document
|
|
if err := tx.Where("kb_id = ?", kb.ID).Find(&documents).Error; err != nil {
|
|
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
|
}
|
|
|
|
docIDs := make([]string, 0, len(documents))
|
|
for _, document := range documents {
|
|
docIDs = append(docIDs, document.ID)
|
|
}
|
|
|
|
if len(docIDs) > 0 {
|
|
var mappings []entity.File2Document
|
|
if err := tx.Where("document_id IN ?", docIDs).Find(&mappings).Error; err != nil {
|
|
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
|
}
|
|
|
|
fileIDs := make([]string, 0, len(mappings))
|
|
seenFileIDs := make(map[string]struct{}, len(mappings))
|
|
for _, mapping := range mappings {
|
|
if mapping.FileID == nil || *mapping.FileID == "" {
|
|
continue
|
|
}
|
|
if _, exists := seenFileIDs[*mapping.FileID]; exists {
|
|
continue
|
|
}
|
|
seenFileIDs[*mapping.FileID] = struct{}{}
|
|
fileIDs = append(fileIDs, *mapping.FileID)
|
|
}
|
|
|
|
if err := tx.Where("doc_id IN ?", docIDs).Delete(&entity.Task{}).Error; err != nil {
|
|
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
|
}
|
|
if err := tx.Where("document_id IN ?", docIDs).Delete(&entity.File2Document{}).Error; err != nil {
|
|
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
|
}
|
|
if len(fileIDs) > 0 {
|
|
if err := tx.Unscoped().Where("id IN ?", fileIDs).Delete(&entity.File{}).Error; err != nil {
|
|
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
|
}
|
|
}
|
|
if err := tx.Where("id IN ?", docIDs).Delete(&entity.Document{}).Error; err != nil {
|
|
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
|
}
|
|
}
|
|
|
|
if err := tx.Unscoped().
|
|
Where("source_type = ? AND type = ? AND name = ? AND tenant_id = ?", string(entity.FileSourceKnowledgebase), "folder", kb.Name, tenantID).
|
|
Delete(&entity.File{}).Error; err != nil {
|
|
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
|
}
|
|
|
|
if err := tx.Where("id = ?", kb.ID).Delete(&entity.Knowledgebase{}).Error; err != nil {
|
|
return fmt.Errorf("Delete dataset error for %s", kb.ID)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func validateDatasetChunkMethod(chunkMethod string) error {
|
|
if _, ok := datasetAllowedChunkMethods[chunkMethod]; !ok {
|
|
return errors.New(datasetChunkMethodErrorMessage)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateDatasetAvatar(avatar string) error {
|
|
if !strings.Contains(avatar, ",") {
|
|
return errors.New("Missing MIME prefix. Expected format: data:<mime>;base64,<data>")
|
|
}
|
|
|
|
prefix, _, _ := strings.Cut(avatar, ",")
|
|
if !strings.HasPrefix(prefix, "data:") {
|
|
return errors.New("Invalid MIME prefix format. Must start with 'data:'")
|
|
}
|
|
|
|
mimeType, _, _ := strings.Cut(strings.TrimPrefix(prefix, "data:"), ";")
|
|
if _, ok := datasetSupportedAvatarMIMETypes[mimeType]; !ok {
|
|
return errors.New("Unsupported MIME type. Allowed: [image/jpeg image/png]")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateDatasetEmbeddingModel checks that an embedding model identifier has
// the form <model_name>@<provider>.
//
// The identifier is split at the first "@"; an empty identifier or one without
// "@" is a format error. Both parts must be non-blank after trimming
// whitespace. Everything after the first "@" counts as the provider, so
// further "@" characters are accepted.
func validateDatasetEmbeddingModel(embeddingModel string) error {
	const formatMessage = "Embedding model identifier must follow <model_name>@<provider> format"

	at := strings.Index(embeddingModel, "@")
	if embeddingModel == "" || at < 0 {
		return errors.New(formatMessage)
	}

	name, vendor := embeddingModel[:at], embeddingModel[at+1:]
	if strings.TrimSpace(name) == "" || strings.TrimSpace(vendor) == "" {
		return errors.New("Both model_name and provider must be non-empty strings")
	}
	return nil
}
|
|
|
|
// normalizeDatasetPipelineID validates and canonicalizes a pipeline id.
//
// The input is whitespace-trimmed. A blank id yields (nil, nil). Otherwise the
// id must be exactly 32 bytes long and consist solely of hexadecimal digits;
// the lower-cased id is returned through a pointer.
func normalizeDatasetPipelineID(pipelineID string) (*string, error) {
	id := strings.TrimSpace(pipelineID)
	if id == "" {
		return nil, nil
	}
	if len(id) != 32 {
		return nil, errors.New("pipeline_id must be 32 hex characters")
	}

	// Hex digits are single-byte, so any byte outside the three ranges
	// (including every byte of a multi-byte rune) is invalid.
	for i := 0; i < len(id); i++ {
		c := id[i]
		isHex := (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')
		if !isHex {
			return nil, errors.New("pipeline_id must be hexadecimal")
		}
	}

	lowered := strings.ToLower(id)
	return &lowered, nil
}
|
|
|
|
// validateDatasetParserConfigSize rejects parser configurations whose JSON
// encoding is longer than 65535 bytes.
//
// A nil or empty map is accepted without encoding. A map that cannot be
// marshalled to JSON is reported as invalid. The size that is compared and
// reported is the byte length of the json.Marshal output.
func validateDatasetParserConfigSize(parserConfig map[string]interface{}) error {
	const maxEncodedSize = 65535

	if len(parserConfig) == 0 {
		return nil
	}

	encoded, marshalErr := json.Marshal(parserConfig)
	if marshalErr != nil {
		return errors.New("parser_config must be valid JSON")
	}

	size := len(encoded)
	if size <= maxEncodedSize {
		return nil
	}
	return fmt.Errorf("Parser config exceeds size limit (max 65,535 characters). Current size: %d", size)
}
|
|
|
|
// normalizeDatasetID canonicalizes an id into the 32-char hex form used by
|
|
// the storage layer. The "UUID1" name was a legacy term from when the
|
|
// Python service generated ids with `uuid.uuid1().hex`; the Go port uses
|
|
// `uuid.New()` (v4), so we accept any RFC 4122 version. We only reject the
|
|
// Nil UUID, which is the reserved "no id" sentinel.
|
|
func normalizeDatasetID(id string) (string, error) {
|
|
parsedUUID, err := uuid.Parse(id)
|
|
if err != nil {
|
|
return "", errors.New("Invalid UUID format")
|
|
}
|
|
if parsedUUID == (uuid.UUID{}) {
|
|
return "", errors.New("Invalid UUID format")
|
|
}
|
|
return strings.ReplaceAll(parsedUUID.String(), "-", ""), nil
|
|
}
|
|
|
|
func (s *DatasetService) verifyEmbeddingAvailability(embdID string, tenantID string) (bool, string) {
|
|
_, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, entity.ModelTypeEmbedding, embdID)
|
|
if err != nil {
|
|
return false, err.Error()
|
|
}
|
|
return true, ""
|
|
}
|
|
|
|
func applyAutoMetadataConfig(parserConfig map[string]interface{}, config *AutoMetadataConfig) map[string]interface{} {
|
|
if parserConfig == nil {
|
|
parserConfig = make(map[string]interface{})
|
|
}
|
|
|
|
fields := make([]map[string]interface{}, 0, len(config.Fields))
|
|
for _, field := range config.Fields {
|
|
fields = append(fields, map[string]interface{}{
|
|
"name": field.Name,
|
|
"type": field.Type,
|
|
"description": field.Description,
|
|
"examples": field.Examples,
|
|
"restrict_values": field.RestrictValues,
|
|
})
|
|
}
|
|
parserConfig["metadata"] = fields
|
|
enableMetadata := true
|
|
if config.Enabled != nil {
|
|
enableMetadata = *config.Enabled
|
|
}
|
|
parserConfig["enable_metadata"] = enableMetadata
|
|
return parserConfig
|
|
}
|
|
|
|
// parserConfigValueOrEmptyList returns parserConfig[key] when the key is
// present with a non-nil value. In every other case (nil map, missing key, nil
// value) it returns an empty, non-nil []interface{}.
func parserConfigValueOrEmptyList(parserConfig map[string]interface{}, key string) interface{} {
	// Reading from a nil map is safe and reports the key as absent.
	if stored, present := parserConfig[key]; present && stored != nil {
		return stored
	}
	return []interface{}{}
}
|
|
|
|
func normalizeMetadataConfigFields(fields []MetadataConfigField, fieldName string) ([]map[string]interface{}, error) {
|
|
normalizedFields := make([]map[string]interface{}, 0, len(fields))
|
|
for i, field := range fields {
|
|
key := strings.TrimSpace(field.Key)
|
|
if key == "" {
|
|
return nil, fmt.Errorf("%s[%d].key is required", fieldName, i)
|
|
}
|
|
if len(key) > 255 {
|
|
return nil, fmt.Errorf("%s[%d].key should have at most 255 characters", fieldName, i)
|
|
}
|
|
|
|
fieldType := strings.TrimSpace(field.Type)
|
|
if _, ok := datasetAllowedMetadataTypes[fieldType]; !ok {
|
|
return nil, fmt.Errorf("%s[%d].type should be one of 'string', 'list', 'time' or 'number'", fieldName, i)
|
|
}
|
|
|
|
if field.Description != nil && len(*field.Description) > 65535 {
|
|
return nil, fmt.Errorf("%s[%d].description should have at most 65535 characters", fieldName, i)
|
|
}
|
|
|
|
normalizedFields = append(normalizedFields, map[string]interface{}{
|
|
"key": key,
|
|
"type": fieldType,
|
|
"description": field.Description,
|
|
"enum": field.Enum,
|
|
})
|
|
}
|
|
|
|
return normalizedFields, nil
|
|
}
|
|
|
|
func datasetListItemToMap(kb *entity.KnowledgebaseListItem) map[string]interface{} {
|
|
item := map[string]interface{}{
|
|
"id": kb.ID,
|
|
"name": kb.Name,
|
|
"tenant_id": kb.TenantID,
|
|
"permission": kb.Permission,
|
|
"document_count": kb.DocNum,
|
|
"token_num": kb.TokenNum,
|
|
"chunk_count": kb.ChunkNum,
|
|
"chunk_method": kb.ParserID,
|
|
"embedding_model": kb.EmbdID,
|
|
"nickname": kb.Nickname,
|
|
}
|
|
|
|
if kb.Avatar != nil {
|
|
item["avatar"] = *kb.Avatar
|
|
}
|
|
if kb.Language != nil {
|
|
item["language"] = *kb.Language
|
|
}
|
|
if kb.Description != nil {
|
|
item["description"] = *kb.Description
|
|
}
|
|
if kb.TenantAvatar != nil {
|
|
item["tenant_avatar"] = *kb.TenantAvatar
|
|
}
|
|
if kb.UpdateTime != nil {
|
|
item["update_time"] = *kb.UpdateTime
|
|
}
|
|
|
|
return item
|
|
}
|
|
|
|
func datasetToMap(kb *entity.Knowledgebase) map[string]interface{} {
|
|
item := map[string]interface{}{
|
|
"id": kb.ID,
|
|
"tenant_id": kb.TenantID,
|
|
"name": kb.Name,
|
|
"embedding_model": kb.EmbdID,
|
|
"permission": kb.Permission,
|
|
"created_by": kb.CreatedBy,
|
|
"document_count": kb.DocNum,
|
|
"token_num": kb.TokenNum,
|
|
"chunk_count": kb.ChunkNum,
|
|
"similarity_threshold": kb.SimilarityThreshold,
|
|
"vector_similarity_weight": kb.VectorSimilarityWeight,
|
|
"chunk_method": kb.ParserID,
|
|
"parser_config": kb.ParserConfig,
|
|
"pagerank": kb.Pagerank,
|
|
"create_time": kb.CreateTime,
|
|
}
|
|
|
|
if kb.Avatar != nil {
|
|
item["avatar"] = *kb.Avatar
|
|
}
|
|
if kb.Language != nil {
|
|
item["language"] = *kb.Language
|
|
}
|
|
if kb.Description != nil {
|
|
item["description"] = *kb.Description
|
|
}
|
|
if kb.PipelineID != nil {
|
|
item["pipeline_id"] = *kb.PipelineID
|
|
}
|
|
if kb.GraphragTaskID != nil {
|
|
item["graphrag_task_id"] = *kb.GraphragTaskID
|
|
}
|
|
if kb.GraphragTaskFinishAt != nil {
|
|
item["graphrag_task_finish_at"] = kb.GraphragTaskFinishAt.Format("2006-01-02 15:04:05")
|
|
}
|
|
if kb.RaptorTaskID != nil {
|
|
item["raptor_task_id"] = *kb.RaptorTaskID
|
|
}
|
|
if kb.RaptorTaskFinishAt != nil {
|
|
item["raptor_task_finish_at"] = kb.RaptorTaskFinishAt.Format("2006-01-02 15:04:05")
|
|
}
|
|
if kb.MindmapTaskID != nil {
|
|
item["mindmap_task_id"] = *kb.MindmapTaskID
|
|
}
|
|
if kb.MindmapTaskFinishAt != nil {
|
|
item["mindmap_task_finish_at"] = kb.MindmapTaskFinishAt.Format("2006-01-02 15:04:05")
|
|
}
|
|
if kb.UpdateTime != nil {
|
|
item["update_time"] = *kb.UpdateTime
|
|
}
|
|
|
|
return item
|
|
}
|
|
|
|
// limitStrings returns at most limit leading elements of values.
//
// When values already has limit or fewer elements it is returned unchanged.
// Otherwise the returned slice is a prefix of values and shares its backing
// array. A negative limit is treated as zero instead of triggering a
// slice-bounds panic.
func limitStrings(values []string, limit int) []string {
	if limit < 0 {
		limit = 0
	}
	if len(values) <= limit {
		return values
	}
	return values[:limit]
}
|
|
|
|
func (s *DatasetService) RenameTag(datasetID, userID, fromTag, toTag string) (map[string]interface{}, common.ErrorCode, error) {
|
|
fromTag = strings.TrimSpace(fromTag)
|
|
toTag = strings.TrimSpace(toTag)
|
|
|
|
datasetID, err := normalizeDatasetID(datasetID)
|
|
if err != nil {
|
|
return nil, common.CodeDataError, err
|
|
}
|
|
if strings.TrimSpace(datasetID) == "" {
|
|
return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"")
|
|
}
|
|
if !s.kbDAO.Accessible(datasetID, userID) {
|
|
return nil, common.CodeDataError, errors.New("No authorization.")
|
|
}
|
|
if s.docEngine == nil {
|
|
return nil, common.CodeServerError, errors.New("Document engine is not initialized")
|
|
}
|
|
|
|
kb, err := s.kbDAO.GetByID(datasetID)
|
|
if err != nil || kb == nil {
|
|
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
|
}
|
|
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
|
|
|
condition := map[string]interface{}{
|
|
"tag_kwd": fromTag,
|
|
"kb_id": datasetID,
|
|
}
|
|
newValue := map[string]interface{}{
|
|
"remove": map[string]interface{}{
|
|
"tag_kwd": fromTag,
|
|
},
|
|
"add": map[string]interface{}{
|
|
"tag_kwd": toTag,
|
|
},
|
|
}
|
|
|
|
err = s.docEngine.UpdateChunks(context.Background(), condition, newValue, indexName, datasetID)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, fmt.Errorf("failed to rename tag: %w", err)
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"from": fromTag,
|
|
"to": toTag,
|
|
}, common.CodeSuccess, nil
|
|
}
|