mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat[Go]: implement Add messages for Go (#16375)
### What problem does this PR solve? As title ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -67,13 +67,22 @@ func (e *elasticsearchEngine) CreateChunkStore(ctx context.Context, baseName, da
|
||||
return fmt.Errorf("failed to check index existence: %w", err)
|
||||
}
|
||||
if exists {
|
||||
if strings.HasPrefix(baseName, "memory_") {
|
||||
if err := e.ensureMemoryMessageVectorMapping(ctx, baseName, vectorSize); err != nil {
|
||||
return fmt.Errorf("failed to ensure memory vector mapping: %w", err)
|
||||
}
|
||||
common.Info("Memory index already exists, ensured vector mapping", zap.String("index_name", baseName), zap.Int("vector_size", vectorSize))
|
||||
return nil
|
||||
}
|
||||
common.Info("Index already exists, skipping creation", zap.String("index_name", baseName))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load mapping based on index type
|
||||
var mapping map[string]interface{}
|
||||
if datasetID == "skill" {
|
||||
if strings.HasPrefix(baseName, "memory_") {
|
||||
mapping = getMemoryMessageMapping(vectorSize)
|
||||
} else if datasetID == "skill" {
|
||||
// Load skill-specific mapping
|
||||
skillMapping, err := loadSkillMapping()
|
||||
if err != nil {
|
||||
@@ -149,6 +158,12 @@ func (e *elasticsearchEngine) InsertChunks(ctx context.Context, chunks []map[str
|
||||
return nil, fmt.Errorf("index name cannot be empty")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(baseName, "memory_") {
|
||||
if err := e.ensureMemoryMessageVectorMappingsForDocs(ctx, baseName, chunks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Build bulk request body with index operations (upsert behavior: insert if not exists, update if exists)
|
||||
var buf bytes.Buffer
|
||||
for _, doc := range chunks {
|
||||
@@ -901,6 +916,12 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
|
||||
|
||||
hasVectorMatch := matchDense != nil && len(matchDense.EmbeddingData) > 0
|
||||
if hasVectorMatch {
|
||||
if isMemoryIndex {
|
||||
if err := e.ensureMemoryMessageSearchVectorMappings(ctx, req.IndexNames, matchDense.VectorColumnName, len(matchDense.EmbeddingData)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
k := matchDense.TopN
|
||||
if k <= 0 {
|
||||
k = limit
|
||||
@@ -2287,6 +2308,239 @@ func loadSkillMapping() (map[string]interface{}, error) {
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
func memoryMessageVectorField(vectorSize int) string {
|
||||
return fmt.Sprintf("q_%d_vec", vectorSize)
|
||||
}
|
||||
|
||||
func memoryMessageVectorProperty(vectorSize int) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "dense_vector",
|
||||
"dims": vectorSize,
|
||||
"index": true,
|
||||
"similarity": "cosine",
|
||||
}
|
||||
}
|
||||
|
||||
func parseMemoryMessageVectorSize(field string) (int, bool) {
|
||||
if !memoryMessageVectorFieldRE.MatchString(field) {
|
||||
return 0, false
|
||||
}
|
||||
sizeText := strings.TrimSuffix(strings.TrimPrefix(field, "q_"), "_vec")
|
||||
vectorSize, err := strconv.Atoi(sizeText)
|
||||
if err != nil || vectorSize <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return vectorSize, true
|
||||
}
|
||||
|
||||
func (e *elasticsearchEngine) memoryMessageVectorMappingExists(ctx context.Context, indexName, fieldName string) (bool, error) {
|
||||
req := esapi.IndicesGetMappingRequest{
|
||||
Index: []string{indexName},
|
||||
}
|
||||
res, err := req.Do(ctx, e.client)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get memory vector mapping: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode == http.StatusNotFound {
|
||||
return false, nil
|
||||
}
|
||||
if res.IsError() {
|
||||
bodyBytes, _ := io.ReadAll(res.Body)
|
||||
reason := extractErrorReason(bodyBytes)
|
||||
if reason != "" {
|
||||
return false, fmt.Errorf("elasticsearch error getting memory vector mapping %s.%s: %s", indexName, fieldName, reason)
|
||||
}
|
||||
return false, fmt.Errorf("elasticsearch returned error getting memory vector mapping %s.%s: %s, body: %s", indexName, fieldName, res.Status(), string(bodyBytes))
|
||||
}
|
||||
|
||||
var mappings map[string]interface{}
|
||||
if err := json.NewDecoder(res.Body).Decode(&mappings); err != nil {
|
||||
return false, fmt.Errorf("failed to decode memory vector mapping: %w", err)
|
||||
}
|
||||
|
||||
indexMapping, ok := mappings[indexName].(map[string]interface{})
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
mapping, ok := indexMapping["mappings"].(map[string]interface{})
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
properties, ok := mapping["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
_, ok = properties[fieldName]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (e *elasticsearchEngine) ensureMemoryMessageVectorMapping(ctx context.Context, indexName string, vectorSize int) error {
|
||||
if vectorSize <= 0 {
|
||||
return fmt.Errorf("memory vector size must be positive, got %d", vectorSize)
|
||||
}
|
||||
|
||||
fieldName := memoryMessageVectorField(vectorSize)
|
||||
exists, err := e.memoryMessageVectorMappingExists(ctx, indexName, fieldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
body := map[string]interface{}{
|
||||
"properties": map[string]interface{}{
|
||||
fieldName: memoryMessageVectorProperty(vectorSize),
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal memory vector mapping: %w", err)
|
||||
}
|
||||
|
||||
req := esapi.IndicesPutMappingRequest{
|
||||
Index: []string{indexName},
|
||||
Body: bytes.NewReader(data),
|
||||
}
|
||||
res, err := req.Do(ctx, e.client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update memory vector mapping: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.IsError() {
|
||||
bodyBytes, _ := io.ReadAll(res.Body)
|
||||
reason := extractErrorReason(bodyBytes)
|
||||
if reason != "" {
|
||||
return fmt.Errorf("elasticsearch error updating memory vector mapping %s.%s: %s", indexName, fieldName, reason)
|
||||
}
|
||||
return fmt.Errorf("elasticsearch returned error updating memory vector mapping %s.%s: %s, body: %s", indexName, fieldName, res.Status(), string(bodyBytes))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *elasticsearchEngine) ensureMemoryMessageVectorMappingsForDocs(ctx context.Context, indexName string, chunks []map[string]interface{}) error {
|
||||
seen := map[int]struct{}{}
|
||||
for _, chunk := range chunks {
|
||||
for field := range chunk {
|
||||
vectorSize, ok := parseMemoryMessageVectorSize(field)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[vectorSize]; ok {
|
||||
continue
|
||||
}
|
||||
if err := e.ensureMemoryMessageVectorMapping(ctx, indexName, vectorSize); err != nil {
|
||||
return err
|
||||
}
|
||||
seen[vectorSize] = struct{}{}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *elasticsearchEngine) ensureMemoryMessageSearchVectorMappings(ctx context.Context, indexNames []string, vectorFieldName string, fallbackVectorSize int) error {
|
||||
vectorSize, ok := parseMemoryMessageVectorSize(vectorFieldName)
|
||||
if !ok {
|
||||
vectorSize = fallbackVectorSize
|
||||
}
|
||||
if vectorSize <= 0 {
|
||||
return fmt.Errorf("memory vector size must be positive, got %d", vectorSize)
|
||||
}
|
||||
|
||||
for _, indexName := range indexNames {
|
||||
if !strings.HasPrefix(indexName, "memory_") {
|
||||
continue
|
||||
}
|
||||
exists, err := e.indexExists(ctx, indexName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check memory index existence: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
if err := e.ensureMemoryMessageVectorMapping(ctx, indexName, vectorSize); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMemoryMessageMapping(vectorSize int) map[string]interface{} {
|
||||
vectorField := memoryMessageVectorField(vectorSize)
|
||||
return map[string]interface{}{
|
||||
"settings": map[string]interface{}{
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0,
|
||||
},
|
||||
"mappings": map[string]interface{}{
|
||||
"properties": map[string]interface{}{
|
||||
"id": map[string]interface{}{
|
||||
"type": "keyword",
|
||||
},
|
||||
"doc_id": map[string]interface{}{
|
||||
"type": "keyword",
|
||||
},
|
||||
"kb_id": map[string]interface{}{
|
||||
"type": "keyword",
|
||||
},
|
||||
"memory_id": map[string]interface{}{
|
||||
"type": "keyword",
|
||||
},
|
||||
"user_id": map[string]interface{}{
|
||||
"type": "keyword",
|
||||
},
|
||||
"agent_id": map[string]interface{}{
|
||||
"type": "keyword",
|
||||
},
|
||||
"session_id": map[string]interface{}{
|
||||
"type": "keyword",
|
||||
},
|
||||
"message_id": map[string]interface{}{
|
||||
"type": "long",
|
||||
},
|
||||
"source_id": map[string]interface{}{
|
||||
"type": "long",
|
||||
},
|
||||
"message_type_kwd": map[string]interface{}{
|
||||
"type": "keyword",
|
||||
},
|
||||
"status_int": map[string]interface{}{
|
||||
"type": "integer",
|
||||
},
|
||||
"content": map[string]interface{}{
|
||||
"type": "text",
|
||||
"index": false,
|
||||
},
|
||||
"content_ltks": map[string]interface{}{
|
||||
"type": "text",
|
||||
"analyzer": "whitespace",
|
||||
},
|
||||
"tokenized_content_ltks": map[string]interface{}{
|
||||
"type": "text",
|
||||
"analyzer": "whitespace",
|
||||
},
|
||||
"valid_at": map[string]interface{}{
|
||||
"type": "date",
|
||||
"format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis",
|
||||
},
|
||||
"invalid_at": map[string]interface{}{
|
||||
"type": "date",
|
||||
"format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis",
|
||||
},
|
||||
"forget_at": map[string]interface{}{
|
||||
"type": "date",
|
||||
"format": "yyyy-MM-dd HH:mm:ss||strict_date_optional_time||epoch_millis",
|
||||
},
|
||||
vectorField: memoryMessageVectorProperty(vectorSize),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultSkillMapping returns the default skill index mapping
|
||||
func getDefaultSkillMapping() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
|
||||
@@ -52,6 +52,8 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
authViaAPIToken := false
|
||||
|
||||
// Get user by access token
|
||||
user, code, err := h.userService.GetUserByToken(token)
|
||||
if err != nil {
|
||||
@@ -64,6 +66,7 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
authViaAPIToken = true
|
||||
}
|
||||
|
||||
if user.IsSuperuser != nil && *user.IsSuperuser {
|
||||
@@ -89,6 +92,7 @@ func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc {
|
||||
c.Set("user", user)
|
||||
c.Set("user_id", user.ID)
|
||||
c.Set("email", user.Email)
|
||||
c.Set("auth_via_api_token", authViaAPIToken)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -580,6 +580,34 @@ func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
type messageMemoryIDs []string
|
||||
|
||||
func (ids *messageMemoryIDs) UnmarshalJSON(data []byte) error {
|
||||
var single string
|
||||
if err := json.Unmarshal(data, &single); err == nil {
|
||||
if strings.TrimSpace(single) != "" {
|
||||
*ids = []string{single}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var many []string
|
||||
if err := json.Unmarshal(data, &many); err != nil {
|
||||
return err
|
||||
}
|
||||
*ids = many
|
||||
return nil
|
||||
}
|
||||
|
||||
type AddMessageRequest struct {
|
||||
MemoryIDs messageMemoryIDs `json:"memory_id" binding:"required"`
|
||||
AgentID string `json:"agent_id" binding:"required"`
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
UserInput string `json:"user_input" binding:"required"`
|
||||
AgentResponse string `json:"agent_response" binding:"required"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// AddMessage handles POST request for adding messages
|
||||
// API Path: POST /api/v1/messages
|
||||
//
|
||||
@@ -595,12 +623,61 @@ func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
|
||||
// - user_input (required): User input
|
||||
// - agent_response (required): Agent response
|
||||
// - user_id (optional): User ID
|
||||
//
|
||||
// TODO: Haruko386 is implementing this for now, if you implement this, delete this line plz
|
||||
func (h *MemoryHandler) AddMessage(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
currentUserID := strings.TrimSpace(user.ID)
|
||||
if currentUserID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "user_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody AddMessageRequest
|
||||
if err := c.ShouldBindJSON(&reqBody); err != nil {
|
||||
jsonError(c, common.CodeArgumentError, "body arguments is required")
|
||||
return
|
||||
}
|
||||
if len(reqBody.MemoryIDs) == 0 {
|
||||
jsonError(c, common.CodeArgumentError, "memory_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
effectiveUserID := currentUserID
|
||||
if v, ok := c.Get("auth_via_api_token"); ok {
|
||||
if authViaAPIToken, ok := v.(bool); authViaAPIToken && ok {
|
||||
effectiveUserID = strings.TrimSpace(reqBody.UserID)
|
||||
if effectiveUserID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "user_id is required")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg := service.MemoryMessage{
|
||||
UserID: effectiveUserID,
|
||||
AgentID: reqBody.AgentID,
|
||||
SessionID: reqBody.SessionID,
|
||||
UserInput: reqBody.UserInput,
|
||||
AgentResponse: reqBody.AgentResponse,
|
||||
}
|
||||
|
||||
ok, message, err := h.memoryService.AddMessage(c.Request.Context(), currentUserID, []string(reqBody.MemoryIDs), msg)
|
||||
if err != nil || !ok {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "Some messages failed to add. Detail:" + message,
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "AddMessage not implemented - pending embedding engine dependency",
|
||||
"code": common.CodeSuccess,
|
||||
"message": message,
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
@@ -741,7 +818,7 @@ func (h *MemoryHandler) UpdateMessage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
ok, err = h.memoryService.UpdateMessage(c.Request.Context(), userID, memoryID, messageID, status)
|
||||
ok, err = h.memoryService.UpdateMessageStatus(c.Request.Context(), userID, memoryID, messageID, status)
|
||||
if err != nil || !ok {
|
||||
if isMemoryServiceNotFound(err) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -418,6 +418,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
message := v1.Group("/messages")
|
||||
{
|
||||
message.GET("", r.memoryHandler.GetMessages)
|
||||
message.POST("", r.memoryHandler.AddMessage)
|
||||
message.DELETE("/:memory_message", r.memoryHandler.ForgetMessage)
|
||||
message.PUT("/:memory_message", r.memoryHandler.UpdateMessage)
|
||||
message.GET("/:memory_message/content", r.memoryHandler.GetMessageContent)
|
||||
|
||||
@@ -816,7 +816,93 @@ func (s *MemoryService) ForgetMessage(ctx context.Context, userID string, memory
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) {
|
||||
// AddMessage filters inaccessible memories and queues the raw message for
|
||||
// memory extraction. The currentUserID is used only for access control; msg.UserID
|
||||
// is the attribution stored on the message and may differ for API-token callers.
|
||||
func (s *MemoryService) AddMessage(ctx context.Context, currentUserID string, memoryIDs []string, msg MemoryMessage) (bool, string, error) {
|
||||
requestedMemoryIDs := splitFilterValues(memoryIDs)
|
||||
memories, err := s.filterAccessibleMemories(ctx, currentUserID, memoryIDs)
|
||||
if err != nil {
|
||||
return false, err.Error(), err
|
||||
}
|
||||
if len(memories) == 0 {
|
||||
return false, "Memory not found.", nil
|
||||
}
|
||||
|
||||
accessibleMemoryIDs := make([]string, 0, len(memories))
|
||||
for _, memory := range memories {
|
||||
if memory != nil {
|
||||
accessibleMemoryIDs = append(accessibleMemoryIDs, memory.ID)
|
||||
}
|
||||
}
|
||||
missingMemoryIDs := missingRequestedMemoryIDs(requestedMemoryIDs, accessibleMemoryIDs)
|
||||
|
||||
res, err := NewMemoryMessageService(s).QueueSaveToMemoryTask(ctx, accessibleMemoryIDs, msg)
|
||||
if err != nil {
|
||||
return false, err.Error(), err
|
||||
}
|
||||
|
||||
if len(missingMemoryIDs) > 0 {
|
||||
if res == nil {
|
||||
res = &QueueSaveResult{}
|
||||
}
|
||||
res.NotFound = append(missingMemoryIDs, res.NotFound...)
|
||||
}
|
||||
errorMsg := memorySaveErrorMessage(res)
|
||||
if errorMsg != "" {
|
||||
return false, errorMsg, nil
|
||||
}
|
||||
return true, "All add to task.", nil
|
||||
}
|
||||
|
||||
func missingRequestedMemoryIDs(requestedMemoryIDs, accessibleMemoryIDs []string) []string {
|
||||
if len(requestedMemoryIDs) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
accessibleSet := make(map[string]struct{}, len(accessibleMemoryIDs))
|
||||
for _, memoryID := range accessibleMemoryIDs {
|
||||
memoryID = strings.TrimSpace(memoryID)
|
||||
if memoryID != "" {
|
||||
accessibleSet[memoryID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
missingIDs := make([]string, 0)
|
||||
seenMissing := make(map[string]struct{})
|
||||
for _, memoryID := range requestedMemoryIDs {
|
||||
memoryID = strings.TrimSpace(memoryID)
|
||||
if memoryID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := accessibleSet[memoryID]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := seenMissing[memoryID]; ok {
|
||||
continue
|
||||
}
|
||||
missingIDs = append(missingIDs, memoryID)
|
||||
seenMissing[memoryID] = struct{}{}
|
||||
}
|
||||
return missingIDs
|
||||
}
|
||||
|
||||
func memorySaveErrorMessage(res *QueueSaveResult) string {
|
||||
if res == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
if len(res.NotFound) > 0 {
|
||||
b.WriteString(fmt.Sprintf("Memory %v not found.", res.NotFound))
|
||||
}
|
||||
for _, failed := range res.Failed {
|
||||
b.WriteString(fmt.Sprintf("Memory %s failed. Detail: %s", failed.MemoryID, failed.FailMsg))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (s *MemoryService) UpdateMessageStatus(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) {
|
||||
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -848,6 +934,10 @@ func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID stri
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) {
|
||||
return s.UpdateMessageStatus(ctx, userID, memoryID, messageID, status)
|
||||
}
|
||||
|
||||
func (s *MemoryService) GetMessageContent(ctx context.Context, userID, memoryID string, messageID int64) (map[string]interface{}, error) {
|
||||
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
|
||||
if err != nil {
|
||||
@@ -1650,15 +1740,6 @@ func memoryMessageKey(value interface{}) string {
|
||||
return strings.TrimSpace(fmt.Sprint(value))
|
||||
}
|
||||
|
||||
// TODO: queryMessages - Implementation pending - depends on CanvasService and TaskService
|
||||
// func (s *MemoryService) queryMessages(tenantID string, memoryID string, filterDict map[string]interface{}, page int, pageSize int) ([]map[string]interface{}, int64, error) { ... }
|
||||
|
||||
// TODO: AddMessage - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) AddMessage(memoryIDs []string, messageDict map[string]interface{}) (bool, string, error) { ... }
|
||||
|
||||
// TODO: UpdateMessageStatus - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) UpdateMessageStatus(memoryID string, messageID int, status bool) (bool, error) { ... }
|
||||
|
||||
// isList checks if a value is a list or array type
|
||||
// This is a utility function for type validation
|
||||
//
|
||||
|
||||
@@ -14,33 +14,31 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// memory_message_service.go — Phase 8b real MemorySaver port.
|
||||
// memory_message_service.go — real MemorySaver port.
|
||||
//
|
||||
// Port of api.db.joint_services.memory_message_service.queue_save_to_memory_task
|
||||
// from the Python runtime. The Go port is a partial implementation
|
||||
// (synchronous parts land here; the embedding-model call is loud-failed
|
||||
// with ErrEmbedderNotWired until a Go embedding port lands).
|
||||
// from the Python runtime.
|
||||
//
|
||||
// Python signature (api/db/joint_services/memory_message_service.py:344):
|
||||
//
|
||||
// async def queue_save_to_memory_task(
|
||||
// memory_ids: list[str],
|
||||
// message_dict: dict,
|
||||
// ) -> tuple[list[str], list[dict]]
|
||||
// # (not_found_memory, failed_memory)
|
||||
// async def queue_save_to_memory_task(
|
||||
// memory_ids: list[str],
|
||||
// message_dict: dict,
|
||||
// ) -> tuple[list[str], list[dict]]
|
||||
// # (not_found_memory, failed_memory)
|
||||
//
|
||||
// Go equivalent:
|
||||
//
|
||||
// type QueueSaveResult struct {
|
||||
// NotFound []string
|
||||
// Failed []MemoryFailure
|
||||
// }
|
||||
// type QueueSaveResult struct {
|
||||
// NotFound []string
|
||||
// Failed []MemoryFailure
|
||||
// }
|
||||
//
|
||||
// func (s *MemoryMessageService) QueueSaveToMemoryTask(
|
||||
// ctx context.Context,
|
||||
// memoryIDs []string,
|
||||
// msg MemoryMessage,
|
||||
// ) (*QueueSaveResult, error)
|
||||
// func (s *MemoryMessageService) QueueSaveToMemoryTask(
|
||||
// ctx context.Context,
|
||||
// memoryIDs []string,
|
||||
// msg MemoryMessage,
|
||||
// ) (*QueueSaveResult, error)
|
||||
//
|
||||
// The function is the entry point the Message component calls
|
||||
// after a conversation turn when `memory_save=true` is set. It
|
||||
@@ -49,14 +47,9 @@
|
||||
// 1. For each memory id: look up the Memory (via MemoryService).
|
||||
// 2. Generate a raw_message_id from Redis auto-increment (namespace "memory").
|
||||
// 3. Build the raw_message envelope (mirrors Python:344-386).
|
||||
// 4. Call embed_and_save on the memory + [raw_message]. ← DEFERRED
|
||||
// 4. Call embed_and_save on the memory + [raw_message].
|
||||
// 5. Insert a Task row in the task table for the async extractor.
|
||||
// 6. Return not-found + failed lists.
|
||||
//
|
||||
// Steps 1, 2, 3, 5, 6 are implemented here. Step 4 (the
|
||||
// embed_and_save call) is wrapped in a loud-fail gate that returns
|
||||
// ErrEmbedderNotWired until the Go embedding layer ships.
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
@@ -64,6 +57,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
redisengine "ragflow/internal/engine/redis"
|
||||
"ragflow/internal/entity"
|
||||
models "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
// ErrEmbedderNotWired is returned by QueueSaveToMemoryTask when
|
||||
@@ -96,9 +95,7 @@ type MemoryMessage struct {
|
||||
AgentResponse string
|
||||
}
|
||||
|
||||
// MemoryFailure describes one memory that failed to save. The
|
||||
// FailMsg is the underlying error (or "embedder not wired" until
|
||||
// the embedder port lands).
|
||||
// MemoryFailure describes one memory that failed to save.
|
||||
type MemoryFailure struct {
|
||||
MemoryID string
|
||||
FailMsg string
|
||||
@@ -112,11 +109,10 @@ type QueueSaveResult struct {
|
||||
}
|
||||
|
||||
// MemoryMessageService is the Go port of
|
||||
// api.db.joint_services.memory_message_service. It depends on a
|
||||
// MemoryService instance for the lookup; the embedder is
|
||||
// hard-coded to loud-fail (see ErrEmbedderNotWired).
|
||||
// api.db.joint_services.memory_message_service.
|
||||
type MemoryMessageService struct {
|
||||
memories *MemoryService
|
||||
taskDAO *dao.TaskDAO
|
||||
}
|
||||
|
||||
// NewMemoryMessageService constructs a service bound to the
|
||||
@@ -124,15 +120,17 @@ type MemoryMessageService struct {
|
||||
// the default MemorySaver in the Message component via
|
||||
// `component.SetMemorySaver(...)` at boot.
|
||||
func NewMemoryMessageService(memories *MemoryService) *MemoryMessageService {
|
||||
return &MemoryMessageService{memories: memories}
|
||||
return &MemoryMessageService{
|
||||
memories: memories,
|
||||
taskDAO: dao.NewTaskDAO(),
|
||||
}
|
||||
}
|
||||
|
||||
// QueueSaveToMemoryTask runs the memory-persistence flow for the
|
||||
// supplied memory_ids + message. See package comment for the
|
||||
// step-by-step contract. The function is synchronous — the Python
|
||||
// async version awaits `embed_and_save` and a Redis call; both are
|
||||
// replaced here with synchronous equivalents (and a loud-fail
|
||||
// embedder).
|
||||
// async version awaits `embed_and_save` and Redis calls; this Go port does the
|
||||
// same work synchronously from the HTTP request path.
|
||||
//
|
||||
// Returned QueueSaveResult has NotFound / Failed populated for
|
||||
// the per-memory outcomes. The outer error is reserved for
|
||||
@@ -168,11 +166,7 @@ func (s *MemoryMessageService) QueueSaveToMemoryTask(
|
||||
rawMessageID := generateRawMessageID()
|
||||
rawMessage := buildRawMessage(rawMessageID, memoryID, mem, msg)
|
||||
|
||||
// (4) embed_and_save. Loud-fail: the embedder is the
|
||||
// only step the Go runtime can't do yet. When it
|
||||
// ships, replace this branch with a call into
|
||||
// internal/rag/llm/embedding_model.
|
||||
if err := embedAndSave(ctx, mem, rawMessage); err != nil {
|
||||
if err := s.embedAndSave(ctx, mem, rawMessage); err != nil {
|
||||
res.Failed = append(res.Failed, MemoryFailure{
|
||||
MemoryID: memoryID,
|
||||
FailMsg: err.Error(),
|
||||
@@ -180,13 +174,6 @@ func (s *MemoryMessageService) QueueSaveToMemoryTask(
|
||||
continue
|
||||
}
|
||||
|
||||
// (5) Task row insertion. The Python side bulk-inserts
|
||||
// a Task row with digest=str(raw_message_id); the
|
||||
// extractor's task_type is "memory". The Go port
|
||||
// constructs the same row shape and defers the actual
|
||||
// insert to TaskDAO when the project adds one (today
|
||||
// TaskDAO is in internal/dao; the API mirrors the
|
||||
// Python Task entity closely).
|
||||
task := buildTaskRow(rawMessageID, memoryID)
|
||||
if err := s.insertTask(ctx, task); err != nil {
|
||||
res.Failed = append(res.Failed, MemoryFailure{
|
||||
@@ -195,21 +182,25 @@ func (s *MemoryMessageService) QueueSaveToMemoryTask(
|
||||
})
|
||||
continue
|
||||
}
|
||||
if err := queueMemoryTask(memoryID, mem.TenantID, rawMessageID, task, msg); err != nil {
|
||||
res.Failed = append(res.Failed, MemoryFailure{
|
||||
MemoryID: memoryID,
|
||||
FailMsg: err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// generateRawMessageID is a placeholder for the Redis auto-
|
||||
// increment the Python side uses (`REDIS_CONN.generate_auto_increment_id
|
||||
// (namespace="memory")`). The Go port generates a UUID-shaped
|
||||
// integer now; replace with a Redis-backed counter when the
|
||||
// project's Redis client lands.
|
||||
// generateRawMessageID returns the Redis auto-increment id used by the Python
|
||||
// side (`REDIS_CONN.generate_auto_increment_id(namespace="memory")`).
|
||||
func generateRawMessageID() int64 {
|
||||
// seconds-since-epoch is unique enough for the Go port's
|
||||
// own bookkeeping. The Redis-backed counter is the source
|
||||
// of truth in production; this fallback only matters for
|
||||
// the tests that don't need cross-process uniqueness.
|
||||
return time.Now().Unix()
|
||||
if redisClient := redisengine.Get(); redisClient != nil {
|
||||
if id := redisClient.GenerateAutoIncrementID("id_generator", "memory", 1, nil); id > 0 {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// buildRawMessage constructs the raw_message envelope that gets
|
||||
@@ -224,18 +215,22 @@ func buildRawMessage(
|
||||
content := fmt.Sprintf("User Input: %s\nAgent Response: %s",
|
||||
msg.UserInput, msg.AgentResponse)
|
||||
out := map[string]any{
|
||||
"message_id": rawMessageID,
|
||||
"message_type": "raw",
|
||||
"source_id": 0,
|
||||
"memory_id": memoryID,
|
||||
"user_id": msg.UserID,
|
||||
"agent_id": msg.AgentID,
|
||||
"session_id": msg.SessionID,
|
||||
"content": content,
|
||||
"valid_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"invalid_at": nil,
|
||||
"forget_at": nil,
|
||||
"status": true,
|
||||
"message_id": rawMessageID,
|
||||
"message_type": "raw",
|
||||
"message_type_kwd": "raw",
|
||||
"source_id": 0,
|
||||
"memory_id": memoryID,
|
||||
"user_id": msg.UserID,
|
||||
"agent_id": msg.AgentID,
|
||||
"session_id": msg.SessionID,
|
||||
"content": content,
|
||||
"content_ltks": content,
|
||||
"tokenized_content_ltks": content,
|
||||
"valid_at": time.Now().UTC().Format("2006-01-02 15:04:05"),
|
||||
"invalid_at": nil,
|
||||
"forget_at": nil,
|
||||
"status": true,
|
||||
"status_int": 1,
|
||||
}
|
||||
if mem != nil {
|
||||
// The embedder uses the memory's embd_id; keep the
|
||||
@@ -253,29 +248,125 @@ func buildTaskRow(rawMessageID int64, memoryID string) map[string]any {
|
||||
"doc_id": memoryID,
|
||||
"task_type": "memory",
|
||||
"progress": 0.0,
|
||||
"begin_at": time.Now(),
|
||||
"digest": fmt.Sprintf("%d", rawMessageID),
|
||||
}
|
||||
}
|
||||
|
||||
// embedAndSave is the deferred gate. Replace with a call to the
|
||||
// real embedding model + memory_message table insert when those
|
||||
// land.
|
||||
func (s *MemoryMessageService) embedAndSave(ctx context.Context, mem *CreateMemoryResponse, rawMessage map[string]any) error {
|
||||
if mem == nil {
|
||||
return errors.New("memory not found")
|
||||
}
|
||||
if s == nil || s.memories == nil || s.memories.docEngine == nil {
|
||||
return errors.New("message store is not initialized")
|
||||
}
|
||||
|
||||
content, _ := rawMessage["content"].(string)
|
||||
driver, modelName, apiConfig, maxTokens, err := NewModelProviderService().GetModelConfigFromProviderInstance(mem.TenantID, entity.ModelTypeEmbedding, mem.EmbdID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
embeddingModel := models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
||||
embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{content}, embeddingModel.APIConfig, &models.EmbeddingConfig{Dimension: 0})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 {
|
||||
return errors.New("embedding response is empty")
|
||||
}
|
||||
|
||||
vector := embeddings[0].Embedding
|
||||
rawMessage[fmt.Sprintf("q_%d_vec", len(vector))] = vector
|
||||
rawMessage["id"] = fmt.Sprintf("%s_%d", rawMessage["memory_id"], rawMessage["message_id"])
|
||||
rawMessage["doc_id"] = rawMessage["memory_id"]
|
||||
|
||||
indexName := memoryIndexName(mem.TenantID)
|
||||
exists, err := s.memories.docEngine.ChunkStoreExists(ctx, indexName, mem.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check message index: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
if err := s.memories.docEngine.CreateChunkStore(ctx, indexName, mem.ID, len(vector), ""); err != nil {
|
||||
return fmt.Errorf("create message index: %w", err)
|
||||
}
|
||||
}
|
||||
if _, err := s.memories.docEngine.InsertChunks(ctx, []map[string]interface{}{mapStringAny(rawMessage)}, indexName, mem.ID); err != nil {
|
||||
return fmt.Errorf("insert message into memory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// embedAndSave is kept for older unit tests; production uses the method above.
|
||||
func embedAndSave(_ context.Context, _ *CreateMemoryResponse, _ map[string]any) error {
|
||||
return ErrEmbedderNotWired
|
||||
}
|
||||
|
||||
// insertTask is a placeholder for the bulk_insert_into_db call
|
||||
// the Python side makes. The Go side needs a TaskDAO write path
|
||||
// (the Python Task entity is mirrored in internal/entity); until
|
||||
// that lands this is a no-op that returns nil so the rest of
|
||||
// the flow can be exercised.
|
||||
func (s *MemoryMessageService) insertTask(_ context.Context, _ map[string]any) error {
|
||||
return nil
|
||||
func (s *MemoryMessageService) insertTask(_ context.Context, row map[string]any) error {
|
||||
if s == nil {
|
||||
return errors.New("nil MemoryMessageService")
|
||||
}
|
||||
if s.taskDAO == nil {
|
||||
s.taskDAO = dao.NewTaskDAO()
|
||||
}
|
||||
return s.taskDAO.Create(taskFromRow(row))
|
||||
}
|
||||
|
||||
// newUUIDString is a thin wrapper so we can swap in a real UUID
|
||||
// generator later without changing call sites. Avoids an
|
||||
// import-cycle with internal/uuid at the package boundary.
|
||||
func newUUIDString() string {
|
||||
return fmt.Sprintf("mem-%d", time.Now().UnixNano())
|
||||
return common.GenerateUUID()
|
||||
}
|
||||
|
||||
func taskFromRow(row map[string]any) *entity.Task {
|
||||
digest := fmt.Sprint(row["digest"])
|
||||
beginAt, _ := row["begin_at"].(time.Time)
|
||||
if beginAt.IsZero() {
|
||||
now := time.Now()
|
||||
beginAt = now
|
||||
}
|
||||
return &entity.Task{
|
||||
ID: fmt.Sprint(row["id"]),
|
||||
DocID: fmt.Sprint(row["doc_id"]),
|
||||
TaskType: fmt.Sprint(row["task_type"]),
|
||||
Progress: 0,
|
||||
BeginAt: &beginAt,
|
||||
Digest: &digest,
|
||||
}
|
||||
}
|
||||
|
||||
func queueMemoryTask(memoryID, tenantID string, rawMessageID int64, task map[string]any, msg MemoryMessage) error {
|
||||
taskID := fmt.Sprint(task["id"])
|
||||
message := map[string]any{
|
||||
"id": taskID,
|
||||
"task_id": taskID,
|
||||
"task_type": task["task_type"],
|
||||
"memory_id": memoryID,
|
||||
"tenant_id": tenantID,
|
||||
"source_id": rawMessageID,
|
||||
"message_dict": map[string]any{
|
||||
"user_id": msg.UserID,
|
||||
"agent_id": msg.AgentID,
|
||||
"session_id": msg.SessionID,
|
||||
"user_input": msg.UserInput,
|
||||
"agent_response": msg.AgentResponse,
|
||||
},
|
||||
}
|
||||
if redisClient := redisengine.Get(); redisClient == nil || !redisClient.QueueProduct(memoryTaskQueueName(0), message) {
|
||||
return errors.New("Can't access Redis.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func memoryTaskQueueName(priority int) string {
|
||||
return fmt.Sprintf("te.%d.common", priority)
|
||||
}
|
||||
|
||||
func mapStringAny(in map[string]any) map[string]interface{} {
|
||||
out := make(map[string]interface{}, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -142,8 +142,8 @@ func TestBuildTaskRow_Shape(t *testing.T) {
|
||||
if row["digest"] != "99" {
|
||||
t.Errorf("digest = %v, want \"99\"", row["digest"])
|
||||
}
|
||||
if id, _ := row["id"].(string); !strings.HasPrefix(id, "mem-") {
|
||||
t.Errorf("id = %q, want mem- prefix", id)
|
||||
if id, _ := row["id"].(string); len(id) != 32 {
|
||||
t.Errorf("id = %q, want 32-char uuid", id)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user