mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat[Go]: implement Create-Chat/Session, Delete-Session (#16386)
### What problem does this PR solve?
As title:
implement:
```go
chats.POST("", r.chatHandler.Create)
chats.POST("/:chat_id/sessions", r.chatSessionHandler.CreateSession)
chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions)
```
bug fixed:
f80d4c7843/internal/handler/chat.go (L84)
↓
```go
result, err := h.chatService.ListChats(userID, "1", keywords, page, pageSize, orderby, desc)
```
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
This commit is contained in:
@@ -169,6 +169,15 @@ func (dao *ChatDAO) GetExistingNames(tenantID string, status string) ([]string,
|
||||
return names, err
|
||||
}
|
||||
|
||||
// ExistsByNameTenantStatus checks whether a chat with the given name exists.
|
||||
func (dao *ChatDAO) ExistsByNameTenantStatus(name, tenantID, status string) (bool, error) {
|
||||
var count int64
|
||||
err := DB.Model(&entity.Chat{}).
|
||||
Where("name = ? AND tenant_id = ? AND status = ?", name, tenantID, status).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// Create creates a new chat/dialog
|
||||
func (dao *ChatDAO) Create(chat *entity.Chat) error {
|
||||
return DB.Create(chat).Error
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"strconv"
|
||||
@@ -81,7 +82,7 @@ func (h *ChatHandler) ListChats(c *gin.Context) {
|
||||
}
|
||||
|
||||
// List chats - default to valid status "1" (same as Python StatusEnum.VALID.value)
|
||||
result, err := h.chatService.ListChats(userID, keywords, "1", page, pageSize, orderby, desc)
|
||||
result, err := h.chatService.ListChats(userID, "1", keywords, page, pageSize, orderby, desc)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
@@ -97,6 +98,46 @@ func (h *ChatHandler) ListChats(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// Create creates a chat.
|
||||
// @Summary Create Chat
|
||||
// @Description Create a chat, aligned with Python POST /api/v1/chats.
|
||||
// @Tags chat
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body service.CreateChatRequest true "chat configuration"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/chats [post]
|
||||
func (h *ChatHandler) Create(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]interface{}
|
||||
decoder := json.NewDecoder(c.Request.Body)
|
||||
decoder.UseNumber()
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
jsonError(c, common.CodeArgumentError, err.Error())
|
||||
return
|
||||
}
|
||||
if req == nil {
|
||||
req = map[string]interface{}{}
|
||||
}
|
||||
|
||||
result, code, err := h.chatService.Create(user.ID, req)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": result,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// ListChatsNext list chats with advanced filtering and pagination
|
||||
// @Summary List Chats Next
|
||||
// @Description Get list of chats with filtering, pagination and sorting (equivalent to list_dialogs_next)
|
||||
|
||||
@@ -17,11 +17,13 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -349,6 +351,98 @@ func (h *ChatSessionHandler) GetSession(c *gin.Context) {
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// CreateSession create a session in a dialog
|
||||
func (h *ChatSessionHandler) CreateSession(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
jsonError(c, common.CodeBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
chatID := strings.TrimSpace(c.Param("chat_id"))
|
||||
if chatID == "" {
|
||||
jsonError(c, common.CodeBadRequest, "chat_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
req := map[string]interface{}{}
|
||||
if err := json.NewDecoder(c.Request.Body).Decode(&req); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
req = map[string]interface{}{}
|
||||
} else {
|
||||
jsonError(c, common.CodeArgumentError, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
if req == nil {
|
||||
req = map[string]interface{}{}
|
||||
}
|
||||
|
||||
result, code, err := h.chatSessionService.CreateSession(userID, chatID, req)
|
||||
if err != nil {
|
||||
if code == common.CodeAuthenticationError {
|
||||
jsonResponse(c, code, false, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
// DeleteSessions delete a session in a dialog
|
||||
func (h *ChatSessionHandler) DeleteSessions(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
chatID := strings.TrimSpace(c.Param("chat_id"))
|
||||
if chatID == "" {
|
||||
jsonError(c, common.CodeBadRequest, "chat_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
if userID == "" {
|
||||
jsonError(c, common.CodeBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
req := map[string]interface{}{}
|
||||
if err := json.NewDecoder(c.Request.Body).Decode(&req); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
req = map[string]interface{}{}
|
||||
} else {
|
||||
jsonError(c, common.CodeArgumentError, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
if req == nil {
|
||||
req = map[string]interface{}{}
|
||||
}
|
||||
|
||||
result, message, code, err := h.chatSessionService.DeleteSessions(userID, chatID, req)
|
||||
if err != nil {
|
||||
if code == common.CodeAuthenticationError {
|
||||
jsonResponse(c, code, false, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, result, message)
|
||||
}
|
||||
|
||||
func (h *ChatSessionHandler) UpdateSession(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
|
||||
@@ -253,12 +253,15 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
chats := v1.Group("/chats")
|
||||
{
|
||||
chats.GET("", r.chatHandler.ListChats)
|
||||
chats.POST("", r.chatHandler.Create)
|
||||
chats.DELETE("", r.chatHandler.BulkDeleteChats)
|
||||
chats.DELETE("/:chat_id", r.chatHandler.DeleteChat)
|
||||
chats.GET("/:chat_id", r.chatHandler.GetChat)
|
||||
chats.PUT("/:chat_id", r.chatHandler.UpdateChat)
|
||||
chats.PATCH("/:chat_id", r.chatHandler.PatchChat)
|
||||
chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions)
|
||||
chats.POST("/:chat_id/sessions", r.chatSessionHandler.CreateSession)
|
||||
chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions)
|
||||
chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession)
|
||||
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
@@ -28,6 +29,43 @@ import (
|
||||
"ragflow/internal/dao"
|
||||
)
|
||||
|
||||
var DefaultPromptConfig = PromptConfig{
|
||||
System: strPtr(pyDefaultSystemPrompt),
|
||||
Prologue: strPtr(pyDefaultPrologue),
|
||||
Parameters: []ParameterConfig{
|
||||
{Key: "knowledge", Optional: false},
|
||||
},
|
||||
EmptyResponse: strPtr(pyDefaultEmptyResponse),
|
||||
Quote: boolPtr(true),
|
||||
TTS: boolPtr(false),
|
||||
RefineMultiturn: boolPtr(true),
|
||||
}
|
||||
|
||||
var DefaultDirectChatPromptConfig = PromptConfig{
|
||||
System: strPtr(""),
|
||||
Prologue: strPtr(""),
|
||||
Parameters: []ParameterConfig{},
|
||||
EmptyResponse: strPtr(""),
|
||||
Quote: boolPtr(false),
|
||||
TTS: boolPtr(false),
|
||||
RefineMultiturn: boolPtr(true),
|
||||
}
|
||||
|
||||
var DefaultRerankModels = map[string]struct{}{
|
||||
"BAAI/bge-reranker-v2-m3": {},
|
||||
"maidalun1020/bce-reranker-base_v1": {},
|
||||
}
|
||||
|
||||
var ReadOnlyFields = map[string]struct{}{
|
||||
"id": {},
|
||||
"tenant_id": {},
|
||||
"created_by": {},
|
||||
"create_time": {},
|
||||
"create_date": {},
|
||||
"update_time": {},
|
||||
"update_date": {},
|
||||
}
|
||||
|
||||
// ChatService chat service
|
||||
type ChatService struct {
|
||||
chatDAO *dao.ChatDAO
|
||||
@@ -114,6 +152,488 @@ func (s *ChatService) ListChats(userID, status, keywords string, page, pageSize
|
||||
}, nil
|
||||
}
|
||||
|
||||
type CreateChatRequest struct {
|
||||
Name string
|
||||
DatasetIDs []string `json:"dataset_ids"`
|
||||
KBIDs []string `json:"kb_ids"`
|
||||
LLMID *string `json:"llm_id"`
|
||||
LLMSetting map[string]interface{} `json:"llm_setting"`
|
||||
RerankID *string `json:"rerank_id"`
|
||||
PromptConfig map[string]interface{} `json:"prompt_config"`
|
||||
Description *string
|
||||
TopN *int
|
||||
TopK *int
|
||||
SimilarityThreshold *float64
|
||||
VectorSimilarityWeight *float64
|
||||
Icon *string
|
||||
TenantID *string `json:"tenant_id"`
|
||||
}
|
||||
|
||||
func (s *ChatService) Create(userID string, req map[string]interface{}) (map[string]interface{}, common.ErrorCode, error) {
|
||||
tenant, err := s.tenantDAO.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, errors.New("Tenant not found!")
|
||||
}
|
||||
|
||||
if tenantValue, ok := req["tenant_id"]; ok && isTruthy(tenantValue) {
|
||||
return nil, common.CodeDataError, errors.New("`tenant_id` must not be provided.")
|
||||
}
|
||||
|
||||
name, err := validateCreateChatName(req["name"])
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
req["name"] = name
|
||||
|
||||
if datasetIDsValue, ok := req["dataset_ids"]; ok {
|
||||
kbIDs, err := s.validateCreateDatasetIDs(datasetIDsValue, userID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
req["kb_ids"] = kbIDs
|
||||
delete(req, "dataset_ids")
|
||||
}
|
||||
|
||||
if llmIDValue, ok := req["llm_id"]; ok {
|
||||
llmID := stringFromValue(llmIDValue)
|
||||
llmSetting, _ := mapFromValue(req["llm_setting"])
|
||||
if err = validateCreateLLMID(llmID, userID, llmSetting); err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
}
|
||||
|
||||
if rerankIDValue, ok := req["rerank_id"]; ok {
|
||||
rerankID := stringFromValue(rerankIDValue)
|
||||
if err = validateCreateRerankID(rerankID, userID); err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
}
|
||||
|
||||
if promptConfigValue, ok := req["prompt_config"]; ok {
|
||||
if _, ok := mapFromValue(promptConfigValue); !ok {
|
||||
return nil, common.CodeDataError, errors.New("`prompt_config` should be an object.")
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := req["kb_ids"]; !ok {
|
||||
req["kb_ids"] = []string{}
|
||||
}
|
||||
if _, ok := req["llm_id"]; !ok || req["llm_id"] == nil {
|
||||
req["llm_id"] = tenant.LLMID
|
||||
}
|
||||
if _, ok := req["llm_setting"]; !ok || req["llm_setting"] == nil {
|
||||
req["llm_setting"] = map[string]interface{}{}
|
||||
}
|
||||
if _, ok := req["description"]; !ok {
|
||||
req["description"] = "A helpful Assistant"
|
||||
}
|
||||
if _, ok := req["top_n"]; !ok {
|
||||
req["top_n"] = 6
|
||||
}
|
||||
if _, ok := req["top_k"]; !ok {
|
||||
req["top_k"] = 1024
|
||||
}
|
||||
if _, ok := req["rerank_id"]; !ok {
|
||||
req["rerank_id"] = ""
|
||||
}
|
||||
if _, ok := req["similarity_threshold"]; !ok {
|
||||
req["similarity_threshold"] = 0.1
|
||||
}
|
||||
if _, ok := req["vector_similarity_weight"]; !ok {
|
||||
req["vector_similarity_weight"] = 0.3
|
||||
}
|
||||
if _, ok := req["icon"]; !ok {
|
||||
req["icon"] = ""
|
||||
}
|
||||
|
||||
applyCreatePromptDefaults(req)
|
||||
filterCreateChatPersistedFields(req)
|
||||
|
||||
exists, err := s.chatDAO.ExistsByNameTenantStatus(name, userID, string(entity.StatusValid))
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if exists {
|
||||
return nil, common.CodeDataError, errors.New("Duplicated chat name in creating chat.")
|
||||
}
|
||||
|
||||
chat := buildCreateChatEntity(req, userID)
|
||||
if err = s.chatDAO.Create(chat); err != nil {
|
||||
return nil, common.CodeDataError, errors.New("Failed to create chat.")
|
||||
}
|
||||
|
||||
chat, err = s.chatDAO.GetByID(chat.ID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, errors.New("Failed to retrieve created chat.")
|
||||
}
|
||||
|
||||
response, err := s.buildCreateChatResponse(chat)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
return response, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func validateCreateChatName(value interface{}) (string, error) {
|
||||
if value == nil {
|
||||
return "", errors.New("`name` is required.")
|
||||
}
|
||||
name, ok := value.(string)
|
||||
if !ok {
|
||||
return "", errors.New("Chat name must be a string.")
|
||||
}
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return "", errors.New("`name` is required.")
|
||||
}
|
||||
if len([]byte(name)) > 255 {
|
||||
return "", fmt.Errorf("Chat name length is %d which is larger than 255.", len([]byte(name)))
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func (s *ChatService) validateCreateDatasetIDs(value interface{}, tenantID string) ([]string, error) {
|
||||
if value == nil {
|
||||
return []string{}, nil
|
||||
}
|
||||
values, ok := listFromValue(value)
|
||||
if !ok {
|
||||
return nil, errors.New("`dataset_ids` should be a list.")
|
||||
}
|
||||
|
||||
normalizedIDs := make([]string, 0, len(values))
|
||||
kbs := make([]*entity.Knowledgebase, 0, len(values))
|
||||
for _, item := range values {
|
||||
if !isTruthy(item) {
|
||||
continue
|
||||
}
|
||||
datasetID := stringFromValue(item)
|
||||
normalizedIDs = append(normalizedIDs, datasetID)
|
||||
}
|
||||
|
||||
for _, datasetID := range normalizedIDs {
|
||||
if !s.kbDAO.Accessible(datasetID, tenantID) {
|
||||
return nil, fmt.Errorf("You don't own the dataset %s", datasetID)
|
||||
}
|
||||
kb, err := s.kbDAO.GetByID(datasetID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("You don't own the dataset %s", datasetID)
|
||||
}
|
||||
if kb.ChunkNum == 0 {
|
||||
return nil, fmt.Errorf("The dataset %s doesn't own parsed file", datasetID)
|
||||
}
|
||||
kbs = append(kbs, kb)
|
||||
}
|
||||
|
||||
embedIDs := make(map[string]struct{}, len(kbs))
|
||||
for _, kb := range kbs {
|
||||
embedIDs[s.splitModelNameAndFactory(kb.EmbdID)] = struct{}{}
|
||||
}
|
||||
if len(embedIDs) > 1 {
|
||||
return nil, fmt.Errorf("Datasets use different embedding models: %v", getEmbdIDs(kbs))
|
||||
}
|
||||
return normalizedIDs, nil
|
||||
}
|
||||
|
||||
func validateCreateLLMID(llmID, tenantID string, llmSetting map[string]interface{}) error {
|
||||
if llmID == "" {
|
||||
return nil
|
||||
}
|
||||
modelType := entity.ModelTypeChat
|
||||
switch confModelType := llmSetting["model_type"].(type) {
|
||||
case string:
|
||||
if confModelType == string(entity.ModelTypeImage2Text) {
|
||||
modelType = entity.ModelTypeImage2Text
|
||||
}
|
||||
case []interface{}:
|
||||
for _, item := range confModelType {
|
||||
if item == string(entity.ModelTypeImage2Text) {
|
||||
modelType = entity.ModelTypeImage2Text
|
||||
break
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, item := range confModelType {
|
||||
if item == string(entity.ModelTypeImage2Text) {
|
||||
modelType = entity.ModelTypeImage2Text
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, modelType, llmID); err != nil {
|
||||
return fmt.Errorf("`llm_id` %s doesn't exist", llmID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateCreateRerankID(rerankID, tenantID string) error {
|
||||
if rerankID == "" {
|
||||
return nil
|
||||
}
|
||||
llmName := strings.Split(rerankID, "@")[0]
|
||||
if _, ok := DefaultRerankModels[llmName]; ok {
|
||||
return nil
|
||||
}
|
||||
if _, _, _, _, err := NewModelProviderService().GetModelConfigFromProviderInstance(tenantID, entity.ModelTypeRerank, rerankID); err != nil {
|
||||
return fmt.Errorf("`rerank_id` %s doesn't exist", rerankID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyCreatePromptDefaults(req map[string]interface{}) {
|
||||
promptConfig, _ := mapFromValue(req["prompt_config"])
|
||||
if promptConfig == nil {
|
||||
promptConfig = map[string]interface{}{}
|
||||
}
|
||||
if system, ok := promptConfig["system"]; !ok || !isTruthy(system) {
|
||||
promptConfig["system"] = pyDefaultSystemPrompt
|
||||
}
|
||||
if _, ok := promptConfig["prologue"]; !ok {
|
||||
promptConfig["prologue"] = pyDefaultPrologue
|
||||
}
|
||||
if _, ok := promptConfig["parameters"]; !ok {
|
||||
promptConfig["parameters"] = []interface{}{map[string]interface{}{"key": "knowledge", "optional": false}}
|
||||
}
|
||||
if _, ok := promptConfig["empty_response"]; !ok {
|
||||
promptConfig["empty_response"] = pyDefaultEmptyResponse
|
||||
}
|
||||
if _, ok := promptConfig["quote"]; !ok {
|
||||
promptConfig["quote"] = true
|
||||
}
|
||||
if _, ok := promptConfig["tts"]; !ok {
|
||||
promptConfig["tts"] = false
|
||||
}
|
||||
if _, ok := promptConfig["refine_multiturn"]; !ok {
|
||||
promptConfig["refine_multiturn"] = true
|
||||
}
|
||||
|
||||
kbIDs, _ := listFromValue(req["kb_ids"])
|
||||
system, _ := promptConfig["system"].(string)
|
||||
if len(kbIDs) > 0 && !isTruthy(promptConfig["parameters"]) && strings.Contains(system, "{knowledge}") {
|
||||
promptConfig["parameters"] = []interface{}{map[string]interface{}{"key": "knowledge", "optional": false}}
|
||||
}
|
||||
req["prompt_config"] = promptConfig
|
||||
}
|
||||
|
||||
func filterCreateChatPersistedFields(req map[string]interface{}) {
|
||||
persisted := map[string]struct{}{
|
||||
"name": {}, "description": {}, "icon": {}, "language": {}, "llm_id": {}, "tenant_llm_id": {},
|
||||
"llm_setting": {}, "prompt_type": {}, "prompt_config": {}, "meta_data_filter": {},
|
||||
"similarity_threshold": {}, "vector_similarity_weight": {}, "top_n": {}, "top_k": {},
|
||||
"do_refer": {}, "rerank_id": {}, "tenant_rerank_id": {}, "kb_ids": {}, "status": {},
|
||||
}
|
||||
for key := range req {
|
||||
if _, ok := persisted[key]; !ok {
|
||||
delete(req, key)
|
||||
}
|
||||
}
|
||||
for key := range ReadOnlyFields {
|
||||
delete(req, key)
|
||||
}
|
||||
}
|
||||
|
||||
func buildCreateChatEntity(req map[string]interface{}, tenantID string) *entity.Chat {
|
||||
name := stringFromValue(req["name"])
|
||||
description := stringFromValue(req["description"])
|
||||
icon := stringFromValue(req["icon"])
|
||||
llmID := stringFromValue(req["llm_id"])
|
||||
rerankID := stringFromValue(req["rerank_id"])
|
||||
llmSetting, _ := mapFromValue(req["llm_setting"])
|
||||
promptConfig, _ := mapFromValue(req["prompt_config"])
|
||||
kbIDs, _ := stringListFromValue(req["kb_ids"])
|
||||
kbIDsJSON := make(entity.JSONSlice, 0, len(kbIDs))
|
||||
for _, id := range kbIDs {
|
||||
kbIDsJSON = append(kbIDsJSON, id)
|
||||
}
|
||||
status, hasStatus := req["status"]
|
||||
statusValue := string(entity.StatusValid)
|
||||
if hasStatus {
|
||||
statusValue = stringFromValue(status)
|
||||
}
|
||||
|
||||
chat := &entity.Chat{
|
||||
ID: common.GenerateUUID(),
|
||||
TenantID: tenantID,
|
||||
Name: &name,
|
||||
Description: &description,
|
||||
Icon: &icon,
|
||||
LLMID: llmID,
|
||||
LLMSetting: entity.JSONMap(llmSetting),
|
||||
PromptType: stringFromValue(req["prompt_type"]),
|
||||
PromptConfig: entity.JSONMap(promptConfig),
|
||||
SimilarityThreshold: floatFromValue(req["similarity_threshold"]),
|
||||
VectorSimilarityWeight: floatFromValue(req["vector_similarity_weight"]),
|
||||
TopN: int64FromValue(req["top_n"]),
|
||||
TopK: int64FromValue(req["top_k"]),
|
||||
DoRefer: stringFromValue(req["do_refer"]),
|
||||
RerankID: rerankID,
|
||||
KBIDs: kbIDsJSON,
|
||||
Status: &statusValue,
|
||||
}
|
||||
if chat.PromptType == "" {
|
||||
chat.PromptType = "simple"
|
||||
}
|
||||
if chat.DoRefer == "" {
|
||||
chat.DoRefer = "1"
|
||||
}
|
||||
if language := stringFromValue(req["language"]); language != "" {
|
||||
chat.Language = &language
|
||||
}
|
||||
if metaDataFilter, ok := mapFromValue(req["meta_data_filter"]); ok {
|
||||
metaDataFilterJSON := entity.JSONMap(metaDataFilter)
|
||||
chat.MetaDataFilter = &metaDataFilterJSON
|
||||
}
|
||||
return chat
|
||||
}
|
||||
|
||||
func (s *ChatService) buildCreateChatResponse(chat *entity.Chat) (map[string]interface{}, error) {
|
||||
data, err := structToMap(chat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs)
|
||||
data["dataset_ids"] = datasetIDs
|
||||
delete(data, "kb_ids")
|
||||
data["kb_names"] = kbNames
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func structToMap(value interface{}) (map[string]interface{}, error) {
|
||||
bytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := map[string]interface{}{}
|
||||
if err = json.Unmarshal(bytes, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func stringFromValue(value interface{}) string {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return ""
|
||||
case string:
|
||||
return typed
|
||||
default:
|
||||
return fmt.Sprint(typed)
|
||||
}
|
||||
}
|
||||
|
||||
func mapFromValue(value interface{}) (map[string]interface{}, bool) {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return nil, false
|
||||
case map[string]interface{}:
|
||||
return typed, true
|
||||
case entity.JSONMap:
|
||||
return map[string]interface{}(typed), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func listFromValue(value interface{}) ([]interface{}, bool) {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return nil, false
|
||||
case []interface{}:
|
||||
return typed, true
|
||||
case []string:
|
||||
result := make([]interface{}, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
result = append(result, item)
|
||||
}
|
||||
return result, true
|
||||
case entity.JSONSlice:
|
||||
return []interface{}(typed), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func stringListFromValue(value interface{}) ([]string, bool) {
|
||||
values, ok := listFromValue(value)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
result := make([]string, 0, len(values))
|
||||
for _, item := range values {
|
||||
if !isTruthy(item) {
|
||||
continue
|
||||
}
|
||||
result = append(result, stringFromValue(item))
|
||||
}
|
||||
return result, true
|
||||
}
|
||||
|
||||
func int64FromValue(value interface{}) int64 {
|
||||
switch typed := value.(type) {
|
||||
case int:
|
||||
return int64(typed)
|
||||
case int64:
|
||||
return typed
|
||||
case float64:
|
||||
return int64(typed)
|
||||
case json.Number:
|
||||
n, err := typed.Int64()
|
||||
if err == nil {
|
||||
return n
|
||||
}
|
||||
f, _ := typed.Float64()
|
||||
return int64(f)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func floatFromValue(value interface{}) float64 {
|
||||
switch typed := value.(type) {
|
||||
case float64:
|
||||
return typed
|
||||
case float32:
|
||||
return float64(typed)
|
||||
case int:
|
||||
return float64(typed)
|
||||
case int64:
|
||||
return float64(typed)
|
||||
case json.Number:
|
||||
n, _ := typed.Float64()
|
||||
return n
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func isTruthy(value interface{}) bool {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return false
|
||||
case bool:
|
||||
return typed
|
||||
case string:
|
||||
return typed != ""
|
||||
case int:
|
||||
return typed != 0
|
||||
case int64:
|
||||
return typed != 0
|
||||
case float64:
|
||||
return typed != 0
|
||||
case json.Number:
|
||||
n, err := typed.Float64()
|
||||
return err != nil || n != 0
|
||||
case []interface{}:
|
||||
return len(typed) > 0
|
||||
case []string:
|
||||
return len(typed) > 0
|
||||
case map[string]interface{}:
|
||||
return len(typed) > 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// ListChatsNextRequest list chats next request
|
||||
type ListChatsNextRequest struct {
|
||||
OwnerIDs []string `json:"owner_ids,omitempty"`
|
||||
@@ -1169,6 +1689,10 @@ func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
// Helper to count UTF-8 characters (not bytes)
|
||||
func (s *ChatService) countRunes(str string) int {
|
||||
return utf8.RuneCountInString(str)
|
||||
|
||||
@@ -22,9 +22,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/storage"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
@@ -306,6 +308,230 @@ func (s *ChatSessionService) GetSession(userID, chatID, sessionID string) (*Chat
|
||||
return s.buildSessionPayload(session, dialog, true), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// CreateSession create a session in a dialog
|
||||
func (s *ChatSessionService) CreateSession(userID, chatID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
|
||||
ok, err := s.ensureOwnedChat(userID, chatID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
|
||||
}
|
||||
|
||||
dialog, err := s.chatSessionDAO.GetDialogByID(chatID)
|
||||
if err != nil {
|
||||
if isChatSessionNotFound(err) {
|
||||
return nil, common.CodeDataError, errors.New("Chat not found!")
|
||||
}
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
name := "New session"
|
||||
if rawName, exists := req["name"]; exists {
|
||||
nameStr, ok := rawName.(string)
|
||||
if !ok || strings.TrimSpace(nameStr) == "" {
|
||||
return nil, common.CodeDataError, errors.New("`name` can not be empty.")
|
||||
}
|
||||
name = strings.TrimSpace(nameStr)
|
||||
}
|
||||
nameRunes := []rune(name)
|
||||
if len(nameRunes) > 255 {
|
||||
name = string(nameRunes[:255])
|
||||
}
|
||||
|
||||
prologue := ""
|
||||
if dialog.PromptConfig != nil {
|
||||
if value, ok := dialog.PromptConfig["prologue"].(string); ok {
|
||||
prologue = value
|
||||
}
|
||||
}
|
||||
messagesJSON, _ := json.Marshal([]map[string]interface{}{
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": prologue,
|
||||
},
|
||||
})
|
||||
|
||||
referenceJSON, _ := json.Marshal([]interface{}{})
|
||||
|
||||
conv := &entity.ChatSession{
|
||||
ID: common.GenerateUUID(),
|
||||
DialogID: chatID,
|
||||
Name: &name,
|
||||
Message: messagesJSON,
|
||||
UserID: &userID,
|
||||
Reference: referenceJSON,
|
||||
}
|
||||
|
||||
if err := s.chatSessionDAO.Create(conv); err != nil {
|
||||
return nil, common.CodeDataError, errors.New("Fail to create a session!")
|
||||
}
|
||||
|
||||
session, err := s.chatSessionDAO.GetByID(conv.ID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, errors.New("Fail to create a session!")
|
||||
}
|
||||
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// DeleteSessions delete a session in a dialog
|
||||
func (s *ChatSessionService) DeleteSessions(userID, chatID string, req map[string]interface{}) (interface{}, string, common.ErrorCode, error) {
|
||||
ok, err := s.ensureOwnedChat(userID, chatID)
|
||||
if err != nil {
|
||||
return nil, "", common.CodeServerError, err
|
||||
}
|
||||
if !ok {
|
||||
return false, "No authorization.", common.CodeAuthenticationError, errors.New("No authorization.")
|
||||
}
|
||||
|
||||
if len(req) == 0 {
|
||||
return map[string]interface{}{}, "success", common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
sessionIDs, hasIDs := stringSliceFromValue(req["ids"])
|
||||
if !hasIDs || len(sessionIDs) == 0 {
|
||||
deleteAll, _ := req["delete_all"].(bool)
|
||||
if deleteAll {
|
||||
sessions, err := s.chatSessionDAO.ListByChatID(chatID)
|
||||
if err != nil {
|
||||
return nil, "", common.CodeServerError, err
|
||||
}
|
||||
for _, session := range sessions {
|
||||
sessionIDs = append(sessionIDs, session.ID)
|
||||
}
|
||||
if len(sessionIDs) == 0 {
|
||||
return map[string]interface{}{}, "success", common.CodeSuccess, nil
|
||||
}
|
||||
} else {
|
||||
return map[string]interface{}{}, "success", common.CodeSuccess, nil
|
||||
}
|
||||
}
|
||||
|
||||
uniqueIDs, duplicateMessages := checkDuplicateChatSessionIDs(sessionIDs)
|
||||
|
||||
errorsList := make([]string, 0)
|
||||
successCount := 0
|
||||
|
||||
for _, sid := range uniqueIDs {
|
||||
session, err := s.chatSessionDAO.GetBySessionIDAndChatID(sid, chatID)
|
||||
if err != nil {
|
||||
errorsList = append(errorsList, fmt.Sprintf("The chat doesn't own the session %s", sid))
|
||||
continue
|
||||
}
|
||||
|
||||
s.removeSessionUploadFiles(userID, session)
|
||||
|
||||
if err := s.chatSessionDAO.DeleteByID(sid); err != nil {
|
||||
return nil, "", common.CodeServerError, err
|
||||
}
|
||||
|
||||
successCount++
|
||||
}
|
||||
|
||||
allErrors := append(errorsList, duplicateMessages...)
|
||||
|
||||
if len(allErrors) > 0 {
|
||||
if successCount > 0 {
|
||||
return map[string]interface{}{
|
||||
"success_count": successCount,
|
||||
"errors": allErrors,
|
||||
}, fmt.Sprintf("Partially deleted %d sessions with %d errors", successCount, len(allErrors)), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
return nil, "", common.CodeDataError, errors.New(strings.Join(allErrors, "; "))
|
||||
}
|
||||
|
||||
return true, "success", common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func stringSliceFromValue(value interface{}) ([]string, bool) {
|
||||
var raw []interface{}
|
||||
switch typed := value.(type) {
|
||||
case []interface{}:
|
||||
raw = typed
|
||||
case []string:
|
||||
raw = make([]interface{}, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
raw = append(raw, item)
|
||||
}
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
id, ok := item.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(id) == "" {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, true
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) removeSessionUploadFiles(userID string, session *entity.ChatSession) {
|
||||
messages := parseMessages(session.Message)
|
||||
bucket := fmt.Sprintf("%s-downloads", userID)
|
||||
storageImpl := storage.GetStorageFactory().GetStorage()
|
||||
if storageImpl == nil {
|
||||
common.Warn("storage is not initialized; skip chat upload cleanup", zap.String("bucket", bucket))
|
||||
return
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
files, ok := msg["files"].([]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, item := range files {
|
||||
file, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
fileID, ok := file["id"].(string)
|
||||
if !ok || fileID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := storageImpl.Remove(bucket, fileID); err != nil {
|
||||
common.Warn("Failed to delete chat upload blob",
|
||||
zap.String("bucket", bucket),
|
||||
zap.String("file_id", fileID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkDuplicateChatSessionIDs(ids []string) ([]string, []string) {
|
||||
idCount := make(map[string]int, len(ids))
|
||||
uniqueIDs := make([]string, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
idCount[id]++
|
||||
if idCount[id] == 1 {
|
||||
uniqueIDs = append(uniqueIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
duplicateMessages := make([]string, 0)
|
||||
for id, count := range idCount {
|
||||
if count > 1 {
|
||||
duplicateMessages = append(duplicateMessages, fmt.Sprintf("Duplicate session ids: %s", id))
|
||||
}
|
||||
}
|
||||
return uniqueIDs, duplicateMessages
|
||||
}
|
||||
|
||||
// UpdateSession updates one chat session after Python-style field validation.
|
||||
func (s *ChatSessionService) UpdateSession(userID, chatID, sessionID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
|
||||
ok, err := s.ensureOwnedChat(userID, chatID)
|
||||
|
||||
Reference in New Issue
Block a user