feat(go-api): align chat session get/update with python behavior (#16239)

## Summary

Align `/chats/:chat_id/sessions/:session_id` GET and PATCH with Python
behavior.
This commit is contained in:
Hz_
2026-06-24 17:34:01 +08:00
committed by GitHub
parent dc8ff63f1d
commit 9a91564194
6 changed files with 744 additions and 61 deletions

View File

@@ -20,6 +20,8 @@ import (
"strings"
"time"
"gorm.io/gorm"
"ragflow/internal/entity"
)
@@ -56,6 +58,16 @@ func (dao *ChatSessionDAO) GetByID(id string) (*entity.ChatSession, error) {
return &conv, nil
}
// GetBySessionIDAndChatID gets a chat session by session ID and chat ID.
func (dao *ChatSessionDAO) GetBySessionIDAndChatID(sessionID, chatID string) (*entity.ChatSession, error) {
var conv entity.ChatSession
err := DB.Where("id = ? AND dialog_id = ?", sessionID, chatID).First(&conv).Error
if err != nil {
return nil, err
}
return &conv, nil
}
// Create creates a new chat session
func (dao *ChatSessionDAO) Create(conv *entity.ChatSession) error {
return DB.Create(conv).Error
@@ -63,7 +75,14 @@ func (dao *ChatSessionDAO) Create(conv *entity.ChatSession) error {
// UpdateByID updates a chat session by ID
func (dao *ChatSessionDAO) UpdateByID(id string, updates map[string]interface{}) error {
return DB.Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates).Error
result := DB.Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// DeleteByID deletes a chat session by ID (hard delete)

View File

@@ -17,6 +17,7 @@
package handler
import (
"errors"
"fmt"
"io"
"net/http"
@@ -329,3 +330,53 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
})
}
}
func (h *ChatSessionHandler) GetSession(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := user.ID
chatID, sessionID := c.Param("chat_id"), c.Param("session_id")
result, code, err := h.chatSessionService.GetSession(userID, chatID, sessionID)
if err != nil {
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
func (h *ChatSessionHandler) UpdateSession(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := user.ID
chatID, sessionID := c.Param("chat_id"), c.Param("session_id")
req := map[string]any{}
if err := c.ShouldBindJSON(&req); err != nil {
if errors.Is(err, io.EOF) {
jsonError(c, common.CodeArgumentError, "Request body cannot be empty")
return
}
jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error())
return
}
if len(req) == 0 {
jsonError(c, common.CodeArgumentError, "Request body cannot be empty")
return
}
result, code, err := h.chatSessionService.UpdateSession(userID, chatID, sessionID, req)
if err != nil {
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}

View File

@@ -0,0 +1,74 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"ragflow/internal/common"
"ragflow/internal/entity"
"ragflow/internal/service"
"github.com/gin-gonic/gin"
)
func TestChatSessionHandlerUpdateSession_RejectsEmptyBody(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPatch, "/api/v1/chats/chat-1/sessions/session-1", nil)
ctx.Params = gin.Params{
{Key: "chat_id", Value: "chat-1"},
{Key: "session_id", Value: "session-1"},
}
ctx.Set("user", &entity.User{ID: "user-1"})
handler := NewChatSessionHandler(service.NewChatSessionService(), nil)
handler.UpdateSession(ctx)
if recorder.Code != http.StatusOK {
t.Fatalf("status=%d", recorder.Code)
}
var body map[string]interface{}
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatalf("decode response body: %v", err)
}
if got := body["code"]; got != float64(common.CodeArgumentError) {
t.Fatalf("code=%v", got)
}
if got := body["message"]; got != "Request body cannot be empty" {
t.Fatalf("message=%v", got)
}
}
func TestChatSessionHandlerUpdateSession_RejectsEmptyJSONObject(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPatch, "/api/v1/chats/chat-1/sessions/session-1", strings.NewReader(`{}`))
ctx.Request.Header.Set("Content-Type", "application/json")
ctx.Params = gin.Params{
{Key: "chat_id", Value: "chat-1"},
{Key: "session_id", Value: "session-1"},
}
ctx.Set("user", &entity.User{ID: "user-1"})
handler := NewChatSessionHandler(service.NewChatSessionService(), nil)
handler.UpdateSession(ctx)
var body map[string]interface{}
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatalf("decode response body: %v", err)
}
if got := body["code"]; got != float64(common.CodeArgumentError) {
t.Fatalf("code=%v", got)
}
if got := body["message"]; got != "Request body cannot be empty" {
t.Fatalf("message=%v", got)
}
}

View File

@@ -253,6 +253,8 @@ func (r *Router) Setup(engine *gin.Engine) {
chats.DELETE("/:chat_id", r.chatHandler.DeleteChat)
chats.GET("/:chat_id", r.chatHandler.GetChat)
chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions)
chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession)
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
}
// OpenAI-compatible chat completions route

View File

@@ -25,6 +25,8 @@ import (
"strings"
"time"
"gorm.io/gorm"
"ragflow/internal/dao"
"ragflow/internal/entity"
)
@@ -33,6 +35,7 @@ import (
type chatSessionStore interface {
GetByID(id string) (*entity.ChatSession, error)
GetBySessionIDAndChatID(sessionID, chatID string) (*entity.ChatSession, error)
Create(conv *entity.ChatSession) error
UpdateByID(id string, updates map[string]interface{}) error
DeleteByID(id string) error
@@ -128,16 +131,13 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe
}
}
// Create initial message - store as JSON object with messages array
messagesObj := map[string]interface{}{
"messages": []map[string]interface{}{
{
"role": "assistant",
"content": prologue,
},
// Store messages in the same list shape as Python Conversation.message.
messagesJSON, _ := json.Marshal([]map[string]interface{}{
{
"role": "assistant",
"content": prologue,
},
}
messagesJSON, _ := json.Marshal(messagesObj)
})
// Create reference - store as JSON array
referenceJSON, _ := json.Marshal([]interface{}{})
@@ -218,6 +218,20 @@ type ListChatSessionsResponse struct {
Sessions []*entity.ChatSession
}
type ChatSessionPayload struct {
ID string `json:"id"`
ChatID string `json:"chat_id"`
Name *string `json:"name,omitempty"`
Messages []map[string]interface{} `json:"messages"`
Reference []interface{} `json:"reference"`
UserID *string `json:"user_id,omitempty"`
Avatar *string `json:"avatar,omitempty"`
CreateDate *time.Time `json:"create_date,omitempty"`
UpdateDate *time.Time `json:"update_date,omitempty"`
CreateTime *int64 `json:"create_time,omitempty"`
UpdateTime *int64 `json:"update_time,omitempty"`
}
// ListChatSessions lists chat sessions for a dialog
func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*ListChatSessionsResponse, error) {
// Get user's tenants
@@ -263,6 +277,214 @@ func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*Li
return &ListChatSessionsResponse{Sessions: sessions}, nil
}
// GetSession returns one chat session after ownership validation.
func (s *ChatSessionService) GetSession(userID, chatID, sessionID string) (*ChatSessionPayload, common.ErrorCode, error) {
ok, err := s.ensureOwnedChat(userID, chatID)
if err != nil {
return nil, common.CodeServerError, err
}
if !ok {
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
}
session, err := s.chatSessionDAO.GetByID(sessionID)
if err != nil {
if isChatSessionNotFound(err) {
return nil, common.CodeDataError, errors.New("Session not found!")
}
return nil, common.CodeServerError, err
}
if session.DialogID != chatID {
return nil, common.CodeDataError, errors.New("Session does not belong to this chat!")
}
dialog, err := s.chatSessionDAO.GetDialogByID(chatID)
if err != nil && !isChatSessionNotFound(err) {
return nil, common.CodeServerError, err
}
return s.buildSessionPayload(session, dialog, true), common.CodeSuccess, nil
}
// UpdateSession updates one chat session after Python-style field validation.
func (s *ChatSessionService) UpdateSession(userID, chatID, sessionID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
ok, err := s.ensureOwnedChat(userID, chatID)
if err != nil {
return nil, common.CodeServerError, err
}
if !ok {
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
}
if len(req) == 0 {
return nil, common.CodeArgumentError, errors.New("Request body cannot be empty")
}
if _, err := s.chatSessionDAO.GetBySessionIDAndChatID(sessionID, chatID); err != nil {
if isChatSessionNotFound(err) {
return nil, common.CodeDataError, errors.New("Session not found!")
}
return nil, common.CodeServerError, err
}
if _, ok := req["message"]; ok {
return nil, common.CodeDataError, errors.New("`messages` cannot be changed.")
}
if _, ok := req["messages"]; ok {
return nil, common.CodeDataError, errors.New("`messages` cannot be changed.")
}
if _, ok := req["reference"]; ok {
return nil, common.CodeDataError, errors.New("`reference` cannot be changed.")
}
if name, exists := req["name"]; exists && name != nil {
nameStr, ok := name.(string)
if !ok || strings.TrimSpace(nameStr) == "" {
return nil, common.CodeDataError, errors.New("`name` can not be empty.")
}
req["name"] = strings.TrimSpace(nameStr)
nameRunes := []rune(req["name"].(string))
if len(nameRunes) > 255 {
req["name"] = string(nameRunes[:255])
}
}
updateFields := make(map[string]interface{})
for k, v := range req {
switch k {
case "id", "dialog_id", "chat_id", "user_id":
continue
default:
updateFields[k] = v
}
}
if err := s.chatSessionDAO.UpdateByID(sessionID, updateFields); err != nil {
if isChatSessionNotFound(err) {
return nil, common.CodeDataError, errors.New("Session not found!")
}
return nil, common.CodeServerError, err
}
session, err := s.chatSessionDAO.GetByID(sessionID)
if err != nil {
if isChatSessionNotFound(err) {
return nil, common.CodeDataError, errors.New("Fail to update a session!")
}
return nil, common.CodeServerError, err
}
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
}
func (s *ChatSessionService) ensureOwnedChat(userID, chatID string) (bool, error) {
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
if err != nil {
return false, err
}
for _, tenantID := range tenantIDs {
exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, chatID)
if err != nil {
return false, err
}
if exists {
return true, nil
}
}
exists, err := s.chatSessionDAO.CheckDialogExists(userID, chatID)
if err != nil {
return false, err
}
return exists, nil
}
func (s *ChatSessionService) buildSessionPayload(session *entity.ChatSession, dialog *entity.Chat, includeAvatar bool) *ChatSessionPayload {
var avatar *string
if includeAvatar {
value := ""
if dialog != nil && dialog.Icon != nil {
value = *dialog.Icon
}
avatar = &value
}
references := parseReferenceList(session.Reference)
for index, ref := range references {
refMap, ok := ref.(map[string]interface{})
if !ok {
continue
}
refMap["chunks"] = formatReferenceChunks(refMap)
references[index] = refMap
}
return &ChatSessionPayload{
ID: session.ID,
ChatID: session.DialogID,
Name: session.Name,
Messages: parseMessages(session.Message),
Reference: references,
UserID: session.UserID,
Avatar: avatar,
CreateDate: session.CreateDate,
UpdateDate: session.UpdateDate,
CreateTime: session.CreateTime,
UpdateTime: session.UpdateTime,
}
}
func parseMessages(raw json.RawMessage) []map[string]interface{} {
var messages []map[string]interface{}
if len(raw) == 0 {
return messages
}
if err := json.Unmarshal(raw, &messages); err == nil {
return messages
}
var wrapped struct {
Messages []map[string]interface{} `json:"messages"`
}
if err := json.Unmarshal(raw, &wrapped); err != nil {
return nil
}
return wrapped.Messages
}
func parseReferenceList(raw json.RawMessage) []interface{} {
var references []interface{}
if len(raw) == 0 {
return references
}
err := json.Unmarshal(raw, &references)
if err != nil {
return nil
}
return references
}
func formatReferenceChunks(reference map[string]interface{}) []FormattedChunk {
rawChunks, ok := reference["chunks"].([]interface{})
if !ok {
return []FormattedChunk{}
}
chunks := make([]map[string]interface{}, 0, len(rawChunks))
for _, item := range rawChunks {
chunk, ok := item.(map[string]interface{})
if !ok {
continue
}
chunks = append(chunks, chunk)
}
return formatChunks(chunks)
}
func isChatSessionNotFound(err error) bool {
return errors.Is(err, gorm.ErrRecordNotFound)
}
// Completion performs chat completion with full RAG support via ChatPipelineService.
func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) {
// Validate the last message is from user
@@ -286,7 +508,7 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me
return nil, errors.New("Dialog not found")
}
// Deep copy messages to session
// Deep copy messages to session, preserving the stored prologue that handler strips from requests.
sessionMessages := s.buildSessionMessages(session, messages)
// Initialize reference if empty
@@ -338,6 +560,12 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me
// Update conversation if not embedded
if !isEmbedded {
sessionMessages = append(sessionMessages, map[string]interface{}{
"role": "assistant",
"content": answer.String(),
"id": messageID,
"created_at": float64(time.Now().Unix()),
})
s.updateSessionMessages(session, sessionMessages, reference)
}
@@ -375,7 +603,7 @@ func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string
return errors.New("Dialog not found")
}
// Deep copy messages to session
// Deep copy messages to session, preserving the stored prologue that handler strips from requests.
sessionMessages := s.buildSessionMessages(session, messages)
// Initialize reference if empty
@@ -440,6 +668,12 @@ func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string
// Update conversation if not embedded
if !isEmbedded {
sessionMessages = append(sessionMessages, map[string]interface{}{
"role": "assistant",
"content": fullAnswer.String(),
"id": messageID,
"created_at": float64(time.Now().Unix()),
})
s.updateSessionMessages(session, sessionMessages, reference)
}
@@ -449,14 +683,33 @@ func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string
// Helper methods
func (s *ChatSessionService) buildSessionMessages(session *entity.ChatSession, messages []map[string]interface{}) []map[string]interface{} {
// Deep copy messages to session
sessionMessages := make([]map[string]interface{}, len(messages))
for i, msg := range messages {
sessionMessages[i] = make(map[string]interface{})
for k, v := range msg {
sessionMessages[i][k] = v
prefix := make([]map[string]interface{}, 0, 1)
existingMessages := parseMessages(session.Message)
if len(existingMessages) > 0 {
if role, _ := existingMessages[0]["role"].(string); role == "assistant" {
firstIncomingRole := ""
if len(messages) > 0 {
firstIncomingRole, _ = messages[0]["role"].(string)
}
if firstIncomingRole != "assistant" {
prologue := make(map[string]interface{}, len(existingMessages[0]))
for k, v := range existingMessages[0] {
prologue[k] = v
}
prefix = append(prefix, prologue)
}
}
}
sessionMessages := make([]map[string]interface{}, 0, len(prefix)+len(messages))
sessionMessages = append(sessionMessages, prefix...)
for _, msg := range messages {
cloned := make(map[string]interface{}, len(msg))
for k, v := range msg {
cloned[k] = v
}
sessionMessages = append(sessionMessages, cloned)
}
return sessionMessages
}
@@ -495,9 +748,7 @@ func (s *ChatSessionService) structureAnswer(session *entity.ChatSession, answer
func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, messages []map[string]interface{}, reference []interface{}) {
// Update session with new messages and reference
messagesJSON, _ := json.Marshal(map[string]interface{}{
"messages": messages,
})
messagesJSON, _ := json.Marshal(messages)
referenceJSON, _ := json.Marshal(reference)
updates := map[string]interface{}{
@@ -505,6 +756,8 @@ func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession,
"reference": referenceJSON,
}
s.chatSessionDAO.UpdateByID(session.ID, updates)
session.Message = messagesJSON
session.Reference = referenceJSON
}
// structureAnswerWithConv structures the answer with conversation update (like Python's structure_answer)
@@ -535,12 +788,8 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession
content = "</think>"
}
// Parse existing messages
var messagesObj map[string]interface{}
if len(session.Message) > 0 {
json.Unmarshal(session.Message, &messagesObj)
}
messages, _ := messagesObj["messages"].([]interface{})
// Parse existing messages. Keep backward compatibility with wrapped legacy rows.
messages := parseMessages(session.Message)
// Update or append assistant message
if len(messages) == 0 || s.getLastRole(messages) != "assistant" {
@@ -552,20 +801,20 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession
})
} else {
lastIdx := len(messages) - 1
lastMsg, _ := messages[lastIdx].(map[string]interface{})
if lastMsg != nil {
if ans["final"] == true && ans["answer"] != nil {
lastMsg["content"] = ans["answer"]
} else {
existing, _ := lastMsg["content"].(string)
lastMsg["content"] = existing + content
}
lastMsg["created_at"] = float64(time.Now().Unix())
lastMsg["id"] = messageID
messages[lastIdx] = lastMsg
lastMsg := messages[lastIdx]
if ans["final"] == true && ans["answer"] != nil {
lastMsg["content"] = ans["answer"]
} else {
existing, _ := lastMsg["content"].(string)
lastMsg["content"] = existing + content
}
lastMsg["created_at"] = float64(time.Now().Unix())
lastMsg["id"] = messageID
messages[lastIdx] = lastMsg
}
session.Message, _ = json.Marshal(messages)
// Update reference
if len(reference) > 0 {
reference[len(reference)-1] = ref
@@ -575,16 +824,12 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession
}
// getLastRole gets the role of the last message
func (s *ChatSessionService) getLastRole(messages []interface{}) string {
func (s *ChatSessionService) getLastRole(messages []map[string]interface{}) string {
if len(messages) == 0 {
return ""
}
lastMsg, _ := messages[len(messages)-1].(map[string]interface{})
if lastMsg != nil {
role, _ := lastMsg["role"].(string)
return role
}
return ""
role, _ := messages[len(messages)-1]["role"].(string)
return role
}
// chunksFormat formats chunks for reference (simplified version)

View File

@@ -4,10 +4,13 @@ import (
"context"
"encoding/json"
"errors"
"reflect"
"strings"
"sync"
"testing"
"gorm.io/gorm"
"ragflow/internal/common"
"ragflow/internal/entity"
)
@@ -48,7 +51,18 @@ func (f *fakeSessionStore) GetByID(id string) (*entity.ChatSession, error) {
}
s, ok := f.sessions[id]
if !ok {
return nil, errors.New("record not found")
return nil, gorm.ErrRecordNotFound
}
return s, nil
}
func (f *fakeSessionStore) GetBySessionIDAndChatID(sessionID, chatID string) (*entity.ChatSession, error) {
s, err := f.GetByID(sessionID)
if err != nil {
return nil, err
}
if s.DialogID != chatID {
return nil, gorm.ErrRecordNotFound
}
return s, nil
}
@@ -70,10 +84,30 @@ func (f *fakeSessionStore) UpdateByID(id string, updates map[string]interface{})
if f.updateByIDErr != nil {
return f.updateByIDErr
}
s, ok := f.sessions[id]
if !ok {
return gorm.ErrRecordNotFound
}
f.updateCalled = append(f.updateCalled, struct {
id string
updates map[string]interface{}
}{id, updates})
for k, v := range updates {
switch k {
case "name":
if str, ok := v.(string); ok {
s.Name = &str
}
case "message":
if raw, ok := v.([]byte); ok {
s.Message = append(json.RawMessage(nil), raw...)
}
case "reference":
if raw, ok := v.([]byte); ok {
s.Reference = append(json.RawMessage(nil), raw...)
}
}
}
return nil
}
@@ -177,16 +211,15 @@ func TestSetChatSession_CreateNew(t *testing.T) {
t.Fatalf("expected 1 Create call, got %d", len(store.createCalled))
}
// Verify prologue is in the message
var msgObj map[string]interface{}
if err := json.Unmarshal(store.createCalled[0].Message, &msgObj); err != nil {
// Verify prologue is in the message list.
var msgs []map[string]interface{}
if err := json.Unmarshal(store.createCalled[0].Message, &msgs); err != nil {
t.Fatalf("failed to unmarshal message: %v", err)
}
msgs, _ := msgObj["messages"].([]interface{})
if len(msgs) != 1 {
t.Fatalf("expected 1 initial message, got %d", len(msgs))
}
firstMsg, _ := msgs[0].(map[string]interface{})
firstMsg := msgs[0]
if firstMsg["role"] != "assistant" || firstMsg["content"] != "Welcome!" {
t.Fatalf("unexpected prologue message: %#v", firstMsg)
}
@@ -213,10 +246,9 @@ func TestSetChatSession_CreateNewDefaultPrologue(t *testing.T) {
t.Fatal("expected session ID")
}
// Default prologue
var msgObj map[string]interface{}
json.Unmarshal(store.createCalled[0].Message, &msgObj)
msgs, _ := msgObj["messages"].([]interface{})
firstMsg, _ := msgs[0].(map[string]interface{})
var msgs []map[string]interface{}
json.Unmarshal(store.createCalled[0].Message, &msgs)
firstMsg := msgs[0]
if !strings.Contains(firstMsg["content"].(string), "Hi! I'm your assistant") {
t.Fatalf("expected default prologue, got %q", firstMsg["content"])
}
@@ -407,6 +439,205 @@ func TestListChatSessions_NotOwner(t *testing.T) {
}
}
// ===================================================================
// GetSession / UpdateSession tests
// ===================================================================
func TestGetSession_Success(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Name: strPtr("session"),
Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`),
Reference: json.RawMessage(`[
{"chunks":[{"chunk_id":"chunk-1","content_with_weight":"hello","doc_id":"doc-1","docnm_kwd":"Doc 1","kb_id":"kb-1"}]},
[]
]`),
UserID: strPtr("user-1"),
}
icon := "avatar.png"
store.dialogs["chat-1"] = &entity.Chat{ID: "chat-1", Icon: &icon}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
resp, code, err := svc.GetSession("user-1", "chat-1", "session-1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("unexpected code: %v", code)
}
if resp.ChatID != "chat-1" {
t.Fatalf("chat_id=%q", resp.ChatID)
}
if resp.Avatar == nil || *resp.Avatar != "avatar.png" {
t.Fatalf("avatar=%v", resp.Avatar)
}
if len(resp.Messages) != 1 || resp.Messages[0]["content"] != "hello" {
t.Fatalf("messages=%#v", resp.Messages)
}
if len(resp.Reference) != 2 {
t.Fatalf("reference len=%d", len(resp.Reference))
}
firstRef, ok := resp.Reference[0].(map[string]interface{})
if !ok {
t.Fatalf("reference[0] type=%T", resp.Reference[0])
}
chunks, ok := firstRef["chunks"].([]FormattedChunk)
if !ok {
t.Fatalf("chunks type=%T", firstRef["chunks"])
}
if len(chunks) != 1 || chunks[0].ID != "chunk-1" {
t.Fatalf("chunks=%#v", chunks)
}
if _, ok := resp.Reference[1].([]interface{}); !ok {
t.Fatalf("reference[1] changed unexpectedly: %T", resp.Reference[1])
}
}
func TestGetSession_NotOwner(t *testing.T) {
svc := &ChatSessionService{
chatSessionDAO: newFakeSessionStore(),
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.GetSession("user-1", "chat-1", "session-1")
if err == nil || err.Error() != "No authorization." {
t.Fatalf("err=%v", err)
}
if code != common.CodeAuthenticationError {
t.Fatalf("code=%v", code)
}
}
func TestGetSession_WrongChat(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-2"}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.GetSession("user-1", "chat-1", "session-1")
if err == nil || err.Error() != "Session does not belong to this chat!" {
t.Fatalf("err=%v", err)
}
if code != common.CodeDataError {
t.Fatalf("code=%v", code)
}
}
func TestUpdateSession_Success(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Name: strPtr("old"),
Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`),
}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
longName := " " + strings.Repeat("x", 260) + " "
resp, code, err := svc.UpdateSession("user-1", "chat-1", "session-1", map[string]interface{}{
"name": longName,
"user_id": "spoof",
"chat_id": "spoof-chat",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
if resp.Name == nil || len(*resp.Name) != 255 {
t.Fatalf("name=%v", resp.Name)
}
if len(store.updateCalled) != 1 {
t.Fatalf("update calls=%d", len(store.updateCalled))
}
if _, ok := store.updateCalled[0].updates["user_id"]; ok {
t.Fatalf("unexpected user_id update: %#v", store.updateCalled[0].updates)
}
if _, ok := store.updateCalled[0].updates["chat_id"]; ok {
t.Fatalf("unexpected chat_id update: %#v", store.updateCalled[0].updates)
}
if !reflect.DeepEqual(resp.Messages, []map[string]interface{}{{"role": "assistant", "content": "hello"}}) {
t.Fatalf("messages=%#v", resp.Messages)
}
}
func TestUpdateSession_ValidationErrors(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-1"}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
cases := []struct {
name string
req map[string]interface{}
message string
code common.ErrorCode
}{
{name: "empty body", req: map[string]interface{}{}, message: "Request body cannot be empty", code: common.CodeArgumentError},
{name: "message", req: map[string]interface{}{"message": []interface{}{}}, message: "`messages` cannot be changed.", code: common.CodeDataError},
{name: "messages", req: map[string]interface{}{"messages": []interface{}{}}, message: "`messages` cannot be changed.", code: common.CodeDataError},
{name: "reference", req: map[string]interface{}{"reference": []interface{}{}}, message: "`reference` cannot be changed.", code: common.CodeDataError},
{name: "empty name", req: map[string]interface{}{"name": " "}, message: "`name` can not be empty.", code: common.CodeDataError},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, code, err := svc.UpdateSession("user-1", "chat-1", "session-1", tc.req)
if err == nil || err.Error() != tc.message {
t.Fatalf("err=%v", err)
}
if code != tc.code {
t.Fatalf("code=%v", code)
}
})
}
}
func TestUpdateSession_NotFound(t *testing.T) {
store := newFakeSessionStore()
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.UpdateSession("user-1", "chat-1", "missing", map[string]interface{}{"name": "renamed"})
if err == nil || err.Error() != "Session not found!" {
t.Fatalf("err=%v", err)
}
if code != common.CodeDataError {
t.Fatalf("code=%v", code)
}
}
// ===================================================================
// Completion tests
// ===================================================================
@@ -415,7 +646,7 @@ func TestCompletion_Success(t *testing.T) {
store := newFakeSessionStore()
session := &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`{"messages":[]}`),
Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
Reference: json.RawMessage(`[]`),
}
store.sessions["session-1"] = session
@@ -447,6 +678,20 @@ func TestCompletion_Success(t *testing.T) {
if ans != "Hello world" {
t.Fatalf("expected answer 'Hello world', got %q", ans)
}
got := parseMessages(store.sessions["session-1"].Message)
if len(got) != 3 {
t.Fatalf("stored messages=%#v", got)
}
if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" {
t.Fatalf("stored prologue=%#v", got[0])
}
if got[1]["role"] != "user" || got[1]["content"] != "hi" {
t.Fatalf("stored user message=%#v", got[1])
}
if got[2]["role"] != "assistant" || got[2]["content"] != "Hello world" || got[2]["id"] != "msg-1" {
t.Fatalf("stored assistant message=%#v", got[2])
}
}
func TestCompletion_EmptyMessages(t *testing.T) {
@@ -498,7 +743,7 @@ func TestCompletion_DialogNotFound(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`{"messages":[]}`),
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
}
@@ -520,7 +765,7 @@ func TestCompletion_PipelineError(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`{"messages":[]}`),
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
}
store.dialogs["dialog-1"] = &entity.Chat{
@@ -566,7 +811,7 @@ func TestCompletionStream_Success(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`{"messages":[]}`),
Message: json.RawMessage(`{"messages":[{"role":"assistant","content":"Welcome!"}]}`),
Reference: json.RawMessage(`[]`),
}
store.dialogs["dialog-1"] = &entity.Chat{
@@ -611,6 +856,53 @@ func TestCompletionStream_Success(t *testing.T) {
if !finalFound {
t.Fatal("expected final=true signal in stream")
}
got := parseMessages(store.sessions["session-1"].Message)
if len(got) != 3 {
t.Fatalf("stored messages=%#v", got)
}
if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" {
t.Fatalf("stored prologue=%#v", got[0])
}
if got[1]["role"] != "user" || got[1]["content"] != "hi" {
t.Fatalf("stored user message=%#v", got[1])
}
if got[2]["role"] != "assistant" || got[2]["content"] != "stream answer" || got[2]["id"] != "msg-1" {
t.Fatalf("stored assistant message=%#v", got[2])
}
}
func TestStructureAnswerWithConv_ParsesArrayMessages(t *testing.T) {
session := &entity.ChatSession{
ID: "session-1",
Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
}
svc := &ChatSessionService{}
ans := svc.structureAnswerWithConv(session, map[string]interface{}{
"answer": "Final answer",
"reference": map[string]interface{}{"chunks": []interface{}{}},
"final": true,
}, "msg-1", "session-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
if ans["id"] != "msg-1" || ans["session_id"] != "session-1" {
t.Fatalf("ans=%#v", ans)
}
got := parseMessages(session.Message)
if len(got) != 1 {
t.Fatalf("stored messages=%#v", got)
}
if got[0]["role"] != "assistant" || got[0]["content"] != "Final answer" || got[0]["id"] != "msg-1" {
t.Fatalf("stored assistant message=%#v", got[0])
}
}
func TestParseMessages_LegacyWrappedObject(t *testing.T) {
got := parseMessages(json.RawMessage(`{"messages":[{"role":"assistant","content":"legacy"}]}`))
if !reflect.DeepEqual(got, []map[string]interface{}{{"role": "assistant", "content": "legacy"}}) {
t.Fatalf("messages=%#v", got)
}
}
func TestCompletionStream_EmptyMessages(t *testing.T) {
@@ -664,7 +956,7 @@ func TestCompletionStream_DialogNotFound(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`{"messages":[]}`),
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
}
@@ -687,7 +979,7 @@ func TestCompletionStream_PipelineError(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`{"messages":[]}`),
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
}
store.dialogs["dialog-1"] = &entity.Chat{