mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve? As the title says, this PR implements the following endpoints: ``` /api/v1/messages/search GET /api/v1/messages GET /api/v1/messages/<memory_id>:<message_id>/content GET /api/v1/memories/<memory_id>/config GET /api/v1/messages/<memory_id>:<message_id> PUT ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality)
275 lines
8.0 KiB
Go
275 lines
8.0 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/glebarez/sqlite"
|
|
"gorm.io/gorm"
|
|
|
|
"ragflow/internal/common"
|
|
"ragflow/internal/dao"
|
|
enginetypes "ragflow/internal/engine/types"
|
|
"ragflow/internal/entity"
|
|
)
|
|
|
|
func TestIsMessageDocumentNotFound(t *testing.T) {
|
|
if !isMessageDocumentNotFound(fmt.Errorf("wrapped: %w", enginetypes.ErrDocumentNotFound)) {
|
|
t.Fatal("expected wrapped document-not-found error to be recognized")
|
|
}
|
|
|
|
if isMessageDocumentNotFound(errors.New("index does not exist")) {
|
|
t.Fatal("expected unrelated backend error to remain a server error")
|
|
}
|
|
}
|
|
|
|
func TestRequireMemoryAccessReturnsCanceledContext(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
err := ctx.Err()
|
|
if _, gotErr := NewMemoryService().requireMemoryAccess(ctx, "user-1", "memory-1"); !errors.Is(gotErr, err) {
|
|
t.Fatalf("requireMemoryAccess error = %v, want %v", gotErr, err)
|
|
}
|
|
}
|
|
|
|
type memoryMessageDocEngine struct {
|
|
fakeChatDocEngine
|
|
searchReq *enginetypes.SearchRequest
|
|
searchResp *enginetypes.SearchResult
|
|
updateCond map[string]interface{}
|
|
updateValue map[string]interface{}
|
|
updateBase string
|
|
updateID string
|
|
}
|
|
|
|
func (e *memoryMessageDocEngine) Search(ctx context.Context, req *enginetypes.SearchRequest) (*enginetypes.SearchResult, error) {
|
|
e.searchReq = req
|
|
if e.searchResp != nil {
|
|
return e.searchResp, nil
|
|
}
|
|
return &enginetypes.SearchResult{}, nil
|
|
}
|
|
|
|
func (e *memoryMessageDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error {
|
|
e.updateCond = condition
|
|
e.updateValue = newValue
|
|
e.updateBase = baseName
|
|
e.updateID = datasetID
|
|
return nil
|
|
}
|
|
|
|
func setupMemoryMessageTestDB(t *testing.T) {
|
|
t.Helper()
|
|
|
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{TranslateError: true})
|
|
if err != nil {
|
|
t.Fatalf("failed to open sqlite: %v", err)
|
|
}
|
|
if err := db.AutoMigrate(&entity.Memory{}, &entity.UserTenant{}); err != nil {
|
|
t.Fatalf("failed to migrate memory test tables: %v", err)
|
|
}
|
|
|
|
orig := dao.DB
|
|
dao.DB = db
|
|
t.Cleanup(func() {
|
|
dao.DB = orig
|
|
})
|
|
}
|
|
|
|
func seedMemoryMessages(t *testing.T) {
|
|
t.Helper()
|
|
|
|
memories := []*entity.Memory{
|
|
{
|
|
ID: "mem-owned",
|
|
Name: "Owned",
|
|
TenantID: "user-1",
|
|
MemoryType: dao.MemoryTypeRaw,
|
|
StorageType: "table",
|
|
EmbdID: "embd-1",
|
|
LLMID: "llm-1",
|
|
Permissions: string(TenantPermissionMe),
|
|
ForgettingPolicy: string(ForgettingPolicyFIFO),
|
|
},
|
|
{
|
|
ID: "mem-other",
|
|
Name: "Other",
|
|
TenantID: "user-2",
|
|
MemoryType: dao.MemoryTypeRaw,
|
|
StorageType: "table",
|
|
EmbdID: "embd-2",
|
|
LLMID: "llm-2",
|
|
Permissions: string(TenantPermissionMe),
|
|
ForgettingPolicy: string(ForgettingPolicyFIFO),
|
|
},
|
|
}
|
|
for _, memory := range memories {
|
|
if err := dao.DB.Create(memory).Error; err != nil {
|
|
t.Fatalf("seed memory %s: %v", memory.ID, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGetMessagesFiltersAccessibleMemoryAndBuildsRecentSearch(t *testing.T) {
|
|
setupMemoryMessageTestDB(t)
|
|
seedMemoryMessages(t)
|
|
|
|
docEngine := &memoryMessageDocEngine{
|
|
searchResp: &enginetypes.SearchResult{
|
|
Total: 1,
|
|
Chunks: []map[string]interface{}{
|
|
{
|
|
"message_id": int64(12),
|
|
"message_type": "raw",
|
|
"memory_id": "mem-owned",
|
|
"user_id": "user-1",
|
|
"agent_id": "agent-1",
|
|
"session_id": "session-1",
|
|
"valid_at": float64(123),
|
|
"status": 1,
|
|
"content": "hello",
|
|
"extra": "should be dropped",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine}
|
|
|
|
got, code, err := svc.GetMessages(context.Background(), []string{"mem-owned", "mem-other"}, "user-1", "agent-1", "session-1", 3)
|
|
if err != nil {
|
|
t.Fatalf("GetMessages error: %v", err)
|
|
}
|
|
if code != common.CodeSuccess {
|
|
t.Fatalf("code = %v, want %v", code, common.CodeSuccess)
|
|
}
|
|
if len(got) != 1 || got[0]["content"] != "hello" {
|
|
t.Fatalf("unexpected messages: %+v", got)
|
|
}
|
|
if _, ok := got[0]["extra"]; ok {
|
|
t.Fatalf("unexpected non-selected field in response: %+v", got[0])
|
|
}
|
|
|
|
req := docEngine.searchReq
|
|
if req == nil {
|
|
t.Fatal("expected doc engine search request")
|
|
}
|
|
if !reflect.DeepEqual(req.IndexNames, []string{"memory_user-1"}) {
|
|
t.Fatalf("IndexNames = %v, want [memory_user-1]", req.IndexNames)
|
|
}
|
|
if len(req.KbIDs) != 0 {
|
|
t.Fatalf("KbIDs = %v, want empty for memory message search", req.KbIDs)
|
|
}
|
|
if !reflect.DeepEqual(req.Filter["memory_id"], []string{"mem-owned"}) {
|
|
t.Fatalf("memory_id filter = %v, want [mem-owned]", req.Filter["memory_id"])
|
|
}
|
|
if req.Filter["agent_id"] != "agent-1" || req.Filter["session_id"] != "session-1" {
|
|
t.Fatalf("unexpected filter: %+v", req.Filter)
|
|
}
|
|
if req.Limit != 3 {
|
|
t.Fatalf("Limit = %d, want 3", req.Limit)
|
|
}
|
|
if req.OrderBy == nil || len(req.OrderBy.Fields) != 1 || req.OrderBy.Fields[0].Field != "valid_at" || req.OrderBy.Fields[0].Type != enginetypes.SortDesc {
|
|
t.Fatalf("unexpected order by: %+v", req.OrderBy)
|
|
}
|
|
}
|
|
|
|
func TestSearchMessageFiltersAccessibleMemoryAndDefaultsStatus(t *testing.T) {
|
|
setupMemoryMessageTestDB(t)
|
|
seedMemoryMessages(t)
|
|
|
|
docEngine := &memoryMessageDocEngine{
|
|
searchResp: &enginetypes.SearchResult{
|
|
Total: 1,
|
|
Chunks: []map[string]interface{}{
|
|
{
|
|
"message_id": int64(13),
|
|
"message_type": "raw",
|
|
"memory_id": "mem-owned",
|
|
"user_id": "user-1",
|
|
"agent_id": "agent-1",
|
|
"session_id": "session-1",
|
|
"valid_at": int64(456),
|
|
"status": 1,
|
|
"content": "matched",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine}
|
|
filter := map[string]interface{}{
|
|
"memory_id": []string{"mem-owned", "mem-other"},
|
|
"agent_id": "agent-1",
|
|
"session_id": "session-1",
|
|
"user_id": "user-1",
|
|
}
|
|
params := map[string]interface{}{
|
|
"query": "",
|
|
"similarity_threshold": 0.2,
|
|
"keywords_similarity_weight": 0.7,
|
|
"top_n": 5,
|
|
}
|
|
|
|
got, code, err := svc.SearchMessage(context.Background(), "user-1", filter, params)
|
|
if err != nil {
|
|
t.Fatalf("SearchMessage error: %v", err)
|
|
}
|
|
if code != common.CodeSuccess {
|
|
t.Fatalf("code = %v, want %v", code, common.CodeSuccess)
|
|
}
|
|
if len(got) != 1 || got[0]["content"] != "matched" {
|
|
t.Fatalf("unexpected search result: %+v", got)
|
|
}
|
|
|
|
req := docEngine.searchReq
|
|
if req == nil {
|
|
t.Fatal("expected doc engine search request")
|
|
}
|
|
if !reflect.DeepEqual(req.Filter["memory_id"], []string{"mem-owned"}) {
|
|
t.Fatalf("memory_id filter = %v, want [mem-owned]", req.Filter["memory_id"])
|
|
}
|
|
if req.Filter["status"] != 1 {
|
|
t.Fatalf("status filter = %v, want 1", req.Filter["status"])
|
|
}
|
|
if req.Filter["agent_id"] != "agent-1" || req.Filter["session_id"] != "session-1" || req.Filter["user_id"] != "user-1" {
|
|
t.Fatalf("unexpected filter: %+v", req.Filter)
|
|
}
|
|
if len(req.MatchExprs) != 0 {
|
|
t.Fatalf("empty query should not build match expressions, got %+v", req.MatchExprs)
|
|
}
|
|
if req.Limit != 5 {
|
|
t.Fatalf("Limit = %d, want 5", req.Limit)
|
|
}
|
|
}
|
|
|
|
func TestUpdateMessageUpdatesStatusByMessageDocID(t *testing.T) {
|
|
setupMemoryMessageTestDB(t)
|
|
seedMemoryMessages(t)
|
|
|
|
docEngine := &memoryMessageDocEngine{}
|
|
svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine}
|
|
|
|
ok, err := svc.UpdateMessage(context.Background(), "user-1", "mem-owned", 42, true)
|
|
if err != nil {
|
|
t.Fatalf("UpdateMessage error: %v", err)
|
|
}
|
|
if !ok {
|
|
t.Fatal("UpdateMessage returned false")
|
|
}
|
|
if docEngine.updateBase != "memory_user-1" {
|
|
t.Fatalf("baseName = %q, want memory_user-1", docEngine.updateBase)
|
|
}
|
|
if docEngine.updateID != "mem-owned" {
|
|
t.Fatalf("datasetID = %q, want mem-owned", docEngine.updateID)
|
|
}
|
|
if docEngine.updateCond["id"] != "mem-owned_42" {
|
|
t.Fatalf("condition = %+v, want id mem-owned_42", docEngine.updateCond)
|
|
}
|
|
if docEngine.updateValue["status"] != 1 {
|
|
t.Fatalf("status update = %+v, want status 1", docEngine.updateValue)
|
|
}
|
|
}
|