mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat[Go]: implement <document_id>/chunks/<chunk_id> PATCH (#16232)
### What problem does this PR solve?

Implement:

1. `/api/v1/datasets/<dataset_id>/documents/<document_id>/chunks GET`
2. `/api/v1/datasets/<dataset_id>/documents/<document_id>/chunks/<chunk_id> PATCH`
3. `/api/v1/datasets/<dataset_id>/documents/<document_id>/chunks PATCH`

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -17,8 +17,10 @@ package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -32,6 +34,7 @@ type chunkService interface {
|
||||
RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
||||
Get(req *service.GetChunkRequest, userID string) (*service.GetChunkResponse, error)
|
||||
List(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error)
|
||||
SwitchChunks(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error
|
||||
UpdateChunk(req *service.UpdateChunkRequest, userID string) error
|
||||
RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error)
|
||||
Parse(userID, datasetID string, req *service.ParseFileRequest) (map[string]interface{}, common.ErrorCode, error)
|
||||
@@ -221,8 +224,8 @@ func (h *ChunkHandler) Parse(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
datasetId := strings.TrimSpace(c.Param("dataset_id"))
|
||||
if datasetId == "" {
|
||||
datasetID := strings.TrimSpace(c.Param("dataset_id"))
|
||||
if datasetID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": "dataset_id is required",
|
||||
@@ -240,7 +243,7 @@ func (h *ChunkHandler) Parse(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
data, code, err := h.chunkService.Parse(userID, datasetId, &req)
|
||||
data, code, err := h.chunkService.Parse(userID, datasetID, &req)
|
||||
if code != common.CodeSuccess {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
@@ -257,6 +260,99 @@ func (h *ChunkHandler) Parse(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ListChunks retrieves chunks for a document from path/query parameters.
|
||||
func (h *ChunkHandler) ListChunks(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
datasetID := c.Param("dataset_id")
|
||||
documentID := c.Param("document_id")
|
||||
if datasetID == "" || documentID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "dataset_id and document_id are required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
page, err := parsePositiveQueryInt(c, "page", 1)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
size, err := parsePositiveQueryInt(c, "page_size", 30)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
req := service.ListChunksRequest{
|
||||
DatasetID: datasetID,
|
||||
DocID: documentID,
|
||||
Page: &page,
|
||||
Size: &size,
|
||||
Keywords: c.Query("keywords"),
|
||||
}
|
||||
available, ok, err := parseAvailableQuery(c.Query("available"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
req.AvailableInt = &available
|
||||
}
|
||||
|
||||
resp, err := h.chunkService.List(&req, user.ID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": int(common.CodeSuccess),
|
||||
"data": resp,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
func parsePositiveQueryInt(c *gin.Context, name string, defaultValue int) (int, error) {
|
||||
raw := strings.TrimSpace(c.Query(name))
|
||||
if raw == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
value, err := strconv.Atoi(raw)
|
||||
if err != nil || value <= 0 {
|
||||
return 0, fmt.Errorf("%s must be a positive integer", name)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// parseAvailableQuery converts the raw `available` query value into the
// integer availability flag.
//
// Matching is case-insensitive and ignores surrounding whitespace. It returns:
//   - (0, false, nil) when the value is empty, meaning no filter was requested;
//   - (1, true, nil) for "true" or "1";
//   - (0, true, nil) for "false" or "0";
//   - (0, false, err) for any other value.
func parseAvailableQuery(raw string) (int, bool, error) {
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "":
		return 0, false, nil
	case "true", "1":
		return 1, true, nil
	case "false", "0":
		// These values are advertised by the error message below and must be
		// accepted, otherwise callers cannot filter for unavailable chunks.
		return 0, true, nil
	default:
		return 0, false, fmt.Errorf("available must be one of: true, false, 1, 0")
	}
}
|
||||
|
||||
// List retrieves chunks for a document.
|
||||
// @Summary List Chunks
|
||||
// @Description Retrieve paginated chunks for a document with optional filtering.
|
||||
@@ -309,6 +405,143 @@ func (h *ChunkHandler) List(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// SwitchChunks enable or disable a chunk
|
||||
func (h *ChunkHandler) SwitchChunks(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeAuthenticationError,
|
||||
"message": "user_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get required ID
|
||||
datasetID := strings.TrimSpace(c.Param("dataset_id"))
|
||||
if datasetID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "dataset_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
documentID := strings.TrimSpace(c.Param("document_id"))
|
||||
if documentID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "document_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var rawBody map[string]interface{}
|
||||
if err := json.NewDecoder(c.Request.Body).Decode(&rawBody); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chunkIDs, ok := parseStringSlice(rawBody["chunk_ids"])
|
||||
if !ok || len(chunkIDs) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": "`chunk_ids` is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if rawBody["available_int"] == nil && rawBody["available"] == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": "`available_int` or `available` is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
availableInt, err := parseAvailableBody(rawBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.chunkService.SwitchChunks(userID, datasetID, documentID, availableInt, chunkIDs); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": true,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// parseStringSlice converts a decoded JSON array into a []string.
//
// It reports false when raw is not a []interface{}, or when any element is
// not a string or is blank after trimming whitespace. Accepted elements are
// returned unmodified (not trimmed).
func parseStringSlice(raw interface{}) ([]string, bool) {
	list, isList := raw.([]interface{})
	if !isList {
		return nil, false
	}

	result := make([]string, len(list))
	for i := range list {
		text, isString := list[i].(string)
		if !isString {
			return nil, false
		}
		if strings.TrimSpace(text) == "" {
			return nil, false
		}
		result[i] = text
	}
	return result, true
}
|
||||
|
||||
// parseAvailableBody extracts the availability flag from a decoded JSON body.
//
// `available_int` takes precedence over `available`. A key whose value is
// JSON null is treated as absent, so {"available_int": null, "available":
// true} resolves through `available`.
//
// Accepted forms:
//   - available_int: a whole number (returned as is) or a boolean (1 or 0);
//   - available: a boolean, or a number where any non-zero value means 1.
//
// An error is returned when neither key carries a value, when the value has
// an unsupported type, or when available_int is a fractional number.
func parseAvailableBody(rawBody map[string]interface{}) (int, error) {
	if raw, ok := rawBody["available_int"]; ok && raw != nil {
		switch v := raw.(type) {
		case float64:
			// encoding/json decodes every JSON number as float64; refuse
			// values such as 1.5 instead of silently truncating them.
			n := int(v)
			if float64(n) != v {
				return 0, fmt.Errorf("available_int must be an integer")
			}
			return n, nil
		case int:
			return v, nil
		case bool:
			if v {
				return 1, nil
			}
			return 0, nil
		default:
			return 0, fmt.Errorf("available_int must be an integer")
		}
	}

	if raw, ok := rawBody["available"]; ok && raw != nil {
		switch v := raw.(type) {
		case bool:
			if v {
				return 1, nil
			}
			return 0, nil
		case float64:
			if v != 0 {
				return 1, nil
			}
			return 0, nil
		default:
			return 0, fmt.Errorf("available must be a boolean")
		}
	}

	return 0, fmt.Errorf("`available_int` or `available` is required.")
}
|
||||
|
||||
// UpdateChunk updates a chunk
|
||||
// @Summary Update Chunk
|
||||
// @Description Update chunk fields
|
||||
@@ -336,29 +569,30 @@ func (h *ChunkHandler) UpdateChunk(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get required ID fields
|
||||
datasetID, ok := rawBody["dataset_id"].(string)
|
||||
if !ok || datasetID == "" {
|
||||
datasetID := strings.TrimSpace(c.Param("dataset_id"))
|
||||
if datasetID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "dataset_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
chunkID, ok := rawBody["chunk_id"].(string)
|
||||
if !ok || chunkID == "" {
|
||||
|
||||
chunkID := strings.TrimSpace(c.Param("chunk_id"))
|
||||
if chunkID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "chunk_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get document_id from request
|
||||
documentID, ok := rawBody["document_id"].(string)
|
||||
if !ok || documentID == "" {
|
||||
documentID := strings.TrimSpace(c.Param("document_id"))
|
||||
if documentID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "doc_id is required",
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "document_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -19,6 +20,9 @@ import (
|
||||
// Only the methods actually called by the test are set; others panic.
|
||||
type mockChunkSvc struct {
|
||||
retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
||||
listFn func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error)
|
||||
switchChunksFn func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error
|
||||
updateChunkFn func(req *service.UpdateChunkRequest, userID string) error
|
||||
}
|
||||
|
||||
func (m *mockChunkSvc) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
@@ -33,10 +37,22 @@ func (m *mockChunkSvc) RetrievalTest(req *service.RetrievalTestRequest, userID s
|
||||
func (m *mockChunkSvc) Get(*service.GetChunkRequest, string) (*service.GetChunkResponse, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (m *mockChunkSvc) List(*service.ListChunksRequest, string) (*service.ListChunksResponse, error) {
|
||||
func (m *mockChunkSvc) List(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error) {
|
||||
if m.listFn != nil {
|
||||
return m.listFn(req, userID)
|
||||
}
|
||||
panic("not implemented")
|
||||
}
|
||||
func (m *mockChunkSvc) UpdateChunk(*service.UpdateChunkRequest, string) error {
|
||||
func (m *mockChunkSvc) SwitchChunks(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error {
|
||||
if m.switchChunksFn != nil {
|
||||
return m.switchChunksFn(userID, datasetID, documentID, availableInt, chunkIDs)
|
||||
}
|
||||
panic("not implemented")
|
||||
}
|
||||
func (m *mockChunkSvc) UpdateChunk(req *service.UpdateChunkRequest, userID string) error {
|
||||
if m.updateChunkFn != nil {
|
||||
return m.updateChunkFn(req, userID)
|
||||
}
|
||||
panic("not implemented")
|
||||
}
|
||||
func (m *mockChunkSvc) RemoveChunks(*service.RemoveChunksRequest, string) (int64, error) {
|
||||
@@ -68,6 +84,144 @@ func setupChunkRetrievalTestNoAuth() *gin.Engine {
|
||||
return r
|
||||
}
|
||||
|
||||
func setupChunkHandlerWithUser(userID string, mock *mockChunkSvc) (*gin.Engine, *ChunkHandler) {
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: userID})
|
||||
})
|
||||
return r, h
|
||||
}
|
||||
|
||||
func TestChunkHandlerListChunksMapsPathAndQuery(t *testing.T) {
|
||||
mock := &mockChunkSvc{}
|
||||
r, h := setupChunkHandlerWithUser("user-1", mock)
|
||||
r.GET("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.ListChunks)
|
||||
|
||||
mock.listFn = func(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error) {
|
||||
if userID != "user-1" {
|
||||
t.Fatalf("userID = %q, want user-1", userID)
|
||||
}
|
||||
if req.DatasetID != "kb-1" || req.DocID != "doc-1" {
|
||||
t.Fatalf("req ids = %q/%q, want kb-1/doc-1", req.DatasetID, req.DocID)
|
||||
}
|
||||
if req.Page == nil || *req.Page != 2 {
|
||||
t.Fatalf("page = %v, want 2", req.Page)
|
||||
}
|
||||
if req.Size == nil || *req.Size != 5 {
|
||||
t.Fatalf("size = %v, want 5", req.Size)
|
||||
}
|
||||
if req.Keywords != "AI" {
|
||||
t.Fatalf("keywords = %q, want AI", req.Keywords)
|
||||
}
|
||||
if req.AvailableInt == nil || *req.AvailableInt != 1 {
|
||||
t.Fatalf("available_int = %v, want 1", req.AvailableInt)
|
||||
}
|
||||
return &service.ListChunksResponse{
|
||||
Total: 1,
|
||||
Chunks: []map[string]interface{}{
|
||||
{"id": "chunk-1"},
|
||||
},
|
||||
Doc: map[string]interface{}{"id": "doc-1"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/datasets/kb-1/documents/doc-1/chunks?page=2&page_size=5&keywords=AI&available=true", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, body = %s", w.Code, w.Body.String())
|
||||
}
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("invalid JSON response: %v", err)
|
||||
}
|
||||
if body["message"] != "success" {
|
||||
t.Fatalf("message = %v, want success", body["message"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkHandlerSwitchChunksCallsService(t *testing.T) {
|
||||
mock := &mockChunkSvc{}
|
||||
r, h := setupChunkHandlerWithUser("user-1", mock)
|
||||
r.PATCH("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.SwitchChunks)
|
||||
|
||||
mock.switchChunksFn = func(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error {
|
||||
if userID != "user-1" || datasetID != "kb-1" || documentID != "doc-1" {
|
||||
t.Fatalf("ids = %q/%q/%q, want user-1/kb-1/doc-1", userID, datasetID, documentID)
|
||||
}
|
||||
if availableInt != 0 {
|
||||
t.Fatalf("availableInt = %d, want 0", availableInt)
|
||||
}
|
||||
if !reflect.DeepEqual(chunkIDs, []string{"chunk-1", "chunk-2"}) {
|
||||
t.Fatalf("chunkIDs = %#v", chunkIDs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
body := `{"chunk_ids":["chunk-1","chunk-2"],"available":false}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/v1/datasets/kb-1/documents/doc-1/chunks", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, body = %s", w.Code, w.Body.String())
|
||||
}
|
||||
var res map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &res); err != nil {
|
||||
t.Fatalf("invalid JSON response: %v", err)
|
||||
}
|
||||
if res["data"] != true {
|
||||
t.Fatalf("data = %v, want true", res["data"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkHandlerSwitchChunksRejectsMissingChunkIDs(t *testing.T) {
|
||||
mock := &mockChunkSvc{}
|
||||
r, h := setupChunkHandlerWithUser("user-1", mock)
|
||||
r.PATCH("/api/v1/datasets/:dataset_id/documents/:document_id/chunks", h.SwitchChunks)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/v1/datasets/kb-1/documents/doc-1/chunks", strings.NewReader(`{"available":true}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, body = %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkHandlerUpdateChunkUsesPathIDs(t *testing.T) {
|
||||
mock := &mockChunkSvc{}
|
||||
r, h := setupChunkHandlerWithUser("user-1", mock)
|
||||
r.PATCH("/api/v1/datasets/:dataset_id/documents/:document_id/chunks/:chunk_id", h.UpdateChunk)
|
||||
|
||||
mock.updateChunkFn = func(req *service.UpdateChunkRequest, userID string) error {
|
||||
if userID != "user-1" {
|
||||
t.Fatalf("userID = %q, want user-1", userID)
|
||||
}
|
||||
if req.DatasetID != "kb-1" || req.DocumentID != "doc-1" || req.ChunkID != "chunk-1" {
|
||||
t.Fatalf("ids = %q/%q/%q, want kb-1/doc-1/chunk-1", req.DatasetID, req.DocumentID, req.ChunkID)
|
||||
}
|
||||
if req.Content == nil || *req.Content != "updated" {
|
||||
t.Fatalf("content = %v, want updated", req.Content)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/v1/datasets/kb-1/documents/doc-1/chunks/chunk-1", strings.NewReader(`{"content":"updated"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, body = %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_EmptyQuestion(t *testing.T) {
|
||||
r, _ := setupChunkRetrievalTest("user1")
|
||||
|
||||
|
||||
@@ -298,8 +298,11 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
datasets.DELETE("/:dataset_id/documents", r.documentHandler.DeleteDocuments)
|
||||
|
||||
// Dataset document chunk
|
||||
datasets.GET("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.ListChunks)
|
||||
datasets.PATCH("/:dataset_id/documents/:document_id/chunks", r.chunkHandler.SwitchChunks)
|
||||
datasets.GET("/:dataset_id/documents/:document_id/chunks/:chunk_id", r.chunkHandler.Get)
|
||||
datasets.POST("/:dataset_id/chunks", r.chunkHandler.Parse)
|
||||
datasets.PATCH("/:dataset_id/documents/:document_id/chunks/:chunk_id", r.chunkHandler.UpdateChunk)
|
||||
datasets.POST("/:dataset_id/documents/parse", r.documentHandler.StartIngestionTask)
|
||||
datasets.GET("/ingestion/tasks", r.documentHandler.ListIngestionTasks)
|
||||
datasets.PUT("/ingestion/tasks", r.documentHandler.StopIngestionTasks)
|
||||
|
||||
@@ -1176,6 +1176,9 @@ func (s *ChunkService) List(req *service.ListChunksRequest, userID string) (*ser
|
||||
if err != nil || doc == nil {
|
||||
return nil, fmt.Errorf("document not found")
|
||||
}
|
||||
if req.DatasetID != "" && doc.KbID != req.DatasetID {
|
||||
return nil, fmt.Errorf("document not found")
|
||||
}
|
||||
|
||||
// Get knowledge base to find tenant
|
||||
kb, err := s.kbDAO.GetByID(doc.KbID)
|
||||
@@ -1329,6 +1332,71 @@ func (s *ChunkService) List(req *service.ListChunksRequest, userID string) (*ser
|
||||
Doc: docInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ChunkService) SwitchChunks(userID, datasetID, documentID string, availableInt int, chunkIDs []string) error {
|
||||
if s.docEngine == nil {
|
||||
return fmt.Errorf("doc engine not initialized")
|
||||
}
|
||||
|
||||
if availableInt != 0 && availableInt != 1 {
|
||||
return fmt.Errorf("available_int should be 0 or 1")
|
||||
}
|
||||
|
||||
if chunkIDs == nil || len(chunkIDs) == 0 {
|
||||
return fmt.Errorf("req is null")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
defer ctx.Done()
|
||||
|
||||
// Get user's tenants
|
||||
tenants, err := s.userTenantDAO.GetByUserID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user tenants: %w", err)
|
||||
}
|
||||
if len(tenants) == 0 {
|
||||
return fmt.Errorf("user has no accessible tenants")
|
||||
}
|
||||
|
||||
// Find the tenant that owns this dataset
|
||||
var targetTenantID string
|
||||
for _, tenant := range tenants {
|
||||
kb, err := s.kbDAO.GetByIDAndTenantID(datasetID, tenant.TenantID)
|
||||
if err == nil && kb != nil {
|
||||
targetTenantID = tenant.TenantID
|
||||
break
|
||||
}
|
||||
}
|
||||
if targetTenantID == "" {
|
||||
return fmt.Errorf("user does not have access to this dataset")
|
||||
}
|
||||
|
||||
docDAO := dao.NewDocumentDAO()
|
||||
doc, err := docDAO.GetByID(documentID)
|
||||
if err != nil || doc == nil {
|
||||
return fmt.Errorf("document not found")
|
||||
}
|
||||
if doc.KbID != datasetID {
|
||||
return fmt.Errorf("document does not belong to this dataset")
|
||||
}
|
||||
|
||||
for _, cid := range chunkIDs {
|
||||
indexName := fmt.Sprintf("ragflow_%s", targetTenantID)
|
||||
|
||||
if err = s.docEngine.UpdateChunks(ctx, map[string]interface{}{
|
||||
"id": cid,
|
||||
"doc_id": documentID,
|
||||
}, map[string]interface{}{
|
||||
"id": cid,
|
||||
"available_int": availableInt,
|
||||
}, indexName, datasetID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ChunkService) UpdateChunk(req *service.UpdateChunkRequest, userID string) error {
|
||||
if s.docEngine == nil {
|
||||
return fmt.Errorf("doc engine not initialized")
|
||||
|
||||
@@ -643,3 +643,183 @@ func (e *parseTestDocEngine) GetType() string {
|
||||
func (e *parseTestDocEngine) FilterDocIdsByMetaPushdown(context.Context, []string, []map[string]interface{}, string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSwitchChunksUpdatesDocEngineWithAvailableInt(t *testing.T) {
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&entity.UserTenant{}, &entity.Knowledgebase{}, &entity.Document{}); err != nil {
|
||||
t.Fatalf("failed to migrate sqlite: %v", err)
|
||||
}
|
||||
previousDB := dao.DB
|
||||
dao.DB = db
|
||||
t.Cleanup(func() { dao.DB = previousDB })
|
||||
|
||||
valid := string(entity.StatusValid)
|
||||
if err := db.Create(&entity.UserTenant{
|
||||
ID: "ut-1",
|
||||
UserID: "user-1",
|
||||
TenantID: "tenant-1",
|
||||
Role: "owner",
|
||||
Status: &valid,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to create user_tenant: %v", err)
|
||||
}
|
||||
if err := db.Create(&entity.Knowledgebase{
|
||||
ID: "kb-1",
|
||||
TenantID: "tenant-1",
|
||||
Name: "dataset",
|
||||
EmbdID: "embed",
|
||||
Permission: string(entity.TenantPermissionMe),
|
||||
CreatedBy: "user-1",
|
||||
ParserID: string(entity.ParserTypeNaive),
|
||||
ParserConfig: entity.JSONMap{},
|
||||
Status: &valid,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to create knowledgebase: %v", err)
|
||||
}
|
||||
if err := db.Create(&entity.Document{
|
||||
ID: "doc-1",
|
||||
KbID: "kb-1",
|
||||
ParserID: string(entity.ParserTypeNaive),
|
||||
ParserConfig: entity.JSONMap{},
|
||||
SourceType: "local",
|
||||
Type: "doc",
|
||||
CreatedBy: "user-1",
|
||||
Suffix: "txt",
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to create document: %v", err)
|
||||
}
|
||||
|
||||
engine := &switchChunksEngineMock{}
|
||||
svc := &ChunkService{
|
||||
docEngine: engine,
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
}
|
||||
|
||||
if err := svc.SwitchChunks("user-1", "kb-1", "doc-1", 0, []string{"chunk-1", "chunk-2"}); err != nil {
|
||||
t.Fatalf("SwitchChunks() error = %v", err)
|
||||
}
|
||||
|
||||
if len(engine.updateCalls) != 2 {
|
||||
t.Fatalf("UpdateChunks calls = %d, want 2", len(engine.updateCalls))
|
||||
}
|
||||
for i, call := range engine.updateCalls {
|
||||
if call.indexName != "ragflow_tenant-1" {
|
||||
t.Fatalf("call %d indexName = %q", i, call.indexName)
|
||||
}
|
||||
if call.datasetID != "kb-1" {
|
||||
t.Fatalf("call %d datasetID = %q", i, call.datasetID)
|
||||
}
|
||||
wantID := []string{"chunk-1", "chunk-2"}[i]
|
||||
if !reflect.DeepEqual(call.condition, map[string]interface{}{
|
||||
"id": wantID,
|
||||
"doc_id": "doc-1",
|
||||
}) {
|
||||
t.Fatalf("call %d condition = %#v", i, call.condition)
|
||||
}
|
||||
if !reflect.DeepEqual(call.newValue, map[string]interface{}{"id": wantID, "available_int": 0}) {
|
||||
t.Fatalf("call %d newValue = %#v", i, call.newValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateChunksCall captures the arguments of one UpdateChunks invocation on
// switchChunksEngineMock (the context is not recorded).
type updateChunksCall struct {
	condition, newValue  map[string]interface{}
	indexName, datasetID string
}
|
||||
|
||||
type switchChunksEngineMock struct {
|
||||
updateCalls []updateChunksCall
|
||||
}
|
||||
|
||||
func (m *switchChunksEngineMock) CreateChunkStore(context.Context, string, string, int, string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) InsertChunks(context.Context, []map[string]interface{}, string, string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) UpdateChunks(_ context.Context, condition map[string]interface{}, newValue map[string]interface{}, indexName string, datasetID string) error {
|
||||
m.updateCalls = append(m.updateCalls, updateChunksCall{
|
||||
condition: copyMap(condition),
|
||||
newValue: copyMap(newValue),
|
||||
indexName: indexName,
|
||||
datasetID: datasetID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) DeleteChunks(context.Context, map[string]interface{}, string, string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) Search(context.Context, *types.SearchRequest) (*types.SearchResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) GetChunk(context.Context, string, string, []string) (interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) DropChunkStore(context.Context, string, string) error { return nil }
|
||||
func (m *switchChunksEngineMock) ChunkStoreExists(context.Context, string, string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) CreateMetadataStore(context.Context, string) error { return nil }
|
||||
func (m *switchChunksEngineMock) InsertMetadata(context.Context, []map[string]interface{}, string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) UpdateMetadata(context.Context, string, string, map[string]interface{}, string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) DeleteMetadata(context.Context, map[string]interface{}, string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) DeleteMetadataKeys(context.Context, string, string, []string, string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) DropMetadataStore(context.Context, string) error { return nil }
|
||||
func (m *switchChunksEngineMock) MetadataStoreExists(context.Context, string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) SearchMetadata(context.Context, *types.SearchMetadataRequest) (*types.SearchMetadataResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) IndexDocument(context.Context, string, string, interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) DeleteDocument(context.Context, string, string) error { return nil }
|
||||
func (m *switchChunksEngineMock) BulkIndex(context.Context, string, []interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) GetFields([]map[string]interface{}, []string) map[string]map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) GetAggregation([]map[string]interface{}, string) []map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) GetHighlight([]map[string]interface{}, []string, string) map[string]string {
|
||||
return nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) RunSQL(context.Context, string, string, []string, string) ([]map[string]interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) GetChunkIDs([]map[string]interface{}) []string { return nil }
|
||||
func (m *switchChunksEngineMock) KNNScores(context.Context, []map[string]interface{}, []float64, int) (map[string]interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *switchChunksEngineMock) GetScores(map[string]interface{}) map[string]float64 { return nil }
|
||||
func (m *switchChunksEngineMock) Ping(context.Context) error { return nil }
|
||||
func (m *switchChunksEngineMock) Close() error { return nil }
|
||||
func (m *switchChunksEngineMock) GetType() string { return "elasticsearch" }
|
||||
func (m *switchChunksEngineMock) FilterDocIdsByMetaPushdown(context.Context, []string, []map[string]interface{}, string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// copyMap returns a shallow copy of in: a new map holding the same keys and
// values. The result is always non-nil, even when in is nil.
func copyMap(in map[string]interface{}) map[string]interface{} {
	clone := make(map[string]interface{}, len(in))
	for key := range in {
		clone[key] = in[key]
	}
	return clone
}
|
||||
|
||||
@@ -164,6 +164,7 @@ func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkRespon
|
||||
|
||||
// ListChunksRequest request for listing chunks
|
||||
type ListChunksRequest struct {
|
||||
DatasetID string `json:"dataset_id,omitempty"`
|
||||
DocID string `json:"doc_id" binding:"required"`
|
||||
Page *int `json:"page,omitempty"`
|
||||
Size *int `json:"size,omitempty"`
|
||||
@@ -205,6 +206,9 @@ func (s *ChunkService) List(req *ListChunksRequest, userID string) (*ListChunksR
|
||||
if err != nil || doc == nil {
|
||||
return nil, fmt.Errorf("document not found")
|
||||
}
|
||||
if req.DatasetID != "" && doc.KbID != req.DatasetID {
|
||||
return nil, fmt.Errorf("document not found")
|
||||
}
|
||||
|
||||
// Get knowledge base to find tenant
|
||||
kb, err := s.kbDAO.GetByID(doc.KbID)
|
||||
|
||||
Reference in New Issue
Block a user