mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve? As title ### Type of change - [x] New Feature (non-breaking change which adds functionality)
1822 lines
55 KiB
Go
1822 lines
55 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package service
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"os"
	"sort"
	"strconv"
	"strings"
	"time"

	"ragflow/internal/common"
	"ragflow/internal/dao"
	"ragflow/internal/engine"
	enginetypes "ragflow/internal/engine/types"
	"ragflow/internal/entity"
	models "ragflow/internal/entity/models"
	"ragflow/internal/service/nlp"
)
|
|
|
|
// Limits applied to memory resources.
const (
	// MemoryNameLimit caps the length of a memory name.
	MemoryNameLimit = 128
	// MemorySizeLimit caps the configurable memory size, in bytes (5 MiB).
	MemorySizeLimit = 5 * 1024 * 1024
)
|
|
|
|
// Note: MemoryType, MemoryTypeRaw, MemoryTypeSemantic, MemoryTypeEpisodic,
|
|
// MemoryTypeProcedural, and CalculateMemoryType are defined in the dao package
|
|
// and imported as dao.MemoryType, dao.MemoryTypeRaw, etc.
|
|
|
|
// TenantPermission is the access level attached to a memory resource.
// The type lives in the service layer only.
type TenantPermission string

// Permission levels recognised by the memory service.
const (
	// TenantPermissionMe limits access to the owner.
	TenantPermissionMe TenantPermission = "me"
	// TenantPermissionTeam opens access to members of the same team.
	TenantPermissionTeam TenantPermission = "team"
	// TenantPermissionAll opens access to every tenant.
	TenantPermissionAll TenantPermission = "all"
)

// Paging defaults for message queries.
// NOTE(review): these are not referenced in the visible part of the file;
// presumably they bound top-N and page-size parameters — confirm at the call sites.
const (
	defaultMessageTopN  = 5
	defaultMessageLimit = 10
	maxMessageLimit     = 100
)

// validPermissions is the lookup set of accepted permission values; any value
// absent from the map reads as false and is therefore rejected.
var validPermissions = map[TenantPermission]bool{
	TenantPermissionMe:   true,
	TenantPermissionTeam: true,
	TenantPermissionAll:  true,
}
|
|
|
|
// ForgettingPolicy names the strategy used to drop old memory entries.
type ForgettingPolicy string

// ForgettingPolicyFIFO drops entries in First-In-First-Out order.
const ForgettingPolicyFIFO ForgettingPolicy = "FIFO"

// validForgettingPolicies is the lookup set of accepted forgetting policies.
var validForgettingPolicies = map[ForgettingPolicy]bool{
	ForgettingPolicyFIFO: true,
}
|
|
|
|
// ResourceNotFoundError reports a memory or message resource that does not
// exist, in a form that can be shown to the client.
type ResourceNotFoundError struct {
	// Resource is the kind of resource that was looked up (e.g. "Message").
	Resource string
	// ID identifies the missing resource.
	ID string
}

// Error implements the error interface, rendering "<Resource> '<ID>' not found.".
func (e *ResourceNotFoundError) Error() string {
	const layout = "%s '%s' not found."
	return fmt.Sprintf(layout, e.Resource, e.ID)
}
|
|
|
|
// Note: CalculateMemoryType and GetMemoryTypeHuman have been moved to the dao package.
// Use dao.CalculateMemoryType() and dao.GetMemoryTypeHuman() instead.
|
|
|
|
// PromptAssembler builds the system prompts used for memory extraction.
// It carries no state; its methods are called on the zero value.
type PromptAssembler struct{}
|
|
|
|
// SYSTEM_BASE_TEMPLATE is the skeleton of the memory-extraction system prompt.
// AssembleSystemPrompt substitutes the {type_specific_instructions},
// {timestamp_format} and {max_items} placeholders before use.
var SYSTEM_BASE_TEMPLATE = `**Memory Extraction Specialist**
You are an expert at analyzing conversations to extract structured memory.

{type_specific_instructions}


**OUTPUT REQUIREMENTS:**
1. Output MUST be valid JSON
2. Follow the specified output format exactly
3. Each extracted item MUST have: content, valid_at, invalid_at
4. Timestamps in {timestamp_format} format
5. Only extract memory types specified above
6. Maximum {max_items} items per type
`
|
|
|
|
// TYPE_INSTRUCTIONS maps a lower-case memory type name to the instruction
// paragraph inserted into the system prompt for that type. There is no entry
// for "raw".
var TYPE_INSTRUCTIONS = map[string]string{
	"semantic": `
**EXTRACT SEMANTIC KNOWLEDGE:**
- Universal facts, definitions, concepts, relationships
- Time-invariant, generally true information

**Timestamp Rules:**
- valid_at: When the fact became true
- invalid_at: When it becomes false or empty if still true
`,
	"episodic": `
**EXTRACT EPISODIC KNOWLEDGE:**
- Specific experiences, events, personal stories
- Time-bound, person-specific, contextual

**Timestamp Rules:**
- valid_at: Event start/occurrence time
- invalid_at: Event end time or empty if instantaneous
`,
	"procedural": `
**EXTRACT PROCEDURAL KNOWLEDGE:**
- Processes, methods, step-by-step instructions
- Goal-oriented, actionable, often includes conditions

**Timestamp Rules:**
- valid_at: When procedure becomes valid/effective
- invalid_at: When it expires/becomes obsolete or empty if current
`,
}
|
|
|
|
// OUTPUT_TEMPLATES maps a lower-case memory type name to the JSON fragment that
// illustrates the expected output for that type. The fragments are joined into
// the "REQUIRED OUTPUT FORMAT" section of the system prompt.
var OUTPUT_TEMPLATES = map[string]string{
	"semantic":   `"semantic": [{"content": "Clear factual statement", "valid_at": "timestamp or empty", "invalid_at": "timestamp or empty"}]`,
	"episodic":   `"episodic": [{"content": "Narrative event description", "valid_at": "event start timestamp", "invalid_at": "event end timestamp or empty"}]`,
	"procedural": `"procedural": [{"content": "Actionable instructions", "valid_at": "procedure effective timestamp", "invalid_at": "procedure expiration timestamp or empty"}]`,
}
|
|
|
|
// AssembleSystemPrompt generates a complete system prompt for memory extraction
|
|
//
|
|
// Parameters:
|
|
// - memoryTypes: Array of memory type names to extract (e.g., ["semantic", "episodic"])
|
|
//
|
|
// Returns:
|
|
// - string: Complete system prompt with type-specific instructions and output format
|
|
//
|
|
// Example:
|
|
//
|
|
// AssembleSystemPrompt([]string{"semantic", "episodic"}) returns a prompt with instructions
|
|
// for both semantic and episodic memory extraction
|
|
func (PromptAssembler) AssembleSystemPrompt(memoryTypes []string) string {
|
|
typesToExtract := getTypesToExtract(memoryTypes)
|
|
if len(typesToExtract) == 0 {
|
|
typesToExtract = []string{"raw"}
|
|
}
|
|
|
|
typeInstructions := generateTypeInstructions(typesToExtract)
|
|
outputFormat := generateOutputFormat(typesToExtract)
|
|
|
|
fullPrompt := strings.Replace(SYSTEM_BASE_TEMPLATE, "{type_specific_instructions}", typeInstructions, 1)
|
|
fullPrompt = strings.Replace(fullPrompt, "{timestamp_format}", "ISO 8601", 1)
|
|
fullPrompt = strings.Replace(fullPrompt, "{max_items}", "5", 1)
|
|
|
|
fullPrompt += fmt.Sprintf("\n**REQUIRED OUTPUT FORMAT (JSON):\n```json\n{\n%s\n}\n```\n", outputFormat)
|
|
|
|
return fullPrompt
|
|
}
|
|
|
|
// getTypesToExtract filters out "raw" type and returns valid memory types
|
|
//
|
|
// Parameters:
|
|
// - requestedTypes: Array of requested memory type names
|
|
//
|
|
// Returns:
|
|
// - []string: Filtered array of memory type names (excluding "raw")
|
|
func getTypesToExtract(requestedTypes []string) []string {
|
|
types := make(map[string]bool)
|
|
for _, rt := range requestedTypes {
|
|
lowerRT := strings.ToLower(rt)
|
|
if lowerRT != "raw" {
|
|
if _, ok := dao.MemoryTypeMap[lowerRT]; ok {
|
|
types[lowerRT] = true
|
|
}
|
|
}
|
|
}
|
|
result := make([]string, 0, len(types))
|
|
for t := range types {
|
|
result = append(result, t)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// generateTypeInstructions concatenates type-specific instructions
|
|
//
|
|
// Parameters:
|
|
// - typesToExtract: Array of memory type names
|
|
//
|
|
// Returns:
|
|
// - string: Concatenated instructions for all specified types
|
|
func generateTypeInstructions(typesToExtract []string) string {
|
|
var instructions []string
|
|
for _, mt := range typesToExtract {
|
|
if instr, ok := TYPE_INSTRUCTIONS[mt]; ok {
|
|
instructions = append(instructions, instr)
|
|
}
|
|
}
|
|
return strings.Join(instructions, "\n")
|
|
}
|
|
|
|
// generateOutputFormat concatenates output format templates
|
|
//
|
|
// Parameters:
|
|
// - typesToExtract: Array of memory type names
|
|
//
|
|
// Returns:
|
|
// - string: Concatenated output format templates
|
|
func generateOutputFormat(typesToExtract []string) string {
|
|
var outputParts []string
|
|
for _, mt := range typesToExtract {
|
|
if tmpl, ok := OUTPUT_TEMPLATES[mt]; ok {
|
|
outputParts = append(outputParts, tmpl)
|
|
}
|
|
}
|
|
return strings.Join(outputParts, ",\n")
|
|
}
|
|
|
|
// MemoryService handles business logic for memory operations
|
|
// It provides methods for creating, updating, deleting, and querying memories
|
|
type MemoryService struct {
|
|
memoryDAO *dao.MemoryDAO
|
|
docEngine engine.DocEngine
|
|
}
|
|
|
|
// NewMemoryService creates a new MemoryService instance
|
|
//
|
|
// Returns:
|
|
// - *MemoryService: Initialized service instance with DAO
|
|
func NewMemoryService() *MemoryService {
|
|
return &MemoryService{
|
|
memoryDAO: dao.NewMemoryDAO(),
|
|
docEngine: engine.Get(),
|
|
}
|
|
}
|
|
|
|
// CreateMemoryRequest is the payload for creating a memory.
type CreateMemoryRequest struct {
	// Name of the memory (required).
	Name string `json:"name" binding:"required"`
	// MemoryType lists the memory type names to enable (required).
	MemoryType []string `json:"memory_type" binding:"required"`
	// EmbdID identifies the embedding model (required).
	EmbdID string `json:"embd_id" binding:"required"`
	// LLMID identifies the language model (required).
	LLMID string `json:"llm_id" binding:"required"`
	// TenantEmbdID is the tenant-specific embedding model ID (optional).
	// CreateMemory overwrites it when the tenant model lookup yields an ID.
	TenantEmbdID *string `json:"tenant_embd_id"`
	// TenantLLMID is the tenant-specific language model ID (optional).
	// CreateMemory overwrites it when the tenant model lookup yields an ID.
	TenantLLMID *string `json:"tenant_llm_id"`
}
|
|
|
|
// UpdateMemoryRequest is the payload for a partial memory update.
// Every field is optional: a nil pointer (or an empty MemoryType) means
// "leave unchanged".
type UpdateMemoryRequest struct {
	// Name is the new memory name.
	Name *string `json:"name"`
	// Permissions is the new permission level ("me", "team" or "all").
	Permissions *string `json:"permissions"`
	// LLMID is the new language model ID.
	LLMID *string `json:"llm_id"`
	// EmbdID is the new embedding model ID.
	EmbdID *string `json:"embd_id"`
	// TenantLLMID is the new tenant-specific language model ID (decimal integer string).
	TenantLLMID *string `json:"tenant_llm_id"`
	// TenantEmbdID is the new tenant-specific embedding model ID (decimal integer string).
	TenantEmbdID *string `json:"tenant_embd_id"`
	// MemoryType is the new list of memory type names.
	MemoryType []string `json:"memory_type"`
	// MemorySize is the new memory size in bytes, in (0, MemorySizeLimit].
	MemorySize *int64 `json:"memory_size"`
	// ForgettingPolicy is the new forgetting policy (currently only "FIFO").
	ForgettingPolicy *string `json:"forgetting_policy"`
	// Temperature is the new temperature, in [0, 1].
	Temperature *float64 `json:"temperature"`
	// Avatar is the new avatar.
	Avatar *string `json:"avatar"`
	// Description is the new description.
	Description *string `json:"description"`
	// SystemPrompt is the new system prompt.
	SystemPrompt *string `json:"system_prompt"`
	// UserPrompt is the new user prompt.
	UserPrompt *string `json:"user_prompt"`
}
|
|
|
|
// CreateMemoryResponse defines the response structure for memory operations
|
|
// Uses struct embedding to extend Memory struct with API-specific fields
|
|
type CreateMemoryResponse struct {
|
|
entity.Memory
|
|
OwnerName *string `json:"owner_name,omitempty"`
|
|
MemoryType []string `json:"memory_type"`
|
|
}
|
|
|
|
// ListMemoryResponse is the payload returned when listing memories.
type ListMemoryResponse struct {
	// MemoryList holds one map per memory.
	MemoryList []map[string]interface{} `json:"memory_list"`
	// TotalCount is the total number of memories.
	TotalCount int64 `json:"total_count"`
}
|
|
|
|
// CreateMemory creates a new memory with the given parameters
|
|
// It validates the request, generates a unique name if needed, and creates the memory record
|
|
//
|
|
// Parameters:
|
|
// - tenantID: The tenant ID for which to create the memory
|
|
// - req: The memory creation request containing name, memory_type, embd_id, llm_id, etc.
|
|
//
|
|
// Returns:
|
|
// - *CreateMemoryResponse: The created memory details
|
|
// - error: Error if validation fails or creation fails
|
|
//
|
|
// Example:
|
|
//
|
|
// req := &CreateMemoryRequest{Name: "MyMemory", MemoryType: []string{"semantic"}, EmbdID: "embd1", LLMID: "llm1"}
|
|
// resp, err := service.CreateMemory("tenant123", req)
|
|
func (s *MemoryService) CreateMemory(tenantID string, req *CreateMemoryRequest) (*CreateMemoryResponse, error) {
|
|
// Ensure tenant model IDs are populated for LLM and embedding model parameters
|
|
// This automatically fills tenant_llm_id and tenant_embd_id based on llm_id and embd_id
|
|
tenantLLMService := NewTenantLLMService()
|
|
params := map[string]interface{}{
|
|
"llm_id": req.LLMID,
|
|
"embd_id": req.EmbdID,
|
|
}
|
|
params = tenantLLMService.EnsureTenantModelIDForParams(tenantID, params)
|
|
|
|
// Update request with tenant model IDs from the processed params
|
|
if tenantLLMID, ok := params["tenant_llm_id"].(int64); ok {
|
|
tenantLLMIDStr := strconv.FormatInt(tenantLLMID, 10)
|
|
req.TenantLLMID = &tenantLLMIDStr
|
|
}
|
|
if tenantEmbdID, ok := params["tenant_embd_id"].(int64); ok {
|
|
tenantEmbdIDStr := strconv.FormatInt(tenantEmbdID, 10)
|
|
req.TenantEmbdID = &tenantEmbdIDStr
|
|
}
|
|
|
|
if err := common.ValidateName(req.Name); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
memoryName := req.Name
|
|
|
|
if !isList(req.MemoryType) {
|
|
return nil, errors.New("memory type must be a list")
|
|
}
|
|
|
|
memoryTypeSet := make(map[string]bool)
|
|
for _, mt := range req.MemoryType {
|
|
lowerMT := strings.ToLower(mt)
|
|
if _, ok := dao.MemoryTypeMap[lowerMT]; !ok {
|
|
return nil, fmt.Errorf("memory type '%s' is not supported", mt)
|
|
}
|
|
memoryTypeSet[lowerMT] = true
|
|
}
|
|
uniqueMemoryTypes := make([]string, 0, len(memoryTypeSet))
|
|
for mt := range memoryTypeSet {
|
|
uniqueMemoryTypes = append(uniqueMemoryTypes, mt)
|
|
}
|
|
|
|
memoryName, err := common.DuplicateName(func(name string, tid string) bool {
|
|
existing, _ := s.memoryDAO.GetByNameAndTenant(name, tid)
|
|
return len(existing) > 0
|
|
}, memoryName, tenantID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
memoryTypeInt := dao.CalculateMemoryType(uniqueMemoryTypes)
|
|
systemPrompt := PromptAssembler{}.AssembleSystemPrompt(uniqueMemoryTypes)
|
|
|
|
newID := common.GenerateUUID()
|
|
|
|
memory := &entity.Memory{
|
|
ID: newID,
|
|
Name: memoryName,
|
|
TenantID: tenantID,
|
|
MemoryType: memoryTypeInt,
|
|
StorageType: "table",
|
|
EmbdID: req.EmbdID,
|
|
LLMID: req.LLMID,
|
|
Permissions: "me",
|
|
MemorySize: MemorySizeLimit,
|
|
ForgettingPolicy: string(ForgettingPolicyFIFO),
|
|
Temperature: 0.5,
|
|
SystemPrompt: &systemPrompt,
|
|
}
|
|
|
|
// Convert tenant model IDs from string to int64 for database
|
|
if req.TenantEmbdID != nil {
|
|
if embdID, err := strconv.ParseInt(*req.TenantEmbdID, 10, 64); err == nil {
|
|
memory.TenantEmbdID = &embdID
|
|
}
|
|
}
|
|
if req.TenantLLMID != nil {
|
|
if llmID, err := strconv.ParseInt(*req.TenantLLMID, 10, 64); err == nil {
|
|
memory.TenantLLMID = &llmID
|
|
}
|
|
}
|
|
if err := s.memoryDAO.Create(memory); err != nil {
|
|
return nil, errors.New("could not create new memory")
|
|
}
|
|
|
|
createdMemory, err := s.memoryDAO.GetByID(newID)
|
|
if err != nil {
|
|
return nil, errors.New("could not create new memory")
|
|
}
|
|
|
|
return formatRetDataFromMemory(createdMemory), nil
|
|
}
|
|
|
|
// UpdateMemory updates an existing memory with the provided fields
|
|
// Only the fields specified in the request will be updated (partial update)
|
|
//
|
|
// Parameters:
|
|
// - tenantID: The tenant ID for ownership verification
|
|
// - memoryID: The ID of the memory to update
|
|
// - req: The update request with optional fields to update
|
|
//
|
|
// Returns:
|
|
// - *CreateMemoryResponse: The updated memory details
|
|
// - error: Error if validation fails or update fails
|
|
//
|
|
// Example:
|
|
//
|
|
// req := &UpdateMemoryRequest{Name: ptr("NewName"), MemorySize: ptr(int64(1000000))}
|
|
// resp, err := service.UpdateMemory("tenant123", "memory456", req)
|
|
func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *UpdateMemoryRequest) (*CreateMemoryResponse, error) {
|
|
updateDict := make(map[string]interface{})
|
|
|
|
if req.Name != nil {
|
|
memoryName := strings.TrimSpace(*req.Name)
|
|
if err := common.ValidateName(memoryName); err != nil {
|
|
return nil, err
|
|
}
|
|
memoryName, err := common.DuplicateName(func(name string, tid string) bool {
|
|
existing, _ := s.memoryDAO.GetByNameAndTenant(name, tid)
|
|
return len(existing) > 0
|
|
}, memoryName, tenantID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
updateDict["name"] = memoryName
|
|
}
|
|
|
|
if req.Permissions != nil {
|
|
perm := TenantPermission(strings.ToLower(*req.Permissions))
|
|
if !validPermissions[perm] {
|
|
return nil, fmt.Errorf("unknown permission '%s'", *req.Permissions)
|
|
}
|
|
updateDict["permissions"] = perm
|
|
}
|
|
|
|
if req.LLMID != nil {
|
|
updateDict["llm_id"] = *req.LLMID
|
|
}
|
|
|
|
if req.EmbdID != nil {
|
|
updateDict["embd_id"] = *req.EmbdID
|
|
}
|
|
|
|
if req.TenantLLMID != nil {
|
|
if llmID, err := strconv.ParseInt(*req.TenantLLMID, 10, 64); err == nil {
|
|
updateDict["tenant_llm_id"] = llmID
|
|
}
|
|
}
|
|
|
|
if req.TenantEmbdID != nil {
|
|
if embdID, err := strconv.ParseInt(*req.TenantEmbdID, 10, 64); err == nil {
|
|
updateDict["tenant_embd_id"] = embdID
|
|
}
|
|
}
|
|
|
|
if req.MemoryType != nil && len(req.MemoryType) > 0 {
|
|
memoryTypeSet := make(map[string]bool)
|
|
for _, mt := range req.MemoryType {
|
|
lowerMT := strings.ToLower(mt)
|
|
if _, ok := dao.MemoryTypeMap[lowerMT]; !ok {
|
|
return nil, fmt.Errorf("memory type '%s' is not supported", mt)
|
|
}
|
|
memoryTypeSet[lowerMT] = true
|
|
}
|
|
uniqueMemoryTypes := make([]string, 0, len(memoryTypeSet))
|
|
for mt := range memoryTypeSet {
|
|
uniqueMemoryTypes = append(uniqueMemoryTypes, mt)
|
|
}
|
|
updateDict["memory_type"] = uniqueMemoryTypes
|
|
}
|
|
|
|
if req.MemorySize != nil {
|
|
memorySize := *req.MemorySize
|
|
if !(memorySize > 0 && memorySize <= MemorySizeLimit) {
|
|
return nil, fmt.Errorf("memory size should be in range (0, %d] Bytes", MemorySizeLimit)
|
|
}
|
|
updateDict["memory_size"] = memorySize
|
|
}
|
|
|
|
if req.ForgettingPolicy != nil {
|
|
fp := ForgettingPolicy(strings.ToUpper(strings.TrimSpace(*req.ForgettingPolicy)))
|
|
if !validForgettingPolicies[fp] {
|
|
return nil, fmt.Errorf("forgetting policy '%s' is not supported", *req.ForgettingPolicy)
|
|
}
|
|
updateDict["forgetting_policy"] = fp
|
|
}
|
|
|
|
if req.Temperature != nil {
|
|
temp := *req.Temperature
|
|
if !(temp >= 0 && temp <= 1) {
|
|
return nil, errors.New("temperature should be in range [0, 1]")
|
|
}
|
|
updateDict["temperature"] = temp
|
|
}
|
|
|
|
for _, field := range []string{"avatar", "description", "system_prompt", "user_prompt"} {
|
|
switch field {
|
|
case "avatar":
|
|
if req.Avatar != nil {
|
|
updateDict["avatar"] = *req.Avatar
|
|
}
|
|
case "description":
|
|
if req.Description != nil {
|
|
updateDict["description"] = *req.Description
|
|
}
|
|
case "system_prompt":
|
|
if req.SystemPrompt != nil {
|
|
updateDict["system_prompt"] = *req.SystemPrompt
|
|
}
|
|
case "user_prompt":
|
|
if req.UserPrompt != nil {
|
|
updateDict["user_prompt"] = *req.UserPrompt
|
|
}
|
|
}
|
|
}
|
|
|
|
currentMemory, err := s.memoryDAO.GetByID(memoryID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("memory '%s' not found", memoryID)
|
|
}
|
|
|
|
if len(updateDict) == 0 {
|
|
return formatRetDataFromMemory(currentMemory), nil
|
|
}
|
|
|
|
currentMemoryTypes := dao.GetMemoryTypeHuman(currentMemory.MemoryType)
|
|
normalizedCurrentMemoryTypes := normalizeMemoryTypes(currentMemoryTypes)
|
|
|
|
filteredUpdateDict := make(map[string]interface{}, len(updateDict))
|
|
for field, value := range updateDict {
|
|
switch field {
|
|
case "name":
|
|
currentName := strings.TrimSpace(currentMemory.Name)
|
|
requestName := strings.TrimSpace(fmt.Sprint(value))
|
|
if currentName != requestName {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "permissions":
|
|
currentPermissions := strings.ToLower(strings.TrimSpace(currentMemory.Permissions))
|
|
requestPermissions := strings.ToLower(strings.TrimSpace(fmt.Sprint(value)))
|
|
if currentPermissions != requestPermissions {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "llm_id":
|
|
currentLLMID := strings.TrimSpace(currentMemory.LLMID)
|
|
requestLLMID := strings.TrimSpace(fmt.Sprint(value))
|
|
if currentLLMID != requestLLMID {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "embd_id":
|
|
currentEmbdID := strings.TrimSpace(currentMemory.EmbdID)
|
|
requestEmbdID := strings.TrimSpace(fmt.Sprint(value))
|
|
if currentEmbdID != requestEmbdID {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "tenant_llm_id":
|
|
if currentMemory.TenantLLMID == nil || *currentMemory.TenantLLMID != value.(int64) {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "tenant_embd_id":
|
|
if currentMemory.TenantEmbdID == nil || *currentMemory.TenantEmbdID != value.(int64) {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "memory_type":
|
|
if types, ok := value.([]string); ok {
|
|
if !sameStringSet(normalizedCurrentMemoryTypes, normalizeMemoryTypes(types)) {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
} else {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "memory_size":
|
|
if currentMemory.MemorySize != value.(int64) {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "forgetting_policy":
|
|
currentForgettingPolicy := strings.ToUpper(strings.TrimSpace(currentMemory.ForgettingPolicy))
|
|
requestForgettingPolicy := strings.ToUpper(strings.TrimSpace(fmt.Sprint(value)))
|
|
if currentForgettingPolicy != requestForgettingPolicy {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "temperature":
|
|
if currentMemory.Temperature != value.(float64) {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "avatar":
|
|
currentAvatar := ""
|
|
if currentMemory.Avatar != nil {
|
|
currentAvatar = *currentMemory.Avatar
|
|
}
|
|
if currentAvatar != fmt.Sprint(value) {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "description":
|
|
currentDescription := ""
|
|
if currentMemory.Description != nil {
|
|
currentDescription = *currentMemory.Description
|
|
}
|
|
if currentDescription != fmt.Sprint(value) {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "system_prompt":
|
|
currentSystemPrompt := ""
|
|
if currentMemory.SystemPrompt != nil {
|
|
currentSystemPrompt = *currentMemory.SystemPrompt
|
|
}
|
|
if currentSystemPrompt != fmt.Sprint(value) {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
case "user_prompt":
|
|
currentUserPrompt := ""
|
|
if currentMemory.UserPrompt != nil {
|
|
currentUserPrompt = *currentMemory.UserPrompt
|
|
}
|
|
if currentUserPrompt != fmt.Sprint(value) {
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
default:
|
|
filteredUpdateDict[field] = value
|
|
}
|
|
}
|
|
updateDict = filteredUpdateDict
|
|
|
|
if len(updateDict) == 0 {
|
|
return formatRetDataFromMemory(currentMemory), nil
|
|
}
|
|
|
|
memorySize := currentMemory.MemorySize
|
|
notAllowedUpdate := []string{}
|
|
for _, f := range []string{"tenant_embd_id", "embd_id", "memory_type"} {
|
|
if _, ok := updateDict[f]; ok && memorySize > 0 {
|
|
notAllowedUpdate = append(notAllowedUpdate, f)
|
|
}
|
|
}
|
|
if len(notAllowedUpdate) > 0 {
|
|
return nil, fmt.Errorf("can't update %v when memory isn't empty", notAllowedUpdate)
|
|
}
|
|
|
|
if _, ok := updateDict["memory_type"]; ok {
|
|
if _, ok := updateDict["system_prompt"]; !ok {
|
|
memoryTypes := dao.GetMemoryTypeHuman(currentMemory.MemoryType)
|
|
if len(memoryTypes) > 0 && currentMemory.SystemPrompt != nil {
|
|
defaultPrompt := PromptAssembler{}.AssembleSystemPrompt(memoryTypes)
|
|
if *currentMemory.SystemPrompt == defaultPrompt {
|
|
if types, ok := updateDict["memory_type"].([]string); ok {
|
|
updateDict["system_prompt"] = PromptAssembler{}.AssembleSystemPrompt(types)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := s.memoryDAO.UpdateByID(memoryID, updateDict); err != nil {
|
|
return nil, errors.New("failed to update memory")
|
|
}
|
|
|
|
updatedMemory, err := s.memoryDAO.GetByID(memoryID)
|
|
if err != nil {
|
|
return nil, errors.New("failed to get updated memory")
|
|
}
|
|
|
|
return formatRetDataFromMemory(updatedMemory), nil
|
|
}
|
|
|
|
// normalizeMemoryTypes trims and lower-cases each memory type name, dropping
// empty names and repeats. The first occurrence of a name decides its position.
// The result is always a non-nil slice.
func normalizeMemoryTypes(memoryTypes []string) []string {
	out := make([]string, 0, len(memoryTypes))
	known := make(map[string]struct{}, len(memoryTypes))
	for i := range memoryTypes {
		name := strings.ToLower(strings.TrimSpace(memoryTypes[i]))
		if _, dup := known[name]; dup || name == "" {
			continue
		}
		known[name] = struct{}{}
		out = append(out, name)
	}
	return out
}
|
|
|
|
// sameStringSet reports whether a and b hold the same strings with the same
// multiplicities, regardless of order.
func sameStringSet(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	remaining := make(map[string]int, len(a))
	for _, s := range a {
		remaining[s]++
	}
	// With equal lengths, b matches a exactly when every element of b can
	// consume one outstanding occurrence from a.
	for _, s := range b {
		left := remaining[s]
		if left == 0 {
			return false
		}
		remaining[s] = left - 1
	}
	return true
}
|
|
|
|
// DeleteMemory deletes a memory by ID
|
|
// It also deletes associated message indexes before removing the memory record
|
|
//
|
|
// Parameters:
|
|
// - memoryID: The ID of the memory to delete
|
|
//
|
|
// Returns:
|
|
// - error: Error if memory not found or deletion fails
|
|
//
|
|
// Example:
|
|
//
|
|
// err := service.DeleteMemory("memory456")
|
|
func (s *MemoryService) DeleteMemory(memoryID string) error {
|
|
_, err := s.memoryDAO.GetByID(memoryID)
|
|
if err != nil {
|
|
return fmt.Errorf("memory '%s' not found", memoryID)
|
|
}
|
|
|
|
// TODO: Delete associated message index - Implementation pending MessageService
|
|
// messageService := NewMessageService()
|
|
// hasIndex, _ := messageService.HasIndex(memory.TenantID, memoryID)
|
|
// if hasIndex {
|
|
// messageService.DeleteMessage(nil, memory.TenantID, memoryID)
|
|
// }
|
|
|
|
// Delete memory record
|
|
if err := s.memoryDAO.DeleteByID(memoryID); err != nil {
|
|
return errors.New("failed to delete memory")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ForgetMessage marks a memory message as forgotten by setting forget_at.
|
|
// This mirrors Python memory_api_service.forget_message and keeps the message
|
|
// record for retention/cleanup policies instead of deleting it immediately.
|
|
func (s *MemoryService) ForgetMessage(ctx context.Context, userID string, memoryID string, messageID int64) error {
|
|
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if s.docEngine == nil {
|
|
return errors.New("message store is not initialized")
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
forgetTime := now.Format("2006-01-02 15:04:05")
|
|
messageDocID := fmt.Sprintf("%s_%d", memoryID, messageID)
|
|
updates := map[string]interface{}{
|
|
"forget_at": forgetTime,
|
|
"forget_at_flt": now.UnixMilli(),
|
|
}
|
|
condition := map[string]interface{}{
|
|
"id": messageDocID,
|
|
}
|
|
indexName := memoryIndexName(memory.TenantID)
|
|
|
|
if err := s.docEngine.UpdateChunks(ctx, condition, updates, indexName, memoryID); err != nil {
|
|
if isMessageDocumentNotFound(err) {
|
|
// Match Python delete-by-query behavior: forgetting an already-missing
|
|
// message document is idempotent and still considered successful.
|
|
return nil
|
|
}
|
|
return fmt.Errorf("failed to forget message '%d' in memory '%s': %w", messageID, memoryID, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddMessage filters inaccessible memories and queues the raw message for
|
|
// memory extraction. The currentUserID is used only for access control; msg.UserID
|
|
// is the attribution stored on the message and may differ for API-token callers.
|
|
func (s *MemoryService) AddMessage(ctx context.Context, currentUserID string, memoryIDs []string, msg MemoryMessage) (bool, string, error) {
|
|
requestedMemoryIDs := splitFilterValues(memoryIDs)
|
|
memories, err := s.filterAccessibleMemories(ctx, currentUserID, memoryIDs)
|
|
if err != nil {
|
|
return false, err.Error(), err
|
|
}
|
|
if len(memories) == 0 {
|
|
return false, "Memory not found.", nil
|
|
}
|
|
|
|
accessibleMemoryIDs := make([]string, 0, len(memories))
|
|
for _, memory := range memories {
|
|
if memory != nil {
|
|
accessibleMemoryIDs = append(accessibleMemoryIDs, memory.ID)
|
|
}
|
|
}
|
|
missingMemoryIDs := missingRequestedMemoryIDs(requestedMemoryIDs, accessibleMemoryIDs)
|
|
|
|
res, err := NewMemoryMessageService(s).QueueSaveToMemoryTask(ctx, accessibleMemoryIDs, msg)
|
|
if err != nil {
|
|
return false, err.Error(), err
|
|
}
|
|
|
|
if len(missingMemoryIDs) > 0 {
|
|
if res == nil {
|
|
res = &QueueSaveResult{}
|
|
}
|
|
res.NotFound = append(missingMemoryIDs, res.NotFound...)
|
|
}
|
|
errorMsg := memorySaveErrorMessage(res)
|
|
if errorMsg != "" {
|
|
return false, errorMsg, nil
|
|
}
|
|
return true, "All add to task.", nil
|
|
}
|
|
|
|
// missingRequestedMemoryIDs returns the requested IDs that are absent from
// accessibleMemoryIDs. IDs are compared after trimming whitespace, blank IDs
// are ignored, and each missing ID is reported once, in request order.
// The result is always a non-nil slice.
func missingRequestedMemoryIDs(requestedMemoryIDs, accessibleMemoryIDs []string) []string {
	missing := []string{}
	if len(requestedMemoryIDs) == 0 {
		return missing
	}

	// skip holds every ID that must not be reported: the accessible ones to
	// start with, plus each missing ID once it has been recorded.
	skip := make(map[string]struct{}, len(accessibleMemoryIDs))
	for _, id := range accessibleMemoryIDs {
		if id = strings.TrimSpace(id); id != "" {
			skip[id] = struct{}{}
		}
	}

	for _, id := range requestedMemoryIDs {
		id = strings.TrimSpace(id)
		if id == "" {
			continue
		}
		if _, done := skip[id]; done {
			continue
		}
		skip[id] = struct{}{}
		missing = append(missing, id)
	}
	return missing
}
|
|
|
|
func memorySaveErrorMessage(res *QueueSaveResult) string {
|
|
if res == nil {
|
|
return ""
|
|
}
|
|
|
|
var b strings.Builder
|
|
if len(res.NotFound) > 0 {
|
|
b.WriteString(fmt.Sprintf("Memory %v not found.", res.NotFound))
|
|
}
|
|
for _, failed := range res.Failed {
|
|
b.WriteString(fmt.Sprintf("Memory %s failed. Detail: %s", failed.MemoryID, failed.FailMsg))
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
func (s *MemoryService) UpdateMessageStatus(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) {
|
|
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if s.docEngine == nil {
|
|
return false, errors.New("message store is not initialized")
|
|
}
|
|
|
|
messageDocID := fmt.Sprintf("%s_%d", memoryID, messageID)
|
|
statusValue := 0
|
|
if status {
|
|
statusValue = 1
|
|
}
|
|
updates := map[string]interface{}{
|
|
"status": statusValue,
|
|
}
|
|
condition := map[string]interface{}{
|
|
"id": messageDocID,
|
|
}
|
|
indexName := memoryIndexName(memory.TenantID)
|
|
if err := s.docEngine.UpdateChunks(ctx, condition, updates, indexName, memoryID); err != nil {
|
|
if isMessageDocumentNotFound(err) {
|
|
return false, &ResourceNotFoundError{Resource: "Message", ID: messageDocID}
|
|
}
|
|
return false, fmt.Errorf("failed to set status for message '%d' in memory '%s': %w", messageID, memoryID, err)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) {
|
|
return s.UpdateMessageStatus(ctx, userID, memoryID, messageID, status)
|
|
}
|
|
|
|
func (s *MemoryService) GetMessageContent(ctx context.Context, userID, memoryID string, messageID int64) (map[string]interface{}, error) {
|
|
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if s.docEngine == nil {
|
|
return nil, errors.New("message store is not initialized")
|
|
}
|
|
|
|
indexName := memoryIndexName(memory.TenantID)
|
|
docID := fmt.Sprintf("%s_%d", memoryID, messageID)
|
|
res, err := s.docEngine.GetChunk(ctx, indexName, docID, []string{memoryID})
|
|
if err != nil {
|
|
if isMessageDocumentNotFound(err) {
|
|
return nil, &ResourceNotFoundError{Resource: "Message", ID: docID}
|
|
}
|
|
return nil, err
|
|
}
|
|
if res == nil {
|
|
return nil, &ResourceNotFoundError{Resource: "Message", ID: docID}
|
|
}
|
|
|
|
message, ok := res.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected message content type %T", res)
|
|
}
|
|
return common.ConvertFloatsToPyFormat(message).(map[string]interface{}), nil
|
|
}
|
|
|
|
func (s *MemoryService) SearchMessage(ctx context.Context, userID string, filterDict, params map[string]interface{}) ([]map[string]interface{}, common.ErrorCode, error) {
|
|
memoryIDs := splitFilterValues(filterDict["memory_id"])
|
|
memories, err := s.filterAccessibleMemories(ctx, userID, memoryIDs)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, err
|
|
}
|
|
if len(memories) == 0 {
|
|
return []map[string]interface{}{}, common.CodeSuccess, nil
|
|
}
|
|
|
|
return s.queryMessage(ctx, memories, filterDict, params)
|
|
}
|
|
|
|
func (s *MemoryService) queryMessage(ctx context.Context, memories []*entity.Memory, filterDict, params map[string]interface{}) ([]map[string]interface{}, common.ErrorCode, error) {
|
|
if s.docEngine == nil {
|
|
return nil, common.CodeServerError, errors.New("message store is not initialized")
|
|
}
|
|
|
|
topN := memoryIntParam(params["top_n"], 5)
|
|
if topN <= 0 {
|
|
topN = defaultMessageTopN
|
|
} else if topN > maxMessageLimit {
|
|
topN = maxMessageLimit
|
|
}
|
|
similarityThreshold := memoryFloatParam(params["similarity_threshold"], 0.2)
|
|
keywordsSimilarityWeight := memoryFloatParam(params["keywords_similarity_weight"], 0.7)
|
|
question := strings.TrimSpace(memoryStringParam(params["query"]))
|
|
|
|
memoryIDs := make([]string, 0, len(memories))
|
|
conditionDict := make(map[string]interface{})
|
|
for _, memory := range memories {
|
|
if memory == nil {
|
|
continue
|
|
}
|
|
memoryIDs = append(memoryIDs, memory.ID)
|
|
}
|
|
conditionDict["memory_id"] = memoryIDs
|
|
for _, key := range []string{"agent_id", "session_id", "user_id"} {
|
|
value := strings.TrimSpace(memoryStringParam(filterDict[key]))
|
|
if value != "" {
|
|
conditionDict[key] = value
|
|
}
|
|
}
|
|
if _, ok := conditionDict["status"]; !ok {
|
|
conditionDict["status"] = 1
|
|
}
|
|
|
|
matchExprs := make([]interface{}, 0, 3)
|
|
if question != "" {
|
|
matchText := memoryMessageTextExpr(question, similarityThreshold)
|
|
matchDense, err := s.memoryMessageDenseExpr(question, memories[0], topN, similarityThreshold)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, err
|
|
}
|
|
fusionExpr := &enginetypes.FusionExpr{
|
|
Method: "weighted_sum",
|
|
TopN: topN,
|
|
FusionParams: map[string]interface{}{
|
|
"weights": fmt.Sprintf("%g,%g", 1-keywordsSimilarityWeight, keywordsSimilarityWeight),
|
|
},
|
|
}
|
|
matchExprs = append(matchExprs, matchText, matchDense, fusionExpr)
|
|
}
|
|
|
|
searchReq := &enginetypes.SearchRequest{
|
|
IndexNames: memorySearchIndexNames(memories),
|
|
Offset: 0,
|
|
Limit: topN,
|
|
SelectFields: memoryMessageSelectFields(),
|
|
Filter: conditionDict,
|
|
MatchExprs: matchExprs,
|
|
OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"),
|
|
}
|
|
|
|
searchResult, err := s.docEngine.Search(ctx, searchReq)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, err
|
|
}
|
|
if searchResult == nil || searchResult.Total == 0 {
|
|
return []map[string]interface{}{}, common.CodeSuccess, nil
|
|
}
|
|
|
|
messages := make([]map[string]interface{}, 0, len(searchResult.Chunks))
|
|
for _, chunk := range searchResult.Chunks {
|
|
message := make(map[string]interface{}, len(chunk))
|
|
for _, field := range memoryMessageSelectFields() {
|
|
if value, ok := chunk[field]; ok {
|
|
message[field] = value
|
|
}
|
|
}
|
|
messages = append(messages, message)
|
|
}
|
|
return common.ConvertFloatsToPyFormat(messages).([]map[string]interface{}), common.CodeSuccess, nil
|
|
}
|
|
|
|
func (s *MemoryService) filterAccessibleMemories(ctx context.Context, userID string, memoryIDs []string) ([]*entity.Memory, error) {
|
|
memoryIDs = splitFilterValues(memoryIDs)
|
|
if len(memoryIDs) == 0 {
|
|
return []*entity.Memory{}, nil
|
|
}
|
|
|
|
memories, err := s.memoryDAO.GetByIDs(memoryIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(memories) == 0 {
|
|
return []*entity.Memory{}, nil
|
|
}
|
|
|
|
joinedTenantIDs := map[string]struct{}{userID: {}}
|
|
needsTeamLookup := false
|
|
for _, memory := range memories {
|
|
if memory != nil && memory.TenantID != userID && memory.Permissions == string(TenantPermissionTeam) {
|
|
needsTeamLookup = true
|
|
break
|
|
}
|
|
}
|
|
if needsTeamLookup {
|
|
userTenants, err := NewUserTenantService().GetUserTenantRelationByUserIDWithContext(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, tenant := range userTenants {
|
|
if tenant != nil && tenant.TenantID != "" {
|
|
joinedTenantIDs[tenant.TenantID] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
|
|
accessible := make([]*entity.Memory, 0, len(memories))
|
|
for _, memory := range memories {
|
|
if memory == nil {
|
|
continue
|
|
}
|
|
if memory.TenantID == userID {
|
|
accessible = append(accessible, memory)
|
|
continue
|
|
}
|
|
if memory.Permissions != string(TenantPermissionTeam) {
|
|
continue
|
|
}
|
|
if _, ok := joinedTenantIDs[memory.TenantID]; ok {
|
|
accessible = append(accessible, memory)
|
|
}
|
|
}
|
|
return accessible, nil
|
|
}
|
|
|
|
// splitFilterValues flattens a filter value into a list of trimmed, non-empty
// strings.
//
// Accepted inputs are a string, a []string, or a []interface{} (whose
// non-string elements are dropped). Every string is split on commas and each
// piece is whitespace-trimmed; blank pieces are discarded. Any other input,
// including nil, yields an empty slice. The result is never nil.
func splitFilterValues(values interface{}) []string {
	var raw []string
	switch typed := values.(type) {
	case string:
		raw = []string{typed}
	case []string:
		raw = typed
	case []interface{}:
		for _, elem := range typed {
			if text, isString := elem.(string); isString {
				raw = append(raw, text)
			}
		}
	default:
		// Covers nil as well as every unsupported type.
		return []string{}
	}

	out := make([]string, 0)
	for _, entry := range raw {
		for _, piece := range strings.Split(entry, ",") {
			if trimmed := strings.TrimSpace(piece); trimmed != "" {
				out = append(out, trimmed)
			}
		}
	}
	return out
}
|
|
|
|
// memoryStringParam converts a loosely typed parameter to a string.
//
// Strings are returned as-is, fmt.Stringer values through their String method,
// nil as "", and everything else through the "%v" formatting verb.
func memoryStringParam(value interface{}) string {
	if value == nil {
		return ""
	}
	if text, isString := value.(string); isString {
		return text
	}
	if stringer, hasString := value.(fmt.Stringer); hasString {
		return stringer.String()
	}
	return fmt.Sprintf("%v", value)
}
|
|
|
|
// memoryFloatParam converts a loosely typed parameter to a float64.
//
// float64, float32, int and int64 values are converted directly. Strings are
// whitespace-trimmed and parsed with strconv.ParseFloat. fallback is returned
// for unparsable strings and for every other type, including nil.
func memoryFloatParam(value interface{}, fallback float64) float64 {
	if text, isString := value.(string); isString {
		number, err := strconv.ParseFloat(strings.TrimSpace(text), 64)
		if err != nil {
			return fallback
		}
		return number
	}

	switch number := value.(type) {
	case float64:
		return number
	case float32:
		return float64(number)
	case int:
		return float64(number)
	case int64:
		return float64(number)
	default:
		return fallback
	}
}
|
|
|
|
// memoryIntParam converts a loosely typed parameter to an int.
//
// int and int64 values are converted directly and float64 values are
// truncated by int conversion. Strings are whitespace-trimmed and parsed with
// strconv.Atoi. fallback is returned for unparsable strings and for every
// other type, including nil.
func memoryIntParam(value interface{}, fallback int) int {
	if text, isString := value.(string); isString {
		number, err := strconv.Atoi(strings.TrimSpace(text))
		if err != nil {
			return fallback
		}
		return number
	}

	switch number := value.(type) {
	case int:
		return number
	case int64:
		return int(number)
	case float64:
		return int(number)
	default:
		return fallback
	}
}
|
|
|
|
// memoryMessageSelectFields returns a fresh slice with the field names that
// message searches select, including the message content.
func memoryMessageSelectFields() []string {
	fields := []string{
		"message_id",
		"message_type",
		"source_id",
		"memory_id",
		"user_id",
		"agent_id",
		"session_id",
		"valid_at",
		"invalid_at",
		"forget_at",
		"status",
		"content",
	}
	return fields
}
|
|
|
|
// memoryIndexName builds the index name for a tenant's memory messages.
//
// The name is "memory_<tenantID>", or "memory_<prefix>_<tenantID>" when the
// ES_INDEX_PREFIX environment variable holds a non-blank value (surrounding
// whitespace of the prefix is trimmed).
func memoryIndexName(tenantID string) string {
	switch prefix := strings.TrimSpace(os.Getenv("ES_INDEX_PREFIX")); prefix {
	case "":
		return fmt.Sprintf("memory_%s", tenantID)
	default:
		return fmt.Sprintf("memory_%s_%s", prefix, tenantID)
	}
}
|
|
|
|
func memorySearchIndexNames(memories []*entity.Memory) []string {
|
|
seen := make(map[string]struct{}, len(memories))
|
|
indexNames := make([]string, 0, len(memories))
|
|
for _, memory := range memories {
|
|
if memory == nil {
|
|
continue
|
|
}
|
|
indexName := memoryIndexName(memory.TenantID)
|
|
if engine.GetEngineType() == engine.EngineInfinity {
|
|
indexName = fmt.Sprintf("%s_%s", indexName, memory.ID)
|
|
}
|
|
if _, ok := seen[indexName]; ok {
|
|
continue
|
|
}
|
|
seen[indexName] = struct{}{}
|
|
indexNames = append(indexNames, indexName)
|
|
}
|
|
return indexNames
|
|
}
|
|
|
|
func memoryMessageTextExpr(question string, similarityThreshold float64) *enginetypes.MatchTextExpr {
|
|
matchText := &enginetypes.MatchTextExpr{
|
|
Fields: []string{"content"},
|
|
MatchingText: question,
|
|
TopN: 100,
|
|
ExtraOptions: map[string]interface{}{"original_query": question},
|
|
}
|
|
|
|
queryBuilder := nlp.GetQueryBuilder()
|
|
if queryBuilder == nil {
|
|
queryBuilder = nlp.NewQueryBuilder()
|
|
}
|
|
if built, _ := queryBuilder.Question(question, "messages", similarityThreshold); built != nil {
|
|
matchText.MatchingText = built.MatchingText
|
|
matchText.ExtraOptions = built.ExtraOptions
|
|
if matchText.ExtraOptions == nil {
|
|
matchText.ExtraOptions = map[string]interface{}{}
|
|
}
|
|
matchText.ExtraOptions["original_query"] = question
|
|
}
|
|
matchText.Fields = []string{"content"}
|
|
matchText.TopN = 100
|
|
return matchText
|
|
}
|
|
|
|
func (s *MemoryService) memoryMessageDenseExpr(question string, memory *entity.Memory, topN int, similarityThreshold float64) (*enginetypes.MatchDenseExpr, error) {
|
|
driver, modelName, apiConfig, maxTokens, err := NewModelProviderService().GetModelConfigFromProviderInstance(memory.TenantID, entity.ModelTypeEmbedding, memory.EmbdID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
embeddingModel := models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
|
embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{question}, embeddingModel.APIConfig, &models.EmbeddingConfig{Dimension: 0})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 {
|
|
return nil, errors.New("embedding response is empty")
|
|
}
|
|
|
|
vector := embeddings[0].Embedding
|
|
return &enginetypes.MatchDenseExpr{
|
|
VectorColumnName: fmt.Sprintf("q_%d_vec", len(vector)),
|
|
EmbeddingData: vector,
|
|
EmbeddingDataType: "float",
|
|
DistanceType: "cosine",
|
|
TopN: topN,
|
|
ExtraOptions: map[string]interface{}{"similarity": similarityThreshold},
|
|
}, nil
|
|
}
|
|
|
|
func (s *MemoryService) GetMessages(ctx context.Context, memoryIDs []string, userID, agentID, sessionID string, limit int) ([]map[string]interface{}, common.ErrorCode, error) {
|
|
memories, err := s.filterAccessibleMemories(ctx, userID, memoryIDs)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, err
|
|
}
|
|
if len(memories) == 0 {
|
|
return []map[string]interface{}{}, common.CodeSuccess, nil
|
|
}
|
|
return s.getRecentMessage(ctx, memories, agentID, sessionID, limit)
|
|
}
|
|
|
|
func (s *MemoryService) getRecentMessage(ctx context.Context, memories []*entity.Memory, agentID, sessionID string, limit int) ([]map[string]interface{}, common.ErrorCode, error) {
|
|
if s.docEngine == nil {
|
|
return nil, common.CodeServerError, errors.New("doc engine is nil")
|
|
}
|
|
if limit <= 0 {
|
|
limit = defaultMessageLimit
|
|
} else if limit > maxMessageLimit {
|
|
limit = maxMessageLimit
|
|
}
|
|
indexNames := memorySearchIndexNames(memories)
|
|
memoryIDs := make([]string, 0, len(memories))
|
|
for _, memory := range memories {
|
|
if memory == nil || strings.TrimSpace(memory.ID) == "" {
|
|
continue
|
|
}
|
|
memoryIDs = append(memoryIDs, memory.ID)
|
|
}
|
|
|
|
conditionDict := map[string]interface{}{"memory_id": memoryIDs}
|
|
if agentID = strings.TrimSpace(agentID); agentID != "" {
|
|
conditionDict["agent_id"] = agentID
|
|
}
|
|
if sessionID = strings.TrimSpace(sessionID); sessionID != "" {
|
|
conditionDict["session_id"] = sessionID
|
|
}
|
|
req := &enginetypes.SearchRequest{
|
|
IndexNames: indexNames,
|
|
Offset: 0,
|
|
Limit: limit,
|
|
SelectFields: memoryMessageSelectFields(),
|
|
Filter: conditionDict,
|
|
MatchExprs: []interface{}{},
|
|
OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"),
|
|
}
|
|
|
|
result, err := s.docEngine.Search(ctx, req)
|
|
if err != nil {
|
|
return nil, common.CodeServerError, err
|
|
}
|
|
if result == nil || result.Total == 0 {
|
|
return []map[string]interface{}{}, common.CodeSuccess, nil
|
|
}
|
|
|
|
messages := make([]map[string]interface{}, 0, len(result.Chunks))
|
|
for _, chunk := range result.Chunks {
|
|
msg := make(map[string]interface{}, len(chunk))
|
|
for _, field := range memoryMessageSelectFields() {
|
|
if val, ok := chunk[field]; ok {
|
|
msg[field] = val
|
|
}
|
|
}
|
|
messages = append(messages, msg)
|
|
}
|
|
return common.ConvertFloatsToPyFormat(messages).([]map[string]interface{}), common.CodeSuccess, nil
|
|
}
|
|
|
|
func isMessageDocumentNotFound(err error) bool {
|
|
return errors.Is(err, enginetypes.ErrDocumentNotFound)
|
|
}
|
|
|
|
func (s *MemoryService) requireMemoryAccess(ctx context.Context, userID string, memoryID string) (*entity.Memory, error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
memory, err := s.memoryDAO.GetByIDWithContext(ctx, memoryID)
|
|
if err != nil {
|
|
if dao.IsNotFoundErr(err) {
|
|
return nil, &ResourceNotFoundError{Resource: "Memory", ID: memoryID}
|
|
}
|
|
return nil, fmt.Errorf("failed to get memory '%s': %w", memoryID, err)
|
|
}
|
|
if memory.TenantID == userID {
|
|
return memory, nil
|
|
}
|
|
if memory.Permissions != string(TenantPermissionTeam) {
|
|
return nil, &ResourceNotFoundError{Resource: "Memory", ID: memoryID}
|
|
}
|
|
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
userTenantService := NewUserTenantService()
|
|
userTenants, err := userTenantService.GetUserTenantRelationByUserIDWithContext(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, tenant := range userTenants {
|
|
if tenant.TenantID == memory.TenantID {
|
|
return memory, nil
|
|
}
|
|
}
|
|
|
|
return nil, &ResourceNotFoundError{Resource: "Memory", ID: memoryID}
|
|
}
|
|
|
|
// ListMemories retrieves a paginated list of memories with optional filters
|
|
// When tenantIDs is empty, it retrieves all tenants associated with the user
|
|
//
|
|
// Parameters:
|
|
// - userID: The user ID for tenant filtering when tenantIDs is empty
|
|
// - tenantIDs: Array of tenant IDs to filter by (empty means all user's tenants)
|
|
// - memoryTypes: Array of memory type names to filter by (empty means all types)
|
|
// - storageType: Storage type to filter by (empty means all types)
|
|
// - keywords: Keywords to search in memory names (empty means no keyword filter)
|
|
// - page: Page number (1-based)
|
|
// - pageSize: Number of items per page
|
|
//
|
|
// Returns:
|
|
// - *ListMemoryResponse: Contains memory list and total count
|
|
// - error: Error if query fails
|
|
//
|
|
// Example:
|
|
//
|
|
// resp, err := service.ListMemories("user123", []string{}, []string{"semantic"}, "table", "test", 1, 10)
|
|
func (s *MemoryService) ListMemories(userID string, tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) (*ListMemoryResponse, error) {
|
|
// If tenantIDs is empty, get all tenants associated with the user
|
|
if len(tenantIDs) == 0 {
|
|
userTenantService := NewUserTenantService()
|
|
userTenants, err := userTenantService.GetUserTenantRelationByUserID(userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user tenants: %w", err)
|
|
}
|
|
tenantIDs = make([]string, len(userTenants))
|
|
for i, tenant := range userTenants {
|
|
tenantIDs[i] = tenant.TenantID
|
|
}
|
|
}
|
|
|
|
memories, total, err := s.memoryDAO.GetByFilter(tenantIDs, memoryTypes, storageType, keywords, page, pageSize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
memoryList := make([]map[string]interface{}, 0, len(memories))
|
|
for _, m := range memories {
|
|
resp := formatRetDataFromMemoryListItem(m)
|
|
var createDateStr *string
|
|
if resp.CreateTime != nil {
|
|
createDateStr = formatDateToString(*resp.CreateTime)
|
|
}
|
|
memoryMap := map[string]interface{}{
|
|
"id": resp.ID,
|
|
"name": resp.Name,
|
|
"avatar": resp.Avatar,
|
|
"tenant_id": resp.TenantID,
|
|
"owner_name": resp.OwnerName,
|
|
"memory_type": resp.MemoryType,
|
|
"storage_type": resp.StorageType,
|
|
"permissions": resp.Permissions,
|
|
"description": resp.Description,
|
|
"create_time": resp.CreateTime,
|
|
"create_date": createDateStr,
|
|
}
|
|
memoryList = append(memoryList, memoryMap)
|
|
}
|
|
|
|
return &ListMemoryResponse{
|
|
MemoryList: memoryList,
|
|
TotalCount: total,
|
|
}, nil
|
|
}
|
|
|
|
// GetMemoryConfig retrieves the full configuration of a memory by ID
|
|
//
|
|
// Parameters:
|
|
// - memoryID: The ID of the memory to retrieve
|
|
//
|
|
// Returns:
|
|
// - *CreateMemoryResponse: The memory configuration details
|
|
// - error: Error if memory not found
|
|
//
|
|
// Example:
|
|
//
|
|
// resp, err := service.GetMemoryConfig("memory456")
|
|
func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse, error) {
|
|
memory, err := s.memoryDAO.GetWithOwnerNameByID(memoryID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("memory '%s' not found", memoryID)
|
|
}
|
|
return formatRetDataFromMemoryListItem(memory), nil
|
|
}
|
|
|
|
func (s *MemoryService) GetMemoryMessages(ctx context.Context, userID, memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) {
|
|
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
messages, err := s.listMemoryMessages(ctx, memory, agentIDs, keywords, page, pageSize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rawMessages, _ := messages["message_list"].([]map[string]interface{})
|
|
agentNames := map[string]string{}
|
|
tasks := map[string]map[string]interface{}{}
|
|
if len(rawMessages) > 0 {
|
|
agentNames, err = s.memoryMessageAgentNames(rawMessages)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tasks, err = s.memoryMessageTasks(memoryID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
for _, message := range rawMessages {
|
|
agentID, _ := message["agent_id"].(string)
|
|
message["agent_name"] = "Unknown"
|
|
if name, ok := agentNames[agentID]; ok {
|
|
message["agent_name"] = name
|
|
}
|
|
message["task"] = map[string]interface{}{}
|
|
if task, ok := tasks[memoryMessageKey(message["message_id"])]; ok {
|
|
message["task"] = task
|
|
}
|
|
if extracts, ok := message["extract"].([]map[string]interface{}); ok {
|
|
for _, extract := range extracts {
|
|
extractAgentID, _ := extract["agent_id"].(string)
|
|
extract["agent_name"] = "Unknown"
|
|
if name, ok := agentNames[extractAgentID]; ok {
|
|
extract["agent_name"] = name
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return common.ConvertFloatsToPyFormat(map[string]interface{}{
|
|
"messages": messages,
|
|
"storage_type": memory.StorageType,
|
|
}).(map[string]interface{}), nil
|
|
}
|
|
|
|
func (s *MemoryService) listMemoryMessages(ctx context.Context, memory *entity.Memory, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) {
|
|
if s.docEngine == nil {
|
|
return nil, errors.New("message store is not initialized")
|
|
}
|
|
if page <= 0 {
|
|
page = 1
|
|
}
|
|
if pageSize <= 0 {
|
|
pageSize = defaultMessageLimit
|
|
} else if pageSize > maxMessageLimit {
|
|
pageSize = maxMessageLimit
|
|
}
|
|
|
|
memoryID := memory.ID
|
|
selectFields := memoryMessageListFields()
|
|
filter := map[string]interface{}{
|
|
"message_type": "raw",
|
|
}
|
|
if len(agentIDs) > 0 {
|
|
filter["agent_id"] = agentIDs
|
|
}
|
|
if keywords = strings.TrimSpace(keywords); keywords != "" {
|
|
filter["session_id"] = keywords
|
|
}
|
|
filter["memory_id"] = []string{memoryID}
|
|
indexNames := memorySearchIndexNames([]*entity.Memory{memory})
|
|
|
|
rawReq := &enginetypes.SearchRequest{
|
|
IndexNames: indexNames,
|
|
Offset: (page - 1) * pageSize,
|
|
Limit: pageSize,
|
|
SelectFields: selectFields,
|
|
Filter: filter,
|
|
MatchExprs: []interface{}{},
|
|
OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"),
|
|
}
|
|
rawResult, err := s.docEngine.Search(ctx, rawReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
messages := map[string]interface{}{
|
|
"message_list": []map[string]interface{}{},
|
|
"total_count": int64(0),
|
|
}
|
|
if rawResult != nil {
|
|
messages["total_count"] = rawResult.Total
|
|
}
|
|
if rawResult == nil || rawResult.Total == 0 {
|
|
return messages, nil
|
|
}
|
|
|
|
rawMessages := make([]map[string]interface{}, 0, len(rawResult.Chunks))
|
|
sourceIDs := make([]interface{}, 0, len(rawResult.Chunks))
|
|
for _, chunk := range rawResult.Chunks {
|
|
message := memoryMessageFromChunk(chunk, selectFields)
|
|
message["extract"] = []map[string]interface{}{}
|
|
if messageID, ok := message["message_id"]; ok {
|
|
sourceIDs = append(sourceIDs, messageID)
|
|
}
|
|
rawMessages = append(rawMessages, message)
|
|
}
|
|
|
|
if len(sourceIDs) > 0 {
|
|
extractReq := &enginetypes.SearchRequest{
|
|
IndexNames: indexNames,
|
|
Offset: 0,
|
|
Limit: 512,
|
|
SelectFields: selectFields,
|
|
Filter: map[string]interface{}{
|
|
"memory_id": []string{memoryID},
|
|
"source_id": sourceIDs,
|
|
},
|
|
MatchExprs: []interface{}{},
|
|
OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"),
|
|
}
|
|
extractResult, err := s.docEngine.Search(ctx, extractReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if extractResult != nil && extractResult.Total > 0 {
|
|
groupedExtracts := make(map[string][]map[string]interface{})
|
|
for _, chunk := range extractResult.Chunks {
|
|
message := memoryMessageFromChunk(chunk, selectFields)
|
|
sourceID := memoryMessageKey(message["source_id"])
|
|
groupedExtracts[sourceID] = append(groupedExtracts[sourceID], message)
|
|
}
|
|
for _, message := range rawMessages {
|
|
messageID := memoryMessageKey(message["message_id"])
|
|
if extracts, ok := groupedExtracts[messageID]; ok {
|
|
message["extract"] = extracts
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
messages["message_list"] = rawMessages
|
|
return messages, nil
|
|
}
|
|
|
|
// memoryMessageListFields returns a fresh slice with the field names selected
// when listing messages; unlike memoryMessageSelectFields it omits "content".
func memoryMessageListFields() []string {
	fields := []string{
		"message_id",
		"message_type",
		"source_id",
		"memory_id",
		"user_id",
		"agent_id",
		"session_id",
		"valid_at",
		"invalid_at",
		"forget_at",
		"status",
	}
	return fields
}
|
|
|
|
// memoryMessageFromChunk copies the listed fields from chunk into a new map.
// Fields that are absent from chunk are left out; fields present with a nil
// value are copied.
func memoryMessageFromChunk(chunk map[string]interface{}, fields []string) map[string]interface{} {
	picked := make(map[string]interface{}, len(fields))
	for i := range fields {
		name := fields[i]
		value, present := chunk[name]
		if !present {
			continue
		}
		picked[name] = value
	}
	return picked
}
|
|
|
|
func (s *MemoryService) memoryMessageAgentNames(messages []map[string]interface{}) (map[string]string, error) {
|
|
agentIDSet := make(map[string]struct{})
|
|
for _, message := range messages {
|
|
agentID, _ := message["agent_id"].(string)
|
|
if agentID != "" {
|
|
agentIDSet[agentID] = struct{}{}
|
|
}
|
|
}
|
|
if len(agentIDSet) == 0 {
|
|
return map[string]string{}, nil
|
|
}
|
|
|
|
agentIDs := make([]string, 0, len(agentIDSet))
|
|
for agentID := range agentIDSet {
|
|
agentIDs = append(agentIDs, agentID)
|
|
}
|
|
|
|
var agents []struct {
|
|
ID string `gorm:"column:id"`
|
|
Title *string `gorm:"column:title"`
|
|
}
|
|
if err := dao.DB.Model(&entity.UserCanvas{}).Select("id, title").Where("id IN ?", agentIDs).Scan(&agents).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
agentNames := make(map[string]string, len(agents))
|
|
for _, agent := range agents {
|
|
if agent.Title != nil {
|
|
agentNames[agent.ID] = *agent.Title
|
|
}
|
|
}
|
|
return agentNames, nil
|
|
}
|
|
|
|
func (s *MemoryService) memoryMessageTasks(memoryID string) (map[string]map[string]interface{}, error) {
|
|
var tasks []struct {
|
|
ID string `gorm:"column:id"`
|
|
DocID string `gorm:"column:doc_id"`
|
|
FromPage int64 `gorm:"column:from_page"`
|
|
Progress float64 `gorm:"column:progress"`
|
|
ProgressMsg *string `gorm:"column:progress_msg"`
|
|
Digest *string `gorm:"column:digest"`
|
|
ChunkIDs *string `gorm:"column:chunk_ids"`
|
|
CreateTime *int64 `gorm:"column:create_time"`
|
|
}
|
|
if err := dao.DB.Model(&entity.Task{}).
|
|
Select("id, doc_id, from_page, progress, progress_msg, digest, chunk_ids, create_time").
|
|
Where("doc_id IN ?", []string{memoryID}).
|
|
Order("create_time ASC").
|
|
Scan(&tasks).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
taskByMessageID := make(map[string]map[string]interface{}, len(tasks))
|
|
for _, task := range tasks {
|
|
if task.Digest == nil {
|
|
continue
|
|
}
|
|
digest := strings.TrimSpace(*task.Digest)
|
|
if digest == "" {
|
|
continue
|
|
}
|
|
var progressMsg interface{}
|
|
if task.ProgressMsg != nil {
|
|
progressMsg = *task.ProgressMsg
|
|
}
|
|
var chunkIDs interface{}
|
|
if task.ChunkIDs != nil {
|
|
chunkIDs = *task.ChunkIDs
|
|
}
|
|
var createTime interface{}
|
|
if task.CreateTime != nil {
|
|
createTime = *task.CreateTime
|
|
}
|
|
taskMap := map[string]interface{}{
|
|
"id": task.ID,
|
|
"doc_id": task.DocID,
|
|
"from_page": task.FromPage,
|
|
"progress": task.Progress,
|
|
"progress_msg": progressMsg,
|
|
"digest": digest,
|
|
"chunk_ids": chunkIDs,
|
|
"create_time": createTime,
|
|
}
|
|
taskByMessageID[digest] = taskMap
|
|
}
|
|
return taskByMessageID, nil
|
|
}
|
|
|
|
// memoryMessageKey renders value with fmt.Sprint and trims surrounding
// whitespace, giving a string key that is comparable across value types.
func memoryMessageKey(value interface{}) string {
	rendered := fmt.Sprint(value)
	return strings.TrimSpace(rendered)
}
|
|
|
|
// isList reports whether v holds a []interface{} or a []string.
//
// Parameters:
//   - v: the value to inspect
//
// Returns:
//   - bool: true for []interface{} and []string (including nil slices of those
//     types), false for everything else
//
// Example:
//
//	isList([]string{"a", "b"}) returns true
//	isList("test") returns false
func isList(v interface{}) bool {
	if _, generic := v.([]interface{}); generic {
		return true
	}
	_, strs := v.([]string)
	return strs
}
|
|
|
|
// formatRetDataFromMemory converts a Memory model to CreateMemoryResponse format
|
|
// This is a utility function for formatting memory data for API responses
|
|
//
|
|
// Parameters:
|
|
// - memory: The Memory model to format
|
|
//
|
|
// Returns:
|
|
// - *CreateMemoryResponse: Formatted memory response with human-readable types and dates
|
|
//
|
|
// Example:
|
|
//
|
|
// resp := formatRetDataFromMemory(memoryModel)
|
|
func formatRetDataFromMemory(memory *entity.Memory) *CreateMemoryResponse {
|
|
memoryTypes := dao.GetMemoryTypeHuman(memory.MemoryType)
|
|
|
|
resp := &CreateMemoryResponse{
|
|
Memory: *memory,
|
|
OwnerName: nil,
|
|
MemoryType: memoryTypes,
|
|
}
|
|
return resp
|
|
}
|
|
|
|
// formatDateToString formats a Unix timestamp as "2006-01-02 15:04:05" in the
// local time zone.
//
// A zero timestamp yields nil. Values greater than 1e10 are treated as
// milliseconds and divided by 1000 first; smaller values are taken as seconds.
func formatDateToString(t int64) *string {
	if t == 0 {
		return nil
	}

	seconds := t
	if seconds > 1e10 {
		// Too large for a seconds value: the timestamp is in milliseconds.
		seconds /= 1000
	}
	formatted := time.Unix(seconds, 0).Format("2006-01-02 15:04:05")
	return &formatted
}
|
|
|
|
// formatRetDataFromMemoryListItem converts a MemoryListItem to CreateMemoryResponse
|
|
// This function is used for both list and detail memory responses where owner_name is from JOIN query
|
|
//
|
|
// Parameters:
|
|
// - memory: MemoryListItem pointer with owner_name from JOIN
|
|
//
|
|
// Returns:
|
|
// - *CreateMemoryResponse: Formatted response with owner_name populated
|
|
//
|
|
// Example:
|
|
//
|
|
// resp := formatRetDataFromMemoryListItem(memoryItem)
|
|
func formatRetDataFromMemoryListItem(memory *entity.MemoryListItem) *CreateMemoryResponse {
|
|
memoryTypes := dao.GetMemoryTypeHuman(memory.MemoryType)
|
|
resp := &CreateMemoryResponse{
|
|
Memory: memory.Memory,
|
|
OwnerName: memory.OwnerName,
|
|
MemoryType: memoryTypes,
|
|
}
|
|
return resp
|
|
}
|