mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
## Summary - Add knowledge-base retrieval support to Go chat completions. ## What changed - Routes KB-backed chat sessions through the Go retrieval service instead of falling back to solo chat. - Resolves embedding and rerank models, validates accessible knowledge bases, and preserves tenant-aware retrieval. - Rejects mixed embedding models across selected knowledge bases before retrieval to avoid incompatible vector dimensions. - Threads the HTTP request context into streaming retrieval so cancelled requests can stop downstream retrieval work. - Applies metadata filters and message-level `doc_ids` before retrieval. - Expands parent/child chunks before building references and prompt context. - Injects retrieved knowledge through a copied dialog prompt config so the caller's original dialog is not mutated. - Honors configured empty responses when no chunks are found. - Names the metadata no-match sentinel and reuses it across retrieval/handler paths. - Adds a defensive content cast while appending streamed answers. - Adds focused unit coverage for retrieval, metadata filtering, authorization, multimodal messages, references, empty-response behavior, prompt immutability, and mixed embedding models. --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com> Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1532 lines
47 KiB
Go
1532 lines
47 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"ragflow/internal/common"
|
|
"ragflow/internal/engine"
|
|
"ragflow/internal/service/nlp"
|
|
"strings"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
"ragflow/internal/dao"
|
|
"ragflow/internal/entity"
|
|
modelModule "ragflow/internal/entity/models"
|
|
)
|
|
|
|
type chatKnowledgebaseStore interface {
|
|
Accessible(kbID, userID string) bool
|
|
GetByIDs(ids []string) ([]*entity.Knowledgebase, error)
|
|
}
|
|
|
|
type chatModelProvider interface {
|
|
GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error)
|
|
GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error)
|
|
GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error)
|
|
GetModelConfigFromProviderInstance(tenantID string, modelType entity.ModelType, modelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error)
|
|
GetTenantDefaultModelByType(tenantID string, modelType entity.ModelType) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error)
|
|
}
|
|
|
|
type chatMetadataService interface {
|
|
LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64
|
|
GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error)
|
|
}
|
|
|
|
type chatRetrievalService interface {
|
|
Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error)
|
|
}
|
|
|
|
// ChatSessionService chat session (conversation) service.
//
// It holds the DAOs and collaborating services used to create, list, delete
// and complete chat sessions, including knowledge-base retrieval.
type ChatSessionService struct {
	chatSessionDAO   *dao.ChatSessionDAO
	chatDAO          *dao.ChatDAO
	userTenantDAO    *dao.UserTenantDAO
	kbDAO            chatKnowledgebaseStore // knowledge-base lookup and per-user access checks
	docEngine        engine.DocEngine       // used after retrieval for child-chunk expansion; may be nil (then skipped)
	modelProviderSvc chatModelProvider
	metadataSvc      chatMetadataService  // may be nil: rank features are skipped, and a configured metadata filter then fails
	retrievalSvc     chatRetrievalService // required for knowledge-base chats; nil makes them fail
}
|
|
|
|
// NewChatSessionService create chat session service
|
|
func NewChatSessionService() *ChatSessionService {
|
|
docEngine := engine.Get()
|
|
return newChatSessionServiceWithRetrieval(docEngine, nlp.NewRetrievalService(docEngine, dao.NewDocumentDAO()))
|
|
}
|
|
|
|
// NewChatSessionServiceWithRetrieval creates a chat session service with a retrieval service.
|
|
func NewChatSessionServiceWithRetrieval(retrievalSvc chatRetrievalService) *ChatSessionService {
|
|
return newChatSessionServiceWithRetrieval(engine.Get(), retrievalSvc)
|
|
}
|
|
|
|
func newChatSessionServiceWithRetrieval(docEngine engine.DocEngine, retrievalSvc chatRetrievalService) *ChatSessionService {
|
|
return &ChatSessionService{
|
|
chatSessionDAO: dao.NewChatSessionDAO(),
|
|
chatDAO: dao.NewChatDAO(),
|
|
userTenantDAO: dao.NewUserTenantDAO(),
|
|
kbDAO: dao.NewKnowledgebaseDAO(),
|
|
docEngine: docEngine,
|
|
modelProviderSvc: NewModelProviderService(),
|
|
metadataSvc: NewMetadataService(),
|
|
retrievalSvc: retrievalSvc,
|
|
}
|
|
}
|
|
|
|
// SetChatSessionRequest set chat session request.
//
// When IsNew is true a session is created under DialogID; otherwise the
// session identified by SessionID (JSON "conversation_id") is updated.
type SetChatSessionRequest struct {
	SessionID string `json:"conversation_id,omitempty"`
	DialogID  string `json:"dialog_id,omitempty"`
	Name      string `json:"name,omitempty"` // empty means the default session name
	IsNew     bool   `json:"is_new"`
}
|
|
|
|
// SetChatSessionResponse set chat session response. It embeds the created or
// updated chat session, so the session's fields are promoted onto it.
type SetChatSessionResponse struct {
	*entity.ChatSession
}
|
|
|
|
// SetChatSession create or update a chat session
|
|
func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRequest) (*SetChatSessionResponse, error) {
|
|
name := req.Name
|
|
if name == "" {
|
|
name = "New chat session"
|
|
}
|
|
// Limit name length to 255 characters
|
|
if len(name) > 255 {
|
|
name = name[:255]
|
|
}
|
|
|
|
if !req.IsNew {
|
|
// Update existing chat session
|
|
updates := map[string]interface{}{
|
|
"name": name,
|
|
"user_id": userID,
|
|
}
|
|
|
|
if err := s.chatSessionDAO.UpdateByID(req.SessionID, updates); err != nil {
|
|
return nil, errors.New("Chat session not found")
|
|
}
|
|
|
|
// Get updated chat session
|
|
session, err := s.chatSessionDAO.GetByID(req.SessionID)
|
|
if err != nil {
|
|
return nil, errors.New("Fail to update a chat session")
|
|
}
|
|
|
|
return &SetChatSessionResponse{ChatSession: session}, nil
|
|
}
|
|
|
|
// Create new chat session
|
|
// Check if dialog exists
|
|
dialog, err := s.chatSessionDAO.GetDialogByID(req.DialogID)
|
|
if err != nil {
|
|
return nil, errors.New("Dialog not found")
|
|
}
|
|
|
|
// Generate UUID for new chat session
|
|
newID := common.GenerateUUID()
|
|
|
|
// Get prologue from dialog's prompt_config
|
|
prologue := "Hi! I'm your assistant. What can I do for you?"
|
|
if dialog.PromptConfig != nil {
|
|
if p, ok := dialog.PromptConfig["prologue"].(string); ok && p != "" {
|
|
prologue = p
|
|
}
|
|
}
|
|
|
|
// Create initial message - store as JSON object with messages array
|
|
messagesObj := map[string]interface{}{
|
|
"messages": []map[string]interface{}{
|
|
{
|
|
"role": "assistant",
|
|
"content": prologue,
|
|
},
|
|
},
|
|
}
|
|
messagesJSON, _ := json.Marshal(messagesObj)
|
|
|
|
// Create reference - store as JSON array
|
|
referenceJSON, _ := json.Marshal([]interface{}{})
|
|
|
|
// Create chat session
|
|
session := &entity.ChatSession{
|
|
ID: newID,
|
|
DialogID: req.DialogID,
|
|
Name: &name,
|
|
Message: messagesJSON,
|
|
UserID: &userID,
|
|
Reference: referenceJSON,
|
|
}
|
|
|
|
if err := s.chatSessionDAO.Create(session); err != nil {
|
|
return nil, errors.New("Fail to create a chat session")
|
|
}
|
|
|
|
return &SetChatSessionResponse{ChatSession: session}, nil
|
|
}
|
|
|
|
// RemoveChatSessionRequest remove chat sessions request. ChatSessions holds
// the IDs (JSON "conversation_ids") of the sessions to delete.
type RemoveChatSessionRequest struct {
	ChatSessions []string `json:"conversation_ids" binding:"required"`
}
|
|
|
|
// RemoveChatSessions removes chat sessions (hard delete)
|
|
func (s *ChatSessionService) RemoveChatSessions(userID string, chatSessions []string) error {
|
|
// Get user's tenants
|
|
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Build a set of user's tenant IDs for quick lookup
|
|
tenantIDSet := make(map[string]bool)
|
|
for _, tid := range tenantIDs {
|
|
tenantIDSet[tid] = true
|
|
}
|
|
tenantIDSet[userID] = true
|
|
|
|
// Check each chat session
|
|
for _, convID := range chatSessions {
|
|
// Get the chat session
|
|
session, err := s.chatSessionDAO.GetByID(convID)
|
|
if err != nil {
|
|
return fmt.Errorf("Chat session not found: %s", convID)
|
|
}
|
|
|
|
// Check if user is the owner by checking dialog ownership
|
|
isOwner := false
|
|
for tenantID := range tenantIDSet {
|
|
exists, err := s.chatSessionDAO.CheckDialogExists(tenantID, session.DialogID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if exists {
|
|
isOwner = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !isOwner {
|
|
return errors.New("Only owner of chat session authorized for this operation")
|
|
}
|
|
|
|
// Delete the chat session
|
|
if err := s.chatSessionDAO.DeleteByID(convID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListChatSessionsRequest list chat sessions request. DialogID identifies the
// dialog whose sessions are listed.
type ListChatSessionsRequest struct {
	DialogID string `json:"dialog_id" binding:"required"`
}
|
|
|
|
// ListChatSessionsResponse list chat sessions response.
//
// NOTE(review): Sessions has no json tag, so encoding/json would emit it as
// "Sessions". Confirm whether clients rely on that before adding a tag.
type ListChatSessionsResponse struct {
	Sessions []*entity.ChatSession
}
|
|
|
|
// ListChatSessions lists chat sessions for a dialog
|
|
func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*ListChatSessionsResponse, error) {
|
|
// Get user's tenants
|
|
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check if user is the owner of the dialog
|
|
isOwner := false
|
|
for _, tenantID := range tenantIDs {
|
|
var exists bool
|
|
exists, err = s.chatSessionDAO.CheckDialogExists(tenantID, chatID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if exists {
|
|
isOwner = true
|
|
break
|
|
}
|
|
}
|
|
|
|
// Also check with userID as tenant
|
|
if !isOwner {
|
|
var exists bool
|
|
exists, err = s.chatSessionDAO.CheckDialogExists(userID, chatID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
isOwner = exists
|
|
}
|
|
|
|
if !isOwner {
|
|
return nil, errors.New("only owner of dialog authorized for this operation")
|
|
}
|
|
|
|
// List chat sessions
|
|
sessions, err := s.chatSessionDAO.ListByChatID(chatID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &ListChatSessionsResponse{Sessions: sessions}, nil
|
|
}
|
|
|
|
// Completion performs chat completion with full RAG support
|
|
func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) {
|
|
// Validate the last message is from user
|
|
if len(messages) == 0 {
|
|
return nil, errors.New("messages cannot be empty")
|
|
}
|
|
lastRole, _ := messages[len(messages)-1]["role"].(string)
|
|
if lastRole != "user" {
|
|
return nil, errors.New("the last content of this conversation is not from user")
|
|
}
|
|
|
|
// Get conversation
|
|
session, err := s.chatSessionDAO.GetByID(conversationID)
|
|
if err != nil {
|
|
return nil, errors.New("Conversation not found")
|
|
}
|
|
|
|
// Get dialog
|
|
dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID)
|
|
if err != nil {
|
|
return nil, errors.New("Dialog not found")
|
|
}
|
|
|
|
// Deep copy messages to session
|
|
sessionMessages := s.buildSessionMessages(session, messages)
|
|
|
|
// Initialize reference if empty
|
|
reference := s.initializeReference(session)
|
|
|
|
// Check if custom LLM is specified and validate API key
|
|
isEmbedded := llmID != ""
|
|
if llmID != "" {
|
|
hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID)
|
|
if err != nil || !hasKey {
|
|
return nil, fmt.Errorf("Cannot use specified model %s", llmID)
|
|
}
|
|
dialog.LLMID = llmID
|
|
if chatModelConfig != nil {
|
|
dialog.LLMSetting = chatModelConfig
|
|
}
|
|
}
|
|
|
|
// Perform chat completion with RAG
|
|
result, err := s.asyncChat(userID, dialog, session, messages, chatModelConfig, messageID, reference, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Update conversation if not embedded
|
|
if !isEmbedded {
|
|
s.updateSessionMessages(session, sessionMessages, reference)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// CompletionStream performs streaming chat completion with full RAG support
|
|
func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
|
|
// Validate the last message is from user
|
|
if len(messages) == 0 {
|
|
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "messages cannot be empty", "data": {"answer": "**ERROR**: messages cannot be empty", "reference": []}}`)
|
|
return errors.New("messages cannot be empty")
|
|
}
|
|
lastRole, _ := messages[len(messages)-1]["role"].(string)
|
|
if lastRole != "user" {
|
|
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "the last content of this conversation is not from user", "data": {"answer": "**ERROR**: the last content of this conversation is not from user", "reference": []}}`)
|
|
return errors.New("the last content of this conversation is not from user")
|
|
}
|
|
|
|
// Get conversation
|
|
session, err := s.chatSessionDAO.GetByID(conversationID)
|
|
if err != nil {
|
|
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Conversation not found", "data": {"answer": "**ERROR**: Conversation not found", "reference": []}}`)
|
|
return errors.New("Conversation not found")
|
|
}
|
|
|
|
// Get dialog
|
|
dialog, err := s.chatSessionDAO.GetDialogByID(session.DialogID)
|
|
if err != nil {
|
|
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "Dialog not found", "data": {"answer": "**ERROR**: Dialog not found", "reference": []}}`)
|
|
return errors.New("Dialog not found")
|
|
}
|
|
|
|
// Deep copy messages to session
|
|
sessionMessages := s.buildSessionMessages(session, messages)
|
|
|
|
// Initialize reference if empty
|
|
reference := s.initializeReference(session)
|
|
|
|
// Check if custom LLM is specified and validate API key
|
|
isEmbedded := llmID != ""
|
|
if llmID != "" {
|
|
hasKey, err := s.checkTenantLLMAPIKey(dialog.TenantID, llmID)
|
|
if err != nil || !hasKey {
|
|
errMsg := fmt.Sprintf(`{"code": 500, "message": "Cannot use specified model %s", "data": {"answer": "**ERROR**: Cannot use specified model", "reference": []}}`, llmID)
|
|
streamChan <- fmt.Sprintf("data: %s\n\n", errMsg)
|
|
return fmt.Errorf("Cannot use specified model %s", llmID)
|
|
}
|
|
dialog.LLMID = llmID
|
|
if chatModelConfig != nil {
|
|
dialog.LLMSetting = chatModelConfig
|
|
}
|
|
}
|
|
|
|
// Perform streaming chat completion with RAG
|
|
resultChan, err := s.asyncChatStream(ctx, userID, dialog, session, messages, chatModelConfig, messageID, reference)
|
|
if err != nil {
|
|
streamChan <- fmt.Sprintf("data: %s\n\n", fmt.Sprintf(`{"code": 500, "message": "%s", "data": {"answer": "**ERROR**: %s", "reference": []}}`, err.Error(), err.Error()))
|
|
return err
|
|
}
|
|
|
|
// Stream results
|
|
for result := range resultChan {
|
|
data, _ := json.Marshal(map[string]interface{}{
|
|
"code": 0,
|
|
"message": "",
|
|
"data": result,
|
|
})
|
|
streamChan <- fmt.Sprintf("data: %s\n\n", string(data))
|
|
}
|
|
|
|
// Send final completion signal
|
|
finalData, _ := json.Marshal(map[string]interface{}{
|
|
"code": 0,
|
|
"message": "",
|
|
"data": true,
|
|
})
|
|
streamChan <- fmt.Sprintf("data: %s\n\n", string(finalData))
|
|
|
|
// Update conversation if not embedded
|
|
if !isEmbedded {
|
|
s.updateSessionMessages(session, sessionMessages, reference)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Helper methods
|
|
|
|
func (s *ChatSessionService) buildSessionMessages(session *entity.ChatSession, messages []map[string]interface{}) []map[string]interface{} {
|
|
// Deep copy messages to session
|
|
sessionMessages := make([]map[string]interface{}, len(messages))
|
|
for i, msg := range messages {
|
|
sessionMessages[i] = make(map[string]interface{})
|
|
for k, v := range msg {
|
|
sessionMessages[i][k] = v
|
|
}
|
|
}
|
|
return sessionMessages
|
|
}
|
|
|
|
func (s *ChatSessionService) initializeReference(session *entity.ChatSession) []interface{} {
|
|
var reference []interface{}
|
|
if len(session.Reference) > 0 {
|
|
json.Unmarshal(session.Reference, &reference)
|
|
}
|
|
// Filter out nil entries and append new reference
|
|
var filtered []interface{}
|
|
for _, r := range reference {
|
|
if r != nil {
|
|
filtered = append(filtered, r)
|
|
}
|
|
}
|
|
filtered = append(filtered, map[string]interface{}{
|
|
"chunks": []interface{}{},
|
|
"doc_aggs": []interface{}{},
|
|
})
|
|
return filtered
|
|
}
|
|
|
|
func (s *ChatSessionService) checkTenantLLMAPIKey(tenantID, modelName string) (bool, error) {
|
|
// Simplified check - in real implementation, check if tenant has API key for this model
|
|
return true, nil
|
|
}
|
|
|
|
func (s *ChatSessionService) structureAnswer(session *entity.ChatSession, answer string, messageID, conversationID string, reference []interface{}) map[string]interface{} {
|
|
return map[string]interface{}{
|
|
"answer": answer,
|
|
"reference": reference,
|
|
"conversation_id": conversationID,
|
|
"message_id": messageID,
|
|
}
|
|
}
|
|
|
|
func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, messages []map[string]interface{}, reference []interface{}) {
|
|
// Update session with new messages and reference
|
|
messagesJSON, _ := json.Marshal(map[string]interface{}{
|
|
"messages": messages,
|
|
})
|
|
referenceJSON, _ := json.Marshal(reference)
|
|
|
|
updates := map[string]interface{}{
|
|
"message": messagesJSON,
|
|
"reference": referenceJSON,
|
|
}
|
|
s.chatSessionDAO.UpdateByID(session.ID, updates)
|
|
}
|
|
|
|
// asyncChat performs chat with RAG support (non-streaming)
|
|
func (s *ChatSessionService) asyncChat(userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
|
// Check if we need RAG (knowledge base or tavily)
|
|
hasKB := len(dialog.KBIDs) > 0
|
|
hasTavily := false
|
|
if dialog.PromptConfig != nil {
|
|
if tavilyKey, ok := dialog.PromptConfig["tavily_api_key"].(string); ok && tavilyKey != "" {
|
|
hasTavily = true
|
|
}
|
|
}
|
|
|
|
if !hasKB && !hasTavily {
|
|
// Simple chat without RAG
|
|
return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream)
|
|
}
|
|
|
|
if hasKB {
|
|
return s.asyncChatWithRetrieval(context.Background(), userID, dialog, session, messages, config, messageID, reference, stream)
|
|
}
|
|
|
|
common.Warn("Tavily-backed chat retrieval is not implemented in Go; falling back to solo chat",
|
|
zap.String("dialog_id", dialog.ID))
|
|
return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream)
|
|
}
|
|
|
|
// asyncChatStream performs streaming chat with RAG support
|
|
func (s *ChatSessionService) asyncChatStream(ctx context.Context, userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
resultChan := make(chan map[string]interface{})
|
|
|
|
go func() {
|
|
defer close(resultChan)
|
|
|
|
// Check if we need RAG
|
|
hasKB := len(dialog.KBIDs) > 0
|
|
hasTavily := false
|
|
if dialog.PromptConfig != nil {
|
|
if tavilyKey, ok := dialog.PromptConfig["tavily_api_key"].(string); ok && tavilyKey != "" {
|
|
hasTavily = true
|
|
}
|
|
}
|
|
|
|
if !hasKB && !hasTavily {
|
|
// Simple chat without RAG
|
|
s.asyncChatSoloStream(dialog, session, messages, config, messageID, reference, resultChan)
|
|
return
|
|
}
|
|
|
|
if hasKB {
|
|
ragMessages, ragDialog, emptyResponse, err := s.messagesWithRetrievedKnowledge(ctx, userID, dialog, messages, reference)
|
|
if err != nil {
|
|
resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference)
|
|
return
|
|
}
|
|
if emptyResponse != nil {
|
|
resultChan <- s.structureAnswer(session, *emptyResponse, messageID, session.ID, reference)
|
|
return
|
|
}
|
|
s.asyncChatSoloStream(ragDialog, session, ragMessages, config, messageID, reference, resultChan)
|
|
return
|
|
}
|
|
|
|
common.Warn("Tavily-backed streaming chat retrieval is not implemented in Go; falling back to solo chat",
|
|
zap.String("dialog_id", dialog.ID))
|
|
s.asyncChatSoloStream(dialog, session, messages, config, messageID, reference, resultChan)
|
|
}()
|
|
|
|
return resultChan, nil
|
|
}
|
|
|
|
func (s *ChatSessionService) asyncChatWithRetrieval(ctx context.Context, userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
|
ragMessages, ragDialog, emptyResponse, err := s.messagesWithRetrievedKnowledge(ctx, userID, dialog, messages, reference)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if emptyResponse != nil {
|
|
var lastRef interface{}
|
|
if len(reference) > 0 {
|
|
lastRef = reference[len(reference)-1]
|
|
}
|
|
ans := map[string]interface{}{
|
|
"answer": *emptyResponse,
|
|
"reference": lastRef,
|
|
"final": true,
|
|
}
|
|
return s.structureAnswerWithConv(session, ans, messageID, session.ID, reference), nil
|
|
}
|
|
return s.asyncChatSolo(ragDialog, session, ragMessages, config, messageID, reference, stream)
|
|
}
|
|
|
|
// messagesWithRetrievedKnowledge retrieves knowledge-base chunks for the
// latest user question and decides what to send to the chat model.
//
// It returns, in order:
//   - the messages to use: a copy, a knowledge-injected version, or the
//     originals;
//   - the dialog to use: a copy carrying the knowledge in its prompt when
//     dialogWithInjectedKnowledgePrompt succeeds, otherwise the original;
//   - a non-nil canned answer when no knowledge was built and the dialog
//     configures an empty_response;
//   - an error for configuration, authorization, model or retrieval
//     failures.
//
// Side effect: setLatestReference is called with reference and the
// retrieved chunks/doc aggregations. It presumably fills the last reference
// entry in place (TODO confirm).
func (s *ChatSessionService) messagesWithRetrievedKnowledge(ctx context.Context, userID string, dialog *entity.Chat, messages []map[string]interface{}, reference []interface{}) ([]map[string]interface{}, *entity.Chat, *string, error) {
	kbIDs := stringSliceFromJSON(dialog.KBIDs)
	// No usable (non-empty string) knowledge-base IDs: nothing to retrieve.
	if len(kbIDs) == 0 {
		return messages, dialog, nil, nil
	}
	if s.retrievalSvc == nil {
		return nil, nil, nil, errors.New("retrieval service is not configured")
	}

	question := latestUserQuestion(messages)
	// No question text to search for: skip retrieval and chat as-is.
	if question == "" {
		return messages, dialog, nil, nil
	}

	kbs, err := s.kbDAO.GetByIDs(kbIDs)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to load knowledge bases: %w", err)
	}
	// Keep kbIDs order and fail on any missing or unauthorized knowledge base.
	kbs, err = s.knowledgebasesForDialog(userID, dialog, kbIDs, kbs)
	if err != nil {
		return nil, nil, nil, err
	}
	// All knowledge bases must share one embedding model; mixed models would
	// produce incompatible vectors.
	embeddingTenantID, embeddingModelName, err := validateKnowledgebaseEmbeddingModels(kbs, dialog.TenantID, resolveEmbeddingModelName)
	if err != nil {
		return nil, nil, nil, err
	}

	embeddingModel, err := s.modelProviderSvc.GetEmbeddingModel(embeddingTenantID, embeddingModelName)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to get embedding model: %w", err)
	}
	// rerankModel is nil when the dialog configures no rerank model.
	rerankModel, err := s.rerankModelForDialog(dialog)
	if err != nil {
		return nil, nil, nil, err
	}

	// NOTE(review): a zero TopK is passed through as Top=0; confirm how the
	// retrieval service treats it.
	top := int(dialog.TopK)
	pageSize := int(dialog.TopN)
	// Default to 6 chunks when the dialog does not set top_n.
	if pageSize <= 0 {
		pageSize = 6
	}
	similarityThreshold := dialog.SimilarityThreshold
	vectorSimilarityWeight := dialog.VectorSimilarityWeight
	// Question-derived rank features; left nil when no metadata service is set.
	var rankFeature map[string]float64
	if s.metadataSvc != nil {
		rankFeature = s.metadataSvc.LabelQuestion(question, kbs)
	}
	// Documents attached to the latest user message, further narrowed by the
	// dialog's metadata filter (which may yield NoMatchDocIDSentinel).
	baseDocIDs := docIDsFromMessages(messages)
	docIDs, err := s.filteredDocIDsForDialog(ctx, dialog, kbIDs, question, baseDocIDs)
	if err != nil {
		return nil, nil, nil, err
	}
	tenantIDs := tenantIDsFromKnowledgebases(kbs, dialog.TenantID)

	retrievalResult, err := s.retrievalSvc.Retrieval(ctx, &nlp.RetrievalRequest{
		Question:               question,
		TenantIDs:              tenantIDs,
		KbIDs:                  kbIDs,
		DocIDs:                 docIDs,
		Page:                   1,
		PageSize:               pageSize,
		Top:                    &top,
		SimilarityThreshold:    &similarityThreshold,
		VectorSimilarityWeight: &vectorSimilarityWeight,
		RankFeature:            &rankFeature,
		EmbeddingModel:         embeddingModel,
		RerankModel:            rerankModel,
	})
	if err != nil {
		return nil, nil, nil, fmt.Errorf("retrieval search failed: %w", err)
	}
	// Treat a nil result as "no chunks" instead of dereferencing it.
	if retrievalResult == nil {
		retrievalResult = &nlp.RetrievalResult{}
	}

	chunks := retrievalResult.Chunks
	// With a document engine available, chunks go through
	// RetrievalByChildren (presumably parent/child chunk expansion).
	if s.docEngine != nil {
		chunks = nlp.RetrievalByChildren(chunks, tenantIDs, s.docEngine, ctx)
	}
	setLatestReference(reference, chunks, retrievalResult.DocAggs)
	knowledge := buildKnowledgeBlock(chunks)
	// Nothing usable was retrieved. Return the configured empty response,
	// or nil to continue chatting without knowledge.
	if knowledge == "" {
		return messages, dialog, emptyResponseForDialog(dialog), nil
	}
	// Prefer injecting the knowledge into a copied dialog prompt so the
	// caller's dialog is not mutated.
	if ragDialog, ok := dialogWithInjectedKnowledgePrompt(dialog, knowledge); ok {
		return copyMessages(messages), ragDialog, nil, nil
	}

	// Fallback: inject the knowledge into the messages themselves.
	return injectKnowledge(messages, knowledge), dialog, nil, nil
}
|
|
|
|
// embeddingModelNameResolver resolves the embedding model name of kb, looked
// up on behalf of tenantID. It is passed into
// validateKnowledgebaseEmbeddingModels so the lookup can be substituted
// (presumably for tests).
type embeddingModelNameResolver func(tenantID string, kb *entity.Knowledgebase) (string, error)
|
|
|
|
func validateKnowledgebaseEmbeddingModels(kbs []*entity.Knowledgebase, fallbackTenantID string, resolve embeddingModelNameResolver) (string, string, error) {
|
|
if len(kbs) == 0 {
|
|
return fallbackTenantID, "", nil
|
|
}
|
|
|
|
expected := ""
|
|
expectedKBID := ""
|
|
expectedTenantID := fallbackTenantID
|
|
for _, kb := range kbs {
|
|
if kb == nil {
|
|
return "", "", errors.New("knowledge base is nil")
|
|
}
|
|
tenantID := kb.TenantID
|
|
if tenantID == "" {
|
|
tenantID = fallbackTenantID
|
|
}
|
|
modelName, err := resolve(tenantID, kb)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
modelName = strings.TrimSpace(modelName)
|
|
if modelName == "" {
|
|
return "", "", fmt.Errorf("knowledge base %s has no embedding model", kb.ID)
|
|
}
|
|
if expected == "" {
|
|
expected = modelName
|
|
expectedKBID = kb.ID
|
|
expectedTenantID = tenantID
|
|
continue
|
|
}
|
|
if modelName != expected {
|
|
return "", "", fmt.Errorf("knowledge bases must use the same embedding model: %s resolves to %q, expected %q from %s", kb.ID, modelName, expected, expectedKBID)
|
|
}
|
|
}
|
|
return expectedTenantID, expected, nil
|
|
}
|
|
|
|
func (s *ChatSessionService) rerankModelForDialog(dialog *entity.Chat) (*modelModule.RerankModel, error) {
|
|
compositeName, err := resolveRerankModelName(dialog)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if compositeName == "" {
|
|
return nil, nil
|
|
}
|
|
rerankModel, err := s.modelProviderSvc.GetRerankModel(dialog.TenantID, compositeName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get rerank model: %w", err)
|
|
}
|
|
return rerankModel, nil
|
|
}
|
|
|
|
func (s *ChatSessionService) filteredDocIDsForDialog(ctx context.Context, dialog *entity.Chat, kbIDs []string, question string, baseDocIDs []string) ([]string, error) {
|
|
if dialog.MetaDataFilter == nil || len(*dialog.MetaDataFilter) == 0 {
|
|
return baseDocIDs, nil
|
|
}
|
|
if s.metadataSvc == nil {
|
|
return nil, errors.New("metadata service is not configured")
|
|
}
|
|
|
|
filter := make(map[string]interface{}, len(*dialog.MetaDataFilter))
|
|
for key, value := range *dialog.MetaDataFilter {
|
|
filter[key] = value
|
|
}
|
|
|
|
metaData, err := s.metadataSvc.GetFlattedMetaByKBs(kbIDs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get flattened metadata for chat retrieval: %w", err)
|
|
}
|
|
|
|
var filterChatModel *modelModule.ChatModel
|
|
method, _ := filter["method"].(string)
|
|
if method == "auto" || method == "semi_auto" {
|
|
filterChatModel, err = s.modelProviderSvc.GetChatModel(dialog.TenantID, dialog.LLMID)
|
|
if err != nil {
|
|
common.Warn("Failed to get chat model for chat metadata filter", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
docIDs, empty := ApplyMetaDataFilter(ctx, filter, metaData, question, filterChatModel, baseDocIDs, kbIDs)
|
|
if empty {
|
|
return []string{NoMatchDocIDSentinel}, nil
|
|
}
|
|
return docIDs, nil
|
|
}
|
|
|
|
func resolveEmbeddingModelName(tenantID string, kb *entity.Knowledgebase) (string, error) {
|
|
if kb.TenantEmbdID != nil && *kb.TenantEmbdID > 0 {
|
|
_, compositeName, err := dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kb.TenantEmbdID)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err)
|
|
}
|
|
return compositeName, nil
|
|
}
|
|
if kb.EmbdID != "" {
|
|
if strings.Contains(kb.EmbdID, "@") {
|
|
return kb.EmbdID, nil
|
|
}
|
|
_, compositeName, err := dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantID, kb.EmbdID, entity.ModelTypeEmbedding)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get embedding model by embd_id: %w", err)
|
|
}
|
|
return compositeName, nil
|
|
}
|
|
|
|
tenantLLM, err := dao.NewTenantLLMDAO().GetByTenantAndType(tenantID, entity.ModelTypeEmbedding)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get tenant default embedding model: %w", err)
|
|
}
|
|
if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" {
|
|
return "", fmt.Errorf("no default embedding model found for tenant %s", tenantID)
|
|
}
|
|
return fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory), nil
|
|
}
|
|
|
|
func resolveRerankModelName(dialog *entity.Chat) (string, error) {
|
|
if dialog.TenantRerankID != nil && *dialog.TenantRerankID > 0 {
|
|
_, compositeName, err := dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *dialog.TenantRerankID)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err)
|
|
}
|
|
return compositeName, nil
|
|
}
|
|
if dialog.RerankID == "" {
|
|
return "", nil
|
|
}
|
|
if strings.Contains(dialog.RerankID, "@") {
|
|
return dialog.RerankID, nil
|
|
}
|
|
_, compositeName, err := dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), dialog.TenantID, dialog.RerankID, entity.ModelTypeRerank)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get rerank model by rerank_id: %w", err)
|
|
}
|
|
return compositeName, nil
|
|
}
|
|
|
|
func stringSliceFromJSON(values entity.JSONSlice) []string {
|
|
result := make([]string, 0, len(values))
|
|
seen := make(map[string]struct{}, len(values))
|
|
for _, value := range values {
|
|
str, ok := value.(string)
|
|
if !ok || str == "" {
|
|
continue
|
|
}
|
|
if _, exists := seen[str]; exists {
|
|
continue
|
|
}
|
|
seen[str] = struct{}{}
|
|
result = append(result, str)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func tenantIDsFromKnowledgebases(kbs []*entity.Knowledgebase, fallback string) []string {
|
|
seen := make(map[string]struct{}, len(kbs)+1)
|
|
var tenantIDs []string
|
|
for _, kb := range kbs {
|
|
if kb == nil || kb.TenantID == "" {
|
|
continue
|
|
}
|
|
if _, exists := seen[kb.TenantID]; exists {
|
|
continue
|
|
}
|
|
seen[kb.TenantID] = struct{}{}
|
|
tenantIDs = append(tenantIDs, kb.TenantID)
|
|
}
|
|
if len(tenantIDs) == 0 && fallback != "" {
|
|
tenantIDs = append(tenantIDs, fallback)
|
|
}
|
|
return tenantIDs
|
|
}
|
|
|
|
func (s *ChatSessionService) knowledgebasesForDialog(userID string, dialog *entity.Chat, kbIDs []string, loaded []*entity.Knowledgebase) ([]*entity.Knowledgebase, error) {
|
|
byID := make(map[string]*entity.Knowledgebase, len(loaded))
|
|
for _, kb := range loaded {
|
|
if kb != nil {
|
|
byID[kb.ID] = kb
|
|
}
|
|
}
|
|
|
|
kbs := make([]*entity.Knowledgebase, 0, len(kbIDs))
|
|
for _, kbID := range kbIDs {
|
|
kb := byID[kbID]
|
|
if kb == nil {
|
|
return nil, fmt.Errorf("knowledge base %s not found", kbID)
|
|
}
|
|
if userID != "" && !s.kbDAO.Accessible(kbID, userID) {
|
|
return nil, fmt.Errorf("knowledge base %s is not authorized for user", kbID)
|
|
}
|
|
if userID == "" && kb.TenantID != dialog.TenantID {
|
|
return nil, fmt.Errorf("knowledge base %s is not authorized for dialog tenant", kbID)
|
|
}
|
|
kbs = append(kbs, kb)
|
|
}
|
|
if len(kbs) == 0 {
|
|
return nil, errors.New("no valid knowledge bases found")
|
|
}
|
|
return kbs, nil
|
|
}
|
|
|
|
func docIDsFromMessages(messages []map[string]interface{}) []string {
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
if role, _ := messages[i]["role"].(string); role != "user" {
|
|
continue
|
|
}
|
|
return stringSliceFromValue(messages[i]["doc_ids"])
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func latestUserQuestion(messages []map[string]interface{}) string {
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
if role, _ := messages[i]["role"].(string); role != "user" {
|
|
continue
|
|
}
|
|
return textFromMessageContent(messages[i]["content"])
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func stringSliceFromValue(value interface{}) []string {
|
|
switch typed := value.(type) {
|
|
case nil:
|
|
return nil
|
|
case []string:
|
|
return uniqueNonEmptyStrings(typed)
|
|
case []interface{}:
|
|
values := make([]string, 0, len(typed))
|
|
for _, item := range typed {
|
|
if str, ok := item.(string); ok {
|
|
values = append(values, str)
|
|
}
|
|
}
|
|
return uniqueNonEmptyStrings(values)
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// uniqueNonEmptyStrings trims surrounding whitespace from each value and keeps
// the distinct non-empty results in first-seen order. It returns nil, not an
// empty slice, when nothing survives.
func uniqueNonEmptyStrings(values []string) []string {
	var out []string
	seen := make(map[string]bool, len(values))
	for _, raw := range values {
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" || seen[trimmed] {
			continue
		}
		seen[trimmed] = true
		out = append(out, trimmed)
	}
	return out
}
|
|
|
|
func emptyResponseForDialog(dialog *entity.Chat) *string {
|
|
if dialog.PromptConfig == nil {
|
|
return nil
|
|
}
|
|
emptyResponse, ok := dialog.PromptConfig["empty_response"].(string)
|
|
if !ok || emptyResponse == "" {
|
|
return nil
|
|
}
|
|
return &emptyResponse
|
|
}
|
|
|
|
func buildKnowledgeBlock(chunks []map[string]interface{}) string {
|
|
var builder strings.Builder
|
|
for i, chunk := range chunks {
|
|
content := chunkText(chunk)
|
|
if content == "" {
|
|
continue
|
|
}
|
|
if builder.Len() > 0 {
|
|
builder.WriteString("\n\n")
|
|
}
|
|
builder.WriteString(fmt.Sprintf("[%d]", i+1))
|
|
if docName, ok := chunk["docnm_kwd"].(string); ok && docName != "" {
|
|
builder.WriteString(" ")
|
|
builder.WriteString(docName)
|
|
}
|
|
builder.WriteString("\n")
|
|
builder.WriteString(content)
|
|
}
|
|
return builder.String()
|
|
}
|
|
|
|
// chunkText returns the whitespace-trimmed text of a retrieval chunk. The text
// comes from the first of "content_with_weight", "content_ltks" and "content"
// that holds a non-blank string. It returns "" when none does.
func chunkText(chunk map[string]interface{}) string {
	candidates := [...]string{"content_with_weight", "content_ltks", "content"}
	for _, field := range candidates {
		raw, ok := chunk[field].(string)
		if !ok {
			continue
		}
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|
|
|
|
func injectKnowledge(messages []map[string]interface{}, knowledge string) []map[string]interface{} {
|
|
copied := copyMessages(messages)
|
|
if len(copied) == 0 {
|
|
return copied
|
|
}
|
|
|
|
knowledgePrompt := fmt.Sprintf("Use the following knowledge snippets to answer the user's question. If the snippets do not contain the answer, say that the knowledge base does not provide enough information.\n\n%s", knowledge)
|
|
for i := len(copied) - 1; i >= 0; i-- {
|
|
if role, _ := copied[i]["role"].(string); role != "user" {
|
|
continue
|
|
}
|
|
copied[i]["content"] = injectKnowledgeIntoContent(copied[i]["content"], knowledgePrompt)
|
|
return copied
|
|
}
|
|
|
|
copied = append(copied, map[string]interface{}{
|
|
"role": "system",
|
|
"content": knowledgePrompt,
|
|
})
|
|
return copied
|
|
}
|
|
|
|
func injectKnowledgeIntoContent(content interface{}, knowledgePrompt string) interface{} {
|
|
switch typed := content.(type) {
|
|
case []interface{}:
|
|
injected := make([]interface{}, 0, len(typed)+1)
|
|
injected = append(injected, knowledgeTextBlock(knowledgePrompt))
|
|
injected = append(injected, typed...)
|
|
return injected
|
|
case []map[string]interface{}:
|
|
injected := make([]interface{}, 0, len(typed)+1)
|
|
injected = append(injected, knowledgeTextBlock(knowledgePrompt))
|
|
for _, block := range typed {
|
|
injected = append(injected, block)
|
|
}
|
|
return injected
|
|
default:
|
|
contentText := ""
|
|
if content != nil {
|
|
contentText = fmt.Sprint(content)
|
|
}
|
|
return strings.TrimSpace(knowledgePrompt + "\n\nQuestion:\n" + contentText)
|
|
}
|
|
}
|
|
|
|
// knowledgeTextBlock builds the multimodal "text" content block that carries
// the knowledge prompt, followed by a "Question:" lead-in.
func knowledgeTextBlock(knowledgePrompt string) map[string]interface{} {
	block := make(map[string]interface{}, 2)
	block["type"] = "text"
	block["text"] = knowledgePrompt + "\n\nQuestion:"
	return block
}
|
|
|
|
func textFromMessageContent(content interface{}) string {
|
|
switch typed := content.(type) {
|
|
case string:
|
|
return strings.TrimSpace(typed)
|
|
case []interface{}:
|
|
return strings.TrimSpace(strings.Join(textsFromContentBlocks(typed), "\n"))
|
|
case []map[string]interface{}:
|
|
blocks := make([]interface{}, 0, len(typed))
|
|
for _, block := range typed {
|
|
blocks = append(blocks, block)
|
|
}
|
|
return strings.TrimSpace(strings.Join(textsFromContentBlocks(blocks), "\n"))
|
|
default:
|
|
if content == nil {
|
|
return ""
|
|
}
|
|
return strings.TrimSpace(fmt.Sprint(content))
|
|
}
|
|
}
|
|
|
|
// textsFromContentBlocks collects the trimmed, non-blank text of each content
// block. A block counts as text if it is a plain string, or a map with a string
// "text" entry. Other blocks (for example images) are skipped. The result is
// never nil.
func textsFromContentBlocks(blocks []interface{}) []string {
	out := make([]string, 0, len(blocks))
	for _, block := range blocks {
		var candidate string
		switch v := block.(type) {
		case string:
			candidate = v
		case map[string]interface{}:
			candidate, _ = v["text"].(string)
		default:
			continue
		}
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			out = append(out, trimmed)
		}
	}
	return out
}
|
|
|
|
func dialogWithInjectedKnowledgePrompt(dialog *entity.Chat, knowledge string) (*entity.Chat, bool) {
|
|
if dialog.PromptConfig == nil {
|
|
return dialog, false
|
|
}
|
|
systemPrompt, ok := dialog.PromptConfig["system"].(string)
|
|
if !ok || !strings.Contains(systemPrompt, "{knowledge}") {
|
|
return dialog, false
|
|
}
|
|
|
|
copied := cloneJSONMap(dialog.PromptConfig)
|
|
copied["system"] = strings.ReplaceAll(systemPrompt, "{knowledge}", knowledge)
|
|
dialogCopy := *dialog
|
|
dialogCopy.PromptConfig = copied
|
|
return &dialogCopy, true
|
|
}
|
|
|
|
func cloneJSONMap(values entity.JSONMap) entity.JSONMap {
|
|
copied := make(entity.JSONMap, len(values))
|
|
for key, value := range values {
|
|
copied[key] = value
|
|
}
|
|
return copied
|
|
}
|
|
|
|
// copyMessages returns a new slice of new maps holding the same key/value pairs
// as messages (a one-level-deep copy; values themselves are shared). Nil maps
// become empty maps, and a nil or empty input yields an empty, non-nil slice.
func copyMessages(messages []map[string]interface{}) []map[string]interface{} {
	out := make([]map[string]interface{}, 0, len(messages))
	for _, src := range messages {
		dst := make(map[string]interface{}, len(src))
		for k, v := range src {
			dst[k] = v
		}
		out = append(out, dst)
	}
	return out
}
|
|
|
|
func setLatestReference(reference []interface{}, chunks []map[string]interface{}, docAggs []map[string]interface{}) {
|
|
ref := map[string]interface{}{
|
|
"chunks": chunksForReference(chunks),
|
|
"doc_aggs": mapsForReference(docAggs),
|
|
}
|
|
if len(reference) == 0 {
|
|
return
|
|
}
|
|
reference[len(reference)-1] = ref
|
|
}
|
|
|
|
// chunksForReference copies each chunk map without its "vector" field, so
// embeddings are not included in references. The input maps are left untouched,
// and the result is never nil.
func chunksForReference(chunks []map[string]interface{}) []interface{} {
	out := make([]interface{}, len(chunks))
	for i, chunk := range chunks {
		stripped := make(map[string]interface{}, len(chunk))
		for key, val := range chunk {
			if key != "vector" {
				stripped[key] = val
			}
		}
		out[i] = stripped
	}
	return out
}
|
|
|
|
// mapsForReference widens a slice of maps to []interface{}. The maps themselves
// are shared with the input, not copied. The result is never nil.
func mapsForReference(values []map[string]interface{}) []interface{} {
	out := make([]interface{}, len(values))
	for i := range values {
		out[i] = values[i]
	}
	return out
}
|
|
|
|
// asyncChatSolo performs simple chat without RAG (non-streaming)
|
|
func (s *ChatSessionService) asyncChatSolo(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
|
common.Info("asyncChatSolo started",
|
|
zap.String("tenant_id", dialog.TenantID),
|
|
zap.String("llm_id", dialog.LLMID),
|
|
zap.String("dialog_id", dialog.ID),
|
|
zap.Int("message_count", len(messages)))
|
|
|
|
// Get system prompt
|
|
systemPrompt := s.buildSystemPrompt(dialog)
|
|
|
|
// Process messages - handle attachments and image files
|
|
processedMessages := s.processMessages(messages, dialog)
|
|
|
|
var (
|
|
driver modelModule.ModelDriver
|
|
modelName string
|
|
apiConfig *modelModule.APIConfig
|
|
err error
|
|
)
|
|
if dialog.LLMID != "" {
|
|
driver, modelName, apiConfig, _, err = s.modelProviderSvc.GetModelConfigFromProviderInstance(
|
|
dialog.TenantID, entity.ModelTypeChat, dialog.LLMID,
|
|
)
|
|
} else {
|
|
driver, modelName, apiConfig, _, err = s.modelProviderSvc.GetTenantDefaultModelByType(
|
|
dialog.TenantID, entity.ModelTypeChat,
|
|
)
|
|
}
|
|
if err != nil {
|
|
common.Error("asyncChatSolo failed to get chat model", err)
|
|
return nil, err
|
|
}
|
|
chatModel := modelModule.NewChatModel(driver, &modelName, apiConfig)
|
|
|
|
// Convert messages to Message format
|
|
var msgs []modelModule.Message
|
|
if systemPrompt != "" {
|
|
msgs = append(msgs, modelModule.Message{Role: "system", Content: systemPrompt})
|
|
}
|
|
for _, msg := range processedMessages {
|
|
role, _ := msg["role"].(string)
|
|
if role == "" || role == "system" {
|
|
continue
|
|
}
|
|
|
|
if msg["content"] != nil {
|
|
msgs = append(msgs, modelModule.Message{Role: role, Content: msg["content"]})
|
|
}
|
|
}
|
|
|
|
// Get ChatConfig directly from dialog and config
|
|
chatConfig := s.buildChatConfig(dialog, config)
|
|
|
|
// Perform chat
|
|
response, err := chatModel.ModelDriver.ChatWithMessages(*chatModel.ModelName, msgs, chatModel.APIConfig, chatConfig)
|
|
if err != nil {
|
|
common.Error("asyncChatSolo chat failed", err)
|
|
return nil, err
|
|
}
|
|
|
|
common.Info("asyncChatSolo completed",
|
|
zap.String("tenant_id", dialog.TenantID),
|
|
zap.String("llm_id", dialog.LLMID),
|
|
zap.Int("response_length", len(*response.Answer)))
|
|
|
|
// Structure the answer
|
|
ans := map[string]interface{}{
|
|
"answer": *response.Answer,
|
|
"reference": reference[len(reference)-1],
|
|
"final": true,
|
|
}
|
|
|
|
return s.structureAnswerWithConv(session, ans, messageID, session.ID, reference), nil
|
|
}
|
|
|
|
// asyncChatSoloStream performs simple streaming chat without RAG
|
|
func (s *ChatSessionService) asyncChatSoloStream(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, resultChan chan<- map[string]interface{}) {
|
|
common.Info("asyncChatSoloStream started",
|
|
zap.String("tenant_id", dialog.TenantID),
|
|
zap.String("llm_id", dialog.LLMID),
|
|
zap.String("dialog_id", dialog.ID),
|
|
zap.Int("message_count", len(messages)))
|
|
|
|
// Get system prompt
|
|
systemPrompt := s.buildSystemPrompt(dialog)
|
|
|
|
// Process messages
|
|
processedMessages := s.processMessages(messages, dialog)
|
|
|
|
var (
|
|
driver modelModule.ModelDriver
|
|
modelName string
|
|
apiConfig *modelModule.APIConfig
|
|
err error
|
|
)
|
|
if dialog.LLMID != "" {
|
|
driver, modelName, apiConfig, _, err = s.modelProviderSvc.GetModelConfigFromProviderInstance(
|
|
dialog.TenantID, entity.ModelTypeChat, dialog.LLMID,
|
|
)
|
|
} else {
|
|
driver, modelName, apiConfig, _, err = s.modelProviderSvc.GetTenantDefaultModelByType(
|
|
dialog.TenantID, entity.ModelTypeChat,
|
|
)
|
|
}
|
|
if err != nil {
|
|
common.Error("asyncChatSoloStream failed to get chat model", err)
|
|
resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference)
|
|
return
|
|
}
|
|
chatModel := modelModule.NewChatModel(driver, &modelName, apiConfig)
|
|
|
|
// Convert messages to []modelModule.Message for ChatStreamlyWithSender
|
|
var chatMessages []modelModule.Message
|
|
if systemPrompt != "" {
|
|
chatMessages = append(chatMessages, modelModule.Message{
|
|
Role: "system",
|
|
Content: systemPrompt,
|
|
})
|
|
}
|
|
for _, msg := range processedMessages {
|
|
role, _ := msg["role"].(string)
|
|
content := msg["content"]
|
|
if role != "" && content != nil && role != "system" {
|
|
chatMessages = append(chatMessages, modelModule.Message{
|
|
Role: role,
|
|
Content: content,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Get ChatConfig directly from dialog and config
|
|
chatConfig := s.buildChatConfig(dialog, config)
|
|
|
|
// Perform streaming chat using ChatStreamlyWithSender
|
|
fullAnswer := ""
|
|
err = chatModel.ModelDriver.ChatStreamlyWithSender(*chatModel.ModelName, chatMessages, chatModel.APIConfig, chatConfig, func(answer *string, reason *string) error {
|
|
if reason != nil && *reason != "" {
|
|
fullAnswer += *reason
|
|
ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference)
|
|
resultChan <- ans
|
|
}
|
|
if answer != nil && *answer != "" {
|
|
fullAnswer += *answer
|
|
fullAnswer = s.removeReasoningContent(fullAnswer)
|
|
ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference)
|
|
resultChan <- ans
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference)
|
|
return
|
|
}
|
|
|
|
common.Info("asyncChatSoloStream completed",
|
|
zap.String("tenant_id", dialog.TenantID),
|
|
zap.String("llm_id", dialog.LLMID),
|
|
zap.Int("response_length", len(fullAnswer)))
|
|
}
|
|
|
|
// buildSystemPrompt builds the system prompt from dialog configuration
|
|
func (s *ChatSessionService) buildSystemPrompt(dialog *entity.Chat) string {
|
|
if dialog.PromptConfig == nil {
|
|
return ""
|
|
}
|
|
|
|
system, _ := dialog.PromptConfig["system"].(string)
|
|
return system
|
|
}
|
|
|
|
// processMessages processes messages and handles attachments
|
|
func (s *ChatSessionService) processMessages(messages []map[string]interface{}, dialog *entity.Chat) []map[string]interface{} {
|
|
// Process each message
|
|
processed := make([]map[string]interface{}, len(messages))
|
|
for i, msg := range messages {
|
|
processed[i] = make(map[string]interface{})
|
|
for k, v := range msg {
|
|
processed[i][k] = v
|
|
}
|
|
|
|
// Clean content - remove file markers
|
|
if content, ok := msg["content"].(string); ok {
|
|
content = s.cleanContent(content)
|
|
processed[i]["content"] = content
|
|
}
|
|
}
|
|
|
|
return processed
|
|
}
|
|
|
|
// cleanContent removes file markers from content
|
|
func (s *ChatSessionService) cleanContent(content string) string {
|
|
// Remove ##N$$ markers
|
|
// This is a simplified version - full implementation would use regex
|
|
return content
|
|
}
|
|
|
|
// removeReasoningContent removes reasoning/thinking content from answer
|
|
func (s *ChatSessionService) removeReasoningContent(answer string) string {
|
|
// Remove </think> tags
|
|
if strings.HasSuffix(answer, "</think>") {
|
|
answer = answer[:len(answer)-len("</think>")]
|
|
}
|
|
return answer
|
|
}
|
|
|
|
// structureAnswerWithConv structures the answer with conversation update (like Python's structure_answer).
//
// It mutates ans in place and returns it:
//   - ans["reference"] is made a map. A missing, nil or non-map value is
//     replaced with {"chunks": [], "doc_aggs": []}.
//   - That reference map's "chunks" entry is replaced with chunksFormat's
//     result.
//   - ans["id"] is set to messageID and ans["session_id"] to conversationID.
//
// If reference is non-empty, its last element is overwritten in place with the
// same reference map.
//
// NOTE(review): the assistant-message update below works on a freshly decoded
// copy of session.Message, and the json.Unmarshal error is ignored. The
// resulting messages slice is never written back to session.Message or
// returned, so this part has no effect visible from this function. Confirm
// whether the caller handles persistence or a write-back is missing here.
func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession, ans map[string]interface{}, messageID, conversationID string, reference []interface{}) map[string]interface{} {
	// Extract reference from answer
	ref, _ := ans["reference"].(map[string]interface{})
	if ref == nil {
		// Install an empty reference so the returned answer always carries a
		// reference map with both keys.
		ref = map[string]interface{}{
			"chunks":   []interface{}{},
			"doc_aggs": []interface{}{},
		}
		ans["reference"] = ref
	}

	// Format chunks
	chunkList := s.chunksFormat(ref)
	ref["chunks"] = chunkList

	// Add message ID and session ID
	ans["id"] = messageID
	ans["session_id"] = conversationID

	// Update session message.
	// start_to_think / end_to_think markers take precedence over the answer text.
	content, _ := ans["answer"].(string)
	if ans["start_to_think"] != nil {
		content = "<think>"
	} else if ans["end_to_think"] != nil {
		content = "</think>"
	}

	// Parse existing messages (decode error ignored; see NOTE(review) above)
	var messagesObj map[string]interface{}
	if len(session.Message) > 0 {
		json.Unmarshal(session.Message, &messagesObj)
	}
	messages, _ := messagesObj["messages"].([]interface{})

	// Update or append assistant message
	if len(messages) == 0 || s.getLastRole(messages) != "assistant" {
		// No trailing assistant turn yet: start a new one.
		messages = append(messages, map[string]interface{}{
			"role":       "assistant",
			"content":    content,
			"created_at": float64(time.Now().Unix()),
			"id":         messageID,
		})
	} else {
		// Extend the trailing assistant turn. A final answer replaces its
		// content; otherwise content is appended to the existing text.
		lastIdx := len(messages) - 1
		lastMsg, _ := messages[lastIdx].(map[string]interface{})
		if lastMsg != nil {
			if ans["final"] == true && ans["answer"] != nil {
				lastMsg["content"] = ans["answer"]
			} else {
				existing, _ := lastMsg["content"].(string)
				lastMsg["content"] = existing + content
			}
			lastMsg["created_at"] = float64(time.Now().Unix())
			lastMsg["id"] = messageID
			messages[lastIdx] = lastMsg
		}
	}

	// Update reference (in place; the caller's slice observes the change)
	if len(reference) > 0 {
		reference[len(reference)-1] = ref
	}

	return ans
}
|
|
|
|
// getLastRole gets the role of the last message
|
|
func (s *ChatSessionService) getLastRole(messages []interface{}) string {
|
|
if len(messages) == 0 {
|
|
return ""
|
|
}
|
|
lastMsg, _ := messages[len(messages)-1].(map[string]interface{})
|
|
if lastMsg != nil {
|
|
role, _ := lastMsg["role"].(string)
|
|
return role
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// chunksFormat formats chunks for reference (simplified version)
|
|
func (s *ChatSessionService) chunksFormat(reference map[string]interface{}) []interface{} {
|
|
chunks, _ := reference["chunks"].([]interface{})
|
|
if chunks == nil {
|
|
return []interface{}{}
|
|
}
|
|
|
|
// Format each chunk
|
|
formatted := make([]interface{}, len(chunks))
|
|
for i, chunk := range chunks {
|
|
formatted[i] = chunk
|
|
}
|
|
return formatted
|
|
}
|
|
|
|
// buildChatConfig builds ChatConfig directly from dialog.LLMSetting and config
|
|
func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[string]interface{}) *modelModule.ChatConfig {
|
|
cfg := &modelModule.ChatConfig{}
|
|
|
|
// Start with dialog's LLM setting
|
|
if dialog.LLMSetting != nil {
|
|
if v, ok := dialog.LLMSetting["stream"].(bool); ok {
|
|
cfg.Stream = &v
|
|
}
|
|
if v, ok := dialog.LLMSetting["thinking"].(bool); ok {
|
|
cfg.Thinking = &v
|
|
}
|
|
if v, ok := dialog.LLMSetting["max_tokens"].(float64); ok {
|
|
intVal := int(v)
|
|
cfg.MaxTokens = &intVal
|
|
}
|
|
if v, ok := dialog.LLMSetting["temperature"].(float64); ok {
|
|
cfg.Temperature = &v
|
|
}
|
|
if v, ok := dialog.LLMSetting["top_p"].(float64); ok {
|
|
cfg.TopP = &v
|
|
}
|
|
if v, ok := dialog.LLMSetting["do_sample"].(bool); ok {
|
|
cfg.DoSample = &v
|
|
}
|
|
if v, ok := dialog.LLMSetting["stop"].([]interface{}); ok {
|
|
stopStrs := make([]string, 0, len(v))
|
|
for _, s := range v {
|
|
if str, ok := s.(string); ok {
|
|
stopStrs = append(stopStrs, str)
|
|
}
|
|
}
|
|
cfg.Stop = &stopStrs
|
|
}
|
|
if v, ok := dialog.LLMSetting["model_class"].(string); ok {
|
|
cfg.ModelClass = &v
|
|
}
|
|
if v, ok := dialog.LLMSetting["effort"].(string); ok {
|
|
cfg.Effort = &v
|
|
}
|
|
if v, ok := dialog.LLMSetting["verbosity"].(string); ok {
|
|
cfg.Verbosity = &v
|
|
}
|
|
}
|
|
|
|
// Override with request config
|
|
if config != nil {
|
|
if v, ok := config["stream"].(bool); ok {
|
|
cfg.Stream = &v
|
|
}
|
|
if v, ok := config["thinking"].(bool); ok {
|
|
cfg.Thinking = &v
|
|
}
|
|
if v, ok := config["max_tokens"].(float64); ok {
|
|
intVal := int(v)
|
|
cfg.MaxTokens = &intVal
|
|
}
|
|
if v, ok := config["temperature"].(float64); ok {
|
|
cfg.Temperature = &v
|
|
}
|
|
if v, ok := config["top_p"].(float64); ok {
|
|
cfg.TopP = &v
|
|
}
|
|
if v, ok := config["do_sample"].(bool); ok {
|
|
cfg.DoSample = &v
|
|
}
|
|
if v, ok := config["stop"].([]interface{}); ok {
|
|
stopStrs := make([]string, 0, len(v))
|
|
for _, s := range v {
|
|
if str, ok := s.(string); ok {
|
|
stopStrs = append(stopStrs, str)
|
|
}
|
|
}
|
|
cfg.Stop = &stopStrs
|
|
}
|
|
if v, ok := config["model_class"].(string); ok {
|
|
cfg.ModelClass = &v
|
|
}
|
|
if v, ok := config["effort"].(string); ok {
|
|
cfg.Effort = &v
|
|
}
|
|
if v, ok := config["verbosity"].(string); ok {
|
|
cfg.Verbosity = &v
|
|
}
|
|
}
|
|
|
|
return cfg
|
|
}
|