feat(go-api): Align document metadata batch APIs and upload_info with Python (#16269)

## Summary

  Align the Go implementations of these APIs with the Python behavior:

  - `POST /api/v1/datasets/:dataset_id/metadata/update`
  - `PATCH /api/v1/datasets/:dataset_id/documents/metadatas`
  - `POST /api/v1/documents/upload`

  ## What changed

  - Added the Go routes and handlers for the 3 APIs.
  - Aligned batch document metadata updates with Python semantics:
    - support `match` in update items
    - support list append / replace behavior
    - support deleting specific list values
    - remove metadata entirely when it becomes empty
    - create metadata for documents that previously had none when updates apply
    - count `updated` only when a document actually changes
  - Aligned `documents/upload` file uploads with Python-style `upload_info` behavior:
    - store upload-info blobs in the per-user downloads bucket
    - return lightweight upload descriptors instead of normal file-management responses
  - Improved URL upload behavior:
    - SSRF-guarded fetch with redirect validation
    - redirect limit aligned to Python behavior
    - normalize filename and MIME type
    - add `.pdf` when the fetched content is PDF
    - normalize HTML content into readable text instead of storing raw HTML shells

  ## Validation

  ### Unit tests

  Passed:

  - `go test ./internal/service`
  - `go test ./internal/handler`

  Also verified targeted cases for:

  - batch metadata update semantics
  - upload_info URL handling
  - upload_info download bucket behavior

  ### curl checks

  Verified the new Go endpoints with `curl` and compared the response shape and behavior with Python for:

  - `POST /api/v1/datasets/{dataset_id}/metadata/update`
  - `PATCH /api/v1/datasets/{dataset_id}/documents/metadatas`
  - `POST /api/v1/documents/upload`

  The Go responses were checked against Python for:
  - argument validation
  - success response shape
  - metadata update results
  - upload_info result structure
  - file vs URL input handling
This commit is contained in:
Hz_
2026-06-24 14:52:47 +08:00
committed by GitHub
parent 97718ec779
commit e35860ad74
8 changed files with 1427 additions and 11 deletions

View File

@@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"mime"
"mime/multipart"
"net/http"
"path/filepath"
"ragflow/internal/common"
@@ -60,6 +61,9 @@ type documentServiceIface interface {
GetDocumentPreview(docID string) (*service.DocumentPreview, error)
DownloadDocument(datasetID, docID string) (*service.DownloadDocumentResp, error)
UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error)
BatchUpdateDocumentMetadatas(datasetID string, selector *service.DocumentMetadataSelector, updates []service.DocumentMetadataUpdate, deletes []service.DocumentMetadataDelete) (*service.BatchUpdateDocumentMetadatasResponse, common.ErrorCode, error)
UploadDocumentInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, common.ErrorCode, error)
UploadDocumentInfoByURL(userID, rawURL string) (map[string]interface{}, common.ErrorCode, error)
ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error)
IngestDocuments(datasetID, userID string, docIDs []string) ([]*service.ParseDocumentResponse, error)
StopIngestionTasks(tasks []string, userID string) ([]*entity.IngestionTask, error)
@@ -1296,3 +1300,123 @@ func (h *DocumentHandler) UpdateDatasetDocument(c *gin.Context) {
"data": data,
})
}
// UploadInfo handles POST /api/v1/documents/upload.
//
// The caller must supply exactly one kind of input: one or more multipart
// "file" parts, or a ?url= query parameter. Supplying both, or neither, is
// rejected with common.CodeArgumentError.
//
// On success the response carries the descriptor(s) returned by the document
// service. For file uploads a single file yields a single object in "data",
// while several files yield a list.
func (h *DocumentHandler) UploadInfo(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	// A request without a multipart body is legitimate here because the caller
	// may be using ?url=..., so only genuine parse failures are rejected. The
	// net/http sentinel is compared with errors.Is instead of matching on the
	// error text, which would silently break if the message ever changed.
	form, err := c.MultipartForm()
	if err != nil && !errors.Is(err, http.ErrNotMultipart) {
		jsonError(c, common.CodeArgumentError, "Failed to parse multipart form: "+err.Error())
		return
	}

	var fileHeaders []*multipart.FileHeader
	if form != nil && form.File != nil {
		fileHeaders = form.File["file"]
	}
	rawURL := strings.TrimSpace(c.Query("url"))

	// The two input modes are mutually exclusive and one of them is required.
	switch {
	case len(fileHeaders) > 0 && rawURL != "":
		jsonError(c, common.CodeArgumentError, "Provide either multipart file(s) or ?url=..., not both.")
		return
	case len(fileHeaders) == 0 && rawURL == "":
		jsonError(c, common.CodeArgumentError, "Missing input: provide multipart file(s) or url")
		return
	}

	if rawURL != "" {
		data, code, err := h.documentService.UploadDocumentInfoByURL(user.ID, rawURL)
		if err != nil {
			jsonError(c, code, err.Error())
			return
		}
		c.JSON(http.StatusOK, gin.H{
			"code":    common.CodeSuccess,
			"data":    data,
			"message": "success",
		})
		return
	}

	data, code, err := h.documentService.UploadDocumentInfos(user.ID, fileHeaders)
	if err != nil {
		jsonError(c, code, err.Error())
		return
	}

	// A single uploaded file is returned as a bare object rather than a
	// one-element list.
	var payload interface{}
	if len(data) == 1 {
		payload = data[0]
	} else {
		payload = data
	}
	c.JSON(http.StatusOK, gin.H{
		"code":    common.CodeSuccess,
		"data":    payload,
		"message": "success",
	})
}
// documentMetadataBatchRequest is the JSON request body bound by
// handleBatchUpdateDocumentMetadatas: a document selector plus the metadata
// updates and deletes to apply to the selected documents. All three fields are
// optional in the payload; the handler substitutes empty values for missing
// ones before calling the service.
type documentMetadataBatchRequest struct {
Selector *service.DocumentMetadataSelector `json:"selector"`
Updates []service.DocumentMetadataUpdate `json:"updates"`
Deletes []service.DocumentMetadataDelete `json:"deletes"`
}
// MetadataBatchUpdate is the handler registered for
// POST /datasets/:dataset_id/metadata/update. It delegates entirely to
// handleBatchUpdateDocumentMetadatas.
func (h *DocumentHandler) MetadataBatchUpdate(c *gin.Context) {
h.handleBatchUpdateDocumentMetadatas(c)
}
// UpdateDocumentMetadatas is the handler registered for
// PATCH /datasets/:dataset_id/documents/metadatas. It delegates entirely to
// handleBatchUpdateDocumentMetadatas.
func (h *DocumentHandler) UpdateDocumentMetadatas(c *gin.Context) {
h.handleBatchUpdateDocumentMetadatas(c)
}
// handleBatchUpdateDocumentMetadatas is the shared implementation behind the
// two batch metadata routes. It authenticates the caller, checks that the
// dataset in the path is accessible to them, binds the JSON body, and hands
// the selector, updates and deletes to the document service.
//
// Missing body sections are replaced with empty values so the service always
// receives a non-nil selector and non-nil slices.
func (h *DocumentHandler) handleBatchUpdateDocumentMetadatas(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	datasetID := strings.TrimSpace(c.Param("dataset_id"))
	switch {
	case datasetID == "":
		jsonError(c, common.CodeArgumentError, "dataset_id is required")
		return
	case !h.datasetService.Accessible(datasetID, user.ID):
		jsonError(c, common.CodeDataError, "You don't own the dataset "+datasetID+".")
		return
	}

	var body documentMetadataBatchRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		jsonError(c, common.CodeDataError, bindErr.Error())
		return
	}

	selector := body.Selector
	if selector == nil {
		selector = &service.DocumentMetadataSelector{}
	}
	updates := body.Updates
	if updates == nil {
		updates = []service.DocumentMetadataUpdate{}
	}
	deletes := body.Deletes
	if deletes == nil {
		deletes = []service.DocumentMetadataDelete{}
	}

	result, code, svcErr := h.documentService.BatchUpdateDocumentMetadatas(datasetID, selector, updates, deletes)
	if svcErr != nil {
		jsonError(c, code, svcErr.Error())
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code":    common.CodeSuccess,
		"data":    result,
		"message": "success",
	})
}

View File

@@ -19,6 +19,7 @@ package handler
import (
"encoding/json"
"fmt"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
@@ -49,6 +50,15 @@ type fakeDocumentService struct {
// UpdateDatasetDocument is a no-op stub: it ignores its arguments and returns
// a nil response, common.CodeSuccess and a nil error.
func (f *fakeDocumentService) UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error) {
return nil, common.CodeSuccess, nil
}
// BatchUpdateDocumentMetadatas is a no-op stub: it ignores its arguments and
// returns a nil response, common.CodeSuccess and a nil error.
func (f *fakeDocumentService) BatchUpdateDocumentMetadatas(datasetID string, selector *service.DocumentMetadataSelector, updates []service.DocumentMetadataUpdate, deletes []service.DocumentMetadataDelete) (*service.BatchUpdateDocumentMetadatasResponse, common.ErrorCode, error) {
return nil, common.CodeSuccess, nil
}
// UploadDocumentInfos is a no-op stub: it ignores its arguments and returns a
// nil slice, common.CodeSuccess and a nil error.
func (f *fakeDocumentService) UploadDocumentInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, common.ErrorCode, error) {
return nil, common.CodeSuccess, nil
}
// UploadDocumentInfoByURL is a no-op stub: it ignores its arguments and
// returns a nil map, common.CodeSuccess and a nil error.
func (f *fakeDocumentService) UploadDocumentInfoByURL(userID, rawURL string) (map[string]interface{}, common.ErrorCode, error) {
return nil, common.CodeSuccess, nil
}
func (f *fakeDocumentService) GetDocumentArtifact(filename string) (*service.ArtifactResponse, error) {
if filename == "error.txt" {

View File

@@ -236,6 +236,7 @@ func (r *Router) Setup(engine *gin.Engine) {
documents := v1.Group("/documents")
{
documents.POST("", r.documentHandler.CreateDocument)
documents.POST("/upload", r.documentHandler.UploadInfo)
documents.GET("", r.documentHandler.ListDocuments)
documents.GET("/artifact/:filename", r.documentHandler.GetDocumentArtifact)
documents.GET("/:id/preview", r.documentHandler.GetDocumentPreview)
@@ -315,6 +316,8 @@ func (r *Router) Setup(engine *gin.Engine) {
//datasets.POST("/:dataset_id/documents/stop", r.documentHandler.StopParseDocuments)
datasets.DELETE("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.RemoveChunks)
datasets.PUT("/:dataset_id/documents/:document_id/metadata/config", r.datasetsHandler.UpdateDocumentMetadataConfig)
datasets.POST("/:dataset_id/metadata/update", r.documentHandler.MetadataBatchUpdate)
datasets.PATCH("/:dataset_id/documents/metadatas", r.documentHandler.UpdateDocumentMetadatas)
}
// Search routes

View File

@@ -22,8 +22,10 @@ import (
"errors"
"fmt"
"math"
"mime/multipart"
"os"
"path/filepath"
"reflect"
"regexp"
"sort"
"strconv"
@@ -41,6 +43,7 @@ import (
"ragflow/internal/tokenizer"
"ragflow/internal/utility"
"go.uber.org/zap"
"gorm.io/gorm"
)
@@ -1571,15 +1574,10 @@ func aggregateMetadata(chunks []map[string]interface{}) map[string]interface{} {
typeCounter[k][valueType] = typeCounter[k][valueType] + 1
}
// Aggregate value counts
values := v
if v, ok := v.([]interface{}); ok {
values = v
} else {
values = []interface{}{v}
}
for _, vv := range values.([]interface{}) {
// Aggregate value counts. Flatten nested arrays so malformed values do
// not surface in the UI as the literal string "[]".
values := flattenMetadataSummaryValues(v)
for _, vv := range values {
if vv == nil {
continue
}
@@ -1676,6 +1674,27 @@ func getMetaValueType(value interface{}) string {
return "string"
}
// flattenMetadataSummaryValues turns a metadata value into a flat list of
// leaf values. Nested []interface{} values are flattened recursively, a
// []string is converted element by element, any other non-nil value becomes a
// one-element list, and nil yields a nil list. Nil entries and empty lists
// nested inside a list therefore contribute nothing to the result.
func flattenMetadataSummaryValues(value interface{}) []interface{} {
	if value == nil {
		return nil
	}

	if strs, isStrings := value.([]string); isStrings {
		flat := make([]interface{}, len(strs))
		for i := range strs {
			flat[i] = strs[i]
		}
		return flat
	}

	items, isList := value.([]interface{})
	if !isList {
		return []interface{}{value}
	}

	flat := make([]interface{}, 0, len(items))
	for i := range items {
		flat = append(flat, flattenMetadataSummaryValues(items[i])...)
	}
	return flat
}
// isTimeString checks if a string is an ISO 8601 datetime
func isTimeString(s string) bool {
matched, _ := regexp.MatchString(`^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$`, s)
@@ -2176,3 +2195,471 @@ func mapDocumentRunStatus(run *string) string {
return "UNSTART"
}
}
// DocumentMetadataUpdate is one update item in a batch metadata request: it
// sets Key to Value. When Match is provided, only an existing value equal to
// Match is replaced. ValueType optionally requests normalisation of Value
// before it is applied.
type DocumentMetadataUpdate struct {
Key string `json:"key"`
Value interface{} `json:"value"`
Match interface{} `json:"match,omitempty"`
ValueType string `json:"valueType,omitempty"`
}
// DocumentMetadataDelete removes a whole key, or a specific value from a list
// field. A missing (nil) Value means the whole key is removed.
type DocumentMetadataDelete struct {
Key string `json:"key"`
Value interface{} `json:"value,omitempty"`
}
// DocumentMetadataSelector selects which documents a batch metadata request
// targets: an explicit list of document IDs, a metadata condition, or both.
type DocumentMetadataSelector struct {
DocumentIDs []string `json:"document_ids"`
MetadataCondition map[string]interface{} `json:"metadata_condition"`
}
// BatchUpdateDocumentMetadatasResponse summarises a batch metadata operation:
// MatchedDocs is the number of documents that were targeted and Updated is
// the number of those whose metadata was actually changed and persisted.
type BatchUpdateDocumentMetadatasResponse struct {
Updated int `json:"updated"`
MatchedDocs int `json:"matched_docs"`
}
// BatchUpdateDocumentMetadatas implements the shared logic for
// PATCH /datasets/:dataset_id/documents/metadatas and
// POST /datasets/:dataset_id/metadata/update.
//
// The target documents are resolved from the selector: explicit document IDs
// must all belong to datasetID (otherwise common.CodeDataError is returned),
// and a metadata condition either narrows that set or, when no IDs were
// given, defines it. With neither IDs nor a condition the target set is empty
// and nothing is updated.
//
// Updates and deletes are then applied to each target document in turn.
// Per-document failures (reading, deleting or replacing metadata) are logged
// as warnings and the document is skipped; they do not fail the whole call.
// The response reports how many documents were targeted and how many were
// actually changed and persisted.
func (s *DocumentService) BatchUpdateDocumentMetadatas(
datasetID string,
selector *DocumentMetadataSelector,
updates []DocumentMetadataUpdate,
deletes []DocumentMetadataDelete,
) (*BatchUpdateDocumentMetadatasResponse, common.ErrorCode, error) {
if selector == nil {
selector = &DocumentMetadataSelector{}
}
if code, err := validateBatchUpdateDocumentMetadatasRequest(selector, updates, deletes); err != nil {
return nil, code, err
}
// Resolve which document IDs to target.
targetDocIDs := make(map[string]struct{})
if len(selector.DocumentIDs) > 0 {
// Validate that supplied IDs actually belong to this dataset.
allRows, err := s.documentDAO.GetAllDocIDsByKBIDs([]string{datasetID})
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to list dataset documents: %w", err)
}
kbDocIDSet := make(map[string]struct{}, len(allRows))
for _, row := range allRows {
kbDocIDSet[row["id"]] = struct{}{}
}
var invalidIDs []string
for _, id := range selector.DocumentIDs {
if _, ok := kbDocIDSet[id]; !ok {
invalidIDs = append(invalidIDs, id)
}
}
if len(invalidIDs) > 0 {
return nil, common.CodeDataError, fmt.Errorf("these documents do not belong to dataset %s: %s",
datasetID, strings.Join(invalidIDs, ", "))
}
for _, id := range selector.DocumentIDs {
targetDocIDs[id] = struct{}{}
}
}
// Apply metadata_condition filter.
if len(selector.MetadataCondition) > 0 {
flattedMeta, err := s.metadataSvc.GetFlattedMetaByKBs([]string{datasetID})
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to get flattened metadata: %w", err)
}
// ParseAndConvert mirrors Python convert_conditions: conditions arrive as
// {name, comparison_operator, value}, the operator is normalised, and the
// (possibly non-string) value is preserved. MetaFilter then matches against
// the common.MetaData returned by GetFlattedMetaByKBs.
filterInput := common.ParseAndConvert(selector.MetadataCondition)
filteredIDs := common.MetaFilter(flattedMeta, filterInput)
filteredSet := make(map[string]struct{}, len(filteredIDs))
for _, id := range filteredIDs {
filteredSet[id] = struct{}{}
}
if len(targetDocIDs) > 0 {
// Intersect with the document_ids restriction.
for id := range targetDocIDs {
if _, ok := filteredSet[id]; !ok {
delete(targetDocIDs, id)
}
}
} else {
targetDocIDs = filteredSet
}
// Early-exit when conditions given but nothing matched.
rawConds, _ := selector.MetadataCondition["conditions"]
if rawConds != nil && len(targetDocIDs) == 0 {
return &BatchUpdateDocumentMetadatasResponse{Updated: 0, MatchedDocs: 0}, common.CodeSuccess, nil
}
}
// Map iteration order is unspecified, so documents are processed in no
// particular order.
ids := make([]string, 0, len(targetDocIDs))
for id := range targetDocIDs {
ids = append(ids, id)
}
// Apply updates and deletes per document using Python's batch_update_metadata
// semantics instead of a simple merge-then-delete.
updated := 0
for _, docID := range ids {
currentMeta, err := s.GetDocumentMetadataByID(docID)
if err != nil {
common.Warn("BatchUpdateDocumentMetadata: get metadata failed",
zap.String("docID", docID), zap.Error(err))
continue
}
// Work on a copy and keep a second copy of the starting state so a net
// no-op (e.g. an update undone by a delete) is not counted as a change.
meta := cloneDocumentMetadata(currentMeta)
originalMeta := cloneDocumentMetadata(meta)
changed := applyDocumentMetadataUpdates(meta, updates)
if applyDocumentMetadataDeletes(meta, deletes) {
changed = true
}
if !changed || reflect.DeepEqual(originalMeta, meta) {
continue
}
// Metadata that ends up empty is removed entirely rather than stored as
// an empty object.
if len(meta) == 0 {
if err := s.DeleteDocumentAllMetadata(docID); err != nil {
common.Warn("BatchUpdateDocumentMetadata: delete all metadata failed",
zap.String("docID", docID), zap.Error(err))
continue
}
updated++
continue
}
if err := s.replaceDocumentMetadata(docID, meta); err != nil {
common.Warn("BatchUpdateDocumentMetadata: replace metadata failed",
zap.String("docID", docID), zap.Error(err))
continue
}
updated++
}
return &BatchUpdateDocumentMetadatasResponse{Updated: updated, MatchedDocs: len(ids)}, common.CodeSuccess, nil
}
// UploadDocumentInfos stores the given multipart files through a FileService
// built from this service's DAOs and returns one upload descriptor per file.
// Any failure is reported with common.CodeDataError.
func (s *DocumentService) UploadDocumentInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, common.ErrorCode, error) {
	uploader := &FileService{
		fileDAO:          s.fileDAO,
		file2DocumentDAO: s.file2DocumentDAO,
	}

	infos, uploadErr := uploader.UploadInfos(userID, files)
	if uploadErr == nil {
		return infos, common.CodeSuccess, nil
	}
	return nil, common.CodeDataError, uploadErr
}
// UploadDocumentInfoByURL fetches rawURL through a FileService built from
// this service's DAOs and returns the resulting upload descriptor. Any
// failure is reported with common.CodeDataError.
func (s *DocumentService) UploadDocumentInfoByURL(userID, rawURL string) (map[string]interface{}, common.ErrorCode, error) {
	uploader := &FileService{
		fileDAO:          s.fileDAO,
		file2DocumentDAO: s.file2DocumentDAO,
	}

	info, fetchErr := uploader.UploadFromURL(userID, rawURL)
	if fetchErr == nil {
		return info, common.CodeSuccess, nil
	}
	return nil, common.CodeDataError, fetchErr
}
// validateBatchUpdateDocumentMetadatasRequest checks the shape of a batch
// metadata request before any document is touched. Every update needs a
// non-blank key and a non-nil value, every delete needs a non-blank key, and
// a non-empty metadata_condition must contain a "conditions" entry. The first
// violation is returned together with common.CodeDataError.
func validateBatchUpdateDocumentMetadatasRequest(
	selector *DocumentMetadataSelector,
	updates []DocumentMetadataUpdate,
	deletes []DocumentMetadataDelete,
) (common.ErrorCode, error) {
	for i := range updates {
		item := &updates[i]
		if item.Value == nil || strings.TrimSpace(item.Key) == "" {
			return common.CodeDataError, errors.New("Each update requires key and value.")
		}
	}

	for i := range deletes {
		if strings.TrimSpace(deletes[i].Key) == "" {
			return common.CodeDataError, errors.New("Each delete requires key.")
		}
	}

	// A missing or empty condition is acceptable; only a populated one is
	// required to carry "conditions".
	if selector == nil || len(selector.MetadataCondition) == 0 {
		return common.CodeSuccess, nil
	}
	if _, hasConditions := selector.MetadataCondition["conditions"]; !hasConditions {
		return common.CodeDataError, errors.New("metadata_condition must be an object.")
	}
	return common.CodeSuccess, nil
}
// cloneDocumentMetadata returns a new map holding a copy of every entry in
// meta, with each value passed through cloneDocumentMetadataValue. A nil
// input produces an empty, non-nil map.
func cloneDocumentMetadata(meta map[string]interface{}) map[string]interface{} {
	// Ranging over a nil map performs no iterations, so nil input needs no
	// separate branch.
	copied := make(map[string]interface{}, len(meta))
	for field, value := range meta {
		copied[field] = cloneDocumentMetadataValue(value)
	}
	return copied
}
// cloneDocumentMetadataValue copies a single metadata value. Slices get a
// fresh backing array (a []string is converted to []interface{} in the
// process); the copy is one level deep. Every other value is returned as is.
func cloneDocumentMetadataValue(v interface{}) interface{} {
	if items, isList := v.([]interface{}); isList {
		return append(make([]interface{}, 0, len(items)), items...)
	}

	if strs, isStrings := v.([]string); isStrings {
		converted := make([]interface{}, len(strs))
		for i := range strs {
			converted[i] = strs[i]
		}
		return converted
	}

	return v
}
// applyDocumentMetadataUpdates applies updates to meta in place and reports
// whether anything changed. For each update, after trimming the key and
// normalising the value according to its ValueType:
//
//   - key absent: the value is added, unless a match was provided, in which
//     case the update is skipped;
//   - key holds a list, no match: the value (or each element of a list value)
//     is appended and the list is de-duplicated;
//   - key holds a list, match provided: every element equal to the match is
//     replaced by the value and the list is de-duplicated;
//   - key holds a scalar: it is overwritten, and when a match was provided
//     only if the current value equals the match.
//
// Updates with a blank key are ignored.
func applyDocumentMetadataUpdates(meta map[string]interface{}, updates []DocumentMetadataUpdate) bool {
changed := false
for _, upd := range updates {
key := strings.TrimSpace(upd.Key)
if key == "" {
continue
}
normalizedValue := normalizeDocumentMetadataUpdateValue(upd.Value, upd.ValueType)
// A match that is nil or formats to the empty string counts as not provided.
matchProvided := upd.Match != nil && !(fmt.Sprintf("%v", upd.Match) == "")
current, exists := meta[key]
if !exists {
if matchProvided {
continue
}
if listVal, ok := toMetadataInterfaceSlice(normalizedValue); ok {
meta[key] = dedupeDocumentMetadataList(listVal)
} else {
meta[key] = normalizedValue
}
changed = true
continue
}
if curList, ok := toMetadataInterfaceSlice(current); ok {
if !matchProvided {
// Append mode: extend the existing list with the new value(s).
newList := append([]interface{}{}, curList...)
if appendList, ok := toMetadataInterfaceSlice(normalizedValue); ok {
newList = append(newList, appendList...)
} else {
newList = append(newList, normalizedValue)
}
newList = dedupeDocumentMetadataList(newList)
if !reflect.DeepEqual(curList, newList) {
meta[key] = newList
changed = true
}
continue
}
// Replace mode: swap each element equal to the match for the new value(s).
replaced := false
newList := make([]interface{}, 0, len(curList))
for _, item := range curList {
if documentMetadataValuesEqual(item, upd.Match) {
if replacementList, ok := toMetadataInterfaceSlice(normalizedValue); ok {
newList = append(newList, replacementList...)
} else {
newList = append(newList, normalizedValue)
}
replaced = true
} else {
newList = append(newList, item)
}
}
newList = dedupeDocumentMetadataList(newList)
if replaced && !reflect.DeepEqual(curList, newList) {
meta[key] = newList
changed = true
}
continue
}
// Scalar value: overwrite unconditionally, or only when it equals the match.
if !matchProvided {
if !reflect.DeepEqual(current, normalizedValue) {
meta[key] = normalizedValue
changed = true
}
continue
}
if documentMetadataValuesEqual(current, upd.Match) && !reflect.DeepEqual(current, normalizedValue) {
meta[key] = normalizedValue
changed = true
}
}
return changed
}
// applyDocumentMetadataDeletes applies deletes to meta in place and reports
// whether anything was removed. A delete without a value drops the whole key.
// A delete with a value removes the matching elements from a list field
// (dropping the key once the list is empty) or drops a scalar field whose
// value matches. Blank keys and keys not present in meta are ignored.
func applyDocumentMetadataDeletes(meta map[string]interface{}, deletes []DocumentMetadataDelete) bool {
	mutated := false
	for i := range deletes {
		field := strings.TrimSpace(deletes[i].Key)
		if field == "" {
			continue
		}
		stored, present := meta[field]
		if !present {
			continue
		}

		// No value given: remove the key whatever it currently holds.
		target := deletes[i].Value
		if target == nil {
			delete(meta, field)
			mutated = true
			continue
		}

		list, isList := toMetadataInterfaceSlice(stored)
		if !isList {
			if documentMetadataValuesEqual(stored, target) {
				delete(meta, field)
				mutated = true
			}
			continue
		}

		kept := make([]interface{}, 0, len(list))
		for _, elem := range list {
			if documentMetadataValuesEqual(elem, target) {
				continue
			}
			kept = append(kept, elem)
		}

		switch {
		case len(kept) == len(list):
			// Nothing matched; leave the field untouched.
		case len(kept) == 0:
			delete(meta, field)
			mutated = true
		default:
			meta[field] = kept
			mutated = true
		}
	}
	return mutated
}
// toMetadataInterfaceSlice reports whether v is a list ([]interface{} or
// []string) and, if so, returns its elements as a freshly allocated
// []interface{}. For any other value it returns nil and false.
func toMetadataInterfaceSlice(v interface{}) ([]interface{}, bool) {
	if items, isList := v.([]interface{}); isList {
		return append(make([]interface{}, 0, len(items)), items...), true
	}

	strs, isStrings := v.([]string)
	if !isStrings {
		return nil, false
	}
	converted := make([]interface{}, len(strs))
	for i := range strs {
		converted[i] = strs[i]
	}
	return converted, true
}
// dedupeDocumentMetadataList returns items with duplicates removed, keeping
// the first occurrence of each and preserving order. Two items are duplicates
// when they have the same dynamic type and the same %v formatting, so the
// string "1" and the integer 1 are kept as distinct entries.
func dedupeDocumentMetadataList(items []interface{}) []interface{} {
	unique := make([]interface{}, 0, len(items))
	known := make(map[string]bool, len(items))
	for i := range items {
		fingerprint := fmt.Sprintf("%T:%v", items[i], items[i])
		if known[fingerprint] {
			continue
		}
		known[fingerprint] = true
		unique = append(unique, items[i])
	}
	return unique
}
// documentMetadataValuesEqual compares two metadata values by their %v text
// form, so values of different types that print the same (for example the
// integer 1 and the string "1") are considered equal.
func documentMetadataValuesEqual(a, b interface{}) bool {
	left := fmt.Sprintf("%v", a)
	right := fmt.Sprintf("%v", b)
	return left == right
}
// normalizeDocumentMetadataUpdateValue coerces an update value according to
// valueType, which is compared case-insensitively after trimming:
//
//   - "list": the value flattened into a list, or an empty list when it is
//     not a list;
//   - "string" / "time": the first scalar formatted with %v, or "" when there
//     is none;
//   - "number": the first scalar, with json.Number and string inputs parsed
//     to int64 or, failing that, float64. An unparsable string is returned
//     trimmed, and when there is no scalar the value is returned unchanged;
//   - anything else: the value unchanged.
func normalizeDocumentMetadataUpdateValue(value interface{}, valueType string) interface{} {
	kind := strings.ToLower(strings.TrimSpace(valueType))

	if kind == "list" {
		list, isList := normalizeMetadataListValue(value)
		if !isList {
			return []interface{}{}
		}
		return list
	}

	if kind == "string" || kind == "time" {
		scalar, found := firstScalarMetadataValue(value)
		if !found {
			return ""
		}
		return fmt.Sprintf("%v", scalar)
	}

	if kind != "number" {
		return value
	}

	scalar, found := firstScalarMetadataValue(value)
	if !found {
		return value
	}
	// Values that are already numeric, and any other type, fall through to the
	// final return untouched; only textual representations need parsing.
	switch raw := scalar.(type) {
	case json.Number:
		if whole, err := raw.Int64(); err == nil {
			return whole
		}
		if frac, err := raw.Float64(); err == nil {
			return frac
		}
	case string:
		text := strings.TrimSpace(raw)
		if text == "" {
			return ""
		}
		if whole, err := strconv.ParseInt(text, 10, 64); err == nil {
			return whole
		}
		if frac, err := strconv.ParseFloat(text, 64); err == nil {
			return frac
		}
		return text
	}
	return scalar
}
// normalizeMetadataListValue reports whether value is a list and, if so,
// returns it flattened: nested lists are expanded recursively and nil entries
// are dropped. A []string is converted element by element. For a non-list
// value it returns nil and false.
func normalizeMetadataListValue(value interface{}) ([]interface{}, bool) {
	if strs, isStrings := value.([]string); isStrings {
		flat := make([]interface{}, len(strs))
		for i := range strs {
			flat[i] = strs[i]
		}
		return flat, true
	}

	items, isList := value.([]interface{})
	if !isList {
		return nil, false
	}

	flat := make([]interface{}, 0, len(items))
	for _, elem := range items {
		if elem == nil {
			continue
		}
		if inner, nested := normalizeMetadataListValue(elem); nested {
			flat = append(flat, inner...)
		} else {
			flat = append(flat, elem)
		}
	}
	return flat, true
}
// firstScalarMetadataValue extracts a single scalar from value. For a list it
// returns the first non-nil element of the flattened list; for any other
// non-nil value it returns the value itself. The boolean is false when no
// scalar is available (nil input, or a list with no usable element).
func firstScalarMetadataValue(value interface{}) (interface{}, bool) {
	list, isList := normalizeMetadataListValue(value)
	if !isList {
		return value, value != nil
	}

	for i := range list {
		if list[i] != nil {
			return list[i], true
		}
	}
	return nil, false
}

View File

@@ -19,7 +19,9 @@ package service
import (
"context"
"errors"
"fmt"
"path/filepath"
"strings"
"testing"
"github.com/glebarez/sqlite"
@@ -130,6 +132,107 @@ type failingDeleteMetadataEngine struct {
updateCalled bool
}
// metadataDocEngine is an in-memory document-engine fake for the metadata
// tests. It embeds fakeChatDocEngine and overrides the metadata operations.
type metadataDocEngine struct {
fakeChatDocEngine
// records maps document ID to that document's metadata fields.
records map[string]map[string]interface{}
// docKBs maps document ID to the ID of the dataset it belongs to.
docKBs map[string]string
}
// newMetadataDocEngine builds a metadataDocEngine seeded with records. Each
// document's field map is copied one level deep, so later writes to the
// engine do not modify the caller's maps; docKBs is stored as given.
func newMetadataDocEngine(records map[string]map[string]interface{}, docKBs map[string]string) *metadataDocEngine {
	engine := &metadataDocEngine{
		records: make(map[string]map[string]interface{}, len(records)),
		docKBs:  docKBs,
	}
	for docID, fields := range records {
		snapshot := make(map[string]interface{}, len(fields))
		for name, val := range fields {
			snapshot[name] = val
		}
		engine.records[docID] = snapshot
	}
	return engine
}
// SearchMetadata returns one record per stored document that passes the
// optional "id" and "kb_id" filters in req.Filter. Each record carries the
// document ID, its dataset ID and its metadata fields. A filter that is
// absent or nil does not restrict the result; the order of the returned
// records follows map iteration and is therefore unspecified.
func (m *metadataDocEngine) SearchMetadata(_ context.Context, req *types.SearchMetadataRequest) (*types.SearchMetadataResult, error) {
// ids stays nil when no "id" filter is supplied, which means "any document".
var ids map[string]struct{}
if rawIDs, ok := req.Filter["id"]; ok && rawIDs != nil {
ids = make(map[string]struct{})
switch typed := rawIDs.(type) {
case []string:
for _, id := range typed {
ids[id] = struct{}{}
}
case []interface{}:
for _, id := range typed {
if s, ok := id.(string); ok {
ids[s] = struct{}{}
}
}
}
}
// kbFilter stays nil when no "kb_id" filter is supplied, which means "any dataset".
var kbFilter map[string]struct{}
if rawKB, ok := req.Filter["kb_id"]; ok && rawKB != nil {
kbFilter = make(map[string]struct{})
switch typed := rawKB.(type) {
case string:
kbFilter[typed] = struct{}{}
case []string:
for _, kb := range typed {
kbFilter[kb] = struct{}{}
}
case []interface{}:
for _, kb := range typed {
if s, ok := kb.(string); ok {
kbFilter[s] = struct{}{}
}
}
}
}
result := &types.SearchMetadataResult{MetadataRecords: []map[string]interface{}{}}
for docID, meta := range m.records {
if ids != nil {
if _, ok := ids[docID]; !ok {
continue
}
}
kbID := m.docKBs[docID]
if kbFilter != nil {
if _, ok := kbFilter[kbID]; !ok {
continue
}
}
result.MetadataRecords = append(result.MetadataRecords, map[string]interface{}{
"id": docID,
"kb_id": kbID,
"meta_fields": meta,
})
}
return result, nil
}
// UpdateMetadata replaces the stored metadata of docID with a one-level copy
// of metaFields. The document's dataset is recorded only when it was not
// already known. It always returns nil.
func (m *metadataDocEngine) UpdateMetadata(_ context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error {
	stored := make(map[string]interface{}, len(metaFields))
	for name, val := range metaFields {
		stored[name] = val
	}
	m.records[docID] = stored

	if _, known := m.docKBs[docID]; known {
		return nil
	}
	m.docKBs[docID] = datasetID
	return nil
}
// DeleteMetadata removes the metadata record named by condition["id"] and
// returns how many records were removed: 1 when the document had a record, 0
// when it had none or when the condition carries no non-empty string ID. It
// never returns an error.
func (m *metadataDocEngine) DeleteMetadata(_ context.Context, condition map[string]interface{}, tenantID string) (int64, error) {
	docID, isString := condition["id"].(string)
	if !isString || docID == "" {
		return 0, nil
	}
	if _, found := m.records[docID]; !found {
		return 0, nil
	}
	delete(m.records, docID)
	return 1, nil
}
// DeleteMetadata ignores its arguments and returns a zero count together with
// the engine's configured deleteErr.
func (f *failingDeleteMetadataEngine) DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) {
return 0, f.deleteErr
}
@@ -1304,6 +1407,192 @@ func TestChunkImageStorageKeyFallsBackToChunkID(t *testing.T) {
}
}
// TestBatchUpdateDocumentMetadatasMatchesPythonSemantics runs one batch with a
// match-based list replacement, a plain key addition and a value-specific
// delete against three documents (one of which starts without metadata) and
// checks the resulting metadata of each, as well as the updated/matched
// counts.
func TestBatchUpdateDocumentMetadatasMatchesPythonSemantics(t *testing.T) {
db := setupServiceTestDB(t)
pushServiceDB(t, db)
insertTestKB(t, "kb-1", "tenant-1", 3, 0, 0)
insertNamedTestDoc(t, "doc-1", "kb-1", "doc1.txt", 0, 0)
insertNamedTestDoc(t, "doc-2", "kb-1", "doc2.txt", 0, 0)
insertNamedTestDoc(t, "doc-3", "kb-1", "doc3.txt", 0, 0)
// doc-3 deliberately has no metadata record to begin with.
engine := newMetadataDocEngine(map[string]map[string]interface{}{
"doc-1": {"tags": []interface{}{"old", "keep"}, "author": "alice"},
"doc-2": {"tags": []interface{}{"old"}, "author": "bob"},
}, map[string]string{"doc-1": "kb-1", "doc-2": "kb-1", "doc-3": "kb-1"})
svc := testDocumentService(t)
svc.docEngine = engine
svc.metadataSvc = &MetadataService{kbDAO: dao.NewKnowledgebaseDAO(), docEngine: engine}
resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{
DocumentIDs: []string{"doc-1", "doc-2", "doc-3"},
}, []DocumentMetadataUpdate{
{Key: "tags", Value: "new", Match: "old"},
{Key: "category", Value: "paper"},
}, []DocumentMetadataDelete{
{Key: "author", Value: "alice"},
})
if err != nil {
t.Fatalf("BatchUpdateDocumentMetadatas failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code = %v, want success", code)
}
if resp.Updated != 3 || resp.MatchedDocs != 3 {
t.Fatalf("resp = %#v, want updated=3 matched=3", resp)
}
// doc-1: "old" replaced in place, author "alice" deleted, category added.
got1 := engine.records["doc-1"]
if fmt.Sprintf("%v", got1["category"]) != "paper" {
t.Fatalf("doc-1 category = %#v", got1["category"])
}
if _, ok := got1["author"]; ok {
t.Fatalf("doc-1 author should be deleted: %#v", got1)
}
if got := got1["tags"].([]interface{}); len(got) != 2 || got[0] != "new" || got[1] != "keep" {
t.Fatalf("doc-1 tags = %#v", got)
}
// doc-2: author "bob" does not match the delete value and must survive.
got2 := engine.records["doc-2"]
if fmt.Sprintf("%v", got2["author"]) != "bob" {
t.Fatalf("doc-2 author should be kept: %#v", got2["author"])
}
if got := got2["tags"].([]interface{}); len(got) != 1 || got[0] != "new" {
t.Fatalf("doc-2 tags = %#v", got)
}
// doc-3: metadata is created for the plain update only.
got3 := engine.records["doc-3"]
if fmt.Sprintf("%v", got3["category"]) != "paper" {
t.Fatalf("doc-3 category = %#v", got3)
}
if _, ok := got3["tags"]; ok {
t.Fatalf("doc-3 tags should not be created by match-only update: %#v", got3)
}
}
// TestBatchUpdateDocumentMetadatasDeletesEmptyMetadataAndNoOps checks that a
// document whose only key is deleted loses its metadata record entirely, and
// that a document the delete does not match is left untouched and is not
// counted as updated.
func TestBatchUpdateDocumentMetadatasDeletesEmptyMetadataAndNoOps(t *testing.T) {
db := setupServiceTestDB(t)
pushServiceDB(t, db)
insertTestKB(t, "kb-1", "tenant-1", 2, 0, 0)
insertNamedTestDoc(t, "doc-1", "kb-1", "doc1.txt", 0, 0)
insertNamedTestDoc(t, "doc-2", "kb-1", "doc2.txt", 0, 0)
engine := newMetadataDocEngine(map[string]map[string]interface{}{
"doc-1": {"status": "draft"},
"doc-2": {"status": "done"},
}, map[string]string{"doc-1": "kb-1", "doc-2": "kb-1"})
svc := testDocumentService(t)
svc.docEngine = engine
svc.metadataSvc = &MetadataService{kbDAO: dao.NewKnowledgebaseDAO(), docEngine: engine}
resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{
DocumentIDs: []string{"doc-1", "doc-2"},
}, nil, []DocumentMetadataDelete{{Key: "status", Value: "draft"}})
if err != nil || code != common.CodeSuccess {
t.Fatalf("delete batch failed: code=%v err=%v", code, err)
}
// Both documents are targeted, but only doc-1 holds the value being deleted.
if resp.Updated != 1 || resp.MatchedDocs != 2 {
t.Fatalf("resp = %#v, want updated=1 matched=2", resp)
}
if _, ok := engine.records["doc-1"]; ok {
t.Fatalf("doc-1 metadata should be fully removed: %#v", engine.records["doc-1"])
}
if fmt.Sprintf("%v", engine.records["doc-2"]["status"]) != "done" {
t.Fatalf("doc-2 metadata unexpectedly changed: %#v", engine.records["doc-2"])
}
}
// TestBatchUpdateDocumentMetadatasNormalizesNumberValues checks that an update
// carrying the string "42" with ValueType "number" is stored as a numeric
// value rather than as a string.
func TestBatchUpdateDocumentMetadatasNormalizesNumberValues(t *testing.T) {
db := setupServiceTestDB(t)
pushServiceDB(t, db)
insertTestKB(t, "kb-1", "tenant-1", 1, 0, 0)
insertNamedTestDoc(t, "doc-1", "kb-1", "doc1.txt", 0, 0)
engine := newMetadataDocEngine(map[string]map[string]interface{}{}, map[string]string{"doc-1": "kb-1"})
svc := testDocumentService(t)
svc.docEngine = engine
svc.metadataSvc = &MetadataService{kbDAO: dao.NewKnowledgebaseDAO(), docEngine: engine}
resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{
DocumentIDs: []string{"doc-1"},
}, []DocumentMetadataUpdate{
{Key: "score", Value: "42", ValueType: "number"},
}, nil)
if err != nil || code != common.CodeSuccess {
t.Fatalf("number batch failed: code=%v err=%v", code, err)
}
if resp.Updated != 1 || resp.MatchedDocs != 1 {
t.Fatalf("resp = %#v, want updated=1 matched=1", resp)
}
// Either integer or floating-point storage is accepted, as long as it is numeric.
got := engine.records["doc-1"]["score"]
switch v := got.(type) {
case int64:
if v != 42 {
t.Fatalf("score = %v, want 42", v)
}
case float64:
if v != 42 {
t.Fatalf("score = %v, want 42", v)
}
default:
t.Fatalf("score type = %T, want numeric value", got)
}
}
// TestBatchUpdateDocumentMetadatasRejectsMissingValue checks that an update
// item without a value is rejected up front with a data error, a nil response
// and the expected validation message.
func TestBatchUpdateDocumentMetadatasRejectsMissingValue(t *testing.T) {
svc := testDocumentService(t)
resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{}, []DocumentMetadataUpdate{
{Key: "status"},
}, nil)
if err == nil {
t.Fatal("expected validation error for missing value")
}
if resp != nil {
t.Fatalf("resp = %#v, want nil", resp)
}
if code != common.CodeDataError {
t.Fatalf("code = %v, want data error", code)
}
if !strings.Contains(err.Error(), "Each update requires key and value.") {
t.Fatalf("err = %v", err)
}
}
// TestAggregateMetadataIgnoresNestedEmptyLists checks that aggregateMetadata
// skips an empty list nested inside a field's value list, so the "score"
// summary contains only the real value 7 with a count of 1.
func TestAggregateMetadataIgnoresNestedEmptyLists(t *testing.T) {
summary := aggregateMetadata([]map[string]interface{}{
{
"id": "doc-1",
"kb_id": "kb-1",
"meta_fields": map[string]interface{}{
"score": []interface{}{[]interface{}{}, 7.0},
"name": "alice",
},
},
})
scoreField, ok := summary["score"].(map[string]interface{})
if !ok {
t.Fatalf("score summary missing: %#v", summary)
}
values, ok := scoreField["values"].([][2]interface{})
if !ok {
t.Fatalf("score values type = %T", scoreField["values"])
}
if len(values) != 1 || values[0][0] != "7" || values[0][1] != 1 {
t.Fatalf("score values = %#v, want [[\"7\",1]]", values)
}
}
// TestMergeFieldValuesKeepsNumericValues checks that merging two float values
// yields a two-element result that still holds the numbers 1.0 and 2.0.
func TestMergeFieldValuesKeepsNumericValues(t *testing.T) {
got := mergeFieldValues(1.0, 2.0)
if len(got) != 2 || got[0] != 1.0 || got[1] != 2.0 {
t.Fatalf("mergeFieldValues = %#v, want [1 2]", got)
}
}
func TestUpdateDatasetDocumentPipelineIDTakesPrecedenceOverChunkMethod(t *testing.T) {
db := setupServiceTestDB(t)
pushServiceDB(t, db)

View File

@@ -20,8 +20,12 @@ import (
"context"
"encoding/base64"
"fmt"
"html"
"io"
"mime/multipart"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"ragflow/internal/common"
@@ -31,6 +35,7 @@ import (
"ragflow/internal/ingestion/parser"
"ragflow/internal/storage"
"ragflow/internal/utility"
"regexp"
"strings"
"time"
@@ -156,6 +161,11 @@ const DatasetFolderName = ".knowledgebase"
// FileSourceDataset represents dataset as file source
const FileSourceDataset = "knowledgebase"
// assertURLSafe and pinnedHTTPClient are package-level aliases for the
// utility URL-safety check and pinned HTTP client constructor.
// fetchRemoteFileSafely calls through these variables rather than the utility
// functions directly, which lets the tests in this package swap in stubs and
// restore the originals afterwards.
var (
assertURLSafe = utility.AssertURLSafe
pinnedHTTPClient = utility.PinnedHTTPClient
)
// toFileResponse converts file model to response format
func (s *FileService) toFileResponse(file *entity.File) map[string]interface{} {
result := map[string]interface{}{
@@ -391,6 +401,57 @@ func (s *FileService) UploadFile(tenantID, parentID string, files []*multipart.F
return result, nil
}
// UploadInfos handles the file branch of the upload_info endpoint. Each
// multipart file is health-checked, read in full (size-capped), normalized and
// written as a raw blob via storeUploadInfoBlob; the resulting lightweight
// descriptors are returned in input order. No File rows are created here.
//
// Processing stops at the first failure and returns that error with a nil
// slice; blobs already stored for earlier files are not rolled back.
func (s *FileService) UploadInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, error) {
	backend := storage.GetStorageFactory().GetStorage()
	if backend == nil {
		return nil, fmt.Errorf("storage not initialized")
	}

	descriptors := make([]map[string]interface{}, 0, len(files))
	for i := range files {
		header := files[i]
		name := header.Filename
		if healthErr := s.checkUploadInfoHealth(userID, name); healthErr != nil {
			return nil, healthErr
		}

		// Read the payload inside a closure so the multipart file is closed
		// as soon as the bytes are in memory.
		payload, err := func() ([]byte, error) {
			f, openErr := header.Open()
			if openErr != nil {
				return nil, fmt.Errorf("failed to open uploaded file: %w", openErr)
			}
			defer f.Close()

			body, readErr := readUploadInfoData(f)
			if readErr != nil {
				return nil, fmt.Errorf("failed to read file data: %w", readErr)
			}
			return body, nil
		}()
		if err != nil {
			return nil, err
		}

		// Prefer the MIME type declared by the client; sniff only when absent.
		mimeType := header.Header.Get("Content-Type")
		if mimeType == "" {
			mimeType = http.DetectContentType(payload)
		}
		name, mimeType, payload = normalizeUploadInfoContent(name, mimeType, payload)

		descriptor, err := s.storeUploadInfoBlob(backend, userID, name, mimeType, payload)
		if err != nil {
			return nil, err
		}
		descriptors = append(descriptors, descriptor)
	}
	return descriptors, nil
}
// readUploadInfoData reads r to the end and returns its bytes. Payloads larger
// than maxRemoteFileSize are rejected with an error instead of being
// truncated; read errors from r are returned unchanged.
func readUploadInfoData(r io.Reader) ([]byte, error) {
	// Allow one byte past the cap so "exactly at the limit" can be told apart
	// from "over the limit".
	payload, err := io.ReadAll(io.LimitReader(r, maxRemoteFileSize+1))
	switch {
	case err != nil:
		return nil, err
	case int64(len(payload)) > maxRemoteFileSize:
		return nil, fmt.Errorf("file size exceeds %d bytes", maxRemoteFileSize)
	}
	return payload, nil
}
func (s *FileService) parseFilePath(filename string) []string {
filename = strings.TrimPrefix(filename, "/")
parts := strings.Split(filename, "/")
@@ -1068,3 +1129,303 @@ func parseFileContent(filename string, data []byte) string {
}
return fp.String()
}
// toUploadInfoResponse builds an upload_info-style descriptor map from a file
// record. The extension is the lower-cased text after the last dot of the file
// name (empty when there is no dot), created_at is the current time in
// fractional seconds, and preview_url is always nil.
func (s *FileService) toUploadInfoResponse(file *entity.File, mimeType string) map[string]interface{} {
	var extension string
	if dot := strings.LastIndexByte(file.Name, '.'); dot != -1 {
		extension = strings.ToLower(file.Name[dot+1:])
	}

	descriptor := make(map[string]interface{}, 8)
	descriptor["id"] = file.ID
	descriptor["name"] = file.Name
	descriptor["size"] = file.Size
	descriptor["extension"] = extension
	descriptor["mime_type"] = mimeType
	descriptor["created_by"] = file.CreatedBy
	descriptor["created_at"] = float64(time.Now().UnixMilli()) / 1000.0
	descriptor["preview_url"] = nil
	return descriptor
}
// maxRemoteFileSize is the upload-info payload cap: 100 << 20 bytes (100 MiB).
// It bounds the body fetched for a ?url= upload (passed to
// fetchRemoteFileSafely) and also the multipart file data read by
// readUploadInfoData.
const maxRemoteFileSize = 100 << 20
// UploadFromURL fetches a remote URL and stores the normalized content as an
// upload-info blob in the "<tenantID>-downloads" bucket, returning the
// lightweight descriptor built by storeUploadInfoBlob (id, name, size,
// extension, mime_type, created_by, created_at, preview_url). It does not
// create a File record in the file-management tree.
//
// rawURL must be a non-empty http/https URL with a host. The download goes
// through fetchRemoteFileSafely, which validates every hop with assertURLSafe,
// follows redirects manually, and rejects bodies larger than
// maxRemoteFileSize instead of truncating them.
//
// The file name is derived from the final (post-redirect) URL; PDFs get a
// ".pdf" suffix and HTML is reduced to readable text before storage.
//
// An error is returned when the URL is invalid, storage is not initialized,
// the fetch fails, the per-user health check fails, or the blob cannot be
// stored.
func (s *FileService) UploadFromURL(tenantID, rawURL string) (map[string]interface{}, error) {
	if rawURL == "" {
		return nil, fmt.Errorf("url is required")
	}
	parsed, err := url.Parse(rawURL)
	if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") || parsed.Hostname() == "" {
		return nil, fmt.Errorf("invalid or unsafe URL")
	}

	// Resolve storage before touching the network: there is no point in
	// downloading up to maxRemoteFileSize bytes if they cannot be stored.
	storageImpl := storage.GetStorageFactory().GetStorage()
	if storageImpl == nil {
		return nil, fmt.Errorf("storage not initialized")
	}

	data, headers, finalURL, err := fetchRemoteFileSafely(rawURL, maxRemoteFileSize)
	if err != nil {
		return nil, err
	}

	contentType := headers.Get("Content-Type")
	filename := normalizeRemoteUploadFilename(finalURL, contentType, data)
	if err := s.checkUploadInfoHealth(tenantID, filename); err != nil {
		return nil, err
	}
	filename, contentType, data = normalizeUploadInfoContent(filename, contentType, data)
	return s.storeUploadInfoBlob(storageImpl, tenantID, filename, contentType, data)
}
// fetchRemoteFileSafely downloads rawURL and returns the body, a copy of the
// response headers, and the URL that finally served the content.
//
// Redirects are never followed by the HTTP client itself. Each hop (at most 10
// requests in total) is first checked with assertURLSafe, then fetched through
// a client obtained from pinnedHTTPClient for the returned host/IP pair.
// Relative Location headers are resolved against the current URL.
//
// Responses with status >= 400 are errors. The body is read with a limit of
// maxSize+1 bytes so that an oversized body is rejected rather than truncated.
func fetchRemoteFileSafely(rawURL string, maxSize int64) ([]byte, http.Header, string, error) {
	const maxHops = 10

	target := rawURL
	for hop := 0; hop < maxHops; hop++ {
		host, ip, err := assertURLSafe(target)
		if err != nil {
			return nil, nil, "", err
		}

		client := pinnedHTTPClient(host, ip, 10*time.Second)
		// Surface 3xx responses to this loop so every hop is re-validated.
		client.CheckRedirect = func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		}

		resp, err := client.Get(target) // #nosec G107
		if err != nil {
			return nil, nil, "", fmt.Errorf("failed to fetch URL: %w", err)
		}

		switch resp.StatusCode {
		case http.StatusMovedPermanently,
			http.StatusFound,
			http.StatusSeeOther,
			http.StatusTemporaryRedirect,
			http.StatusPermanentRedirect:
			next := resp.Header.Get("Location")
			resp.Body.Close()
			if next == "" {
				return nil, nil, "", fmt.Errorf("redirect response missing Location header")
			}
			base, err := url.Parse(target)
			if err != nil {
				return nil, nil, "", err
			}
			resolved, err := base.Parse(next)
			if err != nil {
				return nil, nil, "", err
			}
			target = resolved.String()
			continue
		}

		if resp.StatusCode >= 400 {
			resp.Body.Close()
			return nil, nil, "", fmt.Errorf("remote URL returned HTTP %d", resp.StatusCode)
		}

		body, readErr := io.ReadAll(io.LimitReader(resp.Body, maxSize+1))
		resp.Body.Close()
		switch {
		case readErr != nil:
			return nil, nil, "", fmt.Errorf("failed to read remote content: %w", readErr)
		case int64(len(body)) > maxSize:
			return nil, nil, "", fmt.Errorf("remote file exceeds the maximum allowed size of %d bytes", maxSize)
		}
		return body, resp.Header.Clone(), target, nil
	}
	return nil, nil, "", fmt.Errorf("stopped after too many redirects")
}
// nonGlobalIPNets lists special-purpose ranges that are not globally routable
// and are not already rejected by the net.IP predicates used in isPublicIP.
var nonGlobalIPNets = func() []*net.IPNet {
	cidrs := []string{
		"0.0.0.0/8",       // "this network"
		"100.64.0.0/10",   // carrier-grade NAT (RFC 6598)
		"192.0.0.0/24",    // IETF protocol assignments
		"192.0.2.0/24",    // TEST-NET-1 documentation range
		"198.18.0.0/15",   // network benchmarking
		"198.51.100.0/24", // TEST-NET-2 documentation range
		"203.0.113.0/24",  // TEST-NET-3 documentation range
		"240.0.0.0/4",     // reserved, includes limited broadcast
		"64:ff9b:1::/48",  // local-use IPv4/IPv6 translation
		"100::/64",        // discard-only prefix
		"2001:db8::/32",   // IPv6 documentation range
	}
	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, cidr := range cidrs {
		_, ipNet, err := net.ParseCIDR(cidr)
		if err != nil {
			// The list is a compile-time constant; a parse failure is a
			// programmer error.
			panic("invalid built-in CIDR " + cidr + ": " + err.Error())
		}
		nets = append(nets, ipNet)
	}
	return nets
}()

// isPublicIP reports whether ip is a globally routable address. It follows the
// intent of Python's ip.is_global check: loopback, private, link-local,
// multicast and unspecified addresses are rejected, as are the special-purpose
// ranges in nonGlobalIPNets (carrier-grade NAT, documentation, benchmarking,
// reserved/broadcast, ...).
//
// A nil, empty, or malformed-length ip is never public. IPv4-mapped IPv6
// addresses are evaluated as their IPv4 form, because both the stdlib
// predicates and IPNet.Contains normalize through To4.
func isPublicIP(ip net.IP) bool {
	// To16 is nil for anything that is not a 4- or 16-byte address, including
	// nil and empty slices.
	if ip.To16() == nil {
		return false
	}
	if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() ||
		ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
		ip.IsMulticast() || ip.IsInterfaceLocalMulticast() {
		return false
	}
	for _, reserved := range nonGlobalIPNets {
		if reserved.Contains(ip) {
			return false
		}
	}
	return true
}
// checkUploadInfoHealth validates an upload before any bytes are stored.
// Checks run in this order:
//  1. filename must be non-empty;
//  2. when MAX_FILE_NUM_PER_USER parses to a positive integer, the user's
//     document count (s.GetDocCount) must be below it;
//  3. filename must be at most 255 bytes long.
//
// An unset, unparsable, or non-positive MAX_FILE_NUM_PER_USER disables the
// quota check.
func (s *FileService) checkUploadInfoHealth(userID, filename string) error {
	if filename == "" {
		return fmt.Errorf("No file selected!")
	}

	if raw := os.Getenv("MAX_FILE_NUM_PER_USER"); raw != "" {
		var limit int64
		_, scanErr := fmt.Sscanf(raw, "%d", &limit)
		if scanErr == nil && limit > 0 {
			count, err := s.GetDocCount(userID)
			if err != nil {
				return fmt.Errorf("failed to get document count: %w", err)
			}
			if count >= limit {
				return fmt.Errorf("Exceed the maximum file number of a free user!")
			}
		}
	}

	// len on a string counts bytes, which is the unit the limit is expressed in.
	if len(filename) > 255 {
		return fmt.Errorf("Exceed the maximum length of file name!")
	}
	return nil
}
// storeUploadInfoBlob writes data to the "<userID>-downloads" bucket under a
// freshly generated UUID key and returns the upload descriptor for it. The
// descriptor's id is the storage key, extension is the lower-cased text after
// the last dot in filename (empty when there is none), created_at is the
// current time in fractional seconds, and preview_url is always nil.
func (s *FileService) storeUploadInfoBlob(storageImpl storage.Storage, userID, filename, contentType string, data []byte) (map[string]interface{}, error) {
	blobID := common.GenerateUUID()
	if err := storageImpl.Put(userID+"-downloads", blobID, data); err != nil {
		return nil, fmt.Errorf("failed to store file: %w", err)
	}

	var extension string
	if dot := strings.LastIndexByte(filename, '.'); dot != -1 {
		extension = strings.ToLower(filename[dot+1:])
	}

	descriptor := make(map[string]interface{}, 8)
	descriptor["id"] = blobID
	descriptor["name"] = filename
	descriptor["size"] = int64(len(data))
	descriptor["extension"] = extension
	descriptor["mime_type"] = contentType
	descriptor["created_by"] = userID
	descriptor["created_at"] = float64(time.Now().UnixMilli()) / 1000.0
	descriptor["preview_url"] = nil
	return descriptor, nil
}
// normalizeRemoteUploadFilename derives a storage file name for content
// fetched from rawURL. The last path segment of the URL is sanitized with
// sanitizeFilename; if the URL cannot be parsed the name is "download". When
// the content is a PDF (by declared media type or by its leading bytes) and
// the name does not already end in ".pdf" (case-insensitive), ".pdf" is
// appended.
func normalizeRemoteUploadFilename(rawURL, contentType string, data []byte) string {
	name := "download"
	if u, err := url.Parse(rawURL); err == nil {
		name = sanitizeFilename(filepath.Base(u.Path))
	}

	// Only the media type matters here; drop any ";charset=..." parameters.
	mediaType := strings.ToLower(strings.TrimSpace(strings.SplitN(contentType, ";", 2)[0]))
	isPDF := mediaType == "application/pdf" || bytesLooksLikePDF(data)
	if isPDF && !strings.HasSuffix(strings.ToLower(name), ".pdf") {
		name += ".pdf"
	}
	return name
}
// uploadInfoMediaType returns the lower-cased media type of a Content-Type
// value with any parameters (";charset=...") and surrounding space removed.
func uploadInfoMediaType(contentType string) string {
	return strings.ToLower(strings.TrimSpace(strings.SplitN(contentType, ";", 2)[0]))
}

// normalizeUploadInfoContent canonicalizes an upload before it is stored and
// returns the (possibly adjusted) file name, bare media type, and payload.
//
//   - The media type is contentType without parameters, lower-cased. When
//     contentType is empty it is sniffed from data, and the sniffed value is
//     stripped of parameters as well.
//   - PDF content (declared as application/pdf or starting with "%PDF") gets
//     a ".pdf" suffix if missing, is reported as application/pdf, and its
//     bytes are returned untouched.
//   - Otherwise, HTML content (declared as text/html or application/xhtml+xml,
//     or detected by looksLikeHTML) is reduced to readable text. The reported
//     media type is left unchanged.
func normalizeUploadInfoContent(filename, contentType string, data []byte) (string, string, []byte) {
	mediaType := uploadInfoMediaType(contentType)
	if mediaType == "" {
		// http.DetectContentType may return values such as
		// "text/html; charset=utf-8"; keep only the media type so the
		// comparisons below (and the value returned to callers) are uniform.
		mediaType = uploadInfoMediaType(http.DetectContentType(data))
	}

	if mediaType == "application/pdf" || bytesLooksLikePDF(data) {
		if !strings.HasSuffix(strings.ToLower(filename), ".pdf") {
			filename += ".pdf"
		}
		// Return immediately: PDF bytes are binary and may legitimately
		// contain sequences like "<div"; running them through the HTML
		// stripper would corrupt the document.
		return filename, "application/pdf", data
	}

	if mediaType == "text/html" || mediaType == "application/xhtml+xml" || looksLikeHTML(data) {
		data = htmlToReadableMarkdown(data)
	}
	return filename, mediaType, data
}

// bytesLooksLikePDF reports whether data starts with the "%PDF" signature.
func bytesLooksLikePDF(data []byte) bool {
	return len(data) >= 4 && string(data[:4]) == "%PDF"
}

// looksLikeHTML reports whether data contains an "<html", "<body" or "<div"
// marker anywhere, case-insensitively.
//
// NOTE(review): the whole payload is scanned, so any text file that merely
// mentions one of these markers is treated as HTML; confirm this is intended
// for non-URL uploads.
func looksLikeHTML(data []byte) bool {
	snippet := strings.ToLower(string(data))
	return strings.Contains(snippet, "<html") || strings.Contains(snippet, "<body") || strings.Contains(snippet, "<div")
}

var (
	// htmlScriptStyleRE matches whole <script>/<style> elements with content.
	htmlScriptStyleRE = regexp.MustCompile(`(?is)<(script|style)[^>]*>.*?</(script|style)>`)
	// htmlTagRE matches any remaining tag.
	htmlTagRE = regexp.MustCompile(`(?s)<[^>]+>`)
	// multiSpaceRE matches runs of spaces and tabs.
	multiSpaceRE = regexp.MustCompile(`[ \t]+`)
	// multiNewlineRE matches three or more consecutive newlines.
	multiNewlineRE = regexp.MustCompile(`\n{3,}`)
	// htmlBreakReplacer turns line-breaking tags into newlines in one pass.
	htmlBreakReplacer = strings.NewReplacer(
		"<br>", "\n",
		"<br/>", "\n",
		"<br />", "\n",
		"</p>", "\n\n",
		"</div>", "\n",
		"</li>", "\n",
	)
)

// htmlToReadableMarkdown reduces an HTML document to plain readable text.
// Despite its name it emits no Markdown syntax: script/style elements are
// dropped, <br>, </p>, </div> and </li> become line breaks, all other tags are
// replaced by a space, entities are unescaped, and whitespace is collapsed
// (runs of spaces/tabs to one space, three or more newlines to two).
func htmlToReadableMarkdown(data []byte) []byte {
	text := htmlScriptStyleRE.ReplaceAllString(string(data), " ")
	text = htmlBreakReplacer.Replace(text)
	text = htmlTagRE.ReplaceAllString(text, " ")
	// Unescape only after tags are gone so "&lt;b&gt;" stays literal text.
	text = html.UnescapeString(text)
	text = strings.ReplaceAll(text, "\r", "\n")
	text = multiSpaceRE.ReplaceAllString(text, " ")
	text = multiNewlineRE.ReplaceAllString(text, "\n\n")
	return []byte(strings.TrimSpace(text))
}
// reservedDeviceNames are Windows reserved filenames that must never be used.
var reservedDeviceNames = map[string]bool{
	"CON": true, "PRN": true, "AUX": true, "NUL": true,
	"COM1": true, "COM2": true, "COM3": true, "COM4": true, "COM5": true,
	"COM6": true, "COM7": true, "COM8": true, "COM9": true,
	"LPT1": true, "LPT2": true, "LPT3": true, "LPT4": true, "LPT5": true,
	"LPT6": true, "LPT7": true, "LPT8": true, "LPT9": true,
}

// sanitizeFilename produces a safe, filesystem-friendly filename from an
// arbitrary URL path segment. It:
//   - keeps only the last path element;
//   - replaces path separators, characters that are illegal on common
//     filesystems, and control characters (including NUL and DEL) with '_';
//   - strips leading/trailing dots and spaces;
//   - rejects Windows reserved device names (compared on the part before the
//     first dot, case-insensitively);
//   - limits the result to 255 bytes, cutting on a rune boundary so the result
//     is never left with a partial multi-byte character.
//
// Whenever nothing usable remains, "download" is returned.
func sanitizeFilename(name string) string {
	const maxBytes = 255

	name = filepath.Base(name)
	name = strings.TrimSpace(name)
	name = strings.Map(func(r rune) rune {
		switch r {
		case '/', '\\', ':', '*', '?', '"', '<', '>', '|':
			return '_'
		}
		if r < 0x20 || r == 0x7f { // control characters, NUL and DEL included
			return '_'
		}
		return r
	}, name)

	// Strip leading/trailing dots and spaces to avoid hidden or reserved forms.
	// This also reduces "." and ".." to the empty string.
	name = strings.Trim(name, ". ")
	if name == "" {
		return "download"
	}
	if stem := strings.SplitN(strings.ToUpper(name), ".", 2)[0]; reservedDeviceNames[stem] {
		return "download"
	}

	if len(name) > maxBytes {
		// Find the last rune start at or before the byte limit; slicing there
		// cannot split a multi-byte character.
		cut := 0
		for i := range name {
			if i > maxBytes {
				break
			}
			cut = i
		}
		// Truncation may expose a trailing dot or space again. The first
		// character is neither (see Trim above), so this cannot empty the name.
		name = strings.TrimRight(name[:cut], ". ")
	}
	return name
}

View File

@@ -3,6 +3,10 @@ package service
import (
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@@ -15,6 +19,7 @@ type fakeStorage struct {
lastFnm string
blob []byte
err error
exists bool
}
func (f *fakeStorage) Health() bool {
@@ -22,7 +27,11 @@ func (f *fakeStorage) Health() bool {
}
// Put records the bucket, key and payload of the most recent write so tests
// can assert on them, marks the object as existing, and returns the
// preconfigured f.err (nil by default). The variadic tenantID is ignored.
func (f *fakeStorage) Put(bucket, fnm string, binary []byte, tenantID ...string) error {
	f.lastBucket = bucket
	f.lastFnm = fnm
	f.blob = binary
	f.exists = true
	return f.err
}
func (f *fakeStorage) Get(bucket, fnm string, tenantID ...string) ([]byte, error) {
@@ -36,7 +45,7 @@ func (f *fakeStorage) Remove(bucket, fnm string, tenantID ...string) error {
}
// ObjExist reports whether bucket/fnm is the object recorded by the most
// recent Put. The variadic tenantID is ignored.
func (f *fakeStorage) ObjExist(bucket, fnm string, tenantID ...string) bool {
	return f.exists && f.lastBucket == bucket && f.lastFnm == fnm
}
func (f *fakeStorage) GetPresignedURL(bucket, fnm string, expires time.Duration, tenantID ...string) (string, error) {
@@ -124,3 +133,130 @@ func TestFileService_DownloadAgentFile_Error(t *testing.T) {
t.Errorf("expected nil blob, got %v", blob)
}
}
// TestFileService_UploadFromURL_PDFAddsExtensionAndStoresToDownloads serves a
// fake PDF from a local test server and checks that UploadFromURL stores it in
// the tenant's downloads bucket, appends ".pdf" to the name, reports the PDF
// MIME type, and returns the storage key as the descriptor id.
func TestFileService_UploadFromURL_PDFAddsExtensionAndStoresToDownloads(t *testing.T) {
	servePDF := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/pdf")
		_, _ = w.Write([]byte("%PDF-1.7 fake pdf"))
	}
	srv := httptest.NewServer(http.HandlerFunc(servePDF))
	defer srv.Close()

	// Bypass the SSRF guard and IP pinning: the test server lives on loopback.
	prevAssert, prevPinned := assertURLSafe, pinnedHTTPClient
	t.Cleanup(func() {
		assertURLSafe, pinnedHTTPClient = prevAssert, prevPinned
	})
	assertURLSafe = func(string) (string, string, error) {
		return "127.0.0.1", "127.0.0.1", nil
	}
	pinnedHTTPClient = func(string, string, time.Duration) *http.Client {
		return srv.Client()
	}

	store := &fakeStorage{}
	factory := storage.GetStorageFactory()
	prevStorage := factory.GetStorage()
	factory.SetStorage(store)
	t.Cleanup(func() { factory.SetStorage(prevStorage) })

	resp, err := NewFileService().UploadFromURL("tenant123", srv.URL+"/report")
	if err != nil {
		t.Fatalf("UploadFromURL failed: %v", err)
	}

	switch {
	case store.lastBucket != "tenant123-downloads":
		t.Fatalf("bucket = %q", store.lastBucket)
	case resp["name"] != "report.pdf":
		t.Fatalf("name = %#v, want report.pdf", resp["name"])
	case resp["mime_type"] != "application/pdf":
		t.Fatalf("mime_type = %#v", resp["mime_type"])
	case resp["id"] != store.lastFnm:
		t.Fatalf("id = %#v, stored key = %q", resp["id"], store.lastFnm)
	}
}
// TestFileService_UploadFromURL_HTMLNormalizesReadableContent serves an HTML
// page from a local test server and checks that the stored blob has had its
// markup and scripts removed while the visible text is kept, and that the
// descriptor still reports text/html.
func TestFileService_UploadFromURL_HTMLNormalizesReadableContent(t *testing.T) {
	servePage := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		_, _ = w.Write([]byte(`<html><head><script>bad()</script></head><body><div>Hello</div><p>World</p></body></html>`))
	}
	srv := httptest.NewServer(http.HandlerFunc(servePage))
	defer srv.Close()

	// Bypass the SSRF guard and IP pinning: the test server lives on loopback.
	prevAssert, prevPinned := assertURLSafe, pinnedHTTPClient
	t.Cleanup(func() {
		assertURLSafe, pinnedHTTPClient = prevAssert, prevPinned
	})
	assertURLSafe = func(string) (string, string, error) {
		return "127.0.0.1", "127.0.0.1", nil
	}
	pinnedHTTPClient = func(string, string, time.Duration) *http.Client {
		return srv.Client()
	}

	store := &fakeStorage{}
	factory := storage.GetStorageFactory()
	prevStorage := factory.GetStorage()
	factory.SetStorage(store)
	t.Cleanup(func() { factory.SetStorage(prevStorage) })

	resp, err := NewFileService().UploadFromURL("tenant123", srv.URL+"/page")
	if err != nil {
		t.Fatalf("UploadFromURL failed: %v", err)
	}

	stored := string(store.blob)
	lowered := strings.ToLower(stored)
	hasMarkup := strings.Contains(lowered, "<html") || strings.Contains(lowered, "<script")
	if hasMarkup {
		t.Fatalf("stored html was not normalized: %q", stored)
	}
	hasText := strings.Contains(stored, "Hello") && strings.Contains(stored, "World")
	if !hasText {
		t.Fatalf("stored normalized text missing content: %q", stored)
	}
	if resp["mime_type"] != "text/html" {
		t.Fatalf("mime_type = %#v", resp["mime_type"])
	}
}
// TestNormalizeUploadInfoContent_PDFTakesPrecedenceOverHTML passes PDF bytes
// mislabelled as text/html and expects PDF handling to win: ".pdf" appended,
// MIME type corrected, and the bytes left untouched.
func TestNormalizeUploadInfoContent_PDFTakesPrecedenceOverHTML(t *testing.T) {
	payload := []byte("%PDF-1.7 fake pdf")

	gotName, gotType, gotData := normalizeUploadInfoContent("report", "text/html", payload)

	if gotName != "report.pdf" {
		t.Fatalf("filename = %q, want report.pdf", gotName)
	}
	if gotType != "application/pdf" {
		t.Fatalf("contentType = %q, want application/pdf", gotType)
	}
	// Compare against a fresh literal so in-place mutation is caught as well.
	if !bytes.Equal(gotData, []byte("%PDF-1.7 fake pdf")) {
		t.Fatalf("pdf bytes were unexpectedly transformed: %q", string(gotData))
	}
}
// TestReadUploadInfoData_RejectsOversizedInput streams exactly one byte more
// than maxRemoteFileSize and expects readUploadInfoData to fail with its size
// limit error.
func TestReadUploadInfoData_RejectsOversizedInput(t *testing.T) {
	oversized := io.LimitReader(zeroReader{}, maxRemoteFileSize+1)

	if _, err := readUploadInfoData(oversized); err == nil {
		t.Fatal("expected oversized input to be rejected")
	} else if !strings.Contains(err.Error(), "file size exceeds") {
		t.Fatalf("err = %v, want size limit message", err)
	}
}
// zeroReader is an endless io.Reader that yields zero bytes.
type zeroReader struct{}

// Read fills all of p with zeros and reports len(p) bytes read; it never
// returns an error or io.EOF.
func (zeroReader) Read(p []byte) (int, error) {
	n := len(p)
	for i := 0; i < n; i++ {
		p[i] = 0
	}
	return n, nil
}

View File

@@ -422,10 +422,16 @@ func mergeFieldValues(existing, new interface{}) []interface{} {
if val != "" {
result = append(result, val)
}
case float64, float32, int, int8, int16, int32, int64, bool:
result = append(result, val)
case []interface{}:
for _, item := range val {
addValue(item)
}
case []string:
for _, item := range val {
addValue(item)
}
}
}