diff --git a/internal/dao/memory.go b/internal/dao/memory.go index f82b72b0a4..4925312d35 100644 --- a/internal/dao/memory.go +++ b/internal/dao/memory.go @@ -301,7 +301,7 @@ func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*entity.MemoryListItem, e // Example: // // memories, total, err := dao.GetByFilter([]string{"tenant1"}, []string{"semantic"}, "table", "test", 1, 10) -func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) ([]*entity.MemoryListItem, int64, error) { +func (dao *MemoryDAO) GetByFilter(userID string, tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) ([]*entity.MemoryListItem, int64, error) { var conditions []string var args []interface{} @@ -310,6 +310,11 @@ func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, stor args = append(args, tenantIDs) } + if userID != "" { + conditions = append(conditions, "(m.tenant_id = ? OR m.permissions = ?)") + args = append(args, userID, "team") + } + if len(memoryTypes) > 0 { memoryTypeInt := CalculateMemoryType(memoryTypes) conditions = append(conditions, "m.memory_type & ? > 0") diff --git a/internal/service/memory.go b/internal/service/memory.go index 19d28ec040..039427da03 100644 --- a/internal/service/memory.go +++ b/internal/service/memory.go @@ -1419,13 +1419,14 @@ func (s *MemoryService) ListMemories(userID string, tenantIDs []string, memoryTy if err != nil { return nil, fmt.Errorf("failed to get user tenants: %w", err) } - tenantIDs = make([]string, len(userTenants)) - for i, tenant := range userTenants { - tenantIDs[i] = tenant.TenantID + tenantIDs = make([]string, 0, len(userTenants)+1) + tenantIDs = append(tenantIDs, userID) + for _, tenant := range userTenants { + tenantIDs = append(tenantIDs, tenant.TenantID) } } - memories, total, err := s.memoryDAO.GetByFilter(tenantIDs, memoryTypes, storageType, keywords, page, pageSize) + memories, total, err := s.memoryDAO.GetByFilter(userID, tenantIDs, memoryTypes, storageType, keywords, page, pageSize) if err != nil { return nil, err }