mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat(go-api): add chat update endpoints (#16378)
## Summary - Added Go API route `PUT /api/v1/chats/:chat_id` to align with Python `PUT /api/v1/chats/<chat_id>` chat update behavior. - Added Go API route `PATCH /api/v1/chats/:chat_id` to align with Python `PATCH /api/v1/chats/<chat_id>` partial chat update behavior. - Added matching handler and service logic for owner checks, tenant validation, persisted-field filtering, read-only field filtering, `dataset_ids` to `kb_ids` conversion, and PATCH shallow merge semantics for `prompt_config` and `llm_setting`.
This commit is contained in:
@@ -503,3 +503,57 @@ func (h *ChatHandler) GetChat(c *gin.Context) {
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateChat updates a chat by ID using REST PUT semantics.
|
||||
func (h *ChatHandler) UpdateChat(c *gin.Context) {
|
||||
h.updateChatByMethod(c, false)
|
||||
}
|
||||
|
||||
// PatchChat updates a chat by ID using REST PATCH semantics.
|
||||
func (h *ChatHandler) PatchChat(c *gin.Context) {
|
||||
h.updateChatByMethod(c, true)
|
||||
}
|
||||
|
||||
func (h *ChatHandler) updateChatByMethod(c *gin.Context, patch bool) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
chatID := c.Param("chat_id")
|
||||
if chatID == "" {
|
||||
jsonError(c, common.CodeBadRequest, "chat_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
result map[string]interface{}
|
||||
err error
|
||||
)
|
||||
if patch {
|
||||
result, err = h.chatService.PatchChat(user.ID, chatID, req)
|
||||
} else {
|
||||
result, err = h.chatService.UpdateChat(user.ID, chatID, req)
|
||||
}
|
||||
if err != nil {
|
||||
if err.Error() == "no authorization" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeAuthenticationError,
|
||||
"data": false,
|
||||
"message": "No authorization.",
|
||||
})
|
||||
return
|
||||
}
|
||||
jsonError(c, common.CodeDataError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func setupChatHandlerTestDB(t *testing.T) *gorm.DB {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&entity.Chat{}); err != nil {
|
||||
if err := db.AutoMigrate(&entity.Chat{}, &entity.Tenant{}); err != nil {
|
||||
t.Fatalf("failed to migrate test schema: %v", err)
|
||||
}
|
||||
|
||||
@@ -33,6 +33,20 @@ func setupChatHandlerTestDB(t *testing.T) *gorm.DB {
|
||||
dao.DB = db
|
||||
t.Cleanup(func() { dao.DB = origDB })
|
||||
|
||||
status := string(entity.StatusValid)
|
||||
if err := db.Create(&entity.Tenant{
|
||||
ID: "user-1",
|
||||
LLMID: "model-a",
|
||||
EmbdID: "embd-a",
|
||||
ASRID: "asr-a",
|
||||
Img2TxtID: "img2txt-a",
|
||||
RerankID: "rerank-a",
|
||||
ParserIDs: "naive",
|
||||
Status: &status,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to create tenant: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
@@ -120,3 +134,68 @@ func TestBulkDeleteChatsHandlerPartialSuccess(t *testing.T) {
|
||||
t.Fatalf("unexpected message: %v", resp["message"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchChatHandlerSuccess(t *testing.T) {
|
||||
db := setupChatHandlerTestDB(t)
|
||||
createChatHandlerTestChat(t, db, "chat-1", "user-1")
|
||||
|
||||
h := NewChatHandler(service.NewChatService(), service.NewUserService())
|
||||
c, w := setupGinContextWithUser("PATCH", "/api/v1/chats/chat-1", `{"name":" updated chat "}`)
|
||||
c.Params = []gin.Param{{Key: "chat_id", Value: "chat-1"}}
|
||||
|
||||
h.PatchChat(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected success code, got %v", resp["code"])
|
||||
}
|
||||
data, ok := resp["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected object data, got %+v", resp["data"])
|
||||
}
|
||||
if data["name"] != "updated chat" {
|
||||
t.Fatalf("expected trimmed name in response, got %+v", data["name"])
|
||||
}
|
||||
if _, ok := data["kb_ids"]; ok {
|
||||
t.Fatalf("response must not expose kb_ids: %+v", data)
|
||||
}
|
||||
if _, ok := data["dataset_ids"]; !ok {
|
||||
t.Fatalf("response should expose dataset_ids: %+v", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateChatHandlerRejectsNonOwner(t *testing.T) {
|
||||
db := setupChatHandlerTestDB(t)
|
||||
createChatHandlerTestChat(t, db, "chat-1", "tenant-2")
|
||||
|
||||
h := NewChatHandler(service.NewChatService(), service.NewUserService())
|
||||
c, w := setupGinContextWithUser("PUT", "/api/v1/chats/chat-1", `{"name":"updated"}`)
|
||||
c.Params = []gin.Param{{Key: "chat_id", Value: "chat-1"}}
|
||||
|
||||
h.UpdateChat(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeAuthenticationError) {
|
||||
t.Fatalf("expected auth error code, got %v", resp["code"])
|
||||
}
|
||||
if resp["data"] != false {
|
||||
t.Fatalf("expected data=false, got %v", resp["data"])
|
||||
}
|
||||
if resp["message"] != "No authorization." {
|
||||
t.Fatalf("unexpected message: %v", resp["message"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,6 +256,8 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
chats.DELETE("", r.chatHandler.BulkDeleteChats)
|
||||
chats.DELETE("/:chat_id", r.chatHandler.DeleteChat)
|
||||
chats.GET("/:chat_id", r.chatHandler.GetChat)
|
||||
chats.PUT("/:chat_id", r.chatHandler.UpdateChat)
|
||||
chats.PATCH("/:chat_id", r.chatHandler.PatchChat)
|
||||
chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions)
|
||||
chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession)
|
||||
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
"reflect"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -657,6 +658,373 @@ func (s *ChatService) getOwnedValidChat(userID, chatID string) (*entity.Chat, er
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
var chatPersistedFields = map[string]struct{}{
|
||||
"name": {},
|
||||
"description": {},
|
||||
"icon": {},
|
||||
"language": {},
|
||||
"llm_id": {},
|
||||
"tenant_llm_id": {},
|
||||
"llm_setting": {},
|
||||
"prompt_type": {},
|
||||
"prompt_config": {},
|
||||
"meta_data_filter": {},
|
||||
"similarity_threshold": {},
|
||||
"vector_similarity_weight": {},
|
||||
"top_n": {},
|
||||
"top_k": {},
|
||||
"do_refer": {},
|
||||
"rerank_id": {},
|
||||
"tenant_rerank_id": {},
|
||||
"kb_ids": {},
|
||||
"status": {},
|
||||
}
|
||||
|
||||
var chatReadonlyFields = map[string]struct{}{
|
||||
"id": {},
|
||||
"tenant_id": {},
|
||||
"created_by": {},
|
||||
"create_time": {},
|
||||
"create_date": {},
|
||||
"update_time": {},
|
||||
"update_date": {},
|
||||
}
|
||||
|
||||
var defaultRerankModels = map[string]struct{}{
|
||||
"BAAI/bge-reranker-v2-m3": {},
|
||||
"maidalun1020/bce-reranker-base_v1": {},
|
||||
}
|
||||
|
||||
// UpdateChat mirrors PUT /api/v1/chats/<chat_id> in the Python REST API.
|
||||
func (s *ChatService) UpdateChat(userID, chatID string, req map[string]interface{}) (map[string]interface{}, error) {
|
||||
return s.updateChatREST(userID, chatID, req, false)
|
||||
}
|
||||
|
||||
// PatchChat mirrors PATCH /api/v1/chats/<chat_id> in the Python REST API.
|
||||
func (s *ChatService) PatchChat(userID, chatID string, req map[string]interface{}) (map[string]interface{}, error) {
|
||||
return s.updateChatREST(userID, chatID, req, true)
|
||||
}
|
||||
|
||||
func (s *ChatService) updateChatREST(userID, chatID string, req map[string]interface{}, patch bool) (map[string]interface{}, error) {
|
||||
currentChat, err := s.getOwnedValidChat(userID, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := s.tenantDAO.GetByID(userID); err != nil {
|
||||
return nil, errors.New("Tenant not found!")
|
||||
}
|
||||
|
||||
if !patch && isTruthy(req["tenant_id"]) {
|
||||
return nil, errors.New("`tenant_id` must not be provided.")
|
||||
}
|
||||
|
||||
if value, ok := req["name"]; ok {
|
||||
name, shouldSet, err := validateRESTChatName(value, !patch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if shouldSet {
|
||||
req["name"] = name
|
||||
} else {
|
||||
delete(req, "name")
|
||||
}
|
||||
}
|
||||
|
||||
if value, ok := req["dataset_ids"]; ok {
|
||||
kbIDs, err := s.validateRESTDatasetIDs(value, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req["kb_ids"] = kbIDs
|
||||
delete(req, "dataset_ids")
|
||||
}
|
||||
|
||||
var llmSetting map[string]interface{}
|
||||
llmSettingProvided := false
|
||||
if value, ok := req["llm_setting"]; ok {
|
||||
llmSettingProvided = true
|
||||
setting, ok := mapFromValue(value)
|
||||
if !ok {
|
||||
return nil, errors.New("`llm_setting` should be an object.")
|
||||
}
|
||||
llmSetting = setting
|
||||
}
|
||||
|
||||
if value, ok := req["llm_id"]; ok {
|
||||
llmID := fmt.Sprint(value)
|
||||
if err := s.validateRESTLLMID(llmID, userID, llmSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if value, ok := req["rerank_id"]; ok {
|
||||
rerankID := fmt.Sprint(value)
|
||||
if err := s.validateRESTRerankID(rerankID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if value, ok := req["prompt_config"]; ok {
|
||||
promptConfig, ok := mapFromValue(value)
|
||||
if !ok {
|
||||
return nil, errors.New("`prompt_config` should be an object.")
|
||||
}
|
||||
if patch {
|
||||
req["prompt_config"] = mergeJSONMap(currentChat.PromptConfig, promptConfig)
|
||||
} else {
|
||||
req["prompt_config"] = entity.JSONMap(promptConfig)
|
||||
}
|
||||
}
|
||||
|
||||
if llmSettingProvided {
|
||||
if patch {
|
||||
req["llm_setting"] = mergeJSONMap(currentChat.LLMSetting, llmSetting)
|
||||
} else {
|
||||
req["llm_setting"] = entity.JSONMap(llmSetting)
|
||||
}
|
||||
}
|
||||
|
||||
if value, ok := req["meta_data_filter"]; ok && value != nil {
|
||||
metaDataFilter, ok := mapFromValue(value)
|
||||
if !ok {
|
||||
return nil, errors.New("`meta_data_filter` should be an object.")
|
||||
}
|
||||
req["meta_data_filter"] = entity.JSONMap(metaDataFilter)
|
||||
}
|
||||
|
||||
updates := filterRESTChatUpdates(req)
|
||||
if value, ok := updates["name"]; ok {
|
||||
name := value.(string)
|
||||
currentName := ""
|
||||
if currentChat.Name != nil {
|
||||
currentName = *currentChat.Name
|
||||
}
|
||||
if strings.ToLower(name) != strings.ToLower(currentName) {
|
||||
existingNames, err := s.chatDAO.GetExistingNames(userID, string(entity.StatusValid))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, existingName := range existingNames {
|
||||
if existingName == name {
|
||||
return nil, errors.New("Duplicated chat name.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := s.chatDAO.UpdateByID(chatID, updates); err != nil {
|
||||
if patch {
|
||||
return nil, errors.New("Failed to update chat.")
|
||||
}
|
||||
return nil, errors.New("Chat not found!")
|
||||
}
|
||||
}
|
||||
|
||||
updatedChat, err := s.chatDAO.GetByID(chatID)
|
||||
if err != nil {
|
||||
return nil, errors.New("Failed to retrieve updated chat.")
|
||||
}
|
||||
return s.buildRESTChatResponse(updatedChat), nil
|
||||
}
|
||||
|
||||
func validateRESTChatName(value interface{}, required bool) (string, bool, error) {
|
||||
if value == nil {
|
||||
if required {
|
||||
return "", false, errors.New("`name` is required.")
|
||||
}
|
||||
return "", false, nil
|
||||
}
|
||||
name, ok := value.(string)
|
||||
if !ok {
|
||||
return "", false, errors.New("Chat name must be a string.")
|
||||
}
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
if required {
|
||||
return "", false, errors.New("`name` is required.")
|
||||
}
|
||||
return "", false, errors.New("`name` cannot be empty.")
|
||||
}
|
||||
if len([]byte(name)) > 255 {
|
||||
return "", false, fmt.Errorf("Chat name length is %d which is larger than 255.", len([]byte(name)))
|
||||
}
|
||||
return name, true, nil
|
||||
}
|
||||
|
||||
func (s *ChatService) validateRESTDatasetIDs(value interface{}, userID string) (entity.JSONSlice, error) {
|
||||
if value == nil {
|
||||
return entity.JSONSlice{}, nil
|
||||
}
|
||||
items, ok := value.([]interface{})
|
||||
if !ok {
|
||||
return nil, errors.New("`dataset_ids` should be a list.")
|
||||
}
|
||||
|
||||
var kbs []*entity.Knowledgebase
|
||||
kbIDs := make(entity.JSONSlice, 0, len(items))
|
||||
for _, item := range items {
|
||||
if !isTruthy(item) {
|
||||
continue
|
||||
}
|
||||
datasetID := fmt.Sprint(item)
|
||||
if !s.kbDAO.Accessible(datasetID, userID) {
|
||||
return nil, fmt.Errorf("You don't own the dataset %s", datasetID)
|
||||
}
|
||||
kb, err := s.kbDAO.GetByID(datasetID)
|
||||
if err != nil || kb == nil {
|
||||
return nil, fmt.Errorf("You don't own the dataset %s", datasetID)
|
||||
}
|
||||
if kb.ChunkNum == 0 {
|
||||
return nil, fmt.Errorf("The dataset %s doesn't own parsed file", datasetID)
|
||||
}
|
||||
kbs = append(kbs, kb)
|
||||
kbIDs = append(kbIDs, datasetID)
|
||||
}
|
||||
|
||||
embdIDs := make([]string, 0, len(kbs))
|
||||
seenEmbdIDs := make(map[string]struct{})
|
||||
for _, kb := range kbs {
|
||||
embdIDs = append(embdIDs, kb.EmbdID)
|
||||
seenEmbdIDs[s.splitModelNameAndFactory(kb.EmbdID)] = struct{}{}
|
||||
}
|
||||
if len(seenEmbdIDs) > 1 {
|
||||
return nil, fmt.Errorf("Datasets use different embedding models: %v", embdIDs)
|
||||
}
|
||||
return kbIDs, nil
|
||||
}
|
||||
|
||||
func (s *ChatService) validateRESTLLMID(llmID, tenantID string, llmSetting map[string]interface{}) error {
|
||||
if llmID == "" {
|
||||
return nil
|
||||
}
|
||||
modelType := entity.ModelTypeChat
|
||||
if rawModelType, ok := llmSetting["model_type"]; ok {
|
||||
switch typedModelType := rawModelType.(type) {
|
||||
case string:
|
||||
if typedModelType == string(entity.ModelTypeImage2Text) {
|
||||
modelType = entity.ModelTypeImage2Text
|
||||
}
|
||||
case []interface{}:
|
||||
for _, item := range typedModelType {
|
||||
if fmt.Sprint(item) == string(entity.ModelTypeImage2Text) {
|
||||
modelType = entity.ModelTypeImage2Text
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, modelType, llmID); err != nil {
|
||||
return fmt.Errorf("`llm_id` %s doesn't exist", llmID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ChatService) validateRESTRerankID(rerankID, tenantID string) error {
|
||||
if rerankID == "" {
|
||||
return nil
|
||||
}
|
||||
baseName := s.splitModelNameAndFactory(rerankID)
|
||||
if _, ok := defaultRerankModels[baseName]; ok {
|
||||
return nil
|
||||
}
|
||||
if _, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, entity.ModelTypeRerank, rerankID); err != nil {
|
||||
return fmt.Errorf("`rerank_id` %s doesn't exist", rerankID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func filterRESTChatUpdates(req map[string]interface{}) map[string]interface{} {
|
||||
updates := make(map[string]interface{})
|
||||
for field, value := range req {
|
||||
if _, ok := chatPersistedFields[field]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := chatReadonlyFields[field]; ok {
|
||||
continue
|
||||
}
|
||||
updates[field] = value
|
||||
}
|
||||
return updates
|
||||
}
|
||||
|
||||
func mapFromValue(value interface{}) (map[string]interface{}, bool) {
|
||||
if value == nil {
|
||||
return nil, false
|
||||
}
|
||||
switch typedValue := value.(type) {
|
||||
case map[string]interface{}:
|
||||
return typedValue, true
|
||||
case entity.JSONMap:
|
||||
return map[string]interface{}(typedValue), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func mergeJSONMap(base entity.JSONMap, patch map[string]interface{}) entity.JSONMap {
|
||||
merged := entity.JSONMap{}
|
||||
for key, value := range base {
|
||||
merged[key] = value
|
||||
}
|
||||
for key, value := range patch {
|
||||
merged[key] = value
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func isTruthy(value interface{}) bool {
|
||||
if value == nil {
|
||||
return false
|
||||
}
|
||||
switch typedValue := value.(type) {
|
||||
case bool:
|
||||
return typedValue
|
||||
case string:
|
||||
return typedValue != ""
|
||||
case int, int8, int16, int32, int64:
|
||||
return reflect.ValueOf(value).Int() != 0
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
return reflect.ValueOf(value).Uint() != 0
|
||||
case float32, float64:
|
||||
return reflect.ValueOf(value).Float() != 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ChatService) buildRESTChatResponse(chat *entity.Chat) map[string]interface{} {
|
||||
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
|
||||
return map[string]interface{}{
|
||||
"id": chat.ID,
|
||||
"tenant_id": chat.TenantID,
|
||||
"name": chat.Name,
|
||||
"description": chat.Description,
|
||||
"icon": chat.Icon,
|
||||
"language": chat.Language,
|
||||
"llm_id": chat.LLMID,
|
||||
"tenant_llm_id": chat.TenantLLMID,
|
||||
"llm_setting": chat.LLMSetting,
|
||||
"prompt_type": chat.PromptType,
|
||||
"prompt_config": chat.PromptConfig,
|
||||
"meta_data_filter": chat.MetaDataFilter,
|
||||
"similarity_threshold": chat.SimilarityThreshold,
|
||||
"vector_similarity_weight": chat.VectorSimilarityWeight,
|
||||
"top_n": chat.TopN,
|
||||
"top_k": chat.TopK,
|
||||
"do_refer": chat.DoRefer,
|
||||
"rerank_id": chat.RerankID,
|
||||
"tenant_rerank_id": chat.TenantRerankID,
|
||||
"dataset_ids": datasetIDs,
|
||||
"kb_names": kbNames,
|
||||
"status": chat.Status,
|
||||
"create_time": chat.CreateTime,
|
||||
"create_date": chat.CreateDate,
|
||||
"update_time": chat.UpdateTime,
|
||||
"update_date": chat.UpdateDate,
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteChat soft deletes a single chat owned by the current user.
|
||||
func (s *ChatService) DeleteChat(userID, chatID string) error {
|
||||
if _, err := s.getOwnedValidChat(userID, chatID); err != nil {
|
||||
|
||||
188
internal/service/chat_rest_update_test.go
Normal file
188
internal/service/chat_rest_update_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func setupChatRESTUpdateServiceTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
TranslateError: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&entity.Chat{}, &entity.Tenant{}, &entity.Knowledgebase{}, &entity.UserTenant{}); err != nil {
|
||||
t.Fatalf("failed to migrate test schema: %v", err)
|
||||
}
|
||||
|
||||
origDB := dao.DB
|
||||
dao.DB = db
|
||||
t.Cleanup(func() { dao.DB = origDB })
|
||||
|
||||
status := string(entity.StatusValid)
|
||||
if err := db.Create(&entity.Tenant{
|
||||
ID: "user-1",
|
||||
LLMID: "model-a",
|
||||
EmbdID: "embd-a",
|
||||
ASRID: "asr-a",
|
||||
Img2TxtID: "img2txt-a",
|
||||
RerankID: "rerank-a",
|
||||
ParserIDs: "naive",
|
||||
Status: &status,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to create tenant: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func createChatRESTUpdateServiceTestChat(t *testing.T, db *gorm.DB, id, tenantID string) {
|
||||
t.Helper()
|
||||
|
||||
name := "chat-" + id
|
||||
status := string(entity.StatusValid)
|
||||
chat := &entity.Chat{
|
||||
ID: id,
|
||||
TenantID: tenantID,
|
||||
Name: &name,
|
||||
LLMID: "model-a",
|
||||
LLMSetting: entity.JSONMap{"temperature": float64(0.1), "top_p": float64(0.9)},
|
||||
PromptType: "simple",
|
||||
PromptConfig: entity.JSONMap{"system": "old system", "quote": true},
|
||||
KBIDs: entity.JSONSlice{},
|
||||
Status: &status,
|
||||
}
|
||||
if err := db.Create(chat).Error; err != nil {
|
||||
t.Fatalf("failed to create chat: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatServicePatchChatMergesPromptConfigAndLLMSetting(t *testing.T) {
|
||||
db := setupChatRESTUpdateServiceTestDB(t)
|
||||
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
|
||||
|
||||
svc := NewChatService()
|
||||
resp, err := svc.PatchChat("user-1", "chat-1", map[string]interface{}{
|
||||
"prompt_config": map[string]interface{}{"quote": false},
|
||||
"llm_setting": map[string]interface{}{"temperature": float64(0.2)},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PatchChat failed: %v", err)
|
||||
}
|
||||
if _, ok := resp["kb_ids"]; ok {
|
||||
t.Fatalf("response must not expose kb_ids: %+v", resp)
|
||||
}
|
||||
if _, ok := resp["dataset_ids"]; !ok {
|
||||
t.Fatalf("response should expose dataset_ids: %+v", resp)
|
||||
}
|
||||
|
||||
chat, err := svc.chatDAO.GetByID("chat-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch updated chat: %v", err)
|
||||
}
|
||||
if chat.PromptConfig["system"] != "old system" {
|
||||
t.Fatalf("expected prompt_config.system to be preserved, got %+v", chat.PromptConfig)
|
||||
}
|
||||
if chat.PromptConfig["quote"] != false {
|
||||
t.Fatalf("expected prompt_config.quote to be patched, got %+v", chat.PromptConfig)
|
||||
}
|
||||
if chat.LLMSetting["top_p"] != float64(0.9) {
|
||||
t.Fatalf("expected llm_setting.top_p to be preserved, got %+v", chat.LLMSetting)
|
||||
}
|
||||
if chat.LLMSetting["temperature"] != float64(0.2) {
|
||||
t.Fatalf("expected llm_setting.temperature to be patched, got %+v", chat.LLMSetting)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatServiceUpdateChatRejectsTenantID(t *testing.T) {
|
||||
db := setupChatRESTUpdateServiceTestDB(t)
|
||||
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
|
||||
|
||||
svc := NewChatService()
|
||||
_, err := svc.UpdateChat("user-1", "chat-1", map[string]interface{}{
|
||||
"tenant_id": "tenant-2",
|
||||
})
|
||||
if err == nil || err.Error() != "`tenant_id` must not be provided." {
|
||||
t.Fatalf("expected tenant_id error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatServiceUpdateChatRejectsInvalidLLMSetting(t *testing.T) {
|
||||
db := setupChatRESTUpdateServiceTestDB(t)
|
||||
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
|
||||
|
||||
svc := NewChatService()
|
||||
_, err := svc.UpdateChat("user-1", "chat-1", map[string]interface{}{
|
||||
"llm_setting": "invalid",
|
||||
})
|
||||
if err == nil || err.Error() != "`llm_setting` should be an object." {
|
||||
t.Fatalf("expected llm_setting error, got %v", err)
|
||||
}
|
||||
|
||||
chat, err := svc.chatDAO.GetByID("chat-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch chat: %v", err)
|
||||
}
|
||||
if chat.LLMSetting["temperature"] != float64(0.1) {
|
||||
t.Fatalf("expected llm_setting to remain unchanged, got %+v", chat.LLMSetting)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatServiceUpdateChatAcceptsMetaDataFilterObject(t *testing.T) {
|
||||
db := setupChatRESTUpdateServiceTestDB(t)
|
||||
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
|
||||
|
||||
svc := NewChatService()
|
||||
_, err := svc.UpdateChat("user-1", "chat-1", map[string]interface{}{
|
||||
"name": "chat-chat-1",
|
||||
"meta_data_filter": map[string]interface{}{
|
||||
"method": "disabled",
|
||||
"manual": []interface{}{},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateChat failed: %v", err)
|
||||
}
|
||||
|
||||
chat, err := svc.chatDAO.GetByID("chat-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch chat: %v", err)
|
||||
}
|
||||
if chat.MetaDataFilter == nil || (*chat.MetaDataFilter)["method"] != "disabled" {
|
||||
t.Fatalf("expected meta_data_filter to be persisted, got %+v", chat.MetaDataFilter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatServicePatchChatIgnoresTenantIDAndUpdatesName(t *testing.T) {
|
||||
db := setupChatRESTUpdateServiceTestDB(t)
|
||||
createChatRESTUpdateServiceTestChat(t, db, "chat-1", "user-1")
|
||||
|
||||
svc := NewChatService()
|
||||
_, err := svc.PatchChat("user-1", "chat-1", map[string]interface{}{
|
||||
"tenant_id": "tenant-2",
|
||||
"name": " renamed chat ",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PatchChat failed: %v", err)
|
||||
}
|
||||
|
||||
chat, err := svc.chatDAO.GetByID("chat-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch updated chat: %v", err)
|
||||
}
|
||||
if chat.TenantID != "user-1" {
|
||||
t.Fatalf("expected tenant_id to remain user-1, got %s", chat.TenantID)
|
||||
}
|
||||
if chat.Name == nil || *chat.Name != "renamed chat" {
|
||||
t.Fatalf("expected trimmed name, got %+v", chat.Name)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user