mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat(go-api): align chat session get/update with python behavior (#16239)
## Summary Align `/chats/:chat_id/sessions/:session_id` GET and PATCH with Python behavior.
This commit is contained in:
@@ -20,6 +20,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
@@ -56,6 +58,16 @@ func (dao *ChatSessionDAO) GetByID(id string) (*entity.ChatSession, error) {
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// GetBySessionIDAndChatID gets a chat session by session ID and chat ID.
|
||||
func (dao *ChatSessionDAO) GetBySessionIDAndChatID(sessionID, chatID string) (*entity.ChatSession, error) {
|
||||
var conv entity.ChatSession
|
||||
err := DB.Where("id = ? AND dialog_id = ?", sessionID, chatID).First(&conv).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// Create creates a new chat session
|
||||
func (dao *ChatSessionDAO) Create(conv *entity.ChatSession) error {
|
||||
return DB.Create(conv).Error
|
||||
@@ -63,7 +75,14 @@ func (dao *ChatSessionDAO) Create(conv *entity.ChatSession) error {
|
||||
|
||||
// UpdateByID updates a chat session by ID
|
||||
func (dao *ChatSessionDAO) UpdateByID(id string, updates map[string]interface{}) error {
|
||||
return DB.Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates).Error
|
||||
result := DB.Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByID deletes a chat session by ID (hard delete)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -329,3 +330,53 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ChatSessionHandler) GetSession(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := user.ID
|
||||
chatID, sessionID := c.Param("chat_id"), c.Param("session_id")
|
||||
|
||||
result, code, err := h.chatSessionService.GetSession(userID, chatID, sessionID)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
func (h *ChatSessionHandler) UpdateSession(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := user.ID
|
||||
chatID, sessionID := c.Param("chat_id"), c.Param("session_id")
|
||||
|
||||
req := map[string]any{}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
jsonError(c, common.CodeArgumentError, "Request body cannot be empty")
|
||||
return
|
||||
}
|
||||
jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req) == 0 {
|
||||
jsonError(c, common.CodeArgumentError, "Request body cannot be empty")
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.chatSessionService.UpdateSession(userID, chatID, sessionID, req)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
74
internal/handler/chat_session_test.go
Normal file
74
internal/handler/chat_session_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestChatSessionHandlerUpdateSession_RejectsEmptyBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPatch, "/api/v1/chats/chat-1/sessions/session-1", nil)
|
||||
ctx.Params = gin.Params{
|
||||
{Key: "chat_id", Value: "chat-1"},
|
||||
{Key: "session_id", Value: "session-1"},
|
||||
}
|
||||
ctx.Set("user", &entity.User{ID: "user-1"})
|
||||
|
||||
handler := NewChatSessionHandler(service.NewChatSessionService(), nil)
|
||||
handler.UpdateSession(ctx)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", recorder.Code)
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode response body: %v", err)
|
||||
}
|
||||
if got := body["code"]; got != float64(common.CodeArgumentError) {
|
||||
t.Fatalf("code=%v", got)
|
||||
}
|
||||
if got := body["message"]; got != "Request body cannot be empty" {
|
||||
t.Fatalf("message=%v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatSessionHandlerUpdateSession_RejectsEmptyJSONObject(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPatch, "/api/v1/chats/chat-1/sessions/session-1", strings.NewReader(`{}`))
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
ctx.Params = gin.Params{
|
||||
{Key: "chat_id", Value: "chat-1"},
|
||||
{Key: "session_id", Value: "session-1"},
|
||||
}
|
||||
ctx.Set("user", &entity.User{ID: "user-1"})
|
||||
|
||||
handler := NewChatSessionHandler(service.NewChatSessionService(), nil)
|
||||
handler.UpdateSession(ctx)
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode response body: %v", err)
|
||||
}
|
||||
if got := body["code"]; got != float64(common.CodeArgumentError) {
|
||||
t.Fatalf("code=%v", got)
|
||||
}
|
||||
if got := body["message"]; got != "Request body cannot be empty" {
|
||||
t.Fatalf("message=%v", got)
|
||||
}
|
||||
}
|
||||
@@ -253,6 +253,8 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
chats.DELETE("/:chat_id", r.chatHandler.DeleteChat)
|
||||
chats.GET("/:chat_id", r.chatHandler.GetChat)
|
||||
chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions)
|
||||
chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession)
|
||||
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
|
||||
}
|
||||
|
||||
// OpenAI-compatible chat completions route
|
||||
|
||||
@@ -25,6 +25,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
@@ -33,6 +35,7 @@ import (
|
||||
|
||||
type chatSessionStore interface {
|
||||
GetByID(id string) (*entity.ChatSession, error)
|
||||
GetBySessionIDAndChatID(sessionID, chatID string) (*entity.ChatSession, error)
|
||||
Create(conv *entity.ChatSession) error
|
||||
UpdateByID(id string, updates map[string]interface{}) error
|
||||
DeleteByID(id string) error
|
||||
@@ -128,16 +131,13 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe
|
||||
}
|
||||
}
|
||||
|
||||
// Create initial message - store as JSON object with messages array
|
||||
messagesObj := map[string]interface{}{
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": prologue,
|
||||
},
|
||||
// Store messages in the same list shape as Python Conversation.message.
|
||||
messagesJSON, _ := json.Marshal([]map[string]interface{}{
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": prologue,
|
||||
},
|
||||
}
|
||||
messagesJSON, _ := json.Marshal(messagesObj)
|
||||
})
|
||||
|
||||
// Create reference - store as JSON array
|
||||
referenceJSON, _ := json.Marshal([]interface{}{})
|
||||
@@ -218,6 +218,20 @@ type ListChatSessionsResponse struct {
|
||||
Sessions []*entity.ChatSession
|
||||
}
|
||||
|
||||
type ChatSessionPayload struct {
|
||||
ID string `json:"id"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Messages []map[string]interface{} `json:"messages"`
|
||||
Reference []interface{} `json:"reference"`
|
||||
UserID *string `json:"user_id,omitempty"`
|
||||
Avatar *string `json:"avatar,omitempty"`
|
||||
CreateDate *time.Time `json:"create_date,omitempty"`
|
||||
UpdateDate *time.Time `json:"update_date,omitempty"`
|
||||
CreateTime *int64 `json:"create_time,omitempty"`
|
||||
UpdateTime *int64 `json:"update_time,omitempty"`
|
||||
}
|
||||
|
||||
// ListChatSessions lists chat sessions for a dialog
|
||||
func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*ListChatSessionsResponse, error) {
|
||||
// Get user's tenants
|
||||
@@ -263,6 +277,214 @@ func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*Li
|
||||
return &ListChatSessionsResponse{Sessions: sessions}, nil
|
||||
}
|
||||
|
||||
// GetSession returns one chat session after ownership validation.
|
||||
func (s *ChatSessionService) GetSession(userID, chatID, sessionID string) (*ChatSessionPayload, common.ErrorCode, error) {
|
||||
ok, err := s.ensureOwnedChat(userID, chatID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
|
||||
}
|
||||
|
||||
session, err := s.chatSessionDAO.GetByID(sessionID)
|
||||
if err != nil {
|
||||
if isChatSessionNotFound(err) {
|
||||
return nil, common.CodeDataError, errors.New("Session not found!")
|
||||
}
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if session.DialogID != chatID {
|
||||
return nil, common.CodeDataError, errors.New("Session does not belong to this chat!")
|
||||
}
|
||||
|
||||
dialog, err := s.chatSessionDAO.GetDialogByID(chatID)
|
||||
if err != nil && !isChatSessionNotFound(err) {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
return s.buildSessionPayload(session, dialog, true), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// UpdateSession updates one chat session after Python-style field validation.
|
||||
func (s *ChatSessionService) UpdateSession(userID, chatID, sessionID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
|
||||
ok, err := s.ensureOwnedChat(userID, chatID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
|
||||
}
|
||||
if len(req) == 0 {
|
||||
return nil, common.CodeArgumentError, errors.New("Request body cannot be empty")
|
||||
}
|
||||
|
||||
if _, err := s.chatSessionDAO.GetBySessionIDAndChatID(sessionID, chatID); err != nil {
|
||||
if isChatSessionNotFound(err) {
|
||||
return nil, common.CodeDataError, errors.New("Session not found!")
|
||||
}
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
if _, ok := req["message"]; ok {
|
||||
return nil, common.CodeDataError, errors.New("`messages` cannot be changed.")
|
||||
}
|
||||
if _, ok := req["messages"]; ok {
|
||||
return nil, common.CodeDataError, errors.New("`messages` cannot be changed.")
|
||||
}
|
||||
if _, ok := req["reference"]; ok {
|
||||
return nil, common.CodeDataError, errors.New("`reference` cannot be changed.")
|
||||
}
|
||||
|
||||
if name, exists := req["name"]; exists && name != nil {
|
||||
nameStr, ok := name.(string)
|
||||
if !ok || strings.TrimSpace(nameStr) == "" {
|
||||
return nil, common.CodeDataError, errors.New("`name` can not be empty.")
|
||||
}
|
||||
req["name"] = strings.TrimSpace(nameStr)
|
||||
nameRunes := []rune(req["name"].(string))
|
||||
if len(nameRunes) > 255 {
|
||||
req["name"] = string(nameRunes[:255])
|
||||
}
|
||||
}
|
||||
|
||||
updateFields := make(map[string]interface{})
|
||||
for k, v := range req {
|
||||
switch k {
|
||||
case "id", "dialog_id", "chat_id", "user_id":
|
||||
continue
|
||||
default:
|
||||
updateFields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.chatSessionDAO.UpdateByID(sessionID, updateFields); err != nil {
|
||||
if isChatSessionNotFound(err) {
|
||||
return nil, common.CodeDataError, errors.New("Session not found!")
|
||||
}
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
session, err := s.chatSessionDAO.GetByID(sessionID)
|
||||
if err != nil {
|
||||
if isChatSessionNotFound(err) {
|
||||
return nil, common.CodeDataError, errors.New("Fail to update a session!")
|
||||
}
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) ensureOwnedChat(userID, chatID string) (bool, error) {
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, tenantID := range tenantIDs {
|
||||
exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, chatID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if exists {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
exists, err := s.chatSessionDAO.CheckDialogExists(userID, chatID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) buildSessionPayload(session *entity.ChatSession, dialog *entity.Chat, includeAvatar bool) *ChatSessionPayload {
|
||||
var avatar *string
|
||||
if includeAvatar {
|
||||
value := ""
|
||||
if dialog != nil && dialog.Icon != nil {
|
||||
value = *dialog.Icon
|
||||
}
|
||||
avatar = &value
|
||||
}
|
||||
|
||||
references := parseReferenceList(session.Reference)
|
||||
for index, ref := range references {
|
||||
refMap, ok := ref.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
refMap["chunks"] = formatReferenceChunks(refMap)
|
||||
references[index] = refMap
|
||||
}
|
||||
|
||||
return &ChatSessionPayload{
|
||||
ID: session.ID,
|
||||
ChatID: session.DialogID,
|
||||
Name: session.Name,
|
||||
Messages: parseMessages(session.Message),
|
||||
Reference: references,
|
||||
UserID: session.UserID,
|
||||
Avatar: avatar,
|
||||
CreateDate: session.CreateDate,
|
||||
UpdateDate: session.UpdateDate,
|
||||
CreateTime: session.CreateTime,
|
||||
UpdateTime: session.UpdateTime,
|
||||
}
|
||||
}
|
||||
|
||||
func parseMessages(raw json.RawMessage) []map[string]interface{} {
|
||||
var messages []map[string]interface{}
|
||||
if len(raw) == 0 {
|
||||
return messages
|
||||
}
|
||||
if err := json.Unmarshal(raw, &messages); err == nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
var wrapped struct {
|
||||
Messages []map[string]interface{} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &wrapped); err != nil {
|
||||
return nil
|
||||
}
|
||||
return wrapped.Messages
|
||||
}
|
||||
|
||||
func parseReferenceList(raw json.RawMessage) []interface{} {
|
||||
var references []interface{}
|
||||
if len(raw) == 0 {
|
||||
return references
|
||||
}
|
||||
err := json.Unmarshal(raw, &references)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return references
|
||||
}
|
||||
|
||||
func formatReferenceChunks(reference map[string]interface{}) []FormattedChunk {
|
||||
rawChunks, ok := reference["chunks"].([]interface{})
|
||||
if !ok {
|
||||
return []FormattedChunk{}
|
||||
}
|
||||
|
||||
chunks := make([]map[string]interface{}, 0, len(rawChunks))
|
||||
for _, item := range rawChunks {
|
||||
chunk, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
return formatChunks(chunks)
|
||||
}
|
||||
|
||||
func isChatSessionNotFound(err error) bool {
|
||||
return errors.Is(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
// Completion performs chat completion with full RAG support via ChatPipelineService.
|
||||
func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) {
|
||||
// Validate the last message is from user
|
||||
@@ -286,7 +508,7 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me
|
||||
return nil, errors.New("Dialog not found")
|
||||
}
|
||||
|
||||
// Deep copy messages to session
|
||||
// Deep copy messages to session, preserving the stored prologue that handler strips from requests.
|
||||
sessionMessages := s.buildSessionMessages(session, messages)
|
||||
|
||||
// Initialize reference if empty
|
||||
@@ -338,6 +560,12 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me
|
||||
|
||||
// Update conversation if not embedded
|
||||
if !isEmbedded {
|
||||
sessionMessages = append(sessionMessages, map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": answer.String(),
|
||||
"id": messageID,
|
||||
"created_at": float64(time.Now().Unix()),
|
||||
})
|
||||
s.updateSessionMessages(session, sessionMessages, reference)
|
||||
}
|
||||
|
||||
@@ -375,7 +603,7 @@ func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string
|
||||
return errors.New("Dialog not found")
|
||||
}
|
||||
|
||||
// Deep copy messages to session
|
||||
// Deep copy messages to session, preserving the stored prologue that handler strips from requests.
|
||||
sessionMessages := s.buildSessionMessages(session, messages)
|
||||
|
||||
// Initialize reference if empty
|
||||
@@ -440,6 +668,12 @@ func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string
|
||||
|
||||
// Update conversation if not embedded
|
||||
if !isEmbedded {
|
||||
sessionMessages = append(sessionMessages, map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": fullAnswer.String(),
|
||||
"id": messageID,
|
||||
"created_at": float64(time.Now().Unix()),
|
||||
})
|
||||
s.updateSessionMessages(session, sessionMessages, reference)
|
||||
}
|
||||
|
||||
@@ -449,14 +683,33 @@ func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string
|
||||
// Helper methods
|
||||
|
||||
func (s *ChatSessionService) buildSessionMessages(session *entity.ChatSession, messages []map[string]interface{}) []map[string]interface{} {
|
||||
// Deep copy messages to session
|
||||
sessionMessages := make([]map[string]interface{}, len(messages))
|
||||
for i, msg := range messages {
|
||||
sessionMessages[i] = make(map[string]interface{})
|
||||
for k, v := range msg {
|
||||
sessionMessages[i][k] = v
|
||||
prefix := make([]map[string]interface{}, 0, 1)
|
||||
existingMessages := parseMessages(session.Message)
|
||||
if len(existingMessages) > 0 {
|
||||
if role, _ := existingMessages[0]["role"].(string); role == "assistant" {
|
||||
firstIncomingRole := ""
|
||||
if len(messages) > 0 {
|
||||
firstIncomingRole, _ = messages[0]["role"].(string)
|
||||
}
|
||||
if firstIncomingRole != "assistant" {
|
||||
prologue := make(map[string]interface{}, len(existingMessages[0]))
|
||||
for k, v := range existingMessages[0] {
|
||||
prologue[k] = v
|
||||
}
|
||||
prefix = append(prefix, prologue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sessionMessages := make([]map[string]interface{}, 0, len(prefix)+len(messages))
|
||||
sessionMessages = append(sessionMessages, prefix...)
|
||||
for _, msg := range messages {
|
||||
cloned := make(map[string]interface{}, len(msg))
|
||||
for k, v := range msg {
|
||||
cloned[k] = v
|
||||
}
|
||||
sessionMessages = append(sessionMessages, cloned)
|
||||
}
|
||||
return sessionMessages
|
||||
}
|
||||
|
||||
@@ -495,9 +748,7 @@ func (s *ChatSessionService) structureAnswer(session *entity.ChatSession, answer
|
||||
|
||||
func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, messages []map[string]interface{}, reference []interface{}) {
|
||||
// Update session with new messages and reference
|
||||
messagesJSON, _ := json.Marshal(map[string]interface{}{
|
||||
"messages": messages,
|
||||
})
|
||||
messagesJSON, _ := json.Marshal(messages)
|
||||
referenceJSON, _ := json.Marshal(reference)
|
||||
|
||||
updates := map[string]interface{}{
|
||||
@@ -505,6 +756,8 @@ func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession,
|
||||
"reference": referenceJSON,
|
||||
}
|
||||
s.chatSessionDAO.UpdateByID(session.ID, updates)
|
||||
session.Message = messagesJSON
|
||||
session.Reference = referenceJSON
|
||||
}
|
||||
|
||||
// structureAnswerWithConv structures the answer with conversation update (like Python's structure_answer)
|
||||
@@ -535,12 +788,8 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession
|
||||
content = "</think>"
|
||||
}
|
||||
|
||||
// Parse existing messages
|
||||
var messagesObj map[string]interface{}
|
||||
if len(session.Message) > 0 {
|
||||
json.Unmarshal(session.Message, &messagesObj)
|
||||
}
|
||||
messages, _ := messagesObj["messages"].([]interface{})
|
||||
// Parse existing messages. Keep backward compatibility with wrapped legacy rows.
|
||||
messages := parseMessages(session.Message)
|
||||
|
||||
// Update or append assistant message
|
||||
if len(messages) == 0 || s.getLastRole(messages) != "assistant" {
|
||||
@@ -552,20 +801,20 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession
|
||||
})
|
||||
} else {
|
||||
lastIdx := len(messages) - 1
|
||||
lastMsg, _ := messages[lastIdx].(map[string]interface{})
|
||||
if lastMsg != nil {
|
||||
if ans["final"] == true && ans["answer"] != nil {
|
||||
lastMsg["content"] = ans["answer"]
|
||||
} else {
|
||||
existing, _ := lastMsg["content"].(string)
|
||||
lastMsg["content"] = existing + content
|
||||
}
|
||||
lastMsg["created_at"] = float64(time.Now().Unix())
|
||||
lastMsg["id"] = messageID
|
||||
messages[lastIdx] = lastMsg
|
||||
lastMsg := messages[lastIdx]
|
||||
if ans["final"] == true && ans["answer"] != nil {
|
||||
lastMsg["content"] = ans["answer"]
|
||||
} else {
|
||||
existing, _ := lastMsg["content"].(string)
|
||||
lastMsg["content"] = existing + content
|
||||
}
|
||||
lastMsg["created_at"] = float64(time.Now().Unix())
|
||||
lastMsg["id"] = messageID
|
||||
messages[lastIdx] = lastMsg
|
||||
}
|
||||
|
||||
session.Message, _ = json.Marshal(messages)
|
||||
|
||||
// Update reference
|
||||
if len(reference) > 0 {
|
||||
reference[len(reference)-1] = ref
|
||||
@@ -575,16 +824,12 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession
|
||||
}
|
||||
|
||||
// getLastRole gets the role of the last message
|
||||
func (s *ChatSessionService) getLastRole(messages []interface{}) string {
|
||||
func (s *ChatSessionService) getLastRole(messages []map[string]interface{}) string {
|
||||
if len(messages) == 0 {
|
||||
return ""
|
||||
}
|
||||
lastMsg, _ := messages[len(messages)-1].(map[string]interface{})
|
||||
if lastMsg != nil {
|
||||
role, _ := lastMsg["role"].(string)
|
||||
return role
|
||||
}
|
||||
return ""
|
||||
role, _ := messages[len(messages)-1]["role"].(string)
|
||||
return role
|
||||
}
|
||||
|
||||
// chunksFormat formats chunks for reference (simplified version)
|
||||
|
||||
@@ -4,10 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
@@ -48,7 +51,18 @@ func (f *fakeSessionStore) GetByID(id string) (*entity.ChatSession, error) {
|
||||
}
|
||||
s, ok := f.sessions[id]
|
||||
if !ok {
|
||||
return nil, errors.New("record not found")
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (f *fakeSessionStore) GetBySessionIDAndChatID(sessionID, chatID string) (*entity.ChatSession, error) {
|
||||
s, err := f.GetByID(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.DialogID != chatID {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@@ -70,10 +84,30 @@ func (f *fakeSessionStore) UpdateByID(id string, updates map[string]interface{})
|
||||
if f.updateByIDErr != nil {
|
||||
return f.updateByIDErr
|
||||
}
|
||||
s, ok := f.sessions[id]
|
||||
if !ok {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
f.updateCalled = append(f.updateCalled, struct {
|
||||
id string
|
||||
updates map[string]interface{}
|
||||
}{id, updates})
|
||||
for k, v := range updates {
|
||||
switch k {
|
||||
case "name":
|
||||
if str, ok := v.(string); ok {
|
||||
s.Name = &str
|
||||
}
|
||||
case "message":
|
||||
if raw, ok := v.([]byte); ok {
|
||||
s.Message = append(json.RawMessage(nil), raw...)
|
||||
}
|
||||
case "reference":
|
||||
if raw, ok := v.([]byte); ok {
|
||||
s.Reference = append(json.RawMessage(nil), raw...)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -177,16 +211,15 @@ func TestSetChatSession_CreateNew(t *testing.T) {
|
||||
t.Fatalf("expected 1 Create call, got %d", len(store.createCalled))
|
||||
}
|
||||
|
||||
// Verify prologue is in the message
|
||||
var msgObj map[string]interface{}
|
||||
if err := json.Unmarshal(store.createCalled[0].Message, &msgObj); err != nil {
|
||||
// Verify prologue is in the message list.
|
||||
var msgs []map[string]interface{}
|
||||
if err := json.Unmarshal(store.createCalled[0].Message, &msgs); err != nil {
|
||||
t.Fatalf("failed to unmarshal message: %v", err)
|
||||
}
|
||||
msgs, _ := msgObj["messages"].([]interface{})
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected 1 initial message, got %d", len(msgs))
|
||||
}
|
||||
firstMsg, _ := msgs[0].(map[string]interface{})
|
||||
firstMsg := msgs[0]
|
||||
if firstMsg["role"] != "assistant" || firstMsg["content"] != "Welcome!" {
|
||||
t.Fatalf("unexpected prologue message: %#v", firstMsg)
|
||||
}
|
||||
@@ -213,10 +246,9 @@ func TestSetChatSession_CreateNewDefaultPrologue(t *testing.T) {
|
||||
t.Fatal("expected session ID")
|
||||
}
|
||||
// Default prologue
|
||||
var msgObj map[string]interface{}
|
||||
json.Unmarshal(store.createCalled[0].Message, &msgObj)
|
||||
msgs, _ := msgObj["messages"].([]interface{})
|
||||
firstMsg, _ := msgs[0].(map[string]interface{})
|
||||
var msgs []map[string]interface{}
|
||||
json.Unmarshal(store.createCalled[0].Message, &msgs)
|
||||
firstMsg := msgs[0]
|
||||
if !strings.Contains(firstMsg["content"].(string), "Hi! I'm your assistant") {
|
||||
t.Fatalf("expected default prologue, got %q", firstMsg["content"])
|
||||
}
|
||||
@@ -407,6 +439,205 @@ func TestListChatSessions_NotOwner(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// GetSession / UpdateSession tests
|
||||
// ===================================================================
|
||||
|
||||
func TestGetSession_Success(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Name: strPtr("session"),
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`),
|
||||
Reference: json.RawMessage(`[
|
||||
{"chunks":[{"chunk_id":"chunk-1","content_with_weight":"hello","doc_id":"doc-1","docnm_kwd":"Doc 1","kb_id":"kb-1"}]},
|
||||
[]
|
||||
]`),
|
||||
UserID: strPtr("user-1"),
|
||||
}
|
||||
icon := "avatar.png"
|
||||
store.dialogs["chat-1"] = &entity.Chat{ID: "chat-1", Icon: &icon}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, code, err := svc.GetSession("user-1", "chat-1", "session-1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("unexpected code: %v", code)
|
||||
}
|
||||
if resp.ChatID != "chat-1" {
|
||||
t.Fatalf("chat_id=%q", resp.ChatID)
|
||||
}
|
||||
if resp.Avatar == nil || *resp.Avatar != "avatar.png" {
|
||||
t.Fatalf("avatar=%v", resp.Avatar)
|
||||
}
|
||||
if len(resp.Messages) != 1 || resp.Messages[0]["content"] != "hello" {
|
||||
t.Fatalf("messages=%#v", resp.Messages)
|
||||
}
|
||||
if len(resp.Reference) != 2 {
|
||||
t.Fatalf("reference len=%d", len(resp.Reference))
|
||||
}
|
||||
firstRef, ok := resp.Reference[0].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("reference[0] type=%T", resp.Reference[0])
|
||||
}
|
||||
chunks, ok := firstRef["chunks"].([]FormattedChunk)
|
||||
if !ok {
|
||||
t.Fatalf("chunks type=%T", firstRef["chunks"])
|
||||
}
|
||||
if len(chunks) != 1 || chunks[0].ID != "chunk-1" {
|
||||
t.Fatalf("chunks=%#v", chunks)
|
||||
}
|
||||
if _, ok := resp.Reference[1].([]interface{}); !ok {
|
||||
t.Fatalf("reference[1] changed unexpectedly: %T", resp.Reference[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSession_NotOwner(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: newFakeSessionStore(),
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.GetSession("user-1", "chat-1", "session-1")
|
||||
if err == nil || err.Error() != "No authorization." {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeAuthenticationError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSession_WrongChat(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-2"}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.GetSession("user-1", "chat-1", "session-1")
|
||||
if err == nil || err.Error() != "Session does not belong to this chat!" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSession_Success(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Name: strPtr("old"),
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
longName := " " + strings.Repeat("x", 260) + " "
|
||||
resp, code, err := svc.UpdateSession("user-1", "chat-1", "session-1", map[string]interface{}{
|
||||
"name": longName,
|
||||
"user_id": "spoof",
|
||||
"chat_id": "spoof-chat",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if resp.Name == nil || len(*resp.Name) != 255 {
|
||||
t.Fatalf("name=%v", resp.Name)
|
||||
}
|
||||
if len(store.updateCalled) != 1 {
|
||||
t.Fatalf("update calls=%d", len(store.updateCalled))
|
||||
}
|
||||
if _, ok := store.updateCalled[0].updates["user_id"]; ok {
|
||||
t.Fatalf("unexpected user_id update: %#v", store.updateCalled[0].updates)
|
||||
}
|
||||
if _, ok := store.updateCalled[0].updates["chat_id"]; ok {
|
||||
t.Fatalf("unexpected chat_id update: %#v", store.updateCalled[0].updates)
|
||||
}
|
||||
if !reflect.DeepEqual(resp.Messages, []map[string]interface{}{{"role": "assistant", "content": "hello"}}) {
|
||||
t.Fatalf("messages=%#v", resp.Messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSession_ValidationErrors(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-1"}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
req map[string]interface{}
|
||||
message string
|
||||
code common.ErrorCode
|
||||
}{
|
||||
{name: "empty body", req: map[string]interface{}{}, message: "Request body cannot be empty", code: common.CodeArgumentError},
|
||||
{name: "message", req: map[string]interface{}{"message": []interface{}{}}, message: "`messages` cannot be changed.", code: common.CodeDataError},
|
||||
{name: "messages", req: map[string]interface{}{"messages": []interface{}{}}, message: "`messages` cannot be changed.", code: common.CodeDataError},
|
||||
{name: "reference", req: map[string]interface{}{"reference": []interface{}{}}, message: "`reference` cannot be changed.", code: common.CodeDataError},
|
||||
{name: "empty name", req: map[string]interface{}{"name": " "}, message: "`name` can not be empty.", code: common.CodeDataError},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, code, err := svc.UpdateSession("user-1", "chat-1", "session-1", tc.req)
|
||||
if err == nil || err.Error() != tc.message {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != tc.code {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSession_NotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.UpdateSession("user-1", "chat-1", "missing", map[string]interface{}{"name": "renamed"})
|
||||
if err == nil || err.Error() != "Session not found!" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// Completion tests
|
||||
// ===================================================================
|
||||
@@ -415,7 +646,7 @@ func TestCompletion_Success(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
session := &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`{"messages":[]}`),
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.sessions["session-1"] = session
|
||||
@@ -447,6 +678,20 @@ func TestCompletion_Success(t *testing.T) {
|
||||
if ans != "Hello world" {
|
||||
t.Fatalf("expected answer 'Hello world', got %q", ans)
|
||||
}
|
||||
|
||||
got := parseMessages(store.sessions["session-1"].Message)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("stored messages=%#v", got)
|
||||
}
|
||||
if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" {
|
||||
t.Fatalf("stored prologue=%#v", got[0])
|
||||
}
|
||||
if got[1]["role"] != "user" || got[1]["content"] != "hi" {
|
||||
t.Fatalf("stored user message=%#v", got[1])
|
||||
}
|
||||
if got[2]["role"] != "assistant" || got[2]["content"] != "Hello world" || got[2]["id"] != "msg-1" {
|
||||
t.Fatalf("stored assistant message=%#v", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_EmptyMessages(t *testing.T) {
|
||||
@@ -498,7 +743,7 @@ func TestCompletion_DialogNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`{"messages":[]}`),
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
|
||||
@@ -520,7 +765,7 @@ func TestCompletion_PipelineError(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`{"messages":[]}`),
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.dialogs["dialog-1"] = &entity.Chat{
|
||||
@@ -566,7 +811,7 @@ func TestCompletionStream_Success(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`{"messages":[]}`),
|
||||
Message: json.RawMessage(`{"messages":[{"role":"assistant","content":"Welcome!"}]}`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.dialogs["dialog-1"] = &entity.Chat{
|
||||
@@ -611,6 +856,53 @@ func TestCompletionStream_Success(t *testing.T) {
|
||||
if !finalFound {
|
||||
t.Fatal("expected final=true signal in stream")
|
||||
}
|
||||
|
||||
got := parseMessages(store.sessions["session-1"].Message)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("stored messages=%#v", got)
|
||||
}
|
||||
if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" {
|
||||
t.Fatalf("stored prologue=%#v", got[0])
|
||||
}
|
||||
if got[1]["role"] != "user" || got[1]["content"] != "hi" {
|
||||
t.Fatalf("stored user message=%#v", got[1])
|
||||
}
|
||||
if got[2]["role"] != "assistant" || got[2]["content"] != "stream answer" || got[2]["id"] != "msg-1" {
|
||||
t.Fatalf("stored assistant message=%#v", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructureAnswerWithConv_ParsesArrayMessages(t *testing.T) {
|
||||
session := &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
|
||||
}
|
||||
svc := &ChatSessionService{}
|
||||
|
||||
ans := svc.structureAnswerWithConv(session, map[string]interface{}{
|
||||
"answer": "Final answer",
|
||||
"reference": map[string]interface{}{"chunks": []interface{}{}},
|
||||
"final": true,
|
||||
}, "msg-1", "session-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
|
||||
if ans["id"] != "msg-1" || ans["session_id"] != "session-1" {
|
||||
t.Fatalf("ans=%#v", ans)
|
||||
}
|
||||
|
||||
got := parseMessages(session.Message)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("stored messages=%#v", got)
|
||||
}
|
||||
if got[0]["role"] != "assistant" || got[0]["content"] != "Final answer" || got[0]["id"] != "msg-1" {
|
||||
t.Fatalf("stored assistant message=%#v", got[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMessages_LegacyWrappedObject(t *testing.T) {
|
||||
got := parseMessages(json.RawMessage(`{"messages":[{"role":"assistant","content":"legacy"}]}`))
|
||||
if !reflect.DeepEqual(got, []map[string]interface{}{{"role": "assistant", "content": "legacy"}}) {
|
||||
t.Fatalf("messages=%#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionStream_EmptyMessages(t *testing.T) {
|
||||
@@ -664,7 +956,7 @@ func TestCompletionStream_DialogNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`{"messages":[]}`),
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
|
||||
@@ -687,7 +979,7 @@ func TestCompletionStream_PipelineError(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`{"messages":[]}`),
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.dialogs["dialog-1"] = &entity.Chat{
|
||||
|
||||
Reference in New Issue
Block a user