diff --git a/cmd/server_main.go b/cmd/server_main.go index 8e72aecd0f..6bc08138b1 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -297,18 +297,17 @@ func startServer(config *server.Config) { // no-op echo (the audio package contract), so this is always // safe to call. configureTTSSynthesizer(modelProviderService) - modelProviderLLM := &handler.ModelProviderLLM{Svc: modelProviderService} searchBotHandler := handler.NewSearchBotHandler( searchService, tenantService, - modelProviderLLM, + modelProviderService, chunkService, ) - searchBotHandler.SetStreamLLM(modelProviderLLM) + searchBotHandler.SetStreamLLM(modelProviderService) askService := service.NewAskService(chunkService, nil, 0, 0) searchBotHandler.SetAskService(askService) - chatHandler.SetMindMapDependencies(searchService, tenantService, modelProviderLLM, chunkService) - searchHandler.SetCompletionDependencies(modelProviderLLM, askService) + chatHandler.SetMindMapDependencies(searchService, tenantService, modelProviderService, chunkService) + searchHandler.SetCompletionDependencies(modelProviderService, askService) pluginHandler := handler.NewPluginHandler(service.NewPluginService()) modelHandler := handler.NewModelHandler(service.NewModelProviderService()) fileCommitHandler := handler.NewFileCommitHandler(service.NewFileCommitService()) diff --git a/internal/handler/chat.go b/internal/handler/chat.go index 614c3281be..eba5f41266 100644 --- a/internal/handler/chat.go +++ b/internal/handler/chat.go @@ -35,8 +35,8 @@ type ChatHandler struct { userService *service.UserService searchSvc *service.SearchService tenantSvc *service.TenantService - llm chatLLM - chunkSvc ChunkRetriever + llm *service.ModelProviderService + chunkSvc service.Retriever } // NewChatHandler create chat handler @@ -48,7 +48,7 @@ func NewChatHandler(chatService *service.ChatService, userService *service.UserS } // SetMindMapDependencies sets dependencies used by POST /api/v1/chat/mindmap. -func (h *ChatHandler) SetMindMapDependencies(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm chatLLM, chunkSvc ChunkRetriever) { +func (h *ChatHandler) SetMindMapDependencies(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm *service.ModelProviderService, chunkSvc service.Retriever) { h.searchSvc = searchSvc h.tenantSvc = tenantSvc h.llm = llm diff --git a/internal/handler/chat_recommendation.go b/internal/handler/chat_recommendation.go new file mode 100644 index 0000000000..1650af6581 --- /dev/null +++ b/internal/handler/chat_recommendation.go @@ -0,0 +1,67 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" + + "ragflow/internal/common" + "ragflow/internal/service" +) + +// ChatRecommendationRequest is the request body for POST /api/v1/chat/recommendation. +type ChatRecommendationRequest struct { + Question string `json:"question" binding:"required"` + SearchID string `json:"search_id,omitempty"` +} + +// Recommendation generates related search questions for a chat query. +// @Summary Generate Chat Recommendations +// @Description Generates related questions using the chat model configured by search_config.chat_id or the tenant default. +// @Tags chat +// @Accept json +// @Produce json +// @Param request body ChatRecommendationRequest true "Recommendation parameters" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/chat/recommendation [post] +func (h *ChatHandler) Recommendation(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + var req ChatRecommendationRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "question is required"}) + return + } + if strings.TrimSpace(req.Question) == "" { + c.JSON(http.StatusOK, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "question is required"}) + return + } + questions, err := service.GenerateRelatedQuestions(user.ID, req.Question, req.SearchID, h.searchSvc, h.tenantSvc, h.llm) + if err != nil { + jsonInternalError(c, err) + return + } + + jsonResponse(c, common.CodeSuccess, questions, "success") +} diff --git a/internal/handler/chat_test.go b/internal/handler/chat_test.go index 253efe0d2f..f75a5b07fb 100644 --- a/internal/handler/chat_test.go +++ b/internal/handler/chat_test.go @@ -3,7 +3,6 @@ package handler import ( "encoding/json" "net/http" - "strings" "testing" "github.com/gin-gonic/gin" @@ -72,41 +71,6 @@ func createChatHandlerTestChat(t *testing.T, db *gorm.DB, id, tenantID string) { } } -func TestChatMindMapHandlerSuccess(t *testing.T) { - llm := &fakeChatLLM{response: "# Product\n## Features\n### Search"} - chunks := &mockChunkService{retrievalTestFn: func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) { - return &service.RetrievalTestResponse{ - Chunks: []map[string]interface{}{{"content_with_weight": "Hybrid search combines vector and keyword retrieval."}}, - }, nil - }} - h := NewChatHandler(service.NewChatService(), service.NewUserService()) - h.SetMindMapDependencies(nil, nil, llm, chunks) - c, w := setupGinContextWithUser("POST", "/api/v1/chat/mindmap", `{"question":"What is search?","kb_ids":["kb-1"]}`) - - h.MindMap(c) - - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) - } - var resp map[string]interface{} - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatal(err) - } - if resp["code"] != float64(common.CodeSuccess) { - t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"]) - } - data := resp["data"].(map[string]interface{}) - if data["id"] != "Product" { - t.Fatalf("mindmap root = %v, want Product", data["id"]) - } - if chunks.LastReq == nil || len(chunks.LastReq.Datasets) != 1 || chunks.LastReq.Datasets[0] != "kb-1" { - t.Fatalf("retrieval datasets = %+v, want [kb-1]", chunks.LastReq) - } - if llm.lastTenantID != "user-1" || len(llm.lastMessages) != 2 || !strings.Contains(llm.lastMessages[0].Content, "Hybrid search combines") { - t.Fatalf("unexpected LLM call: tenant=%q messages=%v", llm.lastTenantID, llm.lastMessages) - } -} - func TestDeleteChatHandlerSuccess(t *testing.T) { db := setupChatHandlerTestDB(t) createChatHandlerTestChat(t, db, "chat-1", "user-1") diff --git a/internal/handler/mindmap.go b/internal/handler/mindmap.go index 95d2067029..92ee3edfe4 100644 --- a/internal/handler/mindmap.go +++ b/internal/handler/mindmap.go @@ -35,8 +35,8 @@ type mindMapRunConfig struct { SearchConfig map[string]interface{} AuthUserID string ModelTenantID string - ChunkSvc ChunkRetriever - LLM chatLLM + ChunkSvc service.Retriever + LLM *service.ModelProviderService TenantSvc *service.TenantService } diff --git a/internal/handler/search.go b/internal/handler/search.go index c7db252ecf..ff58928a23 100644 --- a/internal/handler/search.go +++ b/internal/handler/search.go @@ -30,7 +30,7 @@ import ( type SearchHandler struct { searchService *service.SearchService userService *service.UserService - streamLLM streamingLLM + streamLLM *service.ModelProviderService askService *service.AskService sseWriter SSEWriter } @@ -45,7 +45,7 @@ func NewSearchHandler(searchService *service.SearchService, userService *service } // SetCompletionDependencies wires the streaming search completion runtime. -func (h *SearchHandler) SetCompletionDependencies(streamLLM streamingLLM, askService *service.AskService) { +func (h *SearchHandler) SetCompletionDependencies(streamLLM *service.ModelProviderService, askService *service.AskService) { h.streamLLM = streamLLM h.askService = askService } @@ -495,7 +495,7 @@ func (h *SearchHandler) Completion(c *gin.Context) { return } - adapter := &askStreamAdapter{llm: h.streamLLM, tenantID: plan.UserID, modelID: plan.ModelID} + adapter := &service.TenantStreamAdapter{LLM: h.streamLLM, TenantID: plan.UserID, ModelID: plan.ModelID} hadError := false for delta := range h.askService.StreamWithOptions(c.Request.Context(), adapter, plan.UserID, plan.Question, plan.KBIDs, plan.Options) { diff --git a/internal/handler/searchbot.go b/internal/handler/searchbot.go index 6b289f2704..dd386d6c6d 100644 --- a/internal/handler/searchbot.go +++ b/internal/handler/searchbot.go @@ -17,41 +17,21 @@ package handler import ( - "context" "encoding/json" "fmt" "io" "net/http" - "regexp" "strings" "github.com/gin-gonic/gin" "ragflow/internal/common" "ragflow/internal/entity" - modelModule "ragflow/internal/entity/models" "ragflow/internal/service" "go.uber.org/zap" ) -// chatLLM is the interface for LLM calls used by chat-style handlers. -type chatLLM interface { - Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) -} - -// ChunkRetriever abstracts chunk retrieval for the searchbots handler. -type ChunkRetriever interface { - RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) -} - -// streamingLLM abstracts streaming chat for the Ask endpoint. -// The returned channel delivers raw text deltas from the LLM. -// Implementations should respect ctx cancellation to prevent goroutine leaks. -type streamingLLM interface { - ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) -} - // SearchBotAskRequest is the request body for POST /api/v1/searchbots/ask. type SearchBotAskRequest struct { Question string `json:"question" binding:"required"` @@ -66,55 +46,6 @@ type SearchBotMindMapRequest struct { SearchID string `json:"search_id,omitempty"` } -// ModelProviderLLM wraps ModelProviderService to implement chatLLM. -type ModelProviderLLM struct { - Svc *service.ModelProviderService -} - -func (r *ModelProviderLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) { - chatModel, err := r.Svc.GetChatModel(tenantID, modelID) - if err != nil { - return nil, err - } - return chatModel.ModelDriver.ChatWithMessages(*chatModel.ModelName, messages, chatModel.APIConfig, config) -} - -// ChatStream implements streamingLLM. -func (r *ModelProviderLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) { - chatModel, err := r.Svc.GetChatModel(tenantID, modelID) - if err != nil { - return nil, err - } - return chatStreamWithContext(ctx, chatModel, messages, config), nil -} - -// chatStreamWithContext creates a streaming LLM channel that stops sending -// when ctx is cancelled, preventing goroutine leaks on client disconnect. -func chatStreamWithContext(ctx context.Context, chatModel *modelModule.ChatModel, messages []modelModule.Message, config *modelModule.ChatConfig) <-chan string { - ch := make(chan string, 256) - go func() { - defer close(ch) - if err := chatModel.ModelDriver.ChatStreamlyWithSender(*chatModel.ModelName, messages, chatModel.APIConfig, config, - func(delta *string, _ *string) error { - if delta == nil { - return nil - } - select { - case ch <- *delta: - return nil - case <-ctx.Done(): - return ctx.Err() - } - }); err != nil { - if err == context.Canceled || err == context.DeadlineExceeded { - return - } - common.Warn("ChatStreamlyWithSender returned error", zap.Error(err)) - } - }() - return ch -} - // SearchBotRetrievalTestRequest is the request body for POST /api/v1/searchbots/retrieval_test. type SearchBotRetrievalTestRequest struct { KbIDs common.StringSlice `json:"kb_ids" binding:"required"` @@ -172,38 +103,24 @@ type SearchBotRequest struct { type SearchBotHandler struct { searchSvc *service.SearchService tenantSvc *service.TenantService - llm chatLLM - streamLLM streamingLLM - chunkSvc ChunkRetriever + llm *service.ModelProviderService + streamLLM *service.ModelProviderService + chunkSvc service.Retriever askSvc *service.AskService sseWriter SSEWriter } // NewSearchBotHandler creates a new SearchBotHandler. -func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm chatLLM, chunkSvc ChunkRetriever) *SearchBotHandler { +func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm *service.ModelProviderService, chunkSvc service.Retriever) *SearchBotHandler { return &SearchBotHandler{searchSvc: searchSvc, tenantSvc: tenantSvc, llm: llm, chunkSvc: chunkSvc, sseWriter: &ginSSEWriter{}} } // SetStreamLLM sets the streaming LLM for the Ask endpoint. -func (h *SearchBotHandler) SetStreamLLM(llm streamingLLM) { h.streamLLM = llm } +func (h *SearchBotHandler) SetStreamLLM(llm *service.ModelProviderService) { h.streamLLM = llm } // SetAskService sets the AskService used by the Ask endpoint. func (h *SearchBotHandler) SetAskService(svc *service.AskService) { h.askSvc = svc } -// askStreamAdapter adapts handler.streamingLLM to service.StreamingLLM. -type askStreamAdapter struct { - llm streamingLLM - tenantID string - modelID string -} - -func (a *askStreamAdapter) ChatStream(ctx context.Context, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) { - if a.llm == nil { - return nil, fmt.Errorf("streaming LLM not configured") - } - return a.llm.ChatStream(ctx, a.tenantID, a.modelID, messages, config) -} - // Handle generates related search questions based on a user query. // @Summary Generate Related Questions // @Description Generates 5-10 related search questions to expand the search scope. @@ -239,36 +156,9 @@ func (h *SearchBotHandler) Handle(c *gin.Context) { return } - // Resolve model ID from search config if provided - modelID := "" - if req.SearchID != "" && h.searchSvc != nil { - if detail, err := h.searchSvc.GetDetail(req.SearchID); err == nil { - if sc, ok := detail["search_config"].(map[string]interface{}); ok { - if cid, ok := sc["chat_id"].(string); ok && cid != "" { - modelID = cid - } - } - } - } - if modelID == "" && h.tenantSvc != nil { - defaultModel, err := h.tenantSvc.GetDefaultModelName(user.ID, entity.ModelTypeChat) - if err == nil && defaultModel != "" { - modelID = defaultModel - } - } - - messages := []modelModule.Message{ - {Role: "system", Content: relatedQuestionPrompt}, - {Role: "user", Content: "Keywords: " + req.Question + "\nRelated search terms:\n"}, - } - - genConf := &modelModule.ChatConfig{ - Temperature: ptrFloat64(0.9), - } - - response, err := h.llm.Chat(user.ID, modelID, messages, genConf) + questions, err := service.GenerateRelatedQuestions(user.ID, req.Question, req.SearchID, h.searchSvc, h.tenantSvc, h.llm) if err != nil { - common.Warn("searchbot LLM call failed", zap.String("error", err.Error())) + common.Warn("searchbot related questions failed", zap.String("error", err.Error())) c.JSON(http.StatusOK, gin.H{ "code": common.CodeOperatingError, "data": nil, @@ -277,10 +167,6 @@ func (h *SearchBotHandler) Handle(c *gin.Context) { return } - var questions []string - if response != nil && response.Answer != nil { - questions = parseRelatedQuestions(*response.Answer) - } c.JSON(http.StatusOK, gin.H{ "code": common.CodeSuccess, "data": questions, @@ -409,8 +295,12 @@ func (h *SearchBotHandler) Ask(c *gin.Context) { h.sseWriter.Write(c, sseError("ask service not configured")) return } + if h.streamLLM == nil { + h.sseWriter.Write(c, sseError("streaming LLM not configured")) + return + } ctx := c.Request.Context() - adapter := &askStreamAdapter{llm: h.streamLLM, tenantID: user.ID, modelID: modelID} + adapter := &service.TenantStreamAdapter{LLM: h.streamLLM, TenantID: user.ID, ModelID: modelID} for delta := range h.askSvc.Stream(ctx, adapter, user.ID, req.Question, filtered) { switch delta.Kind { case service.AskDeltaAnswer: @@ -607,7 +497,6 @@ func sseMarker(marker string) string { return fmt.Sprintf("data: %s\n\n", string(b)) } -// SSEWriter writes an SSE event to the client. type SSEWriter interface { Write(c *gin.Context, data string) } @@ -683,78 +572,3 @@ func applyRetrievalDefaults(req *SearchBotRetrievalTestRequest) { req.VectorSimilarityWeight = &v } } - -var relatedQuestionLineRe = regexp.MustCompile(`^\d+\.\s`) - -// parseRelatedQuestions extracts numbered list items from an LLM response. -// Lines matching "^N. " are extracted and the number prefix is stripped. -func parseRelatedQuestions(text string) []string { - var result []string - for _, line := range strings.Split(text, "\n") { - if relatedQuestionLineRe.MatchString(line) { - result = append(result, relatedQuestionLineRe.ReplaceAllString(line, "")) - } - } - if result == nil { - return []string{} - } - return result -} - -// relatedQuestionPrompt is the system prompt for generating related search questions. -// Matches Python rag/prompts/related_question.md -const relatedQuestionPrompt = `# Role -You are an AI language model assistant tasked with generating **5-10 related questions** based on a user's original query. -These questions should help **expand the search query scope** and **improve search relevance**. - ---- - -## Instructions - -**Input:** -You are provided with a **user's question**. - -**Output:** -Generate **5-10 alternative questions** that are **related** to the original user question. -These alternatives should help retrieve a **broader range of relevant documents** from a vector database. - -**Context:** -Focus on **rephrasing** the original question in different ways, ensuring the alternative questions are **diverse but still connected** to the topic of the original query. -Do **not** create overly obscure, irrelevant, or unrelated questions. - -**Fallback:** -If you cannot generate any relevant alternatives, do **not** return any questions. - ---- - -## Guidance - -1. Each alternative should be **unique** but still **relevant** to the original query. -2. Keep the phrasing **clear, concise, and easy to understand**. -3. Avoid overly technical jargon or specialized terms **unless directly relevant**. -4. Ensure that each question **broadens** the search angle, **not narrows** it. - ---- - -## Example - -**Original Question:** -> What are the benefits of electric vehicles? - -**Alternative Questions:** -1. How do electric vehicles impact the environment? -2. What are the advantages of owning an electric car? -3. What is the cost-effectiveness of electric vehicles? -4. How do electric vehicles compare to traditional cars in terms of fuel efficiency? -5. What are the environmental benefits of switching to electric cars? -6. How do electric vehicles help reduce carbon emissions? -7. Why are electric vehicles becoming more popular? -8. What are the long-term savings of using electric vehicles? -9. How do electric vehicles contribute to sustainability? -10. What are the key benefits of electric vehicles for consumers? - ---- - -## Reason -Rephrasing the original query into multiple alternative questions helps the user explore **different aspects** of their search topic, improving the **quality of search results**. -These questions guide the search engine to provide a **more comprehensive set** of relevant documents.` diff --git a/internal/handler/searchbot_test.go b/internal/handler/searchbot_test.go index b5e299dce0..04b246d8a5 100644 --- a/internal/handler/searchbot_test.go +++ b/internal/handler/searchbot_test.go @@ -17,29 +17,22 @@ package handler import ( - "context" "encoding/json" "errors" "fmt" - "io" "net/http" "net/http/httptest" "strings" "testing" - "time" "ragflow/internal/common" - "ragflow/internal/dao" "ragflow/internal/entity" - modelModule "ragflow/internal/entity/models" "ragflow/internal/service" "github.com/gin-gonic/gin" - "github.com/glebarez/sqlite" - "gorm.io/gorm" ) -// mockChunkService implements ChunkRetriever for testing. +// mockChunkService implements service.Retriever for testing. // It captures the last request received so tests can verify field mapping. type mockChunkService struct { retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) @@ -411,25 +404,6 @@ func TestSearchBotsRetrieval_EmptyQuestion(t *testing.T) { } } -// fakeChatLLM implements chatLLM for testing. -type fakeChatLLM struct { - response string - err error - lastTenantID string - lastModelID string - lastMessages []modelModule.Message -} - -func (f *fakeChatLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) { - f.lastTenantID = tenantID - f.lastModelID = modelID - f.lastMessages = messages - if f.err != nil { - return nil, f.err - } - return &modelModule.ChatResponse{Answer: &f.response}, nil -} - func setupSearchBotRequest(body string) (*gin.Context, *httptest.ResponseRecorder) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() @@ -442,86 +416,9 @@ func setupSearchBotRequest(body string) (*gin.Context, *httptest.ResponseRecorde return c, w } -// TestSearchBotHandler_Success verifies the happy path. -func TestSearchBotHandler_Success(t *testing.T) { - llm := &fakeChatLLM{ - response: `Here are some related questions: -1. How do EV impact environment? -2. What are advantages of EV? -3. Cost of EV?`, - } - h := NewSearchBotHandler(nil, nil, llm, nil) - - c, w := setupSearchBotRequest(`{"question": "EV benefits"}`) - h.Handle(c) - - var resp map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &resp) - if resp["code"] != float64(common.CodeSuccess) { - t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"]) - } - if msg, _ := resp["message"].(string); msg != "success" { - t.Errorf("expected message 'success', got %q", msg) - } - - questions, ok := resp["data"].([]interface{}) - if !ok { - t.Fatalf("expected data array, got %T", resp["data"]) - } - if len(questions) != 3 { - t.Fatalf("expected 3 questions, got %d", len(questions)) - } - if questions[0] != "How do EV impact environment?" { - t.Errorf("unexpected [0]: %v", questions[0]) - } -} - -// TestSearchBotHandler_EmptyResponse verifies empty LLM response returns empty list. -func TestSearchBotHandler_EmptyResponse(t *testing.T) { - llm := &fakeChatLLM{ - response: "No related questions found.", - } - h := NewSearchBotHandler(nil, nil, llm, nil) - - c, w := setupSearchBotRequest(`{"question": "EV benefits"}`) - h.Handle(c) - - var resp map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &resp) - if resp["code"] != float64(common.CodeSuccess) { - t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"]) - } - questions, ok := resp["data"].([]interface{}) - if !ok { - t.Fatalf("expected data array, got %T", resp["data"]) - } - if len(questions) != 0 { - t.Errorf("expected 0 questions, got %d", len(questions)) - } -} - -// TestSearchBotHandler_LLMFailure verifies error handling on LLM failure. -func TestSearchBotHandler_LLMFailure(t *testing.T) { - llm := &fakeChatLLM{ - err: errFake{msg: "LLM unavailable"}, - } - h := NewSearchBotHandler(nil, nil, llm, nil) - - c, w := setupSearchBotRequest(`{"question": "EV benefits"}`) - h.Handle(c) - - var resp map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &resp) - code, _ := resp["code"].(float64) - if code == 0 { - t.Errorf("expected error code, got 0") - } -} - // TestSearchBotHandler_MissingQuestion verifies validation. func TestSearchBotHandler_MissingQuestion(t *testing.T) { - llm := &fakeChatLLM{response: "dummy"} - h := NewSearchBotHandler(nil, nil, llm, nil) + h := NewSearchBotHandler(nil, nil, nil, nil) c, w := setupSearchBotRequest(`{}`) h.Handle(c) @@ -534,96 +431,11 @@ func TestSearchBotHandler_MissingQuestion(t *testing.T) { } } -// errFake implements error for testing. -type errFake struct{ msg string } - -func (e errFake) Error() string { return e.msg } - -// Existing parse tests below -func TestParseRelatedQuestions_Standard(t *testing.T) { - input := `1. How do electric vehicles impact the environment? -2. What are the advantages of owning an electric car? -3. What is the cost-effectiveness?` - - got := parseRelatedQuestions(input) - if len(got) != 3 { - t.Fatalf("expected 3, got %d", len(got)) - } - if got[0] != "How do electric vehicles impact the environment?" { - t.Errorf("unexpected [0]: %q", got[0]) - } - if got[1] != "What are the advantages of owning an electric car?" { - t.Errorf("unexpected [1]: %q", got[1]) - } - if got[2] != "What is the cost-effectiveness?" { - t.Errorf("unexpected [2]: %q", got[2]) - } -} - -func TestParseRelatedQuestions_Empty(t *testing.T) { - got := parseRelatedQuestions("") - if len(got) != 0 { - t.Errorf("expected 0, got %d", len(got)) - } -} - -func TestParseRelatedQuestions_NoNumberedLines(t *testing.T) { - input := `Here are some related questions: -- First question -- Second question` - - got := parseRelatedQuestions(input) - if len(got) != 0 { - t.Errorf("expected 0, got %d", len(got)) - } -} - -func TestParseRelatedQuestions_MixedContent(t *testing.T) { - input := `Here are some related questions: -1. First related question. -Some explanation text. -2. Second related question. -More text. -3. Third related question.` - - got := parseRelatedQuestions(input) - if len(got) != 3 { - t.Fatalf("expected 3, got %d", len(got)) - } - if got[0] != "First related question." { - t.Errorf("unexpected [0]: %q", got[0]) - } - if got[1] != "Second related question." { - t.Errorf("unexpected [1]: %q", got[1]) - } - if got[2] != "Third related question." { - t.Errorf("unexpected [2]: %q", got[2]) - } -} - -func TestParseRelatedQuestions_MultiDigit(t *testing.T) { - input := `10. Tenth question. -11. Eleventh question.` - - got := parseRelatedQuestions(input) - if len(got) != 2 { - t.Fatalf("expected 2, got %d", len(got)) - } - if got[0] != "Tenth question." { - t.Errorf("unexpected [0]: %q", got[0]) - } - if got[1] != "Eleventh question." { - t.Errorf("unexpected [1]: %q", got[1]) - } -} - // ---- Ask handler tests ---- func TestAskHandler_MissingQuestion(t *testing.T) { - llm := &fakeStreamingLLM{chunks: []string{"answer"}} ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{}} h := NewSearchBotHandler(nil, nil, nil, ret) - h.SetStreamLLM(llm) c, w := cw() c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask", strings.NewReader(`{"kb_ids": ["kb1"]}`)) @@ -636,10 +448,8 @@ func TestAskHandler_MissingQuestion(t *testing.T) { } func TestAskHandler_MissingKbIDs(t *testing.T) { - llm := &fakeStreamingLLM{chunks: []string{"answer"}} ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{}} h := NewSearchBotHandler(nil, nil, nil, ret) - h.SetStreamLLM(llm) c, w := cw() c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask", strings.NewReader(`{"question": "test"}`)) @@ -651,46 +461,6 @@ func TestAskHandler_MissingKbIDs(t *testing.T) { } } -// fakeStreamingLLM implements streamingLLM for testing. -type fakeStreamingLLM struct { - chunks []string - err error - delay time.Duration -} - -func (f *fakeStreamingLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) { - if f.err != nil { - return nil, f.err - } - if f.delay > 0 { - ch := make(chan string) - go func() { - defer close(ch) - for i, chunk := range f.chunks { - if i > 0 { - select { - case <-time.After(f.delay): - case <-ctx.Done(): - return - } - } - select { - case ch <- chunk: - case <-ctx.Done(): - return - } - } - }() - return ch, nil - } - ch := make(chan string, len(f.chunks)+1) - for _, c := range f.chunks { - ch <- c - } - close(ch) - return ch, nil -} - type fakeChunkRetriever struct { result *service.RetrievalTestResponse err error @@ -721,117 +491,14 @@ func (w *bufferSSEWriter) Write(_ *gin.Context, data string) { func (w *bufferSSEWriter) String() string { return w.buf.String() } -func setupAskHandlerTenantDB(t *testing.T) { - t.Helper() - - db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{ - TranslateError: true, - }) - if err != nil { - t.Fatalf("failed to open sqlite: %v", err) - } - sqlDB, err := db.DB() - if err != nil { - t.Fatalf("failed to get sqlite db: %v", err) - } - sqlDB.SetMaxOpenConns(1) - - if err := db.AutoMigrate(&entity.Tenant{}); err != nil { - t.Fatalf("failed to migrate tenant table: %v", err) - } - - status := "1" - name := "Test Tenant" - if err := db.Create(&entity.Tenant{ - ID: "user-1", - Name: &name, - LLMID: "test-model", - EmbdID: "test-embedding", - ASRID: "test-asr", - Img2TxtID: "test-image", - RerankID: "test-rerank", - ParserIDs: "naive", - Status: &status, - }).Error; err != nil { - t.Fatalf("failed to create tenant: %v", err) - } - - orig := dao.DB - dao.DB = db - t.Cleanup(func() { - dao.DB = orig - _ = sqlDB.Close() - }) -} - -func TestAskHandler_DisablesWriteDeadlineForSSE(t *testing.T) { - setupAskHandlerTenantDB(t) - gin.SetMode(gin.TestMode) - - ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{ - Chunks: []map[string]interface{}{ - {"id": "c1", "content_with_weight": "test chunk", "docnm_kwd": "Doc", "kb_id": "kb1", "doc_id": "d1"}, - }, - DocAggs: []map[string]interface{}{{"doc_id": "d1", "count": 1}}, - }} - llm := &fakeStreamingLLM{ - chunks: []string{"first response chunk", "second response chunk"}, - delay: 120 * time.Millisecond, - } - h := NewSearchBotHandler(nil, service.NewTenantService(), nil, ret) - h.SetStreamLLM(llm) - h.SetAskService(service.NewAskService(ret, nil, 0, 1)) - - router := gin.New() - router.Use(func(c *gin.Context) { - c.Set("user", &entity.User{ID: "user-1"}) - }) - router.POST("/api/v1/searchbots/ask", h.Ask) - - server := httptest.NewUnstartedServer(router) - server.Config.WriteTimeout = 30 * time.Millisecond - server.Start() - defer server.Close() - - client := server.Client() - client.Timeout = time.Second - resp, err := client.Post(server.URL+"/api/v1/searchbots/ask", "application/json", - strings.NewReader(`{"question": "test", "kb_ids": ["kb1"]}`)) - if err != nil { - t.Fatalf("post ask stream: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.StatusCode) - } - if contentType := resp.Header.Get("Content-Type"); !strings.Contains(contentType, "text/event-stream") { - t.Fatalf("expected SSE content type, got %q", contentType) - } - - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read ask stream body: %v", err) - } - - body := string(bodyBytes) - for _, want := range []string{"first response chunk", "second response chunk"} { - if !strings.Contains(body, want) { - t.Fatalf("stream body missing %q: %q", want, body) - } - } -} - // ---- Ask handler tests ---- func TestAskHandler_EmptyQuestion(t *testing.T) { - llm := &fakeStreamingLLM{chunks: []string{"answer"}} ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{ Chunks: []map[string]interface{}{{"id": "c1", "content_with_weight": "test"}}, }} h := NewSearchBotHandler(nil, nil, nil, ret) - h.SetStreamLLM(llm) c, w := cw() c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask", strings.NewReader(`{"question": " ", "kb_ids": ["kb1"]}`)) @@ -844,10 +511,8 @@ func TestAskHandler_EmptyQuestion(t *testing.T) { } func TestAskHandler_EmptyKbIDs(t *testing.T) { - llm := &fakeStreamingLLM{chunks: []string{"answer"}} ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{}} h := NewSearchBotHandler(nil, nil, nil, ret) - h.SetStreamLLM(llm) c, w := cw() c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask", strings.NewReader(`{"question": "test", "kb_ids": []}`)) @@ -861,13 +526,11 @@ func TestAskHandler_EmptyKbIDs(t *testing.T) { func TestAskHandler_NoChatModel(t *testing.T) { buf := &bufferSSEWriter{} - llm := &fakeStreamingLLM{chunks: []string{"answer"}} ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{ Chunks: []map[string]interface{}{{"id": "c1", "content_with_weight": "test"}}, }} h := NewSearchBotHandler(nil, nil, nil, ret) h.sseWriter = buf - h.SetStreamLLM(llm) c, _ := cw() c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask", strings.NewReader(`{"question": "test", "kb_ids": ["kb1"]}`)) @@ -881,10 +544,8 @@ func TestAskHandler_NoChatModel(t *testing.T) { } func TestAskHandler_InvalidJSON(t *testing.T) { - llm := &fakeStreamingLLM{chunks: []string{"answer"}} ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{}} h := NewSearchBotHandler(nil, nil, nil, ret) - h.SetStreamLLM(llm) c, w := cw() c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask", strings.NewReader(`not json`)) @@ -897,10 +558,8 @@ func TestAskHandler_InvalidJSON(t *testing.T) { } func TestAskHandler_WhitespaceKbIDFiltered(t *testing.T) { - llm := &fakeStreamingLLM{chunks: []string{"answer"}} ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{}} h := NewSearchBotHandler(nil, nil, nil, ret) - h.SetStreamLLM(llm) c, w := cw() c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask", strings.NewReader(`{"question": "test", "kb_ids": [" ", ""]}`)) @@ -912,51 +571,6 @@ func TestAskHandler_WhitespaceKbIDFiltered(t *testing.T) { } } -func TestMindMapHandlerSuccess(t *testing.T) { - llm := &fakeChatLLM{response: "# Product\n## Features\n### Search\n#### Hybrid retrieval"} - chunks := &mockChunkService{retrievalTestFn: func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) { - return &service.RetrievalTestResponse{ - Chunks: []map[string]interface{}{{"content_with_weight": "Hybrid search combines vector and keyword retrieval."}}, - }, nil - }} - h := NewSearchBotHandler(nil, nil, llm, chunks) - gin.SetMode(gin.TestMode) - r := gin.New() - r.Use(func(c *gin.Context) { - c.Set("user", &entity.User{ID: "user-1"}) - }) - r.POST("/api/v1/searchbots/mindmap", h.MindMap) - - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/api/v1/searchbots/mindmap", strings.NewReader(`{"question":"What is search?","kb_ids":["kb-1"]}`)) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) - } - var resp map[string]interface{} - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatal(err) - } - if resp["code"] != float64(common.CodeSuccess) { - t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"]) - } - data := resp["data"].(map[string]interface{}) - if data["id"] != "Product" { - t.Fatalf("mindmap root = %v, want Product", data["id"]) - } - if chunks.LastReq == nil { - t.Fatal("RetrievalTest was not called") - } - if *chunks.LastReq.Page != 1 || *chunks.LastReq.Size != 12 || *chunks.LastReq.TopK != 1024 { - t.Fatalf("retrieval defaults = page %d size %d topK %d", *chunks.LastReq.Page, *chunks.LastReq.Size, *chunks.LastReq.TopK) - } - if llm.lastTenantID != "user-1" || len(llm.lastMessages) != 2 || !strings.Contains(llm.lastMessages[0].Content, "Hybrid search combines") { - t.Fatalf("unexpected LLM call: tenant=%q messages=%v", llm.lastTenantID, llm.lastMessages) - } -} - func TestParseMindMapMarkdown_ListUnderHeading(t *testing.T) { got := parseMindMapMarkdown("# Product\n- Features\n - Search") if got.ID != "Product" { diff --git a/internal/router/router.go b/internal/router/router.go index 08dd2634c9..a6700de896 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -275,6 +275,7 @@ func (r *Router) Setup(engine *gin.Engine) { // Chat routes v1.POST("/chat/mindmap", r.chatHandler.MindMap) + v1.POST("/chat/recommendation", r.chatHandler.Recommendation) chats := v1.Group("/chats") { chats.GET("", r.chatHandler.ListChats) diff --git a/internal/service/model_chat.go b/internal/service/model_chat.go new file mode 100644 index 0000000000..5e83e1c1f0 --- /dev/null +++ b/internal/service/model_chat.go @@ -0,0 +1,82 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "context" + "fmt" + + "go.uber.org/zap" + + "ragflow/internal/common" + modelModule "ragflow/internal/entity/models" +) + +func (m *ModelProviderService) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) { + chatModel, err := m.GetChatModel(tenantID, modelID) + if err != nil { + return nil, err + } + return chatModel.ModelDriver.ChatWithMessages(*chatModel.ModelName, messages, chatModel.APIConfig, config) +} + +func (m *ModelProviderService) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) { + chatModel, err := m.GetChatModel(tenantID, modelID) + if err != nil { + return nil, err + } + return chatStreamWithContext(ctx, chatModel, messages, config), nil +} + +func chatStreamWithContext(ctx context.Context, chatModel *modelModule.ChatModel, messages []modelModule.Message, config *modelModule.ChatConfig) <-chan string { + ch := make(chan string, 256) + go func() { + defer close(ch) + if err := chatModel.ModelDriver.ChatStreamlyWithSender(*chatModel.ModelName, messages, chatModel.APIConfig, config, + func(delta *string, _ *string) error { + if delta == nil { + return nil + } + select { + case ch <- *delta: + return nil + case <-ctx.Done(): + return ctx.Err() + } + }); err != nil { + if err == context.Canceled || err == context.DeadlineExceeded { + return + } + common.Warn("ChatStreamlyWithSender returned error", zap.Error(err)) + } + }() + return ch +} + +// TenantStreamAdapter adapts tenant/model-aware chat streaming to AskService. +type TenantStreamAdapter struct { + LLM *ModelProviderService + TenantID string + ModelID string +} + +func (a *TenantStreamAdapter) ChatStream(ctx context.Context, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) { + if a.LLM == nil { + return nil, fmt.Errorf("streaming LLM not configured") + } + return a.LLM.ChatStream(ctx, a.TenantID, a.ModelID, messages, config) +} diff --git a/internal/service/related_question.go b/internal/service/related_question.go new file mode 100644 index 0000000000..ddb99749cd --- /dev/null +++ b/internal/service/related_question.go @@ -0,0 +1,213 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + + "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" +) + +// GenerateRelatedQuestions generates related search questions for chat/searchbot endpoints. +func GenerateRelatedQuestions(tenantID, question, searchID string, searchSvc *SearchService, tenantSvc *TenantService, modelProviderSvc *ModelProviderService) ([]string, error) { + if modelProviderSvc == nil { + return nil, fmt.Errorf("model provider service not configured") + } + searchConfig := relatedQuestionsSearchConfig(searchID, searchSvc) + modelID := relatedQuestionsModelID(tenantID, searchConfig, tenantSvc) + prompt, err := LoadPrompt("related_question") + if err != nil { + return nil, err + } + messages := []modelModule.Message{ + {Role: "system", Content: prompt}, + {Role: "user", Content: "\nKeywords: " + question + "\nRelated search terms:\n "}, + } + response, err := modelProviderSvc.Chat(tenantID, modelID, messages, relatedQuestionsConfig(searchConfig)) + if err != nil { + return nil, err + } + if response != nil && response.Answer != nil { + return parseRelatedQuestions(*response.Answer), nil + } + return []string{}, nil +} + +func relatedQuestionsSearchConfig(searchID string, searchSvc *SearchService) map[string]interface{} { + if searchID == "" || searchSvc == nil { + return map[string]interface{}{} + } + if detail, err := searchSvc.GetDetail(searchID); err == nil && detail != nil { + return relatedQuestionsSearchConfigFromDetail(detail) + } + return map[string]interface{}{} +} + +func relatedQuestionsSearchConfigFromDetail(detail map[string]interface{}) map[string]interface{} { + if sc, ok := detail["search_config"].(map[string]interface{}); ok && sc != nil { + return sc + } + if sc, ok := detail["search_config"].(entity.JSONMap); ok && sc != nil { + return map[string]interface{}(sc) + } + return map[string]interface{}{} +} + +func relatedQuestionsModelID(tenantID string, searchConfig map[string]interface{}, tenantSvc *TenantService) string { + modelID, _ := searchConfig["chat_id"].(string) + if modelID != "" || tenantSvc == nil { + return modelID + } + defaultModel, err := tenantSvc.GetDefaultModelName(tenantID, entity.ModelTypeChat) + if err == nil { + modelID = defaultModel + } + return modelID +} + +func relatedQuestionsConfig(searchConfig map[string]interface{}) *modelModule.ChatConfig { + var genConf map[string]interface{} + switch v := searchConfig["llm_setting"].(type) { + case map[string]interface{}: + genConf = v + case entity.JSONMap: + genConf = map[string]interface{}(v) + } + if genConf == nil { + return &modelModule.ChatConfig{Temperature: float64Ptr(0.9)} + } + cfg := &modelModule.ChatConfig{} + for key, value := range genConf { + if key == "parameter" { + continue + } + switch key { + case "stream": + if v, ok := value.(bool); ok { + cfg.Stream = &v + } + case "thinking": + if v, ok := value.(bool); ok { + cfg.Thinking = &v + } + case "max_tokens": + if v, ok := intFromRelatedQuestionConfig(value); ok { + cfg.MaxTokens = &v + } + case "temperature": + if v, ok := floatFromRelatedQuestionConfig(value); ok { + cfg.Temperature = &v + } + case "top_p": + if v, ok := floatFromRelatedQuestionConfig(value); ok && v > 0 { + cfg.TopP = &v + } + case "do_sample": + if v, ok := value.(bool); ok { + cfg.DoSample = &v + } + case "stop": + if stops := stringSliceFromRelatedQuestionConfig(value); len(stops) > 0 { + cfg.Stop = &stops + } + case "model_class": + if v, ok := value.(string); ok { + cfg.ModelClass = &v + } + case "effort": + if v, ok := value.(string); ok { + cfg.Effort = &v + } + case "verbosity": + if v, ok := value.(string); ok { + cfg.Verbosity = &v + } + } + } + return cfg +} + +func intFromRelatedQuestionConfig(value interface{}) (int, bool) { + switch v := value.(type) { + case int: + return v, true + case int64: + return int(v), true + case float64: + return int(v), true + case json.Number: + if n, err := v.Int64(); err == nil { + return int(n), true + } + } + return 0, false +} + +func floatFromRelatedQuestionConfig(value interface{}) (float64, bool) { + switch v := value.(type) { + case float64: + return v, true + case float32: + return float64(v), true + case int: + return float64(v), true + case int64: + return float64(v), true + case json.Number: + if n, err := v.Float64(); err == nil { + return n, true + } + } + return 0, false +} + +func stringSliceFromRelatedQuestionConfig(value interface{}) []string { + switch v := value.(type) { + case []string: + return v + case []interface{}: + result := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok { + result = append(result, s) + } + } + return result + } + return nil +} + +var relatedQuestionLineRe = regexp.MustCompile(`^\d+\.\s`) + +func parseRelatedQuestions(text string) []string { + var result []string + for _, line := range strings.Split(text, "\n") { + if relatedQuestionLineRe.MatchString(line) { + result = append(result, relatedQuestionLineRe.ReplaceAllString(line, "")) + } + } + if result == nil { + return []string{} + } + return result +} + +func float64Ptr(v float64) *float64 { return &v } diff --git a/internal/service/related_question_test.go b/internal/service/related_question_test.go new file mode 100644 index 0000000000..47f2bcbdfc --- /dev/null +++ b/internal/service/related_question_test.go @@ -0,0 +1,97 @@ +package service + +import "testing" + +func TestRelatedQuestionsConfigSkipsZeroTopP(t *testing.T) { + cfg := relatedQuestionsConfig(map[string]interface{}{ + "llm_setting": map[string]interface{}{ + "temperature": float64(0.2), + "top_p": float64(0), + "parameter": map[string]interface{}{"unused": true}, + }, + }) + + if cfg == nil || cfg.Temperature == nil || *cfg.Temperature != 0.2 { + t.Fatalf("expected temperature 0.2, got %+v", cfg) + } + if cfg.TopP != nil { + t.Fatalf("expected zero top_p to be omitted, got %v", *cfg.TopP) + } +} + +func TestParseRelatedQuestionsStandard(t *testing.T) { + input := `1. How do electric vehicles impact the environment? +2. What are the advantages of owning an electric car? +3. What is the cost-effectiveness?` + + got := parseRelatedQuestions(input) + if len(got) != 3 { + t.Fatalf("expected 3, got %d", len(got)) + } + if got[0] != "How do electric vehicles impact the environment?" { + t.Errorf("unexpected [0]: %q", got[0]) + } + if got[1] != "What are the advantages of owning an electric car?" { + t.Errorf("unexpected [1]: %q", got[1]) + } + if got[2] != "What is the cost-effectiveness?" { + t.Errorf("unexpected [2]: %q", got[2]) + } +} + +func TestParseRelatedQuestionsEmpty(t *testing.T) { + got := parseRelatedQuestions("") + if len(got) != 0 { + t.Errorf("expected 0, got %d", len(got)) + } +} + +func TestParseRelatedQuestionsNoNumberedLines(t *testing.T) { + input := `Here are some related questions: +- First question +- Second question` + + got := parseRelatedQuestions(input) + if len(got) != 0 { + t.Errorf("expected 0, got %d", len(got)) + } +} + +func TestParseRelatedQuestionsMixedContent(t *testing.T) { + input := `Here are some related questions: +1. First related question. +Some explanation text. +2. Second related question. +More text. +3. Third related question.` + + got := parseRelatedQuestions(input) + if len(got) != 3 { + t.Fatalf("expected 3, got %d", len(got)) + } + if got[0] != "First related question." { + t.Errorf("unexpected [0]: %q", got[0]) + } + if got[1] != "Second related question." { + t.Errorf("unexpected [1]: %q", got[1]) + } + if got[2] != "Third related question." { + t.Errorf("unexpected [2]: %q", got[2]) + } +} + +func TestParseRelatedQuestionsMultiDigit(t *testing.T) { + input := `10. Tenth question. +11. Eleventh question.` + + got := parseRelatedQuestions(input) + if len(got) != 2 { + t.Fatalf("expected 2, got %d", len(got)) + } + if got[0] != "Tenth question." { + t.Errorf("unexpected [0]: %q", got[0]) + } + if got[1] != "Eleventh question." { + t.Errorf("unexpected [1]: %q", got[1]) + } +}