Go: fix forgetting policy validation and fix memory update diff checks (#14976)

### What problem does this PR solve?

 fix forgetting policy validation and fix memory update diff checks

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
buua436
2026-05-18 19:21:47 +08:00
committed by GitHub
parent 93d3deb5e4
commit 3290257014

View File

@@ -499,7 +499,7 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda
}
if req.ForgettingPolicy != nil {
fp := ForgettingPolicy(strings.ToLower(*req.ForgettingPolicy))
fp := ForgettingPolicy(strings.ToUpper(strings.TrimSpace(*req.ForgettingPolicy)))
if !validForgettingPolicies[fp] {
return nil, fmt.Errorf("forgetting policy '%s' is not supported", *req.ForgettingPolicy)
}
@@ -544,6 +544,108 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda
return formatRetDataFromMemory(currentMemory), nil
}
currentMemoryTypes := dao.GetMemoryTypeHuman(currentMemory.MemoryType)
normalizedCurrentMemoryTypes := normalizeMemoryTypes(currentMemoryTypes)
filteredUpdateDict := make(map[string]interface{}, len(updateDict))
for field, value := range updateDict {
switch field {
case "name":
currentName := strings.TrimSpace(currentMemory.Name)
requestName := strings.TrimSpace(fmt.Sprint(value))
if currentName != requestName {
filteredUpdateDict[field] = value
}
case "permissions":
currentPermissions := strings.ToLower(strings.TrimSpace(currentMemory.Permissions))
requestPermissions := strings.ToLower(strings.TrimSpace(fmt.Sprint(value)))
if currentPermissions != requestPermissions {
filteredUpdateDict[field] = value
}
case "llm_id":
currentLLMID := strings.TrimSpace(currentMemory.LLMID)
requestLLMID := strings.TrimSpace(fmt.Sprint(value))
if currentLLMID != requestLLMID {
filteredUpdateDict[field] = value
}
case "embd_id":
currentEmbdID := strings.TrimSpace(currentMemory.EmbdID)
requestEmbdID := strings.TrimSpace(fmt.Sprint(value))
if currentEmbdID != requestEmbdID {
filteredUpdateDict[field] = value
}
case "tenant_llm_id":
if currentMemory.TenantLLMID == nil || *currentMemory.TenantLLMID != value.(int64) {
filteredUpdateDict[field] = value
}
case "tenant_embd_id":
if currentMemory.TenantEmbdID == nil || *currentMemory.TenantEmbdID != value.(int64) {
filteredUpdateDict[field] = value
}
case "memory_type":
if types, ok := value.([]string); ok {
if !sameStringSet(normalizedCurrentMemoryTypes, normalizeMemoryTypes(types)) {
filteredUpdateDict[field] = value
}
} else {
filteredUpdateDict[field] = value
}
case "memory_size":
if currentMemory.MemorySize != value.(int64) {
filteredUpdateDict[field] = value
}
case "forgetting_policy":
currentForgettingPolicy := strings.ToUpper(strings.TrimSpace(currentMemory.ForgettingPolicy))
requestForgettingPolicy := strings.ToUpper(strings.TrimSpace(fmt.Sprint(value)))
if currentForgettingPolicy != requestForgettingPolicy {
filteredUpdateDict[field] = value
}
case "temperature":
if currentMemory.Temperature != value.(float64) {
filteredUpdateDict[field] = value
}
case "avatar":
currentAvatar := ""
if currentMemory.Avatar != nil {
currentAvatar = *currentMemory.Avatar
}
if currentAvatar != fmt.Sprint(value) {
filteredUpdateDict[field] = value
}
case "description":
currentDescription := ""
if currentMemory.Description != nil {
currentDescription = *currentMemory.Description
}
if currentDescription != fmt.Sprint(value) {
filteredUpdateDict[field] = value
}
case "system_prompt":
currentSystemPrompt := ""
if currentMemory.SystemPrompt != nil {
currentSystemPrompt = *currentMemory.SystemPrompt
}
if currentSystemPrompt != fmt.Sprint(value) {
filteredUpdateDict[field] = value
}
case "user_prompt":
currentUserPrompt := ""
if currentMemory.UserPrompt != nil {
currentUserPrompt = *currentMemory.UserPrompt
}
if currentUserPrompt != fmt.Sprint(value) {
filteredUpdateDict[field] = value
}
default:
filteredUpdateDict[field] = value
}
}
updateDict = filteredUpdateDict
if len(updateDict) == 0 {
return formatRetDataFromMemory(currentMemory), nil
}
memorySize := currentMemory.MemorySize
notAllowedUpdate := []string{}
for _, f := range []string{"tenant_embd_id", "embd_id", "memory_type"} {
@@ -581,6 +683,45 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda
return formatRetDataFromMemory(updatedMemory), nil
}
func normalizeMemoryTypes(memoryTypes []string) []string {
normalized := make([]string, 0, len(memoryTypes))
seen := make(map[string]struct{}, len(memoryTypes))
for _, mt := range memoryTypes {
mt = strings.ToLower(strings.TrimSpace(mt))
if mt == "" {
continue
}
if _, exists := seen[mt]; exists {
continue
}
seen[mt] = struct{}{}
normalized = append(normalized, mt)
}
return normalized
}
func sameStringSet(a, b []string) bool {
if len(a) != len(b) {
return false
}
counts := make(map[string]int, len(a))
for _, item := range a {
counts[item]++
}
for _, item := range b {
counts[item]--
if counts[item] < 0 {
return false
}
}
for _, count := range counts {
if count != 0 {
return false
}
}
return true
}
// DeleteMemory deletes a memory by ID
// It also deletes associated message indexes before removing the memory record
//