feat(go-api): migrate datasets tags aggregation API to Go (#16181)

### Description

Migrates the datasets tags aggregation API `GET
/api/v1/datasets/tags/aggregation` from Python to Go.

### Changes
- Registered the `GET /api/v1/datasets/tags/aggregation` route.
- Implemented `AggregateTags` in the datasets handler and service.
- Added unit tests for the handler and the service.

### Test Verification
- Verified by comparing responses from the Python service (port 9380) and
the Go service (port 9384).
- Tested scenarios: single dataset, multiple datasets, empty parameters,
and unauthorized/invalid IDs.
- All of the scenarios above and the Go unit tests passed.
This commit is contained in:
Hz_
2026-06-24 14:42:10 +08:00
committed by GitHub
parent 68d2ca0ff1
commit 368db6fa58
5 changed files with 579 additions and 0 deletions

View File

@@ -592,6 +592,44 @@ func (h *DatasetsHandler) RemoveTags(c *gin.Context) {
jsonResponse(c, common.CodeSuccess, true, "success")
}
// AggregateTags handles GET /api/v1/datasets/tags/aggregation.
//
// It resolves the current user, splits the dataset_ids query parameter on
// commas, drops blank entries, and delegates to the datasets service. A
// request that carries no usable ID is rejected with CodeDataError before the
// service is called.
// @Summary Aggregate dataset tags
// @Description Aggregate tags across multiple datasets
// @Tags datasets
// @Produce json
// @Security ApiKeyAuth
// @Param dataset_ids query string true "Comma-separated dataset IDs"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/datasets/tags/aggregation [get]
func (h *DatasetsHandler) AggregateTags(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	// Collect the trimmed, non-empty pieces of the comma-separated parameter.
	parts := strings.Split(c.Query("dataset_ids"), ",")
	datasetIDs := make([]string, 0, len(parts))
	for _, part := range parts {
		if id := strings.TrimSpace(part); id != "" {
			datasetIDs = append(datasetIDs, id)
		}
	}
	if len(datasetIDs) == 0 {
		jsonError(c, common.CodeDataError, "Lack of dataset_ids in query parameters")
		return
	}

	result, code, err := h.datasetsService.AggregateTags(datasetIDs, user.ID)
	if err != nil {
		jsonError(c, code, err.Error())
		return
	}
	jsonResponse(c, common.CodeSuccess, result, "success")
}
// RunIndex Run an indexing task (graph/raptor/mindmap) for a dataset.
func (h *DatasetsHandler) RunIndex(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)

View File

@@ -0,0 +1,78 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"ragflow/internal/common"
"ragflow/internal/entity"
)
// newAggregateTagsHandlerRouter builds a test-mode gin engine that serves only
// the tags aggregation route with a zero-value DatasetsHandler. When
// authenticated is true, a user with ID "user-1" is stored in the context
// under the "user" key before the handler runs.
func newAggregateTagsHandlerRouter(authenticated bool) *gin.Engine {
	gin.SetMode(gin.TestMode)

	datasets := &DatasetsHandler{}
	serve := func(c *gin.Context) {
		if authenticated {
			c.Set("user", &entity.User{ID: "user-1"})
		}
		datasets.AggregateTags(c)
	}

	router := gin.New()
	router.GET("/api/v1/datasets/tags/aggregation", serve)
	return router
}
// TestDatasetsHandlerAggregateTagsRequiresDatasetIDs checks that an
// authenticated request without dataset_ids is answered with HTTP 200 and a
// CodeDataError envelope carrying the expected message and null data.
func TestDatasetsHandlerAggregateTagsRequiresDatasetIDs(t *testing.T) {
	const wantMessage = "Lack of dataset_ids in query parameters"

	router := newAggregateTagsHandlerRouter(true)
	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/api/v1/datasets/tags/aggregation", nil))

	if recorder.Code != http.StatusOK {
		t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
	}

	var envelope struct {
		Code    int         `json:"code"`
		Data    interface{} `json:"data"`
		Message string      `json:"message"`
	}
	err := json.Unmarshal(recorder.Body.Bytes(), &envelope)
	if err != nil {
		t.Fatalf("unmarshal response: %v body=%s", err, recorder.Body.String())
	}

	if want := int(common.CodeDataError); envelope.Code != want {
		t.Fatalf("code=%d want=%d body=%s", envelope.Code, common.CodeDataError, recorder.Body.String())
	}
	if envelope.Message != wantMessage {
		t.Fatalf("message=%q want=%q", envelope.Message, wantMessage)
	}
	if envelope.Data != nil {
		t.Fatalf("data=%v want nil", envelope.Data)
	}
}
// TestDatasetsHandlerAggregateTagsRequiresAuth checks that a request without a
// user in the context is answered with CodeUnauthorized and the message
// "User not found", even though a dataset ID is supplied.
func TestDatasetsHandlerAggregateTagsRequiresAuth(t *testing.T) {
	const (
		target      = "/api/v1/datasets/tags/aggregation?dataset_ids=123e4567-e89b-12d3-a456-426614174000"
		wantMessage = "User not found"
	)

	router := newAggregateTagsHandlerRouter(false)
	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, target, nil))

	var envelope struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	}
	err := json.Unmarshal(recorder.Body.Bytes(), &envelope)
	if err != nil {
		t.Fatalf("unmarshal response: %v body=%s", err, recorder.Body.String())
	}

	if want := int(common.CodeUnauthorized); envelope.Code != want {
		t.Fatalf("code=%d want=%d body=%s", envelope.Code, common.CodeUnauthorized, recorder.Body.String())
	}
	if envelope.Message != wantMessage {
		t.Fatalf("message=%q want=%q", envelope.Message, wantMessage)
	}
}

View File

@@ -269,6 +269,7 @@ func (r *Router) Setup(engine *gin.Engine) {
datasets := v1.Group("/datasets")
{
datasets.GET("", r.datasetsHandler.ListDatasets)
datasets.GET("/tags/aggregation", r.datasetsHandler.AggregateTags)
datasets.GET("/:dataset_id", r.datasetsHandler.GetDataset)
datasets.PUT("/:dataset_id", r.datasetsHandler.UpdateDataset)
datasets.GET("/:dataset_id/graph", r.datasetsHandler.GetKnowledgeGraph)

View File

@@ -25,6 +25,7 @@ import (
"ragflow/internal/dao"
"ragflow/internal/engine"
redisengine "ragflow/internal/engine/redis"
"ragflow/internal/engine/types"
"ragflow/internal/entity"
"ragflow/internal/entity/models"
"ragflow/internal/server"
@@ -1831,6 +1832,90 @@ func (s *DatasetService) Accessible(kbID, userID string) bool {
return s.kbDAO.Accessible(kbID, userID)
}
// AggregateTags counts how often each tag_kwd value occurs across the chunks
// of the given datasets and returns one {"value": tag, "count": n} entry per
// distinct tag.
//
// Every non-blank ID is normalized with normalizeDatasetUUID1, checked with
// kbDAO.Accessible for userID, and loaded with kbDAO.GetByID. IDs that
// normalize to a dataset already seen in this call are skipped, and so are
// datasets whose DocNum is not positive. The remaining IDs are grouped by the
// owning tenant, and each tenant's "ragflow_<tenantID>" index is searched page
// by page until a short page is returned or the reported total is reached.
//
// Returns:
//   - CodeDataError when datasetIDs is empty, an ID cannot be normalized, the
//     user is not authorized for a dataset, or a dataset is not found.
//   - CodeServerError when the document engine is missing, the dataset lookup
//     fails for any other reason, or a search fails.
//   - CodeSuccess with the merged counts otherwise. The order of the returned
//     entries is unspecified because they are collected from a map.
func (s *DatasetService) AggregateTags(datasetIDs []string, userID string) ([]map[string]interface{}, common.ErrorCode, error) {
	if len(datasetIDs) == 0 {
		return nil, common.CodeDataError, errors.New("Lack of dataset_ids in query parameters")
	}
	if s.docEngine == nil {
		return nil, common.CodeServerError, errors.New("Document engine is not initialized")
	}

	datasetIDsByTenant := make(map[string][]string)
	// Track normalized IDs so a dataset listed more than once is validated and
	// queried only once.
	seen := make(map[string]struct{}, len(datasetIDs))
	for _, rawID := range datasetIDs {
		rawID = strings.TrimSpace(rawID)
		if rawID == "" {
			continue
		}
		datasetID, err := normalizeDatasetUUID1(rawID)
		if err != nil {
			return nil, common.CodeDataError, err
		}
		if _, duplicate := seen[datasetID]; duplicate {
			continue
		}
		seen[datasetID] = struct{}{}

		if !s.kbDAO.Accessible(datasetID, userID) {
			return nil, common.CodeDataError, fmt.Errorf("No authorization for dataset '%s'", datasetID)
		}
		kb, err := s.kbDAO.GetByID(datasetID)
		if err != nil {
			if dao.IsNotFoundErr(err) {
				return nil, common.CodeDataError, fmt.Errorf("Invalid Dataset ID '%s'", datasetID)
			}
			return nil, common.CodeServerError, errors.New("Database operation failed")
		}
		// Datasets without documents are left out of the search.
		if kb.DocNum <= 0 {
			continue
		}
		datasetIDsByTenant[kb.TenantID] = append(datasetIDsByTenant[kb.TenantID], datasetID)
	}

	const pageSize = 10000
	merged := make(map[string]int)
	for tenantID, kbIDs := range datasetIDsByTenant {
		for offset := 0; ; offset += pageSize {
			searchResp, err := s.docEngine.Search(context.Background(), &types.SearchRequest{
				IndexNames:   []string{fmt.Sprintf("ragflow_%s", tenantID)},
				KbIDs:        kbIDs,
				Offset:       offset,
				Limit:        pageSize,
				SelectFields: []string{"tag_kwd"},
			})
			if err != nil {
				return nil, common.CodeServerError, fmt.Errorf("failed to aggregate tags: %w", err)
			}
			// Treat a missing result like an empty page instead of dereferencing nil.
			if searchResp == nil {
				break
			}

			for _, agg := range s.docEngine.GetAggregation(searchResp.Chunks, "tag_kwd") {
				tag, _ := agg["key"].(string)
				if tag == "" {
					continue
				}
				// The count may arrive as any of these numeric types; other
				// types are ignored.
				switch count := agg["count"].(type) {
				case int:
					merged[tag] += count
				case int32:
					merged[tag] += int(count)
				case int64:
					merged[tag] += int(count)
				case float64:
					merged[tag] += int(count)
				}
			}

			// A page shorter than pageSize (including an empty one) is the last page.
			chunkCount := len(searchResp.Chunks)
			if chunkCount < pageSize {
				break
			}
			// Stop after a full page once the reported total has been consumed.
			if searchResp.Total > 0 && int64(offset+chunkCount) >= searchResp.Total {
				break
			}
		}
	}

	result := make([]map[string]interface{}, 0, len(merged))
	for tag, count := range merged {
		result = append(result, map[string]interface{}{
			"value": tag,
			"count": count,
		})
	}
	return result, common.CodeSuccess, nil
}
// GetIngestionSummary returns dataset-level ingestion counters together with
// the aggregated document parsing status, mirroring
// dataset_api_service.get_ingestion_summary.

View File

@@ -0,0 +1,377 @@
package service
import (
"context"
"errors"
"strings"
"testing"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/engine"
"ragflow/internal/engine/types"
"ragflow/internal/entity"
)
// aggregateTagsMockEngine is a test double for engine.DocEngine. It embeds the
// interface and overrides Search, GetAggregation, GetType, Ping, and Close.
type aggregateTagsMockEngine struct {
	engine.DocEngine
	// searchResults maps the first index name of a request to the result
	// Search returns when no paged entry exists for that index.
	searchResults map[string]*types.SearchResult
	// pagedSearchResults maps the first index name of a request to results
	// keyed by request offset; it takes precedence over searchResults.
	pagedSearchResults map[string]map[int]*types.SearchResult
	// searchErr, when non-nil, is returned by every Search call.
	searchErr error
	// requests holds a copy of each request Search received while searchErr
	// was nil, in call order.
	requests []*types.SearchRequest
}
// Search returns the canned result configured for the first index name of req.
//
// When searchErr is set it is returned immediately and the request is not
// recorded. Otherwise a copy of the request is appended to m.requests. Paged
// results are looked up by index name and offset first; plain results by
// index name second. Any miss yields an empty SearchResult.
func (m *aggregateTagsMockEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
	if m.searchErr != nil {
		return nil, m.searchErr
	}

	// Copy the slices so later changes by the caller cannot alter the record.
	dup := func(in []string) []string { return append([]string(nil), in...) }
	m.requests = append(m.requests, &types.SearchRequest{
		IndexNames:   dup(req.IndexNames),
		KbIDs:        dup(req.KbIDs),
		Offset:       req.Offset,
		Limit:        req.Limit,
		SelectFields: dup(req.SelectFields),
	})

	if len(req.IndexNames) == 0 {
		return &types.SearchResult{}, nil
	}
	index := req.IndexNames[0]

	if pages, paged := m.pagedSearchResults[index]; paged {
		page, found := pages[req.Offset]
		if !found {
			return &types.SearchResult{}, nil
		}
		return page, nil
	}

	canned, found := m.searchResults[index]
	if !found {
		return &types.SearchResult{}, nil
	}
	return canned, nil
}
// GetAggregation counts tag occurrences in the fieldName value of each chunk
// and returns one {"key": tag, "count": n} entry per distinct tag.
//
// A string value is split on "###"; a []string value is used element by
// element. Tags are trimmed and blank tags are dropped. Chunks whose field is
// missing, nil, or of another type contribute nothing.
func (m *aggregateTagsMockEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} {
	tally := make(map[string]int)
	bump := func(raw string) {
		if tag := strings.TrimSpace(raw); tag != "" {
			tally[tag]++
		}
	}

	for _, chunk := range chunks {
		// A missing or nil field matches neither case and is skipped.
		switch tags := chunk[fieldName].(type) {
		case string:
			for _, part := range strings.Split(tags, "###") {
				bump(part)
			}
		case []string:
			for _, part := range tags {
				bump(part)
			}
		}
	}

	out := make([]map[string]interface{}, 0, len(tally))
	for tag, n := range tally {
		entry := map[string]interface{}{
			"key":   tag,
			"count": n,
		}
		out = append(out, entry)
	}
	return out
}
// GetType reports the fixed engine type name "mock".
func (m *aggregateTagsMockEngine) GetType() string {
	return "mock"
}

// Ping always succeeds.
func (m *aggregateTagsMockEngine) Ping(ctx context.Context) error {
	return nil
}

// Close does nothing and always succeeds.
func (m *aggregateTagsMockEngine) Close() error {
	return nil
}
// testDatasetServiceForAggregateTags returns a DatasetService that has only
// the knowledgebase DAO and the supplied document engine set.
func testDatasetServiceForAggregateTags(t *testing.T, docEngine engine.DocEngine) *DatasetService {
	t.Helper()

	svc := &DatasetService{
		kbDAO:     dao.NewKnowledgebaseDAO(),
		docEngine: docEngine,
	}
	return svc
}
// insertAggregateTagsKB inserts a valid knowledgebase row through dao.DB and
// fails the test if the insert returns an error.
//
// The row uses datasetID as its ID, tenantID as both TenantID and CreatedBy,
// and the given permission and docNum. Its name is "kb-" followed by at most
// the first six bytes of datasetID.
func insertAggregateTagsKB(t *testing.T, datasetID, tenantID, permission string, docNum int64) {
	t.Helper()

	// Clamp the prefix so an ID shorter than six bytes does not panic on slicing.
	namePrefix := datasetID
	if len(namePrefix) > 6 {
		namePrefix = namePrefix[:6]
	}

	kb := &entity.Knowledgebase{
		ID:           datasetID,
		TenantID:     tenantID,
		Name:         "kb-" + namePrefix,
		EmbdID:       "embedding@OpenAI",
		CreatedBy:    tenantID,
		Permission:   permission,
		ParserID:     "naive",
		ParserConfig: entity.JSONMap{},
		DocNum:       docNum,
		Status:       sptr(string(entity.StatusValid)),
	}
	if err := dao.DB.Create(kb).Error; err != nil {
		t.Fatalf("insert test kb: %v", err)
	}
}
// insertAggregateTagsMembership inserts a valid user_tenant row that makes
// userID a "member" of tenantID, and fails the test if the insert returns an
// error. The row ID is "<tenantID>-<userID>".
func insertAggregateTagsMembership(t *testing.T, tenantID, userID string) {
	t.Helper()

	membership := entity.UserTenant{
		ID:        tenantID + "-" + userID,
		UserID:    userID,
		TenantID:  tenantID,
		Role:      "member",
		InvitedBy: tenantID,
		Status:    sptr(string(entity.StatusValid)),
	}
	err := dao.DB.Create(&membership).Error
	if err != nil {
		t.Fatalf("insert user_tenant: %v", err)
	}
}
// aggregateTagsResultMap flattens AggregateTags rows into a tag-to-count map.
//
// A row whose "value" is not a string is stored under the empty tag, and a
// row whose "count" is not an int is stored with count 0. Later rows overwrite
// earlier rows that share a tag.
func aggregateTagsResultMap(rows []map[string]interface{}) map[string]int {
	byTag := make(map[string]int, len(rows))
	for i := range rows {
		var (
			tag   string
			count int
		)
		if v, ok := rows[i]["value"].(string); ok {
			tag = v
		}
		if v, ok := rows[i]["count"].(int); ok {
			count = v
		}
		byTag[tag] = count
	}
	return byTag
}
func TestDatasetServiceAggregateTagsMergesAcrossTenants(t *testing.T) {
db := setupServiceTestDB(t)
pushServiceDB(t, db)
kb1Input := "123e4567-e89b-12d3-a456-426614174000"
kb2Input := "223e4567-e89b-12d3-a456-426614174001"
kb1ID := strings.ReplaceAll(kb1Input, "-", "")
kb2ID := strings.ReplaceAll(kb2Input, "-", "")
insertAggregateTagsKB(t, kb1ID, "user-1", string(entity.TenantPermissionMe), 2)
insertAggregateTagsKB(t, kb2ID, "tenant-2", string(entity.TenantPermissionTeam), 1)
insertAggregateTagsMembership(t, "tenant-2", "user-1")
docEngine := &aggregateTagsMockEngine{
searchResults: map[string]*types.SearchResult{
"ragflow_user-1": {
Chunks: []map[string]interface{}{
{"tag_kwd": "finance###urgent"},
{"tag_kwd": "finance"},
},
},
"ragflow_tenant-2": {
Chunks: []map[string]interface{}{
{"tag_kwd": "urgent###internal"},
},
},
},
}
result, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{kb1Input, kb2Input}, "user-1")
if err != nil {
t.Fatalf("AggregateTags failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%d want=%d", code, common.CodeSuccess)
}
got := aggregateTagsResultMap(result)
want := map[string]int{
"finance": 2,
"urgent": 2,
"internal": 1,
}
if len(got) != len(want) {
t.Fatalf("result len=%d want=%d result=%v", len(got), len(want), got)
}
for tag, wantCount := range want {
if got[tag] != wantCount {
t.Fatalf("tag %q count=%d want=%d all=%v", tag, got[tag], wantCount, got)
}
}
if len(docEngine.requests) != 2 {
t.Fatalf("search requests=%d want=2", len(docEngine.requests))
}
requestByIndex := make(map[string]*types.SearchRequest, len(docEngine.requests))
for _, req := range docEngine.requests {
if len(req.IndexNames) != 1 {
t.Fatalf("IndexNames=%v want single entry", req.IndexNames)
}
requestByIndex[req.IndexNames[0]] = req
if req.Offset != 0 {
t.Fatalf("Offset=%d want=0", req.Offset)
}
if req.Limit != 10000 {
t.Fatalf("Limit=%d want=10000", req.Limit)
}
if len(req.SelectFields) != 1 || req.SelectFields[0] != "tag_kwd" {
t.Fatalf("SelectFields=%v want [tag_kwd]", req.SelectFields)
}
}
if req := requestByIndex["ragflow_user-1"]; req == nil || len(req.KbIDs) != 1 || req.KbIDs[0] != kb1ID {
t.Fatalf("request for ragflow_user-1 = %#v, want kbIDs=[%s]", req, kb1ID)
}
if req := requestByIndex["ragflow_tenant-2"]; req == nil || len(req.KbIDs) != 1 || req.KbIDs[0] != kb2ID {
t.Fatalf("request for ragflow_tenant-2 = %#v, want kbIDs=[%s]", req, kb2ID)
}
}
func TestDatasetServiceAggregateTagsPagesThroughAllChunks(t *testing.T) {
db := setupServiceTestDB(t)
pushServiceDB(t, db)
kbInput := "723e4567-e89b-12d3-a456-426614174006"
kbID := strings.ReplaceAll(kbInput, "-", "")
insertAggregateTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 10002)
firstPage := make([]map[string]interface{}, 10000)
for i := range firstPage {
firstPage[i] = map[string]interface{}{"tag_kwd": "finance"}
}
docEngine := &aggregateTagsMockEngine{
pagedSearchResults: map[string]map[int]*types.SearchResult{
"ragflow_user-1": {
0: {
Chunks: firstPage,
Total: 10002,
},
10000: {
Chunks: []map[string]interface{}{
{"tag_kwd": "finance"},
{"tag_kwd": "urgent"},
},
Total: 10002,
},
},
},
}
result, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{kbInput}, "user-1")
if err != nil {
t.Fatalf("AggregateTags failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%d want=%d", code, common.CodeSuccess)
}
got := aggregateTagsResultMap(result)
if got["finance"] != 10001 {
t.Fatalf("finance count=%d want=10001 result=%v", got["finance"], got)
}
if got["urgent"] != 1 {
t.Fatalf("urgent count=%d want=1 result=%v", got["urgent"], got)
}
if len(docEngine.requests) != 2 {
t.Fatalf("search requests=%d want=2", len(docEngine.requests))
}
if docEngine.requests[0].Offset != 0 {
t.Fatalf("first request offset=%d want=0", docEngine.requests[0].Offset)
}
if docEngine.requests[1].Offset != 10000 {
t.Fatalf("second request offset=%d want=10000", docEngine.requests[1].Offset)
}
}
func TestDatasetServiceAggregateTagsRejectsUnauthorizedDataset(t *testing.T) {
db := setupServiceTestDB(t)
pushServiceDB(t, db)
kbInput := "323e4567-e89b-12d3-a456-426614174002"
kbID := strings.ReplaceAll(kbInput, "-", "")
insertAggregateTagsKB(t, kbID, "tenant-9", string(entity.TenantPermissionMe), 1)
_, code, err := testDatasetServiceForAggregateTags(t, &aggregateTagsMockEngine{}).AggregateTags([]string{kbInput}, "user-1")
if err == nil {
t.Fatal("expected authorization error")
}
if code != common.CodeDataError {
t.Fatalf("code=%d want=%d", code, common.CodeDataError)
}
if err.Error() != "No authorization for dataset '"+kbID+"'" {
t.Fatalf("unexpected error: %v", err)
}
}
// TestDatasetServiceAggregateTagsRequiresDocumentEngine checks that a service
// without a document engine fails with CodeServerError and the
// "Document engine is not initialized" message.
func TestDatasetServiceAggregateTagsRequiresDocumentEngine(t *testing.T) {
	const wantMessage = "Document engine is not initialized"

	svc := testDatasetServiceForAggregateTags(t, nil)
	ids := []string{"123e4567-e89b-12d3-a456-426614174000"}
	_, code, err := svc.AggregateTags(ids, "user-1")
	if err == nil {
		t.Fatal("expected missing doc engine error")
	}
	if code != common.CodeServerError {
		t.Fatalf("code=%d want=%d", code, common.CodeServerError)
	}
	if got := err.Error(); got != wantMessage {
		t.Fatalf("unexpected error: %v", err)
	}
}
func TestDatasetServiceAggregateTagsSkipsDatasetsWithoutDocuments(t *testing.T) {
db := setupServiceTestDB(t)
pushServiceDB(t, db)
emptyInput := "423e4567-e89b-12d3-a456-426614174003"
liveInput := "523e4567-e89b-12d3-a456-426614174004"
emptyID := strings.ReplaceAll(emptyInput, "-", "")
liveID := strings.ReplaceAll(liveInput, "-", "")
insertAggregateTagsKB(t, emptyID, "user-1", string(entity.TenantPermissionMe), 0)
insertAggregateTagsKB(t, liveID, "user-1", string(entity.TenantPermissionMe), 1)
docEngine := &aggregateTagsMockEngine{
searchResults: map[string]*types.SearchResult{
"ragflow_user-1": {
Chunks: []map[string]interface{}{
{"tag_kwd": "alpha###beta"},
},
},
},
}
result, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{emptyInput, liveInput}, "user-1")
if err != nil {
t.Fatalf("AggregateTags failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%d want=%d", code, common.CodeSuccess)
}
if len(docEngine.requests) != 1 {
t.Fatalf("search requests=%d want=1", len(docEngine.requests))
}
if len(docEngine.requests[0].KbIDs) != 1 || docEngine.requests[0].KbIDs[0] != liveID {
t.Fatalf("KbIDs=%v want [%s]", docEngine.requests[0].KbIDs, liveID)
}
got := aggregateTagsResultMap(result)
if got["alpha"] != 1 || got["beta"] != 1 {
t.Fatalf("unexpected result=%v", got)
}
}
func TestDatasetServiceAggregateTagsReturnsSearchError(t *testing.T) {
db := setupServiceTestDB(t)
pushServiceDB(t, db)
kbInput := "623e4567-e89b-12d3-a456-426614174005"
kbID := strings.ReplaceAll(kbInput, "-", "")
insertAggregateTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 1)
docEngine := &aggregateTagsMockEngine{searchErr: errors.New("boom")}
_, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{kbInput}, "user-1")
if err == nil {
t.Fatal("expected search error")
}
if code != common.CodeServerError {
t.Fatalf("code=%d want=%d", code, common.CodeServerError)
}
if !strings.Contains(err.Error(), "failed to aggregate tags: boom") {
t.Fatalf("unexpected error: %v", err)
}
}