Files
ragflow/internal/service/chat_session.go
Haruko386 a57a841a11 feat[Go]: implement Create-Chat/Session, Delete-Session (#16386)
### What problem does this PR solve?

As the title says, this PR implements:
```go
chats.POST("", r.chatHandler.Create)
chats.POST("/:chat_id/sessions", r.chatSessionHandler.CreateSession)
chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions)
```

Bug fixed:

f80d4c7843/internal/handler/chat.go (L84)
↓
```go
result, err := h.chatService.ListChats(userID, "1", keywords, page, pageSize, orderby, desc)
```

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
2026-06-26 19:23:45 +08:00

1080 lines
31 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"ragflow/internal/common"
"ragflow/internal/storage"
"strings"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"ragflow/internal/dao"
"ragflow/internal/entity"
)
// Interfaces for testability — satisfied by the concrete DAO/pipeline types.
//
// chatSessionStore is the persistence surface ChatSessionService needs for
// chat sessions and their parent chats (dialogs). NewChatSessionService wires
// it to dao.NewChatSessionDAO(). The per-method notes below describe how this
// file uses each call; the DAO implementation is authoritative for the exact
// query semantics.
type chatSessionStore interface {
// GetByID loads one session by ID. Callers in this file treat
// gorm.ErrRecordNotFound as "session not found".
GetByID(id string) (*entity.ChatSession, error)
// GetBySessionIDAndChatID loads a session scoped to a chat; used to verify
// that a session belongs to the chat named in the request.
GetBySessionIDAndChatID(sessionID, chatID string) (*entity.ChatSession, error)
// Create persists a new session.
Create(conv *entity.ChatSession) error
// UpdateByID applies the given field updates to one session.
UpdateByID(id string, updates map[string]interface{}) error
// DeleteByID removes one session.
DeleteByID(id string) error
// ListByChatID returns the sessions of one chat.
ListByChatID(chatID string) ([]*entity.ChatSession, error)
// GetDialogByID loads the chat (dialog) record.
GetDialogByID(chatID string) (*entity.Chat, error)
// CheckDialogExists is used for ownership checks: the service calls it
// with each tenant ID of the user (and the user ID itself) for a chat ID.
CheckDialogExists(tenantID, chatID string) (bool, error)
}
// userTenantStore resolves which tenants a user is associated with. The
// service uses the returned IDs as candidate owners when checking whether a
// user may access a chat. NewChatSessionService wires it to
// dao.NewUserTenantDAO().
type userTenantStore interface {
GetTenantIDsByUserID(userID string) ([]string, error)
}
// chatPipelineRunner abstracts the chat pipeline that produces answers.
// AsyncChat returns a channel of AsyncChatResult values; this file drains that
// channel with range until it is closed. The stream flag is passed as false by
// Completion and true by CompletionStream.
type chatPipelineRunner interface {
AsyncChat(ctx context.Context, chat *entity.Chat, messages []map[string]interface{}, stream bool, kwargs map[string]interface{}) (<-chan AsyncChatResult, error)
}
// ChatSessionService is the chat session (conversation) service.
// The RAG pipeline is delegated to ChatPipelineService.
//
// All collaborators are held behind small interfaces so tests can substitute
// fakes for the DAOs and the pipeline.
type ChatSessionService struct {
// chatSessionDAO reads and writes sessions and looks up chats (dialogs).
chatSessionDAO chatSessionStore
// userTenantDAO resolves the tenant IDs used for ownership checks.
userTenantDAO userTenantStore
// pipeline runs the chat completion and streams results back.
pipeline chatPipelineRunner
}
// NewChatSessionService builds a ChatSessionService wired to the concrete
// session DAO, user-tenant DAO and chat pipeline service.
func NewChatSessionService() *ChatSessionService {
	svc := new(ChatSessionService)
	svc.chatSessionDAO = dao.NewChatSessionDAO()
	svc.userTenantDAO = dao.NewUserTenantDAO()
	svc.pipeline = NewChatPipelineService()
	return svc
}
// SetChatSessionRequest is the input of SetChatSession.
//
// When IsNew is true a session is created under DialogID; otherwise the
// session identified by SessionID is updated.
type SetChatSessionRequest struct {
// SessionID is the session to update (JSON key "conversation_id"); only read when IsNew is false.
SessionID string `json:"conversation_id,omitempty"`
// DialogID is the chat the new session belongs to; only read when IsNew is true.
DialogID string `json:"dialog_id,omitempty"`
// Name is the session name; SetChatSession substitutes a default when it is empty.
Name string `json:"name,omitempty"`
// IsNew selects between create (true) and update (false).
IsNew bool `json:"is_new"`
}
// SetChatSessionResponse is the result of SetChatSession. It embeds the
// created or updated session, so the session's fields are promoted onto the
// response.
type SetChatSessionResponse struct {
*entity.ChatSession
}
// SetChatSession creates a new chat session or renames an existing one.
//
// The session name defaults to "New chat session" when req.Name is empty and
// is capped at 255 characters.
//
// When req.IsNew is false, the session identified by req.SessionID gets its
// name and user_id updated and is re-read from the store. When req.IsNew is
// true, a session is created under req.DialogID with a single assistant
// message holding the dialog's prologue (or a built-in greeting when the
// dialog has none) and an empty reference list.
//
// Store errors are reported as fixed, user-facing messages rather than the
// underlying error.
func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRequest) (*SetChatSessionResponse, error) {
	name := req.Name
	if name == "" {
		name = "New chat session"
	}
	// Limit the name to 255 characters. Truncate on rune boundaries so a
	// multi-byte UTF-8 character is never cut in half (byte slicing would
	// produce invalid UTF-8); this matches CreateSession and UpdateSession.
	if runes := []rune(name); len(runes) > 255 {
		name = string(runes[:255])
	}

	if !req.IsNew {
		// Update the existing chat session.
		updates := map[string]interface{}{
			"name":    name,
			"user_id": userID,
		}
		if err := s.chatSessionDAO.UpdateByID(req.SessionID, updates); err != nil {
			return nil, errors.New("Chat session not found")
		}
		// Re-read so the response reflects what is stored.
		session, err := s.chatSessionDAO.GetByID(req.SessionID)
		if err != nil {
			return nil, errors.New("Fail to update a chat session")
		}
		return &SetChatSessionResponse{ChatSession: session}, nil
	}

	// Create a new chat session; the dialog must exist.
	dialog, err := s.chatSessionDAO.GetDialogByID(req.DialogID)
	if err != nil {
		return nil, errors.New("Dialog not found")
	}

	// Use the prologue from the dialog's prompt_config when it is a non-empty string.
	prologue := "Hi! I'm your assistant. What can I do for you?"
	if dialog.PromptConfig != nil {
		if p, ok := dialog.PromptConfig["prologue"].(string); ok && p != "" {
			prologue = p
		}
	}

	// Store messages in the same list shape as Python Conversation.message.
	messagesJSON, err := json.Marshal([]map[string]interface{}{
		{
			"role":    "assistant",
			"content": prologue,
		},
	})
	if err != nil {
		return nil, errors.New("Fail to create a chat session")
	}
	// The reference is stored as an (initially empty) JSON array.
	referenceJSON, err := json.Marshal([]interface{}{})
	if err != nil {
		return nil, errors.New("Fail to create a chat session")
	}

	session := &entity.ChatSession{
		ID:        common.GenerateUUID(),
		DialogID:  req.DialogID,
		Name:      &name,
		Message:   messagesJSON,
		UserID:    &userID,
		Reference: referenceJSON,
	}
	if err := s.chatSessionDAO.Create(session); err != nil {
		return nil, errors.New("Fail to create a chat session")
	}
	return &SetChatSessionResponse{ChatSession: session}, nil
}
// RemoveChatSessions deletes the given chat sessions one by one.
//
// A session may be deleted only if its dialog exists under the user's own ID
// or under one of the tenant IDs returned for the user. Processing stops at
// the first failure, so sessions earlier in the list may already have been
// deleted when an error is returned.
func (s *ChatSessionService) RemoveChatSessions(userID string, chatSessions []string) error {
	tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
	if err != nil {
		return err
	}

	// Candidate owners: every tenant of the user plus the user itself.
	owners := make(map[string]bool)
	owners[userID] = true
	for _, id := range tenantIDs {
		owners[id] = true
	}

	// ownsDialog reports whether any candidate owner has the dialog.
	ownsDialog := func(dialogID string) (bool, error) {
		for owner := range owners {
			found, checkErr := s.chatSessionDAO.CheckDialogExists(owner, dialogID)
			if checkErr != nil {
				return false, checkErr
			}
			if found {
				return true, nil
			}
		}
		return false, nil
	}

	for _, sessionID := range chatSessions {
		session, getErr := s.chatSessionDAO.GetByID(sessionID)
		if getErr != nil {
			return fmt.Errorf("Chat session not found: %s", sessionID)
		}

		owned, ownErr := ownsDialog(session.DialogID)
		if ownErr != nil {
			return ownErr
		}
		if !owned {
			return errors.New("Only owner of chat session authorized for this operation")
		}

		if delErr := s.chatSessionDAO.DeleteByID(sessionID); delErr != nil {
			return delErr
		}
	}
	return nil
}
// ListChatSessionsRequest is the request body for listing the sessions of a
// dialog. DialogID is mandatory (binding:"required").
type ListChatSessionsRequest struct {
DialogID string `json:"dialog_id" binding:"required"`
}
// ListChatSessionsResponse is the result of ListChatSessions: the sessions
// returned by the store for the requested chat.
type ListChatSessionsResponse struct {
Sessions []*entity.ChatSession
}
// ChatSessionPayload is the API representation of a chat session, produced by
// buildSessionPayload from a stored entity.ChatSession.
type ChatSessionPayload struct {
ID string `json:"id"`
// ChatID is filled from the session's DialogID.
ChatID string `json:"chat_id"`
Name *string `json:"name,omitempty"`
// Messages is the decoded form of the session's stored message JSON.
Messages []map[string]interface{} `json:"messages"`
// Reference is the decoded reference list with each entry's "chunks" formatted.
Reference []interface{} `json:"reference"`
UserID *string `json:"user_id,omitempty"`
// Avatar is set only when the payload is built with includeAvatar; it is
// the dialog's icon, or an empty string when the dialog has none.
Avatar *string `json:"avatar,omitempty"`
CreateDate *time.Time `json:"create_date,omitempty"`
UpdateDate *time.Time `json:"update_date,omitempty"`
CreateTime *int64 `json:"create_time,omitempty"`
UpdateTime *int64 `json:"update_time,omitempty"`
}
// ListChatSessions lists the chat sessions of a dialog.
//
// The caller must own the dialog: it has to exist under one of the user's
// tenant IDs or under the user's own ID. The ownership rule is shared with
// the other session endpoints through ensureOwnedChat, so it is no longer
// duplicated here.
//
// Returns the sessions of chatID, an authorization error when the user does
// not own the dialog, or the underlying store error.
func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*ListChatSessionsResponse, error) {
	owned, err := s.ensureOwnedChat(userID, chatID)
	if err != nil {
		return nil, err
	}
	if !owned {
		return nil, errors.New("only owner of dialog authorized for this operation")
	}

	sessions, err := s.chatSessionDAO.ListByChatID(chatID)
	if err != nil {
		return nil, err
	}
	return &ListChatSessionsResponse{Sessions: sessions}, nil
}
// GetSession returns a single chat session once the caller's ownership of the
// chat has been confirmed.
//
// Error codes: CodeAuthenticationError when the user does not own the chat,
// CodeDataError when the session is missing or belongs to a different chat,
// and CodeServerError for any other store failure. A missing dialog record is
// tolerated; the payload is then built without an icon.
func (s *ChatSessionService) GetSession(userID, chatID, sessionID string) (*ChatSessionPayload, common.ErrorCode, error) {
	owned, err := s.ensureOwnedChat(userID, chatID)
	switch {
	case err != nil:
		return nil, common.CodeServerError, err
	case !owned:
		return nil, common.CodeAuthenticationError, errors.New("No authorization.")
	}

	session, err := s.chatSessionDAO.GetByID(sessionID)
	if isChatSessionNotFound(err) {
		return nil, common.CodeDataError, errors.New("Session not found!")
	}
	if err != nil {
		return nil, common.CodeServerError, err
	}
	if session.DialogID != chatID {
		return nil, common.CodeDataError, errors.New("Session does not belong to this chat!")
	}

	// The dialog is only needed for its icon, so "not found" is not fatal.
	dialog, err := s.chatSessionDAO.GetDialogByID(chatID)
	if err != nil && !isChatSessionNotFound(err) {
		return nil, common.CodeServerError, err
	}

	return s.buildSessionPayload(session, dialog, true), common.CodeSuccess, nil
}
// CreateSession creates a session inside a chat (dialog) the caller owns.
//
// The optional req["name"] must be a non-blank string when present; it is
// trimmed and capped at 255 characters, and defaults to "New session". The new
// session starts with one assistant message carrying the dialog's prologue
// (empty when the dialog defines none) and an empty reference list. The stored
// row is re-read before the payload is built.
func (s *ChatSessionService) CreateSession(userID, chatID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
	owned, err := s.ensureOwnedChat(userID, chatID)
	switch {
	case err != nil:
		return nil, common.CodeServerError, err
	case !owned:
		return nil, common.CodeAuthenticationError, errors.New("No authorization.")
	}

	dialog, err := s.chatSessionDAO.GetDialogByID(chatID)
	if isChatSessionNotFound(err) {
		return nil, common.CodeDataError, errors.New("Chat not found!")
	}
	if err != nil {
		return nil, common.CodeServerError, err
	}

	// Resolve the session name.
	name := "New session"
	if raw, given := req["name"]; given {
		text, isString := raw.(string)
		text = strings.TrimSpace(text)
		if !isString || text == "" {
			return nil, common.CodeDataError, errors.New("`name` can not be empty.")
		}
		name = text
	}
	if runes := []rune(name); len(runes) > 255 {
		name = string(runes[:255])
	}

	// A non-string or missing prologue leaves the greeting empty.
	var prologue string
	if dialog.PromptConfig != nil {
		prologue, _ = dialog.PromptConfig["prologue"].(string)
	}

	greeting := map[string]interface{}{
		"role":    "assistant",
		"content": prologue,
	}
	messagesJSON, _ := json.Marshal([]map[string]interface{}{greeting})
	referenceJSON, _ := json.Marshal([]interface{}{})

	conv := &entity.ChatSession{
		ID:        common.GenerateUUID(),
		DialogID:  chatID,
		Name:      &name,
		Message:   messagesJSON,
		UserID:    &userID,
		Reference: referenceJSON,
	}
	if createErr := s.chatSessionDAO.Create(conv); createErr != nil {
		return nil, common.CodeDataError, errors.New("Fail to create a session!")
	}

	stored, err := s.chatSessionDAO.GetByID(conv.ID)
	if err != nil {
		return nil, common.CodeDataError, errors.New("Fail to create a session!")
	}
	return s.buildSessionPayload(stored, nil, false), common.CodeSuccess, nil
}
// DeleteSessions deletes sessions of a chat the caller owns.
//
// req["ids"] lists the sessions to delete. When it is absent or yields no
// usable IDs, req["delete_all"] == true deletes every session of the chat;
// otherwise the call is a successful no-op. For each unique ID the session
// must belong to chatID; its uploaded files are cleaned up before the row is
// deleted.
//
// Returns (data, message, code, err):
//   - true / "success" when everything requested was deleted;
//   - a map with "success_count" and "errors" when some deletions succeeded
//     but others failed or IDs were duplicated;
//   - CodeDataError with the joined messages when nothing was deleted;
//   - CodeServerError as soon as a store call fails, in which case earlier
//     sessions in the list may already be gone.
func (s *ChatSessionService) DeleteSessions(userID, chatID string, req map[string]interface{}) (interface{}, string, common.ErrorCode, error) {
	owned, err := s.ensureOwnedChat(userID, chatID)
	if err != nil {
		return nil, "", common.CodeServerError, err
	}
	if !owned {
		return false, "No authorization.", common.CodeAuthenticationError, errors.New("No authorization.")
	}
	if len(req) == 0 {
		return map[string]interface{}{}, "success", common.CodeSuccess, nil
	}

	sessionIDs, hasIDs := stringSliceFromValue(req["ids"])
	if !hasIDs || len(sessionIDs) == 0 {
		// No explicit IDs: only proceed when the caller asked to delete everything.
		if wantAll, _ := req["delete_all"].(bool); !wantAll {
			return map[string]interface{}{}, "success", common.CodeSuccess, nil
		}
		existing, listErr := s.chatSessionDAO.ListByChatID(chatID)
		if listErr != nil {
			return nil, "", common.CodeServerError, listErr
		}
		for i := range existing {
			sessionIDs = append(sessionIDs, existing[i].ID)
		}
		if len(sessionIDs) == 0 {
			return map[string]interface{}{}, "success", common.CodeSuccess, nil
		}
	}

	uniqueIDs, duplicateMessages := checkDuplicateChatSessionIDs(sessionIDs)
	failures := make([]string, 0)
	deleted := 0
	for _, sid := range uniqueIDs {
		session, lookupErr := s.chatSessionDAO.GetBySessionIDAndChatID(sid, chatID)
		if lookupErr != nil {
			failures = append(failures, fmt.Sprintf("The chat doesn't own the session %s", sid))
			continue
		}
		s.removeSessionUploadFiles(userID, session)
		if delErr := s.chatSessionDAO.DeleteByID(sid); delErr != nil {
			return nil, "", common.CodeServerError, delErr
		}
		deleted++
	}

	// Duplicate-ID notices are reported after the per-session failures.
	failures = append(failures, duplicateMessages...)
	switch {
	case len(failures) == 0:
		return true, "success", common.CodeSuccess, nil
	case deleted > 0:
		report := map[string]interface{}{
			"success_count": deleted,
			"errors":        failures,
		}
		message := fmt.Sprintf("Partially deleted %d sessions with %d errors", deleted, len(failures))
		return report, message, common.CodeSuccess, nil
	default:
		return nil, "", common.CodeDataError, errors.New(strings.Join(failures, "; "))
	}
}
// stringSliceFromValue extracts a list of strings from a loosely typed value
// such as a decoded JSON field.
//
// It accepts []interface{} and []string. Elements that are not strings, or
// that are blank after trimming whitespace, are dropped; kept elements are
// returned untrimmed and in their original order. The boolean is false (and
// the slice nil) when value is of any other type, including nil.
func stringSliceFromValue(value interface{}) ([]string, bool) {
	// keep appends s to dst unless it is blank.
	keep := func(dst []string, s string) []string {
		if strings.TrimSpace(s) == "" {
			return dst
		}
		return append(dst, s)
	}

	switch items := value.(type) {
	case []string:
		out := make([]string, 0, len(items))
		for _, s := range items {
			out = keep(out, s)
		}
		return out, true
	case []interface{}:
		out := make([]string, 0, len(items))
		for _, item := range items {
			if s, isString := item.(string); isString {
				out = keep(out, s)
			}
		}
		return out, true
	}
	return nil, false
}
// removeSessionUploadFiles deletes, on a best-effort basis, the blobs
// referenced by a session's messages.
//
// Every message may carry a "files" list whose entries are objects with a
// string "id"; each such ID is removed from the "<userID>-downloads" bucket.
// Malformed entries are skipped. Failures are logged as warnings and never
// abort the cleanup; nothing is returned to the caller.
func (s *ChatSessionService) removeSessionUploadFiles(userID string, session *entity.ChatSession) {
	messages := parseMessages(session.Message)
	bucket := fmt.Sprintf("%s-downloads", userID)

	backend := storage.GetStorageFactory().GetStorage()
	if backend == nil {
		common.Warn("storage is not initialized; skip chat upload cleanup", zap.String("bucket", bucket))
		return
	}

	for _, msg := range messages {
		// A failed assertion yields a nil slice/map, which the loops below treat as empty.
		attachments, _ := msg["files"].([]interface{})
		for _, entry := range attachments {
			meta, _ := entry.(map[string]interface{})
			fileID, _ := meta["id"].(string)
			if fileID == "" {
				continue
			}
			removeErr := backend.Remove(bucket, fileID)
			if removeErr == nil {
				continue
			}
			common.Warn("Failed to delete chat upload blob",
				zap.String("bucket", bucket),
				zap.String("file_id", fileID),
				zap.Error(removeErr),
			)
		}
	}
}
// checkDuplicateChatSessionIDs de-duplicates a list of session IDs.
//
// IDs are trimmed and blank ones ignored. It returns the unique IDs in
// first-seen order, plus one "Duplicate session ids: <id>" message for every
// ID that occurred more than once. The messages follow the same first-seen
// order so the result is deterministic (ranging over the count map would
// yield a random order on every call).
func checkDuplicateChatSessionIDs(ids []string) ([]string, []string) {
	idCount := make(map[string]int, len(ids))
	uniqueIDs := make([]string, 0, len(ids))
	for _, id := range ids {
		id = strings.TrimSpace(id)
		if id == "" {
			continue
		}
		idCount[id]++
		if idCount[id] == 1 {
			uniqueIDs = append(uniqueIDs, id)
		}
	}

	// Walk the ordered unique list rather than the map to keep output stable.
	duplicateMessages := make([]string, 0)
	for _, id := range uniqueIDs {
		if idCount[id] > 1 {
			duplicateMessages = append(duplicateMessages, fmt.Sprintf("Duplicate session ids: %s", id))
		}
	}
	return uniqueIDs, duplicateMessages
}
// UpdateSession updates one chat session after Python-style field validation.
//
// Validation order: ownership of the chat, non-empty request body, session
// must belong to chatID, then the read-only fields ("message", "messages",
// "reference") must be absent. A non-nil "name" must be a non-blank string; it
// is trimmed, capped at 255 characters and written back into req. All
// remaining keys except id, dialog_id, chat_id and user_id are forwarded to
// the store as-is. The session is re-read to build the returned payload.
func (s *ChatSessionService) UpdateSession(userID, chatID, sessionID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
	owned, err := s.ensureOwnedChat(userID, chatID)
	switch {
	case err != nil:
		return nil, common.CodeServerError, err
	case !owned:
		return nil, common.CodeAuthenticationError, errors.New("No authorization.")
	case len(req) == 0:
		return nil, common.CodeArgumentError, errors.New("Request body cannot be empty")
	}

	_, err = s.chatSessionDAO.GetBySessionIDAndChatID(sessionID, chatID)
	if isChatSessionNotFound(err) {
		return nil, common.CodeDataError, errors.New("Session not found!")
	}
	if err != nil {
		return nil, common.CodeServerError, err
	}

	// Fields that may never be changed through this endpoint, checked in order.
	readOnly := [...]struct{ field, message string }{
		{"message", "`messages` cannot be changed."},
		{"messages", "`messages` cannot be changed."},
		{"reference", "`reference` cannot be changed."},
	}
	for _, rule := range readOnly {
		if _, present := req[rule.field]; present {
			return nil, common.CodeDataError, errors.New(rule.message)
		}
	}

	if rawName, present := req["name"]; present && rawName != nil {
		text, isString := rawName.(string)
		text = strings.TrimSpace(text)
		if !isString || text == "" {
			return nil, common.CodeDataError, errors.New("`name` can not be empty.")
		}
		if runes := []rune(text); len(runes) > 255 {
			text = string(runes[:255])
		}
		// The normalised name is stored back into the caller's map.
		req["name"] = text
	}

	// Identity columns are never forwarded to the store.
	updateFields := make(map[string]interface{})
	for field, value := range req {
		if field == "id" || field == "dialog_id" || field == "chat_id" || field == "user_id" {
			continue
		}
		updateFields[field] = value
	}

	err = s.chatSessionDAO.UpdateByID(sessionID, updateFields)
	if isChatSessionNotFound(err) {
		return nil, common.CodeDataError, errors.New("Session not found!")
	}
	if err != nil {
		return nil, common.CodeServerError, err
	}

	session, err := s.chatSessionDAO.GetByID(sessionID)
	if isChatSessionNotFound(err) {
		return nil, common.CodeDataError, errors.New("Fail to update a session!")
	}
	if err != nil {
		return nil, common.CodeServerError, err
	}
	return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
}
// ensureOwnedChat reports whether the user may act on the chat.
//
// The chat is considered owned when it exists under any tenant ID returned for
// the user or, checked last, under the user's own ID. The first store error
// aborts the check and is returned with false.
func (s *ChatSessionService) ensureOwnedChat(userID, chatID string) (bool, error) {
	tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
	if err != nil {
		return false, err
	}

	// Tenants first, then the user itself; copy so the DAO's slice is untouched.
	candidates := make([]string, 0, len(tenantIDs)+1)
	candidates = append(candidates, tenantIDs...)
	candidates = append(candidates, userID)

	for _, owner := range candidates {
		found, checkErr := s.chatSessionDAO.CheckDialogExists(owner, chatID)
		if checkErr != nil {
			return false, checkErr
		}
		if found {
			return true, nil
		}
	}
	return false, nil
}
// buildSessionPayload converts a stored session into its API payload.
//
// Messages and references are decoded from the session's raw JSON; every
// reference entry that is an object gets its "chunks" replaced by the
// formatted chunk list. When includeAvatar is true the payload's Avatar is
// set to the dialog's icon, or to an empty string if dialog is nil or has no
// icon; otherwise Avatar stays nil.
func (s *ChatSessionService) buildSessionPayload(session *entity.ChatSession, dialog *entity.Chat, includeAvatar bool) *ChatSessionPayload {
	payload := &ChatSessionPayload{
		ID:         session.ID,
		ChatID:     session.DialogID,
		Name:       session.Name,
		UserID:     session.UserID,
		CreateDate: session.CreateDate,
		UpdateDate: session.UpdateDate,
		CreateTime: session.CreateTime,
		UpdateTime: session.UpdateTime,
	}

	if includeAvatar {
		icon := ""
		if dialog != nil && dialog.Icon != nil {
			icon = *dialog.Icon
		}
		payload.Avatar = &icon
	}

	refs := parseReferenceList(session.Reference)
	for i := range refs {
		// Maps are reference types, so updating entry updates refs[i] in place.
		if entry, isMap := refs[i].(map[string]interface{}); isMap {
			entry["chunks"] = formatReferenceChunks(entry)
		}
	}
	payload.Reference = refs
	payload.Messages = parseMessages(session.Message)

	return payload
}
// parseMessages decodes a session's stored message JSON.
//
// Two shapes are understood: a plain JSON array of message objects, and an
// object wrapping that array under the "messages" key. Empty input or input
// matching neither shape yields nil.
func parseMessages(raw json.RawMessage) []map[string]interface{} {
	if len(raw) == 0 {
		return nil
	}

	// Preferred shape: a bare list of messages.
	var list []map[string]interface{}
	if json.Unmarshal(raw, &list) == nil {
		return list
	}

	// Fallback shape: {"messages": [...]}.
	envelope := struct {
		Messages []map[string]interface{} `json:"messages"`
	}{}
	if json.Unmarshal(raw, &envelope) == nil {
		return envelope.Messages
	}
	return nil
}
// parseReferenceList decodes a session's stored reference JSON, which is
// expected to be a JSON array. Empty or undecodable input yields nil.
func parseReferenceList(raw json.RawMessage) []interface{} {
	if len(raw) == 0 {
		return nil
	}
	var list []interface{}
	if json.Unmarshal(raw, &list) != nil {
		return nil
	}
	return list
}
// formatReferenceChunks formats the "chunks" list of one reference entry.
//
// Only chunk entries that are JSON objects are passed to formatChunks; others
// are dropped. When "chunks" is missing or not a list, an empty (non-nil)
// slice is returned.
func formatReferenceChunks(reference map[string]interface{}) []FormattedChunk {
	items, isList := reference["chunks"].([]interface{})
	if !isList {
		return []FormattedChunk{}
	}

	objects := make([]map[string]interface{}, 0, len(items))
	for i := range items {
		if obj, isObject := items[i].(map[string]interface{}); isObject {
			objects = append(objects, obj)
		}
	}
	return formatChunks(objects)
}
// isChatSessionNotFound reports whether err is, or wraps,
// gorm.ErrRecordNotFound. It returns false for a nil error. Despite the name
// it is also used in this file to classify failed dialog lookups.
func isChatSessionNotFound(err error) bool {
return errors.Is(err, gorm.ErrRecordNotFound)
}
// Completion runs a non-streaming chat completion through the chat pipeline
// and returns the structured answer.
//
// The last entry of messages must have role "user". When llmID is non-empty
// the request is treated as embedded: the dialog's model (and, if provided,
// its settings) are overridden for this call and the conversation is NOT
// persisted. Otherwise the assistant reply is appended to the session's
// messages and stored together with the reference list.
//
// userID is currently not used by this method.
func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) {
	if len(messages) == 0 {
		return nil, errors.New("messages cannot be empty")
	}
	if role, _ := messages[len(messages)-1]["role"].(string); role != "user" {
		return nil, errors.New("the last content of this conversation is not from user")
	}

	session, err := s.chatSessionDAO.GetByID(conversationID)
	if err != nil {
		return nil, errors.New("Conversation not found")
	}
	dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID)
	if err != nil {
		return nil, errors.New("Dialog not found")
	}

	// Snapshot what will be persisted (stored prologue + cloned request
	// messages) before the helpers below start mutating the session.
	sessionMessages := s.buildSessionMessages(session, messages)
	reference := s.initializeReference(session)

	// A caller-specified model marks the request as embedded.
	isEmbedded := llmID != ""
	if isEmbedded {
		hasKey, keyErr := s.checkTenantLLMAPIKey(dialog.TenantID, llmID)
		if keyErr != nil || !hasKey {
			return nil, fmt.Errorf("Cannot use specified model %s", llmID)
		}
		dialog.LLMID = llmID
		if chatModelConfig != nil {
			dialog.LLMSetting = chatModelConfig
		}
	}

	kwargs := chatModelConfig
	if kwargs == nil {
		kwargs = map[string]interface{}{}
	}
	resultChan, err := s.pipeline.AsyncChat(context.Background(), dialog, messages, false, kwargs)
	if err != nil {
		return nil, err
	}

	// Drain the pipeline: concatenate answer parts, keep the latest reference.
	var answer strings.Builder
	var finalRef map[string]interface{}
	for part := range resultChan {
		if part.Answer != "" {
			answer.WriteString(part.Answer)
		}
		if part.Reference != nil {
			finalRef = part.Reference
		}
	}
	fullAnswer := answer.String()

	ans := map[string]interface{}{
		"answer":    fullAnswer,
		"reference": finalRef,
		"final":     true,
	}
	result := s.structureAnswerWithConv(session, ans, messageID, session.ID, reference)

	if isEmbedded {
		return result, nil
	}

	// Persist the conversation with the assistant's reply appended.
	sessionMessages = append(sessionMessages, map[string]interface{}{
		"role":       "assistant",
		"content":    fullAnswer,
		"id":         messageID,
		"created_at": float64(time.Now().Unix()),
	})
	s.updateSessionMessages(session, sessionMessages, reference)
	return result, nil
}
// chatStreamErrorEvent renders one server-sent-event frame carrying an error
// payload. message and answer are JSON-encoded before being embedded, so
// quotes, backslashes and control characters in caller-supplied or
// error-derived text cannot break or inject into the JSON document.
func chatStreamErrorEvent(message, answer string) string {
	// json.Marshal cannot fail for a plain string value.
	quotedMessage, _ := json.Marshal(message)
	quotedAnswer, _ := json.Marshal(answer)
	return fmt.Sprintf("data: {\"code\": 500, \"message\": %s, \"data\": {\"answer\": %s, \"reference\": []}}\n\n", quotedMessage, quotedAnswer)
}

// CompletionStream runs a streaming chat completion through the chat pipeline.
//
// Every pipeline result is written to streamChan as an SSE frame
// ("data: <json>\n\n") holding the answer accumulated so far; a final frame
// with "data": true marks completion. Validation and lookup failures are
// reported both as an error frame on streamChan and as the returned error.
//
// The last entry of messages must have role "user". When llmID is non-empty
// the request is treated as embedded: the dialog's model (and, if provided,
// its settings) are overridden for this call and the conversation is NOT
// persisted. Otherwise the assistant reply is appended to the session's
// messages and stored together with the reference list.
//
// A nil ctx is replaced with context.Background(). Sends on streamChan block
// until the consumer receives them. userID is currently not used.
func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error {
	if ctx == nil {
		ctx = context.Background()
	}

	// Validate that the last message is from the user.
	if len(messages) == 0 {
		streamChan <- chatStreamErrorEvent("messages cannot be empty", "**ERROR**: messages cannot be empty")
		return errors.New("messages cannot be empty")
	}
	lastRole, _ := messages[len(messages)-1]["role"].(string)
	if lastRole != "user" {
		streamChan <- chatStreamErrorEvent("the last content of this conversation is not from user", "**ERROR**: the last content of this conversation is not from user")
		return errors.New("the last content of this conversation is not from user")
	}

	// Load the conversation and its dialog.
	session, err := s.chatSessionDAO.GetByID(conversationID)
	if err != nil {
		streamChan <- chatStreamErrorEvent("Conversation not found", "**ERROR**: Conversation not found")
		return errors.New("Conversation not found")
	}
	dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID)
	if err != nil {
		streamChan <- chatStreamErrorEvent("Dialog not found", "**ERROR**: Dialog not found")
		return errors.New("Dialog not found")
	}

	// Copy messages for the session, preserving the stored prologue that the handler strips from requests.
	sessionMessages := s.buildSessionMessages(session, messages)
	// The reference list gets a fresh trailing entry for this turn.
	reference := s.initializeReference(session)

	// A caller-specified model marks the request as embedded.
	isEmbedded := llmID != ""
	if isEmbedded {
		hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID)
		if err != nil || !hasKey {
			// llmID comes from the request, so it must be escaped before
			// being placed inside the JSON frame.
			streamChan <- chatStreamErrorEvent(fmt.Sprintf("Cannot use specified model %s", llmID), "**ERROR**: Cannot use specified model")
			return fmt.Errorf("Cannot use specified model %s", llmID)
		}
		dialog.LLMID = llmID
		if chatModelConfig != nil {
			dialog.LLMSetting = chatModelConfig
		}
	}

	// Start the streaming chat via the shared pipeline.
	kwargs := chatModelConfig
	if kwargs == nil {
		kwargs = map[string]interface{}{}
	}
	resultChan, err := s.pipeline.AsyncChat(ctx, dialog, messages, true, kwargs)
	if err != nil {
		// Error text is arbitrary; escape it rather than splicing it raw.
		streamChan <- chatStreamErrorEvent(err.Error(), "**ERROR**: "+err.Error())
		return err
	}

	// Stream results, accumulating the answer.
	var fullAnswer strings.Builder
	for result := range resultChan {
		if result.Reference != nil && len(reference) > 0 {
			reference[len(reference)-1] = result.Reference
		}
		if result.Final {
			// A non-empty final answer replaces what was accumulated.
			if result.Answer != "" {
				fullAnswer.Reset()
				fullAnswer.WriteString(result.Answer)
			}
		} else if result.Answer != "" {
			fullAnswer.WriteString(result.Answer)
		}
		ans := s.structureAnswer(session, fullAnswer.String(), messageID, session.ID, reference)
		// Best effort, unchanged from before: a marshal failure yields an empty frame.
		data, _ := json.Marshal(map[string]interface{}{
			"code":    0,
			"message": "",
			"data":    ans,
		})
		streamChan <- fmt.Sprintf("data: %s\n\n", string(data))
	}

	// Send the final completion signal.
	finalData, _ := json.Marshal(map[string]interface{}{
		"code":    0,
		"message": "",
		"data":    true,
	})
	streamChan <- fmt.Sprintf("data: %s\n\n", string(finalData))

	// Persist the conversation unless the request was embedded.
	if !isEmbedded {
		sessionMessages = append(sessionMessages, map[string]interface{}{
			"role":       "assistant",
			"content":    fullAnswer.String(),
			"id":         messageID,
			"created_at": float64(time.Now().Unix()),
		})
		s.updateSessionMessages(session, sessionMessages, reference)
	}
	return nil
}
// Helper methods

// buildSessionMessages assembles the message list to persist for a session.
//
// Each incoming message map is copied one level deep (nested values are
// shared). If the stored session starts with an assistant message (the
// prologue) and the incoming list does not itself start with an assistant
// message, a copy of that stored first message is placed in front.
func (s *ChatSessionService) buildSessionMessages(session *entity.ChatSession, messages []map[string]interface{}) []map[string]interface{} {
	// cloneMessage makes a one-level copy of a message map.
	cloneMessage := func(src map[string]interface{}) map[string]interface{} {
		dst := make(map[string]interface{}, len(src))
		for key, value := range src {
			dst[key] = value
		}
		return dst
	}

	// Decide whether the stored prologue has to be carried over.
	var prologue map[string]interface{}
	if stored := parseMessages(session.Message); len(stored) > 0 {
		storedRole, _ := stored[0]["role"].(string)
		incomingRole := ""
		if len(messages) > 0 {
			incomingRole, _ = messages[0]["role"].(string)
		}
		if storedRole == "assistant" && incomingRole != "assistant" {
			prologue = cloneMessage(stored[0])
		}
	}

	size := len(messages)
	if prologue != nil {
		size++
	}
	result := make([]map[string]interface{}, 0, size)
	if prologue != nil {
		result = append(result, prologue)
	}
	for i := range messages {
		result = append(result, cloneMessage(messages[i]))
	}
	return result
}
// initializeReference prepares the reference list for a new assistant turn.
//
// It decodes the session's stored references (a decode failure is ignored and
// treated as "no references"), drops nil entries, and appends a fresh entry
// with empty "chunks" and "doc_aggs" for the answer about to be produced. The
// result therefore always has at least one element.
func (s *ChatSessionService) initializeReference(session *entity.ChatSession) []interface{} {
	var stored []interface{}
	if len(session.Reference) != 0 {
		// Best effort: malformed reference JSON simply yields no prior entries.
		_ = json.Unmarshal(session.Reference, &stored)
	}

	result := make([]interface{}, 0, len(stored)+1)
	for i := range stored {
		if stored[i] == nil {
			continue
		}
		result = append(result, stored[i])
	}

	return append(result, map[string]interface{}{
		"chunks":   []map[string]interface{}{},
		"doc_aggs": []interface{}{},
	})
}
// checkTenantLLMAPIKey is meant to report whether the tenant has an API key
// configured for the given model.
//
// NOTE: this is currently a stub. It ignores both arguments and always
// returns (true, nil), so callers never reject a caller-specified model.
func (s *ChatSessionService) checkTenantLLMAPIKey(tenantID, modelName string) (bool, error) {
// Simplified check - in real implementation, check if tenant has API key for this model
return true, nil
}
// structureAnswer packages one streamed answer snapshot for the client: the
// answer text, the current reference list, and the conversation and message
// IDs. The session argument is not used.
func (s *ChatSessionService) structureAnswer(session *entity.ChatSession, answer string, messageID, conversationID string, reference []interface{}) map[string]interface{} {
	payload := make(map[string]interface{}, 4)
	payload["answer"] = answer
	payload["reference"] = reference
	payload["conversation_id"] = conversationID
	payload["message_id"] = messageID
	return payload
}
// updateSessionMessages persists a session's messages and reference list and
// mirrors them onto the in-memory session.
//
// The function has no error return, so failures are logged as warnings
// instead of being silently dropped. If either value cannot be JSON-encoded,
// nothing is written and the session is left untouched, so previously stored
// messages are never overwritten with an empty value. If the store update
// fails, the in-memory session is still updated so the rest of the request
// sees the new state.
func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, messages []map[string]interface{}, reference []interface{}) {
	messagesJSON, err := json.Marshal(messages)
	if err != nil {
		common.Warn("Failed to encode chat session messages; skip session update",
			zap.String("session_id", session.ID),
			zap.Error(err),
		)
		return
	}
	referenceJSON, err := json.Marshal(reference)
	if err != nil {
		common.Warn("Failed to encode chat session reference; skip session update",
			zap.String("session_id", session.ID),
			zap.Error(err),
		)
		return
	}

	updates := map[string]interface{}{
		"message":   messagesJSON,
		"reference": referenceJSON,
	}
	if err := s.chatSessionDAO.UpdateByID(session.ID, updates); err != nil {
		common.Warn("Failed to persist chat session messages",
			zap.String("session_id", session.ID),
			zap.Error(err),
		)
	}

	session.Message = messagesJSON
	session.Reference = referenceJSON
}
// structureAnswerWithConv decorates a pipeline answer for the API response and
// mirrors it into the in-memory session (like Python's structure_answer).
//
// It mutates its arguments in place:
//   - ans: "reference" is defaulted to an empty chunks/doc_aggs map when it is
//     missing or not a map, its "chunks" are normalised via chunksFormat, and
//     "id" / "session_id" are set to messageID / conversationID.
//   - session.Message: the trailing assistant message is updated, or a new
//     assistant message is appended, and the list is re-marshalled. Nothing is
//     written to the database here.
//   - reference: when non-empty, its last element is replaced by the answer's
//     reference map.
//
// It returns the same ans map it was given.
func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession, ans map[string]interface{}, messageID, conversationID string, reference []interface{}) map[string]interface{} {
// Extract reference from answer; a failed assertion or a nil map falls
// through to the default below.
ref, _ := ans["reference"].(map[string]interface{})
if ref == nil {
ref = map[string]interface{}{
"chunks": []map[string]interface{}{},
"doc_aggs": []interface{}{},
}
ans["reference"] = ref
}
// Format chunks
chunkList := s.chunksFormat(ref)
ref["chunks"] = chunkList
// Add message ID and session ID
ans["id"] = messageID
ans["session_id"] = conversationID
// Update session message. The presence of "start_to_think" or
// "end_to_think" replaces the content with the matching think tag.
content, _ := ans["answer"].(string)
if ans["start_to_think"] != nil {
content = "<think>"
} else if ans["end_to_think"] != nil {
content = "</think>"
}
// Parse existing messages. Keep backward compatibility with wrapped legacy rows.
messages := parseMessages(session.Message)
// Update or append assistant message
if len(messages) == 0 || s.getLastRole(messages) != "assistant" {
messages = append(messages, map[string]interface{}{
"role": "assistant",
"content": content,
"created_at": float64(time.Now().Unix()),
"id": messageID,
})
} else {
lastIdx := len(messages) - 1
lastMsg := messages[lastIdx]
// A final answer overwrites the last assistant message; otherwise the
// new content is appended to whatever text it already holds.
if ans["final"] == true && ans["answer"] != nil {
lastMsg["content"] = ans["answer"]
} else {
existing, _ := lastMsg["content"].(string)
lastMsg["content"] = existing + content
}
lastMsg["created_at"] = float64(time.Now().Unix())
lastMsg["id"] = messageID
messages[lastIdx] = lastMsg
}
// The marshal error is discarded here.
session.Message, _ = json.Marshal(messages)
// Update reference
if len(reference) > 0 {
reference[len(reference)-1] = ref
}
return ans
}
// getLastRole returns the "role" of the final message, or an empty string
// when the list is empty or that role is not a string.
func (s *ChatSessionService) getLastRole(messages []map[string]interface{}) string {
	if n := len(messages); n > 0 {
		if role, isString := messages[n-1]["role"].(string); isString {
			return role
		}
	}
	return ""
}
// chunksFormat normalises reference["chunks"] into a slice of chunk maps
// (simplified version).
//
// A []map[string]interface{} is copied into a new slice; a []interface{} is
// filtered down to its map elements; anything else yields an empty slice. The
// returned slice is never nil and the chunk maps themselves are shared with
// the input.
func (s *ChatSessionService) chunksFormat(reference map[string]interface{}) []map[string]interface{} {
	raw := reference["chunks"]

	if typed, isTyped := raw.([]map[string]interface{}); isTyped {
		out := make([]map[string]interface{}, 0, len(typed))
		return append(out, typed...)
	}

	if items, isList := raw.([]interface{}); isList {
		out := make([]map[string]interface{}, 0, len(items))
		for i := range items {
			if chunk, isMap := items[i].(map[string]interface{}); isMap {
				out = append(out, chunk)
			}
		}
		return out
	}

	return []map[string]interface{}{}
}