mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix: new chat cannot be edit (#16434)
### What problem does this PR solve?
As title
main fix:
```go
if _, ok := req["meta_data_filter"]; !ok || req["meta_data_filter"] == nil {
req["meta_data_filter"] = map[string]interface{}{}
}
```
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
This commit is contained in:
@@ -18,8 +18,12 @@ package dao
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// ChatDAO chat data access object
|
||||
@@ -185,7 +189,28 @@ func (dao *ChatDAO) Create(chat *entity.Chat) error {
|
||||
|
||||
// UpdateByID updates a chat by ID
|
||||
func (dao *ChatDAO) UpdateByID(id string, updates map[string]interface{}) error {
|
||||
return DB.Model(&entity.Chat{}).Where("id = ?", id).Updates(updates).Error
|
||||
if updates == nil {
|
||||
updates = make(map[string]interface{})
|
||||
}
|
||||
|
||||
now := time.Now().Local()
|
||||
updates["update_time"] = now.UnixMilli()
|
||||
updates["update_date"] = now.Truncate(time.Second)
|
||||
|
||||
result := DB.Session(&gorm.Session{SkipHooks: true}).Model(&entity.Chat{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
var count int64
|
||||
if err := DB.Model(&entity.Chat{}).Where("id = ?", id).Count(&count).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateManyByID updates multiple chats by ID (batch update)
|
||||
|
||||
@@ -76,12 +76,26 @@ func (dao *ChatSessionDAO) Create(conv *entity.ChatSession) error {
|
||||
|
||||
// UpdateByID updates a chat session by ID
|
||||
func (dao *ChatSessionDAO) UpdateByID(id string, updates map[string]interface{}) error {
|
||||
result := DB.Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates)
|
||||
if updates == nil {
|
||||
updates = make(map[string]interface{})
|
||||
}
|
||||
|
||||
now := time.Now().Local()
|
||||
updates["update_time"] = now.UnixMilli()
|
||||
updates["update_date"] = now.Truncate(time.Second)
|
||||
|
||||
result := DB.Session(&gorm.Session{SkipHooks: true}).Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
var count int64
|
||||
if err := DB.Model(&entity.ChatSession{}).Where("id = ?", id).Count(&count).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ package dao
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -37,7 +38,7 @@ func setupChatSessionDAOTestDB(t *testing.T) *gorm.DB {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&entity.API4Conversation{}); err != nil {
|
||||
if err := db.AutoMigrate(&entity.API4Conversation{}, &entity.ChatSession{}); err != nil {
|
||||
t.Fatalf("failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
@@ -66,6 +67,72 @@ func createAgentSessionForDAOTest(t *testing.T, db *gorm.DB, id, agentID, userID
|
||||
}
|
||||
}
|
||||
|
||||
func createChatSessionForDAOTest(t *testing.T, db *gorm.DB, id, chatID, name string, updateTime int64) {
|
||||
t.Helper()
|
||||
|
||||
updateDate := time.UnixMilli(updateTime).Local()
|
||||
session := &entity.ChatSession{
|
||||
ID: id,
|
||||
DialogID: chatID,
|
||||
Name: &name,
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &updateTime,
|
||||
CreateDate: &updateDate,
|
||||
UpdateTime: &updateTime,
|
||||
UpdateDate: &updateDate,
|
||||
},
|
||||
}
|
||||
if err := db.Create(session).Error; err != nil {
|
||||
t.Fatalf("failed to create chat session %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatSessionDAOUpdateByIDRefreshesTimestampsOnEmptyUpdate(t *testing.T) {
|
||||
db := setupChatSessionDAOTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
oldUpdateTime := int64(1000)
|
||||
createChatSessionForDAOTest(t, db, "session-1", "chat-1", "same", oldUpdateTime)
|
||||
|
||||
if err := NewChatSessionDAO().UpdateByID("session-1", map[string]interface{}{}); err != nil {
|
||||
t.Fatalf("UpdateByID failed: %v", err)
|
||||
}
|
||||
|
||||
session, err := NewChatSessionDAO().GetByID("session-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID failed: %v", err)
|
||||
}
|
||||
if session.UpdateTime == nil || *session.UpdateTime <= oldUpdateTime {
|
||||
t.Fatalf("expected update_time to be refreshed, got %v", session.UpdateTime)
|
||||
}
|
||||
if session.UpdateDate == nil || !session.UpdateDate.After(time.UnixMilli(oldUpdateTime)) {
|
||||
t.Fatalf("expected update_date to be refreshed, got %v", session.UpdateDate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatSessionDAOUpdateByIDSameValueSucceeds(t *testing.T) {
|
||||
db := setupChatSessionDAOTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
createChatSessionForDAOTest(t, db, "session-1", "chat-1", "same", 1000)
|
||||
|
||||
if err := NewChatSessionDAO().UpdateByID("session-1", map[string]interface{}{"name": "same"}); err != nil {
|
||||
t.Fatalf("UpdateByID failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatSessionDAOUpdateByIDMissingSession(t *testing.T) {
|
||||
db := setupChatSessionDAOTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
err := NewChatSessionDAO().UpdateByID("missing", nil)
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
t.Fatalf("expected ErrRecordNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatSessionDAOListAgentSessionsOrdersByUpdateTimeDesc(t *testing.T) {
|
||||
db := setupChatSessionDAOTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
@@ -214,6 +214,12 @@ func (s *ChatService) Create(userID string, req map[string]interface{}) (map[str
|
||||
}
|
||||
}
|
||||
|
||||
if metaDataFilterValue, ok := req["meta_data_filter"]; ok && metaDataFilterValue != nil {
|
||||
if _, ok := mapFromValue(metaDataFilterValue); !ok {
|
||||
return nil, common.CodeDataError, errors.New("`meta_data_filter` should be an object.")
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := req["kb_ids"]; !ok {
|
||||
req["kb_ids"] = []string{}
|
||||
}
|
||||
@@ -244,6 +250,9 @@ func (s *ChatService) Create(userID string, req map[string]interface{}) (map[str
|
||||
if _, ok := req["icon"]; !ok {
|
||||
req["icon"] = ""
|
||||
}
|
||||
if _, ok := req["meta_data_filter"]; !ok || req["meta_data_filter"] == nil {
|
||||
req["meta_data_filter"] = map[string]interface{}{}
|
||||
}
|
||||
|
||||
applyCreatePromptDefaults(req)
|
||||
filterCreateChatPersistedFields(req)
|
||||
@@ -481,6 +490,9 @@ func buildCreateChatEntity(req map[string]interface{}, tenantID string) *entity.
|
||||
if metaDataFilter, ok := mapFromValue(req["meta_data_filter"]); ok {
|
||||
metaDataFilterJSON := entity.JSONMap(metaDataFilter)
|
||||
chat.MetaDataFilter = &metaDataFilterJSON
|
||||
} else {
|
||||
metaDataFilterJSON := entity.JSONMap{}
|
||||
chat.MetaDataFilter = &metaDataFilterJSON
|
||||
}
|
||||
return chat
|
||||
}
|
||||
@@ -494,6 +506,7 @@ func (s *ChatService) buildCreateChatResponse(chat *entity.Chat) (map[string]int
|
||||
data["dataset_ids"] = datasetIDs
|
||||
delete(data, "kb_ids")
|
||||
data["kb_names"] = kbNames
|
||||
data["meta_data_filter"] = normalizeMetaDataFilter(chat.MetaDataFilter)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -533,6 +546,13 @@ func mapFromValue(value interface{}) (map[string]interface{}, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMetaDataFilter(value *entity.JSONMap) entity.JSONMap {
|
||||
if value == nil || *value == nil {
|
||||
return entity.JSONMap{}
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func listFromValue(value interface{}) ([]interface{}, bool) {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
@@ -1303,12 +1323,18 @@ func (s *ChatService) updateChatREST(userID, chatID string, req map[string]inter
|
||||
}
|
||||
}
|
||||
|
||||
if value, ok := req["meta_data_filter"]; ok && value != nil {
|
||||
metaDataFilter, ok := mapFromValue(value)
|
||||
if !ok {
|
||||
return nil, errors.New("`meta_data_filter` should be an object.")
|
||||
if value, ok := req["meta_data_filter"]; ok {
|
||||
if value == nil {
|
||||
req["meta_data_filter"] = entity.JSONMap{}
|
||||
} else {
|
||||
metaDataFilter, ok := mapFromValue(value)
|
||||
if !ok {
|
||||
return nil, errors.New("`meta_data_filter` should be an object.")
|
||||
}
|
||||
req["meta_data_filter"] = entity.JSONMap(metaDataFilter)
|
||||
}
|
||||
req["meta_data_filter"] = entity.JSONMap(metaDataFilter)
|
||||
} else if currentChat.MetaDataFilter == nil || *currentChat.MetaDataFilter == nil {
|
||||
req["meta_data_filter"] = entity.JSONMap{}
|
||||
}
|
||||
|
||||
updates := filterRESTChatUpdates(req)
|
||||
@@ -1492,7 +1518,7 @@ func (s *ChatService) buildRESTChatResponse(chat *entity.Chat) map[string]interf
|
||||
"llm_setting": chat.LLMSetting,
|
||||
"prompt_type": chat.PromptType,
|
||||
"prompt_config": chat.PromptConfig,
|
||||
"meta_data_filter": chat.MetaDataFilter,
|
||||
"meta_data_filter": normalizeMetaDataFilter(chat.MetaDataFilter),
|
||||
"similarity_threshold": chat.SimilarityThreshold,
|
||||
"vector_similarity_weight": chat.VectorSimilarityWeight,
|
||||
"top_n": chat.TopN,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
@@ -66,6 +67,85 @@ func createChatRESTUpdateServiceTestChat(t *testing.T, db *gorm.DB, id, tenantID
|
||||
}
|
||||
}
|
||||
|
||||
func assertEmptyMetaDataFilter(t *testing.T, value interface{}) {
|
||||
t.Helper()
|
||||
|
||||
switch typed := value.(type) {
|
||||
case entity.JSONMap:
|
||||
if len(typed) != 0 {
|
||||
t.Fatalf("expected empty meta_data_filter, got %+v", typed)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
if len(typed) != 0 {
|
||||
t.Fatalf("expected empty meta_data_filter, got %+v", typed)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("expected meta_data_filter object, got %T: %+v", value, value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatServiceCreateDefaultsMetaDataFilter(t *testing.T) {
|
||||
setupChatRESTUpdateServiceTestDB(t)
|
||||
|
||||
svc := NewChatService()
|
||||
resp, code, err := svc.Create("user-1", map[string]interface{}{
|
||||
"name": "created chat",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("unexpected code: %v", code)
|
||||
}
|
||||
assertEmptyMetaDataFilter(t, resp["meta_data_filter"])
|
||||
|
||||
chatID, ok := resp["id"].(string)
|
||||
if !ok || chatID == "" {
|
||||
t.Fatalf("expected created chat id, got %+v", resp["id"])
|
||||
}
|
||||
chat, err := svc.chatDAO.GetByID(chatID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch created chat: %v", err)
|
||||
}
|
||||
if chat.MetaDataFilter == nil {
|
||||
t.Fatal("expected persisted meta_data_filter to be non-nil")
|
||||
}
|
||||
assertEmptyMetaDataFilter(t, *chat.MetaDataFilter)
|
||||
}
|
||||
|
||||
func TestChatServiceCreateAcceptsNilMetaDataFilter(t *testing.T) {
|
||||
setupChatRESTUpdateServiceTestDB(t)
|
||||
|
||||
svc := NewChatService()
|
||||
resp, code, err := svc.Create("user-1", map[string]interface{}{
|
||||
"name": "created chat",
|
||||
"meta_data_filter": nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("unexpected code: %v", code)
|
||||
}
|
||||
assertEmptyMetaDataFilter(t, resp["meta_data_filter"])
|
||||
}
|
||||
|
||||
func TestChatServiceCreateRejectsInvalidMetaDataFilter(t *testing.T) {
|
||||
setupChatRESTUpdateServiceTestDB(t)
|
||||
|
||||
svc := NewChatService()
|
||||
_, code, err := svc.Create("user-1", map[string]interface{}{
|
||||
"name": "created chat",
|
||||
"meta_data_filter": []interface{}{"invalid"},
|
||||
})
|
||||
if err == nil || err.Error() != "`meta_data_filter` should be an object." {
|
||||
t.Fatalf("expected meta_data_filter error, got %v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("unexpected code: %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatServicePatchChatMergesPromptConfigAndLLMSetting(t *testing.T) {
|
||||
db := setupChatRESTUpdateServiceTestDB(t)
|
||||
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
|
||||
@@ -162,6 +242,29 @@ func TestChatServiceUpdateChatAcceptsMetaDataFilterObject(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatServiceUpdateChatBackfillsNilMetaDataFilter(t *testing.T) {
|
||||
db := setupChatRESTUpdateServiceTestDB(t)
|
||||
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
|
||||
|
||||
svc := NewChatService()
|
||||
resp, err := svc.UpdateChat("user-1", "chat-1", map[string]interface{}{
|
||||
"name": "chat-chat-1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateChat failed: %v", err)
|
||||
}
|
||||
assertEmptyMetaDataFilter(t, resp["meta_data_filter"])
|
||||
|
||||
chat, err := svc.chatDAO.GetByID("chat-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch chat: %v", err)
|
||||
}
|
||||
if chat.MetaDataFilter == nil {
|
||||
t.Fatal("expected meta_data_filter to be backfilled")
|
||||
}
|
||||
assertEmptyMetaDataFilter(t, *chat.MetaDataFilter)
|
||||
}
|
||||
|
||||
func TestChatServicePatchChatIgnoresTenantIDAndUpdatesName(t *testing.T) {
|
||||
db := setupChatRESTUpdateServiceTestDB(t)
|
||||
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
|
||||
|
||||
@@ -661,11 +661,14 @@ func (s *ChatSessionService) buildSessionPayload(session *entity.ChatSession, di
|
||||
}
|
||||
|
||||
func parseMessages(raw json.RawMessage) []map[string]interface{} {
|
||||
var messages []map[string]interface{}
|
||||
messages := make([]map[string]interface{}, 0)
|
||||
if len(raw) == 0 {
|
||||
return messages
|
||||
}
|
||||
if err := json.Unmarshal(raw, &messages); err == nil {
|
||||
if messages == nil {
|
||||
return make([]map[string]interface{}, 0)
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -673,19 +676,25 @@ func parseMessages(raw json.RawMessage) []map[string]interface{} {
|
||||
Messages []map[string]interface{} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &wrapped); err != nil {
|
||||
return nil
|
||||
return make([]map[string]interface{}, 0)
|
||||
}
|
||||
if wrapped.Messages == nil {
|
||||
return make([]map[string]interface{}, 0)
|
||||
}
|
||||
return wrapped.Messages
|
||||
}
|
||||
|
||||
func parseReferenceList(raw json.RawMessage) []interface{} {
|
||||
var references []interface{}
|
||||
references := make([]interface{}, 0)
|
||||
if len(raw) == 0 {
|
||||
return references
|
||||
}
|
||||
err := json.Unmarshal(raw, &references)
|
||||
if err != nil {
|
||||
return nil
|
||||
return make([]interface{}, 0)
|
||||
}
|
||||
if references == nil {
|
||||
return make([]interface{}, 0)
|
||||
}
|
||||
return references
|
||||
}
|
||||
|
||||
@@ -905,6 +905,61 @@ func TestParseMessages_LegacyWrappedObject(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) {
|
||||
svc := &ChatSessionService{}
|
||||
payload := svc.buildSessionPayload(&entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: nil,
|
||||
Reference: json.RawMessage(`null`),
|
||||
}, nil, false)
|
||||
|
||||
if payload.Messages == nil {
|
||||
t.Fatal("messages is nil")
|
||||
}
|
||||
if payload.Reference == nil {
|
||||
t.Fatal("reference is nil")
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(body), `"messages":[]`) {
|
||||
t.Fatalf("messages did not encode as empty array: %s", string(body))
|
||||
}
|
||||
if !strings.Contains(string(body), `"reference":[]`) {
|
||||
t.Fatalf("reference did not encode as empty array: %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T) {
|
||||
messageInputs := []json.RawMessage{
|
||||
nil,
|
||||
json.RawMessage(`null`),
|
||||
json.RawMessage(`{"messages":null}`),
|
||||
json.RawMessage(`not-json`),
|
||||
}
|
||||
for _, input := range messageInputs {
|
||||
got := parseMessages(input)
|
||||
if got == nil || len(got) != 0 {
|
||||
t.Fatalf("parseMessages(%s)=%#v", string(input), got)
|
||||
}
|
||||
}
|
||||
|
||||
referenceInputs := []json.RawMessage{
|
||||
nil,
|
||||
json.RawMessage(`null`),
|
||||
json.RawMessage(`not-json`),
|
||||
}
|
||||
for _, input := range referenceInputs {
|
||||
got := parseReferenceList(input)
|
||||
if got == nil || len(got) != 0 {
|
||||
t.Fatalf("parseReferenceList(%s)=%#v", string(input), got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionStream_EmptyMessages(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: &fakeSessionStore{},
|
||||
|
||||
Reference in New Issue
Block a user