From 3bb976b383d03d91f3961ce56a3d39df5c2b295c Mon Sep 17 00:00:00 2001 From: Wang Qi Date: Tue, 30 Jun 2026 16:35:33 +0800 Subject: [PATCH] [Go] Add /api/v1/searchbots/mindmap and /api/v1/chat/mindmap (#16443) --- cmd/server_main.go | 9 +- internal/handler/chat.go | 89 +++++++++ internal/handler/chat_test.go | 36 ++++ internal/handler/mindmap.go | 291 +++++++++++++++++++++++++++++ internal/handler/searchbot.go | 98 +++++++++- internal/handler/searchbot_test.go | 82 +++++++- internal/router/router.go | 2 + internal/router/router_test.go | 62 ++++++ 8 files changed, 648 insertions(+), 21 deletions(-) create mode 100644 internal/handler/mindmap.go diff --git a/cmd/server_main.go b/cmd/server_main.go index 0ddcd5c73d..8e72aecd0f 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -297,17 +297,18 @@ func startServer(config *server.Config) { // no-op echo (the audio package contract), so this is always // safe to call. configureTTSSynthesizer(modelProviderService) - searchBotLLM := &handler.SearchBotRealLLM{Svc: modelProviderService} + modelProviderLLM := &handler.ModelProviderLLM{Svc: modelProviderService} searchBotHandler := handler.NewSearchBotHandler( searchService, tenantService, - searchBotLLM, + modelProviderLLM, chunkService, ) - searchBotHandler.SetStreamLLM(searchBotLLM) + searchBotHandler.SetStreamLLM(modelProviderLLM) askService := service.NewAskService(chunkService, nil, 0, 0) searchBotHandler.SetAskService(askService) - searchHandler.SetCompletionDependencies(searchBotLLM, askService) + chatHandler.SetMindMapDependencies(searchService, tenantService, modelProviderLLM, chunkService) + searchHandler.SetCompletionDependencies(modelProviderLLM, askService) pluginHandler := handler.NewPluginHandler(service.NewPluginService()) modelHandler := handler.NewModelHandler(service.NewModelProviderService()) fileCommitHandler := handler.NewFileCommitHandler(service.NewFileCommitService()) diff --git a/internal/handler/chat.go b/internal/handler/chat.go index e15f5f4c65..614c3281be 100644 --- a/internal/handler/chat.go +++ b/internal/handler/chat.go @@ -18,9 +18,11 @@ package handler import ( "encoding/json" + "fmt" "net/http" "ragflow/internal/common" "strconv" + "strings" "github.com/gin-gonic/gin" @@ -31,6 +33,10 @@ import ( type ChatHandler struct { chatService *service.ChatService userService *service.UserService + searchSvc *service.SearchService + tenantSvc *service.TenantService + llm chatLLM + chunkSvc ChunkRetriever } // NewChatHandler create chat handler @@ -41,6 +47,21 @@ func NewChatHandler(chatService *service.ChatService, userService *service.UserS } } +// SetMindMapDependencies sets dependencies used by POST /api/v1/chat/mindmap. +func (h *ChatHandler) SetMindMapDependencies(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm chatLLM, chunkSvc ChunkRetriever) { + h.searchSvc = searchSvc + h.tenantSvc = tenantSvc + h.llm = llm + h.chunkSvc = chunkSvc +} + +// ChatMindMapRequest is the request body for POST /api/v1/chat/mindmap. +type ChatMindMapRequest struct { + Question string `json:"question" binding:"required"` + KbIDs common.StringSlice `json:"kb_ids" binding:"required"` + SearchID string `json:"search_id,omitempty"` +} + // ListChats list chats // @Summary List Chats // @Description Get list of chats (dialogs) for the current user @@ -138,6 +159,74 @@ func (h *ChatHandler) Create(c *gin.Context) { }) } +// MindMap generates a query mind map for chat search results. +// @Summary Generate Chat Mind Map +// @Description Retrieves related chunks and asks the configured chat model to summarize them into a mind map. +// @Tags chat +// @Accept json +// @Produce json +// @Param request body ChatMindMapRequest true "Mind map parameters" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/chat/mindmap [post] +func (h *ChatHandler) MindMap(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + var req ChatMindMapRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": err.Error()}) + return + } + if strings.TrimSpace(req.Question) == "" { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_ids and question are required"}) + return + } + + searchConfig := map[string]interface{}{} + modelTenantID := user.ID + if req.SearchID != "" { + if h.searchSvc == nil { + jsonInternalError(c, fmt.Errorf("search service not configured")) + return + } + detail, err := h.searchSvc.GetDetail(req.SearchID) + if err != nil { + jsonInternalError(c, err) + return + } + searchConfig = searchConfigFromDetail(detail) + if tenantID, ok := detail["tenant_id"].(string); ok && tenantID != "" { + modelTenantID = tenantID + } + } + + kbIDs := mergeMindMapKbIDs(stringSliceFromConfig(searchConfig, "kb_ids"), req.KbIDs) + if len(kbIDs) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_ids and question are required"}) + return + } + + mindMap, err := runMindMap(mindMapRunConfig{ + Question: req.Question, + KbIDs: kbIDs, + SearchID: req.SearchID, + SearchConfig: searchConfig, + AuthUserID: user.ID, + ModelTenantID: modelTenantID, + ChunkSvc: h.chunkSvc, + LLM: h.llm, + TenantSvc: h.tenantSvc, + }) + if err != nil { + jsonInternalError(c, err) + return + } + jsonResponse(c, common.CodeSuccess, mindMap, "success") +} + // ListChatsNext list chats with advanced filtering and pagination // @Summary List Chats Next // @Description Get list of chats with filtering, pagination and sorting (equivalent to list_dialogs_next) diff --git a/internal/handler/chat_test.go b/internal/handler/chat_test.go index f75a5b07fb..253efe0d2f 100644 --- a/internal/handler/chat_test.go +++ b/internal/handler/chat_test.go @@ -3,6 +3,7 @@ package handler import ( "encoding/json" "net/http" + "strings" "testing" "github.com/gin-gonic/gin" @@ -71,6 +72,41 @@ func createChatHandlerTestChat(t *testing.T, db *gorm.DB, id, tenantID string) { } } +func TestChatMindMapHandlerSuccess(t *testing.T) { + llm := &fakeChatLLM{response: "# Product\n## Features\n### Search"} + chunks := &mockChunkService{retrievalTestFn: func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) { + return &service.RetrievalTestResponse{ + Chunks: []map[string]interface{}{{"content_with_weight": "Hybrid search combines vector and keyword retrieval."}}, + }, nil + }} + h := NewChatHandler(service.NewChatService(), service.NewUserService()) + h.SetMindMapDependencies(nil, nil, llm, chunks) + c, w := setupGinContextWithUser("POST", "/api/v1/chat/mindmap", `{"question":"What is search?","kb_ids":["kb-1"]}`) + + h.MindMap(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["code"] != float64(common.CodeSuccess) { + t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"]) + } + data := resp["data"].(map[string]interface{}) + if data["id"] != "Product" { + t.Fatalf("mindmap root = %v, want Product", data["id"]) + } + if chunks.LastReq == nil || len(chunks.LastReq.Datasets) != 1 || chunks.LastReq.Datasets[0] != "kb-1" { + t.Fatalf("retrieval datasets = %+v, want [kb-1]", chunks.LastReq) + } + if llm.lastTenantID != "user-1" || len(llm.lastMessages) != 2 || !strings.Contains(llm.lastMessages[0].Content, "Hybrid search combines") { + t.Fatalf("unexpected LLM call: tenant=%q messages=%v", llm.lastTenantID, llm.lastMessages) + } +} + func TestDeleteChatHandlerSuccess(t *testing.T) { db := setupChatHandlerTestDB(t) createChatHandlerTestChat(t, db, "chat-1", "user-1") diff --git a/internal/handler/mindmap.go b/internal/handler/mindmap.go new file mode 100644 index 0000000000..95d2067029 --- /dev/null +++ b/internal/handler/mindmap.go @@ -0,0 +1,291 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + + "ragflow/internal/common" + "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" + "ragflow/internal/service" +) + +type mindMapRunConfig struct { + Question string + KbIDs common.StringSlice + SearchID string + SearchConfig map[string]interface{} + AuthUserID string + ModelTenantID string + ChunkSvc ChunkRetriever + LLM chatLLM + TenantSvc *service.TenantService +} + +func runMindMap(config mindMapRunConfig) (mindMapNode, error) { + if config.ChunkSvc == nil { + return mindMapNode{}, fmt.Errorf("chunk service not configured") + } + if config.LLM == nil { + return mindMapNode{}, fmt.Errorf("LLM not configured") + } + modelTenantID := config.ModelTenantID + if modelTenantID == "" { + modelTenantID = config.AuthUserID + } + retrievalReq := mindMapRetrievalRequest(config.Question, config.KbIDs, config.SearchID, config.SearchConfig) + ranks, err := config.ChunkSvc.RetrievalTest(retrievalReq, config.AuthUserID) + if err != nil { + return mindMapNode{}, err + } + sections := mindMapSections(ranks) + if len(sections) == 0 { + return mindMapNode{ID: "root", Children: []mindMapNode{}}, nil + } + modelID, _ := config.SearchConfig["chat_id"].(string) + if modelID == "" && config.TenantSvc != nil { + defaultModel, err := config.TenantSvc.GetDefaultModelName(modelTenantID, entity.ModelTypeChat) + if err == nil { + modelID = defaultModel + } + } + response, err := config.LLM.Chat(modelTenantID, modelID, []modelModule.Message{{Role: "user", Content: mindMapPrompt(strings.Join(sections, "\n"))}, {Role: "user", Content: "Output:"}}, &modelModule.ChatConfig{}) + if err != nil { + return mindMapNode{}, err + } + if response == nil || response.Answer == nil { + return mindMapNode{ID: "root", Children: []mindMapNode{}}, nil + } + return parseMindMapMarkdown(*response.Answer), nil +} + +func searchConfigFromDetail(detail map[string]interface{}) map[string]interface{} { + if sc, ok := detail["search_config"].(map[string]interface{}); ok && sc != nil { + return sc + } + if sc, ok := detail["search_config"].(entity.JSONMap); ok && sc != nil { + return map[string]interface{}(sc) + } + return map[string]interface{}{} +} + +func mindMapRetrievalRequest(question string, kbIDs common.StringSlice, searchID string, searchConfig map[string]interface{}) *service.RetrievalTestRequest { + page := 1 + size := 12 + topK := intFromConfig(searchConfig, "top_k", 1024) + similarityThreshold := floatFromConfig(searchConfig, "similarity_threshold", 0.2) + vectorSimilarityWeight := floatFromConfig(searchConfig, "vector_similarity_weight", 0.3) + req := &service.RetrievalTestRequest{ + Datasets: kbIDs, + Question: question, + Page: &page, + Size: &size, + TopK: &topK, + SimilarityThreshold: &similarityThreshold, + VectorSimilarityWeight: &vectorSimilarityWeight, + DocIDs: stringSliceFromConfig(searchConfig, "doc_ids"), + Filter: mapFromConfig(searchConfig, "meta_data_filter"), + } + if searchID != "" { + req.SearchID = &searchID + } + if rerankID, _ := searchConfig["rerank_id"].(string); rerankID != "" { + req.RerankID = &rerankID + } + return req +} + +func mindMapSections(ranks *service.RetrievalTestResponse) []string { + if ranks == nil { + return nil + } + sections := make([]string, 0, len(ranks.Chunks)) + for _, chunk := range ranks.Chunks { + if content, ok := chunk["content_with_weight"].(string); ok && strings.TrimSpace(content) != "" { + sections = append(sections, content) + } + } + return sections +} + +func mergeMindMapKbIDs(saved []string, requested common.StringSlice) common.StringSlice { + seen := map[string]bool{} + merged := make(common.StringSlice, 0, len(saved)+len(requested)) + for _, id := range saved { + id = strings.TrimSpace(id) + if id != "" && !seen[id] { + seen[id] = true + merged = append(merged, id) + } + } + for _, id := range requested { + id = strings.TrimSpace(id) + if id != "" && !seen[id] { + seen[id] = true + merged = append(merged, id) + } + } + return merged +} + +func intFromConfig(config map[string]interface{}, key string, fallback int) int { + switch v := config[key].(type) { + case int: + return v + case int64: + return int(v) + case float64: + return int(v) + case json.Number: + if n, err := v.Int64(); err == nil { + return int(n) + } + } + return fallback +} + +func floatFromConfig(config map[string]interface{}, key string, fallback float64) float64 { + switch v := config[key].(type) { + case float64: + return v + case float32: + return float64(v) + case int: + return float64(v) + case int64: + return float64(v) + case json.Number: + if n, err := v.Float64(); err == nil { + return n + } + } + return fallback +} + +func stringSliceFromConfig(config map[string]interface{}, key string) []string { + switch v := config[key].(type) { + case []string: + return v + case []interface{}: + out := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok && s != "" { + out = append(out, s) + } + } + return out + } + return nil +} + +func mapFromConfig(config map[string]interface{}, key string) map[string]interface{} { + if m, ok := config[key].(map[string]interface{}); ok { + return m + } + if m, ok := config[key].(entity.JSONMap); ok { + return map[string]interface{}(m) + } + return nil +} + +func mindMapPrompt(inputText string) string { + return `- Role: You're a talent text processor to summarize a piece of text into a mind map. + +- Step of task: + 1. Generate a title for user's 'TEXT'. + 2. Classify the 'TEXT' into sections of a mind map. + 3. If the subject matter is really complex, split them into sub-sections and sub-subsections. + 4. Add a shot content summary of the bottom level section. + +- Output requirement: + - Generate at least 4 levels. + - Always try to maximize the number of sub-sections. + - In language of 'Text' + - MUST IN FORMAT OF MARKDOWN + +-TEXT- +` + inputText + "\n" +} + +type mindMapNode struct { + ID string `json:"id"` + Children []mindMapNode `json:"children"` +} + +var mindMapHeadingRe = regexp.MustCompile(`^(#{1,6})\s+(.+)$`) +var mindMapListRe = regexp.MustCompile(`^(\s*)(?:[-*+]|\d+\.)\s+(.+)$`) + +func parseMindMapMarkdown(text string) mindMapNode { + lines := strings.Split(strings.ReplaceAll(text, "\r\n", "\n"), "\n") + root := mindMapNode{ID: "root", Children: []mindMapNode{}} + stack := []*mindMapNode{&root} + inFence := false + listBaseLevel := 1 + lastWasList := false + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "```") { + inFence = !inFence + lastWasList = false + continue + } + if inFence || trimmed == "" { + lastWasList = false + continue + } + level := 0 + title := "" + if m := mindMapHeadingRe.FindStringSubmatch(trimmed); len(m) == 3 { + level = len(m[1]) + title = cleanMindMapText(m[2]) + lastWasList = false + } else if m := mindMapListRe.FindStringSubmatch(line); len(m) == 3 { + rawLevel := len(m[1])/2 + 1 + if !lastWasList { + listBaseLevel = len(stack) + } + level = listBaseLevel + rawLevel - 1 + title = cleanMindMapText(m[2]) + lastWasList = true + } + if title == "" { + lastWasList = false + continue + } + for len(stack) > level { + stack = stack[:len(stack)-1] + } + parent := stack[len(stack)-1] + parent.Children = append(parent.Children, mindMapNode{ID: title, Children: []mindMapNode{}}) + stack = append(stack, &parent.Children[len(parent.Children)-1]) + } + if len(root.Children) == 1 { + return root.Children[0] + } + return root +} + +func cleanMindMapText(text string) string { + text = strings.TrimSpace(text) + text = strings.Trim(text, "`") + text = strings.Trim(text, "*_ ") + return strings.TrimSpace(text) +} diff --git a/internal/handler/searchbot.go b/internal/handler/searchbot.go index 9a056f045a..6b289f2704 100644 --- a/internal/handler/searchbot.go +++ b/internal/handler/searchbot.go @@ -35,8 +35,8 @@ import ( "go.uber.org/zap" ) -// searchbotLLM is the interface for LLM calls used by SearchBotHandler. -type searchbotLLM interface { +// chatLLM is the interface for LLM calls used by chat-style handlers. +type chatLLM interface { Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) } @@ -59,12 +59,19 @@ type SearchBotAskRequest struct { SearchID string `json:"search_id,omitempty"` } -// SearchBotRealLLM wraps ModelProviderService to implement searchbotLLM. -type SearchBotRealLLM struct { +// SearchBotMindMapRequest is the request body for POST /api/v1/searchbots/mindmap. +type SearchBotMindMapRequest struct { + Question string `json:"question" binding:"required"` + KbIDs common.StringSlice `json:"kb_ids" binding:"required"` + SearchID string `json:"search_id,omitempty"` +} + +// ModelProviderLLM wraps ModelProviderService to implement chatLLM. +type ModelProviderLLM struct { Svc *service.ModelProviderService } -func (r *SearchBotRealLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) { +func (r *ModelProviderLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) { chatModel, err := r.Svc.GetChatModel(tenantID, modelID) if err != nil { return nil, err @@ -73,7 +80,7 @@ func (r *SearchBotRealLLM) Chat(tenantID, modelID string, messages []modelModule } // ChatStream implements streamingLLM. -func (r *SearchBotRealLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) { +func (r *ModelProviderLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) { chatModel, err := r.Svc.GetChatModel(tenantID, modelID) if err != nil { return nil, err @@ -161,10 +168,11 @@ type SearchBotRequest struct { // POST /api/v1/searchbots/related_questions // POST /api/v1/searchbots/retrieval_test // POST /api/v1/searchbots/ask +// POST /api/v1/searchbots/mindmap type SearchBotHandler struct { searchSvc *service.SearchService tenantSvc *service.TenantService - llm searchbotLLM + llm chatLLM streamLLM streamingLLM chunkSvc ChunkRetriever askSvc *service.AskService @@ -172,7 +180,7 @@ type SearchBotHandler struct { } // NewSearchBotHandler creates a new SearchBotHandler. -func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm searchbotLLM, chunkSvc ChunkRetriever) *SearchBotHandler { +func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm chatLLM, chunkSvc ChunkRetriever) *SearchBotHandler { return &SearchBotHandler{searchSvc: searchSvc, tenantSvc: tenantSvc, llm: llm, chunkSvc: chunkSvc, sseWriter: &ginSSEWriter{}} } @@ -422,6 +430,80 @@ func (h *SearchBotHandler) Ask(c *gin.Context) { } +// MindMap generates a query mind map for a shared search bot. +// @Summary Generate Mind Map +// @Description Retrieves related chunks and asks the configured chat model to summarize them into a mind map. +// @Tags searchbots +// @Accept json +// @Produce json +// @Param request body SearchBotMindMapRequest true "Mind map parameters" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/searchbots/mindmap [post] +func (h *SearchBotHandler) MindMap(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + var req SearchBotMindMapRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": err.Error()}) + return + } + + filtered := make(common.StringSlice, 0, len(req.KbIDs)) + for _, id := range req.KbIDs { + if strings.TrimSpace(id) != "" { + filtered = append(filtered, id) + } + } + if len(filtered) == 0 || strings.TrimSpace(req.Question) == "" { + c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_ids and question are required"}) + return + } + if h.chunkSvc == nil { + jsonInternalError(c, fmt.Errorf("chunk service not configured")) + return + } + if h.llm == nil { + jsonInternalError(c, fmt.Errorf("LLM not configured")) + return + } + + searchConfig := map[string]interface{}{} + if req.SearchID != "" { + if h.searchSvc == nil { + jsonInternalError(c, fmt.Errorf("search service not configured")) + return + } + detail, err := h.searchSvc.GetDetail(req.SearchID) + if err != nil { + jsonInternalError(c, err) + return + } + searchConfig = searchConfigFromDetail(detail) + } + + mindMap, err := runMindMap(mindMapRunConfig{ + Question: req.Question, + KbIDs: filtered, + SearchID: req.SearchID, + SearchConfig: searchConfig, + AuthUserID: user.ID, + ModelTenantID: user.ID, + ChunkSvc: h.chunkSvc, + LLM: h.llm, + TenantSvc: h.tenantSvc, + }) + if err != nil { + common.Warn("searchbot mindmap failed", zap.String("error", err.Error())) + jsonInternalError(c, err) + return + } + jsonResponse(c, common.CodeSuccess, mindMap, "success") +} + // SearchbotDetail returns the public share-page bootstrap payload for a // search app. The route is mounted under apiNoAuth but still requires a beta // token, matching Python's AUTH_BETA flow. diff --git a/internal/handler/searchbot_test.go b/internal/handler/searchbot_test.go index 6ad6557387..b5e299dce0 100644 --- a/internal/handler/searchbot_test.go +++ b/internal/handler/searchbot_test.go @@ -411,13 +411,19 @@ func TestSearchBotsRetrieval_EmptyQuestion(t *testing.T) { } } -// fakeSearchbotLLM implements searchbotLLM for testing. -type fakeSearchbotLLM struct { - response string - err error +// fakeChatLLM implements chatLLM for testing. +type fakeChatLLM struct { + response string + err error + lastTenantID string + lastModelID string + lastMessages []modelModule.Message } -func (f *fakeSearchbotLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) { +func (f *fakeChatLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) { + f.lastTenantID = tenantID + f.lastModelID = modelID + f.lastMessages = messages if f.err != nil { return nil, f.err } @@ -438,7 +444,7 @@ func setupSearchBotRequest(body string) (*gin.Context, *httptest.ResponseRecorde // TestSearchBotHandler_Success verifies the happy path. func TestSearchBotHandler_Success(t *testing.T) { - llm := &fakeSearchbotLLM{ + llm := &fakeChatLLM{ response: `Here are some related questions: 1. How do EV impact environment? 2. What are advantages of EV? @@ -472,7 +478,7 @@ func TestSearchBotHandler_Success(t *testing.T) { // TestSearchBotHandler_EmptyResponse verifies empty LLM response returns empty list. func TestSearchBotHandler_EmptyResponse(t *testing.T) { - llm := &fakeSearchbotLLM{ + llm := &fakeChatLLM{ response: "No related questions found.", } h := NewSearchBotHandler(nil, nil, llm, nil) @@ -496,7 +502,7 @@ func TestSearchBotHandler_EmptyResponse(t *testing.T) { // TestSearchBotHandler_LLMFailure verifies error handling on LLM failure. func TestSearchBotHandler_LLMFailure(t *testing.T) { - llm := &fakeSearchbotLLM{ + llm := &fakeChatLLM{ err: errFake{msg: "LLM unavailable"}, } h := NewSearchBotHandler(nil, nil, llm, nil) @@ -514,7 +520,7 @@ func TestSearchBotHandler_LLMFailure(t *testing.T) { // TestSearchBotHandler_MissingQuestion verifies validation. func TestSearchBotHandler_MissingQuestion(t *testing.T) { - llm := &fakeSearchbotLLM{response: "dummy"} + llm := &fakeChatLLM{response: "dummy"} h := NewSearchBotHandler(nil, nil, llm, nil) c, w := setupSearchBotRequest(`{}`) @@ -906,6 +912,64 @@ func TestAskHandler_WhitespaceKbIDFiltered(t *testing.T) { } } +func TestMindMapHandlerSuccess(t *testing.T) { + llm := &fakeChatLLM{response: "# Product\n## Features\n### Search\n#### Hybrid retrieval"} + chunks := &mockChunkService{retrievalTestFn: func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) { + return &service.RetrievalTestResponse{ + Chunks: []map[string]interface{}{{"content_with_weight": "Hybrid search combines vector and keyword retrieval."}}, + }, nil + }} + h := NewSearchBotHandler(nil, nil, llm, chunks) + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "user-1"}) + }) + r.POST("/api/v1/searchbots/mindmap", h.MindMap) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/searchbots/mindmap", strings.NewReader(`{"question":"What is search?","kb_ids":["kb-1"]}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if resp["code"] != float64(common.CodeSuccess) { + t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"]) + } + data := resp["data"].(map[string]interface{}) + if data["id"] != "Product" { + t.Fatalf("mindmap root = %v, want Product", data["id"]) + } + if chunks.LastReq == nil { + t.Fatal("RetrievalTest was not called") + } + if *chunks.LastReq.Page != 1 || *chunks.LastReq.Size != 12 || *chunks.LastReq.TopK != 1024 { + t.Fatalf("retrieval defaults = page %d size %d topK %d", *chunks.LastReq.Page, *chunks.LastReq.Size, *chunks.LastReq.TopK) + } + if llm.lastTenantID != "user-1" || len(llm.lastMessages) != 2 || !strings.Contains(llm.lastMessages[0].Content, "Hybrid search combines") { + t.Fatalf("unexpected LLM call: tenant=%q messages=%v", llm.lastTenantID, llm.lastMessages) + } +} + +func TestParseMindMapMarkdown_ListUnderHeading(t *testing.T) { + got := parseMindMapMarkdown("# Product\n- Features\n - Search") + if got.ID != "Product" { + t.Fatalf("root = %q, want Product", got.ID) + } + if len(got.Children) != 1 || got.Children[0].ID != "Features" { + t.Fatalf("children = %+v, want Features under Product", got.Children) + } + if len(got.Children[0].Children) != 1 || got.Children[0].Children[0].ID != "Search" { + t.Fatalf("nested children = %+v, want Search under Features", got.Children[0].Children) + } +} + // ---- SSE helper direct tests ---- func TestSseAnswer_Final(t *testing.T) { diff --git a/internal/router/router.go b/internal/router/router.go index 258ac2d6ad..6968ebab23 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -194,6 +194,7 @@ func (r *Router) Setup(engine *gin.Engine) { searchbotGroup.POST("/related_questions", r.searchBotHandler.Handle) searchbotGroup.POST("/retrieval_test", r.searchBotHandler.RetrievalTest) searchbotGroup.POST("/ask", r.searchBotHandler.Ask) + searchbotGroup.POST("/mindmap", r.searchBotHandler.MindMap) if r.botHandler != nil { chatbotGroup := apiBetaAuth.Group("/chatbots") @@ -271,6 +272,7 @@ func (r *Router) Setup(engine *gin.Engine) { } // Chat routes + v1.POST("/chat/mindmap", r.chatHandler.MindMap) chats := v1.Group("/chats") { chats.GET("", r.chatHandler.ListChats) diff --git a/internal/router/router_test.go b/internal/router/router_test.go index 414ac4283a..5b2bfa854d 100644 --- a/internal/router/router_test.go +++ b/internal/router/router_test.go @@ -1,13 +1,16 @@ package router import ( + "encoding/json" "net/http" "net/http/httptest" "testing" "github.com/gin-gonic/gin" + "ragflow/internal/common" "ragflow/internal/handler" + "ragflow/internal/service" ) func TestConnectorRoutesDoNotConflictWithOAuthCallbacks(t *testing.T) { @@ -93,3 +96,62 @@ func TestRouterSetupRegistersUpdateDatasetRoute(t *testing.T) { t.Fatalf("status=%d body=%s; want auth middleware to handle registered UpdateDataset route", resp.Code, resp.Body.String()) } } + +func TestRouterSetupRegistersSearchbotMindMapRoute(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + r := &Router{ + authHandler: handler.NewAuthHandler(), + searchBotHandler: handler.NewSearchBotHandler(nil, nil, nil, nil), + } + r.Setup(engine) + + resp := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/searchbots/mindmap", nil) + engine.ServeHTTP(resp, req) + + if resp.Code == http.StatusNotFound { + t.Fatalf("POST /api/v1/searchbots/mindmap returned 404; MindMap route is not registered") + } + var body struct { + Code common.ErrorCode `json:"code"` + } + if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { + t.Fatalf("failed to decode response body: %v", err) + } + if body.Code != common.CodeUnauthorized { + t.Fatalf("status=%d body=%s; want beta auth middleware to handle registered MindMap route", resp.Code, resp.Body.String()) + } +} + +func TestRouterSetupRegistersChatMindMapRoute(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + r := &Router{ + authHandler: handler.NewAuthHandler(), + chatHandler: handler.NewChatHandler( + service.NewChatService(), + service.NewUserService(), + ), + } + r.Setup(engine) + + resp := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/mindmap", nil) + engine.ServeHTTP(resp, req) + + if resp.Code == http.StatusNotFound { + t.Fatalf("POST /api/v1/chat/mindmap returned 404; Chat MindMap route is not registered") + } + var body struct { + Code common.ErrorCode `json:"code"` + } + if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { + t.Fatalf("failed to decode response body: %v", err) + } + if body.Code != common.CodeUnauthorized { + t.Fatalf("status=%d body=%s; want auth middleware to handle registered Chat MindMap route", resp.Code, resp.Body.String()) + } +}