diff --git a/go.mod b/go.mod index dabf9ec028..6175fde918 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,7 @@ require ( github.com/spf13/viper v1.18.2 github.com/xuri/excelize/v2 v2.10.1 github.com/yfedoseev/office_oxide/go v0.1.2 + github.com/zeebo/xxh3 v1.0.2 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0 go.opentelemetry.io/otel v1.44.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.44.0 diff --git a/internal/dao/document.go b/internal/dao/document.go index 16c9671ad0..3d0adefc8b 100644 --- a/internal/dao/document.go +++ b/internal/dao/document.go @@ -297,3 +297,14 @@ func (dao *DocumentDAO) GetByNameAndKBID(name, kbID string) ([]*entity.Document, err := DB.Where("name = ? AND kb_id = ?", name, kbID).Find(&docs).Error return docs, err } + +// ListNamesByKbID returns every document name in a dataset, used to compute a +// non-colliding upload filename (mirrors Python duplicate_name). +func (dao *DocumentDAO) ListNamesByKbID(kbID string) ([]string, error) { + var names []string + err := DB.Model(&entity.Document{}).Where("kb_id = ?", kbID).Pluck("name", &names).Error + if err != nil { + return nil, err + } + return names, nil +} diff --git a/internal/handler/document.go b/internal/handler/document.go index 11e260d6ab..ce69207f8c 100644 --- a/internal/handler/document.go +++ b/internal/handler/document.go @@ -20,9 +20,11 @@ import ( "encoding/json" "errors" "fmt" + "io" "mime" "mime/multipart" "net/http" + "net/url" "path/filepath" "ragflow/internal/common" "ragflow/internal/entity" @@ -59,6 +61,9 @@ type documentServiceIface interface { GetDocumentMetadataByID(docID string) (map[string]interface{}, error) GetDocumentArtifact(filename string) (*service.ArtifactResponse, error) GetDocumentPreview(docID string) (*service.DocumentPreview, error) + UploadLocalDocuments(kb *entity.Knowledgebase, tenantID string, files []*multipart.FileHeader, parentPath string, parserConfigOverride map[string]interface{}) ([]map[string]interface{}, []string) 
+ UploadWebDocument(kb *entity.Knowledgebase, tenantID, name, url string) (map[string]interface{}, common.ErrorCode, error) + UploadEmptyDocument(kb *entity.Knowledgebase, tenantID, name string) (map[string]interface{}, common.ErrorCode, error) DownloadDocument(datasetID, docID string) (*service.DownloadDocumentResp, error) UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error) BatchUpdateDocumentMetadatas(datasetID string, selector *service.DocumentMetadataSelector, updates []service.DocumentMetadataUpdate, deletes []service.DocumentMetadataDelete) (*service.BatchUpdateDocumentMetadatasResponse, common.ErrorCode, error) @@ -532,6 +537,197 @@ func (h *DocumentHandler) ListDocuments(c *gin.Context) { }) } +func (h *DocumentHandler) UploadDocuments(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + tenantID := user.ID + datasetID := c.Param("dataset_id") + uploadType := strings.ToLower(c.DefaultQuery("type", "local")) + + kb, err := h.datasetService.GetKnowledgebaseByID(datasetID) + if err != nil || kb == nil { + jsonError(c, common.CodeDataError, fmt.Sprintf("Can't find the dataset with ID %s!", datasetID)) + return + } + if !h.datasetService.CheckKBTeamPermission(kb, tenantID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") + return + } + + switch uploadType { + case "web": + h.uploadWebDocument(c, kb, tenantID) + case "empty": + h.uploadEmptyDocument(c, kb, tenantID) + case "local": + h.uploadLocalDocuments(c, kb, tenantID) + default: + jsonError(c, common.CodeArgumentError, `"type" must be one of "local", "web", or "empty".`) + } +} + +func (h *DocumentHandler) uploadLocalDocuments(c *gin.Context, kb *entity.Knowledgebase, tenantID string) { + form, err := c.MultipartForm() + if err != nil || form == 
nil || len(form.File["file"]) == 0 { + jsonError(c, common.CodeArgumentError, "No file part!") + return + } + files := form.File["file"] + for _, fh := range files { + if fh == nil || fh.Filename == "" { + jsonError(c, common.CodeArgumentError, "No file selected!") + return + } + if len([]byte(fh.Filename)) > 255 { + jsonError(c, common.CodeArgumentError, "File name must be 255 bytes or less.") + return + } + } + + // Optional parser_config override — only the allow-listed table column keys. + // Python ignores malformed or non-object input here instead of failing the + // whole upload request. + var override map[string]interface{} + if raw := strings.TrimSpace(c.PostForm("parser_config")); raw != "" { + var parsed map[string]interface{} + if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil { + override = map[string]interface{}{} + for _, k := range []string{"table_column_mode", "table_column_roles"} { + if v, ok := parsed[k]; ok { + override[k] = v + } + } + if len(override) == 0 { + override = nil + } + } + } + + data, errMsgs := h.documentService.UploadLocalDocuments(kb, tenantID, files, c.PostForm("parent_path"), override) + if len(data) == 0 && len(errMsgs) > 0 { + jsonError(c, common.CodeServerError, strings.Join(errMsgs, "\n")) + return + } + if len(data) == 0 { + jsonError(c, common.CodeDataError, "There seems to be an issue with your file format. 
please verify it is correct and not corrupted.") + return + } + + if strings.ToLower(c.DefaultQuery("return_raw_files", "false")) == "true" { + if len(errMsgs) > 0 { + jsonSuccess(c, gin.H{"documents": data, "errors": errMsgs}) + return + } + jsonSuccess(c, data) + return + } + mapped := make([]map[string]interface{}, len(data)) + for i, d := range data { + mapped[i] = mapDocKeysWithRunStatus(d) + } + if len(errMsgs) > 0 { + jsonSuccess(c, gin.H{"documents": mapped, "errors": errMsgs}) + return + } + jsonSuccess(c, mapped) +} + +func (h *DocumentHandler) uploadEmptyDocument(c *gin.Context, kb *entity.Knowledgebase, tenantID string) { + var req struct { + Name string `json:"name"` + } + // An empty body is valid (falls through to the name-required check below); + // a non-empty but malformed body should report the syntax error, not a + // misleading "File name can't be empty." + if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) { + jsonError(c, common.CodeArgumentError, "Invalid JSON body: "+err.Error()) + return + } + name := strings.TrimSpace(req.Name) + if name == "" { + jsonError(c, common.CodeArgumentError, "File name can't be empty.") + return + } + if len([]byte(name)) > 255 { + jsonError(c, common.CodeArgumentError, "File name must be 255 bytes or less.") + return + } + data, code, err := h.documentService.UploadEmptyDocument(kb, tenantID, name) + if err != nil { + jsonError(c, code, err.Error()) + return + } + jsonSuccess(c, mapDocKeysWithRunStatus(data)) +} + +func (h *DocumentHandler) uploadWebDocument(c *gin.Context, kb *entity.Knowledgebase, tenantID string) { + name := strings.TrimSpace(c.PostForm("name")) + rawURL := c.PostForm("url") + if name == "" { + jsonError(c, common.CodeArgumentError, `Lack of "name"`) + return + } + if rawURL == "" { + jsonError(c, common.CodeArgumentError, `Lack of "url"`) + return + } + if len([]byte(name)) > 255 { + jsonError(c, common.CodeArgumentError, "File name must be 255 bytes or less.") + 
// mapDocKeysWithRunStatus converts the raw map of a freshly created document
// into the public response shape.
//
// Four keys are always emitted under their public names, with a nil value when
// the raw key is absent: chunk_num→chunk_count, token_num→token_count,
// kb_id→dataset_id, parser_id→chunk_method. "run" is always reported as the
// label "UNSTART". The remaining allow-listed keys are copied through only
// when present in raw; anything else in raw is dropped.
func mapDocKeysWithRunStatus(raw map[string]interface{}) map[string]interface{} {
	// {raw key, public key} pairs that are always written.
	renamed := [...][2]string{
		{"chunk_num", "chunk_count"},
		{"token_num", "token_count"},
		{"kb_id", "dataset_id"},
		{"parser_id", "chunk_method"},
	}
	// Keys copied verbatim, but only if the raw map carries them.
	passthrough := [...]string{
		"id", "name", "type", "size", "suffix", "source_type",
		"created_by", "parser_config", "location", "pipeline_id", "content_hash",
	}

	mapped := make(map[string]interface{}, len(renamed)+len(passthrough)+1)
	for _, pair := range renamed {
		mapped[pair[1]] = raw[pair[0]]
	}
	mapped["run"] = "UNSTART"
	for _, key := range passthrough {
		if value, present := raw[key]; present {
			mapped[key] = value
		}
	}
	return mapped
}
// isValidHTTPURL reports whether raw (ignoring surrounding whitespace) parses
// as a URL with an http or https scheme and a non-empty host name.
//
// The host check uses Hostname() rather than Host so that an authority made of
// only a port or only userinfo (for example "http://:8080") is rejected: Host
// is non-empty for those inputs even though there is nothing to connect to.
func isValidHTTPURL(raw string) bool {
	u, err := url.Parse(strings.TrimSpace(raw))
	if err != nil {
		return false
	}
	if u.Scheme != "http" && u.Scheme != "https" {
		return false
	}
	return u.Hostname() != ""
}
f.uploadLocalData, f.uploadLocalErrs +} +func (f *fakeDocumentService) UploadWebDocument(kb *entity.Knowledgebase, tenantID, name, url string) (map[string]interface{}, common.ErrorCode, error) { + return nil, common.CodeServerError, fmt.Errorf("not implemented") +} +func (f *fakeDocumentService) UploadEmptyDocument(kb *entity.Knowledgebase, tenantID, name string) (map[string]interface{}, common.ErrorCode, error) { + return nil, common.CodeServerError, fmt.Errorf("not implemented") +} func (f *fakeDocumentService) ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error) { return nil, nil @@ -189,6 +209,81 @@ func setupGinContextWithUser(method, path, body string) (*gin.Context, *httptest return c, w } +func setupUploadHandlerDB(t *testing.T, role string) *gorm.DB { + t.Helper() + + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + TranslateError: true, + }) + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + if err := db.AutoMigrate( + &entity.User{}, + &entity.Tenant{}, + &entity.UserTenant{}, + &entity.Knowledgebase{}, + ); err != nil { + t.Fatalf("failed to migrate: %v", err) + } + if err := db.Create(&entity.User{ID: "user-1", Nickname: "test", Email: "test@test.com", Password: sptr("x")}).Error; err != nil { + t.Fatalf("insert user: %v", err) + } + if err := db.Create(&entity.Tenant{ID: "tenant-1", LLMID: "llm-1", EmbdID: "embd-1", ASRID: "asr-1", Status: sptr(string(entity.StatusValid))}).Error; err != nil { + t.Fatalf("insert tenant: %v", err) + } + if err := db.Create(&entity.UserTenant{ID: "ut-1", UserID: "user-1", TenantID: "tenant-1", Role: role, Status: sptr(string(entity.StatusValid))}).Error; err != nil { + t.Fatalf("insert user_tenant: %v", err) + } + pipelineID := "pipe-1" + if err := db.Create(&entity.Knowledgebase{ + ID: "123e4567e89b12d3a456426614174000", + TenantID: "tenant-1", + Name: "kb-upload", + EmbdID: "embd-1", + CreatedBy: "user-1", + Permission: 
string(entity.TenantPermissionTeam), + ParserID: "naive", + PipelineID: &pipelineID, + ParserConfig: entity.JSONMap{"base": "cfg"}, + Status: sptr(string(entity.StatusValid)), + }).Error; err != nil { + t.Fatalf("insert knowledgebase: %v", err) + } + return db +} + +func setupUploadContext(t *testing.T, path string, fields map[string]string, fileName string, fileContent []byte) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + var body bytes.Buffer + writer := multipart.NewWriter(&body) + for k, v := range fields { + if err := writer.WriteField(k, v); err != nil { + t.Fatalf("write field %s: %v", k, err) + } + } + part, err := writer.CreateFormFile("file", fileName) + if err != nil { + t.Fatalf("create form file: %v", err) + } + if _, err := part.Write(fileContent); err != nil { + t.Fatalf("write form file: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("close writer: %v", err) + } + req := httptest.NewRequest(http.MethodPost, path, &body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + c, _ := gin.CreateTestContext(w) + c.Request = req + c.Set("user", &entity.User{ID: "user-1"}) + c.Set("user_id", "user-1") + c.Params = gin.Params{{Key: "dataset_id", Value: uploadTestDatasetID}} + return c, w +} + func setupDocumentIngestRoute(userID string, svc *fakeDocumentService) *gin.Engine { gin.SetMode(gin.TestMode) h := &DocumentHandler{ @@ -233,6 +328,115 @@ func TestDeleteDocumentsHandler_Success(t *testing.T) { } } +func TestUploadDocumentsHandler_LocalUsesFullKBAndIgnoresBadParserConfig(t *testing.T) { + db := setupUploadHandlerDB(t, "normal") + orig := dao.DB + dao.DB = db + t.Cleanup(func() { dao.DB = orig }) + + fake := &fakeDocumentService{ + uploadLocalData: []map[string]interface{}{ + {"id": "doc-1", "kb_id": "ds-1", "parser_id": "naive", "chunk_num": int64(0), "token_num": int64(0), "name": "a.txt"}, + }, + } + h := &DocumentHandler{ + documentService: fake, + 
datasetService: service.NewDatasetService(), + } + + c, w := setupUploadContext(t, "/api/v1/datasets/ds-1/documents?type=local", map[string]string{ + "parent_path": "nested/path", + "parser_config": "{bad json", + }, "a.txt", []byte("abc")) + + h.UploadDocuments(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if fake.uploadLocalKB == nil { + t.Fatalf("UploadLocalDocuments was not called, response=%s", w.Body.String()) + } + if fake.uploadLocalKB.TenantID != "tenant-1" || fake.uploadLocalKB.Name != "kb-upload" || fake.uploadLocalKB.ParserID != "naive" { + t.Fatalf("incomplete kb passed to service: %+v", fake.uploadLocalKB) + } + if fake.uploadLocalPath != "nested/path" { + t.Fatalf("parent path=%q, want nested/path", fake.uploadLocalPath) + } + if fake.uploadOverride != nil { + t.Fatalf("bad parser_config should be ignored, got %v", fake.uploadOverride) + } +} + +func TestUploadDocumentsHandler_LocalReturnsPartialSuccess(t *testing.T) { + db := setupUploadHandlerDB(t, "normal") + orig := dao.DB + dao.DB = db + t.Cleanup(func() { dao.DB = orig }) + + fake := &fakeDocumentService{ + uploadLocalData: []map[string]interface{}{ + {"id": "doc-1", "kb_id": "ds-1", "parser_id": "naive", "chunk_num": int64(0), "token_num": int64(0), "name": "ok.txt"}, + }, + uploadLocalErrs: []string{"bad.exe: This type of file has not been supported yet!"}, + } + h := &DocumentHandler{ + documentService: fake, + datasetService: service.NewDatasetService(), + } + + c, w := setupUploadContext(t, "/api/v1/datasets/ds-1/documents?type=local", nil, "ok.txt", []byte("abc")) + h.UploadDocuments(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp["code"] != float64(common.CodeSuccess) { + t.Fatalf("expected success for partial upload, got 
%v", resp) + } + data := resp["data"].(map[string]interface{}) + if len(data["documents"].([]interface{})) != 1 { + t.Fatalf("expected one successful document, got %v", data["documents"]) + } + if len(data["errors"].([]interface{})) != 1 { + t.Fatalf("expected one file error, got %v", data["errors"]) + } +} + +func TestUploadDocumentsHandler_DeniesNonNormalTeamRole(t *testing.T) { + db := setupUploadHandlerDB(t, "admin") + orig := dao.DB + dao.DB = db + t.Cleanup(func() { dao.DB = orig }) + + fake := &fakeDocumentService{} + h := &DocumentHandler{ + documentService: fake, + datasetService: service.NewDatasetService(), + } + + c, w := setupUploadContext(t, "/api/v1/datasets/ds-1/documents?type=local", nil, "a.txt", []byte("abc")) + h.UploadDocuments(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp["code"] == float64(common.CodeSuccess) { + t.Fatalf("expected authorization error, got %v", resp) + } + if fake.uploadLocalKB != nil { + t.Fatal("service should not be called on denied upload") + } +} + func TestDeleteDocumentsHandler_DeleteAll(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/internal/router/router.go b/internal/router/router.go index 7c4c9b6678..975666f26a 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -306,6 +306,7 @@ func (r *Router) Setup(engine *gin.Engine) { // Dataset documents datasets.GET("/:dataset_id/documents", r.documentHandler.ListDocuments) + datasets.POST("/:dataset_id/documents", r.documentHandler.UploadDocuments) datasets.GET("/:dataset_id/documents/:document_id", r.documentHandler.DownloadDocument) datasets.PATCH("/:dataset_id/documents/:document_id", r.documentHandler.UpdateDatasetDocument) datasets.DELETE("/:dataset_id/documents", r.documentHandler.DeleteDocuments) diff --git a/internal/service/dataset.go 
b/internal/service/dataset.go index 3304be203d..32f2b47e92 100644 --- a/internal/service/dataset.go +++ b/internal/service/dataset.go @@ -3042,6 +3042,25 @@ func (s *DatasetService) Accessible(kbID, userID string) bool { return s.kbDAO.Accessible(kbID, userID) } +// GetKnowledgebaseByID resolves a dataset entity without applying permission +// checks. Upload needs the same existence-then-auth ordering as Python. +func (s *DatasetService) GetKnowledgebaseByID(datasetID string) (*entity.Knowledgebase, error) { + datasetID = strings.TrimSpace(datasetID) + if datasetID == "" { + return nil, errors.New("Lack of \"Dataset ID\"") + } + normalizedID, err := normalizeDatasetID(datasetID) + if err != nil { + return nil, err + } + return s.kbDAO.GetByID(normalizedID) +} + +// CheckKBTeamPermission mirrors Python check_kb_team_permission. +func (s *DatasetService) CheckKBTeamPermission(kb *entity.Knowledgebase, userID string) bool { + return hasKBTeamPermission(kb, userID, s.tenantDAO) +} + func (s *DatasetService) AggregateTags(datasetIDs []string, userID string) ([]map[string]interface{}, common.ErrorCode, error) { if len(datasetIDs) == 0 { return nil, common.CodeDataError, errors.New("Lack of dataset_ids in query parameters") diff --git a/internal/service/document.go b/internal/service/document.go index 6eed85692c..43db9ccc82 100644 --- a/internal/service/document.go +++ b/internal/service/document.go @@ -29,6 +29,7 @@ import ( "math" "math/rand" "mime/multipart" + "net/http" "os" "path/filepath" "reflect" @@ -50,6 +51,7 @@ import ( "ragflow/internal/utility" "github.com/cespare/xxhash/v2" + "github.com/google/uuid" "go.uber.org/zap" "gorm.io/gorm" ) @@ -2781,6 +2783,403 @@ func mapDocumentRunStatus(run *string) string { } } +// UploadLocalDocuments stores each uploaded file in object storage and inserts a +// matching Document row into the dataset. 
It mirrors Python +// FileService.upload_document: it derives parser_id by filetype, merges the +// optional parser_config override into the dataset config, dedup-renames the +// filename, records size + xxhash content hash, and links each document into the +// file manager (a File row under the dataset folder + a file2document mapping) +// so it surfaces in the dataset's document list. Chunking/embedding happen later +// in the parse step, so nothing here touches the doc store index. +// +// Gaps vs Python (documented, not yet ported): thumbnail generation and +// read_potential_broken_pdf repair. +func (s *DocumentService) UploadLocalDocuments(kb *entity.Knowledgebase, tenantID string, files []*multipart.FileHeader, parentPath string, parserConfigOverride map[string]interface{}) ([]map[string]interface{}, []string) { + storageImpl := storage.GetStorageFactory().GetStorage() + if storageImpl == nil { + return nil, []string{"storage not initialized"} + } + + // Resolve (and create if needed) the dataset's file-manager folder up front. + // Without the File / file2document linkage the document list (which inner-joins + // file2document + file) would never surface the uploaded files. + kbFolder, err := s.ensureKBFolder(kb, tenantID) + if err != nil { + return nil, []string{err.Error()} + } + + // Merge parser_config override (allow-listed keys only) over the dataset config. + merged := entity.JSONMap{} + for k, v := range kb.ParserConfig { + merged[k] = v + } + for k, v := range parserConfigOverride { + merged[k] = v + } + + safeParent := utility.SanitizeFilename(parentPath) + + // Don't silently disable dedupe protection: a transient lookup failure means + // the existing-name set is unknown, so fail rather than risk duplicates. 
+ names, err := s.documentDAO.ListNamesByKbID(kb.ID) + if err != nil { + return nil, []string{err.Error()} + } + taken := map[string]bool{} + for _, n := range names { + taken[n] = true + } + + var results []map[string]interface{} + var errMsgs []string + + for _, fh := range files { + blob, err := readFileHeaderBytes(fh) + if err != nil { + errMsgs = append(errMsgs, fh.Filename+": "+err.Error()) + continue + } + + filename := uniqueUploadName(fh.Filename, taken) + + filetype := utility.FilenameType(filename) + if filetype == utility.FileTypeOTHER { + errMsgs = append(errMsgs, fh.Filename+": This type of file has not been supported yet!") + continue + } + + location := filename + if safeParent != "" { + location = safeParent + "/" + filename + } + for storageImpl.ObjExist(kb.ID, location) { + location += "_" + } + if err := storageImpl.Put(kb.ID, location, blob); err != nil { + errMsgs = append(errMsgs, fh.Filename+": "+err.Error()) + continue + } + + doc := s.newDatasetDocument(kb, tenantID, filename, location, string(filetype), merged, "local", int64(len(blob)), blob) + if err := s.documentDAO.Create(doc); err != nil { + // Roll back the orphaned blob so a failed insert doesn't leak storage. + _ = storageImpl.Remove(kb.ID, location) + errMsgs = append(errMsgs, fh.Filename+": "+err.Error()) + continue + } + if err := s.addFileFromKB(doc, kbFolder.ID, kb.TenantID); err != nil { + // Linkage failed: roll back the document row and blob so the partial + // state doesn't leave an invisible (unlisted) document behind. + _, _ = s.documentDAO.Delete(doc.ID) + _ = storageImpl.Remove(kb.ID, location) + errMsgs = append(errMsgs, fh.Filename+": "+err.Error()) + continue + } + // Only reserve the name once the write fully succeeds. + taken[filename] = true + results = append(results, docToRawMap(doc)) + } + + return results, errMsgs +} + +// UploadEmptyDocument inserts a zero-byte "virtual" document into the dataset. 
+func (s *DocumentService) UploadEmptyDocument(kb *entity.Knowledgebase, tenantID, name string) (map[string]interface{}, common.ErrorCode, error) { + // A transient lookup failure means the existing-name set is unknown; fail + // rather than write blind and risk a duplicate. + names, err := s.documentDAO.ListNamesByKbID(kb.ID) + if err != nil { + return nil, common.CodeServerError, err + } + for _, n := range names { + if n == name { + return nil, common.CodeDataError, fmt.Errorf("Duplicated document name in the same dataset.") + } + } + + kbFolder, err := s.ensureKBFolder(kb, tenantID) + if err != nil { + return nil, common.CodeServerError, err + } + + doc := s.newDatasetDocument(kb, tenantID, name, "", "virtual", kb.ParserConfig, "local", 0, nil) + if err := s.documentDAO.Create(doc); err != nil { + return nil, common.CodeServerError, err + } + if err := s.addFileFromKB(doc, kbFolder.ID, kb.TenantID); err != nil { + _, _ = s.documentDAO.Delete(doc.ID) + return nil, common.CodeServerError, err + } + return docToRawMap(doc), common.CodeSuccess, nil +} + +// knowledgebaseFolderName is the file-manager folder under each tenant's root +// that holds per-dataset subfolders, mirroring Python KNOWLEDGEBASE_FOLDER_NAME. +const knowledgebaseFolderName = ".knowledgebase" + +// ensureKBFolder resolves (creating as needed) the per-dataset file-manager +// folder: root -> .knowledgebase -> . Mirrors Python +// get_root_folder + get_kb_folder + new_a_file_from_kb. +func (s *DocumentService) ensureKBFolder(kb *entity.Knowledgebase, tenantID string) (*entity.File, error) { + root, err := s.fileDAO.GetRootFolder(tenantID) + if err != nil { + return nil, err + } + kbRoot, err := s.newAFileFromKB(tenantID, knowledgebaseFolderName, root.ID) + if err != nil { + return nil, err + } + return s.newAFileFromKB(kb.TenantID, kb.Name, kbRoot.ID) +} + +// newAFileFromKB returns the existing folder named name under parentID, or +// creates it. Mirrors Python FileService.new_a_file_from_kb. 
+func (s *DocumentService) newAFileFromKB(tenantID, name, parentID string) (*entity.File, error) { + for _, f := range s.fileDAO.Query(name, parentID) { + if f.TenantID == tenantID { + return f, nil + } + } + loc := "" + folder := &entity.File{ + ID: strings.ReplaceAll(uuid.New().String(), "-", ""), + ParentID: parentID, + TenantID: tenantID, + CreatedBy: tenantID, + Name: name, + Type: "folder", + Size: 0, + Location: &loc, + SourceType: string(entity.FileSourceKnowledgebase), + } + if err := s.fileDAO.Create(folder); err != nil { + return nil, err + } + return folder, nil +} + +// addFileFromKB links a document into the file manager: a File row under the +// dataset folder plus a file2document mapping. Mirrors Python +// FileService.add_file_from_kb (idempotent on the document mapping). +func (s *DocumentService) addFileFromKB(doc *entity.Document, kbFolderID, tenantID string) error { + if existing, err := s.file2DocumentDAO.GetByDocumentID(doc.ID); err == nil && len(existing) > 0 { + return nil + } + name := "" + if doc.Name != nil { + name = *doc.Name + } + loc := "" + if doc.Location != nil { + loc = *doc.Location + } + fileID := strings.ReplaceAll(uuid.New().String(), "-", "") + file := &entity.File{ + ID: fileID, + ParentID: kbFolderID, + TenantID: tenantID, + CreatedBy: tenantID, + Name: name, + Type: doc.Type, + Size: doc.Size, + Location: &loc, + SourceType: string(entity.FileSourceKnowledgebase), + } + if err := s.fileDAO.Create(file); err != nil { + return err + } + docID := doc.ID + if err := s.file2DocumentDAO.Create(&entity.File2Document{ + ID: strings.ReplaceAll(uuid.New().String(), "-", ""), + FileID: &fileID, + DocumentID: &docID, + }); err != nil { + _ = s.fileDAO.Delete(fileID) + return err + } + return nil +} + +func (s *DocumentService) UploadWebDocument(kb *entity.Knowledgebase, tenantID, name, url string) (map[string]interface{}, common.ErrorCode, error) { + storageImpl := storage.GetStorageFactory().GetStorage() + if storageImpl == nil { + 
return nil, common.CodeServerError, fmt.Errorf("storage not initialized") + } + + kbFolder, err := s.ensureKBFolder(kb, tenantID) + if err != nil { + return nil, common.CodeServerError, err + } + + names, err := s.documentDAO.ListNamesByKbID(kb.ID) + if err != nil { + return nil, common.CodeServerError, err + } + taken := map[string]bool{} + for _, n := range names { + taken[n] = true + } + + blob, headers, _, err := fetchRemoteFileSafely(url, maxUploadDocSize) + if err != nil { + return nil, common.CodeDataError, err + } + contentType := "" + if headers != nil { + contentType = headers.Get("Content-Type") + } + filename := normalizeWebDocumentName(name, contentType, blob) + filename, _, blob = normalizeUploadInfoContent(filename, contentType, blob) + filename = uniqueUploadName(filename, taken) + + filetype := utility.FilenameType(filename) + if filetype == utility.FileTypeOTHER { + return nil, common.CodeDataError, fmt.Errorf("This type of file has not been supported yet!") + } + + location := filename + for storageImpl.ObjExist(kb.ID, location) { + location += "_" + } + if err := storageImpl.Put(kb.ID, location, blob); err != nil { + return nil, common.CodeServerError, err + } + + doc := s.newDatasetDocument(kb, tenantID, filename, location, string(filetype), kb.ParserConfig, "web", int64(len(blob)), blob) + if err := s.documentDAO.Create(doc); err != nil { + _ = storageImpl.Remove(kb.ID, location) + return nil, common.CodeServerError, err + } + if err := s.addFileFromKB(doc, kbFolder.ID, kb.TenantID); err != nil { + _, _ = s.documentDAO.Delete(doc.ID) + _ = storageImpl.Remove(kb.ID, location) + return nil, common.CodeServerError, err + } + return docToRawMap(doc), common.CodeSuccess, nil +} + +func normalizeWebDocumentName(name, contentType string, blob []byte) string { + filename := utility.SanitizeFilename(name) + if filepath.Ext(filename) != "" { + return filename + } + lowerCT := strings.ToLower(strings.TrimSpace(strings.Split(contentType, ";")[0])) + 
switch { + case lowerCT == "application/pdf" || http.DetectContentType(blob) == "application/pdf" || bytesLooksLikePDF(blob): + return filename + ".pdf" + case lowerCT == "text/html" || lowerCT == "application/xhtml+xml" || looksLikeHTML(blob): + return filename + ".html" + default: + return filename + } +} + +// newDatasetDocument builds a Document row for an upload, deriving parser_id, +// suffix and content hash. blob may be nil for the empty/virtual document. +func (s *DocumentService) newDatasetDocument(kb *entity.Knowledgebase, tenantID, filename, location, filetype string, parserConfig entity.JSONMap, src string, size int64, blob []byte) *entity.Document { + docID := strings.ReplaceAll(uuid.New().String(), "-", "") + zero := "0" + suffix := "" + if i := strings.LastIndex(filename, "."); i >= 0 { + suffix = filename[i+1:] + } + loc := location + doc := &entity.Document{ + ID: docID, + KbID: kb.ID, + ParserID: selectUploadParser(utility.FileType(filetype), filename, kb.ParserID), + PipelineID: kb.PipelineID, + ParserConfig: parserConfig, + CreatedBy: tenantID, + Type: filetype, + SourceType: src, + Name: &filename, + Location: &loc, + Size: size, + Suffix: suffix, + Run: &zero, + Status: &zero, + } + if blob != nil { + hash := contentHashHex(blob) + doc.ContentHash = &hash + } + return doc +} + +// docToRawMap serialises a freshly created Document into the raw key shape the +// handler remaps (chunk_num→chunk_count, kb_id→dataset_id, parser_id→chunk_method). 
// uniqueUploadName returns name unchanged when it is not in taken; otherwise it
// inserts "(1)", "(2)", ... before the extension until the candidate is free
// (the original comment notes this mirrors Python duplicate_name). taken is
// only read, and a nil map is treated as empty.
//
// The extension is the text from the last "." onward, except that a name whose
// only dot is its first character (for example ".env") has no extension: the
// counter is appended to the whole name (".env(1)") instead of being placed in
// front of it ("(1).env").
func uniqueUploadName(name string, taken map[string]bool) string {
	if !taken[name] {
		return name
	}
	base, ext := name, ""
	if i := strings.LastIndex(name, "."); i > 0 {
		base, ext = name[:i], name[i:]
	}
	for i := 1; ; i++ {
		candidate := fmt.Sprintf("%s(%d)%s", base, i, ext)
		if !taken[candidate] {
			return candidate
		}
	}
}

// maxUploadDocSize bounds a single uploaded file held in memory, mirroring the
// Python DOC_MAXIMUM_SIZE default (128 MiB; overridable there via MAX_CONTENT_LENGTH).
const maxUploadDocSize = 128 * 1024 * 1024

// readFileHeaderBytes reads the full content of an uploaded multipart file into
// memory. It rejects a nil header and any file larger than maxUploadDocSize.
// The size is checked twice: first against the declared fh.Size, then against
// the number of bytes actually read, so a header that under-reports its size
// cannot push more than the limit (plus one byte) into memory.
func readFileHeaderBytes(fh *multipart.FileHeader) ([]byte, error) {
	if fh == nil {
		return nil, fmt.Errorf("no file provided")
	}
	if fh.Size > maxUploadDocSize {
		return nil, fmt.Errorf("file exceeds the maximum allowed size of %d bytes", maxUploadDocSize)
	}
	src, err := fh.Open()
	if err != nil {
		return nil, err
	}
	defer src.Close()
	// Read one byte past the limit so an oversized stream is detectable.
	blob, err := io.ReadAll(io.LimitReader(src, maxUploadDocSize+1))
	if err != nil {
		return nil, err
	}
	if len(blob) > maxUploadDocSize {
		return nil, fmt.Errorf("file exceeds the maximum allowed size of %d bytes", maxUploadDocSize)
	}
	return blob, nil
}
type DocumentMetadataUpdate struct { Key string `json:"key"` diff --git a/internal/service/document_test.go b/internal/service/document_test.go index 929904c7a5..5d1e089287 100644 --- a/internal/service/document_test.go +++ b/internal/service/document_test.go @@ -17,12 +17,17 @@ package service import ( + "bytes" "context" "errors" "fmt" + "mime/multipart" + "net/http" + "net/http/httptest" "path/filepath" "strings" "testing" + "time" "github.com/glebarez/sqlite" "gorm.io/gorm" @@ -31,8 +36,60 @@ import ( "ragflow/internal/dao" "ragflow/internal/engine/types" "ragflow/internal/entity" + "ragflow/internal/storage" + "ragflow/internal/utility" ) +type fakeUploadStorage struct { + objects map[string][]byte +} + +func newFakeUploadStorage() *fakeUploadStorage { + return &fakeUploadStorage{objects: map[string][]byte{}} +} + +func (f *fakeUploadStorage) Health() bool { return true } +func (f *fakeUploadStorage) key(bucket, fnm string) string { return bucket + "/" + fnm } +func (f *fakeUploadStorage) Put(bucket, fnm string, binary []byte, tenantID ...string) error { + f.objects[f.key(bucket, fnm)] = append([]byte(nil), binary...) 
+ return nil +} +func (f *fakeUploadStorage) Get(bucket, fnm string, tenantID ...string) ([]byte, error) { + v, ok := f.objects[f.key(bucket, fnm)] + if !ok { + return nil, errors.New("not found") + } + return append([]byte(nil), v...), nil +} +func (f *fakeUploadStorage) Remove(bucket, fnm string, tenantID ...string) error { + delete(f.objects, f.key(bucket, fnm)) + return nil +} +func (f *fakeUploadStorage) ObjExist(bucket, fnm string, tenantID ...string) bool { + _, ok := f.objects[f.key(bucket, fnm)] + return ok +} +func (f *fakeUploadStorage) GetPresignedURL(bucket, fnm string, expires time.Duration, tenantID ...string) (string, error) { + return "", nil +} +func (f *fakeUploadStorage) BucketExists(bucket string) bool { return true } +func (f *fakeUploadStorage) RemoveBucket(bucket string) error { return nil } +func (f *fakeUploadStorage) Copy(srcBucket, srcPath, destBucket, destPath string) bool { + v, ok := f.objects[f.key(srcBucket, srcPath)] + if !ok { + return false + } + f.objects[f.key(destBucket, destPath)] = append([]byte(nil), v...) 
+ return true +} +func (f *fakeUploadStorage) Move(srcBucket, srcPath, destBucket, destPath string) bool { + if !f.Copy(srcBucket, srcPath, destBucket, destPath) { + return false + } + delete(f.objects, f.key(srcBucket, srcPath)) + return true +} + type fakeChatDocEngine struct{} func (fakeChatDocEngine) CreateChunkStore(context.Context, string, string, int, string) error { @@ -294,6 +351,32 @@ func testDocumentService(t *testing.T) *DocumentService { } } +func makeTestFileHeader(t *testing.T, field, filename string, content []byte) *multipart.FileHeader { + t.Helper() + var body bytes.Buffer + writer := multipart.NewWriter(&body) + part, err := writer.CreateFormFile(field, filename) + if err != nil { + t.Fatalf("create form file: %v", err) + } + if _, err := part.Write(content); err != nil { + t.Fatalf("write form file: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("close multipart writer: %v", err) + } + req := httptest.NewRequest(http.MethodPost, "/", &body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + if err := req.ParseMultipartForm(int64(len(content) + 1024)); err != nil { + t.Fatalf("parse multipart form: %v", err) + } + fhs := req.MultipartForm.File[field] + if len(fhs) != 1 { + t.Fatalf("expected 1 file header, got %d", len(fhs)) + } + return fhs[0] +} + // sptr returns a pointer to the given string. 
func sptr(s string) *string { return &s } @@ -502,6 +585,164 @@ func TestDeleteDocumentFull_SharedFilePreserved(t *testing.T) { } } +func TestSelectUploadParser_MirrorsPython(t *testing.T) { + tests := []struct { + name string + docType utility.FileType + filename string + defaultValue string + want string + }{ + {name: "visual", docType: utility.FileTypeVISUAL, filename: "img.png", defaultValue: "naive", want: "picture"}, + {name: "aural", docType: utility.FileTypeAURAL, filename: "audio.mp3", defaultValue: "naive", want: "audio"}, + {name: "presentation by ext", docType: utility.FileTypeDOC, filename: "deck.pptx", defaultValue: "naive", want: "presentation"}, + {name: "email by ext", docType: utility.FileTypeDOC, filename: "mail.eml", defaultValue: "naive", want: "email"}, + {name: "fallback default", docType: utility.FileTypeDOC, filename: "notes.txt", defaultValue: "manual", want: "manual"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := selectUploadParser(tt.docType, tt.filename, tt.defaultValue); got != tt.want { + t.Fatalf("selectUploadParser(%q)=%q, want %q", tt.filename, got, tt.want) + } + }) + } +} + +func TestContentHashHex_MatchesPythonXXH128(t *testing.T) { + tests := []struct { + data []byte + want string + }{ + {data: []byte("abc"), want: "06b05ab6733a618578af5f94892f3950"}, + {data: []byte(""), want: "99aa06d3014798d86001c324468d497f"}, + } + for _, tt := range tests { + if got := contentHashHex(tt.data); got != tt.want { + t.Fatalf("contentHashHex(%q)=%s, want %s", tt.data, got, tt.want) + } + } +} + +func TestUploadLocalDocuments_MirrorsPythonCoreFields(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + mockStorage := newFakeUploadStorage() + factory := storage.GetStorageFactory() + origStorage := factory.GetStorage() + factory.SetStorage(mockStorage) + t.Cleanup(func() { factory.SetStorage(origStorage) }) + + pipelineID := "pipe-1" + kb := &entity.Knowledgebase{ + ID: "kb-upload", + 
TenantID: "tenant-1", + Name: "kb-upload", + ParserID: "naive", + PipelineID: &pipelineID, + ParserConfig: entity.JSONMap{ + "existing": "value", + }, + } + if err := dao.DB.Create(kb).Error; err != nil { + t.Fatalf("insert kb: %v", err) + } + if err := dao.DB.Create(&entity.Document{ + ID: "doc-existing", + KbID: kb.ID, + ParserID: "naive", + ParserConfig: entity.JSONMap{}, + Name: sptr("deck.pptx"), + Status: sptr("1"), + }).Error; err != nil { + t.Fatalf("insert existing doc: %v", err) + } + + svc := testDocumentService(t) + fh := makeTestFileHeader(t, "file", "deck.pptx", []byte("abc")) + got, errs := svc.UploadLocalDocuments(kb, "user-1", []*multipart.FileHeader{fh}, "nested/path", map[string]interface{}{ + "table_column_mode": "assist", + }) + if len(errs) != 0 { + t.Fatalf("unexpected errs: %v", errs) + } + if len(got) != 1 { + t.Fatalf("expected 1 uploaded doc, got %d", len(got)) + } + doc := got[0] + if doc["name"] != "deck(1).pptx" { + t.Fatalf("name=%v, want deck(1).pptx", doc["name"]) + } + if doc["location"] != "nested/path/deck(1).pptx" { + t.Fatalf("location=%v, want nested/path/deck(1).pptx", doc["location"]) + } + if doc["parser_id"] != "presentation" { + t.Fatalf("parser_id=%v, want presentation", doc["parser_id"]) + } + if doc["content_hash"] != "06b05ab6733a618578af5f94892f3950" { + t.Fatalf("content_hash=%v", doc["content_hash"]) + } + cfg := doc["parser_config"].(map[string]interface{}) + if cfg["existing"] != "value" || cfg["table_column_mode"] != "assist" { + t.Fatalf("parser_config=%v", cfg) + } + + storedBlob, err := mockStorage.Get(kb.ID, "nested/path/deck(1).pptx") + if err != nil { + t.Fatalf("blob not stored: %v", err) + } + if string(storedBlob) != "abc" { + t.Fatalf("stored blob=%q, want abc", storedBlob) + } +} + +func TestUploadEmptyDocument_CreatesVirtualDocumentAndFileLink(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + pipelineID := "pipe-2" + kb := &entity.Knowledgebase{ + ID: "kb-empty", + TenantID: 
"tenant-1", + Name: "kb-empty", + ParserID: "manual", + PipelineID: &pipelineID, + ParserConfig: entity.JSONMap{ + "foo": "bar", + }, + } + if err := dao.DB.Create(kb).Error; err != nil { + t.Fatalf("insert kb: %v", err) + } + + svc := testDocumentService(t) + got, code, err := svc.UploadEmptyDocument(kb, "user-1", "draft.md") + if err != nil { + t.Fatalf("UploadEmptyDocument error: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code=%v, want success", code) + } + if got["type"] != "virtual" || got["parser_id"] != "manual" || got["size"] != int64(0) { + t.Fatalf("unexpected doc map: %v", got) + } + + var docCount int64 + if err := dao.DB.Model(&entity.Document{}).Where("kb_id = ?", kb.ID).Count(&docCount).Error; err != nil { + t.Fatalf("count docs: %v", err) + } + if docCount != 1 { + t.Fatalf("doc count=%d, want 1", docCount) + } + var linkCount int64 + if err := dao.DB.Model(&entity.File2Document{}).Count(&linkCount).Error; err != nil { + t.Fatalf("count links: %v", err) + } + if linkCount != 1 { + t.Fatalf("link count=%d, want 1", linkCount) + } +} + func insertUserTenantForAccessCheck(t *testing.T, userID, tenantID string) { t.Helper() // Insert user if not exists (email is NOT NULL, password is nullable pointer) diff --git a/internal/service/document_upload_helpers.go b/internal/service/document_upload_helpers.go new file mode 100644 index 0000000000..b8a6e6a8c4 --- /dev/null +++ b/internal/service/document_upload_helpers.go @@ -0,0 +1,42 @@ +package service + +import ( + "encoding/hex" + "path/filepath" + "regexp" + "strings" + + "ragflow/internal/utility" + + "github.com/zeebo/xxh3" +) + +var ( + presentationUploadPattern = regexp.MustCompile(`(?i)\.(ppt|pptx|pages)$`) + emailUploadPattern = regexp.MustCompile(`(?i)\.(msg|eml)$`) +) + +// selectUploadParser mirrors Python FileService.get_parser. 
+func selectUploadParser(docType utility.FileType, filename, defaultParser string) string { + switch docType { + case utility.FileTypeVISUAL: + return "picture" + case utility.FileTypeAURAL: + return "audio" + } + base := filepath.Base(strings.TrimSpace(filename)) + switch { + case presentationUploadPattern.MatchString(base): + return "presentation" + case emailUploadPattern.MatchString(base): + return "email" + default: + return defaultParser + } +} + +// contentHashHex mirrors Python xxhash.xxh128(blob).hexdigest(). +func contentHashHex(blob []byte) string { + sum := xxh3.Hash128(blob).Bytes() + return hex.EncodeToString(sum[:]) +} diff --git a/internal/service/file.go b/internal/service/file.go index 07a3579409..741413822d 100644 --- a/internal/service/file.go +++ b/internal/service/file.go @@ -600,8 +600,6 @@ func (s *FileService) checkFileTeamPermission(file *entity.File, uid string) boo } kbDAO := dao.NewKnowledgebaseDAO() - userTenantDAO := dao.NewUserTenantDAO() - for _, datasetID := range datasetIDs { ds, err := kbDAO.GetByID(datasetID) if err != nil || ds == nil { @@ -609,7 +607,7 @@ func (s *FileService) checkFileTeamPermission(file *entity.File, uid string) boo } // Check KB tenant permission - if s.checkDatasetTeamPermission(ds, uid, userTenantDAO) { + if s.checkDatasetTeamPermission(ds, uid) { return true } } @@ -619,31 +617,8 @@ func (s *FileService) checkFileTeamPermission(file *entity.File, uid string) boo // checkDatasetTeamPermission checks if user has permission to access the dataset // Matches Python's check_kb_team_permission function -func (s *FileService) checkDatasetTeamPermission(ds *entity.Knowledgebase, uid string, userTenantDAO *dao.UserTenantDAO) bool { - // KB's tenant directly authorized - if ds.TenantID == uid { - return true - } - - // Check permission type - permission := ds.Permission - if permission != string(entity.TenantPermissionTeam) { - return false - } - - // Check if user joined the tenant - joinedTenantIDs, err := 
userTenantDAO.GetTenantIDsByUserID(uid) - if err != nil || len(joinedTenantIDs) == 0 { - return false - } - - for _, tenantID := range joinedTenantIDs { - if tenantID == ds.TenantID { - return true - } - } - - return false +func (s *FileService) checkDatasetTeamPermission(ds *entity.Knowledgebase, uid string) bool { + return hasKBTeamPermission(ds, uid, dao.NewTenantDAO()) } // deleteSingleFile deletes a single file (not folder) diff --git a/internal/service/file2document.go b/internal/service/file2document.go index 2f490a4e56..c1e4022219 100644 --- a/internal/service/file2document.go +++ b/internal/service/file2document.go @@ -26,6 +26,7 @@ import ( "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" + "ragflow/internal/utility" ) // Sentinel errors returned by File2DocumentService. Handlers map these to @@ -196,7 +197,7 @@ func (s *File2DocumentService) convertFiles(fileIDs, kbIDs []string, userID stri continue } - parserID := getParser(file.Type, file.Name, kb.ParserID) + parserID := selectUploadParser(utility.FileType(file.Type), file.Name, kb.ParserID) suffix := strings.TrimPrefix(filepath.Ext(file.Name), ".") doc := &entity.Document{ ID: common.GenerateUUID(), @@ -266,12 +267,18 @@ func (s *File2DocumentService) checkFileTeamPermission(file *entity.File, userID if file.TenantID == userID { return true } - tenants, err := s.userTenantDAO.GetByUserID(userID) - if err != nil { + + datasetIDs, err := s.fileDAO.GetDatasetIDByFileID(file.ID) + if err != nil || len(datasetIDs) == 0 { return false } - for _, t := range tenants { - if t.TenantID == file.TenantID { + + for _, datasetID := range datasetIDs { + kb, err := s.kbDAO.GetByID(datasetID) + if err != nil || kb == nil { + continue + } + if s.checkKBTeamPermission(kb, userID) { return true } } @@ -281,43 +288,7 @@ func (s *File2DocumentService) checkFileTeamPermission(file *entity.File, userID // checkKBTeamPermission mirrors Python check_kb_team_permission: // true when kb.TenantID == 
userID or user is in the KB tenant's team. func (s *File2DocumentService) checkKBTeamPermission(kb *entity.Knowledgebase, userID string) bool { - if kb.TenantID == userID { - return true - } - tenants, err := s.userTenantDAO.GetByUserID(userID) - if err != nil { - return false - } - for _, t := range tenants { - if t.TenantID == kb.TenantID { - return true - } - } - return false -} - -// getParser maps (fileType, fileName, kbParserID) → a parser ID. -// Mirrors Python FileService.get_parser — falls back to the KB's parser. -func getParser(fileType, fileName, kbParserID string) string { - ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(fileName), ".")) - switch ext { - case "pdf": - return "pdf" - case "doc", "docx": - return "naive" - case "ppt", "pptx": - return "presentation" - case "xls", "xlsx": - return "table" - case "txt", "md": - return "naive" - case "png", "jpg", "jpeg", "gif", "bmp", "webp": - return "picture" - } - if kbParserID != "" { - return kbParserID - } - return "naive" + return hasKBTeamPermission(kb, userID, dao.NewTenantDAO()) } // dedupeStrings returns the input slice with duplicates removed, preserving the diff --git a/internal/service/team_permission.go b/internal/service/team_permission.go new file mode 100644 index 0000000000..73a0882138 --- /dev/null +++ b/internal/service/team_permission.go @@ -0,0 +1,31 @@ +package service + +import ( + "ragflow/internal/dao" + "ragflow/internal/entity" +) + +// hasKBTeamPermission mirrors Python check_kb_team_permission: +// direct owner access is always allowed; otherwise the KB must be team-shared +// and the caller must be a joined normal member of the owner tenant. 
+func hasKBTeamPermission(kb *entity.Knowledgebase, userID string, tenantDAO *dao.TenantDAO) bool { + if kb == nil { + return false + } + if kb.TenantID == userID { + return true + } + if kb.Permission != string(entity.TenantPermissionTeam) { + return false + } + joinedTenants, err := tenantDAO.GetJoinedTenantsByUserID(userID) + if err != nil { + return false + } + for _, tenant := range joinedTenants { + if tenant.TenantID == kb.TenantID { + return true + } + } + return false +}