diff --git a/internal/engine/elasticsearch/chunk.go b/internal/engine/elasticsearch/chunk.go index c5efa2b49e..93a6554b54 100644 --- a/internal/engine/elasticsearch/chunk.go +++ b/internal/engine/elasticsearch/chunk.go @@ -271,6 +271,93 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st return e.updateChunksByQuery(ctx, fullIndexName, condition, newValue) } +// AdjustChunkPagerank atomically adjusts pagerank_fea and clamps it to +// [minWeight, maxWeight]. +func (e *elasticsearchEngine) AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error { + if indexName == "" { + return fmt.Errorf("index name cannot be empty") + } + if chunkID == "" { + return fmt.Errorf("chunk id cannot be empty") + } + script := ` + if (ctx._source.kb_id == null || !ctx._source.kb_id.equals(params.kb_id)) { + ctx.op = 'noop'; + } else { + double current = 0.0; + if (ctx._source.containsKey(params.field) && ctx._source[params.field] != null) { + Object currentValue = ctx._source[params.field]; + if (currentValue instanceof Number) { + current = ((Number)currentValue).doubleValue(); + } else { + try { + current = Double.parseDouble(currentValue.toString()); + } catch (Exception e) { + current = 0.0; + } + } + } + double next = current + params.delta; + if (next < params.min_weight) { + next = params.min_weight; + } + if (next > params.max_weight) { + next = params.max_weight; + } + if (next <= 0.0) { + ctx._source.remove(params.field); + } else { + ctx._source[params.field] = next; + } + } + ` + body, err := json.Marshal(map[string]interface{}{ + "script": map[string]interface{}{ + "source": script, + "lang": "painless", + "params": map[string]interface{}{ + "field": common.PAGERANK_FLD, + "kb_id": kbID, + "delta": delta, + "min_weight": minWeight, + "max_weight": maxWeight, + }, + }, + }) + if err != nil { + return fmt.Errorf("failed to marshal pagerank adjust request: %w", err) + } + retryOnConflict := 3 + req := esapi.UpdateRequest{ + Index: indexName, + DocumentID: chunkID, + Body: bytes.NewReader(body), + RetryOnConflict: &retryOnConflict, + } + res, err := req.Do(ctx, e.client) + if err != nil { + return fmt.Errorf("failed to adjust chunk pagerank: %w", err) + } + defer res.Body.Close() + if res.IsError() { + if res.StatusCode == http.StatusNotFound { + return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, chunkID) + } + bodyBytes, _ := io.ReadAll(res.Body) + return fmt.Errorf("elasticsearch pagerank adjust error: %s, body: %s", res.Status(), string(bodyBytes)) + } + var updateResp struct { + Result string `json:"result"` + } + if err := json.NewDecoder(res.Body).Decode(&updateResp); err != nil { + return fmt.Errorf("failed to decode pagerank adjust response: %w", err) + } + if updateResp.Result == "noop" { + return fmt.Errorf("chunk %s does not belong to dataset %s", chunkID, kbID) + } + return nil +} + func (e *elasticsearchEngine) updateSingleMemoryMessage(ctx context.Context, indexName, messageDocID string, newValue map[string]interface{}) error { doc := mapMemoryMessageESUpdateFields(newValue) delete(doc, "id") diff --git a/internal/engine/infinity/chunk.go b/internal/engine/infinity/chunk.go index 314d2a4599..ad71d1d1da 100644 --- a/internal/engine/infinity/chunk.go +++ b/internal/engine/infinity/chunk.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "fmt" + "hash/fnv" "os" "path/filepath" "ragflow/internal/common" @@ -30,6 +31,7 @@ import ( "sort" "strconv" "strings" + "sync" infinity "github.com/infiniflow/infinity-go-sdk" "go.uber.org/zap" @@ -38,6 +40,13 @@ import ( // ChinesePunctRegex splits on comma, semicolon, Chinese punctuations, and newlines var ChinesePunctRegex = regexp.MustCompile(`[,,;;、\r\n]+`) +const ( + pagerankAdjustRetryCount = 2 + pagerankAdjustLockCount = 256 +) + +var pagerankAdjustLocks [pagerankAdjustLockCount]sync.Mutex + // CreateChunkStore creates a chunk table in Infinity // baseName is the table name prefix (e.g., "ragflow_") // The full table name is built as "{baseName}_{datasetID}" @@ -518,6 +527,97 @@ func (e *infinityEngine) UpdateChunks(ctx context.Context, condition map[string] return nil } +// AdjustChunkPagerank adjusts pagerank_fea and clamps it to [minWeight, maxWeight]. +func (e *infinityEngine) AdjustChunkPagerank(ctx context.Context, baseName, chunkID, datasetID string, delta, minWeight, maxWeight float64) error { + if baseName == "" { + return fmt.Errorf("index name cannot be empty") + } + if chunkID == "" { + return fmt.Errorf("chunk id cannot be empty") + } + if ctx == nil { + ctx = context.Background() + } + if e.client == nil || e.client.conn == nil { + return fmt.Errorf("Infinity client not initialized") + } + + tableName := buildChunkTableName(baseName, datasetID) + lock := pagerankAdjustLock(tableName + ":" + chunkID) + lock.Lock() + defer lock.Unlock() + + db, err := e.client.conn.GetDatabase(e.client.dbName) + if err != nil { + return fmt.Errorf("failed to get database: %w", err) + } + table, err := db.GetTable(tableName) + if err != nil { + return fmt.Errorf("failed to get table %s: %w", tableName, err) + } + + var lastErr error + filter := fmt.Sprintf("id = '%s'", escapeFilterValue(chunkID)) + for attempt := 0; attempt <= pagerankAdjustRetryCount; attempt++ { + if err := ctx.Err(); err != nil { + return err + } + result, err := table.Output([]string{common.PAGERANK_FLD, "row_id()"}).Filter(filter).ToResult() + if err != nil { + lastErr = err + continue + } + qr, ok := result.(*infinity.QueryResult) + if !ok { + return fmt.Errorf("unexpected query result type: %T", result) + } + + rowID, ok := firstQueryInt64(qr.Data, "ROW_ID", "row_id()", "row_id") + if !ok { + return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, chunkID) + } + + currentWeight := 0.0 + if currentValue, ok := firstQueryValue(qr.Data, common.PAGERANK_FLD); ok { + if weight, ok := coerceToFloat(currentValue); ok { + currentWeight = weight + } + } + nextWeight := currentWeight + delta + if nextWeight < minWeight { + nextWeight = minWeight + } + if nextWeight > maxWeight { + nextWeight = maxWeight + } + if floatsEqual(currentWeight, nextWeight) { + return nil + } + + updateFilter := fmt.Sprintf("_row_id = %d AND %s = %s", rowID, common.PAGERANK_FLD, formatFilterFloat(currentWeight)) + if _, err := table.Update(updateFilter, map[string]interface{}{common.PAGERANK_FLD: nextWeight}); err != nil { + lastErr = err + continue + } + verifyResult, err := table.Output([]string{common.PAGERANK_FLD}).Filter(filter).ToResult() + if err != nil { + lastErr = err + continue + } + verifyQR, ok := verifyResult.(*infinity.QueryResult) + if !ok { + return fmt.Errorf("unexpected query result type: %T", verifyResult) + } + if currentValue, ok := firstQueryValue(verifyQR.Data, common.PAGERANK_FLD); ok { + if actualWeight, ok := coerceToFloat(currentValue); ok && floatsEqual(actualWeight, nextWeight) { + return nil + } + } + lastErr = fmt.Errorf("pagerank update conflict") + } + return fmt.Errorf("failed to adjust chunk pagerank: %w", lastErr) +} + // DeleteChunks deletes chunks from a dataset table // Table name format: {baseName}_{datasetID} // condition specifies which chunks to delete @@ -1976,6 +2076,77 @@ func escapeFilterValue(s string) string { return strings.ReplaceAll(s, "'", "''") } +func firstQueryValue(data map[string][]interface{}, columns ...string) (interface{}, bool) { + for _, column := range columns { + values, ok := data[column] + if ok && len(values) > 0 { + return values[0], true + } + } + return nil, false +} + +func firstQueryInt64(data map[string][]interface{}, columns ...string) (int64, bool) { + value, ok := firstQueryValue(data, columns...) + if !ok { + return 0, false + } + switch v := value.(type) { + case int: + return int64(v), true + case int8: + return int64(v), true + case int16: + return int64(v), true + case int32: + return int64(v), true + case int64: + return v, true + case uint: + if uint64(v) <= ^uint64(0)>>1 { + return int64(v), true + } + case uint8: + return int64(v), true + case uint16: + return int64(v), true + case uint32: + return int64(v), true + case uint64: + if v <= ^uint64(0)>>1 { + return int64(v), true + } + case float32: + return int64(v), true + case float64: + return int64(v), true + case string: + parsed, err := strconv.ParseInt(v, 10, 64) + if err == nil { + return parsed, true + } + } + return 0, false +} + +func pagerankAdjustLock(key string) *sync.Mutex { + hash := fnv.New32a() + _, _ = hash.Write([]byte(key)) + return &pagerankAdjustLocks[hash.Sum32()%pagerankAdjustLockCount] +} + +func formatFilterFloat(value float64) string { + return strconv.FormatFloat(value, 'f', -1, 64) +} + +func floatsEqual(a, b float64) bool { + diff := a - b + if diff < 0 { + diff = -diff + } + return diff < 1e-9 +} + // equivalentConditionToStr converts a condition map to an Infinity filter string func equivalentConditionToStr(condition map[string]interface{}) string { if len(condition) == 0 { diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index ce3adf8808..26f9c493ab 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -398,3 +398,60 @@ func (h *ChatSessionHandler) UpdateSession(c *gin.Context) { } jsonResponse(c, common.CodeSuccess, result, "success") } + +func (h *ChatSessionHandler) DeleteSessionMessage(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + userID := user.ID + chatID, sessionID, msgID := c.Param("chat_id"), c.Param("session_id"), c.Param("msg_id") + + result, code, err := h.chatSessionService.DeleteSessionMessage(userID, chatID, sessionID, msgID) + if err != nil { + if code == common.CodeAuthenticationError { + jsonResponse(c, code, false, err.Error()) + return + } + jsonError(c, code, err.Error()) + return + } + jsonResponse(c, common.CodeSuccess, result, "success") +} + +func (h *ChatSessionHandler) UpdateMessageFeedback(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + userID := user.ID + chatID, sessionID, msgID := c.Param("chat_id"), c.Param("session_id"), c.Param("msg_id") + + req := map[string]interface{}{} + if err := c.ShouldBindJSON(&req); err != nil { + if errors.Is(err, io.EOF) { + jsonError(c, common.CodeArgumentError, "Request body cannot be empty") + return + } + jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error()) + return + } + if len(req) == 0 { + jsonError(c, common.CodeArgumentError, "Request body cannot be empty") + return + } + + result, code, err := h.chatSessionService.UpdateMessageFeedback(c.Request.Context(), userID, chatID, sessionID, msgID, req) + if err != nil { + if code == common.CodeAuthenticationError { + jsonResponse(c, code, false, err.Error()) + return + } + jsonError(c, code, err.Error()) + return + } + jsonResponse(c, common.CodeSuccess, result, "success") +} diff --git a/internal/handler/chat_session_test.go b/internal/handler/chat_session_test.go index 4aa621068e..6a2e05dd05 100644 --- a/internal/handler/chat_session_test.go +++ b/internal/handler/chat_session_test.go @@ -72,3 +72,64 @@ func TestChatSessionHandlerUpdateSession_RejectsEmptyJSONObject(t *testing.T) { t.Fatalf("message=%v", got) } } + +func TestChatSessionHandlerUpdateMessageFeedback_RejectsEmptyBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPut, "/api/v1/chats/chat-1/sessions/session-1/messages/msg-1/feedback", nil) + ctx.Params = gin.Params{ + {Key: "chat_id", Value: "chat-1"}, + {Key: "session_id", Value: "session-1"}, + {Key: "msg_id", Value: "msg-1"}, + } + ctx.Set("user", &entity.User{ID: "user-1"}) + + handler := NewChatSessionHandler(service.NewChatSessionService(), nil) + handler.UpdateMessageFeedback(ctx) + + if recorder.Code != http.StatusOK { + t.Fatalf("status=%d", recorder.Code) + } + + var body map[string]interface{} + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatalf("decode response body: %v", err) + } + if got := body["code"]; got != float64(common.CodeArgumentError) { + t.Fatalf("code=%v", got) + } + if got := body["message"]; got != "Request body cannot be empty" { + t.Fatalf("message=%v", got) + } +} + +func TestChatSessionHandlerUpdateMessageFeedback_RejectsEmptyJSONObject(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPut, "/api/v1/chats/chat-1/sessions/session-1/messages/msg-1/feedback", strings.NewReader(`{}`)) + ctx.Request.Header.Set("Content-Type", "application/json") + ctx.Params = gin.Params{ + {Key: "chat_id", Value: "chat-1"}, + {Key: "session_id", Value: "session-1"}, + {Key: "msg_id", Value: "msg-1"}, + } + ctx.Set("user", &entity.User{ID: "user-1"}) + + handler := NewChatSessionHandler(service.NewChatSessionService(), nil) + handler.UpdateMessageFeedback(ctx) + + var body map[string]interface{} + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatalf("decode response body: %v", err) + } + if got := body["code"]; got != float64(common.CodeArgumentError) { + t.Fatalf("code=%v", got) + } + if got := body["message"]; got != "Request body cannot be empty" { + t.Fatalf("message=%v", got) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 5b3a8967fc..88d043f06f 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -300,6 +300,8 @@ func (r *Router) Setup(engine *gin.Engine) { chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions) chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession) chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession) + chats.DELETE("/:chat_id/sessions/:session_id/messages/:msg_id", r.chatSessionHandler.DeleteSessionMessage) + chats.PUT("/:chat_id/sessions/:session_id/messages/:msg_id/feedback", r.chatSessionHandler.UpdateMessageFeedback) } chat := v1.Group("/chat") diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index b700658a87..49ae9a7355 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -22,8 +22,11 @@ import ( "errors" "fmt" "math" + "os" "ragflow/internal/common" + "ragflow/internal/engine" "ragflow/internal/storage" + "strconv" "strings" "time" @@ -61,14 +64,15 @@ type chatPipelineRunner interface { // (api/db/services/chunk_feedback_service.py) call site at // api/apps/restful_apis/chat_api.py — that handler records the thumb // vote against every chunk that produced the assistant message, in -// addition to the session-level thumbup field. The Go stack does not -// yet have a chunk-feedback DAO, so this interface is the seam where -// one will plug in. Production uses *ChatSessionService itself via -// applyChunkFeedback (which currently no-ops with a debug log) so the -// handler can still update the session-level thumbup without crashing; -// tests can swap in a fake by setting ChatSessionService.chunkFeedbackApplier. +// addition to the session-level thumbup field. Production uses +// *ChatSessionService itself via applyChunkFeedback; tests can swap +// in a fake by setting ChatSessionService.chunkFeedbackApplier. type chunkFeedbackApplier interface { - applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) + applyChunkFeedback(ctx context.Context, tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) +} + +type chunkPagerankAdjuster interface { + AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error } // ChatSessionService chat session (conversation) service. @@ -78,6 +82,7 @@ type ChatSessionService struct { userTenantDAO userTenantStore pipeline chatPipelineRunner chunkFeedbackApplier chunkFeedbackApplier + docEngine engine.DocEngine } // NewChatSessionService create chat session service @@ -86,9 +91,131 @@ func NewChatSessionService() *ChatSessionService { chatSessionDAO: dao.NewChatSessionDAO(), userTenantDAO: dao.NewUserTenantDAO(), pipeline: NewChatPipelineService(), + docEngine: engine.Get(), } } +// SetChatSessionRequest set chat session request. +type SetChatSessionRequest struct { + SessionID string `json:"conversation_id,omitempty"` + DialogID string `json:"dialog_id,omitempty"` + Name string `json:"name,omitempty"` + IsNew bool `json:"is_new"` +} + +// SetChatSessionResponse set chat session response. +type SetChatSessionResponse struct { + *entity.ChatSession +} + +// SetChatSession creates or updates a chat session. +// Kept as a compatibility entrypoint for older chat-session callers. +func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRequest) (*SetChatSessionResponse, error) { + name := req.Name + if name == "" { + name = "New chat session" + } + if len(name) > 255 { + name = name[:255] + } + + if !req.IsNew { + updates := map[string]interface{}{ + "name": name, + "user_id": userID, + } + if err := s.chatSessionDAO.UpdateByID(req.SessionID, updates); err != nil { + return nil, errors.New("Chat session not found") + } + session, err := s.chatSessionDAO.GetByID(req.SessionID) + if err != nil { + return nil, errors.New("Fail to update a chat session") + } + return &SetChatSessionResponse{ChatSession: session}, nil + } + + dialog, err := s.chatSessionDAO.GetDialogByID(req.DialogID) + if err != nil { + return nil, errors.New("Dialog not found") + } + + prologue := "Hi! I'm your assistant. What can I do for you?" + if dialog.PromptConfig != nil { + if p, ok := dialog.PromptConfig["prologue"].(string); ok && p != "" { + prologue = p + } + } + messagesJSON, _ := json.Marshal([]map[string]interface{}{ + { + "role": "assistant", + "content": prologue, + }, + }) + referenceJSON, _ := json.Marshal([]interface{}{}) + + session := &entity.ChatSession{ + ID: common.GenerateUUID(), + DialogID: req.DialogID, + Name: &name, + Message: messagesJSON, + UserID: &userID, + Reference: referenceJSON, + } + if err := s.chatSessionDAO.Create(session); err != nil { + return nil, errors.New("Fail to create a chat session") + } + + return &SetChatSessionResponse{ChatSession: session}, nil +} + +// RemoveChatSessions removes chat sessions. +// Kept as a compatibility entrypoint for older chat-session callers. +func (s *ChatSessionService) RemoveChatSessions(userID string, chatSessions []string) error { + tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) + if err != nil { + return err + } + + tenantIDSet := make(map[string]bool) + for _, tid := range tenantIDs { + tenantIDSet[tid] = true + } + tenantIDSet[userID] = true + + for _, convID := range chatSessions { + session, err := s.chatSessionDAO.GetByID(convID) + if err != nil { + return fmt.Errorf("Chat session not found: %s", convID) + } + + isOwner := false + for tenantID := range tenantIDSet { + exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, session.DialogID) + if err != nil { + return err + } + if exists { + isOwner = true + break + } + } + if !isOwner { + return errors.New("Only owner of chat session authorized for this operation") + } + + if err := s.chatSessionDAO.DeleteByID(convID); err != nil { + return err + } + } + + return nil +} + +// ListChatSessionsRequest list chat sessions request. +type ListChatSessionsRequest struct { + DialogID string `json:"dialog_id" binding:"required"` +} + // ListChatSessionsResponse list chat sessions response type ListChatSessionsResponse struct { Sessions []*entity.ChatSession @@ -542,7 +669,10 @@ func (s *ChatSessionService) DeleteSessionMessage(userID, chatID, sessionID, msg return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil } -func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, msgID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) { +func (s *ChatSessionService) UpdateMessageFeedback(ctx context.Context, userID, chatID, sessionID, msgID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) { + if ctx == nil { + ctx = context.Background() + } ownerTenantID := "" tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) if err != nil { @@ -654,7 +784,7 @@ func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, ms applier = s } if priorThumbBool, ok := priorThumb.(bool); ok && priorThumbBool != thumbup { - result, _ := applier.applyChunkFeedback(ownerTenantID, feedbackReference, !priorThumbBool) + result, _ := applier.applyChunkFeedback(ctx, ownerTenantID, feedbackReference, !priorThumbBool) if result != nil { common.Debug("Chunk feedback undo applied", zap.Any("success_count", result["success_count"]), @@ -662,7 +792,7 @@ func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, ms ) } } - result, _ := applier.applyChunkFeedback(ownerTenantID, feedbackReference, thumbup) + result, _ := applier.applyChunkFeedback(ctx, ownerTenantID, feedbackReference, thumbup) if result != nil { common.Debug("Chunk feedback applied", zap.Any("success_count", result["success_count"]), @@ -674,32 +804,308 @@ func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, ms return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil } -// applyChunkFeedback records a thumb vote against the chunks that -// produced a session message. Mirrors Python's -// ChunkFeedbackService.apply_feedback side effect (called from -// api/apps/restful_apis/chat_api.py when a user toggles a thumb on -// an assistant message). The Go persistence port for chunk feedback -// is intentionally not yet landed — the call here is a documented -// no-op so the session-level thumbup flow (the user-visible behavior) -// keeps working while a future PR ports the Python DAO. The returned -// counts let the caller log a "Chunk feedback applied: N succeeded, -// M failed" line consistent with the Python equivalent, so log -// scrapers don't see a regression in success/fail rates. -// -// Production callers should always go through the chunkFeedbackApplier -// field; this method is the default implementation used when that -// field is nil. -func (s *ChatSessionService) applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) { - common.Debug("chunk feedback persistence not yet ported; dropping vote", - zap.String("tenant_id", tenantID), +const ( + upvoteWeightIncrement = 1 + downvoteWeightDecrement = 1 + minPagerankWeight = 0.0 + maxPagerankWeight = 100.0 + chunkFeedbackTimeout = 30 * time.Second +) + +type feedbackChunkRow struct { + chunkID string + kbID string + chunk map[string]interface{} +} + +// applyChunkFeedback records a thumb vote against the chunks that produced a +// session message. It mirrors Python's ChunkFeedbackService.apply_feedback: +// feature-flagged by CHUNK_FEEDBACK_ENABLED, split by relevance unless +// CHUNK_FEEDBACK_WEIGHTING=uniform, and clamped through the document engine. +func (s *ChatSessionService) applyChunkFeedback(ctx context.Context, tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) { + if !chunkFeedbackEnabled() { + common.Debug("Chunk feedback feature is disabled") + return map[string]interface{}{ + "success_count": 0, + "fail_count": 0, + "chunk_ids": []string{}, + "disabled": true, + }, nil + } + if ctx == nil { + ctx = context.Background() + } + if _, ok := ctx.Deadline(); !ok { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, chunkFeedbackTimeout) + defer cancel() + } + + rows := feedbackRowsFromReference(reference) + chunkIDs := make([]string, 0, len(rows)) + for _, row := range rows { + chunkIDs = append(chunkIDs, row.chunkID) + } + if len(rows) == 0 { + common.Debug("No chunk IDs found in reference for feedback") + return map[string]interface{}{ + "success_count": 0, + "fail_count": 0, + "chunk_ids": chunkIDs, + }, nil + } + + signedBudget := float64(upvoteWeightIncrement) + if !isPositive { + signedBudget = -float64(downvoteWeightDecrement) + } + deltas := allocateFeedbackDeltas(rows, signedBudget, chunkFeedbackWeighting()) + + successCount := 0 + failCount := 0 + for _, delta := range deltas { + if delta.delta == 0 { + continue + } + if s.updateChunkWeight(ctx, tenantID, delta.chunkID, delta.kbID, delta.delta) { + successCount++ + } else { + failCount++ + } + } + + common.Info("Applied chunk feedback", zap.Bool("is_positive", isPositive), + zap.String("weighting", chunkFeedbackWeighting()), + zap.Int("success_count", successCount), + zap.Int("chunk_count", len(chunkIDs)), ) return map[string]interface{}{ - "success_count": 0, - "fail_count": 0, + "success_count": successCount, + "fail_count": failCount, + "chunk_ids": chunkIDs, }, nil } +type feedbackDelta struct { + chunkID string + kbID string + delta float64 +} + +func chunkFeedbackEnabled() bool { + return strings.ToLower(os.Getenv("CHUNK_FEEDBACK_ENABLED")) == "true" +} + +func chunkFeedbackWeighting() string { + weighting := strings.ToLower(strings.TrimSpace(os.Getenv("CHUNK_FEEDBACK_WEIGHTING"))) + if weighting == "uniform" || weighting == "relevance" { + return weighting + } + return "relevance" +} + +func feedbackRowsFromReference(reference map[string]interface{}) []feedbackChunkRow { + if len(reference) == 0 { + return nil + } + rawChunks, ok := reference["chunks"].([]interface{}) + if !ok { + if chunks, ok := reference["chunks"].([]map[string]interface{}); ok { + rows := make([]feedbackChunkRow, 0, len(chunks)) + for _, chunk := range chunks { + if row, ok := feedbackRowFromChunk(chunk); ok { + rows = append(rows, row) + } + } + return rows + } + return nil + } + rows := make([]feedbackChunkRow, 0, len(rawChunks)) + for _, raw := range rawChunks { + chunk, ok := raw.(map[string]interface{}) + if !ok { + continue + } + if row, ok := feedbackRowFromChunk(chunk); ok { + rows = append(rows, row) + } + } + return rows +} + +func feedbackRowFromChunk(chunk map[string]interface{}) (feedbackChunkRow, bool) { + chunkID := stringValue(chunk["id"]) + if chunkID == "" { + chunkID = stringValue(chunk["chunk_id"]) + } + kbID := stringValue(chunk["dataset_id"]) + if kbID == "" { + kbID = stringValue(chunk["kb_id"]) + } + if chunkID == "" || kbID == "" { + return feedbackChunkRow{}, false + } + return feedbackChunkRow{chunkID: chunkID, kbID: kbID, chunk: chunk}, true +} + +func allocateFeedbackDeltas(rows []feedbackChunkRow, signedBudget float64, weighting string) []feedbackDelta { + if len(rows) == 0 { + return nil + } + if weighting == "uniform" { + step := signedBudget / float64(len(rows)) + deltas := make([]feedbackDelta, 0, len(rows)) + for _, row := range rows { + deltas = append(deltas, feedbackDelta{chunkID: row.chunkID, kbID: row.kbID, delta: step}) + } + return deltas + } + + magnitudes := make([]float64, 0, len(rows)) + for _, row := range rows { + signal := retrievalSignal(row.chunk) + if signal <= 0 { + signal = 1 + } + magnitudes = append(magnitudes, signal) + } + parts := splitFloatBudget(magnitudes, math.Abs(signedBudget)) + sign := math.Copysign(1, signedBudget) + deltas := make([]feedbackDelta, 0, len(rows)) + for i, row := range rows { + deltas = append(deltas, feedbackDelta{chunkID: row.chunkID, kbID: row.kbID, delta: sign * parts[i]}) + } + return deltas +} + +func retrievalSignal(chunk map[string]interface{}) float64 { + best := 0.0 + for _, key := range []string{"similarity", "vector_similarity", "term_similarity"} { + val, ok := floatValue(chunk[key]) + if ok && !math.IsInf(val, 0) && !math.IsNaN(val) && val > best { + best = val + } + } + return best +} + +func splitFloatBudget(magnitudes []float64, budget float64) []float64 { + n := len(magnitudes) + out := make([]float64, n) + if n == 0 || budget == 0 { + return out + } + total := 0.0 + for _, magnitude := range magnitudes { + total += magnitude + } + if total <= 0 { + base := budget / float64(n) + for i := range out { + out[i] = base + } + return out + } + + for i, magnitude := range magnitudes { + out[i] = budget * magnitude / total + } + return out +} + +func (s *ChatSessionService) updateChunkWeight(ctx context.Context, tenantID, chunkID, kbID string, delta float64) bool { + docEngine := s.docEngine + if docEngine == nil { + docEngine = engine.Get() + } + if docEngine == nil { + common.Warn("Document engine is not initialized; chunk feedback skipped", + zap.String("chunk_id", chunkID), + zap.String("kb_id", kbID), + ) + return false + } + + indexName := fmt.Sprintf("ragflow_%s", tenantID) + if adjuster, ok := docEngine.(chunkPagerankAdjuster); ok { + if err := adjuster.AdjustChunkPagerank(ctx, indexName, chunkID, kbID, delta, minPagerankWeight, maxPagerankWeight); err != nil { + common.Warn("Failed atomic pagerank adjust for chunk", + zap.String("chunk_id", chunkID), + zap.Error(err), + ) + return false + } + return true + } + + rawChunk, err := docEngine.GetChunk(ctx, indexName, chunkID, []string{kbID}) + if err != nil { + common.Warn("Chunk not found for feedback", + zap.String("chunk_id", chunkID), + zap.String("index", indexName), + zap.Error(err), + ) + return false + } + chunk, ok := rawChunk.(map[string]interface{}) + if !ok { + common.Warn("Unexpected chunk shape for feedback", + zap.String("chunk_id", chunkID), + zap.String("kb_id", kbID), + zap.String("index", indexName), + zap.String("chunk_type", fmt.Sprintf("%T", rawChunk)), + ) + return false + } + + currentWeight, _ := floatValue(chunk[common.PAGERANK_FLD]) + nextWeight := currentWeight + delta + if nextWeight < minPagerankWeight { + nextWeight = minPagerankWeight + } + if nextWeight > maxPagerankWeight { + nextWeight = maxPagerankWeight + } + + newValue := map[string]interface{}{common.PAGERANK_FLD: nextWeight} + if nextWeight <= 0 && strings.ToLower(docEngine.GetType()) == string(engine.EngineElasticsearch) { + newValue = map[string]interface{}{"remove": common.PAGERANK_FLD} + } + if err := docEngine.UpdateChunks(ctx, map[string]interface{}{"id": chunkID}, newValue, indexName, kbID); err != nil { + common.Warn("Failed to update chunk pagerank", + zap.String("chunk_id", chunkID), + zap.Error(err), + ) + return false + } + return true +} + +func floatValue(value interface{}) (float64, bool) { + switch v := value.(type) { + case float64: + return v, true + case float32: + return float64(v), true + case int: + return float64(v), true + case int64: + return float64(v), true + case int32: + return float64(v), true + case json.Number: + f, err := v.Float64() + return f, err == nil + case string: + f, err := strconv.ParseFloat(strings.TrimSpace(v), 64) + return f, err == nil + default: + return 0, false + } +} + func (s *ChatSessionService) ensureOwnedChat(userID, chatID string) (bool, error) { tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) if err != nil { @@ -851,6 +1257,180 @@ func isChatSessionNotFound(err error) bool { return errors.Is(err, gorm.ErrRecordNotFound) } +// Completion performs chat completion with full RAG support via ChatPipelineService. +// Kept as a compatibility entrypoint for callers that still use the pre-ChatCompletions API. +func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) { + if len(messages) == 0 { + return nil, errors.New("messages cannot be empty") + } + lastRole, _ := messages[len(messages)-1]["role"].(string) + if lastRole != "user" { + return nil, errors.New("the last content of this conversation is not from user") + } + + session, err := s.chatSessionDAO.GetByID(conversationID) + if err != nil { + return nil, errors.New("Conversation not found") + } + + dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID) + if err != nil { + return nil, errors.New("Dialog not found") + } + + sessionMessages := s.buildSessionMessages(session, messages) + reference := s.initializeReference(session) + + isEmbedded := llmID != "" + if llmID != "" { + hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID) + if err != nil || !hasKey { + return nil, fmt.Errorf("Cannot use specified model %s", llmID) + } + dialog.LLMID = llmID + if chatModelConfig != nil { + dialog.LLMSetting = chatModelConfig + } + } + + kwargs := chatModelConfig + if kwargs == nil { + kwargs = map[string]interface{}{} + } + resultChan, err := s.pipeline.AsyncChat(context.Background(), dialog, messages, false, kwargs) + if err != nil { + return nil, err + } + + var answer strings.Builder + var finalRef map[string]interface{} + for result := range resultChan { + if result.Answer != "" { + answer.WriteString(result.Answer) + } + if result.Reference != nil { + finalRef = result.Reference + } + } + + ans := map[string]interface{}{ + "answer": answer.String(), + "reference": finalRef, + "final": true, + } + result := s.structureAnswerWithConv(session, ans, messageID, session.ID, reference) + + if !isEmbedded { + sessionMessages = append(sessionMessages, map[string]interface{}{ + "role": "assistant", + "content": answer.String(), + "id": messageID, + "created_at": float64(time.Now().Unix()), + }) + s.updateSessionMessages(session, sessionMessages, reference) + } + + return result, nil +} + +// CompletionStream performs streaming chat completion with full RAG support via ChatPipelineService. +// Kept as a compatibility entrypoint for callers that still use the pre-ChatCompletions API. +func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error { + if ctx == nil { + ctx = context.Background() + } + + if len(messages) == 0 { + streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "messages cannot be empty", "data": {"answer": "**ERROR**: messages cannot be empty", "reference": []}}`) + return errors.New("messages cannot be empty") + } + lastRole, _ := messages[len(messages)-1]["role"].(string) + if lastRole != "user" { + streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "the last content of this conversation is not from user", "data": {"answer": "**ERROR**: the last content of this conversation is not from user", "reference": []}}`) + return errors.New("the last content of this conversation is not from user") + } + + session, err := s.chatSessionDAO.GetByID(conversationID) + if err != nil { + streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Conversation not found", "data": {"answer": "**ERROR**: Conversation not found", "reference": []}}`) + return errors.New("Conversation not found") + } + + dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID) + if err != nil { + streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Dialog not found", "data": {"answer": "**ERROR**: Dialog not found", "reference": []}}`) + return errors.New("Dialog not found") + } + + sessionMessages := s.buildSessionMessages(session, messages) + reference := s.initializeReference(session) + + isEmbedded := llmID != "" + if llmID != "" { + hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID) + if err != nil || !hasKey { + errMsg := fmt.Sprintf(`{"code": 500, "message": "Cannot use specified model %s", "data": {"answer": "**ERROR**: Cannot use specified model", "reference": []}}`, llmID) + streamChan <- fmt.Sprintf("data: %s\n\n", errMsg) + return fmt.Errorf("Cannot use specified model %s", llmID) + } + dialog.LLMID = llmID + if chatModelConfig != nil { + dialog.LLMSetting = chatModelConfig + } + } + + kwargs := chatModelConfig + if kwargs == nil { + kwargs = map[string]interface{}{} + } + resultChan, err := s.pipeline.AsyncChat(ctx, dialog, messages, true, kwargs) + if err != nil { + streamChan <- fmt.Sprintf("data: %s\n\n", fmt.Sprintf(`{"code": 500, "message": "%s", "data": {"answer": "**ERROR**: %s", "reference": []}}`, err.Error(), err.Error())) + return err + } + + var fullAnswer strings.Builder + for result := range resultChan { + if result.Reference != nil && len(reference) > 0 { + reference[len(reference)-1] = result.Reference + } + if result.Final { + if result.Answer != "" { + fullAnswer.Reset() + fullAnswer.WriteString(result.Answer) + } + } else if result.Answer != "" { + fullAnswer.WriteString(result.Answer) + } + ans := s.structureAnswer(session, fullAnswer.String(), messageID, session.ID, reference) + data, _ := json.Marshal(map[string]interface{}{ + "code": 0, + "message": "", + "data": ans, + }) + streamChan <- fmt.Sprintf("data: %s\n\n", string(data)) + } + + finalData, _ := json.Marshal(map[string]interface{}{ + "code": 0, + "message": "", + "data": true, + }) + streamChan <- fmt.Sprintf("data: %s\n\n", string(finalData)) + + if !isEmbedded { + sessionMessages = append(sessionMessages, map[string]interface{}{ + "role": "assistant", + "content": fullAnswer.String(), + "id": messageID, + "created_at": float64(time.Now().Unix()), + }) + s.updateSessionMessages(session, sessionMessages, reference) + } + + return nil +} + // ChatCompletions handles chat completion matching Python's session_completion. // When stream=true, returns nil result and streams SSE via streamChan. // When stream=false, returns the structured answer map. diff --git a/internal/service/chat_session_test.go b/internal/service/chat_session_test.go index 24cf4d91bf..969a63ed1e 100644 --- a/internal/service/chat_session_test.go +++ b/internal/service/chat_session_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "math" "reflect" "strings" "sync" @@ -11,6 +12,7 @@ import ( "gorm.io/gorm" "ragflow/internal/common" + "ragflow/internal/engine" "ragflow/internal/entity" ) @@ -171,6 +173,88 @@ func makeResultChan(results ...AsyncChatResult) <-chan AsyncChatResult { return ch } +type feedbackContextKey struct{} + +type fakeFeedbackDocEngine struct { + engine.DocEngine + adjustCalls []struct { + ctx context.Context + indexName string + chunkID string + kbID string + delta float64 + minWeight float64 + maxWeight float64 + } +} + +func (f *fakeFeedbackDocEngine) GetType() string { return "elasticsearch" } + +func (f *fakeFeedbackDocEngine) AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error { + f.adjustCalls = append(f.adjustCalls, struct { + ctx context.Context + indexName string + chunkID string + kbID string + delta float64 + minWeight float64 + maxWeight float64 + }{ctx: ctx, indexName: indexName, chunkID: chunkID, kbID: kbID, delta: delta, minWeight: minWeight, maxWeight: maxWeight}) + return nil +} + +type fakeInfinityFeedbackDocEngine struct { + fakeFeedbackDocEngine + getChunkCalled bool +} + +func (f *fakeInfinityFeedbackDocEngine) GetType() string { return string(engine.EngineInfinity) } + +func (f *fakeInfinityFeedbackDocEngine) GetChunk(ctx context.Context, indexName, chunkID string, datasetIDs []string) (interface{}, error) { + f.getChunkCalled = true + return nil, errors.New("fallback should not be used") +} + +type fakeFallbackFeedbackDocEngine struct { + engine.DocEngine + chunks map[string]map[string]interface{} + updateCalls []struct { + ctx context.Context + condition map[string]interface{} + newValue map[string]interface{} + indexName string + kbID string + } +} + +func (f *fakeFallbackFeedbackDocEngine) GetType() string { return "elasticsearch" } + +func (f *fakeFallbackFeedbackDocEngine) GetChunk(ctx context.Context, indexName, chunkID string, datasetIDs []string) (interface{}, error) { + chunk, ok := f.chunks[chunkID] + if !ok { + return nil, gorm.ErrRecordNotFound + } + return chunk, nil +} + +func (f *fakeFallbackFeedbackDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, indexName string, kbID string) error { + f.updateCalls = append(f.updateCalls, struct { + ctx context.Context + condition map[string]interface{} + newValue map[string]interface{} + indexName string + kbID string + }{ctx: ctx, condition: condition, newValue: newValue, indexName: indexName, kbID: kbID}) + return nil +} + +func requireFloatClose(t *testing.T, got, want float64) { + t.Helper() + if math.Abs(got-want) > 1e-9 { + t.Fatalf("got %v want %v", got, want) + } +} + // =================================================================== // ListChatSessions tests // =================================================================== @@ -410,6 +494,591 @@ func TestUpdateSession_NotFound(t *testing.T) { } } +// =================================================================== +// Session message delete / feedback tests +// =================================================================== + +func TestDeleteSessionMessage_RemovesMessagePairAndReference(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", + DialogID: "chat-1", + Message: json.RawMessage(`[ + {"role":"assistant","content":"Welcome!"}, + {"role":"user","content":"first","id":"msg-1"}, + {"role":"assistant","content":"answer 1","id":"msg-1"}, + {"role":"user","content":"second","id":"msg-2"}, + {"role":"assistant","content":"answer 2","id":"msg-2"} + ]`), + Reference: json.RawMessage(`[ + {"chunks":[{"id":"chunk-1","kb_id":"kb-1"}]}, + {"chunks":[{"id":"chunk-2","kb_id":"kb-2"}]} + ]`), + } + store.dialogExists["tenant-1|chat-1"] = true + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}}, + pipeline: &fakePipeline{}, + } + + resp, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "msg-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code=%v", code) + } + if len(resp.Messages) != 3 { + t.Fatalf("messages=%#v", resp.Messages) + } + if resp.Messages[1]["id"] != "msg-2" || resp.Messages[2]["id"] != "msg-2" { + t.Fatalf("remaining pair=%#v", resp.Messages) + } + if len(resp.Reference) != 1 { + t.Fatalf("reference=%#v", resp.Reference) + } + ref, ok := resp.Reference[0].(map[string]interface{}) + if !ok { + t.Fatalf("reference type=%T", resp.Reference[0]) + } + chunks, ok := ref["chunks"].([]FormattedChunk) + if !ok || len(chunks) != 1 || chunks[0].ID != "chunk-2" { + t.Fatalf("remaining chunks=%#v", ref["chunks"]) + } +} + +func TestUpdateMessageFeedback_AppliesChunkFeedbackWithResolvedTenantAndContext(t *testing.T) { + t.Setenv("CHUNK_FEEDBACK_ENABLED", "true") + t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform") + + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", + DialogID: "chat-1", + Message: json.RawMessage(`[ + {"role":"assistant","content":"Welcome!"}, + {"role":"user","content":"question","id":"msg-1"}, + {"role":"assistant","content":"answer","id":"msg-1","feedback":"old"} + ]`), + Reference: json.RawMessage(`[ + {"chunks":[{"id":"chunk-1","kb_id":"kb-1","similarity":0.9}]} + ]`), + } + store.dialogExists["tenant-owner|chat-1"] = true + docEngine := &fakeFeedbackDocEngine{} + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-owner"}}, + pipeline: &fakePipeline{}, + docEngine: docEngine, + } + ctx := context.WithValue(context.Background(), feedbackContextKey{}, "request-context") + + resp, code, err := svc.UpdateMessageFeedback(ctx, "user-1", "chat-1", "session-1", "msg-1", map[string]interface{}{ + "thumbup": true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code=%v", code) + } + if len(docEngine.adjustCalls) != 1 { + t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls)) + } + call := docEngine.adjustCalls[0] + if call.indexName != "ragflow_tenant-owner" || call.chunkID != "chunk-1" || call.kbID != "kb-1" { + t.Fatalf("call=%#v", call) + } + if call.delta != 1 || call.minWeight != 0 || call.maxWeight != 100 { + t.Fatalf("weights=%#v", call) + } + if got := call.ctx.Value(feedbackContextKey{}); got != "request-context" { + t.Fatalf("context value=%v", got) + } + assistant := resp.Messages[2] + if assistant["thumbup"] != true { + t.Fatalf("assistant=%#v", assistant) + } + if _, ok := assistant["feedback"]; ok { + t.Fatalf("positive feedback should remove text feedback: %#v", assistant) + } +} + +func TestUpdateMessageFeedback_ToggleUsesResolvedTenantForUndoAndApply(t *testing.T) { + t.Setenv("CHUNK_FEEDBACK_ENABLED", "true") + t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform") + + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", + DialogID: "chat-1", + Message: json.RawMessage(`[ + {"role":"assistant","content":"Welcome!"}, + {"role":"user","content":"question","id":"msg-1"}, + {"role":"assistant","content":"answer","id":"msg-1","thumbup":true} + ]`), + Reference: json.RawMessage(`[ + {"chunks":[{"chunk_id":"chunk-1","dataset_id":"kb-1"}]} + ]`), + } + store.dialogExists["tenant-owner|chat-1"] = true + docEngine := &fakeFeedbackDocEngine{} + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-owner"}}, + pipeline: &fakePipeline{}, + docEngine: docEngine, + } + + resp, code, err := svc.UpdateMessageFeedback(context.Background(), "user-1", "chat-1", "session-1", "msg-1", map[string]interface{}{ + "thumbup": false, + "feedback": "not useful", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("code=%v", code) + } + if len(docEngine.adjustCalls) != 2 { + t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls)) + } + for _, call := range docEngine.adjustCalls { + if call.indexName != "ragflow_tenant-owner" || call.delta != -1 { + t.Fatalf("call=%#v", call) + } + } + assistant := resp.Messages[2] + if assistant["thumbup"] != false || assistant["feedback"] != "not useful" { + t.Fatalf("assistant=%#v", assistant) + } +} + +func TestApplyChunkFeedback_DisabledDoesNotTouchEngine(t *testing.T) { + t.Setenv("CHUNK_FEEDBACK_ENABLED", "false") + docEngine := &fakeFeedbackDocEngine{} + svc := &ChatSessionService{docEngine: docEngine} + + result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{ + "chunks": []interface{}{map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1"}}, + }, true) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result["disabled"] != true { + t.Fatalf("result=%#v", result) + } + if len(docEngine.adjustCalls) != 0 { + t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls)) + } +} + +func TestApplyChunkFeedback_UniformSplitsOneVoteAcrossChunks(t *testing.T) { + t.Setenv("CHUNK_FEEDBACK_ENABLED", "true") + t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform") + + docEngine := &fakeFeedbackDocEngine{} + svc := &ChatSessionService{docEngine: docEngine} + + result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{ + "chunks": []interface{}{ + map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1"}, + map[string]interface{}{"id": "chunk-2", "kb_id": "kb-1"}, + }, + }, true) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result["success_count"] != 2 { + t.Fatalf("result=%#v", result) + } + if len(docEngine.adjustCalls) != 2 { + t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls)) + } + requireFloatClose(t, docEngine.adjustCalls[0].delta, 0.5) + requireFloatClose(t, docEngine.adjustCalls[1].delta, 0.5) + requireFloatClose(t, docEngine.adjustCalls[0].delta+docEngine.adjustCalls[1].delta, 1) +} + +func TestApplyChunkFeedback_RelevanceDistributesOneVoteBySignals(t *testing.T) { + t.Setenv("CHUNK_FEEDBACK_ENABLED", "true") + t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "relevance") + + docEngine := &fakeFeedbackDocEngine{} + svc := &ChatSessionService{docEngine: docEngine} + + result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{ + "chunks": []interface{}{ + map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1", "similarity": 2.0}, + map[string]interface{}{"id": "chunk-2", "kb_id": "kb-1", "vector_similarity": 1.0}, + map[string]interface{}{"id": "chunk-3", "kb_id": "kb-1"}, + }, + }, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result["success_count"] != 3 { + t.Fatalf("result=%#v", result) + } + if len(docEngine.adjustCalls) != 3 { + t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls)) + } + requireFloatClose(t, docEngine.adjustCalls[0].delta, -0.5) + requireFloatClose(t, docEngine.adjustCalls[1].delta, -0.25) + requireFloatClose(t, docEngine.adjustCalls[2].delta, -0.25) + requireFloatClose(t, docEngine.adjustCalls[0].delta+docEngine.adjustCalls[1].delta+docEngine.adjustCalls[2].delta, -1) +} + +func TestUpdateChunkWeight_InfinityUsesAtomicAdjuster(t *testing.T) { + docEngine := &fakeInfinityFeedbackDocEngine{} + svc := &ChatSessionService{docEngine: docEngine} + + if ok := svc.updateChunkWeight(context.Background(), "tenant-1", "chunk-1", "kb-1", 0.25); !ok { + t.Fatal("expected updateChunkWeight to succeed") + } + if docEngine.getChunkCalled { + t.Fatal("expected Infinity adjuster path, got GetChunk fallback") + } + if len(docEngine.adjustCalls) != 1 { + t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls)) + } + call := docEngine.adjustCalls[0] + if call.indexName != "ragflow_tenant-1" || call.chunkID != "chunk-1" || call.kbID != "kb-1" { + t.Fatalf("call=%#v", call) + } + requireFloatClose(t, call.delta, 0.25) +} + +func TestApplyChunkFeedback_FallbackClampsAndRemovesPagerank(t *testing.T) { + t.Setenv("CHUNK_FEEDBACK_ENABLED", "true") + t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform") + + docEngine := &fakeFallbackFeedbackDocEngine{ + chunks: map[string]map[string]interface{}{ + "chunk-1": {common.PAGERANK_FLD: 0}, + }, + } + svc := &ChatSessionService{docEngine: docEngine} + + result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{ + "chunks": []interface{}{map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1"}}, + }, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result["success_count"] != 1 { + t.Fatalf("result=%#v", result) + } + if len(docEngine.updateCalls) != 1 { + t.Fatalf("update calls=%d", len(docEngine.updateCalls)) + } + call := docEngine.updateCalls[0] + if call.indexName != "ragflow_tenant-1" || call.kbID != "kb-1" { + t.Fatalf("call=%#v", call) + } + if call.condition["id"] != "chunk-1" || call.newValue["remove"] != common.PAGERANK_FLD { + t.Fatalf("call=%#v", call) + } +} + +// =================================================================== +// Completion tests +// =================================================================== + +func TestCompletion_Success(t *testing.T) { + store := newFakeSessionStore() + session := &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", + Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`), + Reference: json.RawMessage(`[]`), + } + store.sessions["session-1"] = session + store.dialogs["dialog-1"] = &entity.Chat{ + ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", + LLMSetting: entity.JSONMap{}, + } + + pipeline := &fakePipeline{ + resultChan: makeResultChan( + AsyncChatResult{Answer: "Hello", Reference: map[string]interface{}{"chunks": []interface{}{}}}, + AsyncChatResult{Answer: " world", Final: true, Reference: map[string]interface{}{"chunks": []interface{}{}}}, + ), + } + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: pipeline, + } + + result, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ans, _ := result["answer"].(string) + if ans != "Hello world" { + t.Fatalf("expected answer 'Hello world', got %q", ans) + } + + got := parseMessages(store.sessions["session-1"].Message) + if len(got) != 3 { + t.Fatalf("stored messages=%#v", got) + } + if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" { + t.Fatalf("stored prologue=%#v", got[0]) + } + if got[1]["role"] != "user" || got[1]["content"] != "hi" { + t.Fatalf("stored user message=%#v", got[1]) + } + if got[2]["role"] != "assistant" || got[2]["content"] != "Hello world" || got[2]["id"] != "msg-1" { + t.Fatalf("stored assistant message=%#v", got[2]) + } +} + +func TestCompletion_EmptyMessages(t *testing.T) { + svc := &ChatSessionService{ + chatSessionDAO: &fakeSessionStore{}, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + _, err := svc.Completion("user-1", "session-1", nil, "", nil, "msg-1") + if err == nil || err.Error() != "messages cannot be empty" { + t.Fatalf("expected 'messages cannot be empty', got %v", err) + } +} + +func TestCompletion_LastMessageNotFromUser(t *testing.T) { + svc := &ChatSessionService{ + chatSessionDAO: &fakeSessionStore{}, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + _, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ + {"role": "assistant", "content": "hello"}, + }, "", nil, "msg-1") + if err == nil || !strings.Contains(err.Error(), "not from user") { + t.Fatalf("expected 'not from user' error, got %v", err) + } +} + +func TestCompletion_ConversationNotFound(t *testing.T) { + store := newFakeSessionStore() + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + _, err := svc.Completion("user-1", "missing", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1") + if err == nil || err.Error() != "Conversation not found" { + t.Fatalf("expected 'Conversation not found', got %v", err) + } +} + +func TestCompletion_DialogNotFound(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", + Message: json.RawMessage(`[]`), + Reference: json.RawMessage(`[]`), + } + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + _, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1") + if err == nil || err.Error() != "Dialog not found" { + t.Fatalf("expected 'Dialog not found', got %v", err) + } +} + +func TestCompletion_PipelineError(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", + Message: json.RawMessage(`[]`), + Reference: json.RawMessage(`[]`), + } + store.dialogs["dialog-1"] = &entity.Chat{ + ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", + LLMSetting: entity.JSONMap{}, + } + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{err: errors.New("model unavailable")}, + } + + _, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1") + if err == nil || err.Error() != "model unavailable" { + t.Fatalf("expected 'model unavailable' error, got %v", err) + } +} + +// =================================================================== +// CompletionStream tests +// =================================================================== + +func readStreamChan(ch <-chan string, n int) []string { + var msgs []string + for i := 0; i < n; i++ { + select { + case msg, ok := <-ch: + if !ok { + return msgs + } + msgs = append(msgs, msg) + default: + return msgs + } + } + return msgs +} + +func TestCompletionStream_Success(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", + Message: json.RawMessage(`{"messages":[{"role":"assistant","content":"Welcome!"}]}`), + Reference: json.RawMessage(`[]`), + } + store.dialogs["dialog-1"] = &entity.Chat{ + ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", + LLMSetting: entity.JSONMap{}, + } + + pipeline := &fakePipeline{ + resultChan: makeResultChan( + AsyncChatResult{Answer: "stream", Reference: map[string]interface{}{"chunks": []interface{}{}}}, + AsyncChatResult{Answer: " answer", Reference: map[string]interface{}{"chunks": []interface{}{}}}, + ), + } + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: pipeline, + } + + streamChan := make(chan string, 10) + err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1", streamChan) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should receive data events and final signal + msgs := readStreamChan(streamChan, 5) + if len(msgs) < 3 { + t.Fatalf("expected at least 3 stream messages, got %d: %v", len(msgs), msgs) + } + // Check final signal + finalFound := false + for _, m := range msgs { + if strings.Contains(m, `"data":true`) { + finalFound = true + break + } + } + if !finalFound { + t.Fatal("expected final=true signal in stream") + } + + got := parseMessages(store.sessions["session-1"].Message) + if len(got) != 3 { + t.Fatalf("stored messages=%#v", got) + } + if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" { + t.Fatalf("stored prologue=%#v", got[0]) + } + if got[1]["role"] != "user" || got[1]["content"] != "hi" { + t.Fatalf("stored user message=%#v", got[1]) + } + if got[2]["role"] != "assistant" || got[2]["content"] != "stream answer" || got[2]["id"] != "msg-1" { + t.Fatalf("stored assistant message=%#v", got[2]) + } +} + +func TestStructureAnswerWithConv_ParsesArrayMessages(t *testing.T) { + session := &entity.ChatSession{ + ID: "session-1", + Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`), + } + svc := &ChatSessionService{} + + ans := svc.structureAnswerWithConv(session, map[string]interface{}{ + "answer": "Final answer", + "reference": map[string]interface{}{"chunks": []interface{}{}}, + "final": true, + }, "msg-1", "session-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + + if ans["id"] != "msg-1" || ans["session_id"] != "session-1" { + t.Fatalf("ans=%#v", ans) + } + + got := parseMessages(session.Message) + if len(got) != 1 { + t.Fatalf("stored messages=%#v", got) + } + if got[0]["role"] != "assistant" || got[0]["content"] != "Final answer" || got[0]["id"] != "msg-1" { + t.Fatalf("stored assistant message=%#v", got[0]) + } +} + +func TestParseMessages_LegacyWrappedObject(t *testing.T) { + got := parseMessages(json.RawMessage(`{"messages":[{"role":"assistant","content":"legacy"}]}`)) + if !reflect.DeepEqual(got, []map[string]interface{}{{"role": "assistant", "content": "legacy"}}) { + t.Fatalf("messages=%#v", got) + } +} + +func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) { + svc := &ChatSessionService{} + payload := svc.buildSessionPayload(&entity.ChatSession{ + ID: "session-1", + DialogID: "chat-1", + Message: nil, + Reference: json.RawMessage(`null`), + }, nil, false) + + if payload.Messages == nil { + t.Fatal("messages is nil") + } + if payload.Reference == nil { + t.Fatal("reference is nil") + } + + body, err := json.Marshal(payload) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if !strings.Contains(string(body), `"messages":[]`) { + t.Fatalf("messages did not encode as empty array: %s", string(body)) + } + if !strings.Contains(string(body), `"reference":[]`) { + t.Fatalf("reference did not encode as empty array: %s", string(body)) + } +} + func TestParseCollections_ReturnEmptySlicesForMissingOrNull(t *testing.T) { messageInputs := []json.RawMessage{ nil,