mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 01:01:56 +08:00
[Go] Add API /api/v1/chat/recommendation and consolidate with /api/v1/searchbots/related_questions (#16500)
This commit is contained in:
@@ -297,18 +297,17 @@ func startServer(config *server.Config) {
|
||||
// no-op echo (the audio package contract), so this is always
|
||||
// safe to call.
|
||||
configureTTSSynthesizer(modelProviderService)
|
||||
modelProviderLLM := &handler.ModelProviderLLM{Svc: modelProviderService}
|
||||
searchBotHandler := handler.NewSearchBotHandler(
|
||||
searchService,
|
||||
tenantService,
|
||||
modelProviderLLM,
|
||||
modelProviderService,
|
||||
chunkService,
|
||||
)
|
||||
searchBotHandler.SetStreamLLM(modelProviderLLM)
|
||||
searchBotHandler.SetStreamLLM(modelProviderService)
|
||||
askService := service.NewAskService(chunkService, nil, 0, 0)
|
||||
searchBotHandler.SetAskService(askService)
|
||||
chatHandler.SetMindMapDependencies(searchService, tenantService, modelProviderLLM, chunkService)
|
||||
searchHandler.SetCompletionDependencies(modelProviderLLM, askService)
|
||||
chatHandler.SetMindMapDependencies(searchService, tenantService, modelProviderService, chunkService)
|
||||
searchHandler.SetCompletionDependencies(modelProviderService, askService)
|
||||
pluginHandler := handler.NewPluginHandler(service.NewPluginService())
|
||||
modelHandler := handler.NewModelHandler(service.NewModelProviderService())
|
||||
fileCommitHandler := handler.NewFileCommitHandler(service.NewFileCommitService())
|
||||
|
||||
@@ -35,8 +35,8 @@ type ChatHandler struct {
|
||||
userService *service.UserService
|
||||
searchSvc *service.SearchService
|
||||
tenantSvc *service.TenantService
|
||||
llm chatLLM
|
||||
chunkSvc ChunkRetriever
|
||||
llm *service.ModelProviderService
|
||||
chunkSvc service.Retriever
|
||||
}
|
||||
|
||||
// NewChatHandler create chat handler
|
||||
@@ -48,7 +48,7 @@ func NewChatHandler(chatService *service.ChatService, userService *service.UserS
|
||||
}
|
||||
|
||||
// SetMindMapDependencies sets dependencies used by POST /api/v1/chat/mindmap.
|
||||
func (h *ChatHandler) SetMindMapDependencies(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm chatLLM, chunkSvc ChunkRetriever) {
|
||||
func (h *ChatHandler) SetMindMapDependencies(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm *service.ModelProviderService, chunkSvc service.Retriever) {
|
||||
h.searchSvc = searchSvc
|
||||
h.tenantSvc = tenantSvc
|
||||
h.llm = llm
|
||||
|
||||
67
internal/handler/chat_recommendation.go
Normal file
67
internal/handler/chat_recommendation.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
// ChatRecommendationRequest is the request body for POST /api/v1/chat/recommendation.
|
||||
type ChatRecommendationRequest struct {
|
||||
Question string `json:"question" binding:"required"`
|
||||
SearchID string `json:"search_id,omitempty"`
|
||||
}
|
||||
|
||||
// Recommendation generates related search questions for a chat query.
|
||||
// @Summary Generate Chat Recommendations
|
||||
// @Description Generates related questions using the chat model configured by search_config.chat_id or the tenant default.
|
||||
// @Tags chat
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body ChatRecommendationRequest true "Recommendation parameters"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/chat/recommendation [post]
|
||||
func (h *ChatHandler) Recommendation(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
var req ChatRecommendationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "question is required"})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Question) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "question is required"})
|
||||
return
|
||||
}
|
||||
questions, err := service.GenerateRelatedQuestions(user.ID, req.Question, req.SearchID, h.searchSvc, h.tenantSvc, h.llm)
|
||||
if err != nil {
|
||||
jsonInternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse(c, common.CodeSuccess, questions, "success")
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package handler
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -72,41 +71,6 @@ func createChatHandlerTestChat(t *testing.T, db *gorm.DB, id, tenantID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatMindMapHandlerSuccess(t *testing.T) {
|
||||
llm := &fakeChatLLM{response: "# Product\n## Features\n### Search"}
|
||||
chunks := &mockChunkService{retrievalTestFn: func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
return &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{{"content_with_weight": "Hybrid search combines vector and keyword retrieval."}},
|
||||
}, nil
|
||||
}}
|
||||
h := NewChatHandler(service.NewChatService(), service.NewUserService())
|
||||
h.SetMindMapDependencies(nil, nil, llm, chunks)
|
||||
c, w := setupGinContextWithUser("POST", "/api/v1/chat/mindmap", `{"question":"What is search?","kb_ids":["kb-1"]}`)
|
||||
|
||||
h.MindMap(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"])
|
||||
}
|
||||
data := resp["data"].(map[string]interface{})
|
||||
if data["id"] != "Product" {
|
||||
t.Fatalf("mindmap root = %v, want Product", data["id"])
|
||||
}
|
||||
if chunks.LastReq == nil || len(chunks.LastReq.Datasets) != 1 || chunks.LastReq.Datasets[0] != "kb-1" {
|
||||
t.Fatalf("retrieval datasets = %+v, want [kb-1]", chunks.LastReq)
|
||||
}
|
||||
if llm.lastTenantID != "user-1" || len(llm.lastMessages) != 2 || !strings.Contains(llm.lastMessages[0].Content, "Hybrid search combines") {
|
||||
t.Fatalf("unexpected LLM call: tenant=%q messages=%v", llm.lastTenantID, llm.lastMessages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteChatHandlerSuccess(t *testing.T) {
|
||||
db := setupChatHandlerTestDB(t)
|
||||
createChatHandlerTestChat(t, db, "chat-1", "user-1")
|
||||
|
||||
@@ -35,8 +35,8 @@ type mindMapRunConfig struct {
|
||||
SearchConfig map[string]interface{}
|
||||
AuthUserID string
|
||||
ModelTenantID string
|
||||
ChunkSvc ChunkRetriever
|
||||
LLM chatLLM
|
||||
ChunkSvc service.Retriever
|
||||
LLM *service.ModelProviderService
|
||||
TenantSvc *service.TenantService
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ import (
|
||||
type SearchHandler struct {
|
||||
searchService *service.SearchService
|
||||
userService *service.UserService
|
||||
streamLLM streamingLLM
|
||||
streamLLM *service.ModelProviderService
|
||||
askService *service.AskService
|
||||
sseWriter SSEWriter
|
||||
}
|
||||
@@ -45,7 +45,7 @@ func NewSearchHandler(searchService *service.SearchService, userService *service
|
||||
}
|
||||
|
||||
// SetCompletionDependencies wires the streaming search completion runtime.
|
||||
func (h *SearchHandler) SetCompletionDependencies(streamLLM streamingLLM, askService *service.AskService) {
|
||||
func (h *SearchHandler) SetCompletionDependencies(streamLLM *service.ModelProviderService, askService *service.AskService) {
|
||||
h.streamLLM = streamLLM
|
||||
h.askService = askService
|
||||
}
|
||||
@@ -495,7 +495,7 @@ func (h *SearchHandler) Completion(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
adapter := &askStreamAdapter{llm: h.streamLLM, tenantID: plan.UserID, modelID: plan.ModelID}
|
||||
adapter := &service.TenantStreamAdapter{LLM: h.streamLLM, TenantID: plan.UserID, ModelID: plan.ModelID}
|
||||
|
||||
hadError := false
|
||||
for delta := range h.askService.StreamWithOptions(c.Request.Context(), adapter, plan.UserID, plan.Question, plan.KBIDs, plan.Options) {
|
||||
|
||||
@@ -17,41 +17,21 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
"ragflow/internal/service"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// chatLLM is the interface for LLM calls used by chat-style handlers.
|
||||
type chatLLM interface {
|
||||
Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error)
|
||||
}
|
||||
|
||||
// ChunkRetriever abstracts chunk retrieval for the searchbots handler.
|
||||
type ChunkRetriever interface {
|
||||
RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
||||
}
|
||||
|
||||
// streamingLLM abstracts streaming chat for the Ask endpoint.
|
||||
// The returned channel delivers raw text deltas from the LLM.
|
||||
// Implementations should respect ctx cancellation to prevent goroutine leaks.
|
||||
type streamingLLM interface {
|
||||
ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error)
|
||||
}
|
||||
|
||||
// SearchBotAskRequest is the request body for POST /api/v1/searchbots/ask.
|
||||
type SearchBotAskRequest struct {
|
||||
Question string `json:"question" binding:"required"`
|
||||
@@ -66,55 +46,6 @@ type SearchBotMindMapRequest struct {
|
||||
SearchID string `json:"search_id,omitempty"`
|
||||
}
|
||||
|
||||
// ModelProviderLLM wraps ModelProviderService to implement chatLLM.
|
||||
type ModelProviderLLM struct {
|
||||
Svc *service.ModelProviderService
|
||||
}
|
||||
|
||||
func (r *ModelProviderLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
|
||||
chatModel, err := r.Svc.GetChatModel(tenantID, modelID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return chatModel.ModelDriver.ChatWithMessages(*chatModel.ModelName, messages, chatModel.APIConfig, config)
|
||||
}
|
||||
|
||||
// ChatStream implements streamingLLM.
|
||||
func (r *ModelProviderLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
|
||||
chatModel, err := r.Svc.GetChatModel(tenantID, modelID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return chatStreamWithContext(ctx, chatModel, messages, config), nil
|
||||
}
|
||||
|
||||
// chatStreamWithContext creates a streaming LLM channel that stops sending
|
||||
// when ctx is cancelled, preventing goroutine leaks on client disconnect.
|
||||
func chatStreamWithContext(ctx context.Context, chatModel *modelModule.ChatModel, messages []modelModule.Message, config *modelModule.ChatConfig) <-chan string {
|
||||
ch := make(chan string, 256)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
if err := chatModel.ModelDriver.ChatStreamlyWithSender(*chatModel.ModelName, messages, chatModel.APIConfig, config,
|
||||
func(delta *string, _ *string) error {
|
||||
if delta == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case ch <- *delta:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}); err != nil {
|
||||
if err == context.Canceled || err == context.DeadlineExceeded {
|
||||
return
|
||||
}
|
||||
common.Warn("ChatStreamlyWithSender returned error", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
// SearchBotRetrievalTestRequest is the request body for POST /api/v1/searchbots/retrieval_test.
|
||||
type SearchBotRetrievalTestRequest struct {
|
||||
KbIDs common.StringSlice `json:"kb_ids" binding:"required"`
|
||||
@@ -172,38 +103,24 @@ type SearchBotRequest struct {
|
||||
type SearchBotHandler struct {
|
||||
searchSvc *service.SearchService
|
||||
tenantSvc *service.TenantService
|
||||
llm chatLLM
|
||||
streamLLM streamingLLM
|
||||
chunkSvc ChunkRetriever
|
||||
llm *service.ModelProviderService
|
||||
streamLLM *service.ModelProviderService
|
||||
chunkSvc service.Retriever
|
||||
askSvc *service.AskService
|
||||
sseWriter SSEWriter
|
||||
}
|
||||
|
||||
// NewSearchBotHandler creates a new SearchBotHandler.
|
||||
func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm chatLLM, chunkSvc ChunkRetriever) *SearchBotHandler {
|
||||
func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm *service.ModelProviderService, chunkSvc service.Retriever) *SearchBotHandler {
|
||||
return &SearchBotHandler{searchSvc: searchSvc, tenantSvc: tenantSvc, llm: llm, chunkSvc: chunkSvc, sseWriter: &ginSSEWriter{}}
|
||||
}
|
||||
|
||||
// SetStreamLLM sets the streaming LLM for the Ask endpoint.
|
||||
func (h *SearchBotHandler) SetStreamLLM(llm streamingLLM) { h.streamLLM = llm }
|
||||
func (h *SearchBotHandler) SetStreamLLM(llm *service.ModelProviderService) { h.streamLLM = llm }
|
||||
|
||||
// SetAskService sets the AskService used by the Ask endpoint.
|
||||
func (h *SearchBotHandler) SetAskService(svc *service.AskService) { h.askSvc = svc }
|
||||
|
||||
// askStreamAdapter adapts handler.streamingLLM to service.StreamingLLM.
|
||||
type askStreamAdapter struct {
|
||||
llm streamingLLM
|
||||
tenantID string
|
||||
modelID string
|
||||
}
|
||||
|
||||
func (a *askStreamAdapter) ChatStream(ctx context.Context, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
|
||||
if a.llm == nil {
|
||||
return nil, fmt.Errorf("streaming LLM not configured")
|
||||
}
|
||||
return a.llm.ChatStream(ctx, a.tenantID, a.modelID, messages, config)
|
||||
}
|
||||
|
||||
// Handle generates related search questions based on a user query.
|
||||
// @Summary Generate Related Questions
|
||||
// @Description Generates 5-10 related search questions to expand the search scope.
|
||||
@@ -239,36 +156,9 @@ func (h *SearchBotHandler) Handle(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve model ID from search config if provided
|
||||
modelID := ""
|
||||
if req.SearchID != "" && h.searchSvc != nil {
|
||||
if detail, err := h.searchSvc.GetDetail(req.SearchID); err == nil {
|
||||
if sc, ok := detail["search_config"].(map[string]interface{}); ok {
|
||||
if cid, ok := sc["chat_id"].(string); ok && cid != "" {
|
||||
modelID = cid
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if modelID == "" && h.tenantSvc != nil {
|
||||
defaultModel, err := h.tenantSvc.GetDefaultModelName(user.ID, entity.ModelTypeChat)
|
||||
if err == nil && defaultModel != "" {
|
||||
modelID = defaultModel
|
||||
}
|
||||
}
|
||||
|
||||
messages := []modelModule.Message{
|
||||
{Role: "system", Content: relatedQuestionPrompt},
|
||||
{Role: "user", Content: "Keywords: " + req.Question + "\nRelated search terms:\n"},
|
||||
}
|
||||
|
||||
genConf := &modelModule.ChatConfig{
|
||||
Temperature: ptrFloat64(0.9),
|
||||
}
|
||||
|
||||
response, err := h.llm.Chat(user.ID, modelID, messages, genConf)
|
||||
questions, err := service.GenerateRelatedQuestions(user.ID, req.Question, req.SearchID, h.searchSvc, h.tenantSvc, h.llm)
|
||||
if err != nil {
|
||||
common.Warn("searchbot LLM call failed", zap.String("error", err.Error()))
|
||||
common.Warn("searchbot related questions failed", zap.String("error", err.Error()))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeOperatingError,
|
||||
"data": nil,
|
||||
@@ -277,10 +167,6 @@ func (h *SearchBotHandler) Handle(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var questions []string
|
||||
if response != nil && response.Answer != nil {
|
||||
questions = parseRelatedQuestions(*response.Answer)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": questions,
|
||||
@@ -409,8 +295,12 @@ func (h *SearchBotHandler) Ask(c *gin.Context) {
|
||||
h.sseWriter.Write(c, sseError("ask service not configured"))
|
||||
return
|
||||
}
|
||||
if h.streamLLM == nil {
|
||||
h.sseWriter.Write(c, sseError("streaming LLM not configured"))
|
||||
return
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
adapter := &askStreamAdapter{llm: h.streamLLM, tenantID: user.ID, modelID: modelID}
|
||||
adapter := &service.TenantStreamAdapter{LLM: h.streamLLM, TenantID: user.ID, ModelID: modelID}
|
||||
for delta := range h.askSvc.Stream(ctx, adapter, user.ID, req.Question, filtered) {
|
||||
switch delta.Kind {
|
||||
case service.AskDeltaAnswer:
|
||||
@@ -607,7 +497,6 @@ func sseMarker(marker string) string {
|
||||
return fmt.Sprintf("data: %s\n\n", string(b))
|
||||
}
|
||||
|
||||
// SSEWriter writes an SSE event to the client.
|
||||
type SSEWriter interface {
|
||||
Write(c *gin.Context, data string)
|
||||
}
|
||||
@@ -683,78 +572,3 @@ func applyRetrievalDefaults(req *SearchBotRetrievalTestRequest) {
|
||||
req.VectorSimilarityWeight = &v
|
||||
}
|
||||
}
|
||||
|
||||
var relatedQuestionLineRe = regexp.MustCompile(`^\d+\.\s`)
|
||||
|
||||
// parseRelatedQuestions extracts numbered list items from an LLM response.
|
||||
// Lines matching "^N. " are extracted and the number prefix is stripped.
|
||||
func parseRelatedQuestions(text string) []string {
|
||||
var result []string
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
if relatedQuestionLineRe.MatchString(line) {
|
||||
result = append(result, relatedQuestionLineRe.ReplaceAllString(line, ""))
|
||||
}
|
||||
}
|
||||
if result == nil {
|
||||
return []string{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// relatedQuestionPrompt is the system prompt for generating related search questions.
|
||||
// Matches Python rag/prompts/related_question.md
|
||||
const relatedQuestionPrompt = `# Role
|
||||
You are an AI language model assistant tasked with generating **5-10 related questions** based on a user's original query.
|
||||
These questions should help **expand the search query scope** and **improve search relevance**.
|
||||
|
||||
---
|
||||
|
||||
## Instructions
|
||||
|
||||
**Input:**
|
||||
You are provided with a **user's question**.
|
||||
|
||||
**Output:**
|
||||
Generate **5-10 alternative questions** that are **related** to the original user question.
|
||||
These alternatives should help retrieve a **broader range of relevant documents** from a vector database.
|
||||
|
||||
**Context:**
|
||||
Focus on **rephrasing** the original question in different ways, ensuring the alternative questions are **diverse but still connected** to the topic of the original query.
|
||||
Do **not** create overly obscure, irrelevant, or unrelated questions.
|
||||
|
||||
**Fallback:**
|
||||
If you cannot generate any relevant alternatives, do **not** return any questions.
|
||||
|
||||
---
|
||||
|
||||
## Guidance
|
||||
|
||||
1. Each alternative should be **unique** but still **relevant** to the original query.
|
||||
2. Keep the phrasing **clear, concise, and easy to understand**.
|
||||
3. Avoid overly technical jargon or specialized terms **unless directly relevant**.
|
||||
4. Ensure that each question **broadens** the search angle, **not narrows** it.
|
||||
|
||||
---
|
||||
|
||||
## Example
|
||||
|
||||
**Original Question:**
|
||||
> What are the benefits of electric vehicles?
|
||||
|
||||
**Alternative Questions:**
|
||||
1. How do electric vehicles impact the environment?
|
||||
2. What are the advantages of owning an electric car?
|
||||
3. What is the cost-effectiveness of electric vehicles?
|
||||
4. How do electric vehicles compare to traditional cars in terms of fuel efficiency?
|
||||
5. What are the environmental benefits of switching to electric cars?
|
||||
6. How do electric vehicles help reduce carbon emissions?
|
||||
7. Why are electric vehicles becoming more popular?
|
||||
8. What are the long-term savings of using electric vehicles?
|
||||
9. How do electric vehicles contribute to sustainability?
|
||||
10. What are the key benefits of electric vehicles for consumers?
|
||||
|
||||
---
|
||||
|
||||
## Reason
|
||||
Rephrasing the original query into multiple alternative questions helps the user explore **different aspects** of their search topic, improving the **quality of search results**.
|
||||
These questions guide the search engine to provide a **more comprehensive set** of relevant documents.`
|
||||
|
||||
@@ -17,29 +17,22 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
"ragflow/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// mockChunkService implements ChunkRetriever for testing.
|
||||
// mockChunkService implements service.Retriever for testing.
|
||||
// It captures the last request received so tests can verify field mapping.
|
||||
type mockChunkService struct {
|
||||
retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
||||
@@ -411,25 +404,6 @@ func TestSearchBotsRetrieval_EmptyQuestion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// fakeChatLLM implements chatLLM for testing.
|
||||
type fakeChatLLM struct {
|
||||
response string
|
||||
err error
|
||||
lastTenantID string
|
||||
lastModelID string
|
||||
lastMessages []modelModule.Message
|
||||
}
|
||||
|
||||
func (f *fakeChatLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
|
||||
f.lastTenantID = tenantID
|
||||
f.lastModelID = modelID
|
||||
f.lastMessages = messages
|
||||
if f.err != nil {
|
||||
return nil, f.err
|
||||
}
|
||||
return &modelModule.ChatResponse{Answer: &f.response}, nil
|
||||
}
|
||||
|
||||
func setupSearchBotRequest(body string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -442,86 +416,9 @@ func setupSearchBotRequest(body string) (*gin.Context, *httptest.ResponseRecorde
|
||||
return c, w
|
||||
}
|
||||
|
||||
// TestSearchBotHandler_Success verifies the happy path.
|
||||
func TestSearchBotHandler_Success(t *testing.T) {
|
||||
llm := &fakeChatLLM{
|
||||
response: `Here are some related questions:
|
||||
1. How do EV impact environment?
|
||||
2. What are advantages of EV?
|
||||
3. Cost of EV?`,
|
||||
}
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
|
||||
c, w := setupSearchBotRequest(`{"question": "EV benefits"}`)
|
||||
h.Handle(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"])
|
||||
}
|
||||
if msg, _ := resp["message"].(string); msg != "success" {
|
||||
t.Errorf("expected message 'success', got %q", msg)
|
||||
}
|
||||
|
||||
questions, ok := resp["data"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data array, got %T", resp["data"])
|
||||
}
|
||||
if len(questions) != 3 {
|
||||
t.Fatalf("expected 3 questions, got %d", len(questions))
|
||||
}
|
||||
if questions[0] != "How do EV impact environment?" {
|
||||
t.Errorf("unexpected [0]: %v", questions[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchBotHandler_EmptyResponse verifies empty LLM response returns empty list.
|
||||
func TestSearchBotHandler_EmptyResponse(t *testing.T) {
|
||||
llm := &fakeChatLLM{
|
||||
response: "No related questions found.",
|
||||
}
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
|
||||
c, w := setupSearchBotRequest(`{"question": "EV benefits"}`)
|
||||
h.Handle(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"])
|
||||
}
|
||||
questions, ok := resp["data"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data array, got %T", resp["data"])
|
||||
}
|
||||
if len(questions) != 0 {
|
||||
t.Errorf("expected 0 questions, got %d", len(questions))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchBotHandler_LLMFailure verifies error handling on LLM failure.
|
||||
func TestSearchBotHandler_LLMFailure(t *testing.T) {
|
||||
llm := &fakeChatLLM{
|
||||
err: errFake{msg: "LLM unavailable"},
|
||||
}
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
|
||||
c, w := setupSearchBotRequest(`{"question": "EV benefits"}`)
|
||||
h.Handle(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
code, _ := resp["code"].(float64)
|
||||
if code == 0 {
|
||||
t.Errorf("expected error code, got 0")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchBotHandler_MissingQuestion verifies validation.
|
||||
func TestSearchBotHandler_MissingQuestion(t *testing.T) {
|
||||
llm := &fakeChatLLM{response: "dummy"}
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
h := NewSearchBotHandler(nil, nil, nil, nil)
|
||||
|
||||
c, w := setupSearchBotRequest(`{}`)
|
||||
h.Handle(c)
|
||||
@@ -534,96 +431,11 @@ func TestSearchBotHandler_MissingQuestion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// errFake implements error for testing.
|
||||
type errFake struct{ msg string }
|
||||
|
||||
func (e errFake) Error() string { return e.msg }
|
||||
|
||||
// Existing parse tests below
|
||||
func TestParseRelatedQuestions_Standard(t *testing.T) {
|
||||
input := `1. How do electric vehicles impact the environment?
|
||||
2. What are the advantages of owning an electric car?
|
||||
3. What is the cost-effectiveness?`
|
||||
|
||||
got := parseRelatedQuestions(input)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(got))
|
||||
}
|
||||
if got[0] != "How do electric vehicles impact the environment?" {
|
||||
t.Errorf("unexpected [0]: %q", got[0])
|
||||
}
|
||||
if got[1] != "What are the advantages of owning an electric car?" {
|
||||
t.Errorf("unexpected [1]: %q", got[1])
|
||||
}
|
||||
if got[2] != "What is the cost-effectiveness?" {
|
||||
t.Errorf("unexpected [2]: %q", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRelatedQuestions_Empty(t *testing.T) {
|
||||
got := parseRelatedQuestions("")
|
||||
if len(got) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRelatedQuestions_NoNumberedLines(t *testing.T) {
|
||||
input := `Here are some related questions:
|
||||
- First question
|
||||
- Second question`
|
||||
|
||||
got := parseRelatedQuestions(input)
|
||||
if len(got) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRelatedQuestions_MixedContent(t *testing.T) {
|
||||
input := `Here are some related questions:
|
||||
1. First related question.
|
||||
Some explanation text.
|
||||
2. Second related question.
|
||||
More text.
|
||||
3. Third related question.`
|
||||
|
||||
got := parseRelatedQuestions(input)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(got))
|
||||
}
|
||||
if got[0] != "First related question." {
|
||||
t.Errorf("unexpected [0]: %q", got[0])
|
||||
}
|
||||
if got[1] != "Second related question." {
|
||||
t.Errorf("unexpected [1]: %q", got[1])
|
||||
}
|
||||
if got[2] != "Third related question." {
|
||||
t.Errorf("unexpected [2]: %q", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRelatedQuestions_MultiDigit(t *testing.T) {
|
||||
input := `10. Tenth question.
|
||||
11. Eleventh question.`
|
||||
|
||||
got := parseRelatedQuestions(input)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(got))
|
||||
}
|
||||
if got[0] != "Tenth question." {
|
||||
t.Errorf("unexpected [0]: %q", got[0])
|
||||
}
|
||||
if got[1] != "Eleventh question." {
|
||||
t.Errorf("unexpected [1]: %q", got[1])
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Ask handler tests ----
|
||||
|
||||
func TestAskHandler_MissingQuestion(t *testing.T) {
|
||||
llm := &fakeStreamingLLM{chunks: []string{"answer"}}
|
||||
ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{}}
|
||||
h := NewSearchBotHandler(nil, nil, nil, ret)
|
||||
h.SetStreamLLM(llm)
|
||||
c, w := cw()
|
||||
c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask",
|
||||
strings.NewReader(`{"kb_ids": ["kb1"]}`))
|
||||
@@ -636,10 +448,8 @@ func TestAskHandler_MissingQuestion(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAskHandler_MissingKbIDs(t *testing.T) {
|
||||
llm := &fakeStreamingLLM{chunks: []string{"answer"}}
|
||||
ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{}}
|
||||
h := NewSearchBotHandler(nil, nil, nil, ret)
|
||||
h.SetStreamLLM(llm)
|
||||
c, w := cw()
|
||||
c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask",
|
||||
strings.NewReader(`{"question": "test"}`))
|
||||
@@ -651,46 +461,6 @@ func TestAskHandler_MissingKbIDs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// fakeStreamingLLM implements streamingLLM for testing.
|
||||
type fakeStreamingLLM struct {
|
||||
chunks []string
|
||||
err error
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (f *fakeStreamingLLM) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
|
||||
if f.err != nil {
|
||||
return nil, f.err
|
||||
}
|
||||
if f.delay > 0 {
|
||||
ch := make(chan string)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
for i, chunk := range f.chunks {
|
||||
if i > 0 {
|
||||
select {
|
||||
case <-time.After(f.delay):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
select {
|
||||
case ch <- chunk:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return ch, nil
|
||||
}
|
||||
ch := make(chan string, len(f.chunks)+1)
|
||||
for _, c := range f.chunks {
|
||||
ch <- c
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
type fakeChunkRetriever struct {
|
||||
result *service.RetrievalTestResponse
|
||||
err error
|
||||
@@ -721,117 +491,14 @@ func (w *bufferSSEWriter) Write(_ *gin.Context, data string) {
|
||||
|
||||
func (w *bufferSSEWriter) String() string { return w.buf.String() }
|
||||
|
||||
func setupAskHandlerTenantDB(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
|
||||
TranslateError: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get sqlite db: %v", err)
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
if err := db.AutoMigrate(&entity.Tenant{}); err != nil {
|
||||
t.Fatalf("failed to migrate tenant table: %v", err)
|
||||
}
|
||||
|
||||
status := "1"
|
||||
name := "Test Tenant"
|
||||
if err := db.Create(&entity.Tenant{
|
||||
ID: "user-1",
|
||||
Name: &name,
|
||||
LLMID: "test-model",
|
||||
EmbdID: "test-embedding",
|
||||
ASRID: "test-asr",
|
||||
Img2TxtID: "test-image",
|
||||
RerankID: "test-rerank",
|
||||
ParserIDs: "naive",
|
||||
Status: &status,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to create tenant: %v", err)
|
||||
}
|
||||
|
||||
orig := dao.DB
|
||||
dao.DB = db
|
||||
t.Cleanup(func() {
|
||||
dao.DB = orig
|
||||
_ = sqlDB.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func TestAskHandler_DisablesWriteDeadlineForSSE(t *testing.T) {
|
||||
setupAskHandlerTenantDB(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{
|
||||
{"id": "c1", "content_with_weight": "test chunk", "docnm_kwd": "Doc", "kb_id": "kb1", "doc_id": "d1"},
|
||||
},
|
||||
DocAggs: []map[string]interface{}{{"doc_id": "d1", "count": 1}},
|
||||
}}
|
||||
llm := &fakeStreamingLLM{
|
||||
chunks: []string{"first response chunk", "second response chunk"},
|
||||
delay: 120 * time.Millisecond,
|
||||
}
|
||||
h := NewSearchBotHandler(nil, service.NewTenantService(), nil, ret)
|
||||
h.SetStreamLLM(llm)
|
||||
h.SetAskService(service.NewAskService(ret, nil, 0, 1))
|
||||
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
})
|
||||
router.POST("/api/v1/searchbots/ask", h.Ask)
|
||||
|
||||
server := httptest.NewUnstartedServer(router)
|
||||
server.Config.WriteTimeout = 30 * time.Millisecond
|
||||
server.Start()
|
||||
defer server.Close()
|
||||
|
||||
client := server.Client()
|
||||
client.Timeout = time.Second
|
||||
resp, err := client.Post(server.URL+"/api/v1/searchbots/ask", "application/json",
|
||||
strings.NewReader(`{"question": "test", "kb_ids": ["kb1"]}`))
|
||||
if err != nil {
|
||||
t.Fatalf("post ask stream: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if contentType := resp.Header.Get("Content-Type"); !strings.Contains(contentType, "text/event-stream") {
|
||||
t.Fatalf("expected SSE content type, got %q", contentType)
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read ask stream body: %v", err)
|
||||
}
|
||||
|
||||
body := string(bodyBytes)
|
||||
for _, want := range []string{"first response chunk", "second response chunk"} {
|
||||
if !strings.Contains(body, want) {
|
||||
t.Fatalf("stream body missing %q: %q", want, body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Ask handler tests ----
|
||||
|
||||
func TestAskHandler_EmptyQuestion(t *testing.T) {
|
||||
|
||||
llm := &fakeStreamingLLM{chunks: []string{"answer"}}
|
||||
ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{{"id": "c1", "content_with_weight": "test"}},
|
||||
}}
|
||||
h := NewSearchBotHandler(nil, nil, nil, ret)
|
||||
h.SetStreamLLM(llm)
|
||||
c, w := cw()
|
||||
c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask",
|
||||
strings.NewReader(`{"question": " ", "kb_ids": ["kb1"]}`))
|
||||
@@ -844,10 +511,8 @@ func TestAskHandler_EmptyQuestion(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAskHandler_EmptyKbIDs(t *testing.T) {
|
||||
llm := &fakeStreamingLLM{chunks: []string{"answer"}}
|
||||
ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{}}
|
||||
h := NewSearchBotHandler(nil, nil, nil, ret)
|
||||
h.SetStreamLLM(llm)
|
||||
c, w := cw()
|
||||
c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask",
|
||||
strings.NewReader(`{"question": "test", "kb_ids": []}`))
|
||||
@@ -861,13 +526,11 @@ func TestAskHandler_EmptyKbIDs(t *testing.T) {
|
||||
|
||||
func TestAskHandler_NoChatModel(t *testing.T) {
|
||||
buf := &bufferSSEWriter{}
|
||||
llm := &fakeStreamingLLM{chunks: []string{"answer"}}
|
||||
ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{{"id": "c1", "content_with_weight": "test"}},
|
||||
}}
|
||||
h := NewSearchBotHandler(nil, nil, nil, ret)
|
||||
h.sseWriter = buf
|
||||
h.SetStreamLLM(llm)
|
||||
c, _ := cw()
|
||||
c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask",
|
||||
strings.NewReader(`{"question": "test", "kb_ids": ["kb1"]}`))
|
||||
@@ -881,10 +544,8 @@ func TestAskHandler_NoChatModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAskHandler_InvalidJSON(t *testing.T) {
|
||||
llm := &fakeStreamingLLM{chunks: []string{"answer"}}
|
||||
ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{}}
|
||||
h := NewSearchBotHandler(nil, nil, nil, ret)
|
||||
h.SetStreamLLM(llm)
|
||||
c, w := cw()
|
||||
c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask",
|
||||
strings.NewReader(`not json`))
|
||||
@@ -897,10 +558,8 @@ func TestAskHandler_InvalidJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAskHandler_WhitespaceKbIDFiltered(t *testing.T) {
|
||||
llm := &fakeStreamingLLM{chunks: []string{"answer"}}
|
||||
ret := &fakeChunkRetriever{result: &service.RetrievalTestResponse{}}
|
||||
h := NewSearchBotHandler(nil, nil, nil, ret)
|
||||
h.SetStreamLLM(llm)
|
||||
c, w := cw()
|
||||
c.Request = httptest.NewRequest("POST", "/api/v1/searchbots/ask",
|
||||
strings.NewReader(`{"question": "test", "kb_ids": [" ", ""]}`))
|
||||
@@ -912,51 +571,6 @@ func TestAskHandler_WhitespaceKbIDFiltered(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMindMapHandlerSuccess(t *testing.T) {
|
||||
llm := &fakeChatLLM{response: "# Product\n## Features\n### Search\n#### Hybrid retrieval"}
|
||||
chunks := &mockChunkService{retrievalTestFn: func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
return &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{{"content_with_weight": "Hybrid search combines vector and keyword retrieval."}},
|
||||
}, nil
|
||||
}}
|
||||
h := NewSearchBotHandler(nil, nil, llm, chunks)
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user-1"})
|
||||
})
|
||||
r.POST("/api/v1/searchbots/mindmap", h.MindMap)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/searchbots/mindmap", strings.NewReader(`{"question":"What is search?","kb_ids":["kb-1"]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"])
|
||||
}
|
||||
data := resp["data"].(map[string]interface{})
|
||||
if data["id"] != "Product" {
|
||||
t.Fatalf("mindmap root = %v, want Product", data["id"])
|
||||
}
|
||||
if chunks.LastReq == nil {
|
||||
t.Fatal("RetrievalTest was not called")
|
||||
}
|
||||
if *chunks.LastReq.Page != 1 || *chunks.LastReq.Size != 12 || *chunks.LastReq.TopK != 1024 {
|
||||
t.Fatalf("retrieval defaults = page %d size %d topK %d", *chunks.LastReq.Page, *chunks.LastReq.Size, *chunks.LastReq.TopK)
|
||||
}
|
||||
if llm.lastTenantID != "user-1" || len(llm.lastMessages) != 2 || !strings.Contains(llm.lastMessages[0].Content, "Hybrid search combines") {
|
||||
t.Fatalf("unexpected LLM call: tenant=%q messages=%v", llm.lastTenantID, llm.lastMessages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMindMapMarkdown_ListUnderHeading(t *testing.T) {
|
||||
got := parseMindMapMarkdown("# Product\n- Features\n - Search")
|
||||
if got.ID != "Product" {
|
||||
|
||||
@@ -275,6 +275,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
|
||||
// Chat routes
|
||||
v1.POST("/chat/mindmap", r.chatHandler.MindMap)
|
||||
v1.POST("/chat/recommendation", r.chatHandler.Recommendation)
|
||||
chats := v1.Group("/chats")
|
||||
{
|
||||
chats.GET("", r.chatHandler.ListChats)
|
||||
|
||||
82
internal/service/model_chat.go
Normal file
82
internal/service/model_chat.go
Normal file
@@ -0,0 +1,82 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"ragflow/internal/common"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
func (m *ModelProviderService) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
|
||||
chatModel, err := m.GetChatModel(tenantID, modelID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return chatModel.ModelDriver.ChatWithMessages(*chatModel.ModelName, messages, chatModel.APIConfig, config)
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) ChatStream(ctx context.Context, tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
|
||||
chatModel, err := m.GetChatModel(tenantID, modelID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return chatStreamWithContext(ctx, chatModel, messages, config), nil
|
||||
}
|
||||
|
||||
func chatStreamWithContext(ctx context.Context, chatModel *modelModule.ChatModel, messages []modelModule.Message, config *modelModule.ChatConfig) <-chan string {
|
||||
ch := make(chan string, 256)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
if err := chatModel.ModelDriver.ChatStreamlyWithSender(*chatModel.ModelName, messages, chatModel.APIConfig, config,
|
||||
func(delta *string, _ *string) error {
|
||||
if delta == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case ch <- *delta:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}); err != nil {
|
||||
if err == context.Canceled || err == context.DeadlineExceeded {
|
||||
return
|
||||
}
|
||||
common.Warn("ChatStreamlyWithSender returned error", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
// TenantStreamAdapter adapts tenant/model-aware chat streaming to AskService.
|
||||
type TenantStreamAdapter struct {
|
||||
LLM *ModelProviderService
|
||||
TenantID string
|
||||
ModelID string
|
||||
}
|
||||
|
||||
func (a *TenantStreamAdapter) ChatStream(ctx context.Context, messages []modelModule.Message, config *modelModule.ChatConfig) (<-chan string, error) {
|
||||
if a.LLM == nil {
|
||||
return nil, fmt.Errorf("streaming LLM not configured")
|
||||
}
|
||||
return a.LLM.ChatStream(ctx, a.TenantID, a.ModelID, messages, config)
|
||||
}
|
||||
213
internal/service/related_question.go
Normal file
213
internal/service/related_question.go
Normal file
@@ -0,0 +1,213 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
// GenerateRelatedQuestions generates related search questions for chat/searchbot endpoints.
|
||||
func GenerateRelatedQuestions(tenantID, question, searchID string, searchSvc *SearchService, tenantSvc *TenantService, modelProviderSvc *ModelProviderService) ([]string, error) {
|
||||
if modelProviderSvc == nil {
|
||||
return nil, fmt.Errorf("model provider service not configured")
|
||||
}
|
||||
searchConfig := relatedQuestionsSearchConfig(searchID, searchSvc)
|
||||
modelID := relatedQuestionsModelID(tenantID, searchConfig, tenantSvc)
|
||||
prompt, err := LoadPrompt("related_question")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages := []modelModule.Message{
|
||||
{Role: "system", Content: prompt},
|
||||
{Role: "user", Content: "\nKeywords: " + question + "\nRelated search terms:\n "},
|
||||
}
|
||||
response, err := modelProviderSvc.Chat(tenantID, modelID, messages, relatedQuestionsConfig(searchConfig))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response != nil && response.Answer != nil {
|
||||
return parseRelatedQuestions(*response.Answer), nil
|
||||
}
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func relatedQuestionsSearchConfig(searchID string, searchSvc *SearchService) map[string]interface{} {
|
||||
if searchID == "" || searchSvc == nil {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
if detail, err := searchSvc.GetDetail(searchID); err == nil && detail != nil {
|
||||
return relatedQuestionsSearchConfigFromDetail(detail)
|
||||
}
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
func relatedQuestionsSearchConfigFromDetail(detail map[string]interface{}) map[string]interface{} {
|
||||
if sc, ok := detail["search_config"].(map[string]interface{}); ok && sc != nil {
|
||||
return sc
|
||||
}
|
||||
if sc, ok := detail["search_config"].(entity.JSONMap); ok && sc != nil {
|
||||
return map[string]interface{}(sc)
|
||||
}
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
func relatedQuestionsModelID(tenantID string, searchConfig map[string]interface{}, tenantSvc *TenantService) string {
|
||||
modelID, _ := searchConfig["chat_id"].(string)
|
||||
if modelID != "" || tenantSvc == nil {
|
||||
return modelID
|
||||
}
|
||||
defaultModel, err := tenantSvc.GetDefaultModelName(tenantID, entity.ModelTypeChat)
|
||||
if err == nil {
|
||||
modelID = defaultModel
|
||||
}
|
||||
return modelID
|
||||
}
|
||||
|
||||
func relatedQuestionsConfig(searchConfig map[string]interface{}) *modelModule.ChatConfig {
|
||||
var genConf map[string]interface{}
|
||||
switch v := searchConfig["llm_setting"].(type) {
|
||||
case map[string]interface{}:
|
||||
genConf = v
|
||||
case entity.JSONMap:
|
||||
genConf = map[string]interface{}(v)
|
||||
}
|
||||
if genConf == nil {
|
||||
return &modelModule.ChatConfig{Temperature: float64Ptr(0.9)}
|
||||
}
|
||||
cfg := &modelModule.ChatConfig{}
|
||||
for key, value := range genConf {
|
||||
if key == "parameter" {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case "stream":
|
||||
if v, ok := value.(bool); ok {
|
||||
cfg.Stream = &v
|
||||
}
|
||||
case "thinking":
|
||||
if v, ok := value.(bool); ok {
|
||||
cfg.Thinking = &v
|
||||
}
|
||||
case "max_tokens":
|
||||
if v, ok := intFromRelatedQuestionConfig(value); ok {
|
||||
cfg.MaxTokens = &v
|
||||
}
|
||||
case "temperature":
|
||||
if v, ok := floatFromRelatedQuestionConfig(value); ok {
|
||||
cfg.Temperature = &v
|
||||
}
|
||||
case "top_p":
|
||||
if v, ok := floatFromRelatedQuestionConfig(value); ok && v > 0 {
|
||||
cfg.TopP = &v
|
||||
}
|
||||
case "do_sample":
|
||||
if v, ok := value.(bool); ok {
|
||||
cfg.DoSample = &v
|
||||
}
|
||||
case "stop":
|
||||
if stops := stringSliceFromRelatedQuestionConfig(value); len(stops) > 0 {
|
||||
cfg.Stop = &stops
|
||||
}
|
||||
case "model_class":
|
||||
if v, ok := value.(string); ok {
|
||||
cfg.ModelClass = &v
|
||||
}
|
||||
case "effort":
|
||||
if v, ok := value.(string); ok {
|
||||
cfg.Effort = &v
|
||||
}
|
||||
case "verbosity":
|
||||
if v, ok := value.(string); ok {
|
||||
cfg.Verbosity = &v
|
||||
}
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func intFromRelatedQuestionConfig(value interface{}) (int, bool) {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v, true
|
||||
case int64:
|
||||
return int(v), true
|
||||
case float64:
|
||||
return int(v), true
|
||||
case json.Number:
|
||||
if n, err := v.Int64(); err == nil {
|
||||
return int(n), true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func floatFromRelatedQuestionConfig(value interface{}) (float64, bool) {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return v, true
|
||||
case float32:
|
||||
return float64(v), true
|
||||
case int:
|
||||
return float64(v), true
|
||||
case int64:
|
||||
return float64(v), true
|
||||
case json.Number:
|
||||
if n, err := v.Float64(); err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func stringSliceFromRelatedQuestionConfig(value interface{}) []string {
|
||||
switch v := value.(type) {
|
||||
case []string:
|
||||
return v
|
||||
case []interface{}:
|
||||
result := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var relatedQuestionLineRe = regexp.MustCompile(`^\d+\.\s`)
|
||||
|
||||
func parseRelatedQuestions(text string) []string {
|
||||
var result []string
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
if relatedQuestionLineRe.MatchString(line) {
|
||||
result = append(result, relatedQuestionLineRe.ReplaceAllString(line, ""))
|
||||
}
|
||||
}
|
||||
if result == nil {
|
||||
return []string{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func float64Ptr(v float64) *float64 { return &v }
|
||||
97
internal/service/related_question_test.go
Normal file
97
internal/service/related_question_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRelatedQuestionsConfigSkipsZeroTopP(t *testing.T) {
|
||||
cfg := relatedQuestionsConfig(map[string]interface{}{
|
||||
"llm_setting": map[string]interface{}{
|
||||
"temperature": float64(0.2),
|
||||
"top_p": float64(0),
|
||||
"parameter": map[string]interface{}{"unused": true},
|
||||
},
|
||||
})
|
||||
|
||||
if cfg == nil || cfg.Temperature == nil || *cfg.Temperature != 0.2 {
|
||||
t.Fatalf("expected temperature 0.2, got %+v", cfg)
|
||||
}
|
||||
if cfg.TopP != nil {
|
||||
t.Fatalf("expected zero top_p to be omitted, got %v", *cfg.TopP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRelatedQuestionsStandard(t *testing.T) {
|
||||
input := `1. How do electric vehicles impact the environment?
|
||||
2. What are the advantages of owning an electric car?
|
||||
3. What is the cost-effectiveness?`
|
||||
|
||||
got := parseRelatedQuestions(input)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(got))
|
||||
}
|
||||
if got[0] != "How do electric vehicles impact the environment?" {
|
||||
t.Errorf("unexpected [0]: %q", got[0])
|
||||
}
|
||||
if got[1] != "What are the advantages of owning an electric car?" {
|
||||
t.Errorf("unexpected [1]: %q", got[1])
|
||||
}
|
||||
if got[2] != "What is the cost-effectiveness?" {
|
||||
t.Errorf("unexpected [2]: %q", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRelatedQuestionsEmpty(t *testing.T) {
|
||||
got := parseRelatedQuestions("")
|
||||
if len(got) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRelatedQuestionsNoNumberedLines(t *testing.T) {
|
||||
input := `Here are some related questions:
|
||||
- First question
|
||||
- Second question`
|
||||
|
||||
got := parseRelatedQuestions(input)
|
||||
if len(got) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRelatedQuestionsMixedContent(t *testing.T) {
|
||||
input := `Here are some related questions:
|
||||
1. First related question.
|
||||
Some explanation text.
|
||||
2. Second related question.
|
||||
More text.
|
||||
3. Third related question.`
|
||||
|
||||
got := parseRelatedQuestions(input)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(got))
|
||||
}
|
||||
if got[0] != "First related question." {
|
||||
t.Errorf("unexpected [0]: %q", got[0])
|
||||
}
|
||||
if got[1] != "Second related question." {
|
||||
t.Errorf("unexpected [1]: %q", got[1])
|
||||
}
|
||||
if got[2] != "Third related question." {
|
||||
t.Errorf("unexpected [2]: %q", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRelatedQuestionsMultiDigit(t *testing.T) {
|
||||
input := `10. Tenth question.
|
||||
11. Eleventh question.`
|
||||
|
||||
got := parseRelatedQuestions(input)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(got))
|
||||
}
|
||||
if got[0] != "Tenth question." {
|
||||
t.Errorf("unexpected [0]: %q", got[0])
|
||||
}
|
||||
if got[1] != "Eleventh question." {
|
||||
t.Errorf("unexpected [1]: %q", got[1])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user