mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat(go-api): add Go support for POST /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks (#16256)
## Summary
Add the Go implementation of `POST
/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks`.
This wires the full create-chunk path in Go:
- router and handler registration
- request/response structs
- chunk creation service logic
- embedding generation
- chunk insert into doc engine
- chunk/token counter increment
- `tag_feas` validation
- `image_base64` decoding and chunk image storage/merge
- unit tests for handler and service
## Testing
Unit tests:
- `/usr/local/go/bin/go test ./internal/handler`
- `/usr/local/go/bin/go test ./internal/service/chunk`
- `/usr/local/go/bin/go test ./internal/service`
- `/usr/local/go/bin/go test ./...`
All passed locally.
Manual curl checks:
- basic text chunk: Go passed
- chunk with `important_keywords` / `questions` / `tag_kwd` /
`tag_feas`: Go passed
- blank content validation: Go matched expected `code=102`
- invalid `image_base64` validation: Go matched expected `code=102`
- image upload and repeated image upload / merge path: Go passed twice
This commit is contained in:
@@ -17,6 +17,7 @@ package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
@@ -38,6 +39,7 @@ type chunkService interface {
|
||||
UpdateChunk(req *service.UpdateChunkRequest, userID string) error
|
||||
RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error)
|
||||
Parse(userID, datasetID string, req *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error)
|
||||
AddChunk(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error)
|
||||
StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error)
|
||||
}
|
||||
|
||||
@@ -788,3 +790,131 @@ func (h *ChunkHandler) RemoveChunks(c *gin.Context) {
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
func addChunkStringField(rawBody map[string]json.RawMessage, field string) (string, error) {
|
||||
raw, ok := rawBody[field]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
var value string
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return "", fmt.Errorf("`%s` must be a string", field)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func addChunkStringPtrField(rawBody map[string]json.RawMessage, field string) (*string, error) {
|
||||
raw, ok := rawBody[field]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
var value string
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return nil, fmt.Errorf("`%s` must be a string", field)
|
||||
}
|
||||
return &value, nil
|
||||
}
|
||||
|
||||
func addChunkStringListField(rawBody map[string]json.RawMessage, field, listMessage, elementMessage string) ([]string, error) {
|
||||
raw, ok := rawBody[field]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
var values []interface{}
|
||||
if err := json.Unmarshal(raw, &values); err != nil {
|
||||
return nil, errors.New(listMessage)
|
||||
}
|
||||
result := make([]string, len(values))
|
||||
for i, value := range values {
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return nil, errors.New(elementMessage)
|
||||
}
|
||||
result[i] = str
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func addChunkResponseMessage(code common.ErrorCode, err error) string {
|
||||
if code == common.CodeServerError {
|
||||
common.Warn("add chunk failed", zap.String("error", err.Error()))
|
||||
return "Failed to add chunk"
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func (h *ChunkHandler) AddChunk(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := user.ID
|
||||
datasetID, documentID := strings.TrimSpace(c.Param("dataset_id")), strings.TrimSpace(c.Param("document_id"))
|
||||
|
||||
var rawBody map[string]json.RawMessage
|
||||
if err := json.NewDecoder(c.Request.Body).Decode(&rawBody); err != nil {
|
||||
jsonError(c, common.CodeArgumentError, err.Error())
|
||||
return
|
||||
}
|
||||
content, err := addChunkStringField(rawBody, "content")
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeArgumentError, err.Error())
|
||||
return
|
||||
}
|
||||
importantKeywords, err := addChunkStringListField(rawBody, "important_keywords", "`important_keywords` is required to be a list", "`important_keywords` must be a list of strings")
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeArgumentError, err.Error())
|
||||
return
|
||||
}
|
||||
questions, err := addChunkStringListField(rawBody, "questions", "`questions` is required to be a list", "`questions` must be a list of strings")
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeArgumentError, err.Error())
|
||||
return
|
||||
}
|
||||
tagKwd, err := addChunkStringListField(rawBody, "tag_kwd", "`tag_kwd` is required to be a list", "`tag_kwd` must be a list of strings")
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeArgumentError, err.Error())
|
||||
return
|
||||
}
|
||||
imageBase64, err := addChunkStringPtrField(rawBody, "image_base64")
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeArgumentError, err.Error())
|
||||
return
|
||||
}
|
||||
var tagFeas interface{}
|
||||
if raw, ok := rawBody["tag_feas"]; ok {
|
||||
if err := json.Unmarshal(raw, &tagFeas); err != nil {
|
||||
jsonError(c, common.CodeArgumentError, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req := service.AddChunkRequest{
|
||||
DatasetID: datasetID,
|
||||
DocumentID: documentID,
|
||||
Content: content,
|
||||
ImportantKeywords: importantKeywords,
|
||||
Questions: questions,
|
||||
TagKwd: tagKwd,
|
||||
TagFeas: tagFeas,
|
||||
ImageBase64: imageBase64,
|
||||
}
|
||||
|
||||
resp, err := h.chunkService.AddChunk(&req, userID)
|
||||
if err != nil {
|
||||
if codedErr, ok := err.(service.ErrorCoder); ok {
|
||||
jsonError(c, codedErr.Code(), addChunkResponseMessage(codedErr.Code(), err))
|
||||
return
|
||||
}
|
||||
jsonError(c, common.CodeServerError, addChunkResponseMessage(common.CodeServerError, err))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": resp,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
// Only the methods actually called by the test are set; others panic.
|
||||
type mockChunkSvc struct {
|
||||
retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
||||
addChunkFn func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error)
|
||||
listFn func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error)
|
||||
switchChunksFn func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error
|
||||
updateChunkFn func(req *service.UpdateChunkRequest, userID string) error
|
||||
@@ -68,6 +69,12 @@ func (m *mockChunkSvc) StopParsing(userID, datasetID string, req service.StopPar
|
||||
func (m *mockChunkSvc) Parse(string, string, *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (m *mockChunkSvc) AddChunk(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||
if m.addChunkFn != nil {
|
||||
return m.addChunkFn(req, userID)
|
||||
}
|
||||
return &service.AddChunkResponse{Chunk: map[string]interface{}{"id": "chunk-1"}}, nil
|
||||
}
|
||||
|
||||
func setupChunkRetrievalTest(userID string) (*gin.Engine, *mockChunkSvc) {
|
||||
mock := &mockChunkSvc{}
|
||||
@@ -563,3 +570,194 @@ func TestChunkRetrieval_ServiceError(t *testing.T) {
|
||||
t.Errorf("internal error details leaked to response: %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
type addChunkTestError struct {
|
||||
code common.ErrorCode
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e addChunkTestError) Error() string { return e.msg }
|
||||
func (e addChunkTestError) Code() common.ErrorCode { return e.code }
|
||||
|
||||
func TestChunkHandlerAddChunkSuccess(t *testing.T) {
|
||||
mock := &mockChunkSvc{
|
||||
addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||
if userID != "user1" {
|
||||
t.Fatalf("userID = %q, want user1", userID)
|
||||
}
|
||||
if req.DatasetID != "kb1" || req.DocumentID != "doc1" || req.Content != "chunk body" {
|
||||
t.Fatalf("unexpected request: %#v", req)
|
||||
}
|
||||
return &service.AddChunkResponse{Chunk: map[string]interface{}{"id": "chunk-1", "content": req.Content}}, nil
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user1"})
|
||||
})
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"content":"chunk body"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected success code, got %v", resp["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkHandlerAddChunkPathIDsOverrideBody(t *testing.T) {
|
||||
mock := &mockChunkSvc{
|
||||
addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||
if req.DatasetID != "kb1" || req.DocumentID != "doc1" {
|
||||
t.Fatalf("path IDs were not preserved: %#v", req)
|
||||
}
|
||||
if req.Content != "chunk body" {
|
||||
t.Fatalf("unexpected content: %#v", req)
|
||||
}
|
||||
return &service.AddChunkResponse{Chunk: map[string]interface{}{"id": "chunk-1"}}, nil
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user1"})
|
||||
})
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"dataset_id":"evil-kb","document_id":"evil-doc","content":"chunk body"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkHandlerAddChunkCodedError(t *testing.T) {
|
||||
mock := &mockChunkSvc{
|
||||
addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||
return nil, addChunkTestError{code: common.CodeDataError, msg: "`content` is required"}
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user1"})
|
||||
})
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"content":" "}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeDataError) {
|
||||
t.Fatalf("expected data error code, got %v", resp["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkHandlerAddChunkValidatesListFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "important keywords type",
|
||||
body: `{"content":"chunk body","important_keywords":{}}`,
|
||||
wantMsg: "`important_keywords` is required to be a list",
|
||||
},
|
||||
{
|
||||
name: "tag kwd element type",
|
||||
body: `{"content":"chunk body","tag_kwd":[1]}`,
|
||||
wantMsg: "`tag_kwd` must be a list of strings",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mock := &mockChunkSvc{
|
||||
addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||
t.Fatal("service should not be called for invalid request")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user1"})
|
||||
})
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(tt.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["message"] != tt.wantMsg {
|
||||
t.Fatalf("message = %v, want %q", resp["message"], tt.wantMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkHandlerAddChunkHidesServerErrorDetails(t *testing.T) {
|
||||
mock := &mockChunkSvc{
|
||||
addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||
return nil, addChunkTestError{code: common.CodeServerError, msg: "encode chunk embedding: provider secret"}
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user1"})
|
||||
})
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"content":"chunk body"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["message"] != "Failed to add chunk" {
|
||||
t.Fatalf("message = %v, want generic failure", resp["message"])
|
||||
}
|
||||
if strings.Contains(w.Body.String(), "provider secret") {
|
||||
t.Fatalf("server error details leaked: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,9 +367,6 @@ func TestEngine_ChainOf100(t *testing.T) {
|
||||
if m == nil {
|
||||
m = map[string]any{}
|
||||
}
|
||||
if m == nil {
|
||||
m = map[string]any{}
|
||||
}
|
||||
m["value"] = i
|
||||
return m, nil
|
||||
})
|
||||
@@ -380,14 +377,16 @@ func TestEngine_ChainOf100(t *testing.T) {
|
||||
|
||||
engine := NewEngine(sg,
|
||||
WithRecursionLimit(150),
|
||||
WithMaxConcurrency(4),
|
||||
)
|
||||
|
||||
result, err := engine.RunSync(context.Background(), map[string]any{"value": "start"})
|
||||
if err != nil {
|
||||
t.Fatalf("RunSync: %v", err)
|
||||
}
|
||||
m := result.(map[string]any)
|
||||
m, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected map[string]any result, got %T", result)
|
||||
}
|
||||
if v, ok := m["value"]; !ok || v.(int) != 99 {
|
||||
t.Fatalf("expected value=99, got %v", m["value"])
|
||||
}
|
||||
|
||||
@@ -310,6 +310,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
datasets.GET("/:dataset_id/documents/:document_id", r.documentHandler.DownloadDocument)
|
||||
datasets.PATCH("/:dataset_id/documents/:document_id", r.documentHandler.UpdateDatasetDocument)
|
||||
datasets.DELETE("/:dataset_id/documents", r.documentHandler.DeleteDocuments)
|
||||
datasets.POST("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.AddChunk)
|
||||
|
||||
// Dataset document chunk
|
||||
datasets.GET("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.ListChunks)
|
||||
|
||||
@@ -20,11 +20,17 @@ import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
"image/jpeg"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand"
|
||||
"path/filepath"
|
||||
"ragflow/internal/common"
|
||||
@@ -36,10 +42,15 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "image/gif"
|
||||
_ "image/png"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
@@ -56,6 +67,16 @@ const (
|
||||
maximumTaskPageNumber = maximumPageNumber * 1000
|
||||
)
|
||||
|
||||
var chunkImageMergeLocks = struct {
|
||||
sync.Mutex
|
||||
locks map[string]*chunkImageMergeLock
|
||||
}{locks: make(map[string]*chunkImageMergeLock)}
|
||||
|
||||
type chunkImageMergeLock struct {
|
||||
mu sync.Mutex
|
||||
refs int
|
||||
}
|
||||
|
||||
// ChunkService chunk service
|
||||
type ChunkService struct {
|
||||
docEngine engine.DocEngine
|
||||
@@ -74,6 +95,12 @@ type ChunkService struct {
|
||||
queueParseTasksFunc func(*entity.Document, string, string, int64) error
|
||||
beginParseDocumentFunc func(string) error
|
||||
deleteTasksByDocIDsFunc func([]string) (int64, error)
|
||||
getEmbeddingModelFunc func(string, string) (*models.EmbeddingModel, error)
|
||||
incrementChunkStatsFunc func(string, string, int64, int64, float64) error
|
||||
storeChunkImageFunc func(string, string, []byte) error
|
||||
tokenizeFunc func(string) (string, error)
|
||||
fineGrainedTokenizeFunc func(string) (string, error)
|
||||
numTokensFunc func(string) int
|
||||
}
|
||||
|
||||
// NewChunkService creates chunk service
|
||||
@@ -1711,3 +1738,382 @@ func (s *ChunkService) RemoveChunks(req *service.RemoveChunksRequest, userID str
|
||||
|
||||
return deletedCount, nil
|
||||
}
|
||||
|
||||
func (s *ChunkService) AddChunk(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) {
|
||||
if s.docEngine == nil {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: "doc engine not initialized"}
|
||||
}
|
||||
if req == nil {
|
||||
return nil, addChunkError{code: common.CodeDataError, message: "invalid request payload"}
|
||||
}
|
||||
if !s.accessible(req.DatasetID, userID) {
|
||||
return nil, addChunkError{code: common.CodeDataError, message: fmt.Sprintf("You don't own the dataset %s.", req.DatasetID)}
|
||||
}
|
||||
|
||||
kb, err := s.getKnowledgebaseByID(req.DatasetID)
|
||||
if err != nil || kb == nil {
|
||||
return nil, addChunkError{code: common.CodeDataError, message: fmt.Sprintf("You don't own the dataset %s.", req.DatasetID)}
|
||||
}
|
||||
|
||||
doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(req.DocumentID, req.DatasetID)
|
||||
if err != nil || doc == nil {
|
||||
return nil, addChunkError{code: common.CodeDataError, message: fmt.Sprintf("You don't own the document %s.", req.DocumentID)}
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(req.Content)
|
||||
if content == "" {
|
||||
return nil, addChunkError{code: common.CodeDataError, message: "`content` is required"}
|
||||
}
|
||||
|
||||
var tagFeas map[string]float64
|
||||
if req.TagFeas != nil {
|
||||
tagFeas, err = validateTagFeatures(req.TagFeas)
|
||||
if err != nil {
|
||||
return nil, addChunkError{code: common.CodeDataError, message: "`tag_feas` " + err.Error()}
|
||||
}
|
||||
}
|
||||
|
||||
chunkID := strconv.FormatUint(xxhash.Sum64([]byte(req.Content+req.DocumentID)), 16)
|
||||
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
||||
contentLtks, err := s.tokenize(req.Content)
|
||||
if err != nil {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize content: %v", err)}
|
||||
}
|
||||
contentSmLtks, err := s.fineGrainedTokenize(contentLtks)
|
||||
if err != nil {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize content fine-grained: %v", err)}
|
||||
}
|
||||
importantTks, err := s.tokenize(strings.Join(req.ImportantKeywords, " "))
|
||||
if err != nil {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize important keywords: %v", err)}
|
||||
}
|
||||
questionKwd := filterTrimmedStrings(req.Questions)
|
||||
questionTks, err := s.tokenize(strings.Join(req.Questions, "\n"))
|
||||
if err != nil {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize questions: %v", err)}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
docName := ""
|
||||
if doc.Name != nil {
|
||||
docName = *doc.Name
|
||||
}
|
||||
importantKeywords := req.ImportantKeywords
|
||||
if importantKeywords == nil {
|
||||
importantKeywords = []string{}
|
||||
}
|
||||
|
||||
chunkData := map[string]interface{}{
|
||||
"id": chunkID,
|
||||
"content_with_weight": req.Content,
|
||||
"content_ltks": contentLtks,
|
||||
"content_sm_ltks": contentSmLtks,
|
||||
"important_kwd": importantKeywords,
|
||||
"important_tks": importantTks,
|
||||
"question_kwd": questionKwd,
|
||||
"question_tks": questionTks,
|
||||
"create_time": now.Format("2006-01-02 15:04:05"),
|
||||
"create_timestamp_flt": float64(now.UnixNano()) / float64(time.Second),
|
||||
"kb_id": req.DatasetID,
|
||||
"docnm_kwd": docName,
|
||||
"doc_id": req.DocumentID,
|
||||
}
|
||||
if req.TagKwd != nil {
|
||||
chunkData["tag_kwd"] = req.TagKwd
|
||||
}
|
||||
if tagFeas != nil {
|
||||
chunkData["tag_feas"] = tagFeas
|
||||
}
|
||||
|
||||
if req.ImageBase64 != nil {
|
||||
imageBinary, err := decodeChunkImageBase64(*req.ImageBase64)
|
||||
if err != nil {
|
||||
return nil, addChunkError{code: common.CodeDataError, message: err.Error()}
|
||||
}
|
||||
if err := s.storeChunkImage(req.DatasetID, chunkID, imageBinary); err != nil {
|
||||
return nil, addChunkError{code: common.CodeDataError, message: "Failed to store chunk image"}
|
||||
}
|
||||
chunkData["img_id"] = fmt.Sprintf("%s-%s", req.DatasetID, chunkID)
|
||||
chunkData["doc_type_kwd"] = "image"
|
||||
}
|
||||
|
||||
embeddingModel, err := s.getEmbeddingModel(kb.TenantID, kb.EmbdID)
|
||||
if err != nil {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("get embedding model: %v", err)}
|
||||
}
|
||||
embeddingText := req.Content
|
||||
if len(questionKwd) > 0 {
|
||||
embeddingText = strings.Join(questionKwd, "\n")
|
||||
}
|
||||
embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{docName, embeddingText}, embeddingModel.APIConfig, &models.EmbeddingConfig{Dimension: 0})
|
||||
if err != nil {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("encode chunk embedding: %v", err)}
|
||||
}
|
||||
if len(embeddings) != 2 {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("unexpected embedding count: %d", len(embeddings))}
|
||||
}
|
||||
mergedVec, err := mergeChunkEmbeddings(embeddings[0].Embedding, embeddings[1].Embedding)
|
||||
if err != nil {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: err.Error()}
|
||||
}
|
||||
chunkData[fmt.Sprintf("q_%d_vec", len(mergedVec))] = mergedVec
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 600*time.Second)
|
||||
defer cancel()
|
||||
if _, err := s.docEngine.InsertChunks(ctx, []map[string]interface{}{chunkData}, indexName, req.DatasetID); err != nil {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("insert chunk: %v", err)}
|
||||
}
|
||||
|
||||
tokenNum := int64(s.numTokens(req.Content))
|
||||
if err := s.incrementChunkStats(req.DocumentID, req.DatasetID, tokenNum, 1, 0); err != nil {
|
||||
return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("increment chunk stats: %v", err)}
|
||||
}
|
||||
|
||||
renamedChunk := map[string]interface{}{
|
||||
"id": chunkID,
|
||||
"content": req.Content,
|
||||
"document_id": req.DocumentID,
|
||||
"document": docName,
|
||||
"important_keywords": importantKeywords,
|
||||
"questions": questionKwd,
|
||||
"dataset_id": req.DatasetID,
|
||||
"create_timestamp": chunkData["create_timestamp_flt"],
|
||||
"create_time": chunkData["create_time"],
|
||||
}
|
||||
if req.TagKwd != nil {
|
||||
renamedChunk["tag_kwd"] = req.TagKwd
|
||||
}
|
||||
if imgID, ok := chunkData["img_id"]; ok {
|
||||
renamedChunk["image_id"] = imgID
|
||||
}
|
||||
|
||||
return &service.AddChunkResponse{Chunk: renamedChunk}, nil
|
||||
}
|
||||
|
||||
type addChunkError struct {
|
||||
code common.ErrorCode
|
||||
message string
|
||||
}
|
||||
|
||||
func (e addChunkError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func (e addChunkError) Code() common.ErrorCode {
|
||||
return e.code
|
||||
}
|
||||
|
||||
func validateTagFeatures(raw interface{}) (map[string]float64, error) {
|
||||
parsed, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("must be an object mapping string tags to finite numeric scores")
|
||||
}
|
||||
cleaned := make(map[string]float64, len(parsed))
|
||||
for key, value := range parsed {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return nil, fmt.Errorf("keys must be non-empty strings")
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case float64:
|
||||
if math.IsNaN(typed) || math.IsInf(typed, 0) {
|
||||
return nil, fmt.Errorf("values must be finite numbers")
|
||||
}
|
||||
cleaned[key] = typed
|
||||
case float32:
|
||||
if math.IsNaN(float64(typed)) || math.IsInf(float64(typed), 0) {
|
||||
return nil, fmt.Errorf("values must be finite numbers")
|
||||
}
|
||||
cleaned[key] = float64(typed)
|
||||
case int:
|
||||
cleaned[key] = float64(typed)
|
||||
case int8:
|
||||
cleaned[key] = float64(typed)
|
||||
case int16:
|
||||
cleaned[key] = float64(typed)
|
||||
case int32:
|
||||
cleaned[key] = float64(typed)
|
||||
case int64:
|
||||
cleaned[key] = float64(typed)
|
||||
default:
|
||||
return nil, fmt.Errorf("values must be finite numbers")
|
||||
}
|
||||
}
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
func decodeChunkImageBase64(raw string) ([]byte, error) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil, fmt.Errorf("`image_base64` must be a non-empty string")
|
||||
}
|
||||
imageBinary, err := base64.StdEncoding.Strict().DecodeString(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Invalid `image_base64`")
|
||||
}
|
||||
if len(imageBinary) == 0 {
|
||||
return nil, fmt.Errorf("`image_base64` is empty")
|
||||
}
|
||||
return imageBinary, nil
|
||||
}
|
||||
|
||||
func mergeChunkEmbeddings(a, b []float64) ([]float64, error) {
|
||||
if len(a) == 0 || len(b) == 0 || len(a) != len(b) {
|
||||
return nil, fmt.Errorf("unexpected embedding dimensions")
|
||||
}
|
||||
merged := make([]float64, len(a))
|
||||
for i := range a {
|
||||
merged[i] = 0.1*a[i] + 0.9*b[i]
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func filterTrimmedStrings(values []string) []string {
|
||||
filtered := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed != "" {
|
||||
filtered = append(filtered, trimmed)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (s *ChunkService) tokenize(text string) (string, error) {
|
||||
if s.tokenizeFunc != nil {
|
||||
return s.tokenizeFunc(text)
|
||||
}
|
||||
return tokenizer.Tokenize(text)
|
||||
}
|
||||
|
||||
func (s *ChunkService) fineGrainedTokenize(text string) (string, error) {
|
||||
if s.fineGrainedTokenizeFunc != nil {
|
||||
return s.fineGrainedTokenizeFunc(text)
|
||||
}
|
||||
return tokenizer.FineGrainedTokenize(text)
|
||||
}
|
||||
|
||||
func (s *ChunkService) numTokens(text string) int {
|
||||
if s.numTokensFunc != nil {
|
||||
return s.numTokensFunc(text)
|
||||
}
|
||||
return tokenizer.NumTokensFromString(text)
|
||||
}
|
||||
|
||||
func (s *ChunkService) getEmbeddingModel(tenantID, embdID string) (*models.EmbeddingModel, error) {
|
||||
if s.getEmbeddingModelFunc != nil {
|
||||
return s.getEmbeddingModelFunc(tenantID, embdID)
|
||||
}
|
||||
return service.NewModelProviderService().GetEmbeddingModel(tenantID, embdID)
|
||||
}
|
||||
|
||||
func (s *ChunkService) incrementChunkStats(docID, kbID string, tokenNum, chunkNum int64, duration float64) error {
|
||||
if s.incrementChunkStatsFunc != nil {
|
||||
return s.incrementChunkStatsFunc(docID, kbID, tokenNum, chunkNum, duration)
|
||||
}
|
||||
return dao.DB.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Model(&entity.Document{}).
|
||||
Where("id = ? AND kb_id = ?", docID, kbID).
|
||||
Updates(map[string]interface{}{
|
||||
"token_num": gorm.Expr("token_num + ?", tokenNum),
|
||||
"chunk_num": gorm.Expr("chunk_num + ?", chunkNum),
|
||||
"process_duration": gorm.Expr("process_duration + ?", duration),
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("document not found")
|
||||
}
|
||||
|
||||
result = tx.Model(&entity.Knowledgebase{}).
|
||||
Where("id = ?", kbID).
|
||||
Updates(map[string]interface{}{
|
||||
"token_num": gorm.Expr("token_num + ?", tokenNum),
|
||||
"chunk_num": gorm.Expr("chunk_num + ?", chunkNum),
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("knowledgebase not found")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ChunkService) storeChunkImage(bucket, chunkID string, imageBinary []byte) error {
|
||||
if s.storeChunkImageFunc != nil {
|
||||
return s.storeChunkImageFunc(bucket, chunkID, imageBinary)
|
||||
}
|
||||
storageImpl := storage.GetStorageFactory().GetStorage()
|
||||
if storageImpl == nil {
|
||||
return fmt.Errorf("storage not initialized")
|
||||
}
|
||||
lockKey := bucket + "/" + chunkID
|
||||
lock := acquireChunkImageMergeLock(lockKey)
|
||||
lock.mu.Lock()
|
||||
defer func() {
|
||||
lock.mu.Unlock()
|
||||
releaseChunkImageMergeLock(lockKey)
|
||||
}()
|
||||
|
||||
if !storageImpl.ObjExist(bucket, chunkID) {
|
||||
return storageImpl.Put(bucket, chunkID, imageBinary)
|
||||
}
|
||||
|
||||
oldBinary, err := storageImpl.Get(bucket, chunkID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oldImage, _, err := image.Decode(bytes.NewReader(oldBinary))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newImage, _, err := image.Decode(bytes.NewReader(imageBinary))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oldBounds, newBounds := oldImage.Bounds(), newImage.Bounds()
|
||||
width := oldBounds.Dx()
|
||||
if newBounds.Dx() > width {
|
||||
width = newBounds.Dx()
|
||||
}
|
||||
height := oldBounds.Dy() + newBounds.Dy()
|
||||
combined := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
draw.Draw(combined, combined.Bounds(), &image.Uniform{C: color.White}, image.Point{}, draw.Src)
|
||||
draw.Draw(combined, oldBounds, oldImage, oldBounds.Min, draw.Src)
|
||||
draw.Draw(combined, image.Rect(0, oldBounds.Dy(), newBounds.Dx(), oldBounds.Dy()+newBounds.Dy()), newImage, newBounds.Min, draw.Src)
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := jpeg.Encode(&buf, combined, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
return storageImpl.Put(bucket, chunkID, buf.Bytes())
|
||||
}
|
||||
|
||||
func acquireChunkImageMergeLock(key string) *chunkImageMergeLock {
|
||||
chunkImageMergeLocks.Lock()
|
||||
defer chunkImageMergeLocks.Unlock()
|
||||
|
||||
lock := chunkImageMergeLocks.locks[key]
|
||||
if lock == nil {
|
||||
lock = &chunkImageMergeLock{}
|
||||
chunkImageMergeLocks.locks[key] = lock
|
||||
}
|
||||
lock.refs++
|
||||
return lock
|
||||
}
|
||||
|
||||
func releaseChunkImageMergeLock(key string) {
|
||||
chunkImageMergeLocks.Lock()
|
||||
defer chunkImageMergeLocks.Unlock()
|
||||
|
||||
lock := chunkImageMergeLocks.locks[key]
|
||||
if lock == nil {
|
||||
return
|
||||
}
|
||||
lock.refs--
|
||||
if lock.refs == 0 {
|
||||
delete(chunkImageMergeLocks.locks, key)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
package chunk
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine/types"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/entity/models"
|
||||
"ragflow/internal/service"
|
||||
"ragflow/internal/storage"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
@@ -412,6 +419,305 @@ func TestParseQueuesAndBeginsDocument(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddChunkSuccess(t *testing.T) {
|
||||
db := setupChunkTestDB(t)
|
||||
pushChunkTestDB(t, db)
|
||||
userID, datasetID, documentID := "user-1", "kb-1", "doc-1"
|
||||
insertChunkTestKB(t, datasetID, userID)
|
||||
insertChunkTestDoc(t, documentID, datasetID)
|
||||
|
||||
engine := &addChunkTestEngine{}
|
||||
var incrementTokenNum, incrementChunkNum int64
|
||||
svc := &ChunkService{
|
||||
docEngine: engine,
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
documentDAO: dao.NewDocumentDAO(),
|
||||
accessibleFunc: func(datasetIDArg, userIDArg string) bool {
|
||||
return datasetIDArg == datasetID && userIDArg == userID
|
||||
},
|
||||
getKnowledgebaseByIDFunc: func(id string) (*entity.Knowledgebase, error) {
|
||||
return &entity.Knowledgebase{ID: id, TenantID: userID, EmbdID: "embed-1"}, nil
|
||||
},
|
||||
getEmbeddingModelFunc: func(string, string) (*models.EmbeddingModel, error) {
|
||||
driver := &stubEmbeddingDriver{
|
||||
embeddings: []models.EmbeddingData{
|
||||
{Embedding: []float64{1, 2}},
|
||||
{Embedding: []float64{3, 4}},
|
||||
},
|
||||
}
|
||||
modelName := "embed-1"
|
||||
return models.NewEmbeddingModel(driver, &modelName, &models.APIConfig{}, 0), nil
|
||||
},
|
||||
incrementChunkStatsFunc: func(docID, kbID string, tokenNum, chunkNum int64, duration float64) error {
|
||||
if docID != documentID || kbID != datasetID || duration != 0 {
|
||||
t.Fatalf("unexpected increment args doc=%s kb=%s duration=%v", docID, kbID, duration)
|
||||
}
|
||||
incrementTokenNum = tokenNum
|
||||
incrementChunkNum = chunkNum
|
||||
return nil
|
||||
},
|
||||
tokenizeFunc: func(text string) (string, error) { return text, nil },
|
||||
fineGrainedTokenizeFunc: func(text string) (string, error) { return text + "_fg", nil },
|
||||
numTokensFunc: func(text string) int { return len(text) },
|
||||
}
|
||||
|
||||
resp, err := svc.AddChunk(&service.AddChunkRequest{
|
||||
DatasetID: datasetID,
|
||||
DocumentID: documentID,
|
||||
Content: "chunk body",
|
||||
ImportantKeywords: []string{"k1"},
|
||||
Questions: []string{" q1 ", ""},
|
||||
TagKwd: []string{"tag1"},
|
||||
TagFeas: map[string]interface{}{"tag1": float64(0.5)},
|
||||
}, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("AddChunk() error = %v", err)
|
||||
}
|
||||
if resp == nil || resp.Chunk == nil {
|
||||
t.Fatalf("expected chunk response, got %#v", resp)
|
||||
}
|
||||
if resp.Chunk["dataset_id"] != datasetID || resp.Chunk["document_id"] != documentID {
|
||||
t.Fatalf("unexpected response chunk: %#v", resp.Chunk)
|
||||
}
|
||||
if resp.Chunk["content"] != "chunk body" {
|
||||
t.Fatalf("content = %v, want chunk body", resp.Chunk["content"])
|
||||
}
|
||||
if resp.Chunk["document"] != "doc-1.txt" {
|
||||
t.Fatalf("document = %v, want doc-1.txt", resp.Chunk["document"])
|
||||
}
|
||||
if incrementChunkNum != 1 {
|
||||
t.Fatalf("increment chunk num = %d, want 1", incrementChunkNum)
|
||||
}
|
||||
if incrementTokenNum <= 0 {
|
||||
t.Fatalf("increment token num = %d, want > 0", incrementTokenNum)
|
||||
}
|
||||
if len(engine.insertedChunks) != 1 {
|
||||
t.Fatalf("inserted chunks = %d, want 1", len(engine.insertedChunks))
|
||||
}
|
||||
inserted := engine.insertedChunks[0]
|
||||
if inserted["doc_id"] != documentID || inserted["kb_id"] != datasetID {
|
||||
t.Fatalf("unexpected inserted chunk: %#v", inserted)
|
||||
}
|
||||
if inserted["img_id"] != nil {
|
||||
t.Fatalf("did not expect image id in inserted chunk: %#v", inserted)
|
||||
}
|
||||
vec, ok := inserted["q_2_vec"].([]float64)
|
||||
if !ok {
|
||||
t.Fatalf("expected q_2_vec []float64, got %T", inserted["q_2_vec"])
|
||||
}
|
||||
if len(vec) != 2 || vec[0] < 2.7999 || vec[0] > 2.8001 || vec[1] < 3.7999 || vec[1] > 3.8001 {
|
||||
t.Fatalf("vector = %v, want approximately [2.8 3.8]", vec)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddChunkValidationErrors(t *testing.T) {
|
||||
db := setupChunkTestDB(t)
|
||||
pushChunkTestDB(t, db)
|
||||
insertChunkTestKB(t, "kb-1", "user-1")
|
||||
insertChunkTestDoc(t, "doc-1", "kb-1")
|
||||
|
||||
svc := &ChunkService{
|
||||
docEngine: &addChunkTestEngine{},
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
documentDAO: dao.NewDocumentDAO(),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req *service.AddChunkRequest
|
||||
setup func(*ChunkService)
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "nil request",
|
||||
req: nil,
|
||||
wantMsg: "invalid request payload",
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
req: &service.AddChunkRequest{DatasetID: "kb-1", DocumentID: "doc-1", Content: " "},
|
||||
setup: func(svc *ChunkService) {
|
||||
svc.accessibleFunc = func(string, string) bool { return true }
|
||||
svc.getKnowledgebaseByIDFunc = func(string) (*entity.Knowledgebase, error) {
|
||||
return &entity.Knowledgebase{ID: "kb-1", TenantID: "user-1", EmbdID: "embed-1"}, nil
|
||||
}
|
||||
},
|
||||
wantMsg: "`content` is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setup != nil {
|
||||
tt.setup(svc)
|
||||
}
|
||||
_, err := svc.AddChunk(tt.req, "user-1")
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantMsg) {
|
||||
t.Fatalf("error = %v, want substring %q", err, tt.wantMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddChunkImageAndTagFeatureValidation(t *testing.T) {
|
||||
db := setupChunkTestDB(t)
|
||||
pushChunkTestDB(t, db)
|
||||
userID, datasetID, documentID := "user-1", "kb-1", "doc-1"
|
||||
insertChunkTestKB(t, datasetID, userID)
|
||||
insertChunkTestDoc(t, documentID, datasetID)
|
||||
|
||||
storeCalls := 0
|
||||
svc := &ChunkService{
|
||||
docEngine: &addChunkTestEngine{},
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
documentDAO: dao.NewDocumentDAO(),
|
||||
accessibleFunc: func(string, string) bool { return true },
|
||||
getKnowledgebaseByIDFunc: func(id string) (*entity.Knowledgebase, error) {
|
||||
return &entity.Knowledgebase{ID: id, TenantID: userID, EmbdID: "embed-1"}, nil
|
||||
},
|
||||
tokenizeFunc: func(text string) (string, error) { return text, nil },
|
||||
fineGrainedTokenizeFunc: func(text string) (string, error) { return text + "_fg", nil },
|
||||
numTokensFunc: func(text string) int { return len(text) },
|
||||
getEmbeddingModelFunc: func(string, string) (*models.EmbeddingModel, error) {
|
||||
driver := &stubEmbeddingDriver{
|
||||
embeddings: []models.EmbeddingData{
|
||||
{Embedding: []float64{1, 1}},
|
||||
{Embedding: []float64{1, 1}},
|
||||
},
|
||||
}
|
||||
modelName := "embed-1"
|
||||
return models.NewEmbeddingModel(driver, &modelName, &models.APIConfig{}, 0), nil
|
||||
},
|
||||
incrementChunkStatsFunc: func(string, string, int64, int64, float64) error { return nil },
|
||||
storeChunkImageFunc: func(bucket, chunkID string, imageBinary []byte) error {
|
||||
storeCalls++
|
||||
if bucket != datasetID || chunkID == "" || len(imageBinary) == 0 {
|
||||
t.Fatalf("unexpected store args bucket=%s chunkID=%s len=%d", bucket, chunkID, len(imageBinary))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
_, err := svc.AddChunk(&service.AddChunkRequest{
|
||||
DatasetID: datasetID,
|
||||
DocumentID: documentID,
|
||||
Content: "chunk body",
|
||||
ImageBase64: strPtr("not-base64"),
|
||||
}, userID)
|
||||
if err == nil || !strings.Contains(err.Error(), "Invalid `image_base64`") {
|
||||
t.Fatalf("expected invalid image error, got %v", err)
|
||||
}
|
||||
|
||||
validJPEG := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO2pRZ0AAAAASUVORK5CYII="
|
||||
|
||||
resp, err := svc.AddChunk(&service.AddChunkRequest{
|
||||
DatasetID: datasetID,
|
||||
DocumentID: documentID,
|
||||
Content: "chunk body",
|
||||
TagFeas: map[string]interface{}{"tag1": "bad"},
|
||||
}, userID)
|
||||
if err == nil || !strings.Contains(err.Error(), "`tag_feas` values must be finite numbers") || resp != nil {
|
||||
t.Fatalf("expected tag_feas validation error, got resp=%#v err=%v", resp, err)
|
||||
}
|
||||
|
||||
resp, err = svc.AddChunk(&service.AddChunkRequest{
|
||||
DatasetID: datasetID,
|
||||
DocumentID: documentID,
|
||||
Content: "chunk body",
|
||||
TagFeas: map[string]interface{}{"tag1": float64(1)},
|
||||
ImageBase64: strPtr(validJPEG),
|
||||
}, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("AddChunk() with image error = %v", err)
|
||||
}
|
||||
if storeCalls != 1 {
|
||||
t.Fatalf("store image calls = %d, want 1", storeCalls)
|
||||
}
|
||||
if _, ok := resp.Chunk["image_id"]; !ok {
|
||||
t.Fatalf("expected image_id in response, got %#v", resp.Chunk)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddChunkIncrementsStatsAfterInsert(t *testing.T) {
|
||||
db := setupChunkTestDB(t)
|
||||
pushChunkTestDB(t, db)
|
||||
userID, datasetID, documentID := "user-1", "kb-1", "doc-1"
|
||||
insertChunkTestKB(t, datasetID, userID)
|
||||
insertChunkTestDoc(t, documentID, datasetID)
|
||||
|
||||
var incrementCalls int
|
||||
engine := &addChunkTestEngine{}
|
||||
svc := &ChunkService{
|
||||
docEngine: engine,
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
documentDAO: dao.NewDocumentDAO(),
|
||||
accessibleFunc: func(string, string) bool { return true },
|
||||
getKnowledgebaseByIDFunc: func(id string) (*entity.Knowledgebase, error) {
|
||||
return &entity.Knowledgebase{ID: id, TenantID: userID, EmbdID: "embed-1"}, nil
|
||||
},
|
||||
getEmbeddingModelFunc: func(string, string) (*models.EmbeddingModel, error) {
|
||||
driver := &stubEmbeddingDriver{
|
||||
embeddings: []models.EmbeddingData{
|
||||
{Embedding: []float64{1, 2}},
|
||||
{Embedding: []float64{3, 4}},
|
||||
},
|
||||
}
|
||||
modelName := "embed-1"
|
||||
return models.NewEmbeddingModel(driver, &modelName, &models.APIConfig{}, 0), nil
|
||||
},
|
||||
incrementChunkStatsFunc: func(string, string, int64, int64, float64) error {
|
||||
incrementCalls++
|
||||
return nil
|
||||
},
|
||||
tokenizeFunc: func(text string) (string, error) { return text, nil },
|
||||
fineGrainedTokenizeFunc: func(text string) (string, error) { return text + "_fg", nil },
|
||||
numTokensFunc: func(text string) int { return len(text) },
|
||||
}
|
||||
|
||||
_, err := svc.AddChunk(&service.AddChunkRequest{
|
||||
DatasetID: datasetID,
|
||||
DocumentID: documentID,
|
||||
Content: "chunk body",
|
||||
}, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("AddChunk() error = %v", err)
|
||||
}
|
||||
if incrementCalls != 1 {
|
||||
t.Fatalf("increment calls = %d, want 1", incrementCalls)
|
||||
}
|
||||
if len(engine.insertedChunks) != 1 {
|
||||
t.Fatalf("inserted chunks = %d, want 1", len(engine.insertedChunks))
|
||||
}
|
||||
importantKwd, ok := engine.insertedChunks[0]["important_kwd"].([]string)
|
||||
if !ok || len(importantKwd) != 0 {
|
||||
t.Fatalf("important_kwd = %#v, want empty []string", engine.insertedChunks[0]["important_kwd"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreChunkImageMergesExistingImage(t *testing.T) {
|
||||
oldImage := mustEncodePNG(t, image.Rect(0, 0, 2, 2))
|
||||
newImage := mustEncodePNG(t, image.Rect(0, 0, 1, 1))
|
||||
mockStorage := &chunkImageStorage{
|
||||
exists: true,
|
||||
oldBinary: oldImage,
|
||||
}
|
||||
|
||||
factory := storage.GetStorageFactory()
|
||||
originalStorage := factory.GetStorage()
|
||||
factory.SetStorage(mockStorage)
|
||||
t.Cleanup(func() {
|
||||
factory.SetStorage(originalStorage)
|
||||
})
|
||||
|
||||
svc := &ChunkService{}
|
||||
if err := svc.storeChunkImage("kb-1", "chunk-1", newImage); err != nil {
|
||||
t.Fatalf("storeChunkImage() error = %v", err)
|
||||
}
|
||||
if mockStorage.putCalls != 1 {
|
||||
t.Fatalf("put calls = %d, want 1", mockStorage.putCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func setupChunkTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
@@ -620,6 +926,117 @@ func (e *parseTestDocEngine) GetChunkIDs([]map[string]interface{}) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
type addChunkTestEngine struct {
|
||||
parseTestDocEngine
|
||||
insertedChunks []map[string]interface{}
|
||||
insertIndex string
|
||||
insertDataset string
|
||||
insertErr error
|
||||
}
|
||||
|
||||
func (e *addChunkTestEngine) InsertChunks(_ context.Context, chunks []map[string]interface{}, baseName string, datasetID string) ([]string, error) {
|
||||
e.insertedChunks = chunks
|
||||
e.insertIndex = baseName
|
||||
e.insertDataset = datasetID
|
||||
return nil, e.insertErr
|
||||
}
|
||||
|
||||
type chunkImageStorage struct {
|
||||
exists bool
|
||||
oldBinary []byte
|
||||
putCalls int
|
||||
}
|
||||
|
||||
func (s *chunkImageStorage) Health() bool { return true }
|
||||
func (s *chunkImageStorage) Put(bucket, fnm string, binary []byte, tenantID ...string) error {
|
||||
s.putCalls++
|
||||
return nil
|
||||
}
|
||||
func (s *chunkImageStorage) Get(bucket, fnm string, tenantID ...string) ([]byte, error) {
|
||||
return s.oldBinary, nil
|
||||
}
|
||||
func (s *chunkImageStorage) Remove(bucket, fnm string, tenantID ...string) error { return nil }
|
||||
func (s *chunkImageStorage) ObjExist(bucket, fnm string, tenantID ...string) bool { return s.exists }
|
||||
func (s *chunkImageStorage) GetPresignedURL(bucket, fnm string, expires time.Duration, tenantID ...string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (s *chunkImageStorage) BucketExists(bucket string) bool { return true }
|
||||
func (s *chunkImageStorage) RemoveBucket(bucket string) error { return nil }
|
||||
func (s *chunkImageStorage) Copy(srcBucket, srcPath, destBucket, destPath string) bool { return false }
|
||||
func (s *chunkImageStorage) Move(srcBucket, srcPath, destBucket, destPath string) bool { return false }
|
||||
|
||||
func mustEncodePNG(t *testing.T, rect image.Rectangle) []byte {
|
||||
t.Helper()
|
||||
|
||||
img := image.NewRGBA(rect)
|
||||
for y := rect.Min.Y; y < rect.Max.Y; y++ {
|
||||
for x := rect.Min.X; x < rect.Max.X; x++ {
|
||||
img.Set(x, y, color.White)
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
t.Fatalf("encode png: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
type stubEmbeddingDriver struct {
|
||||
embeddings []models.EmbeddingData
|
||||
embedErr error
|
||||
}
|
||||
|
||||
func (d *stubEmbeddingDriver) NewInstance(map[string]string) models.ModelDriver { return d }
|
||||
func (d *stubEmbeddingDriver) Name() string { return "stub" }
|
||||
func (d *stubEmbeddingDriver) ChatWithMessages(string, []models.Message, *models.APIConfig, *models.ChatConfig) (*models.ChatResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) ChatStreamlyWithSender(string, []models.Message, *models.APIConfig, *models.ChatConfig, func(*string, *string) error) error {
|
||||
return nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) Embed(*string, []string, *models.APIConfig, *models.EmbeddingConfig) ([]models.EmbeddingData, error) {
|
||||
return d.embeddings, d.embedErr
|
||||
}
|
||||
func (d *stubEmbeddingDriver) Rerank(*string, string, []string, *models.APIConfig, *models.RerankConfig) (*models.RerankResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) TranscribeAudio(*string, *string, *models.APIConfig, *models.ASRConfig) (*models.ASRResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) TranscribeAudioWithSender(*string, *string, *models.APIConfig, *models.ASRConfig, func(*string, *string) error) error {
|
||||
return nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) AudioSpeech(*string, *string, *models.APIConfig, *models.TTSConfig) (*models.TTSResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) AudioSpeechWithSender(*string, *string, *models.APIConfig, *models.TTSConfig, func(*string, *string) error) error {
|
||||
return nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) OCRFile(*string, []byte, *string, *models.APIConfig, *models.OCRConfig) (*models.OCRFileResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) ParseFile(*string, []byte, *string, *models.APIConfig, *models.ParseFileConfig) (*models.ParseFileResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) ListModels(*models.APIConfig) ([]models.ListModelResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) Balance(*models.APIConfig) (map[string]interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) CheckConnection(*models.APIConfig) error { return nil }
|
||||
func (d *stubEmbeddingDriver) ListTasks(*models.APIConfig) ([]models.ListTaskStatus, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (d *stubEmbeddingDriver) ShowTask(string, *models.APIConfig) (*models.TaskResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func strPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func (e *parseTestDocEngine) KNNScores(context.Context, []map[string]interface{}, []float64, int) (map[string]interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -95,6 +95,29 @@ type ParseFileRequest struct {
|
||||
DocumentIDs []string `json:"document_ids"`
|
||||
}
|
||||
|
||||
// AddChunkRequest request for adding a chunk
|
||||
type AddChunkRequest struct {
|
||||
DatasetID string `json:"dataset_id"`
|
||||
DocumentID string `json:"document_id"`
|
||||
Content string `json:"content"`
|
||||
ImportantKeywords []string `json:"important_keywords,omitempty"`
|
||||
Questions []string `json:"questions,omitempty"`
|
||||
TagKwd []string `json:"tag_kwd,omitempty"`
|
||||
TagFeas interface{} `json:"tag_feas,omitempty"`
|
||||
ImageBase64 *string `json:"image_base64,omitempty"`
|
||||
}
|
||||
|
||||
// AddChunkResponse response for adding a chunk
|
||||
type AddChunkResponse struct {
|
||||
Chunk map[string]interface{} `json:"chunk"`
|
||||
}
|
||||
|
||||
// ErrorCoder exposes an application error code alongside an error string.
|
||||
type ErrorCoder interface {
|
||||
error
|
||||
Code() common.ErrorCode
|
||||
}
|
||||
|
||||
// Get retrieves a chunk by ID
|
||||
func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkResponse, error) {
|
||||
if s.docEngine == nil {
|
||||
|
||||
Reference in New Issue
Block a user