// // Copyright 2026 The InfiniFlow Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // package service import ( "errors" "fmt" "ragflow/internal/common" "ragflow/internal/entity" "strings" "unicode/utf8" "ragflow/internal/dao" ) // ChatService chat service type ChatService struct { chatDAO *dao.ChatDAO kbDAO *dao.KnowledgebaseDAO userTenantDAO *dao.UserTenantDAO tenantDAO *dao.TenantDAO } // NewChatService create chat service func NewChatService() *ChatService { return &ChatService{ chatDAO: dao.NewChatDAO(), kbDAO: dao.NewKnowledgebaseDAO(), userTenantDAO: dao.NewUserTenantDAO(), tenantDAO: dao.NewTenantDAO(), } } // ChatWithKBNames chat with knowledge base names type ChatWithKBNames struct { *entity.Chat KBNames []string `json:"kb_names"` DatasetIDs []string `json:"dataset_ids"` } // ListChatsResponse list chats response type ListChatsResponse struct { Chats []*ChatWithKBNames `json:"chats"` } // ListChats list chats for a user func (s *ChatService) ListChats(userID, status, keywords string, page, pageSize int, orderby string, desc bool) (*ListChatsResponse, error) { // Get tenant IDs by user ID tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) if err != nil { return nil, err } // For now, use the first tenant ID (primary tenant) // This matches the Python implementation behavior var tenantID string if len(tenantIDs) > 0 { tenantID = tenantIDs[0] } else { tenantID = userID } // Query chats by tenant ID chats, err := s.chatDAO.ListByTenantID(tenantID, status) if err != nil { return nil, err } total := int64(len(chats)) if page > 0 && pageSize > 0 { start := (page - 1) * pageSize end := start + pageSize if start < int(total) { if end > int(total) { end = int(total) } chats = chats[start:end] } else { chats = []*entity.Chat{} } } // Enrich with knowledge base names chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats)) for _, chat := range chats { kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs) chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ Chat: chat, KBNames: kbNames, DatasetIDs: datasetIDs, }) } return &ListChatsResponse{ Chats: chatsWithKBNames, }, nil } // ListChatsNextRequest list chats next request type ListChatsNextRequest struct { OwnerIDs []string `json:"owner_ids,omitempty"` } // ListChatsNextResponse list chats next response type ListChatsNextResponse struct { Chats []*ChatWithKBNames `json:"dialogs"` Total int64 `json:"total"` } // ListChatsNext list chats with advanced filtering (equivalent to list_dialogs_next) func (s *ChatService) ListChatsNext(userID string, keywords string, page, pageSize int, orderby string, desc bool, ownerIDs []string) (*ListChatsNextResponse, error) { var chats []*entity.Chat var total int64 var err error if len(ownerIDs) == 0 { // Get tenant IDs by user ID (joined tenants) tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) if err != nil { return nil, err } // Use database pagination chats, total, err = s.chatDAO.ListByTenantIDs(tenantIDs, userID, page, pageSize, orderby, desc, keywords) if err != nil { return nil, err } } else { // Filter by owner IDs, manual pagination chats, total, err = s.chatDAO.ListByOwnerIDs(ownerIDs, userID, orderby, desc, keywords) if err != nil { return nil, err } // Manual pagination if page > 0 && pageSize > 0 { start := (page - 1) * pageSize end := start + pageSize if start < int(total) { if end > int(total) { end = int(total) } chats = chats[start:end] } else { chats = []*entity.Chat{} } } } // Enrich with knowledge base names chatsWithKBNames := make([]*ChatWithKBNames, 0, len(chats)) for _, chat := range chats { kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs) chatsWithKBNames = append(chatsWithKBNames, &ChatWithKBNames{ Chat: chat, KBNames: kbNames, DatasetIDs: datasetIDs, }) } return &ListChatsNextResponse{ Chats: chatsWithKBNames, Total: total, }, nil } // getDatasetNamesAndIDs gets knowledge base names by IDs func (s *ChatService) getDatasetNamesAndIDs(kbIDs entity.JSONSlice) ([]string, []string) { var names = make([]string, 0, 0) var ids = make([]string, 0, 0) for _, kbID := range kbIDs { kbIDStr, ok := kbID.(string) if !ok { continue } kb, err := s.kbDAO.GetByID(kbIDStr) if err != nil || kb == nil { continue } // Only include valid KBs if kb.Status != nil && *kb.Status == "1" { names = append(names, kb.Name) ids = append(ids, kbIDStr) } } return names, ids } // ParameterConfig parameter configuration in prompt_config type ParameterConfig struct { Key string `json:"key"` Optional bool `json:"optional"` } // PromptConfig prompt configuration type PromptConfig struct { System string `json:"system"` Prologue string `json:"prologue"` Parameters []ParameterConfig `json:"parameters"` EmptyResponse string `json:"empty_response"` TavilyAPIKey string `json:"tavily_api_key,omitempty"` Keyword bool `json:"keyword,omitempty"` Quote bool `json:"quote,omitempty"` Reasoning bool `json:"reasoning,omitempty"` RefineMultiturn bool `json:"refine_multiturn,omitempty"` TocEnhance bool `json:"toc_enhance,omitempty"` TTS bool `json:"tts,omitempty"` UseKG bool `json:"use_kg,omitempty"` } // SetDialogRequest set chat request type SetDialogRequest struct { DialogID string `json:"dialog_id,omitempty"` Name string `json:"name,omitempty"` Description string `json:"description,omitempty"` Icon string `json:"icon,omitempty"` TopN int64 `json:"top_n,omitempty"` TopK int64 `json:"top_k,omitempty"` RerankID string `json:"rerank_id,omitempty"` SimilarityThreshold float64 `json:"similarity_threshold,omitempty"` VectorSimilarityWeight float64 `json:"vector_similarity_weight,omitempty"` LLMSetting map[string]interface{} `json:"llm_setting,omitempty"` MetaDataFilter map[string]interface{} `json:"meta_data_filter,omitempty"` PromptConfig *PromptConfig `json:"prompt_config" binding:"required"` KBIDs []string `json:"kb_ids,omitempty"` LLMID string `json:"llm_id,omitempty"` } // SetDialogResponse set chat response type SetDialogResponse struct { *entity.Chat KBNames []string `json:"kb_names"` } // SetDialog create or update a chat func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialogResponse, error) { // Determine if this is a create or update operation isCreate := req.DialogID == "" // Validate and process name name := req.Name if name == "" { name = "New Chat" } // Validate name type and content if strings.TrimSpace(name) == "" { return nil, errors.New("Chat name can't be empty") } // Check name length (UTF-8 byte length) if len(name) > 255 { return nil, fmt.Errorf("Chat name length is %d which is larger than 255", len(name)) } name = strings.TrimSpace(name) // Get tenant ID (use userID as default tenant) tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) if err != nil { return nil, err } var tenantID string if len(tenantIDs) > 0 { tenantID = tenantIDs[0] } else { tenantID = userID } // For create: check for duplicate names and generate unique name if isCreate { existingNames, err := s.chatDAO.GetExistingNames(tenantID, "1") if err != nil { return nil, err } // Check if name exists (case-insensitive) nameLower := strings.ToLower(name) for _, existing := range existingNames { if strings.ToLower(existing) == nameLower { // Generate unique name name = s.generateUniqueName(name, existingNames) break } } } // Set default values description := req.Description if description == "" { description = "A helpful chat" } topN := req.TopN if topN == 0 { topN = 6 } topK := req.TopK if topK == 0 { topK = 1024 } rerankID := req.RerankID similarityThreshold := req.SimilarityThreshold if similarityThreshold == 0 { similarityThreshold = 0.1 } vectorSimilarityWeight := req.VectorSimilarityWeight if vectorSimilarityWeight == 0 { vectorSimilarityWeight = 0.3 } llmSetting := req.LLMSetting if llmSetting == nil { llmSetting = make(map[string]interface{}) } metaDataFilter := req.MetaDataFilter if metaDataFilter == nil { metaDataFilter = make(map[string]interface{}) } promptConfig := req.PromptConfig // Process kb_ids kbIDs := req.KBIDs if kbIDs == nil { kbIDs = []string{} } // Set default parameters for datasets with knowledge retrieval // Check if parameters is missing or empty and kb_ids is provided if len(kbIDs) > 0 && (promptConfig.Parameters == nil || len(promptConfig.Parameters) == 0) { // Check if system prompt uses {knowledge} placeholder if strings.Contains(promptConfig.System, "{knowledge}") { // Set default parameters for any dataset with knowledge placeholder promptConfig.Parameters = []ParameterConfig{ {Key: "knowledge", Optional: false}, } } } // For update: validate that {knowledge} is not used when no KBs or Tavily if !isCreate { if len(kbIDs) == 0 && promptConfig.TavilyAPIKey == "" && strings.Contains(promptConfig.System, "{knowledge}") { return nil, errors.New("Please remove `{knowledge}` in system prompt since no dataset / Tavily used here") } } // Validate parameters for _, p := range promptConfig.Parameters { if p.Optional { continue } placeholder := fmt.Sprintf("{%s}", p.Key) if !strings.Contains(promptConfig.System, placeholder) { return nil, fmt.Errorf("Parameter '%s' is not used", p.Key) } } // Check knowledge bases and their embedding models if len(kbIDs) > 0 { kbs, err := s.kbDAO.GetByIDs(kbIDs) if err != nil { return nil, err } // Check if all KBs use the same embedding model var embdID string for i, kb := range kbs { if i == 0 { embdID = kb.EmbdID } else { // Extract base model name (remove vendor suffix) embdBase := s.splitModelNameAndFactory(embdID) kbEmbdBase := s.splitModelNameAndFactory(kb.EmbdID) if embdBase != kbEmbdBase { return nil, fmt.Errorf("Datasets use different embedding models: %v", getEmbdIDs(kbs)) } } } } // Get LLM ID (use tenant's default if not provided) llmID := req.LLMID if llmID == "" { tenant, err := s.tenantDAO.GetByID(tenantID) if err != nil { return nil, errors.New("Tenant not found") } llmID = tenant.LLMID } // Convert prompt config to JSONMap with all fields promptConfigMap := entity.JSONMap{ "system": promptConfig.System, "prologue": promptConfig.Prologue, "empty_response": promptConfig.EmptyResponse, "keyword": promptConfig.Keyword, "quote": promptConfig.Quote, "reasoning": promptConfig.Reasoning, "refine_multiturn": promptConfig.RefineMultiturn, "toc_enhance": promptConfig.TocEnhance, "tts": promptConfig.TTS, "use_kg": promptConfig.UseKG, } if promptConfig.TavilyAPIKey != "" { promptConfigMap["tavily_api_key"] = promptConfig.TavilyAPIKey } if len(promptConfig.Parameters) > 0 { params := make([]map[string]interface{}, len(promptConfig.Parameters)) for i, p := range promptConfig.Parameters { params[i] = map[string]interface{}{ "key": p.Key, "optional": p.Optional, } } promptConfigMap["parameters"] = params } // Convert kbIDs to JSONSlice kbIDsJSON := make(entity.JSONSlice, len(kbIDs)) for i, id := range kbIDs { kbIDsJSON[i] = id } if isCreate { // Generate UUID for new chat newID := common.GenerateUUID() // Set default language language := "English" // Create new chat chat := &entity.Chat{ ID: newID, TenantID: tenantID, Name: &name, Description: &description, Icon: &req.Icon, Language: &language, LLMID: llmID, LLMSetting: llmSetting, PromptConfig: promptConfigMap, MetaDataFilter: (*entity.JSONMap)(&metaDataFilter), TopN: topN, TopK: topK, RerankID: rerankID, SimilarityThreshold: similarityThreshold, VectorSimilarityWeight: vectorSimilarityWeight, KBIDs: kbIDsJSON, Status: strPtr("1"), } if err := s.chatDAO.Create(chat); err != nil { return nil, errors.New("Fail to new a chat") } // Get KB names kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs) return &SetDialogResponse{ Chat: chat, KBNames: kbNames, }, nil } updateData := map[string]interface{}{ "name": name, "description": description, "icon": req.Icon, "llm_id": llmID, "llm_setting": llmSetting, "prompt_config": promptConfigMap, "meta_data_filter": metaDataFilter, "top_n": topN, "top_k": topK, "rerank_id": rerankID, "similarity_threshold": similarityThreshold, "vector_similarity_weight": vectorSimilarityWeight, "kb_ids": kbIDsJSON, } if err := s.chatDAO.UpdateByID(req.DialogID, updateData); err != nil { return nil, errors.New("Dialog not found") } // Get updated chat chat, err := s.chatDAO.GetByID(req.DialogID) if err != nil { return nil, errors.New("Fail to update a chat") } // Get KB names kbNames, _ := s.getDatasetNamesAndIDs(chat.KBIDs) return &SetDialogResponse{ Chat: chat, KBNames: kbNames, }, nil } // generateUniqueName generates a unique name by appending a number func (s *ChatService) generateUniqueName(name string, existingNames []string) string { baseName := name counter := 1 // Check if name already has a suffix like "(1)" if idx := strings.LastIndex(name, "("); idx > 0 { if idx2 := strings.LastIndex(name, ")"); idx2 > idx { if num, err := fmt.Sscanf(name[idx+1:idx2], "%d", &counter); err == nil && num == 1 { baseName = strings.TrimSpace(name[:idx]) counter++ } } } existingMap := make(map[string]bool) for _, n := range existingNames { existingMap[strings.ToLower(n)] = true } newName := name for { if !existingMap[strings.ToLower(newName)] { return newName } newName = fmt.Sprintf("%s(%d)", baseName, counter) counter++ } } // splitModelNameAndFactory extracts the base model name (removes vendor suffix) func (s *ChatService) splitModelNameAndFactory(embdID string) string { // Remove vendor suffix (e.g., "model@openai" -> "model") if idx := strings.LastIndex(embdID, "@"); idx > 0 { return embdID[:idx] } return embdID } // getEmbdIDs extracts embedding IDs from knowledge bases func getEmbdIDs(kbs []*entity.Knowledgebase) []string { ids := make([]string, len(kbs)) for i, kb := range kbs { ids[i] = kb.EmbdID } return ids } // RemoveChats removes dialogs by setting their status to invalid (soft delete) // Only the owner of the chat can perform this operation func (s *ChatService) RemoveChats(userID string, chatIDs []string) error { // Get user's tenants tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID) if err != nil { return err } // Build a set of user's tenant IDs for quick lookup tenantIDSet := make(map[string]bool) for _, tid := range tenantIDs { tenantIDSet[tid] = true } // Also add userID itself as a tenant (for cases where tenant_id = user_id) tenantIDSet[userID] = true // Check each chat and build update list var updates []map[string]interface{} for _, chatID := range chatIDs { // Get the chat to check ownership chat, err := s.chatDAO.GetByID(chatID) if err != nil { return fmt.Errorf("chat not found: %s", chatID) } // Check if user is the owner (chat's tenant_id must be in user's tenants) if !tenantIDSet[chat.TenantID] { return errors.New("only owner of chat authorized for this operation") } // Add to update list (soft delete by setting status to "0") updates = append(updates, map[string]interface{}{ "id": chatID, "status": "0", }) } // Batch update all dialogs if err := s.chatDAO.UpdateManyByID(updates); err != nil { return err } return nil } // strPtr returns a pointer to a string func strPtr(s string) *string { return &s } // Helper to count UTF-8 characters (not bytes) func (s *ChatService) countRunes(str string) int { return utf8.RuneCountInString(str) } // GetChatResponse get chat response with kb_names // Reference: Python _build_chat_response type GetChatResponse struct { *entity.Chat DatasetIDs []string `json:"dataset_ids"` KBNames []string `json:"kb_names"` } // GetChat gets chat detail by ID with permission check func (s *ChatService) GetChat(userID string, chatID string) (*GetChatResponse, error) { // Step 1: Get user tenants (same as Python UserTenantService.query(user_id=current_user.id)) tenants, err := s.userTenantDAO.GetByUserID(userID) if err != nil { return nil, fmt.Errorf("failed to get user tenants: %w", err) } // Step 2: Check if user has permission to access this chat // Python: for tenant in tenants: if DialogService.query(tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value): break hasPermission := false for _, tenant := range tenants { chats, err := s.chatDAO.QueryByTenantIDAndID(tenant.TenantID, chatID, "1") if err != nil { continue // Try next tenant } if len(chats) > 0 { hasPermission = true break } } if !hasPermission { return nil, fmt.Errorf("no authorization") } // Step 3: Get chat detail (same as Python DialogService.get_by_id(chat_id)) chat, err := s.chatDAO.GetByID(chatID) if err != nil { return nil, fmt.Errorf("chat not found") } // Step 4: Build response with kb_names (same as Python _build_chat_response) // Resolve kb_ids to kb_names kbNames, datasetIDs := s.getDatasetNamesAndIDs(chat.KBIDs) // Build dataset_ids from kb_ids (same as Python _resolve_kb_names returns ids) for _, kbID := range chat.KBIDs { datasetID, ok := kbID.(string) if !ok { continue } datasetIDs = append(datasetIDs, datasetID) } return &GetChatResponse{ Chat: chat, DatasetIDs: datasetIDs, KBNames: kbNames, }, nil }