mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat(go-api): Align document metadata batch APIs and upload_info with Python (#16269)
## Summary
Align the Go implementations of these APIs with the Python behavior:
- `POST /api/v1/datasets/:dataset_id/metadata/update`
- `PATCH /api/v1/datasets/:dataset_id/documents/metadatas`
- `POST /api/v1/documents/upload`
## What changed
- Added the Go routes and handlers for the 3 APIs.
- Aligned batch document metadata updates with Python semantics:
  - support `match` in update items
  - support list append / replace behavior
  - support deleting specific list values
  - remove metadata entirely when it becomes empty
  - create metadata for documents that previously had none when updates apply
  - count `updated` only when a document actually changes
- Aligned `documents/upload` file uploads with Python-style `upload_info` behavior:
  - store upload-info blobs in the per-user downloads bucket
  - return lightweight upload descriptors instead of normal file-management responses
- Improved URL upload behavior:
  - SSRF-guarded fetch with redirect validation
  - redirect limit aligned to Python behavior
  - normalize filename and MIME type
  - add `.pdf` when the fetched content is PDF
  - normalize HTML content into readable text instead of storing raw HTML shells
## Validation
### Unit tests
Passed:
- `go test ./internal/service`
- `go test ./internal/handler`
Also verified targeted cases for:
- batch metadata update semantics
- upload_info URL handling
- upload_info download bucket behavior
### curl checks
Verified the new Go endpoints with `curl` and compared the response
shape and behavior with Python for:
- `POST /api/v1/datasets/{dataset_id}/metadata/update`
- `PATCH /api/v1/datasets/{dataset_id}/documents/metadatas`
- `POST /api/v1/documents/upload`
The Go responses were checked against Python for:
- argument validation
- success response shape
- metadata update results
- upload_info result structure
- file vs URL input handling
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"ragflow/internal/common"
|
||||
@@ -60,6 +61,9 @@ type documentServiceIface interface {
|
||||
GetDocumentPreview(docID string) (*service.DocumentPreview, error)
|
||||
DownloadDocument(datasetID, docID string) (*service.DownloadDocumentResp, error)
|
||||
UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error)
|
||||
BatchUpdateDocumentMetadatas(datasetID string, selector *service.DocumentMetadataSelector, updates []service.DocumentMetadataUpdate, deletes []service.DocumentMetadataDelete) (*service.BatchUpdateDocumentMetadatasResponse, common.ErrorCode, error)
|
||||
UploadDocumentInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, common.ErrorCode, error)
|
||||
UploadDocumentInfoByURL(userID, rawURL string) (map[string]interface{}, common.ErrorCode, error)
|
||||
ListIngestionTasks(userID string, datasetID *string, page, pageSize int) ([]*entity.IngestionTask, error)
|
||||
IngestDocuments(datasetID, userID string, docIDs []string) ([]*service.ParseDocumentResponse, error)
|
||||
StopIngestionTasks(tasks []string, userID string) ([]*entity.IngestionTask, error)
|
||||
@@ -1296,3 +1300,123 @@ func (h *DocumentHandler) UpdateDatasetDocument(c *gin.Context) {
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) UploadInfo(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
form, err := c.MultipartForm()
|
||||
if err != nil && !strings.Contains(err.Error(), "request Content-Type isn't multipart/form-data") {
|
||||
jsonError(c, common.CodeArgumentError, "Failed to parse multipart form: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var fileHeaders []*multipart.FileHeader
|
||||
if form != nil && form.File != nil {
|
||||
fileHeaders = form.File["file"]
|
||||
}
|
||||
rawURL := strings.TrimSpace(c.Query("url"))
|
||||
|
||||
if len(fileHeaders) > 0 && rawURL != "" {
|
||||
jsonError(c, common.CodeArgumentError, "Provide either multipart file(s) or ?url=..., not both.")
|
||||
return
|
||||
}
|
||||
if len(fileHeaders) == 0 && rawURL == "" {
|
||||
jsonError(c, common.CodeArgumentError, "Missing input: provide multipart file(s) or url")
|
||||
return
|
||||
}
|
||||
|
||||
if rawURL != "" {
|
||||
data, code, err := h.documentService.UploadDocumentInfoByURL(user.ID, rawURL)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": data,
|
||||
"message": "success",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
data, code, err := h.documentService.UploadDocumentInfos(user.ID, fileHeaders)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var payload interface{}
|
||||
if len(data) == 1 {
|
||||
payload = data[0]
|
||||
} else {
|
||||
payload = data
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": payload,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
type documentMetadataBatchRequest struct {
|
||||
Selector *service.DocumentMetadataSelector `json:"selector"`
|
||||
Updates []service.DocumentMetadataUpdate `json:"updates"`
|
||||
Deletes []service.DocumentMetadataDelete `json:"deletes"`
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) MetadataBatchUpdate(c *gin.Context) {
|
||||
h.handleBatchUpdateDocumentMetadatas(c)
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) UpdateDocumentMetadatas(c *gin.Context) {
|
||||
h.handleBatchUpdateDocumentMetadatas(c)
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) handleBatchUpdateDocumentMetadatas(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
datasetID := strings.TrimSpace(c.Param("dataset_id"))
|
||||
if datasetID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "dataset_id is required")
|
||||
return
|
||||
}
|
||||
if !h.datasetService.Accessible(datasetID, user.ID) {
|
||||
jsonError(c, common.CodeDataError, "You don't own the dataset "+datasetID+".")
|
||||
return
|
||||
}
|
||||
|
||||
var req documentMetadataBatchRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
if req.Selector == nil {
|
||||
req.Selector = &service.DocumentMetadataSelector{}
|
||||
}
|
||||
if req.Updates == nil {
|
||||
req.Updates = []service.DocumentMetadataUpdate{}
|
||||
}
|
||||
if req.Deletes == nil {
|
||||
req.Deletes = []service.DocumentMetadataDelete{}
|
||||
}
|
||||
|
||||
resp, code, err := h.documentService.BatchUpdateDocumentMetadatas(datasetID, req.Selector, req.Updates, req.Deletes)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": resp,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ package handler
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -49,6 +50,15 @@ type fakeDocumentService struct {
|
||||
func (f *fakeDocumentService) UpdateDatasetDocument(userID, datasetID, documentID string, req *service.UpdateDatasetDocumentRequest, present map[string]bool) (*service.UpdateDatasetDocumentResponse, common.ErrorCode, error) {
|
||||
return nil, common.CodeSuccess, nil
|
||||
}
|
||||
func (f *fakeDocumentService) BatchUpdateDocumentMetadatas(datasetID string, selector *service.DocumentMetadataSelector, updates []service.DocumentMetadataUpdate, deletes []service.DocumentMetadataDelete) (*service.BatchUpdateDocumentMetadatasResponse, common.ErrorCode, error) {
|
||||
return nil, common.CodeSuccess, nil
|
||||
}
|
||||
func (f *fakeDocumentService) UploadDocumentInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, common.ErrorCode, error) {
|
||||
return nil, common.CodeSuccess, nil
|
||||
}
|
||||
func (f *fakeDocumentService) UploadDocumentInfoByURL(userID, rawURL string) (map[string]interface{}, common.ErrorCode, error) {
|
||||
return nil, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentService) GetDocumentArtifact(filename string) (*service.ArtifactResponse, error) {
|
||||
if filename == "error.txt" {
|
||||
|
||||
@@ -236,6 +236,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
documents := v1.Group("/documents")
|
||||
{
|
||||
documents.POST("", r.documentHandler.CreateDocument)
|
||||
documents.POST("/upload", r.documentHandler.UploadInfo)
|
||||
documents.GET("", r.documentHandler.ListDocuments)
|
||||
documents.GET("/artifact/:filename", r.documentHandler.GetDocumentArtifact)
|
||||
documents.GET("/:id/preview", r.documentHandler.GetDocumentPreview)
|
||||
@@ -315,6 +316,8 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
//datasets.POST("/:dataset_id/documents/stop", r.documentHandler.StopParseDocuments)
|
||||
datasets.DELETE("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.RemoveChunks)
|
||||
datasets.PUT("/:dataset_id/documents/:document_id/metadata/config", r.datasetsHandler.UpdateDocumentMetadataConfig)
|
||||
datasets.POST("/:dataset_id/metadata/update", r.documentHandler.MetadataBatchUpdate)
|
||||
datasets.PATCH("/:dataset_id/documents/metadatas", r.documentHandler.UpdateDocumentMetadatas)
|
||||
}
|
||||
|
||||
// Search routes
|
||||
|
||||
@@ -22,8 +22,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -41,6 +43,7 @@ import (
|
||||
"ragflow/internal/tokenizer"
|
||||
"ragflow/internal/utility"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -1571,15 +1574,10 @@ func aggregateMetadata(chunks []map[string]interface{}) map[string]interface{} {
|
||||
typeCounter[k][valueType] = typeCounter[k][valueType] + 1
|
||||
}
|
||||
|
||||
// Aggregate value counts
|
||||
values := v
|
||||
if v, ok := v.([]interface{}); ok {
|
||||
values = v
|
||||
} else {
|
||||
values = []interface{}{v}
|
||||
}
|
||||
|
||||
for _, vv := range values.([]interface{}) {
|
||||
// Aggregate value counts. Flatten nested arrays so malformed values do
|
||||
// not surface in the UI as the literal string "[]".
|
||||
values := flattenMetadataSummaryValues(v)
|
||||
for _, vv := range values {
|
||||
if vv == nil {
|
||||
continue
|
||||
}
|
||||
@@ -1676,6 +1674,27 @@ func getMetaValueType(value interface{}) string {
|
||||
return "string"
|
||||
}
|
||||
|
||||
// flattenMetadataSummaryValues turns a metadata value into a flat list.
//
// A nil value yields nil. A []string is converted element by element. A
// []interface{} is flattened recursively, so nested lists contribute their
// leaves and nil elements contribute nothing. Any other value is returned as
// a one-element list.
func flattenMetadataSummaryValues(value interface{}) []interface{} {
	if value == nil {
		return nil
	}

	if strs, ok := value.([]string); ok {
		flat := make([]interface{}, len(strs))
		for i := range strs {
			flat[i] = strs[i]
		}
		return flat
	}

	nested, ok := value.([]interface{})
	if !ok {
		// Scalar (or any unrecognised type): wrap it as-is.
		return []interface{}{value}
	}

	flat := make([]interface{}, 0, len(nested))
	for i := range nested {
		flat = append(flat, flattenMetadataSummaryValues(nested[i])...)
	}
	return flat
}
|
||||
|
||||
// isTimeString checks if a string is an ISO 8601 datetime
|
||||
func isTimeString(s string) bool {
|
||||
matched, _ := regexp.MatchString(`^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$`, s)
|
||||
@@ -2176,3 +2195,471 @@ func mapDocumentRunStatus(run *string) string {
|
||||
return "UNSTART"
|
||||
}
|
||||
}
|
||||
|
||||
// DocumentMetadataUpdate is one update item of a batch metadata request: it
// sets Key to Value.
type DocumentMetadataUpdate struct {
	// Key is the metadata field to write ("key" in JSON).
	Key string `json:"key"`
	// Value is the new value ("value" in JSON).
	Value interface{} `json:"value"`
	// Match, when present and non-empty, limits the update to existing
	// values that compare equal to it ("match" in JSON).
	Match interface{} `json:"match,omitempty"`
	// ValueType optionally names how Value should be normalised before it is
	// applied ("valueType" in JSON).
	ValueType string `json:"valueType,omitempty"`
}
|
||||
|
||||
// DocumentMetadataDelete removes a whole key, or a specific value from a
// list field.
type DocumentMetadataDelete struct {
	// Key is the metadata field to delete from ("key" in JSON).
	Key string `json:"key"`
	// Value, when set, restricts the delete to that value; when nil the
	// whole key is removed ("value" in JSON).
	Value interface{} `json:"value,omitempty"`
}
|
||||
|
||||
// DocumentMetadataSelector selects which documents a batch metadata request
// targets.
type DocumentMetadataSelector struct {
	// DocumentIDs is an explicit list of document IDs ("document_ids" in
	// JSON).
	DocumentIDs []string `json:"document_ids"`
	// MetadataCondition is a filter over existing metadata
	// ("metadata_condition" in JSON).
	MetadataCondition map[string]interface{} `json:"metadata_condition"`
}
|
||||
|
||||
// BatchUpdateDocumentMetadatasResponse is the result payload of a batch
// metadata update.
type BatchUpdateDocumentMetadatasResponse struct {
	// Updated is serialised as "updated".
	Updated int `json:"updated"`
	// MatchedDocs is serialised as "matched_docs".
	MatchedDocs int `json:"matched_docs"`
}
|
||||
|
||||
// BatchUpdateDocumentMetadatas implements the shared logic for
|
||||
// PATCH /datasets/:dataset_id/documents/metadatas and
|
||||
// POST /datasets/:dataset_id/metadata/update.
|
||||
func (s *DocumentService) BatchUpdateDocumentMetadatas(
|
||||
datasetID string,
|
||||
selector *DocumentMetadataSelector,
|
||||
updates []DocumentMetadataUpdate,
|
||||
deletes []DocumentMetadataDelete,
|
||||
) (*BatchUpdateDocumentMetadatasResponse, common.ErrorCode, error) {
|
||||
if selector == nil {
|
||||
selector = &DocumentMetadataSelector{}
|
||||
}
|
||||
if code, err := validateBatchUpdateDocumentMetadatasRequest(selector, updates, deletes); err != nil {
|
||||
return nil, code, err
|
||||
}
|
||||
|
||||
// Resolve which document IDs to target.
|
||||
targetDocIDs := make(map[string]struct{})
|
||||
|
||||
if len(selector.DocumentIDs) > 0 {
|
||||
// Validate that supplied IDs actually belong to this dataset.
|
||||
allRows, err := s.documentDAO.GetAllDocIDsByKBIDs([]string{datasetID})
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to list dataset documents: %w", err)
|
||||
}
|
||||
kbDocIDSet := make(map[string]struct{}, len(allRows))
|
||||
for _, row := range allRows {
|
||||
kbDocIDSet[row["id"]] = struct{}{}
|
||||
}
|
||||
var invalidIDs []string
|
||||
for _, id := range selector.DocumentIDs {
|
||||
if _, ok := kbDocIDSet[id]; !ok {
|
||||
invalidIDs = append(invalidIDs, id)
|
||||
}
|
||||
}
|
||||
if len(invalidIDs) > 0 {
|
||||
return nil, common.CodeDataError, fmt.Errorf("these documents do not belong to dataset %s: %s",
|
||||
datasetID, strings.Join(invalidIDs, ", "))
|
||||
}
|
||||
for _, id := range selector.DocumentIDs {
|
||||
targetDocIDs[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply metadata_condition filter.
|
||||
if len(selector.MetadataCondition) > 0 {
|
||||
flattedMeta, err := s.metadataSvc.GetFlattedMetaByKBs([]string{datasetID})
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to get flattened metadata: %w", err)
|
||||
}
|
||||
|
||||
// ParseAndConvert mirrors Python convert_conditions: conditions arrive as
|
||||
// {name, comparison_operator, value}, the operator is normalised, and the
|
||||
// (possibly non-string) value is preserved. MetaFilter then matches against
|
||||
// the common.MetaData returned by GetFlattedMetaByKBs.
|
||||
filterInput := common.ParseAndConvert(selector.MetadataCondition)
|
||||
filteredIDs := common.MetaFilter(flattedMeta, filterInput)
|
||||
|
||||
filteredSet := make(map[string]struct{}, len(filteredIDs))
|
||||
for _, id := range filteredIDs {
|
||||
filteredSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
if len(targetDocIDs) > 0 {
|
||||
// Intersect with the document_ids restriction.
|
||||
for id := range targetDocIDs {
|
||||
if _, ok := filteredSet[id]; !ok {
|
||||
delete(targetDocIDs, id)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
targetDocIDs = filteredSet
|
||||
}
|
||||
|
||||
// Early-exit when conditions given but nothing matched.
|
||||
rawConds, _ := selector.MetadataCondition["conditions"]
|
||||
if rawConds != nil && len(targetDocIDs) == 0 {
|
||||
return &BatchUpdateDocumentMetadatasResponse{Updated: 0, MatchedDocs: 0}, common.CodeSuccess, nil
|
||||
}
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(targetDocIDs))
|
||||
for id := range targetDocIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Apply updates and deletes per document using Python's batch_update_metadata
|
||||
// semantics instead of a simple merge-then-delete.
|
||||
updated := 0
|
||||
for _, docID := range ids {
|
||||
currentMeta, err := s.GetDocumentMetadataByID(docID)
|
||||
if err != nil {
|
||||
common.Warn("BatchUpdateDocumentMetadata: get metadata failed",
|
||||
zap.String("docID", docID), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
meta := cloneDocumentMetadata(currentMeta)
|
||||
originalMeta := cloneDocumentMetadata(meta)
|
||||
|
||||
changed := applyDocumentMetadataUpdates(meta, updates)
|
||||
if applyDocumentMetadataDeletes(meta, deletes) {
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !changed || reflect.DeepEqual(originalMeta, meta) {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(meta) == 0 {
|
||||
if err := s.DeleteDocumentAllMetadata(docID); err != nil {
|
||||
common.Warn("BatchUpdateDocumentMetadata: delete all metadata failed",
|
||||
zap.String("docID", docID), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
updated++
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.replaceDocumentMetadata(docID, meta); err != nil {
|
||||
common.Warn("BatchUpdateDocumentMetadata: replace metadata failed",
|
||||
zap.String("docID", docID), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
updated++
|
||||
}
|
||||
|
||||
return &BatchUpdateDocumentMetadatasResponse{Updated: updated, MatchedDocs: len(ids)}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) UploadDocumentInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, common.ErrorCode, error) {
|
||||
fileSvc := &FileService{
|
||||
fileDAO: s.fileDAO,
|
||||
file2DocumentDAO: s.file2DocumentDAO,
|
||||
}
|
||||
data, err := fileSvc.UploadInfos(userID, files)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
return data, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) UploadDocumentInfoByURL(userID, rawURL string) (map[string]interface{}, common.ErrorCode, error) {
|
||||
fileSvc := &FileService{
|
||||
fileDAO: s.fileDAO,
|
||||
file2DocumentDAO: s.file2DocumentDAO,
|
||||
}
|
||||
data, err := fileSvc.UploadFromURL(userID, rawURL)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
return data, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func validateBatchUpdateDocumentMetadatasRequest(
|
||||
selector *DocumentMetadataSelector,
|
||||
updates []DocumentMetadataUpdate,
|
||||
deletes []DocumentMetadataDelete,
|
||||
) (common.ErrorCode, error) {
|
||||
for _, upd := range updates {
|
||||
if strings.TrimSpace(upd.Key) == "" || upd.Value == nil {
|
||||
return common.CodeDataError, errors.New("Each update requires key and value.")
|
||||
}
|
||||
}
|
||||
for _, del := range deletes {
|
||||
if strings.TrimSpace(del.Key) == "" {
|
||||
return common.CodeDataError, errors.New("Each delete requires key.")
|
||||
}
|
||||
}
|
||||
if selector != nil && selector.MetadataCondition != nil {
|
||||
if _, ok := selector.MetadataCondition["conditions"]; !ok && len(selector.MetadataCondition) > 0 {
|
||||
return common.CodeDataError, errors.New("metadata_condition must be an object.")
|
||||
}
|
||||
}
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func cloneDocumentMetadata(meta map[string]interface{}) map[string]interface{} {
|
||||
if meta == nil {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
cloned := make(map[string]interface{}, len(meta))
|
||||
for k, v := range meta {
|
||||
cloned[k] = cloneDocumentMetadataValue(v)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
// cloneDocumentMetadataValue copies list-shaped metadata values so the
// result does not share a backing array with the input.
//
// A []interface{} is copied one level deep (its elements are not cloned), a
// []string is converted to a fresh []interface{}, and every other value is
// returned unchanged.
func cloneDocumentMetadataValue(v interface{}) interface{} {
	if items, ok := v.([]interface{}); ok {
		return append(make([]interface{}, 0, len(items)), items...)
	}
	if strs, ok := v.([]string); ok {
		dup := make([]interface{}, len(strs))
		for i, s := range strs {
			dup[i] = s
		}
		return dup
	}
	return v
}
|
||||
|
||||
func applyDocumentMetadataUpdates(meta map[string]interface{}, updates []DocumentMetadataUpdate) bool {
|
||||
changed := false
|
||||
for _, upd := range updates {
|
||||
key := strings.TrimSpace(upd.Key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
normalizedValue := normalizeDocumentMetadataUpdateValue(upd.Value, upd.ValueType)
|
||||
matchProvided := upd.Match != nil && !(fmt.Sprintf("%v", upd.Match) == "")
|
||||
current, exists := meta[key]
|
||||
if !exists {
|
||||
if matchProvided {
|
||||
continue
|
||||
}
|
||||
if listVal, ok := toMetadataInterfaceSlice(normalizedValue); ok {
|
||||
meta[key] = dedupeDocumentMetadataList(listVal)
|
||||
} else {
|
||||
meta[key] = normalizedValue
|
||||
}
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
|
||||
if curList, ok := toMetadataInterfaceSlice(current); ok {
|
||||
if !matchProvided {
|
||||
newList := append([]interface{}{}, curList...)
|
||||
if appendList, ok := toMetadataInterfaceSlice(normalizedValue); ok {
|
||||
newList = append(newList, appendList...)
|
||||
} else {
|
||||
newList = append(newList, normalizedValue)
|
||||
}
|
||||
newList = dedupeDocumentMetadataList(newList)
|
||||
if !reflect.DeepEqual(curList, newList) {
|
||||
meta[key] = newList
|
||||
changed = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
replaced := false
|
||||
newList := make([]interface{}, 0, len(curList))
|
||||
for _, item := range curList {
|
||||
if documentMetadataValuesEqual(item, upd.Match) {
|
||||
if replacementList, ok := toMetadataInterfaceSlice(normalizedValue); ok {
|
||||
newList = append(newList, replacementList...)
|
||||
} else {
|
||||
newList = append(newList, normalizedValue)
|
||||
}
|
||||
replaced = true
|
||||
} else {
|
||||
newList = append(newList, item)
|
||||
}
|
||||
}
|
||||
newList = dedupeDocumentMetadataList(newList)
|
||||
if replaced && !reflect.DeepEqual(curList, newList) {
|
||||
meta[key] = newList
|
||||
changed = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !matchProvided {
|
||||
if !reflect.DeepEqual(current, normalizedValue) {
|
||||
meta[key] = normalizedValue
|
||||
changed = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if documentMetadataValuesEqual(current, upd.Match) && !reflect.DeepEqual(current, normalizedValue) {
|
||||
meta[key] = normalizedValue
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
func applyDocumentMetadataDeletes(meta map[string]interface{}, deletes []DocumentMetadataDelete) bool {
|
||||
changed := false
|
||||
for _, del := range deletes {
|
||||
key := strings.TrimSpace(del.Key)
|
||||
current, exists := meta[key]
|
||||
if key == "" || !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
if curList, ok := toMetadataInterfaceSlice(current); ok {
|
||||
if del.Value == nil {
|
||||
delete(meta, key)
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
newList := make([]interface{}, 0, len(curList))
|
||||
for _, item := range curList {
|
||||
if !documentMetadataValuesEqual(item, del.Value) {
|
||||
newList = append(newList, item)
|
||||
}
|
||||
}
|
||||
if len(newList) != len(curList) {
|
||||
if len(newList) == 0 {
|
||||
delete(meta, key)
|
||||
} else {
|
||||
meta[key] = newList
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if del.Value == nil || documentMetadataValuesEqual(current, del.Value) {
|
||||
delete(meta, key)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// toMetadataInterfaceSlice reports whether v is list-shaped and, if so,
// returns its elements as a freshly allocated []interface{}.
//
// Both []interface{} and []string are recognised. Any other value yields
// (nil, false).
func toMetadataInterfaceSlice(v interface{}) ([]interface{}, bool) {
	if items, ok := v.([]interface{}); ok {
		return append(make([]interface{}, 0, len(items)), items...), true
	}
	if strs, ok := v.([]string); ok {
		out := make([]interface{}, len(strs))
		for i, s := range strs {
			out[i] = s
		}
		return out, true
	}
	return nil, false
}
|
||||
|
||||
// dedupeDocumentMetadataList returns the items with duplicates removed,
// keeping the first occurrence of each and preserving order.
//
// Two items are duplicates when they have the same dynamic type and the same
// default formatting, so the integer 1 and the string "1" are kept as
// distinct entries. The result is always a new, non-nil slice.
func dedupeDocumentMetadataList(items []interface{}) []interface{} {
	unique := make([]interface{}, 0, len(items))
	visited := make(map[string]bool, len(items))
	for i := range items {
		signature := fmt.Sprintf("%T:%v", items[i], items[i])
		if visited[signature] {
			continue
		}
		visited[signature] = true
		unique = append(unique, items[i])
	}
	return unique
}
|
||||
|
||||
// documentMetadataValuesEqual compares two metadata values by their default
// string formatting, so values of different types that print the same (for
// example the integer 1 and the string "1") are considered equal.
func documentMetadataValuesEqual(a, b interface{}) bool {
	left := fmt.Sprintf("%v", a)
	right := fmt.Sprintf("%v", b)
	return left == right
}
|
||||
|
||||
func normalizeDocumentMetadataUpdateValue(value interface{}, valueType string) interface{} {
|
||||
switch strings.ToLower(strings.TrimSpace(valueType)) {
|
||||
case "list":
|
||||
if list, ok := normalizeMetadataListValue(value); ok {
|
||||
return list
|
||||
}
|
||||
return []interface{}{}
|
||||
case "number":
|
||||
scalar, ok := firstScalarMetadataValue(value)
|
||||
if !ok {
|
||||
return value
|
||||
}
|
||||
switch typed := scalar.(type) {
|
||||
case float64, float32, int, int8, int16, int32, int64:
|
||||
return typed
|
||||
case json.Number:
|
||||
if i, err := typed.Int64(); err == nil {
|
||||
return i
|
||||
}
|
||||
if f, err := typed.Float64(); err == nil {
|
||||
return f
|
||||
}
|
||||
case string:
|
||||
trimmed := strings.TrimSpace(typed)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if i, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
|
||||
return i
|
||||
}
|
||||
if f, err := strconv.ParseFloat(trimmed, 64); err == nil {
|
||||
return f
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
return scalar
|
||||
case "string", "time":
|
||||
if scalar, ok := firstScalarMetadataValue(value); ok {
|
||||
return fmt.Sprintf("%v", scalar)
|
||||
}
|
||||
return ""
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeMetadataListValue reports whether value is list-shaped and, if
// so, returns it as a flat []interface{}.
//
// A []string is converted element by element. A []interface{} is flattened
// recursively and its nil elements are dropped. Any other value yields
// (nil, false).
func normalizeMetadataListValue(value interface{}) ([]interface{}, bool) {
	if strs, ok := value.([]string); ok {
		flat := make([]interface{}, len(strs))
		for i, s := range strs {
			flat[i] = s
		}
		return flat, true
	}

	items, ok := value.([]interface{})
	if !ok {
		return nil, false
	}

	flat := make([]interface{}, 0, len(items))
	for _, item := range items {
		nested, isList := normalizeMetadataListValue(item)
		switch {
		case isList:
			flat = append(flat, nested...)
		case item != nil:
			flat = append(flat, item)
		}
	}
	return flat, true
}
|
||||
|
||||
func firstScalarMetadataValue(value interface{}) (interface{}, bool) {
|
||||
if list, ok := normalizeMetadataListValue(value); ok {
|
||||
for _, item := range list {
|
||||
if item != nil {
|
||||
return item, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
if value == nil {
|
||||
return nil, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
@@ -19,7 +19,9 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
@@ -130,6 +132,107 @@ type failingDeleteMetadataEngine struct {
|
||||
updateCalled bool
|
||||
}
|
||||
|
||||
type metadataDocEngine struct {
|
||||
fakeChatDocEngine
|
||||
records map[string]map[string]interface{}
|
||||
docKBs map[string]string
|
||||
}
|
||||
|
||||
func newMetadataDocEngine(records map[string]map[string]interface{}, docKBs map[string]string) *metadataDocEngine {
|
||||
cp := make(map[string]map[string]interface{}, len(records))
|
||||
for id, meta := range records {
|
||||
dup := make(map[string]interface{}, len(meta))
|
||||
for k, v := range meta {
|
||||
dup[k] = v
|
||||
}
|
||||
cp[id] = dup
|
||||
}
|
||||
return &metadataDocEngine{records: cp, docKBs: docKBs}
|
||||
}
|
||||
|
||||
func (m *metadataDocEngine) SearchMetadata(_ context.Context, req *types.SearchMetadataRequest) (*types.SearchMetadataResult, error) {
|
||||
var ids map[string]struct{}
|
||||
if rawIDs, ok := req.Filter["id"]; ok && rawIDs != nil {
|
||||
ids = make(map[string]struct{})
|
||||
switch typed := rawIDs.(type) {
|
||||
case []string:
|
||||
for _, id := range typed {
|
||||
ids[id] = struct{}{}
|
||||
}
|
||||
case []interface{}:
|
||||
for _, id := range typed {
|
||||
if s, ok := id.(string); ok {
|
||||
ids[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var kbFilter map[string]struct{}
|
||||
if rawKB, ok := req.Filter["kb_id"]; ok && rawKB != nil {
|
||||
kbFilter = make(map[string]struct{})
|
||||
switch typed := rawKB.(type) {
|
||||
case string:
|
||||
kbFilter[typed] = struct{}{}
|
||||
case []string:
|
||||
for _, kb := range typed {
|
||||
kbFilter[kb] = struct{}{}
|
||||
}
|
||||
case []interface{}:
|
||||
for _, kb := range typed {
|
||||
if s, ok := kb.(string); ok {
|
||||
kbFilter[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := &types.SearchMetadataResult{MetadataRecords: []map[string]interface{}{}}
|
||||
for docID, meta := range m.records {
|
||||
if ids != nil {
|
||||
if _, ok := ids[docID]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
kbID := m.docKBs[docID]
|
||||
if kbFilter != nil {
|
||||
if _, ok := kbFilter[kbID]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
result.MetadataRecords = append(result.MetadataRecords, map[string]interface{}{
|
||||
"id": docID,
|
||||
"kb_id": kbID,
|
||||
"meta_fields": meta,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *metadataDocEngine) UpdateMetadata(_ context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error {
|
||||
dup := make(map[string]interface{}, len(metaFields))
|
||||
for k, v := range metaFields {
|
||||
dup[k] = v
|
||||
}
|
||||
m.records[docID] = dup
|
||||
if _, ok := m.docKBs[docID]; !ok {
|
||||
m.docKBs[docID] = datasetID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *metadataDocEngine) DeleteMetadata(_ context.Context, condition map[string]interface{}, tenantID string) (int64, error) {
|
||||
docID, _ := condition["id"].(string)
|
||||
if docID == "" {
|
||||
return 0, nil
|
||||
}
|
||||
if _, ok := m.records[docID]; ok {
|
||||
delete(m.records, docID)
|
||||
return 1, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (f *failingDeleteMetadataEngine) DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) {
|
||||
return 0, f.deleteErr
|
||||
}
|
||||
@@ -1304,6 +1407,192 @@ func TestChunkImageStorageKeyFallsBackToChunkID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchUpdateDocumentMetadatasMatchesPythonSemantics(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
insertTestKB(t, "kb-1", "tenant-1", 3, 0, 0)
|
||||
insertNamedTestDoc(t, "doc-1", "kb-1", "doc1.txt", 0, 0)
|
||||
insertNamedTestDoc(t, "doc-2", "kb-1", "doc2.txt", 0, 0)
|
||||
insertNamedTestDoc(t, "doc-3", "kb-1", "doc3.txt", 0, 0)
|
||||
|
||||
engine := newMetadataDocEngine(map[string]map[string]interface{}{
|
||||
"doc-1": {"tags": []interface{}{"old", "keep"}, "author": "alice"},
|
||||
"doc-2": {"tags": []interface{}{"old"}, "author": "bob"},
|
||||
}, map[string]string{"doc-1": "kb-1", "doc-2": "kb-1", "doc-3": "kb-1"})
|
||||
|
||||
svc := testDocumentService(t)
|
||||
svc.docEngine = engine
|
||||
svc.metadataSvc = &MetadataService{kbDAO: dao.NewKnowledgebaseDAO(), docEngine: engine}
|
||||
|
||||
resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{
|
||||
DocumentIDs: []string{"doc-1", "doc-2", "doc-3"},
|
||||
}, []DocumentMetadataUpdate{
|
||||
{Key: "tags", Value: "new", Match: "old"},
|
||||
{Key: "category", Value: "paper"},
|
||||
}, []DocumentMetadataDelete{
|
||||
{Key: "author", Value: "alice"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BatchUpdateDocumentMetadatas failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code = %v, want success", code)
|
||||
}
|
||||
if resp.Updated != 3 || resp.MatchedDocs != 3 {
|
||||
t.Fatalf("resp = %#v, want updated=3 matched=3", resp)
|
||||
}
|
||||
|
||||
got1 := engine.records["doc-1"]
|
||||
if fmt.Sprintf("%v", got1["category"]) != "paper" {
|
||||
t.Fatalf("doc-1 category = %#v", got1["category"])
|
||||
}
|
||||
if _, ok := got1["author"]; ok {
|
||||
t.Fatalf("doc-1 author should be deleted: %#v", got1)
|
||||
}
|
||||
if got := got1["tags"].([]interface{}); len(got) != 2 || got[0] != "new" || got[1] != "keep" {
|
||||
t.Fatalf("doc-1 tags = %#v", got)
|
||||
}
|
||||
|
||||
got2 := engine.records["doc-2"]
|
||||
if fmt.Sprintf("%v", got2["author"]) != "bob" {
|
||||
t.Fatalf("doc-2 author should be kept: %#v", got2["author"])
|
||||
}
|
||||
if got := got2["tags"].([]interface{}); len(got) != 1 || got[0] != "new" {
|
||||
t.Fatalf("doc-2 tags = %#v", got)
|
||||
}
|
||||
|
||||
got3 := engine.records["doc-3"]
|
||||
if fmt.Sprintf("%v", got3["category"]) != "paper" {
|
||||
t.Fatalf("doc-3 category = %#v", got3)
|
||||
}
|
||||
if _, ok := got3["tags"]; ok {
|
||||
t.Fatalf("doc-3 tags should not be created by match-only update: %#v", got3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchUpdateDocumentMetadatasDeletesEmptyMetadataAndNoOps(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
insertTestKB(t, "kb-1", "tenant-1", 2, 0, 0)
|
||||
insertNamedTestDoc(t, "doc-1", "kb-1", "doc1.txt", 0, 0)
|
||||
insertNamedTestDoc(t, "doc-2", "kb-1", "doc2.txt", 0, 0)
|
||||
|
||||
engine := newMetadataDocEngine(map[string]map[string]interface{}{
|
||||
"doc-1": {"status": "draft"},
|
||||
"doc-2": {"status": "done"},
|
||||
}, map[string]string{"doc-1": "kb-1", "doc-2": "kb-1"})
|
||||
|
||||
svc := testDocumentService(t)
|
||||
svc.docEngine = engine
|
||||
svc.metadataSvc = &MetadataService{kbDAO: dao.NewKnowledgebaseDAO(), docEngine: engine}
|
||||
|
||||
resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{
|
||||
DocumentIDs: []string{"doc-1", "doc-2"},
|
||||
}, nil, []DocumentMetadataDelete{{Key: "status", Value: "draft"}})
|
||||
if err != nil || code != common.CodeSuccess {
|
||||
t.Fatalf("delete batch failed: code=%v err=%v", code, err)
|
||||
}
|
||||
if resp.Updated != 1 || resp.MatchedDocs != 2 {
|
||||
t.Fatalf("resp = %#v, want updated=1 matched=2", resp)
|
||||
}
|
||||
if _, ok := engine.records["doc-1"]; ok {
|
||||
t.Fatalf("doc-1 metadata should be fully removed: %#v", engine.records["doc-1"])
|
||||
}
|
||||
if fmt.Sprintf("%v", engine.records["doc-2"]["status"]) != "done" {
|
||||
t.Fatalf("doc-2 metadata unexpectedly changed: %#v", engine.records["doc-2"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchUpdateDocumentMetadatasNormalizesNumberValues(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
insertTestKB(t, "kb-1", "tenant-1", 1, 0, 0)
|
||||
insertNamedTestDoc(t, "doc-1", "kb-1", "doc1.txt", 0, 0)
|
||||
|
||||
engine := newMetadataDocEngine(map[string]map[string]interface{}{}, map[string]string{"doc-1": "kb-1"})
|
||||
|
||||
svc := testDocumentService(t)
|
||||
svc.docEngine = engine
|
||||
svc.metadataSvc = &MetadataService{kbDAO: dao.NewKnowledgebaseDAO(), docEngine: engine}
|
||||
|
||||
resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{
|
||||
DocumentIDs: []string{"doc-1"},
|
||||
}, []DocumentMetadataUpdate{
|
||||
{Key: "score", Value: "42", ValueType: "number"},
|
||||
}, nil)
|
||||
if err != nil || code != common.CodeSuccess {
|
||||
t.Fatalf("number batch failed: code=%v err=%v", code, err)
|
||||
}
|
||||
if resp.Updated != 1 || resp.MatchedDocs != 1 {
|
||||
t.Fatalf("resp = %#v, want updated=1 matched=1", resp)
|
||||
}
|
||||
|
||||
got := engine.records["doc-1"]["score"]
|
||||
switch v := got.(type) {
|
||||
case int64:
|
||||
if v != 42 {
|
||||
t.Fatalf("score = %v, want 42", v)
|
||||
}
|
||||
case float64:
|
||||
if v != 42 {
|
||||
t.Fatalf("score = %v, want 42", v)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("score type = %T, want numeric value", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchUpdateDocumentMetadatasRejectsMissingValue(t *testing.T) {
|
||||
svc := testDocumentService(t)
|
||||
resp, code, err := svc.BatchUpdateDocumentMetadatas("kb-1", &DocumentMetadataSelector{}, []DocumentMetadataUpdate{
|
||||
{Key: "status"},
|
||||
}, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected validation error for missing value")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("resp = %#v, want nil", resp)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code = %v, want data error", code)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Each update requires key and value.") {
|
||||
t.Fatalf("err = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggregateMetadataIgnoresNestedEmptyLists(t *testing.T) {
|
||||
summary := aggregateMetadata([]map[string]interface{}{
|
||||
{
|
||||
"id": "doc-1",
|
||||
"kb_id": "kb-1",
|
||||
"meta_fields": map[string]interface{}{
|
||||
"score": []interface{}{[]interface{}{}, 7.0},
|
||||
"name": "alice",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
scoreField, ok := summary["score"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("score summary missing: %#v", summary)
|
||||
}
|
||||
values, ok := scoreField["values"].([][2]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("score values type = %T", scoreField["values"])
|
||||
}
|
||||
if len(values) != 1 || values[0][0] != "7" || values[0][1] != 1 {
|
||||
t.Fatalf("score values = %#v, want [[\"7\",1]]", values)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeFieldValuesKeepsNumericValues(t *testing.T) {
|
||||
got := mergeFieldValues(1.0, 2.0)
|
||||
if len(got) != 2 || got[0] != 1.0 || got[1] != 2.0 {
|
||||
t.Fatalf("mergeFieldValues = %#v, want [1 2]", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDatasetDocumentPipelineIDTakesPrecedenceOverChunkMethod(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
@@ -20,8 +20,12 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"ragflow/internal/common"
|
||||
@@ -31,6 +35,7 @@ import (
|
||||
"ragflow/internal/ingestion/parser"
|
||||
"ragflow/internal/storage"
|
||||
"ragflow/internal/utility"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -156,6 +161,11 @@ const DatasetFolderName = ".knowledgebase"
|
||||
// FileSourceDataset is the file source value that denotes a dataset; its
// literal value is "knowledgebase".
const FileSourceDataset = "knowledgebase"
|
||||
|
||||
var (
|
||||
assertURLSafe = utility.AssertURLSafe
|
||||
pinnedHTTPClient = utility.PinnedHTTPClient
|
||||
)
|
||||
|
||||
// toFileResponse converts file model to response format
|
||||
func (s *FileService) toFileResponse(file *entity.File) map[string]interface{} {
|
||||
result := map[string]interface{}{
|
||||
@@ -391,6 +401,57 @@ func (s *FileService) UploadFile(tenantID, parentID string, files []*multipart.F
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UploadInfos mirrors Python's upload_info file branch: store raw bytes in the
|
||||
// per-user downloads bucket and return lightweight upload descriptors instead
|
||||
// of creating full File rows in the file-management tree.
|
||||
func (s *FileService) UploadInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, error) {
|
||||
storageImpl := storage.GetStorageFactory().GetStorage()
|
||||
if storageImpl == nil {
|
||||
return nil, fmt.Errorf("storage not initialized")
|
||||
}
|
||||
|
||||
results := make([]map[string]interface{}, 0, len(files))
|
||||
for _, fileHeader := range files {
|
||||
filename := fileHeader.Filename
|
||||
if err := s.checkUploadInfoHealth(userID, filename); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
src, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
|
||||
}
|
||||
data, readErr := readUploadInfoData(src)
|
||||
src.Close()
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("failed to read file data: %w", readErr)
|
||||
}
|
||||
|
||||
contentType := fileHeader.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = http.DetectContentType(data)
|
||||
}
|
||||
filename, contentType, data = normalizeUploadInfoContent(filename, contentType, data)
|
||||
resp, err := s.storeUploadInfoBlob(storageImpl, userID, filename, contentType, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, resp)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func readUploadInfoData(r io.Reader) ([]byte, error) {
|
||||
limited := &io.LimitedReader{R: r, N: maxRemoteFileSize + 1}
|
||||
data, err := io.ReadAll(limited)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(data)) > maxRemoteFileSize {
|
||||
return nil, fmt.Errorf("file size exceeds %d bytes", maxRemoteFileSize)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *FileService) parseFilePath(filename string) []string {
|
||||
filename = strings.TrimPrefix(filename, "/")
|
||||
parts := strings.Split(filename, "/")
|
||||
@@ -1068,3 +1129,303 @@ func parseFileContent(filename string, data []byte) string {
|
||||
}
|
||||
return fp.String()
|
||||
}
|
||||
|
||||
// toUploadInfoResponse converts a newly-uploaded file record to the shape
|
||||
// Python's upload_info endpoint returns.
|
||||
func (s *FileService) toUploadInfoResponse(file *entity.File, mimeType string) map[string]interface{} {
|
||||
ext := ""
|
||||
if idx := strings.LastIndex(file.Name, "."); idx >= 0 {
|
||||
ext = strings.ToLower(file.Name[idx+1:])
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"id": file.ID,
|
||||
"name": file.Name,
|
||||
"size": file.Size,
|
||||
"extension": ext,
|
||||
"mime_type": mimeType,
|
||||
"created_by": file.CreatedBy,
|
||||
"created_at": float64(time.Now().UnixMilli()) / 1000.0,
|
||||
"preview_url": nil,
|
||||
}
|
||||
}
|
||||
|
||||
// maxRemoteFileSize is the largest accepted upload payload in bytes
// (100 << 20, i.e. 100 MiB). UploadFromURL passes it to fetchRemoteFileSafely
// as the body cap, and readUploadInfoData enforces the same limit on uploaded
// files.
const maxRemoteFileSize = 100 << 20
|
||||
|
||||
// UploadFromURL fetches a remote URL, saves the content to the tenant's root
|
||||
// folder, and returns the file metadata map — mirroring Python
|
||||
// FileService.upload_info(tenant_id, None, url).
|
||||
//
|
||||
// The remote fetch is SSRF-guarded (mirrors Python's assert_url_is_safe): the
|
||||
// scheme must be http/https and every address the host resolves to must be
|
||||
// globally routable; the validated IP is pinned for the actual connection — and
|
||||
// re-validated on each redirect hop — to defeat DNS-rebinding. The HTTP client
|
||||
// carries connect and overall timeouts, and the response body is bounded with
|
||||
// truncation detection so an oversized file is rejected rather than silently
|
||||
// clipped.
|
||||
func (s *FileService) UploadFromURL(tenantID, rawURL string) (map[string]interface{}, error) {
|
||||
if rawURL == "" {
|
||||
return nil, fmt.Errorf("url is required")
|
||||
}
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") || parsed.Hostname() == "" {
|
||||
return nil, fmt.Errorf("invalid or unsafe URL")
|
||||
}
|
||||
|
||||
data, headers, finalURL, err := fetchRemoteFileSafely(rawURL, maxRemoteFileSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storageImpl := storage.GetStorageFactory().GetStorage()
|
||||
if storageImpl == nil {
|
||||
return nil, fmt.Errorf("storage not initialized")
|
||||
}
|
||||
|
||||
contentType := headers.Get("Content-Type")
|
||||
filename := normalizeRemoteUploadFilename(finalURL, contentType, data)
|
||||
if err := s.checkUploadInfoHealth(tenantID, filename); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filename, contentType, data = normalizeUploadInfoContent(filename, contentType, data)
|
||||
return s.storeUploadInfoBlob(storageImpl, tenantID, filename, contentType, data)
|
||||
}
|
||||
|
||||
// fetchRemoteFileSafely downloads rawURL with SSRF protection, connect/overall
|
||||
// timeouts, and a hard size cap that rejects (rather than truncates) oversized
|
||||
// bodies.
|
||||
func fetchRemoteFileSafely(rawURL string, maxSize int64) ([]byte, http.Header, string, error) {
|
||||
currentURL := rawURL
|
||||
for redirects := 0; redirects < 10; redirects++ {
|
||||
hostname, resolvedIP, err := assertURLSafe(currentURL)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
client := pinnedHTTPClient(hostname, resolvedIP, 10*time.Second)
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
||||
resp, err := client.Get(currentURL) // #nosec G107
|
||||
if err != nil {
|
||||
return nil, nil, "", fmt.Errorf("failed to fetch URL: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusMovedPermanently ||
|
||||
resp.StatusCode == http.StatusFound ||
|
||||
resp.StatusCode == http.StatusSeeOther ||
|
||||
resp.StatusCode == http.StatusTemporaryRedirect ||
|
||||
resp.StatusCode == http.StatusPermanentRedirect {
|
||||
location := resp.Header.Get("Location")
|
||||
resp.Body.Close()
|
||||
if location == "" {
|
||||
return nil, nil, "", fmt.Errorf("redirect response missing Location header")
|
||||
}
|
||||
baseURL, parseErr := url.Parse(currentURL)
|
||||
if parseErr != nil {
|
||||
return nil, nil, "", parseErr
|
||||
}
|
||||
nextURL, resolveErr := baseURL.Parse(location)
|
||||
if resolveErr != nil {
|
||||
return nil, nil, "", resolveErr
|
||||
}
|
||||
currentURL = nextURL.String()
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
resp.Body.Close()
|
||||
return nil, nil, "", fmt.Errorf("remote URL returned HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
data, readErr := io.ReadAll(io.LimitReader(resp.Body, maxSize+1))
|
||||
resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, nil, "", fmt.Errorf("failed to read remote content: %w", readErr)
|
||||
}
|
||||
if int64(len(data)) > maxSize {
|
||||
return nil, nil, "", fmt.Errorf("remote file exceeds the maximum allowed size of %d bytes", maxSize)
|
||||
}
|
||||
return data, resp.Header.Clone(), currentURL, nil
|
||||
}
|
||||
return nil, nil, "", fmt.Errorf("stopped after too many redirects")
|
||||
}
|
||||
|
||||
// isPublicIP reports whether ip is a globally routable address. It follows
// the allowlist intent of Python's ipaddress is_global check: besides
// loopback, private, unspecified, link-local and multicast addresses it also
// rejects the IPv4 special-purpose ranges that the stdlib predicates do not
// cover, the IPv6 documentation prefix, and values that are not a valid
// 4- or 16-byte address. IPv4-mapped IPv6 addresses are evaluated as IPv4.
func isPublicIP(ip net.IP) bool {
	if ip == nil {
		return false
	}
	// IsMulticast already includes the link-local and interface-local
	// multicast scopes, so those need no separate check.
	if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() ||
		ip.IsLinkLocalUnicast() || ip.IsMulticast() {
		return false
	}

	if ip4 := ip.To4(); ip4 != nil {
		switch {
		case ip4[0] == 0: // 0.0.0.0/8 "this network" (RFC 791)
		case ip4[0] == 100 && ip4[1]&0xc0 == 0x40: // 100.64.0.0/10 carrier-grade NAT (RFC 6598)
		case ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0: // 192.0.0.0/24 IETF protocol assignments (RFC 6890)
		case ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 2: // 192.0.2.0/24 TEST-NET-1 (RFC 5737)
		case ip4[0] == 198 && ip4[1]&0xfe == 18: // 198.18.0.0/15 benchmarking (RFC 2544)
		case ip4[0] == 198 && ip4[1] == 51 && ip4[2] == 100: // 198.51.100.0/24 TEST-NET-2
		case ip4[0] == 203 && ip4[1] == 0 && ip4[2] == 113: // 203.0.113.0/24 TEST-NET-3
		case ip4[0] >= 240: // 240.0.0.0/4 reserved, including limited broadcast
		default:
			return true
		}
		return false
	}

	// Not IPv4: anything that is not a well-formed 16-byte address is unsafe.
	if len(ip) != net.IPv6len {
		return false
	}
	// 2001:db8::/32 is reserved for documentation (RFC 3849).
	if ip[0] == 0x20 && ip[1] == 0x01 && ip[2] == 0x0d && ip[3] == 0xb8 {
		return false
	}
	return true
}
|
||||
|
||||
func (s *FileService) checkUploadInfoHealth(userID, filename string) error {
|
||||
if filename == "" {
|
||||
return fmt.Errorf("No file selected!")
|
||||
}
|
||||
maxFileNumPerUser := os.Getenv("MAX_FILE_NUM_PER_USER")
|
||||
if maxFileNumPerUser != "" {
|
||||
var maxNum int64
|
||||
if _, err := fmt.Sscanf(maxFileNumPerUser, "%d", &maxNum); err == nil && maxNum > 0 {
|
||||
docCount, err := s.GetDocCount(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get document count: %w", err)
|
||||
}
|
||||
if docCount >= maxNum {
|
||||
return fmt.Errorf("Exceed the maximum file number of a free user!")
|
||||
}
|
||||
}
|
||||
}
|
||||
if len([]byte(filename)) > 255 {
|
||||
return fmt.Errorf("Exceed the maximum length of file name!")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileService) storeUploadInfoBlob(storageImpl storage.Storage, userID, filename, contentType string, data []byte) (map[string]interface{}, error) {
|
||||
location := common.GenerateUUID()
|
||||
bucket := fmt.Sprintf("%s-downloads", userID)
|
||||
if err := storageImpl.Put(bucket, location, data); err != nil {
|
||||
return nil, fmt.Errorf("failed to store file: %w", err)
|
||||
}
|
||||
ext := ""
|
||||
if idx := strings.LastIndex(filename, "."); idx >= 0 {
|
||||
ext = strings.ToLower(filename[idx+1:])
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"id": location,
|
||||
"name": filename,
|
||||
"size": int64(len(data)),
|
||||
"extension": ext,
|
||||
"mime_type": contentType,
|
||||
"created_by": userID,
|
||||
"created_at": float64(time.Now().UnixMilli()) / 1000.0,
|
||||
"preview_url": nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeRemoteUploadFilename(rawURL, contentType string, data []byte) string {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
filename := "download"
|
||||
if err == nil {
|
||||
filename = sanitizeFilename(filepath.Base(parsed.Path))
|
||||
}
|
||||
ct := strings.ToLower(strings.TrimSpace(strings.Split(contentType, ";")[0]))
|
||||
if ct == "application/pdf" || bytesLooksLikePDF(data) {
|
||||
if !strings.HasSuffix(strings.ToLower(filename), ".pdf") {
|
||||
filename += ".pdf"
|
||||
}
|
||||
}
|
||||
return filename
|
||||
}
|
||||
|
||||
// normalizeUploadInfoContent canonicalizes an upload before it is stored and
// returns the possibly adjusted file name, the bare lower-case media type
// (parameters such as "; charset=..." removed), and the payload.
//
//   - When contentType carries no media type, the type is sniffed from data.
//   - PDF content (declared as application/pdf or starting with the PDF
//     signature) gets a ".pdf" suffix and the application/pdf type; its bytes
//     are never modified, even if they happen to contain HTML-looking text.
//   - Otherwise, HTML content (declared as text/html or application/xhtml+xml,
//     or containing an HTML marker) is replaced by its readable text; the
//     media type is left as determined above.
func normalizeUploadInfoContent(filename, contentType string, data []byte) (string, string, []byte) {
	bareType := func(v string) string {
		return strings.ToLower(strings.TrimSpace(strings.Split(v, ";")[0]))
	}

	mediaType := bareType(contentType)
	if mediaType == "" {
		// DetectContentType may append parameters; strip them so the
		// comparisons below and the returned value use the same form as a
		// declared type.
		mediaType = bareType(http.DetectContentType(data))
	}

	switch {
	case mediaType == "application/pdf" || bytesLooksLikePDF(data):
		// PDF takes precedence: binary PDF bytes must not go through tag stripping.
		if !strings.HasSuffix(strings.ToLower(filename), ".pdf") {
			filename += ".pdf"
		}
		mediaType = "application/pdf"
	case mediaType == "text/html" || mediaType == "application/xhtml+xml" || looksLikeHTML(data):
		data = htmlToReadableMarkdown(data)
	}
	return filename, mediaType, data
}

// bytesLooksLikePDF reports whether data starts with the "%PDF" signature.
func bytesLooksLikePDF(data []byte) bool {
	return len(data) >= 4 && string(data[:4]) == "%PDF"
}

// htmlSniffTags are the tag names whose opening "<name" marks data as HTML.
var htmlSniffTags = [...]string{"html", "body", "div"}

// looksLikeHTML reports whether data contains "<html", "<body" or "<div",
// compared ASCII case-insensitively. It scans the bytes in place instead of
// copying and lower-casing the whole payload.
func looksLikeHTML(data []byte) bool {
	for i, b := range data {
		if b != '<' {
			continue
		}
		rest := data[i+1:]
		for _, tag := range htmlSniffTags {
			if len(rest) >= len(tag) && strings.EqualFold(string(rest[:len(tag)]), tag) {
				return true
			}
		}
	}
	return false
}

var (
	// htmlScriptStyleRE matches whole <script>/<style> elements including their content.
	htmlScriptStyleRE = regexp.MustCompile(`(?is)<(script|style)[^>]*>.*?</(script|style)>`)
	// htmlLineBreakRE matches <br> in any spelling and the end of <div>/<li> elements.
	htmlLineBreakRE = regexp.MustCompile(`(?i)<br\s*/?>|</(?:div|li)\s*>`)
	// htmlParagraphEndRE matches the end of a <p> element.
	htmlParagraphEndRE = regexp.MustCompile(`(?i)</p\s*>`)
	// htmlTagRE matches any remaining tag.
	htmlTagRE = regexp.MustCompile(`(?s)<[^>]+>`)
	// multiSpaceRE matches runs of spaces and tabs.
	multiSpaceRE = regexp.MustCompile(`[ \t]+`)
	// multiNewlineRE matches three or more consecutive newlines.
	multiNewlineRE = regexp.MustCompile(`\n{3,}`)
	// lineEndingNormalizer maps CRLF and lone CR to LF without doubling CRLF.
	lineEndingNormalizer = strings.NewReplacer("\r\n", "\n", "\r", "\n")
)

// htmlToReadableMarkdown reduces an HTML document to plain readable text:
// script and style elements are dropped, line-break and block-closing tags
// become newlines (regardless of letter case), all other tags become a space,
// character entities are decoded, and whitespace runs are collapsed. The
// result is trimmed of leading and trailing whitespace.
func htmlToReadableMarkdown(data []byte) []byte {
	text := htmlScriptStyleRE.ReplaceAllString(string(data), " ")
	text = htmlLineBreakRE.ReplaceAllString(text, "\n")
	text = htmlParagraphEndRE.ReplaceAllString(text, "\n\n")
	text = htmlTagRE.ReplaceAllString(text, " ")
	// Decode entities only after tag stripping so escaped markup stays text.
	text = html.UnescapeString(text)
	text = lineEndingNormalizer.Replace(text)
	text = multiSpaceRE.ReplaceAllString(text, " ")
	text = multiNewlineRE.ReplaceAllString(text, "\n\n")
	return []byte(strings.TrimSpace(text))
}
|
||||
|
||||
// reservedDeviceNames are Windows reserved filenames that must never be used.
var reservedDeviceNames = map[string]bool{
	"CON": true, "PRN": true, "AUX": true, "NUL": true,
	"COM1": true, "COM2": true, "COM3": true, "COM4": true, "COM5": true,
	"COM6": true, "COM7": true, "COM8": true, "COM9": true,
	"LPT1": true, "LPT2": true, "LPT3": true, "LPT4": true, "LPT5": true,
	"LPT6": true, "LPT7": true, "LPT8": true, "LPT9": true,
}

// sanitizeFilename produces a safe, filesystem-friendly filename from an
// arbitrary URL path segment. It strips directory components, replaces
// path-special and control characters (including DEL) with "_", trims leading
// and trailing dots and spaces, and limits the result to 255 bytes without
// splitting a multi-byte character. Empty results, "." / ".." and names whose
// stem is a Windows reserved device name yield "download".
func sanitizeFilename(name string) string {
	const maxLen = 255

	name = strings.TrimSpace(filepath.Base(name))

	name = strings.Map(func(r rune) rune {
		switch {
		case r < 0x20 || r == 0x7f: // control characters, NUL and DEL
			return '_'
		case strings.ContainsRune(`/\:*?"<>|`, r):
			return '_'
		}
		return r
	}, name)

	// Strip leading/trailing dots and spaces to avoid hidden or reserved forms.
	name = strings.Trim(name, ". ")

	if name == "" || name == "." || name == ".." {
		return "download"
	}
	if stem := strings.SplitN(strings.ToUpper(name), ".", 2)[0]; reservedDeviceNames[stem] {
		return "download"
	}

	if len(name) > maxLen {
		// Cut at the last rune boundary at or below the limit; slicing at a
		// fixed byte offset could leave half of a multi-byte character.
		cut := 0
		for i := range name {
			if i > maxLen {
				break
			}
			cut = i
		}
		// Truncation may expose a trailing dot or space again.
		name = strings.TrimRight(name[:cut], ". ")
	}
	return name
}
|
||||
|
||||
@@ -3,6 +3,10 @@ package service
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -15,6 +19,7 @@ type fakeStorage struct {
|
||||
lastFnm string
|
||||
blob []byte
|
||||
err error
|
||||
exists bool
|
||||
}
|
||||
|
||||
func (f *fakeStorage) Health() bool {
|
||||
@@ -22,7 +27,11 @@ func (f *fakeStorage) Health() bool {
|
||||
}
|
||||
|
||||
func (f *fakeStorage) Put(bucket, fnm string, binary []byte, tenantID ...string) error {
|
||||
panic("not implemented in fakeStorage")
|
||||
f.lastBucket = bucket
|
||||
f.lastFnm = fnm
|
||||
f.blob = binary
|
||||
f.exists = true
|
||||
return f.err
|
||||
}
|
||||
|
||||
func (f *fakeStorage) Get(bucket, fnm string, tenantID ...string) ([]byte, error) {
|
||||
@@ -36,7 +45,7 @@ func (f *fakeStorage) Remove(bucket, fnm string, tenantID ...string) error {
|
||||
}
|
||||
|
||||
func (f *fakeStorage) ObjExist(bucket, fnm string, tenantID ...string) bool {
|
||||
panic("not implemented in fakeStorage")
|
||||
return f.exists && f.lastBucket == bucket && f.lastFnm == fnm
|
||||
}
|
||||
|
||||
func (f *fakeStorage) GetPresignedURL(bucket, fnm string, expires time.Duration, tenantID ...string) (string, error) {
|
||||
@@ -124,3 +133,130 @@ func TestFileService_DownloadAgentFile_Error(t *testing.T) {
|
||||
t.Errorf("expected nil blob, got %v", blob)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileService_UploadFromURL_PDFAddsExtensionAndStoresToDownloads(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/pdf")
|
||||
_, _ = w.Write([]byte("%PDF-1.7 fake pdf"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origAssert := assertURLSafe
|
||||
origPinned := pinnedHTTPClient
|
||||
assertURLSafe = func(rawURL string) (string, string, error) {
|
||||
return "127.0.0.1", "127.0.0.1", nil
|
||||
}
|
||||
pinnedHTTPClient = func(hostname, resolvedIP string, timeout time.Duration) *http.Client {
|
||||
return server.Client()
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
assertURLSafe = origAssert
|
||||
pinnedHTTPClient = origPinned
|
||||
})
|
||||
|
||||
mockStorage := &fakeStorage{}
|
||||
factory := storage.GetStorageFactory()
|
||||
originalStorage := factory.GetStorage()
|
||||
factory.SetStorage(mockStorage)
|
||||
t.Cleanup(func() { factory.SetStorage(originalStorage) })
|
||||
|
||||
svc := NewFileService()
|
||||
resp, err := svc.UploadFromURL("tenant123", server.URL+"/report")
|
||||
if err != nil {
|
||||
t.Fatalf("UploadFromURL failed: %v", err)
|
||||
}
|
||||
|
||||
if mockStorage.lastBucket != "tenant123-downloads" {
|
||||
t.Fatalf("bucket = %q", mockStorage.lastBucket)
|
||||
}
|
||||
if resp["name"] != "report.pdf" {
|
||||
t.Fatalf("name = %#v, want report.pdf", resp["name"])
|
||||
}
|
||||
if resp["mime_type"] != "application/pdf" {
|
||||
t.Fatalf("mime_type = %#v", resp["mime_type"])
|
||||
}
|
||||
if resp["id"] != mockStorage.lastFnm {
|
||||
t.Fatalf("id = %#v, stored key = %q", resp["id"], mockStorage.lastFnm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileService_UploadFromURL_HTMLNormalizesReadableContent(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
_, _ = w.Write([]byte(`<html><head><script>bad()</script></head><body><div>Hello</div><p>World</p></body></html>`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origAssert := assertURLSafe
|
||||
origPinned := pinnedHTTPClient
|
||||
assertURLSafe = func(rawURL string) (string, string, error) {
|
||||
return "127.0.0.1", "127.0.0.1", nil
|
||||
}
|
||||
pinnedHTTPClient = func(hostname, resolvedIP string, timeout time.Duration) *http.Client {
|
||||
return server.Client()
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
assertURLSafe = origAssert
|
||||
pinnedHTTPClient = origPinned
|
||||
})
|
||||
|
||||
mockStorage := &fakeStorage{}
|
||||
factory := storage.GetStorageFactory()
|
||||
originalStorage := factory.GetStorage()
|
||||
factory.SetStorage(mockStorage)
|
||||
t.Cleanup(func() { factory.SetStorage(originalStorage) })
|
||||
|
||||
svc := NewFileService()
|
||||
resp, err := svc.UploadFromURL("tenant123", server.URL+"/page")
|
||||
if err != nil {
|
||||
t.Fatalf("UploadFromURL failed: %v", err)
|
||||
}
|
||||
|
||||
stored := string(mockStorage.blob)
|
||||
if strings.Contains(strings.ToLower(stored), "<html") || strings.Contains(strings.ToLower(stored), "<script") {
|
||||
t.Fatalf("stored html was not normalized: %q", stored)
|
||||
}
|
||||
if !strings.Contains(stored, "Hello") || !strings.Contains(stored, "World") {
|
||||
t.Fatalf("stored normalized text missing content: %q", stored)
|
||||
}
|
||||
if resp["mime_type"] != "text/html" {
|
||||
t.Fatalf("mime_type = %#v", resp["mime_type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUploadInfoContent_PDFTakesPrecedenceOverHTML(t *testing.T) {
|
||||
filename, contentType, data := normalizeUploadInfoContent(
|
||||
"report",
|
||||
"text/html",
|
||||
[]byte("%PDF-1.7 fake pdf"),
|
||||
)
|
||||
if filename != "report.pdf" {
|
||||
t.Fatalf("filename = %q, want report.pdf", filename)
|
||||
}
|
||||
if contentType != "application/pdf" {
|
||||
t.Fatalf("contentType = %q, want application/pdf", contentType)
|
||||
}
|
||||
if !bytes.Equal(data, []byte("%PDF-1.7 fake pdf")) {
|
||||
t.Fatalf("pdf bytes were unexpectedly transformed: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadUploadInfoData_RejectsOversizedInput(t *testing.T) {
|
||||
reader := io.LimitReader(zeroReader{}, maxRemoteFileSize+1)
|
||||
_, err := readUploadInfoData(reader)
|
||||
if err == nil {
|
||||
t.Fatal("expected oversized input to be rejected")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "file size exceeds") {
|
||||
t.Fatalf("err = %v, want size limit message", err)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroReader is an endless source of zero bytes.
type zeroReader struct{}

// Read fills p entirely with zero bytes and reports len(p) bytes read; it
// never returns an error.
func (zeroReader) Read(p []byte) (int, error) {
	for idx := 0; idx < len(p); idx++ {
		p[idx] = 0
	}
	return len(p), nil
}
|
||||
|
||||
@@ -422,10 +422,16 @@ func mergeFieldValues(existing, new interface{}) []interface{} {
|
||||
if val != "" {
|
||||
result = append(result, val)
|
||||
}
|
||||
case float64, float32, int, int8, int16, int32, int64, bool:
|
||||
result = append(result, val)
|
||||
case []interface{}:
|
||||
for _, item := range val {
|
||||
addValue(item)
|
||||
}
|
||||
case []string:
|
||||
for _, item := range val {
|
||||
addValue(item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user