fix: new chat cannot be edit (#16434)

### What problem does this PR solve?

As title
main fix:

```go
if _, ok := req["meta_data_filter"]; !ok || req["meta_data_filter"] == nil {
	req["meta_data_filter"] = map[string]interface{}{}
}
```


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
This commit is contained in:
Haruko386
2026-06-29 19:04:59 +08:00
committed by GitHub
parent 43f75fdfc7
commit 445a13ee9a
7 changed files with 314 additions and 15 deletions

View File

@@ -18,8 +18,12 @@ package dao
import (
"fmt"
"ragflow/internal/entity"
"strings"
"time"
"gorm.io/gorm"
"ragflow/internal/entity"
)
// ChatDAO chat data access object
@@ -185,7 +189,28 @@ func (dao *ChatDAO) Create(chat *entity.Chat) error {
// UpdateByID updates a chat by ID
func (dao *ChatDAO) UpdateByID(id string, updates map[string]interface{}) error {
return DB.Model(&entity.Chat{}).Where("id = ?", id).Updates(updates).Error
if updates == nil {
updates = make(map[string]interface{})
}
now := time.Now().Local()
updates["update_time"] = now.UnixMilli()
updates["update_date"] = now.Truncate(time.Second)
result := DB.Session(&gorm.Session{SkipHooks: true}).Model(&entity.Chat{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
var count int64
if err := DB.Model(&entity.Chat{}).Where("id = ?", id).Count(&count).Error; err != nil {
return err
}
if count == 0 {
return gorm.ErrRecordNotFound
}
}
return nil
}
// UpdateManyByID updates multiple chats by ID (batch update)

View File

@@ -76,12 +76,26 @@ func (dao *ChatSessionDAO) Create(conv *entity.ChatSession) error {
// UpdateByID updates a chat session by ID
func (dao *ChatSessionDAO) UpdateByID(id string, updates map[string]interface{}) error {
result := DB.Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates)
if updates == nil {
updates = make(map[string]interface{})
}
now := time.Now().Local()
updates["update_time"] = now.UnixMilli()
updates["update_date"] = now.Truncate(time.Second)
result := DB.Session(&gorm.Session{SkipHooks: true}).Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
var count int64
if err := DB.Model(&entity.ChatSession{}).Where("id = ?", id).Count(&count).Error; err != nil {
return err
}
if count == 0 {
return gorm.ErrRecordNotFound
}
}
return nil
}

View File

@@ -18,6 +18,7 @@ package dao
import (
"encoding/json"
"errors"
"testing"
"time"
@@ -37,7 +38,7 @@ func setupChatSessionDAOTestDB(t *testing.T) *gorm.DB {
t.Fatalf("failed to open sqlite: %v", err)
}
if err := db.AutoMigrate(&entity.API4Conversation{}); err != nil {
if err := db.AutoMigrate(&entity.API4Conversation{}, &entity.ChatSession{}); err != nil {
t.Fatalf("failed to migrate: %v", err)
}
@@ -66,6 +67,72 @@ func createAgentSessionForDAOTest(t *testing.T, db *gorm.DB, id, agentID, userID
}
}
func createChatSessionForDAOTest(t *testing.T, db *gorm.DB, id, chatID, name string, updateTime int64) {
t.Helper()
updateDate := time.UnixMilli(updateTime).Local()
session := &entity.ChatSession{
ID: id,
DialogID: chatID,
Name: &name,
Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`),
Reference: json.RawMessage(`[]`),
BaseModel: entity.BaseModel{
CreateTime: &updateTime,
CreateDate: &updateDate,
UpdateTime: &updateTime,
UpdateDate: &updateDate,
},
}
if err := db.Create(session).Error; err != nil {
t.Fatalf("failed to create chat session %s: %v", id, err)
}
}
func TestChatSessionDAOUpdateByIDRefreshesTimestampsOnEmptyUpdate(t *testing.T) {
db := setupChatSessionDAOTestDB(t)
pushDB(t, db)
oldUpdateTime := int64(1000)
createChatSessionForDAOTest(t, db, "session-1", "chat-1", "same", oldUpdateTime)
if err := NewChatSessionDAO().UpdateByID("session-1", map[string]interface{}{}); err != nil {
t.Fatalf("UpdateByID failed: %v", err)
}
session, err := NewChatSessionDAO().GetByID("session-1")
if err != nil {
t.Fatalf("GetByID failed: %v", err)
}
if session.UpdateTime == nil || *session.UpdateTime <= oldUpdateTime {
t.Fatalf("expected update_time to be refreshed, got %v", session.UpdateTime)
}
if session.UpdateDate == nil || !session.UpdateDate.After(time.UnixMilli(oldUpdateTime)) {
t.Fatalf("expected update_date to be refreshed, got %v", session.UpdateDate)
}
}
func TestChatSessionDAOUpdateByIDSameValueSucceeds(t *testing.T) {
db := setupChatSessionDAOTestDB(t)
pushDB(t, db)
createChatSessionForDAOTest(t, db, "session-1", "chat-1", "same", 1000)
if err := NewChatSessionDAO().UpdateByID("session-1", map[string]interface{}{"name": "same"}); err != nil {
t.Fatalf("UpdateByID failed: %v", err)
}
}
func TestChatSessionDAOUpdateByIDMissingSession(t *testing.T) {
db := setupChatSessionDAOTestDB(t)
pushDB(t, db)
err := NewChatSessionDAO().UpdateByID("missing", nil)
if !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatalf("expected ErrRecordNotFound, got %v", err)
}
}
func TestChatSessionDAOListAgentSessionsOrdersByUpdateTimeDesc(t *testing.T) {
db := setupChatSessionDAOTestDB(t)
pushDB(t, db)

View File

@@ -214,6 +214,12 @@ func (s *ChatService) Create(userID string, req map[string]interface{}) (map[str
}
}
if metaDataFilterValue, ok := req["meta_data_filter"]; ok && metaDataFilterValue != nil {
if _, ok := mapFromValue(metaDataFilterValue); !ok {
return nil, common.CodeDataError, errors.New("`meta_data_filter` should be an object.")
}
}
if _, ok := req["kb_ids"]; !ok {
req["kb_ids"] = []string{}
}
@@ -244,6 +250,9 @@ func (s *ChatService) Create(userID string, req map[string]interface{}) (map[str
if _, ok := req["icon"]; !ok {
req["icon"] = ""
}
if _, ok := req["meta_data_filter"]; !ok || req["meta_data_filter"] == nil {
req["meta_data_filter"] = map[string]interface{}{}
}
applyCreatePromptDefaults(req)
filterCreateChatPersistedFields(req)
@@ -481,6 +490,9 @@ func buildCreateChatEntity(req map[string]interface{}, tenantID string) *entity.
if metaDataFilter, ok := mapFromValue(req["meta_data_filter"]); ok {
metaDataFilterJSON := entity.JSONMap(metaDataFilter)
chat.MetaDataFilter = &metaDataFilterJSON
} else {
metaDataFilterJSON := entity.JSONMap{}
chat.MetaDataFilter = &metaDataFilterJSON
}
return chat
}
@@ -494,6 +506,7 @@ func (s *ChatService) buildCreateChatResponse(chat *entity.Chat) (map[string]int
data["dataset_ids"] = datasetIDs
delete(data, "kb_ids")
data["kb_names"] = kbNames
data["meta_data_filter"] = normalizeMetaDataFilter(chat.MetaDataFilter)
return data, nil
}
@@ -533,6 +546,13 @@ func mapFromValue(value interface{}) (map[string]interface{}, bool) {
}
}
func normalizeMetaDataFilter(value *entity.JSONMap) entity.JSONMap {
if value == nil || *value == nil {
return entity.JSONMap{}
}
return *value
}
func listFromValue(value interface{}) ([]interface{}, bool) {
switch typed := value.(type) {
case nil:
@@ -1303,12 +1323,18 @@ func (s *ChatService) updateChatREST(userID, chatID string, req map[string]inter
}
}
if value, ok := req["meta_data_filter"]; ok && value != nil {
metaDataFilter, ok := mapFromValue(value)
if !ok {
return nil, errors.New("`meta_data_filter` should be an object.")
if value, ok := req["meta_data_filter"]; ok {
if value == nil {
req["meta_data_filter"] = entity.JSONMap{}
} else {
metaDataFilter, ok := mapFromValue(value)
if !ok {
return nil, errors.New("`meta_data_filter` should be an object.")
}
req["meta_data_filter"] = entity.JSONMap(metaDataFilter)
}
req["meta_data_filter"] = entity.JSONMap(metaDataFilter)
} else if currentChat.MetaDataFilter == nil || *currentChat.MetaDataFilter == nil {
req["meta_data_filter"] = entity.JSONMap{}
}
updates := filterRESTChatUpdates(req)
@@ -1492,7 +1518,7 @@ func (s *ChatService) buildRESTChatResponse(chat *entity.Chat) map[string]interf
"llm_setting": chat.LLMSetting,
"prompt_type": chat.PromptType,
"prompt_config": chat.PromptConfig,
"meta_data_filter": chat.MetaDataFilter,
"meta_data_filter": normalizeMetaDataFilter(chat.MetaDataFilter),
"similarity_threshold": chat.SimilarityThreshold,
"vector_similarity_weight": chat.VectorSimilarityWeight,
"top_n": chat.TopN,

View File

@@ -6,6 +6,7 @@ import (
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
)
@@ -66,6 +67,85 @@ func createChatRESTUpdateServiceTestChat(t *testing.T, db *gorm.DB, id, tenantID
}
}
func assertEmptyMetaDataFilter(t *testing.T, value interface{}) {
t.Helper()
switch typed := value.(type) {
case entity.JSONMap:
if len(typed) != 0 {
t.Fatalf("expected empty meta_data_filter, got %+v", typed)
}
case map[string]interface{}:
if len(typed) != 0 {
t.Fatalf("expected empty meta_data_filter, got %+v", typed)
}
default:
t.Fatalf("expected meta_data_filter object, got %T: %+v", value, value)
}
}
func TestChatServiceCreateDefaultsMetaDataFilter(t *testing.T) {
setupChatRESTUpdateServiceTestDB(t)
svc := NewChatService()
resp, code, err := svc.Create("user-1", map[string]interface{}{
"name": "created chat",
})
if err != nil {
t.Fatalf("Create failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("unexpected code: %v", code)
}
assertEmptyMetaDataFilter(t, resp["meta_data_filter"])
chatID, ok := resp["id"].(string)
if !ok || chatID == "" {
t.Fatalf("expected created chat id, got %+v", resp["id"])
}
chat, err := svc.chatDAO.GetByID(chatID)
if err != nil {
t.Fatalf("failed to fetch created chat: %v", err)
}
if chat.MetaDataFilter == nil {
t.Fatal("expected persisted meta_data_filter to be non-nil")
}
assertEmptyMetaDataFilter(t, *chat.MetaDataFilter)
}
func TestChatServiceCreateAcceptsNilMetaDataFilter(t *testing.T) {
setupChatRESTUpdateServiceTestDB(t)
svc := NewChatService()
resp, code, err := svc.Create("user-1", map[string]interface{}{
"name": "created chat",
"meta_data_filter": nil,
})
if err != nil {
t.Fatalf("Create failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("unexpected code: %v", code)
}
assertEmptyMetaDataFilter(t, resp["meta_data_filter"])
}
func TestChatServiceCreateRejectsInvalidMetaDataFilter(t *testing.T) {
setupChatRESTUpdateServiceTestDB(t)
svc := NewChatService()
_, code, err := svc.Create("user-1", map[string]interface{}{
"name": "created chat",
"meta_data_filter": []interface{}{"invalid"},
})
if err == nil || err.Error() != "`meta_data_filter` should be an object." {
t.Fatalf("expected meta_data_filter error, got %v", err)
}
if code != common.CodeDataError {
t.Fatalf("unexpected code: %v", code)
}
}
func TestChatServicePatchChatMergesPromptConfigAndLLMSetting(t *testing.T) {
db := setupChatRESTUpdateServiceTestDB(t)
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
@@ -162,6 +242,29 @@ func TestChatServiceUpdateChatAcceptsMetaDataFilterObject(t *testing.T) {
}
}
func TestChatServiceUpdateChatBackfillsNilMetaDataFilter(t *testing.T) {
db := setupChatRESTUpdateServiceTestDB(t)
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
svc := NewChatService()
resp, err := svc.UpdateChat("user-1", "chat-1", map[string]interface{}{
"name": "chat-chat-1",
})
if err != nil {
t.Fatalf("UpdateChat failed: %v", err)
}
assertEmptyMetaDataFilter(t, resp["meta_data_filter"])
chat, err := svc.chatDAO.GetByID("chat-1")
if err != nil {
t.Fatalf("failed to fetch chat: %v", err)
}
if chat.MetaDataFilter == nil {
t.Fatal("expected meta_data_filter to be backfilled")
}
assertEmptyMetaDataFilter(t, *chat.MetaDataFilter)
}
func TestChatServicePatchChatIgnoresTenantIDAndUpdatesName(t *testing.T) {
db := setupChatRESTUpdateServiceTestDB(t)
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")

View File

@@ -661,11 +661,14 @@ func (s *ChatSessionService) buildSessionPayload(session *entity.ChatSession, di
}
func parseMessages(raw json.RawMessage) []map[string]interface{} {
var messages []map[string]interface{}
messages := make([]map[string]interface{}, 0)
if len(raw) == 0 {
return messages
}
if err := json.Unmarshal(raw, &messages); err == nil {
if messages == nil {
return make([]map[string]interface{}, 0)
}
return messages
}
@@ -673,19 +676,25 @@ func parseMessages(raw json.RawMessage) []map[string]interface{} {
Messages []map[string]interface{} `json:"messages"`
}
if err := json.Unmarshal(raw, &wrapped); err != nil {
return nil
return make([]map[string]interface{}, 0)
}
if wrapped.Messages == nil {
return make([]map[string]interface{}, 0)
}
return wrapped.Messages
}
func parseReferenceList(raw json.RawMessage) []interface{} {
var references []interface{}
references := make([]interface{}, 0)
if len(raw) == 0 {
return references
}
err := json.Unmarshal(raw, &references)
if err != nil {
return nil
return make([]interface{}, 0)
}
if references == nil {
return make([]interface{}, 0)
}
return references
}

View File

@@ -905,6 +905,61 @@ func TestParseMessages_LegacyWrappedObject(t *testing.T) {
}
}
func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) {
svc := &ChatSessionService{}
payload := svc.buildSessionPayload(&entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: nil,
Reference: json.RawMessage(`null`),
}, nil, false)
if payload.Messages == nil {
t.Fatal("messages is nil")
}
if payload.Reference == nil {
t.Fatal("reference is nil")
}
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if !strings.Contains(string(body), `"messages":[]`) {
t.Fatalf("messages did not encode as empty array: %s", string(body))
}
if !strings.Contains(string(body), `"reference":[]`) {
t.Fatalf("reference did not encode as empty array: %s", string(body))
}
}
func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T) {
messageInputs := []json.RawMessage{
nil,
json.RawMessage(`null`),
json.RawMessage(`{"messages":null}`),
json.RawMessage(`not-json`),
}
for _, input := range messageInputs {
got := parseMessages(input)
if got == nil || len(got) != 0 {
t.Fatalf("parseMessages(%s)=%#v", string(input), got)
}
}
referenceInputs := []json.RawMessage{
nil,
json.RawMessage(`null`),
json.RawMessage(`not-json`),
}
for _, input := range referenceInputs {
got := parseReferenceList(input)
if got == nil || len(got) != 0 {
t.Fatalf("parseReferenceList(%s)=%#v", string(input), got)
}
}
}
func TestCompletionStream_EmptyMessages(t *testing.T) {
svc := &ChatSessionService{
chatSessionDAO: &fakeSessionStore{},