diff --git a/internal/handler/document.go b/internal/handler/document.go index cdd1d348d6..7054789558 100644 --- a/internal/handler/document.go +++ b/internal/handler/document.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "mime" + "mime/multipart" "net/http" "path/filepath" "ragflow/internal/common" @@ -60,6 +61,9 @@ type documentServiceIface interface { GetDocumentPreview(docID string) (*service.DocumentPreview, error) DownloadDocument(datasetID, docID string) (*service.DownloadDocumentResp, error) UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error) + BatchUpdateDocumentMetadatas(datasetID string, selector *service.DocumentMetadataSelector, updates []service.DocumentMetadataUpdate, deletes []service.DocumentMetadataDelete) (*service.BatchUpdateDocumentMetadatasResponse, common.ErrorCode, error) + UploadDocumentInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, common.ErrorCode, error) + UploadDocumentInfoByURL(userID, rawURL string) (map[string]interface{}, common.ErrorCode, error) ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error) IngestDocuments(datasetID, userID string, docIDs []string) ([]*service.ParseDocumentResponse, error) StopIngestionTasks(tasks []string, userID string) ([]*entity.IngestionTask, error) @@ -1296,3 +1300,123 @@ func (h *DocumentHandler) UpdateDatasetDocument(c *gin.Context) { "data": data, }) } + +func (h *DocumentHandler) UploadInfo(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + form, err := c.MultipartForm() + if err != nil && !strings.Contains(err.Error(), "request Content-Type isn't multipart/form-data") { + jsonError(c, common.CodeArgumentError, "Failed to parse multipart form: "+err.Error()) + return + } + + var fileHeaders []*multipart.FileHeader + if form != nil && form.File != nil { + fileHeaders = form.File["file"] + } + rawURL := strings.TrimSpace(c.Query("url")) + + if len(fileHeaders) > 0 && rawURL != "" { + jsonError(c, common.CodeArgumentError, "Provide either multipart file(s) or ?url=..., not both.") + return + } + if len(fileHeaders) == 0 && rawURL == "" { + jsonError(c, common.CodeArgumentError, "Missing input: provide multipart file(s) or url") + return + } + + if rawURL != "" { + data, code, err := h.documentService.UploadDocumentInfoByURL(user.ID, rawURL) + if err != nil { + jsonError(c, code, err.Error()) + return + } + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "data": data, + "message": "success", + }) + return + } + + data, code, err := h.documentService.UploadDocumentInfos(user.ID, fileHeaders) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + var payload interface{} + if len(data) == 1 { + payload = data[0] + } else { + payload = data + } + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "data": payload, + "message": "success", + }) +} + +type documentMetadataBatchRequest struct { + Selector *service.DocumentMetadataSelector `json:"selector"` + Updates []service.DocumentMetadataUpdate `json:"updates"` + Deletes []service.DocumentMetadataDelete `json:"deletes"` +} + +func (h *DocumentHandler) MetadataBatchUpdate(c *gin.Context) { + h.handleBatchUpdateDocumentMetadatas(c) +} + +func (h *DocumentHandler) UpdateDocumentMetadatas(c *gin.Context) { + h.handleBatchUpdateDocumentMetadatas(c) +} + +func (h *DocumentHandler) handleBatchUpdateDocumentMetadatas(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { + jsonError(c, common.CodeArgumentError, "dataset_id is required") + return + } + if !h.datasetService.Accessible(datasetID, user.ID) { + jsonError(c, common.CodeDataError, "You don't own the dataset "+datasetID+".") + return + } + + var req documentMetadataBatchRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeDataError, err.Error()) + return + } + if req.Selector == nil { + req.Selector = &service.DocumentMetadataSelector{} + } + if req.Updates == nil { + req.Updates = []service.DocumentMetadataUpdate{} + } + if req.Deletes == nil { + req.Deletes = []service.DocumentMetadataDelete{} + } + + resp, code, err := h.documentService.BatchUpdateDocumentMetadatas(datasetID, req.Selector, req.Updates, req.Deletes) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "data": resp, + "message": "success", + }) +} diff --git a/internal/handler/document_test.go b/internal/handler/document_test.go index 93d807539d..aeedc2d7f1 100644 --- a/internal/handler/document_test.go +++ b/internal/handler/document_test.go @@ -19,6 +19,7 @@ package handler import ( "encoding/json" "fmt" + "mime/multipart" "net/http" "net/http/httptest" "strings" @@ -49,6 +50,15 @@ type fakeDocumentService struct { func (f *fakeDocumentService) UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error) { return nil, common.CodeSuccess, nil } +func (f *fakeDocumentService) BatchUpdateDocumentMetadatas(datasetID string, selector *service.DocumentMetadataSelector, updates []service.DocumentMetadataUpdate, deletes []service.DocumentMetadataDelete) (*service.BatchUpdateDocumentMetadatasResponse, common.ErrorCode, error) { + return nil, common.CodeSuccess, nil +} +func (f *fakeDocumentService) UploadDocumentInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, common.ErrorCode, error) { + return nil, common.CodeSuccess, nil +} +func (f *fakeDocumentService) UploadDocumentInfoByURL(userID, rawURL string) (map[string]interface{}, common.ErrorCode, error) { + return nil, common.CodeSuccess, nil +} func (f *fakeDocumentService) GetDocumentArtifact(filename string) (*service.ArtifactResponse, error) { if filename == "error.txt" { diff --git a/internal/router/router.go b/internal/router/router.go index 349d0c329d..34ac580e77 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -236,6 +236,7 @@ func (r *Router) Setup(engine *gin.Engine) { documents := v1.Group("/documents") { documents.POST("", r.documentHandler.CreateDocument) + documents.POST("/upload", r.documentHandler.UploadInfo) documents.GET("", r.documentHandler.ListDocuments) documents.GET("/artifact/:filename", r.documentHandler.GetDocumentArtifact) documents.GET("/:id/preview", r.documentHandler.GetDocumentPreview) @@ -315,6 +316,8 @@ func (r *Router) Setup(engine *gin.Engine) { //datasets.POST("/:dataset_id/documents/stop", r.documentHandler.StopParseDocuments) datasets.DELETE("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.RemoveChunks) datasets.PUT("/:dataset_id/documents/:document_id/metadata/config", r.datasetsHandler.UpdateDocumentMetadataConfig) + datasets.POST("/:dataset_id/metadata/update", r.documentHandler.MetadataBatchUpdate) + datasets.PATCH("/:dataset_id/documents/metadatas", r.documentHandler.UpdateDocumentMetadatas) } // Search routes diff --git a/internal/service/document.go b/internal/service/document.go index b8f8042802..fe5b0c615f 100644 --- a/internal/service/document.go +++ b/internal/service/document.go @@ -22,8 +22,10 @@ import ( "errors" "fmt" "math" + "mime/multipart" "os" "path/filepath" + "reflect" "regexp" "sort" "strconv" @@ -41,6 +43,7 @@ import ( "ragflow/internal/tokenizer" "ragflow/internal/utility" + "go.uber.org/zap" "gorm.io/gorm" ) @@ -1571,15 +1574,10 @@ func aggregateMetadata(chunks []map[string]interface{}) map[string]interface{} { typeCounter[k][valueType] = typeCounter[k][valueType] + 1 } - // Aggregate value counts - values := v - if v, ok := v.([]interface{}); ok { - values = v - } else { - values = []interface{}{v} - } - - for _, vv := range values.([]interface{}) { + // Aggregate value counts. Flatten nested arrays so malformed values do + // not surface in the UI as the literal string "[]". + values := flattenMetadataSummaryValues(v) + for _, vv := range values { if vv == nil { continue } @@ -1676,6 +1674,27 @@ func getMetaValueType(value interface{}) string { return "string" } +func flattenMetadataSummaryValues(value interface{}) []interface{} { + switch typed := value.(type) { + case []interface{}: + result := make([]interface{}, 0, len(typed)) + for _, item := range typed { + result = append(result, flattenMetadataSummaryValues(item)...) + } + return result + case []string: + result := make([]interface{}, 0, len(typed)) + for _, item := range typed { + result = append(result, item) + } + return result + case nil: + return nil + default: + return []interface{}{typed} + } +} + // isTimeString checks if a string is an ISO 8601 datetime func isTimeString(s string) bool { matched, _ := regexp.MatchString(`^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$`, s) @@ -2176,3 +2195,471 @@ func mapDocumentRunStatus(run *string) string { return "UNSTART" } } + +// MetadataUpdate is one update item: set key to value. +type DocumentMetadataUpdate struct { + Key string `json:"key"` + Value interface{} `json:"value"` + Match interface{} `json:"match,omitempty"` + ValueType string `json:"valueType,omitempty"` +} + +// MetadataDelete removes a whole key, or a specific value from a list field. +type DocumentMetadataDelete struct { + Key string `json:"key"` + Value interface{} `json:"value,omitempty"` +} + +// MetadataSelector selects which documents to target. +type DocumentMetadataSelector struct { + DocumentIDs []string `json:"document_ids"` + MetadataCondition map[string]interface{} `json:"metadata_condition"` +} + +// BatchUpdateDocumentMetadatasResponse summarises the operation. +type BatchUpdateDocumentMetadatasResponse struct { + Updated int `json:"updated"` + MatchedDocs int `json:"matched_docs"` +} + +// BatchUpdateDocumentMetadatas implements the shared logic for +// PATCH /datasets/:dataset_id/documents/metadatas and +// POST /datasets/:dataset_id/metadata/update. +func (s *DocumentService) BatchUpdateDocumentMetadatas( + datasetID string, + selector *DocumentMetadataSelector, + updates []DocumentMetadataUpdate, + deletes []DocumentMetadataDelete, +) (*BatchUpdateDocumentMetadatasResponse, common.ErrorCode, error) { + if selector == nil { + selector = &DocumentMetadataSelector{} + } + if code, err := validateBatchUpdateDocumentMetadatasRequest(selector, updates, deletes); err != nil { + return nil, code, err + } + + // Resolve which document IDs to target. + targetDocIDs := make(map[string]struct{}) + + if len(selector.DocumentIDs) > 0 { + // Validate that supplied IDs actually belong to this dataset. + allRows, err := s.documentDAO.GetAllDocIDsByKBIDs([]string{datasetID}) + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to list dataset documents: %w", err) + } + kbDocIDSet := make(map[string]struct{}, len(allRows)) + for _, row := range allRows { + kbDocIDSet[row["id"]] = struct{}{} + } + var invalidIDs []string + for _, id := range selector.DocumentIDs { + if _, ok := kbDocIDSet[id]; !ok { + invalidIDs = append(invalidIDs, id) + } + } + if len(invalidIDs) > 0 { + return nil, common.CodeDataError, fmt.Errorf("these documents do not belong to dataset %s: %s", + datasetID, strings.Join(invalidIDs, ", ")) + } + for _, id := range selector.DocumentIDs { + targetDocIDs[id] = struct{}{} + } + } + + // Apply metadata_condition filter. + if len(selector.MetadataCondition) > 0 { + flattedMeta, err := s.metadataSvc.GetFlattedMetaByKBs([]string{datasetID}) + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to get flattened metadata: %w", err) + } + + // ParseAndConvert mirrors Python convert_conditions: conditions arrive as + // {name, comparison_operator, value}, the operator is normalised, and the + // (possibly non-string) value is preserved. MetaFilter then matches against + // the common.MetaData returned by GetFlattedMetaByKBs. + filterInput := common.ParseAndConvert(selector.MetadataCondition) + filteredIDs := common.MetaFilter(flattedMeta, filterInput) + + filteredSet := make(map[string]struct{}, len(filteredIDs)) + for _, id := range filteredIDs { + filteredSet[id] = struct{}{} + } + + if len(targetDocIDs) > 0 { + // Intersect with the document_ids restriction. + for id := range targetDocIDs { + if _, ok := filteredSet[id]; !ok { + delete(targetDocIDs, id) + } + } + } else { + targetDocIDs = filteredSet + } + + // Early-exit when conditions given but nothing matched. + rawConds, _ := selector.MetadataCondition["conditions"] + if rawConds != nil && len(targetDocIDs) == 0 { + return &BatchUpdateDocumentMetadatasResponse{Updated: 0, MatchedDocs: 0}, common.CodeSuccess, nil + } + } + + ids := make([]string, 0, len(targetDocIDs)) + for id := range targetDocIDs { + ids = append(ids, id) + } + + // Apply updates and deletes per document using Python's batch_update_metadata + // semantics instead of a simple merge-then-delete. + updated := 0 + for _, docID := range ids { + currentMeta, err := s.GetDocumentMetadataByID(docID) + if err != nil { + common.Warn("BatchUpdateDocumentMetadata: get metadata failed", + zap.String("docID", docID), zap.Error(err)) + continue + } + + meta := cloneDocumentMetadata(currentMeta) + originalMeta := cloneDocumentMetadata(meta) + + changed := applyDocumentMetadataUpdates(meta, updates) + if applyDocumentMetadataDeletes(meta, deletes) { + changed = true + } + + if !changed || reflect.DeepEqual(originalMeta, meta) { + continue + } + + if len(meta) == 0 { + if err := s.DeleteDocumentAllMetadata(docID); err != nil { + common.Warn("BatchUpdateDocumentMetadata: delete all metadata failed", + zap.String("docID", docID), zap.Error(err)) + continue + } + updated++ + continue + } + + if err := s.replaceDocumentMetadata(docID, meta); err != nil { + common.Warn("BatchUpdateDocumentMetadata: replace metadata failed", + zap.String("docID", docID), zap.Error(err)) + continue + } + updated++ + } + + return &BatchUpdateDocumentMetadatasResponse{Updated: updated, MatchedDocs: len(ids)}, common.CodeSuccess, nil +} + +func (s *DocumentService) UploadDocumentInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, common.ErrorCode, error) { + fileSvc := &FileService{ + fileDAO: s.fileDAO, + file2DocumentDAO: s.file2DocumentDAO, + } + data, err := fileSvc.UploadInfos(userID, files) + if err != nil { + return nil, common.CodeDataError, err + } + return data, common.CodeSuccess, nil +} + +func (s *DocumentService) UploadDocumentInfoByURL(userID, rawURL string) (map[string]interface{}, common.ErrorCode, error) { + fileSvc := &FileService{ + fileDAO: s.fileDAO, + file2DocumentDAO: s.file2DocumentDAO, + } + data, err := fileSvc.UploadFromURL(userID, rawURL) + if err != nil { + return nil, common.CodeDataError, err + } + return data, common.CodeSuccess, nil +} + +func validateBatchUpdateDocumentMetadatasRequest( + selector *DocumentMetadataSelector, + updates []DocumentMetadataUpdate, + deletes []DocumentMetadataDelete, +) (common.ErrorCode, error) { + for _, upd := range updates { + if strings.TrimSpace(upd.Key) == "" || upd.Value == nil { + return common.CodeDataError, errors.New("Each update requires key and value.") + } + } + for _, del := range deletes { + if strings.TrimSpace(del.Key) == "" { + return common.CodeDataError, errors.New("Each delete requires key.") + } + } + if selector != nil && selector.MetadataCondition != nil { + if _, ok := selector.MetadataCondition["conditions"]; !ok && len(selector.MetadataCondition) > 0 { + return common.CodeDataError, errors.New("metadata_condition must be an object.") + } + } + return common.CodeSuccess, nil +} + +func cloneDocumentMetadata(meta map[string]interface{}) map[string]interface{} { + if meta == nil { + return map[string]interface{}{} + } + cloned := make(map[string]interface{}, len(meta)) + for k, v := range meta { + cloned[k] = cloneDocumentMetadataValue(v) + } + return cloned +} + +func cloneDocumentMetadataValue(v interface{}) interface{} { + switch typed := v.(type) { + case []interface{}: + cp := make([]interface{}, len(typed)) + copy(cp, typed) + return cp + case []string: + cp := make([]interface{}, 0, len(typed)) + for _, item := range typed { + cp = append(cp, item) + } + return cp + default: + return typed + } +} + +func applyDocumentMetadataUpdates(meta map[string]interface{}, updates []DocumentMetadataUpdate) bool { + changed := false + for _, upd := range updates { + key := strings.TrimSpace(upd.Key) + if key == "" { + continue + } + normalizedValue := normalizeDocumentMetadataUpdateValue(upd.Value, upd.ValueType) + matchProvided := upd.Match != nil && !(fmt.Sprintf("%v", upd.Match) == "") + current, exists := meta[key] + if !exists { + if matchProvided { + continue + } + if listVal, ok := toMetadataInterfaceSlice(normalizedValue); ok { + meta[key] = dedupeDocumentMetadataList(listVal) + } else { + meta[key] = normalizedValue + } + changed = true + continue + } + + if curList, ok := toMetadataInterfaceSlice(current); ok { + if !matchProvided { + newList := append([]interface{}{}, curList...) + if appendList, ok := toMetadataInterfaceSlice(normalizedValue); ok { + newList = append(newList, appendList...) + } else { + newList = append(newList, normalizedValue) + } + newList = dedupeDocumentMetadataList(newList) + if !reflect.DeepEqual(curList, newList) { + meta[key] = newList + changed = true + } + continue + } + + replaced := false + newList := make([]interface{}, 0, len(curList)) + for _, item := range curList { + if documentMetadataValuesEqual(item, upd.Match) { + if replacementList, ok := toMetadataInterfaceSlice(normalizedValue); ok { + newList = append(newList, replacementList...) + } else { + newList = append(newList, normalizedValue) + } + replaced = true + } else { + newList = append(newList, item) + } + } + newList = dedupeDocumentMetadataList(newList) + if replaced && !reflect.DeepEqual(curList, newList) { + meta[key] = newList + changed = true + } + continue + } + + if !matchProvided { + if !reflect.DeepEqual(current, normalizedValue) { + meta[key] = normalizedValue + changed = true + } + continue + } + if documentMetadataValuesEqual(current, upd.Match) && !reflect.DeepEqual(current, normalizedValue) { + meta[key] = normalizedValue + changed = true + } + } + return changed +} + +func applyDocumentMetadataDeletes(meta map[string]interface{}, deletes []DocumentMetadataDelete) bool { + changed := false + for _, del := range deletes { + key := strings.TrimSpace(del.Key) + current, exists := meta[key] + if key == "" || !exists { + continue + } + + if curList, ok := toMetadataInterfaceSlice(current); ok { + if del.Value == nil { + delete(meta, key) + changed = true + continue + } + newList := make([]interface{}, 0, len(curList)) + for _, item := range curList { + if !documentMetadataValuesEqual(item, del.Value) { + newList = append(newList, item) + } + } + if len(newList) != len(curList) { + if len(newList) == 0 { + delete(meta, key) + } else { + meta[key] = newList + } + changed = true + } + continue + } + + if del.Value == nil || documentMetadataValuesEqual(current, del.Value) { + delete(meta, key) + changed = true + } + } + return changed +} + +func toMetadataInterfaceSlice(v interface{}) ([]interface{}, bool) { + switch typed := v.(type) { + case []interface{}: + cp := make([]interface{}, len(typed)) + copy(cp, typed) + return cp, true + case []string: + cp := make([]interface{}, 0, len(typed)) + for _, item := range typed { + cp = append(cp, item) + } + return cp, true + default: + return nil, false + } +} + +func dedupeDocumentMetadataList(items []interface{}) []interface{} { + result := make([]interface{}, 0, len(items)) + seen := make(map[string]struct{}, len(items)) + for _, item := range items { + key := fmt.Sprintf("%T:%v", item, item) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + result = append(result, item) + } + return result +} + +func documentMetadataValuesEqual(a, b interface{}) bool { + return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b) +} + +func normalizeDocumentMetadataUpdateValue(value interface{}, valueType string) interface{} { + switch strings.ToLower(strings.TrimSpace(valueType)) { + case "list": + if list, ok := normalizeMetadataListValue(value); ok { + return list + } + return []interface{}{} + case "number": + scalar, ok := firstScalarMetadataValue(value) + if !ok { + return value + } + switch typed := scalar.(type) { + case float64, float32, int, int8, int16, int32, int64: + return typed + case json.Number: + if i, err := typed.Int64(); err == nil { + return i + } + if f, err := typed.Float64(); err == nil { + return f + } + case string: + trimmed := strings.TrimSpace(typed) + if trimmed == "" { + return "" + } + if i, err := strconv.ParseInt(trimmed, 10, 64); err == nil { + return i + } + if f, err := strconv.ParseFloat(trimmed, 64); err == nil { + return f + } + return trimmed + } + return scalar + case "string", "time": + if scalar, ok := firstScalarMetadataValue(value); ok { + return fmt.Sprintf("%v", scalar) + } + return "" + default: + return value + } +} + +func normalizeMetadataListValue(value interface{}) ([]interface{}, bool) { + switch typed := value.(type) { + case []interface{}: + result := make([]interface{}, 0, len(typed)) + for _, item := range typed { + if nested, ok := normalizeMetadataListValue(item); ok { + result = append(result, nested...) + continue + } + if item != nil { + result = append(result, item) + } + } + return result, true + case []string: + result := make([]interface{}, 0, len(typed)) + for _, item := range typed { + result = append(result, item) + } + return result, true + default: + return nil, false + } +} + +func firstScalarMetadataValue(value interface{}) (interface{}, bool) { + if list, ok := normalizeMetadataListValue(value); ok { + for _, item := range list { + if item != nil { + return item, true + } + } + return nil, false + } + if value == nil { + return nil, false + } + return value, true +} diff --git a/internal/service/document_test.go b/internal/service/document_test.go index 621aa34c3d..929904c7a5 100644 --- a/internal/service/document_test.go +++ b/internal/service/document_test.go @@ -19,7 +19,9 @@ package service import ( "context" "errors" + "fmt" "path/filepath" + "strings" "testing" "github.com/glebarez/sqlite" @@ -130,6 +132,107 @@ type failingDeleteMetadataEngine struct { updateCalled bool } +type metadataDocEngine struct { + fakeChatDocEngine + records map[string]map[string]interface{} + docKBs map[string]string +} + +func newMetadataDocEngine(records map[string]map[string]interface{}, docKBs map[string]string) *metadataDocEngine { + cp := make(map[string]map[string]interface{}, len(records)) + for id, meta := range records { + dup := make(map[string]interface{}, len(meta)) + for k, v := range meta { + dup[k] = v + } + cp[id] = dup + } + return &metadataDocEngine{records: cp, docKBs: docKBs} +} + +func (m *metadataDocEngine) SearchMetadata(_ context.Context, req *types.SearchMetadataRequest) (*types.SearchMetadataResult, error) { + var ids map[string]struct{} + if rawIDs, ok := req.Filter["id"]; ok && rawIDs != nil { + ids = make(map[string]struct{}) + switch typed := rawIDs.(type) { + case []string: + for _, id := range typed { + ids[id] = struct{}{} + } + case []interface{}: + for _, id := range typed { + if s, ok := id.(string); ok { + ids[s] = struct{}{} + } + } + } + } + + var kbFilter map[string]struct{} + if rawKB, ok := req.Filter["kb_id"]; ok && rawKB != nil { + kbFilter = make(map[string]struct{}) + switch typed := rawKB.(type) { + case string: + kbFilter[typed] = struct{}{} + case []string: + for _, kb := range typed { + kbFilter[kb] = struct{}{} + } + case []interface{}: + for _, kb := range typed { + if s, ok := kb.(string); ok { + kbFilter[s] = struct{}{} + } + } + } + } + + result := &types.SearchMetadataResult{MetadataRecords: []map[string]interface{}{}} + for docID, meta := range m.records { + if ids != nil { + if _, ok := ids[docID]; !ok { + continue + } + } + kbID := m.docKBs[docID] + if kbFilter != nil { + if _, ok := kbFilter[kbID]; !ok { + continue + } + } + result.MetadataRecords = append(result.MetadataRecords, map[string]interface{}{ + "id": docID, + "kb_id": kbID, + "meta_fields": meta, + }) + } + return result, nil +} + +func (m *metadataDocEngine) UpdateMetadata(_ context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error { + dup := make(map[string]interface{}, len(metaFields)) + for k, v := range metaFields { + dup[k] = v + } + m.records[docID] = dup + if _, ok := m.docKBs[docID]; !ok { + m.docKBs[docID] = datasetID + } + return nil +} + +func (m *metadataDocEngine) DeleteMetadata(_ context.Context, condition map[string]interface{}, tenantID string) (int64, error) { + docID, _ := condition["id"].(string) + if docID == "" { + return 0, nil + } + if _, ok := m.records[docID]; ok { + delete(m.records, docID) + return 1, nil + } + return 0, nil +} + func (f *failingDeleteMetadataEngine) DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) { return 0, f.deleteErr } @@ -1304,6 +1407,192 @@ func TestChunkImageStorageKeyFallsBackToChunkID(t *testing.T) { } } +func TestBatchUpdateDocumentMetadatasMatchesPythonSemantics(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 3, 0, 0) + insertNamedTestDoc(t, "doc-1", "kb-1", "doc1.txt", 0, 0) + insertNamedTestDoc(t, "doc-2", "kb-1", "doc2.txt", 0, 0) + insertNamedTestDoc(t, "doc-3", "kb-1", "doc3.txt", 0, 0) + + engine := newMetadataDocEngine(map[string]map[string]interface{}{ + "doc-1": {"tags": []interface{}{"old", "keep"}, "author": "alice"}, + "doc-2": {"tags": []interface{}{"old"}, "author": "bob"}, + }, map[string]string{"doc-1": "kb-1", "doc-2": "kb-1", "doc-3": "kb-1"}) + + svc := testDocumentService(t) + svc.docEngine = engine + svc.metadataSvc = &MetadataService{kbDAO: dao.NewKnowledgebaseDAO(), docEngine: engine} + + resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{ + DocumentIDs: []string{"doc-1", "doc-2", "doc-3"}, + }, []DocumentMetadataUpdate{ + {Key: "tags", Value: "new", Match: "old"}, + {Key: "category", Value: "paper"}, + }, []DocumentMetadataDelete{ + {Key: "author", Value: "alice"}, + }) + if err != nil { + t.Fatalf("BatchUpdateDocumentMetadatas failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code = %v, want success", code) + } + if resp.Updated != 3 || resp.MatchedDocs != 3 { + t.Fatalf("resp = %#v, want updated=3 matched=3", resp) + } + + got1 := engine.records["doc-1"] + if fmt.Sprintf("%v", got1["category"]) != "paper" { + t.Fatalf("doc-1 category = %#v", got1["category"]) + } + if _, ok := got1["author"]; ok { + t.Fatalf("doc-1 author should be deleted: %#v", got1) + } + if got := got1["tags"].([]interface{}); len(got) != 2 || got[0] != "new" || got[1] != "keep" { + t.Fatalf("doc-1 tags = %#v", got) + } + + got2 := engine.records["doc-2"] + if fmt.Sprintf("%v", got2["author"]) != "bob" { + t.Fatalf("doc-2 author should be kept: %#v", got2["author"]) + } + if got := got2["tags"].([]interface{}); len(got) != 1 || got[0] != "new" { + t.Fatalf("doc-2 tags = %#v", got) + } + + got3 := engine.records["doc-3"] + if fmt.Sprintf("%v", got3["category"]) != "paper" { + t.Fatalf("doc-3 category = %#v", got3) + } + if _, ok := got3["tags"]; ok { + t.Fatalf("doc-3 tags should not be created by match-only update: %#v", got3) + } +} + +func TestBatchUpdateDocumentMetadatasDeletesEmptyMetadataAndNoOps(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 2, 0, 0) + insertNamedTestDoc(t, "doc-1", "kb-1", "doc1.txt", 0, 0) + insertNamedTestDoc(t, "doc-2", "kb-1", "doc2.txt", 0, 0) + + engine := newMetadataDocEngine(map[string]map[string]interface{}{ + "doc-1": {"status": "draft"}, + "doc-2": {"status": "done"}, + }, map[string]string{"doc-1": "kb-1", "doc-2": "kb-1"}) + + svc := testDocumentService(t) + svc.docEngine = engine + svc.metadataSvc = &MetadataService{kbDAO: dao.NewKnowledgebaseDAO(), docEngine: engine} + + resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{ + DocumentIDs: []string{"doc-1", "doc-2"}, + }, nil, []DocumentMetadataDelete{{Key: "status", Value: "draft"}}) + if err != nil || code != common.CodeSuccess { + t.Fatalf("delete batch failed: code=%v err=%v", code, err) + } + if resp.Updated != 1 || resp.MatchedDocs != 2 { + t.Fatalf("resp = %#v, want updated=1 matched=2", resp) + } + if _, ok := engine.records["doc-1"]; ok { + t.Fatalf("doc-1 metadata should be fully removed: %#v", engine.records["doc-1"]) + } + if fmt.Sprintf("%v", engine.records["doc-2"]["status"]) != "done" { + t.Fatalf("doc-2 metadata unexpectedly changed: %#v", engine.records["doc-2"]) + } +} + +func TestBatchUpdateDocumentMetadatasNormalizesNumberValues(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + insertTestKB(t, "kb-1", "tenant-1", 1, 0, 0) + insertNamedTestDoc(t, "doc-1", "kb-1", "doc1.txt", 0, 0) + + engine := newMetadataDocEngine(map[string]map[string]interface{}{}, map[string]string{"doc-1": "kb-1"}) + + svc := testDocumentService(t) + svc.docEngine = engine + svc.metadataSvc = &MetadataService{kbDAO: dao.NewKnowledgebaseDAO(), docEngine: engine} + + resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{ + DocumentIDs: []string{"doc-1"}, + }, []DocumentMetadataUpdate{ + {Key: "score", Value: "42", ValueType: "number"}, + }, nil) + if err != nil || code != common.CodeSuccess { + t.Fatalf("number batch failed: code=%v err=%v", code, err) + } + if resp.Updated != 1 || resp.MatchedDocs != 1 { + t.Fatalf("resp = %#v, want updated=1 matched=1", resp) + } + + got := engine.records["doc-1"]["score"] + switch v := got.(type) { + case int64: + if v != 42 { + t.Fatalf("score = %v, want 42", v) + } + case float64: + if v != 42 { + t.Fatalf("score = %v, want 42", v) + } + default: + t.Fatalf("score type = %T, want numeric value", got) + } +} + +func TestBatchUpdateDocumentMetadatasRejectsMissingValue(t *testing.T) { + svc := testDocumentService(t) + resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{}, []DocumentMetadataUpdate{ + {Key: "status"}, + }, nil) + if err == nil { + t.Fatal("expected validation error for missing value") + } + if resp != nil { + t.Fatalf("resp = %#v, want nil", resp) + } + if code != common.CodeDataError { + t.Fatalf("code = %v, want data error", code) + } + if !strings.Contains(err.Error(), "Each update requires key and value.") { + t.Fatalf("err = %v", err) + } +} + +func TestAggregateMetadataIgnoresNestedEmptyLists(t *testing.T) { + summary := aggregateMetadata([]map[string]interface{}{ + { + "id": "doc-1", + "kb_id": "kb-1", + "meta_fields": map[string]interface{}{ + "score": []interface{}{[]interface{}{}, 7.0}, + "name": "alice", + }, + }, + }) + + scoreField, ok := summary["score"].(map[string]interface{}) + if !ok { + t.Fatalf("score summary missing: %#v", summary) + } + values, ok := scoreField["values"].([][2]interface{}) + if !ok { + t.Fatalf("score values type = %T", scoreField["values"]) + } + if len(values) != 1 || values[0][0] != "7" || values[0][1] != 1 { + t.Fatalf("score values = %#v, want [[\"7\",1]]", values) + } +} + +func TestMergeFieldValuesKeepsNumericValues(t *testing.T) { + got := mergeFieldValues(1.0, 2.0) + if len(got) != 2 || got[0] != 1.0 || got[1] != 2.0 { + t.Fatalf("mergeFieldValues = %#v, want [1 2]", got) + } +} + func TestUpdateDatasetDocumentPipelineIDTakesPrecedenceOverChunkMethod(t *testing.T) { db := setupServiceTestDB(t) pushServiceDB(t, db) diff --git a/internal/service/file.go b/internal/service/file.go index f46e3ecb14..07a3579409 100644 --- a/internal/service/file.go +++ b/internal/service/file.go @@ -20,8 +20,12 @@ import ( "context" "encoding/base64" "fmt" + "html" "io" "mime/multipart" + "net" + "net/http" + "net/url" "os" "path/filepath" "ragflow/internal/common" @@ -31,6 +35,7 @@ import ( "ragflow/internal/ingestion/parser" "ragflow/internal/storage" "ragflow/internal/utility" + "regexp" "strings" "time" @@ -156,6 +161,11 @@ const DatasetFolderName = ".knowledgebase" // FileSourceDataset represents dataset as file source const FileSourceDataset = "knowledgebase" +var ( + assertURLSafe = utility.AssertURLSafe + pinnedHTTPClient = utility.PinnedHTTPClient +) + // toFileResponse converts file model to response format func (s *FileService) toFileResponse(file *entity.File) map[string]interface{} { result := map[string]interface{}{ @@ -391,6 +401,57 @@ func (s *FileService) UploadFile(tenantID, parentID string, files []*multipart.F return result, nil } +// UploadInfos mirrors Python's upload_info file branch: store raw bytes in the +// per-user downloads bucket and return lightweight upload descriptors instead +// of creating full File rows in the file-management tree. +func (s *FileService) UploadInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, error) { + storageImpl := storage.GetStorageFactory().GetStorage() + if storageImpl == nil { + return nil, fmt.Errorf("storage not initialized") + } + + results := make([]map[string]interface{}, 0, len(files)) + for _, fileHeader := range files { + filename := fileHeader.Filename + if err := s.checkUploadInfoHealth(userID, filename); err != nil { + return nil, err + } + src, err := fileHeader.Open() + if err != nil { + return nil, fmt.Errorf("failed to open uploaded file: %w", err) + } + data, readErr := readUploadInfoData(src) + src.Close() + if readErr != nil { + return nil, fmt.Errorf("failed to read file data: %w", readErr) + } + + contentType := fileHeader.Header.Get("Content-Type") + if contentType == "" { + contentType = http.DetectContentType(data) + } + filename, contentType, data = normalizeUploadInfoContent(filename, contentType, data) + resp, err := s.storeUploadInfoBlob(storageImpl, userID, filename, contentType, data) + if err != nil { + return nil, err + } + results = append(results, resp) + } + return results, nil +} + +func readUploadInfoData(r io.Reader) ([]byte, error) { + limited := &io.LimitedReader{R: r, N: maxRemoteFileSize + 1} + data, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + if int64(len(data)) > maxRemoteFileSize { + return nil, fmt.Errorf("file size exceeds %d bytes", maxRemoteFileSize) + } + return data, nil +} + func (s *FileService) parseFilePath(filename string) []string { filename = strings.TrimPrefix(filename, "/") parts := strings.Split(filename, "/") @@ -1068,3 +1129,303 @@ func parseFileContent(filename string, data []byte) string { } return fp.String() } + +// toUploadInfoResponse converts a newly-uploaded file record to the shape +// Python's upload_info endpoint returns. +func (s *FileService) toUploadInfoResponse(file *entity.File, mimeType string) map[string]interface{} { + ext := "" + if idx := strings.LastIndex(file.Name, "."); idx >= 0 { + ext = strings.ToLower(file.Name[idx+1:]) + } + return map[string]interface{}{ + "id": file.ID, + "name": file.Name, + "size": file.Size, + "extension": ext, + "mime_type": mimeType, + "created_by": file.CreatedBy, + "created_at": float64(time.Now().UnixMilli()) / 1000.0, + "preview_url": nil, + } +} + +// maxRemoteFileSize bounds the body of a ?url= upload (100 MB). +const maxRemoteFileSize = 100 << 20 + +// UploadFromURL fetches a remote URL, saves the content to the tenant's root +// folder, and returns the file metadata map — mirroring Python +// FileService.upload_info(tenant_id, None, url). +// +// The remote fetch is SSRF-guarded (mirrors Python's assert_url_is_safe): the +// scheme must be http/https and every address the host resolves to must be +// globally routable; the validated IP is pinned for the actual connection — and +// re-validated on each redirect hop — to defeat DNS-rebinding. The HTTP client +// carries connect and overall timeouts, and the response body is bounded with +// truncation detection so an oversized file is rejected rather than silently +// clipped. +func (s *FileService) UploadFromURL(tenantID, rawURL string) (map[string]interface{}, error) { + if rawURL == "" { + return nil, fmt.Errorf("url is required") + } + parsed, err := url.Parse(rawURL) + if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") || parsed.Hostname() == "" { + return nil, fmt.Errorf("invalid or unsafe URL") + } + + data, headers, finalURL, err := fetchRemoteFileSafely(rawURL, maxRemoteFileSize) + if err != nil { + return nil, err + } + + storageImpl := storage.GetStorageFactory().GetStorage() + if storageImpl == nil { + return nil, fmt.Errorf("storage not initialized") + } + + contentType := headers.Get("Content-Type") + filename := normalizeRemoteUploadFilename(finalURL, contentType, data) + if err := s.checkUploadInfoHealth(tenantID, filename); err != nil { + return nil, err + } + filename, contentType, data = normalizeUploadInfoContent(filename, contentType, data) + return s.storeUploadInfoBlob(storageImpl, tenantID, filename, contentType, data) +} + +// fetchRemoteFileSafely downloads rawURL with SSRF protection, connect/overall +// timeouts, and a hard size cap that rejects (rather than truncates) oversized +// bodies. +func fetchRemoteFileSafely(rawURL string, maxSize int64) ([]byte, http.Header, string, error) { + currentURL := rawURL + for redirects := 0; redirects < 10; redirects++ { + hostname, resolvedIP, err := assertURLSafe(currentURL) + if err != nil { + return nil, nil, "", err + } + client := pinnedHTTPClient(hostname, resolvedIP, 10*time.Second) + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + + resp, err := client.Get(currentURL) // #nosec G107 + if err != nil { + return nil, nil, "", fmt.Errorf("failed to fetch URL: %w", err) + } + + if resp.StatusCode == http.StatusMovedPermanently || + resp.StatusCode == http.StatusFound || + resp.StatusCode == http.StatusSeeOther || + resp.StatusCode == http.StatusTemporaryRedirect || + resp.StatusCode == http.StatusPermanentRedirect { + location := resp.Header.Get("Location") + resp.Body.Close() + if location == "" { + return nil, nil, "", fmt.Errorf("redirect response missing Location header") + } + baseURL, parseErr := url.Parse(currentURL) + if parseErr != nil { + return nil, nil, "", parseErr + } + nextURL, resolveErr := baseURL.Parse(location) + if resolveErr != nil { + return nil, nil, "", resolveErr + } + currentURL = nextURL.String() + continue + } + + if resp.StatusCode >= 400 { + resp.Body.Close() + return nil, nil, "", fmt.Errorf("remote URL returned HTTP %d", resp.StatusCode) + } + + data, readErr := io.ReadAll(io.LimitReader(resp.Body, maxSize+1)) + resp.Body.Close() + if readErr != nil { + return nil, nil, "", fmt.Errorf("failed to read remote content: %w", readErr) + } + if int64(len(data)) > maxSize { + return nil, nil, "", fmt.Errorf("remote file exceeds the maximum allowed size of %d bytes", maxSize) + } + return data, resp.Header.Clone(), currentURL, nil + } + return nil, nil, "", fmt.Errorf("stopped after too many redirects") +} + +// isPublicIP reports whether ip is a globally routable address. It mirrors the +// allowlist intent of Python's assert_url_is_safe (which requires ip.is_global) +// by rejecting loopback, private, link-local, multicast, unspecified, and +// carrier-grade NAT ranges. IPv4-mapped IPv6 addresses are handled by the +// stdlib predicates. +func isPublicIP(ip net.IP) bool { + if ip == nil { + return false + } + if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || + ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || + ip.IsMulticast() || ip.IsInterfaceLocalMulticast() { + return false + } + // Carrier-grade NAT 100.64.0.0/10 (RFC 6598) — not covered by IsPrivate. + if ip4 := ip.To4(); ip4 != nil && ip4[0] == 100 && ip4[1]&0xc0 == 0x40 { + return false + } + return true +} + +func (s *FileService) checkUploadInfoHealth(userID, filename string) error { + if filename == "" { + return fmt.Errorf("No file selected!") + } + maxFileNumPerUser := os.Getenv("MAX_FILE_NUM_PER_USER") + if maxFileNumPerUser != "" { + var maxNum int64 + if _, err := fmt.Sscanf(maxFileNumPerUser, "%d", &maxNum); err == nil && maxNum > 0 { + docCount, err := s.GetDocCount(userID) + if err != nil { + return fmt.Errorf("failed to get document count: %w", err) + } + if docCount >= maxNum { + return fmt.Errorf("Exceed the maximum file number of a free user!") + } + } + } + if len([]byte(filename)) > 255 { + return fmt.Errorf("Exceed the maximum length of file name!") + } + return nil +} + +func (s *FileService) storeUploadInfoBlob(storageImpl storage.Storage, userID, filename, contentType string, data []byte) (map[string]interface{}, error) { + location := common.GenerateUUID() + bucket := fmt.Sprintf("%s-downloads", userID) + if err := storageImpl.Put(bucket, location, data); err != nil { + return nil, fmt.Errorf("failed to store file: %w", err) + } + ext := "" + if idx := strings.LastIndex(filename, "."); idx >= 0 { + ext = strings.ToLower(filename[idx+1:]) + } + return map[string]interface{}{ + "id": location, + "name": filename, + "size": int64(len(data)), + "extension": ext, + "mime_type": contentType, + "created_by": userID, + "created_at": float64(time.Now().UnixMilli()) / 1000.0, + "preview_url": nil, + }, nil +} + +func normalizeRemoteUploadFilename(rawURL, contentType string, data []byte) string { + parsed, err := url.Parse(rawURL) + filename := "download" + if err == nil { + filename = sanitizeFilename(filepath.Base(parsed.Path)) + } + ct := strings.ToLower(strings.TrimSpace(strings.Split(contentType, ";")[0])) + if ct == "application/pdf" || bytesLooksLikePDF(data) { + if !strings.HasSuffix(strings.ToLower(filename), ".pdf") { + filename += ".pdf" + } + } + return filename +} + +func normalizeUploadInfoContent(filename, contentType string, data []byte) (string, string, []byte) { + lowerCT := strings.ToLower(strings.TrimSpace(strings.Split(contentType, ";")[0])) + if lowerCT == "" { + lowerCT = http.DetectContentType(data) + } + + if lowerCT == "application/pdf" || bytesLooksLikePDF(data) { + if !strings.HasSuffix(strings.ToLower(filename), ".pdf") { + filename += ".pdf" + } + lowerCT = "application/pdf" + } + if lowerCT == "text/html" || lowerCT == "application/xhtml+xml" || looksLikeHTML(data) { + data = htmlToReadableMarkdown(data) + if lowerCT == "" { + lowerCT = "text/html" + } + } + return filename, lowerCT, data +} + +func bytesLooksLikePDF(data []byte) bool { + return len(data) >= 4 && string(data[:4]) == "%PDF" +} + +func looksLikeHTML(data []byte) bool { + snippet := strings.ToLower(string(data)) + return strings.Contains(snippet, "]*>.*?`) + htmlTagRE = regexp.MustCompile(`(?s)<[^>]+>`) + multiSpaceRE = regexp.MustCompile(`[ \t]+`) + multiNewlineRE = regexp.MustCompile(`\n{3,}`) +) + +func htmlToReadableMarkdown(data []byte) []byte { + text := string(data) + text = htmlScriptStyleRE.ReplaceAllString(text, " ") + text = strings.ReplaceAll(text, "
", "\n") + text = strings.ReplaceAll(text, "
", "\n") + text = strings.ReplaceAll(text, "
", "\n") + text = strings.ReplaceAll(text, "

", "\n\n") + text = strings.ReplaceAll(text, "", "\n") + text = strings.ReplaceAll(text, "", "\n") + text = htmlTagRE.ReplaceAllString(text, " ") + text = html.UnescapeString(text) + text = strings.ReplaceAll(text, "\r", "\n") + text = multiSpaceRE.ReplaceAllString(text, " ") + text = multiNewlineRE.ReplaceAllString(text, "\n\n") + text = strings.TrimSpace(text) + return []byte(text) +} + +// reservedDeviceNames are Windows reserved filenames that must never be used. +var reservedDeviceNames = map[string]bool{ + "CON": true, "PRN": true, "AUX": true, "NUL": true, + "COM1": true, "COM2": true, "COM3": true, "COM4": true, "COM5": true, + "COM6": true, "COM7": true, "COM8": true, "COM9": true, + "LPT1": true, "LPT2": true, "LPT3": true, "LPT4": true, "LPT5": true, + "LPT6": true, "LPT7": true, "LPT8": true, "LPT9": true, +} + +// sanitizeFilename produces a safe, filesystem-friendly filename from an +// arbitrary URL path segment: it strips directory components, replaces unsafe / +// control characters, rejects reserved names, bounds the length, and falls back +// to "download". +func sanitizeFilename(name string) string { + name = filepath.Base(name) + name = strings.TrimSpace(name) + + name = strings.Map(func(r rune) rune { + switch r { + case '/', '\\', ':', '*', '?', '"', '<', '>', '|', 0: + return '_' + } + if r < 0x20 { // control characters + return '_' + } + return r + }, name) + + // Strip leading/trailing dots and spaces to avoid hidden or reserved forms. + name = strings.Trim(name, ". ") + + if name == "" || name == "." || name == ".." { + return "download" + } + if stem := strings.SplitN(strings.ToUpper(name), ".", 2)[0]; reservedDeviceNames[stem] { + return "download" + } + if len(name) > 255 { + name = name[:255] + } + return name +} diff --git a/internal/service/file_test.go b/internal/service/file_test.go index bb32232788..db1ec167ce 100644 --- a/internal/service/file_test.go +++ b/internal/service/file_test.go @@ -3,6 +3,10 @@ package service import ( "bytes" "errors" + "io" + "net/http" + "net/http/httptest" + "strings" "testing" "time" @@ -15,6 +19,7 @@ type fakeStorage struct { lastFnm string blob []byte err error + exists bool } func (f *fakeStorage) Health() bool { @@ -22,7 +27,11 @@ func (f *fakeStorage) Health() bool { } func (f *fakeStorage) Put(bucket, fnm string, binary []byte, tenantID ...string) error { - panic("not implemented in fakeStorage") + f.lastBucket = bucket + f.lastFnm = fnm + f.blob = binary + f.exists = true + return f.err } func (f *fakeStorage) Get(bucket, fnm string, tenantID ...string) ([]byte, error) { @@ -36,7 +45,7 @@ func (f *fakeStorage) Remove(bucket, fnm string, tenantID ...string) error { } func (f *fakeStorage) ObjExist(bucket, fnm string, tenantID ...string) bool { - panic("not implemented in fakeStorage") + return f.exists && f.lastBucket == bucket && f.lastFnm == fnm } func (f *fakeStorage) GetPresignedURL(bucket, fnm string, expires time.Duration, tenantID ...string) (string, error) { @@ -124,3 +133,130 @@ func TestFileService_DownloadAgentFile_Error(t *testing.T) { t.Errorf("expected nil blob, got %v", blob) } } + +func TestFileService_UploadFromURL_PDFAddsExtensionAndStoresToDownloads(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/pdf") + _, _ = w.Write([]byte("%PDF-1.7 fake pdf")) + })) + defer server.Close() + + origAssert := assertURLSafe + origPinned := pinnedHTTPClient + assertURLSafe = func(rawURL string) (string, string, error) { + return "127.0.0.1", "127.0.0.1", nil + } + pinnedHTTPClient = func(hostname, resolvedIP string, timeout time.Duration) *http.Client { + return server.Client() + } + t.Cleanup(func() { + assertURLSafe = origAssert + pinnedHTTPClient = origPinned + }) + + mockStorage := &fakeStorage{} + factory := storage.GetStorageFactory() + originalStorage := factory.GetStorage() + factory.SetStorage(mockStorage) + t.Cleanup(func() { factory.SetStorage(originalStorage) }) + + svc := NewFileService() + resp, err := svc.UploadFromURL("tenant123", server.URL+"/report") + if err != nil { + t.Fatalf("UploadFromURL failed: %v", err) + } + + if mockStorage.lastBucket != "tenant123-downloads" { + t.Fatalf("bucket = %q", mockStorage.lastBucket) + } + if resp["name"] != "report.pdf" { + t.Fatalf("name = %#v, want report.pdf", resp["name"]) + } + if resp["mime_type"] != "application/pdf" { + t.Fatalf("mime_type = %#v", resp["mime_type"]) + } + if resp["id"] != mockStorage.lastFnm { + t.Fatalf("id = %#v, stored key = %q", resp["id"], mockStorage.lastFnm) + } +} + +func TestFileService_UploadFromURL_HTMLNormalizesReadableContent(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, _ = w.Write([]byte(`
Hello

World

`)) + })) + defer server.Close() + + origAssert := assertURLSafe + origPinned := pinnedHTTPClient + assertURLSafe = func(rawURL string) (string, string, error) { + return "127.0.0.1", "127.0.0.1", nil + } + pinnedHTTPClient = func(hostname, resolvedIP string, timeout time.Duration) *http.Client { + return server.Client() + } + t.Cleanup(func() { + assertURLSafe = origAssert + pinnedHTTPClient = origPinned + }) + + mockStorage := &fakeStorage{} + factory := storage.GetStorageFactory() + originalStorage := factory.GetStorage() + factory.SetStorage(mockStorage) + t.Cleanup(func() { factory.SetStorage(originalStorage) }) + + svc := NewFileService() + resp, err := svc.UploadFromURL("tenant123", server.URL+"/page") + if err != nil { + t.Fatalf("UploadFromURL failed: %v", err) + } + + stored := string(mockStorage.blob) + if strings.Contains(strings.ToLower(stored), "