From 54fb5b0fa779db73db25f26f473ee9a838d712e2 Mon Sep 17 00:00:00 2001 From: Hz_ Date: Thu, 25 Jun 2026 14:15:29 +0800 Subject: [PATCH] feat(go-api): add Go support for POST /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks (#16256) ## Summary Add the Go implementation of `POST /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks`. This wires the full create-chunk path in Go: - router and handler registration - request/response structs - chunk creation service logic - embedding generation - chunk insert into doc engine - chunk/token counter increment - `tag_feas` validation - `image_base64` decoding and chunk image storage/merge - unit tests for handler and service ## Testing Unit tests: - `/usr/local/go/bin/go test ./internal/handler` - `/usr/local/go/bin/go test ./internal/service/chunk` - `/usr/local/go/bin/go test ./internal/service` - `/usr/local/go/bin/go test ./...` All passed locally. Manual curl checks: - basic text chunk: Go passed - chunk with `important_keywords` / `questions` / `tag_kwd` / `tag_feas`: Go passed - blank content validation: Go matched expected `code=102` - invalid `image_base64` validation: Go matched expected `code=102` - image upload and repeated image upload / merge path: Go passed twice --- internal/handler/chunk.go | 130 ++++++ internal/handler/chunk_test.go | 198 +++++++++ .../pregel_stream_retry_integration_test.go | 9 +- internal/router/router.go | 1 + internal/service/chunk/chunk.go | 406 +++++++++++++++++ internal/service/chunk/chunk_test.go | 417 ++++++++++++++++++ internal/service/chunk_types.go | 23 + 7 files changed, 1179 insertions(+), 5 deletions(-) diff --git a/internal/handler/chunk.go b/internal/handler/chunk.go index fd667adb81..b1b53bd502 100644 --- a/internal/handler/chunk.go +++ b/internal/handler/chunk.go @@ -17,6 +17,7 @@ package handler import ( "encoding/json" + "errors" "fmt" "net/http" "ragflow/internal/common" @@ -38,6 +39,7 @@ type chunkService interface { UpdateChunk(req *service.UpdateChunkRequest, userID string) error RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error) Parse(userID, datasetID string, req *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error) + AddChunk(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) StopParsing(userID, datasetID string, req service.StopParsingRequest) (*service.StopParsingResponse, common.ErrorCode, error) } @@ -788,3 +790,131 @@ func (h *ChunkHandler) RemoveChunks(c *gin.Context) { "message": "success", }) } + +func addChunkStringField(rawBody map[string]json.RawMessage, field string) (string, error) { + raw, ok := rawBody[field] + if !ok { + return "", nil + } + var value string + if err := json.Unmarshal(raw, &value); err != nil { + return "", fmt.Errorf("`%s` must be a string", field) + } + return value, nil +} + +func addChunkStringPtrField(rawBody map[string]json.RawMessage, field string) (*string, error) { + raw, ok := rawBody[field] + if !ok { + return nil, nil + } + var value string + if err := json.Unmarshal(raw, &value); err != nil { + return nil, fmt.Errorf("`%s` must be a string", field) + } + return &value, nil +} + +func addChunkStringListField(rawBody map[string]json.RawMessage, field, listMessage, elementMessage string) ([]string, error) { + raw, ok := rawBody[field] + if !ok { + return nil, nil + } + var values []interface{} + if err := json.Unmarshal(raw, &values); err != nil { + return nil, errors.New(listMessage) + } + result := make([]string, len(values)) + for i, value := range values { + str, ok := value.(string) + if !ok { + return nil, errors.New(elementMessage) + } + result[i] = str + } + return result, nil +} + +func addChunkResponseMessage(code common.ErrorCode, err error) string { + if code == common.CodeServerError { + common.Warn("add chunk failed", zap.String("error", err.Error())) + return "Failed to add chunk" + } + return err.Error() +} + +func (h *ChunkHandler) AddChunk(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := user.ID + datasetID, documentID := strings.TrimSpace(c.Param("dataset_id")), strings.TrimSpace(c.Param("document_id")) + + var rawBody map[string]json.RawMessage + if err := json.NewDecoder(c.Request.Body).Decode(&rawBody); err != nil { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + content, err := addChunkStringField(rawBody, "content") + if err != nil { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + importantKeywords, err := addChunkStringListField(rawBody, "important_keywords", "`important_keywords` is required to be a list", "`important_keywords` must be a list of strings") + if err != nil { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + questions, err := addChunkStringListField(rawBody, "questions", "`questions` is required to be a list", "`questions` must be a list of strings") + if err != nil { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + tagKwd, err := addChunkStringListField(rawBody, "tag_kwd", "`tag_kwd` is required to be a list", "`tag_kwd` must be a list of strings") + if err != nil { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + imageBase64, err := addChunkStringPtrField(rawBody, "image_base64") + if err != nil { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + var tagFeas interface{} + if raw, ok := rawBody["tag_feas"]; ok { + if err := json.Unmarshal(raw, &tagFeas); err != nil { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + } + + req := service.AddChunkRequest{ + DatasetID: datasetID, + DocumentID: documentID, + Content: content, + ImportantKeywords: importantKeywords, + Questions: questions, + TagKwd: tagKwd, + TagFeas: tagFeas, + ImageBase64: imageBase64, + } + + resp, err := h.chunkService.AddChunk(&req, userID) + if err != nil { + if codedErr, ok := err.(service.ErrorCoder); ok { + jsonError(c, codedErr.Code(), addChunkResponseMessage(codedErr.Code(), err)) + return + } + jsonError(c, common.CodeServerError, addChunkResponseMessage(common.CodeServerError, err)) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": resp, + "message": "success", + }) +} diff --git a/internal/handler/chunk_test.go b/internal/handler/chunk_test.go index a5ced9f270..597238386c 100644 --- a/internal/handler/chunk_test.go +++ b/internal/handler/chunk_test.go @@ -20,6 +20,7 @@ import ( // Only the methods actually called by the test are set; others panic. type mockChunkSvc struct { retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) + addChunkFn func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) listFn func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error) switchChunksFn func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error updateChunkFn func(req *service.UpdateChunkRequest, userID string) error @@ -68,6 +69,12 @@ func (m *mockChunkSvc) StopParsing(userID, datasetID string, req service.StopPar func (m *mockChunkSvc) Parse(string, string, *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error) { panic("not implemented") } +func (m *mockChunkSvc) AddChunk(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) { + if m.addChunkFn != nil { + return m.addChunkFn(req, userID) + } + return &service.AddChunkResponse{Chunk: map[string]interface{}{"id": "chunk-1"}}, nil +} func setupChunkRetrievalTest(userID string) (*gin.Engine, *mockChunkSvc) { mock := &mockChunkSvc{} @@ -563,3 +570,194 @@ func TestChunkRetrieval_ServiceError(t *testing.T) { t.Errorf("internal error details leaked to response: %q", msg) } } + +type addChunkTestError struct { + code common.ErrorCode + msg string +} + +func (e addChunkTestError) Error() string { return e.msg } +func (e addChunkTestError) Code() common.ErrorCode { return e.code } + +func TestChunkHandlerAddChunkSuccess(t *testing.T) { + mock := &mockChunkSvc{ + addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) { + if userID != "user1" { + t.Fatalf("userID = %q, want user1", userID) + } + if req.DatasetID != "kb1" || req.DocumentID != "doc1" || req.Content != "chunk body" { + t.Fatalf("unexpected request: %#v", req) + } + return &service.AddChunkResponse{Chunk: map[string]interface{}{"id": "chunk-1", "content": req.Content}}, nil + }, + } + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "user1"}) + }) + h := &ChunkHandler{chunkService: mock} + r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"content":"chunk body"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["code"] != float64(common.CodeSuccess) { + t.Fatalf("expected success code, got %v", resp["code"]) + } +} + +func TestChunkHandlerAddChunkPathIDsOverrideBody(t *testing.T) { + mock := &mockChunkSvc{ + addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) { + if req.DatasetID != "kb1" || req.DocumentID != "doc1" { + t.Fatalf("path IDs were not preserved: %#v", req) + } + if req.Content != "chunk body" { + t.Fatalf("unexpected content: %#v", req) + } + return &service.AddChunkResponse{Chunk: map[string]interface{}{"id": "chunk-1"}}, nil + }, + } + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "user1"}) + }) + h := &ChunkHandler{chunkService: mock} + r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"dataset_id":"evil-kb","document_id":"evil-doc","content":"chunk body"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestChunkHandlerAddChunkCodedError(t *testing.T) { + mock := &mockChunkSvc{ + addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) { + return nil, addChunkTestError{code: common.CodeDataError, msg: "`content` is required"} + }, + } + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "user1"}) + }) + h := &ChunkHandler{chunkService: mock} + r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"content":" "}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["code"] != float64(common.CodeDataError) { + t.Fatalf("expected data error code, got %v", resp["code"]) + } +} + +func TestChunkHandlerAddChunkValidatesListFields(t *testing.T) { + tests := []struct { + name string + body string + wantMsg string + }{ + { + name: "important keywords type", + body: `{"content":"chunk body","important_keywords":{}}`, + wantMsg: "`important_keywords` is required to be a list", + }, + { + name: "tag kwd element type", + body: `{"content":"chunk body","tag_kwd":[1]}`, + wantMsg: "`tag_kwd` must be a list of strings", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockChunkSvc{ + addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) { + t.Fatal("service should not be called for invalid request") + return nil, nil + }, + } + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "user1"}) + }) + h := &ChunkHandler{chunkService: mock} + r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["message"] != tt.wantMsg { + t.Fatalf("message = %v, want %q", resp["message"], tt.wantMsg) + } + }) + } +} + +func TestChunkHandlerAddChunkHidesServerErrorDetails(t *testing.T) { + mock := &mockChunkSvc{ + addChunkFn: func(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) { + return nil, addChunkTestError{code: common.CodeServerError, msg: "encode chunk embedding: provider secret"} + }, + } + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "user1"}) + }) + h := &ChunkHandler{chunkService: mock} + r.POST("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.AddChunk) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/datasets/kb1/documents/doc1/chunks", strings.NewReader(`{"content":"chunk body"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["message"] != "Failed to add chunk" { + t.Fatalf("message = %v, want generic failure", resp["message"]) + } + if strings.Contains(w.Body.String(), "provider secret") { + t.Fatalf("server error details leaked: %s", w.Body.String()) + } +} diff --git a/internal/harness/graph/pregel/pregel_stream_retry_integration_test.go b/internal/harness/graph/pregel/pregel_stream_retry_integration_test.go index b1527a2933..98f283c3ae 100644 --- a/internal/harness/graph/pregel/pregel_stream_retry_integration_test.go +++ b/internal/harness/graph/pregel/pregel_stream_retry_integration_test.go @@ -367,9 +367,6 @@ func TestEngine_ChainOf100(t *testing.T) { if m == nil { m = map[string]any{} } - if m == nil { - m = map[string]any{} - } m["value"] = i return m, nil }) @@ -380,14 +377,16 @@ func TestEngine_ChainOf100(t *testing.T) { engine := NewEngine(sg, WithRecursionLimit(150), - WithMaxConcurrency(4), ) result, err := engine.RunSync(context.Background(), map[string]any{"value": "start"}) if err != nil { t.Fatalf("RunSync: %v", err) } - m := result.(map[string]any) + m, ok := result.(map[string]any) + if !ok { + t.Fatalf("expected map[string]any result, got %T", result) + } if v, ok := m["value"]; !ok || v.(int) != 99 { t.Fatalf("expected value=99, got %v", m["value"]) } diff --git a/internal/router/router.go b/internal/router/router.go index 975666f26a..ab72024798 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -310,6 +310,7 @@ func (r *Router) Setup(engine *gin.Engine) { datasets.GET("/:dataset_id/documents/:document_id", r.documentHandler.DownloadDocument) datasets.PATCH("/:dataset_id/documents/:document_id", r.documentHandler.UpdateDatasetDocument) datasets.DELETE("/:dataset_id/documents", r.documentHandler.DeleteDocuments) + datasets.POST("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.AddChunk) // Dataset document chunk datasets.GET("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.ListChunks) diff --git a/internal/service/chunk/chunk.go b/internal/service/chunk/chunk.go index 83e58cd917..36b05b5657 100644 --- a/internal/service/chunk/chunk.go +++ b/internal/service/chunk/chunk.go @@ -20,11 +20,17 @@ import ( "archive/zip" "bytes" "context" + "encoding/base64" "encoding/csv" "encoding/json" "encoding/xml" "fmt" + "image" + "image/color" + "image/draw" + "image/jpeg" "io" + "math" "math/rand" "path/filepath" "ragflow/internal/common" @@ -36,10 +42,15 @@ import ( "sort" "strconv" "strings" + "sync" "time" + _ "image/gif" + _ "image/png" + "github.com/cespare/xxhash/v2" "go.uber.org/zap" + "gorm.io/gorm" "ragflow/internal/dao" "ragflow/internal/engine" @@ -56,6 +67,16 @@ const ( maximumTaskPageNumber = maximumPageNumber * 1000 ) +var chunkImageMergeLocks = struct { + sync.Mutex + locks map[string]*chunkImageMergeLock +}{locks: make(map[string]*chunkImageMergeLock)} + +type chunkImageMergeLock struct { + mu sync.Mutex + refs int +} + // ChunkService chunk service type ChunkService struct { docEngine engine.DocEngine @@ -74,6 +95,12 @@ type ChunkService struct { queueParseTasksFunc func(*entity.Document, string, string, int64) error beginParseDocumentFunc func(string) error deleteTasksByDocIDsFunc func([]string) (int64, error) + getEmbeddingModelFunc func(string, string) (*models.EmbeddingModel, error) + incrementChunkStatsFunc func(string, string, int64, int64, float64) error + storeChunkImageFunc func(string, string, []byte) error + tokenizeFunc func(string) (string, error) + fineGrainedTokenizeFunc func(string) (string, error) + numTokensFunc func(string) int } // NewChunkService creates chunk service @@ -1711,3 +1738,382 @@ func (s *ChunkService) RemoveChunks(req *service.RemoveChunksRequest, userID str return deletedCount, nil } + +func (s *ChunkService) AddChunk(req *service.AddChunkRequest, userID string) (*service.AddChunkResponse, error) { + if s.docEngine == nil { + return nil, addChunkError{code: common.CodeServerError, message: "doc engine not initialized"} + } + if req == nil { + return nil, addChunkError{code: common.CodeDataError, message: "invalid request payload"} + } + if !s.accessible(req.DatasetID, userID) { + return nil, addChunkError{code: common.CodeDataError, message: fmt.Sprintf("You don't own the dataset %s.", req.DatasetID)} + } + + kb, err := s.getKnowledgebaseByID(req.DatasetID) + if err != nil || kb == nil { + return nil, addChunkError{code: common.CodeDataError, message: fmt.Sprintf("You don't own the dataset %s.", req.DatasetID)} + } + + doc, err := s.documentDAO.GetByDocumentIDAndDatasetID(req.DocumentID, req.DatasetID) + if err != nil || doc == nil { + return nil, addChunkError{code: common.CodeDataError, message: fmt.Sprintf("You don't own the document %s.", req.DocumentID)} + } + + content := strings.TrimSpace(req.Content) + if content == "" { + return nil, addChunkError{code: common.CodeDataError, message: "`content` is required"} + } + + var tagFeas map[string]float64 + if req.TagFeas != nil { + tagFeas, err = validateTagFeatures(req.TagFeas) + if err != nil { + return nil, addChunkError{code: common.CodeDataError, message: "`tag_feas` " + err.Error()} + } + } + + chunkID := strconv.FormatUint(xxhash.Sum64([]byte(req.Content+req.DocumentID)), 16) + indexName := fmt.Sprintf("ragflow_%s", kb.TenantID) + contentLtks, err := s.tokenize(req.Content) + if err != nil { + return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize content: %v", err)} + } + contentSmLtks, err := s.fineGrainedTokenize(contentLtks) + if err != nil { + return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize content fine-grained: %v", err)} + } + importantTks, err := s.tokenize(strings.Join(req.ImportantKeywords, " ")) + if err != nil { + return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize important keywords: %v", err)} + } + questionKwd := filterTrimmedStrings(req.Questions) + questionTks, err := s.tokenize(strings.Join(req.Questions, "\n")) + if err != nil { + return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("tokenize questions: %v", err)} + } + + now := time.Now() + docName := "" + if doc.Name != nil { + docName = *doc.Name + } + importantKeywords := req.ImportantKeywords + if importantKeywords == nil { + importantKeywords = []string{} + } + + chunkData := map[string]interface{}{ + "id": chunkID, + "content_with_weight": req.Content, + "content_ltks": contentLtks, + "content_sm_ltks": contentSmLtks, + "important_kwd": importantKeywords, + "important_tks": importantTks, + "question_kwd": questionKwd, + "question_tks": questionTks, + "create_time": now.Format("2006-01-02 15:04:05"), + "create_timestamp_flt": float64(now.UnixNano()) / float64(time.Second), + "kb_id": req.DatasetID, + "docnm_kwd": docName, + "doc_id": req.DocumentID, + } + if req.TagKwd != nil { + chunkData["tag_kwd"] = req.TagKwd + } + if tagFeas != nil { + chunkData["tag_feas"] = tagFeas + } + + if req.ImageBase64 != nil { + imageBinary, err := decodeChunkImageBase64(*req.ImageBase64) + if err != nil { + return nil, addChunkError{code: common.CodeDataError, message: err.Error()} + } + if err := s.storeChunkImage(req.DatasetID, chunkID, imageBinary); err != nil { + return nil, addChunkError{code: common.CodeDataError, message: "Failed to store chunk image"} + } + chunkData["img_id"] = fmt.Sprintf("%s-%s", req.DatasetID, chunkID) + chunkData["doc_type_kwd"] = "image" + } + + embeddingModel, err := s.getEmbeddingModel(kb.TenantID, kb.EmbdID) + if err != nil { + return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("get embedding model: %v", err)} + } + embeddingText := req.Content + if len(questionKwd) > 0 { + embeddingText = strings.Join(questionKwd, "\n") + } + embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{docName, embeddingText}, embeddingModel.APIConfig, &models.EmbeddingConfig{Dimension: 0}) + if err != nil { + return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("encode chunk embedding: %v", err)} + } + if len(embeddings) != 2 { + return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("unexpected embedding count: %d", len(embeddings))} + } + mergedVec, err := mergeChunkEmbeddings(embeddings[0].Embedding, embeddings[1].Embedding) + if err != nil { + return nil, addChunkError{code: common.CodeServerError, message: err.Error()} + } + chunkData[fmt.Sprintf("q_%d_vec", len(mergedVec))] = mergedVec + + ctx, cancel := context.WithTimeout(context.Background(), 600*time.Second) + defer cancel() + if _, err := s.docEngine.InsertChunks(ctx, []map[string]interface{}{chunkData}, indexName, req.DatasetID); err != nil { + return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("insert chunk: %v", err)} + } + + tokenNum := int64(s.numTokens(req.Content)) + if err := s.incrementChunkStats(req.DocumentID, req.DatasetID, tokenNum, 1, 0); err != nil { + return nil, addChunkError{code: common.CodeServerError, message: fmt.Sprintf("increment chunk stats: %v", err)} + } + + renamedChunk := map[string]interface{}{ + "id": chunkID, + "content": req.Content, + "document_id": req.DocumentID, + "document": docName, + "important_keywords": importantKeywords, + "questions": questionKwd, + "dataset_id": req.DatasetID, + "create_timestamp": chunkData["create_timestamp_flt"], + "create_time": chunkData["create_time"], + } + if req.TagKwd != nil { + renamedChunk["tag_kwd"] = req.TagKwd + } + if imgID, ok := chunkData["img_id"]; ok { + renamedChunk["image_id"] = imgID + } + + return &service.AddChunkResponse{Chunk: renamedChunk}, nil +} + +type addChunkError struct { + code common.ErrorCode + message string +} + +func (e addChunkError) Error() string { + return e.message +} + +func (e addChunkError) Code() common.ErrorCode { + return e.code +} + +func validateTagFeatures(raw interface{}) (map[string]float64, error) { + parsed, ok := raw.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("must be an object mapping string tags to finite numeric scores") + } + cleaned := make(map[string]float64, len(parsed)) + for key, value := range parsed { + key = strings.TrimSpace(key) + if key == "" { + return nil, fmt.Errorf("keys must be non-empty strings") + } + switch typed := value.(type) { + case float64: + if math.IsNaN(typed) || math.IsInf(typed, 0) { + return nil, fmt.Errorf("values must be finite numbers") + } + cleaned[key] = typed + case float32: + if math.IsNaN(float64(typed)) || math.IsInf(float64(typed), 0) { + return nil, fmt.Errorf("values must be finite numbers") + } + cleaned[key] = float64(typed) + case int: + cleaned[key] = float64(typed) + case int8: + cleaned[key] = float64(typed) + case int16: + cleaned[key] = float64(typed) + case int32: + cleaned[key] = float64(typed) + case int64: + cleaned[key] = float64(typed) + default: + return nil, fmt.Errorf("values must be finite numbers") + } + } + return cleaned, nil +} + +func decodeChunkImageBase64(raw string) ([]byte, error) { + if strings.TrimSpace(raw) == "" { + return nil, fmt.Errorf("`image_base64` must be a non-empty string") + } + imageBinary, err := base64.StdEncoding.Strict().DecodeString(raw) + if err != nil { + return nil, fmt.Errorf("Invalid `image_base64`") + } + if len(imageBinary) == 0 { + return nil, fmt.Errorf("`image_base64` is empty") + } + return imageBinary, nil +} + +func mergeChunkEmbeddings(a, b []float64) ([]float64, error) { + if len(a) == 0 || len(b) == 0 || len(a) != len(b) { + return nil, fmt.Errorf("unexpected embedding dimensions") + } + merged := make([]float64, len(a)) + for i := range a { + merged[i] = 0.1*a[i] + 0.9*b[i] + } + return merged, nil +} + +func filterTrimmedStrings(values []string) []string { + filtered := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed != "" { + filtered = append(filtered, trimmed) + } + } + return filtered +} + +func (s *ChunkService) tokenize(text string) (string, error) { + if s.tokenizeFunc != nil { + return s.tokenizeFunc(text) + } + return tokenizer.Tokenize(text) +} + +func (s *ChunkService) fineGrainedTokenize(text string) (string, error) { + if s.fineGrainedTokenizeFunc != nil { + return s.fineGrainedTokenizeFunc(text) + } + return tokenizer.FineGrainedTokenize(text) +} + +func (s *ChunkService) numTokens(text string) int { + if s.numTokensFunc != nil { + return s.numTokensFunc(text) + } + return tokenizer.NumTokensFromString(text) +} + +func (s *ChunkService) getEmbeddingModel(tenantID, embdID string) (*models.EmbeddingModel, error) { + if s.getEmbeddingModelFunc != nil { + return s.getEmbeddingModelFunc(tenantID, embdID) + } + return service.NewModelProviderService().GetEmbeddingModel(tenantID, embdID) +} + +func (s *ChunkService) incrementChunkStats(docID, kbID string, tokenNum, chunkNum int64, duration float64) error { + if s.incrementChunkStatsFunc != nil { + return s.incrementChunkStatsFunc(docID, kbID, tokenNum, chunkNum, duration) + } + return dao.DB.Transaction(func(tx *gorm.DB) error { + result := tx.Model(&entity.Document{}). + Where("id = ? AND kb_id = ?", docID, kbID). + Updates(map[string]interface{}{ + "token_num": gorm.Expr("token_num + ?", tokenNum), + "chunk_num": gorm.Expr("chunk_num + ?", chunkNum), + "process_duration": gorm.Expr("process_duration + ?", duration), + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return fmt.Errorf("document not found") + } + + result = tx.Model(&entity.Knowledgebase{}). + Where("id = ?", kbID). + Updates(map[string]interface{}{ + "token_num": gorm.Expr("token_num + ?", tokenNum), + "chunk_num": gorm.Expr("chunk_num + ?", chunkNum), + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return fmt.Errorf("knowledgebase not found") + } + return nil + }) +} + +func (s *ChunkService) storeChunkImage(bucket, chunkID string, imageBinary []byte) error { + if s.storeChunkImageFunc != nil { + return s.storeChunkImageFunc(bucket, chunkID, imageBinary) + } + storageImpl := storage.GetStorageFactory().GetStorage() + if storageImpl == nil { + return fmt.Errorf("storage not initialized") + } + lockKey := bucket + "/" + chunkID + lock := acquireChunkImageMergeLock(lockKey) + lock.mu.Lock() + defer func() { + lock.mu.Unlock() + releaseChunkImageMergeLock(lockKey) + }() + + if !storageImpl.ObjExist(bucket, chunkID) { + return storageImpl.Put(bucket, chunkID, imageBinary) + } + + oldBinary, err := storageImpl.Get(bucket, chunkID) + if err != nil { + return err + } + oldImage, _, err := image.Decode(bytes.NewReader(oldBinary)) + if err != nil { + return err + } + newImage, _, err := image.Decode(bytes.NewReader(imageBinary)) + if err != nil { + return err + } + oldBounds, newBounds := oldImage.Bounds(), newImage.Bounds() + width := oldBounds.Dx() + if newBounds.Dx() > width { + width = newBounds.Dx() + } + height := oldBounds.Dy() + newBounds.Dy() + combined := image.NewRGBA(image.Rect(0, 0, width, height)) + draw.Draw(combined, combined.Bounds(), &image.Uniform{C: color.White}, image.Point{}, draw.Src) + draw.Draw(combined, oldBounds, oldImage, oldBounds.Min, draw.Src) + draw.Draw(combined, image.Rect(0, oldBounds.Dy(), newBounds.Dx(), oldBounds.Dy()+newBounds.Dy()), newImage, newBounds.Min, draw.Src) + + var buf bytes.Buffer + if err := jpeg.Encode(&buf, combined, nil); err != nil { + return err + } + return storageImpl.Put(bucket, chunkID, buf.Bytes()) +} + +func acquireChunkImageMergeLock(key string) *chunkImageMergeLock { + chunkImageMergeLocks.Lock() + defer chunkImageMergeLocks.Unlock() + + lock := chunkImageMergeLocks.locks[key] + if lock == nil { + lock = &chunkImageMergeLock{} + chunkImageMergeLocks.locks[key] = lock + } + lock.refs++ + return lock +} + +func releaseChunkImageMergeLock(key string) { + chunkImageMergeLocks.Lock() + defer chunkImageMergeLocks.Unlock() + + lock := chunkImageMergeLocks.locks[key] + if lock == nil { + return + } + lock.refs-- + if lock.refs == 0 { + delete(chunkImageMergeLocks.locks, key) + } +} diff --git a/internal/service/chunk/chunk_test.go b/internal/service/chunk/chunk_test.go index 88e9d113eb..04cf534567 100644 --- a/internal/service/chunk/chunk_test.go +++ b/internal/service/chunk/chunk_test.go @@ -1,16 +1,23 @@ package chunk import ( + "bytes" "context" "errors" + "image" + "image/color" + "image/png" "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/engine/types" "ragflow/internal/entity" + "ragflow/internal/entity/models" "ragflow/internal/service" + "ragflow/internal/storage" "reflect" "strings" "testing" + "time" "github.com/glebarez/sqlite" "gorm.io/gorm" @@ -412,6 +419,305 @@ func TestParseQueuesAndBeginsDocument(t *testing.T) { } } +func TestAddChunkSuccess(t *testing.T) { + db := setupChunkTestDB(t) + pushChunkTestDB(t, db) + userID, datasetID, documentID := "user-1", "kb-1", "doc-1" + insertChunkTestKB(t, datasetID, userID) + insertChunkTestDoc(t, documentID, datasetID) + + engine := &addChunkTestEngine{} + var incrementTokenNum, incrementChunkNum int64 + svc := &ChunkService{ + docEngine: engine, + kbDAO: dao.NewKnowledgebaseDAO(), + documentDAO: dao.NewDocumentDAO(), + accessibleFunc: func(datasetIDArg, userIDArg string) bool { + return datasetIDArg == datasetID && userIDArg == userID + }, + getKnowledgebaseByIDFunc: func(id string) (*entity.Knowledgebase, error) { + return &entity.Knowledgebase{ID: id, TenantID: userID, EmbdID: "embed-1"}, nil + }, + getEmbeddingModelFunc: func(string, string) (*models.EmbeddingModel, error) { + driver := &stubEmbeddingDriver{ + embeddings: []models.EmbeddingData{ + {Embedding: []float64{1, 2}}, + {Embedding: []float64{3, 4}}, + }, + } + modelName := "embed-1" + return models.NewEmbeddingModel(driver, &modelName, &models.APIConfig{}, 0), nil + }, + incrementChunkStatsFunc: func(docID, kbID string, tokenNum, chunkNum int64, duration float64) error { + if docID != documentID || kbID != datasetID || duration != 0 { + t.Fatalf("unexpected increment args doc=%s kb=%s duration=%v", docID, kbID, duration) + } + incrementTokenNum = tokenNum + incrementChunkNum = chunkNum + return nil + }, + tokenizeFunc: func(text string) (string, error) { return text, nil }, + fineGrainedTokenizeFunc: func(text string) (string, error) { return text + "_fg", nil }, + numTokensFunc: func(text string) int { return len(text) }, + } + + resp, err := svc.AddChunk(&service.AddChunkRequest{ + DatasetID: datasetID, + DocumentID: documentID, + Content: "chunk body", + ImportantKeywords: []string{"k1"}, + Questions: []string{" q1 ", ""}, + TagKwd: []string{"tag1"}, + TagFeas: map[string]interface{}{"tag1": float64(0.5)}, + }, userID) + if err != nil { + t.Fatalf("AddChunk() error = %v", err) + } + if resp == nil || resp.Chunk == nil { + t.Fatalf("expected chunk response, got %#v", resp) + } + if resp.Chunk["dataset_id"] != datasetID || resp.Chunk["document_id"] != documentID { + t.Fatalf("unexpected response chunk: %#v", resp.Chunk) + } + if resp.Chunk["content"] != "chunk body" { + t.Fatalf("content = %v, want chunk body", resp.Chunk["content"]) + } + if resp.Chunk["document"] != "doc-1.txt" { + t.Fatalf("document = %v, want doc-1.txt", resp.Chunk["document"]) + } + if incrementChunkNum != 1 { + t.Fatalf("increment chunk num = %d, want 1", incrementChunkNum) + } + if incrementTokenNum <= 0 { + t.Fatalf("increment token num = %d, want > 0", incrementTokenNum) + } + if len(engine.insertedChunks) != 1 { + t.Fatalf("inserted chunks = %d, want 1", len(engine.insertedChunks)) + } + inserted := engine.insertedChunks[0] + if inserted["doc_id"] != documentID || inserted["kb_id"] != datasetID { + t.Fatalf("unexpected inserted chunk: %#v", inserted) + } + if inserted["img_id"] != nil { + t.Fatalf("did not expect image id in inserted chunk: %#v", inserted) + } + vec, ok := inserted["q_2_vec"].([]float64) + if !ok { + t.Fatalf("expected q_2_vec []float64, got %T", inserted["q_2_vec"]) + } + if len(vec) != 2 || vec[0] < 2.7999 || vec[0] > 2.8001 || vec[1] < 3.7999 || vec[1] > 3.8001 { + t.Fatalf("vector = %v, want approximately [2.8 3.8]", vec) + } +} + +func TestAddChunkValidationErrors(t *testing.T) { + db := setupChunkTestDB(t) + pushChunkTestDB(t, db) + insertChunkTestKB(t, "kb-1", "user-1") + insertChunkTestDoc(t, "doc-1", "kb-1") + + svc := &ChunkService{ + docEngine: &addChunkTestEngine{}, + kbDAO: dao.NewKnowledgebaseDAO(), + documentDAO: dao.NewDocumentDAO(), + } + + tests := []struct { + name string + req *service.AddChunkRequest + setup func(*ChunkService) + wantMsg string + }{ + { + name: "nil request", + req: nil, + wantMsg: "invalid request payload", + }, + { + name: "empty content", + req: &service.AddChunkRequest{DatasetID: "kb-1", DocumentID: "doc-1", Content: " "}, + setup: func(svc *ChunkService) { + svc.accessibleFunc = func(string, string) bool { return true } + svc.getKnowledgebaseByIDFunc = func(string) (*entity.Knowledgebase, error) { + return &entity.Knowledgebase{ID: "kb-1", TenantID: "user-1", EmbdID: "embed-1"}, nil + } + }, + wantMsg: "`content` is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup(svc) + } + _, err := svc.AddChunk(tt.req, "user-1") + if err == nil || !strings.Contains(err.Error(), tt.wantMsg) { + t.Fatalf("error = %v, want substring %q", err, tt.wantMsg) + } + }) + } +} + +func TestAddChunkImageAndTagFeatureValidation(t *testing.T) { + db := setupChunkTestDB(t) + pushChunkTestDB(t, db) + userID, datasetID, documentID := "user-1", "kb-1", "doc-1" + insertChunkTestKB(t, datasetID, userID) + insertChunkTestDoc(t, documentID, datasetID) + + storeCalls := 0 + svc := &ChunkService{ + docEngine: &addChunkTestEngine{}, + kbDAO: dao.NewKnowledgebaseDAO(), + documentDAO: dao.NewDocumentDAO(), + accessibleFunc: func(string, string) bool { return true }, + getKnowledgebaseByIDFunc: func(id string) (*entity.Knowledgebase, error) { + return &entity.Knowledgebase{ID: id, TenantID: userID, EmbdID: "embed-1"}, nil + }, + tokenizeFunc: func(text string) (string, error) { return text, nil }, + fineGrainedTokenizeFunc: func(text string) (string, error) { return text + "_fg", nil }, + numTokensFunc: func(text string) int { return len(text) }, + getEmbeddingModelFunc: func(string, string) (*models.EmbeddingModel, error) { + driver := &stubEmbeddingDriver{ + embeddings: []models.EmbeddingData{ + {Embedding: []float64{1, 1}}, + {Embedding: []float64{1, 1}}, + }, + } + modelName := "embed-1" + return models.NewEmbeddingModel(driver, &modelName, &models.APIConfig{}, 0), nil + }, + incrementChunkStatsFunc: func(string, string, int64, int64, float64) error { return nil }, + storeChunkImageFunc: func(bucket, chunkID string, imageBinary []byte) error { + storeCalls++ + if bucket != datasetID || chunkID == "" || len(imageBinary) == 0 { + t.Fatalf("unexpected store args bucket=%s chunkID=%s len=%d", bucket, chunkID, len(imageBinary)) + } + return nil + }, + } + + _, err := svc.AddChunk(&service.AddChunkRequest{ + DatasetID: datasetID, + DocumentID: documentID, + Content: "chunk body", + ImageBase64: strPtr("not-base64"), + }, userID) + if err == nil || !strings.Contains(err.Error(), "Invalid `image_base64`") { + t.Fatalf("expected invalid image error, got %v", err) + } + + validJPEG := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO2pRZ0AAAAASUVORK5CYII=" + + resp, err := svc.AddChunk(&service.AddChunkRequest{ + DatasetID: datasetID, + DocumentID: documentID, + Content: "chunk body", + TagFeas: map[string]interface{}{"tag1": "bad"}, + }, userID) + if err == nil || !strings.Contains(err.Error(), "`tag_feas` values must be finite numbers") || resp != nil { + t.Fatalf("expected tag_feas validation error, got resp=%#v err=%v", resp, err) + } + + resp, err = svc.AddChunk(&service.AddChunkRequest{ + DatasetID: datasetID, + DocumentID: documentID, + Content: "chunk body", + TagFeas: map[string]interface{}{"tag1": float64(1)}, + ImageBase64: strPtr(validJPEG), + }, userID) + if err != nil { + t.Fatalf("AddChunk() with image error = %v", err) + } + if storeCalls != 1 { + t.Fatalf("store image calls = %d, want 1", storeCalls) + } + if _, ok := resp.Chunk["image_id"]; !ok { + t.Fatalf("expected image_id in response, got %#v", resp.Chunk) + } +} + +func TestAddChunkIncrementsStatsAfterInsert(t *testing.T) { + db := setupChunkTestDB(t) + pushChunkTestDB(t, db) + userID, datasetID, documentID := "user-1", "kb-1", "doc-1" + insertChunkTestKB(t, datasetID, userID) + insertChunkTestDoc(t, documentID, datasetID) + + var incrementCalls int + engine := &addChunkTestEngine{} + svc := &ChunkService{ + docEngine: engine, + kbDAO: dao.NewKnowledgebaseDAO(), + documentDAO: dao.NewDocumentDAO(), + accessibleFunc: func(string, string) bool { return true }, + getKnowledgebaseByIDFunc: func(id string) (*entity.Knowledgebase, error) { + return &entity.Knowledgebase{ID: id, TenantID: userID, EmbdID: "embed-1"}, nil + }, + getEmbeddingModelFunc: func(string, string) (*models.EmbeddingModel, error) { + driver := &stubEmbeddingDriver{ + embeddings: []models.EmbeddingData{ + {Embedding: []float64{1, 2}}, + {Embedding: []float64{3, 4}}, + }, + } + modelName := "embed-1" + return models.NewEmbeddingModel(driver, &modelName, &models.APIConfig{}, 0), nil + }, + incrementChunkStatsFunc: func(string, string, int64, int64, float64) error { + incrementCalls++ + return nil + }, + tokenizeFunc: func(text string) (string, error) { return text, nil }, + fineGrainedTokenizeFunc: func(text string) (string, error) { return text + "_fg", nil }, + numTokensFunc: func(text string) int { return len(text) }, + } + + _, err := svc.AddChunk(&service.AddChunkRequest{ + DatasetID: datasetID, + DocumentID: documentID, + Content: "chunk body", + }, userID) + if err != nil { + t.Fatalf("AddChunk() error = %v", err) + } + if incrementCalls != 1 { + t.Fatalf("increment calls = %d, want 1", incrementCalls) + } + if len(engine.insertedChunks) != 1 { + t.Fatalf("inserted chunks = %d, want 1", len(engine.insertedChunks)) + } + importantKwd, ok := engine.insertedChunks[0]["important_kwd"].([]string) + if !ok || len(importantKwd) != 0 { + t.Fatalf("important_kwd = %#v, want empty []string", engine.insertedChunks[0]["important_kwd"]) + } +} + +func TestStoreChunkImageMergesExistingImage(t *testing.T) { + oldImage := mustEncodePNG(t, image.Rect(0, 0, 2, 2)) + newImage := mustEncodePNG(t, image.Rect(0, 0, 1, 1)) + mockStorage := &chunkImageStorage{ + exists: true, + oldBinary: oldImage, + } + + factory := storage.GetStorageFactory() + originalStorage := factory.GetStorage() + factory.SetStorage(mockStorage) + t.Cleanup(func() { + factory.SetStorage(originalStorage) + }) + + svc := &ChunkService{} + if err := svc.storeChunkImage("kb-1", "chunk-1", newImage); err != nil { + t.Fatalf("storeChunkImage() error = %v", err) + } + if mockStorage.putCalls != 1 { + t.Fatalf("put calls = %d, want 1", mockStorage.putCalls) + } +} + func setupChunkTestDB(t *testing.T) *gorm.DB { t.Helper() @@ -620,6 +926,117 @@ func (e *parseTestDocEngine) GetChunkIDs([]map[string]interface{}) []string { return nil } +type addChunkTestEngine struct { + parseTestDocEngine + insertedChunks []map[string]interface{} + insertIndex string + insertDataset string + insertErr error +} + +func (e *addChunkTestEngine) InsertChunks(_ context.Context, chunks []map[string]interface{}, baseName string, datasetID string) ([]string, error) { + e.insertedChunks = chunks + e.insertIndex = baseName + e.insertDataset = datasetID + return nil, e.insertErr +} + +type chunkImageStorage struct { + exists bool + oldBinary []byte + putCalls int +} + +func (s *chunkImageStorage) Health() bool { return true } +func (s *chunkImageStorage) Put(bucket, fnm string, binary []byte, tenantID ...string) error { + s.putCalls++ + return nil +} +func (s *chunkImageStorage) Get(bucket, fnm string, tenantID ...string) ([]byte, error) { + return s.oldBinary, nil +} +func (s *chunkImageStorage) Remove(bucket, fnm string, tenantID ...string) error { return nil } +func (s *chunkImageStorage) ObjExist(bucket, fnm string, tenantID ...string) bool { return s.exists } +func (s *chunkImageStorage) GetPresignedURL(bucket, fnm string, expires time.Duration, tenantID ...string) (string, error) { + return "", nil +} +func (s *chunkImageStorage) BucketExists(bucket string) bool { return true } +func (s *chunkImageStorage) RemoveBucket(bucket string) error { return nil } +func (s *chunkImageStorage) Copy(srcBucket, srcPath, destBucket, destPath string) bool { return false } +func (s *chunkImageStorage) Move(srcBucket, srcPath, destBucket, destPath string) bool { return false } + +func mustEncodePNG(t *testing.T, rect image.Rectangle) []byte { + t.Helper() + + img := image.NewRGBA(rect) + for y := rect.Min.Y; y < rect.Max.Y; y++ { + for x := rect.Min.X; x < rect.Max.X; x++ { + img.Set(x, y, color.White) + } + } + + var buf bytes.Buffer + if err := png.Encode(&buf, img); err != nil { + t.Fatalf("encode png: %v", err) + } + return buf.Bytes() +} + +type stubEmbeddingDriver struct { + embeddings []models.EmbeddingData + embedErr error +} + +func (d *stubEmbeddingDriver) NewInstance(map[string]string) models.ModelDriver { return d } +func (d *stubEmbeddingDriver) Name() string { return "stub" } +func (d *stubEmbeddingDriver) ChatWithMessages(string, []models.Message, *models.APIConfig, *models.ChatConfig) (*models.ChatResponse, error) { + return nil, nil +} +func (d *stubEmbeddingDriver) ChatStreamlyWithSender(string, []models.Message, *models.APIConfig, *models.ChatConfig, func(*string, *string) error) error { + return nil +} +func (d *stubEmbeddingDriver) Embed(*string, []string, *models.APIConfig, *models.EmbeddingConfig) ([]models.EmbeddingData, error) { + return d.embeddings, d.embedErr +} +func (d *stubEmbeddingDriver) Rerank(*string, string, []string, *models.APIConfig, *models.RerankConfig) (*models.RerankResponse, error) { + return nil, nil +} +func (d *stubEmbeddingDriver) TranscribeAudio(*string, *string, *models.APIConfig, *models.ASRConfig) (*models.ASRResponse, error) { + return nil, nil +} +func (d *stubEmbeddingDriver) TranscribeAudioWithSender(*string, *string, *models.APIConfig, *models.ASRConfig, func(*string, *string) error) error { + return nil +} +func (d *stubEmbeddingDriver) AudioSpeech(*string, *string, *models.APIConfig, *models.TTSConfig) (*models.TTSResponse, error) { + return nil, nil +} +func (d *stubEmbeddingDriver) AudioSpeechWithSender(*string, *string, *models.APIConfig, *models.TTSConfig, func(*string, *string) error) error { + return nil +} +func (d *stubEmbeddingDriver) OCRFile(*string, []byte, *string, *models.APIConfig, *models.OCRConfig) (*models.OCRFileResponse, error) { + return nil, nil +} +func (d *stubEmbeddingDriver) ParseFile(*string, []byte, *string, *models.APIConfig, *models.ParseFileConfig) (*models.ParseFileResponse, error) { + return nil, nil +} +func (d *stubEmbeddingDriver) ListModels(*models.APIConfig) ([]models.ListModelResponse, error) { + return nil, nil +} +func (d *stubEmbeddingDriver) Balance(*models.APIConfig) (map[string]interface{}, error) { + return nil, nil +} +func (d *stubEmbeddingDriver) CheckConnection(*models.APIConfig) error { return nil } +func (d *stubEmbeddingDriver) ListTasks(*models.APIConfig) ([]models.ListTaskStatus, error) { + return nil, nil +} +func (d *stubEmbeddingDriver) ShowTask(string, *models.APIConfig) (*models.TaskResponse, error) { + return nil, nil +} + +func strPtr(v string) *string { + return &v +} + func (e *parseTestDocEngine) KNNScores(context.Context, []map[string]interface{}, []float64, int) (map[string]interface{}, error) { return nil, nil } diff --git a/internal/service/chunk_types.go b/internal/service/chunk_types.go index c04cfb3d4b..2e147895f2 100644 --- a/internal/service/chunk_types.go +++ b/internal/service/chunk_types.go @@ -95,6 +95,29 @@ type ParseFileRequest struct { DocumentIDs []string `json:"document_ids"` } +// AddChunkRequest request for adding a chunk +type AddChunkRequest struct { + DatasetID string `json:"dataset_id"` + DocumentID string `json:"document_id"` + Content string `json:"content"` + ImportantKeywords []string `json:"important_keywords,omitempty"` + Questions []string `json:"questions,omitempty"` + TagKwd []string `json:"tag_kwd,omitempty"` + TagFeas interface{} `json:"tag_feas,omitempty"` + ImageBase64 *string `json:"image_base64,omitempty"` +} + +// AddChunkResponse response for adding a chunk +type AddChunkResponse struct { + Chunk map[string]interface{} `json:"chunk"` +} + +// ErrorCoder exposes an application error code alongside an error string. +type ErrorCoder interface { + error + Code() common.ErrorCode +} + // Get retrieves a chunk by ID func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkResponse, error) { if s.docEngine == nil {