mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Compare commits
1 Commits
nightly
...
revert-164
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2c2d8d61a |
@@ -271,77 +271,6 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st
|
||||
return e.updateChunksByQuery(ctx, fullIndexName, condition, newValue)
|
||||
}
|
||||
|
||||
// AdjustChunkPagerank atomically adjusts pagerank_fea and clamps it to [minWeight, maxWeight].
|
||||
func (e *elasticsearchEngine) AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error {
|
||||
if indexName == "" {
|
||||
return fmt.Errorf("index name cannot be empty")
|
||||
}
|
||||
if chunkID == "" {
|
||||
return fmt.Errorf("chunk id cannot be empty")
|
||||
}
|
||||
script := `
|
||||
if (ctx._source.kb_id == null || !ctx._source.kb_id.equals(params.kb_id)) {
|
||||
ctx.op = 'noop';
|
||||
} else {
|
||||
double current = 0.0;
|
||||
if (ctx._source.containsKey(params.field) && ctx._source[params.field] != null) {
|
||||
current = ((Number)ctx._source[params.field]).doubleValue();
|
||||
}
|
||||
double next = current + params.delta;
|
||||
if (next < params.min_weight) {
|
||||
next = params.min_weight;
|
||||
}
|
||||
if (next > params.max_weight) {
|
||||
next = params.max_weight;
|
||||
}
|
||||
ctx._source[params.field] = next;
|
||||
}
|
||||
`
|
||||
body, err := json.Marshal(map[string]interface{}{
|
||||
"script": map[string]interface{}{
|
||||
"source": script,
|
||||
"lang": "painless",
|
||||
"params": map[string]interface{}{
|
||||
"field": common.PAGERANK_FLD,
|
||||
"kb_id": kbID,
|
||||
"delta": delta,
|
||||
"min_weight": minWeight,
|
||||
"max_weight": maxWeight,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal pagerank adjust request: %w", err)
|
||||
}
|
||||
req := esapi.UpdateRequest{
|
||||
Index: indexName,
|
||||
DocumentID: chunkID,
|
||||
Body: bytes.NewReader(body),
|
||||
}
|
||||
res, err := req.Do(ctx, e.client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to adjust chunk pagerank: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.IsError() {
|
||||
if res.StatusCode == http.StatusNotFound {
|
||||
return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, chunkID)
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(res.Body)
|
||||
return fmt.Errorf("elasticsearch pagerank adjust error: %s, body: %s", res.Status(), string(bodyBytes))
|
||||
}
|
||||
var updateResp struct {
|
||||
Result string `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&updateResp); err != nil {
|
||||
return fmt.Errorf("failed to decode pagerank adjust response: %w", err)
|
||||
}
|
||||
if updateResp.Result == "noop" {
|
||||
return fmt.Errorf("chunk %s does not belong to dataset %s", chunkID, kbID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *elasticsearchEngine) updateSingleMemoryMessage(ctx context.Context, indexName, messageDocID string, newValue map[string]interface{}) error {
|
||||
doc := mapMemoryMessageESUpdateFields(newValue)
|
||||
delete(doc, "id")
|
||||
|
||||
@@ -474,53 +474,3 @@ func (h *ChatSessionHandler) UpdateSession(c *gin.Context) {
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
func (h *ChatSessionHandler) DeleteSessionMessage(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := user.ID
|
||||
chatID, sessionID, msgID := c.Param("chat_id"), c.Param("session_id"), c.Param("msg_id")
|
||||
|
||||
result, code, err := h.chatSessionService.DeleteSessionMessage(userID, chatID, sessionID, msgID)
|
||||
if err != nil {
|
||||
if code == common.CodeAuthenticationError {
|
||||
jsonResponse(c, code, false, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
func (h *ChatSessionHandler) UpdateMessageFeedback(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := user.ID
|
||||
chatID, sessionID, msgID := c.Param("chat_id"), c.Param("session_id"), c.Param("msg_id")
|
||||
|
||||
req := map[string]interface{}{}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.chatSessionService.UpdateMessageFeedback(userID, chatID, sessionID, msgID, req)
|
||||
if err != nil {
|
||||
if code == common.CodeAuthenticationError {
|
||||
jsonResponse(c, code, false, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
@@ -284,8 +284,6 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions)
|
||||
chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession)
|
||||
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
|
||||
chats.DELETE("/:chat_id/sessions/:session_id/messages/:msg_id", r.chatSessionHandler.DeleteSessionMessage)
|
||||
chats.PUT("/:chat_id/sessions/:session_id/messages/:msg_id/feedback", r.chatSessionHandler.UpdateMessageFeedback)
|
||||
}
|
||||
|
||||
// OpenAI-compatible chat completions route
|
||||
|
||||
@@ -21,10 +21,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/storage"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -57,21 +54,12 @@ type chatPipelineRunner interface {
|
||||
AsyncChat(ctx context.Context, chat *entity.Chat, messages []map[string]interface{}, stream bool, kwargs map[string]interface{}) (<-chan AsyncChatResult, error)
|
||||
}
|
||||
|
||||
type chunkFeedbackApplier interface {
|
||||
applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
type atomicChunkPagerankAdjuster interface {
|
||||
AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error
|
||||
}
|
||||
|
||||
// ChatSessionService chat session (conversation) service.
|
||||
// The RAG pipeline is delegated to ChatPipelineService.
|
||||
type ChatSessionService struct {
|
||||
chatSessionDAO chatSessionStore
|
||||
userTenantDAO userTenantStore
|
||||
pipeline chatPipelineRunner
|
||||
chunkFeedbackApplier chunkFeedbackApplier
|
||||
chatSessionDAO chatSessionStore
|
||||
userTenantDAO userTenantStore
|
||||
pipeline chatPipelineRunner
|
||||
}
|
||||
|
||||
// NewChatSessionService create chat session service
|
||||
@@ -614,194 +602,6 @@ func (s *ChatSessionService) UpdateSession(userID, chatID, sessionID string, req
|
||||
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) DeleteSessionMessage(userID, chatID, sessionID, msgID string) (*ChatSessionPayload, common.ErrorCode, error) {
|
||||
ok, err := s.ensureOwnedChat(userID, chatID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
|
||||
}
|
||||
|
||||
session, err := s.chatSessionDAO.GetByID(sessionID)
|
||||
if err != nil || session.DialogID != chatID {
|
||||
if err != nil && !isChatSessionNotFound(err) {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
return nil, common.CodeDataError, errors.New("Session not found!")
|
||||
}
|
||||
|
||||
messages := parseMessages(session.Message)
|
||||
if len(session.Message) > 0 && messages == nil {
|
||||
return nil, common.CodeDataError, errors.New("Invalid session messages")
|
||||
}
|
||||
references := parseReferenceList(session.Reference)
|
||||
if len(session.Reference) > 0 && references == nil {
|
||||
return nil, common.CodeDataError, errors.New("Invalid session reference")
|
||||
}
|
||||
for i, msg := range messages {
|
||||
if msgID != stringValue(msg["id"]) {
|
||||
continue
|
||||
}
|
||||
if i+1 >= len(messages) || stringValue(messages[i+1]["id"]) != msgID {
|
||||
return nil, common.CodeServerError, errors.New("message pair assertion failed")
|
||||
}
|
||||
messages = append(messages[:i], messages[i+2:]...)
|
||||
refIndex := (i - 1) / 2
|
||||
if refIndex < 0 {
|
||||
refIndex = 0
|
||||
}
|
||||
if refIndex < len(references) {
|
||||
references = append(references[:refIndex], references[refIndex+1:]...)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
messageRaw, err := json.Marshal(messages)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
referenceRaw, err := json.Marshal(references)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if err := s.chatSessionDAO.UpdateByID(session.ID, map[string]interface{}{
|
||||
"message": messageRaw,
|
||||
"reference": referenceRaw,
|
||||
}); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
session.Message = messageRaw
|
||||
session.Reference = referenceRaw
|
||||
|
||||
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, msgID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
|
||||
ownerTenantID := ""
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
for _, tenantID := range tenantIDs {
|
||||
exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, chatID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if exists {
|
||||
ownerTenantID = tenantID
|
||||
break
|
||||
}
|
||||
}
|
||||
if ownerTenantID == "" {
|
||||
exists, err := s.chatSessionDAO.CheckDialogExists(userID, chatID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if exists {
|
||||
ownerTenantID = userID
|
||||
}
|
||||
}
|
||||
ok := ownerTenantID != ""
|
||||
if !ok {
|
||||
return nil, common.CodeAuthenticationError, errors.New("No authorization.")
|
||||
}
|
||||
|
||||
session, err := s.chatSessionDAO.GetByID(sessionID)
|
||||
if err != nil || session.DialogID != chatID {
|
||||
if err != nil && !isChatSessionNotFound(err) {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
return nil, common.CodeDataError, errors.New("Session not found!")
|
||||
}
|
||||
|
||||
thumbRaw, ok := req["thumbup"]
|
||||
if !ok {
|
||||
return nil, common.CodeDataError, errors.New("thumbup must be a boolean")
|
||||
}
|
||||
thumbup, ok := thumbRaw.(bool)
|
||||
if !ok {
|
||||
return nil, common.CodeDataError, errors.New("thumbup must be a boolean")
|
||||
}
|
||||
|
||||
messages := parseMessages(session.Message)
|
||||
if len(session.Message) > 0 && messages == nil {
|
||||
return nil, common.CodeDataError, errors.New("Invalid session messages")
|
||||
}
|
||||
messageIndex := -1
|
||||
var priorThumb interface{}
|
||||
applyChunkFeedback := false
|
||||
var feedbackReference map[string]interface{}
|
||||
for i, msg := range messages {
|
||||
if msgID != stringValue(msg["id"]) || stringValue(msg["role"]) != "assistant" {
|
||||
continue
|
||||
}
|
||||
priorThumb = msg["thumbup"]
|
||||
priorThumbBool, priorThumbIsBool := priorThumb.(bool)
|
||||
if thumbup {
|
||||
msg["thumbup"] = true
|
||||
delete(msg, "feedback")
|
||||
applyChunkFeedback = !priorThumbIsBool || !priorThumbBool
|
||||
} else {
|
||||
msg["thumbup"] = false
|
||||
if feedback, exists := req["feedback"]; exists && isTruthy(feedback) {
|
||||
msg["feedback"] = feedback
|
||||
}
|
||||
applyChunkFeedback = !priorThumbIsBool || priorThumbBool
|
||||
}
|
||||
messages[i] = msg
|
||||
messageIndex = i
|
||||
break
|
||||
}
|
||||
|
||||
if messageIndex != -1 && applyChunkFeedback {
|
||||
references := parseReferenceList(session.Reference)
|
||||
if len(session.Reference) > 0 && references == nil {
|
||||
return nil, common.CodeDataError, errors.New("Invalid session reference")
|
||||
}
|
||||
refIndex := (messageIndex - 1) / 2
|
||||
if refIndex >= 0 && refIndex < len(references) {
|
||||
if reference, ok := references[refIndex].(map[string]interface{}); ok && len(reference) > 0 {
|
||||
feedbackReference = reference
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messageRaw, err := json.Marshal(messages)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if err := s.chatSessionDAO.UpdateByID(session.ID, map[string]interface{}{"message": messageRaw}); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
session.Message = messageRaw
|
||||
|
||||
if feedbackReference != nil {
|
||||
applier := s.chunkFeedbackApplier
|
||||
if applier == nil {
|
||||
applier = s
|
||||
}
|
||||
if priorThumbBool, ok := priorThumb.(bool); ok && priorThumbBool != thumbup {
|
||||
result, _ := applier.applyChunkFeedback(ownerTenantID, feedbackReference, !priorThumbBool)
|
||||
if result != nil {
|
||||
common.Debug("Chunk feedback undo applied",
|
||||
zap.Any("success_count", result["success_count"]),
|
||||
zap.Any("fail_count", result["fail_count"]),
|
||||
)
|
||||
}
|
||||
}
|
||||
result, _ := applier.applyChunkFeedback(ownerTenantID, feedbackReference, thumbup)
|
||||
if result != nil {
|
||||
common.Debug("Chunk feedback applied",
|
||||
zap.Any("success_count", result["success_count"]),
|
||||
zap.Any("fail_count", result["fail_count"]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) ensureOwnedChat(userID, chatID string) (bool, error) {
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
@@ -907,208 +707,6 @@ func parseReferenceList(raw json.RawMessage) []interface{} {
|
||||
return references
|
||||
}
|
||||
|
||||
type chunkFeedbackRow struct {
|
||||
chunkID string
|
||||
kbID string
|
||||
chunk map[string]interface{}
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) {
|
||||
if strings.ToLower(os.Getenv("CHUNK_FEEDBACK_ENABLED")) != "true" {
|
||||
return map[string]interface{}{"success_count": 0, "fail_count": 0, "chunk_ids": []string{}, "disabled": true}, nil
|
||||
}
|
||||
|
||||
rows := feedbackRowsFromReference(reference)
|
||||
chunkIDs := make([]string, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
chunkIDs = append(chunkIDs, row.chunkID)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return map[string]interface{}{"success_count": 0, "fail_count": 0, "chunk_ids": chunkIDs}, nil
|
||||
}
|
||||
|
||||
signedBudget := 1
|
||||
if !isPositive {
|
||||
signedBudget = -1
|
||||
}
|
||||
weighting := strings.TrimSpace(strings.ToLower(os.Getenv("CHUNK_FEEDBACK_WEIGHTING")))
|
||||
deltas := allocateFeedbackDeltasRelevance(rows, signedBudget)
|
||||
if weighting == "uniform" {
|
||||
deltas = allocateFeedbackDeltasUniform(rows, signedBudget)
|
||||
}
|
||||
|
||||
successCount := 0
|
||||
failCount := 0
|
||||
for i, delta := range deltas {
|
||||
if delta == 0 {
|
||||
continue
|
||||
}
|
||||
if s.updateChunkWeight(context.Background(), tenantID, rows[i].chunkID, rows[i].kbID, delta) {
|
||||
successCount++
|
||||
} else {
|
||||
failCount++
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{"success_count": successCount, "fail_count": failCount, "chunk_ids": chunkIDs}, nil
|
||||
}
|
||||
|
||||
func feedbackRowsFromReference(reference map[string]interface{}) []chunkFeedbackRow {
|
||||
if len(reference) == 0 {
|
||||
return nil
|
||||
}
|
||||
rawChunks, ok := reference["chunks"].([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
rows := make([]chunkFeedbackRow, 0, len(rawChunks))
|
||||
for _, raw := range rawChunks {
|
||||
chunk, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
chunkID := stringValue(chunk["id"])
|
||||
if chunkID == "" {
|
||||
chunkID = stringValue(chunk["chunk_id"])
|
||||
}
|
||||
kbID := stringValue(chunk["dataset_id"])
|
||||
if kbID == "" {
|
||||
kbID = stringValue(chunk["kb_id"])
|
||||
}
|
||||
if chunkID != "" && kbID != "" {
|
||||
rows = append(rows, chunkFeedbackRow{chunkID: chunkID, kbID: kbID, chunk: chunk})
|
||||
}
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func allocateFeedbackDeltasUniform(rows []chunkFeedbackRow, signedBudget int) []int {
|
||||
deltas := make([]int, len(rows))
|
||||
for i := range rows {
|
||||
deltas[i] = signedBudget
|
||||
}
|
||||
return deltas
|
||||
}
|
||||
|
||||
func allocateFeedbackDeltasRelevance(rows []chunkFeedbackRow, signedBudget int) []int {
|
||||
magnitudes := make([]float64, len(rows))
|
||||
for i, row := range rows {
|
||||
magnitudes[i] = feedbackRetrievalSignal(row.chunk)
|
||||
}
|
||||
total := 0.0
|
||||
for _, magnitude := range magnitudes {
|
||||
total += magnitude
|
||||
}
|
||||
if total <= 0 {
|
||||
for i := range magnitudes {
|
||||
magnitudes[i] = 1
|
||||
}
|
||||
}
|
||||
|
||||
sign := 1
|
||||
if signedBudget < 0 {
|
||||
sign = -1
|
||||
}
|
||||
budgetAbs := signedBudget
|
||||
if budgetAbs < 0 {
|
||||
budgetAbs = -budgetAbs
|
||||
}
|
||||
parts := splitIntegerBudget(magnitudes, budgetAbs)
|
||||
for i := range parts {
|
||||
parts[i] *= sign
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func feedbackRetrievalSignal(chunk map[string]interface{}) float64 {
|
||||
best := 0.0
|
||||
for _, key := range []string{"similarity", "vector_similarity", "term_similarity"} {
|
||||
val := floatFromValue(chunk[key])
|
||||
if !math.IsNaN(val) && !math.IsInf(val, 0) && val > best {
|
||||
best = val
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func splitIntegerBudget(magnitudes []float64, budget int) []int {
|
||||
n := len(magnitudes)
|
||||
parts := make([]int, n)
|
||||
if n == 0 || budget == 0 {
|
||||
return parts
|
||||
}
|
||||
total := 0.0
|
||||
for _, magnitude := range magnitudes {
|
||||
total += magnitude
|
||||
}
|
||||
if total <= 0 {
|
||||
base := budget / n
|
||||
remainder := budget % n
|
||||
for i := range parts {
|
||||
parts[i] = base
|
||||
if i < remainder {
|
||||
parts[i]++
|
||||
}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
type remainder struct {
|
||||
index int
|
||||
value float64
|
||||
}
|
||||
remainders := make([]remainder, 0, n)
|
||||
assigned := 0
|
||||
for i, magnitude := range magnitudes {
|
||||
exact := magnitude / total * float64(budget)
|
||||
base := int(math.Floor(exact))
|
||||
parts[i] = base
|
||||
assigned += base
|
||||
remainders = append(remainders, remainder{index: i, value: exact - float64(base)})
|
||||
}
|
||||
for remaining := budget - assigned; remaining > 0; remaining-- {
|
||||
best := 0
|
||||
for i := 1; i < len(remainders); i++ {
|
||||
if remainders[i].value > remainders[best].value {
|
||||
best = i
|
||||
}
|
||||
}
|
||||
parts[remainders[best].index]++
|
||||
remainders[best].value = -1
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) updateChunkWeight(ctx context.Context, tenantID, chunkID, kbID string, delta int) bool {
|
||||
docEngine := engine.Get()
|
||||
if docEngine == nil {
|
||||
return false
|
||||
}
|
||||
indexName := fmt.Sprintf("ragflow_%s", tenantID)
|
||||
|
||||
if adjuster, ok := docEngine.(atomicChunkPagerankAdjuster); ok {
|
||||
return adjuster.AdjustChunkPagerank(ctx, indexName, chunkID, kbID, float64(delta), 0, 100) == nil
|
||||
}
|
||||
|
||||
rawChunk, err := docEngine.GetChunk(ctx, indexName, chunkID, []string{kbID})
|
||||
if err != nil || rawChunk == nil {
|
||||
return false
|
||||
}
|
||||
chunk, ok := rawChunk.(map[string]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
currentWeight := floatFromValue(chunk[common.PAGERANK_FLD])
|
||||
newWeight := currentWeight + float64(delta)
|
||||
if newWeight < 0 {
|
||||
newWeight = 0
|
||||
}
|
||||
if newWeight > 100 {
|
||||
newWeight = 100
|
||||
}
|
||||
return docEngine.UpdateChunks(ctx, map[string]interface{}{"id": chunkID}, map[string]interface{}{common.PAGERANK_FLD: newWeight}, indexName, kbID) == nil
|
||||
}
|
||||
|
||||
func formatReferenceChunks(reference map[string]interface{}) []FormattedChunk {
|
||||
rawChunks, ok := reference["chunks"].([]interface{})
|
||||
if !ok {
|
||||
|
||||
@@ -170,24 +170,6 @@ func (f *fakePipeline) AsyncChat(ctx context.Context, chat *entity.Chat, message
|
||||
return f.resultChan, f.err
|
||||
}
|
||||
|
||||
type fakeChunkFeedbackApplier struct {
|
||||
calls []struct {
|
||||
tenantID string
|
||||
reference map[string]interface{}
|
||||
isPositive bool
|
||||
}
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeChunkFeedbackApplier) applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) {
|
||||
f.calls = append(f.calls, struct {
|
||||
tenantID string
|
||||
reference map[string]interface{}
|
||||
isPositive bool
|
||||
}{tenantID: tenantID, reference: reference, isPositive: isPositive})
|
||||
return map[string]interface{}{"success_count": 0, "fail_count": 0, "chunk_ids": []string{}}, f.err
|
||||
}
|
||||
|
||||
func makeResultChan(results ...AsyncChatResult) <-chan AsyncChatResult {
|
||||
ch := make(chan AsyncChatResult, len(results))
|
||||
for _, r := range results {
|
||||
@@ -207,10 +189,9 @@ func TestSetChatSession_CreateNew(t *testing.T) {
|
||||
store.dialogs["dialog-1"] = dialog
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
chunkFeedbackApplier: &fakeChunkFeedbackApplier{},
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{
|
||||
@@ -566,12 +547,10 @@ func TestUpdateSession_Success(t *testing.T) {
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
feedback := &fakeChunkFeedbackApplier{}
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
chunkFeedbackApplier: feedback,
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
longName := " " + strings.Repeat("x", 260) + " "
|
||||
@@ -659,610 +638,6 @@ func TestUpdateSession_NotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSessionMessage_RemovesMessagePairAndMatchingReference(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"m1"},
|
||||
{"role":"user","content":"q2","id":"m2"},
|
||||
{"role":"assistant","content":"a2","id":"m2"}
|
||||
]`),
|
||||
Reference: json.RawMessage(`[
|
||||
{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]},
|
||||
{"chunks":[{"id":"c2","dataset_id":"kb2"}],"doc_aggs":[]}
|
||||
]`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "m2")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(resp.Messages) != 3 {
|
||||
t.Fatalf("messages=%#v", resp.Messages)
|
||||
}
|
||||
if resp.Messages[1]["id"] != "m1" || resp.Messages[2]["id"] != "m1" {
|
||||
t.Fatalf("unexpected remaining messages=%#v", resp.Messages)
|
||||
}
|
||||
|
||||
var refs []map[string]interface{}
|
||||
if err := json.Unmarshal(store.sessions["session-1"].Reference, &refs); err != nil {
|
||||
t.Fatalf("decode stored references: %v", err)
|
||||
}
|
||||
if len(refs) != 1 {
|
||||
t.Fatalf("reference len=%d refs=%#v", len(refs), refs)
|
||||
}
|
||||
chunks, _ := refs[0]["chunks"].([]interface{})
|
||||
chunk, _ := chunks[0].(map[string]interface{})
|
||||
if chunk["id"] != "c1" {
|
||||
t.Fatalf("wrong reference remained: %#v", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSessionMessage_NotOwner(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: newFakeSessionStore(),
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "m1")
|
||||
if err == nil || err.Error() != "No authorization." {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeAuthenticationError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSessionMessage_SessionNotFoundForWrongChat(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-2"}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "m1")
|
||||
if err == nil || err.Error() != "Session not found!" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSessionMessage_MissingMessageIDLeavesSessionUnchanged(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
originalMessages := json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"m1"}
|
||||
]`)
|
||||
originalReferences := json.RawMessage(`[{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]}]`)
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: originalMessages,
|
||||
Reference: originalReferences,
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "missing")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(resp.Messages) != 3 || len(resp.Reference) != 1 {
|
||||
t.Fatalf("response changed unexpectedly: messages=%#v refs=%#v", resp.Messages, resp.Reference)
|
||||
}
|
||||
if !reflect.DeepEqual(parseMessages(store.sessions["session-1"].Message), parseMessages(originalMessages)) {
|
||||
t.Fatalf("stored messages changed: %s", store.sessions["session-1"].Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_UpdatesAssistantMessage(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"m1","thumbup":true}
|
||||
]`),
|
||||
Reference: json.RawMessage(`[{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]}]`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{
|
||||
"thumbup": false,
|
||||
"feedback": "bad answer",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
msg := resp.Messages[2]
|
||||
if msg["thumbup"] != false || msg["feedback"] != "bad answer" {
|
||||
t.Fatalf("message=%#v", msg)
|
||||
}
|
||||
|
||||
resp, code, err = svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{
|
||||
"thumbup": true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected second error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("second code=%v", code)
|
||||
}
|
||||
msg = resp.Messages[2]
|
||||
if msg["thumbup"] != true {
|
||||
t.Fatalf("thumbup not set: %#v", msg)
|
||||
}
|
||||
if _, ok := msg["feedback"]; ok {
|
||||
t.Fatalf("feedback should be removed: %#v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_RejectsNonBooleanThumbup(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-1"}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": "true"})
|
||||
if err == nil || err.Error() != "thumbup must be a boolean" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSessionMessage_RejectsMalformedMessagesWithoutUpdate(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`{bad json`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "missing")
|
||||
if err == nil || err.Error() != "Invalid session messages" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(store.updateCalled) != 0 {
|
||||
t.Fatalf("unexpected update calls=%#v", store.updateCalled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSessionMessage_RejectsMalformedReferenceWithoutUpdate(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"m1"}
|
||||
]`),
|
||||
Reference: json.RawMessage(`{bad json`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "m1")
|
||||
if err == nil || err.Error() != "Invalid session reference" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(store.updateCalled) != 0 {
|
||||
t.Fatalf("unexpected update calls=%#v", store.updateCalled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_RejectsMissingThumbup(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-1"}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{})
|
||||
if err == nil || err.Error() != "thumbup must be a boolean" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_RejectsMalformedMessagesWithoutUpdate(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`{"unexpected":[]}`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
|
||||
if err == nil || err.Error() != "Invalid session messages" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(store.updateCalled) != 0 {
|
||||
t.Fatalf("unexpected update calls=%#v", store.updateCalled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_RejectsMalformedReferenceWithoutUpdate(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"m1"}
|
||||
]`),
|
||||
Reference: json.RawMessage(`{bad json`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
|
||||
if err == nil || err.Error() != "Invalid session reference" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(store.updateCalled) != 0 {
|
||||
t.Fatalf("unexpected update calls=%#v", store.updateCalled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_NotOwner(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: newFakeSessionStore(),
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
|
||||
if err == nil || err.Error() != "No authorization." {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeAuthenticationError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_SessionNotFoundForWrongChat(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{ID: "session-1", DialogID: "chat-2"}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
|
||||
if err == nil || err.Error() != "Session not found!" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_MissingMessageIDLeavesMessagesUnchanged(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
originalMessages := json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"m1"}
|
||||
]`)
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: originalMessages,
|
||||
Reference: json.RawMessage(`[{"chunks":[],"doc_aggs":[]}]`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "missing", map[string]interface{}{"thumbup": false})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if !reflect.DeepEqual(resp.Messages, parseMessages(originalMessages)) {
|
||||
t.Fatalf("messages changed: %#v", resp.Messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_SkipsMatchingNonAssistantMessage(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"assistant-1"}
|
||||
]`),
|
||||
Reference: json.RawMessage(`[{"chunks":[],"doc_aggs":[]}]`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": false, "feedback": "ignored"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
for _, msg := range resp.Messages {
|
||||
if _, ok := msg["thumbup"]; ok {
|
||||
t.Fatalf("non-assistant match should be skipped, messages=%#v", resp.Messages)
|
||||
}
|
||||
if _, ok := msg["feedback"]; ok {
|
||||
t.Fatalf("feedback should not be written, messages=%#v", resp.Messages)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_ChangedFeedbackTriggersChunkFeedback(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"m1","thumbup":true},
|
||||
{"role":"user","content":"q2","id":"m2"},
|
||||
{"role":"assistant","content":"a2","id":"m2"}
|
||||
]`),
|
||||
Reference: json.RawMessage(`[
|
||||
{"chunks":[{"id":"c1","dataset_id":"kb1","similarity":0.9}],"doc_aggs":[]},
|
||||
{"chunks":[{"id":"c2","dataset_id":"kb2","similarity":0.8}],"doc_aggs":[]}
|
||||
]`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
feedback := &fakeChunkFeedbackApplier{}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
chunkFeedbackApplier: feedback,
|
||||
}
|
||||
|
||||
resp, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": false})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if resp.Messages[2]["thumbup"] != false {
|
||||
t.Fatalf("message not updated: %#v", resp.Messages[2])
|
||||
}
|
||||
if len(feedback.calls) != 2 {
|
||||
t.Fatalf("feedback calls=%#v", feedback.calls)
|
||||
}
|
||||
if feedback.calls[0].tenantID != "user-1" || feedback.calls[1].tenantID != "user-1" {
|
||||
t.Fatalf("tenant ids=%#v", feedback.calls)
|
||||
}
|
||||
if feedback.calls[0].isPositive || feedback.calls[1].isPositive {
|
||||
t.Fatalf("expected two negative applications when changing true -> false: %#v", feedback.calls)
|
||||
}
|
||||
chunks, _ := feedback.calls[0].reference["chunks"].([]interface{})
|
||||
chunk, _ := chunks[0].(map[string]interface{})
|
||||
if chunk["id"] != "c1" {
|
||||
t.Fatalf("wrong reference used for first pair: %#v", feedback.calls[0].reference)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_UsesOwningTenantForChunkFeedback(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"m1"}
|
||||
]`),
|
||||
Reference: json.RawMessage(`[{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]}]`),
|
||||
}
|
||||
store.dialogExists["tenant-1|chat-1"] = true
|
||||
feedback := &fakeChunkFeedbackApplier{}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}},
|
||||
pipeline: &fakePipeline{},
|
||||
chunkFeedbackApplier: feedback,
|
||||
}
|
||||
|
||||
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(feedback.calls) != 1 {
|
||||
t.Fatalf("feedback calls=%#v", feedback.calls)
|
||||
}
|
||||
if feedback.calls[0].tenantID != "tenant-1" {
|
||||
t.Fatalf("tenantID=%q", feedback.calls[0].tenantID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_DoesNotTriggerChunkFeedbackWhenUpdateFails(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"m1"}
|
||||
]`),
|
||||
Reference: json.RawMessage(`[{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]}]`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
store.updateByIDErr = errors.New("update failed")
|
||||
feedback := &fakeChunkFeedbackApplier{}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
chunkFeedbackApplier: feedback,
|
||||
}
|
||||
|
||||
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
|
||||
if err == nil || err.Error() != "update failed" {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
if code != common.CodeServerError {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(feedback.calls) != 0 {
|
||||
t.Fatalf("feedback should not be called: %#v", feedback.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_UnchangedFeedbackDoesNotTriggerChunkFeedback(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"prologue"},
|
||||
{"role":"user","content":"q1","id":"m1"},
|
||||
{"role":"assistant","content":"a1","id":"m1","thumbup":true}
|
||||
]`),
|
||||
Reference: json.RawMessage(`[{"chunks":[{"id":"c1","dataset_id":"kb1"}],"doc_aggs":[]}]`),
|
||||
}
|
||||
store.dialogExists["user-1|chat-1"] = true
|
||||
feedback := &fakeChunkFeedbackApplier{}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
chunkFeedbackApplier: feedback,
|
||||
}
|
||||
|
||||
_, code, err := svc.UpdateMessageFeedback("user-1", "chat-1", "session-1", "m1", map[string]interface{}{"thumbup": true})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(feedback.calls) != 0 {
|
||||
t.Fatalf("feedback should not be called: %#v", feedback.calls)
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// Completion tests
|
||||
// ===================================================================
|
||||
|
||||
Reference in New Issue
Block a user