mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Go: fix forgetting policy validation and fix memory update diff checks (#14976)
### What problem does this PR solve? fix forgetting policy validation and fix memory update diff checks ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -499,7 +499,7 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda
|
||||
}
|
||||
|
||||
if req.ForgettingPolicy != nil {
|
||||
fp := ForgettingPolicy(strings.ToLower(*req.ForgettingPolicy))
|
||||
fp := ForgettingPolicy(strings.ToUpper(strings.TrimSpace(*req.ForgettingPolicy)))
|
||||
if !validForgettingPolicies[fp] {
|
||||
return nil, fmt.Errorf("forgetting policy '%s' is not supported", *req.ForgettingPolicy)
|
||||
}
|
||||
@@ -544,6 +544,108 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda
|
||||
return formatRetDataFromMemory(currentMemory), nil
|
||||
}
|
||||
|
||||
currentMemoryTypes := dao.GetMemoryTypeHuman(currentMemory.MemoryType)
|
||||
normalizedCurrentMemoryTypes := normalizeMemoryTypes(currentMemoryTypes)
|
||||
|
||||
filteredUpdateDict := make(map[string]interface{}, len(updateDict))
|
||||
for field, value := range updateDict {
|
||||
switch field {
|
||||
case "name":
|
||||
currentName := strings.TrimSpace(currentMemory.Name)
|
||||
requestName := strings.TrimSpace(fmt.Sprint(value))
|
||||
if currentName != requestName {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "permissions":
|
||||
currentPermissions := strings.ToLower(strings.TrimSpace(currentMemory.Permissions))
|
||||
requestPermissions := strings.ToLower(strings.TrimSpace(fmt.Sprint(value)))
|
||||
if currentPermissions != requestPermissions {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "llm_id":
|
||||
currentLLMID := strings.TrimSpace(currentMemory.LLMID)
|
||||
requestLLMID := strings.TrimSpace(fmt.Sprint(value))
|
||||
if currentLLMID != requestLLMID {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "embd_id":
|
||||
currentEmbdID := strings.TrimSpace(currentMemory.EmbdID)
|
||||
requestEmbdID := strings.TrimSpace(fmt.Sprint(value))
|
||||
if currentEmbdID != requestEmbdID {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "tenant_llm_id":
|
||||
if currentMemory.TenantLLMID == nil || *currentMemory.TenantLLMID != value.(int64) {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "tenant_embd_id":
|
||||
if currentMemory.TenantEmbdID == nil || *currentMemory.TenantEmbdID != value.(int64) {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "memory_type":
|
||||
if types, ok := value.([]string); ok {
|
||||
if !sameStringSet(normalizedCurrentMemoryTypes, normalizeMemoryTypes(types)) {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
} else {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "memory_size":
|
||||
if currentMemory.MemorySize != value.(int64) {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "forgetting_policy":
|
||||
currentForgettingPolicy := strings.ToUpper(strings.TrimSpace(currentMemory.ForgettingPolicy))
|
||||
requestForgettingPolicy := strings.ToUpper(strings.TrimSpace(fmt.Sprint(value)))
|
||||
if currentForgettingPolicy != requestForgettingPolicy {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "temperature":
|
||||
if currentMemory.Temperature != value.(float64) {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "avatar":
|
||||
currentAvatar := ""
|
||||
if currentMemory.Avatar != nil {
|
||||
currentAvatar = *currentMemory.Avatar
|
||||
}
|
||||
if currentAvatar != fmt.Sprint(value) {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "description":
|
||||
currentDescription := ""
|
||||
if currentMemory.Description != nil {
|
||||
currentDescription = *currentMemory.Description
|
||||
}
|
||||
if currentDescription != fmt.Sprint(value) {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "system_prompt":
|
||||
currentSystemPrompt := ""
|
||||
if currentMemory.SystemPrompt != nil {
|
||||
currentSystemPrompt = *currentMemory.SystemPrompt
|
||||
}
|
||||
if currentSystemPrompt != fmt.Sprint(value) {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
case "user_prompt":
|
||||
currentUserPrompt := ""
|
||||
if currentMemory.UserPrompt != nil {
|
||||
currentUserPrompt = *currentMemory.UserPrompt
|
||||
}
|
||||
if currentUserPrompt != fmt.Sprint(value) {
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
default:
|
||||
filteredUpdateDict[field] = value
|
||||
}
|
||||
}
|
||||
updateDict = filteredUpdateDict
|
||||
|
||||
if len(updateDict) == 0 {
|
||||
return formatRetDataFromMemory(currentMemory), nil
|
||||
}
|
||||
|
||||
memorySize := currentMemory.MemorySize
|
||||
notAllowedUpdate := []string{}
|
||||
for _, f := range []string{"tenant_embd_id", "embd_id", "memory_type"} {
|
||||
@@ -581,6 +683,45 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda
|
||||
return formatRetDataFromMemory(updatedMemory), nil
|
||||
}
|
||||
|
||||
func normalizeMemoryTypes(memoryTypes []string) []string {
|
||||
normalized := make([]string, 0, len(memoryTypes))
|
||||
seen := make(map[string]struct{}, len(memoryTypes))
|
||||
for _, mt := range memoryTypes {
|
||||
mt = strings.ToLower(strings.TrimSpace(mt))
|
||||
if mt == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[mt]; exists {
|
||||
continue
|
||||
}
|
||||
seen[mt] = struct{}{}
|
||||
normalized = append(normalized, mt)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func sameStringSet(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
counts := make(map[string]int, len(a))
|
||||
for _, item := range a {
|
||||
counts[item]++
|
||||
}
|
||||
for _, item := range b {
|
||||
counts[item]--
|
||||
if counts[item] < 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, count := range counts {
|
||||
if count != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// DeleteMemory deletes a memory by ID
|
||||
// It also deletes associated message indexes before removing the memory record
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user