mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat(go-api): migrate datasets tags aggregation API to Go (#16181)
### Description Migrates the datasets tags aggregation API `GET /api/v1/datasets/tags/aggregation` from Python to Go. ### Changes - Registered the `GET /api/v1/datasets/tags/aggregation` route. - Implemented `AggregateTags` in datasets `handler` and `service`. - Added handler and service `unit tests`. ### Test Verification - Verified by comparing results between Python (9380) and Go (9384) services. - Tested scenarios: single dataset, multiple datasets, empty parameters, and unauthorized/invalid IDs. - All tests and Go `unit tests` passed.
This commit is contained in:
@@ -592,6 +592,44 @@ func (h *DatasetsHandler) RemoveTags(c *gin.Context) {
|
||||
jsonResponse(c, common.CodeSuccess, true, "success")
|
||||
}
|
||||
|
||||
// AggregateTags handles GET /api/v1/datasets/tags/aggregation.
|
||||
// @Summary Aggregate dataset tags
|
||||
// @Description Aggregate tags across multiple datasets
|
||||
// @Tags datasets
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param dataset_ids query string true "Comma-separated dataset IDs"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/datasets/tags/aggregation [get]
|
||||
func (h *DatasetsHandler) AggregateTags(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
rawIDs := strings.Split(c.Query("dataset_ids"), ",")
|
||||
datasetIDs := make([]string, 0, len(rawIDs))
|
||||
|
||||
for _, rawID := range rawIDs {
|
||||
tempID := strings.TrimSpace(rawID)
|
||||
if tempID != "" {
|
||||
datasetIDs = append(datasetIDs, tempID)
|
||||
}
|
||||
}
|
||||
if len(datasetIDs) == 0 {
|
||||
jsonError(c, common.CodeDataError, "Lack of dataset_ids in query parameters")
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.datasetsService.AggregateTags(datasetIDs, user.ID)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// RunIndex Run an indexing task (graph/raptor/mindmap) for a dataset.
|
||||
func (h *DatasetsHandler) RunIndex(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
|
||||
78
internal/handler/datasets_aggregate_tags_test.go
Normal file
78
internal/handler/datasets_aggregate_tags_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func newAggregateTagsHandlerRouter(authenticated bool) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &DatasetsHandler{}
|
||||
r := gin.New()
|
||||
r.GET("/api/v1/datasets/tags/aggregation", func(c *gin.Context) {
|
||||
if authenticated {
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
}
|
||||
h.AggregateTags(c)
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func TestDatasetsHandlerAggregateTagsRequiresDatasetIDs(t *testing.T) {
|
||||
r := newAggregateTagsHandlerRouter(true)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/datasets/tags/aggregation", nil)
|
||||
r.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Data interface{} `json:"data"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String())
|
||||
}
|
||||
if body.Code != int(common.CodeDataError) {
|
||||
t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeDataError, resp.Body.String())
|
||||
}
|
||||
if body.Message != "Lack of dataset_ids in query parameters" {
|
||||
t.Fatalf("message=%q want=%q", body.Message, "Lack of dataset_ids in query parameters")
|
||||
}
|
||||
if body.Data != nil {
|
||||
t.Fatalf("data=%v want nil", body.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetsHandlerAggregateTagsRequiresAuth(t *testing.T) {
|
||||
r := newAggregateTagsHandlerRouter(false)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/datasets/tags/aggregation?dataset_ids=123e4567-e89b-12d3-a456-426614174000", nil)
|
||||
r.ServeHTTP(resp, req)
|
||||
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String())
|
||||
}
|
||||
if body.Code != int(common.CodeUnauthorized) {
|
||||
t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeUnauthorized, resp.Body.String())
|
||||
}
|
||||
if body.Message != "User not found" {
|
||||
t.Fatalf("message=%q want=%q", body.Message, "User not found")
|
||||
}
|
||||
}
|
||||
@@ -269,6 +269,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
datasets := v1.Group("/datasets")
|
||||
{
|
||||
datasets.GET("", r.datasetsHandler.ListDatasets)
|
||||
datasets.GET("/tags/aggregation", r.datasetsHandler.AggregateTags)
|
||||
datasets.GET("/:dataset_id", r.datasetsHandler.GetDataset)
|
||||
datasets.PUT("/:dataset_id", r.datasetsHandler.UpdateDataset)
|
||||
datasets.GET("/:dataset_id/graph", r.datasetsHandler.GetKnowledgeGraph)
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
redisengine "ragflow/internal/engine/redis"
|
||||
"ragflow/internal/engine/types"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/entity/models"
|
||||
"ragflow/internal/server"
|
||||
@@ -1831,6 +1832,90 @@ func (s *DatasetService) Accessible(kbID, userID string) bool {
|
||||
return s.kbDAO.Accessible(kbID, userID)
|
||||
}
|
||||
|
||||
func (s *DatasetService) AggregateTags(datasetIDs []string, userID string) ([]map[string]interface{}, common.ErrorCode, error) {
|
||||
if len(datasetIDs) == 0 {
|
||||
return nil, common.CodeDataError, errors.New("Lack of dataset_ids in query parameters")
|
||||
}
|
||||
if s.docEngine == nil {
|
||||
return nil, common.CodeServerError, errors.New("Document engine is not initialized")
|
||||
}
|
||||
|
||||
datasetIDsByTenant := make(map[string][]string)
|
||||
for _, rawID := range datasetIDs {
|
||||
rawID = strings.TrimSpace(rawID)
|
||||
if rawID == "" {
|
||||
continue
|
||||
}
|
||||
datasetID, err := normalizeDatasetUUID1(rawID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
if !s.kbDAO.Accessible(datasetID, userID) {
|
||||
return nil, common.CodeDataError, fmt.Errorf("No authorization for dataset '%s'", datasetID)
|
||||
}
|
||||
kb, err := s.kbDAO.GetByID(datasetID)
|
||||
if err != nil {
|
||||
if dao.IsNotFoundErr(err) {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Invalid Dataset ID '%s'", datasetID)
|
||||
}
|
||||
return nil, common.CodeServerError, errors.New("Database operation failed")
|
||||
}
|
||||
if kb.DocNum <= 0 {
|
||||
continue
|
||||
}
|
||||
datasetIDsByTenant[kb.TenantID] = append(datasetIDsByTenant[kb.TenantID], datasetID)
|
||||
}
|
||||
|
||||
const pageSize = 10000
|
||||
merged := make(map[string]int)
|
||||
for tenantID, kbIDs := range datasetIDsByTenant {
|
||||
for offset := 0; ; offset += pageSize {
|
||||
searchResp, err := s.docEngine.Search(context.Background(), &types.SearchRequest{
|
||||
IndexNames: []string{fmt.Sprintf("ragflow_%s", tenantID)},
|
||||
KbIDs: kbIDs,
|
||||
Offset: offset,
|
||||
Limit: pageSize,
|
||||
SelectFields: []string{"tag_kwd"},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to aggregate tags: %w", err)
|
||||
}
|
||||
for _, agg := range s.docEngine.GetAggregation(searchResp.Chunks, "tag_kwd") {
|
||||
tag, _ := agg["key"].(string)
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
switch count := agg["count"].(type) {
|
||||
case int:
|
||||
merged[tag] += count
|
||||
case int32:
|
||||
merged[tag] += int(count)
|
||||
case int64:
|
||||
merged[tag] += int(count)
|
||||
case float64:
|
||||
merged[tag] += int(count)
|
||||
}
|
||||
}
|
||||
|
||||
chunkCount := len(searchResp.Chunks)
|
||||
if chunkCount == 0 || chunkCount < pageSize {
|
||||
break
|
||||
}
|
||||
if searchResp.Total > 0 && int64(offset+chunkCount) >= searchResp.Total {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
result := make([]map[string]interface{}, 0, len(merged))
|
||||
for tag, count := range merged {
|
||||
result = append(result, map[string]interface{}{
|
||||
"value": tag,
|
||||
"count": count,
|
||||
})
|
||||
}
|
||||
return result, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// GetIngestionSummary returns dataset-level ingestion counters together with
|
||||
// the aggregated document parsing status, mirroring
|
||||
// dataset_api_service.get_ingestion_summary.
|
||||
|
||||
377
internal/service/dataset_aggregate_tags_test.go
Normal file
377
internal/service/dataset_aggregate_tags_test.go
Normal file
@@ -0,0 +1,377 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/engine/types"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
type aggregateTagsMockEngine struct {
|
||||
engine.DocEngine
|
||||
searchResults map[string]*types.SearchResult
|
||||
pagedSearchResults map[string]map[int]*types.SearchResult
|
||||
searchErr error
|
||||
requests []*types.SearchRequest
|
||||
}
|
||||
|
||||
func (m *aggregateTagsMockEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
|
||||
if m.searchErr != nil {
|
||||
return nil, m.searchErr
|
||||
}
|
||||
cloned := &types.SearchRequest{
|
||||
IndexNames: append([]string(nil), req.IndexNames...),
|
||||
KbIDs: append([]string(nil), req.KbIDs...),
|
||||
Offset: req.Offset,
|
||||
Limit: req.Limit,
|
||||
SelectFields: append([]string(nil), req.SelectFields...),
|
||||
}
|
||||
m.requests = append(m.requests, cloned)
|
||||
if len(req.IndexNames) == 0 {
|
||||
return &types.SearchResult{}, nil
|
||||
}
|
||||
if byOffset, ok := m.pagedSearchResults[req.IndexNames[0]]; ok {
|
||||
if res, ok := byOffset[req.Offset]; ok {
|
||||
return res, nil
|
||||
}
|
||||
return &types.SearchResult{}, nil
|
||||
}
|
||||
if res, ok := m.searchResults[req.IndexNames[0]]; ok {
|
||||
return res, nil
|
||||
}
|
||||
return &types.SearchResult{}, nil
|
||||
}
|
||||
|
||||
func (m *aggregateTagsMockEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} {
|
||||
counts := make(map[string]int)
|
||||
for _, chunk := range chunks {
|
||||
raw, ok := chunk[fieldName]
|
||||
if !ok || raw == nil {
|
||||
continue
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case string:
|
||||
for _, tag := range strings.Split(value, "###") {
|
||||
tag = strings.TrimSpace(tag)
|
||||
if tag != "" {
|
||||
counts[tag]++
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, tag := range value {
|
||||
tag = strings.TrimSpace(tag)
|
||||
if tag != "" {
|
||||
counts[tag]++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(counts))
|
||||
for tag, count := range counts {
|
||||
result = append(result, map[string]interface{}{
|
||||
"key": tag,
|
||||
"count": count,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetType identifies this engine as the "mock" implementation.
func (m *aggregateTagsMockEngine) GetType() string {
	return "mock"
}
|
||||
|
||||
// Ping always succeeds; the mock has no backend to reach.
func (m *aggregateTagsMockEngine) Ping(ctx context.Context) error {
	return nil
}
|
||||
|
||||
// Close is a no-op that always reports success.
func (m *aggregateTagsMockEngine) Close() error {
	return nil
}
|
||||
|
||||
func testDatasetServiceForAggregateTags(t *testing.T, docEngine engine.DocEngine) *DatasetService {
|
||||
t.Helper()
|
||||
return &DatasetService{
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
docEngine: docEngine,
|
||||
}
|
||||
}
|
||||
|
||||
func insertAggregateTagsKB(t *testing.T, datasetID, tenantID, permission string, docNum int64) {
|
||||
t.Helper()
|
||||
kb := &entity.Knowledgebase{
|
||||
ID: datasetID,
|
||||
TenantID: tenantID,
|
||||
Name: "kb-" + datasetID[:6],
|
||||
EmbdID: "embedding@OpenAI",
|
||||
CreatedBy: tenantID,
|
||||
Permission: permission,
|
||||
ParserID: "naive",
|
||||
ParserConfig: entity.JSONMap{},
|
||||
DocNum: docNum,
|
||||
Status: sptr(string(entity.StatusValid)),
|
||||
}
|
||||
if err := dao.DB.Create(kb).Error; err != nil {
|
||||
t.Fatalf("insert test kb: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func insertAggregateTagsMembership(t *testing.T, tenantID, userID string) {
|
||||
t.Helper()
|
||||
row := &entity.UserTenant{
|
||||
ID: tenantID + "-" + userID,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
Role: "member",
|
||||
InvitedBy: tenantID,
|
||||
Status: sptr(string(entity.StatusValid)),
|
||||
}
|
||||
if err := dao.DB.Create(row).Error; err != nil {
|
||||
t.Fatalf("insert user_tenant: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// aggregateTagsResultMap flattens AggregateTags output rows into a tag->count
// map. A row whose "value" is not a string is stored under the empty key, and
// a "count" that is not an int is recorded as zero; later rows overwrite
// earlier rows with the same tag.
func aggregateTagsResultMap(rows []map[string]interface{}) map[string]int {
	counts := make(map[string]int, len(rows))
	for i := range rows {
		var (
			name string
			n    int
		)
		if s, ok := rows[i]["value"].(string); ok {
			name = s
		}
		if c, ok := rows[i]["count"].(int); ok {
			n = c
		}
		counts[name] = n
	}
	return counts
}
|
||||
|
||||
func TestDatasetServiceAggregateTagsMergesAcrossTenants(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kb1Input := "123e4567-e89b-12d3-a456-426614174000"
|
||||
kb2Input := "223e4567-e89b-12d3-a456-426614174001"
|
||||
kb1ID := strings.ReplaceAll(kb1Input, "-", "")
|
||||
kb2ID := strings.ReplaceAll(kb2Input, "-", "")
|
||||
|
||||
insertAggregateTagsKB(t, kb1ID, "user-1", string(entity.TenantPermissionMe), 2)
|
||||
insertAggregateTagsKB(t, kb2ID, "tenant-2", string(entity.TenantPermissionTeam), 1)
|
||||
insertAggregateTagsMembership(t, "tenant-2", "user-1")
|
||||
|
||||
docEngine := &aggregateTagsMockEngine{
|
||||
searchResults: map[string]*types.SearchResult{
|
||||
"ragflow_user-1": {
|
||||
Chunks: []map[string]interface{}{
|
||||
{"tag_kwd": "finance###urgent"},
|
||||
{"tag_kwd": "finance"},
|
||||
},
|
||||
},
|
||||
"ragflow_tenant-2": {
|
||||
Chunks: []map[string]interface{}{
|
||||
{"tag_kwd": "urgent###internal"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{kb1Input, kb2Input}, "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("AggregateTags failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeSuccess)
|
||||
}
|
||||
|
||||
got := aggregateTagsResultMap(result)
|
||||
want := map[string]int{
|
||||
"finance": 2,
|
||||
"urgent": 2,
|
||||
"internal": 1,
|
||||
}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("result len=%d want=%d result=%v", len(got), len(want), got)
|
||||
}
|
||||
for tag, wantCount := range want {
|
||||
if got[tag] != wantCount {
|
||||
t.Fatalf("tag %q count=%d want=%d all=%v", tag, got[tag], wantCount, got)
|
||||
}
|
||||
}
|
||||
|
||||
if len(docEngine.requests) != 2 {
|
||||
t.Fatalf("search requests=%d want=2", len(docEngine.requests))
|
||||
}
|
||||
|
||||
requestByIndex := make(map[string]*types.SearchRequest, len(docEngine.requests))
|
||||
for _, req := range docEngine.requests {
|
||||
if len(req.IndexNames) != 1 {
|
||||
t.Fatalf("IndexNames=%v want single entry", req.IndexNames)
|
||||
}
|
||||
requestByIndex[req.IndexNames[0]] = req
|
||||
if req.Offset != 0 {
|
||||
t.Fatalf("Offset=%d want=0", req.Offset)
|
||||
}
|
||||
if req.Limit != 10000 {
|
||||
t.Fatalf("Limit=%d want=10000", req.Limit)
|
||||
}
|
||||
if len(req.SelectFields) != 1 || req.SelectFields[0] != "tag_kwd" {
|
||||
t.Fatalf("SelectFields=%v want [tag_kwd]", req.SelectFields)
|
||||
}
|
||||
}
|
||||
|
||||
if req := requestByIndex["ragflow_user-1"]; req == nil || len(req.KbIDs) != 1 || req.KbIDs[0] != kb1ID {
|
||||
t.Fatalf("request for ragflow_user-1 = %#v, want kbIDs=[%s]", req, kb1ID)
|
||||
}
|
||||
if req := requestByIndex["ragflow_tenant-2"]; req == nil || len(req.KbIDs) != 1 || req.KbIDs[0] != kb2ID {
|
||||
t.Fatalf("request for ragflow_tenant-2 = %#v, want kbIDs=[%s]", req, kb2ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceAggregateTagsPagesThroughAllChunks(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kbInput := "723e4567-e89b-12d3-a456-426614174006"
|
||||
kbID := strings.ReplaceAll(kbInput, "-", "")
|
||||
insertAggregateTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 10002)
|
||||
|
||||
firstPage := make([]map[string]interface{}, 10000)
|
||||
for i := range firstPage {
|
||||
firstPage[i] = map[string]interface{}{"tag_kwd": "finance"}
|
||||
}
|
||||
|
||||
docEngine := &aggregateTagsMockEngine{
|
||||
pagedSearchResults: map[string]map[int]*types.SearchResult{
|
||||
"ragflow_user-1": {
|
||||
0: {
|
||||
Chunks: firstPage,
|
||||
Total: 10002,
|
||||
},
|
||||
10000: {
|
||||
Chunks: []map[string]interface{}{
|
||||
{"tag_kwd": "finance"},
|
||||
{"tag_kwd": "urgent"},
|
||||
},
|
||||
Total: 10002,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{kbInput}, "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("AggregateTags failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeSuccess)
|
||||
}
|
||||
|
||||
got := aggregateTagsResultMap(result)
|
||||
if got["finance"] != 10001 {
|
||||
t.Fatalf("finance count=%d want=10001 result=%v", got["finance"], got)
|
||||
}
|
||||
if got["urgent"] != 1 {
|
||||
t.Fatalf("urgent count=%d want=1 result=%v", got["urgent"], got)
|
||||
}
|
||||
|
||||
if len(docEngine.requests) != 2 {
|
||||
t.Fatalf("search requests=%d want=2", len(docEngine.requests))
|
||||
}
|
||||
if docEngine.requests[0].Offset != 0 {
|
||||
t.Fatalf("first request offset=%d want=0", docEngine.requests[0].Offset)
|
||||
}
|
||||
if docEngine.requests[1].Offset != 10000 {
|
||||
t.Fatalf("second request offset=%d want=10000", docEngine.requests[1].Offset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceAggregateTagsRejectsUnauthorizedDataset(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kbInput := "323e4567-e89b-12d3-a456-426614174002"
|
||||
kbID := strings.ReplaceAll(kbInput, "-", "")
|
||||
insertAggregateTagsKB(t, kbID, "tenant-9", string(entity.TenantPermissionMe), 1)
|
||||
|
||||
_, code, err := testDatasetServiceForAggregateTags(t, &aggregateTagsMockEngine{}).AggregateTags([]string{kbInput}, "user-1")
|
||||
if err == nil {
|
||||
t.Fatal("expected authorization error")
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeDataError)
|
||||
}
|
||||
if err.Error() != "No authorization for dataset '"+kbID+"'" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceAggregateTagsRequiresDocumentEngine(t *testing.T) {
|
||||
_, code, err := testDatasetServiceForAggregateTags(t, nil).AggregateTags([]string{"123e4567-e89b-12d3-a456-426614174000"}, "user-1")
|
||||
if err == nil {
|
||||
t.Fatal("expected missing doc engine error")
|
||||
}
|
||||
if code != common.CodeServerError {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeServerError)
|
||||
}
|
||||
if err.Error() != "Document engine is not initialized" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceAggregateTagsSkipsDatasetsWithoutDocuments(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
emptyInput := "423e4567-e89b-12d3-a456-426614174003"
|
||||
liveInput := "523e4567-e89b-12d3-a456-426614174004"
|
||||
emptyID := strings.ReplaceAll(emptyInput, "-", "")
|
||||
liveID := strings.ReplaceAll(liveInput, "-", "")
|
||||
|
||||
insertAggregateTagsKB(t, emptyID, "user-1", string(entity.TenantPermissionMe), 0)
|
||||
insertAggregateTagsKB(t, liveID, "user-1", string(entity.TenantPermissionMe), 1)
|
||||
|
||||
docEngine := &aggregateTagsMockEngine{
|
||||
searchResults: map[string]*types.SearchResult{
|
||||
"ragflow_user-1": {
|
||||
Chunks: []map[string]interface{}{
|
||||
{"tag_kwd": "alpha###beta"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{emptyInput, liveInput}, "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("AggregateTags failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeSuccess)
|
||||
}
|
||||
if len(docEngine.requests) != 1 {
|
||||
t.Fatalf("search requests=%d want=1", len(docEngine.requests))
|
||||
}
|
||||
if len(docEngine.requests[0].KbIDs) != 1 || docEngine.requests[0].KbIDs[0] != liveID {
|
||||
t.Fatalf("KbIDs=%v want [%s]", docEngine.requests[0].KbIDs, liveID)
|
||||
}
|
||||
|
||||
got := aggregateTagsResultMap(result)
|
||||
if got["alpha"] != 1 || got["beta"] != 1 {
|
||||
t.Fatalf("unexpected result=%v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceAggregateTagsReturnsSearchError(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kbInput := "623e4567-e89b-12d3-a456-426614174005"
|
||||
kbID := strings.ReplaceAll(kbInput, "-", "")
|
||||
insertAggregateTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 1)
|
||||
|
||||
docEngine := &aggregateTagsMockEngine{searchErr: errors.New("boom")}
|
||||
_, code, err := testDatasetServiceForAggregateTags(t, docEngine).AggregateTags([]string{kbInput}, "user-1")
|
||||
if err == nil {
|
||||
t.Fatal("expected search error")
|
||||
}
|
||||
if code != common.CodeServerError {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeServerError)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to aggregate tags: boom") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user