diff --git a/internal/dao/chat_session.go b/internal/dao/chat_session.go index 60b83ea08b..a5f77f1809 100644 --- a/internal/dao/chat_session.go +++ b/internal/dao/chat_session.go @@ -20,6 +20,8 @@ import ( "strings" "time" + "gorm.io/gorm" + "ragflow/internal/entity" ) @@ -56,6 +58,16 @@ func (dao *ChatSessionDAO) GetByID(id string) (*entity.ChatSession, error) { return &conv, nil } +// GetBySessionIDAndChatID gets a chat session by session ID and chat ID. +func (dao *ChatSessionDAO) GetBySessionIDAndChatID(sessionID, chatID string) (*entity.ChatSession, error) { + var conv entity.ChatSession + err := DB.Where("id = ? AND dialog_id = ?", sessionID, chatID).First(&conv).Error + if err != nil { + return nil, err + } + return &conv, nil +} + // Create creates a new chat session func (dao *ChatSessionDAO) Create(conv *entity.ChatSession) error { return DB.Create(conv).Error @@ -63,7 +75,14 @@ func (dao *ChatSessionDAO) Create(conv *entity.ChatSession) error { // UpdateByID updates a chat session by ID func (dao *ChatSessionDAO) UpdateByID(id string, updates map[string]interface{}) error { - return DB.Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates).Error + result := DB.Model(&entity.ChatSession{}).Where("id = ?", id).Updates(updates) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return gorm.ErrRecordNotFound + } + return nil } // DeleteByID deletes a chat session by ID (hard delete) diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index 3c395e88fc..ec3dbaf449 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -17,6 +17,7 @@ package handler import ( + "errors" "fmt" "io" "net/http" @@ -329,3 +330,53 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) { }) } } + +func (h *ChatSessionHandler) GetSession(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := user.ID + chatID, sessionID := c.Param("chat_id"), c.Param("session_id") + + result, code, err := h.chatSessionService.GetSession(userID, chatID, sessionID) + if err != nil { + jsonError(c, code, err.Error()) + return + } + jsonResponse(c, common.CodeSuccess, result, "success") +} + +func (h *ChatSessionHandler) UpdateSession(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := user.ID + chatID, sessionID := c.Param("chat_id"), c.Param("session_id") + + req := map[string]any{} + if err := c.ShouldBindJSON(&req); err != nil { + if errors.Is(err, io.EOF) { + jsonError(c, common.CodeArgumentError, "Request body cannot be empty") + return + } + jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error()) + return + } + if len(req) == 0 { + jsonError(c, common.CodeArgumentError, "Request body cannot be empty") + return + } + + result, code, err := h.chatSessionService.UpdateSession(userID, chatID, sessionID, req) + if err != nil { + jsonError(c, code, err.Error()) + return + } + jsonResponse(c, common.CodeSuccess, result, "success") +} diff --git a/internal/handler/chat_session_test.go b/internal/handler/chat_session_test.go new file mode 100644 index 0000000000..4aa621068e --- /dev/null +++ b/internal/handler/chat_session_test.go @@ -0,0 +1,74 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "ragflow/internal/common" + "ragflow/internal/entity" + "ragflow/internal/service" + + "github.com/gin-gonic/gin" +) + +func TestChatSessionHandlerUpdateSession_RejectsEmptyBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPatch, "/api/v1/chats/chat-1/sessions/session-1", nil) + ctx.Params = gin.Params{ + {Key: "chat_id", Value: "chat-1"}, + {Key: "session_id", Value: "session-1"}, + } + ctx.Set("user", &entity.User{ID: "user-1"}) + + handler := NewChatSessionHandler(service.NewChatSessionService(), nil) + handler.UpdateSession(ctx) + + if recorder.Code != http.StatusOK { + t.Fatalf("status=%d", recorder.Code) + } + + var body map[string]interface{} + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatalf("decode response body: %v", err) + } + if got := body["code"]; got != float64(common.CodeArgumentError) { + t.Fatalf("code=%v", got) + } + if got := body["message"]; got != "Request body cannot be empty" { + t.Fatalf("message=%v", got) + } +} + +func TestChatSessionHandlerUpdateSession_RejectsEmptyJSONObject(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPatch, "/api/v1/chats/chat-1/sessions/session-1", strings.NewReader(`{}`)) + ctx.Request.Header.Set("Content-Type", "application/json") + ctx.Params = gin.Params{ + {Key: "chat_id", Value: "chat-1"}, + {Key: "session_id", Value: "session-1"}, + } + ctx.Set("user", &entity.User{ID: "user-1"}) + + handler := NewChatSessionHandler(service.NewChatSessionService(), nil) + handler.UpdateSession(ctx) + + var body map[string]interface{} + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatalf("decode response body: %v", err) + } + if got := body["code"]; got != float64(common.CodeArgumentError) { + t.Fatalf("code=%v", got) + } + if got := body["message"]; got != "Request body cannot be empty" { + t.Fatalf("message=%v", got) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 0b3fff9ba0..ed51bc090c 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -253,6 +253,8 @@ func (r *Router) Setup(engine *gin.Engine) { chats.DELETE("/:chat_id", r.chatHandler.DeleteChat) chats.GET("/:chat_id", r.chatHandler.GetChat) chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions) + chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession) + chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession) } // OpenAI-compatible chat completions route diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index 046b21eeb9..ac3b8cd21f 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -25,6 +25,8 @@ import ( "strings" "time" + "gorm.io/gorm" + "ragflow/internal/dao" "ragflow/internal/entity" ) @@ -33,6 +35,7 @@ import ( type chatSessionStore interface { GetByID(id string) (*entity.ChatSession, error) + GetBySessionIDAndChatID(sessionID, chatID string) (*entity.ChatSession, error) Create(conv *entity.ChatSession) error UpdateByID(id string, updates map[string]interface{}) error DeleteByID(id string) error @@ -128,16 +131,13 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe } } - // Create initial message - store as JSON object with messages array - messagesObj := map[string]interface{}{ - "messages": []map[string]interface{}{ - { - "role": "assistant", - "content": prologue, - }, + // Store messages in the same list shape as Python Conversation.message. + messagesJSON, _ := json.Marshal([]map[string]interface{}{ + { + "role": "assistant", + "content": prologue, }, - } - messagesJSON, _ := json.Marshal(messagesObj) + }) // Create reference - store as JSON array referenceJSON, _ := json.Marshal([]interface{}{}) @@ -218,6 +218,20 @@ type ListChatSessionsResponse struct { Sessions []*entity.ChatSession } +type ChatSessionPayload struct { + ID string `json:"id"` + ChatID string `json:"chat_id"` + Name *string `json:"name,omitempty"` + Messages []map[string]interface{} `json:"messages"` + Reference []interface{} `json:"reference"` + UserID *string `json:"user_id,omitempty"` + Avatar *string `json:"avatar,omitempty"` + CreateDate *time.Time `json:"create_date,omitempty"` + UpdateDate *time.Time `json:"update_date,omitempty"` + CreateTime *int64 `json:"create_time,omitempty"` + UpdateTime *int64 `json:"update_time,omitempty"` +} + // ListChatSessions lists chat sessions for a dialog func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*ListChatSessionsResponse, error) { // Get user's tenants @@ -263,6 +277,214 @@ func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*Li return &ListChatSessionsResponse{Sessions: sessions}, nil } +// GetSession returns one chat session after ownership validation. +func (s *ChatSessionService) GetSession(userID, chatID, sessionID string) (*ChatSessionPayload, common.ErrorCode, error) { + ok, err := s.ensureOwnedChat(userID, chatID) + if err != nil { + return nil, common.CodeServerError, err + } + if !ok { + return nil, common.CodeAuthenticationError, errors.New("No authorization.") + } + + session, err := s.chatSessionDAO.GetByID(sessionID) + if err != nil { + if isChatSessionNotFound(err) { + return nil, common.CodeDataError, errors.New("Session not found!") + } + return nil, common.CodeServerError, err + } + if session.DialogID != chatID { + return nil, common.CodeDataError, errors.New("Session does not belong to this chat!") + } + + dialog, err := s.chatSessionDAO.GetDialogByID(chatID) + if err != nil && !isChatSessionNotFound(err) { + return nil, common.CodeServerError, err + } + + return s.buildSessionPayload(session, dialog, true), common.CodeSuccess, nil +} + +// UpdateSession updates one chat session after Python-style field validation. +func (s *ChatSessionService) UpdateSession(userID, chatID, sessionID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) { + ok, err := s.ensureOwnedChat(userID, chatID) + if err != nil { + return nil, common.CodeServerError, err + } + if !ok { + return nil, common.CodeAuthenticationError, errors.New("No authorization.") + } + if len(req) == 0 { + return nil, common.CodeArgumentError, errors.New("Request body cannot be empty") + } + + if _, err := s.chatSessionDAO.GetBySessionIDAndChatID(sessionID, chatID); err != nil { + if isChatSessionNotFound(err) { + return nil, common.CodeDataError, errors.New("Session not found!") + } + return nil, common.CodeServerError, err + } + + if _, ok := req["message"]; ok { + return nil, common.CodeDataError, errors.New("`messages` cannot be changed.") + } + if _, ok := req["messages"]; ok { + return nil, common.CodeDataError, errors.New("`messages` cannot be changed.") + } + if _, ok := req["reference"]; ok { + return nil, common.CodeDataError, errors.New("`reference` cannot be changed.") + } + + if name, exists := req["name"]; exists && name != nil { + nameStr, ok := name.(string) + if !ok || strings.TrimSpace(nameStr) == "" { + return nil, common.CodeDataError, errors.New("`name` can not be empty.") + } + req["name"] = strings.TrimSpace(nameStr) + nameRunes := []rune(req["name"].(string)) + if len(nameRunes) > 255 { + req["name"] = string(nameRunes[:255]) + } + } + + updateFields := make(map[string]interface{}) + for k, v := range req { + switch k { + case "id", "dialog_id", "chat_id", "user_id": + continue + default: + updateFields[k] = v + } + } + + if err := s.chatSessionDAO.UpdateByID(sessionID, updateFields); err != nil { + if isChatSessionNotFound(err) { + return nil, common.CodeDataError, errors.New("Session not found!") + } + return nil, common.CodeServerError, err + } + + session, err := s.chatSessionDAO.GetByID(sessionID) + if err != nil { + if isChatSessionNotFound(err) { + return nil, common.CodeDataError, errors.New("Fail to update a session!") + } + return nil, common.CodeServerError, err + } + + return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil +} + +func (s *ChatSessionService) ensureOwnedChat(userID, chatID string) (bool, error) { + tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) + if err != nil { + return false, err + } + + for _, tenantID := range tenantIDs { + exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, chatID) + if err != nil { + return false, err + } + if exists { + return true, nil + } + } + + exists, err := s.chatSessionDAO.CheckDialogExists(userID, chatID) + if err != nil { + return false, err + } + return exists, nil +} + +func (s *ChatSessionService) buildSessionPayload(session *entity.ChatSession, dialog *entity.Chat, includeAvatar bool) *ChatSessionPayload { + var avatar *string + if includeAvatar { + value := "" + if dialog != nil && dialog.Icon != nil { + value = *dialog.Icon + } + avatar = &value + } + + references := parseReferenceList(session.Reference) + for index, ref := range references { + refMap, ok := ref.(map[string]interface{}) + if !ok { + continue + } + refMap["chunks"] = formatReferenceChunks(refMap) + references[index] = refMap + } + + return &ChatSessionPayload{ + ID: session.ID, + ChatID: session.DialogID, + Name: session.Name, + Messages: parseMessages(session.Message), + Reference: references, + UserID: session.UserID, + Avatar: avatar, + CreateDate: session.CreateDate, + UpdateDate: session.UpdateDate, + CreateTime: session.CreateTime, + UpdateTime: session.UpdateTime, + } +} + +func parseMessages(raw json.RawMessage) []map[string]interface{} { + var messages []map[string]interface{} + if len(raw) == 0 { + return messages + } + if err := json.Unmarshal(raw, &messages); err == nil { + return messages + } + + var wrapped struct { + Messages []map[string]interface{} `json:"messages"` + } + if err := json.Unmarshal(raw, &wrapped); err != nil { + return nil + } + return wrapped.Messages +} + +func parseReferenceList(raw json.RawMessage) []interface{} { + var references []interface{} + if len(raw) == 0 { + return references + } + err := json.Unmarshal(raw, &references) + if err != nil { + return nil + } + return references +} + +func formatReferenceChunks(reference map[string]interface{}) []FormattedChunk { + rawChunks, ok := reference["chunks"].([]interface{}) + if !ok { + return []FormattedChunk{} + } + + chunks := make([]map[string]interface{}, 0, len(rawChunks)) + for _, item := range rawChunks { + chunk, ok := item.(map[string]interface{}) + if !ok { + continue + } + chunks = append(chunks, chunk) + } + return formatChunks(chunks) +} + +func isChatSessionNotFound(err error) bool { + return errors.Is(err, gorm.ErrRecordNotFound) +} + // Completion performs chat completion with full RAG support via ChatPipelineService. func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) { // Validate the last message is from user @@ -286,7 +508,7 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me return nil, errors.New("Dialog not found") } - // Deep copy messages to session + // Deep copy messages to session, preserving the stored prologue that handler strips from requests. sessionMessages := s.buildSessionMessages(session, messages) // Initialize reference if empty @@ -338,6 +560,12 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me // Update conversation if not embedded if !isEmbedded { + sessionMessages = append(sessionMessages, map[string]interface{}{ + "role": "assistant", + "content": answer.String(), + "id": messageID, + "created_at": float64(time.Now().Unix()), + }) s.updateSessionMessages(session, sessionMessages, reference) } @@ -375,7 +603,7 @@ func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string return errors.New("Dialog not found") } - // Deep copy messages to session + // Deep copy messages to session, preserving the stored prologue that handler strips from requests. sessionMessages := s.buildSessionMessages(session, messages) // Initialize reference if empty @@ -440,6 +668,12 @@ func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string // Update conversation if not embedded if !isEmbedded { + sessionMessages = append(sessionMessages, map[string]interface{}{ + "role": "assistant", + "content": fullAnswer.String(), + "id": messageID, + "created_at": float64(time.Now().Unix()), + }) s.updateSessionMessages(session, sessionMessages, reference) } @@ -449,14 +683,33 @@ func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string // Helper methods func (s *ChatSessionService) buildSessionMessages(session *entity.ChatSession, messages []map[string]interface{}) []map[string]interface{} { - // Deep copy messages to session - sessionMessages := make([]map[string]interface{}, len(messages)) - for i, msg := range messages { - sessionMessages[i] = make(map[string]interface{}) - for k, v := range msg { - sessionMessages[i][k] = v + prefix := make([]map[string]interface{}, 0, 1) + existingMessages := parseMessages(session.Message) + if len(existingMessages) > 0 { + if role, _ := existingMessages[0]["role"].(string); role == "assistant" { + firstIncomingRole := "" + if len(messages) > 0 { + firstIncomingRole, _ = messages[0]["role"].(string) + } + if firstIncomingRole != "assistant" { + prologue := make(map[string]interface{}, len(existingMessages[0])) + for k, v := range existingMessages[0] { + prologue[k] = v + } + prefix = append(prefix, prologue) + } } } + + sessionMessages := make([]map[string]interface{}, 0, len(prefix)+len(messages)) + sessionMessages = append(sessionMessages, prefix...) + for _, msg := range messages { + cloned := make(map[string]interface{}, len(msg)) + for k, v := range msg { + cloned[k] = v + } + sessionMessages = append(sessionMessages, cloned) + } return sessionMessages } @@ -495,9 +748,7 @@ func (s *ChatSessionService) structureAnswer(session *entity.ChatSession, answer func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, messages []map[string]interface{}, reference []interface{}) { // Update session with new messages and reference - messagesJSON, _ := json.Marshal(map[string]interface{}{ - "messages": messages, - }) + messagesJSON, _ := json.Marshal(messages) referenceJSON, _ := json.Marshal(reference) updates := map[string]interface{}{ @@ -505,6 +756,8 @@ func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, "reference": referenceJSON, } s.chatSessionDAO.UpdateByID(session.ID, updates) + session.Message = messagesJSON + session.Reference = referenceJSON } // structureAnswerWithConv structures the answer with conversation update (like Python's structure_answer) @@ -535,12 +788,8 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession content = "" } - // Parse existing messages - var messagesObj map[string]interface{} - if len(session.Message) > 0 { - json.Unmarshal(session.Message, &messagesObj) - } - messages, _ := messagesObj["messages"].([]interface{}) + // Parse existing messages. Keep backward compatibility with wrapped legacy rows. + messages := parseMessages(session.Message) // Update or append assistant message if len(messages) == 0 || s.getLastRole(messages) != "assistant" { @@ -552,20 +801,20 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession }) } else { lastIdx := len(messages) - 1 - lastMsg, _ := messages[lastIdx].(map[string]interface{}) - if lastMsg != nil { - if ans["final"] == true && ans["answer"] != nil { - lastMsg["content"] = ans["answer"] - } else { - existing, _ := lastMsg["content"].(string) - lastMsg["content"] = existing + content - } - lastMsg["created_at"] = float64(time.Now().Unix()) - lastMsg["id"] = messageID - messages[lastIdx] = lastMsg + lastMsg := messages[lastIdx] + if ans["final"] == true && ans["answer"] != nil { + lastMsg["content"] = ans["answer"] + } else { + existing, _ := lastMsg["content"].(string) + lastMsg["content"] = existing + content } + lastMsg["created_at"] = float64(time.Now().Unix()) + lastMsg["id"] = messageID + messages[lastIdx] = lastMsg } + session.Message, _ = json.Marshal(messages) + // Update reference if len(reference) > 0 { reference[len(reference)-1] = ref @@ -575,16 +824,12 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession } // getLastRole gets the role of the last message -func (s *ChatSessionService) getLastRole(messages []interface{}) string { +func (s *ChatSessionService) getLastRole(messages []map[string]interface{}) string { if len(messages) == 0 { return "" } - lastMsg, _ := messages[len(messages)-1].(map[string]interface{}) - if lastMsg != nil { - role, _ := lastMsg["role"].(string) - return role - } - return "" + role, _ := messages[len(messages)-1]["role"].(string) + return role } // chunksFormat formats chunks for reference (simplified version) diff --git a/internal/service/chat_session_test.go b/internal/service/chat_session_test.go index c6e6d16cea..ecf006e3fa 100644 --- a/internal/service/chat_session_test.go +++ b/internal/service/chat_session_test.go @@ -4,10 +4,13 @@ import ( "context" "encoding/json" "errors" + "reflect" "strings" "sync" "testing" + "gorm.io/gorm" + "ragflow/internal/common" "ragflow/internal/entity" ) @@ -48,7 +51,18 @@ func (f *fakeSessionStore) GetByID(id string) (*entity.ChatSession, error) { } s, ok := f.sessions[id] if !ok { - return nil, errors.New("record not found") + return nil, gorm.ErrRecordNotFound + } + return s, nil +} + +func (f *fakeSessionStore) GetBySessionIDAndChatID(sessionID, chatID string) (*entity.ChatSession, error) { + s, err := f.GetByID(sessionID) + if err != nil { + return nil, err + } + if s.DialogID != chatID { + return nil, gorm.ErrRecordNotFound } return s, nil } @@ -70,10 +84,30 @@ func (f *fakeSessionStore) UpdateByID(id string, updates map[string]interface{}) if f.updateByIDErr != nil { return f.updateByIDErr } + s, ok := f.sessions[id] + if !ok { + return gorm.ErrRecordNotFound + } f.updateCalled = append(f.updateCalled, struct { id string updates map[string]interface{} }{id, updates}) + for k, v := range updates { + switch k { + case "name": + if str, ok := v.(string); ok { + s.Name = &str + } + case "message": + if raw, ok := v.([]byte); ok { + s.Message = append(json.RawMessage(nil), raw...) + } + case "reference": + if raw, ok := v.([]byte); ok { + s.Reference = append(json.RawMessage(nil), raw...) + } + } + } return nil } @@ -177,16 +211,15 @@ func TestSetChatSession_CreateNew(t *testing.T) { t.Fatalf("expected 1 Create call, got %d", len(store.createCalled)) } - // Verify prologue is in the message - var msgObj map[string]interface{} - if err := json.Unmarshal(store.createCalled[0].Message, &msgObj); err != nil { + // Verify prologue is in the message list. + var msgs []map[string]interface{} + if err := json.Unmarshal(store.createCalled[0].Message, &msgs); err != nil { t.Fatalf("failed to unmarshal message: %v", err) } - msgs, _ := msgObj["messages"].([]interface{}) if len(msgs) != 1 { t.Fatalf("expected 1 initial message, got %d", len(msgs)) } - firstMsg, _ := msgs[0].(map[string]interface{}) + firstMsg := msgs[0] if firstMsg["role"] != "assistant" || firstMsg["content"] != "Welcome!" { t.Fatalf("unexpected prologue message: %#v", firstMsg) } @@ -213,10 +246,9 @@ func TestSetChatSession_CreateNewDefaultPrologue(t *testing.T) { t.Fatal("expected session ID") } // Default prologue - var msgObj map[string]interface{} - json.Unmarshal(store.createCalled[0].Message, &msgObj) - msgs, _ := msgObj["messages"].([]interface{}) - firstMsg, _ := msgs[0].(map[string]interface{}) + var msgs []map[string]interface{} + json.Unmarshal(store.createCalled[0].Message, &msgs) + firstMsg := msgs[0] if !strings.Contains(firstMsg["content"].(string), "Hi! I'm your assistant") { t.Fatalf("expected default prologue, got %q", firstMsg["content"]) } @@ -407,6 +439,205 @@ func TestListChatSessions_NotOwner(t *testing.T) { } } +// =================================================================== +// GetSession / UpdateSession tests +// =================================================================== + +func TestGetSession_Success(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", + DialogID: "chat-1", + Name: strPtr("session"), + Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`), + Reference: json.RawMessage(`[ + {"chunks":[{"chunk_id":"chunk-1","content_with_weight":"hello","doc_id":"doc-1","docnm_kwd":"Doc 1","kb_id":"kb-1"}]}, + [] + ]`), + UserID: strPtr("user-1"), + } + icon := "avatar.png" + store.dialogs["chat-1"] = &entity.Chat{ID: "chat-1", Icon: &icon} + store.dialogExists["user-1|chat-1"] = true + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + resp, code, err := svc.GetSession("user-1", "chat-1", "session-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("unexpected code: %v", code) + } + if resp.ChatID != "chat-1" { + t.Fatalf("chat_id=%q", resp.ChatID) + } + if resp.Avatar == nil || *resp.Avatar != "avatar.png" { + t.Fatalf("avatar=%v", resp.Avatar) + } + if len(resp.Messages) != 1 || resp.Messages[0]["content"] != "hello" { + t.Fatalf("messages=%#v", resp.Messages) + } + if len(resp.Reference) != 2 { + t.Fatalf("reference len=%d", len(resp.Reference)) + } + firstRef, ok := resp.Reference[0].(map[string]interface{}) + if !ok { + t.Fatalf("reference[0] type=%T", resp.Reference[0]) + } + chunks, ok := firstRef["chunks"].([]FormattedChunk) + if !ok { + t.Fatalf("chunks type=%T", firstRef["chunks"]) + } + if len(chunks) != 1 || chunks[0].ID != "chunk-1" { + t.Fatalf("chunks=%#v", chunks) + } + if _, ok := resp.Reference[1].([]interface{}); !ok { + t.Fatalf("reference[1] changed unexpectedly: %T", resp.Reference[1]) + } +} + +func TestGetSession_NotOwner(t *testing.T) { + svc := &ChatSessionService{ + chatSessionDAO: newFakeSessionStore(), + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + _, code, err := svc.GetSession("user-1", "chat-1", "session-1") + if err == nil || err.Error() != "No authorization." { + t.Fatalf("err=%v", err) + } + if code != common.CodeAuthenticationError { + t.Fatalf("code=%v", code) + } +} + +func TestGetSession_WrongChat(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-2"} + store.dialogExists["user-1|chat-1"] = true + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + _, code, err := svc.GetSession("user-1", "chat-1", "session-1") + if err == nil || err.Error() != "Session does not belong to this chat!" { + t.Fatalf("err=%v", err) + } + if code != common.CodeDataError { + t.Fatalf("code=%v", code) + } +} + +func TestUpdateSession_Success(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", + DialogID: "chat-1", + Name: strPtr("old"), + Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`), + } + store.dialogExists["user-1|chat-1"] = true + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + longName := " " + strings.Repeat("x", 260) + " " + resp, code, err := svc.UpdateSession("user-1", "chat-1", "session-1", map[string]interface{}{ + "name": longName, + "user_id": "spoof", + "chat_id": "spoof-chat", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code=%v", code) + } + if resp.Name == nil || len(*resp.Name) != 255 { + t.Fatalf("name=%v", resp.Name) + } + if len(store.updateCalled) != 1 { + t.Fatalf("update calls=%d", len(store.updateCalled)) + } + if _, ok := store.updateCalled[0].updates["user_id"]; ok { + t.Fatalf("unexpected user_id update: %#v", store.updateCalled[0].updates) + } + if _, ok := store.updateCalled[0].updates["chat_id"]; ok { + t.Fatalf("unexpected chat_id update: %#v", store.updateCalled[0].updates) + } + if !reflect.DeepEqual(resp.Messages, []map[string]interface{}{{"role": "assistant", "content": "hello"}}) { + t.Fatalf("messages=%#v", resp.Messages) + } +} + +func TestUpdateSession_ValidationErrors(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-1"} + store.dialogExists["user-1|chat-1"] = true + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + cases := []struct { + name string + req map[string]interface{} + message string + code common.ErrorCode + }{ + {name: "empty body", req: map[string]interface{}{}, message: "Request body cannot be empty", code: common.CodeArgumentError}, + {name: "message", req: map[string]interface{}{"message": []interface{}{}}, message: "`messages` cannot be changed.", code: common.CodeDataError}, + {name: "messages", req: map[string]interface{}{"messages": []interface{}{}}, message: "`messages` cannot be changed.", code: common.CodeDataError}, + {name: "reference", req: map[string]interface{}{"reference": []interface{}{}}, message: "`reference` cannot be changed.", code: common.CodeDataError}, + {name: "empty name", req: map[string]interface{}{"name": " "}, message: "`name` can not be empty.", code: common.CodeDataError}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, code, err := svc.UpdateSession("user-1", "chat-1", "session-1", tc.req) + if err == nil || err.Error() != tc.message { + t.Fatalf("err=%v", err) + } + if code != tc.code { + t.Fatalf("code=%v", code) + } + }) + } +} + +func TestUpdateSession_NotFound(t *testing.T) { + store := newFakeSessionStore() + store.dialogExists["user-1|chat-1"] = true + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + _, code, err := svc.UpdateSession("user-1", "chat-1", "missing", map[string]interface{}{"name": "renamed"}) + if err == nil || err.Error() != "Session not found!" { + t.Fatalf("err=%v", err) + } + if code != common.CodeDataError { + t.Fatalf("code=%v", code) + } +} + // =================================================================== // Completion tests // =================================================================== @@ -415,7 +646,7 @@ func TestCompletion_Success(t *testing.T) { store := newFakeSessionStore() session := &entity.ChatSession{ ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`{"messages":[]}`), + Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`), Reference: json.RawMessage(`[]`), } store.sessions["session-1"] = session @@ -447,6 +678,20 @@ func TestCompletion_Success(t *testing.T) { if ans != "Hello world" { t.Fatalf("expected answer 'Hello world', got %q", ans) } + + got := parseMessages(store.sessions["session-1"].Message) + if len(got) != 3 { + t.Fatalf("stored messages=%#v", got) + } + if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" { + t.Fatalf("stored prologue=%#v", got[0]) + } + if got[1]["role"] != "user" || got[1]["content"] != "hi" { + t.Fatalf("stored user message=%#v", got[1]) + } + if got[2]["role"] != "assistant" || got[2]["content"] != "Hello world" || got[2]["id"] != "msg-1" { + t.Fatalf("stored assistant message=%#v", got[2]) + } } func TestCompletion_EmptyMessages(t *testing.T) { @@ -498,7 +743,7 @@ func TestCompletion_DialogNotFound(t *testing.T) { store := newFakeSessionStore() store.sessions["session-1"] = &entity.ChatSession{ ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`{"messages":[]}`), + Message: json.RawMessage(`[]`), Reference: json.RawMessage(`[]`), } @@ -520,7 +765,7 @@ func TestCompletion_PipelineError(t *testing.T) { store := newFakeSessionStore() store.sessions["session-1"] = &entity.ChatSession{ ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`{"messages":[]}`), + Message: json.RawMessage(`[]`), Reference: json.RawMessage(`[]`), } store.dialogs["dialog-1"] = &entity.Chat{ @@ -566,7 +811,7 @@ func TestCompletionStream_Success(t *testing.T) { store := newFakeSessionStore() store.sessions["session-1"] = &entity.ChatSession{ ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`{"messages":[]}`), + Message: json.RawMessage(`{"messages":[{"role":"assistant","content":"Welcome!"}]}`), Reference: json.RawMessage(`[]`), } store.dialogs["dialog-1"] = &entity.Chat{ @@ -611,6 +856,53 @@ func TestCompletionStream_Success(t *testing.T) { if !finalFound { t.Fatal("expected final=true signal in stream") } + + got := parseMessages(store.sessions["session-1"].Message) + if len(got) != 3 { + t.Fatalf("stored messages=%#v", got) + } + if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" { + t.Fatalf("stored prologue=%#v", got[0]) + } + if got[1]["role"] != "user" || got[1]["content"] != "hi" { + t.Fatalf("stored user message=%#v", got[1]) + } + if got[2]["role"] != "assistant" || got[2]["content"] != "stream answer" || got[2]["id"] != "msg-1" { + t.Fatalf("stored assistant message=%#v", got[2]) + } +} + +func TestStructureAnswerWithConv_ParsesArrayMessages(t *testing.T) { + session := &entity.ChatSession{ + ID: "session-1", + Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`), + } + svc := &ChatSessionService{} + + ans := svc.structureAnswerWithConv(session, map[string]interface{}{ + "answer": "Final answer", + "reference": map[string]interface{}{"chunks": []interface{}{}}, + "final": true, + }, "msg-1", "session-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + + if ans["id"] != "msg-1" || ans["session_id"] != "session-1" { + t.Fatalf("ans=%#v", ans) + } + + got := parseMessages(session.Message) + if len(got) != 1 { + t.Fatalf("stored messages=%#v", got) + } + if got[0]["role"] != "assistant" || got[0]["content"] != "Final answer" || got[0]["id"] != "msg-1" { + t.Fatalf("stored assistant message=%#v", got[0]) + } +} + +func TestParseMessages_LegacyWrappedObject(t *testing.T) { + got := parseMessages(json.RawMessage(`{"messages":[{"role":"assistant","content":"legacy"}]}`)) + if !reflect.DeepEqual(got, []map[string]interface{}{{"role": "assistant", "content": "legacy"}}) { + t.Fatalf("messages=%#v", got) + } } func TestCompletionStream_EmptyMessages(t *testing.T) { @@ -664,7 +956,7 @@ func TestCompletionStream_DialogNotFound(t *testing.T) { store := newFakeSessionStore() store.sessions["session-1"] = &entity.ChatSession{ ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`{"messages":[]}`), + Message: json.RawMessage(`[]`), Reference: json.RawMessage(`[]`), } @@ -687,7 +979,7 @@ func TestCompletionStream_PipelineError(t *testing.T) { store := newFakeSessionStore() store.sessions["session-1"] = &entity.ChatSession{ ID: "session-1", DialogID: "dialog-1", - Message: json.RawMessage(`{"messages":[]}`), + Message: json.RawMessage(`[]`), Reference: json.RawMessage(`[]`), } store.dialogs["dialog-1"] = &entity.Chat{