mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat(go-api): add dataset tags endpoints (#16231)
## Summary - add `GET /api/v1/datasets/:dataset_id/tags` - add `PUT /api/v1/datasets/:dataset_id/tags` - implement dataset tag listing and rename flow - align rename tag validation and response shape with the Python API - add handler and service tests for dataset tags ## Routes - `GET /api/v1/datasets/:dataset_id/tags` - `PUT /api/v1/datasets/:dataset_id/tags` ## Test - Run specific tests for dataset tags: ``` go test -v ./internal/service ./internal/handler -run 'TestDatasetServiceListTags|TestDatasetServiceRenameTag|TestDatasetsHandlerListTags|TestDatasetsHandlerRenameTag' ``` - Run all tests for service and handler to verify no regressions: ``` go test ./internal/service ./internal/handler ``` - use curl cmd to test
This commit is contained in:
@@ -479,6 +479,77 @@ func (h *DatasetsHandler) GetKnowledgeGraph(c *gin.Context) {
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// ListTags handles GET /api/v1/datasets/:dataset_id/tags.
|
||||
// @Summary List dataset tags
|
||||
// @Description List tags for a dataset
|
||||
// @Tags datasets
|
||||
// @Produce json
|
||||
// @Security ApiKeyAuth
|
||||
// @Param dataset_id path string true "Dataset ID"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/datasets/{dataset_id}/tags [get]
|
||||
func (h *DatasetsHandler) ListTags(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
datasetID := strings.TrimSpace(c.Param("dataset_id"))
|
||||
result, code, err := h.datasetsService.ListTags(datasetID, user.ID)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
type renameTagRequest struct {
|
||||
FromTag string `json:"from_tag"`
|
||||
ToTag string `json:"to_tag"`
|
||||
}
|
||||
|
||||
func (h *DatasetsHandler) RenameTag(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
datasetID := strings.TrimSpace(c.Param("dataset_id"))
|
||||
|
||||
var payload map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
jsonError(c, common.CodeDataError, "Lack of from_tag or to_tag in request body")
|
||||
return
|
||||
}
|
||||
fromTagValue, hasFrom := payload["from_tag"]
|
||||
toTagValue, hasTo := payload["to_tag"]
|
||||
if !hasFrom || !hasTo {
|
||||
jsonError(c, common.CodeDataError, "Lack of from_tag or to_tag in request body")
|
||||
return
|
||||
}
|
||||
fromTag, okFrom := fromTagValue.(string)
|
||||
toTag, okTo := toTagValue.(string)
|
||||
if !okFrom || !okTo {
|
||||
jsonError(c, common.CodeArgumentError, "from_tag and to_tag must be strings")
|
||||
return
|
||||
}
|
||||
req := renameTagRequest{FromTag: fromTag, ToTag: toTag}
|
||||
if strings.TrimSpace(req.FromTag) == "" || strings.TrimSpace(req.ToTag) == "" {
|
||||
jsonError(c, common.CodeArgumentError, "from_tag and to_tag must not be empty")
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.datasetsService.RenameTag(datasetID, user.ID, req.FromTag, req.ToTag)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// DeleteKnowledgeGraph handles DELETE /api/v1/datasets/:dataset_id/graph.
|
||||
func (h *DatasetsHandler) DeleteKnowledgeGraph(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
|
||||
39
internal/handler/datasets_list_tags_test.go
Normal file
39
internal/handler/datasets_list_tags_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
)
|
||||
|
||||
func TestDatasetsHandlerListTagsRequiresAuth(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &DatasetsHandler{}
|
||||
r := gin.New()
|
||||
r.GET("/api/v1/datasets/:dataset_id/tags", func(c *gin.Context) {
|
||||
h.ListTags(c)
|
||||
})
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/datasets/123e4567-e89b-12d3-a456-426614174000/tags", nil)
|
||||
r.ServeHTTP(resp, req)
|
||||
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String())
|
||||
}
|
||||
if body.Code != int(common.CodeUnauthorized) {
|
||||
t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeUnauthorized, resp.Body.String())
|
||||
}
|
||||
if body.Message != "User not found" {
|
||||
t.Fatalf("message=%q want=%q", body.Message, "User not found")
|
||||
}
|
||||
}
|
||||
126
internal/handler/datasets_rename_tag_test.go
Normal file
126
internal/handler/datasets_rename_tag_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func TestDatasetsHandlerRenameTagRequiresAuth(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &DatasetsHandler{}
|
||||
r := gin.New()
|
||||
r.PUT("/api/v1/datasets/:dataset_id/tags", func(c *gin.Context) {
|
||||
h.RenameTag(c)
|
||||
})
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/datasets/123e4567-e89b-12d3-a456-426614174000/tags", bytes.NewBufferString(`{"from_tag":"a","to_tag":"b"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(resp, req)
|
||||
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String())
|
||||
}
|
||||
if body.Code != int(common.CodeUnauthorized) {
|
||||
t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeUnauthorized, resp.Body.String())
|
||||
}
|
||||
if body.Message != "User not found" {
|
||||
t.Fatalf("message=%q want=%q", body.Message, "User not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetsHandlerRenameTagRejectsMissingFields(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &DatasetsHandler{}
|
||||
r := gin.New()
|
||||
r.PUT("/api/v1/datasets/:dataset_id/tags", func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
h.RenameTag(c)
|
||||
})
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/datasets/123e4567-e89b-12d3-a456-426614174000/tags", bytes.NewBufferString(`{"from_tag":"a"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(resp, req)
|
||||
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String())
|
||||
}
|
||||
if body.Code != int(common.CodeDataError) {
|
||||
t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeDataError, resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetsHandlerRenameTagRejectsEmptyFields(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &DatasetsHandler{}
|
||||
r := gin.New()
|
||||
r.PUT("/api/v1/datasets/:dataset_id/tags", func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
h.RenameTag(c)
|
||||
})
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/datasets/123e4567-e89b-12d3-a456-426614174000/tags", bytes.NewBufferString(`{"from_tag":" ","to_tag":"x"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(resp, req)
|
||||
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String())
|
||||
}
|
||||
if body.Code != int(common.CodeArgumentError) {
|
||||
t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeArgumentError, resp.Body.String())
|
||||
}
|
||||
if body.Message != "from_tag and to_tag must not be empty" {
|
||||
t.Fatalf("message=%q want=%q", body.Message, "from_tag and to_tag must not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetsHandlerRenameTagRejectsNonStringFields(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &DatasetsHandler{}
|
||||
r := gin.New()
|
||||
r.PUT("/api/v1/datasets/:dataset_id/tags", func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
h.RenameTag(c)
|
||||
})
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/datasets/123e4567-e89b-12d3-a456-426614174000/tags", bytes.NewBufferString(`{"from_tag":1,"to_tag":"x"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(resp, req)
|
||||
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v body=%s", err, resp.Body.String())
|
||||
}
|
||||
if body.Code != int(common.CodeArgumentError) {
|
||||
t.Fatalf("code=%d want=%d body=%s", body.Code, common.CodeArgumentError, resp.Body.String())
|
||||
}
|
||||
if body.Message != "from_tag and to_tag must be strings" {
|
||||
t.Fatalf("message=%q want=%q", body.Message, "from_tag and to_tag must be strings")
|
||||
}
|
||||
}
|
||||
@@ -274,6 +274,8 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
datasets.GET("/:dataset_id", r.datasetsHandler.GetDataset)
|
||||
datasets.PUT("/:dataset_id", r.datasetsHandler.UpdateDataset)
|
||||
datasets.GET("/:dataset_id/graph", r.datasetsHandler.GetKnowledgeGraph)
|
||||
datasets.GET("/:dataset_id/tags", r.datasetsHandler.ListTags)
|
||||
datasets.PUT("/:dataset_id/tags", r.datasetsHandler.RenameTag)
|
||||
datasets.DELETE("/:dataset_id/tags", r.datasetsHandler.RemoveTags)
|
||||
datasets.POST("/:dataset_id/documents/batch-update-status", r.documentHandler.BatchUpdateDocumentStatus)
|
||||
datasets.GET("/:dataset_id/index", r.datasetsHandler.TraceIndex)
|
||||
|
||||
@@ -2031,6 +2031,113 @@ func (s *DatasetService) AggregateTags(datasetIDs []string, userID string) ([]ma
|
||||
return result, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *DatasetService) ListTags(datasetID, userID string) ([]map[string]interface{}, common.ErrorCode, error) {
|
||||
datasetID = strings.TrimSpace(datasetID)
|
||||
if datasetID == "" {
|
||||
return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"")
|
||||
}
|
||||
|
||||
normalizedID, err := normalizeDatasetID(datasetID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
datasetID = normalizedID
|
||||
|
||||
if !s.kbDAO.Accessible(datasetID, userID) {
|
||||
return nil, common.CodeDataError, errors.New("No authorization.")
|
||||
}
|
||||
if s.docEngine == nil {
|
||||
return nil, common.CodeServerError, errors.New("Document engine is not initialized")
|
||||
}
|
||||
|
||||
kb, err := s.kbDAO.GetByID(datasetID)
|
||||
if err != nil || kb == nil {
|
||||
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
||||
}
|
||||
|
||||
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
exists, err := s.docEngine.ChunkStoreExists(ctx, indexName, datasetID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to inspect chunk store: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return []map[string]interface{}{}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
const pageSize = 10000
|
||||
counts := make(map[string]int)
|
||||
for offset := 0; ; offset += pageSize {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("list tags timeout or canceled: %w", err)
|
||||
}
|
||||
|
||||
searchResp, err := s.docEngine.Search(ctx, &types.SearchRequest{
|
||||
IndexNames: []string{indexName},
|
||||
KbIDs: []string{datasetID},
|
||||
Offset: offset,
|
||||
Limit: pageSize,
|
||||
SelectFields: []string{"tag_kwd"},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to list tags: %w", err)
|
||||
}
|
||||
|
||||
for _, agg := range s.docEngine.GetAggregation(searchResp.Chunks, "tag_kwd") {
|
||||
tag, _ := agg["key"].(string)
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
switch count := agg["count"].(type) {
|
||||
case int:
|
||||
counts[tag] += count
|
||||
case int32:
|
||||
counts[tag] += int(count)
|
||||
case int64:
|
||||
counts[tag] += int(count)
|
||||
case float64:
|
||||
counts[tag] += int(count)
|
||||
}
|
||||
}
|
||||
|
||||
chunkCount := len(searchResp.Chunks)
|
||||
if chunkCount == 0 || chunkCount < pageSize {
|
||||
break
|
||||
}
|
||||
if searchResp.Total > 0 && int64(offset+chunkCount) >= searchResp.Total {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(counts) == 0 {
|
||||
return []map[string]interface{}{}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
tags := make([]string, 0, len(counts))
|
||||
for tag := range counts {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
sort.Slice(tags, func(i, j int) bool {
|
||||
if counts[tags[i]] != counts[tags[j]] {
|
||||
return counts[tags[i]] > counts[tags[j]]
|
||||
}
|
||||
return tags[i] < tags[j]
|
||||
})
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
result = append(result, map[string]interface{}{
|
||||
"key": tag,
|
||||
"count": counts[tag],
|
||||
})
|
||||
}
|
||||
|
||||
return result, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// GetIngestionSummary returns dataset-level ingestion counters together with
|
||||
// the aggregated document parsing status, mirroring
|
||||
// dataset_api_service.get_ingestion_summary.
|
||||
@@ -2544,3 +2651,51 @@ func limitStrings(values []string, limit int) []string {
|
||||
}
|
||||
return values[:limit]
|
||||
}
|
||||
|
||||
func (s *DatasetService) RenameTag(datasetID, userID, fromTag, toTag string) (map[string]interface{}, common.ErrorCode, error) {
|
||||
fromTag = strings.TrimSpace(fromTag)
|
||||
toTag = strings.TrimSpace(toTag)
|
||||
|
||||
datasetID, err := normalizeDatasetID(datasetID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
if strings.TrimSpace(datasetID) == "" {
|
||||
return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"")
|
||||
}
|
||||
if !s.kbDAO.Accessible(datasetID, userID) {
|
||||
return nil, common.CodeDataError, errors.New("No authorization.")
|
||||
}
|
||||
if s.docEngine == nil {
|
||||
return nil, common.CodeServerError, errors.New("Document engine is not initialized")
|
||||
}
|
||||
|
||||
kb, err := s.kbDAO.GetByID(datasetID)
|
||||
if err != nil || kb == nil {
|
||||
return nil, common.CodeDataError, errors.New("Invalid Dataset ID")
|
||||
}
|
||||
indexName := fmt.Sprintf("ragflow_%s", kb.TenantID)
|
||||
|
||||
condition := map[string]interface{}{
|
||||
"tag_kwd": fromTag,
|
||||
"kb_id": datasetID,
|
||||
}
|
||||
newValue := map[string]interface{}{
|
||||
"remove": map[string]interface{}{
|
||||
"tag_kwd": fromTag,
|
||||
},
|
||||
"add": map[string]interface{}{
|
||||
"tag_kwd": toTag,
|
||||
},
|
||||
}
|
||||
|
||||
err = s.docEngine.UpdateChunks(context.Background(), condition, newValue, indexName, datasetID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to rename tag: %w", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"from": fromTag,
|
||||
"to": toTag,
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
238
internal/service/dataset_list_tags_test.go
Normal file
238
internal/service/dataset_list_tags_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/engine/types"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
type listTagsMockEngine struct {
|
||||
engine.DocEngine
|
||||
searchResults map[string]*types.SearchResult
|
||||
searchErr error
|
||||
requests []*types.SearchRequest
|
||||
chunkStoreExists bool
|
||||
chunkStoreErr error
|
||||
}
|
||||
|
||||
func (m *listTagsMockEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
|
||||
if m.searchErr != nil {
|
||||
return nil, m.searchErr
|
||||
}
|
||||
cloned := &types.SearchRequest{
|
||||
IndexNames: append([]string(nil), req.IndexNames...),
|
||||
KbIDs: append([]string(nil), req.KbIDs...),
|
||||
Offset: req.Offset,
|
||||
Limit: req.Limit,
|
||||
SelectFields: append([]string(nil), req.SelectFields...),
|
||||
}
|
||||
m.requests = append(m.requests, cloned)
|
||||
if len(req.IndexNames) == 0 {
|
||||
return &types.SearchResult{}, nil
|
||||
}
|
||||
if res, ok := m.searchResults[req.IndexNames[0]]; ok {
|
||||
return res, nil
|
||||
}
|
||||
return &types.SearchResult{}, nil
|
||||
}
|
||||
|
||||
func (m *listTagsMockEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} {
|
||||
counts := make(map[string]int)
|
||||
for _, chunk := range chunks {
|
||||
raw, ok := chunk[fieldName]
|
||||
if !ok || raw == nil {
|
||||
continue
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case string:
|
||||
for _, tag := range strings.Split(value, "###") {
|
||||
tag = strings.TrimSpace(tag)
|
||||
if tag != "" {
|
||||
counts[tag]++
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, tag := range value {
|
||||
tag = strings.TrimSpace(tag)
|
||||
if tag != "" {
|
||||
counts[tag]++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(counts))
|
||||
for tag, count := range counts {
|
||||
result = append(result, map[string]interface{}{
|
||||
"key": tag,
|
||||
"count": count,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *listTagsMockEngine) ChunkStoreExists(ctx context.Context, baseName, datasetID string) (bool, error) {
|
||||
if m.chunkStoreErr != nil {
|
||||
return false, m.chunkStoreErr
|
||||
}
|
||||
return m.chunkStoreExists, nil
|
||||
}
|
||||
|
||||
func (m *listTagsMockEngine) GetType() string { return "mock" }
|
||||
|
||||
func (m *listTagsMockEngine) Ping(ctx context.Context) error { return nil }
|
||||
|
||||
func (m *listTagsMockEngine) Close() error { return nil }
|
||||
|
||||
func testDatasetServiceForListTags(t *testing.T, docEngine engine.DocEngine) *DatasetService {
|
||||
t.Helper()
|
||||
return &DatasetService{
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
docEngine: docEngine,
|
||||
}
|
||||
}
|
||||
|
||||
func insertListTagsKB(t *testing.T, datasetID, tenantID, permission string, docNum int64) {
|
||||
t.Helper()
|
||||
kb := &entity.Knowledgebase{
|
||||
ID: datasetID,
|
||||
TenantID: tenantID,
|
||||
Name: "kb-" + datasetID[:6],
|
||||
EmbdID: "embedding@OpenAI",
|
||||
CreatedBy: tenantID,
|
||||
Permission: permission,
|
||||
ParserID: "naive",
|
||||
ParserConfig: entity.JSONMap{},
|
||||
DocNum: docNum,
|
||||
Status: sptr(string(entity.StatusValid)),
|
||||
}
|
||||
if err := dao.DB.Create(kb).Error; err != nil {
|
||||
t.Fatalf("insert test kb: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceListTagsSuccess(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kbInput := "123e4567-e89b-12d3-a456-426614174000"
|
||||
kbID := strings.ReplaceAll(kbInput, "-", "")
|
||||
insertListTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 2)
|
||||
|
||||
docEngine := &listTagsMockEngine{
|
||||
chunkStoreExists: true,
|
||||
searchResults: map[string]*types.SearchResult{
|
||||
"ragflow_user-1": {
|
||||
Chunks: []map[string]interface{}{
|
||||
{"tag_kwd": "finance###urgent"},
|
||||
{"tag_kwd": "finance"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, code, err := testDatasetServiceForListTags(t, docEngine).ListTags(kbInput, "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("ListTags failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeSuccess)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("len(result)=%d want=2 result=%v", len(result), result)
|
||||
}
|
||||
if result[0]["key"] != "finance" || result[0]["count"] != 2 {
|
||||
t.Fatalf("first row=%v want finance/2", result[0])
|
||||
}
|
||||
if result[1]["key"] != "urgent" || result[1]["count"] != 1 {
|
||||
t.Fatalf("second row=%v want urgent/1", result[1])
|
||||
}
|
||||
if len(docEngine.requests) != 1 {
|
||||
t.Fatalf("search requests=%d want=1", len(docEngine.requests))
|
||||
}
|
||||
req := docEngine.requests[0]
|
||||
if len(req.IndexNames) != 1 || req.IndexNames[0] != "ragflow_user-1" {
|
||||
t.Fatalf("IndexNames=%v want [ragflow_user-1]", req.IndexNames)
|
||||
}
|
||||
if len(req.KbIDs) != 1 || req.KbIDs[0] != kbID {
|
||||
t.Fatalf("KbIDs=%v want [%s]", req.KbIDs, kbID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceListTagsReturnsEmptyWhenChunkStoreMissing(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kbInput := "123e4567-e89b-12d3-a456-426614174000"
|
||||
kbID := strings.ReplaceAll(kbInput, "-", "")
|
||||
insertListTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 1)
|
||||
|
||||
docEngine := &listTagsMockEngine{chunkStoreExists: false}
|
||||
|
||||
result, code, err := testDatasetServiceForListTags(t, docEngine).ListTags(kbInput, "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("ListTags failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeSuccess)
|
||||
}
|
||||
if len(result) != 0 {
|
||||
t.Fatalf("len(result)=%d want=0 result=%v", len(result), result)
|
||||
}
|
||||
if len(docEngine.requests) != 0 {
|
||||
t.Fatalf("search requests=%d want=0", len(docEngine.requests))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceListTagsRejectsUnauthorizedDataset(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kbInput := "123e4567-e89b-12d3-a456-426614174000"
|
||||
kbID := strings.ReplaceAll(kbInput, "-", "")
|
||||
insertListTagsKB(t, kbID, "tenant-9", string(entity.TenantPermissionMe), 1)
|
||||
|
||||
docEngine := &listTagsMockEngine{chunkStoreExists: true}
|
||||
|
||||
_, code, err := testDatasetServiceForListTags(t, docEngine).ListTags(kbInput, "user-1")
|
||||
if err == nil {
|
||||
t.Fatal("expected unauthorized error")
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeDataError)
|
||||
}
|
||||
if err.Error() != "No authorization." {
|
||||
t.Fatalf("error=%q want=%q", err.Error(), "No authorization.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceListTagsReturnsChunkStoreError(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kbInput := "123e4567-e89b-12d3-a456-426614174000"
|
||||
kbID := strings.ReplaceAll(kbInput, "-", "")
|
||||
insertListTagsKB(t, kbID, "user-1", string(entity.TenantPermissionMe), 1)
|
||||
|
||||
docEngine := &listTagsMockEngine{
|
||||
chunkStoreErr: errors.New("boom"),
|
||||
}
|
||||
|
||||
_, code, err := testDatasetServiceForListTags(t, docEngine).ListTags(kbInput, "user-1")
|
||||
if err == nil {
|
||||
t.Fatal("expected chunk store error")
|
||||
}
|
||||
if code != common.CodeServerError {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeServerError)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to inspect chunk store: boom") {
|
||||
t.Fatalf("err=%q want contains %q", err.Error(), "failed to inspect chunk store: boom")
|
||||
}
|
||||
}
|
||||
154
internal/service/dataset_rename_tag_test.go
Normal file
154
internal/service/dataset_rename_tag_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
type renameTagUpdateCall struct {
|
||||
Condition map[string]interface{}
|
||||
NewValue map[string]interface{}
|
||||
BaseName string
|
||||
DatasetID string
|
||||
}
|
||||
|
||||
type renameTagMockEngine struct {
|
||||
engine.DocEngine
|
||||
updateCalls []renameTagUpdateCall
|
||||
updateErr error
|
||||
}
|
||||
|
||||
func (m *renameTagMockEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error {
|
||||
m.updateCalls = append(m.updateCalls, renameTagUpdateCall{
|
||||
Condition: condition,
|
||||
NewValue: newValue,
|
||||
BaseName: baseName,
|
||||
DatasetID: datasetID,
|
||||
})
|
||||
return m.updateErr
|
||||
}
|
||||
|
||||
func (m *renameTagMockEngine) GetType() string { return "mock" }
|
||||
func (m *renameTagMockEngine) Ping(ctx context.Context) error { return nil }
|
||||
func (m *renameTagMockEngine) Close() error { return nil }
|
||||
|
||||
func testDatasetServiceForRenameTag(t *testing.T, docEngine engine.DocEngine) *DatasetService {
|
||||
t.Helper()
|
||||
return &DatasetService{
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
docEngine: docEngine,
|
||||
}
|
||||
}
|
||||
|
||||
func insertRenameTagKB(t *testing.T, datasetID, tenantID, permission string) {
|
||||
t.Helper()
|
||||
kb := &entity.Knowledgebase{
|
||||
ID: datasetID,
|
||||
TenantID: tenantID,
|
||||
Name: "kb-" + datasetID[:6],
|
||||
EmbdID: "embedding@OpenAI",
|
||||
CreatedBy: tenantID,
|
||||
Permission: permission,
|
||||
ParserID: "naive",
|
||||
ParserConfig: entity.JSONMap{},
|
||||
Status: sptr(string(entity.StatusValid)),
|
||||
}
|
||||
if err := dao.DB.Create(kb).Error; err != nil {
|
||||
t.Fatalf("insert test kb: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceRenameTagSuccess(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kbInput := "123e4567-e89b-12d3-a456-426614174000"
|
||||
kbID := strings.ReplaceAll(kbInput, "-", "")
|
||||
insertRenameTagKB(t, kbID, "user-1", string(entity.TenantPermissionMe))
|
||||
|
||||
docEngine := &renameTagMockEngine{}
|
||||
result, code, err := testDatasetServiceForRenameTag(t, docEngine).RenameTag(kbInput, "user-1", "old-tag ", "new-tag")
|
||||
if err != nil {
|
||||
t.Fatalf("RenameTag failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeSuccess)
|
||||
}
|
||||
if result["from"] != "old-tag" || result["to"] != "new-tag" {
|
||||
t.Fatalf("result=%v", result)
|
||||
}
|
||||
if len(docEngine.updateCalls) != 1 {
|
||||
t.Fatalf("update calls=%d want=1", len(docEngine.updateCalls))
|
||||
}
|
||||
|
||||
call := docEngine.updateCalls[0]
|
||||
if call.BaseName != "ragflow_user-1" {
|
||||
t.Fatalf("baseName=%q want=%q", call.BaseName, "ragflow_user-1")
|
||||
}
|
||||
if call.DatasetID != kbID {
|
||||
t.Fatalf("datasetID=%q want=%q", call.DatasetID, kbID)
|
||||
}
|
||||
if got := call.Condition["tag_kwd"]; got != "old-tag" {
|
||||
t.Fatalf("condition tag_kwd=%v want=%q", got, "old-tag")
|
||||
}
|
||||
if got := call.Condition["kb_id"]; got != kbID {
|
||||
t.Fatalf("condition kb_id=%v want=%q", got, kbID)
|
||||
}
|
||||
|
||||
remove, _ := call.NewValue["remove"].(map[string]interface{})
|
||||
add, _ := call.NewValue["add"].(map[string]interface{})
|
||||
if got := remove["tag_kwd"]; got != "old-tag" {
|
||||
t.Fatalf("remove tag_kwd=%v want=%q", got, "old-tag")
|
||||
}
|
||||
if got := add["tag_kwd"]; got != "new-tag" {
|
||||
t.Fatalf("add tag_kwd=%v want=%q", got, "new-tag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceRenameTagUnauthorized(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kbInput := "123e4567-e89b-12d3-a456-426614174000"
|
||||
kbID := strings.ReplaceAll(kbInput, "-", "")
|
||||
insertRenameTagKB(t, kbID, "tenant-9", string(entity.TenantPermissionMe))
|
||||
|
||||
_, code, err := testDatasetServiceForRenameTag(t, &renameTagMockEngine{}).RenameTag(kbInput, "user-1", "old", "new")
|
||||
if err == nil {
|
||||
t.Fatal("expected unauthorized error")
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeDataError)
|
||||
}
|
||||
if err.Error() != "No authorization." {
|
||||
t.Fatalf("error=%q want=%q", err.Error(), "No authorization.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatasetServiceRenameTagUpdateError(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
pushServiceDB(t, db)
|
||||
|
||||
kbInput := "123e4567-e89b-12d3-a456-426614174000"
|
||||
kbID := strings.ReplaceAll(kbInput, "-", "")
|
||||
insertRenameTagKB(t, kbID, "user-1", string(entity.TenantPermissionMe))
|
||||
|
||||
docEngine := &renameTagMockEngine{updateErr: errors.New("boom")}
|
||||
_, code, err := testDatasetServiceForRenameTag(t, docEngine).RenameTag(kbInput, "user-1", "old", "new")
|
||||
if err == nil {
|
||||
t.Fatal("expected update error")
|
||||
}
|
||||
if code != common.CodeServerError {
|
||||
t.Fatalf("code=%d want=%d", code, common.CodeServerError)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to rename tag") {
|
||||
t.Fatalf("error=%q want contains %q", err.Error(), "failed to rename tag")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user