feat[Go]: implement Add messages for Go (#16375)

### What problem does this PR solve?

As title

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Haruko386
2026-06-26 19:21:52 +08:00
committed by GitHub
parent f763044889
commit a1f1dd5007
7 changed files with 605 additions and 97 deletions

View File

@@ -67,13 +67,22 @@ func (e *elasticsearchEngine) CreateChunkStore(ctx context.Context, baseName, da
return fmt.Errorf("failed to check index existence: %w", err)
}
if exists {
if strings.HasPrefix(baseName, "memory_") {
if err := e.ensureMemoryMessageVectorMapping(ctx, baseName, vectorSize); err != nil {
return fmt.Errorf("failed to ensure memory vector mapping: %w", err)
}
common.Info("Memory index already exists, ensured vector mapping", zap.String("index_name", baseName), zap.Int("vector_size", vectorSize))
return nil
}
common.Info("Index already exists, skipping creation", zap.String("index_name", baseName))
return nil
}
// Load mapping based on index type
var mapping map[string]interface{}
if datasetID == "skill" {
if strings.HasPrefix(baseName, "memory_") {
mapping = getMemoryMessageMapping(vectorSize)
} else if datasetID == "skill" {
// Load skill-specific mapping
skillMapping, err := loadSkillMapping()
if err != nil {
@@ -149,6 +158,12 @@ func (e *elasticsearchEngine) InsertChunks(ctx context.Context, chunks []map[str
return nil, fmt.Errorf("index name cannot be empty")
}
if strings.HasPrefix(baseName, "memory_") {
if err := e.ensureMemoryMessageVectorMappingsForDocs(ctx, baseName, chunks); err != nil {
return nil, err
}
}
// Build bulk request body with index operations (upsert behavior: insert if not exists, update if exists)
var buf bytes.Buffer
for _, doc := range chunks {
@@ -901,6 +916,12 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
hasVectorMatch := matchDense != nil && len(matchDense.EmbeddingData) > 0
if hasVectorMatch {
if isMemoryIndex {
if err := e.ensureMemoryMessageSearchVectorMappings(ctx, req.IndexNames, matchDense.VectorColumnName, len(matchDense.EmbeddingData)); err != nil {
return nil, err
}
}
k := matchDense.TopN
if k <= 0 {
k = limit
@@ -2287,6 +2308,239 @@ func loadSkillMapping() (map[string]interface{}, error) {
return mapping, nil
}
func memoryMessageVectorField(vectorSize int) string {
return fmt.Sprintf("q_%d_vec", vectorSize)
}
func memoryMessageVectorProperty(vectorSize int) map[string]interface{} {
return map[string]interface{}{
"type": "dense_vector",
"dims": vectorSize,
"index": true,
"similarity": "cosine",
}
}
func parseMemoryMessageVectorSize(field string) (int, bool) {
if !memoryMessageVectorFieldRE.MatchString(field) {
return 0, false
}
sizeText := strings.TrimSuffix(strings.TrimPrefix(field, "q_"), "_vec")
vectorSize, err := strconv.Atoi(sizeText)
if err != nil || vectorSize <= 0 {
return 0, false
}
return vectorSize, true
}
func (e *elasticsearchEngine) memoryMessageVectorMappingExists(ctx context.Context, indexName, fieldName string) (bool, error) {
req := esapi.IndicesGetMappingRequest{
Index: []string{indexName},
}
res, err := req.Do(ctx, e.client)
if err != nil {
return false, fmt.Errorf("failed to get memory vector mapping: %w", err)
}
defer res.Body.Close()
if res.StatusCode == http.StatusNotFound {
return false, nil
}
if res.IsError() {
bodyBytes, _ := io.ReadAll(res.Body)
reason := extractErrorReason(bodyBytes)
if reason != "" {
return false, fmt.Errorf("elasticsearch error getting memory vector mapping %s.%s: %s", indexName, fieldName, reason)
}
return false, fmt.Errorf("elasticsearch returned error getting memory vector mapping %s.%s: %s, body: %s", indexName, fieldName, res.Status(), string(bodyBytes))
}
var mappings map[string]interface{}
if err := json.NewDecoder(res.Body).Decode(&mappings); err != nil {
return false, fmt.Errorf("failed to decode memory vector mapping: %w", err)
}
indexMapping, ok := mappings[indexName].(map[string]interface{})
if !ok {
return false, nil
}
mapping, ok := indexMapping["mappings"].(map[string]interface{})
if !ok {
return false, nil
}
properties, ok := mapping["properties"].(map[string]interface{})
if !ok {
return false, nil
}
_, ok = properties[fieldName]
return ok, nil
}
func (e *elasticsearchEngine) ensureMemoryMessageVectorMapping(ctx context.Context, indexName string, vectorSize int) error {
if vectorSize <= 0 {
return fmt.Errorf("memory vector size must be positive, got %d", vectorSize)
}
fieldName := memoryMessageVectorField(vectorSize)
exists, err := e.memoryMessageVectorMappingExists(ctx, indexName, fieldName)
if err != nil {
return err
}
if exists {
return nil
}
body := map[string]interface{}{
"properties": map[string]interface{}{
fieldName: memoryMessageVectorProperty(vectorSize),
},
}
data, err := json.Marshal(body)
if err != nil {
return fmt.Errorf("failed to marshal memory vector mapping: %w", err)
}
req := esapi.IndicesPutMappingRequest{
Index: []string{indexName},
Body: bytes.NewReader(data),
}
res, err := req.Do(ctx, e.client)
if err != nil {
return fmt.Errorf("failed to update memory vector mapping: %w", err)
}
defer res.Body.Close()
if res.IsError() {
bodyBytes, _ := io.ReadAll(res.Body)
reason := extractErrorReason(bodyBytes)
if reason != "" {
return fmt.Errorf("elasticsearch error updating memory vector mapping %s.%s: %s", indexName, fieldName, reason)
}
return fmt.Errorf("elasticsearch returned error updating memory vector mapping %s.%s: %s, body: %s", indexName, fieldName, res.Status(), string(bodyBytes))
}
return nil
}
func (e *elasticsearchEngine) ensureMemoryMessageVectorMappingsForDocs(ctx context.Context, indexName string, chunks []map[string]interface{}) error {
seen := map[int]struct{}{}
for _, chunk := range chunks {
for field := range chunk {
vectorSize, ok := parseMemoryMessageVectorSize(field)
if !ok {
continue
}
if _, ok := seen[vectorSize]; ok {
continue
}
if err := e.ensureMemoryMessageVectorMapping(ctx, indexName, vectorSize); err != nil {
return err
}
seen[vectorSize] = struct{}{}
}
}
return nil
}
func (e *elasticsearchEngine) ensureMemoryMessageSearchVectorMappings(ctx context.Context, indexNames []string, vectorFieldName string, fallbackVectorSize int) error {
vectorSize, ok := parseMemoryMessageVectorSize(vectorFieldName)
if !ok {
vectorSize = fallbackVectorSize
}
if vectorSize <= 0 {
return fmt.Errorf("memory vector size must be positive, got %d", vectorSize)
}
for _, indexName := range indexNames {
if !strings.HasPrefix(indexName, "memory_") {
continue
}
exists, err := e.indexExists(ctx, indexName)
if err != nil {
return fmt.Errorf("failed to check memory index existence: %w", err)
}
if !exists {
continue
}
if err := e.ensureMemoryMessageVectorMapping(ctx, indexName, vectorSize); err != nil {
return err
}
}
return nil
}
func getMemoryMessageMapping(vectorSize int) map[string]interface{} {
vectorField := memoryMessageVectorField(vectorSize)
return map[string]interface{}{
"settings": map[string]interface{}{
"number_of_shards": 1,
"number_of_replicas": 0,
},
"mappings": map[string]interface{}{
"properties": map[string]interface{}{
"id": map[string]interface{}{
"type": "keyword",
},
"doc_id": map[string]interface{}{
"type": "keyword",
},
"kb_id": map[string]interface{}{
"type": "keyword",
},
"memory_id": map[string]interface{}{
"type": "keyword",
},
"user_id": map[string]interface{}{
"type": "keyword",
},
"agent_id": map[string]interface{}{
"type": "keyword",
},
"session_id": map[string]interface{}{
"type": "keyword",
},
"message_id": map[string]interface{}{
"type": "long",
},
"source_id": map[string]interface{}{
"type": "long",
},
"message_type_kwd": map[string]interface{}{
"type": "keyword",
},
"status_int": map[string]interface{}{
"type": "integer",
},
"content": map[string]interface{}{
"type": "text",
"index": false,
},
"content_ltks": map[string]interface{}{
"type": "text",
"analyzer": "whitespace",
},
"tokenized_content_ltks": map[string]interface{}{
"type": "text",
"analyzer": "whitespace",
},
"valid_at": map[string]interface{}{
"type": "date",
"format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis",
},
"invalid_at": map[string]interface{}{
"type": "date",
"format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis",
},
"forget_at": map[string]interface{}{
"type": "date",
"format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis",
},
vectorField: memoryMessageVectorProperty(vectorSize),
},
},
}
}
// getDefaultSkillMapping returns the default skill index mapping
func getDefaultSkillMapping() map[string]interface{} {
return map[string]interface{}{

View File

@@ -52,6 +52,8 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc {
return
}
authViaAPIToken := false
// Get user by access token
user, code, err := h.userService.GetUserByToken(token)
if err != nil {
@@ -64,6 +66,7 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc {
c.Abort()
return
}
authViaAPIToken = true
}
if user.IsSuperuser != nil && *user.IsSuperuser {
@@ -89,6 +92,7 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc {
c.Set("user", user)
c.Set("user_id", user.ID)
c.Set("email", user.Email)
c.Set("auth_via_api_token", authViaAPIToken)
c.Next()
}
}

View File

@@ -580,6 +580,34 @@ func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
})
}
type messageMemoryIDs []string
func (ids *messageMemoryIDs) UnmarshalJSON(data []byte) error {
var single string
if err := json.Unmarshal(data, &single); err == nil {
if strings.TrimSpace(single) != "" {
*ids = []string{single}
}
return nil
}
var many []string
if err := json.Unmarshal(data, &many); err != nil {
return err
}
*ids = many
return nil
}
type AddMessageRequest struct {
MemoryIDs messageMemoryIDs `json:"memory_id" binding:"required"`
AgentID string `json:"agent_id" binding:"required"`
SessionID string `json:"session_id" binding:"required"`
UserInput string `json:"user_input" binding:"required"`
AgentResponse string `json:"agent_response" binding:"required"`
UserID string `json:"user_id"`
}
// AddMessage handles POST request for adding messages
// API Path: POST /api/v1/messages
//
@@ -595,12 +623,61 @@ func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
// - user_input (required): User input
// - agent_response (required): Agent response
// - user_id (optional): User ID
//
// TODO: Haruko386 is implementing this for now, if you implement this, delete this line plz
func (h *MemoryHandler) AddMessage(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
currentUserID := strings.TrimSpace(user.ID)
if currentUserID == "" {
jsonError(c, common.CodeArgumentError, "user_id is required")
return
}
var reqBody AddMessageRequest
if err := c.ShouldBindJSON(&reqBody); err != nil {
jsonError(c, common.CodeArgumentError, "body arguments is required")
return
}
if len(reqBody.MemoryIDs) == 0 {
jsonError(c, common.CodeArgumentError, "memory_id is required")
return
}
effectiveUserID := currentUserID
if v, ok := c.Get("auth_via_api_token"); ok {
if authViaAPIToken, ok := v.(bool); authViaAPIToken && ok {
effectiveUserID = strings.TrimSpace(reqBody.UserID)
if effectiveUserID == "" {
jsonError(c, common.CodeArgumentError, "user_id is required")
return
}
}
}
msg := service.MemoryMessage{
UserID: effectiveUserID,
AgentID: reqBody.AgentID,
SessionID: reqBody.SessionID,
UserInput: reqBody.UserInput,
AgentResponse: reqBody.AgentResponse,
}
ok, message, err := h.memoryService.AddMessage(c.Request.Context(), currentUserID, []string(reqBody.MemoryIDs), msg)
if err != nil || !ok {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "Some messages failed to add. Detail:" + message,
"data": nil,
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeServerError,
"message": "AddMessage not implemented - pending embedding engine dependency",
"code": common.CodeSuccess,
"message": message,
"data": nil,
})
}
@@ -741,7 +818,7 @@ func (h *MemoryHandler) UpdateMessage(c *gin.Context) {
return
}
ok, err = h.memoryService.UpdateMessage(c.Request.Context(), userID, memoryID, messageID, status)
ok, err = h.memoryService.UpdateMessageStatus(c.Request.Context(), userID, memoryID, messageID, status)
if err != nil || !ok {
if isMemoryServiceNotFound(err) {
c.JSON(http.StatusOK, gin.H{

View File

@@ -418,6 +418,7 @@ func (r *Router) Setup(engine *gin.Engine) {
message := v1.Group("/messages")
{
message.GET("", r.memoryHandler.GetMessages)
message.POST("", r.memoryHandler.AddMessage)
message.DELETE("/:memory_message", r.memoryHandler.ForgetMessage)
message.PUT("/:memory_message", r.memoryHandler.UpdateMessage)
message.GET("/:memory_message/content", r.memoryHandler.GetMessageContent)

View File

@@ -816,7 +816,93 @@ func (s *MemoryService) ForgetMessage(ctx context.Context, userID string, memory
return nil
}
func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) {
// AddMessage filters inaccessible memories and queues the raw message for
// memory extraction. The currentUserID is used only for access control; msg.UserID
// is the attribution stored on the message and may differ for API-token callers.
func (s *MemoryService) AddMessage(ctx context.Context, currentUserID string, memoryIDs []string, msg MemoryMessage) (bool, string, error) {
requestedMemoryIDs := splitFilterValues(memoryIDs)
memories, err := s.filterAccessibleMemories(ctx, currentUserID, memoryIDs)
if err != nil {
return false, err.Error(), err
}
if len(memories) == 0 {
return false, "Memory not found.", nil
}
accessibleMemoryIDs := make([]string, 0, len(memories))
for _, memory := range memories {
if memory != nil {
accessibleMemoryIDs = append(accessibleMemoryIDs, memory.ID)
}
}
missingMemoryIDs := missingRequestedMemoryIDs(requestedMemoryIDs, accessibleMemoryIDs)
res, err := NewMemoryMessageService(s).QueueSaveToMemoryTask(ctx, accessibleMemoryIDs, msg)
if err != nil {
return false, err.Error(), err
}
if len(missingMemoryIDs) > 0 {
if res == nil {
res = &QueueSaveResult{}
}
res.NotFound = append(missingMemoryIDs, res.NotFound...)
}
errorMsg := memorySaveErrorMessage(res)
if errorMsg != "" {
return false, errorMsg, nil
}
return true, "All add to task.", nil
}
func missingRequestedMemoryIDs(requestedMemoryIDs, accessibleMemoryIDs []string) []string {
if len(requestedMemoryIDs) == 0 {
return []string{}
}
accessibleSet := make(map[string]struct{}, len(accessibleMemoryIDs))
for _, memoryID := range accessibleMemoryIDs {
memoryID = strings.TrimSpace(memoryID)
if memoryID != "" {
accessibleSet[memoryID] = struct{}{}
}
}
missingIDs := make([]string, 0)
seenMissing := make(map[string]struct{})
for _, memoryID := range requestedMemoryIDs {
memoryID = strings.TrimSpace(memoryID)
if memoryID == "" {
continue
}
if _, ok := accessibleSet[memoryID]; ok {
continue
}
if _, ok := seenMissing[memoryID]; ok {
continue
}
missingIDs = append(missingIDs, memoryID)
seenMissing[memoryID] = struct{}{}
}
return missingIDs
}
func memorySaveErrorMessage(res *QueueSaveResult) string {
if res == nil {
return ""
}
var b strings.Builder
if len(res.NotFound) > 0 {
b.WriteString(fmt.Sprintf("Memory %v not found.", res.NotFound))
}
for _, failed := range res.Failed {
b.WriteString(fmt.Sprintf("Memory %s failed. Detail: %s", failed.MemoryID, failed.FailMsg))
}
return b.String()
}
func (s *MemoryService) UpdateMessageStatus(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) {
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
if err != nil {
return false, err
@@ -848,6 +934,10 @@ func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID stri
return true, nil
}
func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) {
return s.UpdateMessageStatus(ctx, userID, memoryID, messageID, status)
}
func (s *MemoryService) GetMessageContent(ctx context.Context, userID, memoryID string, messageID int64) (map[string]interface{}, error) {
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
if err != nil {
@@ -1650,15 +1740,6 @@ func memoryMessageKey(value interface{}) string {
return strings.TrimSpace(fmt.Sprint(value))
}
// TODO: queryMessages - Implementation pending - depends on CanvasService and TaskService
// func (s *MemoryService) queryMessages(tenantID string, memoryID string, filterDict map[string]interface{}, page int, pageSize int) ([]map[string]interface{}, int64, error) { ... }
// TODO: AddMessage - Implementation pending - depends on embedding engine
// func (s *MemoryService) AddMessage(memoryIDs []string, messageDict map[string]interface{}) (bool, string, error) { ... }
// TODO: UpdateMessageStatus - Implementation pending - depends on embedding engine
// func (s *MemoryService) UpdateMessageStatus(memoryID string, messageID int, status bool) (bool, error) { ... }
// isList checks if a value is a list or array type
// This is a utility function for type validation
//

View File

@@ -14,33 +14,31 @@
// limitations under the License.
//
// memory_message_service.go — Phase 8b real MemorySaver port.
// memory_message_service.go — real MemorySaver port.
//
// Port of api.db.joint_services.memory_message_service.queue_save_to_memory_task
// from the Python runtime. The Go port is a partial implementation
// (synchronous parts land here; the embedding-model call is loud-failed
// with ErrEmbedderNotWired until a Go embedding port lands).
// from the Python runtime.
//
// Python signature (api/db/joint_services/memory_message_service.py:344):
//
// async def queue_save_to_memory_task(
// memory_ids: list[str],
// message_dict: dict,
// ) -> tuple[list[str], list[dict]]
// # (not_found_memory, failed_memory)
// async def queue_save_to_memory_task(
// memory_ids: list[str],
// message_dict: dict,
// ) -> tuple[list[str], list[dict]]
// # (not_found_memory, failed_memory)
//
// Go equivalent:
//
// type QueueSaveResult struct {
// NotFound []string
// Failed []MemoryFailure
// }
// type QueueSaveResult struct {
// NotFound []string
// Failed []MemoryFailure
// }
//
// func (s *MemoryMessageService) QueueSaveToMemoryTask(
// ctx context.Context,
// memoryIDs []string,
// msg MemoryMessage,
// ) (*QueueSaveResult, error)
// func (s *MemoryMessageService) QueueSaveToMemoryTask(
// ctx context.Context,
// memoryIDs []string,
// msg MemoryMessage,
// ) (*QueueSaveResult, error)
//
// The function is the entry point the Message component calls
// after a conversation turn when `memory_save=true` is set. It
@@ -49,14 +47,9 @@
// 1. For each memory id: look up the Memory (via MemoryService).
// 2. Generate a raw_message_id from Redis auto-increment (namespace "memory").
// 3. Build the raw_message envelope (mirrors Python:344-386).
// 4. Call embed_and_save on the memory + [raw_message]. ← DEFERRED
// 4. Call embed_and_save on the memory + [raw_message].
// 5. Insert a Task row in the task table for the async extractor.
// 6. Return not-found + failed lists.
//
// Steps 1, 2, 3, 5, 6 are implemented here. Step 4 (the
// embed_and_save call) is wrapped in a loud-fail gate that returns
// ErrEmbedderNotWired until the Go embedding layer ships.
package service
import (
@@ -64,6 +57,12 @@ import (
"errors"
"fmt"
"time"
"ragflow/internal/common"
"ragflow/internal/dao"
redisengine "ragflow/internal/engine/redis"
"ragflow/internal/entity"
models "ragflow/internal/entity/models"
)
// ErrEmbedderNotWired is returned by QueueSaveToMemoryTask when
@@ -96,9 +95,7 @@ type MemoryMessage struct {
AgentResponse string
}
// MemoryFailure describes one memory that failed to save. The
// FailMsg is the underlying error (or "embedder not wired" until
// the embedder port lands).
// MemoryFailure describes one memory that failed to save.
type MemoryFailure struct {
MemoryID string
FailMsg string
@@ -112,11 +109,10 @@ type QueueSaveResult struct {
}
// MemoryMessageService is the Go port of
// api.db.joint_services.memory_message_service. It depends on a
// MemoryService instance for the lookup; the embedder is
// hard-coded to loud-fail (see ErrEmbedderNotWired).
// api.db.joint_services.memory_message_service.
type MemoryMessageService struct {
memories *MemoryService
taskDAO *dao.TaskDAO
}
// NewMemoryMessageService constructs a service bound to the
@@ -124,15 +120,17 @@ type MemoryMessageService struct {
// the default MemorySaver in the Message component via
// `component.SetMemorySaver(...)` at boot.
func NewMemoryMessageService(memories *MemoryService) *MemoryMessageService {
return &MemoryMessageService{memories: memories}
return &MemoryMessageService{
memories: memories,
taskDAO: dao.NewTaskDAO(),
}
}
// QueueSaveToMemoryTask runs the memory-persistence flow for the
// supplied memory_ids + message. See package comment for the
// step-by-step contract. The function is synchronous — the Python
// async version awaits `embed_and_save` and a Redis call; both are
// replaced here with synchronous equivalents (and a loud-fail
// embedder).
// async version awaits `embed_and_save` and Redis calls; this Go port does the
// same work synchronously from the HTTP request path.
//
// Returned QueueSaveResult has NotFound / Failed populated for
// the per-memory outcomes. The outer error is reserved for
@@ -168,11 +166,7 @@ func (s *MemoryMessageService) QueueSaveToMemoryTask(
rawMessageID := generateRawMessageID()
rawMessage := buildRawMessage(rawMessageID, memoryID, mem, msg)
// (4) embed_and_save. Loud-fail: the embedder is the
// only step the Go runtime can't do yet. When it
// ships, replace this branch with a call into
// internal/rag/llm/embedding_model.
if err := embedAndSave(ctx, mem, rawMessage); err != nil {
if err := s.embedAndSave(ctx, mem, rawMessage); err != nil {
res.Failed = append(res.Failed, MemoryFailure{
MemoryID: memoryID,
FailMsg: err.Error(),
@@ -180,13 +174,6 @@ func (s *MemoryMessageService) QueueSaveToMemoryTask(
continue
}
// (5) Task row insertion. The Python side bulk-inserts
// a Task row with digest=str(raw_message_id); the
// extractor's task_type is "memory". The Go port
// constructs the same row shape and defers the actual
// insert to TaskDAO when the project adds one (today
// TaskDAO is in internal/dao; the API mirrors the
// Python Task entity closely).
task := buildTaskRow(rawMessageID, memoryID)
if err := s.insertTask(ctx, task); err != nil {
res.Failed = append(res.Failed, MemoryFailure{
@@ -195,21 +182,25 @@ func (s *MemoryMessageService) QueueSaveToMemoryTask(
})
continue
}
if err := queueMemoryTask(memoryID, mem.TenantID, rawMessageID, task, msg); err != nil {
res.Failed = append(res.Failed, MemoryFailure{
MemoryID: memoryID,
FailMsg: err.Error(),
})
}
}
return res, nil
}
// generateRawMessageID is a placeholder for the Redis auto-
// increment the Python side uses (`REDIS_CONN.generate_auto_increment_id
// (namespace="memory")`). The Go port generates a UUID-shaped
// integer now; replace with a Redis-backed counter when the
// project's Redis client lands.
// generateRawMessageID returns the Redis auto-increment id used by the Python
// side (`REDIS_CONN.generate_auto_increment_id(namespace="memory")`).
func generateRawMessageID() int64 {
// seconds-since-epoch is unique enough for the Go port's
// own bookkeeping. The Redis-backed counter is the source
// of truth in production; this fallback only matters for
// the tests that don't need cross-process uniqueness.
return time.Now().Unix()
if redisClient := redisengine.Get(); redisClient != nil {
if id := redisClient.GenerateAutoIncrementID("id_generator", "memory", 1, nil); id > 0 {
return id
}
}
return time.Now().UnixNano()
}
// buildRawMessage constructs the raw_message envelope that gets
@@ -224,18 +215,22 @@ func buildRawMessage(
content := fmt.Sprintf("User Input: %s\nAgent Response: %s",
msg.UserInput, msg.AgentResponse)
out := map[string]any{
"message_id": rawMessageID,
"message_type": "raw",
"source_id": 0,
"memory_id": memoryID,
"user_id": msg.UserID,
"agent_id": msg.AgentID,
"session_id": msg.SessionID,
"content": content,
"valid_at": time.Now().UTC().Format(time.RFC3339),
"invalid_at": nil,
"forget_at": nil,
"status": true,
"message_id": rawMessageID,
"message_type": "raw",
"message_type_kwd": "raw",
"source_id": 0,
"memory_id": memoryID,
"user_id": msg.UserID,
"agent_id": msg.AgentID,
"session_id": msg.SessionID,
"content": content,
"content_ltks": content,
"tokenized_content_ltks": content,
"valid_at": time.Now().UTC().Format("2006-01-02 15:04:05"),
"invalid_at": nil,
"forget_at": nil,
"status": true,
"status_int": 1,
}
if mem != nil {
// The embedder uses the memory's embd_id; keep the
@@ -253,29 +248,125 @@ func buildTaskRow(rawMessageID int64, memoryID string) map[string]any {
"doc_id": memoryID,
"task_type": "memory",
"progress": 0.0,
"begin_at": time.Now(),
"digest": fmt.Sprintf("%d", rawMessageID),
}
}
// embedAndSave is the deferred gate. Replace with a call to the
// real embedding model + memory_message table insert when those
// land.
func (s *MemoryMessageService) embedAndSave(ctx context.Context, mem *CreateMemoryResponse, rawMessage map[string]any) error {
if mem == nil {
return errors.New("memory not found")
}
if s == nil || s.memories == nil || s.memories.docEngine == nil {
return errors.New("message store is not initialized")
}
content, _ := rawMessage["content"].(string)
driver, modelName, apiConfig, maxTokens, err := NewModelProviderService().GetModelConfigFromProviderInstance(mem.TenantID, entity.ModelTypeEmbedding, mem.EmbdID)
if err != nil {
return err
}
embeddingModel := models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{content}, embeddingModel.APIConfig, &models.EmbeddingConfig{Dimension: 0})
if err != nil {
return err
}
if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 {
return errors.New("embedding response is empty")
}
vector := embeddings[0].Embedding
rawMessage[fmt.Sprintf("q_%d_vec", len(vector))] = vector
rawMessage["id"] = fmt.Sprintf("%s_%d", rawMessage["memory_id"], rawMessage["message_id"])
rawMessage["doc_id"] = rawMessage["memory_id"]
indexName := memoryIndexName(mem.TenantID)
exists, err := s.memories.docEngine.ChunkStoreExists(ctx, indexName, mem.ID)
if err != nil {
return fmt.Errorf("check message index: %w", err)
}
if !exists {
if err := s.memories.docEngine.CreateChunkStore(ctx, indexName, mem.ID, len(vector), ""); err != nil {
return fmt.Errorf("create message index: %w", err)
}
}
if _, err := s.memories.docEngine.InsertChunks(ctx, []map[string]interface{}{mapStringAny(rawMessage)}, indexName, mem.ID); err != nil {
return fmt.Errorf("insert message into memory: %w", err)
}
return nil
}
// embedAndSave is kept for older unit tests; production uses the method above.
func embedAndSave(_ context.Context, _ *CreateMemoryResponse, _ map[string]any) error {
return ErrEmbedderNotWired
}
// insertTask is a placeholder for the bulk_insert_into_db call
// the Python side makes. The Go side needs a TaskDAO write path
// (the Python Task entity is mirrored in internal/entity); until
// that lands this is a no-op that returns nil so the rest of
// the flow can be exercised.
func (s *MemoryMessageService) insertTask(_ context.Context, _ map[string]any) error {
return nil
func (s *MemoryMessageService) insertTask(_ context.Context, row map[string]any) error {
if s == nil {
return errors.New("nil MemoryMessageService")
}
if s.taskDAO == nil {
s.taskDAO = dao.NewTaskDAO()
}
return s.taskDAO.Create(taskFromRow(row))
}
// newUUIDString is a thin wrapper so we can swap in a real UUID
// generator later without changing call sites. Avoids an
// import-cycle with internal/uuid at the package boundary.
func newUUIDString() string {
return fmt.Sprintf("mem-%d", time.Now().UnixNano())
return common.GenerateUUID()
}
func taskFromRow(row map[string]any) *entity.Task {
digest := fmt.Sprint(row["digest"])
beginAt, _ := row["begin_at"].(time.Time)
if beginAt.IsZero() {
now := time.Now()
beginAt = now
}
return &entity.Task{
ID: fmt.Sprint(row["id"]),
DocID: fmt.Sprint(row["doc_id"]),
TaskType: fmt.Sprint(row["task_type"]),
Progress: 0,
BeginAt: &beginAt,
Digest: &digest,
}
}
func queueMemoryTask(memoryID, tenantID string, rawMessageID int64, task map[string]any, msg MemoryMessage) error {
taskID := fmt.Sprint(task["id"])
message := map[string]any{
"id": taskID,
"task_id": taskID,
"task_type": task["task_type"],
"memory_id": memoryID,
"tenant_id": tenantID,
"source_id": rawMessageID,
"message_dict": map[string]any{
"user_id": msg.UserID,
"agent_id": msg.AgentID,
"session_id": msg.SessionID,
"user_input": msg.UserInput,
"agent_response": msg.AgentResponse,
},
}
if redisClient := redisengine.Get(); redisClient == nil || !redisClient.QueueProduct(memoryTaskQueueName(0), message) {
return errors.New("Can't access Redis.")
}
return nil
}
func memoryTaskQueueName(priority int) string {
return fmt.Sprintf("te.%d.common", priority)
}
func mapStringAny(in map[string]any) map[string]interface{} {
out := make(map[string]interface{}, len(in))
for k, v := range in {
out[k] = v
}
return out
}

View File

@@ -142,8 +142,8 @@ func TestBuildTaskRow_Shape(t *testing.T) {
if row["digest"] != "99" {
t.Errorf("digest = %v, want \"99\"", row["digest"])
}
if id, _ := row["id"].(string); !strings.HasPrefix(id, "mem-") {
t.Errorf("id = %q, want mem- prefix", id)
if id, _ := row["id"].(string); len(id) != 32 {
t.Errorf("id = %q, want 32-char uuid", id)
}
}