feat[Go]: implement Create-Chat/Session, Delete-Session (#16386)

### What problem does this PR solve?

As title:
implement:
```go
chats.POST("", r.chatHandler.Create)
chats.POST("/:chat_id/sessions", r.chatSessionHandler.CreateSession)
chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions)
```

bug fixed:

f80d4c7843/internal/handler/chat.go (L84)
↓
```go
result, err := h.chatService.ListChats(userID, "1", keywords, page, pageSize, orderby, desc)
```

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
This commit is contained in:
Haruko386
2026-06-26 19:23:45 +08:00
committed by GitHub
parent e3063da390
commit a57a841a11
6 changed files with 898 additions and 1 deletions

View File

@@ -169,6 +169,15 @@ func (dao *ChatDAO) GetExistingNames(tenantID string, status string) ([]string,
return names, err
}
// ExistsByNameTenantStatus checks whether a chat with the given name exists.
func (dao *ChatDAO) ExistsByNameTenantStatus(name, tenantID, status string) (bool, error) {
var count int64
err := DB.Model(&entity.Chat{}).
Where("name = ? AND tenant_id = ? AND status = ?", name, tenantID, status).
Count(&count).Error
return count > 0, err
}
// Create creates a new chat/dialog
func (dao *ChatDAO) Create(chat *entity.Chat) error {
return DB.Create(chat).Error

View File

@@ -17,6 +17,7 @@
package handler
import (
"encoding/json"
"net/http"
"ragflow/internal/common"
"strconv"
@@ -81,7 +82,7 @@ func (h *ChatHandler) ListChats(c *gin.Context) {
}
// List chats - default to valid status "1" (same as Python StatusEnum.VALID.value)
result, err := h.chatService.ListChats(userID, keywords, "1", page, pageSize, orderby, desc)
result, err := h.chatService.ListChats(userID, "1", keywords, page, pageSize, orderby, desc)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"code": 500,
@@ -97,6 +98,46 @@ func (h *ChatHandler) ListChats(c *gin.Context) {
})
}
// Create creates a chat.
// @Summary Create Chat
// @Description Create a chat, aligned with Python POST /api/v1/chats.
// @Tags chat
// @Accept json
// @Produce json
// @Param request body service.CreateChatRequest true "chat configuration"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/chats [post]
func (h *ChatHandler) Create(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
var req map[string]interface{}
decoder := json.NewDecoder(c.Request.Body)
decoder.UseNumber()
if err := decoder.Decode(&req); err != nil {
jsonError(c, common.CodeArgumentError, err.Error())
return
}
if req == nil {
req = map[string]interface{}{}
}
result, code, err := h.chatService.Create(user.ID, req)
if err != nil {
jsonError(c, code, err.Error())
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"data": result,
"message": "success",
})
}
// ListChatsNext list chats with advanced filtering and pagination
// @Summary List Chats Next
// @Description Get list of chats with filtering, pagination and sorting (equivalent to list_dialogs_next)

View File

@@ -17,11 +17,13 @@
package handler
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"ragflow/internal/common"
"strings"
"github.com/gin-gonic/gin"
@@ -349,6 +351,98 @@ func (h *ChatSessionHandler) GetSession(c *gin.Context) {
jsonResponse(c, common.CodeSuccess, result, "success")
}
// CreateSession create a session in a dialog
func (h *ChatSessionHandler) CreateSession(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := strings.TrimSpace(user.ID)
if userID == "" {
jsonError(c, common.CodeBadRequest, "user_id is required")
return
}
chatID := strings.TrimSpace(c.Param("chat_id"))
if chatID == "" {
jsonError(c, common.CodeBadRequest, "chat_id is required")
return
}
req := map[string]interface{}{}
if err := json.NewDecoder(c.Request.Body).Decode(&req); err != nil {
if errors.Is(err, io.EOF) {
req = map[string]interface{}{}
} else {
jsonError(c, common.CodeArgumentError, err.Error())
return
}
}
if req == nil {
req = map[string]interface{}{}
}
result, code, err := h.chatSessionService.CreateSession(userID, chatID, req)
if err != nil {
if code == common.CodeAuthenticationError {
jsonResponse(c, code, false, err.Error())
return
}
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
// DeleteSessions delete a session in a dialog
func (h *ChatSessionHandler) DeleteSessions(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
chatID := strings.TrimSpace(c.Param("chat_id"))
if chatID == "" {
jsonError(c, common.CodeBadRequest, "chat_id is required")
return
}
userID := strings.TrimSpace(user.ID)
if userID == "" {
jsonError(c, common.CodeBadRequest, "user_id is required")
return
}
req := map[string]interface{}{}
if err := json.NewDecoder(c.Request.Body).Decode(&req); err != nil {
if errors.Is(err, io.EOF) {
req = map[string]interface{}{}
} else {
jsonError(c, common.CodeArgumentError, err.Error())
return
}
}
if req == nil {
req = map[string]interface{}{}
}
result, message, code, err := h.chatSessionService.DeleteSessions(userID, chatID, req)
if err != nil {
if code == common.CodeAuthenticationError {
jsonResponse(c, code, false, err.Error())
return
}
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, message)
}
func (h *ChatSessionHandler) UpdateSession(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {

View File

@@ -253,12 +253,15 @@ func (r *Router) Setup(engine *gin.Engine) {
chats := v1.Group("/chats")
{
chats.GET("", r.chatHandler.ListChats)
chats.POST("", r.chatHandler.Create)
chats.DELETE("", r.chatHandler.BulkDeleteChats)
chats.DELETE("/:chat_id", r.chatHandler.DeleteChat)
chats.GET("/:chat_id", r.chatHandler.GetChat)
chats.PUT("/:chat_id", r.chatHandler.UpdateChat)
chats.PATCH("/:chat_id", r.chatHandler.PatchChat)
chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions)
chats.POST("/:chat_id/sessions", r.chatSessionHandler.CreateSession)
chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions)
chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession)
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
}

View File

@@ -17,6 +17,7 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"ragflow/internal/common"
@@ -28,6 +29,43 @@ import (
"ragflow/internal/dao"
)
var DefaultPromptConfig = PromptConfig{
System: strPtr(pyDefaultSystemPrompt),
Prologue: strPtr(pyDefaultPrologue),
Parameters: []ParameterConfig{
{Key: "knowledge", Optional: false},
},
EmptyResponse: strPtr(pyDefaultEmptyResponse),
Quote: boolPtr(true),
TTS: boolPtr(false),
RefineMultiturn: boolPtr(true),
}
var DefaultDirectChatPromptConfig = PromptConfig{
System: strPtr(""),
Prologue: strPtr(""),
Parameters: []ParameterConfig{},
EmptyResponse: strPtr(""),
Quote: boolPtr(false),
TTS: boolPtr(false),
RefineMultiturn: boolPtr(true),
}
var DefaultRerankModels = map[string]struct{}{
"BAAI/bge-reranker-v2-m3": {},
"maidalun1020/bce-reranker-base_v1": {},
}
var ReadOnlyFields = map[string]struct{}{
"id": {},
"tenant_id": {},
"created_by": {},
"create_time": {},
"create_date": {},
"update_time": {},
"update_date": {},
}
// ChatService chat service
type ChatService struct {
chatDAO *dao.ChatDAO
@@ -114,6 +152,488 @@ func (s *ChatService) ListChats(userID, status, keywords string, page, pageSize
}, nil
}
type CreateChatRequest struct {
Name string
DatasetIDs []string `json:"dataset_ids"`
KBIDs []string `json:"kb_ids"`
LLMID *string `json:"llm_id"`
LLMSetting map[string]interface{} `json:"llm_setting"`
RerankID *string `json:"rerank_id"`
PromptConfig map[string]interface{} `json:"prompt_config"`
Description *string
TopN *int
TopK *int
SimilarityThreshold *float64
VectorSimilarityWeight *float64
Icon *string
TenantID *string `json:"tenant_id"`
}
func (s *ChatService) Create(userID string, req map[string]interface{}) (map[string]interface{}, common.ErrorCode, error) {
tenant, err := s.tenantDAO.GetByID(userID)
if err != nil {
return nil, common.CodeDataError, errors.New("Tenant not found!")
}
if tenantValue, ok := req["tenant_id"]; ok && isTruthy(tenantValue) {
return nil, common.CodeDataError, errors.New("`tenant_id` must not be provided.")
}
name, err := validateCreateChatName(req["name"])
if err != nil {
return nil, common.CodeDataError, err
}
req["name"] = name
if datasetIDsValue, ok := req["dataset_ids"]; ok {
kbIDs, err := s.validateCreateDatasetIDs(datasetIDsValue, userID)
if err != nil {
return nil, common.CodeDataError, err
}
req["kb_ids"] = kbIDs
delete(req, "dataset_ids")
}
if llmIDValue, ok := req["llm_id"]; ok {
llmID := stringFromValue(llmIDValue)
llmSetting, _ := mapFromValue(req["llm_setting"])
if err = validateCreateLLMID(llmID, userID, llmSetting); err != nil {
return nil, common.CodeDataError, err
}
}
if rerankIDValue, ok := req["rerank_id"]; ok {
rerankID := stringFromValue(rerankIDValue)
if err = validateCreateRerankID(rerankID, userID); err != nil {
return nil, common.CodeDataError, err
}
}
if promptConfigValue, ok := req["prompt_config"]; ok {
if _, ok := mapFromValue(promptConfigValue); !ok {
return nil, common.CodeDataError, errors.New("`prompt_config` should be an object.")
}
}
if _, ok := req["kb_ids"]; !ok {
req["kb_ids"] = []string{}
}
if _, ok := req["llm_id"]; !ok || req["llm_id"] == nil {
req["llm_id"] = tenant.LLMID
}
if _, ok := req["llm_setting"]; !ok || req["llm_setting"] == nil {
req["llm_setting"] = map[string]interface{}{}
}
if _, ok := req["description"]; !ok {
req["description"] = "A helpful Assistant"
}
if _, ok := req["top_n"]; !ok {
req["top_n"] = 6
}
if _, ok := req["top_k"]; !ok {
req["top_k"] = 1024
}
if _, ok := req["rerank_id"]; !ok {
req["rerank_id"] = ""
}
if _, ok := req["similarity_threshold"]; !ok {
req["similarity_threshold"] = 0.1
}
if _, ok := req["vector_similarity_weight"]; !ok {
req["vector_similarity_weight"] = 0.3
}
if _, ok := req["icon"]; !ok {
req["icon"] = ""
}
applyCreatePromptDefaults(req)
filterCreateChatPersistedFields(req)
exists, err := s.chatDAO.ExistsByNameTenantStatus(name, userID, string(entity.StatusValid))
if err != nil {
return nil, common.CodeServerError, err
}
if exists {
return nil, common.CodeDataError, errors.New("Duplicated chat name in creating chat.")
}
chat := buildCreateChatEntity(req, userID)
if err = s.chatDAO.Create(chat); err != nil {
return nil, common.CodeDataError, errors.New("Failed to create chat.")
}
chat, err = s.chatDAO.GetByID(chat.ID)
if err != nil {
return nil, common.CodeDataError, errors.New("Failed to retrieve created chat.")
}
response, err := s.buildCreateChatResponse(chat)
if err != nil {
return nil, common.CodeServerError, err
}
return response, common.CodeSuccess, nil
}
func validateCreateChatName(value interface{}) (string, error) {
if value == nil {
return "", errors.New("`name` is required.")
}
name, ok := value.(string)
if !ok {
return "", errors.New("Chat name must be a string.")
}
name = strings.TrimSpace(name)
if name == "" {
return "", errors.New("`name` is required.")
}
if len([]byte(name)) > 255 {
return "", fmt.Errorf("Chat name length is %d which is larger than 255.", len([]byte(name)))
}
return name, nil
}
func (s *ChatService) validateCreateDatasetIDs(value interface{}, tenantID string) ([]string, error) {
if value == nil {
return []string{}, nil
}
values, ok := listFromValue(value)
if !ok {
return nil, errors.New("`dataset_ids` should be a list.")
}
normalizedIDs := make([]string, 0, len(values))
kbs := make([]*entity.Knowledgebase, 0, len(values))
for _, item := range values {
if !isTruthy(item) {
continue
}
datasetID := stringFromValue(item)
normalizedIDs = append(normalizedIDs, datasetID)
}
for _, datasetID := range normalizedIDs {
if !s.kbDAO.Accessible(datasetID, tenantID) {
return nil, fmt.Errorf("You don't own the dataset %s", datasetID)
}
kb, err := s.kbDAO.GetByID(datasetID)
if err != nil {
return nil, fmt.Errorf("You don't own the dataset %s", datasetID)
}
if kb.ChunkNum == 0 {
return nil, fmt.Errorf("The dataset %s doesn't own parsed file", datasetID)
}
kbs = append(kbs, kb)
}
embedIDs := make(map[string]struct{}, len(kbs))
for _, kb := range kbs {
embedIDs[s.splitModelNameAndFactory(kb.EmbdID)] = struct{}{}
}
if len(embedIDs) > 1 {
return nil, fmt.Errorf("Datasets use different embedding models: %v", getEmbdIDs(kbs))
}
return normalizedIDs, nil
}
func validateCreateLLMID(llmID, tenantID string, llmSetting map[string]interface{}) error {
if llmID == "" {
return nil
}
modelType := entity.ModelTypeChat
switch confModelType := llmSetting["model_type"].(type) {
case string:
if confModelType == string(entity.ModelTypeImage2Text) {
modelType = entity.ModelTypeImage2Text
}
case []interface{}:
for _, item := range confModelType {
if item == string(entity.ModelTypeImage2Text) {
modelType = entity.ModelTypeImage2Text
break
}
}
case []string:
for _, item := range confModelType {
if item == string(entity.ModelTypeImage2Text) {
modelType = entity.ModelTypeImage2Text
break
}
}
}
if _, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, modelType, llmID); err != nil {
return fmt.Errorf("`llm_id` %s doesn't exist", llmID)
}
return nil
}
func validateCreateRerankID(rerankID, tenantID string) error {
if rerankID == "" {
return nil
}
llmName := strings.Split(rerankID, "@")[0]
if _, ok := DefaultRerankModels[llmName]; ok {
return nil
}
if _, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, entity.ModelTypeRerank, rerankID); err != nil {
return fmt.Errorf("`rerank_id` %s doesn't exist", rerankID)
}
return nil
}
func applyCreatePromptDefaults(req map[string]interface{}) {
promptConfig, _ := mapFromValue(req["prompt_config"])
if promptConfig == nil {
promptConfig = map[string]interface{}{}
}
if system, ok := promptConfig["system"]; !ok || !isTruthy(system) {
promptConfig["system"] = pyDefaultSystemPrompt
}
if _, ok := promptConfig["prologue"]; !ok {
promptConfig["prologue"] = pyDefaultPrologue
}
if _, ok := promptConfig["parameters"]; !ok {
promptConfig["parameters"] = []interface{}{map[string]interface{}{"key": "knowledge", "optional": false}}
}
if _, ok := promptConfig["empty_response"]; !ok {
promptConfig["empty_response"] = pyDefaultEmptyResponse
}
if _, ok := promptConfig["quote"]; !ok {
promptConfig["quote"] = true
}
if _, ok := promptConfig["tts"]; !ok {
promptConfig["tts"] = false
}
if _, ok := promptConfig["refine_multiturn"]; !ok {
promptConfig["refine_multiturn"] = true
}
kbIDs, _ := listFromValue(req["kb_ids"])
system, _ := promptConfig["system"].(string)
if len(kbIDs) > 0 && !isTruthy(promptConfig["parameters"]) && strings.Contains(system, "{knowledge}") {
promptConfig["parameters"] = []interface{}{map[string]interface{}{"key": "knowledge", "optional": false}}
}
req["prompt_config"] = promptConfig
}
func filterCreateChatPersistedFields(req map[string]interface{}) {
persisted := map[string]struct{}{
"name": {}, "description": {}, "icon": {}, "language": {}, "llm_id": {}, "tenant_llm_id": {},
"llm_setting": {}, "prompt_type": {}, "prompt_config": {}, "meta_data_filter": {},
"similarity_threshold": {}, "vector_similarity_weight": {}, "top_n": {}, "top_k": {},
"do_refer": {}, "rerank_id": {}, "tenant_rerank_id": {}, "kb_ids": {}, "status": {},
}
for key := range req {
if _, ok := persisted[key]; !ok {
delete(req, key)
}
}
for key := range ReadOnlyFields {
delete(req, key)
}
}
func buildCreateChatEntity(req map[string]interface{}, tenantID string) *entity.Chat {
name := stringFromValue(req["name"])
description := stringFromValue(req["description"])
icon := stringFromValue(req["icon"])
llmID := stringFromValue(req["llm_id"])
rerankID := stringFromValue(req["rerank_id"])
llmSetting, _ := mapFromValue(req["llm_setting"])
promptConfig, _ := mapFromValue(req["prompt_config"])
kbIDs, _ := stringListFromValue(req["kb_ids"])
kbIDsJSON := make(entity.JSONSlice, 0, len(kbIDs))
for _, id := range kbIDs {
kbIDsJSON = append(kbIDsJSON, id)
}
status, hasStatus := req["status"]
statusValue := string(entity.StatusValid)
if hasStatus {
statusValue = stringFromValue(status)
}
chat := &entity.Chat{
ID: common.GenerateUUID(),
TenantID: tenantID,
Name: &name,
Description: &description,
Icon: &icon,
LLMID: llmID,
LLMSetting: entity.JSONMap(llmSetting),
PromptType: stringFromValue(req["prompt_type"]),
PromptConfig: entity.JSONMap(promptConfig),
SimilarityThreshold: floatFromValue(req["similarity_threshold"]),
VectorSimilarityWeight: floatFromValue(req["vector_similarity_weight"]),
TopN: int64FromValue(req["top_n"]),
TopK: int64FromValue(req["top_k"]),
DoRefer: stringFromValue(req["do_refer"]),
RerankID: rerankID,
KBIDs: kbIDsJSON,
Status: &statusValue,
}
if chat.PromptType == "" {
chat.PromptType = "simple"
}
if chat.DoRefer == "" {
chat.DoRefer = "1"
}
if language := stringFromValue(req["language"]); language != "" {
chat.Language = &language
}
if metaDataFilter, ok := mapFromValue(req["meta_data_filter"]); ok {
metaDataFilterJSON := entity.JSONMap(metaDataFilter)
chat.MetaDataFilter = &metaDataFilterJSON
}
return chat
}
func (s *ChatService) buildCreateChatResponse(chat *entity.Chat) (map[string]interface{}, error) {
data, err := structToMap(chat)
if err != nil {
return nil, err
}
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
data["dataset_ids"] = datasetIDs
delete(data, "kb_ids")
data["kb_names"] = kbNames
return data, nil
}
func structToMap(value interface{}) (map[string]interface{}, error) {
bytes, err := json.Marshal(value)
if err != nil {
return nil, err
}
result := map[string]interface{}{}
if err = json.Unmarshal(bytes, &result); err != nil {
return nil, err
}
return result, nil
}
func stringFromValue(value interface{}) string {
switch typed := value.(type) {
case nil:
return ""
case string:
return typed
default:
return fmt.Sprint(typed)
}
}
func mapFromValue(value interface{}) (map[string]interface{}, bool) {
switch typed := value.(type) {
case nil:
return nil, false
case map[string]interface{}:
return typed, true
case entity.JSONMap:
return map[string]interface{}(typed), true
default:
return nil, false
}
}
func listFromValue(value interface{}) ([]interface{}, bool) {
switch typed := value.(type) {
case nil:
return nil, false
case []interface{}:
return typed, true
case []string:
result := make([]interface{}, 0, len(typed))
for _, item := range typed {
result = append(result, item)
}
return result, true
case entity.JSONSlice:
return []interface{}(typed), true
default:
return nil, false
}
}
func stringListFromValue(value interface{}) ([]string, bool) {
values, ok := listFromValue(value)
if !ok {
return nil, false
}
result := make([]string, 0, len(values))
for _, item := range values {
if !isTruthy(item) {
continue
}
result = append(result, stringFromValue(item))
}
return result, true
}
func int64FromValue(value interface{}) int64 {
switch typed := value.(type) {
case int:
return int64(typed)
case int64:
return typed
case float64:
return int64(typed)
case json.Number:
n, err := typed.Int64()
if err == nil {
return n
}
f, _ := typed.Float64()
return int64(f)
default:
return 0
}
}
func floatFromValue(value interface{}) float64 {
switch typed := value.(type) {
case float64:
return typed
case float32:
return float64(typed)
case int:
return float64(typed)
case int64:
return float64(typed)
case json.Number:
n, _ := typed.Float64()
return n
default:
return 0
}
}
func isTruthy(value interface{}) bool {
switch typed := value.(type) {
case nil:
return false
case bool:
return typed
case string:
return typed != ""
case int:
return typed != 0
case int64:
return typed != 0
case float64:
return typed != 0
case json.Number:
n, err := typed.Float64()
return err != nil || n != 0
case []interface{}:
return len(typed) > 0
case []string:
return len(typed) > 0
case map[string]interface{}:
return len(typed) > 0
default:
return true
}
}
// ListChatsNextRequest list chats next request
type ListChatsNextRequest struct {
OwnerIDs []string `json:"owner_ids,omitempty"`
@@ -1169,6 +1689,10 @@ func strPtr(s string) *string {
return &s
}
func boolPtr(b bool) *bool {
return &b
}
// Helper to count UTF-8 characters (not bytes)
func (s *ChatService) countRunes(str string) int {
return utf8.RuneCountInString(str)

View File

@@ -22,9 +22,11 @@ import (
"errors"
"fmt"
"ragflow/internal/common"
"ragflow/internal/storage"
"strings"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"ragflow/internal/dao"
@@ -306,6 +308,230 @@ func (s *ChatSessionService) GetSession(userID, chatID, sessionID string) (*Chat
return s.buildSessionPayload(session, dialog, true), common.CodeSuccess, nil
}
// CreateSession create a session in a dialog
func (s *ChatSessionService) CreateSession(userID, chatID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
ok, err := s.ensureOwnedChat(userID, chatID)
if err != nil {
return nil, common.CodeServerError, err
}
if !ok {
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
}
dialog, err := s.chatSessionDAO.GetDialogByID(chatID)
if err != nil {
if isChatSessionNotFound(err) {
return nil, common.CodeDataError, errors.New("Chat not found!")
}
return nil, common.CodeServerError, err
}
name := "New session"
if rawName, exists := req["name"]; exists {
nameStr, ok := rawName.(string)
if !ok || strings.TrimSpace(nameStr) == "" {
return nil, common.CodeDataError, errors.New("`name` can not be empty.")
}
name = strings.TrimSpace(nameStr)
}
nameRunes := []rune(name)
if len(nameRunes) > 255 {
name = string(nameRunes[:255])
}
prologue := ""
if dialog.PromptConfig != nil {
if value, ok := dialog.PromptConfig["prologue"].(string); ok {
prologue = value
}
}
messagesJSON, _ := json.Marshal([]map[string]interface{}{
{
"role": "assistant",
"content": prologue,
},
})
referenceJSON, _ := json.Marshal([]interface{}{})
conv := &entity.ChatSession{
ID: common.GenerateUUID(),
DialogID: chatID,
Name: &name,
Message: messagesJSON,
UserID: &userID,
Reference: referenceJSON,
}
if err := s.chatSessionDAO.Create(conv); err != nil {
return nil, common.CodeDataError, errors.New("Fail to create a session!")
}
session, err := s.chatSessionDAO.GetByID(conv.ID)
if err != nil {
return nil, common.CodeDataError, errors.New("Fail to create a session!")
}
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
}
// DeleteSessions delete a session in a dialog
func (s *ChatSessionService) DeleteSessions(userID, chatID string, req map[string]interface{}) (interface{}, string, common.ErrorCode, error) {
ok, err := s.ensureOwnedChat(userID, chatID)
if err != nil {
return nil, "", common.CodeServerError, err
}
if !ok {
return false, "No authorization.", common.CodeAuthenticationError, errors.New("No authorization.")
}
if len(req) == 0 {
return map[string]interface{}{}, "success", common.CodeSuccess, nil
}
sessionIDs, hasIDs := stringSliceFromValue(req["ids"])
if !hasIDs || len(sessionIDs) == 0 {
deleteAll, _ := req["delete_all"].(bool)
if deleteAll {
sessions, err := s.chatSessionDAO.ListByChatID(chatID)
if err != nil {
return nil, "", common.CodeServerError, err
}
for _, session := range sessions {
sessionIDs = append(sessionIDs, session.ID)
}
if len(sessionIDs) == 0 {
return map[string]interface{}{}, "success", common.CodeSuccess, nil
}
} else {
return map[string]interface{}{}, "success", common.CodeSuccess, nil
}
}
uniqueIDs, duplicateMessages := checkDuplicateChatSessionIDs(sessionIDs)
errorsList := make([]string, 0)
successCount := 0
for _, sid := range uniqueIDs {
session, err := s.chatSessionDAO.GetBySessionIDAndChatID(sid, chatID)
if err != nil {
errorsList = append(errorsList, fmt.Sprintf("The chat doesn't own the session %s", sid))
continue
}
s.removeSessionUploadFiles(userID, session)
if err := s.chatSessionDAO.DeleteByID(sid); err != nil {
return nil, "", common.CodeServerError, err
}
successCount++
}
allErrors := append(errorsList, duplicateMessages...)
if len(allErrors) > 0 {
if successCount > 0 {
return map[string]interface{}{
"success_count": successCount,
"errors": allErrors,
}, fmt.Sprintf("Partially deleted %d sessions with %d errors", successCount, len(allErrors)), common.CodeSuccess, nil
}
return nil, "", common.CodeDataError, errors.New(strings.Join(allErrors, "; "))
}
return true, "success", common.CodeSuccess, nil
}
func stringSliceFromValue(value interface{}) ([]string, bool) {
var raw []interface{}
switch typed := value.(type) {
case []interface{}:
raw = typed
case []string:
raw = make([]interface{}, 0, len(typed))
for _, item := range typed {
raw = append(raw, item)
}
default:
return nil, false
}
ids := make([]string, 0, len(raw))
for _, item := range raw {
id, ok := item.(string)
if !ok {
continue
}
if strings.TrimSpace(id) == "" {
continue
}
ids = append(ids, id)
}
return ids, true
}
func (s *ChatSessionService) removeSessionUploadFiles(userID string, session *entity.ChatSession) {
messages := parseMessages(session.Message)
bucket := fmt.Sprintf("%s-downloads", userID)
storageImpl := storage.GetStorageFactory().GetStorage()
if storageImpl == nil {
common.Warn("storage is not initialized; skip chat upload cleanup", zap.String("bucket", bucket))
return
}
for _, msg := range messages {
files, ok := msg["files"].([]interface{})
if !ok {
continue
}
for _, item := range files {
file, ok := item.(map[string]interface{})
if !ok {
continue
}
fileID, ok := file["id"].(string)
if !ok || fileID == "" {
continue
}
if err := storageImpl.Remove(bucket, fileID); err != nil {
common.Warn("Failed to delete chat upload blob",
zap.String("bucket", bucket),
zap.String("file_id", fileID),
zap.Error(err),
)
}
}
}
}
func checkDuplicateChatSessionIDs(ids []string) ([]string, []string) {
idCount := make(map[string]int, len(ids))
uniqueIDs := make([]string, 0, len(ids))
for _, id := range ids {
id = strings.TrimSpace(id)
if id == "" {
continue
}
idCount[id]++
if idCount[id] == 1 {
uniqueIDs = append(uniqueIDs, id)
}
}
duplicateMessages := make([]string, 0)
for id, count := range idCount {
if count > 1 {
duplicateMessages = append(duplicateMessages, fmt.Sprintf("Duplicate session ids: %s", id))
}
}
return uniqueIDs, duplicateMessages
}
// UpdateSession updates one chat session after Python-style field validation.
func (s *ChatSessionService) UpdateSession(userID, chatID, sessionID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
ok, err := s.ensureOwnedChat(userID, chatID)