diff --git a/internal/dao/chat.go b/internal/dao/chat.go index 98d300a3f2..e44e338f48 100644 --- a/internal/dao/chat.go +++ b/internal/dao/chat.go @@ -169,6 +169,15 @@ func (dao *ChatDAO) GetExistingNames(tenantID string, status string) ([]string, return names, err } +// ExistsByNameTenantStatus checks whether a chat with the given name exists. +func (dao *ChatDAO) ExistsByNameTenantStatus(name, tenantID, status string) (bool, error) { + var count int64 + err := DB.Model(&entity.Chat{}). + Where("name = ? AND tenant_id = ? AND status = ?", name, tenantID, status). + Count(&count).Error + return count > 0, err +} + // Create creates a new chat/dialog func (dao *ChatDAO) Create(chat *entity.Chat) error { return DB.Create(chat).Error diff --git a/internal/handler/chat.go b/internal/handler/chat.go index c84ebe74ef..e15f5f4c65 100644 --- a/internal/handler/chat.go +++ b/internal/handler/chat.go @@ -17,6 +17,7 @@ package handler import ( + "encoding/json" "net/http" "ragflow/internal/common" "strconv" @@ -81,7 +82,7 @@ func (h *ChatHandler) ListChats(c *gin.Context) { } // List chats - default to valid status "1" (same as Python StatusEnum.VALID.value) - result, err := h.chatService.ListChats(userID, keywords, "1", page, pageSize, orderby, desc) + result, err := h.chatService.ListChats(userID, "1", keywords, page, pageSize, orderby, desc) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, @@ -97,6 +98,46 @@ func (h *ChatHandler) ListChats(c *gin.Context) { }) } +// Create creates a chat. +// @Summary Create Chat +// @Description Create a chat, aligned with Python POST /api/v1/chats. +// @Tags chat +// @Accept json +// @Produce json +// @Param request body service.CreateChatRequest true "chat configuration" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/chats [post] +func (h *ChatHandler) Create(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + var req map[string]interface{} + decoder := json.NewDecoder(c.Request.Body) + decoder.UseNumber() + if err := decoder.Decode(&req); err != nil { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + if req == nil { + req = map[string]interface{}{} + } + + result, code, err := h.chatService.Create(user.ID, req) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "data": result, + "message": "success", + }) +} + // ListChatsNext list chats with advanced filtering and pagination // @Summary List Chats Next // @Description Get list of chats with filtering, pagination and sorting (equivalent to list_dialogs_next) diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index ec3dbaf449..bbc34c95c9 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -17,11 +17,13 @@ package handler import ( + "encoding/json" "errors" "fmt" "io" "net/http" "ragflow/internal/common" + "strings" "github.com/gin-gonic/gin" @@ -349,6 +351,98 @@ func (h *ChatSessionHandler) GetSession(c *gin.Context) { jsonResponse(c, common.CodeSuccess, result, "success") } +// CreateSession create a session in a dialog +func (h *ChatSessionHandler) CreateSession(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeBadRequest, "user_id is required") + return + } + + chatID := strings.TrimSpace(c.Param("chat_id")) + if chatID == "" { + jsonError(c, common.CodeBadRequest, "chat_id is required") + return + } + + req := map[string]interface{}{} + if err := json.NewDecoder(c.Request.Body).Decode(&req); err != nil { + if errors.Is(err, io.EOF) { + req = map[string]interface{}{} + } else { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + } + if req == nil { + req = map[string]interface{}{} + } + + result, code, err := h.chatSessionService.CreateSession(userID, chatID, req) + if err != nil { + if code == common.CodeAuthenticationError { + jsonResponse(c, code, false, err.Error()) + return + } + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// DeleteSessions delete a session in a dialog +func (h *ChatSessionHandler) DeleteSessions(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + chatID := strings.TrimSpace(c.Param("chat_id")) + if chatID == "" { + jsonError(c, common.CodeBadRequest, "chat_id is required") + return + } + + userID := strings.TrimSpace(user.ID) + if userID == "" { + jsonError(c, common.CodeBadRequest, "user_id is required") + return + } + + req := map[string]interface{}{} + if err := json.NewDecoder(c.Request.Body).Decode(&req); err != nil { + if errors.Is(err, io.EOF) { + req = map[string]interface{}{} + } else { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + } + if req == nil { + req = map[string]interface{}{} + } + + result, message, code, err := h.chatSessionService.DeleteSessions(userID, chatID, req) + if err != nil { + if code == common.CodeAuthenticationError { + jsonResponse(c, code, false, err.Error()) + return + } + jsonError(c, code, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, result, message) +} + func (h *ChatSessionHandler) UpdateSession(c *gin.Context) { user, errorCode, errorMessage := GetUser(c) if errorCode != common.CodeSuccess { diff --git a/internal/router/router.go b/internal/router/router.go index e4441ef475..8816cb9145 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -253,12 +253,15 @@ func (r *Router) Setup(engine *gin.Engine) { chats := v1.Group("/chats") { chats.GET("", r.chatHandler.ListChats) + chats.POST("", r.chatHandler.Create) chats.DELETE("", r.chatHandler.BulkDeleteChats) chats.DELETE("/:chat_id", r.chatHandler.DeleteChat) chats.GET("/:chat_id", r.chatHandler.GetChat) chats.PUT("/:chat_id", r.chatHandler.UpdateChat) chats.PATCH("/:chat_id", r.chatHandler.PatchChat) chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions) + chats.POST("/:chat_id/sessions", r.chatSessionHandler.CreateSession) + chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions) chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession) chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession) } diff --git a/internal/service/chat.go b/internal/service/chat.go index 46e3218ea6..6c57654ad1 100644 --- a/internal/service/chat.go +++ b/internal/service/chat.go @@ -17,6 +17,7 @@ package service import ( + "encoding/json" "errors" "fmt" "ragflow/internal/common" @@ -28,6 +29,43 @@ import ( "ragflow/internal/dao" ) +var DefaultPromptConfig = PromptConfig{ + System: strPtr(pyDefaultSystemPrompt), + Prologue: strPtr(pyDefaultPrologue), + Parameters: []ParameterConfig{ + {Key: "knowledge", Optional: false}, + }, + EmptyResponse: strPtr(pyDefaultEmptyResponse), + Quote: boolPtr(true), + TTS: boolPtr(false), + RefineMultiturn: boolPtr(true), +} + +var DefaultDirectChatPromptConfig = PromptConfig{ + System: strPtr(""), + Prologue: strPtr(""), + Parameters: []ParameterConfig{}, + EmptyResponse: strPtr(""), + Quote: boolPtr(false), + TTS: boolPtr(false), + RefineMultiturn: boolPtr(true), +} + +var DefaultRerankModels = map[string]struct{}{ + "BAAI/bge-reranker-v2-m3": {}, + "maidalun1020/bce-reranker-base_v1": {}, +} + +var ReadOnlyFields = map[string]struct{}{ + "id": {}, + "tenant_id": {}, + "created_by": {}, + "create_time": {}, + "create_date": {}, + "update_time": {}, + "update_date": {}, +} + // ChatService chat service type ChatService struct { chatDAO *dao.ChatDAO @@ -114,6 +152,488 @@ func (s *ChatService) ListChats(userID, status, keywords string, page, pageSize }, nil } +type CreateChatRequest struct { + Name string + DatasetIDs []string `json:"dataset_ids"` + KBIDs []string `json:"kb_ids"` + LLMID *string `json:"llm_id"` + LLMSetting map[string]interface{} `json:"llm_setting"` + RerankID *string `json:"rerank_id"` + PromptConfig map[string]interface{} `json:"prompt_config"` + Description *string + TopN *int + TopK *int + SimilarityThreshold *float64 + VectorSimilarityWeight *float64 + Icon *string + TenantID *string `json:"tenant_id"` +} + +func (s *ChatService) Create(userID string, req map[string]interface{}) (map[string]interface{}, common.ErrorCode, error) { + tenant, err := s.tenantDAO.GetByID(userID) + if err != nil { + return nil, common.CodeDataError, errors.New("Tenant not found!") + } + + if tenantValue, ok := req["tenant_id"]; ok && isTruthy(tenantValue) { + return nil, common.CodeDataError, errors.New("`tenant_id` must not be provided.") + } + + name, err := validateCreateChatName(req["name"]) + if err != nil { + return nil, common.CodeDataError, err + } + req["name"] = name + + if datasetIDsValue, ok := req["dataset_ids"]; ok { + kbIDs, err := s.validateCreateDatasetIDs(datasetIDsValue, userID) + if err != nil { + return nil, common.CodeDataError, err + } + req["kb_ids"] = kbIDs + delete(req, "dataset_ids") + } + + if llmIDValue, ok := req["llm_id"]; ok { + llmID := stringFromValue(llmIDValue) + llmSetting, _ := mapFromValue(req["llm_setting"]) + if err = validateCreateLLMID(llmID, userID, llmSetting); err != nil { + return nil, common.CodeDataError, err + } + } + + if rerankIDValue, ok := req["rerank_id"]; ok { + rerankID := stringFromValue(rerankIDValue) + if err = validateCreateRerankID(rerankID, userID); err != nil { + return nil, common.CodeDataError, err + } + } + + if promptConfigValue, ok := req["prompt_config"]; ok { + if _, ok := mapFromValue(promptConfigValue); !ok { + return nil, common.CodeDataError, errors.New("`prompt_config` should be an object.") + } + } + + if _, ok := req["kb_ids"]; !ok { + req["kb_ids"] = []string{} + } + if _, ok := req["llm_id"]; !ok || req["llm_id"] == nil { + req["llm_id"] = tenant.LLMID + } + if _, ok := req["llm_setting"]; !ok || req["llm_setting"] == nil { + req["llm_setting"] = map[string]interface{}{} + } + if _, ok := req["description"]; !ok { + req["description"] = "A helpful Assistant" + } + if _, ok := req["top_n"]; !ok { + req["top_n"] = 6 + } + if _, ok := req["top_k"]; !ok { + req["top_k"] = 1024 + } + if _, ok := req["rerank_id"]; !ok { + req["rerank_id"] = "" + } + if _, ok := req["similarity_threshold"]; !ok { + req["similarity_threshold"] = 0.1 + } + if _, ok := req["vector_similarity_weight"]; !ok { + req["vector_similarity_weight"] = 0.3 + } + if _, ok := req["icon"]; !ok { + req["icon"] = "" + } + + applyCreatePromptDefaults(req) + filterCreateChatPersistedFields(req) + + exists, err := s.chatDAO.ExistsByNameTenantStatus(name, userID, string(entity.StatusValid)) + if err != nil { + return nil, common.CodeServerError, err + } + if exists { + return nil, common.CodeDataError, errors.New("Duplicated chat name in creating chat.") + } + + chat := buildCreateChatEntity(req, userID) + if err = s.chatDAO.Create(chat); err != nil { + return nil, common.CodeDataError, errors.New("Failed to create chat.") + } + + chat, err = s.chatDAO.GetByID(chat.ID) + if err != nil { + return nil, common.CodeDataError, errors.New("Failed to retrieve created chat.") + } + + response, err := s.buildCreateChatResponse(chat) + if err != nil { + return nil, common.CodeServerError, err + } + return response, common.CodeSuccess, nil +} + +func validateCreateChatName(value interface{}) (string, error) { + if value == nil { + return "", errors.New("`name` is required.") + } + name, ok := value.(string) + if !ok { + return "", errors.New("Chat name must be a string.") + } + name = strings.TrimSpace(name) + if name == "" { + return "", errors.New("`name` is required.") + } + if len([]byte(name)) > 255 { + return "", fmt.Errorf("Chat name length is %d which is larger than 255.", len([]byte(name))) + } + return name, nil +} + +func (s *ChatService) validateCreateDatasetIDs(value interface{}, tenantID string) ([]string, error) { + if value == nil { + return []string{}, nil + } + values, ok := listFromValue(value) + if !ok { + return nil, errors.New("`dataset_ids` should be a list.") + } + + normalizedIDs := make([]string, 0, len(values)) + kbs := make([]*entity.Knowledgebase, 0, len(values)) + for _, item := range values { + if !isTruthy(item) { + continue + } + datasetID := stringFromValue(item) + normalizedIDs = append(normalizedIDs, datasetID) + } + + for _, datasetID := range normalizedIDs { + if !s.kbDAO.Accessible(datasetID, tenantID) { + return nil, fmt.Errorf("You don't own the dataset %s", datasetID) + } + kb, err := s.kbDAO.GetByID(datasetID) + if err != nil { + return nil, fmt.Errorf("You don't own the dataset %s", datasetID) + } + if kb.ChunkNum == 0 { + return nil, fmt.Errorf("The dataset %s doesn't own parsed file", datasetID) + } + kbs = append(kbs, kb) + } + + embedIDs := make(map[string]struct{}, len(kbs)) + for _, kb := range kbs { + embedIDs[s.splitModelNameAndFactory(kb.EmbdID)] = struct{}{} + } + if len(embedIDs) > 1 { + return nil, fmt.Errorf("Datasets use different embedding models: %v", getEmbdIDs(kbs)) + } + return normalizedIDs, nil +} + +func validateCreateLLMID(llmID, tenantID string, llmSetting map[string]interface{}) error { + if llmID == "" { + return nil + } + modelType := entity.ModelTypeChat + switch confModelType := llmSetting["model_type"].(type) { + case string: + if confModelType == string(entity.ModelTypeImage2Text) { + modelType = entity.ModelTypeImage2Text + } + case []interface{}: + for _, item := range confModelType { + if item == string(entity.ModelTypeImage2Text) { + modelType = entity.ModelTypeImage2Text + break + } + } + case []string: + for _, item := range confModelType { + if item == string(entity.ModelTypeImage2Text) { + modelType = entity.ModelTypeImage2Text + break + } + } + } + if _, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, modelType, llmID); err != nil { + return fmt.Errorf("`llm_id` %s doesn't exist", llmID) + } + return nil +} + +func validateCreateRerankID(rerankID, tenantID string) error { + if rerankID == "" { + return nil + } + llmName := strings.Split(rerankID, "@")[0] + if _, ok := DefaultRerankModels[llmName]; ok { + return nil + } + if _, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, entity.ModelTypeRerank, rerankID); err != nil { + return fmt.Errorf("`rerank_id` %s doesn't exist", rerankID) + } + return nil +} + +func applyCreatePromptDefaults(req map[string]interface{}) { + promptConfig, _ := mapFromValue(req["prompt_config"]) + if promptConfig == nil { + promptConfig = map[string]interface{}{} + } + if system, ok := promptConfig["system"]; !ok || !isTruthy(system) { + promptConfig["system"] = pyDefaultSystemPrompt + } + if _, ok := promptConfig["prologue"]; !ok { + promptConfig["prologue"] = pyDefaultPrologue + } + if _, ok := promptConfig["parameters"]; !ok { + promptConfig["parameters"] = []interface{}{map[string]interface{}{"key": "knowledge", "optional": false}} + } + if _, ok := promptConfig["empty_response"]; !ok { + promptConfig["empty_response"] = pyDefaultEmptyResponse + } + if _, ok := promptConfig["quote"]; !ok { + promptConfig["quote"] = true + } + if _, ok := promptConfig["tts"]; !ok { + promptConfig["tts"] = false + } + if _, ok := promptConfig["refine_multiturn"]; !ok { + promptConfig["refine_multiturn"] = true + } + + kbIDs, _ := listFromValue(req["kb_ids"]) + system, _ := promptConfig["system"].(string) + if len(kbIDs) > 0 && !isTruthy(promptConfig["parameters"]) && strings.Contains(system, "{knowledge}") { + promptConfig["parameters"] = []interface{}{map[string]interface{}{"key": "knowledge", "optional": false}} + } + req["prompt_config"] = promptConfig +} + +func filterCreateChatPersistedFields(req map[string]interface{}) { + persisted := map[string]struct{}{ + "name": {}, "description": {}, "icon": {}, "language": {}, "llm_id": {}, "tenant_llm_id": {}, + "llm_setting": {}, "prompt_type": {}, "prompt_config": {}, "meta_data_filter": {}, + "similarity_threshold": {}, "vector_similarity_weight": {}, "top_n": {}, "top_k": {}, + "do_refer": {}, "rerank_id": {}, "tenant_rerank_id": {}, "kb_ids": {}, "status": {}, + } + for key := range req { + if _, ok := persisted[key]; !ok { + delete(req, key) + } + } + for key := range ReadOnlyFields { + delete(req, key) + } +} + +func buildCreateChatEntity(req map[string]interface{}, tenantID string) *entity.Chat { + name := stringFromValue(req["name"]) + description := stringFromValue(req["description"]) + icon := stringFromValue(req["icon"]) + llmID := stringFromValue(req["llm_id"]) + rerankID := stringFromValue(req["rerank_id"]) + llmSetting, _ := mapFromValue(req["llm_setting"]) + promptConfig, _ := mapFromValue(req["prompt_config"]) + kbIDs, _ := stringListFromValue(req["kb_ids"]) + kbIDsJSON := make(entity.JSONSlice, 0, len(kbIDs)) + for _, id := range kbIDs { + kbIDsJSON = append(kbIDsJSON, id) + } + status, hasStatus := req["status"] + statusValue := string(entity.StatusValid) + if hasStatus { + statusValue = stringFromValue(status) + } + + chat := &entity.Chat{ + ID: common.GenerateUUID(), + TenantID: tenantID, + Name: &name, + Description: &description, + Icon: &icon, + LLMID: llmID, + LLMSetting: entity.JSONMap(llmSetting), + PromptType: stringFromValue(req["prompt_type"]), + PromptConfig: entity.JSONMap(promptConfig), + SimilarityThreshold: floatFromValue(req["similarity_threshold"]), + VectorSimilarityWeight: floatFromValue(req["vector_similarity_weight"]), + TopN: int64FromValue(req["top_n"]), + TopK: int64FromValue(req["top_k"]), + DoRefer: stringFromValue(req["do_refer"]), + RerankID: rerankID, + KBIDs: kbIDsJSON, + Status: &statusValue, + } + if chat.PromptType == "" { + chat.PromptType = "simple" + } + if chat.DoRefer == "" { + chat.DoRefer = "1" + } + if language := stringFromValue(req["language"]); language != "" { + chat.Language = &language + } + if metaDataFilter, ok := mapFromValue(req["meta_data_filter"]); ok { + metaDataFilterJSON := entity.JSONMap(metaDataFilter) + chat.MetaDataFilter = &metaDataFilterJSON + } + return chat +} + +func (s *ChatService) buildCreateChatResponse(chat *entity.Chat) (map[string]interface{}, error) { + data, err := structToMap(chat) + if err != nil { + return nil, err + } + kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs) + data["dataset_ids"] = datasetIDs + delete(data, "kb_ids") + data["kb_names"] = kbNames + return data, nil +} + +func structToMap(value interface{}) (map[string]interface{}, error) { + bytes, err := json.Marshal(value) + if err != nil { + return nil, err + } + result := map[string]interface{}{} + if err = json.Unmarshal(bytes, &result); err != nil { + return nil, err + } + return result, nil +} + +func stringFromValue(value interface{}) string { + switch typed := value.(type) { + case nil: + return "" + case string: + return typed + default: + return fmt.Sprint(typed) + } +} + +func mapFromValue(value interface{}) (map[string]interface{}, bool) { + switch typed := value.(type) { + case nil: + return nil, false + case map[string]interface{}: + return typed, true + case entity.JSONMap: + return map[string]interface{}(typed), true + default: + return nil, false + } +} + +func listFromValue(value interface{}) ([]interface{}, bool) { + switch typed := value.(type) { + case nil: + return nil, false + case []interface{}: + return typed, true + case []string: + result := make([]interface{}, 0, len(typed)) + for _, item := range typed { + result = append(result, item) + } + return result, true + case entity.JSONSlice: + return []interface{}(typed), true + default: + return nil, false + } +} + +func stringListFromValue(value interface{}) ([]string, bool) { + values, ok := listFromValue(value) + if !ok { + return nil, false + } + result := make([]string, 0, len(values)) + for _, item := range values { + if !isTruthy(item) { + continue + } + result = append(result, stringFromValue(item)) + } + return result, true +} + +func int64FromValue(value interface{}) int64 { + switch typed := value.(type) { + case int: + return int64(typed) + case int64: + return typed + case float64: + return int64(typed) + case json.Number: + n, err := typed.Int64() + if err == nil { + return n + } + f, _ := typed.Float64() + return int64(f) + default: + return 0 + } +} + +func floatFromValue(value interface{}) float64 { + switch typed := value.(type) { + case float64: + return typed + case float32: + return float64(typed) + case int: + return float64(typed) + case int64: + return float64(typed) + case json.Number: + n, _ := typed.Float64() + return n + default: + return 0 + } +} + +func isTruthy(value interface{}) bool { + switch typed := value.(type) { + case nil: + return false + case bool: + return typed + case string: + return typed != "" + case int: + return typed != 0 + case int64: + return typed != 0 + case float64: + return typed != 0 + case json.Number: + n, err := typed.Float64() + return err != nil || n != 0 + case []interface{}: + return len(typed) > 0 + case []string: + return len(typed) > 0 + case map[string]interface{}: + return len(typed) > 0 + default: + return true + } +} + // ListChatsNextRequest list chats next request type ListChatsNextRequest struct { OwnerIDs []string `json:"owner_ids,omitempty"` @@ -1169,6 +1689,10 @@ func strPtr(s string) *string { return &s } +func boolPtr(b bool) *bool { + return &b +} + // Helper to count UTF-8 characters (not bytes) func (s *ChatService) countRunes(str string) int { return utf8.RuneCountInString(str) diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index ac3b8cd21f..718a766409 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -22,9 +22,11 @@ import ( "errors" "fmt" "ragflow/internal/common" + "ragflow/internal/storage" "strings" "time" + "go.uber.org/zap" "gorm.io/gorm" "ragflow/internal/dao" @@ -306,6 +308,230 @@ func (s *ChatSessionService) GetSession(userID, chatID, sessionID string) (*Chat return s.buildSessionPayload(session, dialog, true), common.CodeSuccess, nil } +// CreateSession create a session in a dialog +func (s *ChatSessionService) CreateSession(userID, chatID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) { + ok, err := s.ensureOwnedChat(userID, chatID) + if err != nil { + return nil, common.CodeServerError, err + } + if !ok { + return nil, common.CodeAuthenticationError, errors.New("No authorization.") + } + + dialog, err := s.chatSessionDAO.GetDialogByID(chatID) + if err != nil { + if isChatSessionNotFound(err) { + return nil, common.CodeDataError, errors.New("Chat not found!") + } + return nil, common.CodeServerError, err + } + + name := "New session" + if rawName, exists := req["name"]; exists { + nameStr, ok := rawName.(string) + if !ok || strings.TrimSpace(nameStr) == "" { + return nil, common.CodeDataError, errors.New("`name` can not be empty.") + } + name = strings.TrimSpace(nameStr) + } + nameRunes := []rune(name) + if len(nameRunes) > 255 { + name = string(nameRunes[:255]) + } + + prologue := "" + if dialog.PromptConfig != nil { + if value, ok := dialog.PromptConfig["prologue"].(string); ok { + prologue = value + } + } + messagesJSON, _ := json.Marshal([]map[string]interface{}{ + { + "role": "assistant", + "content": prologue, + }, + }) + + referenceJSON, _ := json.Marshal([]interface{}{}) + + conv := &entity.ChatSession{ + ID: common.GenerateUUID(), + DialogID: chatID, + Name: &name, + Message: messagesJSON, + UserID: &userID, + Reference: referenceJSON, + } + + if err := s.chatSessionDAO.Create(conv); err != nil { + return nil, common.CodeDataError, errors.New("Fail to create a session!") + } + + session, err := s.chatSessionDAO.GetByID(conv.ID) + if err != nil { + return nil, common.CodeDataError, errors.New("Fail to create a session!") + } + return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil +} + +// DeleteSessions delete a session in a dialog +func (s *ChatSessionService) DeleteSessions(userID, chatID string, req map[string]interface{}) (interface{}, string, common.ErrorCode, error) { + ok, err := s.ensureOwnedChat(userID, chatID) + if err != nil { + return nil, "", common.CodeServerError, err + } + if !ok { + return false, "No authorization.", common.CodeAuthenticationError, errors.New("No authorization.") + } + + if len(req) == 0 { + return map[string]interface{}{}, "success", common.CodeSuccess, nil + } + + sessionIDs, hasIDs := stringSliceFromValue(req["ids"]) + if !hasIDs || len(sessionIDs) == 0 { + deleteAll, _ := req["delete_all"].(bool) + if deleteAll { + sessions, err := s.chatSessionDAO.ListByChatID(chatID) + if err != nil { + return nil, "", common.CodeServerError, err + } + for _, session := range sessions { + sessionIDs = append(sessionIDs, session.ID) + } + if len(sessionIDs) == 0 { + return map[string]interface{}{}, "success", common.CodeSuccess, nil + } + } else { + return map[string]interface{}{}, "success", common.CodeSuccess, nil + } + } + + uniqueIDs, duplicateMessages := checkDuplicateChatSessionIDs(sessionIDs) + + errorsList := make([]string, 0) + successCount := 0 + + for _, sid := range uniqueIDs { + session, err := s.chatSessionDAO.GetBySessionIDAndChatID(sid, chatID) + if err != nil { + errorsList = append(errorsList, fmt.Sprintf("The chat doesn't own the session %s", sid)) + continue + } + + s.removeSessionUploadFiles(userID, session) + + if err := s.chatSessionDAO.DeleteByID(sid); err != nil { + return nil, "", common.CodeServerError, err + } + + successCount++ + } + + allErrors := append(errorsList, duplicateMessages...) + + if len(allErrors) > 0 { + if successCount > 0 { + return map[string]interface{}{ + "success_count": successCount, + "errors": allErrors, + }, fmt.Sprintf("Partially deleted %d sessions with %d errors", successCount, len(allErrors)), common.CodeSuccess, nil + } + + return nil, "", common.CodeDataError, errors.New(strings.Join(allErrors, "; ")) + } + + return true, "success", common.CodeSuccess, nil +} + +func stringSliceFromValue(value interface{}) ([]string, bool) { + var raw []interface{} + switch typed := value.(type) { + case []interface{}: + raw = typed + case []string: + raw = make([]interface{}, 0, len(typed)) + for _, item := range typed { + raw = append(raw, item) + } + default: + return nil, false + } + + ids := make([]string, 0, len(raw)) + for _, item := range raw { + id, ok := item.(string) + if !ok { + continue + } + if strings.TrimSpace(id) == "" { + continue + } + ids = append(ids, id) + } + return ids, true +} + +func (s *ChatSessionService) removeSessionUploadFiles(userID string, session *entity.ChatSession) { + messages := parseMessages(session.Message) + bucket := fmt.Sprintf("%s-downloads", userID) + storageImpl := storage.GetStorageFactory().GetStorage() + if storageImpl == nil { + common.Warn("storage is not initialized; skip chat upload cleanup", zap.String("bucket", bucket)) + return + } + + for _, msg := range messages { + files, ok := msg["files"].([]interface{}) + if !ok { + continue + } + + for _, item := range files { + file, ok := item.(map[string]interface{}) + if !ok { + continue + } + + fileID, ok := file["id"].(string) + if !ok || fileID == "" { + continue + } + + if err := storageImpl.Remove(bucket, fileID); err != nil { + common.Warn("Failed to delete chat upload blob", + zap.String("bucket", bucket), + zap.String("file_id", fileID), + zap.Error(err), + ) + } + } + } +} + +func checkDuplicateChatSessionIDs(ids []string) ([]string, []string) { + idCount := make(map[string]int, len(ids)) + uniqueIDs := make([]string, 0, len(ids)) + for _, id := range ids { + id = strings.TrimSpace(id) + if id == "" { + continue + } + idCount[id]++ + if idCount[id] == 1 { + uniqueIDs = append(uniqueIDs, id) + } + } + + duplicateMessages := make([]string, 0) + for id, count := range idCount { + if count > 1 { + duplicateMessages = append(duplicateMessages, fmt.Sprintf("Duplicate session ids: %s", id)) + } + } + return uniqueIDs, duplicateMessages +} + // UpdateSession updates one chat session after Python-style field validation. func (s *ChatSessionService) UpdateSession(userID, chatID, sessionID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) { ok, err := s.ensureOwnedChat(userID, chatID)