diff --git a/internal/handler/chunk.go b/internal/handler/chunk.go index ff75f155bc..fd8e63d7b7 100644 --- a/internal/handler/chunk.go +++ b/internal/handler/chunk.go @@ -17,8 +17,10 @@ package handler import ( "encoding/json" + "fmt" "net/http" "ragflow/internal/common" + "strconv" "strings" "github.com/gin-gonic/gin" @@ -32,6 +34,7 @@ type chunkService interface { RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) Get(req *service.GetChunkRequest, userID string) (*service.GetChunkResponse, error) List(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error) + SwitchChunks(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error UpdateChunk(req *service.UpdateChunkRequest, userID string) error RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error) Parse(userID, datasetID string, req *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error) @@ -221,8 +224,8 @@ func (h *ChunkHandler) Parse(c *gin.Context) { }) return } - datasetId := strings.TrimSpace(c.Param("dataset_id")) - if datasetId == "" { + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { c.JSON(http.StatusBadRequest, gin.H{ "code": common.CodeBadRequest, "message": "dataset_id is required", @@ -240,7 +243,7 @@ func (h *ChunkHandler) Parse(c *gin.Context) { return } - data, code, err := h.chunkService.Parse(userID, datasetId, &req) + data, code, err := h.chunkService.Parse(userID, datasetID, &req) if code != common.CodeSuccess { c.JSON(http.StatusOK, gin.H{ "code": code, @@ -257,6 +260,99 @@ func (h *ChunkHandler) Parse(c *gin.Context) { }) } +// ListChunks retrieves chunks for a document from path/query parameters. +func (h *ChunkHandler) ListChunks(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := c.Param("dataset_id") + documentID := c.Param("document_id") + if datasetID == "" || documentID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeArgumentError, + "message": "dataset_id and document_id are required", + }) + return + } + + page, err := parsePositiveQueryInt(c, "page", 1) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeArgumentError, + "message": err.Error(), + }) + return + } + size, err := parsePositiveQueryInt(c, "page_size", 30) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeArgumentError, + "message": err.Error(), + }) + return + } + + req := service.ListChunksRequest{ + DatasetID: datasetID, + DocID: documentID, + Page: &page, + Size: &size, + Keywords: c.Query("keywords"), + } + available, ok, err := parseAvailableQuery(c.Query("available")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeArgumentError, + "message": err.Error(), + }) + return + } + if ok { + req.AvailableInt = &available + } + + resp, err := h.chunkService.List(&req, user.ID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": common.CodeServerError, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": int(common.CodeSuccess), + "data": resp, + "message": "success", + }) +} + +func parsePositiveQueryInt(c *gin.Context, name string, defaultValue int) (int, error) { + raw := strings.TrimSpace(c.Query(name)) + if raw == "" { + return defaultValue, nil + } + value, err := strconv.Atoi(raw) + if err != nil || value <= 0 { + return 0, fmt.Errorf("%s must be a positive integer", name) + } + return value, nil +} + +func parseAvailableQuery(raw string) (int, bool, error) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "": + return 0, false, nil + case "true", "1": + return 1, true, nil + default: + return 0, true, fmt.Errorf("available must be one of: true, false, 1, 0") + } +} + // List retrieves chunks for a document. // @Summary List Chunks // @Description Retrieve paginated chunks for a document with optional filtering. @@ -309,6 +405,143 @@ func (h *ChunkHandler) List(c *gin.Context) { }) } +// SwitchChunks enable or disable a chunk +func (h *ChunkHandler) SwitchChunks(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeAuthenticationError, + "message": "user_id is required", + }) + return + } + + // Get required ID + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeArgumentError, + "message": "dataset_id is required", + }) + return + } + + documentID := strings.TrimSpace(c.Param("document_id")) + if documentID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeArgumentError, + "message": "document_id is required", + }) + return + } + + var rawBody map[string]interface{} + if err := json.NewDecoder(c.Request.Body).Decode(&rawBody); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeArgumentError, + "message": err.Error(), + }) + return + } + + chunkIDs, ok := parseStringSlice(rawBody["chunk_ids"]) + if !ok || len(chunkIDs) == 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeBadRequest, + "message": "`chunk_ids` is required.", + }) + return + } + + if rawBody["available_int"] == nil && rawBody["available"] == nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeBadRequest, + "message": "`available_int` or `available` is required.", + }) + return + } + + availableInt, err := parseAvailableBody(rawBody) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeArgumentError, + "message": err.Error(), + }) + return + } + + if err := h.chunkService.SwitchChunks(userID, datasetID, documentID, availableInt, chunkIDs); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": common.CodeServerError, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "data": true, + "message": "success", + }) +} + +func parseStringSlice(raw interface{}) ([]string, bool) { + items, ok := raw.([]interface{}) + if !ok { + return nil, false + } + out := make([]string, 0, len(items)) + for _, item := range items { + s, ok := item.(string) + if !ok || strings.TrimSpace(s) == "" { + return nil, false + } + out = append(out, s) + } + return out, true +} + +func parseAvailableBody(rawBody map[string]interface{}) (int, error) { + if raw, ok := rawBody["available_int"]; ok { + switch v := raw.(type) { + case float64: + return int(v), nil + case int: + return v, nil + case bool: + if v { + return 1, nil + } + return 0, nil + default: + return 0, fmt.Errorf("available_int must be an integer") + } + } + if raw, ok := rawBody["available"]; ok { + switch v := raw.(type) { + case bool: + if v { + return 1, nil + } + return 0, nil + case float64: + if v != 0 { + return 1, nil + } + return 0, nil + default: + return 0, fmt.Errorf("available must be a boolean") + } + } + return 0, fmt.Errorf("`available_int` or `available` is required.") +} + // UpdateChunk updates a chunk // @Summary Update Chunk // @Description Update chunk fields @@ -336,29 +569,30 @@ func (h *ChunkHandler) UpdateChunk(c *gin.Context) { } // Get required ID fields - datasetID, ok := rawBody["dataset_id"].(string) - if !ok || datasetID == "" { + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, + "code": common.CodeArgumentError, "message": "dataset_id is required", }) return } - chunkID, ok := rawBody["chunk_id"].(string) - if !ok || chunkID == "" { + + chunkID := strings.TrimSpace(c.Param("chunk_id")) + if chunkID == "" { c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, + "code": common.CodeArgumentError, "message": "chunk_id is required", }) return } // Get document_id from request - documentID, ok := rawBody["document_id"].(string) - if !ok || documentID == "" { + documentID := strings.TrimSpace(c.Param("document_id")) + if documentID == "" { c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "doc_id is required", + "code": common.CodeArgumentError, + "message": "document_id is required", }) return } diff --git a/internal/handler/chunk_test.go b/internal/handler/chunk_test.go index e82b0d63c3..355a875b20 100644 --- a/internal/handler/chunk_test.go +++ b/internal/handler/chunk_test.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "net/http/httptest" + "reflect" "strings" "testing" @@ -19,6 +20,9 @@ import ( // Only the methods actually called by the test are set; others panic. type mockChunkSvc struct { retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) + listFn func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error) + switchChunksFn func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error + updateChunkFn func(req *service.UpdateChunkRequest, userID string) error } func (m *mockChunkSvc) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) { @@ -33,10 +37,22 @@ func (m *mockChunkSvc) RetrievalTest(req *service.RetrievalTestRequest, userID s func (m *mockChunkSvc) Get(*service.GetChunkRequest, string) (*service.GetChunkResponse, error) { panic("not implemented") } -func (m *mockChunkSvc) List(*service.ListChunksRequest, string) (*service.ListChunksResponse, error) { +func (m *mockChunkSvc) List(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error) { + if m.listFn != nil { + return m.listFn(req, userID) + } panic("not implemented") } -func (m *mockChunkSvc) UpdateChunk(*service.UpdateChunkRequest, string) error { +func (m *mockChunkSvc) SwitchChunks(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error { + if m.switchChunksFn != nil { + return m.switchChunksFn(userID, datasetID, documentID, availableInt, chunkIDs) + } + panic("not implemented") +} +func (m *mockChunkSvc) UpdateChunk(req *service.UpdateChunkRequest, userID string) error { + if m.updateChunkFn != nil { + return m.updateChunkFn(req, userID) + } panic("not implemented") } func (m *mockChunkSvc) RemoveChunks(*service.RemoveChunksRequest, string) (int64, error) { @@ -68,6 +84,144 @@ func setupChunkRetrievalTestNoAuth() *gin.Engine { return r } +func setupChunkHandlerWithUser(userID string, mock *mockChunkSvc) (*gin.Engine, *ChunkHandler) { + h := &ChunkHandler{chunkService: mock} + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: userID}) + }) + return r, h +} + +func TestChunkHandlerListChunksMapsPathAndQuery(t *testing.T) { + mock := &mockChunkSvc{} + r, h := setupChunkHandlerWithUser("user-1", mock) + r.GET("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.ListChunks) + + mock.listFn = func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error) { + if userID != "user-1" { + t.Fatalf("userID = %q, want user-1", userID) + } + if req.DatasetID != "kb-1" || req.DocID != "doc-1" { + t.Fatalf("req ids = %q/%q, want kb-1/doc-1", req.DatasetID, req.DocID) + } + if req.Page == nil || *req.Page != 2 { + t.Fatalf("page = %v, want 2", req.Page) + } + if req.Size == nil || *req.Size != 5 { + t.Fatalf("size = %v, want 5", req.Size) + } + if req.Keywords != "AI" { + t.Fatalf("keywords = %q, want AI", req.Keywords) + } + if req.AvailableInt == nil || *req.AvailableInt != 1 { + t.Fatalf("available_int = %v, want 1", req.AvailableInt) + } + return &service.ListChunksResponse{ + Total: 1, + Chunks: []map[string]interface{}{ + {"id": "chunk-1"}, + }, + Doc: map[string]interface{}{"id": "doc-1"}, + }, nil + } + + req := httptest.NewRequest(http.MethodGet, "/api/v1/datasets/kb-1/documents/doc-1/chunks?page=2&page_size=5&keywords=AI&available=true", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, body = %s", w.Code, w.Body.String()) + } + var body map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("invalid JSON response: %v", err) + } + if body["message"] != "success" { + t.Fatalf("message = %v, want success", body["message"]) + } +} + +func TestChunkHandlerSwitchChunksCallsService(t *testing.T) { + mock := &mockChunkSvc{} + r, h := setupChunkHandlerWithUser("user-1", mock) + r.PATCH("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.SwitchChunks) + + mock.switchChunksFn = func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error { + if userID != "user-1" || datasetID != "kb-1" || documentID != "doc-1" { + t.Fatalf("ids = %q/%q/%q, want user-1/kb-1/doc-1", userID, datasetID, documentID) + } + if availableInt != 0 { + t.Fatalf("availableInt = %d, want 0", availableInt) + } + if !reflect.DeepEqual(chunkIDs, []string{"chunk-1", "chunk-2"}) { + t.Fatalf("chunkIDs = %#v", chunkIDs) + } + return nil + } + + body := `{"chunk_ids":["chunk-1","chunk-2"],"available":false}` + req := httptest.NewRequest(http.MethodPatch, "/api/v1/datasets/kb-1/documents/doc-1/chunks", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, body = %s", w.Code, w.Body.String()) + } + var res map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &res); err != nil { + t.Fatalf("invalid JSON response: %v", err) + } + if res["data"] != true { + t.Fatalf("data = %v, want true", res["data"]) + } +} + +func TestChunkHandlerSwitchChunksRejectsMissingChunkIDs(t *testing.T) { + mock := &mockChunkSvc{} + r, h := setupChunkHandlerWithUser("user-1", mock) + r.PATCH("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.SwitchChunks) + + req := httptest.NewRequest(http.MethodPatch, "/api/v1/datasets/kb-1/documents/doc-1/chunks", strings.NewReader(`{"available":true}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status = %d, body = %s", w.Code, w.Body.String()) + } +} + +func TestChunkHandlerUpdateChunkUsesPathIDs(t *testing.T) { + mock := &mockChunkSvc{} + r, h := setupChunkHandlerWithUser("user-1", mock) + r.PATCH("/api/v1/datasets/:dataset_id/documents/:document_id/chunks/:chunk_id", h.UpdateChunk) + + mock.updateChunkFn = func(req *service.UpdateChunkRequest, userID string) error { + if userID != "user-1" { + t.Fatalf("userID = %q, want user-1", userID) + } + if req.DatasetID != "kb-1" || req.DocumentID != "doc-1" || req.ChunkID != "chunk-1" { + t.Fatalf("ids = %q/%q/%q, want kb-1/doc-1/chunk-1", req.DatasetID, req.DocumentID, req.ChunkID) + } + if req.Content == nil || *req.Content != "updated" { + t.Fatalf("content = %v, want updated", req.Content) + } + return nil + } + + req := httptest.NewRequest(http.MethodPatch, "/api/v1/datasets/kb-1/documents/doc-1/chunks/chunk-1", strings.NewReader(`{"content":"updated"}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, body = %s", w.Code, w.Body.String()) + } +} + func TestChunkRetrieval_EmptyQuestion(t *testing.T) { r, _ := setupChunkRetrievalTest("user1") diff --git a/internal/router/router.go b/internal/router/router.go index 8a4d3cc942..f0d5f5d103 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -298,8 +298,11 @@ func (r *Router) Setup(engine *gin.Engine) { datasets.DELETE("/:dataset_id/documents", r.documentHandler.DeleteDocuments) // Dataset document chunk + datasets.GET("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.ListChunks) + datasets.PATCH("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.SwitchChunks) datasets.GET("/:dataset_id/documents/:document_id/chunks/:chunk_id", r.chunkHandler.Get) datasets.POST("/:dataset_id/chunks", r.chunkHandler.Parse) + datasets.PATCH("/:dataset_id/documents/:document_id/chunks/:chunk_id", r.chunkHandler.UpdateChunk) datasets.POST("/:dataset_id/documents/parse", r.documentHandler.StartIngestionTask) datasets.GET("/ingestion/tasks", r.documentHandler.ListIngestionTasks) datasets.PUT("/ingestion/tasks", r.documentHandler.StopIngestionTasks) diff --git a/internal/service/chunk/chunk.go b/internal/service/chunk/chunk.go index 4b608d8e09..1193b110b1 100644 --- a/internal/service/chunk/chunk.go +++ b/internal/service/chunk/chunk.go @@ -1176,6 +1176,9 @@ func (s *ChunkService) List(req *service.ListChunksRequest, userID string) (*ser if err != nil || doc == nil { return nil, fmt.Errorf("document not found") } + if req.DatasetID != "" && doc.KbID != req.DatasetID { + return nil, fmt.Errorf("document not found") + } // Get knowledge base to find tenant kb, err := s.kbDAO.GetByID(doc.KbID) @@ -1329,6 +1332,71 @@ func (s *ChunkService) List(req *service.ListChunksRequest, userID string) (*ser Doc: docInfo, }, nil } + +func (s *ChunkService) SwitchChunks(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error { + if s.docEngine == nil { + return fmt.Errorf("doc engine not initialized") + } + + if availableInt != 0 && availableInt != 1 { + return fmt.Errorf("available_int should be 0 or 1") + } + + if chunkIDs == nil || len(chunkIDs) == 0 { + return fmt.Errorf("req is null") + } + + ctx := context.Background() + defer ctx.Done() + + // Get user's tenants + tenants, err := s.userTenantDAO.GetByUserID(userID) + if err != nil { + return fmt.Errorf("failed to get user tenants: %w", err) + } + if len(tenants) == 0 { + return fmt.Errorf("user has no accessible tenants") + } + + // Find the tenant that owns this dataset + var targetTenantID string + for _, tenant := range tenants { + kb, err := s.kbDAO.GetByIDAndTenantID(datasetID, tenant.TenantID) + if err == nil && kb != nil { + targetTenantID = tenant.TenantID + break + } + } + if targetTenantID == "" { + return fmt.Errorf("user does not have access to this dataset") + } + + docDAO := dao.NewDocumentDAO() + doc, err := docDAO.GetByID(documentID) + if err != nil || doc == nil { + return fmt.Errorf("document not found") + } + if doc.KbID != datasetID { + return fmt.Errorf("document does not belong to this dataset") + } + + for _, cid := range chunkIDs { + indexName := fmt.Sprintf("ragflow_%s", targetTenantID) + + if err = s.docEngine.UpdateChunks(ctx, map[string]interface{}{ + "id": cid, + "doc_id": documentID, + }, map[string]interface{}{ + "id": cid, + "available_int": availableInt, + }, indexName, datasetID); err != nil { + return err + } + } + + return nil +} + func (s *ChunkService) UpdateChunk(req *service.UpdateChunkRequest, userID string) error { if s.docEngine == nil { return fmt.Errorf("doc engine not initialized") diff --git a/internal/service/chunk/chunk_test.go b/internal/service/chunk/chunk_test.go index bc0a5981df..88e9d113eb 100644 --- a/internal/service/chunk/chunk_test.go +++ b/internal/service/chunk/chunk_test.go @@ -643,3 +643,183 @@ func (e *parseTestDocEngine) GetType() string { func (e *parseTestDocEngine) FilterDocIdsByMetaPushdown(context.Context, []string, []map[string]interface{}, string) []string { return nil } + +func TestSwitchChunksUpdatesDocEngineWithAvailableInt(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + if err := db.AutoMigrate(&entity.UserTenant{}, &entity.Knowledgebase{}, &entity.Document{}); err != nil { + t.Fatalf("failed to migrate sqlite: %v", err) + } + previousDB := dao.DB + dao.DB = db + t.Cleanup(func() { dao.DB = previousDB }) + + valid := string(entity.StatusValid) + if err := db.Create(&entity.UserTenant{ + ID: "ut-1", + UserID: "user-1", + TenantID: "tenant-1", + Role: "owner", + Status: &valid, + }).Error; err != nil { + t.Fatalf("failed to create user_tenant: %v", err) + } + if err := db.Create(&entity.Knowledgebase{ + ID: "kb-1", + TenantID: "tenant-1", + Name: "dataset", + EmbdID: "embed", + Permission: string(entity.TenantPermissionMe), + CreatedBy: "user-1", + ParserID: string(entity.ParserTypeNaive), + ParserConfig: entity.JSONMap{}, + Status: &valid, + }).Error; err != nil { + t.Fatalf("failed to create knowledgebase: %v", err) + } + if err := db.Create(&entity.Document{ + ID: "doc-1", + KbID: "kb-1", + ParserID: string(entity.ParserTypeNaive), + ParserConfig: entity.JSONMap{}, + SourceType: "local", + Type: "doc", + CreatedBy: "user-1", + Suffix: "txt", + }).Error; err != nil { + t.Fatalf("failed to create document: %v", err) + } + + engine := &switchChunksEngineMock{} + svc := &ChunkService{ + docEngine: engine, + kbDAO: dao.NewKnowledgebaseDAO(), + userTenantDAO: dao.NewUserTenantDAO(), + } + + if err := svc.SwitchChunks("user-1", "kb-1", "doc-1", 0, []string{"chunk-1", "chunk-2"}); err != nil { + t.Fatalf("SwitchChunks() error = %v", err) + } + + if len(engine.updateCalls) != 2 { + t.Fatalf("UpdateChunks calls = %d, want 2", len(engine.updateCalls)) + } + for i, call := range engine.updateCalls { + if call.indexName != "ragflow_tenant-1" { + t.Fatalf("call %d indexName = %q", i, call.indexName) + } + if call.datasetID != "kb-1" { + t.Fatalf("call %d datasetID = %q", i, call.datasetID) + } + wantID := []string{"chunk-1", "chunk-2"}[i] + if !reflect.DeepEqual(call.condition, map[string]interface{}{ + "id": wantID, + "doc_id": "doc-1", + }) { + t.Fatalf("call %d condition = %#v", i, call.condition) + } + if !reflect.DeepEqual(call.newValue, map[string]interface{}{"id": wantID, "available_int": 0}) { + t.Fatalf("call %d newValue = %#v", i, call.newValue) + } + } +} + +type updateChunksCall struct { + condition map[string]interface{} + newValue map[string]interface{} + indexName string + datasetID string +} + +type switchChunksEngineMock struct { + updateCalls []updateChunksCall +} + +func (m *switchChunksEngineMock) CreateChunkStore(context.Context, string, string, int, string) error { + return nil +} +func (m *switchChunksEngineMock) InsertChunks(context.Context, []map[string]interface{}, string, string) ([]string, error) { + return nil, nil +} +func (m *switchChunksEngineMock) UpdateChunks(_ context.Context, condition map[string]interface{}, newValue map[string]interface{}, indexName string, datasetID string) error { + m.updateCalls = append(m.updateCalls, updateChunksCall{ + condition: copyMap(condition), + newValue: copyMap(newValue), + indexName: indexName, + datasetID: datasetID, + }) + return nil +} +func (m *switchChunksEngineMock) DeleteChunks(context.Context, map[string]interface{}, string, string) (int64, error) { + return 0, nil +} +func (m *switchChunksEngineMock) Search(context.Context, *types.SearchRequest) (*types.SearchResult, error) { + return nil, nil +} +func (m *switchChunksEngineMock) GetChunk(context.Context, string, string, []string) (interface{}, error) { + return nil, nil +} +func (m *switchChunksEngineMock) DropChunkStore(context.Context, string, string) error { return nil } +func (m *switchChunksEngineMock) ChunkStoreExists(context.Context, string, string) (bool, error) { + return false, nil +} +func (m *switchChunksEngineMock) CreateMetadataStore(context.Context, string) error { return nil } +func (m *switchChunksEngineMock) InsertMetadata(context.Context, []map[string]interface{}, string) ([]string, error) { + return nil, nil +} +func (m *switchChunksEngineMock) UpdateMetadata(context.Context, string, string, map[string]interface{}, string) error { + return nil +} +func (m *switchChunksEngineMock) DeleteMetadata(context.Context, map[string]interface{}, string) (int64, error) { + return 0, nil +} +func (m *switchChunksEngineMock) DeleteMetadataKeys(context.Context, string, string, []string, string) error { + return nil +} +func (m *switchChunksEngineMock) DropMetadataStore(context.Context, string) error { return nil } +func (m *switchChunksEngineMock) MetadataStoreExists(context.Context, string) (bool, error) { + return false, nil +} +func (m *switchChunksEngineMock) SearchMetadata(context.Context, *types.SearchMetadataRequest) (*types.SearchMetadataResult, error) { + return nil, nil +} +func (m *switchChunksEngineMock) IndexDocument(context.Context, string, string, interface{}) error { + return nil +} +func (m *switchChunksEngineMock) DeleteDocument(context.Context, string, string) error { return nil } +func (m *switchChunksEngineMock) BulkIndex(context.Context, string, []interface{}) (interface{}, error) { + return nil, nil +} +func (m *switchChunksEngineMock) GetFields([]map[string]interface{}, []string) map[string]map[string]interface{} { + return nil +} +func (m *switchChunksEngineMock) GetAggregation([]map[string]interface{}, string) []map[string]interface{} { + return nil +} +func (m *switchChunksEngineMock) GetHighlight([]map[string]interface{}, []string, string) map[string]string { + return nil +} +func (m *switchChunksEngineMock) RunSQL(context.Context, string, string, []string, string) ([]map[string]interface{}, error) { + return nil, nil +} +func (m *switchChunksEngineMock) GetChunkIDs([]map[string]interface{}) []string { return nil } +func (m *switchChunksEngineMock) KNNScores(context.Context, []map[string]interface{}, []float64, int) (map[string]interface{}, error) { + return nil, nil +} +func (m *switchChunksEngineMock) GetScores(map[string]interface{}) map[string]float64 { return nil } +func (m *switchChunksEngineMock) Ping(context.Context) error { return nil } +func (m *switchChunksEngineMock) Close() error { return nil } +func (m *switchChunksEngineMock) GetType() string { return "elasticsearch" } +func (m *switchChunksEngineMock) FilterDocIdsByMetaPushdown(context.Context, []string, []map[string]interface{}, string) []string { + return nil +} + +func copyMap(in map[string]interface{}) map[string]interface{} { + out := make(map[string]interface{}, len(in)) + for k, v := range in { + out[k] = v + } + return out +} diff --git a/internal/service/chunk_types.go b/internal/service/chunk_types.go index 244770eb07..3628ea42c1 100644 --- a/internal/service/chunk_types.go +++ b/internal/service/chunk_types.go @@ -164,6 +164,7 @@ func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkRespon // ListChunksRequest request for listing chunks type ListChunksRequest struct { + DatasetID string `json:"dataset_id,omitempty"` DocID string `json:"doc_id" binding:"required"` Page *int `json:"page,omitempty"` Size *int `json:"size,omitempty"` @@ -205,6 +206,9 @@ func (s *ChunkService) List(req *ListChunksRequest, userID string) (*ListChunksR if err != nil || doc == nil { return nil, fmt.Errorf("document not found") } + if req.DatasetID != "" && doc.KbID != req.DatasetID { + return nil, fmt.Errorf("document not found") + } // Get knowledge base to find tenant kb, err := s.kbDAO.GetByID(doc.KbID)