mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-02 16:55:42 +08:00
feat(go-api) sessions message update (#16517)
### Summary ``` /api/v1/chats/<chat_id>/sessions/<session_id>/messages/<msg_id> DELETE /api/v1/chats/<chat_id>/sessions/<session_id>/messages/<msg_id>/feedback PUT ``` Migrates the chat session message delete and feedback APIs to the Go server, matching the Python behavior for authorization, session ownership checks, message/reference updates, and feedback validation.
This commit is contained in:
@@ -271,6 +271,93 @@ func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[st
|
||||
return e.updateChunksByQuery(ctx, fullIndexName, condition, newValue)
|
||||
}
|
||||
|
||||
// AdjustChunkPagerank atomically adjusts pagerank_fea and clamps it to
|
||||
// [minWeight, maxWeight].
|
||||
func (e *elasticsearchEngine) AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error {
|
||||
if indexName == "" {
|
||||
return fmt.Errorf("index name cannot be empty")
|
||||
}
|
||||
if chunkID == "" {
|
||||
return fmt.Errorf("chunk id cannot be empty")
|
||||
}
|
||||
script := `
|
||||
if (ctx._source.kb_id == null || !ctx._source.kb_id.equals(params.kb_id)) {
|
||||
ctx.op = 'noop';
|
||||
} else {
|
||||
double current = 0.0;
|
||||
if (ctx._source.containsKey(params.field) && ctx._source[params.field] != null) {
|
||||
Object currentValue = ctx._source[params.field];
|
||||
if (currentValue instanceof Number) {
|
||||
current = ((Number)currentValue).doubleValue();
|
||||
} else {
|
||||
try {
|
||||
current = Double.parseDouble(currentValue.toString());
|
||||
} catch (Exception e) {
|
||||
current = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
double next = current + params.delta;
|
||||
if (next < params.min_weight) {
|
||||
next = params.min_weight;
|
||||
}
|
||||
if (next > params.max_weight) {
|
||||
next = params.max_weight;
|
||||
}
|
||||
if (next <= 0.0) {
|
||||
ctx._source.remove(params.field);
|
||||
} else {
|
||||
ctx._source[params.field] = next;
|
||||
}
|
||||
}
|
||||
`
|
||||
body, err := json.Marshal(map[string]interface{}{
|
||||
"script": map[string]interface{}{
|
||||
"source": script,
|
||||
"lang": "painless",
|
||||
"params": map[string]interface{}{
|
||||
"field": common.PAGERANK_FLD,
|
||||
"kb_id": kbID,
|
||||
"delta": delta,
|
||||
"min_weight": minWeight,
|
||||
"max_weight": maxWeight,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal pagerank adjust request: %w", err)
|
||||
}
|
||||
retryOnConflict := 3
|
||||
req := esapi.UpdateRequest{
|
||||
Index: indexName,
|
||||
DocumentID: chunkID,
|
||||
Body: bytes.NewReader(body),
|
||||
RetryOnConflict: &retryOnConflict,
|
||||
}
|
||||
res, err := req.Do(ctx, e.client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to adjust chunk pagerank: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.IsError() {
|
||||
if res.StatusCode == http.StatusNotFound {
|
||||
return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, chunkID)
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(res.Body)
|
||||
return fmt.Errorf("elasticsearch pagerank adjust error: %s, body: %s", res.Status(), string(bodyBytes))
|
||||
}
|
||||
var updateResp struct {
|
||||
Result string `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&updateResp); err != nil {
|
||||
return fmt.Errorf("failed to decode pagerank adjust response: %w", err)
|
||||
}
|
||||
if updateResp.Result == "noop" {
|
||||
return fmt.Errorf("chunk %s does not belong to dataset %s", chunkID, kbID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *elasticsearchEngine) updateSingleMemoryMessage(ctx context.Context, indexName, messageDocID string, newValue map[string]interface{}) error {
|
||||
doc := mapMemoryMessageESUpdateFields(newValue)
|
||||
delete(doc, "id")
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"ragflow/internal/common"
|
||||
@@ -30,6 +31,7 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
infinity "github.com/infiniflow/infinity-go-sdk"
|
||||
"go.uber.org/zap"
|
||||
@@ -38,6 +40,13 @@ import (
|
||||
// ChinesePunctRegex splits on comma, semicolon, Chinese punctuations, and newlines
|
||||
var ChinesePunctRegex = regexp.MustCompile(`[,,;;、\r\n]+`)
|
||||
|
||||
const (
|
||||
pagerankAdjustRetryCount = 2
|
||||
pagerankAdjustLockCount = 256
|
||||
)
|
||||
|
||||
var pagerankAdjustLocks [pagerankAdjustLockCount]sync.Mutex
|
||||
|
||||
// CreateChunkStore creates a chunk table in Infinity
|
||||
// baseName is the table name prefix (e.g., "ragflow_<tenant_id>")
|
||||
// The full table name is built as "{baseName}_{datasetID}"
|
||||
@@ -518,6 +527,97 @@ func (e *infinityEngine) UpdateChunks(ctx context.Context, condition map[string]
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdjustChunkPagerank adjusts pagerank_fea and clamps it to [minWeight, maxWeight].
|
||||
func (e *infinityEngine) AdjustChunkPagerank(ctx context.Context, baseName, chunkID, datasetID string, delta, minWeight, maxWeight float64) error {
|
||||
if baseName == "" {
|
||||
return fmt.Errorf("index name cannot be empty")
|
||||
}
|
||||
if chunkID == "" {
|
||||
return fmt.Errorf("chunk id cannot be empty")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if e.client == nil || e.client.conn == nil {
|
||||
return fmt.Errorf("Infinity client not initialized")
|
||||
}
|
||||
|
||||
tableName := buildChunkTableName(baseName, datasetID)
|
||||
lock := pagerankAdjustLock(tableName + ":" + chunkID)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
db, err := e.client.conn.GetDatabase(e.client.dbName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get database: %w", err)
|
||||
}
|
||||
table, err := db.GetTable(tableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get table %s: %w", tableName, err)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
filter := fmt.Sprintf("id = '%s'", escapeFilterValue(chunkID))
|
||||
for attempt := 0; attempt <= pagerankAdjustRetryCount; attempt++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := table.Output([]string{common.PAGERANK_FLD, "row_id()"}).Filter(filter).ToResult()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
qr, ok := result.(*infinity.QueryResult)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected query result type: %T", result)
|
||||
}
|
||||
|
||||
rowID, ok := firstQueryInt64(qr.Data, "ROW_ID", "row_id()", "row_id")
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %s", types.ErrDocumentNotFound, chunkID)
|
||||
}
|
||||
|
||||
currentWeight := 0.0
|
||||
if currentValue, ok := firstQueryValue(qr.Data, common.PAGERANK_FLD); ok {
|
||||
if weight, ok := coerceToFloat(currentValue); ok {
|
||||
currentWeight = weight
|
||||
}
|
||||
}
|
||||
nextWeight := currentWeight + delta
|
||||
if nextWeight < minWeight {
|
||||
nextWeight = minWeight
|
||||
}
|
||||
if nextWeight > maxWeight {
|
||||
nextWeight = maxWeight
|
||||
}
|
||||
if floatsEqual(currentWeight, nextWeight) {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateFilter := fmt.Sprintf("_row_id = %d AND %s = %s", rowID, common.PAGERANK_FLD, formatFilterFloat(currentWeight))
|
||||
if _, err := table.Update(updateFilter, map[string]interface{}{common.PAGERANK_FLD: nextWeight}); err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
verifyResult, err := table.Output([]string{common.PAGERANK_FLD}).Filter(filter).ToResult()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
verifyQR, ok := verifyResult.(*infinity.QueryResult)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected query result type: %T", verifyResult)
|
||||
}
|
||||
if currentValue, ok := firstQueryValue(verifyQR.Data, common.PAGERANK_FLD); ok {
|
||||
if actualWeight, ok := coerceToFloat(currentValue); ok && floatsEqual(actualWeight, nextWeight) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
lastErr = fmt.Errorf("pagerank update conflict")
|
||||
}
|
||||
return fmt.Errorf("failed to adjust chunk pagerank: %w", lastErr)
|
||||
}
|
||||
|
||||
// DeleteChunks deletes chunks from a dataset table
|
||||
// Table name format: {baseName}_{datasetID}
|
||||
// condition specifies which chunks to delete
|
||||
@@ -1976,6 +2076,77 @@ func escapeFilterValue(s string) string {
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
func firstQueryValue(data map[string][]interface{}, columns ...string) (interface{}, bool) {
|
||||
for _, column := range columns {
|
||||
values, ok := data[column]
|
||||
if ok && len(values) > 0 {
|
||||
return values[0], true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func firstQueryInt64(data map[string][]interface{}, columns ...string) (int64, bool) {
|
||||
value, ok := firstQueryValue(data, columns...)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return int64(v), true
|
||||
case int8:
|
||||
return int64(v), true
|
||||
case int16:
|
||||
return int64(v), true
|
||||
case int32:
|
||||
return int64(v), true
|
||||
case int64:
|
||||
return v, true
|
||||
case uint:
|
||||
if uint64(v) <= ^uint64(0)>>1 {
|
||||
return int64(v), true
|
||||
}
|
||||
case uint8:
|
||||
return int64(v), true
|
||||
case uint16:
|
||||
return int64(v), true
|
||||
case uint32:
|
||||
return int64(v), true
|
||||
case uint64:
|
||||
if v <= ^uint64(0)>>1 {
|
||||
return int64(v), true
|
||||
}
|
||||
case float32:
|
||||
return int64(v), true
|
||||
case float64:
|
||||
return int64(v), true
|
||||
case string:
|
||||
parsed, err := strconv.ParseInt(v, 10, 64)
|
||||
if err == nil {
|
||||
return parsed, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func pagerankAdjustLock(key string) *sync.Mutex {
|
||||
hash := fnv.New32a()
|
||||
_, _ = hash.Write([]byte(key))
|
||||
return &pagerankAdjustLocks[hash.Sum32()%pagerankAdjustLockCount]
|
||||
}
|
||||
|
||||
func formatFilterFloat(value float64) string {
|
||||
return strconv.FormatFloat(value, 'f', -1, 64)
|
||||
}
|
||||
|
||||
func floatsEqual(a, b float64) bool {
|
||||
diff := a - b
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
return diff < 1e-9
|
||||
}
|
||||
|
||||
// equivalentConditionToStr converts a condition map to an Infinity filter string
|
||||
func equivalentConditionToStr(condition map[string]interface{}) string {
|
||||
if len(condition) == 0 {
|
||||
|
||||
@@ -398,3 +398,60 @@ func (h *ChatSessionHandler) UpdateSession(c *gin.Context) {
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
func (h *ChatSessionHandler) DeleteSessionMessage(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
userID := user.ID
|
||||
chatID, sessionID, msgID := c.Param("chat_id"), c.Param("session_id"), c.Param("msg_id")
|
||||
|
||||
result, code, err := h.chatSessionService.DeleteSessionMessage(userID, chatID, sessionID, msgID)
|
||||
if err != nil {
|
||||
if code == common.CodeAuthenticationError {
|
||||
jsonResponse(c, code, false, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
func (h *ChatSessionHandler) UpdateMessageFeedback(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
userID := user.ID
|
||||
chatID, sessionID, msgID := c.Param("chat_id"), c.Param("session_id"), c.Param("msg_id")
|
||||
|
||||
req := map[string]interface{}{}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
jsonError(c, common.CodeArgumentError, "Request body cannot be empty")
|
||||
return
|
||||
}
|
||||
jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req) == 0 {
|
||||
jsonError(c, common.CodeArgumentError, "Request body cannot be empty")
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.chatSessionService.UpdateMessageFeedback(c.Request.Context(), userID, chatID, sessionID, msgID, req)
|
||||
if err != nil {
|
||||
if code == common.CodeAuthenticationError {
|
||||
jsonResponse(c, code, false, err.Error())
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, result, "success")
|
||||
}
|
||||
|
||||
@@ -72,3 +72,64 @@ func TestChatSessionHandlerUpdateSession_RejectsEmptyJSONObject(t *testing.T) {
|
||||
t.Fatalf("message=%v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatSessionHandlerUpdateMessageFeedback_RejectsEmptyBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPut, "/api/v1/chats/chat-1/sessions/session-1/messages/msg-1/feedback", nil)
|
||||
ctx.Params = gin.Params{
|
||||
{Key: "chat_id", Value: "chat-1"},
|
||||
{Key: "session_id", Value: "session-1"},
|
||||
{Key: "msg_id", Value: "msg-1"},
|
||||
}
|
||||
ctx.Set("user", &entity.User{ID: "user-1"})
|
||||
|
||||
handler := NewChatSessionHandler(service.NewChatSessionService(), nil)
|
||||
handler.UpdateMessageFeedback(ctx)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", recorder.Code)
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode response body: %v", err)
|
||||
}
|
||||
if got := body["code"]; got != float64(common.CodeArgumentError) {
|
||||
t.Fatalf("code=%v", got)
|
||||
}
|
||||
if got := body["message"]; got != "Request body cannot be empty" {
|
||||
t.Fatalf("message=%v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatSessionHandlerUpdateMessageFeedback_RejectsEmptyJSONObject(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPut, "/api/v1/chats/chat-1/sessions/session-1/messages/msg-1/feedback", strings.NewReader(`{}`))
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
ctx.Params = gin.Params{
|
||||
{Key: "chat_id", Value: "chat-1"},
|
||||
{Key: "session_id", Value: "session-1"},
|
||||
{Key: "msg_id", Value: "msg-1"},
|
||||
}
|
||||
ctx.Set("user", &entity.User{ID: "user-1"})
|
||||
|
||||
handler := NewChatSessionHandler(service.NewChatSessionService(), nil)
|
||||
handler.UpdateMessageFeedback(ctx)
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode response body: %v", err)
|
||||
}
|
||||
if got := body["code"]; got != float64(common.CodeArgumentError) {
|
||||
t.Fatalf("code=%v", got)
|
||||
}
|
||||
if got := body["message"]; got != "Request body cannot be empty" {
|
||||
t.Fatalf("message=%v", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,6 +300,8 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
chats.DELETE("/:chat_id/sessions", r.chatSessionHandler.DeleteSessions)
|
||||
chats.GET("/:chat_id/sessions/:session_id", r.chatSessionHandler.GetSession)
|
||||
chats.PATCH("/:chat_id/sessions/:session_id", r.chatSessionHandler.UpdateSession)
|
||||
chats.DELETE("/:chat_id/sessions/:session_id/messages/:msg_id", r.chatSessionHandler.DeleteSessionMessage)
|
||||
chats.PUT("/:chat_id/sessions/:session_id/messages/:msg_id/feedback", r.chatSessionHandler.UpdateMessageFeedback)
|
||||
}
|
||||
|
||||
chat := v1.Group("/chat")
|
||||
|
||||
@@ -22,8 +22,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/storage"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -61,14 +64,15 @@ type chatPipelineRunner interface {
|
||||
// (api/db/services/chunk_feedback_service.py) call site at
|
||||
// api/apps/restful_apis/chat_api.py — that handler records the thumb
|
||||
// vote against every chunk that produced the assistant message, in
|
||||
// addition to the session-level thumbup field. The Go stack does not
|
||||
// yet have a chunk-feedback DAO, so this interface is the seam where
|
||||
// one will plug in. Production uses *ChatSessionService itself via
|
||||
// applyChunkFeedback (which currently no-ops with a debug log) so the
|
||||
// handler can still update the session-level thumbup without crashing;
|
||||
// tests can swap in a fake by setting ChatSessionService.chunkFeedbackApplier.
|
||||
// addition to the session-level thumbup field. Production uses
|
||||
// *ChatSessionService itself via applyChunkFeedback; tests can swap
|
||||
// in a fake by setting ChatSessionService.chunkFeedbackApplier.
|
||||
type chunkFeedbackApplier interface {
|
||||
applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error)
|
||||
applyChunkFeedback(ctx context.Context, tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
type chunkPagerankAdjuster interface {
|
||||
AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error
|
||||
}
|
||||
|
||||
// ChatSessionService chat session (conversation) service.
|
||||
@@ -78,6 +82,7 @@ type ChatSessionService struct {
|
||||
userTenantDAO userTenantStore
|
||||
pipeline chatPipelineRunner
|
||||
chunkFeedbackApplier chunkFeedbackApplier
|
||||
docEngine engine.DocEngine
|
||||
}
|
||||
|
||||
// NewChatSessionService create chat session service
|
||||
@@ -86,9 +91,131 @@ func NewChatSessionService() *ChatSessionService {
|
||||
chatSessionDAO: dao.NewChatSessionDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
pipeline: NewChatPipelineService(),
|
||||
docEngine: engine.Get(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetChatSessionRequest set chat session request.
|
||||
type SetChatSessionRequest struct {
|
||||
SessionID string `json:"conversation_id,omitempty"`
|
||||
DialogID string `json:"dialog_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
IsNew bool `json:"is_new"`
|
||||
}
|
||||
|
||||
// SetChatSessionResponse set chat session response.
|
||||
type SetChatSessionResponse struct {
|
||||
*entity.ChatSession
|
||||
}
|
||||
|
||||
// SetChatSession creates or updates a chat session.
|
||||
// Kept as a compatibility entrypoint for older chat-session callers.
|
||||
func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRequest) (*SetChatSessionResponse, error) {
|
||||
name := req.Name
|
||||
if name == "" {
|
||||
name = "New chat session"
|
||||
}
|
||||
if len(name) > 255 {
|
||||
name = name[:255]
|
||||
}
|
||||
|
||||
if !req.IsNew {
|
||||
updates := map[string]interface{}{
|
||||
"name": name,
|
||||
"user_id": userID,
|
||||
}
|
||||
if err := s.chatSessionDAO.UpdateByID(req.SessionID, updates); err != nil {
|
||||
return nil, errors.New("Chat session not found")
|
||||
}
|
||||
session, err := s.chatSessionDAO.GetByID(req.SessionID)
|
||||
if err != nil {
|
||||
return nil, errors.New("Fail to update a chat session")
|
||||
}
|
||||
return &SetChatSessionResponse{ChatSession: session}, nil
|
||||
}
|
||||
|
||||
dialog, err := s.chatSessionDAO.GetDialogByID(req.DialogID)
|
||||
if err != nil {
|
||||
return nil, errors.New("Dialog not found")
|
||||
}
|
||||
|
||||
prologue := "Hi! I'm your assistant. What can I do for you?"
|
||||
if dialog.PromptConfig != nil {
|
||||
if p, ok := dialog.PromptConfig["prologue"].(string); ok && p != "" {
|
||||
prologue = p
|
||||
}
|
||||
}
|
||||
messagesJSON, _ := json.Marshal([]map[string]interface{}{
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": prologue,
|
||||
},
|
||||
})
|
||||
referenceJSON, _ := json.Marshal([]interface{}{})
|
||||
|
||||
session := &entity.ChatSession{
|
||||
ID: common.GenerateUUID(),
|
||||
DialogID: req.DialogID,
|
||||
Name: &name,
|
||||
Message: messagesJSON,
|
||||
UserID: &userID,
|
||||
Reference: referenceJSON,
|
||||
}
|
||||
if err := s.chatSessionDAO.Create(session); err != nil {
|
||||
return nil, errors.New("Fail to create a chat session")
|
||||
}
|
||||
|
||||
return &SetChatSessionResponse{ChatSession: session}, nil
|
||||
}
|
||||
|
||||
// RemoveChatSessions removes chat sessions.
|
||||
// Kept as a compatibility entrypoint for older chat-session callers.
|
||||
func (s *ChatSessionService) RemoveChatSessions(userID string, chatSessions []string) error {
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tenantIDSet := make(map[string]bool)
|
||||
for _, tid := range tenantIDs {
|
||||
tenantIDSet[tid] = true
|
||||
}
|
||||
tenantIDSet[userID] = true
|
||||
|
||||
for _, convID := range chatSessions {
|
||||
session, err := s.chatSessionDAO.GetByID(convID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Chat session not found: %s", convID)
|
||||
}
|
||||
|
||||
isOwner := false
|
||||
for tenantID := range tenantIDSet {
|
||||
exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, session.DialogID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
isOwner = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isOwner {
|
||||
return errors.New("Only owner of chat session authorized for this operation")
|
||||
}
|
||||
|
||||
if err := s.chatSessionDAO.DeleteByID(convID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListChatSessionsRequest list chat sessions request.
|
||||
type ListChatSessionsRequest struct {
|
||||
DialogID string `json:"dialog_id" binding:"required"`
|
||||
}
|
||||
|
||||
// ListChatSessionsResponse list chat sessions response
|
||||
type ListChatSessionsResponse struct {
|
||||
Sessions []*entity.ChatSession
|
||||
@@ -542,7 +669,10 @@ func (s *ChatSessionService) DeleteSessionMessage(userID, chatID, sessionID, msg
|
||||
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, msgID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
|
||||
func (s *ChatSessionService) UpdateMessageFeedback(ctx context.Context, userID, chatID, sessionID, msgID string, req map[string]interface{}) (*ChatSessionPayload, common.ErrorCode, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ownerTenantID := ""
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
@@ -654,7 +784,7 @@ func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, ms
|
||||
applier = s
|
||||
}
|
||||
if priorThumbBool, ok := priorThumb.(bool); ok && priorThumbBool != thumbup {
|
||||
result, _ := applier.applyChunkFeedback(ownerTenantID, feedbackReference, !priorThumbBool)
|
||||
result, _ := applier.applyChunkFeedback(ctx, ownerTenantID, feedbackReference, !priorThumbBool)
|
||||
if result != nil {
|
||||
common.Debug("Chunk feedback undo applied",
|
||||
zap.Any("success_count", result["success_count"]),
|
||||
@@ -662,7 +792,7 @@ func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, ms
|
||||
)
|
||||
}
|
||||
}
|
||||
result, _ := applier.applyChunkFeedback(ownerTenantID, feedbackReference, thumbup)
|
||||
result, _ := applier.applyChunkFeedback(ctx, ownerTenantID, feedbackReference, thumbup)
|
||||
if result != nil {
|
||||
common.Debug("Chunk feedback applied",
|
||||
zap.Any("success_count", result["success_count"]),
|
||||
@@ -674,32 +804,308 @@ func (s *ChatSessionService) UpdateMessageFeedback(userID, chatID, sessionID, ms
|
||||
return s.buildSessionPayload(session, nil, false), common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// applyChunkFeedback records a thumb vote against the chunks that
|
||||
// produced a session message. Mirrors Python's
|
||||
// ChunkFeedbackService.apply_feedback side effect (called from
|
||||
// api/apps/restful_apis/chat_api.py when a user toggles a thumb on
|
||||
// an assistant message). The Go persistence port for chunk feedback
|
||||
// is intentionally not yet landed — the call here is a documented
|
||||
// no-op so the session-level thumbup flow (the user-visible behavior)
|
||||
// keeps working while a future PR ports the Python DAO. The returned
|
||||
// counts let the caller log a "Chunk feedback applied: N succeeded,
|
||||
// M failed" line consistent with the Python equivalent, so log
|
||||
// scrapers don't see a regression in success/fail rates.
|
||||
//
|
||||
// Production callers should always go through the chunkFeedbackApplier
|
||||
// field; this method is the default implementation used when that
|
||||
// field is nil.
|
||||
func (s *ChatSessionService) applyChunkFeedback(tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) {
|
||||
common.Debug("chunk feedback persistence not yet ported; dropping vote",
|
||||
zap.String("tenant_id", tenantID),
|
||||
const (
|
||||
upvoteWeightIncrement = 1
|
||||
downvoteWeightDecrement = 1
|
||||
minPagerankWeight = 0.0
|
||||
maxPagerankWeight = 100.0
|
||||
chunkFeedbackTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type feedbackChunkRow struct {
|
||||
chunkID string
|
||||
kbID string
|
||||
chunk map[string]interface{}
|
||||
}
|
||||
|
||||
// applyChunkFeedback records a thumb vote against the chunks that produced a
|
||||
// session message. It mirrors Python's ChunkFeedbackService.apply_feedback:
|
||||
// feature-flagged by CHUNK_FEEDBACK_ENABLED, split by relevance unless
|
||||
// CHUNK_FEEDBACK_WEIGHTING=uniform, and clamped through the document engine.
|
||||
func (s *ChatSessionService) applyChunkFeedback(ctx context.Context, tenantID string, reference map[string]interface{}, isPositive bool) (map[string]interface{}, error) {
|
||||
if !chunkFeedbackEnabled() {
|
||||
common.Debug("Chunk feedback feature is disabled")
|
||||
return map[string]interface{}{
|
||||
"success_count": 0,
|
||||
"fail_count": 0,
|
||||
"chunk_ids": []string{},
|
||||
"disabled": true,
|
||||
}, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, chunkFeedbackTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
rows := feedbackRowsFromReference(reference)
|
||||
chunkIDs := make([]string, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
chunkIDs = append(chunkIDs, row.chunkID)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
common.Debug("No chunk IDs found in reference for feedback")
|
||||
return map[string]interface{}{
|
||||
"success_count": 0,
|
||||
"fail_count": 0,
|
||||
"chunk_ids": chunkIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
signedBudget := float64(upvoteWeightIncrement)
|
||||
if !isPositive {
|
||||
signedBudget = -float64(downvoteWeightDecrement)
|
||||
}
|
||||
deltas := allocateFeedbackDeltas(rows, signedBudget, chunkFeedbackWeighting())
|
||||
|
||||
successCount := 0
|
||||
failCount := 0
|
||||
for _, delta := range deltas {
|
||||
if delta.delta == 0 {
|
||||
continue
|
||||
}
|
||||
if s.updateChunkWeight(ctx, tenantID, delta.chunkID, delta.kbID, delta.delta) {
|
||||
successCount++
|
||||
} else {
|
||||
failCount++
|
||||
}
|
||||
}
|
||||
|
||||
common.Info("Applied chunk feedback",
|
||||
zap.Bool("is_positive", isPositive),
|
||||
zap.String("weighting", chunkFeedbackWeighting()),
|
||||
zap.Int("success_count", successCount),
|
||||
zap.Int("chunk_count", len(chunkIDs)),
|
||||
)
|
||||
return map[string]interface{}{
|
||||
"success_count": 0,
|
||||
"fail_count": 0,
|
||||
"success_count": successCount,
|
||||
"fail_count": failCount,
|
||||
"chunk_ids": chunkIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type feedbackDelta struct {
|
||||
chunkID string
|
||||
kbID string
|
||||
delta float64
|
||||
}
|
||||
|
||||
func chunkFeedbackEnabled() bool {
|
||||
return strings.ToLower(os.Getenv("CHUNK_FEEDBACK_ENABLED")) == "true"
|
||||
}
|
||||
|
||||
func chunkFeedbackWeighting() string {
|
||||
weighting := strings.ToLower(strings.TrimSpace(os.Getenv("CHUNK_FEEDBACK_WEIGHTING")))
|
||||
if weighting == "uniform" || weighting == "relevance" {
|
||||
return weighting
|
||||
}
|
||||
return "relevance"
|
||||
}
|
||||
|
||||
func feedbackRowsFromReference(reference map[string]interface{}) []feedbackChunkRow {
|
||||
if len(reference) == 0 {
|
||||
return nil
|
||||
}
|
||||
rawChunks, ok := reference["chunks"].([]interface{})
|
||||
if !ok {
|
||||
if chunks, ok := reference["chunks"].([]map[string]interface{}); ok {
|
||||
rows := make([]feedbackChunkRow, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
if row, ok := feedbackRowFromChunk(chunk); ok {
|
||||
rows = append(rows, row)
|
||||
}
|
||||
}
|
||||
return rows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
rows := make([]feedbackChunkRow, 0, len(rawChunks))
|
||||
for _, raw := range rawChunks {
|
||||
chunk, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if row, ok := feedbackRowFromChunk(chunk); ok {
|
||||
rows = append(rows, row)
|
||||
}
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func feedbackRowFromChunk(chunk map[string]interface{}) (feedbackChunkRow, bool) {
|
||||
chunkID := stringValue(chunk["id"])
|
||||
if chunkID == "" {
|
||||
chunkID = stringValue(chunk["chunk_id"])
|
||||
}
|
||||
kbID := stringValue(chunk["dataset_id"])
|
||||
if kbID == "" {
|
||||
kbID = stringValue(chunk["kb_id"])
|
||||
}
|
||||
if chunkID == "" || kbID == "" {
|
||||
return feedbackChunkRow{}, false
|
||||
}
|
||||
return feedbackChunkRow{chunkID: chunkID, kbID: kbID, chunk: chunk}, true
|
||||
}
|
||||
|
||||
func allocateFeedbackDeltas(rows []feedbackChunkRow, signedBudget float64, weighting string) []feedbackDelta {
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
if weighting == "uniform" {
|
||||
step := signedBudget / float64(len(rows))
|
||||
deltas := make([]feedbackDelta, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
deltas = append(deltas, feedbackDelta{chunkID: row.chunkID, kbID: row.kbID, delta: step})
|
||||
}
|
||||
return deltas
|
||||
}
|
||||
|
||||
magnitudes := make([]float64, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
signal := retrievalSignal(row.chunk)
|
||||
if signal <= 0 {
|
||||
signal = 1
|
||||
}
|
||||
magnitudes = append(magnitudes, signal)
|
||||
}
|
||||
parts := splitFloatBudget(magnitudes, math.Abs(signedBudget))
|
||||
sign := math.Copysign(1, signedBudget)
|
||||
deltas := make([]feedbackDelta, 0, len(rows))
|
||||
for i, row := range rows {
|
||||
deltas = append(deltas, feedbackDelta{chunkID: row.chunkID, kbID: row.kbID, delta: sign * parts[i]})
|
||||
}
|
||||
return deltas
|
||||
}
|
||||
|
||||
func retrievalSignal(chunk map[string]interface{}) float64 {
|
||||
best := 0.0
|
||||
for _, key := range []string{"similarity", "vector_similarity", "term_similarity"} {
|
||||
val, ok := floatValue(chunk[key])
|
||||
if ok && !math.IsInf(val, 0) && !math.IsNaN(val) && val > best {
|
||||
best = val
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func splitFloatBudget(magnitudes []float64, budget float64) []float64 {
|
||||
n := len(magnitudes)
|
||||
out := make([]float64, n)
|
||||
if n == 0 || budget == 0 {
|
||||
return out
|
||||
}
|
||||
total := 0.0
|
||||
for _, magnitude := range magnitudes {
|
||||
total += magnitude
|
||||
}
|
||||
if total <= 0 {
|
||||
base := budget / float64(n)
|
||||
for i := range out {
|
||||
out[i] = base
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
for i, magnitude := range magnitudes {
|
||||
out[i] = budget * magnitude / total
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) updateChunkWeight(ctx context.Context, tenantID, chunkID, kbID string, delta float64) bool {
|
||||
docEngine := s.docEngine
|
||||
if docEngine == nil {
|
||||
docEngine = engine.Get()
|
||||
}
|
||||
if docEngine == nil {
|
||||
common.Warn("Document engine is not initialized; chunk feedback skipped",
|
||||
zap.String("chunk_id", chunkID),
|
||||
zap.String("kb_id", kbID),
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
indexName := fmt.Sprintf("ragflow_%s", tenantID)
|
||||
if adjuster, ok := docEngine.(chunkPagerankAdjuster); ok {
|
||||
if err := adjuster.AdjustChunkPagerank(ctx, indexName, chunkID, kbID, delta, minPagerankWeight, maxPagerankWeight); err != nil {
|
||||
common.Warn("Failed atomic pagerank adjust for chunk",
|
||||
zap.String("chunk_id", chunkID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
rawChunk, err := docEngine.GetChunk(ctx, indexName, chunkID, []string{kbID})
|
||||
if err != nil {
|
||||
common.Warn("Chunk not found for feedback",
|
||||
zap.String("chunk_id", chunkID),
|
||||
zap.String("index", indexName),
|
||||
zap.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
chunk, ok := rawChunk.(map[string]interface{})
|
||||
if !ok {
|
||||
common.Warn("Unexpected chunk shape for feedback",
|
||||
zap.String("chunk_id", chunkID),
|
||||
zap.String("kb_id", kbID),
|
||||
zap.String("index", indexName),
|
||||
zap.String("chunk_type", fmt.Sprintf("%T", rawChunk)),
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
currentWeight, _ := floatValue(chunk[common.PAGERANK_FLD])
|
||||
nextWeight := currentWeight + delta
|
||||
if nextWeight < minPagerankWeight {
|
||||
nextWeight = minPagerankWeight
|
||||
}
|
||||
if nextWeight > maxPagerankWeight {
|
||||
nextWeight = maxPagerankWeight
|
||||
}
|
||||
|
||||
newValue := map[string]interface{}{common.PAGERANK_FLD: nextWeight}
|
||||
if nextWeight <= 0 && strings.ToLower(docEngine.GetType()) == string(engine.EngineElasticsearch) {
|
||||
newValue = map[string]interface{}{"remove": common.PAGERANK_FLD}
|
||||
}
|
||||
if err := docEngine.UpdateChunks(ctx, map[string]interface{}{"id": chunkID}, newValue, indexName, kbID); err != nil {
|
||||
common.Warn("Failed to update chunk pagerank",
|
||||
zap.String("chunk_id", chunkID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func floatValue(value interface{}) (float64, bool) {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return v, true
|
||||
case float32:
|
||||
return float64(v), true
|
||||
case int:
|
||||
return float64(v), true
|
||||
case int64:
|
||||
return float64(v), true
|
||||
case int32:
|
||||
return float64(v), true
|
||||
case json.Number:
|
||||
f, err := v.Float64()
|
||||
return f, err == nil
|
||||
case string:
|
||||
f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
|
||||
return f, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) ensureOwnedChat(userID, chatID string) (bool, error) {
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
@@ -851,6 +1257,180 @@ func isChatSessionNotFound(err error) bool {
|
||||
return errors.Is(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
// Completion performs chat completion with full RAG support via ChatPipelineService.
|
||||
// Kept as a compatibility entrypoint for callers that still use the pre-ChatCompletions API.
|
||||
func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, errors.New("messages cannot be empty")
|
||||
}
|
||||
lastRole, _ := messages[len(messages)-1]["role"].(string)
|
||||
if lastRole != "user" {
|
||||
return nil, errors.New("the last content of this conversation is not from user")
|
||||
}
|
||||
|
||||
session, err := s.chatSessionDAO.GetByID(conversationID)
|
||||
if err != nil {
|
||||
return nil, errors.New("Conversation not found")
|
||||
}
|
||||
|
||||
dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID)
|
||||
if err != nil {
|
||||
return nil, errors.New("Dialog not found")
|
||||
}
|
||||
|
||||
sessionMessages := s.buildSessionMessages(session, messages)
|
||||
reference := s.initializeReference(session)
|
||||
|
||||
isEmbedded := llmID != ""
|
||||
if llmID != "" {
|
||||
hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID)
|
||||
if err != nil || !hasKey {
|
||||
return nil, fmt.Errorf("Cannot use specified model %s", llmID)
|
||||
}
|
||||
dialog.LLMID = llmID
|
||||
if chatModelConfig != nil {
|
||||
dialog.LLMSetting = chatModelConfig
|
||||
}
|
||||
}
|
||||
|
||||
kwargs := chatModelConfig
|
||||
if kwargs == nil {
|
||||
kwargs = map[string]interface{}{}
|
||||
}
|
||||
resultChan, err := s.pipeline.AsyncChat(context.Background(), dialog, messages, false, kwargs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var answer strings.Builder
|
||||
var finalRef map[string]interface{}
|
||||
for result := range resultChan {
|
||||
if result.Answer != "" {
|
||||
answer.WriteString(result.Answer)
|
||||
}
|
||||
if result.Reference != nil {
|
||||
finalRef = result.Reference
|
||||
}
|
||||
}
|
||||
|
||||
ans := map[string]interface{}{
|
||||
"answer": answer.String(),
|
||||
"reference": finalRef,
|
||||
"final": true,
|
||||
}
|
||||
result := s.structureAnswerWithConv(session, ans, messageID, session.ID, reference)
|
||||
|
||||
if !isEmbedded {
|
||||
sessionMessages = append(sessionMessages, map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": answer.String(),
|
||||
"id": messageID,
|
||||
"created_at": float64(time.Now().Unix()),
|
||||
})
|
||||
s.updateSessionMessages(session, sessionMessages, reference)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CompletionStream performs streaming chat completion with full RAG support via ChatPipelineService.
|
||||
// Kept as a compatibility entrypoint for callers that still use the pre-ChatCompletions API.
|
||||
func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "messages cannot be empty", "data": {"answer": "**ERROR**: messages cannot be empty", "reference": []}}`)
|
||||
return errors.New("messages cannot be empty")
|
||||
}
|
||||
lastRole, _ := messages[len(messages)-1]["role"].(string)
|
||||
if lastRole != "user" {
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "the last content of this conversation is not from user", "data": {"answer": "**ERROR**: the last content of this conversation is not from user", "reference": []}}`)
|
||||
return errors.New("the last content of this conversation is not from user")
|
||||
}
|
||||
|
||||
session, err := s.chatSessionDAO.GetByID(conversationID)
|
||||
if err != nil {
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Conversation not found", "data": {"answer": "**ERROR**: Conversation not found", "reference": []}}`)
|
||||
return errors.New("Conversation not found")
|
||||
}
|
||||
|
||||
dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID)
|
||||
if err != nil {
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Dialog not found", "data": {"answer": "**ERROR**: Dialog not found", "reference": []}}`)
|
||||
return errors.New("Dialog not found")
|
||||
}
|
||||
|
||||
sessionMessages := s.buildSessionMessages(session, messages)
|
||||
reference := s.initializeReference(session)
|
||||
|
||||
isEmbedded := llmID != ""
|
||||
if llmID != "" {
|
||||
hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID)
|
||||
if err != nil || !hasKey {
|
||||
errMsg := fmt.Sprintf(`{"code": 500, "message": "Cannot use specified model %s", "data": {"answer": "**ERROR**: Cannot use specified model", "reference": []}}`, llmID)
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", errMsg)
|
||||
return fmt.Errorf("Cannot use specified model %s", llmID)
|
||||
}
|
||||
dialog.LLMID = llmID
|
||||
if chatModelConfig != nil {
|
||||
dialog.LLMSetting = chatModelConfig
|
||||
}
|
||||
}
|
||||
|
||||
kwargs := chatModelConfig
|
||||
if kwargs == nil {
|
||||
kwargs = map[string]interface{}{}
|
||||
}
|
||||
resultChan, err := s.pipeline.AsyncChat(ctx, dialog, messages, true, kwargs)
|
||||
if err != nil {
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", fmt.Sprintf(`{"code": 500, "message": "%s", "data": {"answer": "**ERROR**: %s", "reference": []}}`, err.Error(), err.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
var fullAnswer strings.Builder
|
||||
for result := range resultChan {
|
||||
if result.Reference != nil && len(reference) > 0 {
|
||||
reference[len(reference)-1] = result.Reference
|
||||
}
|
||||
if result.Final {
|
||||
if result.Answer != "" {
|
||||
fullAnswer.Reset()
|
||||
fullAnswer.WriteString(result.Answer)
|
||||
}
|
||||
} else if result.Answer != "" {
|
||||
fullAnswer.WriteString(result.Answer)
|
||||
}
|
||||
ans := s.structureAnswer(session, fullAnswer.String(), messageID, session.ID, reference)
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"code": 0,
|
||||
"message": "",
|
||||
"data": ans,
|
||||
})
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", string(data))
|
||||
}
|
||||
|
||||
finalData, _ := json.Marshal(map[string]interface{}{
|
||||
"code": 0,
|
||||
"message": "",
|
||||
"data": true,
|
||||
})
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", string(finalData))
|
||||
|
||||
if !isEmbedded {
|
||||
sessionMessages = append(sessionMessages, map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": fullAnswer.String(),
|
||||
"id": messageID,
|
||||
"created_at": float64(time.Now().Unix()),
|
||||
})
|
||||
s.updateSessionMessages(session, sessionMessages, reference)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChatCompletions handles chat completion matching Python's session_completion.
|
||||
// When stream=true, returns nil result and streams SSE via streamChan.
|
||||
// When stream=false, returns the structured answer map.
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
|
||||
"gorm.io/gorm"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
@@ -171,6 +173,88 @@ func makeResultChan(results ...AsyncChatResult) <-chan AsyncChatResult {
|
||||
return ch
|
||||
}
|
||||
|
||||
type feedbackContextKey struct{}
|
||||
|
||||
type fakeFeedbackDocEngine struct {
|
||||
engine.DocEngine
|
||||
adjustCalls []struct {
|
||||
ctx context.Context
|
||||
indexName string
|
||||
chunkID string
|
||||
kbID string
|
||||
delta float64
|
||||
minWeight float64
|
||||
maxWeight float64
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeFeedbackDocEngine) GetType() string { return "elasticsearch" }
|
||||
|
||||
func (f *fakeFeedbackDocEngine) AdjustChunkPagerank(ctx context.Context, indexName, chunkID, kbID string, delta, minWeight, maxWeight float64) error {
|
||||
f.adjustCalls = append(f.adjustCalls, struct {
|
||||
ctx context.Context
|
||||
indexName string
|
||||
chunkID string
|
||||
kbID string
|
||||
delta float64
|
||||
minWeight float64
|
||||
maxWeight float64
|
||||
}{ctx: ctx, indexName: indexName, chunkID: chunkID, kbID: kbID, delta: delta, minWeight: minWeight, maxWeight: maxWeight})
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeInfinityFeedbackDocEngine struct {
|
||||
fakeFeedbackDocEngine
|
||||
getChunkCalled bool
|
||||
}
|
||||
|
||||
func (f *fakeInfinityFeedbackDocEngine) GetType() string { return string(engine.EngineInfinity) }
|
||||
|
||||
func (f *fakeInfinityFeedbackDocEngine) GetChunk(ctx context.Context, indexName, chunkID string, datasetIDs []string) (interface{}, error) {
|
||||
f.getChunkCalled = true
|
||||
return nil, errors.New("fallback should not be used")
|
||||
}
|
||||
|
||||
type fakeFallbackFeedbackDocEngine struct {
|
||||
engine.DocEngine
|
||||
chunks map[string]map[string]interface{}
|
||||
updateCalls []struct {
|
||||
ctx context.Context
|
||||
condition map[string]interface{}
|
||||
newValue map[string]interface{}
|
||||
indexName string
|
||||
kbID string
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeFallbackFeedbackDocEngine) GetType() string { return "elasticsearch" }
|
||||
|
||||
func (f *fakeFallbackFeedbackDocEngine) GetChunk(ctx context.Context, indexName, chunkID string, datasetIDs []string) (interface{}, error) {
|
||||
chunk, ok := f.chunks[chunkID]
|
||||
if !ok {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return chunk, nil
|
||||
}
|
||||
|
||||
func (f *fakeFallbackFeedbackDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, indexName string, kbID string) error {
|
||||
f.updateCalls = append(f.updateCalls, struct {
|
||||
ctx context.Context
|
||||
condition map[string]interface{}
|
||||
newValue map[string]interface{}
|
||||
indexName string
|
||||
kbID string
|
||||
}{ctx: ctx, condition: condition, newValue: newValue, indexName: indexName, kbID: kbID})
|
||||
return nil
|
||||
}
|
||||
|
||||
func requireFloatClose(t *testing.T, got, want float64) {
|
||||
t.Helper()
|
||||
if math.Abs(got-want) > 1e-9 {
|
||||
t.Fatalf("got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// ListChatSessions tests
|
||||
// ===================================================================
|
||||
@@ -410,6 +494,591 @@ func TestUpdateSession_NotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// Session message delete / feedback tests
|
||||
// ===================================================================
|
||||
|
||||
func TestDeleteSessionMessage_RemovesMessagePairAndReference(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"Welcome!"},
|
||||
{"role":"user","content":"first","id":"msg-1"},
|
||||
{"role":"assistant","content":"answer 1","id":"msg-1"},
|
||||
{"role":"user","content":"second","id":"msg-2"},
|
||||
{"role":"assistant","content":"answer 2","id":"msg-2"}
|
||||
]`),
|
||||
Reference: json.RawMessage(`[
|
||||
{"chunks":[{"id":"chunk-1","kb_id":"kb-1"}]},
|
||||
{"chunks":[{"id":"chunk-2","kb_id":"kb-2"}]}
|
||||
]`),
|
||||
}
|
||||
store.dialogExists["tenant-1|chat-1"] = true
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
resp, code, err := svc.DeleteSessionMessage("user-1", "chat-1", "session-1", "msg-1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(resp.Messages) != 3 {
|
||||
t.Fatalf("messages=%#v", resp.Messages)
|
||||
}
|
||||
if resp.Messages[1]["id"] != "msg-2" || resp.Messages[2]["id"] != "msg-2" {
|
||||
t.Fatalf("remaining pair=%#v", resp.Messages)
|
||||
}
|
||||
if len(resp.Reference) != 1 {
|
||||
t.Fatalf("reference=%#v", resp.Reference)
|
||||
}
|
||||
ref, ok := resp.Reference[0].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("reference type=%T", resp.Reference[0])
|
||||
}
|
||||
chunks, ok := ref["chunks"].([]FormattedChunk)
|
||||
if !ok || len(chunks) != 1 || chunks[0].ID != "chunk-2" {
|
||||
t.Fatalf("remaining chunks=%#v", ref["chunks"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_AppliesChunkFeedbackWithResolvedTenantAndContext(t *testing.T) {
|
||||
t.Setenv("CHUNK_FEEDBACK_ENABLED", "true")
|
||||
t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform")
|
||||
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"Welcome!"},
|
||||
{"role":"user","content":"question","id":"msg-1"},
|
||||
{"role":"assistant","content":"answer","id":"msg-1","feedback":"old"}
|
||||
]`),
|
||||
Reference: json.RawMessage(`[
|
||||
{"chunks":[{"id":"chunk-1","kb_id":"kb-1","similarity":0.9}]}
|
||||
]`),
|
||||
}
|
||||
store.dialogExists["tenant-owner|chat-1"] = true
|
||||
docEngine := &fakeFeedbackDocEngine{}
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-owner"}},
|
||||
pipeline: &fakePipeline{},
|
||||
docEngine: docEngine,
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), feedbackContextKey{}, "request-context")
|
||||
|
||||
resp, code, err := svc.UpdateMessageFeedback(ctx, "user-1", "chat-1", "session-1", "msg-1", map[string]interface{}{
|
||||
"thumbup": true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(docEngine.adjustCalls) != 1 {
|
||||
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
|
||||
}
|
||||
call := docEngine.adjustCalls[0]
|
||||
if call.indexName != "ragflow_tenant-owner" || call.chunkID != "chunk-1" || call.kbID != "kb-1" {
|
||||
t.Fatalf("call=%#v", call)
|
||||
}
|
||||
if call.delta != 1 || call.minWeight != 0 || call.maxWeight != 100 {
|
||||
t.Fatalf("weights=%#v", call)
|
||||
}
|
||||
if got := call.ctx.Value(feedbackContextKey{}); got != "request-context" {
|
||||
t.Fatalf("context value=%v", got)
|
||||
}
|
||||
assistant := resp.Messages[2]
|
||||
if assistant["thumbup"] != true {
|
||||
t.Fatalf("assistant=%#v", assistant)
|
||||
}
|
||||
if _, ok := assistant["feedback"]; ok {
|
||||
t.Fatalf("positive feedback should remove text feedback: %#v", assistant)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessageFeedback_ToggleUsesResolvedTenantForUndoAndApply(t *testing.T) {
|
||||
t.Setenv("CHUNK_FEEDBACK_ENABLED", "true")
|
||||
t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform")
|
||||
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: json.RawMessage(`[
|
||||
{"role":"assistant","content":"Welcome!"},
|
||||
{"role":"user","content":"question","id":"msg-1"},
|
||||
{"role":"assistant","content":"answer","id":"msg-1","thumbup":true}
|
||||
]`),
|
||||
Reference: json.RawMessage(`[
|
||||
{"chunks":[{"chunk_id":"chunk-1","dataset_id":"kb-1"}]}
|
||||
]`),
|
||||
}
|
||||
store.dialogExists["tenant-owner|chat-1"] = true
|
||||
docEngine := &fakeFeedbackDocEngine{}
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-owner"}},
|
||||
pipeline: &fakePipeline{},
|
||||
docEngine: docEngine,
|
||||
}
|
||||
|
||||
resp, code, err := svc.UpdateMessageFeedback(context.Background(), "user-1", "chat-1", "session-1", "msg-1", map[string]interface{}{
|
||||
"thumbup": false,
|
||||
"feedback": "not useful",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("code=%v", code)
|
||||
}
|
||||
if len(docEngine.adjustCalls) != 2 {
|
||||
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
|
||||
}
|
||||
for _, call := range docEngine.adjustCalls {
|
||||
if call.indexName != "ragflow_tenant-owner" || call.delta != -1 {
|
||||
t.Fatalf("call=%#v", call)
|
||||
}
|
||||
}
|
||||
assistant := resp.Messages[2]
|
||||
if assistant["thumbup"] != false || assistant["feedback"] != "not useful" {
|
||||
t.Fatalf("assistant=%#v", assistant)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyChunkFeedback_DisabledDoesNotTouchEngine(t *testing.T) {
|
||||
t.Setenv("CHUNK_FEEDBACK_ENABLED", "false")
|
||||
docEngine := &fakeFeedbackDocEngine{}
|
||||
svc := &ChatSessionService{docEngine: docEngine}
|
||||
|
||||
result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{
|
||||
"chunks": []interface{}{map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1"}},
|
||||
}, true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result["disabled"] != true {
|
||||
t.Fatalf("result=%#v", result)
|
||||
}
|
||||
if len(docEngine.adjustCalls) != 0 {
|
||||
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyChunkFeedback_UniformSplitsOneVoteAcrossChunks(t *testing.T) {
|
||||
t.Setenv("CHUNK_FEEDBACK_ENABLED", "true")
|
||||
t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform")
|
||||
|
||||
docEngine := &fakeFeedbackDocEngine{}
|
||||
svc := &ChatSessionService{docEngine: docEngine}
|
||||
|
||||
result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{
|
||||
"chunks": []interface{}{
|
||||
map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1"},
|
||||
map[string]interface{}{"id": "chunk-2", "kb_id": "kb-1"},
|
||||
},
|
||||
}, true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result["success_count"] != 2 {
|
||||
t.Fatalf("result=%#v", result)
|
||||
}
|
||||
if len(docEngine.adjustCalls) != 2 {
|
||||
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
|
||||
}
|
||||
requireFloatClose(t, docEngine.adjustCalls[0].delta, 0.5)
|
||||
requireFloatClose(t, docEngine.adjustCalls[1].delta, 0.5)
|
||||
requireFloatClose(t, docEngine.adjustCalls[0].delta+docEngine.adjustCalls[1].delta, 1)
|
||||
}
|
||||
|
||||
func TestApplyChunkFeedback_RelevanceDistributesOneVoteBySignals(t *testing.T) {
|
||||
t.Setenv("CHUNK_FEEDBACK_ENABLED", "true")
|
||||
t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "relevance")
|
||||
|
||||
docEngine := &fakeFeedbackDocEngine{}
|
||||
svc := &ChatSessionService{docEngine: docEngine}
|
||||
|
||||
result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{
|
||||
"chunks": []interface{}{
|
||||
map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1", "similarity": 2.0},
|
||||
map[string]interface{}{"id": "chunk-2", "kb_id": "kb-1", "vector_similarity": 1.0},
|
||||
map[string]interface{}{"id": "chunk-3", "kb_id": "kb-1"},
|
||||
},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result["success_count"] != 3 {
|
||||
t.Fatalf("result=%#v", result)
|
||||
}
|
||||
if len(docEngine.adjustCalls) != 3 {
|
||||
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
|
||||
}
|
||||
requireFloatClose(t, docEngine.adjustCalls[0].delta, -0.5)
|
||||
requireFloatClose(t, docEngine.adjustCalls[1].delta, -0.25)
|
||||
requireFloatClose(t, docEngine.adjustCalls[2].delta, -0.25)
|
||||
requireFloatClose(t, docEngine.adjustCalls[0].delta+docEngine.adjustCalls[1].delta+docEngine.adjustCalls[2].delta, -1)
|
||||
}
|
||||
|
||||
func TestUpdateChunkWeight_InfinityUsesAtomicAdjuster(t *testing.T) {
|
||||
docEngine := &fakeInfinityFeedbackDocEngine{}
|
||||
svc := &ChatSessionService{docEngine: docEngine}
|
||||
|
||||
if ok := svc.updateChunkWeight(context.Background(), "tenant-1", "chunk-1", "kb-1", 0.25); !ok {
|
||||
t.Fatal("expected updateChunkWeight to succeed")
|
||||
}
|
||||
if docEngine.getChunkCalled {
|
||||
t.Fatal("expected Infinity adjuster path, got GetChunk fallback")
|
||||
}
|
||||
if len(docEngine.adjustCalls) != 1 {
|
||||
t.Fatalf("adjust calls=%d", len(docEngine.adjustCalls))
|
||||
}
|
||||
call := docEngine.adjustCalls[0]
|
||||
if call.indexName != "ragflow_tenant-1" || call.chunkID != "chunk-1" || call.kbID != "kb-1" {
|
||||
t.Fatalf("call=%#v", call)
|
||||
}
|
||||
requireFloatClose(t, call.delta, 0.25)
|
||||
}
|
||||
|
||||
func TestApplyChunkFeedback_FallbackClampsAndRemovesPagerank(t *testing.T) {
|
||||
t.Setenv("CHUNK_FEEDBACK_ENABLED", "true")
|
||||
t.Setenv("CHUNK_FEEDBACK_WEIGHTING", "uniform")
|
||||
|
||||
docEngine := &fakeFallbackFeedbackDocEngine{
|
||||
chunks: map[string]map[string]interface{}{
|
||||
"chunk-1": {common.PAGERANK_FLD: 0},
|
||||
},
|
||||
}
|
||||
svc := &ChatSessionService{docEngine: docEngine}
|
||||
|
||||
result, err := svc.applyChunkFeedback(context.Background(), "tenant-1", map[string]interface{}{
|
||||
"chunks": []interface{}{map[string]interface{}{"id": "chunk-1", "kb_id": "kb-1"}},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result["success_count"] != 1 {
|
||||
t.Fatalf("result=%#v", result)
|
||||
}
|
||||
if len(docEngine.updateCalls) != 1 {
|
||||
t.Fatalf("update calls=%d", len(docEngine.updateCalls))
|
||||
}
|
||||
call := docEngine.updateCalls[0]
|
||||
if call.indexName != "ragflow_tenant-1" || call.kbID != "kb-1" {
|
||||
t.Fatalf("call=%#v", call)
|
||||
}
|
||||
if call.condition["id"] != "chunk-1" || call.newValue["remove"] != common.PAGERANK_FLD {
|
||||
t.Fatalf("call=%#v", call)
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// Completion tests
|
||||
// ===================================================================
|
||||
|
||||
func TestCompletion_Success(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
session := &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.sessions["session-1"] = session
|
||||
store.dialogs["dialog-1"] = &entity.Chat{
|
||||
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
|
||||
LLMSetting: entity.JSONMap{},
|
||||
}
|
||||
|
||||
pipeline := &fakePipeline{
|
||||
resultChan: makeResultChan(
|
||||
AsyncChatResult{Answer: "Hello", Reference: map[string]interface{}{"chunks": []interface{}{}}},
|
||||
AsyncChatResult{Answer: " world", Final: true, Reference: map[string]interface{}{"chunks": []interface{}{}}},
|
||||
),
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: pipeline,
|
||||
}
|
||||
|
||||
result, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
ans, _ := result["answer"].(string)
|
||||
if ans != "Hello world" {
|
||||
t.Fatalf("expected answer 'Hello world', got %q", ans)
|
||||
}
|
||||
|
||||
got := parseMessages(store.sessions["session-1"].Message)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("stored messages=%#v", got)
|
||||
}
|
||||
if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" {
|
||||
t.Fatalf("stored prologue=%#v", got[0])
|
||||
}
|
||||
if got[1]["role"] != "user" || got[1]["content"] != "hi" {
|
||||
t.Fatalf("stored user message=%#v", got[1])
|
||||
}
|
||||
if got[2]["role"] != "assistant" || got[2]["content"] != "Hello world" || got[2]["id"] != "msg-1" {
|
||||
t.Fatalf("stored assistant message=%#v", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_EmptyMessages(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: &fakeSessionStore{},
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, err := svc.Completion("user-1", "session-1", nil, "", nil, "msg-1")
|
||||
if err == nil || err.Error() != "messages cannot be empty" {
|
||||
t.Fatalf("expected 'messages cannot be empty', got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_LastMessageNotFromUser(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: &fakeSessionStore{},
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "assistant", "content": "hello"},
|
||||
}, "", nil, "msg-1")
|
||||
if err == nil || !strings.Contains(err.Error(), "not from user") {
|
||||
t.Fatalf("expected 'not from user' error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_ConversationNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, err := svc.Completion("user-1", "missing", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1")
|
||||
if err == nil || err.Error() != "Conversation not found" {
|
||||
t.Fatalf("expected 'Conversation not found', got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_DialogNotFound(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{},
|
||||
}
|
||||
|
||||
_, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1")
|
||||
if err == nil || err.Error() != "Dialog not found" {
|
||||
t.Fatalf("expected 'Dialog not found', got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletion_PipelineError(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.dialogs["dialog-1"] = &entity.Chat{
|
||||
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
|
||||
LLMSetting: entity.JSONMap{},
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: &fakePipeline{err: errors.New("model unavailable")},
|
||||
}
|
||||
|
||||
_, err := svc.Completion("user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1")
|
||||
if err == nil || err.Error() != "model unavailable" {
|
||||
t.Fatalf("expected 'model unavailable' error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// CompletionStream tests
|
||||
// ===================================================================
|
||||
|
||||
func readStreamChan(ch <-chan string, n int) []string {
|
||||
var msgs []string
|
||||
for i := 0; i < n; i++ {
|
||||
select {
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
return msgs
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
default:
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
func TestCompletionStream_Success(t *testing.T) {
|
||||
store := newFakeSessionStore()
|
||||
store.sessions["session-1"] = &entity.ChatSession{
|
||||
ID: "session-1", DialogID: "dialog-1",
|
||||
Message: json.RawMessage(`{"messages":[{"role":"assistant","content":"Welcome!"}]}`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}
|
||||
store.dialogs["dialog-1"] = &entity.Chat{
|
||||
ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory",
|
||||
LLMSetting: entity.JSONMap{},
|
||||
}
|
||||
|
||||
pipeline := &fakePipeline{
|
||||
resultChan: makeResultChan(
|
||||
AsyncChatResult{Answer: "stream", Reference: map[string]interface{}{"chunks": []interface{}{}}},
|
||||
AsyncChatResult{Answer: " answer", Reference: map[string]interface{}{"chunks": []interface{}{}}},
|
||||
),
|
||||
}
|
||||
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: store,
|
||||
userTenantDAO: &fakeTenantStore{},
|
||||
pipeline: pipeline,
|
||||
}
|
||||
|
||||
streamChan := make(chan string, 10)
|
||||
err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{
|
||||
{"role": "user", "content": "hi"},
|
||||
}, "", nil, "msg-1", streamChan)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should receive data events and final signal
|
||||
msgs := readStreamChan(streamChan, 5)
|
||||
if len(msgs) < 3 {
|
||||
t.Fatalf("expected at least 3 stream messages, got %d: %v", len(msgs), msgs)
|
||||
}
|
||||
// Check final signal
|
||||
finalFound := false
|
||||
for _, m := range msgs {
|
||||
if strings.Contains(m, `"data":true`) {
|
||||
finalFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !finalFound {
|
||||
t.Fatal("expected final=true signal in stream")
|
||||
}
|
||||
|
||||
got := parseMessages(store.sessions["session-1"].Message)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("stored messages=%#v", got)
|
||||
}
|
||||
if got[0]["role"] != "assistant" || got[0]["content"] != "Welcome!" {
|
||||
t.Fatalf("stored prologue=%#v", got[0])
|
||||
}
|
||||
if got[1]["role"] != "user" || got[1]["content"] != "hi" {
|
||||
t.Fatalf("stored user message=%#v", got[1])
|
||||
}
|
||||
if got[2]["role"] != "assistant" || got[2]["content"] != "stream answer" || got[2]["id"] != "msg-1" {
|
||||
t.Fatalf("stored assistant message=%#v", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructureAnswerWithConv_ParsesArrayMessages(t *testing.T) {
|
||||
session := &entity.ChatSession{
|
||||
ID: "session-1",
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"Welcome!"}]`),
|
||||
}
|
||||
svc := &ChatSessionService{}
|
||||
|
||||
ans := svc.structureAnswerWithConv(session, map[string]interface{}{
|
||||
"answer": "Final answer",
|
||||
"reference": map[string]interface{}{"chunks": []interface{}{}},
|
||||
"final": true,
|
||||
}, "msg-1", "session-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
|
||||
if ans["id"] != "msg-1" || ans["session_id"] != "session-1" {
|
||||
t.Fatalf("ans=%#v", ans)
|
||||
}
|
||||
|
||||
got := parseMessages(session.Message)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("stored messages=%#v", got)
|
||||
}
|
||||
if got[0]["role"] != "assistant" || got[0]["content"] != "Final answer" || got[0]["id"] != "msg-1" {
|
||||
t.Fatalf("stored assistant message=%#v", got[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMessages_LegacyWrappedObject(t *testing.T) {
|
||||
got := parseMessages(json.RawMessage(`{"messages":[{"role":"assistant","content":"legacy"}]}`))
|
||||
if !reflect.DeepEqual(got, []map[string]interface{}{{"role": "assistant", "content": "legacy"}}) {
|
||||
t.Fatalf("messages=%#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) {
|
||||
svc := &ChatSessionService{}
|
||||
payload := svc.buildSessionPayload(&entity.ChatSession{
|
||||
ID: "session-1",
|
||||
DialogID: "chat-1",
|
||||
Message: nil,
|
||||
Reference: json.RawMessage(`null`),
|
||||
}, nil, false)
|
||||
|
||||
if payload.Messages == nil {
|
||||
t.Fatal("messages is nil")
|
||||
}
|
||||
if payload.Reference == nil {
|
||||
t.Fatal("reference is nil")
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(body), `"messages":[]`) {
|
||||
t.Fatalf("messages did not encode as empty array: %s", string(body))
|
||||
}
|
||||
if !strings.Contains(string(body), `"reference":[]`) {
|
||||
t.Fatalf("reference did not encode as empty array: %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCollections_ReturnEmptySlicesForMissingOrNull(t *testing.T) {
|
||||
messageInputs := []json.RawMessage{
|
||||
nil,
|
||||
|
||||
Reference in New Issue
Block a user