From 74597b86831dc8b3bf4aa9b1a69cfe59296aea91 Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Thu, 25 Jun 2026 19:07:34 +0800 Subject: [PATCH] feat[Go]: implemet api: Search/Get/Update-Messages (#16307) ### What problem does this PR solve? As title: implement: ``` /api/v1/messages/search GET /api/v1/messages GET /api/v1/messages/:/content GET /api/v1/memories//config GET /api/v1/messages/: PUT ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- internal/engine/elasticsearch/chunk.go | 261 ++++++- internal/engine/elasticsearch/chunk_test.go | 2 +- internal/engine/infinity/chunk.go | 48 ++ internal/handler/memory.go | 307 +++++++-- internal/router/router.go | 3 + internal/service/memory.go | 729 +++++++++++++++++++- internal/service/memory_message_test.go | 244 +++++++ 7 files changed, 1524 insertions(+), 70 deletions(-) diff --git a/internal/engine/elasticsearch/chunk.go b/internal/engine/elasticsearch/chunk.go index 8a3559be19..299b326a1f 100644 --- a/internal/engine/elasticsearch/chunk.go +++ b/internal/engine/elasticsearch/chunk.go @@ -32,11 +32,12 @@ import ( "strconv" "strings" - "github.com/elastic/go-elasticsearch/v8/esapi" - "github.com/json-iterator/go" "ragflow/internal/common" "ragflow/internal/engine/types" + "github.com/elastic/go-elasticsearch/v8/esapi" + "github.com/json-iterator/go" + "go.uber.org/zap" ) @@ -44,6 +45,8 @@ var jsonIterator = jsoniter.Config{ SortMapKeys: false, }.Froze() +var memoryMessageVectorFieldRE = regexp.MustCompile(`^q_\d+_vec$`) + var ( elasticsearchHighlightEmTagRE = regexp.MustCompile(`[^<>]+`) elasticsearchHighlightNewlineRE = regexp.MustCompile(`[\r\n]`) @@ -233,6 +236,14 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st return fmt.Errorf("index '%s' does not exist", fullIndexName) } + if strings.HasPrefix(fullIndexName, "memory_") { + condition["memory_id"] = datasetID + if messageDocID, ok := condition["id"].(string); ok { + return e.updateSingleMemoryMessage(ctx, fullIndexName, messageDocID, newValue) + } + return e.updateChunksByQuery(ctx, fullIndexName, mapMemoryMessageESConditionFields(condition), mapMemoryMessageESUpdateFields(newValue)) + } + // Add kb_id to condition condition["kb_id"] = datasetID @@ -245,6 +256,59 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st return e.updateChunksByQuery(ctx, fullIndexName, condition, newValue) } +func (e *elasticsearchEngine) updateSingleMemoryMessage(ctx context.Context, indexName, messageDocID string, newValue map[string]interface{}) error { + doc := mapMemoryMessageESUpdateFields(newValue) + delete(doc, "id") + if len(doc) == 0 { + return nil + } + + updateBody := map[string]interface{}{"doc": doc} + body, err := json.Marshal(updateBody) + if err != nil { + return fmt.Errorf("failed to marshal memory message update request: %w", err) + } + req := esapi.UpdateRequest{ + Index: indexName, + DocumentID: messageDocID, + Body: bytes.NewReader(body), + } + res, err := req.Do(ctx, e.client) + if err != nil { + return fmt.Errorf("failed to update memory message: %w", err) + } + defer res.Body.Close() + if res.IsError() { + if res.StatusCode == http.StatusNotFound { + return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, messageDocID) + } + bodyBytes, _ := io.ReadAll(res.Body) + return fmt.Errorf("elasticsearch memory message update error: %s, body: %s", res.Status(), string(bodyBytes)) + } + return nil +} + +func mapMemoryMessageESUpdateFields(newValue map[string]interface{}) map[string]interface{} { + doc := make(map[string]interface{}, len(newValue)) + for k, v := range newValue { + switch k { + case "remove", "add": + doc[k] = v + default: + doc[mapMemoryMessageESField(k, false)] = v + } + } + return doc +} + +func mapMemoryMessageESConditionFields(condition map[string]interface{}) map[string]interface{} { + mapped := make(map[string]interface{}, len(condition)) + for k, v := range condition { + mapped[mapMemoryMessageESField(k, false)] = v + } + return mapped +} + // updateSingleChunk handles single document update func (e *elasticsearchEngine) updateSingleChunk(ctx context.Context, indexName, chunkID string, newValue map[string]interface{}) error { common.Debug("ElasticsearchConnection.updateSingleChunk called", zap.String("indexName", indexName), zap.String("chunkID", chunkID)) @@ -762,15 +826,18 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque // Detect index types isSkillIndex := false + isMemoryIndex := false for _, idx := range req.IndexNames { if strings.HasPrefix(idx, "skill_") { isSkillIndex = true - break + } + if strings.HasPrefix(idx, "memory_") { + isMemoryIndex = true } } // Build bool query from condition - boolQuery := buildBoolQueryFromCondition(req.Filter, req.KbIDs, isSkillIndex) + boolQuery := buildBoolQueryFromCondition(req.Filter, req.KbIDs, isSkillIndex, isMemoryIndex) // Extract vector_similarity_weight from FusionExpr var matchText *types.MatchTextExpr @@ -816,7 +883,7 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque queryBody := make(map[string]interface{}) if matchText != nil { - textQuery := buildQueryStringQuery(matchText, vectorSimilarityWeight, isSkillIndex) + textQuery := buildQueryStringQuery(matchText, vectorSimilarityWeight, isSkillIndex, isMemoryIndex) if boolQuery != nil { if boolMap, ok := boolQuery["bool"].(map[string]interface{}); ok { if must, ok := boolMap["must"].([]interface{}); ok { @@ -874,7 +941,7 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque } // Add rank_feature queries - if req.RankFeature != nil && len(req.RankFeature) > 0 && !isSkillIndex { + if req.RankFeature != nil && len(req.RankFeature) > 0 && !isSkillIndex && !isMemoryIndex { rankFeatureQuery := buildRankFeatureQuery(req.RankFeature) if rankFeatureQuery != nil { if boolQuery, ok := queryBody["query"].(map[string]interface{}); ok { @@ -929,21 +996,25 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque // Set _source and fields for vector fields hasTextMatch := matchText != nil + selectFields := req.SelectFields + if isMemoryIndex { + selectFields = mapMemoryMessageESFields(req.SelectFields, false) + } if len(req.SelectFields) > 0 { // Use caller-specified fields, add pagerank_fld/tag_fld if needed - queryBody["_source"] = req.SelectFields + queryBody["_source"] = selectFields if hasTextMatch || hasVectorMatch { - if !isSkillIndex { - if !slices.Contains(req.SelectFields, common.PAGERANK_FLD) { + if !isSkillIndex && !isMemoryIndex { + if !slices.Contains(selectFields, common.PAGERANK_FLD) { queryBody["_source"] = append(queryBody["_source"].([]string), common.PAGERANK_FLD) } - if !slices.Contains(req.SelectFields, common.TAG_FLD) { + if !slices.Contains(selectFields, common.TAG_FLD) { queryBody["_source"] = append(queryBody["_source"].([]string), common.TAG_FLD) } } } var vectorFields []string - for _, f := range req.SelectFields { + for _, f := range selectFields { if strings.HasSuffix(f, "_vec") { vectorFields = append(vectorFields, f) } @@ -954,7 +1025,7 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque } else { // No explicit SelectFields - use match_all, but add pagerank_fld/tag_fld for scoring if needed if hasTextMatch || hasVectorMatch { - if !isSkillIndex { + if !isSkillIndex && !isMemoryIndex { queryBody["_source"] = []string{common.PAGERANK_FLD, common.TAG_FLD} } } @@ -1021,6 +1092,10 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque } } + if isMemoryIndex { + normalizeMemoryMessageChunks(allResults) + } + // Post-processing: Sort results by score if len(allResults) > 0 && (matchText != nil || hasVectorMatch) { scoreColumn := "_score" @@ -1032,6 +1107,9 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque if isSkillIndex { pagerankField = "" } + if isMemoryIndex { + pagerankField = "" + } allResults = calculateScores(allResults, scoreColumn, pagerankField) allResults = sortByScore(allResults, limit) @@ -1293,17 +1371,105 @@ func sortValuesEqual(a, b []interface{}) bool { return true } -// buildBoolQueryFromCondition builds an ES bool query from condition map -// For skill index, uses 'status' field instead of 'available_int' -func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string, isSkillIndex bool) map[string]interface{} { +func mapMemoryMessageESField(field string, useTokenizedContent bool) string { + name := field + boost := "" + if base, suffix, ok := strings.Cut(field, "^"); ok { + name = base + boost = "^" + suffix + } + + switch name { + case "message_type": + name = "message_type_kwd" + case "status": + name = "status_int" + case "content": + if useTokenizedContent { + name = "tokenized_content_ltks" + } else { + name = "content_ltks" + } + } + return name + boost +} + +func mapMemoryMessageESFields(fields []string, useTokenizedContent bool) []string { + if len(fields) == 0 { + return fields + } + mapped := make([]string, 0, len(fields)) + seen := make(map[string]struct{}, len(fields)) + for _, field := range fields { + mappedField := mapMemoryMessageESField(field, useTokenizedContent) + if _, ok := seen[mappedField]; ok { + continue + } + seen[mappedField] = struct{}{} + mapped = append(mapped, mappedField) + } + return mapped +} + +func normalizeMemoryMessageChunks(chunks []map[string]interface{}) { + for _, chunk := range chunks { + for key, val := range chunk { + if memoryMessageVectorFieldRE.MatchString(key) { + chunk["content_embed"] = val + delete(chunk, key) + } + } + if val, ok := chunk["message_type_kwd"]; ok { + chunk["message_type"] = val + delete(chunk, "message_type_kwd") + } + if val, ok := chunk["status_int"]; ok { + chunk["status"] = memoryMessageStatusBool(val) + delete(chunk, "status_int") + } + if val, ok := chunk["content_ltks"]; ok { + chunk["content"] = val + delete(chunk, "content_ltks") + } + } +} + +func memoryMessageStatusBool(value interface{}) bool { + switch v := value.(type) { + case bool: + return v + case int: + return v != 0 + case int64: + return v != 0 + case float64: + return v != 0 + case json.Number: + n, err := v.Int64() + return err == nil && n != 0 + case string: + return v != "" && v != "0" && !strings.EqualFold(v, "false") + default: + return false + } +} + +// buildBoolQueryFromCondition builds an ES bool query from condition map. +// Skill indexes use status, regular chunk indexes use kb_id, and memory +// message indexes use memory_id plus message-specific storage fields. +func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string, isSkillIndex, isMemoryIndex bool) map[string]interface{} { var mustClauses []interface{} var filterClauses []interface{} var shouldClauses []interface{} - // Add kb_id to condition + // Memory message indexes use memory_id, regular chunk indexes use kb_id. if kbIDs != nil && len(kbIDs) > 0 { + fieldName := "kb_id" + if isMemoryIndex { + fieldName = "memory_id" + } filterClauses = append(filterClauses, map[string]interface{}{ - "terms": map[string]interface{}{"kb_id": kbIDs}, + "terms": map[string]interface{}{fieldName: kbIDs}, }) } @@ -1321,6 +1487,9 @@ func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string, } for k, v := range filter { + if isMemoryIndex { + k = mapMemoryMessageESField(k, false) + } // For skill index, handle 'status' field instead of 'available_int' if isSkillIndex && k == "status" { if v == nil || v == "" { @@ -1389,6 +1558,18 @@ func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string, if v == nil || v == "" { continue } + if isMemoryIndex && k == "session_id" { + if strVal, ok := v.(string); ok && strVal != "" { + filterClauses = append(filterClauses, map[string]interface{}{ + "query_string": map[string]interface{}{ + "query": fmt.Sprintf("*%s*", strVal), + "fields": []string{"session_id"}, + "analyze_wildcard": true, + }, + }) + continue + } + } if listVal, ok := v.([]interface{}); ok { filterClauses = append(filterClauses, map[string]interface{}{ "terms": map[string]interface{}{k: listVal}, @@ -1435,7 +1616,7 @@ func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string, // buildQueryStringQuery builds a query_string query from MatchTextExpr // When isSkillIndex is true, uses skill-specific fields (name_tks, tags_tks, etc.) // Otherwise uses document fields (title_tks, content_ltks, etc.) -func buildQueryStringQuery(matchText *types.MatchTextExpr, vectorSimilarityWeight float64, isSkillIndex bool) map[string]interface{} { +func buildQueryStringQuery(matchText *types.MatchTextExpr, vectorSimilarityWeight float64, isSkillIndex, isMemoryIndex bool) map[string]interface{} { if matchText == nil { return nil } @@ -1451,10 +1632,15 @@ func buildQueryStringQuery(matchText *types.MatchTextExpr, vectorSimilarityWeigh if fields == nil || len(fields) == 0 { if isSkillIndex { fields = []string{"name_tks^10", "tags_tks^5", "description_tks^3", "content_tks^1"} + } else if isMemoryIndex { + fields = []string{"tokenized_content_ltks"} } else { fields = []string{"title_tks^10", "title_sm_tks^5", "important_kwd^30", "important_tks^20", "question_tks^20", "content_ltks^2", "content_sm_ltks"} } } + if isMemoryIndex { + fields = mapMemoryMessageESFields(fields, true) + } boost := 1.0 if matchText.ExtraOptions != nil { @@ -1507,6 +1693,10 @@ func buildRankFeatureQuery(rankFeature map[string]float64) []map[string]interfac // GetChunk gets a chunk by ID using ES search API func (e *elasticsearchEngine) GetChunk(ctx context.Context, baseName, chunkID string, datasetIDs []string) (interface{}, error) { + if strings.HasPrefix(baseName, "memory_") { + return e.getMemoryMessage(ctx, baseName, chunkID) + } + // Try search by doc_id field (which is stored in the document) for _, datasetID := range datasetIDs { searchReq := map[string]interface{}{ @@ -1575,6 +1765,41 @@ func (e *elasticsearchEngine) GetChunk(ctx context.Context, baseName, chunkID st return nil, nil } +func (e *elasticsearchEngine) getMemoryMessage(ctx context.Context, indexName, docID string) (interface{}, error) { + req := esapi.GetRequest{ + Index: indexName, + DocumentID: docID, + } + res, err := req.Do(ctx, e.client) + if err != nil { + return nil, fmt.Errorf("failed to get memory message: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + if res.StatusCode == http.StatusNotFound { + return nil, fmt.Errorf("%w: %s", types.ErrDocumentNotFound, docID) + } + return nil, fmt.Errorf("elasticsearch memory message get error: %s", res.Status()) + } + + var getResult struct { + Found bool `json:"found"` + Source map[string]interface{} `json:"_source"` + } + if err := json.NewDecoder(res.Body).Decode(&getResult); err != nil { + return nil, fmt.Errorf("failed to parse memory message get response: %w", err) + } + if !getResult.Found || getResult.Source == nil { + return nil, nil + } + + message := getResult.Source + message["id"] = docID + normalizeMemoryMessageChunks([]map[string]interface{}{message}) + return message, nil +} + // GetFields extracts the requested fields from ES search response chunks // // Unlike Infinity, Elasticsearch does NOT use convertSelectFields before querying. diff --git a/internal/engine/elasticsearch/chunk_test.go b/internal/engine/elasticsearch/chunk_test.go index 34313d2895..9bd6f06c58 100644 --- a/internal/engine/elasticsearch/chunk_test.go +++ b/internal/engine/elasticsearch/chunk_test.go @@ -379,7 +379,7 @@ func sortedCopy(in []int) []int { func TestBuildBoolQueryFromConditionIDFilter(t *testing.T) { check := func(name string, cond map[string]interface{}, wantFields []string) { t.Helper() - got := buildBoolQueryFromCondition(cond, nil, false) + got := buildBoolQueryFromCondition(cond, nil, false, false) outer, ok := got["bool"].(map[string]interface{}) if !ok { t.Fatalf("%s: missing bool wrapper: %v", name, got) diff --git a/internal/engine/infinity/chunk.go b/internal/engine/infinity/chunk.go index b3205abc24..19a9155959 100644 --- a/internal/engine/infinity/chunk.go +++ b/internal/engine/infinity/chunk.go @@ -1282,6 +1282,13 @@ func applyFieldMappings(chunks []map[string]interface{}) { chunk["authors_sm_tks"] = val } + if val, ok := chunk["message_type_kwd"]; ok { + chunk["message_type"] = val + } + if val, ok := chunk["status_int"]; ok { + chunk["status"] = memoryMessageStatusBool(val) + } + // position_int: convert from hex string to array format (grouped by 5) if val, ok := chunk["position_int"].(string); ok { chunk["position_int"] = utility.ConvertHexToPositionIntArray(val) @@ -1332,6 +1339,26 @@ func applyFieldMappings(chunks []map[string]interface{}) { } } +func memoryMessageStatusBool(value interface{}) bool { + switch v := value.(type) { + case bool: + return v + case int: + return v != 0 + case int64: + return v != 0 + case float64: + return v != 0 + case json.Number: + n, err := v.Int64() + return err == nil && n != 0 + case string: + return v != "" && v != "0" && !strings.EqualFold(v, "false") + default: + return false + } +} + // GetFields extracts the requested fields from Infinity search results func (e *infinityEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} { result := make(map[string]map[string]interface{}) @@ -1837,6 +1864,8 @@ func convertSelectFields(output []string, isSkillIndex ...bool) []string { "content_sm_ltks": "content", "authors_tks": "authors", "authors_sm_tks": "authors", + "message_type": "message_type_kwd", + "status": "status_int", } skillIndex := false @@ -2005,6 +2034,12 @@ func equivalentConditionToStr(condition map[string]interface{}) string { continue } + if k == "message_type" { + k = "message_type_kwd" + } else if k == "status" { + k = "status_int" + } + // Handle list values (mixed types - strings get quotes, numbers don't) if list, ok := v.([]interface{}); ok && len(list) > 0 { var strItems, numItems []string @@ -2210,6 +2245,19 @@ func transformChunkFields(chunk map[string]interface{}, embeddingCols [][2]inter d["questions"] = strings.Join(utility.ConvertToStringSlice(v), "\n") case "tag_kwd": d["tag_kwd"] = strings.Join(utility.ConvertToStringSlice(v), "###") + case "message_type": + d["message_type_kwd"] = v + case "status": + switch status := v.(type) { + case bool: + if status { + d["status_int"] = 1 + } else { + d["status_int"] = 0 + } + default: + d["status_int"] = v + } case "question_tks": if _, exists := chunk["question_kwd"]; !exists { d["questions"] = utility.ConvertToString(v) diff --git a/internal/handler/memory.go b/internal/handler/memory.go index 94594d2f38..219f5194c7 100644 --- a/internal/handler/memory.go +++ b/internal/handler/memory.go @@ -20,6 +20,7 @@ package handler import ( + "encoding/json" "errors" "net/http" "os" @@ -515,13 +516,67 @@ func (h *MemoryHandler) GetMemoryConfig(c *gin.Context) { // - message: true // - data.messages: Array of message objects // - data.storage_type: Storage type -// -// TODO: Implementation pending - depends on CanvasService and TaskService func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeAuthenticationError, "user id is required") + return + } + + memoryID := strings.TrimSpace(c.Param("memory_id")) + if memoryID == "" { + jsonError(c, common.CodeArgumentError, "memory_id is required") + return + } + + var agentIDs []string + values := c.QueryArray("agent_id") + for _, v := range values { + parts := strings.Split(v, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + agentIDs = append(agentIDs, p) + } + } + } + + keywords := strings.TrimSpace(c.DefaultQuery("keywords", "")) + page, err := strconv.Atoi(c.DefaultQuery("page", "1")) + if err != nil || page <= 0 { + jsonError(c, common.CodeArgumentError, "page must be a positive integer") + return + } + pageSize, err := strconv.Atoi(c.DefaultQuery("page_size", "50")) + if err != nil || pageSize <= 0 { + jsonError(c, common.CodeArgumentError, "page_size must be a positive integer") + return + } + if pageSize > 100 { + jsonError(c, common.CodeArgumentError, "page_size must be less than or equal to 100") + return + } + + data, err := h.memoryService.GetMemoryMessages(c.Request.Context(), userID, memoryID, agentIDs, keywords, page, pageSize) + if err != nil { + if isMemoryServiceNotFound(err) { + jsonError(c, common.CodeNotFound, err.Error()) + return + } + jsonError(c, common.CodeServerError, "Internal server error") + return + } + c.JSON(http.StatusOK, gin.H{ - "code": common.CodeServerError, - "message": "GetMemoryMessages not implemented - pending CanvasService and TaskService dependencies", - "data": nil, + "code": common.CodeSuccess, + "message": true, + "data": data, }) } @@ -541,7 +596,7 @@ func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) { // - agent_response (required): Agent response // - user_id (optional): User ID // -// TODO: Implementation pending - depends on embedding engine +// TODO: Haruko386 is implementing this for now, if you implement this, delete this line plz func (h *MemoryHandler) AddMessage(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "code": common.CodeServerError, @@ -645,16 +700,123 @@ func parseMemoryMessagePath(memoryMessage string) (string, int64, error) { // // Request Parameters (JSON Body): // - status (required): Message status, boolean -// -// TODO: Implementation pending - depends on embedding engine func (h *MemoryHandler) UpdateMessage(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeAuthenticationError, "user id is required") + return + } + + memoryID, messageID, err := parseMemoryMessagePath(c.Param("memory_message")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": err.Error(), + }) + return + } + + var req map[string]interface{} + if err = json.NewDecoder(c.Request.Body).Decode(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": common.CodeArgumentError, + "message": err.Error(), + }) + return + } + + status, ok := req["status"].(bool) + if !ok { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": "Status must be a boolean.", + "data": nil, + }) + return + } + + ok, err = h.memoryService.UpdateMessage(c.Request.Context(), userID, memoryID, messageID, status) + if err != nil || !ok { + if isMemoryServiceNotFound(err) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeNotFound, + "message": err.Error(), + "data": nil, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "Internal server error", + "data": nil, + }) + return + } + c.JSON(http.StatusOK, gin.H{ - "code": common.CodeServerError, - "message": "UpdateMessage not implemented - pending embedding engine dependency", + "code": common.CodeSuccess, + "message": true, "data": nil, }) } +// GetMessageContent handles GET request for getting message content +// API Path: GET /api/v1/messages/:memory_id/:message_id/content +// +// Function: +// - Gets complete content of the specified message +// - doc_id format: memory_id + "_" + message_id +// +// Parameter Format: +// - memory_id: Memory ID +// - message_id: Message ID (integer) +// + +func (h *MemoryHandler) GetMessageContent(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeAuthenticationError, "user id is required") + return + } + + memoryID, messageID, err := parseMemoryMessagePath(c.Param("memory_message")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "message": err.Error(), + }) + return + } + + data, err := h.memoryService.GetMessageContent(c.Request.Context(), userID, memoryID, messageID) + if err != nil { + if _, ok := err.(*service.ResourceNotFoundError); ok { + jsonError(c, common.CodeNotFound, err.Error()) + return + } + jsonError(c, common.CodeServerError, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": true, + "data": data, + }) +} + // SearchMessage handles GET request for searching messages // API Path: GET /api/v1/messages/search // @@ -672,13 +834,64 @@ func (h *MemoryHandler) UpdateMessage(c *gin.Context) { // - agent_id (optional): Agent ID filter // - session_id (optional): Session ID filter // - user_id (optional): User ID filter -// -// TODO: Implementation pending - depends on embedding engine func (h *MemoryHandler) SearchMessage(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeAuthenticationError, "user id is required") + return + } + + var memoryIDs []string + values := c.QueryArray("memory_id") + for _, v := range values { + parts := strings.Split(v, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + memoryIDs = append(memoryIDs, p) + } + } + } + + query := c.Query("query") + + similarityThreshold, _ := strconv.ParseFloat(c.DefaultQuery("similarity_threshold", "0.2"), 64) + keywordsSimilarityWeight, _ := strconv.ParseFloat(c.DefaultQuery("keywords_similarity_weight", "0.7"), 64) + topN, _ := strconv.Atoi(c.DefaultQuery("top_n", "5")) + + agentID := c.DefaultQuery("agent_id", "") + sessionID := c.DefaultQuery("session_id", "") + + filterDict := map[string]interface{}{ + "memory_id": memoryIDs, + "agent_id": agentID, + "session_id": sessionID, + "user_id": c.DefaultQuery("user_id", ""), + } + + params := map[string]interface{}{ + "query": query, + "similarity_threshold": similarityThreshold, + "keywords_similarity_weight": keywordsSimilarityWeight, + "top_n": topN, + } + + res, code, err := h.memoryService.SearchMessage(c.Request.Context(), userID, filterDict, params) + if err != nil { + jsonError(c, code, err.Error()) + return + } + c.JSON(http.StatusOK, gin.H{ - "code": common.CodeServerError, - "message": "SearchMessage not implemented - pending embedding engine dependency", - "data": nil, + "code": common.CodeSuccess, + "message": true, + "data": res, }) } @@ -694,33 +907,49 @@ func (h *MemoryHandler) SearchMessage(c *gin.Context) { // - agent_id (optional): Agent ID filter // - session_id (optional): Session ID filter // - limit (optional): Number of results to return, default 10 -// -// TODO: Implementation pending - depends on embedding engine func (h *MemoryHandler) GetMessages(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "code": common.CodeServerError, - "message": "GetMessages not implemented - pending embedding engine dependency", - "data": nil, - }) -} + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeAuthenticationError, "user id is required") + return + } + + var memoryIDs []string + values := c.QueryArray("memory_id") + for _, v := range values { + parts := strings.Split(v, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + memoryIDs = append(memoryIDs, p) + } + } + } + + agentID := c.DefaultQuery("agent_id", "") + sessionID := c.DefaultQuery("session_id", "") + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "10")) + if len(memoryIDs) == 0 { + jsonError(c, common.CodeArgumentError, "memory_ids is required.") + return + } + + data, code, err := h.memoryService.GetMessages(c.Request.Context(), memoryIDs, userID, agentID, sessionID, limit) + if err != nil { + jsonError(c, code, err.Error()) + return + } -// GetMessageContent handles GET request for getting message content -// API Path: GET /api/v1/messages/:memory_id/:message_id/content -// -// Function: -// - Gets complete content of the specified message -// - doc_id format: memory_id + "_" + message_id -// -// Parameter Format: -// - memory_id: Memory ID -// - message_id: Message ID (integer) -// -// TODO: Implementation pending - depends on embedding engine -func (h *MemoryHandler) GetMessageContent(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ - "code": common.CodeServerError, - "message": "GetMessageContent not implemented - pending embedding engine dependency", - "data": nil, + "code": common.CodeSuccess, + "message": true, + "data": data, }) } diff --git a/internal/router/router.go b/internal/router/router.go index ab72024798..f6a65e1d0f 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -414,8 +414,11 @@ func (r *Router) Setup(engine *gin.Engine) { // Message routes message := v1.Group("/messages") { + message.GET("", r.memoryHandler.GetMessages) message.DELETE("/:memory_message", r.memoryHandler.ForgetMessage) message.PUT("/:memory_message", r.memoryHandler.UpdateMessage) + message.GET("/:memory_message/content", r.memoryHandler.GetMessageContent) + message.GET("/search", r.memoryHandler.SearchMessage) } // Skill search routes diff --git a/internal/service/memory.go b/internal/service/memory.go index 88440c3476..87855014b9 100644 --- a/internal/service/memory.go +++ b/internal/service/memory.go @@ -20,8 +20,10 @@ import ( "context" "errors" "fmt" + "os" "ragflow/internal/common" "ragflow/internal/entity" + models "ragflow/internal/entity/models" "strconv" "strings" "time" @@ -29,6 +31,7 @@ import ( "ragflow/internal/dao" "ragflow/internal/engine" enginetypes "ragflow/internal/engine/types" + "ragflow/internal/service/nlp" ) const ( @@ -55,6 +58,12 @@ const ( TenantPermissionAll TenantPermission = "all" ) +const ( + defaultMessageTopN = 5 + defaultMessageLimit = 10 + maxMessageLimit = 100 +) + // validPermissions defines which permission values are valid var validPermissions = map[TenantPermission]bool{ TenantPermissionMe: true, @@ -793,7 +802,7 @@ func (s *MemoryService) ForgetMessage(ctx context.Context, userID string, memory condition := map[string]interface{}{ "id": messageDocID, } - indexName := fmt.Sprintf("memory_%s", memory.TenantID) + indexName := memoryIndexName(memory.TenantID) if err := s.docEngine.UpdateChunks(ctx, condition, updates, indexName, memoryID); err != nil { if isMessageDocumentNotFound(err) { @@ -807,6 +816,453 @@ func (s *MemoryService) ForgetMessage(ctx context.Context, userID string, memory return nil } +func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) { + memory, err := s.requireMemoryAccess(ctx, userID, memoryID) + if err != nil { + return false, err + } + + if s.docEngine == nil { + return false, errors.New("message store is not initialized") + } + + messageDocID := fmt.Sprintf("%s_%d", memoryID, messageID) + statusValue := 0 + if status { + statusValue = 1 + } + updates := map[string]interface{}{ + "status": statusValue, + } + condition := map[string]interface{}{ + "id": messageDocID, + } + indexName := memoryIndexName(memory.TenantID) + if err := s.docEngine.UpdateChunks(ctx, condition, updates, indexName, memoryID); err != nil { + if isMessageDocumentNotFound(err) { + return false, &ResourceNotFoundError{Resource: "Message", ID: messageDocID} + } + return false, fmt.Errorf("failed to set status for message '%d' in memory '%s': %w", messageID, memoryID, err) + } + + return true, nil +} + +func (s *MemoryService) GetMessageContent(ctx context.Context, userID, memoryID string, messageID int64) (map[string]interface{}, error) { + memory, err := s.requireMemoryAccess(ctx, userID, memoryID) + if err != nil { + return nil, err + } + if s.docEngine == nil { + return nil, errors.New("message store is not initialized") + } + + indexName := memoryIndexName(memory.TenantID) + docID := fmt.Sprintf("%s_%d", memoryID, messageID) + res, err := s.docEngine.GetChunk(ctx, indexName, docID, []string{memoryID}) + if err != nil { + if isMessageDocumentNotFound(err) { + return nil, &ResourceNotFoundError{Resource: "Message", ID: docID} + } + return nil, err + } + if res == nil { + return nil, &ResourceNotFoundError{Resource: "Message", ID: docID} + } + + message, ok := res.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("unexpected message content type %T", res) + } + return common.ConvertFloatsToPyFormat(message).(map[string]interface{}), nil +} + +func (s *MemoryService) SearchMessage(ctx context.Context, userID string, filterDict, params map[string]interface{}) ([]map[string]interface{}, common.ErrorCode, error) { + memoryIDs := splitFilterValues(filterDict["memory_id"]) + memories, err := s.filterAccessibleMemories(ctx, userID, memoryIDs) + if err != nil { + return nil, common.CodeServerError, err + } + if len(memories) == 0 { + return []map[string]interface{}{}, common.CodeSuccess, nil + } + + return s.queryMessage(ctx, memories, filterDict, params) +} + +func (s *MemoryService) queryMessage(ctx context.Context, memories []*entity.Memory, filterDict, params map[string]interface{}) ([]map[string]interface{}, common.ErrorCode, error) { + if s.docEngine == nil { + return nil, common.CodeServerError, errors.New("message store is not initialized") + } + + topN := memoryIntParam(params["top_n"], 5) + if topN <= 0 { + topN = defaultMessageTopN + } else if topN > maxMessageLimit { + topN = maxMessageLimit + } + similarityThreshold := memoryFloatParam(params["similarity_threshold"], 0.2) + keywordsSimilarityWeight := memoryFloatParam(params["keywords_similarity_weight"], 0.7) + question := strings.TrimSpace(memoryStringParam(params["query"])) + + memoryIDs := make([]string, 0, len(memories)) + conditionDict := make(map[string]interface{}) + for _, memory := range memories { + if memory == nil { + continue + } + memoryIDs = append(memoryIDs, memory.ID) + } + conditionDict["memory_id"] = memoryIDs + for _, key := range []string{"agent_id", "session_id", "user_id"} { + value := strings.TrimSpace(memoryStringParam(filterDict[key])) + if value != "" { + conditionDict[key] = value + } + } + if _, ok := conditionDict["status"]; !ok { + conditionDict["status"] = 1 + } + + matchExprs := make([]interface{}, 0, 3) + if question != "" { + matchText := memoryMessageTextExpr(question, similarityThreshold) + matchDense, err := s.memoryMessageDenseExpr(question, memories[0], topN, similarityThreshold) + if err != nil { + return nil, common.CodeServerError, err + } + fusionExpr := &enginetypes.FusionExpr{ + Method: "weighted_sum", + TopN: topN, + FusionParams: map[string]interface{}{ + "weights": fmt.Sprintf("%g,%g", 1-keywordsSimilarityWeight, keywordsSimilarityWeight), + }, + } + matchExprs = append(matchExprs, matchText, matchDense, fusionExpr) + } + + searchReq := &enginetypes.SearchRequest{ + IndexNames: memorySearchIndexNames(memories), + Offset: 0, + Limit: topN, + SelectFields: memoryMessageSelectFields(), + Filter: conditionDict, + MatchExprs: matchExprs, + OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"), + } + + searchResult, err := s.docEngine.Search(ctx, searchReq) + if err != nil { + return nil, common.CodeServerError, err + } + if searchResult == nil || searchResult.Total == 0 { + return []map[string]interface{}{}, common.CodeSuccess, nil + } + + messages := make([]map[string]interface{}, 0, len(searchResult.Chunks)) + for _, chunk := range searchResult.Chunks { + message := make(map[string]interface{}, len(chunk)) + for _, field := range memoryMessageSelectFields() { + if value, ok := chunk[field]; ok { + message[field] = value + } + } + messages = append(messages, message) + } + return common.ConvertFloatsToPyFormat(messages).([]map[string]interface{}), common.CodeSuccess, nil +} + +func (s *MemoryService) filterAccessibleMemories(ctx context.Context, userID string, memoryIDs []string) ([]*entity.Memory, error) { + memoryIDs = splitFilterValues(memoryIDs) + if len(memoryIDs) == 0 { + return []*entity.Memory{}, nil + } + + memories, err := s.memoryDAO.GetByIDs(memoryIDs) + if err != nil { + return nil, err + } + if len(memories) == 0 { + return []*entity.Memory{}, nil + } + + joinedTenantIDs := map[string]struct{}{userID: {}} + needsTeamLookup := false + for _, memory := range memories { + if memory != nil && memory.TenantID != userID && memory.Permissions == string(TenantPermissionTeam) { + needsTeamLookup = true + break + } + } + if needsTeamLookup { + userTenants, err := NewUserTenantService().GetUserTenantRelationByUserIDWithContext(ctx, userID) + if err != nil { + return nil, err + } + for _, tenant := range userTenants { + if tenant != nil && tenant.TenantID != "" { + joinedTenantIDs[tenant.TenantID] = struct{}{} + } + } + } + + accessible := make([]*entity.Memory, 0, len(memories)) + for _, memory := range memories { + if memory == nil { + continue + } + if memory.TenantID == userID { + accessible = append(accessible, memory) + continue + } + if memory.Permissions != string(TenantPermissionTeam) { + continue + } + if _, ok := joinedTenantIDs[memory.TenantID]; ok { + accessible = append(accessible, memory) + } + } + return accessible, nil +} + +func splitFilterValues(values interface{}) []string { + if values == nil { + return []string{} + } + + var list []string + + switch v := values.(type) { + case string: + list = []string{v} + case []string: + list = v + case []interface{}: + for _, x := range v { + if s, ok := x.(string); ok { + list = append(list, s) + } + } + default: + return []string{} + } + + res := make([]string, 0) + for _, item := range list { + if item == "" { + continue + } + parts := strings.Split(item, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + res = append(res, p) + } + } + } + return res +} + +func memoryStringParam(value interface{}) string { + switch typed := value.(type) { + case string: + return typed + case fmt.Stringer: + return typed.String() + default: + if typed == nil { + return "" + } + return fmt.Sprintf("%v", typed) + } +} + +func memoryFloatParam(value interface{}, fallback float64) float64 { + switch typed := value.(type) { + case float64: + return typed + case float32: + return float64(typed) + case int: + return float64(typed) + case int64: + return float64(typed) + case string: + if parsed, err := strconv.ParseFloat(strings.TrimSpace(typed), 64); err == nil { + return parsed + } + } + return fallback +} + +func memoryIntParam(value interface{}, fallback int) int { + switch typed := value.(type) { + case int: + return typed + case int64: + return int(typed) + case float64: + return int(typed) + case string: + if parsed, err := strconv.Atoi(strings.TrimSpace(typed)); err == nil { + return parsed + } + } + return fallback +} + +func memoryMessageSelectFields() []string { + return []string{ + "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", + "valid_at", "invalid_at", "forget_at", "status", "content", + } +} + +func memoryIndexName(tenantID string) string { + prefix := strings.TrimSpace(os.Getenv("ES_INDEX_PREFIX")) + if prefix == "" { + return fmt.Sprintf("memory_%s", tenantID) + } + return fmt.Sprintf("memory_%s_%s", prefix, tenantID) +} + +func memorySearchIndexNames(memories []*entity.Memory) []string { + seen := make(map[string]struct{}, len(memories)) + indexNames := make([]string, 0, len(memories)) + for _, memory := range memories { + if memory == nil { + continue + } + indexName := memoryIndexName(memory.TenantID) + if engine.GetEngineType() == engine.EngineInfinity { + indexName = fmt.Sprintf("%s_%s", indexName, memory.ID) + } + if _, ok := seen[indexName]; ok { + continue + } + seen[indexName] = struct{}{} + indexNames = append(indexNames, indexName) + } + return indexNames +} + +func memoryMessageTextExpr(question string, similarityThreshold float64) *enginetypes.MatchTextExpr { + matchText := &enginetypes.MatchTextExpr{ + Fields: []string{"content"}, + MatchingText: question, + TopN: 100, + ExtraOptions: map[string]interface{}{"original_query": question}, + } + + queryBuilder := nlp.GetQueryBuilder() + if queryBuilder == nil { + queryBuilder = nlp.NewQueryBuilder() + } + if built, _ := queryBuilder.Question(question, "messages", similarityThreshold); built != nil { + matchText.MatchingText = built.MatchingText + matchText.ExtraOptions = built.ExtraOptions + if matchText.ExtraOptions == nil { + matchText.ExtraOptions = map[string]interface{}{} + } + matchText.ExtraOptions["original_query"] = question + } + matchText.Fields = []string{"content"} + matchText.TopN = 100 + return matchText +} + +func (s *MemoryService) memoryMessageDenseExpr(question string, memory *entity.Memory, topN int, similarityThreshold float64) (*enginetypes.MatchDenseExpr, error) { + driver, modelName, apiConfig, maxTokens, err := NewModelProviderService().GetModelConfigFromProviderInstance(memory.TenantID, entity.ModelTypeEmbedding, memory.EmbdID) + if err != nil { + return nil, err + } + embeddingModel := models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens) + embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{question}, embeddingModel.APIConfig, &models.EmbeddingConfig{Dimension: 0}) + if err != nil { + return nil, err + } + if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 { + return nil, errors.New("embedding response is empty") + } + + vector := embeddings[0].Embedding + return &enginetypes.MatchDenseExpr{ + VectorColumnName: fmt.Sprintf("q_%d_vec", len(vector)), + EmbeddingData: vector, + EmbeddingDataType: "float", + DistanceType: "cosine", + TopN: topN, + ExtraOptions: map[string]interface{}{"similarity": similarityThreshold}, + }, nil +} + +func (s *MemoryService) GetMessages(ctx context.Context, memoryIDs []string, userID, agentID, sessionID string, limit int) ([]map[string]interface{}, common.ErrorCode, error) { + memories, err := s.filterAccessibleMemories(ctx, userID, memoryIDs) + if err != nil { + return nil, common.CodeServerError, err + } + if len(memories) == 0 { + return []map[string]interface{}{}, common.CodeSuccess, nil + } + return s.getRecentMessage(ctx, memories, agentID, sessionID, limit) +} + +func (s *MemoryService) getRecentMessage(ctx context.Context, memories []*entity.Memory, agentID, sessionID string, limit int) ([]map[string]interface{}, common.ErrorCode, error) { + if s.docEngine == nil { + return nil, common.CodeServerError, errors.New("doc engine is nil") + } + if limit <= 0 { + limit = defaultMessageLimit + } else if limit > maxMessageLimit { + limit = maxMessageLimit + } + indexNames := memorySearchIndexNames(memories) + memoryIDs := make([]string, 0, len(memories)) + for _, memory := range memories { + if memory == nil || strings.TrimSpace(memory.ID) == "" { + continue + } + memoryIDs = append(memoryIDs, memory.ID) + } + + conditionDict := map[string]interface{}{"memory_id": memoryIDs} + if agentID = strings.TrimSpace(agentID); agentID != "" { + conditionDict["agent_id"] = agentID + } + if sessionID = strings.TrimSpace(sessionID); sessionID != "" { + conditionDict["session_id"] = sessionID + } + req := &enginetypes.SearchRequest{ + IndexNames: indexNames, + Offset: 0, + Limit: limit, + SelectFields: memoryMessageSelectFields(), + Filter: conditionDict, + MatchExprs: []interface{}{}, + OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"), + } + + result, err := s.docEngine.Search(ctx, req) + if err != nil { + return nil, common.CodeServerError, err + } + if result == nil || result.Total == 0 { + return []map[string]interface{}{}, common.CodeSuccess, nil + } + + messages := make([]map[string]interface{}, 0, len(result.Chunks)) + for _, chunk := range result.Chunks { + msg := make(map[string]interface{}, len(chunk)) + for _, field := range memoryMessageSelectFields() { + if val, ok := chunk[field]; ok { + msg[field] = val + } + } + messages = append(messages, msg) + } + return common.ConvertFloatsToPyFormat(messages).([]map[string]interface{}), common.CodeSuccess, nil +} + func isMessageDocumentNotFound(err error) bool { return errors.Is(err, enginetypes.ErrDocumentNotFound) } @@ -933,8 +1389,266 @@ func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse, return formatRetDataFromMemoryListItem(memory), nil } -// TODO: GetMemoryMessages - Implementation pending - depends on CanvasService and TaskService -// func (s *MemoryService) GetMemoryMessages(memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) { ... } +func (s *MemoryService) GetMemoryMessages(ctx context.Context, userID, memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) { + memory, err := s.requireMemoryAccess(ctx, userID, memoryID) + if err != nil { + return nil, err + } + + messages, err := s.listMemoryMessages(ctx, memory, agentIDs, keywords, page, pageSize) + if err != nil { + return nil, err + } + + rawMessages, _ := messages["message_list"].([]map[string]interface{}) + agentNames := map[string]string{} + tasks := map[string]map[string]interface{}{} + if len(rawMessages) > 0 { + agentNames, err = s.memoryMessageAgentNames(rawMessages) + if err != nil { + return nil, err + } + tasks, err = s.memoryMessageTasks(memoryID) + if err != nil { + return nil, err + } + } + + for _, message := range rawMessages { + agentID, _ := message["agent_id"].(string) + message["agent_name"] = "Unknown" + if name, ok := agentNames[agentID]; ok { + message["agent_name"] = name + } + message["task"] = map[string]interface{}{} + if task, ok := tasks[memoryMessageKey(message["message_id"])]; ok { + message["task"] = task + } + if extracts, ok := message["extract"].([]map[string]interface{}); ok { + for _, extract := range extracts { + extractAgentID, _ := extract["agent_id"].(string) + extract["agent_name"] = "Unknown" + if name, ok := agentNames[extractAgentID]; ok { + extract["agent_name"] = name + } + } + } + } + + return common.ConvertFloatsToPyFormat(map[string]interface{}{ + "messages": messages, + "storage_type": memory.StorageType, + }).(map[string]interface{}), nil +} + +func (s *MemoryService) listMemoryMessages(ctx context.Context, memory *entity.Memory, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) { + if s.docEngine == nil { + return nil, errors.New("message store is not initialized") + } + if page <= 0 { + page = 1 + } + if pageSize <= 0 { + pageSize = defaultMessageLimit + } else if pageSize > maxMessageLimit { + pageSize = maxMessageLimit + } + + memoryID := memory.ID + selectFields := memoryMessageListFields() + filter := map[string]interface{}{ + "message_type": "raw", + } + if len(agentIDs) > 0 { + filter["agent_id"] = agentIDs + } + if keywords = strings.TrimSpace(keywords); keywords != "" { + filter["session_id"] = keywords + } + filter["memory_id"] = []string{memoryID} + indexNames := memorySearchIndexNames([]*entity.Memory{memory}) + + rawReq := &enginetypes.SearchRequest{ + IndexNames: indexNames, + Offset: (page - 1) * pageSize, + Limit: pageSize, + SelectFields: selectFields, + Filter: filter, + MatchExprs: []interface{}{}, + OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"), + } + rawResult, err := s.docEngine.Search(ctx, rawReq) + if err != nil { + return nil, err + } + + messages := map[string]interface{}{ + "message_list": []map[string]interface{}{}, + "total_count": int64(0), + } + if rawResult != nil { + messages["total_count"] = rawResult.Total + } + if rawResult == nil || rawResult.Total == 0 { + return messages, nil + } + + rawMessages := make([]map[string]interface{}, 0, len(rawResult.Chunks)) + sourceIDs := make([]interface{}, 0, len(rawResult.Chunks)) + for _, chunk := range rawResult.Chunks { + message := memoryMessageFromChunk(chunk, selectFields) + message["extract"] = []map[string]interface{}{} + if messageID, ok := message["message_id"]; ok { + sourceIDs = append(sourceIDs, messageID) + } + rawMessages = append(rawMessages, message) + } + + if len(sourceIDs) > 0 { + extractReq := &enginetypes.SearchRequest{ + IndexNames: indexNames, + Offset: 0, + Limit: 512, + SelectFields: selectFields, + Filter: map[string]interface{}{ + "memory_id": []string{memoryID}, + "source_id": sourceIDs, + }, + MatchExprs: []interface{}{}, + OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"), + } + extractResult, err := s.docEngine.Search(ctx, extractReq) + if err != nil { + return nil, err + } + if extractResult != nil && extractResult.Total > 0 { + groupedExtracts := make(map[string][]map[string]interface{}) + for _, chunk := range extractResult.Chunks { + message := memoryMessageFromChunk(chunk, selectFields) + sourceID := memoryMessageKey(message["source_id"]) + groupedExtracts[sourceID] = append(groupedExtracts[sourceID], message) + } + for _, message := range rawMessages { + messageID := memoryMessageKey(message["message_id"]) + if extracts, ok := groupedExtracts[messageID]; ok { + message["extract"] = extracts + } + } + } + } + + messages["message_list"] = rawMessages + return messages, nil +} + +func memoryMessageListFields() []string { + return []string{ + "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", + "valid_at", "invalid_at", "forget_at", "status", + } +} + +func memoryMessageFromChunk(chunk map[string]interface{}, fields []string) map[string]interface{} { + message := make(map[string]interface{}, len(fields)) + for _, field := range fields { + if value, ok := chunk[field]; ok { + message[field] = value + } + } + return message +} + +func (s *MemoryService) memoryMessageAgentNames(messages []map[string]interface{}) (map[string]string, error) { + agentIDSet := make(map[string]struct{}) + for _, message := range messages { + agentID, _ := message["agent_id"].(string) + if agentID != "" { + agentIDSet[agentID] = struct{}{} + } + } + if len(agentIDSet) == 0 { + return map[string]string{}, nil + } + + agentIDs := make([]string, 0, len(agentIDSet)) + for agentID := range agentIDSet { + agentIDs = append(agentIDs, agentID) + } + + var agents []struct { + ID string `gorm:"column:id"` + Title *string `gorm:"column:title"` + } + if err := dao.DB.Model(&entity.UserCanvas{}).Select("id, title").Where("id IN ?", agentIDs).Scan(&agents).Error; err != nil { + return nil, err + } + agentNames := make(map[string]string, len(agents)) + for _, agent := range agents { + if agent.Title != nil { + agentNames[agent.ID] = *agent.Title + } + } + return agentNames, nil +} + +func (s *MemoryService) memoryMessageTasks(memoryID string) (map[string]map[string]interface{}, error) { + var tasks []struct { + ID string `gorm:"column:id"` + DocID string `gorm:"column:doc_id"` + FromPage int64 `gorm:"column:from_page"` + Progress float64 `gorm:"column:progress"` + ProgressMsg *string `gorm:"column:progress_msg"` + Digest *string `gorm:"column:digest"` + ChunkIDs *string `gorm:"column:chunk_ids"` + CreateTime *int64 `gorm:"column:create_time"` + } + if err := dao.DB.Model(&entity.Task{}). + Select("id, doc_id, from_page, progress, progress_msg, digest, chunk_ids, create_time"). + Where("doc_id IN ?", []string{memoryID}). + Order("create_time ASC"). + Scan(&tasks).Error; err != nil { + return nil, err + } + + taskByMessageID := make(map[string]map[string]interface{}, len(tasks)) + for _, task := range tasks { + if task.Digest == nil { + continue + } + digest := strings.TrimSpace(*task.Digest) + if digest == "" { + continue + } + var progressMsg interface{} + if task.ProgressMsg != nil { + progressMsg = *task.ProgressMsg + } + var chunkIDs interface{} + if task.ChunkIDs != nil { + chunkIDs = *task.ChunkIDs + } + var createTime interface{} + if task.CreateTime != nil { + createTime = *task.CreateTime + } + taskMap := map[string]interface{}{ + "id": task.ID, + "doc_id": task.DocID, + "from_page": task.FromPage, + "progress": task.Progress, + "progress_msg": progressMsg, + "digest": digest, + "chunk_ids": chunkIDs, + "create_time": createTime, + } + taskByMessageID[digest] = taskMap + } + return taskByMessageID, nil +} + +func memoryMessageKey(value interface{}) string { + return strings.TrimSpace(fmt.Sprint(value)) +} // TODO: queryMessages - Implementation pending - depends on CanvasService and TaskService // func (s *MemoryService) queryMessages(tenantID string, memoryID string, filterDict map[string]interface{}, page int, pageSize int) ([]map[string]interface{}, int64, error) { ... } @@ -945,15 +1659,6 @@ func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse, // TODO: UpdateMessageStatus - Implementation pending - depends on embedding engine // func (s *MemoryService) UpdateMessageStatus(memoryID string, messageID int, status bool) (bool, error) { ... } -// TODO: SearchMessage - Implementation pending - depends on embedding engine -// func (s *MemoryService) SearchMessage(filterDict map[string]interface{}, params map[string]interface{}) ([]map[string]interface{}, error) { ... } - -// TODO: GetMessages - Implementation pending - depends on embedding engine -// func (s *MemoryService) GetMessages(memoryIDs []string, agentID string, sessionID string, limit int) ([]map[string]interface{}, error) { ... } - -// TODO: GetMessageContent - Implementation pending - depends on embedding engine -// func (s *MemoryService) GetMessageContent(memoryID string, messageID int) (map[string]interface{}, error) { ... } - // isList checks if a value is a list or array type // This is a utility function for type validation // diff --git a/internal/service/memory_message_test.go b/internal/service/memory_message_test.go index 301542b9cc..40bf26c193 100644 --- a/internal/service/memory_message_test.go +++ b/internal/service/memory_message_test.go @@ -4,9 +4,16 @@ import ( "context" "errors" "fmt" + "reflect" "testing" + "github.com/glebarez/sqlite" + "gorm.io/gorm" + + "ragflow/internal/common" + "ragflow/internal/dao" enginetypes "ragflow/internal/engine/types" + "ragflow/internal/entity" ) func TestIsMessageDocumentNotFound(t *testing.T) { @@ -28,3 +35,240 @@ func TestRequireMemoryAccessReturnsCanceledContext(t *testing.T) { t.Fatalf("requireMemoryAccess error = %v, want %v", gotErr, err) } } + +type memoryMessageDocEngine struct { + fakeChatDocEngine + searchReq *enginetypes.SearchRequest + searchResp *enginetypes.SearchResult + updateCond map[string]interface{} + updateValue map[string]interface{} + updateBase string + updateID string +} + +func (e *memoryMessageDocEngine) Search(ctx context.Context, req *enginetypes.SearchRequest) (*enginetypes.SearchResult, error) { + e.searchReq = req + if e.searchResp != nil { + return e.searchResp, nil + } + return &enginetypes.SearchResult{}, nil +} + +func (e *memoryMessageDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error { + e.updateCond = condition + e.updateValue = newValue + e.updateBase = baseName + e.updateID = datasetID + return nil +} + +func setupMemoryMessageTestDB(t *testing.T) { + t.Helper() + + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + if err := db.AutoMigrate(&entity.Memory{}, &entity.UserTenant{}); err != nil { + t.Fatalf("failed to migrate memory test tables: %v", err) + } + + orig := dao.DB + dao.DB = db + t.Cleanup(func() { + dao.DB = orig + }) +} + +func seedMemoryMessages(t *testing.T) { + t.Helper() + + memories := []*entity.Memory{ + { + ID: "mem-owned", + Name: "Owned", + TenantID: "user-1", + MemoryType: dao.MemoryTypeRaw, + StorageType: "table", + EmbdID: "embd-1", + LLMID: "llm-1", + Permissions: string(TenantPermissionMe), + ForgettingPolicy: string(ForgettingPolicyFIFO), + }, + { + ID: "mem-other", + Name: "Other", + TenantID: "user-2", + MemoryType: dao.MemoryTypeRaw, + StorageType: "table", + EmbdID: "embd-2", + LLMID: "llm-2", + Permissions: string(TenantPermissionMe), + ForgettingPolicy: string(ForgettingPolicyFIFO), + }, + } + for _, memory := range memories { + if err := dao.DB.Create(memory).Error; err != nil { + t.Fatalf("seed memory %s: %v", memory.ID, err) + } + } +} + +func TestGetMessagesFiltersAccessibleMemoryAndBuildsRecentSearch(t *testing.T) { + setupMemoryMessageTestDB(t) + seedMemoryMessages(t) + + docEngine := &memoryMessageDocEngine{ + searchResp: &enginetypes.SearchResult{ + Total: 1, + Chunks: []map[string]interface{}{ + { + "message_id": int64(12), + "message_type": "raw", + "memory_id": "mem-owned", + "user_id": "user-1", + "agent_id": "agent-1", + "session_id": "session-1", + "valid_at": float64(123), + "status": 1, + "content": "hello", + "extra": "should be dropped", + }, + }, + }, + } + svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine} + + got, code, err := svc.GetMessages(context.Background(), []string{"mem-owned", "mem-other"}, "user-1", "agent-1", "session-1", 3) + if err != nil { + t.Fatalf("GetMessages error: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code = %v, want %v", code, common.CodeSuccess) + } + if len(got) != 1 || got[0]["content"] != "hello" { + t.Fatalf("unexpected messages: %+v", got) + } + if _, ok := got[0]["extra"]; ok { + t.Fatalf("unexpected non-selected field in response: %+v", got[0]) + } + + req := docEngine.searchReq + if req == nil { + t.Fatal("expected doc engine search request") + } + if !reflect.DeepEqual(req.IndexNames, []string{"memory_user-1"}) { + t.Fatalf("IndexNames = %v, want [memory_user-1]", req.IndexNames) + } + if len(req.KbIDs) != 0 { + t.Fatalf("KbIDs = %v, want empty for memory message search", req.KbIDs) + } + if !reflect.DeepEqual(req.Filter["memory_id"], []string{"mem-owned"}) { + t.Fatalf("memory_id filter = %v, want [mem-owned]", req.Filter["memory_id"]) + } + if req.Filter["agent_id"] != "agent-1" || req.Filter["session_id"] != "session-1" { + t.Fatalf("unexpected filter: %+v", req.Filter) + } + if req.Limit != 3 { + t.Fatalf("Limit = %d, want 3", req.Limit) + } + if req.OrderBy == nil || len(req.OrderBy.Fields) != 1 || req.OrderBy.Fields[0].Field != "valid_at" || req.OrderBy.Fields[0].Type != enginetypes.SortDesc { + t.Fatalf("unexpected order by: %+v", req.OrderBy) + } +} + +func TestSearchMessageFiltersAccessibleMemoryAndDefaultsStatus(t *testing.T) { + setupMemoryMessageTestDB(t) + seedMemoryMessages(t) + + docEngine := &memoryMessageDocEngine{ + searchResp: &enginetypes.SearchResult{ + Total: 1, + Chunks: []map[string]interface{}{ + { + "message_id": int64(13), + "message_type": "raw", + "memory_id": "mem-owned", + "user_id": "user-1", + "agent_id": "agent-1", + "session_id": "session-1", + "valid_at": int64(456), + "status": 1, + "content": "matched", + }, + }, + }, + } + svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine} + filter := map[string]interface{}{ + "memory_id": []string{"mem-owned", "mem-other"}, + "agent_id": "agent-1", + "session_id": "session-1", + "user_id": "user-1", + } + params := map[string]interface{}{ + "query": "", + "similarity_threshold": 0.2, + "keywords_similarity_weight": 0.7, + "top_n": 5, + } + + got, code, err := svc.SearchMessage(context.Background(), "user-1", filter, params) + if err != nil { + t.Fatalf("SearchMessage error: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code = %v, want %v", code, common.CodeSuccess) + } + if len(got) != 1 || got[0]["content"] != "matched" { + t.Fatalf("unexpected search result: %+v", got) + } + + req := docEngine.searchReq + if req == nil { + t.Fatal("expected doc engine search request") + } + if !reflect.DeepEqual(req.Filter["memory_id"], []string{"mem-owned"}) { + t.Fatalf("memory_id filter = %v, want [mem-owned]", req.Filter["memory_id"]) + } + if req.Filter["status"] != 1 { + t.Fatalf("status filter = %v, want 1", req.Filter["status"]) + } + if req.Filter["agent_id"] != "agent-1" || req.Filter["session_id"] != "session-1" || req.Filter["user_id"] != "user-1" { + t.Fatalf("unexpected filter: %+v", req.Filter) + } + if len(req.MatchExprs) != 0 { + t.Fatalf("empty query should not build match expressions, got %+v", req.MatchExprs) + } + if req.Limit != 5 { + t.Fatalf("Limit = %d, want 5", req.Limit) + } +} + +func TestUpdateMessageUpdatesStatusByMessageDocID(t *testing.T) { + setupMemoryMessageTestDB(t) + seedMemoryMessages(t) + + docEngine := &memoryMessageDocEngine{} + svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine} + + ok, err := svc.UpdateMessage(context.Background(), "user-1", "mem-owned", 42, true) + if err != nil { + t.Fatalf("UpdateMessage error: %v", err) + } + if !ok { + t.Fatal("UpdateMessage returned false") + } + if docEngine.updateBase != "memory_user-1" { + t.Fatalf("baseName = %q, want memory_user-1", docEngine.updateBase) + } + if docEngine.updateID != "mem-owned" { + t.Fatalf("datasetID = %q, want mem-owned", docEngine.updateID) + } + if docEngine.updateCond["id"] != "mem-owned_42" { + t.Fatalf("condition = %+v, want id mem-owned_42", docEngine.updateCond) + } + if docEngine.updateValue["status"] != 1 { + t.Fatalf("status update = %+v, want status 1", docEngine.updateValue) + } +}