From 445a13ee9aa620c54512a1bf043b33e5f41152c0 Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Mon, 29 Jun 2026 19:04:59 +0800 Subject: [PATCH] fix: new chat cannot be edit (#16434) ### What problem does this PR solve? As title main fix: ```go if _, ok := req["meta_data_filter"]; !ok || req["meta_data_filter"] == nil { req["meta_data_filter"] = map[string]interface{}{} } ``` ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Refactoring --- internal/dao/chat.go | 29 +++++- internal/dao/chat_session.go | 18 +++- internal/dao/chat_session_test.go | 69 ++++++++++++++- internal/service/chat.go | 38 ++++++-- internal/service/chat_rest_update_test.go | 103 ++++++++++++++++++++++ internal/service/chat_session.go | 17 +++- internal/service/chat_session_test.go | 55 ++++++++++++ 7 files changed, 314 insertions(+), 15 deletions(-) diff --git a/internal/dao/chat.go b/internal/dao/chat.go index e44e338f48..7100318635 100644 --- a/internal/dao/chat.go +++ b/internal/dao/chat.go @@ -18,8 +18,12 @@ package dao import ( "fmt" - "ragflow/internal/entity" "strings" + "time" + + "gorm.io/gorm" + + "ragflow/internal/entity" ) // ChatDAO chat data access object @@ -185,7 +189,28 @@ func (dao *ChatDAO) Create(chat *entity.Chat) error { // UpdateByID updates a chat by ID func (dao *ChatDAO) UpdateByID(id string, updates map[string]interface{}) error { - return DB.Model(&entity.Chat{}).Where("id = ?", id).Updates(updates).Error + if updates == nil { + updates = make(map[string]interface{}) + } + + now := time.Now().Local() + updates["update_time"] = now.UnixMilli() + updates["update_date"] = now.Truncate(time.Second) + + result := DB.Session(&gorm.Session{SkipHooks: true}).Model(&entity.Chat{}).Where("id = ?", id).Updates(updates) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + var count int64 + if err := DB.Model(&entity.Chat{}).Where("id = ?", id).Count(&count).Error; err != nil { + return err + } + if count == 0 { + return gorm.ErrRecordNotFound + } + } + return nil } // UpdateManyByID updates multiple chats by ID (batch update) diff --git a/internal/dao/chat_session.go b/internal/dao/chat_session.go index 940f0b6b1e..21d60ab60b 100644 --- a/internal/dao/chat_session.go +++ b/internal/dao/chat_session.go @@ -76,12 +76,26 @@ func (dao *ChatSessionDAO) Create(conv *entity.ChatSession) error { // UpdateByID updates a chat session by ID func (dao *ChatSessionDAO) UpdateByID(id string, updates map[string]interface{}) error { - result := DB.Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates) + if updates == nil { + updates = make(map[string]interface{}) + } + + now := time.Now().Local() + updates["update_time"] = now.UnixMilli() + updates["update_date"] = now.Truncate(time.Second) + + result := DB.Session(&gorm.Session{SkipHooks: true}).Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { - return gorm.ErrRecordNotFound + var count int64 + if err := DB.Model(&entity.ChatSession{}).Where("id = ?", id).Count(&count).Error; err != nil { + return err + } + if count == 0 { + return gorm.ErrRecordNotFound + } } return nil } diff --git a/internal/dao/chat_session_test.go b/internal/dao/chat_session_test.go index e2fb6ce726..077a5d8dc0 100644 --- a/internal/dao/chat_session_test.go +++ b/internal/dao/chat_session_test.go @@ -18,6 +18,7 @@ package dao import ( "encoding/json" + "errors" "testing" "time" @@ -37,7 +38,7 @@ func setupChatSessionDAOTestDB(t *testing.T) *gorm.DB { t.Fatalf("failed to open sqlite: %v", err) } - if err := db.AutoMigrate(&entity.API4Conversation{}); err != nil { + if err := db.AutoMigrate(&entity.API4Conversation{}, &entity.ChatSession{}); err != nil { t.Fatalf("failed to migrate: %v", err) } @@ -66,6 +67,72 @@ func createAgentSessionForDAOTest(t *testing.T, db *gorm.DB, id, agentID, userID } } +func createChatSessionForDAOTest(t *testing.T, db *gorm.DB, id, chatID, name string, updateTime int64) { + t.Helper() + + updateDate := time.UnixMilli(updateTime).Local() + session := &entity.ChatSession{ + ID: id, + DialogID: chatID, + Name: &name, + Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`), + Reference: json.RawMessage(`[]`), + BaseModel: entity.BaseModel{ + CreateTime: &updateTime, + CreateDate: &updateDate, + UpdateTime: &updateTime, + UpdateDate: &updateDate, + }, + } + if err := db.Create(session).Error; err != nil { + t.Fatalf("failed to create chat session %s: %v", id, err) + } +} + +func TestChatSessionDAOUpdateByIDRefreshesTimestampsOnEmptyUpdate(t *testing.T) { + db := setupChatSessionDAOTestDB(t) + pushDB(t, db) + + oldUpdateTime := int64(1000) + createChatSessionForDAOTest(t, db, "session-1", "chat-1", "same", oldUpdateTime) + + if err := NewChatSessionDAO().UpdateByID("session-1", map[string]interface{}{}); err != nil { + t.Fatalf("UpdateByID failed: %v", err) + } + + session, err := NewChatSessionDAO().GetByID("session-1") + if err != nil { + t.Fatalf("GetByID failed: %v", err) + } + if session.UpdateTime == nil || *session.UpdateTime <= oldUpdateTime { + t.Fatalf("expected update_time to be refreshed, got %v", session.UpdateTime) + } + if session.UpdateDate == nil || !session.UpdateDate.After(time.UnixMilli(oldUpdateTime)) { + t.Fatalf("expected update_date to be refreshed, got %v", session.UpdateDate) + } +} + +func TestChatSessionDAOUpdateByIDSameValueSucceeds(t *testing.T) { + db := setupChatSessionDAOTestDB(t) + pushDB(t, db) + + createChatSessionForDAOTest(t, db, "session-1", "chat-1", "same", 1000) + + if err := NewChatSessionDAO().UpdateByID("session-1", map[string]interface{}{"name": "same"}); err != nil { + t.Fatalf("UpdateByID failed: %v", err) + } +} + +func TestChatSessionDAOUpdateByIDMissingSession(t *testing.T) { + db := setupChatSessionDAOTestDB(t) + pushDB(t, db) + + err := NewChatSessionDAO().UpdateByID("missing", nil) + if !errors.Is(err, gorm.ErrRecordNotFound) { + t.Fatalf("expected ErrRecordNotFound, got %v", err) + } +} + func TestChatSessionDAOListAgentSessionsOrdersByUpdateTimeDesc(t *testing.T) { db := setupChatSessionDAOTestDB(t) pushDB(t, db) diff --git a/internal/service/chat.go b/internal/service/chat.go index 6fa77cb084..4fa939d75c 100644 --- a/internal/service/chat.go +++ b/internal/service/chat.go @@ -214,6 +214,12 @@ func (s *ChatService) Create(userID string, req map[string]interface{}) (map[str } } + if metaDataFilterValue, ok := req["meta_data_filter"]; ok && metaDataFilterValue != nil { + if _, ok := mapFromValue(metaDataFilterValue); !ok { + return nil, common.CodeDataError, errors.New("`meta_data_filter` should be an object.") + } + } + if _, ok := req["kb_ids"]; !ok { req["kb_ids"] = []string{} } @@ -244,6 +250,9 @@ func (s *ChatService) Create(userID string, req map[string]interface{}) (map[str if _, ok := req["icon"]; !ok { req["icon"] = "" } + if _, ok := req["meta_data_filter"]; !ok || req["meta_data_filter"] == nil { + req["meta_data_filter"] = map[string]interface{}{} + } applyCreatePromptDefaults(req) filterCreateChatPersistedFields(req) @@ -481,6 +490,9 @@ func buildCreateChatEntity(req map[string]interface{}, tenantID string) *entity. if metaDataFilter, ok := mapFromValue(req["meta_data_filter"]); ok { metaDataFilterJSON := entity.JSONMap(metaDataFilter) chat.MetaDataFilter = &metaDataFilterJSON + } else { + metaDataFilterJSON := entity.JSONMap{} + chat.MetaDataFilter = &metaDataFilterJSON } return chat } @@ -494,6 +506,7 @@ func (s *ChatService) buildCreateChatResponse(chat *entity.Chat) (map[string]int data["dataset_ids"] = datasetIDs delete(data, "kb_ids") data["kb_names"] = kbNames + data["meta_data_filter"] = normalizeMetaDataFilter(chat.MetaDataFilter) return data, nil } @@ -533,6 +546,13 @@ func mapFromValue(value interface{}) (map[string]interface{}, bool) { } } +func normalizeMetaDataFilter(value *entity.JSONMap) entity.JSONMap { + if value == nil || *value == nil { + return entity.JSONMap{} + } + return *value +} + func listFromValue(value interface{}) ([]interface{}, bool) { switch typed := value.(type) { case nil: @@ -1303,12 +1323,18 @@ func (s *ChatService) updateChatREST(userID, chatID string, req map[string]inter } } - if value, ok := req["meta_data_filter"]; ok && value != nil { - metaDataFilter, ok := mapFromValue(value) - if !ok { - return nil, errors.New("`meta_data_filter` should be an object.") + if value, ok := req["meta_data_filter"]; ok { + if value == nil { + req["meta_data_filter"] = entity.JSONMap{} + } else { + metaDataFilter, ok := mapFromValue(value) + if !ok { + return nil, errors.New("`meta_data_filter` should be an object.") + } + req["meta_data_filter"] = entity.JSONMap(metaDataFilter) } - req["meta_data_filter"] = entity.JSONMap(metaDataFilter) + } else if currentChat.MetaDataFilter == nil || *currentChat.MetaDataFilter == nil { + req["meta_data_filter"] = entity.JSONMap{} } updates := filterRESTChatUpdates(req) @@ -1492,7 +1518,7 @@ func (s *ChatService) buildRESTChatResponse(chat *entity.Chat) map[string]interf "llm_setting": chat.LLMSetting, "prompt_type": chat.PromptType, "prompt_config": chat.PromptConfig, - "meta_data_filter": chat.MetaDataFilter, + "meta_data_filter": normalizeMetaDataFilter(chat.MetaDataFilter), "similarity_threshold": chat.SimilarityThreshold, "vector_similarity_weight": chat.VectorSimilarityWeight, "top_n": chat.TopN, diff --git a/internal/service/chat_rest_update_test.go b/internal/service/chat_rest_update_test.go index 586a184017..a84d762167 100644 --- a/internal/service/chat_rest_update_test.go +++ b/internal/service/chat_rest_update_test.go @@ -6,6 +6,7 @@ import ( "github.com/glebarez/sqlite" "gorm.io/gorm" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" ) @@ -66,6 +67,85 @@ func createChatRESTUpdateServiceTestChat(t *testing.T, db *gorm.DB, id, tenantID } } +func assertEmptyMetaDataFilter(t *testing.T, value interface{}) { + t.Helper() + + switch typed := value.(type) { + case entity.JSONMap: + if len(typed) != 0 { + t.Fatalf("expected empty meta_data_filter, got %+v", typed) + } + case map[string]interface{}: + if len(typed) != 0 { + t.Fatalf("expected empty meta_data_filter, got %+v", typed) + } + default: + t.Fatalf("expected meta_data_filter object, got %T: %+v", value, value) + } +} + +func TestChatServiceCreateDefaultsMetaDataFilter(t *testing.T) { + setupChatRESTUpdateServiceTestDB(t) + + svc := NewChatService() + resp, code, err := svc.Create("user-1", map[string]interface{}{ + "name": "created chat", + }) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("unexpected code: %v", code) + } + assertEmptyMetaDataFilter(t, resp["meta_data_filter"]) + + chatID, ok := resp["id"].(string) + if !ok || chatID == "" { + t.Fatalf("expected created chat id, got %+v", resp["id"]) + } + chat, err := svc.chatDAO.GetByID(chatID) + if err != nil { + t.Fatalf("failed to fetch created chat: %v", err) + } + if chat.MetaDataFilter == nil { + t.Fatal("expected persisted meta_data_filter to be non-nil") + } + assertEmptyMetaDataFilter(t, *chat.MetaDataFilter) +} + +func TestChatServiceCreateAcceptsNilMetaDataFilter(t *testing.T) { + setupChatRESTUpdateServiceTestDB(t) + + svc := NewChatService() + resp, code, err := svc.Create("user-1", map[string]interface{}{ + "name": "created chat", + "meta_data_filter": nil, + }) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("unexpected code: %v", code) + } + assertEmptyMetaDataFilter(t, resp["meta_data_filter"]) +} + +func TestChatServiceCreateRejectsInvalidMetaDataFilter(t *testing.T) { + setupChatRESTUpdateServiceTestDB(t) + + svc := NewChatService() + _, code, err := svc.Create("user-1", map[string]interface{}{ + "name": "created chat", + "meta_data_filter": []interface{}{"invalid"}, + }) + if err == nil || err.Error() != "`meta_data_filter` should be an object." { + t.Fatalf("expected meta_data_filter error, got %v", err) + } + if code != common.CodeDataError { + t.Fatalf("unexpected code: %v", code) + } +} + func TestChatServicePatchChatMergesPromptConfigAndLLMSetting(t *testing.T) { db := setupChatRESTUpdateServiceTestDB(t) createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1") @@ -162,6 +242,29 @@ func TestChatServiceUpdateChatAcceptsMetaDataFilterObject(t *testing.T) { } } +func TestChatServiceUpdateChatBackfillsNilMetaDataFilter(t *testing.T) { + db := setupChatRESTUpdateServiceTestDB(t) + createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1") + + svc := NewChatService() + resp, err := svc.UpdateChat("user-1", "chat-1", map[string]interface{}{ + "name": "chat-chat-1", + }) + if err != nil { + t.Fatalf("UpdateChat failed: %v", err) + } + assertEmptyMetaDataFilter(t, resp["meta_data_filter"]) + + chat, err := svc.chatDAO.GetByID("chat-1") + if err != nil { + t.Fatalf("failed to fetch chat: %v", err) + } + if chat.MetaDataFilter == nil { + t.Fatal("expected meta_data_filter to be backfilled") + } + assertEmptyMetaDataFilter(t, *chat.MetaDataFilter) +} + func TestChatServicePatchChatIgnoresTenantIDAndUpdatesName(t *testing.T) { db := setupChatRESTUpdateServiceTestDB(t) createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1") diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index 718a766409..eb0817cd82 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -661,11 +661,14 @@ func (s *ChatSessionService) buildSessionPayload(session *entity.ChatSession, di } func parseMessages(raw json.RawMessage) []map[string]interface{} { - var messages []map[string]interface{} + messages := make([]map[string]interface{}, 0) if len(raw) == 0 { return messages } if err := json.Unmarshal(raw, &messages); err == nil { + if messages == nil { + return make([]map[string]interface{}, 0) + } return messages } @@ -673,19 +676,25 @@ func parseMessages(raw json.RawMessage) []map[string]interface{} { Messages []map[string]interface{} `json:"messages"` } if err := json.Unmarshal(raw, &wrapped); err != nil { - return nil + return make([]map[string]interface{}, 0) + } + if wrapped.Messages == nil { + return make([]map[string]interface{}, 0) } return wrapped.Messages } func parseReferenceList(raw json.RawMessage) []interface{} { - var references []interface{} + references := make([]interface{}, 0) if len(raw) == 0 { return references } err := json.Unmarshal(raw, &references) if err != nil { - return nil + return make([]interface{}, 0) + } + if references == nil { + return make([]interface{}, 0) } return references } diff --git a/internal/service/chat_session_test.go b/internal/service/chat_session_test.go index ecf006e3fa..0501f0239f 100644 --- a/internal/service/chat_session_test.go +++ b/internal/service/chat_session_test.go @@ -905,6 +905,61 @@ func TestParseMessages_LegacyWrappedObject(t *testing.T) { } } +func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) { + svc := &ChatSessionService{} + payload := svc.buildSessionPayload(&entity.ChatSession{ + ID: "session-1", + DialogID: "chat-1", + Message: nil, + Reference: json.RawMessage(`null`), + }, nil, false) + + if payload.Messages == nil { + t.Fatal("messages is nil") + } + if payload.Reference == nil { + t.Fatal("reference is nil") + } + + body, err := json.Marshal(payload) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if !strings.Contains(string(body), `"messages":[]`) { + t.Fatalf("messages did not encode as empty array: %s", string(body)) + } + if !strings.Contains(string(body), `"reference":[]`) { + t.Fatalf("reference did not encode as empty array: %s", string(body)) + } +} + +func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T) { + messageInputs := []json.RawMessage{ + nil, + json.RawMessage(`null`), + json.RawMessage(`{"messages":null}`), + json.RawMessage(`not-json`), + } + for _, input := range messageInputs { + got := parseMessages(input) + if got == nil || len(got) != 0 { + t.Fatalf("parseMessages(%s)=%#v", string(input), got) + } + } + + referenceInputs := []json.RawMessage{ + nil, + json.RawMessage(`null`), + json.RawMessage(`not-json`), + } + for _, input := range referenceInputs { + got := parseReferenceList(input) + if got == nil || len(got) != 0 { + t.Fatalf("parseReferenceList(%s)=%#v", string(input), got) + } + } +} + func TestCompletionStream_EmptyMessages(t *testing.T) { svc := &ChatSessionService{ chatSessionDAO: &fakeSessionStore{},