From dc8ff63f1dd1173b33836e6be15007999f5d99cb Mon Sep 17 00:00:00 2001
From: Hz_
Date: Wed, 24 Jun 2026 17:05:58 +0800
Subject: [PATCH] feat(go-api): add dataset tags endpoints (#16231)

## Summary
- add `GET /api/v1/datasets/:dataset_id/tags`
- add `PUT /api/v1/datasets/:dataset_id/tags`
- implement dataset tag listing and rename flow
- align rename tag validation and response shape with the Python API
- add handler and service tests for dataset tags

## Routes
- `GET /api/v1/datasets/:dataset_id/tags`
- `PUT /api/v1/datasets/:dataset_id/tags`

## Test
- Run specific tests for dataset tags:
```
go test -v ./internal/service ./internal/handler -run 'TestDatasetServiceListTags|TestDatasetServiceRenameTag|TestDatasetsHandlerListTags|TestDatasetsHandlerRenameTag'
```
- Run all tests for service and handler to verify no regressions:
```
go test ./internal/service ./internal/handler
```
- Manually exercise both endpoints with curl against a running server
---
 internal/handler/datasets.go                 |  86 +++++++
 internal/handler/datasets_list_tags_test.go  |  39 +++
 internal/handler/datasets_rename_tag_test.go | 126 ++++++++++
 internal/router/router.go                    |   2 +
 internal/service/dataset.go                  | 155 ++++++++++++
 internal/service/dataset_list_tags_test.go   | 238 +++++++++++++++++++
 internal/service/dataset_rename_tag_test.go  | 154 ++++++++++++
 7 files changed, 800 insertions(+)
 create mode 100644 internal/handler/datasets_list_tags_test.go
 create mode 100644 internal/handler/datasets_rename_tag_test.go
 create mode 100644 internal/service/dataset_list_tags_test.go
 create mode 100644 internal/service/dataset_rename_tag_test.go

diff --git a/internal/handler/datasets.go b/internal/handler/datasets.go
index a0636e4792..3982de068a 100644
--- a/internal/handler/datasets.go
+++ b/internal/handler/datasets.go
@@ -479,6 +479,92 @@ func (h *DatasetsHandler) GetKnowledgeGraph(c *gin.Context) {
 	jsonResponse(c, common.CodeSuccess, result, "success")
 }
 
+// ListTags handles GET /api/v1/datasets/:dataset_id/tags.
+// @Summary List dataset tags
+// @Description List tags for a dataset
+// @Tags datasets
+// @Produce json
+// @Security ApiKeyAuth
+// @Param dataset_id path string true "Dataset ID"
+// @Success 200 {object} map[string]interface{}
+// @Router /api/v1/datasets/{dataset_id}/tags [get]
+func (h *DatasetsHandler) ListTags(c *gin.Context) {
+	user, errorCode, errorMessage := GetUser(c)
+	if errorCode != common.CodeSuccess {
+		jsonError(c, errorCode, errorMessage)
+		return
+	}
+
+	datasetID := strings.TrimSpace(c.Param("dataset_id"))
+	result, code, err := h.datasetsService.ListTags(datasetID, user.ID)
+	if err != nil {
+		jsonError(c, code, err.Error())
+		return
+	}
+
+	jsonResponse(c, common.CodeSuccess, result, "success")
+}
+
+// renameTagRequest carries the from_tag/to_tag pair once RenameTag has
+// checked that both are present and are strings.
+type renameTagRequest struct {
+	FromTag string `json:"from_tag"`
+	ToTag   string `json:"to_tag"`
+}
+
+// RenameTag handles PUT /api/v1/datasets/:dataset_id/tags.
+// @Summary Rename dataset tag
+// @Description Rename a tag within a dataset
+// @Tags datasets
+// @Accept json
+// @Produce json
+// @Security ApiKeyAuth
+// @Param dataset_id path string true "Dataset ID"
+// @Param request body renameTagRequest true "Tag to rename and its new name"
+// @Success 200 {object} map[string]interface{}
+// @Router /api/v1/datasets/{dataset_id}/tags [put]
+func (h *DatasetsHandler) RenameTag(c *gin.Context) {
+	user, errorCode, errorMessage := GetUser(c)
+	if errorCode != common.CodeSuccess {
+		jsonError(c, errorCode, errorMessage)
+		return
+	}
+	datasetID := strings.TrimSpace(c.Param("dataset_id"))
+
+	// Decode into a generic map rather than renameTagRequest so that a missing
+	// key, a non-string value and a blank string each get their own error.
+	var payload map[string]interface{}
+	if err := c.ShouldBindJSON(&payload); err != nil {
+		jsonError(c, common.CodeDataError, "Lack of from_tag or to_tag in request body")
+		return
+	}
+	fromTagValue, hasFrom := payload["from_tag"]
+	toTagValue, hasTo := payload["to_tag"]
+	if !hasFrom || !hasTo {
+		jsonError(c, common.CodeDataError, "Lack of from_tag or to_tag in request body")
+		return
+	}
+	fromTag, okFrom := fromTagValue.(string)
+	toTag, okTo := toTagValue.(string)
+	if !okFrom || !okTo {
+		jsonError(c, common.CodeArgumentError, "from_tag and to_tag must be strings")
+		return
+	}
+	req := renameTagRequest{FromTag: fromTag, ToTag: toTag}
+	if strings.TrimSpace(req.FromTag) == "" || strings.TrimSpace(req.ToTag) == "" {
+		jsonError(c, common.CodeArgumentError, "from_tag and to_tag must not be empty")
+		return
+	}
+
+	result, code, err := h.datasetsService.RenameTag(datasetID, user.ID, req.FromTag, req.ToTag)
+	if err != nil {
+		jsonError(c, code, err.Error())
+		return
+	}
+
+	jsonResponse(c, common.CodeSuccess, result, "success")
+}
+
 // DeleteKnowledgeGraph handles DELETE /api/v1/datasets/:dataset_id/graph.
 func (h *DatasetsHandler) DeleteKnowledgeGraph(c *gin.Context) {
 	user, errorCode, errorMessage := GetUser(c)
diff --git a/internal/handler/datasets_list_tags_test.go b/internal/handler/datasets_list_tags_test.go
new file mode 100644
index 0000000000..371e6264d7
--- /dev/null
+++ b/internal/handler/datasets_list_tags_test.go
@@ -0,0 +1,39 @@
+package handler
+
+import (
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/gin-gonic/gin"
+
+	"ragflow/internal/common"
+)
+
+func TestDatasetsHandlerListTagsRequiresAuth(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	h := &DatasetsHandler{}
+	r := gin.New()
+	r.GET("/api/v1/datasets/:dataset_id/tags", func(c *gin.Context) {
+		h.ListTags(c)
+	})
+
+	resp := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/datasets/123e4567-e89b-12d3-a456-426614174000/tags", nil)
+	r.ServeHTTP(resp, req)
+
+	var body struct {
+		Code    int    `json:"code"`
+		Message string `json:"message"`
+	}
+	if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
+		t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String())
+	}
+	if body.Code != int(common.CodeUnauthorized) {
+		t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeUnauthorized, resp.Body.String())
+	}
+	if body.Message != "User not found" {
+		t.Fatalf("message=%q want=%q", body.Message, "User not found")
+	}
+}
diff --git a/internal/handler/datasets_rename_tag_test.go b/internal/handler/datasets_rename_tag_test.go
new file mode 100644
index 0000000000..d2dd17a323
--- /dev/null
+++ b/internal/handler/datasets_rename_tag_test.go
@@ -0,0 +1,126 @@
+package handler
+
+import (
+	"bytes"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/gin-gonic/gin"
+
+	
"ragflow/internal/common" + "ragflow/internal/entity" +) + +func TestDatasetsHandlerRenameTagRequiresAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + h := &DatasetsHandler{} + r := gin.New() + r.PUT("/api/v1/datasets/:dataset_id/tags", func(c *gin.Context) { + h.RenameTag(c) + }) + + resp := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/datasets/123e4567-e89b-12d3-a456-426614174000/tags", bytes.NewBufferString(`{"from_tag":"a","to_tag":"b"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(resp, req) + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String()) + } + if body.Code != int(common.CodeUnauthorized) { + t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeUnauthorized, resp.Body.String()) + } + if body.Message != "User not found" { + t.Fatalf("message=%q want=%q", body.Message, "User not found") + } +} + +func TestDatasetsHandlerRenameTagRejectsMissingFields(t *testing.T) { + gin.SetMode(gin.TestMode) + h := &DatasetsHandler{} + r := gin.New() + r.PUT("/api/v1/datasets/:dataset_id/tags", func(c *gin.Context) { + c.Set("user", &entity.User{ID: "user-1"}) + h.RenameTag(c) + }) + + resp := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/datasets/123e4567-e89b-12d3-a456-426614174000/tags", bytes.NewBufferString(`{"from_tag":"a"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(resp, req) + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String()) + } + if body.Code != int(common.CodeDataError) { + t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeDataError, resp.Body.String()) + } +} + +func 
TestDatasetsHandlerRenameTagRejectsEmptyFields(t *testing.T) { + gin.SetMode(gin.TestMode) + h := &DatasetsHandler{} + r := gin.New() + r.PUT("/api/v1/datasets/:dataset_id/tags", func(c *gin.Context) { + c.Set("user", &entity.User{ID: "user-1"}) + h.RenameTag(c) + }) + + resp := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/datasets/123e4567-e89b-12d3-a456-426614174000/tags", bytes.NewBufferString(`{"from_tag":" ","to_tag":"x"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(resp, req) + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String()) + } + if body.Code != int(common.CodeArgumentError) { + t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeArgumentError, resp.Body.String()) + } + if body.Message != "from_tag and to_tag must not be empty" { + t.Fatalf("message=%q want=%q", body.Message, "from_tag and to_tag must not be empty") + } +} + +func TestDatasetsHandlerRenameTagRejectsNonStringFields(t *testing.T) { + gin.SetMode(gin.TestMode) + h := &DatasetsHandler{} + r := gin.New() + r.PUT("/api/v1/datasets/:dataset_id/tags", func(c *gin.Context) { + c.Set("user", &entity.User{ID: "user-1"}) + h.RenameTag(c) + }) + + resp := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/datasets/123e4567-e89b-12d3-a456-426614174000/tags", bytes.NewBufferString(`{"from_tag":1,"to_tag":"x"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(resp, req) + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String()) + } + if body.Code != int(common.CodeArgumentError) { + t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeArgumentError, 
resp.Body.String())
+	}
+	if body.Message != "from_tag and to_tag must be strings" {
+		t.Fatalf("message=%q want=%q", body.Message, "from_tag and to_tag must be strings")
+	}
+}
diff --git a/internal/router/router.go b/internal/router/router.go
index 6973962068..0b3fff9ba0 100644
--- a/internal/router/router.go
+++ b/internal/router/router.go
@@ -274,6 +274,8 @@ func (r *Router) Setup(engine *gin.Engine) {
 			datasets.GET("/:dataset_id", r.datasetsHandler.GetDataset)
 			datasets.PUT("/:dataset_id", r.datasetsHandler.UpdateDataset)
 			datasets.GET("/:dataset_id/graph", r.datasetsHandler.GetKnowledgeGraph)
+			datasets.GET("/:dataset_id/tags", r.datasetsHandler.ListTags)
+			datasets.PUT("/:dataset_id/tags", r.datasetsHandler.RenameTag)
 			datasets.DELETE("/:dataset_id/tags", r.datasetsHandler.RemoveTags)
 			datasets.POST("/:dataset_id/documents/batch-update-status", r.documentHandler.BatchUpdateDocumentStatus)
 			datasets.GET("/:dataset_id/index", r.datasetsHandler.TraceIndex)
diff --git a/internal/service/dataset.go b/internal/service/dataset.go
index 538b0e4379..62baef3133 100644
--- a/internal/service/dataset.go
+++ b/internal/service/dataset.go
@@ -2031,6 +2031,121 @@ func (s *DatasetService) AggregateTags(datasetIDs []string, userID string) ([]ma
 	return result, common.CodeSuccess, nil
 }
 
+// ListTags pages through the dataset's chunks, sums the tag_kwd aggregation
+// counts reported by the document engine, and returns one {"key", "count"}
+// entry per tag ordered by count (descending) and then by tag name. It
+// returns an empty slice when the dataset has no chunk store.
+func (s *DatasetService) ListTags(datasetID, userID string) ([]map[string]interface{}, common.ErrorCode, error) {
+	datasetID = strings.TrimSpace(datasetID)
+	if datasetID == "" {
+		return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"")
+	}
+
+	normalizedID, err := normalizeDatasetID(datasetID)
+	if err != nil {
+		return nil, common.CodeDataError, err
+	}
+	datasetID = normalizedID
+
+	if !s.kbDAO.Accessible(datasetID, userID) {
+		return nil, common.CodeDataError, errors.New("No authorization.")
+	}
+	if s.docEngine == nil {
+		return nil, common.CodeServerError, errors.New("Document engine is not initialized")
+	}
+
+	kb, err := s.kbDAO.GetByID(datasetID)
+	if err != nil || kb == nil {
+		return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
+	}
+
+	indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	exists, err := s.docEngine.ChunkStoreExists(ctx, indexName, datasetID)
+	if err != nil {
+		return nil, common.CodeServerError, fmt.Errorf("failed to inspect chunk store: %w", err)
+	}
+	if !exists {
+		return []map[string]interface{}{}, common.CodeSuccess, nil
+	}
+
+	const pageSize = 10000
+	counts := make(map[string]int)
+	for offset := 0; ; offset += pageSize {
+		if err := ctx.Err(); err != nil {
+			return nil, common.CodeServerError, fmt.Errorf("list tags timeout or canceled: %w", err)
+		}
+
+		searchResp, err := s.docEngine.Search(ctx, &types.SearchRequest{
+			IndexNames:   []string{indexName},
+			KbIDs:        []string{datasetID},
+			Offset:       offset,
+			Limit:        pageSize,
+			SelectFields: []string{"tag_kwd"},
+		})
+		if err != nil {
+			return nil, common.CodeServerError, fmt.Errorf("failed to list tags: %w", err)
+		}
+		// A nil result carries no chunks; treat it as the end of the data.
+		if searchResp == nil {
+			break
+		}
+
+		for _, agg := range s.docEngine.GetAggregation(searchResp.Chunks, "tag_kwd") {
+			tag, _ := agg["key"].(string)
+			if tag == "" {
+				continue
+			}
+			switch count := agg["count"].(type) {
+			case int:
+				counts[tag] += count
+			case int32:
+				counts[tag] += int(count)
+			case int64:
+				counts[tag] += int(count)
+			case float64:
+				counts[tag] += int(count)
+			}
+		}
+
+		chunkCount := len(searchResp.Chunks)
+		if chunkCount == 0 || chunkCount < pageSize {
+			break
+		}
+		if searchResp.Total > 0 && int64(offset+chunkCount) >= searchResp.Total {
+			break
+		}
+	}
+
+	if len(counts) == 0 {
+		return []map[string]interface{}{}, common.CodeSuccess, nil
+	}
+
+	tags := make([]string, 0, len(counts))
+	for tag := range counts {
+		tags = append(tags, tag)
+	}
+	sort.Slice(tags, func(i, j int) bool {
+		if counts[tags[i]] != counts[tags[j]] {
+			return counts[tags[i]] > counts[tags[j]]
+		}
+		return tags[i] < tags[j]
+	})
+
+	result := make([]map[string]interface{}, 0, len(tags))
+	for _, tag := range tags {
+		result = append(result, map[string]interface{}{
+			"key":   tag,
+			"count": counts[tag],
+		})
+	}
+
+	return result, common.CodeSuccess, nil
+}
+
 // GetIngestionSummary returns dataset-level ingestion counters together with
 // the aggregated document parsing status, mirroring
 // dataset_api_service.get_ingestion_summary.
@@ -2544,3 +2659,68 @@ func limitStrings(values []string, limit int) []string {
 	}
 	return values[:limit]
 }
+
+// RenameTag issues a chunk update for the dataset that matches chunks whose
+// tag_kwd is fromTag, removing fromTag and adding toTag. Both tags are
+// trimmed first; the trimmed values are echoed back under the "from" and
+// "to" keys of the result.
+func (s *DatasetService) RenameTag(datasetID, userID, fromTag, toTag string) (map[string]interface{}, common.ErrorCode, error) {
+	fromTag = strings.TrimSpace(fromTag)
+	toTag = strings.TrimSpace(toTag)
+	// The HTTP handler rejects blank tags as well, but the service can be
+	// called directly and a blank tag_kwd condition must not reach the engine.
+	if fromTag == "" || toTag == "" {
+		return nil, common.CodeArgumentError, errors.New("from_tag and to_tag must not be empty")
+	}
+
+	// Same validation order as ListTags: blank check first, then normalize.
+	datasetID = strings.TrimSpace(datasetID)
+	if datasetID == "" {
+		return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"")
+	}
+	normalizedID, err := normalizeDatasetID(datasetID)
+	if err != nil {
+		return nil, common.CodeDataError, err
+	}
+	datasetID = normalizedID
+
+	if !s.kbDAO.Accessible(datasetID, userID) {
+		return nil, common.CodeDataError, errors.New("No authorization.")
+	}
+	if s.docEngine == nil {
+		return nil, common.CodeServerError, errors.New("Document engine is not initialized")
+	}
+
+	kb, err := s.kbDAO.GetByID(datasetID)
+	if err != nil || kb == nil {
+		return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
+	}
+	indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
+
+	condition := map[string]interface{}{
+		"tag_kwd": fromTag,
+		"kb_id":   datasetID,
+	}
+	newValue := map[string]interface{}{
+		"remove": map[string]interface{}{
+			"tag_kwd": fromTag,
+		},
+		"add": map[string]interface{}{
+			"tag_kwd": toTag,
+		},
+	}
+
+	// Bound the engine call, as ListTags does, so a stalled engine cannot
+	// hold the request forever.
+	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
+
+	if err := s.docEngine.UpdateChunks(ctx, condition, newValue, indexName, datasetID); err != nil {
+		return nil, common.CodeServerError, fmt.Errorf("failed to rename tag: %w", err)
+	}
+
+	return map[string]interface{}{
+		"from": fromTag,
+		"to":   toTag,
+	}, common.CodeSuccess, nil
+}
diff --git a/internal/service/dataset_list_tags_test.go b/internal/service/dataset_list_tags_test.go
new file mode 100644 index 0000000000..3bac34e945 --- /dev/null +++ b/internal/service/dataset_list_tags_test.go @@ -0,0 +1,238 @@ +package service + +import ( + "context" + "errors" + "strings" + "testing" + + "ragflow/internal/common" + "ragflow/internal/dao" + "ragflow/internal/engine" + "ragflow/internal/engine/types" + "ragflow/internal/entity" +) + +type listTagsMockEngine struct { + engine.DocEngine + searchResults map[string]*types.SearchResult + searchErr error + requests []*types.SearchRequest + chunkStoreExists bool + chunkStoreErr error +} + +func (m *listTagsMockEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { + if m.searchErr != nil { + return nil, m.searchErr + } + cloned := &types.SearchRequest{ + IndexNames: append([]string(nil), req.IndexNames...), + KbIDs: append([]string(nil), req.KbIDs...), + Offset: req.Offset, + Limit: req.Limit, + SelectFields: append([]string(nil), req.SelectFields...), + } + m.requests = append(m.requests, cloned) + if len(req.IndexNames) == 0 { + return &types.SearchResult{}, nil + } + if res, ok := m.searchResults[req.IndexNames[0]]; ok { + return res, nil + } + return &types.SearchResult{}, nil +} + +func (m *listTagsMockEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} { + counts := make(map[string]int) + for _, chunk := range chunks { + raw, ok := chunk[fieldName] + if !ok || raw == nil { + continue + } + switch value := raw.(type) { + case string: + for _, tag := range strings.Split(value, "###") { + tag = strings.TrimSpace(tag) + if tag != "" { + counts[tag]++ + } + } + case []string: + for _, tag := range value { + tag = strings.TrimSpace(tag) + if tag != "" { + counts[tag]++ + } + } + } + } + + result := make([]map[string]interface{}, 0, len(counts)) + for tag, count := range counts { + result = append(result, map[string]interface{}{ + "key": tag, + "count": count, + }) + } + return result +} + +func (m 
*listTagsMockEngine) ChunkStoreExists(ctx context.Context, baseName, datasetID string) (bool, error) { + if m.chunkStoreErr != nil { + return false, m.chunkStoreErr + } + return m.chunkStoreExists, nil +} + +func (m *listTagsMockEngine) GetType() string { return "mock" } + +func (m *listTagsMockEngine) Ping(ctx context.Context) error { return nil } + +func (m *listTagsMockEngine) Close() error { return nil } + +func testDatasetServiceForListTags(t *testing.T, docEngine engine.DocEngine) *DatasetService { + t.Helper() + return &DatasetService{ + kbDAO: dao.NewKnowledgebaseDAO(), + docEngine: docEngine, + } +} + +func insertListTagsKB(t *testing.T, datasetID, tenantID, permission string, docNum int64) { + t.Helper() + kb := &entity.Knowledgebase{ + ID: datasetID, + TenantID: tenantID, + Name: "kb-" + datasetID[:6], + EmbdID: "embedding@OpenAI", + CreatedBy: tenantID, + Permission: permission, + ParserID: "naive", + ParserConfig: entity.JSONMap{}, + DocNum: docNum, + Status: sptr(string(entity.StatusValid)), + } + if err := dao.DB.Create(kb).Error; err != nil { + t.Fatalf("insert test kb: %v", err) + } +} + +func TestDatasetServiceListTagsSuccess(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + kbInput := "123e4567-e89b-12d3-a456-426614174000" + kbID := strings.ReplaceAll(kbInput, "-", "") + insertListTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 2) + + docEngine := &listTagsMockEngine{ + chunkStoreExists: true, + searchResults: map[string]*types.SearchResult{ + "ragflow_user-1": { + Chunks: []map[string]interface{}{ + {"tag_kwd": "finance###urgent"}, + {"tag_kwd": "finance"}, + }, + }, + }, + } + + result, code, err := testDatasetServiceForListTags(t, docEngine).ListTags(kbInput, "user-1") + if err != nil { + t.Fatalf("ListTags failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code=%d want=%d", code, common.CodeSuccess) + } + if len(result) != 2 { + t.Fatalf("len(result)=%d want=2 result=%v", len(result), 
result) + } + if result[0]["key"] != "finance" || result[0]["count"] != 2 { + t.Fatalf("first row=%v want finance/2", result[0]) + } + if result[1]["key"] != "urgent" || result[1]["count"] != 1 { + t.Fatalf("second row=%v want urgent/1", result[1]) + } + if len(docEngine.requests) != 1 { + t.Fatalf("search requests=%d want=1", len(docEngine.requests)) + } + req := docEngine.requests[0] + if len(req.IndexNames) != 1 || req.IndexNames[0] != "ragflow_user-1" { + t.Fatalf("IndexNames=%v want [ragflow_user-1]", req.IndexNames) + } + if len(req.KbIDs) != 1 || req.KbIDs[0] != kbID { + t.Fatalf("KbIDs=%v want [%s]", req.KbIDs, kbID) + } +} + +func TestDatasetServiceListTagsReturnsEmptyWhenChunkStoreMissing(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + kbInput := "123e4567-e89b-12d3-a456-426614174000" + kbID := strings.ReplaceAll(kbInput, "-", "") + insertListTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 1) + + docEngine := &listTagsMockEngine{chunkStoreExists: false} + + result, code, err := testDatasetServiceForListTags(t, docEngine).ListTags(kbInput, "user-1") + if err != nil { + t.Fatalf("ListTags failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code=%d want=%d", code, common.CodeSuccess) + } + if len(result) != 0 { + t.Fatalf("len(result)=%d want=0 result=%v", len(result), result) + } + if len(docEngine.requests) != 0 { + t.Fatalf("search requests=%d want=0", len(docEngine.requests)) + } +} + +func TestDatasetServiceListTagsRejectsUnauthorizedDataset(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + kbInput := "123e4567-e89b-12d3-a456-426614174000" + kbID := strings.ReplaceAll(kbInput, "-", "") + insertListTagsKB(t, kbID, "tenant-9", string(entity.TenantPermissionMe), 1) + + docEngine := &listTagsMockEngine{chunkStoreExists: true} + + _, code, err := testDatasetServiceForListTags(t, docEngine).ListTags(kbInput, "user-1") + if err == nil { + t.Fatal("expected unauthorized error") + } + 
if code != common.CodeDataError { + t.Fatalf("code=%d want=%d", code, common.CodeDataError) + } + if err.Error() != "No authorization." { + t.Fatalf("error=%q want=%q", err.Error(), "No authorization.") + } +} + +func TestDatasetServiceListTagsReturnsChunkStoreError(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + kbInput := "123e4567-e89b-12d3-a456-426614174000" + kbID := strings.ReplaceAll(kbInput, "-", "") + insertListTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 1) + + docEngine := &listTagsMockEngine{ + chunkStoreErr: errors.New("boom"), + } + + _, code, err := testDatasetServiceForListTags(t, docEngine).ListTags(kbInput, "user-1") + if err == nil { + t.Fatal("expected chunk store error") + } + if code != common.CodeServerError { + t.Fatalf("code=%d want=%d", code, common.CodeServerError) + } + if !strings.Contains(err.Error(), "failed to inspect chunk store: boom") { + t.Fatalf("err=%q want contains %q", err.Error(), "failed to inspect chunk store: boom") + } +} diff --git a/internal/service/dataset_rename_tag_test.go b/internal/service/dataset_rename_tag_test.go new file mode 100644 index 0000000000..1307bb41fb --- /dev/null +++ b/internal/service/dataset_rename_tag_test.go @@ -0,0 +1,154 @@ +package service + +import ( + "context" + "errors" + "strings" + "testing" + + "ragflow/internal/common" + "ragflow/internal/dao" + "ragflow/internal/engine" + "ragflow/internal/entity" +) + +type renameTagUpdateCall struct { + Condition map[string]interface{} + NewValue map[string]interface{} + BaseName string + DatasetID string +} + +type renameTagMockEngine struct { + engine.DocEngine + updateCalls []renameTagUpdateCall + updateErr error +} + +func (m *renameTagMockEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error { + m.updateCalls = append(m.updateCalls, renameTagUpdateCall{ + Condition: condition, + NewValue: newValue, + 
BaseName: baseName, + DatasetID: datasetID, + }) + return m.updateErr +} + +func (m *renameTagMockEngine) GetType() string { return "mock" } +func (m *renameTagMockEngine) Ping(ctx context.Context) error { return nil } +func (m *renameTagMockEngine) Close() error { return nil } + +func testDatasetServiceForRenameTag(t *testing.T, docEngine engine.DocEngine) *DatasetService { + t.Helper() + return &DatasetService{ + kbDAO: dao.NewKnowledgebaseDAO(), + docEngine: docEngine, + } +} + +func insertRenameTagKB(t *testing.T, datasetID, tenantID, permission string) { + t.Helper() + kb := &entity.Knowledgebase{ + ID: datasetID, + TenantID: tenantID, + Name: "kb-" + datasetID[:6], + EmbdID: "embedding@OpenAI", + CreatedBy: tenantID, + Permission: permission, + ParserID: "naive", + ParserConfig: entity.JSONMap{}, + Status: sptr(string(entity.StatusValid)), + } + if err := dao.DB.Create(kb).Error; err != nil { + t.Fatalf("insert test kb: %v", err) + } +} + +func TestDatasetServiceRenameTagSuccess(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + kbInput := "123e4567-e89b-12d3-a456-426614174000" + kbID := strings.ReplaceAll(kbInput, "-", "") + insertRenameTagKB(t, kbID, "user-1", string(entity.TenantPermissionMe)) + + docEngine := &renameTagMockEngine{} + result, code, err := testDatasetServiceForRenameTag(t, docEngine).RenameTag(kbInput, "user-1", "old-tag ", "new-tag") + if err != nil { + t.Fatalf("RenameTag failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code=%d want=%d", code, common.CodeSuccess) + } + if result["from"] != "old-tag" || result["to"] != "new-tag" { + t.Fatalf("result=%v", result) + } + if len(docEngine.updateCalls) != 1 { + t.Fatalf("update calls=%d want=1", len(docEngine.updateCalls)) + } + + call := docEngine.updateCalls[0] + if call.BaseName != "ragflow_user-1" { + t.Fatalf("baseName=%q want=%q", call.BaseName, "ragflow_user-1") + } + if call.DatasetID != kbID { + t.Fatalf("datasetID=%q want=%q", call.DatasetID, 
kbID) + } + if got := call.Condition["tag_kwd"]; got != "old-tag" { + t.Fatalf("condition tag_kwd=%v want=%q", got, "old-tag") + } + if got := call.Condition["kb_id"]; got != kbID { + t.Fatalf("condition kb_id=%v want=%q", got, kbID) + } + + remove, _ := call.NewValue["remove"].(map[string]interface{}) + add, _ := call.NewValue["add"].(map[string]interface{}) + if got := remove["tag_kwd"]; got != "old-tag" { + t.Fatalf("remove tag_kwd=%v want=%q", got, "old-tag") + } + if got := add["tag_kwd"]; got != "new-tag" { + t.Fatalf("add tag_kwd=%v want=%q", got, "new-tag") + } +} + +func TestDatasetServiceRenameTagUnauthorized(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + kbInput := "123e4567-e89b-12d3-a456-426614174000" + kbID := strings.ReplaceAll(kbInput, "-", "") + insertRenameTagKB(t, kbID, "tenant-9", string(entity.TenantPermissionMe)) + + _, code, err := testDatasetServiceForRenameTag(t, &renameTagMockEngine{}).RenameTag(kbInput, "user-1", "old", "new") + if err == nil { + t.Fatal("expected unauthorized error") + } + if code != common.CodeDataError { + t.Fatalf("code=%d want=%d", code, common.CodeDataError) + } + if err.Error() != "No authorization." 
{ + t.Fatalf("error=%q want=%q", err.Error(), "No authorization.") + } +} + +func TestDatasetServiceRenameTagUpdateError(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + kbInput := "123e4567-e89b-12d3-a456-426614174000" + kbID := strings.ReplaceAll(kbInput, "-", "") + insertRenameTagKB(t, kbID, "user-1", string(entity.TenantPermissionMe)) + + docEngine := &renameTagMockEngine{updateErr: errors.New("boom")} + _, code, err := testDatasetServiceForRenameTag(t, docEngine).RenameTag(kbInput, "user-1", "old", "new") + if err == nil { + t.Fatal("expected update error") + } + if code != common.CodeServerError { + t.Fatalf("code=%d want=%d", code, common.CodeServerError) + } + if !strings.Contains(err.Error(), "failed to rename tag") { + t.Fatalf("error=%q want contains %q", err.Error(), "failed to rename tag") + } +}