From a1f1dd500798cfa72b4c8037464c8e295a2d7a9e Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Fri, 26 Jun 2026 19:21:52 +0800 Subject: [PATCH] feat[Go]: implement Add messages for Go (#16375) ### What problem does this PR solve? As title ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- internal/engine/elasticsearch/chunk.go | 256 +++++++++++++++++- internal/handler/auth.go | 4 + internal/handler/memory.go | 87 +++++- internal/router/router.go | 1 + internal/service/memory.go | 101 ++++++- internal/service/memory_message_service.go | 249 +++++++++++------ .../service/memory_message_service_test.go | 4 +- 7 files changed, 605 insertions(+), 97 deletions(-) diff --git a/internal/engine/elasticsearch/chunk.go b/internal/engine/elasticsearch/chunk.go index 299b326a1f..c5efa2b49e 100644 --- a/internal/engine/elasticsearch/chunk.go +++ b/internal/engine/elasticsearch/chunk.go @@ -67,13 +67,22 @@ func (e *elasticsearchEngine) CreateChunkStore(ctx context.Context, baseName, da return fmt.Errorf("failed to check index existence: %w", err) } if exists { + if strings.HasPrefix(baseName, "memory_") { + if err := e.ensureMemoryMessageVectorMapping(ctx, baseName, vectorSize); err != nil { + return fmt.Errorf("failed to ensure memory vector mapping: %w", err) + } + common.Info("Memory index already exists, ensured vector mapping", zap.String("index_name", baseName), zap.Int("vector_size", vectorSize)) + return nil + } common.Info("Index already exists, skipping creation", zap.String("index_name", baseName)) return nil } // Load mapping based on index type var mapping map[string]interface{} - if datasetID == "skill" { + if strings.HasPrefix(baseName, "memory_") { + mapping = getMemoryMessageMapping(vectorSize) + } else if datasetID == "skill" { // Load skill-specific mapping skillMapping, err := loadSkillMapping() if err != nil { @@ -149,6 +158,12 @@ func (e *elasticsearchEngine) InsertChunks(ctx context.Context, chunks []map[str return nil, fmt.Errorf("index name cannot be empty") } + if strings.HasPrefix(baseName, "memory_") { + if err := e.ensureMemoryMessageVectorMappingsForDocs(ctx, baseName, chunks); err != nil { + return nil, err + } + } + // Build bulk request body with index operations (upsert behavior: insert if not exists, update if exists) var buf bytes.Buffer for _, doc := range chunks { @@ -901,6 +916,12 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque hasVectorMatch := matchDense != nil && len(matchDense.EmbeddingData) > 0 if hasVectorMatch { + if isMemoryIndex { + if err := e.ensureMemoryMessageSearchVectorMappings(ctx, req.IndexNames, matchDense.VectorColumnName, len(matchDense.EmbeddingData)); err != nil { + return nil, err + } + } + k := matchDense.TopN if k <= 0 { k = limit @@ -2287,6 +2308,239 @@ func loadSkillMapping() (map[string]interface{}, error) { return mapping, nil } +func memoryMessageVectorField(vectorSize int) string { + return fmt.Sprintf("q_%d_vec", vectorSize) +} + +func memoryMessageVectorProperty(vectorSize int) map[string]interface{} { + return map[string]interface{}{ + "type": "dense_vector", + "dims": vectorSize, + "index": true, + "similarity": "cosine", + } +} + +func parseMemoryMessageVectorSize(field string) (int, bool) { + if !memoryMessageVectorFieldRE.MatchString(field) { + return 0, false + } + sizeText := strings.TrimSuffix(strings.TrimPrefix(field, "q_"), "_vec") + vectorSize, err := strconv.Atoi(sizeText) + if err != nil || vectorSize <= 0 { + return 0, false + } + return vectorSize, true +} + +func (e *elasticsearchEngine) memoryMessageVectorMappingExists(ctx context.Context, indexName, fieldName string) (bool, error) { + req := esapi.IndicesGetMappingRequest{ + Index: []string{indexName}, + } + res, err := req.Do(ctx, e.client) + if err != nil { + return false, fmt.Errorf("failed to get memory vector mapping: %w", err) + } + defer res.Body.Close() + + if res.StatusCode == http.StatusNotFound { + return false, nil + } + if res.IsError() { + bodyBytes, _ := io.ReadAll(res.Body) + reason := extractErrorReason(bodyBytes) + if reason != "" { + return false, fmt.Errorf("elasticsearch error getting memory vector mapping %s.%s: %s", indexName, fieldName, reason) + } + return false, fmt.Errorf("elasticsearch returned error getting memory vector mapping %s.%s: %s, body: %s", indexName, fieldName, res.Status(), string(bodyBytes)) + } + + var mappings map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&mappings); err != nil { + return false, fmt.Errorf("failed to decode memory vector mapping: %w", err) + } + + indexMapping, ok := mappings[indexName].(map[string]interface{}) + if !ok { + return false, nil + } + mapping, ok := indexMapping["mappings"].(map[string]interface{}) + if !ok { + return false, nil + } + properties, ok := mapping["properties"].(map[string]interface{}) + if !ok { + return false, nil + } + _, ok = properties[fieldName] + return ok, nil +} + +func (e *elasticsearchEngine) ensureMemoryMessageVectorMapping(ctx context.Context, indexName string, vectorSize int) error { + if vectorSize <= 0 { + return fmt.Errorf("memory vector size must be positive, got %d", vectorSize) + } + + fieldName := memoryMessageVectorField(vectorSize) + exists, err := e.memoryMessageVectorMappingExists(ctx, indexName, fieldName) + if err != nil { + return err + } + if exists { + return nil + } + + body := map[string]interface{}{ + "properties": map[string]interface{}{ + fieldName: memoryMessageVectorProperty(vectorSize), + }, + } + data, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("failed to marshal memory vector mapping: %w", err) + } + + req := esapi.IndicesPutMappingRequest{ + Index: []string{indexName}, + Body: bytes.NewReader(data), + } + res, err := req.Do(ctx, e.client) + if err != nil { + return fmt.Errorf("failed to update memory vector mapping: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + bodyBytes, _ := io.ReadAll(res.Body) + reason := extractErrorReason(bodyBytes) + if reason != "" { + return fmt.Errorf("elasticsearch error updating memory vector mapping %s.%s: %s", indexName, fieldName, reason) + } + return fmt.Errorf("elasticsearch returned error updating memory vector mapping %s.%s: %s, body: %s", indexName, fieldName, res.Status(), string(bodyBytes)) + } + + return nil +} + +func (e *elasticsearchEngine) ensureMemoryMessageVectorMappingsForDocs(ctx context.Context, indexName string, chunks []map[string]interface{}) error { + seen := map[int]struct{}{} + for _, chunk := range chunks { + for field := range chunk { + vectorSize, ok := parseMemoryMessageVectorSize(field) + if !ok { + continue + } + if _, ok := seen[vectorSize]; ok { + continue + } + if err := e.ensureMemoryMessageVectorMapping(ctx, indexName, vectorSize); err != nil { + return err + } + seen[vectorSize] = struct{}{} + } + } + return nil +} + +func (e *elasticsearchEngine) ensureMemoryMessageSearchVectorMappings(ctx context.Context, indexNames []string, vectorFieldName string, fallbackVectorSize int) error { + vectorSize, ok := parseMemoryMessageVectorSize(vectorFieldName) + if !ok { + vectorSize = fallbackVectorSize + } + if vectorSize <= 0 { + return fmt.Errorf("memory vector size must be positive, got %d", vectorSize) + } + + for _, indexName := range indexNames { + if !strings.HasPrefix(indexName, "memory_") { + continue + } + exists, err := e.indexExists(ctx, indexName) + if err != nil { + return fmt.Errorf("failed to check memory index existence: %w", err) + } + if !exists { + continue + } + if err := e.ensureMemoryMessageVectorMapping(ctx, indexName, vectorSize); err != nil { + return err + } + } + return nil +} + +func getMemoryMessageMapping(vectorSize int) map[string]interface{} { + vectorField := memoryMessageVectorField(vectorSize) + return map[string]interface{}{ + "settings": map[string]interface{}{ + "number_of_shards": 1, + "number_of_replicas": 0, + }, + "mappings": map[string]interface{}{ + "properties": map[string]interface{}{ + "id": map[string]interface{}{ + "type": "keyword", + }, + "doc_id": map[string]interface{}{ + "type": "keyword", + }, + "kb_id": map[string]interface{}{ + "type": "keyword", + }, + "memory_id": map[string]interface{}{ + "type": "keyword", + }, + "user_id": map[string]interface{}{ + "type": "keyword", + }, + "agent_id": map[string]interface{}{ + "type": "keyword", + }, + "session_id": map[string]interface{}{ + "type": "keyword", + }, + "message_id": map[string]interface{}{ + "type": "long", + }, + "source_id": map[string]interface{}{ + "type": "long", + }, + "message_type_kwd": map[string]interface{}{ + "type": "keyword", + }, + "status_int": map[string]interface{}{ + "type": "integer", + }, + "content": map[string]interface{}{ + "type": "text", + "index": false, + }, + "content_ltks": map[string]interface{}{ + "type": "text", + "analyzer": "whitespace", + }, + "tokenized_content_ltks": map[string]interface{}{ + "type": "text", + "analyzer": "whitespace", + }, + "valid_at": map[string]interface{}{ + "type": "date", + "format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis", + }, + "invalid_at": map[string]interface{}{ + "type": "date", + "format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis", + }, + "forget_at": map[string]interface{}{ + "type": "date", + "format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis", + }, + vectorField: memoryMessageVectorProperty(vectorSize), + }, + }, + } +} + // getDefaultSkillMapping returns the default skill index mapping func getDefaultSkillMapping() map[string]interface{} { return map[string]interface{}{ diff --git a/internal/handler/auth.go b/internal/handler/auth.go index 096f2470c6..02d0d49777 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -52,6 +52,8 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc { return } + authViaAPIToken := false + // Get user by access token user, code, err := h.userService.GetUserByToken(token) if err != nil { @@ -64,6 +66,7 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc { c.Abort() return } + authViaAPIToken = true } if user.IsSuperuser != nil && *user.IsSuperuser { @@ -89,6 +92,7 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc { c.Set("user", user) c.Set("user_id", user.ID) c.Set("email", user.Email) + c.Set("auth_via_api_token", authViaAPIToken) c.Next() } } diff --git a/internal/handler/memory.go b/internal/handler/memory.go index 219f5194c7..3c8a4c8288 100644 --- a/internal/handler/memory.go +++ b/internal/handler/memory.go @@ -580,6 +580,34 @@ func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) { }) } +type messageMemoryIDs []string + +func (ids *messageMemoryIDs) UnmarshalJSON(data []byte) error { + var single string + if err := json.Unmarshal(data, &single); err == nil { + if strings.TrimSpace(single) != "" { + *ids = []string{single} + } + return nil + } + + var many []string + if err := json.Unmarshal(data, &many); err != nil { + return err + } + *ids = many + return nil +} + +type AddMessageRequest struct { + MemoryIDs messageMemoryIDs `json:"memory_id" binding:"required"` + AgentID string `json:"agent_id" binding:"required"` + SessionID string `json:"session_id" binding:"required"` + UserInput string `json:"user_input" binding:"required"` + AgentResponse string `json:"agent_response" binding:"required"` + UserID string `json:"user_id"` +} + // AddMessage handles POST request for adding messages // API Path: POST /api/v1/messages // @@ -595,12 +623,61 @@ func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) { // - user_input (required): User input // - agent_response (required): Agent response // - user_id (optional): User ID -// -// TODO: Haruko386 is implementing this for now, if you implement this, delete this line plz func (h *MemoryHandler) AddMessage(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + currentUserID := strings.TrimSpace(user.ID) + if currentUserID == "" { + jsonError(c, common.CodeArgumentError, "user_id is required") + return + } + + var reqBody AddMessageRequest + if err := c.ShouldBindJSON(&reqBody); err != nil { + jsonError(c, common.CodeArgumentError, "body arguments is required") + return + } + if len(reqBody.MemoryIDs) == 0 { + jsonError(c, common.CodeArgumentError, "memory_id is required") + return + } + + effectiveUserID := currentUserID + if v, ok := c.Get("auth_via_api_token"); ok { + if authViaAPIToken, ok := v.(bool); authViaAPIToken && ok { + effectiveUserID = strings.TrimSpace(reqBody.UserID) + if effectiveUserID == "" { + jsonError(c, common.CodeArgumentError, "user_id is required") + return + } + } + } + + msg := service.MemoryMessage{ + UserID: effectiveUserID, + AgentID: reqBody.AgentID, + SessionID: reqBody.SessionID, + UserInput: reqBody.UserInput, + AgentResponse: reqBody.AgentResponse, + } + + ok, message, err := h.memoryService.AddMessage(c.Request.Context(), currentUserID, []string(reqBody.MemoryIDs), msg) + if err != nil || !ok { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeServerError, + "message": "Some messages failed to add. Detail:" + message, + "data": nil, + }) + return + } + c.JSON(http.StatusOK, gin.H{ - "code": common.CodeServerError, - "message": "AddMessage not implemented - pending embedding engine dependency", + "code": common.CodeSuccess, + "message": message, "data": nil, }) } @@ -741,7 +818,7 @@ func (h *MemoryHandler) UpdateMessage(c *gin.Context) { return } - ok, err = h.memoryService.UpdateMessage(c.Request.Context(), userID, memoryID, messageID, status) + ok, err = h.memoryService.UpdateMessageStatus(c.Request.Context(), userID, memoryID, messageID, status) if err != nil || !ok { if isMemoryServiceNotFound(err) { c.JSON(http.StatusOK, gin.H{ diff --git a/internal/router/router.go b/internal/router/router.go index e1f17cc551..0d590c923d 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -418,6 +418,7 @@ func (r *Router) Setup(engine *gin.Engine) { message := v1.Group("/messages") { message.GET("", r.memoryHandler.GetMessages) + message.POST("", r.memoryHandler.AddMessage) message.DELETE("/:memory_message", r.memoryHandler.ForgetMessage) message.PUT("/:memory_message", r.memoryHandler.UpdateMessage) message.GET("/:memory_message/content", r.memoryHandler.GetMessageContent) diff --git a/internal/service/memory.go b/internal/service/memory.go index 87855014b9..19d28ec040 100644 --- a/internal/service/memory.go +++ b/internal/service/memory.go @@ -816,7 +816,93 @@ func (s *MemoryService) ForgetMessage(ctx context.Context, userID string, memory return nil } -func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) { +// AddMessage filters inaccessible memories and queues the raw message for +// memory extraction. The currentUserID is used only for access control; msg.UserID +// is the attribution stored on the message and may differ for API-token callers. +func (s *MemoryService) AddMessage(ctx context.Context, currentUserID string, memoryIDs []string, msg MemoryMessage) (bool, string, error) { + requestedMemoryIDs := splitFilterValues(memoryIDs) + memories, err := s.filterAccessibleMemories(ctx, currentUserID, memoryIDs) + if err != nil { + return false, err.Error(), err + } + if len(memories) == 0 { + return false, "Memory not found.", nil + } + + accessibleMemoryIDs := make([]string, 0, len(memories)) + for _, memory := range memories { + if memory != nil { + accessibleMemoryIDs = append(accessibleMemoryIDs, memory.ID) + } + } + missingMemoryIDs := missingRequestedMemoryIDs(requestedMemoryIDs, accessibleMemoryIDs) + + res, err := NewMemoryMessageService(s).QueueSaveToMemoryTask(ctx, accessibleMemoryIDs, msg) + if err != nil { + return false, err.Error(), err + } + + if len(missingMemoryIDs) > 0 { + if res == nil { + res = &QueueSaveResult{} + } + res.NotFound = append(missingMemoryIDs, res.NotFound...) + } + errorMsg := memorySaveErrorMessage(res) + if errorMsg != "" { + return false, errorMsg, nil + } + return true, "All add to task.", nil +} + +func missingRequestedMemoryIDs(requestedMemoryIDs, accessibleMemoryIDs []string) []string { + if len(requestedMemoryIDs) == 0 { + return []string{} + } + + accessibleSet := make(map[string]struct{}, len(accessibleMemoryIDs)) + for _, memoryID := range accessibleMemoryIDs { + memoryID = strings.TrimSpace(memoryID) + if memoryID != "" { + accessibleSet[memoryID] = struct{}{} + } + } + + missingIDs := make([]string, 0) + seenMissing := make(map[string]struct{}) + for _, memoryID := range requestedMemoryIDs { + memoryID = strings.TrimSpace(memoryID) + if memoryID == "" { + continue + } + if _, ok := accessibleSet[memoryID]; ok { + continue + } + if _, ok := seenMissing[memoryID]; ok { + continue + } + missingIDs = append(missingIDs, memoryID) + seenMissing[memoryID] = struct{}{} + } + return missingIDs +} + +func memorySaveErrorMessage(res *QueueSaveResult) string { + if res == nil { + return "" + } + + var b strings.Builder + if len(res.NotFound) > 0 { + b.WriteString(fmt.Sprintf("Memory %v not found.", res.NotFound)) + } + for _, failed := range res.Failed { + b.WriteString(fmt.Sprintf("Memory %s failed. Detail: %s", failed.MemoryID, failed.FailMsg)) + } + return b.String() +} + +func (s *MemoryService) UpdateMessageStatus(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) { memory, err := s.requireMemoryAccess(ctx, userID, memoryID) if err != nil { return false, err @@ -848,6 +934,10 @@ func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID stri return true, nil } +func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) { + return s.UpdateMessageStatus(ctx, userID, memoryID, messageID, status) +} + func (s *MemoryService) GetMessageContent(ctx context.Context, userID, memoryID string, messageID int64) (map[string]interface{}, error) { memory, err := s.requireMemoryAccess(ctx, userID, memoryID) if err != nil { @@ -1650,15 +1740,6 @@ func memoryMessageKey(value interface{}) string { return strings.TrimSpace(fmt.Sprint(value)) } -// TODO: queryMessages - Implementation pending - depends on CanvasService and TaskService -// func (s *MemoryService) queryMessages(tenantID string, memoryID string, filterDict map[string]interface{}, page int, pageSize int) ([]map[string]interface{}, int64, error) { ... } - -// TODO: AddMessage - Implementation pending - depends on embedding engine -// func (s *MemoryService) AddMessage(memoryIDs []string, messageDict map[string]interface{}) (bool, string, error) { ... } - -// TODO: UpdateMessageStatus - Implementation pending - depends on embedding engine -// func (s *MemoryService) UpdateMessageStatus(memoryID string, messageID int, status bool) (bool, error) { ... } - // isList checks if a value is a list or array type // This is a utility function for type validation // diff --git a/internal/service/memory_message_service.go b/internal/service/memory_message_service.go index 427b7599e3..e7e82ed4d3 100644 --- a/internal/service/memory_message_service.go +++ b/internal/service/memory_message_service.go @@ -14,33 +14,31 @@ // limitations under the License. // -// memory_message_service.go — Phase 8b real MemorySaver port. +// memory_message_service.go — real MemorySaver port. // // Port of api.db.joint_services.memory_message_service.queue_save_to_memory_task -// from the Python runtime. The Go port is a partial implementation -// (synchronous parts land here; the embedding-model call is loud-failed -// with ErrEmbedderNotWired until a Go embedding port lands). +// from the Python runtime. // // Python signature (api/db/joint_services/memory_message_service.py:344): // -// async def queue_save_to_memory_task( -// memory_ids: list[str], -// message_dict: dict, -// ) -> tuple[list[str], list[dict]] -// # (not_found_memory, failed_memory) +// async def queue_save_to_memory_task( +// memory_ids: list[str], +// message_dict: dict, +// ) -> tuple[list[str], list[dict]] +// # (not_found_memory, failed_memory) // // Go equivalent: // -// type QueueSaveResult struct { -// NotFound []string -// Failed []MemoryFailure -// } +// type QueueSaveResult struct { +// NotFound []string +// Failed []MemoryFailure +// } // -// func (s *MemoryMessageService) QueueSaveToMemoryTask( -// ctx context.Context, -// memoryIDs []string, -// msg MemoryMessage, -// ) (*QueueSaveResult, error) +// func (s *MemoryMessageService) QueueSaveToMemoryTask( +// ctx context.Context, +// memoryIDs []string, +// msg MemoryMessage, +// ) (*QueueSaveResult, error) // // The function is the entry point the Message component calls // after a conversation turn when `memory_save=true` is set. It @@ -49,14 +47,9 @@ // 1. For each memory id: look up the Memory (via MemoryService). // 2. Generate a raw_message_id from Redis auto-increment (namespace "memory"). // 3. Build the raw_message envelope (mirrors Python:344-386). -// 4. Call embed_and_save on the memory + [raw_message]. ← DEFERRED +// 4. Call embed_and_save on the memory + [raw_message]. // 5. Insert a Task row in the task table for the async extractor. // 6. Return not-found + failed lists. -// -// Steps 1, 2, 3, 5, 6 are implemented here. Step 4 (the -// embed_and_save call) is wrapped in a loud-fail gate that returns -// ErrEmbedderNotWired until the Go embedding layer ships. - package service import ( @@ -64,6 +57,12 @@ import ( "errors" "fmt" "time" + + "ragflow/internal/common" + "ragflow/internal/dao" + redisengine "ragflow/internal/engine/redis" + "ragflow/internal/entity" + models "ragflow/internal/entity/models" ) // ErrEmbedderNotWired is returned by QueueSaveToMemoryTask when @@ -96,9 +95,7 @@ type MemoryMessage struct { AgentResponse string } -// MemoryFailure describes one memory that failed to save. The -// FailMsg is the underlying error (or "embedder not wired" until -// the embedder port lands). +// MemoryFailure describes one memory that failed to save. type MemoryFailure struct { MemoryID string FailMsg string @@ -112,11 +109,10 @@ type QueueSaveResult struct { } // MemoryMessageService is the Go port of -// api.db.joint_services.memory_message_service. It depends on a -// MemoryService instance for the lookup; the embedder is -// hard-coded to loud-fail (see ErrEmbedderNotWired). +// api.db.joint_services.memory_message_service. type MemoryMessageService struct { memories *MemoryService + taskDAO *dao.TaskDAO } // NewMemoryMessageService constructs a service bound to the @@ -124,15 +120,17 @@ type MemoryMessageService struct { // the default MemorySaver in the Message component via // `component.SetMemorySaver(...)` at boot. func NewMemoryMessageService(memories *MemoryService) *MemoryMessageService { - return &MemoryMessageService{memories: memories} + return &MemoryMessageService{ + memories: memories, + taskDAO: dao.NewTaskDAO(), + } } // QueueSaveToMemoryTask runs the memory-persistence flow for the // supplied memory_ids + message. See package comment for the // step-by-step contract. The function is synchronous — the Python -// async version awaits `embed_and_save` and a Redis call; both are -// replaced here with synchronous equivalents (and a loud-fail -// embedder). +// async version awaits `embed_and_save` and Redis calls; this Go port does the +// same work synchronously from the HTTP request path. // // Returned QueueSaveResult has NotFound / Failed populated for // the per-memory outcomes. The outer error is reserved for @@ -168,11 +166,7 @@ func (s *MemoryMessageService) QueueSaveToMemoryTask( rawMessageID := generateRawMessageID() rawMessage := buildRawMessage(rawMessageID, memoryID, mem, msg) - // (4) embed_and_save. Loud-fail: the embedder is the - // only step the Go runtime can't do yet. When it - // ships, replace this branch with a call into - // internal/rag/llm/embedding_model. - if err := embedAndSave(ctx, mem, rawMessage); err != nil { + if err := s.embedAndSave(ctx, mem, rawMessage); err != nil { res.Failed = append(res.Failed, MemoryFailure{ MemoryID: memoryID, FailMsg: err.Error(), @@ -180,13 +174,6 @@ func (s *MemoryMessageService) QueueSaveToMemoryTask( continue } - // (5) Task row insertion. The Python side bulk-inserts - // a Task row with digest=str(raw_message_id); the - // extractor's task_type is "memory". The Go port - // constructs the same row shape and defers the actual - // insert to TaskDAO when the project adds one (today - // TaskDAO is in internal/dao; the API mirrors the - // Python Task entity closely). task := buildTaskRow(rawMessageID, memoryID) if err := s.insertTask(ctx, task); err != nil { res.Failed = append(res.Failed, MemoryFailure{ @@ -195,21 +182,25 @@ func (s *MemoryMessageService) QueueSaveToMemoryTask( }) continue } + if err := queueMemoryTask(memoryID, mem.TenantID, rawMessageID, task, msg); err != nil { + res.Failed = append(res.Failed, MemoryFailure{ + MemoryID: memoryID, + FailMsg: err.Error(), + }) + } } return res, nil } -// generateRawMessageID is a placeholder for the Redis auto- -// increment the Python side uses (`REDIS_CONN.generate_auto_increment_id -// (namespace="memory")`). The Go port generates a UUID-shaped -// integer now; replace with a Redis-backed counter when the -// project's Redis client lands. +// generateRawMessageID returns the Redis auto-increment id used by the Python +// side (`REDIS_CONN.generate_auto_increment_id(namespace="memory")`). func generateRawMessageID() int64 { - // seconds-since-epoch is unique enough for the Go port's - // own bookkeeping. The Redis-backed counter is the source - // of truth in production; this fallback only matters for - // the tests that don't need cross-process uniqueness. - return time.Now().Unix() + if redisClient := redisengine.Get(); redisClient != nil { + if id := redisClient.GenerateAutoIncrementID("id_generator", "memory", 1, nil); id > 0 { + return id + } + } + return time.Now().UnixNano() } // buildRawMessage constructs the raw_message envelope that gets @@ -224,18 +215,22 @@ func buildRawMessage( content := fmt.Sprintf("User Input: %s\nAgent Response: %s", msg.UserInput, msg.AgentResponse) out := map[string]any{ - "message_id": rawMessageID, - "message_type": "raw", - "source_id": 0, - "memory_id": memoryID, - "user_id": msg.UserID, - "agent_id": msg.AgentID, - "session_id": msg.SessionID, - "content": content, - "valid_at": time.Now().UTC().Format(time.RFC3339), - "invalid_at": nil, - "forget_at": nil, - "status": true, + "message_id": rawMessageID, + "message_type": "raw", + "message_type_kwd": "raw", + "source_id": 0, + "memory_id": memoryID, + "user_id": msg.UserID, + "agent_id": msg.AgentID, + "session_id": msg.SessionID, + "content": content, + "content_ltks": content, + "tokenized_content_ltks": content, + "valid_at": time.Now().UTC().Format("2006-01-02 15:04:05"), + "invalid_at": nil, + "forget_at": nil, + "status": true, + "status_int": 1, } if mem != nil { // The embedder uses the memory's embd_id; keep the @@ -253,29 +248,125 @@ func buildTaskRow(rawMessageID int64, memoryID string) map[string]any { "doc_id": memoryID, "task_type": "memory", "progress": 0.0, + "begin_at": time.Now(), "digest": fmt.Sprintf("%d", rawMessageID), } } -// embedAndSave is the deferred gate. Replace with a call to the -// real embedding model + memory_message table insert when those -// land. +func (s *MemoryMessageService) embedAndSave(ctx context.Context, mem *CreateMemoryResponse, rawMessage map[string]any) error { + if mem == nil { + return errors.New("memory not found") + } + if s == nil || s.memories == nil || s.memories.docEngine == nil { + return errors.New("message store is not initialized") + } + + content, _ := rawMessage["content"].(string) + driver, modelName, apiConfig, maxTokens, err := NewModelProviderService().GetModelConfigFromProviderInstance(mem.TenantID, entity.ModelTypeEmbedding, mem.EmbdID) + if err != nil { + return err + } + embeddingModel := models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens) + embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{content}, embeddingModel.APIConfig, &models.EmbeddingConfig{Dimension: 0}) + if err != nil { + return err + } + if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 { + return errors.New("embedding response is empty") + } + + vector := embeddings[0].Embedding + rawMessage[fmt.Sprintf("q_%d_vec", len(vector))] = vector + rawMessage["id"] = fmt.Sprintf("%s_%d", rawMessage["memory_id"], rawMessage["message_id"]) + rawMessage["doc_id"] = rawMessage["memory_id"] + + indexName := memoryIndexName(mem.TenantID) + exists, err := s.memories.docEngine.ChunkStoreExists(ctx, indexName, mem.ID) + if err != nil { + return fmt.Errorf("check message index: %w", err) + } + if !exists { + if err := s.memories.docEngine.CreateChunkStore(ctx, indexName, mem.ID, len(vector), ""); err != nil { + return fmt.Errorf("create message index: %w", err) + } + } + if _, err := s.memories.docEngine.InsertChunks(ctx, []map[string]interface{}{mapStringAny(rawMessage)}, indexName, mem.ID); err != nil { + return fmt.Errorf("insert message into memory: %w", err) + } + + return nil +} + +// embedAndSave is kept for older unit tests; production uses the method above. func embedAndSave(_ context.Context, _ *CreateMemoryResponse, _ map[string]any) error { return ErrEmbedderNotWired } -// insertTask is a placeholder for the bulk_insert_into_db call -// the Python side makes. The Go side needs a TaskDAO write path -// (the Python Task entity is mirrored in internal/entity); until -// that lands this is a no-op that returns nil so the rest of -// the flow can be exercised. -func (s *MemoryMessageService) insertTask(_ context.Context, _ map[string]any) error { - return nil +func (s *MemoryMessageService) insertTask(_ context.Context, row map[string]any) error { + if s == nil { + return errors.New("nil MemoryMessageService") + } + if s.taskDAO == nil { + s.taskDAO = dao.NewTaskDAO() + } + return s.taskDAO.Create(taskFromRow(row)) } // newUUIDString is a thin wrapper so we can swap in a real UUID // generator later without changing call sites. Avoids an // import-cycle with internal/uuid at the package boundary. func newUUIDString() string { - return fmt.Sprintf("mem-%d", time.Now().UnixNano()) + return common.GenerateUUID() +} + +func taskFromRow(row map[string]any) *entity.Task { + digest := fmt.Sprint(row["digest"]) + beginAt, _ := row["begin_at"].(time.Time) + if beginAt.IsZero() { + now := time.Now() + beginAt = now + } + return &entity.Task{ + ID: fmt.Sprint(row["id"]), + DocID: fmt.Sprint(row["doc_id"]), + TaskType: fmt.Sprint(row["task_type"]), + Progress: 0, + BeginAt: &beginAt, + Digest: &digest, + } +} + +func queueMemoryTask(memoryID, tenantID string, rawMessageID int64, task map[string]any, msg MemoryMessage) error { + taskID := fmt.Sprint(task["id"]) + message := map[string]any{ + "id": taskID, + "task_id": taskID, + "task_type": task["task_type"], + "memory_id": memoryID, + "tenant_id": tenantID, + "source_id": rawMessageID, + "message_dict": map[string]any{ + "user_id": msg.UserID, + "agent_id": msg.AgentID, + "session_id": msg.SessionID, + "user_input": msg.UserInput, + "agent_response": msg.AgentResponse, + }, + } + if redisClient := redisengine.Get(); redisClient == nil || !redisClient.QueueProduct(memoryTaskQueueName(0), message) { + return errors.New("Can't access Redis.") + } + return nil +} + +func memoryTaskQueueName(priority int) string { + return fmt.Sprintf("te.%d.common", priority) +} + +func mapStringAny(in map[string]any) map[string]interface{} { + out := make(map[string]interface{}, len(in)) + for k, v := range in { + out[k] = v + } + return out } diff --git a/internal/service/memory_message_service_test.go b/internal/service/memory_message_service_test.go index 493e95aa19..a99a931bb3 100644 --- a/internal/service/memory_message_service_test.go +++ b/internal/service/memory_message_service_test.go @@ -142,8 +142,8 @@ func TestBuildTaskRow_Shape(t *testing.T) { if row["digest"] != "99" { t.Errorf("digest = %v, want \"99\"", row["digest"]) } - if id, _ := row["id"].(string); !strings.HasPrefix(id, "mem-") { - t.Errorf("id = %q, want mem- prefix", id) + if id, _ := row["id"].(string); len(id) != 32 { + t.Errorf("id = %q, want 32-char uuid", id) } }