Compare commits

...

1 Commits

Author SHA1 Message Date
Jin Hai
e2c2d8d61a Revert "feat(go-api): Add Go chat session message delete and feedback APIs (#…"
This reverts commit a553886989.
2026-06-29 21:01:22 +08:00
5 changed files with 9 additions and 1159 deletions

View File

@@ -271,77 +271,6 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st
return e.updateChunksByQuery(ctx, fullIndexName, condition, newValue)
}
// AdjustChunkPagerank atomically adjusts pagerank_fea and clamps it to [minWeight, maxWeight].
func (e *elasticsearchEngine) AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error {
if indexName == "" {
return fmt.Errorf("index name cannot be empty")
}
if chunkID == "" {
return fmt.Errorf("chunk id cannot be empty")
}
script := `
if (ctx._source.kb_id == null || !ctx._source.kb_id.equals(params.kb_id)) {
ctx.op = 'noop';
} else {
double current = 0.0;
if (ctx._source.containsKey(params.field) && ctx._source[params.field] != null) {
current = ((Number)ctx._source[params.field]).doubleValue();
}
double next = current + params.delta;
if (next < params.min_weight) {
next = params.min_weight;
}
if (next > params.max_weight) {
next = params.max_weight;
}
ctx._source[params.field] = next;
}
`
body, err := json.Marshal(map[string]interface{}{
"script": map[string]interface{}{
"source": script,
"lang": "painless",
"params": map[string]interface{}{
"field": common.PAGERANK_FLD,
"kb_id": kbID,
"delta": delta,
"min_weight": minWeight,
"max_weight": maxWeight,
},
},
})
if err != nil {
return fmt.Errorf("failed to marshal pagerank adjust request: %w", err)
}
req := esapi.UpdateRequest{
Index: indexName,
DocumentID: chunkID,
Body: bytes.NewReader(body),
}
res, err := req.Do(ctx, e.client)
if err != nil {
return fmt.Errorf("failed to adjust chunk pagerank: %w", err)
}
defer res.Body.Close()
if res.IsError() {
if res.StatusCode == http.StatusNotFound {
return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, chunkID)
}
bodyBytes, _ := io.ReadAll(res.Body)
return fmt.Errorf("elasticsearch pagerank adjust error: %s, body: %s", res.Status(), string(bodyBytes))
}
var updateResp struct {
Result string `json:"result"`
}
if err := json.NewDecoder(res.Body).Decode(&updateResp); err != nil {
return fmt.Errorf("failed to decode pagerank adjust response: %w", err)
}
if updateResp.Result == "noop" {
return fmt.Errorf("chunk %s does not belong to dataset %s", chunkID, kbID)
}
return nil
}
func (e *elasticsearchEngine) updateSingleMemoryMessage(ctx context.Context, indexName, messageDocID string, newValue map[string]interface{}) error {
doc := mapMemoryMessageESUpdateFields(newValue)
delete(doc, "id")

View File

@@ -474,53 +474,3 @@ func (h *ChatSessionHandler) UpdateSession(c *gin.Context) {
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
func (h *ChatSessionHandler) DeleteSessionMessage(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := user.ID
chatID, sessionID, msgID := c.Param("chat_id"), c.Param("session_id"), c.Param("msg_id")
result, code, err := h.chatSessionService.DeleteSessionMessage(userID, chatID, sessionID, msgID)
if err != nil {
if code == common.CodeAuthenticationError {
jsonResponse(c, code, false, err.Error())
return
}
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
func (h *ChatSessionHandler) UpdateMessageFeedback(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := user.ID
chatID, sessionID, msgID := c.Param("chat_id"), c.Param("session_id"), c.Param("msg_id")
req := map[string]interface{}{}
if err := c.ShouldBindJSON(&req); err != nil {
jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error())
return
}
result, code, err := h.chatSessionService.UpdateMessageFeedback(userID, chatID, sessionID, msgID, req)
if err != nil {
if code == common.CodeAuthenticationError {
jsonResponse(c, code, false, err.Error())
return
}
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}

View File

@@ -284,8 +284,6 @@ func (r *Router) Setup(engine *gin.Engine) {
chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions)
chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession)
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
chats.DELETE("/:chat_id/sessions/:session_id/messages/:msg_id", r.chatSessionHandler.DeleteSessionMessage)
chats.PUT("/:chat_id/sessions/:session_id/messages/:msg_id/feedback", r.chatSessionHandler.UpdateMessageFeedback)
}
// OpenAI-compatible chat completions route

View File

@@ -21,10 +21,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"os"
"ragflow/internal/common"
"ragflow/internal/engine"
"ragflow/internal/storage"
"strings"
"time"
@@ -57,21 +54,12 @@ type chatPipelineRunner interface {
AsyncChat(ctx context.Context, chat *entity.Chat, messages []map[string]interface{}, stream bool, kwargs map[string]interface{}) (<-chan AsyncChatResult, error)
}
type chunkFeedbackApplier interface {
applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error)
}
type atomicChunkPagerankAdjuster interface {
AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error
}
// ChatSessionService chat session (conversation) service.
// The RAG pipeline is delegated to ChatPipelineService.
type ChatSessionService struct {
chatSessionDAO chatSessionStore
userTenantDAO userTenantStore
pipeline chatPipelineRunner
chunkFeedbackApplier chunkFeedbackApplier
chatSessionDAO chatSessionStore
userTenantDAO userTenantStore
pipeline chatPipelineRunner
}
// NewChatSessionService create chat session service
@@ -614,194 +602,6 @@ func (s *ChatSessionService) UpdateSession(userID, chatID, sessionID string, req
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
}
func (s *ChatSessionService) DeleteSessionMessage(userID, chatID, sessionID, msgID string) (*ChatSessionPayload, common.ErrorCode, error) {
ok, err := s.ensureOwnedChat(userID, chatID)
if err != nil {
return nil, common.CodeServerError, err
}
if !ok {
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
}
session, err := s.chatSessionDAO.GetByID(sessionID)
if err != nil || session.DialogID != chatID {
if err != nil && !isChatSessionNotFound(err) {
return nil, common.CodeServerError, err
}
return nil, common.CodeDataError, errors.New("Session not found!")
}
messages := parseMessages(session.Message)
if len(session.Message) > 0 && messages == nil {
return nil, common.CodeDataError, errors.New("Invalid session messages")
}
references := parseReferenceList(session.Reference)
if len(session.Reference) > 0 && references == nil {
return nil, common.CodeDataError, errors.New("Invalid session reference")
}
for i, msg := range messages {
if msgID != stringValue(msg["id"]) {
continue
}
if i+1 >= len(messages) || stringValue(messages[i+1]["id"]) != msgID {
return nil, common.CodeServerError, errors.New("message pair assertion failed")
}
messages = append(messages[:i], messages[i+2:]...)
refIndex := (i - 1) / 2
if refIndex < 0 {
refIndex = 0
}
if refIndex < len(references) {
references = append(references[:refIndex], references[refIndex+1:]...)
}
break
}
messageRaw, err := json.Marshal(messages)
if err != nil {
return nil, common.CodeServerError, err
}
referenceRaw, err := json.Marshal(references)
if err != nil {
return nil, common.CodeServerError, err
}
if err := s.chatSessionDAO.UpdateByID(session.ID, map[string]interface{}{
"message": messageRaw,
"reference": referenceRaw,
}); err != nil {
return nil, common.CodeServerError, err
}
session.Message = messageRaw
session.Reference = referenceRaw
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
}
func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, msgID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
ownerTenantID := ""
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
if err != nil {
return nil, common.CodeServerError, err
}
for _, tenantID := range tenantIDs {
exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, chatID)
if err != nil {
return nil, common.CodeServerError, err
}
if exists {
ownerTenantID = tenantID
break
}
}
if ownerTenantID == "" {
exists, err := s.chatSessionDAO.CheckDialogExists(userID, chatID)
if err != nil {
return nil, common.CodeServerError, err
}
if exists {
ownerTenantID = userID
}
}
ok := ownerTenantID != ""
if !ok {
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
}
session, err := s.chatSessionDAO.GetByID(sessionID)
if err != nil || session.DialogID != chatID {
if err != nil && !isChatSessionNotFound(err) {
return nil, common.CodeServerError, err
}
return nil, common.CodeDataError, errors.New("Session not found!")
}
thumbRaw, ok := req["thumbup"]
if !ok {
return nil, common.CodeDataError, errors.New("thumbup must be a boolean")
}
thumbup, ok := thumbRaw.(bool)
if !ok {
return nil, common.CodeDataError, errors.New("thumbup must be a boolean")
}
messages := parseMessages(session.Message)
if len(session.Message) > 0 && messages == nil {
return nil, common.CodeDataError, errors.New("Invalid session messages")
}
messageIndex := -1
var priorThumb interface{}
applyChunkFeedback := false
var feedbackReference map[string]interface{}
for i, msg := range messages {
if msgID != stringValue(msg["id"]) || stringValue(msg["role"]) != "assistant" {
continue
}
priorThumb = msg["thumbup"]
priorThumbBool, priorThumbIsBool := priorThumb.(bool)
if thumbup {
msg["thumbup"] = true
delete(msg, "feedback")
applyChunkFeedback = !priorThumbIsBool || !priorThumbBool
} else {
msg["thumbup"] = false
if feedback, exists := req["feedback"]; exists && isTruthy(feedback) {
msg["feedback"] = feedback
}
applyChunkFeedback = !priorThumbIsBool || priorThumbBool
}
messages[i] = msg
messageIndex = i
break
}
if messageIndex != -1 && applyChunkFeedback {
references := parseReferenceList(session.Reference)
if len(session.Reference) > 0 && references == nil {
return nil, common.CodeDataError, errors.New("Invalid session reference")
}
refIndex := (messageIndex - 1) / 2
if refIndex >= 0 && refIndex < len(references) {
if reference, ok := references[refIndex].(map[string]interface{}); ok && len(reference) > 0 {
feedbackReference = reference
}
}
}
messageRaw, err := json.Marshal(messages)
if err != nil {
return nil, common.CodeServerError, err
}
if err := s.chatSessionDAO.UpdateByID(session.ID, map[string]interface{}{"message": messageRaw}); err != nil {
return nil, common.CodeServerError, err
}
session.Message = messageRaw
if feedbackReference != nil {
applier := s.chunkFeedbackApplier
if applier == nil {
applier = s
}
if priorThumbBool, ok := priorThumb.(bool); ok && priorThumbBool != thumbup {
result, _ := applier.applyChunkFeedback(ownerTenantID, feedbackReference, !priorThumbBool)
if result != nil {
common.Debug("Chunk feedback undo applied",
zap.Any("success_count", result["success_count"]),
zap.Any("fail_count", result["fail_count"]),
)
}
}
result, _ := applier.applyChunkFeedback(ownerTenantID, feedbackReference, thumbup)
if result != nil {
common.Debug("Chunk feedback applied",
zap.Any("success_count", result["success_count"]),
zap.Any("fail_count", result["fail_count"]),
)
}
}
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
}
func (s *ChatSessionService) ensureOwnedChat(userID, chatID string) (bool, error) {
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
if err != nil {
@@ -907,208 +707,6 @@ func parseReferenceList(raw json.RawMessage) []interface{} {
return references
}
type chunkFeedbackRow struct {
chunkID string
kbID string
chunk map[string]interface{}
}
func (s *ChatSessionService) applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) {
if strings.ToLower(os.Getenv("CHUNK_FEEDBACK_ENABLED")) != "true" {
return map[string]interface{}{"success_count": 0, "fail_count": 0, "chunk_ids": []string{}, "disabled": true}, nil
}
rows := feedbackRowsFromReference(reference)
chunkIDs := make([]string, 0, len(rows))
for _, row := range rows {
chunkIDs = append(chunkIDs, row.chunkID)
}
if len(rows) == 0 {
return map[string]interface{}{"success_count": 0, "fail_count": 0, "chunk_ids": chunkIDs}, nil
}
signedBudget := 1
if !isPositive {
signedBudget = -1
}
weighting := strings.TrimSpace(strings.ToLower(os.Getenv("CHUNK_FEEDBACK_WEIGHTING")))
deltas := allocateFeedbackDeltasRelevance(rows, signedBudget)
if weighting == "uniform" {
deltas = allocateFeedbackDeltasUniform(rows, signedBudget)
}
successCount := 0
failCount := 0
for i, delta := range deltas {
if delta == 0 {
continue
}
if s.updateChunkWeight(context.Background(), tenantID, rows[i].chunkID, rows[i].kbID, delta) {
successCount++
} else {
failCount++
}
}
return map[string]interface{}{"success_count": successCount, "fail_count": failCount, "chunk_ids": chunkIDs}, nil
}
func feedbackRowsFromReference(reference map[string]interface{}) []chunkFeedbackRow {
if len(reference) == 0 {
return nil
}
rawChunks, ok := reference["chunks"].([]interface{})
if !ok {
return nil
}
rows := make([]chunkFeedbackRow, 0, len(rawChunks))
for _, raw := range rawChunks {
chunk, ok := raw.(map[string]interface{})
if !ok {
continue
}
chunkID := stringValue(chunk["id"])
if chunkID == "" {
chunkID = stringValue(chunk["chunk_id"])
}
kbID := stringValue(chunk["dataset_id"])
if kbID == "" {
kbID = stringValue(chunk["kb_id"])
}
if chunkID != "" && kbID != "" {
rows = append(rows, chunkFeedbackRow{chunkID: chunkID, kbID: kbID, chunk: chunk})
}
}
return rows
}
func allocateFeedbackDeltasUniform(rows []chunkFeedbackRow, signedBudget int) []int {
deltas := make([]int, len(rows))
for i := range rows {
deltas[i] = signedBudget
}
return deltas
}
func allocateFeedbackDeltasRelevance(rows []chunkFeedbackRow, signedBudget int) []int {
magnitudes := make([]float64, len(rows))
for i, row := range rows {
magnitudes[i] = feedbackRetrievalSignal(row.chunk)
}
total := 0.0
for _, magnitude := range magnitudes {
total += magnitude
}
if total <= 0 {
for i := range magnitudes {
magnitudes[i] = 1
}
}
sign := 1
if signedBudget < 0 {
sign = -1
}
budgetAbs := signedBudget
if budgetAbs < 0 {
budgetAbs = -budgetAbs
}
parts := splitIntegerBudget(magnitudes, budgetAbs)
for i := range parts {
parts[i] *= sign
}
return parts
}
func feedbackRetrievalSignal(chunk map[string]interface{}) float64 {
best := 0.0
for _, key := range []string{"similarity", "vector_similarity", "term_similarity"} {
val := floatFromValue(chunk[key])
if !math.IsNaN(val) && !math.IsInf(val, 0) && val > best {
best = val
}
}
return best
}
func splitIntegerBudget(magnitudes []float64, budget int) []int {
n := len(magnitudes)
parts := make([]int, n)
if n == 0 || budget == 0 {
return parts
}
total := 0.0
for _, magnitude := range magnitudes {
total += magnitude
}
if total <= 0 {
base := budget / n
remainder := budget % n
for i := range parts {
parts[i] = base
if i < remainder {
parts[i]++
}
}
return parts
}
type remainder struct {
index int
value float64
}
remainders := make([]remainder, 0, n)
assigned := 0
for i, magnitude := range magnitudes {
exact := magnitude / total * float64(budget)
base := int(math.Floor(exact))
parts[i] = base
assigned += base
remainders = append(remainders, remainder{index: i, value: exact - float64(base)})
}
for remaining := budget - assigned; remaining > 0; remaining-- {
best := 0
for i := 1; i < len(remainders); i++ {
if remainders[i].value > remainders[best].value {
best = i
}
}
parts[remainders[best].index]++
remainders[best].value = -1
}
return parts
}
func (s *ChatSessionService) updateChunkWeight(ctx context.Context, tenantID, chunkID, kbID string, delta int) bool {
docEngine := engine.Get()
if docEngine == nil {
return false
}
indexName := fmt.Sprintf("ragflow_%s", tenantID)
if adjuster, ok := docEngine.(atomicChunkPagerankAdjuster); ok {
return adjuster.AdjustChunkPagerank(ctx, indexName, chunkID, kbID, float64(delta), 0, 100) == nil
}
rawChunk, err := docEngine.GetChunk(ctx, indexName, chunkID, []string{kbID})
if err != nil || rawChunk == nil {
return false
}
chunk, ok := rawChunk.(map[string]interface{})
if !ok {
return false
}
currentWeight := floatFromValue(chunk[common.PAGERANK_FLD])
newWeight := currentWeight + float64(delta)
if newWeight < 0 {
newWeight = 0
}
if newWeight > 100 {
newWeight = 100
}
return docEngine.UpdateChunks(ctx, map[string]interface{}{"id": chunkID}, map[string]interface{}{common.PAGERANK_FLD: newWeight}, indexName, kbID) == nil
}
func formatReferenceChunks(reference map[string]interface{}) []FormattedChunk {
rawChunks, ok := reference["chunks"].([]interface{})
if !ok {

View File

@@ -170,24 +170,6 @@ func (f *fakePipeline) AsyncChat(ctx context.Context, chat *entity.Chat, message
return f.resultChan, f.err
}
type fakeChunkFeedbackApplier struct {
calls []struct {
tenantID string
reference map[string]interface{}
isPositive bool
}
err error
}
func (f *fakeChunkFeedbackApplier) applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) {
f.calls = append(f.calls, struct {
tenantID string
reference map[string]interface{}
isPositive bool
}{tenantID: tenantID, reference: reference, isPositive: isPositive})
return map[string]interface{}{"success_count": 0, "fail_count": 0, "chunk_ids": []string{}}, f.err
}
func makeResultChan(results ...AsyncChatResult) <-chan AsyncChatResult {
ch := make(chan AsyncChatResult, len(results))
for _, r := range results {
@@ -207,10 +189,9 @@ func TestSetChatSession_CreateNew(t *testing.T) {
store.dialogs["dialog-1"] = dialog
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
chunkFeedbackApplier: &fakeChunkFeedbackApplier{},
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
@@ -566,12 +547,10 @@ func TestUpdateSession_Success(t *testing.T) {
}
store.dialogExists["user-1|chat-1"] = true
feedback := &fakeChunkFeedbackApplier{}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
chunkFeedbackApplier: feedback,
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
longName := " " + strings.Repeat("x", 260) + " "
@@ -659,610 +638,6 @@ func TestUpdateSession_NotFound(t *testing.T) {
}
}
func TestDeleteSessionMessage_RemovesMessagePairAndMatchingReference(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"m1"},
{"role":"user","content":"q2","id":"m2"},
{"role":"assistant","content":"a2","id":"m2"}
]`),
Reference: json.RawMessage(`[
{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]},
{"chunks":[{"id":"c2","dataset_id":"kb2"}],"doc_aggs":[]}
]`),
}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
resp, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "m2")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
if len(resp.Messages) != 3 {
t.Fatalf("messages=%#v", resp.Messages)
}
if resp.Messages[1]["id"] != "m1" || resp.Messages[2]["id"] != "m1" {
t.Fatalf("unexpected remaining messages=%#v", resp.Messages)
}
var refs []map[string]interface{}
if err := json.Unmarshal(store.sessions["session-1"].Reference, &refs); err != nil {
t.Fatalf("decode stored references: %v", err)
}
if len(refs) != 1 {
t.Fatalf("reference len=%d refs=%#v", len(refs), refs)
}
chunks, _ := refs[0]["chunks"].([]interface{})
chunk, _ := chunks[0].(map[string]interface{})
if chunk["id"] != "c1" {
t.Fatalf("wrong reference remained: %#v", refs)
}
}
func TestDeleteSessionMessage_NotOwner(t *testing.T) {
svc := &ChatSessionService{
chatSessionDAO: newFakeSessionStore(),
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "m1")
if err == nil || err.Error() != "No authorization." {
t.Fatalf("err=%v", err)
}
if code != common.CodeAuthenticationError {
t.Fatalf("code=%v", code)
}
}
func TestDeleteSessionMessage_SessionNotFoundForWrongChat(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-2"}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "m1")
if err == nil || err.Error() != "Session not found!" {
t.Fatalf("err=%v", err)
}
if code != common.CodeDataError {
t.Fatalf("code=%v", code)
}
}
func TestDeleteSessionMessage_MissingMessageIDLeavesSessionUnchanged(t *testing.T) {
store := newFakeSessionStore()
originalMessages := json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"m1"}
]`)
originalReferences := json.RawMessage(`[{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]}]`)
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: originalMessages,
Reference: originalReferences,
}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
resp, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "missing")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
if len(resp.Messages) != 3 || len(resp.Reference) != 1 {
t.Fatalf("response changed unexpectedly: messages=%#v refs=%#v", resp.Messages, resp.Reference)
}
if !reflect.DeepEqual(parseMessages(store.sessions["session-1"].Message), parseMessages(originalMessages)) {
t.Fatalf("stored messages changed: %s", store.sessions["session-1"].Message)
}
}
func TestUpdateMessageFeedback_UpdatesAssistantMessage(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"m1","thumbup":true}
]`),
Reference: json.RawMessage(`[{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]}]`),
}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
resp, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{
"thumbup": false,
"feedback": "bad answer",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
msg := resp.Messages[2]
if msg["thumbup"] != false || msg["feedback"] != "bad answer" {
t.Fatalf("message=%#v", msg)
}
resp, code, err = svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{
"thumbup": true,
})
if err != nil {
t.Fatalf("unexpected second error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("second code=%v", code)
}
msg = resp.Messages[2]
if msg["thumbup"] != true {
t.Fatalf("thumbup not set: %#v", msg)
}
if _, ok := msg["feedback"]; ok {
t.Fatalf("feedback should be removed: %#v", msg)
}
}
func TestUpdateMessageFeedback_RejectsNonBooleanThumbup(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-1"}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": "true"})
if err == nil || err.Error() != "thumbup must be a boolean" {
t.Fatalf("err=%v", err)
}
if code != common.CodeDataError {
t.Fatalf("code=%v", code)
}
}
func TestDeleteSessionMessage_RejectsMalformedMessagesWithoutUpdate(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`{bad json`),
Reference: json.RawMessage(`[]`),
}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "missing")
if err == nil || err.Error() != "Invalid session messages" {
t.Fatalf("err=%v", err)
}
if code != common.CodeDataError {
t.Fatalf("code=%v", code)
}
if len(store.updateCalled) != 0 {
t.Fatalf("unexpected update calls=%#v", store.updateCalled)
}
}
func TestDeleteSessionMessage_RejectsMalformedReferenceWithoutUpdate(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"m1"}
]`),
Reference: json.RawMessage(`{bad json`),
}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "m1")
if err == nil || err.Error() != "Invalid session reference" {
t.Fatalf("err=%v", err)
}
if code != common.CodeDataError {
t.Fatalf("code=%v", code)
}
if len(store.updateCalled) != 0 {
t.Fatalf("unexpected update calls=%#v", store.updateCalled)
}
}
func TestUpdateMessageFeedback_RejectsMissingThumbup(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-1"}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{})
if err == nil || err.Error() != "thumbup must be a boolean" {
t.Fatalf("err=%v", err)
}
if code != common.CodeDataError {
t.Fatalf("code=%v", code)
}
}
func TestUpdateMessageFeedback_RejectsMalformedMessagesWithoutUpdate(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`{"unexpected":[]}`),
Reference: json.RawMessage(`[]`),
}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
if err == nil || err.Error() != "Invalid session messages" {
t.Fatalf("err=%v", err)
}
if code != common.CodeDataError {
t.Fatalf("code=%v", code)
}
if len(store.updateCalled) != 0 {
t.Fatalf("unexpected update calls=%#v", store.updateCalled)
}
}
func TestUpdateMessageFeedback_RejectsMalformedReferenceWithoutUpdate(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"m1"}
]`),
Reference: json.RawMessage(`{bad json`),
}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
if err == nil || err.Error() != "Invalid session reference" {
t.Fatalf("err=%v", err)
}
if code != common.CodeDataError {
t.Fatalf("code=%v", code)
}
if len(store.updateCalled) != 0 {
t.Fatalf("unexpected update calls=%#v", store.updateCalled)
}
}
func TestUpdateMessageFeedback_NotOwner(t *testing.T) {
svc := &ChatSessionService{
chatSessionDAO: newFakeSessionStore(),
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
if err == nil || err.Error() != "No authorization." {
t.Fatalf("err=%v", err)
}
if code != common.CodeAuthenticationError {
t.Fatalf("code=%v", code)
}
}
func TestUpdateMessageFeedback_SessionNotFoundForWrongChat(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-2"}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
if err == nil || err.Error() != "Session not found!" {
t.Fatalf("err=%v", err)
}
if code != common.CodeDataError {
t.Fatalf("code=%v", code)
}
}
func TestUpdateMessageFeedback_MissingMessageIDLeavesMessagesUnchanged(t *testing.T) {
store := newFakeSessionStore()
originalMessages := json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"m1"}
]`)
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: originalMessages,
Reference: json.RawMessage(`[{"chunks":[],"doc_aggs":[]}]`),
}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
resp, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "missing", map[string]interface{}{"thumbup": false})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
if !reflect.DeepEqual(resp.Messages, parseMessages(originalMessages)) {
t.Fatalf("messages changed: %#v", resp.Messages)
}
}
func TestUpdateMessageFeedback_SkipsMatchingNonAssistantMessage(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"assistant-1"}
]`),
Reference: json.RawMessage(`[{"chunks":[],"doc_aggs":[]}]`),
}
store.dialogExists["user-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
resp, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": false, "feedback": "ignored"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
for _, msg := range resp.Messages {
if _, ok := msg["thumbup"]; ok {
t.Fatalf("non-assistant match should be skipped, messages=%#v", resp.Messages)
}
if _, ok := msg["feedback"]; ok {
t.Fatalf("feedback should not be written, messages=%#v", resp.Messages)
}
}
}
func TestUpdateMessageFeedback_ChangedFeedbackTriggersChunkFeedback(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"m1","thumbup":true},
{"role":"user","content":"q2","id":"m2"},
{"role":"assistant","content":"a2","id":"m2"}
]`),
Reference: json.RawMessage(`[
{"chunks":[{"id":"c1","dataset_id":"kb1","similarity":0.9}],"doc_aggs":[]},
{"chunks":[{"id":"c2","dataset_id":"kb2","similarity":0.8}],"doc_aggs":[]}
]`),
}
store.dialogExists["user-1|chat-1"] = true
feedback := &fakeChunkFeedbackApplier{}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
chunkFeedbackApplier: feedback,
}
resp, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": false})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
if resp.Messages[2]["thumbup"] != false {
t.Fatalf("message not updated: %#v", resp.Messages[2])
}
if len(feedback.calls) != 2 {
t.Fatalf("feedback calls=%#v", feedback.calls)
}
if feedback.calls[0].tenantID != "user-1" || feedback.calls[1].tenantID != "user-1" {
t.Fatalf("tenant ids=%#v", feedback.calls)
}
if feedback.calls[0].isPositive || feedback.calls[1].isPositive {
t.Fatalf("expected two negative applications when changing true -> false: %#v", feedback.calls)
}
chunks, _ := feedback.calls[0].reference["chunks"].([]interface{})
chunk, _ := chunks[0].(map[string]interface{})
if chunk["id"] != "c1" {
t.Fatalf("wrong reference used for first pair: %#v", feedback.calls[0].reference)
}
}
func TestUpdateMessageFeedback_UsesOwningTenantForChunkFeedback(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"m1"}
]`),
Reference: json.RawMessage(`[{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]}]`),
}
store.dialogExists["tenant-1|chat-1"] = true
feedback := &fakeChunkFeedbackApplier{}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}},
pipeline: &fakePipeline{},
chunkFeedbackApplier: feedback,
}
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
if len(feedback.calls) != 1 {
t.Fatalf("feedback calls=%#v", feedback.calls)
}
if feedback.calls[0].tenantID != "tenant-1" {
t.Fatalf("tenantID=%q", feedback.calls[0].tenantID)
}
}
func TestUpdateMessageFeedback_DoesNotTriggerChunkFeedbackWhenUpdateFails(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"m1"}
]`),
Reference: json.RawMessage(`[{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]}]`),
}
store.dialogExists["user-1|chat-1"] = true
store.updateByIDErr = errors.New("update failed")
feedback := &fakeChunkFeedbackApplier{}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
chunkFeedbackApplier: feedback,
}
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
if err == nil || err.Error() != "update failed" {
t.Fatalf("err=%v", err)
}
if code != common.CodeServerError {
t.Fatalf("code=%v", code)
}
if len(feedback.calls) != 0 {
t.Fatalf("feedback should not be called: %#v", feedback.calls)
}
}
func TestUpdateMessageFeedback_UnchangedFeedbackDoesNotTriggerChunkFeedback(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"prologue"},
{"role":"user","content":"q1","id":"m1"},
{"role":"assistant","content":"a1","id":"m1","thumbup":true}
]`),
Reference: json.RawMessage(`[{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]}]`),
}
store.dialogExists["user-1|chat-1"] = true
feedback := &fakeChunkFeedbackApplier{}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
chunkFeedbackApplier: feedback,
}
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
if len(feedback.calls) != 0 {
t.Fatalf("feedback should not be called: %#v", feedback.calls)
}
}
// ===================================================================
// Completion tests
// ===================================================================