From 32902570141c8da9976b3adbbeb5f0781ccdf887 Mon Sep 17 00:00:00 2001 From: buua436 Date: Mon, 18 May 2026 19:21:47 +0800 Subject: [PATCH] Go: fix forgetting policy validation and fix memory update diff checks (#14976) ### What problem does this PR solve? fix forgetting policy validation and fix memory update diff checks ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- internal/service/memory.go | 143 ++++++++++++++++++++++++++++++++++++- 1 file changed, 142 insertions(+), 1 deletion(-) diff --git a/internal/service/memory.go b/internal/service/memory.go index 882282f4cc..93face56a9 100644 --- a/internal/service/memory.go +++ b/internal/service/memory.go @@ -499,7 +499,7 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda } if req.ForgettingPolicy != nil { - fp := ForgettingPolicy(strings.ToLower(*req.ForgettingPolicy)) + fp := ForgettingPolicy(strings.ToUpper(strings.TrimSpace(*req.ForgettingPolicy))) if !validForgettingPolicies[fp] { return nil, fmt.Errorf("forgetting policy '%s' is not supported", *req.ForgettingPolicy) } @@ -544,6 +544,108 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda return formatRetDataFromMemory(currentMemory), nil } + currentMemoryTypes := dao.GetMemoryTypeHuman(currentMemory.MemoryType) + normalizedCurrentMemoryTypes := normalizeMemoryTypes(currentMemoryTypes) + + filteredUpdateDict := make(map[string]interface{}, len(updateDict)) + for field, value := range updateDict { + switch field { + case "name": + currentName := strings.TrimSpace(currentMemory.Name) + requestName := strings.TrimSpace(fmt.Sprint(value)) + if currentName != requestName { + filteredUpdateDict[field] = value + } + case "permissions": + currentPermissions := strings.ToLower(strings.TrimSpace(currentMemory.Permissions)) + requestPermissions := strings.ToLower(strings.TrimSpace(fmt.Sprint(value))) + if currentPermissions != requestPermissions { + filteredUpdateDict[field] = value + } + case "llm_id": + currentLLMID := strings.TrimSpace(currentMemory.LLMID) + requestLLMID := strings.TrimSpace(fmt.Sprint(value)) + if currentLLMID != requestLLMID { + filteredUpdateDict[field] = value + } + case "embd_id": + currentEmbdID := strings.TrimSpace(currentMemory.EmbdID) + requestEmbdID := strings.TrimSpace(fmt.Sprint(value)) + if currentEmbdID != requestEmbdID { + filteredUpdateDict[field] = value + } + case "tenant_llm_id": + if currentMemory.TenantLLMID == nil || *currentMemory.TenantLLMID != value.(int64) { + filteredUpdateDict[field] = value + } + case "tenant_embd_id": + if currentMemory.TenantEmbdID == nil || *currentMemory.TenantEmbdID != value.(int64) { + filteredUpdateDict[field] = value + } + case "memory_type": + if types, ok := value.([]string); ok { + if !sameStringSet(normalizedCurrentMemoryTypes, normalizeMemoryTypes(types)) { + filteredUpdateDict[field] = value + } + } else { + filteredUpdateDict[field] = value + } + case "memory_size": + if currentMemory.MemorySize != value.(int64) { + filteredUpdateDict[field] = value + } + case "forgetting_policy": + currentForgettingPolicy := strings.ToUpper(strings.TrimSpace(currentMemory.ForgettingPolicy)) + requestForgettingPolicy := strings.ToUpper(strings.TrimSpace(fmt.Sprint(value))) + if currentForgettingPolicy != requestForgettingPolicy { + filteredUpdateDict[field] = value + } + case "temperature": + if currentMemory.Temperature != value.(float64) { + filteredUpdateDict[field] = value + } + case "avatar": + currentAvatar := "" + if currentMemory.Avatar != nil { + currentAvatar = *currentMemory.Avatar + } + if currentAvatar != fmt.Sprint(value) { + filteredUpdateDict[field] = value + } + case "description": + currentDescription := "" + if currentMemory.Description != nil { + currentDescription = *currentMemory.Description + } + if currentDescription != fmt.Sprint(value) { + filteredUpdateDict[field] = value + } + case "system_prompt": + currentSystemPrompt := "" + if currentMemory.SystemPrompt != nil { + currentSystemPrompt = *currentMemory.SystemPrompt + } + if currentSystemPrompt != fmt.Sprint(value) { + filteredUpdateDict[field] = value + } + case "user_prompt": + currentUserPrompt := "" + if currentMemory.UserPrompt != nil { + currentUserPrompt = *currentMemory.UserPrompt + } + if currentUserPrompt != fmt.Sprint(value) { + filteredUpdateDict[field] = value + } + default: + filteredUpdateDict[field] = value + } + } + updateDict = filteredUpdateDict + + if len(updateDict) == 0 { + return formatRetDataFromMemory(currentMemory), nil + } + memorySize := currentMemory.MemorySize notAllowedUpdate := []string{} for _, f := range []string{"tenant_embd_id", "embd_id", "memory_type"} { @@ -581,6 +683,45 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda return formatRetDataFromMemory(updatedMemory), nil } +func normalizeMemoryTypes(memoryTypes []string) []string { + normalized := make([]string, 0, len(memoryTypes)) + seen := make(map[string]struct{}, len(memoryTypes)) + for _, mt := range memoryTypes { + mt = strings.ToLower(strings.TrimSpace(mt)) + if mt == "" { + continue + } + if _, exists := seen[mt]; exists { + continue + } + seen[mt] = struct{}{} + normalized = append(normalized, mt) + } + return normalized +} + +func sameStringSet(a, b []string) bool { + if len(a) != len(b) { + return false + } + counts := make(map[string]int, len(a)) + for _, item := range a { + counts[item]++ + } + for _, item := range b { + counts[item]-- + if counts[item] < 0 { + return false + } + } + for _, count := range counts { + if count != 0 { + return false + } + } + return true +} + // DeleteMemory deletes a memory by ID // It also deletes associated message indexes before removing the memory record //