feat[Go]: implement api: Search/Get/Update-Messages (#16307)

### What problem does this PR solve?

As the title says, this PR implements the following endpoints:
```
/api/v1/messages/search GET
/api/v1/messages GET
/api/v1/messages/<memory_id>:<message_id>/content GET
/api/v1/memories/<memory_id>/config GET
/api/v1/messages/<memory_id>:<message_id> PUT
```

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Haruko386
2026-06-25 19:07:34 +08:00
committed by GitHub
parent 49312cace3
commit 74597b8683
7 changed files with 1524 additions and 70 deletions

View File

@@ -32,11 +32,12 @@ import (
"strconv"
"strings"
"github.com/elastic/go-elasticsearch/v8/esapi"
"github.com/json-iterator/go"
"ragflow/internal/common"
"ragflow/internal/engine/types"
"github.com/elastic/go-elasticsearch/v8/esapi"
"github.com/json-iterator/go"
"go.uber.org/zap"
)
@@ -44,6 +45,8 @@ var jsonIterator = jsoniter.Config{
SortMapKeys: false,
}.Froze()
var memoryMessageVectorFieldRE = regexp.MustCompile(`^q_\d+_vec$`)
var (
elasticsearchHighlightEmTagRE = regexp.MustCompile(`<em>[^<>]+</em>`)
elasticsearchHighlightNewlineRE = regexp.MustCompile(`[\r\n]`)
@@ -233,6 +236,14 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st
return fmt.Errorf("index '%s' does not exist", fullIndexName)
}
if strings.HasPrefix(fullIndexName, "memory_") {
condition["memory_id"] = datasetID
if messageDocID, ok := condition["id"].(string); ok {
return e.updateSingleMemoryMessage(ctx, fullIndexName, messageDocID, newValue)
}
return e.updateChunksByQuery(ctx, fullIndexName, mapMemoryMessageESConditionFields(condition), mapMemoryMessageESUpdateFields(newValue))
}
// Add kb_id to condition
condition["kb_id"] = datasetID
@@ -245,6 +256,59 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st
return e.updateChunksByQuery(ctx, fullIndexName, condition, newValue)
}
// updateSingleMemoryMessage applies a partial update to one memory-message
// document, addressed by its document ID.
//
// The update fields are translated with mapMemoryMessageESUpdateFields and the
// "id" key is dropped from the update document. When nothing is left to write
// the call returns nil without issuing a request. A 404 response is reported as
// an error wrapping types.ErrDocumentNotFound; any other error response
// includes the status and the response body.
func (e *elasticsearchEngine) updateSingleMemoryMessage(ctx context.Context, indexName, messageDocID string, newValue map[string]interface{}) error {
	fields := mapMemoryMessageESUpdateFields(newValue)
	delete(fields, "id")
	if len(fields) == 0 {
		return nil
	}

	payload, err := json.Marshal(map[string]interface{}{"doc": fields})
	if err != nil {
		return fmt.Errorf("failed to marshal memory message update request: %w", err)
	}

	updateReq := esapi.UpdateRequest{
		Index:      indexName,
		DocumentID: messageDocID,
		Body:       bytes.NewReader(payload),
	}
	resp, err := updateReq.Do(ctx, e.client)
	if err != nil {
		return fmt.Errorf("failed to update memory message: %w", err)
	}
	defer resp.Body.Close()

	switch {
	case !resp.IsError():
		return nil
	case resp.StatusCode == http.StatusNotFound:
		return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, messageDocID)
	default:
		// Best effort: the body is only used to enrich the error message.
		raw, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("elasticsearch memory message update error: %s, body: %s", resp.Status(), string(raw))
	}
}
// mapMemoryMessageESUpdateFields returns a new map holding the entries of
// newValue with their keys translated by mapMemoryMessageESField (using the
// non-tokenized content field). The "remove" and "add" keys are copied through
// without translation.
func mapMemoryMessageESUpdateFields(newValue map[string]interface{}) map[string]interface{} {
	translated := make(map[string]interface{}, len(newValue))
	for key, value := range newValue {
		if key == "remove" || key == "add" {
			translated[key] = value
			continue
		}
		translated[mapMemoryMessageESField(key, false)] = value
	}
	return translated
}
// mapMemoryMessageESConditionFields returns a new map holding the entries of
// condition with every key translated by mapMemoryMessageESField (using the
// non-tokenized content field). The input map is not modified.
func mapMemoryMessageESConditionFields(condition map[string]interface{}) map[string]interface{} {
	out := make(map[string]interface{}, len(condition))
	for field, value := range condition {
		esField := mapMemoryMessageESField(field, false)
		out[esField] = value
	}
	return out
}
// updateSingleChunk handles single document update
func (e *elasticsearchEngine) updateSingleChunk(ctx context.Context, indexName, chunkID string, newValue map[string]interface{}) error {
common.Debug("ElasticsearchConnection.updateSingleChunk called", zap.String("indexName", indexName), zap.String("chunkID", chunkID))
@@ -762,15 +826,18 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
// Detect index types
isSkillIndex := false
isMemoryIndex := false
for _, idx := range req.IndexNames {
if strings.HasPrefix(idx, "skill_") {
isSkillIndex = true
break
}
if strings.HasPrefix(idx, "memory_") {
isMemoryIndex = true
}
}
// Build bool query from condition
boolQuery := buildBoolQueryFromCondition(req.Filter, req.KbIDs, isSkillIndex)
boolQuery := buildBoolQueryFromCondition(req.Filter, req.KbIDs, isSkillIndex, isMemoryIndex)
// Extract vector_similarity_weight from FusionExpr
var matchText *types.MatchTextExpr
@@ -816,7 +883,7 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
queryBody := make(map[string]interface{})
if matchText != nil {
textQuery := buildQueryStringQuery(matchText, vectorSimilarityWeight, isSkillIndex)
textQuery := buildQueryStringQuery(matchText, vectorSimilarityWeight, isSkillIndex, isMemoryIndex)
if boolQuery != nil {
if boolMap, ok := boolQuery["bool"].(map[string]interface{}); ok {
if must, ok := boolMap["must"].([]interface{}); ok {
@@ -874,7 +941,7 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
}
// Add rank_feature queries
if req.RankFeature != nil && len(req.RankFeature) > 0 && !isSkillIndex {
if req.RankFeature != nil && len(req.RankFeature) > 0 && !isSkillIndex && !isMemoryIndex {
rankFeatureQuery := buildRankFeatureQuery(req.RankFeature)
if rankFeatureQuery != nil {
if boolQuery, ok := queryBody["query"].(map[string]interface{}); ok {
@@ -929,21 +996,25 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
// Set _source and fields for vector fields
hasTextMatch := matchText != nil
selectFields := req.SelectFields
if isMemoryIndex {
selectFields = mapMemoryMessageESFields(req.SelectFields, false)
}
if len(req.SelectFields) > 0 {
// Use caller-specified fields, add pagerank_fld/tag_fld if needed
queryBody["_source"] = req.SelectFields
queryBody["_source"] = selectFields
if hasTextMatch || hasVectorMatch {
if !isSkillIndex {
if !slices.Contains(req.SelectFields, common.PAGERANK_FLD) {
if !isSkillIndex && !isMemoryIndex {
if !slices.Contains(selectFields, common.PAGERANK_FLD) {
queryBody["_source"] = append(queryBody["_source"].([]string), common.PAGERANK_FLD)
}
if !slices.Contains(req.SelectFields, common.TAG_FLD) {
if !slices.Contains(selectFields, common.TAG_FLD) {
queryBody["_source"] = append(queryBody["_source"].([]string), common.TAG_FLD)
}
}
}
var vectorFields []string
for _, f := range req.SelectFields {
for _, f := range selectFields {
if strings.HasSuffix(f, "_vec") {
vectorFields = append(vectorFields, f)
}
@@ -954,7 +1025,7 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
} else {
// No explicit SelectFields - use match_all, but add pagerank_fld/tag_fld for scoring if needed
if hasTextMatch || hasVectorMatch {
if !isSkillIndex {
if !isSkillIndex && !isMemoryIndex {
queryBody["_source"] = []string{common.PAGERANK_FLD, common.TAG_FLD}
}
}
@@ -1021,6 +1092,10 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
}
}
if isMemoryIndex {
normalizeMemoryMessageChunks(allResults)
}
// Post-processing: Sort results by score
if len(allResults) > 0 && (matchText != nil || hasVectorMatch) {
scoreColumn := "_score"
@@ -1032,6 +1107,9 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
if isSkillIndex {
pagerankField = ""
}
if isMemoryIndex {
pagerankField = ""
}
allResults = calculateScores(allResults, scoreColumn, pagerankField)
allResults = sortByScore(allResults, limit)
@@ -1293,17 +1371,105 @@ func sortValuesEqual(a, b []interface{}) bool {
return true
}
// buildBoolQueryFromCondition builds an ES bool query from condition map
// For skill index, uses 'status' field instead of 'available_int'
func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string, isSkillIndex bool) map[string]interface{} {
// mapMemoryMessageESField translates an API-level memory message field name
// into the name used in the Elasticsearch index.
//
// A boost suffix ("name^boost") is split off at the first '^', the name part
// is translated, and the suffix is re-attached unchanged. "message_type"
// becomes "message_type_kwd", "status" becomes "status_int", and "content"
// becomes "tokenized_content_ltks" when useTokenizedContent is true or
// "content_ltks" otherwise. Any other name is returned as-is.
func mapMemoryMessageESField(field string, useTokenizedContent bool) string {
	// When there is no '^', Cut returns the whole input as base.
	base, boostValue, hasBoost := strings.Cut(field, "^")

	mapped := base
	switch {
	case base == "message_type":
		mapped = "message_type_kwd"
	case base == "status":
		mapped = "status_int"
	case base == "content" && useTokenizedContent:
		mapped = "tokenized_content_ltks"
	case base == "content":
		mapped = "content_ltks"
	}

	if !hasBoost {
		return mapped
	}
	return mapped + "^" + boostValue
}
// mapMemoryMessageESFields translates each entry of fields with
// mapMemoryMessageESField and returns the translated names in their original
// order, keeping only the first occurrence of each translated name. An empty
// or nil input is returned unchanged.
func mapMemoryMessageESFields(fields []string, useTokenizedContent bool) []string {
	if len(fields) == 0 {
		return fields
	}

	emitted := make(map[string]struct{}, len(fields))
	result := make([]string, 0, len(fields))
	for i := range fields {
		esName := mapMemoryMessageESField(fields[i], useTokenizedContent)
		if _, duplicate := emitted[esName]; !duplicate {
			emitted[esName] = struct{}{}
			result = append(result, esName)
		}
	}
	return result
}
// normalizeMemoryMessageChunks rewrites, in place, the storage field names of
// each chunk back to their API-level names:
//   - keys matching memoryMessageVectorFieldRE are moved to "content_embed"
//   - "message_type_kwd" is moved to "message_type"
//   - "status_int" is moved to "status" and converted with memoryMessageStatusBool
//   - "content_ltks" is moved to "content"
func normalizeMemoryMessageChunks(chunks []map[string]interface{}) {
	for i := range chunks {
		doc := chunks[i]

		for field, vector := range doc {
			if !memoryMessageVectorFieldRE.MatchString(field) {
				continue
			}
			doc["content_embed"] = vector
			delete(doc, field)
		}

		if kind, found := doc["message_type_kwd"]; found {
			delete(doc, "message_type_kwd")
			doc["message_type"] = kind
		}
		if rawStatus, found := doc["status_int"]; found {
			delete(doc, "status_int")
			doc["status"] = memoryMessageStatusBool(rawStatus)
		}
		if text, found := doc["content_ltks"]; found {
			delete(doc, "content_ltks")
			doc["content"] = text
		}
	}
}
// memoryMessageStatusBool coerces a stored status value into a boolean.
//
// Booleans are returned as-is. Every built-in integer and floating-point kind
// is true when non-zero. A json.Number is true when it parses to a non-zero
// number. Strings are true unless they are empty, "0", or a case-insensitive
// "false". Any other type, including nil, yields false.
func memoryMessageStatusBool(value interface{}) bool {
	switch v := value.(type) {
	case bool:
		return v
	case int:
		return v != 0
	case int8:
		return v != 0
	case int16:
		return v != 0
	case int32:
		return v != 0
	case int64:
		return v != 0
	case uint:
		return v != 0
	case uint8:
		return v != 0
	case uint16:
		return v != 0
	case uint32:
		return v != 0
	case uint64:
		return v != 0
	case float32:
		return v != 0
	case float64:
		return v != 0
	case json.Number:
		if n, err := v.Int64(); err == nil {
			return n != 0
		}
		// Int64 rejects decimal or exponent notation such as "1.0"; fall back
		// to a floating-point parse before giving up.
		f, err := v.Float64()
		return err == nil && f != 0
	case string:
		return v != "" && v != "0" && !strings.EqualFold(v, "false")
	default:
		return false
	}
}
// buildBoolQueryFromCondition builds an ES bool query from condition map.
// Skill indexes use status, regular chunk indexes use kb_id, and memory
// message indexes use memory_id plus message-specific storage fields.
func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string, isSkillIndex, isMemoryIndex bool) map[string]interface{} {
var mustClauses []interface{}
var filterClauses []interface{}
var shouldClauses []interface{}
// Add kb_id to condition
// Memory message indexes use memory_id, regular chunk indexes use kb_id.
if kbIDs != nil && len(kbIDs) > 0 {
fieldName := "kb_id"
if isMemoryIndex {
fieldName = "memory_id"
}
filterClauses = append(filterClauses, map[string]interface{}{
"terms": map[string]interface{}{"kb_id": kbIDs},
"terms": map[string]interface{}{fieldName: kbIDs},
})
}
@@ -1321,6 +1487,9 @@ func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string,
}
for k, v := range filter {
if isMemoryIndex {
k = mapMemoryMessageESField(k, false)
}
// For skill index, handle 'status' field instead of 'available_int'
if isSkillIndex && k == "status" {
if v == nil || v == "" {
@@ -1389,6 +1558,18 @@ func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string,
if v == nil || v == "" {
continue
}
if isMemoryIndex && k == "session_id" {
if strVal, ok := v.(string); ok && strVal != "" {
filterClauses = append(filterClauses, map[string]interface{}{
"query_string": map[string]interface{}{
"query": fmt.Sprintf("*%s*", strVal),
"fields": []string{"session_id"},
"analyze_wildcard": true,
},
})
continue
}
}
if listVal, ok := v.([]interface{}); ok {
filterClauses = append(filterClauses, map[string]interface{}{
"terms": map[string]interface{}{k: listVal},
@@ -1435,7 +1616,7 @@ func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string,
// buildQueryStringQuery builds a query_string query from MatchTextExpr
// When isSkillIndex is true, uses skill-specific fields (name_tks, tags_tks, etc.)
// Otherwise uses document fields (title_tks, content_ltks, etc.)
func buildQueryStringQuery(matchText *types.MatchTextExpr, vectorSimilarityWeight float64, isSkillIndex bool) map[string]interface{} {
func buildQueryStringQuery(matchText *types.MatchTextExpr, vectorSimilarityWeight float64, isSkillIndex, isMemoryIndex bool) map[string]interface{} {
if matchText == nil {
return nil
}
@@ -1451,10 +1632,15 @@ func buildQueryStringQuery(matchText *types.MatchTextExpr, vectorSimilarityWeigh
if fields == nil || len(fields) == 0 {
if isSkillIndex {
fields = []string{"name_tks^10", "tags_tks^5", "description_tks^3", "content_tks^1"}
} else if isMemoryIndex {
fields = []string{"tokenized_content_ltks"}
} else {
fields = []string{"title_tks^10", "title_sm_tks^5", "important_kwd^30", "important_tks^20", "question_tks^20", "content_ltks^2", "content_sm_ltks"}
}
}
if isMemoryIndex {
fields = mapMemoryMessageESFields(fields, true)
}
boost := 1.0
if matchText.ExtraOptions != nil {
@@ -1507,6 +1693,10 @@ func buildRankFeatureQuery(rankFeature map[string]float64) []map[string]interfac
// GetChunk gets a chunk by ID using ES search API
func (e *elasticsearchEngine) GetChunk(ctx context.Context, baseName, chunkID string, datasetIDs []string) (interface{}, error) {
if strings.HasPrefix(baseName, "memory_") {
return e.getMemoryMessage(ctx, baseName, chunkID)
}
// Try search by doc_id field (which is stored in the document)
for _, datasetID := range datasetIDs {
searchReq := map[string]interface{}{
@@ -1575,6 +1765,41 @@ func (e *elasticsearchEngine) GetChunk(ctx context.Context, baseName, chunkID st
return nil, nil
}
// getMemoryMessage fetches a single memory-message document by its document ID.
//
// A 404 response is reported as an error wrapping types.ErrDocumentNotFound.
// A successful response that is not marked found, or that carries no _source,
// yields (nil, nil). Otherwise the _source map is returned with "id" set to
// docID and its storage field names normalized by normalizeMemoryMessageChunks.
func (e *elasticsearchEngine) getMemoryMessage(ctx context.Context, indexName, docID string) (interface{}, error) {
	getReq := esapi.GetRequest{Index: indexName, DocumentID: docID}
	resp, err := getReq.Do(ctx, e.client)
	if err != nil {
		return nil, fmt.Errorf("failed to get memory message: %w", err)
	}
	defer resp.Body.Close()

	if resp.IsError() {
		switch resp.StatusCode {
		case http.StatusNotFound:
			return nil, fmt.Errorf("%w: %s", types.ErrDocumentNotFound, docID)
		default:
			return nil, fmt.Errorf("elasticsearch memory message get error: %s", resp.Status())
		}
	}

	var payload struct {
		Found  bool                   `json:"found"`
		Source map[string]interface{} `json:"_source"`
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&payload); decodeErr != nil {
		return nil, fmt.Errorf("failed to parse memory message get response: %w", decodeErr)
	}
	if payload.Source == nil || !payload.Found {
		return nil, nil
	}

	doc := payload.Source
	doc["id"] = docID
	normalizeMemoryMessageChunks([]map[string]interface{}{doc})
	return doc, nil
}
// GetFields extracts the requested fields from ES search response chunks
//
// Unlike Infinity, Elasticsearch does NOT use convertSelectFields before querying.

View File

@@ -379,7 +379,7 @@ func sortedCopy(in []int) []int {
func TestBuildBoolQueryFromConditionIDFilter(t *testing.T) {
check := func(name string, cond map[string]interface{}, wantFields []string) {
t.Helper()
got := buildBoolQueryFromCondition(cond, nil, false)
got := buildBoolQueryFromCondition(cond, nil, false, false)
outer, ok := got["bool"].(map[string]interface{})
if !ok {
t.Fatalf("%s: missing bool wrapper: %v", name, got)

View File

@@ -1282,6 +1282,13 @@ func applyFieldMappings(chunks []map[string]interface{}) {
chunk["authors_sm_tks"] = val
}
if val, ok := chunk["message_type_kwd"]; ok {
chunk["message_type"] = val
}
if val, ok := chunk["status_int"]; ok {
chunk["status"] = memoryMessageStatusBool(val)
}
// position_int: convert from hex string to array format (grouped by 5)
if val, ok := chunk["position_int"].(string); ok {
chunk["position_int"] = utility.ConvertHexToPositionIntArray(val)
@@ -1332,6 +1339,26 @@ func applyFieldMappings(chunks []map[string]interface{}) {
}
}
// memoryMessageStatusBool coerces a stored status value into a boolean.
//
// Booleans are returned as-is. Every built-in integer and floating-point kind
// is true when non-zero. A json.Number is true when it parses to a non-zero
// number. Strings are true unless they are empty, "0", or a case-insensitive
// "false". Any other type, including nil, yields false.
func memoryMessageStatusBool(value interface{}) bool {
	switch v := value.(type) {
	case bool:
		return v
	case int:
		return v != 0
	case int8:
		return v != 0
	case int16:
		return v != 0
	case int32:
		return v != 0
	case int64:
		return v != 0
	case uint:
		return v != 0
	case uint8:
		return v != 0
	case uint16:
		return v != 0
	case uint32:
		return v != 0
	case uint64:
		return v != 0
	case float32:
		return v != 0
	case float64:
		return v != 0
	case json.Number:
		if n, err := v.Int64(); err == nil {
			return n != 0
		}
		// Int64 rejects decimal or exponent notation such as "1.0"; fall back
		// to a floating-point parse before giving up.
		f, err := v.Float64()
		return err == nil && f != 0
	case string:
		return v != "" && v != "0" && !strings.EqualFold(v, "false")
	default:
		return false
	}
}
// GetFields extracts the requested fields from Infinity search results
func (e *infinityEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} {
result := make(map[string]map[string]interface{})
@@ -1837,6 +1864,8 @@ func convertSelectFields(output []string, isSkillIndex ...bool) []string {
"content_sm_ltks": "content",
"authors_tks": "authors",
"authors_sm_tks": "authors",
"message_type": "message_type_kwd",
"status": "status_int",
}
skillIndex := false
@@ -2005,6 +2034,12 @@ func equivalentConditionToStr(condition map[string]interface{}) string {
continue
}
if k == "message_type" {
k = "message_type_kwd"
} else if k == "status" {
k = "status_int"
}
// Handle list values (mixed types - strings get quotes, numbers don't)
if list, ok := v.([]interface{}); ok && len(list) > 0 {
var strItems, numItems []string
@@ -2210,6 +2245,19 @@ func transformChunkFields(chunk map[string]interface{}, embeddingCols [][2]inter
d["questions"] = strings.Join(utility.ConvertToStringSlice(v), "\n")
case "tag_kwd":
d["tag_kwd"] = strings.Join(utility.ConvertToStringSlice(v), "###")
case "message_type":
d["message_type_kwd"] = v
case "status":
switch status := v.(type) {
case bool:
if status {
d["status_int"] = 1
} else {
d["status_int"] = 0
}
default:
d["status_int"] = v
}
case "question_tks":
if _, exists := chunk["question_kwd"]; !exists {
d["questions"] = utility.ConvertToString(v)

View File

@@ -20,6 +20,7 @@
package handler
import (
"encoding/json"
"errors"
"net/http"
"os"
@@ -515,13 +516,67 @@ func (h *MemoryHandler) GetMemoryConfig(c *gin.Context) {
// - message: true
// - data.messages: Array of message objects
// - data.storage_type: Storage type
//
// TODO: Implementation pending - depends on CanvasService and TaskService
func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
// Resolve the calling user; a failed lookup or a blank user ID ends the request.
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := strings.TrimSpace(user.ID)
if userID == "" {
jsonError(c, common.CodeAuthenticationError, "user id is required")
return
}
memoryID := strings.TrimSpace(c.Param("memory_id"))
if memoryID == "" {
jsonError(c, common.CodeArgumentError, "memory_id is required")
return
}
// agent_id may be repeated and/or comma-separated; blank entries are dropped.
var agentIDs []string
values := c.QueryArray("agent_id")
for _, v := range values {
parts := strings.Split(v, ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
agentIDs = append(agentIDs, p)
}
}
}
keywords := strings.TrimSpace(c.DefaultQuery("keywords", ""))
// page defaults to 1 and must be a positive integer.
page, err := strconv.Atoi(c.DefaultQuery("page", "1"))
if err != nil || page <= 0 {
jsonError(c, common.CodeArgumentError, "page must be a positive integer")
return
}
// page_size defaults to 50, must be positive, and is rejected above 100.
pageSize, err := strconv.Atoi(c.DefaultQuery("page_size", "50"))
if err != nil || pageSize <= 0 {
jsonError(c, common.CodeArgumentError, "page_size must be a positive integer")
return
}
if pageSize > 100 {
jsonError(c, common.CodeArgumentError, "page_size must be less than or equal to 100")
return
}
data, err := h.memoryService.GetMemoryMessages(c.Request.Context(), userID, memoryID, agentIDs, keywords, page, pageSize)
if err != nil {
// Not-found errors expose the service message; every other failure is
// reported with a generic message.
if isMemoryServiceNotFound(err) {
jsonError(c, common.CodeNotFound, err.Error())
return
}
jsonError(c, common.CodeServerError, "Internal server error")
return
}
// NOTE(review): the literal below lists "code", "message" and "data" twice
// (the "not implemented" placeholder followed by the success values) — this
// looks like leftover placeholder lines; confirm only the success triple remains.
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "GetMemoryMessages not implemented - pending CanvasService and TaskService dependencies",
"data": nil,
"code": common.CodeSuccess,
"message": true,
"data": data,
})
}
@@ -541,7 +596,7 @@ func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
// - agent_response (required): Agent response
// - user_id (optional): User ID
//
// TODO: Implementation pending - depends on embedding engine
// TODO: Haruko386 is implementing this for now, if you implement this, delete this line plz
func (h *MemoryHandler) AddMessage(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
@@ -645,16 +700,123 @@ func parseMemoryMessagePath(memoryMessage string) (string, int64, error) {
//
// Request Parameters (JSON Body):
// - status (required): Message status, boolean
//
// TODO: Implementation pending - depends on embedding engine
func (h *MemoryHandler) UpdateMessage(c *gin.Context) {
// Resolve the calling user; a failed lookup or a blank user ID ends the request.
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := strings.TrimSpace(user.ID)
if userID == "" {
jsonError(c, common.CodeAuthenticationError, "user id is required")
return
}
// The memory_message path parameter is split by parseMemoryMessagePath into
// the memory ID and the message ID.
memoryID, messageID, err := parseMemoryMessagePath(c.Param("memory_message"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": err.Error(),
})
return
}
// Decode into a generic map so that a missing or non-boolean "status" can be
// rejected explicitly below.
var req map[string]interface{}
if err = json.NewDecoder(c.Request.Body).Decode(&req); err != nil {
// NOTE(review): malformed JSON answers with HTTP 400 while the other
// argument errors in this handler answer with HTTP 200 — confirm intended.
c.JSON(http.StatusBadRequest, gin.H{
"code": common.CodeArgumentError,
"message": err.Error(),
})
return
}
status, ok := req["status"].(bool)
if !ok {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"message": "Status must be a boolean.",
"data": nil,
})
return
}
ok, err = h.memoryService.UpdateMessage(c.Request.Context(), userID, memoryID, messageID, status)
if err != nil || !ok {
// NOTE(review): when err is nil and ok is false this relies on
// isMemoryServiceNotFound(nil) returning false — presumably it does; verify.
if isMemoryServiceNotFound(err) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeNotFound,
"message": err.Error(),
"data": nil,
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "Internal server error",
"data": nil,
})
return
}
// NOTE(review): "code" and "message" appear twice in the literal below (the
// "not implemented" placeholder followed by the success values) — this looks
// like leftover placeholder lines; confirm only the success values remain.
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "UpdateMessage not implemented - pending embedding engine dependency",
"code": common.CodeSuccess,
"message": true,
"data": nil,
})
}
// GetMessageContent handles GET request for getting message content
// API Path: GET /api/v1/messages/:memory_message/content
//
// Function:
//   - Gets complete content of the specified message
//   - doc_id format: memory_id + "_" + message_id
//
// Parameter Format:
//   - memory_message: "<memory_id>:<message_id>", split by parseMemoryMessagePath
//     into the memory ID and the integer message ID
//
// Response:
//   - code CodeNotFound when the service reports a ResourceNotFoundError
//   - code CodeServerError with the service error text for any other failure
//   - code CodeSuccess with the message under "data" otherwise
func (h *MemoryHandler) GetMessageContent(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	userID := strings.TrimSpace(user.ID)
	if userID == "" {
		jsonError(c, common.CodeAuthenticationError, "user id is required")
		return
	}

	memoryID, messageID, err := parseMemoryMessagePath(c.Param("memory_message"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"code":    common.CodeArgumentError,
			"message": err.Error(),
		})
		return
	}

	data, err := h.memoryService.GetMessageContent(c.Request.Context(), userID, memoryID, messageID)
	if err != nil {
		// errors.As also matches a not-found error that has been wrapped on its
		// way up, which a bare type assertion would miss.
		var notFound *service.ResourceNotFoundError
		if errors.As(err, &notFound) {
			jsonError(c, common.CodeNotFound, err.Error())
			return
		}
		jsonError(c, common.CodeServerError, err.Error())
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code":    common.CodeSuccess,
		"message": true,
		"data":    data,
	})
}
// SearchMessage handles GET request for searching messages
// API Path: GET /api/v1/messages/search
//
@@ -672,13 +834,64 @@ func (h *MemoryHandler) UpdateMessage(c *gin.Context) {
// - agent_id (optional): Agent ID filter
// - session_id (optional): Session ID filter
// - user_id (optional): User ID filter
//
// TODO: Implementation pending - depends on embedding engine
func (h *MemoryHandler) SearchMessage(c *gin.Context) {
// Resolve the calling user; a failed lookup or a blank user ID ends the request.
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := strings.TrimSpace(user.ID)
if userID == "" {
jsonError(c, common.CodeAuthenticationError, "user id is required")
return
}
// memory_id may be repeated and/or comma-separated; blank entries are dropped.
var memoryIDs []string
values := c.QueryArray("memory_id")
for _, v := range values {
parts := strings.Split(v, ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
memoryIDs = append(memoryIDs, p)
}
}
}
query := c.Query("query")
// NOTE(review): parse errors for the three numeric parameters are discarded,
// so a malformed value silently becomes 0 rather than the documented default —
// consider rejecting it with an argument error.
similarityThreshold, _ := strconv.ParseFloat(c.DefaultQuery("similarity_threshold", "0.2"), 64)
keywordsSimilarityWeight, _ := strconv.ParseFloat(c.DefaultQuery("keywords_similarity_weight", "0.7"), 64)
topN, _ := strconv.Atoi(c.DefaultQuery("top_n", "5"))
agentID := c.DefaultQuery("agent_id", "")
sessionID := c.DefaultQuery("session_id", "")
// Filters handed to the service: memory_id is a string slice, the remaining
// entries are the raw query-string values (possibly empty).
filterDict := map[string]interface{}{
"memory_id": memoryIDs,
"agent_id": agentID,
"session_id": sessionID,
"user_id": c.DefaultQuery("user_id", ""),
}
params := map[string]interface{}{
"query": query,
"similarity_threshold": similarityThreshold,
"keywords_similarity_weight": keywordsSimilarityWeight,
"top_n": topN,
}
// The service supplies the error code that is reported on failure.
res, code, err := h.memoryService.SearchMessage(c.Request.Context(), userID, filterDict, params)
if err != nil {
jsonError(c, code, err.Error())
return
}
// NOTE(review): the literal below lists "code", "message" and "data" twice
// (the "not implemented" placeholder followed by the success values) — this
// looks like leftover placeholder lines; confirm only the success triple remains.
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "SearchMessage not implemented - pending embedding engine dependency",
"data": nil,
"code": common.CodeSuccess,
"message": true,
"data": res,
})
}
@@ -694,33 +907,49 @@ func (h *MemoryHandler) SearchMessage(c *gin.Context) {
// - agent_id (optional): Agent ID filter
// - session_id (optional): Session ID filter
// - limit (optional): Number of results to return, default 10
//
// TODO: Implementation pending - depends on embedding engine
func (h *MemoryHandler) GetMessages(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "GetMessages not implemented - pending embedding engine dependency",
"data": nil,
})
}
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := strings.TrimSpace(user.ID)
if userID == "" {
jsonError(c, common.CodeAuthenticationError, "user id is required")
return
}
var memoryIDs []string
values := c.QueryArray("memory_id")
for _, v := range values {
parts := strings.Split(v, ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
memoryIDs = append(memoryIDs, p)
}
}
}
agentID := c.DefaultQuery("agent_id", "")
sessionID := c.DefaultQuery("session_id", "")
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
if len(memoryIDs) == 0 {
jsonError(c, common.CodeArgumentError, "memory_ids is required.")
return
}
data, code, err := h.memoryService.GetMessages(c.Request.Context(), memoryIDs, userID, agentID, sessionID, limit)
if err != nil {
jsonError(c, code, err.Error())
return
}
// GetMessageContent handles GET request for getting message content
// API Path: GET /api/v1/messages/:memory_id/:message_id/content
//
// Function:
// - Gets complete content of the specified message
// - doc_id format: memory_id + "_" + message_id
//
// Parameter Format:
// - memory_id: Memory ID
// - message_id: Message ID (integer)
//
// TODO: Implementation pending - depends on embedding engine
func (h *MemoryHandler) GetMessageContent(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "GetMessageContent not implemented - pending embedding engine dependency",
"data": nil,
"code": common.CodeSuccess,
"message": true,
"data": data,
})
}

View File

@@ -414,8 +414,11 @@ func (r *Router) Setup(engine *gin.Engine) {
// Message routes
message := v1.Group("/messages")
{
message.GET("", r.memoryHandler.GetMessages)
message.DELETE("/:memory_message", r.memoryHandler.ForgetMessage)
message.PUT("/:memory_message", r.memoryHandler.UpdateMessage)
message.GET("/:memory_message/content", r.memoryHandler.GetMessageContent)
message.GET("/search", r.memoryHandler.SearchMessage)
}
// Skill search routes

View File

@@ -20,8 +20,10 @@ import (
"context"
"errors"
"fmt"
"os"
"ragflow/internal/common"
"ragflow/internal/entity"
models "ragflow/internal/entity/models"
"strconv"
"strings"
"time"
@@ -29,6 +31,7 @@ import (
"ragflow/internal/dao"
"ragflow/internal/engine"
enginetypes "ragflow/internal/engine/types"
"ragflow/internal/service/nlp"
)
const (
@@ -55,6 +58,12 @@ const (
TenantPermissionAll TenantPermission = "all"
)
// Defaults and bounds for memory message queries.
const (
// defaultMessageTopN is the top_n used by queryMessage when the requested
// value is not positive.
defaultMessageTopN = 5
// defaultMessageLimit is presumably the fallback result limit for message
// listing — its use is not visible here; TODO confirm.
defaultMessageLimit = 10
// maxMessageLimit is the upper bound queryMessage applies to top_n.
maxMessageLimit = 100
)
// validPermissions defines which permission values are valid
var validPermissions = map[TenantPermission]bool{
TenantPermissionMe: true,
@@ -793,7 +802,7 @@ func (s *MemoryService) ForgetMessage(ctx context.Context, userID string, memory
condition := map[string]interface{}{
"id": messageDocID,
}
indexName := fmt.Sprintf("memory_%s", memory.TenantID)
indexName := memoryIndexName(memory.TenantID)
if err := s.docEngine.UpdateChunks(ctx, condition, updates, indexName, memoryID); err != nil {
if isMessageDocumentNotFound(err) {
@@ -807,6 +816,453 @@ func (s *MemoryService) ForgetMessage(ctx context.Context, userID string, memory
return nil
}
// UpdateMessage sets the status of one message in a memory.
//
// The caller must pass requireMemoryAccess for memoryID. The message document
// is addressed as "<memoryID>_<messageID>" in the tenant's memory index and its
// "status" is written as 1 (true) or 0 (false). It returns true on success; a
// missing document is reported as *ResourceNotFoundError, and any other engine
// failure is wrapped with the message and memory identifiers.
func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) {
	memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
	if err != nil {
		return false, err
	}
	if s.docEngine == nil {
		return false, errors.New("message store is not initialized")
	}

	docID := fmt.Sprintf("%s_%d", memoryID, messageID)
	var statusInt int
	if status {
		statusInt = 1
	}

	err = s.docEngine.UpdateChunks(
		ctx,
		map[string]interface{}{"id": docID},
		map[string]interface{}{"status": statusInt},
		memoryIndexName(memory.TenantID),
		memoryID,
	)
	switch {
	case err == nil:
		return true, nil
	case isMessageDocumentNotFound(err):
		return false, &ResourceNotFoundError{Resource: "Message", ID: docID}
	default:
		return false, fmt.Errorf("failed to set status for message '%d' in memory '%s': %w", messageID, memoryID, err)
	}
}
// GetMessageContent returns the stored document of one message in a memory.
//
// The caller must pass requireMemoryAccess for memoryID. The document is looked
// up as "<memoryID>_<messageID>" in the tenant's memory index. A missing
// document (either a not-found error or a nil result from the engine) is
// reported as *ResourceNotFoundError. The document is passed through
// common.ConvertFloatsToPyFormat before being returned; an unexpected type at
// either step is returned as an error rather than triggering a panic.
func (s *MemoryService) GetMessageContent(ctx context.Context, userID, memoryID string, messageID int64) (map[string]interface{}, error) {
	memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
	if err != nil {
		return nil, err
	}
	if s.docEngine == nil {
		return nil, errors.New("message store is not initialized")
	}

	indexName := memoryIndexName(memory.TenantID)
	docID := fmt.Sprintf("%s_%d", memoryID, messageID)

	res, err := s.docEngine.GetChunk(ctx, indexName, docID, []string{memoryID})
	if err != nil {
		if isMessageDocumentNotFound(err) {
			return nil, &ResourceNotFoundError{Resource: "Message", ID: docID}
		}
		return nil, err
	}
	if res == nil {
		return nil, &ResourceNotFoundError{Resource: "Message", ID: docID}
	}

	message, ok := res.(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("unexpected message content type %T", res)
	}

	// Use a checked assertion: an unexpected result type from the conversion
	// must surface as an error, not as a panic in the request path.
	formatted := common.ConvertFloatsToPyFormat(message)
	converted, ok := formatted.(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("unexpected converted message content type %T", formatted)
	}
	return converted, nil
}
// SearchMessage searches messages across the memories named in
// filterDict["memory_id"], restricted to those returned by
// filterAccessibleMemories for userID.
//
// A failure while resolving access is returned with common.CodeServerError.
// When no memory remains, an empty (non-nil) slice is returned with
// common.CodeSuccess. Otherwise the result of queryMessage is returned as-is.
func (s *MemoryService) SearchMessage(ctx context.Context, userID string, filterDict, params map[string]interface{}) ([]map[string]interface{}, common.ErrorCode, error) {
	requested := splitFilterValues(filterDict["memory_id"])
	accessible, err := s.filterAccessibleMemories(ctx, userID, requested)
	switch {
	case err != nil:
		return nil, common.CodeServerError, err
	case len(accessible) == 0:
		return []map[string]interface{}{}, common.CodeSuccess, nil
	}
	return s.queryMessage(ctx, accessible, filterDict, params)
}
// queryMessage runs a message search over the given memories.
//
// Recognised params:
//   - "top_n": result limit (default 5); non-positive values fall back to
//     defaultMessageTopN and values above maxMessageLimit are clamped.
//   - "similarity_threshold": default 0.2.
//   - "keywords_similarity_weight": default 0.7.
//   - "query": free-text question; when blank no match expressions are built
//     and the search is a plain filtered listing.
//
// filterDict may carry "agent_id", "session_id" and "user_id"; blank values
// are ignored. The filter always includes the collected memory IDs and
// status = 1. Results are ordered by "valid_at" descending and projected onto
// memoryMessageSelectFields.
//
// Nil entries in memories are skipped. If no non-nil memory remains, an empty
// result is returned without querying the engine, so the store is never
// searched with an empty memory_id filter. The embedding model for the dense
// expression is taken from the first non-nil memory.
func (s *MemoryService) queryMessage(ctx context.Context, memories []*entity.Memory, filterDict, params map[string]interface{}) ([]map[string]interface{}, common.ErrorCode, error) {
	if s.docEngine == nil {
		return nil, common.CodeServerError, errors.New("message store is not initialized")
	}

	topN := memoryIntParam(params["top_n"], 5)
	if topN <= 0 {
		topN = defaultMessageTopN
	} else if topN > maxMessageLimit {
		topN = maxMessageLimit
	}
	similarityThreshold := memoryFloatParam(params["similarity_threshold"], 0.2)
	keywordsSimilarityWeight := memoryFloatParam(params["keywords_similarity_weight"], 0.7)
	question := strings.TrimSpace(memoryStringParam(params["query"]))

	// Collect the usable memories once. The first one also supplies the
	// embedding configuration, so it must not be read from memories[0]
	// directly: that entry may be nil.
	var embeddingMemory *entity.Memory
	memoryIDs := make([]string, 0, len(memories))
	for _, memory := range memories {
		if memory == nil {
			continue
		}
		if embeddingMemory == nil {
			embeddingMemory = memory
		}
		memoryIDs = append(memoryIDs, memory.ID)
	}
	if embeddingMemory == nil {
		return []map[string]interface{}{}, common.CodeSuccess, nil
	}

	conditionDict := map[string]interface{}{
		"memory_id": memoryIDs,
		"status":    1,
	}
	for _, key := range []string{"agent_id", "session_id", "user_id"} {
		if value := strings.TrimSpace(memoryStringParam(filterDict[key])); value != "" {
			conditionDict[key] = value
		}
	}

	matchExprs := make([]interface{}, 0, 3)
	if question != "" {
		matchText := memoryMessageTextExpr(question, similarityThreshold)
		matchDense, err := s.memoryMessageDenseExpr(question, embeddingMemory, topN, similarityThreshold)
		if err != nil {
			return nil, common.CodeServerError, err
		}
		// The weights string is "<1-w>,<w>" with w = keywords_similarity_weight.
		fusionExpr := &enginetypes.FusionExpr{
			Method: "weighted_sum",
			TopN:   topN,
			FusionParams: map[string]interface{}{
				"weights": fmt.Sprintf("%g,%g", 1-keywordsSimilarityWeight, keywordsSimilarityWeight),
			},
		}
		matchExprs = append(matchExprs, matchText, matchDense, fusionExpr)
	}

	searchReq := &enginetypes.SearchRequest{
		IndexNames:   memorySearchIndexNames(memories),
		Offset:       0,
		Limit:        topN,
		SelectFields: memoryMessageSelectFields(),
		Filter:       conditionDict,
		MatchExprs:   matchExprs,
		OrderBy:      (&enginetypes.OrderByExpr{}).Desc("valid_at"),
	}
	searchResult, err := s.docEngine.Search(ctx, searchReq)
	if err != nil {
		return nil, common.CodeServerError, err
	}
	if searchResult == nil || searchResult.Total == 0 {
		return []map[string]interface{}{}, common.CodeSuccess, nil
	}

	// Keep only the selected fields; anything else the engine returned is dropped.
	messages := make([]map[string]interface{}, 0, len(searchResult.Chunks))
	for _, chunk := range searchResult.Chunks {
		message := make(map[string]interface{}, len(chunk))
		for _, field := range memoryMessageSelectFields() {
			if value, ok := chunk[field]; ok {
				message[field] = value
			}
		}
		messages = append(messages, message)
	}
	return common.ConvertFloatsToPyFormat(messages).([]map[string]interface{}), common.CodeSuccess, nil
}
// filterAccessibleMemories resolves memoryIDs to the memories userID may read.
//
// A memory is kept when its TenantID equals userID, or when its Permissions
// equal TenantPermissionTeam and its TenantID is among the tenants returned by
// the user-tenant relation lookup. That lookup is only performed when at least
// one loaded memory is a team-shared memory of another tenant. The returned
// slice is never nil on success.
func (s *MemoryService) filterAccessibleMemories(ctx context.Context, userID string, memoryIDs []string) ([]*entity.Memory, error) {
	ids := splitFilterValues(memoryIDs)
	if len(ids) == 0 {
		return []*entity.Memory{}, nil
	}
	loaded, err := s.memoryDAO.GetByIDs(ids)
	if err != nil {
		return nil, err
	}
	if len(loaded) == 0 {
		return []*entity.Memory{}, nil
	}

	teamShared := string(TenantPermissionTeam)
	tenantSet := map[string]struct{}{userID: {}}
	for _, candidate := range loaded {
		foreignTeamMemory := candidate != nil && candidate.TenantID != userID && candidate.Permissions == teamShared
		if !foreignTeamMemory {
			continue
		}
		// One foreign team memory is enough to require the lookup; do it once.
		relations, err := NewUserTenantService().GetUserTenantRelationByUserIDWithContext(ctx, userID)
		if err != nil {
			return nil, err
		}
		for _, relation := range relations {
			if relation != nil && relation.TenantID != "" {
				tenantSet[relation.TenantID] = struct{}{}
			}
		}
		break
	}

	permitted := make([]*entity.Memory, 0, len(loaded))
	for _, candidate := range loaded {
		if candidate == nil {
			continue
		}
		if candidate.TenantID != userID {
			if candidate.Permissions != teamShared {
				continue
			}
			if _, joined := tenantSet[candidate.TenantID]; !joined {
				continue
			}
		}
		permitted = append(permitted, candidate)
	}
	return permitted, nil
}
// splitFilterValues normalises a filter value into a flat list of non-empty,
// whitespace-trimmed strings.
//
// Accepted inputs are a string, a []string, or a []interface{} (non-string
// elements are ignored). Every string is additionally split on commas. Nil and
// unsupported types yield an empty list. The result is never nil.
func splitFilterValues(values interface{}) []string {
	var candidates []string
	switch typed := values.(type) {
	case nil:
		return []string{}
	case string:
		candidates = append(candidates, typed)
	case []string:
		candidates = typed
	case []interface{}:
		for _, element := range typed {
			if text, isString := element.(string); isString {
				candidates = append(candidates, text)
			}
		}
	default:
		return []string{}
	}

	out := []string{}
	for _, candidate := range candidates {
		for _, piece := range strings.Split(candidate, ",") {
			if trimmed := strings.TrimSpace(piece); trimmed != "" {
				out = append(out, trimmed)
			}
		}
	}
	return out
}
// memoryStringParam converts a loosely typed parameter to a string.
//
// Nil yields "", strings are returned unchanged, fmt.Stringer values use
// String(), and everything else is formatted with "%v".
func memoryStringParam(value interface{}) string {
	if value == nil {
		return ""
	}
	if text, isString := value.(string); isString {
		return text
	}
	if stringer, isStringer := value.(fmt.Stringer); isStringer {
		return stringer.String()
	}
	return fmt.Sprintf("%v", value)
}
// memoryFloatParam converts a loosely typed parameter to a float64.
//
// float64, float32, int and int64 values are converted directly. Strings are
// trimmed and parsed with strconv.ParseFloat. Any other type, or a string that
// does not parse, yields fallback.
func memoryFloatParam(value interface{}, fallback float64) float64 {
	if f, ok := value.(float64); ok {
		return f
	}
	if f, ok := value.(float32); ok {
		return float64(f)
	}
	if i, ok := value.(int); ok {
		return float64(i)
	}
	if i, ok := value.(int64); ok {
		return float64(i)
	}
	text, isString := value.(string)
	if !isString {
		return fallback
	}
	parsed, err := strconv.ParseFloat(strings.TrimSpace(text), 64)
	if err != nil {
		return fallback
	}
	return parsed
}
// memoryIntParam converts a loosely typed parameter to an int.
//
// int and int64 values are converted directly, float64 values are truncated
// by an int conversion, and strings are trimmed and parsed with strconv.Atoi.
// Any other type, or a string that does not parse, yields fallback.
func memoryIntParam(value interface{}, fallback int) int {
	if n, ok := value.(int); ok {
		return n
	}
	if n, ok := value.(int64); ok {
		return int(n)
	}
	if f, ok := value.(float64); ok {
		return int(f)
	}
	text, isString := value.(string)
	if !isString {
		return fallback
	}
	parsed, err := strconv.Atoi(strings.TrimSpace(text))
	if err != nil {
		return fallback
	}
	return parsed
}
// memoryMessageSelectFields returns the fields requested from the message
// store, and kept in responses, by the search and recent-message queries.
// A fresh slice is returned on every call.
func memoryMessageSelectFields() []string {
	const fieldCount = 12
	fields := make([]string, 0, fieldCount)
	fields = append(fields, "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id")
	fields = append(fields, "valid_at", "invalid_at", "forget_at", "status")
	fields = append(fields, "content")
	return fields
}
// memoryIndexName builds the message index name for a tenant.
//
// The result is "memory_<tenantID>", or "memory_<prefix>_<tenantID>" when the
// ES_INDEX_PREFIX environment variable is set to a non-blank value.
func memoryIndexName(tenantID string) string {
	suffix := tenantID
	if prefix := strings.TrimSpace(os.Getenv("ES_INDEX_PREFIX")); prefix != "" {
		suffix = fmt.Sprintf("%s_%s", prefix, tenantID)
	}
	return fmt.Sprintf("memory_%s", suffix)
}
// memorySearchIndexNames returns the distinct index names to search for the
// given memories, in first-seen order. Nil entries are skipped.
//
// Each name starts from memoryIndexName(TenantID). When the configured engine
// is engine.EngineInfinity the memory ID is appended as "<index>_<memoryID>".
func memorySearchIndexNames(memories []*entity.Memory) []string {
	names := make([]string, 0, len(memories))
	visited := make(map[string]struct{}, len(memories))
	for _, m := range memories {
		if m == nil {
			continue
		}
		name := memoryIndexName(m.TenantID)
		if engine.GetEngineType() == engine.EngineInfinity {
			name = fmt.Sprintf("%s_%s", name, m.ID)
		}
		if _, duplicate := visited[name]; !duplicate {
			visited[name] = struct{}{}
			names = append(names, name)
		}
	}
	return names
}
// memoryMessageTextExpr builds the full-text match expression for a question.
//
// The expression targets the "content" field with TopN 100. When the NLP query
// builder returns a non-nil result, its MatchingText and ExtraOptions replace
// the raw question; otherwise the raw question is used as the matching text.
// ExtraOptions always carries the original question under "original_query".
// The error returned by the query builder is intentionally ignored.
func memoryMessageTextExpr(question string, similarityThreshold float64) *enginetypes.MatchTextExpr {
	builder := nlp.GetQueryBuilder()
	if builder == nil {
		builder = nlp.NewQueryBuilder()
	}

	expr := &enginetypes.MatchTextExpr{
		Fields:       []string{"content"},
		MatchingText: question,
		TopN:         100,
		ExtraOptions: map[string]interface{}{"original_query": question},
	}
	built, _ := builder.Question(question, "messages", similarityThreshold)
	if built == nil {
		return expr
	}

	expr.MatchingText = built.MatchingText
	expr.ExtraOptions = built.ExtraOptions
	if expr.ExtraOptions == nil {
		expr.ExtraOptions = map[string]interface{}{}
	}
	expr.ExtraOptions["original_query"] = question
	return expr
}
// memoryMessageDenseExpr embeds question with the memory's embedding model and
// returns the corresponding dense (vector) match expression.
//
// The model configuration is resolved from memory.TenantID and memory.EmbdID.
// The vector column name is "q_<dimension>_vec", where dimension is the length
// of the returned embedding. An error is returned when the model cannot be
// resolved, the embedding call fails, or the response contains no vector.
// memory must not be nil.
func (s *MemoryService) memoryMessageDenseExpr(question string, memory *entity.Memory, topN int, similarityThreshold float64) (*enginetypes.MatchDenseExpr, error) {
	driver, modelName, apiConfig, maxTokens, err := NewModelProviderService().GetModelConfigFromProviderInstance(memory.TenantID, entity.ModelTypeEmbedding, memory.EmbdID)
	if err != nil {
		return nil, err
	}

	model := models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
	results, err := model.ModelDriver.Embed(model.ModelName, []string{question}, model.APIConfig, &models.EmbeddingConfig{Dimension: 0})
	if err != nil {
		return nil, err
	}
	if len(results) == 0 || len(results[0].Embedding) == 0 {
		return nil, errors.New("embedding response is empty")
	}

	queryVector := results[0].Embedding
	dense := &enginetypes.MatchDenseExpr{
		VectorColumnName:  fmt.Sprintf("q_%d_vec", len(queryVector)),
		EmbeddingData:     queryVector,
		EmbeddingDataType: "float",
		DistanceType:      "cosine",
		TopN:              topN,
		ExtraOptions:      map[string]interface{}{"similarity": similarityThreshold},
	}
	return dense, nil
}
// GetMessages returns the most recent messages of the given memories,
// restricted to those userID may access and optionally filtered by agentID
// and sessionID.
//
// An access-lookup failure is reported with common.CodeServerError. When no
// requested memory is accessible an empty, non-nil list is returned with
// common.CodeSuccess; otherwise the lookup is delegated to getRecentMessage.
func (s *MemoryService) GetMessages(ctx context.Context, memoryIDs []string, userID, agentID, sessionID string, limit int) ([]map[string]interface{}, common.ErrorCode, error) {
	accessible, err := s.filterAccessibleMemories(ctx, userID, memoryIDs)
	switch {
	case err != nil:
		return nil, common.CodeServerError, err
	case len(accessible) == 0:
		return []map[string]interface{}{}, common.CodeSuccess, nil
	default:
		return s.getRecentMessage(ctx, accessible, agentID, sessionID, limit)
	}
}
// getRecentMessage lists messages of the given memories ordered by "valid_at"
// descending, without any match expressions.
//
// A non-positive limit falls back to defaultMessageLimit and a limit above
// maxMessageLimit is clamped. Nil memories and memories with a blank ID are
// left out of the memory_id filter. Blank agentID/sessionID values are not
// added to the filter. Each result is projected onto
// memoryMessageSelectFields.
func (s *MemoryService) getRecentMessage(ctx context.Context, memories []*entity.Memory, agentID, sessionID string, limit int) ([]map[string]interface{}, common.ErrorCode, error) {
	if s.docEngine == nil {
		return nil, common.CodeServerError, errors.New("doc engine is nil")
	}
	switch {
	case limit <= 0:
		limit = defaultMessageLimit
	case limit > maxMessageLimit:
		limit = maxMessageLimit
	}

	ids := make([]string, 0, len(memories))
	for _, m := range memories {
		if m != nil && strings.TrimSpace(m.ID) != "" {
			ids = append(ids, m.ID)
		}
	}
	filter := map[string]interface{}{"memory_id": ids}
	if trimmed := strings.TrimSpace(agentID); trimmed != "" {
		filter["agent_id"] = trimmed
	}
	if trimmed := strings.TrimSpace(sessionID); trimmed != "" {
		filter["session_id"] = trimmed
	}

	result, err := s.docEngine.Search(ctx, &enginetypes.SearchRequest{
		IndexNames:   memorySearchIndexNames(memories),
		Offset:       0,
		Limit:        limit,
		SelectFields: memoryMessageSelectFields(),
		Filter:       filter,
		MatchExprs:   []interface{}{},
		OrderBy:      (&enginetypes.OrderByExpr{}).Desc("valid_at"),
	})
	if err != nil {
		return nil, common.CodeServerError, err
	}
	if result == nil || result.Total == 0 {
		return []map[string]interface{}{}, common.CodeSuccess, nil
	}

	// Copy only the selected fields so engine-specific extras are not leaked.
	recent := make([]map[string]interface{}, 0, len(result.Chunks))
	for _, chunk := range result.Chunks {
		projected := make(map[string]interface{}, len(chunk))
		for _, name := range memoryMessageSelectFields() {
			value, present := chunk[name]
			if !present {
				continue
			}
			projected[name] = value
		}
		recent = append(recent, projected)
	}
	return common.ConvertFloatsToPyFormat(recent).([]map[string]interface{}), common.CodeSuccess, nil
}
// isMessageDocumentNotFound reports whether err is, or wraps,
// enginetypes.ErrDocumentNotFound.
func isMessageDocumentNotFound(err error) bool {
	sentinel := enginetypes.ErrDocumentNotFound
	return errors.Is(err, sentinel)
}
@@ -933,8 +1389,266 @@ func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse,
return formatRetDataFromMemoryListItem(memory), nil
}
// TODO: GetMemoryMessages - Implementation pending - depends on CanvasService and TaskService
// func (s *MemoryService) GetMemoryMessages(memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) { ... }
// GetMemoryMessages returns one page of raw messages of a memory, enriched for
// display.
//
// After the access check, messages are loaded via listMemoryMessages. Each
// raw message gets an "agent_name" ("Unknown" when the agent is not found) and
// a "task" (an empty map when no task matches its message_id). Entries under
// "extract" receive an "agent_name" resolved from the same name table. The
// result has the keys "messages" and "storage_type".
func (s *MemoryService) GetMemoryMessages(ctx context.Context, userID, memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) {
	memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
	if err != nil {
		return nil, err
	}
	listing, err := s.listMemoryMessages(ctx, memory, agentIDs, keywords, page, pageSize)
	if err != nil {
		return nil, err
	}

	rawList, _ := listing["message_list"].([]map[string]interface{})
	nameByAgent := map[string]string{}
	taskByMessage := map[string]map[string]interface{}{}
	// Skip both lookups entirely when the page is empty.
	if len(rawList) > 0 {
		if nameByAgent, err = s.memoryMessageAgentNames(rawList); err != nil {
			return nil, err
		}
		if taskByMessage, err = s.memoryMessageTasks(memoryID); err != nil {
			return nil, err
		}
	}
	displayName := func(agentID string) string {
		if name, known := nameByAgent[agentID]; known {
			return name
		}
		return "Unknown"
	}

	for _, entry := range rawList {
		agentID, _ := entry["agent_id"].(string)
		entry["agent_name"] = displayName(agentID)

		task, hasTask := taskByMessage[memoryMessageKey(entry["message_id"])]
		if !hasTask {
			task = map[string]interface{}{}
		}
		entry["task"] = task

		extracts, hasExtracts := entry["extract"].([]map[string]interface{})
		if !hasExtracts {
			continue
		}
		for _, extract := range extracts {
			extractAgentID, _ := extract["agent_id"].(string)
			extract["agent_name"] = displayName(extractAgentID)
		}
	}

	payload := map[string]interface{}{
		"messages":     listing,
		"storage_type": memory.StorageType,
	}
	return common.ConvertFloatsToPyFormat(payload).(map[string]interface{}), nil
}
// listMemoryMessages loads one page of raw messages of a memory together with
// the messages extracted from them.
//
// page defaults to 1; pageSize defaults to defaultMessageLimit and is clamped
// to maxMessageLimit. Raw messages are those with message_type "raw",
// optionally filtered by agentIDs; a non-blank keywords value is placed in the
// filter under "session_id". A second query (limit 512) fetches messages whose
// source_id is one of the raw message IDs and attaches them to their raw
// message under "extract".
//
// The result has the keys "message_list" ([]map[string]interface{}) and
// "total_count" (the total reported by the raw query, int64(0) when the engine
// returns no result).
func (s *MemoryService) listMemoryMessages(ctx context.Context, memory *entity.Memory, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) {
	if s.docEngine == nil {
		return nil, errors.New("message store is not initialized")
	}
	if page <= 0 {
		page = 1
	}
	switch {
	case pageSize <= 0:
		pageSize = defaultMessageLimit
	case pageSize > maxMessageLimit:
		pageSize = maxMessageLimit
	}

	memoryID := memory.ID
	fields := memoryMessageListFields()
	indexNames := memorySearchIndexNames([]*entity.Memory{memory})

	rawFilter := map[string]interface{}{
		"message_type": "raw",
		"memory_id":    []string{memoryID},
	}
	if len(agentIDs) > 0 {
		rawFilter["agent_id"] = agentIDs
	}
	if trimmed := strings.TrimSpace(keywords); trimmed != "" {
		rawFilter["session_id"] = trimmed
	}

	rawResult, err := s.docEngine.Search(ctx, &enginetypes.SearchRequest{
		IndexNames:   indexNames,
		Offset:       (page - 1) * pageSize,
		Limit:        pageSize,
		SelectFields: fields,
		Filter:       rawFilter,
		MatchExprs:   []interface{}{},
		OrderBy:      (&enginetypes.OrderByExpr{}).Desc("valid_at"),
	})
	if err != nil {
		return nil, err
	}

	listing := map[string]interface{}{
		"message_list": []map[string]interface{}{},
		"total_count":  int64(0),
	}
	if rawResult == nil {
		return listing, nil
	}
	listing["total_count"] = rawResult.Total
	if rawResult.Total == 0 {
		return listing, nil
	}

	rawMessages := make([]map[string]interface{}, 0, len(rawResult.Chunks))
	sourceIDs := make([]interface{}, 0, len(rawResult.Chunks))
	for _, chunk := range rawResult.Chunks {
		raw := memoryMessageFromChunk(chunk, fields)
		// Every raw message carries an "extract" list, even when it stays empty.
		raw["extract"] = []map[string]interface{}{}
		if id, hasID := raw["message_id"]; hasID {
			sourceIDs = append(sourceIDs, id)
		}
		rawMessages = append(rawMessages, raw)
	}
	listing["message_list"] = rawMessages
	if len(sourceIDs) == 0 {
		return listing, nil
	}

	extractResult, err := s.docEngine.Search(ctx, &enginetypes.SearchRequest{
		IndexNames:   indexNames,
		Offset:       0,
		Limit:        512,
		SelectFields: fields,
		Filter: map[string]interface{}{
			"memory_id": []string{memoryID},
			"source_id": sourceIDs,
		},
		MatchExprs: []interface{}{},
		OrderBy:    (&enginetypes.OrderByExpr{}).Desc("valid_at"),
	})
	if err != nil {
		return nil, err
	}
	if extractResult == nil || extractResult.Total <= 0 {
		return listing, nil
	}

	// Group extracted messages by the raw message they were derived from.
	bySource := make(map[string][]map[string]interface{})
	for _, chunk := range extractResult.Chunks {
		extract := memoryMessageFromChunk(chunk, fields)
		key := memoryMessageKey(extract["source_id"])
		bySource[key] = append(bySource[key], extract)
	}
	for _, raw := range rawMessages {
		if extracts, found := bySource[memoryMessageKey(raw["message_id"])]; found {
			raw["extract"] = extracts
		}
	}
	return listing, nil
}
// memoryMessageListFields returns the fields requested when listing the
// messages of a memory. The list does not include "content". A fresh slice is
// returned on every call.
func memoryMessageListFields() []string {
	const fieldCount = 11
	fields := make([]string, 0, fieldCount)
	fields = append(fields, "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id")
	fields = append(fields, "valid_at", "invalid_at", "forget_at", "status")
	return fields
}
// memoryMessageFromChunk copies the listed fields from chunk into a new map.
// Fields that are absent from chunk are omitted; fields present with a nil
// value are kept. The result is never nil.
func memoryMessageFromChunk(chunk map[string]interface{}, fields []string) map[string]interface{} {
	projected := make(map[string]interface{}, len(fields))
	for i := 0; i < len(fields); i++ {
		name := fields[i]
		value, present := chunk[name]
		if !present {
			continue
		}
		projected[name] = value
	}
	return projected
}
// memoryMessageAgentNames maps the agent IDs referenced by messages to their
// titles.
//
// Distinct non-empty "agent_id" strings are collected from messages and looked
// up in the entity.UserCanvas table. Agents without a title are left out of
// the result. No query is issued when no agent ID is present.
func (s *MemoryService) memoryMessageAgentNames(messages []map[string]interface{}) (map[string]string, error) {
	seen := make(map[string]struct{})
	agentIDs := make([]string, 0)
	for _, message := range messages {
		id, _ := message["agent_id"].(string)
		if id == "" {
			continue
		}
		if _, duplicate := seen[id]; duplicate {
			continue
		}
		seen[id] = struct{}{}
		agentIDs = append(agentIDs, id)
	}
	if len(agentIDs) == 0 {
		return map[string]string{}, nil
	}

	var rows []struct {
		ID    string  `gorm:"column:id"`
		Title *string `gorm:"column:title"`
	}
	query := dao.DB.Model(&entity.UserCanvas{}).Select("id, title").Where("id IN ?", agentIDs)
	if err := query.Scan(&rows).Error; err != nil {
		return nil, err
	}

	titles := make(map[string]string, len(rows))
	for i := range rows {
		if title := rows[i].Title; title != nil {
			titles[rows[i].ID] = *title
		}
	}
	return titles, nil
}
// memoryMessageTasks loads the tasks whose doc_id equals memoryID and indexes
// them by their trimmed digest.
//
// Tasks with a nil or blank digest are skipped. Rows are read in ascending
// create_time order, so for a repeated digest the later row replaces the
// earlier one. Nullable columns (progress_msg, chunk_ids, create_time) are
// exposed as nil when unset.
func (s *MemoryService) memoryMessageTasks(memoryID string) (map[string]map[string]interface{}, error) {
	var rows []struct {
		ID          string  `gorm:"column:id"`
		DocID       string  `gorm:"column:doc_id"`
		FromPage    int64   `gorm:"column:from_page"`
		Progress    float64 `gorm:"column:progress"`
		ProgressMsg *string `gorm:"column:progress_msg"`
		Digest      *string `gorm:"column:digest"`
		ChunkIDs    *string `gorm:"column:chunk_ids"`
		CreateTime  *int64  `gorm:"column:create_time"`
	}
	query := dao.DB.Model(&entity.Task{}).
		Select("id, doc_id, from_page, progress, progress_msg, digest, chunk_ids, create_time").
		Where("doc_id IN ?", []string{memoryID}).
		Order("create_time ASC")
	if err := query.Scan(&rows).Error; err != nil {
		return nil, err
	}

	// optionalString yields an untyped nil for unset columns, not a typed nil pointer.
	optionalString := func(p *string) interface{} {
		if p == nil {
			return nil
		}
		return *p
	}

	byDigest := make(map[string]map[string]interface{}, len(rows))
	for i := range rows {
		row := &rows[i]
		if row.Digest == nil {
			continue
		}
		digest := strings.TrimSpace(*row.Digest)
		if digest == "" {
			continue
		}

		var createTime interface{}
		if row.CreateTime != nil {
			createTime = *row.CreateTime
		}
		byDigest[digest] = map[string]interface{}{
			"id":           row.ID,
			"doc_id":       row.DocID,
			"from_page":    row.FromPage,
			"progress":     row.Progress,
			"progress_msg": optionalString(row.ProgressMsg),
			"digest":       digest,
			"chunk_ids":    optionalString(row.ChunkIDs),
			"create_time":  createTime,
		}
	}
	return byDigest, nil
}
// memoryMessageKey renders a message identifier as a map key.
//
// Identifiers can arrive as integers, as strings, or as floating-point numbers
// (for example after JSON decoding into interface{}). Integral float values
// within ±2^53 are rendered as plain decimal integers, so float64(1234567)
// produces "1234567" and matches the same identifier held as an int64 or as a
// decimal string. The default fmt formatting would produce "1.234567e+06"
// instead. All other values use fmt.Sprint with surrounding whitespace
// trimmed.
func memoryMessageKey(value interface{}) string {
	var number float64
	switch typed := value.(type) {
	case float64:
		number = typed
	case float32:
		number = float64(typed)
	default:
		return strings.TrimSpace(fmt.Sprint(value))
	}

	// Integers up to 2^53 in magnitude are exactly representable as float64,
	// so the round trip through int64 is lossless inside this range. NaN fails
	// every comparison and falls through to the generic formatting.
	const maxExactInteger = 1 << 53
	if number >= -maxExactInteger && number <= maxExactInteger && number == float64(int64(number)) {
		return strconv.FormatInt(int64(number), 10)
	}
	return strings.TrimSpace(fmt.Sprint(value))
}
// TODO: queryMessages - Implementation pending - depends on CanvasService and TaskService
// func (s *MemoryService) queryMessages(tenantID string, memoryID string, filterDict map[string]interface{}, page int, pageSize int) ([]map[string]interface{}, int64, error) { ... }
@@ -945,15 +1659,6 @@ func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse,
// TODO: UpdateMessageStatus - Implementation pending - depends on embedding engine
// func (s *MemoryService) UpdateMessageStatus(memoryID string, messageID int, status bool) (bool, error) { ... }
// TODO: SearchMessage - Implementation pending - depends on embedding engine
// func (s *MemoryService) SearchMessage(filterDict map[string]interface{}, params map[string]interface{}) ([]map[string]interface{}, error) { ... }
// TODO: GetMessages - Implementation pending - depends on embedding engine
// func (s *MemoryService) GetMessages(memoryIDs []string, agentID string, sessionID string, limit int) ([]map[string]interface{}, error) { ... }
// TODO: GetMessageContent - Implementation pending - depends on embedding engine
// func (s *MemoryService) GetMessageContent(memoryID string, messageID int) (map[string]interface{}, error) { ... }
// isList checks if a value is a list or array type
// This is a utility function for type validation
//

View File

@@ -4,9 +4,16 @@ import (
"context"
"errors"
"fmt"
"reflect"
"testing"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"ragflow/internal/common"
"ragflow/internal/dao"
enginetypes "ragflow/internal/engine/types"
"ragflow/internal/entity"
)
func TestIsMessageDocumentNotFound(t *testing.T) {
@@ -28,3 +35,240 @@ func TestRequireMemoryAccessReturnsCanceledContext(t *testing.T) {
t.Fatalf("requireMemoryAccess error = %v, want %v", gotErr, err)
}
}
// memoryMessageDocEngine is a test double for the document engine. It embeds
// fakeChatDocEngine and overrides Search and UpdateChunks to record their
// arguments for later assertions.
type memoryMessageDocEngine struct {
	fakeChatDocEngine

	// Search: last request received and the canned response to return.
	searchReq  *enginetypes.SearchRequest
	searchResp *enginetypes.SearchResult

	// UpdateChunks: arguments of the last call.
	updateCond, updateValue map[string]interface{}
	updateBase, updateID    string
}
// Search records req and returns the canned searchResp, or an empty result
// when none is configured. It never returns an error.
func (e *memoryMessageDocEngine) Search(ctx context.Context, req *enginetypes.SearchRequest) (*enginetypes.SearchResult, error) {
	e.searchReq = req
	if e.searchResp == nil {
		return &enginetypes.SearchResult{}, nil
	}
	return e.searchResp, nil
}
// UpdateChunks records all arguments of the call and reports success.
func (e *memoryMessageDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error {
	e.updateCond, e.updateValue = condition, newValue
	e.updateBase, e.updateID = baseName, datasetID
	return nil
}
// setupMemoryMessageTestDB points dao.DB at a fresh in-memory SQLite database
// with the Memory and UserTenant tables migrated, and restores the previous
// dao.DB when the test finishes.
func setupMemoryMessageTestDB(t *testing.T) {
	t.Helper()

	testDB, openErr := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{TranslateError: true})
	if openErr != nil {
		t.Fatalf("failed to open sqlite: %v", openErr)
	}
	migrateErr := testDB.AutoMigrate(&entity.Memory{}, &entity.UserTenant{})
	if migrateErr != nil {
		t.Fatalf("failed to migrate memory test tables: %v", migrateErr)
	}

	previous := dao.DB
	t.Cleanup(func() { dao.DB = previous })
	dao.DB = testDB
}
// seedMemoryMessages inserts two private memories into dao.DB: "mem-owned"
// owned by tenant "user-1" and "mem-other" owned by tenant "user-2".
func seedMemoryMessages(t *testing.T) {
	t.Helper()

	// The two fixtures differ only in identity, owner and model IDs.
	newMemory := func(id, name, tenantID, embdID, llmID string) *entity.Memory {
		return &entity.Memory{
			ID:               id,
			Name:             name,
			TenantID:         tenantID,
			MemoryType:       dao.MemoryTypeRaw,
			StorageType:      "table",
			EmbdID:           embdID,
			LLMID:            llmID,
			Permissions:      string(TenantPermissionMe),
			ForgettingPolicy: string(ForgettingPolicyFIFO),
		}
	}
	fixtures := []*entity.Memory{
		newMemory("mem-owned", "Owned", "user-1", "embd-1", "llm-1"),
		newMemory("mem-other", "Other", "user-2", "embd-2", "llm-2"),
	}
	for _, fixture := range fixtures {
		err := dao.DB.Create(fixture).Error
		if err != nil {
			t.Fatalf("seed memory %s: %v", fixture.ID, err)
		}
	}
}
// TestGetMessagesFiltersAccessibleMemoryAndBuildsRecentSearch verifies that
// GetMessages drops memories the user cannot access, projects results onto the
// selected fields, and issues a search limited and ordered as requested.
func TestGetMessagesFiltersAccessibleMemoryAndBuildsRecentSearch(t *testing.T) {
	setupMemoryMessageTestDB(t)
	seedMemoryMessages(t)

	storedChunk := map[string]interface{}{
		"message_id":   int64(12),
		"message_type": "raw",
		"memory_id":    "mem-owned",
		"user_id":      "user-1",
		"agent_id":     "agent-1",
		"session_id":   "session-1",
		"valid_at":     float64(123),
		"status":       1,
		"content":      "hello",
		"extra":        "should be dropped",
	}
	docEngine := &memoryMessageDocEngine{
		searchResp: &enginetypes.SearchResult{
			Total:  1,
			Chunks: []map[string]interface{}{storedChunk},
		},
	}
	svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine}

	got, code, err := svc.GetMessages(context.Background(), []string{"mem-owned", "mem-other"}, "user-1", "agent-1", "session-1", 3)
	if err != nil {
		t.Fatalf("GetMessages error: %v", err)
	}
	if code != common.CodeSuccess {
		t.Fatalf("code = %v, want %v", code, common.CodeSuccess)
	}
	if len(got) != 1 || got[0]["content"] != "hello" {
		t.Fatalf("unexpected messages: %+v", got)
	}
	if _, leaked := got[0]["extra"]; leaked {
		t.Fatalf("unexpected non-selected field in response: %+v", got[0])
	}

	// Inspect the request the service handed to the engine.
	req := docEngine.searchReq
	if req == nil {
		t.Fatal("expected doc engine search request")
	}
	if wantIndexes := []string{"memory_user-1"}; !reflect.DeepEqual(req.IndexNames, wantIndexes) {
		t.Fatalf("IndexNames = %v, want [memory_user-1]", req.IndexNames)
	}
	if len(req.KbIDs) != 0 {
		t.Fatalf("KbIDs = %v, want empty for memory message search", req.KbIDs)
	}
	if wantMemoryIDs := []string{"mem-owned"}; !reflect.DeepEqual(req.Filter["memory_id"], wantMemoryIDs) {
		t.Fatalf("memory_id filter = %v, want [mem-owned]", req.Filter["memory_id"])
	}
	if req.Filter["agent_id"] != "agent-1" || req.Filter["session_id"] != "session-1" {
		t.Fatalf("unexpected filter: %+v", req.Filter)
	}
	if req.Limit != 3 {
		t.Fatalf("Limit = %d, want 3", req.Limit)
	}
	if req.OrderBy == nil || len(req.OrderBy.Fields) != 1 {
		t.Fatalf("unexpected order by: %+v", req.OrderBy)
	}
	if sortField := req.OrderBy.Fields[0]; sortField.Field != "valid_at" || sortField.Type != enginetypes.SortDesc {
		t.Fatalf("unexpected order by: %+v", req.OrderBy)
	}
}
// TestSearchMessageFiltersAccessibleMemoryAndDefaultsStatus verifies that
// SearchMessage drops inaccessible memories, adds the status = 1 filter,
// forwards the agent/session/user filters, and builds no match expressions for
// an empty query.
func TestSearchMessageFiltersAccessibleMemoryAndDefaultsStatus(t *testing.T) {
	setupMemoryMessageTestDB(t)
	seedMemoryMessages(t)

	storedChunk := map[string]interface{}{
		"message_id":   int64(13),
		"message_type": "raw",
		"memory_id":    "mem-owned",
		"user_id":      "user-1",
		"agent_id":     "agent-1",
		"session_id":   "session-1",
		"valid_at":     int64(456),
		"status":       1,
		"content":      "matched",
	}
	docEngine := &memoryMessageDocEngine{
		searchResp: &enginetypes.SearchResult{
			Total:  1,
			Chunks: []map[string]interface{}{storedChunk},
		},
	}
	svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine}

	filter := map[string]interface{}{
		"memory_id":  []string{"mem-owned", "mem-other"},
		"agent_id":   "agent-1",
		"session_id": "session-1",
		"user_id":    "user-1",
	}
	params := map[string]interface{}{
		"query":                      "",
		"similarity_threshold":       0.2,
		"keywords_similarity_weight": 0.7,
		"top_n":                      5,
	}
	got, code, err := svc.SearchMessage(context.Background(), "user-1", filter, params)
	if err != nil {
		t.Fatalf("SearchMessage error: %v", err)
	}
	if code != common.CodeSuccess {
		t.Fatalf("code = %v, want %v", code, common.CodeSuccess)
	}
	if len(got) != 1 || got[0]["content"] != "matched" {
		t.Fatalf("unexpected search result: %+v", got)
	}

	// Inspect the request the service handed to the engine.
	req := docEngine.searchReq
	if req == nil {
		t.Fatal("expected doc engine search request")
	}
	if wantMemoryIDs := []string{"mem-owned"}; !reflect.DeepEqual(req.Filter["memory_id"], wantMemoryIDs) {
		t.Fatalf("memory_id filter = %v, want [mem-owned]", req.Filter["memory_id"])
	}
	if req.Filter["status"] != 1 {
		t.Fatalf("status filter = %v, want 1", req.Filter["status"])
	}
	forwarded := req.Filter["agent_id"] == "agent-1" && req.Filter["session_id"] == "session-1" && req.Filter["user_id"] == "user-1"
	if !forwarded {
		t.Fatalf("unexpected filter: %+v", req.Filter)
	}
	if len(req.MatchExprs) != 0 {
		t.Fatalf("empty query should not build match expressions, got %+v", req.MatchExprs)
	}
	if req.Limit != 5 {
		t.Fatalf("Limit = %d, want 5", req.Limit)
	}
}
// TestUpdateMessageUpdatesStatusByMessageDocID verifies that UpdateMessage
// targets the tenant's memory index, passes the memory ID as dataset ID,
// addresses the message by "<memoryID>_<messageID>", and writes status 1 when
// enabling a message.
func TestUpdateMessageUpdatesStatusByMessageDocID(t *testing.T) {
	setupMemoryMessageTestDB(t)
	seedMemoryMessages(t)

	docEngine := &memoryMessageDocEngine{}
	svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine}

	updated, err := svc.UpdateMessage(context.Background(), "user-1", "mem-owned", 42, true)
	if err != nil {
		t.Fatalf("UpdateMessage error: %v", err)
	}
	if !updated {
		t.Fatal("UpdateMessage returned false")
	}

	if base := docEngine.updateBase; base != "memory_user-1" {
		t.Fatalf("baseName = %q, want memory_user-1", base)
	}
	if datasetID := docEngine.updateID; datasetID != "mem-owned" {
		t.Fatalf("datasetID = %q, want mem-owned", datasetID)
	}
	if docEngine.updateCond["id"] != "mem-owned_42" {
		t.Fatalf("condition = %+v, want id mem-owned_42", docEngine.updateCond)
	}
	if docEngine.updateValue["status"] != 1 {
		t.Fatalf("status update = %+v, want status 1", docEngine.updateValue)
	}
}