Files
ragflow/internal/handler/searchbot.go

575 lines
18 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"ragflow/internal/common"
"ragflow/internal/entity"
"ragflow/internal/service"
"go.uber.org/zap"
)
// SearchBotAskRequest is the request body for POST /api/v1/searchbots/ask.
type SearchBotAskRequest struct {
Question string `json:"question" binding:"required"`
KbIDs common.StringSlice `json:"kb_ids" binding:"required"`
SearchID string `json:"search_id,omitempty"`
}
// SearchBotMindMapRequest is the request body for POST /api/v1/searchbots/mindmap.
type SearchBotMindMapRequest struct {
Question string `json:"question" binding:"required"`
KbIDs common.StringSlice `json:"kb_ids" binding:"required"`
SearchID string `json:"search_id,omitempty"`
}
// SearchBotRetrievalTestRequest is the request body for POST /api/v1/searchbots/retrieval_test.
type SearchBotRetrievalTestRequest struct {
KbIDs common.StringSlice `json:"kb_ids" binding:"required"`
Question string `json:"question" binding:"required"`
Page *int `json:"page,omitempty"`
Size *int `json:"size,omitempty"`
DocIDs []string `json:"doc_ids,omitempty"`
UseKG *bool `json:"use_kg,omitempty"`
TopK *int `json:"top_k,omitempty"`
CrossLanguages []string `json:"cross_languages,omitempty"`
SearchID *string `json:"search_id,omitempty"`
MetaDataFilter map[string]interface{} `json:"meta_data_filter,omitempty"`
TenantRerankID *string `json:"tenant_rerank_id,omitempty"`
RerankID *string `json:"rerank_id,omitempty"`
Keyword *bool `json:"keyword,omitempty"`
SimilarityThreshold *float64 `json:"similarity_threshold,omitempty"`
VectorSimilarityWeight *float64 `json:"vector_similarity_weight,omitempty"`
// TODO: wire highlight to nlp Retrieval when engine supports highlightFields
// Python: bot_api.py → retrieval(highlight=req.get("highlight"))
// → search.py highlightFields → ES get_highlight()
// Issue: https://github.com/infiniflow/ragflow/issues/15712
// Highlight *bool `json:"highlight,omitempty"`
}
// UnmarshalJSON accepts both kb_id (Python API) and kb_ids (Go compatibility).
func (r *SearchBotRetrievalTestRequest) UnmarshalJSON(data []byte) error {
type Alias SearchBotRetrievalTestRequest
aux := struct {
*Alias
KbID common.StringSlice `json:"kb_id"`
}{
Alias: (*Alias)(r),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if len(r.KbIDs) == 0 && len(aux.KbID) > 0 {
r.KbIDs = aux.KbID
}
return nil
}
// SearchBotRequest is the request body for POST /api/v1/searchbots/related_questions.
type SearchBotRequest struct {
Question string `json:"question" binding:"required"`
SearchID string `json:"search_id,omitempty"`
}
// SearchBotHandler handles searchbot endpoints:
//
// POST /api/v1/searchbots/related_questions
// POST /api/v1/searchbots/retrieval_test
// POST /api/v1/searchbots/ask
// POST /api/v1/searchbots/mindmap
type SearchBotHandler struct {
searchSvc *service.SearchService
tenantSvc *service.TenantService
llm *service.ModelProviderService
streamLLM *service.ModelProviderService
chunkSvc service.Retriever
askSvc *service.AskService
sseWriter SSEWriter
}
// NewSearchBotHandler creates a new SearchBotHandler.
func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm *service.ModelProviderService, chunkSvc service.Retriever) *SearchBotHandler {
return &SearchBotHandler{searchSvc: searchSvc, tenantSvc: tenantSvc, llm: llm, chunkSvc: chunkSvc, sseWriter: &ginSSEWriter{}}
}
// SetStreamLLM sets the streaming LLM for the Ask endpoint.
func (h *SearchBotHandler) SetStreamLLM(llm *service.ModelProviderService) { h.streamLLM = llm }
// SetAskService sets the AskService used by the Ask endpoint.
func (h *SearchBotHandler) SetAskService(svc *service.AskService) { h.askSvc = svc }
// Handle generates related search questions based on a user query.
// @Summary Generate Related Questions
// @Description Generates 5-10 related search questions to expand the search scope.
// @Tags searchbots
// @Accept json
// @Produce json
// @Param request body SearchBotRequest true "Request body"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/searchbots/related_questions [post]
func (h *SearchBotHandler) Handle(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
var req SearchBotRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"data": nil,
"message": "question is required",
})
return
}
if req.Question == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"data": nil,
"message": "question is required",
})
return
}
questions, err := service.GenerateRelatedQuestions(user.ID, req.Question, req.SearchID, h.searchSvc, h.tenantSvc, h.llm)
if err != nil {
common.Warn("searchbot related questions failed", zap.String("error", err.Error()))
c.JSON(http.StatusOK, gin.H{
"code": common.CodeOperatingError,
"data": nil,
"message": "LLM call failed",
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"data": questions,
"message": "success",
})
}
// RetrievalTest performs a retrieval test against specified knowledge bases.
// @Summary Retrieval Test
// @Description Test document retrieval across knowledge bases with optional filters, reranking, and KG search.
// @Tags searchbots
// @Accept json
// @Produce json
// @Param request body SearchBotRetrievalTestRequest true "Retrieval test parameters"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/searchbots/retrieval_test [post]
func (h *SearchBotHandler) RetrievalTest(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
c.JSON(http.StatusUnauthorized, gin.H{"code": errorCode, "data": nil, "message": errorMessage})
return
}
var req SearchBotRetrievalTestRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": err.Error()})
return
}
// Filter out empty strings from KbIDs before validation.
filtered := make(common.StringSlice, 0, len(req.KbIDs))
for _, id := range req.KbIDs {
if strings.TrimSpace(id) != "" {
filtered = append(filtered, id)
}
}
req.KbIDs = filtered
if len(req.KbIDs) == 0 || req.Question == "" {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_id and question are required"})
return
}
applyRetrievalDefaults(&req)
if req.TopK != nil && *req.TopK <= 0 {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "top_k must be greater than 0"})
return
}
svcReq := toRetrievalServiceRequest(&req)
result, err := h.chunkSvc.RetrievalTest(svcReq, user.ID)
if err != nil {
common.Warn("searchbot retrieval test failed", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "data": nil, "message": "retrieval test failed"})
return
}
c.JSON(http.StatusOK, gin.H{"code": int(common.CodeSuccess), "data": result, "message": "success"})
}
// Ask performs a retrieval-augmented Q&A with streaming SSE response.
// @Summary Ask with Knowledge Bases
// @Description Retrieves chunks, builds prompt, and streams LLM answer with citations via SSE.
// @Tags searchbots
// @Accept json
// @Produce text/event-stream
// @Param request body SearchBotAskRequest true "Ask parameters"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/searchbots/ask [post]
func (h *SearchBotHandler) Ask(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
var req SearchBotAskRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": err.Error()})
return
}
// Filter empty kb_ids.
filtered := make(common.StringSlice, 0, len(req.KbIDs))
for _, id := range req.KbIDs {
if strings.TrimSpace(id) == "" {
continue
}
filtered = append(filtered, id)
}
if len(filtered) == 0 || strings.TrimSpace(req.Question) == "" {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_ids and question are required"})
return
}
// Resolve chat model ID.
modelID := ""
if req.SearchID != "" && h.searchSvc != nil {
if detail, err := h.searchSvc.GetDetail(req.SearchID); err == nil {
if sc, ok := detail["search_config"].(map[string]interface{}); ok {
if cid, ok := sc["chat_id"].(string); ok && cid != "" {
modelID = cid
}
}
}
}
if modelID == "" && h.tenantSvc != nil {
defaultModel, err := h.tenantSvc.GetDefaultModelName(user.ID, entity.ModelTypeChat)
if err == nil && defaultModel != "" {
modelID = defaultModel
}
}
if modelID == "" {
h.sseWriter.Write(c, sseError("chat model not configured"))
return
}
disableWriteDeadlineForSSE(c)
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
if h.askSvc == nil {
h.sseWriter.Write(c, sseError("ask service not configured"))
return
}
if h.streamLLM == nil {
h.sseWriter.Write(c, sseError("streaming LLM not configured"))
return
}
ctx := c.Request.Context()
adapter := &service.TenantStreamAdapter{LLM: h.streamLLM, TenantID: user.ID, ModelID: modelID}
for delta := range h.askSvc.Stream(ctx, adapter, user.ID, req.Question, filtered) {
switch delta.Kind {
case service.AskDeltaAnswer:
h.sseWriter.Write(c, sseAnswer(delta.Value, nil, false))
case service.AskDeltaMarker:
h.sseWriter.Write(c, sseMarker(delta.Value))
case service.AskDeltaError:
h.sseWriter.Write(c, sseError(delta.Value))
case service.AskDeltaFinal:
h.sseWriter.Write(c, sseAnswer(delta.Value, delta.Refs, true))
}
}
c.Stream(func(w io.Writer) bool {
fmt.Fprintf(w, "data: {\"code\": 0, \"message\": \"\", \"data\": true}\n\n")
return false
})
}
// MindMap generates a query mind map for a shared search bot.
// @Summary Generate Mind Map
// @Description Retrieves related chunks and asks the configured chat model to summarize them into a mind map.
// @Tags searchbots
// @Accept json
// @Produce json
// @Param request body SearchBotMindMapRequest true "Mind map parameters"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/searchbots/mindmap [post]
func (h *SearchBotHandler) MindMap(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
var req SearchBotMindMapRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": err.Error()})
return
}
filtered := make(common.StringSlice, 0, len(req.KbIDs))
for _, id := range req.KbIDs {
if strings.TrimSpace(id) != "" {
filtered = append(filtered, id)
}
}
if len(filtered) == 0 || strings.TrimSpace(req.Question) == "" {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_ids and question are required"})
return
}
if h.chunkSvc == nil {
jsonInternalError(c, fmt.Errorf("chunk service not configured"))
return
}
if h.llm == nil {
jsonInternalError(c, fmt.Errorf("LLM not configured"))
return
}
searchConfig := map[string]interface{}{}
if req.SearchID != "" {
if h.searchSvc == nil {
jsonInternalError(c, fmt.Errorf("search service not configured"))
return
}
detail, err := h.searchSvc.GetDetail(req.SearchID)
if err != nil {
jsonInternalError(c, err)
return
}
searchConfig = searchConfigFromDetail(detail)
}
mindMap, err := runMindMap(mindMapRunConfig{
Question: req.Question,
KbIDs: filtered,
SearchID: req.SearchID,
SearchConfig: searchConfig,
AuthUserID: user.ID,
ModelTenantID: user.ID,
ChunkSvc: h.chunkSvc,
LLM: h.llm,
TenantSvc: h.tenantSvc,
})
if err != nil {
common.Warn("searchbot mindmap failed", zap.String("error", err.Error()))
jsonInternalError(c, err)
return
}
jsonResponse(c, common.CodeSuccess, mindMap, "success")
}
// SearchbotDetail returns the public share-page bootstrap payload for a
// search app. The route is mounted under apiNoAuth but still requires a beta
// token, matching Python's AUTH_BETA flow.
func (h *SearchBotHandler) SearchbotDetail(c *gin.Context) {
searchID := strings.TrimSpace(c.Query("search_id"))
if searchID == "" {
jsonError(c, common.CodeArgumentError, "search_id is required")
return
}
userSvc := service.NewUserService()
user, code, err := userSvc.GetUserByBetaAPIToken(c.GetHeader("Authorization"))
if err != nil {
jsonError(c, code, "Authentication error: API key is invalid!")
return
}
detail, err := h.searchSvc.GetSearchShareDetail(user.ID, searchID)
if err != nil {
switch err.Error() {
case "has no permission for this operation":
jsonError(c, common.CodeOperatingError, "Has no permission for this operation.")
case "can't find this Search App!":
jsonError(c, common.CodeDataError, "Can't find this Search App!")
default:
jsonInternalError(c, err)
}
return
}
jsonResponse(c, common.CodeSuccess, detail, "success")
}
// ---- SSE helpers ----
type ssePayload struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data"`
}
// askSSEData is the inner data object for SSE events, matching Python bot_api.py.
// The Reference field is always present (non-nil) so the frontend can safely
// access .chunks or .reduce without a null guard.
type askSSEData struct {
Answer string `json:"answer"`
Reference interface{} `json:"reference"`
Final bool `json:"final"`
StartToThink bool `json:"start_to_think,omitempty"`
EndToThink bool `json:"end_to_think,omitempty"`
}
func sseAnswer(answer string, refs interface{}, final bool) string {
if refs == nil {
refs = map[string]interface{}{}
}
payload := ssePayload{
Code: 0,
Message: "",
Data: askSSEData{
Answer: answer,
Reference: refs,
Final: final,
},
}
b, _ := json.Marshal(payload)
return fmt.Sprintf("data: %s\n\n", string(b))
}
// sseError matches Python bot_api.py error format:
//
// {"code": 500, "message": "...", "data": {"answer": "**ERROR**: ...", "reference": []}}
func sseError(message string) string {
payload := ssePayload{
Code: int(common.CodeServerError),
Message: message,
Data: askSSEData{
Answer: "**ERROR**: " + message,
Reference: []map[string]interface{}{},
},
}
b, _ := json.Marshal(payload)
return fmt.Sprintf("data: %s\n\n", string(b))
}
// sseMarker matches Python dialog_service.py think-tag marker format:
//
// {"answer": "", "reference": {}, "final": false, "start_to_think": true}
func sseMarker(marker string) string {
d := askSSEData{
Answer: "",
Reference: map[string]interface{}{},
}
if marker == "<think>" {
d.StartToThink = true
} else {
d.EndToThink = true
}
payload := ssePayload{Code: 0, Message: "", Data: d}
b, _ := json.Marshal(payload)
return fmt.Sprintf("data: %s\n\n", string(b))
}
type SSEWriter interface {
Write(c *gin.Context, data string)
}
// ginSSEWriter is the production SSEWriter backed by gin.Context.Stream.
type ginSSEWriter struct{}
func (w *ginSSEWriter) Write(c *gin.Context, data string) {
c.Stream(func(w io.Writer) bool {
fmt.Fprint(w, data)
return false
})
}
// toRetrievalServiceRequest maps the handler DTO to the service DTO.
// The two structs differ in KbIDs (StringSlice → []string) and
// MetaDataFilter (→ Filter) to maintain Python API compatibility.
func toRetrievalServiceRequest(h *SearchBotRetrievalTestRequest) *service.RetrievalTestRequest {
return &service.RetrievalTestRequest{
Datasets: common.StringSlice(h.KbIDs),
Question: h.Question,
Page: h.Page,
Size: h.Size,
DocIDs: h.DocIDs,
UseKG: h.UseKG,
TopK: h.TopK,
CrossLanguages: h.CrossLanguages,
SearchID: h.SearchID,
Filter: h.MetaDataFilter,
TenantRerankID: h.TenantRerankID,
RerankID: h.RerankID,
Keyword: h.Keyword,
SimilarityThreshold: h.SimilarityThreshold,
VectorSimilarityWeight: h.VectorSimilarityWeight,
}
}
// ptrFloat64 returns a pointer to a float64 value.
func ptrFloat64(v float64) *float64 { return &v }
func intPtr(v int) *int { return &v }
func floatPtr(v float64) *float64 { return &v }
// applyRetrievalDefaults fills in default values for optional fields,
// matching Python bot_api.py retrieval_test endpoint.
func applyRetrievalDefaults(req *SearchBotRetrievalTestRequest) {
if req.Page == nil {
v := 1
req.Page = &v
}
if req.Size == nil {
v := 30
req.Size = &v
}
if req.TopK == nil {
v := 1024
req.TopK = &v
}
if req.UseKG == nil {
v := false
req.UseKG = &v
}
if req.Keyword == nil {
v := false
req.Keyword = &v
}
if req.SimilarityThreshold == nil {
v := 0.0
req.SimilarityThreshold = &v
}
if req.VectorSimilarityWeight == nil {
v := 0.3
req.VectorSimilarityWeight = &v
}
}