diff --git a/internal/handler/datasets.go b/internal/handler/datasets.go index 2fa78b9d4b..59b2355e9a 100644 --- a/internal/handler/datasets.go +++ b/internal/handler/datasets.go @@ -592,6 +592,44 @@ func (h *DatasetsHandler) RemoveTags(c *gin.Context) { jsonResponse(c, common.CodeSuccess, true, "success") } +// AggregateTags handles GET /api/v1/datasets/tags/aggregation. +// @Summary Aggregate dataset tags +// @Description Aggregate tags across multiple datasets +// @Tags datasets +// @Produce json +// @Security ApiKeyAuth +// @Param dataset_ids query string true "Comma-separated dataset IDs" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/datasets/tags/aggregation [get] +func (h *DatasetsHandler) AggregateTags(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + rawIDs := strings.Split(c.Query("dataset_ids"), ",") + datasetIDs := make([]string, 0, len(rawIDs)) + + for _, rawID := range rawIDs { + tempID := strings.TrimSpace(rawID) + if tempID != "" { + datasetIDs = append(datasetIDs, tempID) + } + } + if len(datasetIDs) == 0 { + jsonError(c, common.CodeDataError, "Lack of dataset_ids in query parameters") + return + } + + result, code, err := h.datasetsService.AggregateTags(datasetIDs, user.ID) + if err != nil { + jsonError(c, code, err.Error()) + return + } + jsonResponse(c, common.CodeSuccess, result, "success") +} + // RunIndex Run an indexing task (graph/raptor/mindmap) for a dataset. 
func (h *DatasetsHandler) RunIndex(c *gin.Context) { user, errorCode, errorMessage := GetUser(c) diff --git a/internal/handler/datasets_aggregate_tags_test.go b/internal/handler/datasets_aggregate_tags_test.go new file mode 100644 index 0000000000..1a041f6f4d --- /dev/null +++ b/internal/handler/datasets_aggregate_tags_test.go @@ -0,0 +1,78 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + + "ragflow/internal/common" + "ragflow/internal/entity" +) + +func newAggregateTagsHandlerRouter(authenticated bool) *gin.Engine { + gin.SetMode(gin.TestMode) + h := &DatasetsHandler{} + r := gin.New() + r.GET("/api/v1/datasets/tags/aggregation", func(c *gin.Context) { + if authenticated { + c.Set("user", &entity.User{ID: "user-1"}) + } + h.AggregateTags(c) + }) + return r +} + +func TestDatasetsHandlerAggregateTagsRequiresDatasetIDs(t *testing.T) { + r := newAggregateTagsHandlerRouter(true) + + resp := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/datasets/tags/aggregation", nil) + r.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", resp.Code, resp.Body.String()) + } + + var body struct { + Code int `json:"code"` + Data interface{} `json:"data"` + Message string `json:"message"` + } + if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String()) + } + if body.Code != int(common.CodeDataError) { + t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeDataError, resp.Body.String()) + } + if body.Message != "Lack of dataset_ids in query parameters" { + t.Fatalf("message=%q want=%q", body.Message, "Lack of dataset_ids in query parameters") + } + if body.Data != nil { + t.Fatalf("data=%v want nil", body.Data) + } +} + +func TestDatasetsHandlerAggregateTagsRequiresAuth(t *testing.T) { + r := newAggregateTagsHandlerRouter(false) + + resp := httptest.NewRecorder() 
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/datasets/tags/aggregation?dataset_ids=123e4567-e89b-12d3-a456-426614174000", nil) + r.ServeHTTP(resp, req) + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String()) + } + if body.Code != int(common.CodeUnauthorized) { + t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeUnauthorized, resp.Body.String()) + } + if body.Message != "User not found" { + t.Fatalf("message=%q want=%q", body.Message, "User not found") + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 26bf3670ae..608228756c 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -269,6 +269,7 @@ func (r *Router) Setup(engine *gin.Engine) { datasets := v1.Group("/datasets") { datasets.GET("", r.datasetsHandler.ListDatasets) + datasets.GET("/tags/aggregation", r.datasetsHandler.AggregateTags) datasets.GET("/:dataset_id", r.datasetsHandler.GetDataset) datasets.PUT("/:dataset_id", r.datasetsHandler.UpdateDataset) datasets.GET("/:dataset_id/graph", r.datasetsHandler.GetKnowledgeGraph) diff --git a/internal/service/dataset.go b/internal/service/dataset.go index 0d2b76fdfd..ad80850835 100644 --- a/internal/service/dataset.go +++ b/internal/service/dataset.go @@ -25,6 +25,7 @@ import ( "ragflow/internal/dao" "ragflow/internal/engine" redisengine "ragflow/internal/engine/redis" + "ragflow/internal/engine/types" "ragflow/internal/entity" "ragflow/internal/entity/models" "ragflow/internal/server" @@ -1831,6 +1832,90 @@ func (s *DatasetService) Accessible(kbID, userID string) bool { return s.kbDAO.Accessible(kbID, userID) } +func (s *DatasetService) AggregateTags(datasetIDs []string, userID string) ([]map[string]interface{}, common.ErrorCode, error) { + if len(datasetIDs) == 0 { + return nil, common.CodeDataError, errors.New("Lack of dataset_ids in 
query parameters") + } + if s.docEngine == nil { + return nil, common.CodeServerError, errors.New("Document engine is not initialized") + } + + datasetIDsByTenant := make(map[string][]string) + for _, rawID := range datasetIDs { + rawID = strings.TrimSpace(rawID) + if rawID == "" { + continue + } + datasetID, err := normalizeDatasetUUID1(rawID) + if err != nil { + return nil, common.CodeDataError, err + } + if !s.kbDAO.Accessible(datasetID, userID) { + return nil, common.CodeDataError, fmt.Errorf("No authorization for dataset '%s'", datasetID) + } + kb, err := s.kbDAO.GetByID(datasetID) + if err != nil { + if dao.IsNotFoundErr(err) { + return nil, common.CodeDataError, fmt.Errorf("Invalid Dataset ID '%s'", datasetID) + } + return nil, common.CodeServerError, errors.New("Database operation failed") + } + if kb.DocNum <= 0 { + continue + } + datasetIDsByTenant[kb.TenantID] = append(datasetIDsByTenant[kb.TenantID], datasetID) + } + + const pageSize = 10000 + merged := make(map[string]int) + for tenantID, kbIDs := range datasetIDsByTenant { + for offset := 0; ; offset += pageSize { + searchResp, err := s.docEngine.Search(context.Background(), &types.SearchRequest{ + IndexNames: []string{fmt.Sprintf("ragflow_%s", tenantID)}, + KbIDs: kbIDs, + Offset: offset, + Limit: pageSize, + SelectFields: []string{"tag_kwd"}, + }) + if err != nil { + return nil, common.CodeServerError, fmt.Errorf("failed to aggregate tags: %w", err) + } + for _, agg := range s.docEngine.GetAggregation(searchResp.Chunks, "tag_kwd") { + tag, _ := agg["key"].(string) + if tag == "" { + continue + } + switch count := agg["count"].(type) { + case int: + merged[tag] += count + case int32: + merged[tag] += int(count) + case int64: + merged[tag] += int(count) + case float64: + merged[tag] += int(count) + } + } + + chunkCount := len(searchResp.Chunks) + if chunkCount == 0 || chunkCount < pageSize { + break + } + if searchResp.Total > 0 && int64(offset+chunkCount) >= searchResp.Total { + break + } + } + } + 
result := make([]map[string]interface{}, 0, len(merged)) + for tag, count := range merged { + result = append(result, map[string]interface{}{ + "value": tag, + "count": count, + }) + } + return result, common.CodeSuccess, nil +} + // GetIngestionSummary returns dataset-level ingestion counters together with // the aggregated document parsing status, mirroring // dataset_api_service.get_ingestion_summary. diff --git a/internal/service/dataset_aggregate_tags_test.go b/internal/service/dataset_aggregate_tags_test.go new file mode 100644 index 0000000000..29cac97fba --- /dev/null +++ b/internal/service/dataset_aggregate_tags_test.go @@ -0,0 +1,377 @@ +package service + +import ( + "context" + "errors" + "strings" + "testing" + + "ragflow/internal/common" + "ragflow/internal/dao" + "ragflow/internal/engine" + "ragflow/internal/engine/types" + "ragflow/internal/entity" +) + +type aggregateTagsMockEngine struct { + engine.DocEngine + searchResults map[string]*types.SearchResult + pagedSearchResults map[string]map[int]*types.SearchResult + searchErr error + requests []*types.SearchRequest +} + +func (m *aggregateTagsMockEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { + if m.searchErr != nil { + return nil, m.searchErr + } + cloned := &types.SearchRequest{ + IndexNames: append([]string(nil), req.IndexNames...), + KbIDs: append([]string(nil), req.KbIDs...), + Offset: req.Offset, + Limit: req.Limit, + SelectFields: append([]string(nil), req.SelectFields...), + } + m.requests = append(m.requests, cloned) + if len(req.IndexNames) == 0 { + return &types.SearchResult{}, nil + } + if byOffset, ok := m.pagedSearchResults[req.IndexNames[0]]; ok { + if res, ok := byOffset[req.Offset]; ok { + return res, nil + } + return &types.SearchResult{}, nil + } + if res, ok := m.searchResults[req.IndexNames[0]]; ok { + return res, nil + } + return &types.SearchResult{}, nil +} + +func (m *aggregateTagsMockEngine) GetAggregation(chunks 
[]map[string]interface{}, fieldName string) []map[string]interface{} { + counts := make(map[string]int) + for _, chunk := range chunks { + raw, ok := chunk[fieldName] + if !ok || raw == nil { + continue + } + switch value := raw.(type) { + case string: + for _, tag := range strings.Split(value, "###") { + tag = strings.TrimSpace(tag) + if tag != "" { + counts[tag]++ + } + } + case []string: + for _, tag := range value { + tag = strings.TrimSpace(tag) + if tag != "" { + counts[tag]++ + } + } + } + } + + result := make([]map[string]interface{}, 0, len(counts)) + for tag, count := range counts { + result = append(result, map[string]interface{}{ + "key": tag, + "count": count, + }) + } + return result +} + +func (m *aggregateTagsMockEngine) GetType() string { return "mock" } + +func (m *aggregateTagsMockEngine) Ping(ctx context.Context) error { return nil } + +func (m *aggregateTagsMockEngine) Close() error { return nil } + +func testDatasetServiceForAggregateTags(t *testing.T, docEngine engine.DocEngine) *DatasetService { + t.Helper() + return &DatasetService{ + kbDAO: dao.NewKnowledgebaseDAO(), + docEngine: docEngine, + } +} + +func insertAggregateTagsKB(t *testing.T, datasetID, tenantID, permission string, docNum int64) { + t.Helper() + kb := &entity.Knowledgebase{ + ID: datasetID, + TenantID: tenantID, + Name: "kb-" + datasetID[:6], + EmbdID: "embedding@OpenAI", + CreatedBy: tenantID, + Permission: permission, + ParserID: "naive", + ParserConfig: entity.JSONMap{}, + DocNum: docNum, + Status: sptr(string(entity.StatusValid)), + } + if err := dao.DB.Create(kb).Error; err != nil { + t.Fatalf("insert test kb: %v", err) + } +} + +func insertAggregateTagsMembership(t *testing.T, tenantID, userID string) { + t.Helper() + row := &entity.UserTenant{ + ID: tenantID + "-" + userID, + UserID: userID, + TenantID: tenantID, + Role: "member", + InvitedBy: tenantID, + Status: sptr(string(entity.StatusValid)), + } + if err := dao.DB.Create(row).Error; err != nil { + 
// aggregateTagsResultMap flattens AggregateTags-style rows into a tag -> count
// map so tests can compare results without depending on row order.
//
// A row whose "value" is missing or not a string is stored under the empty
// key, and a "count" that is missing or not an int is stored as 0. When the
// same tag occurs in several rows, the last row wins.
func aggregateTagsResultMap(rows []map[string]interface{}) map[string]int {
	counts := make(map[string]int, len(rows))
	for i := range rows {
		var (
			name string
			n    int
		)
		if v, ok := rows[i]["value"].(string); ok {
			name = v
		}
		if v, ok := rows[i]["count"].(int); ok {
			n = v
		}
		counts[name] = n
	}
	return counts
}
make(map[string]*types.SearchRequest, len(docEngine.requests)) + for _, req := range docEngine.requests { + if len(req.IndexNames) != 1 { + t.Fatalf("IndexNames=%v want single entry", req.IndexNames) + } + requestByIndex[req.IndexNames[0]] = req + if req.Offset != 0 { + t.Fatalf("Offset=%d want=0", req.Offset) + } + if req.Limit != 10000 { + t.Fatalf("Limit=%d want=10000", req.Limit) + } + if len(req.SelectFields) != 1 || req.SelectFields[0] != "tag_kwd" { + t.Fatalf("SelectFields=%v want [tag_kwd]", req.SelectFields) + } + } + + if req := requestByIndex["ragflow_user-1"]; req == nil || len(req.KbIDs) != 1 || req.KbIDs[0] != kb1ID { + t.Fatalf("request for ragflow_user-1 = %#v, want kbIDs=[%s]", req, kb1ID) + } + if req := requestByIndex["ragflow_tenant-2"]; req == nil || len(req.KbIDs) != 1 || req.KbIDs[0] != kb2ID { + t.Fatalf("request for ragflow_tenant-2 = %#v, want kbIDs=[%s]", req, kb2ID) + } +} + +func TestDatasetServiceAggregateTagsPagesThroughAllChunks(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + kbInput := "723e4567-e89b-12d3-a456-426614174006" + kbID := strings.ReplaceAll(kbInput, "-", "") + insertAggregateTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 10002) + + firstPage := make([]map[string]interface{}, 10000) + for i := range firstPage { + firstPage[i] = map[string]interface{}{"tag_kwd": "finance"} + } + + docEngine := &aggregateTagsMockEngine{ + pagedSearchResults: map[string]map[int]*types.SearchResult{ + "ragflow_user-1": { + 0: { + Chunks: firstPage, + Total: 10002, + }, + 10000: { + Chunks: []map[string]interface{}{ + {"tag_kwd": "finance"}, + {"tag_kwd": "urgent"}, + }, + Total: 10002, + }, + }, + }, + } + + result, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{kbInput}, "user-1") + if err != nil { + t.Fatalf("AggregateTags failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code=%d want=%d", code, common.CodeSuccess) + } + + got := 
aggregateTagsResultMap(result) + if got["finance"] != 10001 { + t.Fatalf("finance count=%d want=10001 result=%v", got["finance"], got) + } + if got["urgent"] != 1 { + t.Fatalf("urgent count=%d want=1 result=%v", got["urgent"], got) + } + + if len(docEngine.requests) != 2 { + t.Fatalf("search requests=%d want=2", len(docEngine.requests)) + } + if docEngine.requests[0].Offset != 0 { + t.Fatalf("first request offset=%d want=0", docEngine.requests[0].Offset) + } + if docEngine.requests[1].Offset != 10000 { + t.Fatalf("second request offset=%d want=10000", docEngine.requests[1].Offset) + } +} + +func TestDatasetServiceAggregateTagsRejectsUnauthorizedDataset(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + kbInput := "323e4567-e89b-12d3-a456-426614174002" + kbID := strings.ReplaceAll(kbInput, "-", "") + insertAggregateTagsKB(t, kbID, "tenant-9", string(entity.TenantPermissionMe), 1) + + _, code, err := testDatasetServiceForAggregateTags(t, &aggregateTagsMockEngine{}).AggregateTags([]string{kbInput}, "user-1") + if err == nil { + t.Fatal("expected authorization error") + } + if code != common.CodeDataError { + t.Fatalf("code=%d want=%d", code, common.CodeDataError) + } + if err.Error() != "No authorization for dataset '"+kbID+"'" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestDatasetServiceAggregateTagsRequiresDocumentEngine(t *testing.T) { + _, code, err := testDatasetServiceForAggregateTags(t, nil).AggregateTags([]string{"123e4567-e89b-12d3-a456-426614174000"}, "user-1") + if err == nil { + t.Fatal("expected missing doc engine error") + } + if code != common.CodeServerError { + t.Fatalf("code=%d want=%d", code, common.CodeServerError) + } + if err.Error() != "Document engine is not initialized" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestDatasetServiceAggregateTagsSkipsDatasetsWithoutDocuments(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + emptyInput := 
"423e4567-e89b-12d3-a456-426614174003" + liveInput := "523e4567-e89b-12d3-a456-426614174004" + emptyID := strings.ReplaceAll(emptyInput, "-", "") + liveID := strings.ReplaceAll(liveInput, "-", "") + + insertAggregateTagsKB(t, emptyID, "user-1", string(entity.TenantPermissionMe), 0) + insertAggregateTagsKB(t, liveID, "user-1", string(entity.TenantPermissionMe), 1) + + docEngine := &aggregateTagsMockEngine{ + searchResults: map[string]*types.SearchResult{ + "ragflow_user-1": { + Chunks: []map[string]interface{}{ + {"tag_kwd": "alpha###beta"}, + }, + }, + }, + } + + result, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{emptyInput, liveInput}, "user-1") + if err != nil { + t.Fatalf("AggregateTags failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code=%d want=%d", code, common.CodeSuccess) + } + if len(docEngine.requests) != 1 { + t.Fatalf("search requests=%d want=1", len(docEngine.requests)) + } + if len(docEngine.requests[0].KbIDs) != 1 || docEngine.requests[0].KbIDs[0] != liveID { + t.Fatalf("KbIDs=%v want [%s]", docEngine.requests[0].KbIDs, liveID) + } + + got := aggregateTagsResultMap(result) + if got["alpha"] != 1 || got["beta"] != 1 { + t.Fatalf("unexpected result=%v", got) + } +} + +func TestDatasetServiceAggregateTagsReturnsSearchError(t *testing.T) { + db := setupServiceTestDB(t) + pushServiceDB(t, db) + + kbInput := "623e4567-e89b-12d3-a456-426614174005" + kbID := strings.ReplaceAll(kbInput, "-", "") + insertAggregateTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 1) + + docEngine := &aggregateTagsMockEngine{searchErr: errors.New("boom")} + _, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{kbInput}, "user-1") + if err == nil { + t.Fatal("expected search error") + } + if code != common.CodeServerError { + t.Fatalf("code=%d want=%d", code, common.CodeServerError) + } + if !strings.Contains(err.Error(), "failed to aggregate tags: boom") { + 
t.Fatalf("unexpected error: %v", err) + } +}