mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 01:01:56 +08:00
[Go] Add /api/v1/searchbots/mindmap and /api/v1/chat/mindmap (#16443)
This commit is contained in:
@@ -297,17 +297,18 @@ func startServer(config *server.Config) {
|
||||
// no-op echo (the audio package contract), so this is always
|
||||
// safe to call.
|
||||
configureTTSSynthesizer(modelProviderService)
|
||||
searchBotLLM := &handler.SearchBotRealLLM{Svc: modelProviderService}
|
||||
modelProviderLLM := &handler.ModelProviderLLM{Svc: modelProviderService}
|
||||
searchBotHandler := handler.NewSearchBotHandler(
|
||||
searchService,
|
||||
tenantService,
|
||||
searchBotLLM,
|
||||
modelProviderLLM,
|
||||
chunkService,
|
||||
)
|
||||
searchBotHandler.SetStreamLLM(searchBotLLM)
|
||||
searchBotHandler.SetStreamLLM(modelProviderLLM)
|
||||
askService := service.NewAskService(chunkService, nil, 0, 0)
|
||||
searchBotHandler.SetAskService(askService)
|
||||
searchHandler.SetCompletionDependencies(searchBotLLM, askService)
|
||||
chatHandler.SetMindMapDependencies(searchService, tenantService, modelProviderLLM, chunkService)
|
||||
searchHandler.SetCompletionDependencies(modelProviderLLM, askService)
|
||||
pluginHandler := handler.NewPluginHandler(service.NewPluginService())
|
||||
modelHandler := handler.NewModelHandler(service.NewModelProviderService())
|
||||
fileCommitHandler := handler.NewFileCommitHandler(service.NewFileCommitService())
|
||||
|
||||
@@ -18,9 +18,11 @@ package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -31,6 +33,10 @@ import (
|
||||
type ChatHandler struct {
|
||||
chatService *service.ChatService
|
||||
userService *service.UserService
|
||||
searchSvc *service.SearchService
|
||||
tenantSvc *service.TenantService
|
||||
llm chatLLM
|
||||
chunkSvc ChunkRetriever
|
||||
}
|
||||
|
||||
// NewChatHandler create chat handler
|
||||
@@ -41,6 +47,21 @@ func NewChatHandler(chatService *service.ChatService, userService *service.UserS
|
||||
}
|
||||
}
|
||||
|
||||
// SetMindMapDependencies sets dependencies used by POST /api/v1/chat/mindmap.
|
||||
func (h *ChatHandler) SetMindMapDependencies(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm chatLLM, chunkSvc ChunkRetriever) {
|
||||
h.searchSvc = searchSvc
|
||||
h.tenantSvc = tenantSvc
|
||||
h.llm = llm
|
||||
h.chunkSvc = chunkSvc
|
||||
}
|
||||
|
||||
// ChatMindMapRequest is the request body for POST /api/v1/chat/mindmap.
|
||||
type ChatMindMapRequest struct {
|
||||
Question string `json:"question" binding:"required"`
|
||||
KbIDs common.StringSlice `json:"kb_ids" binding:"required"`
|
||||
SearchID string `json:"search_id,omitempty"`
|
||||
}
|
||||
|
||||
// ListChats list chats
|
||||
// @Summary List Chats
|
||||
// @Description Get list of chats (dialogs) for the current user
|
||||
@@ -138,6 +159,74 @@ func (h *ChatHandler) Create(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// MindMap generates a query mind map for chat search results.
|
||||
// @Summary Generate Chat Mind Map
|
||||
// @Description Retrieves related chunks and asks the configured chat model to summarize them into a mind map.
|
||||
// @Tags chat
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body ChatMindMapRequest true "Mind map parameters"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/chat/mindmap [post]
|
||||
func (h *ChatHandler) MindMap(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
var req ChatMindMapRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Question) == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_ids and question are required"})
|
||||
return
|
||||
}
|
||||
|
||||
searchConfig := map[string]interface{}{}
|
||||
modelTenantID := user.ID
|
||||
if req.SearchID != "" {
|
||||
if h.searchSvc == nil {
|
||||
jsonInternalError(c, fmt.Errorf("search service not configured"))
|
||||
return
|
||||
}
|
||||
detail, err := h.searchSvc.GetDetail(req.SearchID)
|
||||
if err != nil {
|
||||
jsonInternalError(c, err)
|
||||
return
|
||||
}
|
||||
searchConfig = searchConfigFromDetail(detail)
|
||||
if tenantID, ok := detail["tenant_id"].(string); ok && tenantID != "" {
|
||||
modelTenantID = tenantID
|
||||
}
|
||||
}
|
||||
|
||||
kbIDs := mergeMindMapKbIDs(stringSliceFromConfig(searchConfig, "kb_ids"), req.KbIDs)
|
||||
if len(kbIDs) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_ids and question are required"})
|
||||
return
|
||||
}
|
||||
|
||||
mindMap, err := runMindMap(mindMapRunConfig{
|
||||
Question: req.Question,
|
||||
KbIDs: kbIDs,
|
||||
SearchID: req.SearchID,
|
||||
SearchConfig: searchConfig,
|
||||
AuthUserID: user.ID,
|
||||
ModelTenantID: modelTenantID,
|
||||
ChunkSvc: h.chunkSvc,
|
||||
LLM: h.llm,
|
||||
TenantSvc: h.tenantSvc,
|
||||
})
|
||||
if err != nil {
|
||||
jsonInternalError(c, err)
|
||||
return
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, mindMap, "success")
|
||||
}
|
||||
|
||||
// ListChatsNext list chats with advanced filtering and pagination
|
||||
// @Summary List Chats Next
|
||||
// @Description Get list of chats with filtering, pagination and sorting (equivalent to list_dialogs_next)
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -71,6 +72,41 @@ func createChatHandlerTestChat(t *testing.T, db *gorm.DB, id, tenantID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatMindMapHandlerSuccess(t *testing.T) {
|
||||
llm := &fakeChatLLM{response: "# Product\n## Features\n### Search"}
|
||||
chunks := &mockChunkService{retrievalTestFn: func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
return &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{{"content_with_weight": "Hybrid search combines vector and keyword retrieval."}},
|
||||
}, nil
|
||||
}}
|
||||
h := NewChatHandler(service.NewChatService(), service.NewUserService())
|
||||
h.SetMindMapDependencies(nil, nil, llm, chunks)
|
||||
c, w := setupGinContextWithUser("POST", "/api/v1/chat/mindmap", `{"question":"What is search?","kb_ids":["kb-1"]}`)
|
||||
|
||||
h.MindMap(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"])
|
||||
}
|
||||
data := resp["data"].(map[string]interface{})
|
||||
if data["id"] != "Product" {
|
||||
t.Fatalf("mindmap root = %v, want Product", data["id"])
|
||||
}
|
||||
if chunks.LastReq == nil || len(chunks.LastReq.Datasets) != 1 || chunks.LastReq.Datasets[0] != "kb-1" {
|
||||
t.Fatalf("retrieval datasets = %+v, want [kb-1]", chunks.LastReq)
|
||||
}
|
||||
if llm.lastTenantID != "user-1" || len(llm.lastMessages) != 2 || !strings.Contains(llm.lastMessages[0].Content, "Hybrid search combines") {
|
||||
t.Fatalf("unexpected LLM call: tenant=%q messages=%v", llm.lastTenantID, llm.lastMessages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteChatHandlerSuccess(t *testing.T) {
|
||||
db := setupChatHandlerTestDB(t)
|
||||
createChatHandlerTestChat(t, db, "chat-1", "user-1")
|
||||
|
||||
291
internal/handler/mindmap.go
Normal file
291
internal/handler/mindmap.go
Normal file
@@ -0,0 +1,291 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
type mindMapRunConfig struct {
|
||||
Question string
|
||||
KbIDs common.StringSlice
|
||||
SearchID string
|
||||
SearchConfig map[string]interface{}
|
||||
AuthUserID string
|
||||
ModelTenantID string
|
||||
ChunkSvc ChunkRetriever
|
||||
LLM chatLLM
|
||||
TenantSvc *service.TenantService
|
||||
}
|
||||
|
||||
func runMindMap(config mindMapRunConfig) (mindMapNode, error) {
|
||||
if config.ChunkSvc == nil {
|
||||
return mindMapNode{}, fmt.Errorf("chunk service not configured")
|
||||
}
|
||||
if config.LLM == nil {
|
||||
return mindMapNode{}, fmt.Errorf("LLM not configured")
|
||||
}
|
||||
modelTenantID := config.ModelTenantID
|
||||
if modelTenantID == "" {
|
||||
modelTenantID = config.AuthUserID
|
||||
}
|
||||
retrievalReq := mindMapRetrievalRequest(config.Question, config.KbIDs, config.SearchID, config.SearchConfig)
|
||||
ranks, err := config.ChunkSvc.RetrievalTest(retrievalReq, config.AuthUserID)
|
||||
if err != nil {
|
||||
return mindMapNode{}, err
|
||||
}
|
||||
sections := mindMapSections(ranks)
|
||||
if len(sections) == 0 {
|
||||
return mindMapNode{ID: "root", Children: []mindMapNode{}}, nil
|
||||
}
|
||||
modelID, _ := config.SearchConfig["chat_id"].(string)
|
||||
if modelID == "" && config.TenantSvc != nil {
|
||||
defaultModel, err := config.TenantSvc.GetDefaultModelName(modelTenantID, entity.ModelTypeChat)
|
||||
if err == nil {
|
||||
modelID = defaultModel
|
||||
}
|
||||
}
|
||||
response, err := config.LLM.Chat(modelTenantID, modelID, []modelModule.Message{{Role: "user", Content: mindMapPrompt(strings.Join(sections, "\n"))}, {Role: "user", Content: "Output:"}}, &modelModule.ChatConfig{})
|
||||
if err != nil {
|
||||
return mindMapNode{}, err
|
||||
}
|
||||
if response == nil || response.Answer == nil {
|
||||
return mindMapNode{ID: "root", Children: []mindMapNode{}}, nil
|
||||
}
|
||||
return parseMindMapMarkdown(*response.Answer), nil
|
||||
}
|
||||
|
||||
func searchConfigFromDetail(detail map[string]interface{}) map[string]interface{} {
|
||||
if sc, ok := detail["search_config"].(map[string]interface{}); ok && sc != nil {
|
||||
return sc
|
||||
}
|
||||
if sc, ok := detail["search_config"].(entity.JSONMap); ok && sc != nil {
|
||||
return map[string]interface{}(sc)
|
||||
}
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
func mindMapRetrievalRequest(question string, kbIDs common.StringSlice, searchID string, searchConfig map[string]interface{}) *service.RetrievalTestRequest {
|
||||
page := 1
|
||||
size := 12
|
||||
topK := intFromConfig(searchConfig, "top_k", 1024)
|
||||
similarityThreshold := floatFromConfig(searchConfig, "similarity_threshold", 0.2)
|
||||
vectorSimilarityWeight := floatFromConfig(searchConfig, "vector_similarity_weight", 0.3)
|
||||
req := &service.RetrievalTestRequest{
|
||||
Datasets: kbIDs,
|
||||
Question: question,
|
||||
Page: &page,
|
||||
Size: &size,
|
||||
TopK: &topK,
|
||||
SimilarityThreshold: &similarityThreshold,
|
||||
VectorSimilarityWeight: &vectorSimilarityWeight,
|
||||
DocIDs: stringSliceFromConfig(searchConfig, "doc_ids"),
|
||||
Filter: mapFromConfig(searchConfig, "meta_data_filter"),
|
||||
}
|
||||
if searchID != "" {
|
||||
req.SearchID = &searchID
|
||||
}
|
||||
if rerankID, _ := searchConfig["rerank_id"].(string); rerankID != "" {
|
||||
req.RerankID = &rerankID
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func mindMapSections(ranks *service.RetrievalTestResponse) []string {
|
||||
if ranks == nil {
|
||||
return nil
|
||||
}
|
||||
sections := make([]string, 0, len(ranks.Chunks))
|
||||
for _, chunk := range ranks.Chunks {
|
||||
if content, ok := chunk["content_with_weight"].(string); ok && strings.TrimSpace(content) != "" {
|
||||
sections = append(sections, content)
|
||||
}
|
||||
}
|
||||
return sections
|
||||
}
|
||||
|
||||
func mergeMindMapKbIDs(saved []string, requested common.StringSlice) common.StringSlice {
|
||||
seen := map[string]bool{}
|
||||
merged := make(common.StringSlice, 0, len(saved)+len(requested))
|
||||
for _, id := range saved {
|
||||
id = strings.TrimSpace(id)
|
||||
if id != "" && !seen[id] {
|
||||
seen[id] = true
|
||||
merged = append(merged, id)
|
||||
}
|
||||
}
|
||||
for _, id := range requested {
|
||||
id = strings.TrimSpace(id)
|
||||
if id != "" && !seen[id] {
|
||||
seen[id] = true
|
||||
merged = append(merged, id)
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func intFromConfig(config map[string]interface{}, key string, fallback int) int {
|
||||
switch v := config[key].(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if n, err := v.Int64(); err == nil {
|
||||
return int(n)
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func floatFromConfig(config map[string]interface{}, key string, fallback float64) float64 {
|
||||
switch v := config[key].(type) {
|
||||
case float64:
|
||||
return v
|
||||
case float32:
|
||||
return float64(v)
|
||||
case int:
|
||||
return float64(v)
|
||||
case int64:
|
||||
return float64(v)
|
||||
case json.Number:
|
||||
if n, err := v.Float64(); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func stringSliceFromConfig(config map[string]interface{}, key string) []string {
|
||||
switch v := config[key].(type) {
|
||||
case []string:
|
||||
return v
|
||||
case []interface{}:
|
||||
out := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok && s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapFromConfig(config map[string]interface{}, key string) map[string]interface{} {
|
||||
if m, ok := config[key].(map[string]interface{}); ok {
|
||||
return m
|
||||
}
|
||||
if m, ok := config[key].(entity.JSONMap); ok {
|
||||
return map[string]interface{}(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mindMapPrompt(inputText string) string {
|
||||
return `- Role: You're a talent text processor to summarize a piece of text into a mind map.
|
||||
|
||||
- Step of task:
|
||||
1. Generate a title for user's 'TEXT'.
|
||||
2. Classify the 'TEXT' into sections of a mind map.
|
||||
3. If the subject matter is really complex, split them into sub-sections and sub-subsections.
|
||||
4. Add a shot content summary of the bottom level section.
|
||||
|
||||
- Output requirement:
|
||||
- Generate at least 4 levels.
|
||||
- Always try to maximize the number of sub-sections.
|
||||
- In language of 'Text'
|
||||
- MUST IN FORMAT OF MARKDOWN
|
||||
|
||||
-TEXT-
|
||||
` + inputText + "\n"
|
||||
}
|
||||
|
||||
type mindMapNode struct {
|
||||
ID string `json:"id"`
|
||||
Children []mindMapNode `json:"children"`
|
||||
}
|
||||
|
||||
var mindMapHeadingRe = regexp.MustCompile(`^(#{1,6})\s+(.+)$`)
|
||||
var mindMapListRe = regexp.MustCompile(`^(\s*)(?:[-*+]|\d+\.)\s+(.+)$`)
|
||||
|
||||
func parseMindMapMarkdown(text string) mindMapNode {
|
||||
lines := strings.Split(strings.ReplaceAll(text, "\r\n", "\n"), "\n")
|
||||
root := mindMapNode{ID: "root", Children: []mindMapNode{}}
|
||||
stack := []*mindMapNode{&root}
|
||||
inFence := false
|
||||
listBaseLevel := 1
|
||||
lastWasList := false
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "```") {
|
||||
inFence = !inFence
|
||||
lastWasList = false
|
||||
continue
|
||||
}
|
||||
if inFence || trimmed == "" {
|
||||
lastWasList = false
|
||||
continue
|
||||
}
|
||||
level := 0
|
||||
title := ""
|
||||
if m := mindMapHeadingRe.FindStringSubmatch(trimmed); len(m) == 3 {
|
||||
level = len(m[1])
|
||||
title = cleanMindMapText(m[2])
|
||||
lastWasList = false
|
||||
} else if m := mindMapListRe.FindStringSubmatch(line); len(m) == 3 {
|
||||
rawLevel := len(m[1])/2 + 1
|
||||
if !lastWasList {
|
||||
listBaseLevel = len(stack)
|
||||
}
|
||||
level = listBaseLevel + rawLevel - 1
|
||||
title = cleanMindMapText(m[2])
|
||||
lastWasList = true
|
||||
}
|
||||
if title == "" {
|
||||
lastWasList = false
|
||||
continue
|
||||
}
|
||||
for len(stack) > level {
|
||||
stack = stack[:len(stack)-1]
|
||||
}
|
||||
parent := stack[len(stack)-1]
|
||||
parent.Children = append(parent.Children, mindMapNode{ID: title, Children: []mindMapNode{}})
|
||||
stack = append(stack, &parent.Children[len(parent.Children)-1])
|
||||
}
|
||||
if len(root.Children) == 1 {
|
||||
return root.Children[0]
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func cleanMindMapText(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
text = strings.Trim(text, "`")
|
||||
text = strings.Trim(text, "*_ ")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
@@ -35,8 +35,8 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// searchbotLLM is the interface for LLM calls used by SearchBotHandler.
|
||||
type searchbotLLM interface {
|
||||
// chatLLM is the interface for LLM calls used by chat-style handlers.
|
||||
type chatLLM interface {
|
||||
Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error)
|
||||
}
|
||||
|
||||
@@ -59,12 +59,19 @@ type SearchBotAskRequest struct {
|
||||
SearchID string `json:"search_id,omitempty"`
|
||||
}
|
||||
|
||||
// SearchBotRealLLM wraps ModelProviderService to implement searchbotLLM.
|
||||
type SearchBotRealLLM struct {
|
||||
// SearchBotMindMapRequest is the request body for POST /api/v1/searchbots/mindmap.
|
||||
type SearchBotMindMapRequest struct {
|
||||
Question string `json:"question" binding:"required"`
|
||||
KbIDs common.StringSlice `json:"kb_ids" binding:"required"`
|
||||
SearchID string `json:"search_id,omitempty"`
|
||||
}
|
||||
|
||||
// ModelProviderLLM wraps ModelProviderService to implement chatLLM.
|
||||
type ModelProviderLLM struct {
|
||||
Svc *service.ModelProviderService
|
||||
}
|
||||
|
||||
func (r *SearchBotRealLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
|
||||
func (r *ModelProviderLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
|
||||
chatModel, err := r.Svc.GetChatModel(tenantID, modelID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -73,7 +80,7 @@ func (r *SearchBotRealLLM) Chat(tenantID, modelID string, messages []modelModule
|
||||
}
|
||||
|
||||
// ChatStream implements streamingLLM.
|
||||
func (r *SearchBotRealLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
|
||||
func (r *ModelProviderLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
|
||||
chatModel, err := r.Svc.GetChatModel(tenantID, modelID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -161,10 +168,11 @@ type SearchBotRequest struct {
|
||||
// POST /api/v1/searchbots/related_questions
|
||||
// POST /api/v1/searchbots/retrieval_test
|
||||
// POST /api/v1/searchbots/ask
|
||||
// POST /api/v1/searchbots/mindmap
|
||||
type SearchBotHandler struct {
|
||||
searchSvc *service.SearchService
|
||||
tenantSvc *service.TenantService
|
||||
llm searchbotLLM
|
||||
llm chatLLM
|
||||
streamLLM streamingLLM
|
||||
chunkSvc ChunkRetriever
|
||||
askSvc *service.AskService
|
||||
@@ -172,7 +180,7 @@ type SearchBotHandler struct {
|
||||
}
|
||||
|
||||
// NewSearchBotHandler creates a new SearchBotHandler.
|
||||
func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm searchbotLLM, chunkSvc ChunkRetriever) *SearchBotHandler {
|
||||
func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm chatLLM, chunkSvc ChunkRetriever) *SearchBotHandler {
|
||||
return &SearchBotHandler{searchSvc: searchSvc, tenantSvc: tenantSvc, llm: llm, chunkSvc: chunkSvc, sseWriter: &ginSSEWriter{}}
|
||||
}
|
||||
|
||||
@@ -422,6 +430,80 @@ func (h *SearchBotHandler) Ask(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
// MindMap generates a query mind map for a shared search bot.
|
||||
// @Summary Generate Mind Map
|
||||
// @Description Retrieves related chunks and asks the configured chat model to summarize them into a mind map.
|
||||
// @Tags searchbots
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body SearchBotMindMapRequest true "Mind map parameters"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/searchbots/mindmap [post]
|
||||
func (h *SearchBotHandler) MindMap(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
var req SearchBotMindMapRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
filtered := make(common.StringSlice, 0, len(req.KbIDs))
|
||||
for _, id := range req.KbIDs {
|
||||
if strings.TrimSpace(id) != "" {
|
||||
filtered = append(filtered, id)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 || strings.TrimSpace(req.Question) == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_ids and question are required"})
|
||||
return
|
||||
}
|
||||
if h.chunkSvc == nil {
|
||||
jsonInternalError(c, fmt.Errorf("chunk service not configured"))
|
||||
return
|
||||
}
|
||||
if h.llm == nil {
|
||||
jsonInternalError(c, fmt.Errorf("LLM not configured"))
|
||||
return
|
||||
}
|
||||
|
||||
searchConfig := map[string]interface{}{}
|
||||
if req.SearchID != "" {
|
||||
if h.searchSvc == nil {
|
||||
jsonInternalError(c, fmt.Errorf("search service not configured"))
|
||||
return
|
||||
}
|
||||
detail, err := h.searchSvc.GetDetail(req.SearchID)
|
||||
if err != nil {
|
||||
jsonInternalError(c, err)
|
||||
return
|
||||
}
|
||||
searchConfig = searchConfigFromDetail(detail)
|
||||
}
|
||||
|
||||
mindMap, err := runMindMap(mindMapRunConfig{
|
||||
Question: req.Question,
|
||||
KbIDs: filtered,
|
||||
SearchID: req.SearchID,
|
||||
SearchConfig: searchConfig,
|
||||
AuthUserID: user.ID,
|
||||
ModelTenantID: user.ID,
|
||||
ChunkSvc: h.chunkSvc,
|
||||
LLM: h.llm,
|
||||
TenantSvc: h.tenantSvc,
|
||||
})
|
||||
if err != nil {
|
||||
common.Warn("searchbot mindmap failed", zap.String("error", err.Error()))
|
||||
jsonInternalError(c, err)
|
||||
return
|
||||
}
|
||||
jsonResponse(c, common.CodeSuccess, mindMap, "success")
|
||||
}
|
||||
|
||||
// SearchbotDetail returns the public share-page bootstrap payload for a
|
||||
// search app. The route is mounted under apiNoAuth but still requires a beta
|
||||
// token, matching Python's AUTH_BETA flow.
|
||||
|
||||
@@ -411,13 +411,19 @@ func TestSearchBotsRetrieval_EmptyQuestion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// fakeSearchbotLLM implements searchbotLLM for testing.
|
||||
type fakeSearchbotLLM struct {
|
||||
response string
|
||||
err error
|
||||
// fakeChatLLM implements chatLLM for testing.
|
||||
type fakeChatLLM struct {
|
||||
response string
|
||||
err error
|
||||
lastTenantID string
|
||||
lastModelID string
|
||||
lastMessages []modelModule.Message
|
||||
}
|
||||
|
||||
func (f *fakeSearchbotLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
|
||||
func (f *fakeChatLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
|
||||
f.lastTenantID = tenantID
|
||||
f.lastModelID = modelID
|
||||
f.lastMessages = messages
|
||||
if f.err != nil {
|
||||
return nil, f.err
|
||||
}
|
||||
@@ -438,7 +444,7 @@ func setupSearchBotRequest(body string) (*gin.Context, *httptest.ResponseRecorde
|
||||
|
||||
// TestSearchBotHandler_Success verifies the happy path.
|
||||
func TestSearchBotHandler_Success(t *testing.T) {
|
||||
llm := &fakeSearchbotLLM{
|
||||
llm := &fakeChatLLM{
|
||||
response: `Here are some related questions:
|
||||
1. How do EV impact environment?
|
||||
2. What are advantages of EV?
|
||||
@@ -472,7 +478,7 @@ func TestSearchBotHandler_Success(t *testing.T) {
|
||||
|
||||
// TestSearchBotHandler_EmptyResponse verifies empty LLM response returns empty list.
|
||||
func TestSearchBotHandler_EmptyResponse(t *testing.T) {
|
||||
llm := &fakeSearchbotLLM{
|
||||
llm := &fakeChatLLM{
|
||||
response: "No related questions found.",
|
||||
}
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
@@ -496,7 +502,7 @@ func TestSearchBotHandler_EmptyResponse(t *testing.T) {
|
||||
|
||||
// TestSearchBotHandler_LLMFailure verifies error handling on LLM failure.
|
||||
func TestSearchBotHandler_LLMFailure(t *testing.T) {
|
||||
llm := &fakeSearchbotLLM{
|
||||
llm := &fakeChatLLM{
|
||||
err: errFake{msg: "LLM unavailable"},
|
||||
}
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
@@ -514,7 +520,7 @@ func TestSearchBotHandler_LLMFailure(t *testing.T) {
|
||||
|
||||
// TestSearchBotHandler_MissingQuestion verifies validation.
|
||||
func TestSearchBotHandler_MissingQuestion(t *testing.T) {
|
||||
llm := &fakeSearchbotLLM{response: "dummy"}
|
||||
llm := &fakeChatLLM{response: "dummy"}
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
|
||||
c, w := setupSearchBotRequest(`{}`)
|
||||
@@ -906,6 +912,64 @@ func TestAskHandler_WhitespaceKbIDFiltered(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMindMapHandlerSuccess(t *testing.T) {
|
||||
llm := &fakeChatLLM{response: "# Product\n## Features\n### Search\n#### Hybrid retrieval"}
|
||||
chunks := &mockChunkService{retrievalTestFn: func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
return &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{{"content_with_weight": "Hybrid search combines vector and keyword retrieval."}},
|
||||
}, nil
|
||||
}}
|
||||
h := NewSearchBotHandler(nil, nil, llm, chunks)
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
})
|
||||
r.POST("/api/v1/searchbots/mindmap", h.MindMap)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/searchbots/mindmap", strings.NewReader(`{"question":"What is search?","kb_ids":["kb-1"]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"])
|
||||
}
|
||||
data := resp["data"].(map[string]interface{})
|
||||
if data["id"] != "Product" {
|
||||
t.Fatalf("mindmap root = %v, want Product", data["id"])
|
||||
}
|
||||
if chunks.LastReq == nil {
|
||||
t.Fatal("RetrievalTest was not called")
|
||||
}
|
||||
if *chunks.LastReq.Page != 1 || *chunks.LastReq.Size != 12 || *chunks.LastReq.TopK != 1024 {
|
||||
t.Fatalf("retrieval defaults = page %d size %d topK %d", *chunks.LastReq.Page, *chunks.LastReq.Size, *chunks.LastReq.TopK)
|
||||
}
|
||||
if llm.lastTenantID != "user-1" || len(llm.lastMessages) != 2 || !strings.Contains(llm.lastMessages[0].Content, "Hybrid search combines") {
|
||||
t.Fatalf("unexpected LLM call: tenant=%q messages=%v", llm.lastTenantID, llm.lastMessages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMindMapMarkdown_ListUnderHeading(t *testing.T) {
|
||||
got := parseMindMapMarkdown("# Product\n- Features\n - Search")
|
||||
if got.ID != "Product" {
|
||||
t.Fatalf("root = %q, want Product", got.ID)
|
||||
}
|
||||
if len(got.Children) != 1 || got.Children[0].ID != "Features" {
|
||||
t.Fatalf("children = %+v, want Features under Product", got.Children)
|
||||
}
|
||||
if len(got.Children[0].Children) != 1 || got.Children[0].Children[0].ID != "Search" {
|
||||
t.Fatalf("nested children = %+v, want Search under Features", got.Children[0].Children)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- SSE helper direct tests ----
|
||||
|
||||
func TestSseAnswer_Final(t *testing.T) {
|
||||
|
||||
@@ -194,6 +194,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
searchbotGroup.POST("/related_questions", r.searchBotHandler.Handle)
|
||||
searchbotGroup.POST("/retrieval_test", r.searchBotHandler.RetrievalTest)
|
||||
searchbotGroup.POST("/ask", r.searchBotHandler.Ask)
|
||||
searchbotGroup.POST("/mindmap", r.searchBotHandler.MindMap)
|
||||
|
||||
if r.botHandler != nil {
|
||||
chatbotGroup := apiBetaAuth.Group("/chatbots")
|
||||
@@ -271,6 +272,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
}
|
||||
|
||||
// Chat routes
|
||||
v1.POST("/chat/mindmap", r.chatHandler.MindMap)
|
||||
chats := v1.Group("/chats")
|
||||
{
|
||||
chats.GET("", r.chatHandler.ListChats)
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/handler"
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
func TestConnectorRoutesDoNotConflictWithOAuthCallbacks(t *testing.T) {
|
||||
@@ -93,3 +96,62 @@ func TestRouterSetupRegistersUpdateDatasetRoute(t *testing.T) {
|
||||
t.Fatalf("status=%d body=%s; want auth middleware to handle registered UpdateDataset route", resp.Code, resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterSetupRegistersSearchbotMindMapRoute(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
r := &Router{
|
||||
authHandler: handler.NewAuthHandler(),
|
||||
searchBotHandler: handler.NewSearchBotHandler(nil, nil, nil, nil),
|
||||
}
|
||||
r.Setup(engine)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/searchbots/mindmap", nil)
|
||||
engine.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code == http.StatusNotFound {
|
||||
t.Fatalf("POST /api/v1/searchbots/mindmap returned 404; MindMap route is not registered")
|
||||
}
|
||||
var body struct {
|
||||
Code common.ErrorCode `json:"code"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("failed to decode response body: %v", err)
|
||||
}
|
||||
if body.Code != common.CodeUnauthorized {
|
||||
t.Fatalf("status=%d body=%s; want beta auth middleware to handle registered MindMap route", resp.Code, resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterSetupRegistersChatMindMapRoute(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
r := &Router{
|
||||
authHandler: handler.NewAuthHandler(),
|
||||
chatHandler: handler.NewChatHandler(
|
||||
service.NewChatService(),
|
||||
service.NewUserService(),
|
||||
),
|
||||
}
|
||||
r.Setup(engine)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/mindmap", nil)
|
||||
engine.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code == http.StatusNotFound {
|
||||
t.Fatalf("POST /api/v1/chat/mindmap returned 404; Chat MindMap route is not registered")
|
||||
}
|
||||
var body struct {
|
||||
Code common.ErrorCode `json:"code"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("failed to decode response body: %v", err)
|
||||
}
|
||||
if body.Code != common.CodeUnauthorized {
|
||||
t.Fatalf("status=%d body=%s; want auth middleware to handle registered Chat MindMap route", resp.Code, resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user