mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat[Go]: implemet api: Search/Get/Update-Messages (#16307)
### What problem does this PR solve? As title: implement: ``` /api/v1/messages/search GET /api/v1/messages GET /api/v1/messages/<memory_id>:<message_id>/content GET /api/v1/memories/<memory_id>/config GET /api/v1/messages/<memory_id>:<message_id> PUT ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -32,11 +32,12 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/elastic/go-elasticsearch/v8/esapi"
|
||||
"github.com/json-iterator/go"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine/types"
|
||||
|
||||
"github.com/elastic/go-elasticsearch/v8/esapi"
|
||||
"github.com/json-iterator/go"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -44,6 +45,8 @@ var jsonIterator = jsoniter.Config{
|
||||
SortMapKeys: false,
|
||||
}.Froze()
|
||||
|
||||
var memoryMessageVectorFieldRE = regexp.MustCompile(`^q_\d+_vec$`)
|
||||
|
||||
var (
|
||||
elasticsearchHighlightEmTagRE = regexp.MustCompile(`<em>[^<>]+</em>`)
|
||||
elasticsearchHighlightNewlineRE = regexp.MustCompile(`[\r\n]`)
|
||||
@@ -233,6 +236,14 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st
|
||||
return fmt.Errorf("index '%s' does not exist", fullIndexName)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(fullIndexName, "memory_") {
|
||||
condition["memory_id"] = datasetID
|
||||
if messageDocID, ok := condition["id"].(string); ok {
|
||||
return e.updateSingleMemoryMessage(ctx, fullIndexName, messageDocID, newValue)
|
||||
}
|
||||
return e.updateChunksByQuery(ctx, fullIndexName, mapMemoryMessageESConditionFields(condition), mapMemoryMessageESUpdateFields(newValue))
|
||||
}
|
||||
|
||||
// Add kb_id to condition
|
||||
condition["kb_id"] = datasetID
|
||||
|
||||
@@ -245,6 +256,59 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st
|
||||
return e.updateChunksByQuery(ctx, fullIndexName, condition, newValue)
|
||||
}
|
||||
|
||||
func (e *elasticsearchEngine) updateSingleMemoryMessage(ctx context.Context, indexName, messageDocID string, newValue map[string]interface{}) error {
|
||||
doc := mapMemoryMessageESUpdateFields(newValue)
|
||||
delete(doc, "id")
|
||||
if len(doc) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateBody := map[string]interface{}{"doc": doc}
|
||||
body, err := json.Marshal(updateBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal memory message update request: %w", err)
|
||||
}
|
||||
req := esapi.UpdateRequest{
|
||||
Index: indexName,
|
||||
DocumentID: messageDocID,
|
||||
Body: bytes.NewReader(body),
|
||||
}
|
||||
res, err := req.Do(ctx, e.client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update memory message: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.IsError() {
|
||||
if res.StatusCode == http.StatusNotFound {
|
||||
return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, messageDocID)
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(res.Body)
|
||||
return fmt.Errorf("elasticsearch memory message update error: %s, body: %s", res.Status(), string(bodyBytes))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapMemoryMessageESUpdateFields(newValue map[string]interface{}) map[string]interface{} {
|
||||
doc := make(map[string]interface{}, len(newValue))
|
||||
for k, v := range newValue {
|
||||
switch k {
|
||||
case "remove", "add":
|
||||
doc[k] = v
|
||||
default:
|
||||
doc[mapMemoryMessageESField(k, false)] = v
|
||||
}
|
||||
}
|
||||
return doc
|
||||
}
|
||||
|
||||
func mapMemoryMessageESConditionFields(condition map[string]interface{}) map[string]interface{} {
|
||||
mapped := make(map[string]interface{}, len(condition))
|
||||
for k, v := range condition {
|
||||
mapped[mapMemoryMessageESField(k, false)] = v
|
||||
}
|
||||
return mapped
|
||||
}
|
||||
|
||||
// updateSingleChunk handles single document update
|
||||
func (e *elasticsearchEngine) updateSingleChunk(ctx context.Context, indexName, chunkID string, newValue map[string]interface{}) error {
|
||||
common.Debug("ElasticsearchConnection.updateSingleChunk called", zap.String("indexName", indexName), zap.String("chunkID", chunkID))
|
||||
@@ -762,15 +826,18 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
|
||||
|
||||
// Detect index types
|
||||
isSkillIndex := false
|
||||
isMemoryIndex := false
|
||||
for _, idx := range req.IndexNames {
|
||||
if strings.HasPrefix(idx, "skill_") {
|
||||
isSkillIndex = true
|
||||
break
|
||||
}
|
||||
if strings.HasPrefix(idx, "memory_") {
|
||||
isMemoryIndex = true
|
||||
}
|
||||
}
|
||||
|
||||
// Build bool query from condition
|
||||
boolQuery := buildBoolQueryFromCondition(req.Filter, req.KbIDs, isSkillIndex)
|
||||
boolQuery := buildBoolQueryFromCondition(req.Filter, req.KbIDs, isSkillIndex, isMemoryIndex)
|
||||
|
||||
// Extract vector_similarity_weight from FusionExpr
|
||||
var matchText *types.MatchTextExpr
|
||||
@@ -816,7 +883,7 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
|
||||
queryBody := make(map[string]interface{})
|
||||
|
||||
if matchText != nil {
|
||||
textQuery := buildQueryStringQuery(matchText, vectorSimilarityWeight, isSkillIndex)
|
||||
textQuery := buildQueryStringQuery(matchText, vectorSimilarityWeight, isSkillIndex, isMemoryIndex)
|
||||
if boolQuery != nil {
|
||||
if boolMap, ok := boolQuery["bool"].(map[string]interface{}); ok {
|
||||
if must, ok := boolMap["must"].([]interface{}); ok {
|
||||
@@ -874,7 +941,7 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
|
||||
}
|
||||
|
||||
// Add rank_feature queries
|
||||
if req.RankFeature != nil && len(req.RankFeature) > 0 && !isSkillIndex {
|
||||
if req.RankFeature != nil && len(req.RankFeature) > 0 && !isSkillIndex && !isMemoryIndex {
|
||||
rankFeatureQuery := buildRankFeatureQuery(req.RankFeature)
|
||||
if rankFeatureQuery != nil {
|
||||
if boolQuery, ok := queryBody["query"].(map[string]interface{}); ok {
|
||||
@@ -929,21 +996,25 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
|
||||
|
||||
// Set _source and fields for vector fields
|
||||
hasTextMatch := matchText != nil
|
||||
selectFields := req.SelectFields
|
||||
if isMemoryIndex {
|
||||
selectFields = mapMemoryMessageESFields(req.SelectFields, false)
|
||||
}
|
||||
if len(req.SelectFields) > 0 {
|
||||
// Use caller-specified fields, add pagerank_fld/tag_fld if needed
|
||||
queryBody["_source"] = req.SelectFields
|
||||
queryBody["_source"] = selectFields
|
||||
if hasTextMatch || hasVectorMatch {
|
||||
if !isSkillIndex {
|
||||
if !slices.Contains(req.SelectFields, common.PAGERANK_FLD) {
|
||||
if !isSkillIndex && !isMemoryIndex {
|
||||
if !slices.Contains(selectFields, common.PAGERANK_FLD) {
|
||||
queryBody["_source"] = append(queryBody["_source"].([]string), common.PAGERANK_FLD)
|
||||
}
|
||||
if !slices.Contains(req.SelectFields, common.TAG_FLD) {
|
||||
if !slices.Contains(selectFields, common.TAG_FLD) {
|
||||
queryBody["_source"] = append(queryBody["_source"].([]string), common.TAG_FLD)
|
||||
}
|
||||
}
|
||||
}
|
||||
var vectorFields []string
|
||||
for _, f := range req.SelectFields {
|
||||
for _, f := range selectFields {
|
||||
if strings.HasSuffix(f, "_vec") {
|
||||
vectorFields = append(vectorFields, f)
|
||||
}
|
||||
@@ -954,7 +1025,7 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
|
||||
} else {
|
||||
// No explicit SelectFields - use match_all, but add pagerank_fld/tag_fld for scoring if needed
|
||||
if hasTextMatch || hasVectorMatch {
|
||||
if !isSkillIndex {
|
||||
if !isSkillIndex && !isMemoryIndex {
|
||||
queryBody["_source"] = []string{common.PAGERANK_FLD, common.TAG_FLD}
|
||||
}
|
||||
}
|
||||
@@ -1021,6 +1092,10 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
|
||||
}
|
||||
}
|
||||
|
||||
if isMemoryIndex {
|
||||
normalizeMemoryMessageChunks(allResults)
|
||||
}
|
||||
|
||||
// Post-processing: Sort results by score
|
||||
if len(allResults) > 0 && (matchText != nil || hasVectorMatch) {
|
||||
scoreColumn := "_score"
|
||||
@@ -1032,6 +1107,9 @@ func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchReque
|
||||
if isSkillIndex {
|
||||
pagerankField = ""
|
||||
}
|
||||
if isMemoryIndex {
|
||||
pagerankField = ""
|
||||
}
|
||||
|
||||
allResults = calculateScores(allResults, scoreColumn, pagerankField)
|
||||
allResults = sortByScore(allResults, limit)
|
||||
@@ -1293,17 +1371,105 @@ func sortValuesEqual(a, b []interface{}) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// buildBoolQueryFromCondition builds an ES bool query from condition map
|
||||
// For skill index, uses 'status' field instead of 'available_int'
|
||||
func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string, isSkillIndex bool) map[string]interface{} {
|
||||
func mapMemoryMessageESField(field string, useTokenizedContent bool) string {
|
||||
name := field
|
||||
boost := ""
|
||||
if base, suffix, ok := strings.Cut(field, "^"); ok {
|
||||
name = base
|
||||
boost = "^" + suffix
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "message_type":
|
||||
name = "message_type_kwd"
|
||||
case "status":
|
||||
name = "status_int"
|
||||
case "content":
|
||||
if useTokenizedContent {
|
||||
name = "tokenized_content_ltks"
|
||||
} else {
|
||||
name = "content_ltks"
|
||||
}
|
||||
}
|
||||
return name + boost
|
||||
}
|
||||
|
||||
func mapMemoryMessageESFields(fields []string, useTokenizedContent bool) []string {
|
||||
if len(fields) == 0 {
|
||||
return fields
|
||||
}
|
||||
mapped := make([]string, 0, len(fields))
|
||||
seen := make(map[string]struct{}, len(fields))
|
||||
for _, field := range fields {
|
||||
mappedField := mapMemoryMessageESField(field, useTokenizedContent)
|
||||
if _, ok := seen[mappedField]; ok {
|
||||
continue
|
||||
}
|
||||
seen[mappedField] = struct{}{}
|
||||
mapped = append(mapped, mappedField)
|
||||
}
|
||||
return mapped
|
||||
}
|
||||
|
||||
func normalizeMemoryMessageChunks(chunks []map[string]interface{}) {
|
||||
for _, chunk := range chunks {
|
||||
for key, val := range chunk {
|
||||
if memoryMessageVectorFieldRE.MatchString(key) {
|
||||
chunk["content_embed"] = val
|
||||
delete(chunk, key)
|
||||
}
|
||||
}
|
||||
if val, ok := chunk["message_type_kwd"]; ok {
|
||||
chunk["message_type"] = val
|
||||
delete(chunk, "message_type_kwd")
|
||||
}
|
||||
if val, ok := chunk["status_int"]; ok {
|
||||
chunk["status"] = memoryMessageStatusBool(val)
|
||||
delete(chunk, "status_int")
|
||||
}
|
||||
if val, ok := chunk["content_ltks"]; ok {
|
||||
chunk["content"] = val
|
||||
delete(chunk, "content_ltks")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func memoryMessageStatusBool(value interface{}) bool {
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case int:
|
||||
return v != 0
|
||||
case int64:
|
||||
return v != 0
|
||||
case float64:
|
||||
return v != 0
|
||||
case json.Number:
|
||||
n, err := v.Int64()
|
||||
return err == nil && n != 0
|
||||
case string:
|
||||
return v != "" && v != "0" && !strings.EqualFold(v, "false")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// buildBoolQueryFromCondition builds an ES bool query from condition map.
|
||||
// Skill indexes use status, regular chunk indexes use kb_id, and memory
|
||||
// message indexes use memory_id plus message-specific storage fields.
|
||||
func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string, isSkillIndex, isMemoryIndex bool) map[string]interface{} {
|
||||
var mustClauses []interface{}
|
||||
var filterClauses []interface{}
|
||||
var shouldClauses []interface{}
|
||||
|
||||
// Add kb_id to condition
|
||||
// Memory message indexes use memory_id, regular chunk indexes use kb_id.
|
||||
if kbIDs != nil && len(kbIDs) > 0 {
|
||||
fieldName := "kb_id"
|
||||
if isMemoryIndex {
|
||||
fieldName = "memory_id"
|
||||
}
|
||||
filterClauses = append(filterClauses, map[string]interface{}{
|
||||
"terms": map[string]interface{}{"kb_id": kbIDs},
|
||||
"terms": map[string]interface{}{fieldName: kbIDs},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1321,6 +1487,9 @@ func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string,
|
||||
}
|
||||
|
||||
for k, v := range filter {
|
||||
if isMemoryIndex {
|
||||
k = mapMemoryMessageESField(k, false)
|
||||
}
|
||||
// For skill index, handle 'status' field instead of 'available_int'
|
||||
if isSkillIndex && k == "status" {
|
||||
if v == nil || v == "" {
|
||||
@@ -1389,6 +1558,18 @@ func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string,
|
||||
if v == nil || v == "" {
|
||||
continue
|
||||
}
|
||||
if isMemoryIndex && k == "session_id" {
|
||||
if strVal, ok := v.(string); ok && strVal != "" {
|
||||
filterClauses = append(filterClauses, map[string]interface{}{
|
||||
"query_string": map[string]interface{}{
|
||||
"query": fmt.Sprintf("*%s*", strVal),
|
||||
"fields": []string{"session_id"},
|
||||
"analyze_wildcard": true,
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
if listVal, ok := v.([]interface{}); ok {
|
||||
filterClauses = append(filterClauses, map[string]interface{}{
|
||||
"terms": map[string]interface{}{k: listVal},
|
||||
@@ -1435,7 +1616,7 @@ func buildBoolQueryFromCondition(filter map[string]interface{}, kbIDs []string,
|
||||
// buildQueryStringQuery builds a query_string query from MatchTextExpr
|
||||
// When isSkillIndex is true, uses skill-specific fields (name_tks, tags_tks, etc.)
|
||||
// Otherwise uses document fields (title_tks, content_ltks, etc.)
|
||||
func buildQueryStringQuery(matchText *types.MatchTextExpr, vectorSimilarityWeight float64, isSkillIndex bool) map[string]interface{} {
|
||||
func buildQueryStringQuery(matchText *types.MatchTextExpr, vectorSimilarityWeight float64, isSkillIndex, isMemoryIndex bool) map[string]interface{} {
|
||||
if matchText == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1451,10 +1632,15 @@ func buildQueryStringQuery(matchText *types.MatchTextExpr, vectorSimilarityWeigh
|
||||
if fields == nil || len(fields) == 0 {
|
||||
if isSkillIndex {
|
||||
fields = []string{"name_tks^10", "tags_tks^5", "description_tks^3", "content_tks^1"}
|
||||
} else if isMemoryIndex {
|
||||
fields = []string{"tokenized_content_ltks"}
|
||||
} else {
|
||||
fields = []string{"title_tks^10", "title_sm_tks^5", "important_kwd^30", "important_tks^20", "question_tks^20", "content_ltks^2", "content_sm_ltks"}
|
||||
}
|
||||
}
|
||||
if isMemoryIndex {
|
||||
fields = mapMemoryMessageESFields(fields, true)
|
||||
}
|
||||
|
||||
boost := 1.0
|
||||
if matchText.ExtraOptions != nil {
|
||||
@@ -1507,6 +1693,10 @@ func buildRankFeatureQuery(rankFeature map[string]float64) []map[string]interfac
|
||||
|
||||
// GetChunk gets a chunk by ID using ES search API
|
||||
func (e *elasticsearchEngine) GetChunk(ctx context.Context, baseName, chunkID string, datasetIDs []string) (interface{}, error) {
|
||||
if strings.HasPrefix(baseName, "memory_") {
|
||||
return e.getMemoryMessage(ctx, baseName, chunkID)
|
||||
}
|
||||
|
||||
// Try search by doc_id field (which is stored in the document)
|
||||
for _, datasetID := range datasetIDs {
|
||||
searchReq := map[string]interface{}{
|
||||
@@ -1575,6 +1765,41 @@ func (e *elasticsearchEngine) GetChunk(ctx context.Context, baseName, chunkID st
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e *elasticsearchEngine) getMemoryMessage(ctx context.Context, indexName, docID string) (interface{}, error) {
|
||||
req := esapi.GetRequest{
|
||||
Index: indexName,
|
||||
DocumentID: docID,
|
||||
}
|
||||
res, err := req.Do(ctx, e.client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get memory message: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.IsError() {
|
||||
if res.StatusCode == http.StatusNotFound {
|
||||
return nil, fmt.Errorf("%w: %s", types.ErrDocumentNotFound, docID)
|
||||
}
|
||||
return nil, fmt.Errorf("elasticsearch memory message get error: %s", res.Status())
|
||||
}
|
||||
|
||||
var getResult struct {
|
||||
Found bool `json:"found"`
|
||||
Source map[string]interface{} `json:"_source"`
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&getResult); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse memory message get response: %w", err)
|
||||
}
|
||||
if !getResult.Found || getResult.Source == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
message := getResult.Source
|
||||
message["id"] = docID
|
||||
normalizeMemoryMessageChunks([]map[string]interface{}{message})
|
||||
return message, nil
|
||||
}
|
||||
|
||||
// GetFields extracts the requested fields from ES search response chunks
|
||||
//
|
||||
// Unlike Infinity, Elasticsearch does NOT use convertSelectFields before querying.
|
||||
|
||||
@@ -379,7 +379,7 @@ func sortedCopy(in []int) []int {
|
||||
func TestBuildBoolQueryFromConditionIDFilter(t *testing.T) {
|
||||
check := func(name string, cond map[string]interface{}, wantFields []string) {
|
||||
t.Helper()
|
||||
got := buildBoolQueryFromCondition(cond, nil, false)
|
||||
got := buildBoolQueryFromCondition(cond, nil, false, false)
|
||||
outer, ok := got["bool"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("%s: missing bool wrapper: %v", name, got)
|
||||
|
||||
@@ -1282,6 +1282,13 @@ func applyFieldMappings(chunks []map[string]interface{}) {
|
||||
chunk["authors_sm_tks"] = val
|
||||
}
|
||||
|
||||
if val, ok := chunk["message_type_kwd"]; ok {
|
||||
chunk["message_type"] = val
|
||||
}
|
||||
if val, ok := chunk["status_int"]; ok {
|
||||
chunk["status"] = memoryMessageStatusBool(val)
|
||||
}
|
||||
|
||||
// position_int: convert from hex string to array format (grouped by 5)
|
||||
if val, ok := chunk["position_int"].(string); ok {
|
||||
chunk["position_int"] = utility.ConvertHexToPositionIntArray(val)
|
||||
@@ -1332,6 +1339,26 @@ func applyFieldMappings(chunks []map[string]interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
func memoryMessageStatusBool(value interface{}) bool {
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case int:
|
||||
return v != 0
|
||||
case int64:
|
||||
return v != 0
|
||||
case float64:
|
||||
return v != 0
|
||||
case json.Number:
|
||||
n, err := v.Int64()
|
||||
return err == nil && n != 0
|
||||
case string:
|
||||
return v != "" && v != "0" && !strings.EqualFold(v, "false")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetFields extracts the requested fields from Infinity search results
|
||||
func (e *infinityEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} {
|
||||
result := make(map[string]map[string]interface{})
|
||||
@@ -1837,6 +1864,8 @@ func convertSelectFields(output []string, isSkillIndex ...bool) []string {
|
||||
"content_sm_ltks": "content",
|
||||
"authors_tks": "authors",
|
||||
"authors_sm_tks": "authors",
|
||||
"message_type": "message_type_kwd",
|
||||
"status": "status_int",
|
||||
}
|
||||
|
||||
skillIndex := false
|
||||
@@ -2005,6 +2034,12 @@ func equivalentConditionToStr(condition map[string]interface{}) string {
|
||||
continue
|
||||
}
|
||||
|
||||
if k == "message_type" {
|
||||
k = "message_type_kwd"
|
||||
} else if k == "status" {
|
||||
k = "status_int"
|
||||
}
|
||||
|
||||
// Handle list values (mixed types - strings get quotes, numbers don't)
|
||||
if list, ok := v.([]interface{}); ok && len(list) > 0 {
|
||||
var strItems, numItems []string
|
||||
@@ -2210,6 +2245,19 @@ func transformChunkFields(chunk map[string]interface{}, embeddingCols [][2]inter
|
||||
d["questions"] = strings.Join(utility.ConvertToStringSlice(v), "\n")
|
||||
case "tag_kwd":
|
||||
d["tag_kwd"] = strings.Join(utility.ConvertToStringSlice(v), "###")
|
||||
case "message_type":
|
||||
d["message_type_kwd"] = v
|
||||
case "status":
|
||||
switch status := v.(type) {
|
||||
case bool:
|
||||
if status {
|
||||
d["status_int"] = 1
|
||||
} else {
|
||||
d["status_int"] = 0
|
||||
}
|
||||
default:
|
||||
d["status_int"] = v
|
||||
}
|
||||
case "question_tks":
|
||||
if _, exists := chunk["question_kwd"]; !exists {
|
||||
d["questions"] = utility.ConvertToString(v)
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -515,13 +516,67 @@ func (h *MemoryHandler) GetMemoryConfig(c *gin.Context) {
|
||||
// - message: true
|
||||
// - data.messages: Array of message objects
|
||||
// - data.storage_type: Storage type
|
||||
//
|
||||
// TODO: Implementation pending - depends on CanvasService and TaskService
|
||||
func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
jsonError(c, common.CodeAuthenticationError, "user id is required")
|
||||
return
|
||||
}
|
||||
|
||||
memoryID := strings.TrimSpace(c.Param("memory_id"))
|
||||
if memoryID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "memory_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var agentIDs []string
|
||||
values := c.QueryArray("agent_id")
|
||||
for _, v := range values {
|
||||
parts := strings.Split(v, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
agentIDs = append(agentIDs, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
keywords := strings.TrimSpace(c.DefaultQuery("keywords", ""))
|
||||
page, err := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
if err != nil || page <= 0 {
|
||||
jsonError(c, common.CodeArgumentError, "page must be a positive integer")
|
||||
return
|
||||
}
|
||||
pageSize, err := strconv.Atoi(c.DefaultQuery("page_size", "50"))
|
||||
if err != nil || pageSize <= 0 {
|
||||
jsonError(c, common.CodeArgumentError, "page_size must be a positive integer")
|
||||
return
|
||||
}
|
||||
if pageSize > 100 {
|
||||
jsonError(c, common.CodeArgumentError, "page_size must be less than or equal to 100")
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.memoryService.GetMemoryMessages(c.Request.Context(), userID, memoryID, agentIDs, keywords, page, pageSize)
|
||||
if err != nil {
|
||||
if isMemoryServiceNotFound(err) {
|
||||
jsonError(c, common.CodeNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, common.CodeServerError, "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "GetMemoryMessages not implemented - pending CanvasService and TaskService dependencies",
|
||||
"data": nil,
|
||||
"code": common.CodeSuccess,
|
||||
"message": true,
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -541,7 +596,7 @@ func (h *MemoryHandler) GetMemoryMessages(c *gin.Context) {
|
||||
// - agent_response (required): Agent response
|
||||
// - user_id (optional): User ID
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
// TODO: Haruko386 is implementing this for now, if you implement this, delete this line plz
|
||||
func (h *MemoryHandler) AddMessage(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
@@ -645,16 +700,123 @@ func parseMemoryMessagePath(memoryMessage string) (string, int64, error) {
|
||||
//
|
||||
// Request Parameters (JSON Body):
|
||||
// - status (required): Message status, boolean
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
func (h *MemoryHandler) UpdateMessage(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
jsonError(c, common.CodeAuthenticationError, "user id is required")
|
||||
return
|
||||
}
|
||||
|
||||
memoryID, messageID, err := parseMemoryMessagePath(c.Param("memory_message"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]interface{}
|
||||
if err = json.NewDecoder(c.Request.Body).Decode(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
status, ok := req["status"].(bool)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": "Status must be a boolean.",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ok, err = h.memoryService.UpdateMessage(c.Request.Context(), userID, memoryID, messageID, status)
|
||||
if err != nil || !ok {
|
||||
if isMemoryServiceNotFound(err) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeNotFound,
|
||||
"message": err.Error(),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "Internal server error",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "UpdateMessage not implemented - pending embedding engine dependency",
|
||||
"code": common.CodeSuccess,
|
||||
"message": true,
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
|
||||
// GetMessageContent handles GET request for getting message content
|
||||
// API Path: GET /api/v1/messages/:memory_id/:message_id/content
|
||||
//
|
||||
// Function:
|
||||
// - Gets complete content of the specified message
|
||||
// - doc_id format: memory_id + "_" + message_id
|
||||
//
|
||||
// Parameter Format:
|
||||
// - memory_id: Memory ID
|
||||
// - message_id: Message ID (integer)
|
||||
//
|
||||
|
||||
func (h *MemoryHandler) GetMessageContent(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
jsonError(c, common.CodeAuthenticationError, "user id is required")
|
||||
return
|
||||
}
|
||||
|
||||
memoryID, messageID, err := parseMemoryMessagePath(c.Param("memory_message"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.memoryService.GetMessageContent(c.Request.Context(), userID, memoryID, messageID)
|
||||
if err != nil {
|
||||
if _, ok := err.(*service.ResourceNotFoundError); ok {
|
||||
jsonError(c, common.CodeNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, common.CodeServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": true,
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
// SearchMessage handles GET request for searching messages
|
||||
// API Path: GET /api/v1/messages/search
|
||||
//
|
||||
@@ -672,13 +834,64 @@ func (h *MemoryHandler) UpdateMessage(c *gin.Context) {
|
||||
// - agent_id (optional): Agent ID filter
|
||||
// - session_id (optional): Session ID filter
|
||||
// - user_id (optional): User ID filter
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
func (h *MemoryHandler) SearchMessage(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
jsonError(c, common.CodeAuthenticationError, "user id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var memoryIDs []string
|
||||
values := c.QueryArray("memory_id")
|
||||
for _, v := range values {
|
||||
parts := strings.Split(v, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
memoryIDs = append(memoryIDs, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
query := c.Query("query")
|
||||
|
||||
similarityThreshold, _ := strconv.ParseFloat(c.DefaultQuery("similarity_threshold", "0.2"), 64)
|
||||
keywordsSimilarityWeight, _ := strconv.ParseFloat(c.DefaultQuery("keywords_similarity_weight", "0.7"), 64)
|
||||
topN, _ := strconv.Atoi(c.DefaultQuery("top_n", "5"))
|
||||
|
||||
agentID := c.DefaultQuery("agent_id", "")
|
||||
sessionID := c.DefaultQuery("session_id", "")
|
||||
|
||||
filterDict := map[string]interface{}{
|
||||
"memory_id": memoryIDs,
|
||||
"agent_id": agentID,
|
||||
"session_id": sessionID,
|
||||
"user_id": c.DefaultQuery("user_id", ""),
|
||||
}
|
||||
|
||||
params := map[string]interface{}{
|
||||
"query": query,
|
||||
"similarity_threshold": similarityThreshold,
|
||||
"keywords_similarity_weight": keywordsSimilarityWeight,
|
||||
"top_n": topN,
|
||||
}
|
||||
|
||||
res, code, err := h.memoryService.SearchMessage(c.Request.Context(), userID, filterDict, params)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "SearchMessage not implemented - pending embedding engine dependency",
|
||||
"data": nil,
|
||||
"code": common.CodeSuccess,
|
||||
"message": true,
|
||||
"data": res,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -694,33 +907,49 @@ func (h *MemoryHandler) SearchMessage(c *gin.Context) {
|
||||
// - agent_id (optional): Agent ID filter
|
||||
// - session_id (optional): Session ID filter
|
||||
// - limit (optional): Number of results to return, default 10
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
func (h *MemoryHandler) GetMessages(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "GetMessages not implemented - pending embedding engine dependency",
|
||||
"data": nil,
|
||||
})
|
||||
}
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
jsonError(c, common.CodeAuthenticationError, "user id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var memoryIDs []string
|
||||
values := c.QueryArray("memory_id")
|
||||
for _, v := range values {
|
||||
parts := strings.Split(v, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
memoryIDs = append(memoryIDs, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
agentID := c.DefaultQuery("agent_id", "")
|
||||
sessionID := c.DefaultQuery("session_id", "")
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
|
||||
if len(memoryIDs) == 0 {
|
||||
jsonError(c, common.CodeArgumentError, "memory_ids is required.")
|
||||
return
|
||||
}
|
||||
|
||||
data, code, err := h.memoryService.GetMessages(c.Request.Context(), memoryIDs, userID, agentID, sessionID, limit)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// GetMessageContent handles GET request for getting message content
|
||||
// API Path: GET /api/v1/messages/:memory_id/:message_id/content
|
||||
//
|
||||
// Function:
|
||||
// - Gets complete content of the specified message
|
||||
// - doc_id format: memory_id + "_" + message_id
|
||||
//
|
||||
// Parameter Format:
|
||||
// - memory_id: Memory ID
|
||||
// - message_id: Message ID (integer)
|
||||
//
|
||||
// TODO: Implementation pending - depends on embedding engine
|
||||
func (h *MemoryHandler) GetMessageContent(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": "GetMessageContent not implemented - pending embedding engine dependency",
|
||||
"data": nil,
|
||||
"code": common.CodeSuccess,
|
||||
"message": true,
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -414,8 +414,11 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
// Message routes
|
||||
message := v1.Group("/messages")
|
||||
{
|
||||
message.GET("", r.memoryHandler.GetMessages)
|
||||
message.DELETE("/:memory_message", r.memoryHandler.ForgetMessage)
|
||||
message.PUT("/:memory_message", r.memoryHandler.UpdateMessage)
|
||||
message.GET("/:memory_message/content", r.memoryHandler.GetMessageContent)
|
||||
message.GET("/search", r.memoryHandler.SearchMessage)
|
||||
}
|
||||
|
||||
// Skill search routes
|
||||
|
||||
@@ -20,8 +20,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
models "ragflow/internal/entity/models"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -29,6 +31,7 @@ import (
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
enginetypes "ragflow/internal/engine/types"
|
||||
"ragflow/internal/service/nlp"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -55,6 +58,12 @@ const (
|
||||
TenantPermissionAll TenantPermission = "all"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMessageTopN = 5
|
||||
defaultMessageLimit = 10
|
||||
maxMessageLimit = 100
|
||||
)
|
||||
|
||||
// validPermissions defines which permission values are valid
|
||||
var validPermissions = map[TenantPermission]bool{
|
||||
TenantPermissionMe: true,
|
||||
@@ -793,7 +802,7 @@ func (s *MemoryService) ForgetMessage(ctx context.Context, userID string, memory
|
||||
condition := map[string]interface{}{
|
||||
"id": messageDocID,
|
||||
}
|
||||
indexName := fmt.Sprintf("memory_%s", memory.TenantID)
|
||||
indexName := memoryIndexName(memory.TenantID)
|
||||
|
||||
if err := s.docEngine.UpdateChunks(ctx, condition, updates, indexName, memoryID); err != nil {
|
||||
if isMessageDocumentNotFound(err) {
|
||||
@@ -807,6 +816,453 @@ func (s *MemoryService) ForgetMessage(ctx context.Context, userID string, memory
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryService) UpdateMessage(ctx context.Context, userID, memoryID string, messageID int64, status bool) (bool, error) {
|
||||
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if s.docEngine == nil {
|
||||
return false, errors.New("message store is not initialized")
|
||||
}
|
||||
|
||||
messageDocID := fmt.Sprintf("%s_%d", memoryID, messageID)
|
||||
statusValue := 0
|
||||
if status {
|
||||
statusValue = 1
|
||||
}
|
||||
updates := map[string]interface{}{
|
||||
"status": statusValue,
|
||||
}
|
||||
condition := map[string]interface{}{
|
||||
"id": messageDocID,
|
||||
}
|
||||
indexName := memoryIndexName(memory.TenantID)
|
||||
if err := s.docEngine.UpdateChunks(ctx, condition, updates, indexName, memoryID); err != nil {
|
||||
if isMessageDocumentNotFound(err) {
|
||||
return false, &ResourceNotFoundError{Resource: "Message", ID: messageDocID}
|
||||
}
|
||||
return false, fmt.Errorf("failed to set status for message '%d' in memory '%s': %w", messageID, memoryID, err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *MemoryService) GetMessageContent(ctx context.Context, userID, memoryID string, messageID int64) (map[string]interface{}, error) {
|
||||
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.docEngine == nil {
|
||||
return nil, errors.New("message store is not initialized")
|
||||
}
|
||||
|
||||
indexName := memoryIndexName(memory.TenantID)
|
||||
docID := fmt.Sprintf("%s_%d", memoryID, messageID)
|
||||
res, err := s.docEngine.GetChunk(ctx, indexName, docID, []string{memoryID})
|
||||
if err != nil {
|
||||
if isMessageDocumentNotFound(err) {
|
||||
return nil, &ResourceNotFoundError{Resource: "Message", ID: docID}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if res == nil {
|
||||
return nil, &ResourceNotFoundError{Resource: "Message", ID: docID}
|
||||
}
|
||||
|
||||
message, ok := res.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected message content type %T", res)
|
||||
}
|
||||
return common.ConvertFloatsToPyFormat(message).(map[string]interface{}), nil
|
||||
}
|
||||
|
||||
func (s *MemoryService) SearchMessage(ctx context.Context, userID string, filterDict, params map[string]interface{}) ([]map[string]interface{}, common.ErrorCode, error) {
|
||||
memoryIDs := splitFilterValues(filterDict["memory_id"])
|
||||
memories, err := s.filterAccessibleMemories(ctx, userID, memoryIDs)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if len(memories) == 0 {
|
||||
return []map[string]interface{}{}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
return s.queryMessage(ctx, memories, filterDict, params)
|
||||
}
|
||||
|
||||
func (s *MemoryService) queryMessage(ctx context.Context, memories []*entity.Memory, filterDict, params map[string]interface{}) ([]map[string]interface{}, common.ErrorCode, error) {
|
||||
if s.docEngine == nil {
|
||||
return nil, common.CodeServerError, errors.New("message store is not initialized")
|
||||
}
|
||||
|
||||
topN := memoryIntParam(params["top_n"], 5)
|
||||
if topN <= 0 {
|
||||
topN = defaultMessageTopN
|
||||
} else if topN > maxMessageLimit {
|
||||
topN = maxMessageLimit
|
||||
}
|
||||
similarityThreshold := memoryFloatParam(params["similarity_threshold"], 0.2)
|
||||
keywordsSimilarityWeight := memoryFloatParam(params["keywords_similarity_weight"], 0.7)
|
||||
question := strings.TrimSpace(memoryStringParam(params["query"]))
|
||||
|
||||
memoryIDs := make([]string, 0, len(memories))
|
||||
conditionDict := make(map[string]interface{})
|
||||
for _, memory := range memories {
|
||||
if memory == nil {
|
||||
continue
|
||||
}
|
||||
memoryIDs = append(memoryIDs, memory.ID)
|
||||
}
|
||||
conditionDict["memory_id"] = memoryIDs
|
||||
for _, key := range []string{"agent_id", "session_id", "user_id"} {
|
||||
value := strings.TrimSpace(memoryStringParam(filterDict[key]))
|
||||
if value != "" {
|
||||
conditionDict[key] = value
|
||||
}
|
||||
}
|
||||
if _, ok := conditionDict["status"]; !ok {
|
||||
conditionDict["status"] = 1
|
||||
}
|
||||
|
||||
matchExprs := make([]interface{}, 0, 3)
|
||||
if question != "" {
|
||||
matchText := memoryMessageTextExpr(question, similarityThreshold)
|
||||
matchDense, err := s.memoryMessageDenseExpr(question, memories[0], topN, similarityThreshold)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
fusionExpr := &enginetypes.FusionExpr{
|
||||
Method: "weighted_sum",
|
||||
TopN: topN,
|
||||
FusionParams: map[string]interface{}{
|
||||
"weights": fmt.Sprintf("%g,%g", 1-keywordsSimilarityWeight, keywordsSimilarityWeight),
|
||||
},
|
||||
}
|
||||
matchExprs = append(matchExprs, matchText, matchDense, fusionExpr)
|
||||
}
|
||||
|
||||
searchReq := &enginetypes.SearchRequest{
|
||||
IndexNames: memorySearchIndexNames(memories),
|
||||
Offset: 0,
|
||||
Limit: topN,
|
||||
SelectFields: memoryMessageSelectFields(),
|
||||
Filter: conditionDict,
|
||||
MatchExprs: matchExprs,
|
||||
OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"),
|
||||
}
|
||||
|
||||
searchResult, err := s.docEngine.Search(ctx, searchReq)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if searchResult == nil || searchResult.Total == 0 {
|
||||
return []map[string]interface{}{}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
messages := make([]map[string]interface{}, 0, len(searchResult.Chunks))
|
||||
for _, chunk := range searchResult.Chunks {
|
||||
message := make(map[string]interface{}, len(chunk))
|
||||
for _, field := range memoryMessageSelectFields() {
|
||||
if value, ok := chunk[field]; ok {
|
||||
message[field] = value
|
||||
}
|
||||
}
|
||||
messages = append(messages, message)
|
||||
}
|
||||
return common.ConvertFloatsToPyFormat(messages).([]map[string]interface{}), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *MemoryService) filterAccessibleMemories(ctx context.Context, userID string, memoryIDs []string) ([]*entity.Memory, error) {
|
||||
memoryIDs = splitFilterValues(memoryIDs)
|
||||
if len(memoryIDs) == 0 {
|
||||
return []*entity.Memory{}, nil
|
||||
}
|
||||
|
||||
memories, err := s.memoryDAO.GetByIDs(memoryIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(memories) == 0 {
|
||||
return []*entity.Memory{}, nil
|
||||
}
|
||||
|
||||
joinedTenantIDs := map[string]struct{}{userID: {}}
|
||||
needsTeamLookup := false
|
||||
for _, memory := range memories {
|
||||
if memory != nil && memory.TenantID != userID && memory.Permissions == string(TenantPermissionTeam) {
|
||||
needsTeamLookup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if needsTeamLookup {
|
||||
userTenants, err := NewUserTenantService().GetUserTenantRelationByUserIDWithContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, tenant := range userTenants {
|
||||
if tenant != nil && tenant.TenantID != "" {
|
||||
joinedTenantIDs[tenant.TenantID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessible := make([]*entity.Memory, 0, len(memories))
|
||||
for _, memory := range memories {
|
||||
if memory == nil {
|
||||
continue
|
||||
}
|
||||
if memory.TenantID == userID {
|
||||
accessible = append(accessible, memory)
|
||||
continue
|
||||
}
|
||||
if memory.Permissions != string(TenantPermissionTeam) {
|
||||
continue
|
||||
}
|
||||
if _, ok := joinedTenantIDs[memory.TenantID]; ok {
|
||||
accessible = append(accessible, memory)
|
||||
}
|
||||
}
|
||||
return accessible, nil
|
||||
}
|
||||
|
||||
func splitFilterValues(values interface{}) []string {
|
||||
if values == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var list []string
|
||||
|
||||
switch v := values.(type) {
|
||||
case string:
|
||||
list = []string{v}
|
||||
case []string:
|
||||
list = v
|
||||
case []interface{}:
|
||||
for _, x := range v {
|
||||
if s, ok := x.(string); ok {
|
||||
list = append(list, s)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return []string{}
|
||||
}
|
||||
|
||||
res := make([]string, 0)
|
||||
for _, item := range list {
|
||||
if item == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(item, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
res = append(res, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func memoryStringParam(value interface{}) string {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return typed
|
||||
case fmt.Stringer:
|
||||
return typed.String()
|
||||
default:
|
||||
if typed == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%v", typed)
|
||||
}
|
||||
}
|
||||
|
||||
func memoryFloatParam(value interface{}, fallback float64) float64 {
|
||||
switch typed := value.(type) {
|
||||
case float64:
|
||||
return typed
|
||||
case float32:
|
||||
return float64(typed)
|
||||
case int:
|
||||
return float64(typed)
|
||||
case int64:
|
||||
return float64(typed)
|
||||
case string:
|
||||
if parsed, err := strconv.ParseFloat(strings.TrimSpace(typed), 64); err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func memoryIntParam(value interface{}, fallback int) int {
|
||||
switch typed := value.(type) {
|
||||
case int:
|
||||
return typed
|
||||
case int64:
|
||||
return int(typed)
|
||||
case float64:
|
||||
return int(typed)
|
||||
case string:
|
||||
if parsed, err := strconv.Atoi(strings.TrimSpace(typed)); err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func memoryMessageSelectFields() []string {
|
||||
return []string{
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
|
||||
"valid_at", "invalid_at", "forget_at", "status", "content",
|
||||
}
|
||||
}
|
||||
|
||||
func memoryIndexName(tenantID string) string {
|
||||
prefix := strings.TrimSpace(os.Getenv("ES_INDEX_PREFIX"))
|
||||
if prefix == "" {
|
||||
return fmt.Sprintf("memory_%s", tenantID)
|
||||
}
|
||||
return fmt.Sprintf("memory_%s_%s", prefix, tenantID)
|
||||
}
|
||||
|
||||
func memorySearchIndexNames(memories []*entity.Memory) []string {
|
||||
seen := make(map[string]struct{}, len(memories))
|
||||
indexNames := make([]string, 0, len(memories))
|
||||
for _, memory := range memories {
|
||||
if memory == nil {
|
||||
continue
|
||||
}
|
||||
indexName := memoryIndexName(memory.TenantID)
|
||||
if engine.GetEngineType() == engine.EngineInfinity {
|
||||
indexName = fmt.Sprintf("%s_%s", indexName, memory.ID)
|
||||
}
|
||||
if _, ok := seen[indexName]; ok {
|
||||
continue
|
||||
}
|
||||
seen[indexName] = struct{}{}
|
||||
indexNames = append(indexNames, indexName)
|
||||
}
|
||||
return indexNames
|
||||
}
|
||||
|
||||
func memoryMessageTextExpr(question string, similarityThreshold float64) *enginetypes.MatchTextExpr {
|
||||
matchText := &enginetypes.MatchTextExpr{
|
||||
Fields: []string{"content"},
|
||||
MatchingText: question,
|
||||
TopN: 100,
|
||||
ExtraOptions: map[string]interface{}{"original_query": question},
|
||||
}
|
||||
|
||||
queryBuilder := nlp.GetQueryBuilder()
|
||||
if queryBuilder == nil {
|
||||
queryBuilder = nlp.NewQueryBuilder()
|
||||
}
|
||||
if built, _ := queryBuilder.Question(question, "messages", similarityThreshold); built != nil {
|
||||
matchText.MatchingText = built.MatchingText
|
||||
matchText.ExtraOptions = built.ExtraOptions
|
||||
if matchText.ExtraOptions == nil {
|
||||
matchText.ExtraOptions = map[string]interface{}{}
|
||||
}
|
||||
matchText.ExtraOptions["original_query"] = question
|
||||
}
|
||||
matchText.Fields = []string{"content"}
|
||||
matchText.TopN = 100
|
||||
return matchText
|
||||
}
|
||||
|
||||
func (s *MemoryService) memoryMessageDenseExpr(question string, memory *entity.Memory, topN int, similarityThreshold float64) (*enginetypes.MatchDenseExpr, error) {
|
||||
driver, modelName, apiConfig, maxTokens, err := NewModelProviderService().GetModelConfigFromProviderInstance(memory.TenantID, entity.ModelTypeEmbedding, memory.EmbdID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
embeddingModel := models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
||||
embeddings, err := embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{question}, embeddingModel.APIConfig, &models.EmbeddingConfig{Dimension: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 {
|
||||
return nil, errors.New("embedding response is empty")
|
||||
}
|
||||
|
||||
vector := embeddings[0].Embedding
|
||||
return &enginetypes.MatchDenseExpr{
|
||||
VectorColumnName: fmt.Sprintf("q_%d_vec", len(vector)),
|
||||
EmbeddingData: vector,
|
||||
EmbeddingDataType: "float",
|
||||
DistanceType: "cosine",
|
||||
TopN: topN,
|
||||
ExtraOptions: map[string]interface{}{"similarity": similarityThreshold},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *MemoryService) GetMessages(ctx context.Context, memoryIDs []string, userID, agentID, sessionID string, limit int) ([]map[string]interface{}, common.ErrorCode, error) {
|
||||
memories, err := s.filterAccessibleMemories(ctx, userID, memoryIDs)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if len(memories) == 0 {
|
||||
return []map[string]interface{}{}, common.CodeSuccess, nil
|
||||
}
|
||||
return s.getRecentMessage(ctx, memories, agentID, sessionID, limit)
|
||||
}
|
||||
|
||||
func (s *MemoryService) getRecentMessage(ctx context.Context, memories []*entity.Memory, agentID, sessionID string, limit int) ([]map[string]interface{}, common.ErrorCode, error) {
|
||||
if s.docEngine == nil {
|
||||
return nil, common.CodeServerError, errors.New("doc engine is nil")
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = defaultMessageLimit
|
||||
} else if limit > maxMessageLimit {
|
||||
limit = maxMessageLimit
|
||||
}
|
||||
indexNames := memorySearchIndexNames(memories)
|
||||
memoryIDs := make([]string, 0, len(memories))
|
||||
for _, memory := range memories {
|
||||
if memory == nil || strings.TrimSpace(memory.ID) == "" {
|
||||
continue
|
||||
}
|
||||
memoryIDs = append(memoryIDs, memory.ID)
|
||||
}
|
||||
|
||||
conditionDict := map[string]interface{}{"memory_id": memoryIDs}
|
||||
if agentID = strings.TrimSpace(agentID); agentID != "" {
|
||||
conditionDict["agent_id"] = agentID
|
||||
}
|
||||
if sessionID = strings.TrimSpace(sessionID); sessionID != "" {
|
||||
conditionDict["session_id"] = sessionID
|
||||
}
|
||||
req := &enginetypes.SearchRequest{
|
||||
IndexNames: indexNames,
|
||||
Offset: 0,
|
||||
Limit: limit,
|
||||
SelectFields: memoryMessageSelectFields(),
|
||||
Filter: conditionDict,
|
||||
MatchExprs: []interface{}{},
|
||||
OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"),
|
||||
}
|
||||
|
||||
result, err := s.docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if result == nil || result.Total == 0 {
|
||||
return []map[string]interface{}{}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
messages := make([]map[string]interface{}, 0, len(result.Chunks))
|
||||
for _, chunk := range result.Chunks {
|
||||
msg := make(map[string]interface{}, len(chunk))
|
||||
for _, field := range memoryMessageSelectFields() {
|
||||
if val, ok := chunk[field]; ok {
|
||||
msg[field] = val
|
||||
}
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
return common.ConvertFloatsToPyFormat(messages).([]map[string]interface{}), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func isMessageDocumentNotFound(err error) bool {
|
||||
return errors.Is(err, enginetypes.ErrDocumentNotFound)
|
||||
}
|
||||
@@ -933,8 +1389,266 @@ func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse,
|
||||
return formatRetDataFromMemoryListItem(memory), nil
|
||||
}
|
||||
|
||||
// TODO: GetMemoryMessages - Implementation pending - depends on CanvasService and TaskService
|
||||
// func (s *MemoryService) GetMemoryMessages(memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) { ... }
|
||||
func (s *MemoryService) GetMemoryMessages(ctx context.Context, userID, memoryID string, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) {
|
||||
memory, err := s.requireMemoryAccess(ctx, userID, memoryID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messages, err := s.listMemoryMessages(ctx, memory, agentIDs, keywords, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawMessages, _ := messages["message_list"].([]map[string]interface{})
|
||||
agentNames := map[string]string{}
|
||||
tasks := map[string]map[string]interface{}{}
|
||||
if len(rawMessages) > 0 {
|
||||
agentNames, err = s.memoryMessageAgentNames(rawMessages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tasks, err = s.memoryMessageTasks(memoryID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, message := range rawMessages {
|
||||
agentID, _ := message["agent_id"].(string)
|
||||
message["agent_name"] = "Unknown"
|
||||
if name, ok := agentNames[agentID]; ok {
|
||||
message["agent_name"] = name
|
||||
}
|
||||
message["task"] = map[string]interface{}{}
|
||||
if task, ok := tasks[memoryMessageKey(message["message_id"])]; ok {
|
||||
message["task"] = task
|
||||
}
|
||||
if extracts, ok := message["extract"].([]map[string]interface{}); ok {
|
||||
for _, extract := range extracts {
|
||||
extractAgentID, _ := extract["agent_id"].(string)
|
||||
extract["agent_name"] = "Unknown"
|
||||
if name, ok := agentNames[extractAgentID]; ok {
|
||||
extract["agent_name"] = name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return common.ConvertFloatsToPyFormat(map[string]interface{}{
|
||||
"messages": messages,
|
||||
"storage_type": memory.StorageType,
|
||||
}).(map[string]interface{}), nil
|
||||
}
|
||||
|
||||
func (s *MemoryService) listMemoryMessages(ctx context.Context, memory *entity.Memory, agentIDs []string, keywords string, page int, pageSize int) (map[string]interface{}, error) {
|
||||
if s.docEngine == nil {
|
||||
return nil, errors.New("message store is not initialized")
|
||||
}
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = defaultMessageLimit
|
||||
} else if pageSize > maxMessageLimit {
|
||||
pageSize = maxMessageLimit
|
||||
}
|
||||
|
||||
memoryID := memory.ID
|
||||
selectFields := memoryMessageListFields()
|
||||
filter := map[string]interface{}{
|
||||
"message_type": "raw",
|
||||
}
|
||||
if len(agentIDs) > 0 {
|
||||
filter["agent_id"] = agentIDs
|
||||
}
|
||||
if keywords = strings.TrimSpace(keywords); keywords != "" {
|
||||
filter["session_id"] = keywords
|
||||
}
|
||||
filter["memory_id"] = []string{memoryID}
|
||||
indexNames := memorySearchIndexNames([]*entity.Memory{memory})
|
||||
|
||||
rawReq := &enginetypes.SearchRequest{
|
||||
IndexNames: indexNames,
|
||||
Offset: (page - 1) * pageSize,
|
||||
Limit: pageSize,
|
||||
SelectFields: selectFields,
|
||||
Filter: filter,
|
||||
MatchExprs: []interface{}{},
|
||||
OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"),
|
||||
}
|
||||
rawResult, err := s.docEngine.Search(ctx, rawReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messages := map[string]interface{}{
|
||||
"message_list": []map[string]interface{}{},
|
||||
"total_count": int64(0),
|
||||
}
|
||||
if rawResult != nil {
|
||||
messages["total_count"] = rawResult.Total
|
||||
}
|
||||
if rawResult == nil || rawResult.Total == 0 {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
rawMessages := make([]map[string]interface{}, 0, len(rawResult.Chunks))
|
||||
sourceIDs := make([]interface{}, 0, len(rawResult.Chunks))
|
||||
for _, chunk := range rawResult.Chunks {
|
||||
message := memoryMessageFromChunk(chunk, selectFields)
|
||||
message["extract"] = []map[string]interface{}{}
|
||||
if messageID, ok := message["message_id"]; ok {
|
||||
sourceIDs = append(sourceIDs, messageID)
|
||||
}
|
||||
rawMessages = append(rawMessages, message)
|
||||
}
|
||||
|
||||
if len(sourceIDs) > 0 {
|
||||
extractReq := &enginetypes.SearchRequest{
|
||||
IndexNames: indexNames,
|
||||
Offset: 0,
|
||||
Limit: 512,
|
||||
SelectFields: selectFields,
|
||||
Filter: map[string]interface{}{
|
||||
"memory_id": []string{memoryID},
|
||||
"source_id": sourceIDs,
|
||||
},
|
||||
MatchExprs: []interface{}{},
|
||||
OrderBy: (&enginetypes.OrderByExpr{}).Desc("valid_at"),
|
||||
}
|
||||
extractResult, err := s.docEngine.Search(ctx, extractReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if extractResult != nil && extractResult.Total > 0 {
|
||||
groupedExtracts := make(map[string][]map[string]interface{})
|
||||
for _, chunk := range extractResult.Chunks {
|
||||
message := memoryMessageFromChunk(chunk, selectFields)
|
||||
sourceID := memoryMessageKey(message["source_id"])
|
||||
groupedExtracts[sourceID] = append(groupedExtracts[sourceID], message)
|
||||
}
|
||||
for _, message := range rawMessages {
|
||||
messageID := memoryMessageKey(message["message_id"])
|
||||
if extracts, ok := groupedExtracts[messageID]; ok {
|
||||
message["extract"] = extracts
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages["message_list"] = rawMessages
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func memoryMessageListFields() []string {
|
||||
return []string{
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
|
||||
"valid_at", "invalid_at", "forget_at", "status",
|
||||
}
|
||||
}
|
||||
|
||||
func memoryMessageFromChunk(chunk map[string]interface{}, fields []string) map[string]interface{} {
|
||||
message := make(map[string]interface{}, len(fields))
|
||||
for _, field := range fields {
|
||||
if value, ok := chunk[field]; ok {
|
||||
message[field] = value
|
||||
}
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func (s *MemoryService) memoryMessageAgentNames(messages []map[string]interface{}) (map[string]string, error) {
|
||||
agentIDSet := make(map[string]struct{})
|
||||
for _, message := range messages {
|
||||
agentID, _ := message["agent_id"].(string)
|
||||
if agentID != "" {
|
||||
agentIDSet[agentID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(agentIDSet) == 0 {
|
||||
return map[string]string{}, nil
|
||||
}
|
||||
|
||||
agentIDs := make([]string, 0, len(agentIDSet))
|
||||
for agentID := range agentIDSet {
|
||||
agentIDs = append(agentIDs, agentID)
|
||||
}
|
||||
|
||||
var agents []struct {
|
||||
ID string `gorm:"column:id"`
|
||||
Title *string `gorm:"column:title"`
|
||||
}
|
||||
if err := dao.DB.Model(&entity.UserCanvas{}).Select("id, title").Where("id IN ?", agentIDs).Scan(&agents).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
agentNames := make(map[string]string, len(agents))
|
||||
for _, agent := range agents {
|
||||
if agent.Title != nil {
|
||||
agentNames[agent.ID] = *agent.Title
|
||||
}
|
||||
}
|
||||
return agentNames, nil
|
||||
}
|
||||
|
||||
func (s *MemoryService) memoryMessageTasks(memoryID string) (map[string]map[string]interface{}, error) {
|
||||
var tasks []struct {
|
||||
ID string `gorm:"column:id"`
|
||||
DocID string `gorm:"column:doc_id"`
|
||||
FromPage int64 `gorm:"column:from_page"`
|
||||
Progress float64 `gorm:"column:progress"`
|
||||
ProgressMsg *string `gorm:"column:progress_msg"`
|
||||
Digest *string `gorm:"column:digest"`
|
||||
ChunkIDs *string `gorm:"column:chunk_ids"`
|
||||
CreateTime *int64 `gorm:"column:create_time"`
|
||||
}
|
||||
if err := dao.DB.Model(&entity.Task{}).
|
||||
Select("id, doc_id, from_page, progress, progress_msg, digest, chunk_ids, create_time").
|
||||
Where("doc_id IN ?", []string{memoryID}).
|
||||
Order("create_time ASC").
|
||||
Scan(&tasks).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
taskByMessageID := make(map[string]map[string]interface{}, len(tasks))
|
||||
for _, task := range tasks {
|
||||
if task.Digest == nil {
|
||||
continue
|
||||
}
|
||||
digest := strings.TrimSpace(*task.Digest)
|
||||
if digest == "" {
|
||||
continue
|
||||
}
|
||||
var progressMsg interface{}
|
||||
if task.ProgressMsg != nil {
|
||||
progressMsg = *task.ProgressMsg
|
||||
}
|
||||
var chunkIDs interface{}
|
||||
if task.ChunkIDs != nil {
|
||||
chunkIDs = *task.ChunkIDs
|
||||
}
|
||||
var createTime interface{}
|
||||
if task.CreateTime != nil {
|
||||
createTime = *task.CreateTime
|
||||
}
|
||||
taskMap := map[string]interface{}{
|
||||
"id": task.ID,
|
||||
"doc_id": task.DocID,
|
||||
"from_page": task.FromPage,
|
||||
"progress": task.Progress,
|
||||
"progress_msg": progressMsg,
|
||||
"digest": digest,
|
||||
"chunk_ids": chunkIDs,
|
||||
"create_time": createTime,
|
||||
}
|
||||
taskByMessageID[digest] = taskMap
|
||||
}
|
||||
return taskByMessageID, nil
|
||||
}
|
||||
|
||||
func memoryMessageKey(value interface{}) string {
|
||||
return strings.TrimSpace(fmt.Sprint(value))
|
||||
}
|
||||
|
||||
// TODO: queryMessages - Implementation pending - depends on CanvasService and TaskService
|
||||
// func (s *MemoryService) queryMessages(tenantID string, memoryID string, filterDict map[string]interface{}, page int, pageSize int) ([]map[string]interface{}, int64, error) { ... }
|
||||
@@ -945,15 +1659,6 @@ func (s *MemoryService) GetMemoryConfig(memoryID string) (*CreateMemoryResponse,
|
||||
// TODO: UpdateMessageStatus - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) UpdateMessageStatus(memoryID string, messageID int, status bool) (bool, error) { ... }
|
||||
|
||||
// TODO: SearchMessage - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) SearchMessage(filterDict map[string]interface{}, params map[string]interface{}) ([]map[string]interface{}, error) { ... }
|
||||
|
||||
// TODO: GetMessages - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) GetMessages(memoryIDs []string, agentID string, sessionID string, limit int) ([]map[string]interface{}, error) { ... }
|
||||
|
||||
// TODO: GetMessageContent - Implementation pending - depends on embedding engine
|
||||
// func (s *MemoryService) GetMessageContent(memoryID string, messageID int) (map[string]interface{}, error) { ... }
|
||||
|
||||
// isList checks if a value is a list or array type
|
||||
// This is a utility function for type validation
|
||||
//
|
||||
|
||||
@@ -4,9 +4,16 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
enginetypes "ragflow/internal/engine/types"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func TestIsMessageDocumentNotFound(t *testing.T) {
|
||||
@@ -28,3 +35,240 @@ func TestRequireMemoryAccessReturnsCanceledContext(t *testing.T) {
|
||||
t.Fatalf("requireMemoryAccess error = %v, want %v", gotErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
type memoryMessageDocEngine struct {
|
||||
fakeChatDocEngine
|
||||
searchReq *enginetypes.SearchRequest
|
||||
searchResp *enginetypes.SearchResult
|
||||
updateCond map[string]interface{}
|
||||
updateValue map[string]interface{}
|
||||
updateBase string
|
||||
updateID string
|
||||
}
|
||||
|
||||
func (e *memoryMessageDocEngine) Search(ctx context.Context, req *enginetypes.SearchRequest) (*enginetypes.SearchResult, error) {
|
||||
e.searchReq = req
|
||||
if e.searchResp != nil {
|
||||
return e.searchResp, nil
|
||||
}
|
||||
return &enginetypes.SearchResult{}, nil
|
||||
}
|
||||
|
||||
func (e *memoryMessageDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error {
|
||||
e.updateCond = condition
|
||||
e.updateValue = newValue
|
||||
e.updateBase = baseName
|
||||
e.updateID = datasetID
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupMemoryMessageTestDB(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{TranslateError: true})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&entity.Memory{}, &entity.UserTenant{}); err != nil {
|
||||
t.Fatalf("failed to migrate memory test tables: %v", err)
|
||||
}
|
||||
|
||||
orig := dao.DB
|
||||
dao.DB = db
|
||||
t.Cleanup(func() {
|
||||
dao.DB = orig
|
||||
})
|
||||
}
|
||||
|
||||
func seedMemoryMessages(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
memories := []*entity.Memory{
|
||||
{
|
||||
ID: "mem-owned",
|
||||
Name: "Owned",
|
||||
TenantID: "user-1",
|
||||
MemoryType: dao.MemoryTypeRaw,
|
||||
StorageType: "table",
|
||||
EmbdID: "embd-1",
|
||||
LLMID: "llm-1",
|
||||
Permissions: string(TenantPermissionMe),
|
||||
ForgettingPolicy: string(ForgettingPolicyFIFO),
|
||||
},
|
||||
{
|
||||
ID: "mem-other",
|
||||
Name: "Other",
|
||||
TenantID: "user-2",
|
||||
MemoryType: dao.MemoryTypeRaw,
|
||||
StorageType: "table",
|
||||
EmbdID: "embd-2",
|
||||
LLMID: "llm-2",
|
||||
Permissions: string(TenantPermissionMe),
|
||||
ForgettingPolicy: string(ForgettingPolicyFIFO),
|
||||
},
|
||||
}
|
||||
for _, memory := range memories {
|
||||
if err := dao.DB.Create(memory).Error; err != nil {
|
||||
t.Fatalf("seed memory %s: %v", memory.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMessagesFiltersAccessibleMemoryAndBuildsRecentSearch(t *testing.T) {
|
||||
setupMemoryMessageTestDB(t)
|
||||
seedMemoryMessages(t)
|
||||
|
||||
docEngine := &memoryMessageDocEngine{
|
||||
searchResp: &enginetypes.SearchResult{
|
||||
Total: 1,
|
||||
Chunks: []map[string]interface{}{
|
||||
{
|
||||
"message_id": int64(12),
|
||||
"message_type": "raw",
|
||||
"memory_id": "mem-owned",
|
||||
"user_id": "user-1",
|
||||
"agent_id": "agent-1",
|
||||
"session_id": "session-1",
|
||||
"valid_at": float64(123),
|
||||
"status": 1,
|
||||
"content": "hello",
|
||||
"extra": "should be dropped",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine}
|
||||
|
||||
got, code, err := svc.GetMessages(context.Background(), []string{"mem-owned", "mem-other"}, "user-1", "agent-1", "session-1", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("GetMessages error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code = %v, want %v", code, common.CodeSuccess)
|
||||
}
|
||||
if len(got) != 1 || got[0]["content"] != "hello" {
|
||||
t.Fatalf("unexpected messages: %+v", got)
|
||||
}
|
||||
if _, ok := got[0]["extra"]; ok {
|
||||
t.Fatalf("unexpected non-selected field in response: %+v", got[0])
|
||||
}
|
||||
|
||||
req := docEngine.searchReq
|
||||
if req == nil {
|
||||
t.Fatal("expected doc engine search request")
|
||||
}
|
||||
if !reflect.DeepEqual(req.IndexNames, []string{"memory_user-1"}) {
|
||||
t.Fatalf("IndexNames = %v, want [memory_user-1]", req.IndexNames)
|
||||
}
|
||||
if len(req.KbIDs) != 0 {
|
||||
t.Fatalf("KbIDs = %v, want empty for memory message search", req.KbIDs)
|
||||
}
|
||||
if !reflect.DeepEqual(req.Filter["memory_id"], []string{"mem-owned"}) {
|
||||
t.Fatalf("memory_id filter = %v, want [mem-owned]", req.Filter["memory_id"])
|
||||
}
|
||||
if req.Filter["agent_id"] != "agent-1" || req.Filter["session_id"] != "session-1" {
|
||||
t.Fatalf("unexpected filter: %+v", req.Filter)
|
||||
}
|
||||
if req.Limit != 3 {
|
||||
t.Fatalf("Limit = %d, want 3", req.Limit)
|
||||
}
|
||||
if req.OrderBy == nil || len(req.OrderBy.Fields) != 1 || req.OrderBy.Fields[0].Field != "valid_at" || req.OrderBy.Fields[0].Type != enginetypes.SortDesc {
|
||||
t.Fatalf("unexpected order by: %+v", req.OrderBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchMessageFiltersAccessibleMemoryAndDefaultsStatus(t *testing.T) {
|
||||
setupMemoryMessageTestDB(t)
|
||||
seedMemoryMessages(t)
|
||||
|
||||
docEngine := &memoryMessageDocEngine{
|
||||
searchResp: &enginetypes.SearchResult{
|
||||
Total: 1,
|
||||
Chunks: []map[string]interface{}{
|
||||
{
|
||||
"message_id": int64(13),
|
||||
"message_type": "raw",
|
||||
"memory_id": "mem-owned",
|
||||
"user_id": "user-1",
|
||||
"agent_id": "agent-1",
|
||||
"session_id": "session-1",
|
||||
"valid_at": int64(456),
|
||||
"status": 1,
|
||||
"content": "matched",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine}
|
||||
filter := map[string]interface{}{
|
||||
"memory_id": []string{"mem-owned", "mem-other"},
|
||||
"agent_id": "agent-1",
|
||||
"session_id": "session-1",
|
||||
"user_id": "user-1",
|
||||
}
|
||||
params := map[string]interface{}{
|
||||
"query": "",
|
||||
"similarity_threshold": 0.2,
|
||||
"keywords_similarity_weight": 0.7,
|
||||
"top_n": 5,
|
||||
}
|
||||
|
||||
got, code, err := svc.SearchMessage(context.Background(), "user-1", filter, params)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchMessage error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code = %v, want %v", code, common.CodeSuccess)
|
||||
}
|
||||
if len(got) != 1 || got[0]["content"] != "matched" {
|
||||
t.Fatalf("unexpected search result: %+v", got)
|
||||
}
|
||||
|
||||
req := docEngine.searchReq
|
||||
if req == nil {
|
||||
t.Fatal("expected doc engine search request")
|
||||
}
|
||||
if !reflect.DeepEqual(req.Filter["memory_id"], []string{"mem-owned"}) {
|
||||
t.Fatalf("memory_id filter = %v, want [mem-owned]", req.Filter["memory_id"])
|
||||
}
|
||||
if req.Filter["status"] != 1 {
|
||||
t.Fatalf("status filter = %v, want 1", req.Filter["status"])
|
||||
}
|
||||
if req.Filter["agent_id"] != "agent-1" || req.Filter["session_id"] != "session-1" || req.Filter["user_id"] != "user-1" {
|
||||
t.Fatalf("unexpected filter: %+v", req.Filter)
|
||||
}
|
||||
if len(req.MatchExprs) != 0 {
|
||||
t.Fatalf("empty query should not build match expressions, got %+v", req.MatchExprs)
|
||||
}
|
||||
if req.Limit != 5 {
|
||||
t.Fatalf("Limit = %d, want 5", req.Limit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageUpdatesStatusByMessageDocID(t *testing.T) {
|
||||
setupMemoryMessageTestDB(t)
|
||||
seedMemoryMessages(t)
|
||||
|
||||
docEngine := &memoryMessageDocEngine{}
|
||||
svc := &MemoryService{memoryDAO: dao.NewMemoryDAO(), docEngine: docEngine}
|
||||
|
||||
ok, err := svc.UpdateMessage(context.Background(), "user-1", "mem-owned", 42, true)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateMessage error: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("UpdateMessage returned false")
|
||||
}
|
||||
if docEngine.updateBase != "memory_user-1" {
|
||||
t.Fatalf("baseName = %q, want memory_user-1", docEngine.updateBase)
|
||||
}
|
||||
if docEngine.updateID != "mem-owned" {
|
||||
t.Fatalf("datasetID = %q, want mem-owned", docEngine.updateID)
|
||||
}
|
||||
if docEngine.updateCond["id"] != "mem-owned_42" {
|
||||
t.Fatalf("condition = %+v, want id mem-owned_42", docEngine.updateCond)
|
||||
}
|
||||
if docEngine.updateValue["status"] != 1 {
|
||||
t.Fatalf("status update = %+v, want status 1", docEngine.updateValue)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user