feat(go-api) sessions message update (#16517)

### Summary
```
/api/v1/chats/<chat_id>/sessions/<session_id>/messages/<msg_id> DELETE
/api/v1/chats/<chat_id>/sessions/<session_id>/messages/<msg_id>/feedback PUT
```
Migrates the chat session message delete and feedback APIs to the Go
server, matching the Python behavior for authorization, session
ownership checks, message/reference updates, and feedback validation.
This commit is contained in:
Hz_
2026-07-02 10:33:27 +08:00
committed by GitHub
parent 5bc4753d1e
commit 0de69e5bba
7 changed files with 1657 additions and 30 deletions

View File

@@ -271,6 +271,93 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st
return e.updateChunksByQuery(ctx, fullIndexName, condition, newValue)
}
// AdjustChunkPagerank atomically adjusts pagerank_fea and clamps it to
// [minWeight, maxWeight].
func (e *elasticsearchEngine) AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error {
if indexName == "" {
return fmt.Errorf("index name cannot be empty")
}
if chunkID == "" {
return fmt.Errorf("chunk id cannot be empty")
}
script := `
if (ctx._source.kb_id == null || !ctx._source.kb_id.equals(params.kb_id)) {
ctx.op = 'noop';
} else {
double current = 0.0;
if (ctx._source.containsKey(params.field) && ctx._source[params.field] != null) {
Object currentValue = ctx._source[params.field];
if (currentValue instanceof Number) {
current = ((Number)currentValue).doubleValue();
} else {
try {
current = Double.parseDouble(currentValue.toString());
} catch (Exception e) {
current = 0.0;
}
}
}
double next = current + params.delta;
if (next < params.min_weight) {
next = params.min_weight;
}
if (next > params.max_weight) {
next = params.max_weight;
}
if (next <= 0.0) {
ctx._source.remove(params.field);
} else {
ctx._source[params.field] = next;
}
}
`
body, err := json.Marshal(map[string]interface{}{
"script": map[string]interface{}{
"source": script,
"lang": "painless",
"params": map[string]interface{}{
"field": common.PAGERANK_FLD,
"kb_id": kbID,
"delta": delta,
"min_weight": minWeight,
"max_weight": maxWeight,
},
},
})
if err != nil {
return fmt.Errorf("failed to marshal pagerank adjust request: %w", err)
}
retryOnConflict := 3
req := esapi.UpdateRequest{
Index: indexName,
DocumentID: chunkID,
Body: bytes.NewReader(body),
RetryOnConflict: &retryOnConflict,
}
res, err := req.Do(ctx, e.client)
if err != nil {
return fmt.Errorf("failed to adjust chunk pagerank: %w", err)
}
defer res.Body.Close()
if res.IsError() {
if res.StatusCode == http.StatusNotFound {
return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, chunkID)
}
bodyBytes, _ := io.ReadAll(res.Body)
return fmt.Errorf("elasticsearch pagerank adjust error: %s, body: %s", res.Status(), string(bodyBytes))
}
var updateResp struct {
Result string `json:"result"`
}
if err := json.NewDecoder(res.Body).Decode(&updateResp); err != nil {
return fmt.Errorf("failed to decode pagerank adjust response: %w", err)
}
if updateResp.Result == "noop" {
return fmt.Errorf("chunk %s does not belong to dataset %s", chunkID, kbID)
}
return nil
}
func (e *elasticsearchEngine) updateSingleMemoryMessage(ctx context.Context, indexName, messageDocID string, newValue map[string]interface{}) error {
doc := mapMemoryMessageESUpdateFields(newValue)
delete(doc, "id")

View File

@@ -20,6 +20,7 @@ import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"os"
"path/filepath"
"ragflow/internal/common"
@@ -30,6 +31,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
infinity "github.com/infiniflow/infinity-go-sdk"
"go.uber.org/zap"
@@ -38,6 +40,13 @@ import (
// ChinesePunctRegex splits on comma, semicolon, Chinese punctuations, and newlines
var ChinesePunctRegex = regexp.MustCompile(`[,;;、\r\n]+`)
const (
pagerankAdjustRetryCount = 2
pagerankAdjustLockCount = 256
)
var pagerankAdjustLocks [pagerankAdjustLockCount]sync.Mutex
// CreateChunkStore creates a chunk table in Infinity
// baseName is the table name prefix (e.g., "ragflow_<tenant_id>")
// The full table name is built as "{baseName}_{datasetID}"
@@ -518,6 +527,97 @@ func (e *infinityEngine) UpdateChunks(ctx context.Context, condition map[string]
return nil
}
// AdjustChunkPagerank adjusts pagerank_fea and clamps it to [minWeight, maxWeight].
func (e *infinityEngine) AdjustChunkPagerank(ctx context.Context, baseName, chunkID, datasetID string, delta, minWeight, maxWeight float64) error {
if baseName == "" {
return fmt.Errorf("index name cannot be empty")
}
if chunkID == "" {
return fmt.Errorf("chunk id cannot be empty")
}
if ctx == nil {
ctx = context.Background()
}
if e.client == nil || e.client.conn == nil {
return fmt.Errorf("Infinity client not initialized")
}
tableName := buildChunkTableName(baseName, datasetID)
lock := pagerankAdjustLock(tableName + ":" + chunkID)
lock.Lock()
defer lock.Unlock()
db, err := e.client.conn.GetDatabase(e.client.dbName)
if err != nil {
return fmt.Errorf("failed to get database: %w", err)
}
table, err := db.GetTable(tableName)
if err != nil {
return fmt.Errorf("failed to get table %s: %w", tableName, err)
}
var lastErr error
filter := fmt.Sprintf("id = '%s'", escapeFilterValue(chunkID))
for attempt := 0; attempt <= pagerankAdjustRetryCount; attempt++ {
if err := ctx.Err(); err != nil {
return err
}
result, err := table.Output([]string{common.PAGERANK_FLD, "row_id()"}).Filter(filter).ToResult()
if err != nil {
lastErr = err
continue
}
qr, ok := result.(*infinity.QueryResult)
if !ok {
return fmt.Errorf("unexpected query result type: %T", result)
}
rowID, ok := firstQueryInt64(qr.Data, "ROW_ID", "row_id()", "row_id")
if !ok {
return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, chunkID)
}
currentWeight := 0.0
if currentValue, ok := firstQueryValue(qr.Data, common.PAGERANK_FLD); ok {
if weight, ok := coerceToFloat(currentValue); ok {
currentWeight = weight
}
}
nextWeight := currentWeight + delta
if nextWeight < minWeight {
nextWeight = minWeight
}
if nextWeight > maxWeight {
nextWeight = maxWeight
}
if floatsEqual(currentWeight, nextWeight) {
return nil
}
updateFilter := fmt.Sprintf("_row_id = %d AND %s = %s", rowID, common.PAGERANK_FLD, formatFilterFloat(currentWeight))
if _, err := table.Update(updateFilter, map[string]interface{}{common.PAGERANK_FLD: nextWeight}); err != nil {
lastErr = err
continue
}
verifyResult, err := table.Output([]string{common.PAGERANK_FLD}).Filter(filter).ToResult()
if err != nil {
lastErr = err
continue
}
verifyQR, ok := verifyResult.(*infinity.QueryResult)
if !ok {
return fmt.Errorf("unexpected query result type: %T", verifyResult)
}
if currentValue, ok := firstQueryValue(verifyQR.Data, common.PAGERANK_FLD); ok {
if actualWeight, ok := coerceToFloat(currentValue); ok && floatsEqual(actualWeight, nextWeight) {
return nil
}
}
lastErr = fmt.Errorf("pagerank update conflict")
}
return fmt.Errorf("failed to adjust chunk pagerank: %w", lastErr)
}
// DeleteChunks deletes chunks from a dataset table
// Table name format: {baseName}_{datasetID}
// condition specifies which chunks to delete
@@ -1976,6 +2076,77 @@ func escapeFilterValue(s string) string {
return strings.ReplaceAll(s, "'", "''")
}
func firstQueryValue(data map[string][]interface{}, columns ...string) (interface{}, bool) {
for _, column := range columns {
values, ok := data[column]
if ok && len(values) > 0 {
return values[0], true
}
}
return nil, false
}
func firstQueryInt64(data map[string][]interface{}, columns ...string) (int64, bool) {
value, ok := firstQueryValue(data, columns...)
if !ok {
return 0, false
}
switch v := value.(type) {
case int:
return int64(v), true
case int8:
return int64(v), true
case int16:
return int64(v), true
case int32:
return int64(v), true
case int64:
return v, true
case uint:
if uint64(v) <= ^uint64(0)>>1 {
return int64(v), true
}
case uint8:
return int64(v), true
case uint16:
return int64(v), true
case uint32:
return int64(v), true
case uint64:
if v <= ^uint64(0)>>1 {
return int64(v), true
}
case float32:
return int64(v), true
case float64:
return int64(v), true
case string:
parsed, err := strconv.ParseInt(v, 10, 64)
if err == nil {
return parsed, true
}
}
return 0, false
}
func pagerankAdjustLock(key string) *sync.Mutex {
hash := fnv.New32a()
_, _ = hash.Write([]byte(key))
return &pagerankAdjustLocks[hash.Sum32()%pagerankAdjustLockCount]
}
func formatFilterFloat(value float64) string {
return strconv.FormatFloat(value, 'f', -1, 64)
}
func floatsEqual(a, b float64) bool {
diff := a - b
if diff < 0 {
diff = -diff
}
return diff < 1e-9
}
// equivalentConditionToStr converts a condition map to an Infinity filter string
func equivalentConditionToStr(condition map[string]interface{}) string {
if len(condition) == 0 {

View File

@@ -398,3 +398,60 @@ func (h *ChatSessionHandler) UpdateSession(c *gin.Context) {
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
func (h *ChatSessionHandler) DeleteSessionMessage(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := user.ID
chatID, sessionID, msgID := c.Param("chat_id"), c.Param("session_id"), c.Param("msg_id")
result, code, err := h.chatSessionService.DeleteSessionMessage(userID, chatID, sessionID, msgID)
if err != nil {
if code == common.CodeAuthenticationError {
jsonResponse(c, code, false, err.Error())
return
}
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}
func (h *ChatSessionHandler) UpdateMessageFeedback(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
userID := user.ID
chatID, sessionID, msgID := c.Param("chat_id"), c.Param("session_id"), c.Param("msg_id")
req := map[string]interface{}{}
if err := c.ShouldBindJSON(&req); err != nil {
if errors.Is(err, io.EOF) {
jsonError(c, common.CodeArgumentError, "Request body cannot be empty")
return
}
jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error())
return
}
if len(req) == 0 {
jsonError(c, common.CodeArgumentError, "Request body cannot be empty")
return
}
result, code, err := h.chatSessionService.UpdateMessageFeedback(c.Request.Context(), userID, chatID, sessionID, msgID, req)
if err != nil {
if code == common.CodeAuthenticationError {
jsonResponse(c, code, false, err.Error())
return
}
jsonError(c, code, err.Error())
return
}
jsonResponse(c, common.CodeSuccess, result, "success")
}

View File

@@ -72,3 +72,64 @@ func TestChatSessionHandlerUpdateSession_RejectsEmptyJSONObject(t *testing.T) {
t.Fatalf("message=%v", got)
}
}
func TestChatSessionHandlerUpdateMessageFeedback_RejectsEmptyBody(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPut, "/api/v1/chats/chat-1/sessions/session-1/messages/msg-1/feedback", nil)
ctx.Params = gin.Params{
{Key: "chat_id", Value: "chat-1"},
{Key: "session_id", Value: "session-1"},
{Key: "msg_id", Value: "msg-1"},
}
ctx.Set("user", &entity.User{ID: "user-1"})
handler := NewChatSessionHandler(service.NewChatSessionService(), nil)
handler.UpdateMessageFeedback(ctx)
if recorder.Code != http.StatusOK {
t.Fatalf("status=%d", recorder.Code)
}
var body map[string]interface{}
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatalf("decode response body: %v", err)
}
if got := body["code"]; got != float64(common.CodeArgumentError) {
t.Fatalf("code=%v", got)
}
if got := body["message"]; got != "Request body cannot be empty" {
t.Fatalf("message=%v", got)
}
}
func TestChatSessionHandlerUpdateMessageFeedback_RejectsEmptyJSONObject(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPut, "/api/v1/chats/chat-1/sessions/session-1/messages/msg-1/feedback", strings.NewReader(`{}`))
ctx.Request.Header.Set("Content-Type", "application/json")
ctx.Params = gin.Params{
{Key: "chat_id", Value: "chat-1"},
{Key: "session_id", Value: "session-1"},
{Key: "msg_id", Value: "msg-1"},
}
ctx.Set("user", &entity.User{ID: "user-1"})
handler := NewChatSessionHandler(service.NewChatSessionService(), nil)
handler.UpdateMessageFeedback(ctx)
var body map[string]interface{}
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
t.Fatalf("decode response body: %v", err)
}
if got := body["code"]; got != float64(common.CodeArgumentError) {
t.Fatalf("code=%v", got)
}
if got := body["message"]; got != "Request body cannot be empty" {
t.Fatalf("message=%v", got)
}
}

View File

@@ -300,6 +300,8 @@ func (r *Router) Setup(engine *gin.Engine) {
chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions)
chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession)
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
chats.DELETE("/:chat_id/sessions/:session_id/messages/:msg_id", r.chatSessionHandler.DeleteSessionMessage)
chats.PUT("/:chat_id/sessions/:session_id/messages/:msg_id/feedback", r.chatSessionHandler.UpdateMessageFeedback)
}
chat := v1.Group("/chat")

View File

@@ -22,8 +22,11 @@ import (
"errors"
"fmt"
"math"
"os"
"ragflow/internal/common"
"ragflow/internal/engine"
"ragflow/internal/storage"
"strconv"
"strings"
"time"
@@ -61,14 +64,15 @@ type chatPipelineRunner interface {
// (api/db/services/chunk_feedback_service.py) call site at
// api/apps/restful_apis/chat_api.py — that handler records the thumb
// vote against every chunk that produced the assistant message, in
// addition to the session-level thumbup field. The Go stack does not
// yet have a chunk-feedback DAO, so this interface is the seam where
// one will plug in. Production uses *ChatSessionService itself via
// applyChunkFeedback (which currently no-ops with a debug log) so the
// handler can still update the session-level thumbup without crashing;
// tests can swap in a fake by setting ChatSessionService.chunkFeedbackApplier.
// addition to the session-level thumbup field. Production uses
// *ChatSessionService itself via applyChunkFeedback; tests can swap
// in a fake by setting ChatSessionService.chunkFeedbackApplier.
type chunkFeedbackApplier interface {
applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error)
applyChunkFeedback(ctx context.Context, tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error)
}
type chunkPagerankAdjuster interface {
AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error
}
// ChatSessionService chat session (conversation) service.
@@ -78,6 +82,7 @@ type ChatSessionService struct {
userTenantDAO userTenantStore
pipeline chatPipelineRunner
chunkFeedbackApplier chunkFeedbackApplier
docEngine engine.DocEngine
}
// NewChatSessionService create chat session service
@@ -86,9 +91,131 @@ func NewChatSessionService() *ChatSessionService {
chatSessionDAO: dao.NewChatSessionDAO(),
userTenantDAO: dao.NewUserTenantDAO(),
pipeline: NewChatPipelineService(),
docEngine: engine.Get(),
}
}
// SetChatSessionRequest set chat session request.
type SetChatSessionRequest struct {
SessionID string `json:"conversation_id,omitempty"`
DialogID string `json:"dialog_id,omitempty"`
Name string `json:"name,omitempty"`
IsNew bool `json:"is_new"`
}
// SetChatSessionResponse set chat session response.
type SetChatSessionResponse struct {
*entity.ChatSession
}
// SetChatSession creates or updates a chat session.
// Kept as a compatibility entrypoint for older chat-session callers.
func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRequest) (*SetChatSessionResponse, error) {
name := req.Name
if name == "" {
name = "New chat session"
}
if len(name) > 255 {
name = name[:255]
}
if !req.IsNew {
updates := map[string]interface{}{
"name": name,
"user_id": userID,
}
if err := s.chatSessionDAO.UpdateByID(req.SessionID, updates); err != nil {
return nil, errors.New("Chat session not found")
}
session, err := s.chatSessionDAO.GetByID(req.SessionID)
if err != nil {
return nil, errors.New("Fail to update a chat session")
}
return &SetChatSessionResponse{ChatSession: session}, nil
}
dialog, err := s.chatSessionDAO.GetDialogByID(req.DialogID)
if err != nil {
return nil, errors.New("Dialog not found")
}
prologue := "Hi! I'm your assistant. What can I do for you?"
if dialog.PromptConfig != nil {
if p, ok := dialog.PromptConfig["prologue"].(string); ok && p != "" {
prologue = p
}
}
messagesJSON, _ := json.Marshal([]map[string]interface{}{
{
"role": "assistant",
"content": prologue,
},
})
referenceJSON, _ := json.Marshal([]interface{}{})
session := &entity.ChatSession{
ID: common.GenerateUUID(),
DialogID: req.DialogID,
Name: &name,
Message: messagesJSON,
UserID: &userID,
Reference: referenceJSON,
}
if err := s.chatSessionDAO.Create(session); err != nil {
return nil, errors.New("Fail to create a chat session")
}
return &SetChatSessionResponse{ChatSession: session}, nil
}
// RemoveChatSessions removes chat sessions.
// Kept as a compatibility entrypoint for older chat-session callers.
func (s *ChatSessionService) RemoveChatSessions(userID string, chatSessions []string) error {
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
if err != nil {
return err
}
tenantIDSet := make(map[string]bool)
for _, tid := range tenantIDs {
tenantIDSet[tid] = true
}
tenantIDSet[userID] = true
for _, convID := range chatSessions {
session, err := s.chatSessionDAO.GetByID(convID)
if err != nil {
return fmt.Errorf("Chat session not found: %s", convID)
}
isOwner := false
for tenantID := range tenantIDSet {
exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, session.DialogID)
if err != nil {
return err
}
if exists {
isOwner = true
break
}
}
if !isOwner {
return errors.New("Only owner of chat session authorized for this operation")
}
if err := s.chatSessionDAO.DeleteByID(convID); err != nil {
return err
}
}
return nil
}
// ListChatSessionsRequest list chat sessions request.
type ListChatSessionsRequest struct {
DialogID string `json:"dialog_id" binding:"required"`
}
// ListChatSessionsResponse list chat sessions response
type ListChatSessionsResponse struct {
Sessions []*entity.ChatSession
@@ -542,7 +669,10 @@ func (s *ChatSessionService) DeleteSessionMessage(userID, chatID, sessionID, msg
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
}
func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, msgID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
func (s *ChatSessionService) UpdateMessageFeedback(ctx context.Context, userID, chatID, sessionID, msgID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
if ctx == nil {
ctx = context.Background()
}
ownerTenantID := ""
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
if err != nil {
@@ -654,7 +784,7 @@ func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, ms
applier = s
}
if priorThumbBool, ok := priorThumb.(bool); ok && priorThumbBool != thumbup {
result, _ := applier.applyChunkFeedback(ownerTenantID, feedbackReference, !priorThumbBool)
result, _ := applier.applyChunkFeedback(ctx, ownerTenantID, feedbackReference, !priorThumbBool)
if result != nil {
common.Debug("Chunk feedback undo applied",
zap.Any("success_count", result["success_count"]),
@@ -662,7 +792,7 @@ func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, ms
)
}
}
result, _ := applier.applyChunkFeedback(ownerTenantID, feedbackReference, thumbup)
result, _ := applier.applyChunkFeedback(ctx, ownerTenantID, feedbackReference, thumbup)
if result != nil {
common.Debug("Chunk feedback applied",
zap.Any("success_count", result["success_count"]),
@@ -674,32 +804,308 @@ func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, ms
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
}
// applyChunkFeedback records a thumb vote against the chunks that
// produced a session message. Mirrors Python's
// ChunkFeedbackService.apply_feedback side effect (called from
// api/apps/restful_apis/chat_api.py when a user toggles a thumb on
// an assistant message). The Go persistence port for chunk feedback
// is intentionally not yet landed — the call here is a documented
// no-op so the session-level thumbup flow (the user-visible behavior)
// keeps working while a future PR ports the Python DAO. The returned
// counts let the caller log a "Chunk feedback applied: N succeeded,
// M failed" line consistent with the Python equivalent, so log
// scrapers don't see a regression in success/fail rates.
//
// Production callers should always go through the chunkFeedbackApplier
// field; this method is the default implementation used when that
// field is nil.
func (s *ChatSessionService) applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) {
common.Debug("chunk feedback persistence not yet ported; dropping vote",
zap.String("tenant_id", tenantID),
const (
upvoteWeightIncrement = 1
downvoteWeightDecrement = 1
minPagerankWeight = 0.0
maxPagerankWeight = 100.0
chunkFeedbackTimeout = 30 * time.Second
)
type feedbackChunkRow struct {
chunkID string
kbID string
chunk map[string]interface{}
}
// applyChunkFeedback records a thumb vote against the chunks that produced a
// session message. It mirrors Python's ChunkFeedbackService.apply_feedback:
// feature-flagged by CHUNK_FEEDBACK_ENABLED, split by relevance unless
// CHUNK_FEEDBACK_WEIGHTING=uniform, and clamped through the document engine.
func (s *ChatSessionService) applyChunkFeedback(ctx context.Context, tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) {
if !chunkFeedbackEnabled() {
common.Debug("Chunk feedback feature is disabled")
return map[string]interface{}{
"success_count": 0,
"fail_count": 0,
"chunk_ids": []string{},
"disabled": true,
}, nil
}
if ctx == nil {
ctx = context.Background()
}
if _, ok := ctx.Deadline(); !ok {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, chunkFeedbackTimeout)
defer cancel()
}
rows := feedbackRowsFromReference(reference)
chunkIDs := make([]string, 0, len(rows))
for _, row := range rows {
chunkIDs = append(chunkIDs, row.chunkID)
}
if len(rows) == 0 {
common.Debug("No chunk IDs found in reference for feedback")
return map[string]interface{}{
"success_count": 0,
"fail_count": 0,
"chunk_ids": chunkIDs,
}, nil
}
signedBudget := float64(upvoteWeightIncrement)
if !isPositive {
signedBudget = -float64(downvoteWeightDecrement)
}
deltas := allocateFeedbackDeltas(rows, signedBudget, chunkFeedbackWeighting())
successCount := 0
failCount := 0
for _, delta := range deltas {
if delta.delta == 0 {
continue
}
if s.updateChunkWeight(ctx, tenantID, delta.chunkID, delta.kbID, delta.delta) {
successCount++
} else {
failCount++
}
}
common.Info("Applied chunk feedback",
zap.Bool("is_positive", isPositive),
zap.String("weighting", chunkFeedbackWeighting()),
zap.Int("success_count", successCount),
zap.Int("chunk_count", len(chunkIDs)),
)
return map[string]interface{}{
"success_count": 0,
"fail_count": 0,
"success_count": successCount,
"fail_count": failCount,
"chunk_ids": chunkIDs,
}, nil
}
type feedbackDelta struct {
chunkID string
kbID string
delta float64
}
func chunkFeedbackEnabled() bool {
return strings.ToLower(os.Getenv("CHUNK_FEEDBACK_ENABLED")) == "true"
}
func chunkFeedbackWeighting() string {
weighting := strings.ToLower(strings.TrimSpace(os.Getenv("CHUNK_FEEDBACK_WEIGHTING")))
if weighting == "uniform" || weighting == "relevance" {
return weighting
}
return "relevance"
}
func feedbackRowsFromReference(reference map[string]interface{}) []feedbackChunkRow {
if len(reference) == 0 {
return nil
}
rawChunks, ok := reference["chunks"].([]interface{})
if !ok {
if chunks, ok := reference["chunks"].([]map[string]interface{}); ok {
rows := make([]feedbackChunkRow, 0, len(chunks))
for _, chunk := range chunks {
if row, ok := feedbackRowFromChunk(chunk); ok {
rows = append(rows, row)
}
}
return rows
}
return nil
}
rows := make([]feedbackChunkRow, 0, len(rawChunks))
for _, raw := range rawChunks {
chunk, ok := raw.(map[string]interface{})
if !ok {
continue
}
if row, ok := feedbackRowFromChunk(chunk); ok {
rows = append(rows, row)
}
}
return rows
}
func feedbackRowFromChunk(chunk map[string]interface{}) (feedbackChunkRow, bool) {
chunkID := stringValue(chunk["id"])
if chunkID == "" {
chunkID = stringValue(chunk["chunk_id"])
}
kbID := stringValue(chunk["dataset_id"])
if kbID == "" {
kbID = stringValue(chunk["kb_id"])
}
if chunkID == "" || kbID == "" {
return feedbackChunkRow{}, false
}
return feedbackChunkRow{chunkID: chunkID, kbID: kbID, chunk: chunk}, true
}
func allocateFeedbackDeltas(rows []feedbackChunkRow, signedBudget float64, weighting string) []feedbackDelta {
if len(rows) == 0 {
return nil
}
if weighting == "uniform" {
step := signedBudget / float64(len(rows))
deltas := make([]feedbackDelta, 0, len(rows))
for _, row := range rows {
deltas = append(deltas, feedbackDelta{chunkID: row.chunkID, kbID: row.kbID, delta: step})
}
return deltas
}
magnitudes := make([]float64, 0, len(rows))
for _, row := range rows {
signal := retrievalSignal(row.chunk)
if signal <= 0 {
signal = 1
}
magnitudes = append(magnitudes, signal)
}
parts := splitFloatBudget(magnitudes, math.Abs(signedBudget))
sign := math.Copysign(1, signedBudget)
deltas := make([]feedbackDelta, 0, len(rows))
for i, row := range rows {
deltas = append(deltas, feedbackDelta{chunkID: row.chunkID, kbID: row.kbID, delta: sign * parts[i]})
}
return deltas
}
func retrievalSignal(chunk map[string]interface{}) float64 {
best := 0.0
for _, key := range []string{"similarity", "vector_similarity", "term_similarity"} {
val, ok := floatValue(chunk[key])
if ok && !math.IsInf(val, 0) && !math.IsNaN(val) && val > best {
best = val
}
}
return best
}
func splitFloatBudget(magnitudes []float64, budget float64) []float64 {
n := len(magnitudes)
out := make([]float64, n)
if n == 0 || budget == 0 {
return out
}
total := 0.0
for _, magnitude := range magnitudes {
total += magnitude
}
if total <= 0 {
base := budget / float64(n)
for i := range out {
out[i] = base
}
return out
}
for i, magnitude := range magnitudes {
out[i] = budget * magnitude / total
}
return out
}
func (s *ChatSessionService) updateChunkWeight(ctx context.Context, tenantID, chunkID, kbID string, delta float64) bool {
docEngine := s.docEngine
if docEngine == nil {
docEngine = engine.Get()
}
if docEngine == nil {
common.Warn("Document engine is not initialized; chunk feedback skipped",
zap.String("chunk_id", chunkID),
zap.String("kb_id", kbID),
)
return false
}
indexName := fmt.Sprintf("ragflow_%s", tenantID)
if adjuster, ok := docEngine.(chunkPagerankAdjuster); ok {
if err := adjuster.AdjustChunkPagerank(ctx, indexName, chunkID, kbID, delta, minPagerankWeight, maxPagerankWeight); err != nil {
common.Warn("Failed atomic pagerank adjust for chunk",
zap.String("chunk_id", chunkID),
zap.Error(err),
)
return false
}
return true
}
rawChunk, err := docEngine.GetChunk(ctx, indexName, chunkID, []string{kbID})
if err != nil {
common.Warn("Chunk not found for feedback",
zap.String("chunk_id", chunkID),
zap.String("index", indexName),
zap.Error(err),
)
return false
}
chunk, ok := rawChunk.(map[string]interface{})
if !ok {
common.Warn("Unexpected chunk shape for feedback",
zap.String("chunk_id", chunkID),
zap.String("kb_id", kbID),
zap.String("index", indexName),
zap.String("chunk_type", fmt.Sprintf("%T", rawChunk)),
)
return false
}
currentWeight, _ := floatValue(chunk[common.PAGERANK_FLD])
nextWeight := currentWeight + delta
if nextWeight < minPagerankWeight {
nextWeight = minPagerankWeight
}
if nextWeight > maxPagerankWeight {
nextWeight = maxPagerankWeight
}
newValue := map[string]interface{}{common.PAGERANK_FLD: nextWeight}
if nextWeight <= 0 && strings.ToLower(docEngine.GetType()) == string(engine.EngineElasticsearch) {
newValue = map[string]interface{}{"remove": common.PAGERANK_FLD}
}
if err := docEngine.UpdateChunks(ctx, map[string]interface{}{"id": chunkID}, newValue, indexName, kbID); err != nil {
common.Warn("Failed to update chunk pagerank",
zap.String("chunk_id", chunkID),
zap.Error(err),
)
return false
}
return true
}
func floatValue(value interface{}) (float64, bool) {
switch v := value.(type) {
case float64:
return v, true
case float32:
return float64(v), true
case int:
return float64(v), true
case int64:
return float64(v), true
case int32:
return float64(v), true
case json.Number:
f, err := v.Float64()
return f, err == nil
case string:
f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
return f, err == nil
default:
return 0, false
}
}
func (s *ChatSessionService) ensureOwnedChat(userID, chatID string) (bool, error) {
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
if err != nil {
@@ -851,6 +1257,180 @@ func isChatSessionNotFound(err error) bool {
return errors.Is(err, gorm.ErrRecordNotFound)
}
// Completion performs chat completion with full RAG support via ChatPipelineService.
// Kept as a compatibility entrypoint for callers that still use the pre-ChatCompletions API.
func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) {
if len(messages) == 0 {
return nil, errors.New("messages cannot be empty")
}
lastRole, _ := messages[len(messages)-1]["role"].(string)
if lastRole != "user" {
return nil, errors.New("the last content of this conversation is not from user")
}
session, err := s.chatSessionDAO.GetByID(conversationID)
if err != nil {
return nil, errors.New("Conversation not found")
}
dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID)
if err != nil {
return nil, errors.New("Dialog not found")
}
sessionMessages := s.buildSessionMessages(session, messages)
reference := s.initializeReference(session)
isEmbedded := llmID != ""
if llmID != "" {
hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID)
if err != nil || !hasKey {
return nil, fmt.Errorf("Cannot use specified model %s", llmID)
}
dialog.LLMID = llmID
if chatModelConfig != nil {
dialog.LLMSetting = chatModelConfig
}
}
kwargs := chatModelConfig
if kwargs == nil {
kwargs = map[string]interface{}{}
}
resultChan, err := s.pipeline.AsyncChat(context.Background(), dialog, messages, false, kwargs)
if err != nil {
return nil, err
}
var answer strings.Builder
var finalRef map[string]interface{}
for result := range resultChan {
if result.Answer != "" {
answer.WriteString(result.Answer)
}
if result.Reference != nil {
finalRef = result.Reference
}
}
ans := map[string]interface{}{
"answer": answer.String(),
"reference": finalRef,
"final": true,
}
result := s.structureAnswerWithConv(session, ans, messageID, session.ID, reference)
if !isEmbedded {
sessionMessages = append(sessionMessages, map[string]interface{}{
"role": "assistant",
"content": answer.String(),
"id": messageID,
"created_at": float64(time.Now().Unix()),
})
s.updateSessionMessages(session, sessionMessages, reference)
}
return result, nil
}
// CompletionStream performs streaming chat completion with full RAG support via ChatPipelineService.
// Kept as a compatibility entrypoint for callers that still use the pre-ChatCompletions API.
func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error {
if ctx == nil {
ctx = context.Background()
}
if len(messages) == 0 {
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "messages cannot be empty", "data": {"answer": "**ERROR**: messages cannot be empty", "reference": []}}`)
return errors.New("messages cannot be empty")
}
lastRole, _ := messages[len(messages)-1]["role"].(string)
if lastRole != "user" {
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "the last content of this conversation is not from user", "data": {"answer": "**ERROR**: the last content of this conversation is not from user", "reference": []}}`)
return errors.New("the last content of this conversation is not from user")
}
session, err := s.chatSessionDAO.GetByID(conversationID)
if err != nil {
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Conversation not found", "data": {"answer": "**ERROR**: Conversation not found", "reference": []}}`)
return errors.New("Conversation not found")
}
dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID)
if err != nil {
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Dialog not found", "data": {"answer": "**ERROR**: Dialog not found", "reference": []}}`)
return errors.New("Dialog not found")
}
sessionMessages := s.buildSessionMessages(session, messages)
reference := s.initializeReference(session)
isEmbedded := llmID != ""
if llmID != "" {
hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID)
if err != nil || !hasKey {
errMsg := fmt.Sprintf(`{"code": 500, "message": "Cannot use specified model %s", "data": {"answer": "**ERROR**: Cannot use specified model", "reference": []}}`, llmID)
streamChan <- fmt.Sprintf("data: %s\n\n", errMsg)
return fmt.Errorf("Cannot use specified model %s", llmID)
}
dialog.LLMID = llmID
if chatModelConfig != nil {
dialog.LLMSetting = chatModelConfig
}
}
kwargs := chatModelConfig
if kwargs == nil {
kwargs = map[string]interface{}{}
}
resultChan, err := s.pipeline.AsyncChat(ctx, dialog, messages, true, kwargs)
if err != nil {
streamChan <- fmt.Sprintf("data: %s\n\n", fmt.Sprintf(`{"code": 500, "message": "%s", "data": {"answer": "**ERROR**: %s", "reference": []}}`, err.Error(), err.Error()))
return err
}
var fullAnswer strings.Builder
for result := range resultChan {
if result.Reference != nil && len(reference) > 0 {
reference[len(reference)-1] = result.Reference
}
if result.Final {
if result.Answer != "" {
fullAnswer.Reset()
fullAnswer.WriteString(result.Answer)
}
} else if result.Answer != "" {
fullAnswer.WriteString(result.Answer)
}
ans := s.structureAnswer(session, fullAnswer.String(), messageID, session.ID, reference)
data, _ := json.Marshal(map[string]interface{}{
"code": 0,
"message": "",
"data": ans,
})
streamChan <- fmt.Sprintf("data: %s\n\n", string(data))
}
finalData, _ := json.Marshal(map[string]interface{}{
"code": 0,
"message": "",
"data": true,
})
streamChan <- fmt.Sprintf("data: %s\n\n", string(finalData))
if !isEmbedded {
sessionMessages = append(sessionMessages, map[string]interface{}{
"role": "assistant",
"content": fullAnswer.String(),
"id": messageID,
"created_at": float64(time.Now().Unix()),
})
s.updateSessionMessages(session, sessionMessages, reference)
}
return nil
}
// ChatCompletions handles chat completion matching Python's session_completion.
// When stream=true, returns nil result and streams SSE via streamChan.
// When stream=false, returns the structured answer map.

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"math"
"reflect"
"strings"
"sync"
@@ -11,6 +12,7 @@ import (
"gorm.io/gorm"
"ragflow/internal/common"
"ragflow/internal/engine"
"ragflow/internal/entity"
)
@@ -171,6 +173,88 @@ func makeResultChan(results ...AsyncChatResult) <-chan AsyncChatResult {
return ch
}
type feedbackContextKey struct{}
type fakeFeedbackDocEngine struct {
engine.DocEngine
adjustCalls []struct {
ctx context.Context
indexName string
chunkID string
kbID string
delta float64
minWeight float64
maxWeight float64
}
}
func (f *fakeFeedbackDocEngine) GetType() string { return "elasticsearch" }
func (f *fakeFeedbackDocEngine) AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error {
f.adjustCalls = append(f.adjustCalls, struct {
ctx context.Context
indexName string
chunkID string
kbID string
delta float64
minWeight float64
maxWeight float64
}{ctx: ctx, indexName: indexName, chunkID: chunkID, kbID: kbID, delta: delta, minWeight: minWeight, maxWeight: maxWeight})
return nil
}
type fakeInfinityFeedbackDocEngine struct {
fakeFeedbackDocEngine
getChunkCalled bool
}
func (f *fakeInfinityFeedbackDocEngine) GetType() string { return string(engine.EngineInfinity) }
func (f *fakeInfinityFeedbackDocEngine) GetChunk(ctx context.Context, indexName, chunkID string, datasetIDs []string) (interface{}, error) {
f.getChunkCalled = true
return nil, errors.New("fallback should not be used")
}
type fakeFallbackFeedbackDocEngine struct {
engine.DocEngine
chunks map[string]map[string]interface{}
updateCalls []struct {
ctx context.Context
condition map[string]interface{}
newValue map[string]interface{}
indexName string
kbID string
}
}
func (f *fakeFallbackFeedbackDocEngine) GetType() string { return "elasticsearch" }
func (f *fakeFallbackFeedbackDocEngine) GetChunk(ctx context.Context, indexName, chunkID string, datasetIDs []string) (interface{}, error) {
chunk, ok := f.chunks[chunkID]
if !ok {
return nil, gorm.ErrRecordNotFound
}
return chunk, nil
}
func (f *fakeFallbackFeedbackDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, indexName string, kbID string) error {
f.updateCalls = append(f.updateCalls, struct {
ctx context.Context
condition map[string]interface{}
newValue map[string]interface{}
indexName string
kbID string
}{ctx: ctx, condition: condition, newValue: newValue, indexName: indexName, kbID: kbID})
return nil
}
func requireFloatClose(t *testing.T, got, want float64) {
t.Helper()
if math.Abs(got-want) > 1e-9 {
t.Fatalf("got %v want %v", got, want)
}
}
// ===================================================================
// ListChatSessions tests
// ===================================================================
@@ -410,6 +494,591 @@ func TestUpdateSession_NotFound(t *testing.T) {
}
}
// ===================================================================
// Session message delete / feedback tests
// ===================================================================
func TestDeleteSessionMessage_RemovesMessagePairAndReference(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"Welcome!"},
{"role":"user","content":"first","id":"msg-1"},
{"role":"assistant","content":"answer 1","id":"msg-1"},
{"role":"user","content":"second","id":"msg-2"},
{"role":"assistant","content":"answer 2","id":"msg-2"}
]`),
Reference: json.RawMessage(`[
{"chunks":[{"id":"chunk-1","kb_id":"kb-1"}]},
{"chunks":[{"id":"chunk-2","kb_id":"kb-2"}]}
]`),
}
store.dialogExists["tenant-1|chat-1"] = true
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}},
pipeline: &fakePipeline{},
}
resp, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "msg-1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
if len(resp.Messages) != 3 {
t.Fatalf("messages=%#v", resp.Messages)
}
if resp.Messages[1]["id"] != "msg-2" || resp.Messages[2]["id"] != "msg-2" {
t.Fatalf("remaining pair=%#v", resp.Messages)
}
if len(resp.Reference) != 1 {
t.Fatalf("reference=%#v", resp.Reference)
}
ref, ok := resp.Reference[0].(map[string]interface{})
if !ok {
t.Fatalf("reference type=%T", resp.Reference[0])
}
chunks, ok := ref["chunks"].([]FormattedChunk)
if !ok || len(chunks) != 1 || chunks[0].ID != "chunk-2" {
t.Fatalf("remaining chunks=%#v", ref["chunks"])
}
}
func TestUpdateMessageFeedback_AppliesChunkFeedbackWithResolvedTenantAndContext(t *testing.T) {
t.Setenv("CHUNK_FEEDBACK_ENABLED", "true")
t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform")
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"Welcome!"},
{"role":"user","content":"question","id":"msg-1"},
{"role":"assistant","content":"answer","id":"msg-1","feedback":"old"}
]`),
Reference: json.RawMessage(`[
{"chunks":[{"id":"chunk-1","kb_id":"kb-1","similarity":0.9}]}
]`),
}
store.dialogExists["tenant-owner|chat-1"] = true
docEngine := &fakeFeedbackDocEngine{}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-owner"}},
pipeline: &fakePipeline{},
docEngine: docEngine,
}
ctx := context.WithValue(context.Background(), feedbackContextKey{}, "request-context")
resp, code, err := svc.UpdateMessageFeedback(ctx, "user-1", "chat-1", "session-1", "msg-1", map[string]interface{}{
"thumbup": true,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
if len(docEngine.adjustCalls) != 1 {
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
}
call := docEngine.adjustCalls[0]
if call.indexName != "ragflow_tenant-owner" || call.chunkID != "chunk-1" || call.kbID != "kb-1" {
t.Fatalf("call=%#v", call)
}
if call.delta != 1 || call.minWeight != 0 || call.maxWeight != 100 {
t.Fatalf("weights=%#v", call)
}
if got := call.ctx.Value(feedbackContextKey{}); got != "request-context" {
t.Fatalf("context value=%v", got)
}
assistant := resp.Messages[2]
if assistant["thumbup"] != true {
t.Fatalf("assistant=%#v", assistant)
}
if _, ok := assistant["feedback"]; ok {
t.Fatalf("positive feedback should remove text feedback: %#v", assistant)
}
}
func TestUpdateMessageFeedback_ToggleUsesResolvedTenantForUndoAndApply(t *testing.T) {
t.Setenv("CHUNK_FEEDBACK_ENABLED", "true")
t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform")
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: json.RawMessage(`[
{"role":"assistant","content":"Welcome!"},
{"role":"user","content":"question","id":"msg-1"},
{"role":"assistant","content":"answer","id":"msg-1","thumbup":true}
]`),
Reference: json.RawMessage(`[
{"chunks":[{"chunk_id":"chunk-1","dataset_id":"kb-1"}]}
]`),
}
store.dialogExists["tenant-owner|chat-1"] = true
docEngine := &fakeFeedbackDocEngine{}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-owner"}},
pipeline: &fakePipeline{},
docEngine: docEngine,
}
resp, code, err := svc.UpdateMessageFeedback(context.Background(), "user-1", "chat-1", "session-1", "msg-1", map[string]interface{}{
"thumbup": false,
"feedback": "not useful",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("code=%v", code)
}
if len(docEngine.adjustCalls) != 2 {
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
}
for _, call := range docEngine.adjustCalls {
if call.indexName != "ragflow_tenant-owner" || call.delta != -1 {
t.Fatalf("call=%#v", call)
}
}
assistant := resp.Messages[2]
if assistant["thumbup"] != false || assistant["feedback"] != "not useful" {
t.Fatalf("assistant=%#v", assistant)
}
}
func TestApplyChunkFeedback_DisabledDoesNotTouchEngine(t *testing.T) {
t.Setenv("CHUNK_FEEDBACK_ENABLED", "false")
docEngine := &fakeFeedbackDocEngine{}
svc := &ChatSessionService{docEngine: docEngine}
result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{
"chunks": []interface{}{map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1"}},
}, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result["disabled"] != true {
t.Fatalf("result=%#v", result)
}
if len(docEngine.adjustCalls) != 0 {
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
}
}
func TestApplyChunkFeedback_UniformSplitsOneVoteAcrossChunks(t *testing.T) {
t.Setenv("CHUNK_FEEDBACK_ENABLED", "true")
t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform")
docEngine := &fakeFeedbackDocEngine{}
svc := &ChatSessionService{docEngine: docEngine}
result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{
"chunks": []interface{}{
map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1"},
map[string]interface{}{"id": "chunk-2", "kb_id": "kb-1"},
},
}, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result["success_count"] != 2 {
t.Fatalf("result=%#v", result)
}
if len(docEngine.adjustCalls) != 2 {
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
}
requireFloatClose(t, docEngine.adjustCalls[0].delta, 0.5)
requireFloatClose(t, docEngine.adjustCalls[1].delta, 0.5)
requireFloatClose(t, docEngine.adjustCalls[0].delta+docEngine.adjustCalls[1].delta, 1)
}
func TestApplyChunkFeedback_RelevanceDistributesOneVoteBySignals(t *testing.T) {
t.Setenv("CHUNK_FEEDBACK_ENABLED", "true")
t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "relevance")
docEngine := &fakeFeedbackDocEngine{}
svc := &ChatSessionService{docEngine: docEngine}
result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{
"chunks": []interface{}{
map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1", "similarity": 2.0},
map[string]interface{}{"id": "chunk-2", "kb_id": "kb-1", "vector_similarity": 1.0},
map[string]interface{}{"id": "chunk-3", "kb_id": "kb-1"},
},
}, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result["success_count"] != 3 {
t.Fatalf("result=%#v", result)
}
if len(docEngine.adjustCalls) != 3 {
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
}
requireFloatClose(t, docEngine.adjustCalls[0].delta, -0.5)
requireFloatClose(t, docEngine.adjustCalls[1].delta, -0.25)
requireFloatClose(t, docEngine.adjustCalls[2].delta, -0.25)
requireFloatClose(t, docEngine.adjustCalls[0].delta+docEngine.adjustCalls[1].delta+docEngine.adjustCalls[2].delta, -1)
}
func TestUpdateChunkWeight_InfinityUsesAtomicAdjuster(t *testing.T) {
docEngine := &fakeInfinityFeedbackDocEngine{}
svc := &ChatSessionService{docEngine: docEngine}
if ok := svc.updateChunkWeight(context.Background(), "tenant-1", "chunk-1", "kb-1", 0.25); !ok {
t.Fatal("expected updateChunkWeight to succeed")
}
if docEngine.getChunkCalled {
t.Fatal("expected Infinity adjuster path, got GetChunk fallback")
}
if len(docEngine.adjustCalls) != 1 {
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
}
call := docEngine.adjustCalls[0]
if call.indexName != "ragflow_tenant-1" || call.chunkID != "chunk-1" || call.kbID != "kb-1" {
t.Fatalf("call=%#v", call)
}
requireFloatClose(t, call.delta, 0.25)
}
func TestApplyChunkFeedback_FallbackClampsAndRemovesPagerank(t *testing.T) {
t.Setenv("CHUNK_FEEDBACK_ENABLED", "true")
t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform")
docEngine := &fakeFallbackFeedbackDocEngine{
chunks: map[string]map[string]interface{}{
"chunk-1": {common.PAGERANK_FLD: 0},
},
}
svc := &ChatSessionService{docEngine: docEngine}
result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{
"chunks": []interface{}{map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1"}},
}, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result["success_count"] != 1 {
t.Fatalf("result=%#v", result)
}
if len(docEngine.updateCalls) != 1 {
t.Fatalf("update calls=%d", len(docEngine.updateCalls))
}
call := docEngine.updateCalls[0]
if call.indexName != "ragflow_tenant-1" || call.kbID != "kb-1" {
t.Fatalf("call=%#v", call)
}
if call.condition["id"] != "chunk-1" || call.newValue["remove"] != common.PAGERANK_FLD {
t.Fatalf("call=%#v", call)
}
}
// ===================================================================
// Completion tests
// ===================================================================
func TestCompletion_Success(t *testing.T) {
store := newFakeSessionStore()
session := &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
Reference: json.RawMessage(`[]`),
}
store.sessions["session-1"] = session
store.dialogs["dialog-1"] = &entity.Chat{
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
LLMSetting: entity.JSONMap{},
}
pipeline := &fakePipeline{
resultChan: makeResultChan(
AsyncChatResult{Answer: "Hello", Reference: map[string]interface{}{"chunks": []interface{}{}}},
AsyncChatResult{Answer: " world", Final: true, Reference: map[string]interface{}{"chunks": []interface{}{}}},
),
}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: pipeline,
}
result, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
{"role": "user", "content": "hi"},
}, "", nil, "msg-1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ans, _ := result["answer"].(string)
if ans != "Hello world" {
t.Fatalf("expected answer 'Hello world', got %q", ans)
}
got := parseMessages(store.sessions["session-1"].Message)
if len(got) != 3 {
t.Fatalf("stored messages=%#v", got)
}
if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" {
t.Fatalf("stored prologue=%#v", got[0])
}
if got[1]["role"] != "user" || got[1]["content"] != "hi" {
t.Fatalf("stored user message=%#v", got[1])
}
if got[2]["role"] != "assistant" || got[2]["content"] != "Hello world" || got[2]["id"] != "msg-1" {
t.Fatalf("stored assistant message=%#v", got[2])
}
}
func TestCompletion_EmptyMessages(t *testing.T) {
svc := &ChatSessionService{
chatSessionDAO: &fakeSessionStore{},
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, err := svc.Completion("user-1", "session-1", nil, "", nil, "msg-1")
if err == nil || err.Error() != "messages cannot be empty" {
t.Fatalf("expected 'messages cannot be empty', got %v", err)
}
}
func TestCompletion_LastMessageNotFromUser(t *testing.T) {
svc := &ChatSessionService{
chatSessionDAO: &fakeSessionStore{},
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
{"role": "assistant", "content": "hello"},
}, "", nil, "msg-1")
if err == nil || !strings.Contains(err.Error(), "not from user") {
t.Fatalf("expected 'not from user' error, got %v", err)
}
}
func TestCompletion_ConversationNotFound(t *testing.T) {
store := newFakeSessionStore()
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, err := svc.Completion("user-1", "missing", []map[string]interface{}{
{"role": "user", "content": "hi"},
}, "", nil, "msg-1")
if err == nil || err.Error() != "Conversation not found" {
t.Fatalf("expected 'Conversation not found', got %v", err)
}
}
func TestCompletion_DialogNotFound(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{},
}
_, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
{"role": "user", "content": "hi"},
}, "", nil, "msg-1")
if err == nil || err.Error() != "Dialog not found" {
t.Fatalf("expected 'Dialog not found', got %v", err)
}
}
func TestCompletion_PipelineError(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
}
store.dialogs["dialog-1"] = &entity.Chat{
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
LLMSetting: entity.JSONMap{},
}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: &fakePipeline{err: errors.New("model unavailable")},
}
_, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
{"role": "user", "content": "hi"},
}, "", nil, "msg-1")
if err == nil || err.Error() != "model unavailable" {
t.Fatalf("expected 'model unavailable' error, got %v", err)
}
}
// ===================================================================
// CompletionStream tests
// ===================================================================
func readStreamChan(ch <-chan string, n int) []string {
var msgs []string
for i := 0; i < n; i++ {
select {
case msg, ok := <-ch:
if !ok {
return msgs
}
msgs = append(msgs, msg)
default:
return msgs
}
}
return msgs
}
func TestCompletionStream_Success(t *testing.T) {
store := newFakeSessionStore()
store.sessions["session-1"] = &entity.ChatSession{
ID: "session-1", DialogID: "dialog-1",
Message: json.RawMessage(`{"messages":[{"role":"assistant","content":"Welcome!"}]}`),
Reference: json.RawMessage(`[]`),
}
store.dialogs["dialog-1"] = &entity.Chat{
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
LLMSetting: entity.JSONMap{},
}
pipeline := &fakePipeline{
resultChan: makeResultChan(
AsyncChatResult{Answer: "stream", Reference: map[string]interface{}{"chunks": []interface{}{}}},
AsyncChatResult{Answer: " answer", Reference: map[string]interface{}{"chunks": []interface{}{}}},
),
}
svc := &ChatSessionService{
chatSessionDAO: store,
userTenantDAO: &fakeTenantStore{},
pipeline: pipeline,
}
streamChan := make(chan string, 10)
err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{
{"role": "user", "content": "hi"},
}, "", nil, "msg-1", streamChan)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Should receive data events and final signal
msgs := readStreamChan(streamChan, 5)
if len(msgs) < 3 {
t.Fatalf("expected at least 3 stream messages, got %d: %v", len(msgs), msgs)
}
// Check final signal
finalFound := false
for _, m := range msgs {
if strings.Contains(m, `"data":true`) {
finalFound = true
break
}
}
if !finalFound {
t.Fatal("expected final=true signal in stream")
}
got := parseMessages(store.sessions["session-1"].Message)
if len(got) != 3 {
t.Fatalf("stored messages=%#v", got)
}
if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" {
t.Fatalf("stored prologue=%#v", got[0])
}
if got[1]["role"] != "user" || got[1]["content"] != "hi" {
t.Fatalf("stored user message=%#v", got[1])
}
if got[2]["role"] != "assistant" || got[2]["content"] != "stream answer" || got[2]["id"] != "msg-1" {
t.Fatalf("stored assistant message=%#v", got[2])
}
}
func TestStructureAnswerWithConv_ParsesArrayMessages(t *testing.T) {
session := &entity.ChatSession{
ID: "session-1",
Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
}
svc := &ChatSessionService{}
ans := svc.structureAnswerWithConv(session, map[string]interface{}{
"answer": "Final answer",
"reference": map[string]interface{}{"chunks": []interface{}{}},
"final": true,
}, "msg-1", "session-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
if ans["id"] != "msg-1" || ans["session_id"] != "session-1" {
t.Fatalf("ans=%#v", ans)
}
got := parseMessages(session.Message)
if len(got) != 1 {
t.Fatalf("stored messages=%#v", got)
}
if got[0]["role"] != "assistant" || got[0]["content"] != "Final answer" || got[0]["id"] != "msg-1" {
t.Fatalf("stored assistant message=%#v", got[0])
}
}
func TestParseMessages_LegacyWrappedObject(t *testing.T) {
got := parseMessages(json.RawMessage(`{"messages":[{"role":"assistant","content":"legacy"}]}`))
if !reflect.DeepEqual(got, []map[string]interface{}{{"role": "assistant", "content": "legacy"}}) {
t.Fatalf("messages=%#v", got)
}
}
func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) {
svc := &ChatSessionService{}
payload := svc.buildSessionPayload(&entity.ChatSession{
ID: "session-1",
DialogID: "chat-1",
Message: nil,
Reference: json.RawMessage(`null`),
}, nil, false)
if payload.Messages == nil {
t.Fatal("messages is nil")
}
if payload.Reference == nil {
t.Fatal("reference is nil")
}
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if !strings.Contains(string(body), `"messages":[]`) {
t.Fatalf("messages did not encode as empty array: %s", string(body))
}
if !strings.Contains(string(body), `"reference":[]`) {
t.Fatalf("reference did not encode as empty array: %s", string(body))
}
}
func TestParseCollections_ReturnEmptySlicesForMissingOrNull(t *testing.T) {
messageInputs := []json.RawMessage{
nil,