feat(go-api): add chat update endpoints (#16378)

## Summary

- Added Go API route `PUT /api/v1/chats/:chat_id` to align with Python
`PUT /api/v1/chats/<chat_id>` chat update behavior.
- Added Go API route `PATCH /api/v1/chats/:chat_id` to align with Python
`PATCH /api/v1/chats/<chat_id>` partial chat update behavior.
- Added matching handler and service logic for owner checks, tenant
validation, persisted-field filtering, read-only field filtering,
`dataset_ids` to `kb_ids` conversion, and PATCH shallow merge semantics
for `prompt_config` and `llm_setting`.
This commit is contained in:
Hz_
2026-06-26 19:22:57 +08:00
committed by GitHub
parent a1f1dd5007
commit e3063da390
5 changed files with 692 additions and 1 deletions

View File

@@ -503,3 +503,57 @@ func (h *ChatHandler) GetChat(c *gin.Context) {
"message": "success",
})
}
// UpdateChat updates a chat by ID using REST PUT semantics.
func (h *ChatHandler) UpdateChat(c *gin.Context) {
h.updateChatByMethod(c, false)
}
// PatchChat updates a chat by ID using REST PATCH semantics.
func (h *ChatHandler) PatchChat(c *gin.Context) {
h.updateChatByMethod(c, true)
}
func (h *ChatHandler) updateChatByMethod(c *gin.Context, patch bool) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
chatID := c.Param("chat_id")
if chatID == "" {
jsonError(c, common.CodeBadRequest, "chat_id is required")
return
}
var req map[string]interface{}
if err := c.ShouldBindJSON(&req); err != nil {
jsonError(c, common.CodeDataError, err.Error())
return
}
var (
result map[string]interface{}
err error
)
if patch {
result, err = h.chatService.PatchChat(user.ID, chatID, req)
} else {
result, err = h.chatService.UpdateChat(user.ID, chatID, req)
}
if err != nil {
if err.Error() == "no authorization" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeAuthenticationError,
"data": false,
"message": "No authorization.",
})
return
}
jsonError(c, common.CodeDataError, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}

View File

@@ -25,7 +25,7 @@ func setupChatHandlerTestDB(t *testing.T) *gorm.DB {
t.Fatalf("failed to open sqlite: %v", err)
}
if err := db.AutoMigrate(&entity.Chat{}); err != nil {
if err := db.AutoMigrate(&entity.Chat{}, &entity.Tenant{}); err != nil {
t.Fatalf("failed to migrate test schema: %v", err)
}
@@ -33,6 +33,20 @@ func setupChatHandlerTestDB(t *testing.T) *gorm.DB {
dao.DB = db
t.Cleanup(func() { dao.DB = origDB })
status := string(entity.StatusValid)
if err := db.Create(&entity.Tenant{
ID: "user-1",
LLMID: "model-a",
EmbdID: "embd-a",
ASRID: "asr-a",
Img2TxtID: "img2txt-a",
RerankID: "rerank-a",
ParserIDs: "naive",
Status: &status,
}).Error; err != nil {
t.Fatalf("failed to create tenant: %v", err)
}
return db
}
@@ -120,3 +134,68 @@ func TestBulkDeleteChatsHandlerPartialSuccess(t *testing.T) {
t.Fatalf("unexpected message: %v", resp["message"])
}
}
func TestPatchChatHandlerSuccess(t *testing.T) {
db := setupChatHandlerTestDB(t)
createChatHandlerTestChat(t, db, "chat-1", "user-1")
h := NewChatHandler(service.NewChatService(), service.NewUserService())
c, w := setupGinContextWithUser("PATCH", "/api/v1/chats/chat-1", `{"name":" updated chat "}`)
c.Params = []gin.Param{{Key: "chat_id", Value: "chat-1"}}
h.PatchChat(c)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Fatalf("expected success code, got %v", resp["code"])
}
data, ok := resp["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected object data, got %+v", resp["data"])
}
if data["name"] != "updated chat" {
t.Fatalf("expected trimmed name in response, got %+v", data["name"])
}
if _, ok := data["kb_ids"]; ok {
t.Fatalf("response must not expose kb_ids: %+v", data)
}
if _, ok := data["dataset_ids"]; !ok {
t.Fatalf("response should expose dataset_ids: %+v", data)
}
}
func TestUpdateChatHandlerRejectsNonOwner(t *testing.T) {
db := setupChatHandlerTestDB(t)
createChatHandlerTestChat(t, db, "chat-1", "tenant-2")
h := NewChatHandler(service.NewChatService(), service.NewUserService())
c, w := setupGinContextWithUser("PUT", "/api/v1/chats/chat-1", `{"name":"updated"}`)
c.Params = []gin.Param{{Key: "chat_id", Value: "chat-1"}}
h.UpdateChat(c)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if resp["code"] != float64(common.CodeAuthenticationError) {
t.Fatalf("expected auth error code, got %v", resp["code"])
}
if resp["data"] != false {
t.Fatalf("expected data=false, got %v", resp["data"])
}
if resp["message"] != "No authorization." {
t.Fatalf("unexpected message: %v", resp["message"])
}
}

View File

@@ -256,6 +256,8 @@ func (r *Router) Setup(engine *gin.Engine) {
chats.DELETE("", r.chatHandler.BulkDeleteChats)
chats.DELETE("/:chat_id", r.chatHandler.DeleteChat)
chats.GET("/:chat_id", r.chatHandler.GetChat)
chats.PUT("/:chat_id", r.chatHandler.UpdateChat)
chats.PATCH("/:chat_id", r.chatHandler.PatchChat)
chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions)
chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession)
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)

View File

@@ -21,6 +21,7 @@ import (
"fmt"
"ragflow/internal/common"
"ragflow/internal/entity"
"reflect"
"strings"
"unicode/utf8"
@@ -657,6 +658,373 @@ func (s *ChatService) getOwnedValidChat(userID, chatID string) (*entity.Chat, er
return chat, nil
}
var chatPersistedFields = map[string]struct{}{
"name": {},
"description": {},
"icon": {},
"language": {},
"llm_id": {},
"tenant_llm_id": {},
"llm_setting": {},
"prompt_type": {},
"prompt_config": {},
"meta_data_filter": {},
"similarity_threshold": {},
"vector_similarity_weight": {},
"top_n": {},
"top_k": {},
"do_refer": {},
"rerank_id": {},
"tenant_rerank_id": {},
"kb_ids": {},
"status": {},
}
var chatReadonlyFields = map[string]struct{}{
"id": {},
"tenant_id": {},
"created_by": {},
"create_time": {},
"create_date": {},
"update_time": {},
"update_date": {},
}
var defaultRerankModels = map[string]struct{}{
"BAAI/bge-reranker-v2-m3": {},
"maidalun1020/bce-reranker-base_v1": {},
}
// UpdateChat mirrors PUT /api/v1/chats/<chat_id> in the Python REST API.
func (s *ChatService) UpdateChat(userID, chatID string, req map[string]interface{}) (map[string]interface{}, error) {
return s.updateChatREST(userID, chatID, req, false)
}
// PatchChat mirrors PATCH /api/v1/chats/<chat_id> in the Python REST API.
func (s *ChatService) PatchChat(userID, chatID string, req map[string]interface{}) (map[string]interface{}, error) {
return s.updateChatREST(userID, chatID, req, true)
}
func (s *ChatService) updateChatREST(userID, chatID string, req map[string]interface{}, patch bool) (map[string]interface{}, error) {
currentChat, err := s.getOwnedValidChat(userID, chatID)
if err != nil {
return nil, err
}
if _, err := s.tenantDAO.GetByID(userID); err != nil {
return nil, errors.New("Tenant not found!")
}
if !patch && isTruthy(req["tenant_id"]) {
return nil, errors.New("`tenant_id` must not be provided.")
}
if value, ok := req["name"]; ok {
name, shouldSet, err := validateRESTChatName(value, !patch)
if err != nil {
return nil, err
}
if shouldSet {
req["name"] = name
} else {
delete(req, "name")
}
}
if value, ok := req["dataset_ids"]; ok {
kbIDs, err := s.validateRESTDatasetIDs(value, userID)
if err != nil {
return nil, err
}
req["kb_ids"] = kbIDs
delete(req, "dataset_ids")
}
var llmSetting map[string]interface{}
llmSettingProvided := false
if value, ok := req["llm_setting"]; ok {
llmSettingProvided = true
setting, ok := mapFromValue(value)
if !ok {
return nil, errors.New("`llm_setting` should be an object.")
}
llmSetting = setting
}
if value, ok := req["llm_id"]; ok {
llmID := fmt.Sprint(value)
if err := s.validateRESTLLMID(llmID, userID, llmSetting); err != nil {
return nil, err
}
}
if value, ok := req["rerank_id"]; ok {
rerankID := fmt.Sprint(value)
if err := s.validateRESTRerankID(rerankID, userID); err != nil {
return nil, err
}
}
if value, ok := req["prompt_config"]; ok {
promptConfig, ok := mapFromValue(value)
if !ok {
return nil, errors.New("`prompt_config` should be an object.")
}
if patch {
req["prompt_config"] = mergeJSONMap(currentChat.PromptConfig, promptConfig)
} else {
req["prompt_config"] = entity.JSONMap(promptConfig)
}
}
if llmSettingProvided {
if patch {
req["llm_setting"] = mergeJSONMap(currentChat.LLMSetting, llmSetting)
} else {
req["llm_setting"] = entity.JSONMap(llmSetting)
}
}
if value, ok := req["meta_data_filter"]; ok && value != nil {
metaDataFilter, ok := mapFromValue(value)
if !ok {
return nil, errors.New("`meta_data_filter` should be an object.")
}
req["meta_data_filter"] = entity.JSONMap(metaDataFilter)
}
updates := filterRESTChatUpdates(req)
if value, ok := updates["name"]; ok {
name := value.(string)
currentName := ""
if currentChat.Name != nil {
currentName = *currentChat.Name
}
if strings.ToLower(name) != strings.ToLower(currentName) {
existingNames, err := s.chatDAO.GetExistingNames(userID, string(entity.StatusValid))
if err != nil {
return nil, err
}
for _, existingName := range existingNames {
if existingName == name {
return nil, errors.New("Duplicated chat name.")
}
}
}
}
if len(updates) > 0 {
if err := s.chatDAO.UpdateByID(chatID, updates); err != nil {
if patch {
return nil, errors.New("Failed to update chat.")
}
return nil, errors.New("Chat not found!")
}
}
updatedChat, err := s.chatDAO.GetByID(chatID)
if err != nil {
return nil, errors.New("Failed to retrieve updated chat.")
}
return s.buildRESTChatResponse(updatedChat), nil
}
func validateRESTChatName(value interface{}, required bool) (string, bool, error) {
if value == nil {
if required {
return "", false, errors.New("`name` is required.")
}
return "", false, nil
}
name, ok := value.(string)
if !ok {
return "", false, errors.New("Chat name must be a string.")
}
name = strings.TrimSpace(name)
if name == "" {
if required {
return "", false, errors.New("`name` is required.")
}
return "", false, errors.New("`name` cannot be empty.")
}
if len([]byte(name)) > 255 {
return "", false, fmt.Errorf("Chat name length is %d which is larger than 255.", len([]byte(name)))
}
return name, true, nil
}
func (s *ChatService) validateRESTDatasetIDs(value interface{}, userID string) (entity.JSONSlice, error) {
if value == nil {
return entity.JSONSlice{}, nil
}
items, ok := value.([]interface{})
if !ok {
return nil, errors.New("`dataset_ids` should be a list.")
}
var kbs []*entity.Knowledgebase
kbIDs := make(entity.JSONSlice, 0, len(items))
for _, item := range items {
if !isTruthy(item) {
continue
}
datasetID := fmt.Sprint(item)
if !s.kbDAO.Accessible(datasetID, userID) {
return nil, fmt.Errorf("You don't own the dataset %s", datasetID)
}
kb, err := s.kbDAO.GetByID(datasetID)
if err != nil || kb == nil {
return nil, fmt.Errorf("You don't own the dataset %s", datasetID)
}
if kb.ChunkNum == 0 {
return nil, fmt.Errorf("The dataset %s doesn't own parsed file", datasetID)
}
kbs = append(kbs, kb)
kbIDs = append(kbIDs, datasetID)
}
embdIDs := make([]string, 0, len(kbs))
seenEmbdIDs := make(map[string]struct{})
for _, kb := range kbs {
embdIDs = append(embdIDs, kb.EmbdID)
seenEmbdIDs[s.splitModelNameAndFactory(kb.EmbdID)] = struct{}{}
}
if len(seenEmbdIDs) > 1 {
return nil, fmt.Errorf("Datasets use different embedding models: %v", embdIDs)
}
return kbIDs, nil
}
func (s *ChatService) validateRESTLLMID(llmID, tenantID string, llmSetting map[string]interface{}) error {
if llmID == "" {
return nil
}
modelType := entity.ModelTypeChat
if rawModelType, ok := llmSetting["model_type"]; ok {
switch typedModelType := rawModelType.(type) {
case string:
if typedModelType == string(entity.ModelTypeImage2Text) {
modelType = entity.ModelTypeImage2Text
}
case []interface{}:
for _, item := range typedModelType {
if fmt.Sprint(item) == string(entity.ModelTypeImage2Text) {
modelType = entity.ModelTypeImage2Text
break
}
}
}
}
if _, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, modelType, llmID); err != nil {
return fmt.Errorf("`llm_id` %s doesn't exist", llmID)
}
return nil
}
func (s *ChatService) validateRESTRerankID(rerankID, tenantID string) error {
if rerankID == "" {
return nil
}
baseName := s.splitModelNameAndFactory(rerankID)
if _, ok := defaultRerankModels[baseName]; ok {
return nil
}
if _, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, entity.ModelTypeRerank, rerankID); err != nil {
return fmt.Errorf("`rerank_id` %s doesn't exist", rerankID)
}
return nil
}
func filterRESTChatUpdates(req map[string]interface{}) map[string]interface{} {
updates := make(map[string]interface{})
for field, value := range req {
if _, ok := chatPersistedFields[field]; !ok {
continue
}
if _, ok := chatReadonlyFields[field]; ok {
continue
}
updates[field] = value
}
return updates
}
func mapFromValue(value interface{}) (map[string]interface{}, bool) {
if value == nil {
return nil, false
}
switch typedValue := value.(type) {
case map[string]interface{}:
return typedValue, true
case entity.JSONMap:
return map[string]interface{}(typedValue), true
default:
return nil, false
}
}
func mergeJSONMap(base entity.JSONMap, patch map[string]interface{}) entity.JSONMap {
merged := entity.JSONMap{}
for key, value := range base {
merged[key] = value
}
for key, value := range patch {
merged[key] = value
}
return merged
}
func isTruthy(value interface{}) bool {
if value == nil {
return false
}
switch typedValue := value.(type) {
case bool:
return typedValue
case string:
return typedValue != ""
case int, int8, int16, int32, int64:
return reflect.ValueOf(value).Int() != 0
case uint, uint8, uint16, uint32, uint64:
return reflect.ValueOf(value).Uint() != 0
case float32, float64:
return reflect.ValueOf(value).Float() != 0
default:
return true
}
}
func (s *ChatService) buildRESTChatResponse(chat *entity.Chat) map[string]interface{} {
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
return map[string]interface{}{
"id": chat.ID,
"tenant_id": chat.TenantID,
"name": chat.Name,
"description": chat.Description,
"icon": chat.Icon,
"language": chat.Language,
"llm_id": chat.LLMID,
"tenant_llm_id": chat.TenantLLMID,
"llm_setting": chat.LLMSetting,
"prompt_type": chat.PromptType,
"prompt_config": chat.PromptConfig,
"meta_data_filter": chat.MetaDataFilter,
"similarity_threshold": chat.SimilarityThreshold,
"vector_similarity_weight": chat.VectorSimilarityWeight,
"top_n": chat.TopN,
"top_k": chat.TopK,
"do_refer": chat.DoRefer,
"rerank_id": chat.RerankID,
"tenant_rerank_id": chat.TenantRerankID,
"dataset_ids": datasetIDs,
"kb_names": kbNames,
"status": chat.Status,
"create_time": chat.CreateTime,
"create_date": chat.CreateDate,
"update_time": chat.UpdateTime,
"update_date": chat.UpdateDate,
}
}
// DeleteChat soft deletes a single chat owned by the current user.
func (s *ChatService) DeleteChat(userID, chatID string) error {
if _, err := s.getOwnedValidChat(userID, chatID); err != nil {

View File

@@ -0,0 +1,188 @@
package service
import (
"testing"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"ragflow/internal/dao"
"ragflow/internal/entity"
)
func setupChatRESTUpdateServiceTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
TranslateError: true,
})
if err != nil {
t.Fatalf("failed to open sqlite: %v", err)
}
if err := db.AutoMigrate(&entity.Chat{}, &entity.Tenant{}, &entity.Knowledgebase{}, &entity.UserTenant{}); err != nil {
t.Fatalf("failed to migrate test schema: %v", err)
}
origDB := dao.DB
dao.DB = db
t.Cleanup(func() { dao.DB = origDB })
status := string(entity.StatusValid)
if err := db.Create(&entity.Tenant{
ID: "user-1",
LLMID: "model-a",
EmbdID: "embd-a",
ASRID: "asr-a",
Img2TxtID: "img2txt-a",
RerankID: "rerank-a",
ParserIDs: "naive",
Status: &status,
}).Error; err != nil {
t.Fatalf("failed to create tenant: %v", err)
}
return db
}
func createChatRESTUpdateServiceTestChat(t *testing.T, db *gorm.DB, id, tenantID string) {
t.Helper()
name := "chat-" + id
status := string(entity.StatusValid)
chat := &entity.Chat{
ID: id,
TenantID: tenantID,
Name: &name,
LLMID: "model-a",
LLMSetting: entity.JSONMap{"temperature": float64(0.1), "top_p": float64(0.9)},
PromptType: "simple",
PromptConfig: entity.JSONMap{"system": "old system", "quote": true},
KBIDs: entity.JSONSlice{},
Status: &status,
}
if err := db.Create(chat).Error; err != nil {
t.Fatalf("failed to create chat: %v", err)
}
}
func TestChatServicePatchChatMergesPromptConfigAndLLMSetting(t *testing.T) {
db := setupChatRESTUpdateServiceTestDB(t)
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
svc := NewChatService()
resp, err := svc.PatchChat("user-1", "chat-1", map[string]interface{}{
"prompt_config": map[string]interface{}{"quote": false},
"llm_setting": map[string]interface{}{"temperature": float64(0.2)},
})
if err != nil {
t.Fatalf("PatchChat failed: %v", err)
}
if _, ok := resp["kb_ids"]; ok {
t.Fatalf("response must not expose kb_ids: %+v", resp)
}
if _, ok := resp["dataset_ids"]; !ok {
t.Fatalf("response should expose dataset_ids: %+v", resp)
}
chat, err := svc.chatDAO.GetByID("chat-1")
if err != nil {
t.Fatalf("failed to fetch updated chat: %v", err)
}
if chat.PromptConfig["system"] != "old system" {
t.Fatalf("expected prompt_config.system to be preserved, got %+v", chat.PromptConfig)
}
if chat.PromptConfig["quote"] != false {
t.Fatalf("expected prompt_config.quote to be patched, got %+v", chat.PromptConfig)
}
if chat.LLMSetting["top_p"] != float64(0.9) {
t.Fatalf("expected llm_setting.top_p to be preserved, got %+v", chat.LLMSetting)
}
if chat.LLMSetting["temperature"] != float64(0.2) {
t.Fatalf("expected llm_setting.temperature to be patched, got %+v", chat.LLMSetting)
}
}
func TestChatServiceUpdateChatRejectsTenantID(t *testing.T) {
db := setupChatRESTUpdateServiceTestDB(t)
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
svc := NewChatService()
_, err := svc.UpdateChat("user-1", "chat-1", map[string]interface{}{
"tenant_id": "tenant-2",
})
if err == nil || err.Error() != "`tenant_id` must not be provided." {
t.Fatalf("expected tenant_id error, got %v", err)
}
}
func TestChatServiceUpdateChatRejectsInvalidLLMSetting(t *testing.T) {
db := setupChatRESTUpdateServiceTestDB(t)
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
svc := NewChatService()
_, err := svc.UpdateChat("user-1", "chat-1", map[string]interface{}{
"llm_setting": "invalid",
})
if err == nil || err.Error() != "`llm_setting` should be an object." {
t.Fatalf("expected llm_setting error, got %v", err)
}
chat, err := svc.chatDAO.GetByID("chat-1")
if err != nil {
t.Fatalf("failed to fetch chat: %v", err)
}
if chat.LLMSetting["temperature"] != float64(0.1) {
t.Fatalf("expected llm_setting to remain unchanged, got %+v", chat.LLMSetting)
}
}
func TestChatServiceUpdateChatAcceptsMetaDataFilterObject(t *testing.T) {
db := setupChatRESTUpdateServiceTestDB(t)
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
svc := NewChatService()
_, err := svc.UpdateChat("user-1", "chat-1", map[string]interface{}{
"name": "chat-chat-1",
"meta_data_filter": map[string]interface{}{
"method": "disabled",
"manual": []interface{}{},
},
})
if err != nil {
t.Fatalf("UpdateChat failed: %v", err)
}
chat, err := svc.chatDAO.GetByID("chat-1")
if err != nil {
t.Fatalf("failed to fetch chat: %v", err)
}
if chat.MetaDataFilter == nil || (*chat.MetaDataFilter)["method"] != "disabled" {
t.Fatalf("expected meta_data_filter to be persisted, got %+v", chat.MetaDataFilter)
}
}
func TestChatServicePatchChatIgnoresTenantIDAndUpdatesName(t *testing.T) {
db := setupChatRESTUpdateServiceTestDB(t)
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
svc := NewChatService()
_, err := svc.PatchChat("user-1", "chat-1", map[string]interface{}{
"tenant_id": "tenant-2",
"name": " renamed chat ",
})
if err != nil {
t.Fatalf("PatchChat failed: %v", err)
}
chat, err := svc.chatDAO.GetByID("chat-1")
if err != nil {
t.Fatalf("failed to fetch updated chat: %v", err)
}
if chat.TenantID != "user-1" {
t.Fatalf("expected tenant_id to remain user-1, got %s", chat.TenantID)
}
if chat.Name == nil || *chat.Name != "renamed chat" {
t.Fatalf("expected trimmed name, got %+v", chat.Name)
}
}