mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat: implement POST /api/v1/searchbots/retrieval_test (#15710)
## What problem does this PR solve? Implements `POST /api/v1/searchbots/retrieval_test` in the Go API server, aligning with the Python `bot_api.py` counterpart. Also applies security hardening and consistency fixes discovered during CTO-level code review: - **Missing endpoint**: `retrieval_test` was not available in Go, requiring Python fallback - **Security**: Both `chunkHandler` and `searchBotHandler` leaked `err.Error()` to API consumers - **Python alignment**: Default values, empty question handling, and `top_k <= 0` validation differed from Python behavior - **Test gaps**: `chunkHandler.RetrievalTest` had zero unit tests; several edge cases uncovered ## Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Refactoring ## Summary ### New Endpoint - `POST /api/v1/searchbots/retrieval_test` — retrieval test with full field support (page, size, top_k, use_kg, cross_languages, keyword, similarity_threshold, vector_similarity_weight) ### New Type - `common.StringSlice` — JSON type that accepts both `"kb1"` and `["kb1", "kb2"]`, matching Python API flexibility ### Security - Both `searchBotHandler` and `chunkHandler` now use `common.Warn()` + generic error messages instead of leaking `err.Error()` to API consumers - All error responses include consistent `"data": nil` shape - `chunkHandler.RetrievalTest` uses interface-based DI (`chunkService`) to enable testability ### Python Alignment - Handler-level defaults align with Python `bot_api.py` (page=1, size=30, top_k=1024, similarity_threshold=0.0, vector_similarity_weight=0.3) - `top_k <= 0` validation matching Python behavior - Empty/whitespace question returns 200 + empty result (matches `chunk_api.py`) - `chunkHandler` `Datasets` field uses `common.StringSlice` for string-or-array flexibility ### Refactoring - `ChunkServiceIface` → `ChunkRetriever`, `chunkSvcIface` → `chunkService` (Go-conventional naming) - Extracted `applyRetrievalDefaults`, `toRetrievalServiceRequest` from handler body - Regex moved to package-level var in `parseRelatedQuestions` - `service.RetrievalTestRequest.Datasets` type changed to `common.StringSlice` - `chunkHandler` now uses consumer-side interface for DI ### Tests - 37 unit tests across both handlers: auth, validation, defaults, StringSlice edge cases, empty/whitespace KbID, service errors, JSON format, `top_k <= 0`, field mapping verification ## Files Changed | File | Change | |------|--------| | `cmd/server_main.go` | Wire new handler + chunkService + difyRetrievalHandler | | `internal/common/json_types.go` | New StringSlice type | | `internal/common/json_types_test.go` | StringSlice tests | | `internal/handler/chunk.go` | Interface-based DI, security, Python alignment, defaults | | `internal/handler/chunk_test.go` | New — 9 comprehensive tests | | `internal/handler/searchbot.go` | New endpoint + refactoring + `top_k <= 0` validation | | `internal/handler/searchbot_test.go` | 18 tests covering all edge cases | | `internal/router/router.go` | Register new route + difyRetrievalHandler | | `internal/service/chunk.go` | Datasets type → StringSlice, Question binding relaxed | 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -214,10 +214,11 @@ func startServer(config *server.Config) {
|
||||
skillSearchHandler := handler.NewSkillSearchHandler(docEngine)
|
||||
providerHandler := handler.NewProviderHandler(userService, modelProviderService)
|
||||
agentHandler := handler.NewAgentHandler(service.NewAgentService(), fileService)
|
||||
relatedQuestionsHandler := handler.NewSearchbotHandler(
|
||||
searchBotHandler := handler.NewSearchBotHandler(
|
||||
searchService,
|
||||
tenantService,
|
||||
&handler.SearchbotRealLLM{Svc: modelProviderService},
|
||||
&handler.SearchBotRealLLM{Svc: modelProviderService},
|
||||
chunkService,
|
||||
)
|
||||
pluginHandler := handler.NewPluginHandler(service.NewPluginService())
|
||||
|
||||
@@ -234,7 +235,7 @@ func startServer(config *server.Config) {
|
||||
)
|
||||
|
||||
// Initialize router
|
||||
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, relatedQuestionsHandler, difyRetrievalHandler, pluginHandler)
|
||||
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, searchBotHandler, difyRetrievalHandler, pluginHandler)
|
||||
|
||||
// Create Gin engine
|
||||
ginEngine := gin.New()
|
||||
|
||||
47
internal/common/json_types.go
Normal file
47
internal/common/json_types.go
Normal file
@@ -0,0 +1,47 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package common
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// StringSlice is a []string that unmarshals from either a JSON string or
|
||||
// a JSON array of strings. This matches Python endpoints that accept
|
||||
// both "kb1" and ["kb1", "kb2"] for list-valued parameters.
|
||||
type StringSlice []string
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (s *StringSlice) UnmarshalJSON(data []byte) error {
|
||||
// Try array first.
|
||||
var arr []string
|
||||
if err := json.Unmarshal(data, &arr); err == nil {
|
||||
*s = arr
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fall back to a single string → wrap as one-element slice.
|
||||
var single string
|
||||
if err := json.Unmarshal(data, &single); err != nil {
|
||||
return err
|
||||
}
|
||||
*s = StringSlice{single}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (s StringSlice) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal([]string(s))
|
||||
}
|
||||
114
internal/common/json_types_test.go
Normal file
114
internal/common/json_types_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStringSlice_UnmarshalJSON_Array(t *testing.T) {
|
||||
var s StringSlice
|
||||
if err := json.Unmarshal([]byte(`["a","b","c"]`), &s); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(s) != 3 || s[0] != "a" || s[1] != "b" || s[2] != "c" {
|
||||
t.Fatalf("got %v", []string(s))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSlice_UnmarshalJSON_SingleString(t *testing.T) {
|
||||
var s StringSlice
|
||||
if err := json.Unmarshal([]byte(`"kb1"`), &s); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(s) != 1 || s[0] != "kb1" {
|
||||
t.Fatalf("got %v", []string(s))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSlice_UnmarshalJSON_EmptyArray(t *testing.T) {
|
||||
var s StringSlice
|
||||
if err := json.Unmarshal([]byte(`[]`), &s); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(s) != 0 {
|
||||
t.Fatalf("expected empty, got %v", []string(s))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSlice_UnmarshalJSON_EmptyString(t *testing.T) {
|
||||
var s StringSlice
|
||||
if err := json.Unmarshal([]byte(`""`), &s); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(s) != 1 || s[0] != "" {
|
||||
t.Fatalf("got %v", []string(s))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSlice_UnmarshalJSON_InvalidValue(t *testing.T) {
|
||||
var s StringSlice
|
||||
err := json.Unmarshal([]byte(`123`), &s)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for number value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSlice_MarshalJSON(t *testing.T) {
|
||||
s := StringSlice{"x", "y"}
|
||||
data, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(data) != `["x","y"]` {
|
||||
t.Fatalf("got %s", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSlice_MarshalJSON_Empty(t *testing.T) {
|
||||
s := StringSlice{}
|
||||
data, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(data) != `[]` {
|
||||
t.Fatalf("got %s", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSlice_EmbeddedInStruct(t *testing.T) {
|
||||
type req struct {
|
||||
KbIDs StringSlice `json:"kb_id"`
|
||||
}
|
||||
// Single string
|
||||
var r1 req
|
||||
if err := json.Unmarshal([]byte(`{"kb_id":"kb1"}`), &r1); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(r1.KbIDs) != 1 || r1.KbIDs[0] != "kb1" {
|
||||
t.Fatalf("got %v", []string(r1.KbIDs))
|
||||
}
|
||||
// Array
|
||||
var r2 req
|
||||
if err := json.Unmarshal([]byte(`{"kb_id":["a","b"]}`), &r2); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(r2.KbIDs) != 2 || r2.KbIDs[0] != "a" || r2.KbIDs[1] != "b" {
|
||||
t.Fatalf("got %v", []string(r2.KbIDs))
|
||||
}
|
||||
}
|
||||
@@ -19,28 +19,155 @@ package handler
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"ragflow/internal/common"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
// chunkService is the consumer-side interface for ChunkHandler's service dependency.
|
||||
type chunkService interface {
|
||||
RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
||||
Get(req *service.GetChunkRequest, userID string) (*service.GetChunkResponse, error)
|
||||
List(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error)
|
||||
UpdateChunk(req *service.UpdateChunkRequest, userID string) error
|
||||
RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error)
|
||||
}
|
||||
|
||||
// ChunkHandler chunk handler
|
||||
type ChunkHandler struct {
|
||||
chunkService *service.ChunkService
|
||||
chunkService chunkService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewChunkHandler create chunk handler
|
||||
func NewChunkHandler(chunkService *service.ChunkService, userService *service.UserService) *ChunkHandler {
|
||||
func NewChunkHandler(chunkService chunkService, userService *service.UserService) *ChunkHandler {
|
||||
return &ChunkHandler{
|
||||
chunkService: chunkService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a chunk by ID
|
||||
// RetrievalTest performs retrieval test for chunks
|
||||
// @Summary Retrieval Test
|
||||
// @Description Test retrieval of chunks based on question and knowledge base
|
||||
// @Tags chunks
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body service.RetrievalTestRequest true "retrieval test parameters"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/datasets/search [post]
|
||||
func (h *ChunkHandler) RetrievalTest(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
// Bind JSON request
|
||||
var req service.RetrievalTestRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"data": nil,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set default values for optional parameters
|
||||
if req.Page == nil {
|
||||
defaultPage := 1
|
||||
req.Page = &defaultPage
|
||||
}
|
||||
if req.Size == nil {
|
||||
defaultSize := 30
|
||||
req.Size = &defaultSize
|
||||
}
|
||||
if req.TopK == nil {
|
||||
defaultTopK := 1024
|
||||
req.TopK = &defaultTopK
|
||||
}
|
||||
if req.UseKG == nil {
|
||||
defaultUseKG := false
|
||||
req.UseKG = &defaultUseKG
|
||||
}
|
||||
|
||||
// Strip and validate question. Matching Python chunk_api.py which returns
|
||||
// an empty result for blank questions rather than an error.
|
||||
if strings.TrimSpace(req.Question) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": int(common.CodeSuccess),
|
||||
"data": &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{},
|
||||
DocAggs: []map[string]interface{}{},
|
||||
Total: 0,
|
||||
},
|
||||
"message": "success",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Datasets == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"data": nil,
|
||||
"message": "kb_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Datasets) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"data": nil,
|
||||
"message": "kb_id array cannot be empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.TopK != nil && *req.TopK <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"data": nil,
|
||||
"message": "top_k must be greater than 0",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Call service with user ID for permission checks
|
||||
resp, err := h.chunkService.RetrievalTest(&req, user.ID)
|
||||
if err != nil {
|
||||
common.Warn("dataset search failed", zap.String("error", err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"data": nil,
|
||||
"message": "dataset search failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": int(common.CodeSuccess),
|
||||
"data": resp,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// Get retrieves a chunk by ID.
|
||||
// @Summary Get Chunk
|
||||
// @Description Retrieve a single chunk by its ID.
|
||||
// @Tags chunks
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param dataset_id path string true "Dataset ID"
|
||||
// @Param document_id path string true "Document ID"
|
||||
// @Param chunk_id path string true "Chunk ID"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id} [get]
|
||||
func (h *ChunkHandler) Get(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
@@ -49,20 +176,16 @@ func (h *ChunkHandler) Get(c *gin.Context) {
|
||||
}
|
||||
|
||||
chunkID := c.Param("chunk_id")
|
||||
datasetID := c.Param("dataset_id")
|
||||
documentID := c.Param("document_id")
|
||||
if chunkID == "" || datasetID == "" || documentID == "" {
|
||||
if chunkID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "dataset_id, document_id and chunk_id are required",
|
||||
"message": "chunk_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
req := &service.GetChunkRequest{
|
||||
ChunkID: chunkID,
|
||||
DocumentID: documentID,
|
||||
DatasetID: datasetID,
|
||||
ChunkID: chunkID,
|
||||
}
|
||||
|
||||
resp, err := h.chunkService.Get(req, user.ID)
|
||||
@@ -81,7 +204,15 @@ func (h *ChunkHandler) Get(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// List retrieves chunks for a document
|
||||
// List retrieves chunks for a document.
|
||||
// @Summary List Chunks
|
||||
// @Description Retrieve paginated chunks for a document with optional filtering.
|
||||
// @Tags chunks
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body service.ListChunksRequest true "List chunks parameters"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/chunk/list [post]
|
||||
func (h *ChunkHandler) List(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
|
||||
292
internal/handler/chunk_test.go
Normal file
292
internal/handler/chunk_test.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// mockChunkSvc implements chunkSvcIface for testing ChunkHandler.
|
||||
// Only the methods actually called by the test are set; others panic.
|
||||
type mockChunkSvc struct {
|
||||
retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
||||
}
|
||||
|
||||
func (m *mockChunkSvc) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
if m.retrievalTestFn != nil {
|
||||
return m.retrievalTestFn(req, userID)
|
||||
}
|
||||
return &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{{"docnm_kwd": "test", "content_with_weight": "content"}},
|
||||
Total: 1,
|
||||
}, nil
|
||||
}
|
||||
func (m *mockChunkSvc) Get(*service.GetChunkRequest, string) (*service.GetChunkResponse, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (m *mockChunkSvc) List(*service.ListChunksRequest, string) (*service.ListChunksResponse, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (m *mockChunkSvc) UpdateChunk(*service.UpdateChunkRequest, string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (m *mockChunkSvc) RemoveChunks(*service.RemoveChunksRequest, string) (int64, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func setupChunkRetrievalTest(userID string) (*gin.Engine, *mockChunkSvc) {
|
||||
mock := &mockChunkSvc{}
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: userID})
|
||||
})
|
||||
r.POST("/api/v1/datasets/search", h.RetrievalTest)
|
||||
return r, mock
|
||||
}
|
||||
|
||||
func setupChunkRetrievalTestNoAuth() *gin.Engine {
|
||||
// Returns a router without the user middleware — used for error-path
|
||||
// tests that don't call the service.
|
||||
h := &ChunkHandler{}
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.POST("/api/v1/datasets/search", h.RetrievalTest)
|
||||
return r
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_EmptyQuestion(t *testing.T) {
|
||||
r, _ := setupChunkRetrievalTest("user1")
|
||||
|
||||
body := `{"dataset_ids": ["kb1"], "question": ""}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Errorf("expected code 0, got %v: %q", resp["code"], resp["message"])
|
||||
}
|
||||
data, ok := resp["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data to be object, got %T", resp["data"])
|
||||
}
|
||||
chunks, _ := data["chunks"].([]interface{})
|
||||
if chunks == nil || len(chunks) != 0 {
|
||||
t.Errorf("expected empty chunks array, got %v", chunks)
|
||||
}
|
||||
if total, _ := data["total"].(float64); total != 0 {
|
||||
t.Errorf("expected total 0, got %v", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_WhitespaceQuestion(t *testing.T) {
|
||||
r, _ := setupChunkRetrievalTest("user1")
|
||||
|
||||
body := `{"dataset_ids": ["kb1"], "question": " "}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Errorf("expected code 0, got %v", resp["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_TopKZero(t *testing.T) {
|
||||
r, _ := setupChunkRetrievalTest("user1")
|
||||
|
||||
body := `{"dataset_ids": ["kb1"], "question": "test", "top_k": 0}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, _ := resp["message"].(string); msg != "top_k must be greater than 0" {
|
||||
t.Errorf("expected 'top_k must be greater than 0', got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_MissingDatasetIDs(t *testing.T) {
|
||||
r, _ := setupChunkRetrievalTest("user1")
|
||||
|
||||
body := `{"question": "test"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_EmptyDatasetIDs(t *testing.T) {
|
||||
r, _ := setupChunkRetrievalTest("user1")
|
||||
|
||||
body := `{"dataset_ids": [], "question": "test"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, _ := resp["message"].(string); msg != "kb_id array cannot be empty" {
|
||||
t.Errorf("expected 'kb_id array cannot be empty', got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_NoAuth(t *testing.T) {
|
||||
r := setupChunkRetrievalTestNoAuth()
|
||||
|
||||
body := `{"dataset_ids": ["kb1"], "question": "test"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", w.Code)
|
||||
}
|
||||
// jsonError returns HTTP 200 with error code in body
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] == float64(common.CodeSuccess) {
|
||||
t.Errorf("expected error code, got %v", resp["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_InvalidJSON(t *testing.T) {
|
||||
r, _ := setupChunkRetrievalTest("user1")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader("{invalid}"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_Success(t *testing.T) {
|
||||
_, mock := setupChunkRetrievalTest("user1")
|
||||
mock.retrievalTestFn = func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
if userID != "user1" {
|
||||
t.Errorf("expected userID 'user1', got %q", userID)
|
||||
}
|
||||
return &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{{"docnm_kwd": "result"}},
|
||||
DocAggs: []map[string]interface{}{{"doc_id": "1", "count": float64(1)}},
|
||||
Total: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user1"})
|
||||
})
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
r.POST("/api/v1/datasets/search", h.RetrievalTest)
|
||||
|
||||
body := `{"dataset_ids": ["kb1"], "question": "test"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %q", resp["code"], resp["message"])
|
||||
}
|
||||
data, ok := resp["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data object, got %T", resp["data"])
|
||||
}
|
||||
if total, _ := data["total"].(float64); total != 1 {
|
||||
t.Errorf("expected total 1, got %v", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkRetrieval_ServiceError(t *testing.T) {
|
||||
_, mock := setupChunkRetrievalTest("user1")
|
||||
mock.retrievalTestFn = func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
return nil, errors.New("db connection refused")
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "user1"})
|
||||
})
|
||||
h := &ChunkHandler{chunkService: mock}
|
||||
r.POST("/api/v1/datasets/search", h.RetrievalTest)
|
||||
|
||||
body := `{"dataset_ids": ["kb1"], "question": "test"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
msg, _ := resp["message"].(string)
|
||||
if msg != "dataset search failed" {
|
||||
t.Errorf("expected generic error message, got %q", msg)
|
||||
}
|
||||
if strings.Contains(msg, "db connection refused") {
|
||||
t.Errorf("internal error details leaked to response: %q", msg)
|
||||
}
|
||||
}
|
||||
@@ -31,41 +31,72 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// searchbotLLM is the interface for LLM calls used by SearchbotHandler.
|
||||
// searchbotLLM is the interface for LLM calls used by SearchBotHandler.
|
||||
type searchbotLLM interface {
|
||||
Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error)
|
||||
}
|
||||
|
||||
// SearchbotRealLLM wraps ModelProviderService to implement searchbotLLM.
|
||||
type SearchbotRealLLM struct {
|
||||
// ChunkRetriever abstracts chunk retrieval for the searchbots handler.
|
||||
type ChunkRetriever interface {
|
||||
RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
||||
}
|
||||
|
||||
// SearchBotRealLLM wraps ModelProviderService to implement searchbotLLM.
|
||||
type SearchBotRealLLM struct {
|
||||
Svc *service.ModelProviderService
|
||||
}
|
||||
|
||||
func (r *SearchbotRealLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
|
||||
driver, modelName, apiConfig, _, err := r.Svc.GetModelConfigFromProviderInstance(tenantID, entity.ModelTypeChat, modelID)
|
||||
func (r *SearchBotRealLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
|
||||
chatModel, err := r.Svc.GetChatModel(tenantID, modelID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chatModel := modelModule.NewChatModel(driver, &modelName, apiConfig)
|
||||
return chatModel.ModelDriver.ChatWithMessages(*chatModel.ModelName, messages, chatModel.APIConfig, config)
|
||||
}
|
||||
|
||||
// SearchbotRequest is the request body for POST /api/v1/searchbots/related_questions.
|
||||
type SearchbotRequest struct {
|
||||
// SearchBotRetrievalTestRequest is the request body for POST /api/v1/searchbots/retrieval_test.
|
||||
type SearchBotRetrievalTestRequest struct {
|
||||
KbIDs common.StringSlice `json:"kb_id" binding:"required"`
|
||||
Question string `json:"question" binding:"required"`
|
||||
Page *int `json:"page,omitempty"`
|
||||
Size *int `json:"size,omitempty"`
|
||||
DocIDs []string `json:"doc_ids,omitempty"`
|
||||
UseKG *bool `json:"use_kg,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
CrossLanguages []string `json:"cross_languages,omitempty"`
|
||||
SearchID *string `json:"search_id,omitempty"`
|
||||
MetaDataFilter map[string]interface{} `json:"meta_data_filter,omitempty"`
|
||||
TenantRerankID *string `json:"tenant_rerank_id,omitempty"`
|
||||
RerankID *string `json:"rerank_id,omitempty"`
|
||||
Keyword *bool `json:"keyword,omitempty"`
|
||||
SimilarityThreshold *float64 `json:"similarity_threshold,omitempty"`
|
||||
VectorSimilarityWeight *float64 `json:"vector_similarity_weight,omitempty"`
|
||||
// TODO: wire highlight to nlp Retrieval when engine supports highlightFields
|
||||
// Python: bot_api.py → retrieval(highlight=req.get("highlight"))
|
||||
// → search.py highlightFields → ES get_highlight()
|
||||
// Issue: https://github.com/infiniflow/ragflow/issues/15712
|
||||
// Highlight *bool `json:"highlight,omitempty"`
|
||||
}
|
||||
|
||||
// SearchBotRequest is the request body for POST /api/v1/searchbots/related_questions.
|
||||
type SearchBotRequest struct {
|
||||
Question string `json:"question" binding:"required"`
|
||||
SearchID string `json:"search_id,omitempty"`
|
||||
}
|
||||
|
||||
// SearchbotHandler handles POST /api/v1/searchbots/related_questions.
|
||||
type SearchbotHandler struct {
|
||||
// SearchBotHandler handles searchbot endpoints:
|
||||
// POST /api/v1/searchbots/related_questions
|
||||
// POST /api/v1/searchbots/retrieval_test
|
||||
type SearchBotHandler struct {
|
||||
searchSvc *service.SearchService
|
||||
tenantSvc *service.TenantService
|
||||
llm searchbotLLM
|
||||
chunkSvc ChunkRetriever
|
||||
}
|
||||
|
||||
// NewSearchbotHandler creates a new SearchbotHandler.
|
||||
func NewSearchbotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm searchbotLLM) *SearchbotHandler {
|
||||
return &SearchbotHandler{searchSvc: searchSvc, tenantSvc: tenantSvc, llm: llm}
|
||||
// NewSearchBotHandler creates a new SearchBotHandler.
|
||||
func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm searchbotLLM, chunkSvc ChunkRetriever) *SearchBotHandler {
|
||||
return &SearchBotHandler{searchSvc: searchSvc, tenantSvc: tenantSvc, llm: llm, chunkSvc: chunkSvc}
|
||||
}
|
||||
|
||||
// Handle generates related search questions based on a user query.
|
||||
@@ -74,17 +105,17 @@ func NewSearchbotHandler(searchSvc *service.SearchService, tenantSvc *service.Te
|
||||
// @Tags searchbots
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body SearchbotRequest true "Request body"
|
||||
// @Param request body SearchBotRequest true "Request body"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/searchbots/related_questions [post]
|
||||
func (h *SearchbotHandler) Handle(c *gin.Context) {
|
||||
func (h *SearchBotHandler) Handle(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
var req SearchbotRequest
|
||||
var req SearchBotRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
@@ -148,21 +179,133 @@ func (h *SearchbotHandler) Handle(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": questions,
|
||||
"message": "",
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// RetrievalTest performs a retrieval test against specified knowledge bases.
|
||||
// @Summary Retrieval Test
|
||||
// @Description Test document retrieval across knowledge bases with optional filters, reranking, and KG search.
|
||||
// @Tags searchbots
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body SearchBotRetrievalTestRequest true "Retrieval test parameters"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/searchbots/retrieval_test [post]
|
||||
func (h *SearchBotHandler) RetrievalTest(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": errorCode, "data": nil, "message": errorMessage})
|
||||
return
|
||||
}
|
||||
|
||||
var req SearchBotRetrievalTestRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Filter out empty strings from KbIDs before validation.
|
||||
filtered := make(common.StringSlice, 0, len(req.KbIDs))
|
||||
for _, id := range req.KbIDs {
|
||||
if strings.TrimSpace(id) != "" {
|
||||
filtered = append(filtered, id)
|
||||
}
|
||||
}
|
||||
req.KbIDs = filtered
|
||||
|
||||
if len(req.KbIDs) == 0 || req.Question == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_id and question are required"})
|
||||
return
|
||||
}
|
||||
|
||||
applyRetrievalDefaults(&req)
|
||||
|
||||
if req.TopK != nil && *req.TopK <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "top_k must be greater than 0"})
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := toRetrievalServiceRequest(&req)
|
||||
|
||||
result, err := h.chunkSvc.RetrievalTest(svcReq, user.ID)
|
||||
if err != nil {
|
||||
common.Warn("searchbot retrieval test failed", zap.String("error", err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "data": nil, "message": "retrieval test failed"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"code": int(common.CodeSuccess), "data": result, "message": "success"})
|
||||
}
|
||||
|
||||
// toRetrievalServiceRequest maps the handler DTO to the service DTO.
|
||||
// The two structs differ in KbIDs (StringSlice → []string) and
|
||||
// MetaDataFilter (→ Filter) to maintain Python API compatibility.
|
||||
func toRetrievalServiceRequest(h *SearchBotRetrievalTestRequest) *service.RetrievalTestRequest {
|
||||
return &service.RetrievalTestRequest{
|
||||
Datasets: common.StringSlice(h.KbIDs),
|
||||
Question: h.Question,
|
||||
Page: h.Page,
|
||||
Size: h.Size,
|
||||
DocIDs: h.DocIDs,
|
||||
UseKG: h.UseKG,
|
||||
TopK: h.TopK,
|
||||
CrossLanguages: h.CrossLanguages,
|
||||
SearchID: h.SearchID,
|
||||
Filter: h.MetaDataFilter,
|
||||
TenantRerankID: h.TenantRerankID,
|
||||
RerankID: h.RerankID,
|
||||
Keyword: h.Keyword,
|
||||
SimilarityThreshold: h.SimilarityThreshold,
|
||||
VectorSimilarityWeight: h.VectorSimilarityWeight,
|
||||
}
|
||||
}
|
||||
|
||||
// ptrFloat64 returns a pointer to a float64 value.
|
||||
func ptrFloat64(v float64) *float64 { return &v }
|
||||
|
||||
// applyRetrievalDefaults fills in default values for optional fields,
|
||||
// matching Python bot_api.py retrieval_test endpoint.
|
||||
func applyRetrievalDefaults(req *SearchBotRetrievalTestRequest) {
|
||||
if req.Page == nil {
|
||||
v := 1
|
||||
req.Page = &v
|
||||
}
|
||||
if req.Size == nil {
|
||||
v := 30
|
||||
req.Size = &v
|
||||
}
|
||||
if req.TopK == nil {
|
||||
v := 1024
|
||||
req.TopK = &v
|
||||
}
|
||||
if req.UseKG == nil {
|
||||
v := false
|
||||
req.UseKG = &v
|
||||
}
|
||||
if req.Keyword == nil {
|
||||
v := false
|
||||
req.Keyword = &v
|
||||
}
|
||||
if req.SimilarityThreshold == nil {
|
||||
v := 0.0
|
||||
req.SimilarityThreshold = &v
|
||||
}
|
||||
if req.VectorSimilarityWeight == nil {
|
||||
v := 0.3
|
||||
req.VectorSimilarityWeight = &v
|
||||
}
|
||||
}
|
||||
|
||||
var relatedQuestionLineRe = regexp.MustCompile(`^\d+\.\s`)
|
||||
|
||||
// parseRelatedQuestions extracts numbered list items from an LLM response.
|
||||
// Lines matching "^N. " are extracted and the number prefix is stripped.
|
||||
func parseRelatedQuestions(text string) []string {
|
||||
lineRe := regexp.MustCompile(`^\d+\.\s`)
|
||||
var result []string
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
if lineRe.MatchString(line) {
|
||||
result = append(result, lineRe.ReplaceAllString(line, ""))
|
||||
if relatedQuestionLineRe.MatchString(line) {
|
||||
result = append(result, relatedQuestionLineRe.ReplaceAllString(line, ""))
|
||||
}
|
||||
}
|
||||
if result == nil {
|
||||
|
||||
@@ -18,17 +18,390 @@ package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
"ragflow/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// mockChunkService implements ChunkRetriever for testing.
|
||||
// It captures the last request received so tests can verify field mapping.
|
||||
type mockChunkService struct {
|
||||
retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
|
||||
LastReq *service.RetrievalTestRequest
|
||||
LastUserID string
|
||||
}
|
||||
|
||||
func (m *mockChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
m.LastReq = req
|
||||
m.LastUserID = userID
|
||||
if m.retrievalTestFn != nil {
|
||||
return m.retrievalTestFn(req, userID)
|
||||
}
|
||||
return &service.RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{{"docnm_kwd": "test", "content_with_weight": "content"}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupSearchbotsTest(userID string) (*SearchBotHandler, *mockChunkService, *gin.Engine) {
|
||||
mockSvc := &mockChunkService{}
|
||||
h := &SearchBotHandler{
|
||||
chunkSvc: mockSvc,
|
||||
}
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: userID})
|
||||
})
|
||||
r.POST("/api/v1/searchbots/retrieval_test", h.RetrievalTest)
|
||||
return h, mockSvc, r
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_Basic(t *testing.T) {
|
||||
_, mockSvc, r := setupSearchbotsTest("user1")
|
||||
|
||||
body := `{"kb_id": ["kb1"], "question": "test question"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Errorf("expected code 0, got %v", resp["code"])
|
||||
}
|
||||
if msg, _ := resp["message"].(string); msg != "success" {
|
||||
t.Errorf("expected message 'success', got %q", msg)
|
||||
}
|
||||
// Verify field mapping: handler → service request
|
||||
if mockSvc.LastReq == nil {
|
||||
t.Fatal("RetrievalTest was not called")
|
||||
}
|
||||
if len(mockSvc.LastReq.Datasets) != 1 || mockSvc.LastReq.Datasets[0] != "kb1" {
|
||||
t.Errorf("Datasets = %v, want [\"kb1\"]", mockSvc.LastReq.Datasets)
|
||||
}
|
||||
if mockSvc.LastReq.Question != "test question" {
|
||||
t.Errorf("Question = %q, want \"test question\"", mockSvc.LastReq.Question)
|
||||
}
|
||||
if mockSvc.LastUserID != "user1" {
|
||||
t.Errorf("userID = %q, want \"user1\"", mockSvc.LastUserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_MissingKbID(t *testing.T) {
|
||||
_, _, r := setupSearchbotsTest("user1")
|
||||
body := `{"question": "test"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
msg, _ := resp["message"].(string)
|
||||
if msg == "" || msg == "success" {
|
||||
t.Errorf("expected validation error message, got %q", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "KbIDs") || !strings.Contains(msg, "required") {
|
||||
t.Errorf("expected message to mention 'KbIDs' and 'required', got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_MissingQuestion(t *testing.T) {
|
||||
_, _, r := setupSearchbotsTest("user1")
|
||||
body := `{"kb_id": ["kb1"]}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
msg, _ := resp["message"].(string)
|
||||
if msg == "" || msg == "success" {
|
||||
t.Errorf("expected validation error message, got %q", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "Question") || !strings.Contains(msg, "required") {
|
||||
t.Errorf("expected message to mention 'Question' and 'required', got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_NoAuth(t *testing.T) {
|
||||
h := NewSearchBotHandler(nil, nil, nil, &mockChunkService{})
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.POST("/api/v1/searchbots/retrieval_test", h.RetrievalTest)
|
||||
w := httptest.NewRecorder()
|
||||
body := `{"kb_id": ["kb1"], "question": "test"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_ServiceError(t *testing.T) {
|
||||
h, _, r := setupSearchbotsTest("user1")
|
||||
h.chunkSvc = &mockChunkService{
|
||||
retrievalTestFn: func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
body := `{"kb_id": ["kb1"], "question": "test"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d", w.Code)
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
code, _ := resp["code"].(float64)
|
||||
if code == 0 {
|
||||
t.Errorf("expected non-zero error code, got %v", code)
|
||||
}
|
||||
msg, _ := resp["message"].(string)
|
||||
if msg == "" || msg == "success" {
|
||||
t.Errorf("expected error message, got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_KbIDSingleString(t *testing.T) {
|
||||
// Verify "kb1" (string) is accepted and converted to []string{"kb1"}
|
||||
_, mockSvc, r := setupSearchbotsTest("user1")
|
||||
|
||||
body := `{"kb_id": "kb1", "question": "test"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if mockSvc.LastReq == nil {
|
||||
t.Fatal("RetrievalTest was not called")
|
||||
}
|
||||
if len(mockSvc.LastReq.Datasets) != 1 || mockSvc.LastReq.Datasets[0] != "kb1" {
|
||||
t.Errorf("Datasets = %v, want [\"kb1\"]", mockSvc.LastReq.Datasets)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_KbIDArray(t *testing.T) {
|
||||
// Verify ["a","b"] (array) still works
|
||||
_, mockSvc, r := setupSearchbotsTest("user1")
|
||||
|
||||
body := `{"kb_id": ["a","b"], "question": "test"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if mockSvc.LastReq == nil {
|
||||
t.Fatal("RetrievalTest was not called")
|
||||
}
|
||||
if len(mockSvc.LastReq.Datasets) != 2 || mockSvc.LastReq.Datasets[0] != "a" || mockSvc.LastReq.Datasets[1] != "b" {
|
||||
t.Errorf("Datasets = %v, want [\"a\",\"b\"]", mockSvc.LastReq.Datasets)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_InvalidJSON(t *testing.T) {
|
||||
_, _, r := setupSearchbotsTest("user1")
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader("{invalid}"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_EmptyStringKbID(t *testing.T) {
|
||||
_, _, r := setupSearchbotsTest("user1")
|
||||
body := `{"kb_id": "", "question": "test"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if msg, _ := resp["message"].(string); msg != "kb_id and question are required" {
|
||||
t.Errorf("expected message 'kb_id and question are required', got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_WhitespaceOnlyKbID(t *testing.T) {
|
||||
_, _, r := setupSearchbotsTest("user1")
|
||||
body := `{"kb_id": " ", "question": "test"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if msg, _ := resp["message"].(string); msg != "kb_id and question are required" {
|
||||
t.Errorf("expected message 'kb_id and question are required', got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_DefaultsApplied(t *testing.T) {
|
||||
// Verify that when optional fields are omitted, the handler applies
|
||||
// defaults matching Python bot_api.py retrieval_test endpoint.
|
||||
_, mockSvc, r := setupSearchbotsTest("user1")
|
||||
|
||||
body := `{"kb_id": ["kb1"], "question": "does this default?"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if mockSvc.LastReq == nil {
|
||||
t.Fatal("RetrievalTest was not called")
|
||||
}
|
||||
|
||||
svcReq := mockSvc.LastReq
|
||||
if svcReq.Page == nil || *svcReq.Page != 1 {
|
||||
t.Errorf("Page = %v, want 1", nullableInt(svcReq.Page))
|
||||
}
|
||||
if svcReq.Size == nil || *svcReq.Size != 30 {
|
||||
t.Errorf("Size = %v, want 30", nullableInt(svcReq.Size))
|
||||
}
|
||||
if svcReq.TopK == nil || *svcReq.TopK != 1024 {
|
||||
t.Errorf("TopK = %v, want 1024", nullableInt(svcReq.TopK))
|
||||
}
|
||||
if svcReq.UseKG == nil || *svcReq.UseKG != false {
|
||||
t.Errorf("UseKG = %v, want false", nullableBool(svcReq.UseKG))
|
||||
}
|
||||
if svcReq.Keyword == nil || *svcReq.Keyword != false {
|
||||
t.Errorf("Keyword = %v, want false", nullableBool(svcReq.Keyword))
|
||||
}
|
||||
if svcReq.SimilarityThreshold == nil || *svcReq.SimilarityThreshold != 0.0 {
|
||||
t.Errorf("SimilarityThreshold = %v, want 0.0", nullableFloat(svcReq.SimilarityThreshold))
|
||||
}
|
||||
if svcReq.VectorSimilarityWeight == nil || *svcReq.VectorSimilarityWeight != 0.3 {
|
||||
t.Errorf("VectorSimilarityWeight = %v, want 0.3", nullableFloat(svcReq.VectorSimilarityWeight))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_TopKZero(t *testing.T) {
|
||||
_, _, r := setupSearchbotsTest("user1")
|
||||
body := `{"kb_id": ["kb1"], "question": "test", "top_k": 0}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if msg, _ := resp["message"].(string); msg != "top_k must be greater than 0" {
|
||||
t.Errorf("expected message 'top_k must be greater than 0', got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBotsRetrieval_TopKNegative(t *testing.T) {
|
||||
_, _, r := setupSearchbotsTest("user1")
|
||||
body := `{"kb_id": ["kb1"], "question": "test", "top_k": -1}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
if msg := jsonDecodeMessage(t, w.Body.Bytes()); msg != "top_k must be greater than 0" {
|
||||
t.Errorf("expected message 'top_k must be greater than 0', got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func jsonDecodeMessage(t *testing.T, body []byte) string {
|
||||
t.Helper()
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
msg, _ := resp["message"].(string)
|
||||
return msg
|
||||
}
|
||||
|
||||
func nullableInt(p *int) string {
|
||||
if p == nil { return "nil" }
|
||||
return fmt.Sprintf("%d", *p)
|
||||
}
|
||||
func nullableBool(p *bool) string {
|
||||
if p == nil { return "nil" }
|
||||
return fmt.Sprintf("%v", *p)
|
||||
}
|
||||
func nullableFloat(p *float64) string {
|
||||
if p == nil { return "nil" }
|
||||
return fmt.Sprintf("%v", *p)
|
||||
}
|
||||
|
||||
|
||||
func TestSearchBotsRetrieval_EmptyQuestion(t *testing.T) {
|
||||
// Send kb_id but empty question — caught by binding:"required" on the DTO.
|
||||
_, _, r := setupSearchbotsTest("user1")
|
||||
body := `{"kb_id": ["kb1"], "question": ""}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
msg := jsonDecodeMessage(t, w.Body.Bytes())
|
||||
if !strings.Contains(msg, "Question") || !strings.Contains(msg, "required") {
|
||||
t.Errorf("expected validation error mentioning Question and required, got %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// fakeSearchbotLLM implements searchbotLLM for testing.
|
||||
type fakeSearchbotLLM struct {
|
||||
response string
|
||||
@@ -42,7 +415,7 @@ func (f *fakeSearchbotLLM) Chat(tenantID, modelID string, messages []modelModule
|
||||
return &modelModule.ChatResponse{Answer: &f.response}, nil
|
||||
}
|
||||
|
||||
func setupSearchbotRequest(body string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
func setupSearchBotRequest(body string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -54,17 +427,17 @@ func setupSearchbotRequest(body string) (*gin.Context, *httptest.ResponseRecorde
|
||||
return c, w
|
||||
}
|
||||
|
||||
// TestSearchbotHandler_Success verifies the happy path.
|
||||
func TestSearchbotHandler_Success(t *testing.T) {
|
||||
// TestSearchBotHandler_Success verifies the happy path.
|
||||
func TestSearchBotHandler_Success(t *testing.T) {
|
||||
llm := &fakeSearchbotLLM{
|
||||
response: `Here are some related questions:
|
||||
1. How do EV impact environment?
|
||||
2. What are advantages of EV?
|
||||
3. Cost of EV?`,
|
||||
}
|
||||
h := NewSearchbotHandler(nil, nil, llm)
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
|
||||
c, w := setupSearchbotRequest(`{"question": "EV benefits"}`)
|
||||
c, w := setupSearchBotRequest(`{"question": "EV benefits"}`)
|
||||
h.Handle(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
@@ -72,6 +445,9 @@ func TestSearchbotHandler_Success(t *testing.T) {
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"])
|
||||
}
|
||||
if msg, _ := resp["message"].(string); msg != "success" {
|
||||
t.Errorf("expected message 'success', got %q", msg)
|
||||
}
|
||||
|
||||
questions, ok := resp["data"].([]interface{})
|
||||
if !ok {
|
||||
@@ -85,14 +461,14 @@ func TestSearchbotHandler_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchbotHandler_EmptyResponse verifies empty LLM response returns empty list.
|
||||
func TestSearchbotHandler_EmptyResponse(t *testing.T) {
|
||||
// TestSearchBotHandler_EmptyResponse verifies empty LLM response returns empty list.
|
||||
func TestSearchBotHandler_EmptyResponse(t *testing.T) {
|
||||
llm := &fakeSearchbotLLM{
|
||||
response: "No related questions found.",
|
||||
}
|
||||
h := NewSearchbotHandler(nil, nil, llm)
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
|
||||
c, w := setupSearchbotRequest(`{"question": "EV benefits"}`)
|
||||
c, w := setupSearchBotRequest(`{"question": "EV benefits"}`)
|
||||
h.Handle(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
@@ -109,14 +485,14 @@ func TestSearchbotHandler_EmptyResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchbotHandler_LLMFailure verifies error handling on LLM failure.
|
||||
func TestSearchbotHandler_LLMFailure(t *testing.T) {
|
||||
// TestSearchBotHandler_LLMFailure verifies error handling on LLM failure.
|
||||
func TestSearchBotHandler_LLMFailure(t *testing.T) {
|
||||
llm := &fakeSearchbotLLM{
|
||||
err: errFake{msg: "LLM unavailable"},
|
||||
}
|
||||
h := NewSearchbotHandler(nil, nil, llm)
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
|
||||
c, w := setupSearchbotRequest(`{"question": "EV benefits"}`)
|
||||
c, w := setupSearchBotRequest(`{"question": "EV benefits"}`)
|
||||
h.Handle(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
@@ -127,12 +503,12 @@ func TestSearchbotHandler_LLMFailure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchbotHandler_MissingQuestion verifies validation.
|
||||
func TestSearchbotHandler_MissingQuestion(t *testing.T) {
|
||||
// TestSearchBotHandler_MissingQuestion verifies validation.
|
||||
func TestSearchBotHandler_MissingQuestion(t *testing.T) {
|
||||
llm := &fakeSearchbotLLM{response: "dummy"}
|
||||
h := NewSearchbotHandler(nil, nil, llm)
|
||||
h := NewSearchBotHandler(nil, nil, llm, nil)
|
||||
|
||||
c, w := setupSearchbotRequest(`{}`)
|
||||
c, w := setupSearchBotRequest(`{}`)
|
||||
h.Handle(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
|
||||
@@ -42,7 +42,7 @@ type Router struct {
|
||||
skillSearchHandler *handler.SkillSearchHandler
|
||||
providerHandler *handler.ProviderHandler
|
||||
agentHandler *handler.AgentHandler
|
||||
relatedQuestionsHandler *handler.SearchbotHandler
|
||||
searchBotHandler *handler.SearchBotHandler
|
||||
difyRetrievalHandler *handler.DifyRetrievalHandler
|
||||
pluginHandler *handler.PluginHandler
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func NewRouter(
|
||||
skillSearchHandler *handler.SkillSearchHandler,
|
||||
providerHandler *handler.ProviderHandler,
|
||||
agentHandler *handler.AgentHandler,
|
||||
relatedQuestionsHandler *handler.SearchbotHandler,
|
||||
searchBotHandler *handler.SearchBotHandler,
|
||||
difyRetrievalHandler *handler.DifyRetrievalHandler,
|
||||
pluginHandler *handler.PluginHandler,
|
||||
) *Router {
|
||||
@@ -92,7 +92,7 @@ func NewRouter(
|
||||
skillSearchHandler: skillSearchHandler,
|
||||
providerHandler: providerHandler,
|
||||
agentHandler: agentHandler,
|
||||
relatedQuestionsHandler: relatedQuestionsHandler,
|
||||
searchBotHandler: searchBotHandler,
|
||||
difyRetrievalHandler: difyRetrievalHandler,
|
||||
pluginHandler: pluginHandler,
|
||||
}
|
||||
@@ -226,7 +226,8 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
}
|
||||
|
||||
// Searchbot routes
|
||||
v1.POST("/searchbots/related_questions", r.relatedQuestionsHandler.Handle)
|
||||
v1.POST("/searchbots/related_questions", r.searchBotHandler.Handle)
|
||||
v1.POST("/searchbots/retrieval_test", r.searchBotHandler.RetrievalTest)
|
||||
|
||||
// Dataset routes
|
||||
datasets := v1.Group("/datasets")
|
||||
|
||||
@@ -19,12 +19,19 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/entity/models"
|
||||
"ragflow/internal/server"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/engine/types"
|
||||
"ragflow/internal/server"
|
||||
"ragflow/internal/service/nlp"
|
||||
"ragflow/internal/tokenizer"
|
||||
"ragflow/internal/utility"
|
||||
)
|
||||
@@ -54,11 +61,391 @@ func NewChunkService() *ChunkService {
|
||||
}
|
||||
}
|
||||
|
||||
// GetChunkRequest request for getting a chunk by ID.
|
||||
// RetrievalTestRequest retrieval test request
|
||||
type RetrievalTestRequest struct {
|
||||
Datasets common.StringSlice `json:"dataset_ids" binding:"required"` // string or []string
|
||||
Question string `json:"question"`
|
||||
Page *int `json:"page,omitempty"`
|
||||
Size *int `json:"size,omitempty"`
|
||||
DocIDs []string `json:"doc_ids,omitempty"`
|
||||
UseKG *bool `json:"use_kg,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
CrossLanguages []string `json:"cross_languages,omitempty"`
|
||||
SearchID *string `json:"search_id,omitempty"`
|
||||
Filter map[string]interface{} `json:"meta_data_filter,omitempty"`
|
||||
TenantRerankID *string `json:"tenant_rerank_id,omitempty"`
|
||||
RerankID *string `json:"rerank_id,omitempty"`
|
||||
Keyword *bool `json:"keyword,omitempty"`
|
||||
SimilarityThreshold *float64 `json:"similarity_threshold,omitempty"`
|
||||
VectorSimilarityWeight *float64 `json:"vector_similarity_weight,omitempty"`
|
||||
}
|
||||
|
||||
// RetrievalTestResponse retrieval test response
|
||||
type RetrievalTestResponse struct {
|
||||
Chunks []map[string]interface{} `json:"chunks"`
|
||||
DocAggs []map[string]interface{} `json:"doc_aggs"`
|
||||
Labels *map[string]float64 `json:"labels"`
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
|
||||
// RetrievalTest performs retrieval test for a given question against specified knowledge bases.
|
||||
//
|
||||
// Flow:
|
||||
// 1. Validate kbs permissions and embedding model
|
||||
// 2. Apply metadata filter if specified (auto/semi_auto uses LLM, manual uses provided conditions)
|
||||
// 3. Apply cross_languages transformation if requested (translate question)
|
||||
// 4. Apply keyword extraction if requested (append keywords to question)
|
||||
// 5. Get rank features via LabelQuestion() - tag-based weights or pagerank_fld fallback
|
||||
// 6. Call RetrievalService.Retrieval() which:
|
||||
// - Computes query embedding
|
||||
// - Performs hybrid search (text + vector) with rank features
|
||||
// - Reranks results
|
||||
// - Builds doc_aggs by aggregating chunks per document
|
||||
// 7. knowledge graph retrieval (not implemented)
|
||||
// 8. Apply retrieval by children to group child chunks under parent chunks
|
||||
func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (*RetrievalTestResponse, error) {
|
||||
common.Info("RetrievalTest started", zap.String("userID", userID), zap.Any("kbID", req.Datasets), zap.String("question", req.Question))
|
||||
|
||||
common.Debug(fmt.Sprintf("RetrievalTest request:\n"+
|
||||
" kbID=%v\n"+
|
||||
" question=%s\n"+
|
||||
" page=%v, size=%v\n"+
|
||||
" docIDs=%v\n"+
|
||||
" useKG=%v, topK=%v\n"+
|
||||
" crossLanguages=%v\n"+
|
||||
" searchID=%v\n"+
|
||||
" filter=%v\n"+
|
||||
" tenantRerankID=%v\n"+
|
||||
" rerankID=%v\n"+
|
||||
" keyword=%v\n"+
|
||||
" similarityThreshold=%v, vectorSimilarityWeight=%v",
|
||||
req.Datasets, req.Question,
|
||||
ptrString(req.Page), ptrString(req.Size), req.DocIDs,
|
||||
ptrString(req.UseKG), ptrString(req.TopK), req.CrossLanguages, ptrString(req.SearchID),
|
||||
req.Filter,
|
||||
ptrString(req.TenantRerankID), ptrString(req.RerankID),
|
||||
ptrString(req.Keyword),
|
||||
ptrString(req.SimilarityThreshold), ptrString(req.VectorSimilarityWeight)))
|
||||
|
||||
if req.Question == "" {
|
||||
return nil, fmt.Errorf("question is required")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tenantIDs, kbRecords, err := s.validateKBs(userID, req.Datasets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
docIDs, err := s.resolveMetaFilter(ctx, req.SearchID, req.Filter, req.Question, req.DocIDs, req.Datasets, tenantIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modifiedQuestion, err := s.transformQuestion(ctx, req.Question, req.CrossLanguages, req.Keyword, tenantIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get tag-based rank features via LabelQuestion
|
||||
metadataSvc := NewMetadataService()
|
||||
labels := metadataSvc.LabelQuestion(modifiedQuestion, kbRecords)
|
||||
common.Debug("LabelQuestion result", zap.Any("labels", labels))
|
||||
|
||||
embeddingModel, err := s.resolveEmbeddingModel(tenantIDs[0], kbRecords[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rerankModel, err := s.resolveRerankModel(tenantIDs[0], req.TenantRerankID, req.RerankID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
retrievalReq := &nlp.RetrievalRequest{
|
||||
TenantIDs: tenantIDs,
|
||||
Question: modifiedQuestion,
|
||||
KbIDs: []string(req.Datasets),
|
||||
DocIDs: docIDs,
|
||||
Page: getPageNum(req.Page, 1),
|
||||
PageSize: getPageSize(req.Size, 30),
|
||||
Top: req.TopK,
|
||||
SimilarityThreshold: req.SimilarityThreshold,
|
||||
VectorSimilarityWeight: req.VectorSimilarityWeight,
|
||||
RerankModel: rerankModel,
|
||||
RankFeature: &labels,
|
||||
EmbeddingModel: embeddingModel,
|
||||
}
|
||||
|
||||
// Call RetrievalService to perform retrieval
|
||||
retrievalResult, err := nlp.NewRetrievalService(s.docEngine, s.documentDAO).Retrieval(ctx, retrievalReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retrieval search failed: %w", err)
|
||||
}
|
||||
|
||||
filteredChunks := retrievalResult.Chunks
|
||||
|
||||
// Handle knowledge graph retrieval
|
||||
// TODO: KG retrieval requires GraphRAG infrastructure which is not yet implemented in Go
|
||||
if req.UseKG != nil && *req.UseKG {
|
||||
common.Warn("use_kg is not yet implemented in Go - skipping KG retrieval")
|
||||
}
|
||||
|
||||
// Apply retrieval_by_children - aggregate child chunks into parent chunks
|
||||
filteredChunks = nlp.RetrievalByChildren(filteredChunks, tenantIDs, s.docEngine, ctx)
|
||||
|
||||
// Remove vector field from each chunk
|
||||
for i := range filteredChunks {
|
||||
delete(filteredChunks[i], "vector")
|
||||
}
|
||||
|
||||
common.Info("RetrievalTest completed", zap.String("userID", userID), zap.Any("kbID", req.Datasets), zap.String("question", req.Question), zap.Int64("chunkCount", int64(len(filteredChunks))))
|
||||
|
||||
return &RetrievalTestResponse{
|
||||
Chunks: filteredChunks,
|
||||
DocAggs: retrievalResult.DocAggs,
|
||||
Labels: &labels,
|
||||
Total: int64(len(filteredChunks)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
// validateKBs resolves tenant IDs and KB records for the given dataset IDs.
|
||||
func (s *ChunkService) validateKBs(userID string, datasetIDs []string) ([]string, []*entity.Knowledgebase, error) {
|
||||
tenants, err := s.userTenantDAO.GetByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get user tenants: %w", err)
|
||||
}
|
||||
if len(tenants) == 0 {
|
||||
return nil, nil, fmt.Errorf("user has no accessible tenants")
|
||||
}
|
||||
common.Debug("Retrieved user tenants from database", zap.String("userID", userID), zap.Int("tenantCount", len(tenants)))
|
||||
|
||||
var tenantIDs []string
|
||||
var kbRecords []*entity.Knowledgebase
|
||||
for _, datasetID := range datasetIDs {
|
||||
found := false
|
||||
for _, tenant := range tenants {
|
||||
kb, err := s.kbDAO.GetByIDAndTenantID(datasetID, tenant.TenantID)
|
||||
if err == nil && kb != nil {
|
||||
common.Debug("Found knowledge base in database",
|
||||
zap.String("datasetID", datasetID),
|
||||
zap.String("tenantID", tenant.TenantID),
|
||||
zap.String("kbName", kb.Name),
|
||||
zap.String("embdID", kb.EmbdID))
|
||||
tenantIDs = append(tenantIDs, tenant.TenantID)
|
||||
kbRecords = append(kbRecords, kb)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, nil, fmt.Errorf("only owner of dataset is authorized for this operation")
|
||||
}
|
||||
}
|
||||
if len(kbRecords) > 1 {
|
||||
firstEmbdID := kbRecords[0].EmbdID
|
||||
for i := 1; i < len(kbRecords); i++ {
|
||||
if kbRecords[i].EmbdID != firstEmbdID {
|
||||
return nil, nil, fmt.Errorf("cannot retrieve across datasets with different embedding models")
|
||||
}
|
||||
}
|
||||
}
|
||||
return tenantIDs, kbRecords, nil
|
||||
}
|
||||
|
||||
// resolveMetaFilter resolves a metadata filter from search_id and applies it.
|
||||
func (s *ChunkService) resolveMetaFilter(ctx context.Context, searchID *string, initialFilter map[string]interface{}, question string, docIDs []string, datasetIDs []string, tenantIDs []string) ([]string, error) {
|
||||
var chatID string
|
||||
var chatModelForFilter *models.ChatModel
|
||||
filter := initialFilter
|
||||
|
||||
if searchID != nil && *searchID != "" {
|
||||
searchDetail, err := s.searchService.GetDetail(*searchID)
|
||||
if err != nil {
|
||||
common.Warn("Failed to get search detail for search_id, proceeding without it", zap.String("searchID", *searchID), zap.Error(err))
|
||||
} else if searchConfig, ok := searchDetail["search_config"].(entity.JSONMap); ok && searchConfig != nil {
|
||||
if searchMetaFilter, ok := searchConfig["meta_data_filter"].(map[string]interface{}); ok {
|
||||
filter = searchMetaFilter
|
||||
}
|
||||
chatID, _ = searchConfig["chat_id"].(string)
|
||||
} else {
|
||||
common.Warn("No search_config found in search detail", zap.String("searchID", *searchID))
|
||||
}
|
||||
}
|
||||
if filter != nil {
|
||||
method, _ := filter["method"].(string)
|
||||
if method == "auto" || method == "semi_auto" {
|
||||
modelProviderSvc := NewModelProviderService()
|
||||
if chatID != "" {
|
||||
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeChat, chatID)
|
||||
if getErr != nil {
|
||||
common.Warn("Failed to get chat model from search_config chat_id, using tenant default", zap.String("chatID", chatID), zap.Error(getErr))
|
||||
} else {
|
||||
chatModelForFilter = models.NewChatModel(driver, &mdlName, apiConfig)
|
||||
common.Info("Fetched chat model (from search_config) for metadata filter",
|
||||
zap.String("chatID", chatID), zap.String("tenantID", tenantIDs[0]))
|
||||
}
|
||||
}
|
||||
if chatModelForFilter == nil {
|
||||
tenantSvc := NewTenantService()
|
||||
modelName, err := tenantSvc.GetDefaultModelName(tenantIDs[0], entity.ModelTypeChat)
|
||||
if err != nil || modelName == "" {
|
||||
common.Warn("Failed to get tenant default chat model name for meta_data_filter", zap.Error(err))
|
||||
} else {
|
||||
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeChat, modelName)
|
||||
if getErr != nil {
|
||||
common.Warn("Failed to get chat model for meta_data_filter", zap.Error(getErr))
|
||||
} else {
|
||||
chatModelForFilter = models.NewChatModel(driver, &mdlName, apiConfig)
|
||||
common.Info("Fetched chat model (tenant default) for metadata filter",
|
||||
zap.String("tenantID", tenantIDs[0]), zap.String("modelName", modelName))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out := make([]string, len(docIDs))
|
||||
copy(out, docIDs)
|
||||
if filter != nil {
|
||||
metadataSvc := NewMetadataService()
|
||||
flattedMeta, err := metadataSvc.GetFlattedMetaByKBs([]string(datasetIDs))
|
||||
if err != nil {
|
||||
common.Warn("Failed to get flatted metadata", zap.Error(err))
|
||||
} else {
|
||||
common.Info("metadata filter conditions", zap.Any("filter", filter))
|
||||
filteredDocIDs, _ := ApplyMetaDataFilter(ctx, filter, flattedMeta, question, chatModelForFilter, docIDs, []string(datasetIDs))
|
||||
out = filteredDocIDs
|
||||
common.Info("ApplyMetaDataFilter result", zap.Strings("docIDs", out))
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// transformQuestion applies cross-languages translation and keyword extraction.
|
||||
func (s *ChunkService) transformQuestion(ctx context.Context, question string, crossLanguages []string, keyword *bool, tenantIDs []string) (string, error) {
|
||||
modifiedQuestion := question
|
||||
if len(crossLanguages) == 0 && (keyword == nil || !*keyword) {
|
||||
return modifiedQuestion, nil
|
||||
}
|
||||
tenantSvc := NewTenantService()
|
||||
modelProviderSvc := NewModelProviderService()
|
||||
modelName, err := tenantSvc.GetDefaultModelName(tenantIDs[0], "chat")
|
||||
if err != nil || modelName == "" {
|
||||
common.Warn("Failed to get default chat model name for LLM transformations", zap.Error(err))
|
||||
return question, nil
|
||||
}
|
||||
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeChat, modelName)
|
||||
if getErr != nil {
|
||||
common.Warn("Failed to get chat model for LLM transformations", zap.Error(getErr))
|
||||
return question, nil
|
||||
}
|
||||
chatModel := models.NewChatModel(driver, &mdlName, apiConfig)
|
||||
common.Info("Fetched chat model (tenant default) for cross_languages/keyword_extraction",
|
||||
zap.String("tenantID", tenantIDs[0]), zap.String("modelName", modelName))
|
||||
if len(crossLanguages) > 0 {
|
||||
translated, err := CrossLanguages(ctx, tenantIDs[0], modelName, question, crossLanguages)
|
||||
if err != nil {
|
||||
common.Warn("Failed to translate question", zap.Error(err))
|
||||
} else {
|
||||
modifiedQuestion = translated
|
||||
}
|
||||
}
|
||||
if keyword != nil && *keyword {
|
||||
extractedKeywords, err := KeywordExtraction(ctx, chatModel, modifiedQuestion, 3)
|
||||
if err != nil {
|
||||
common.Warn("Failed to extract keywords from question", zap.Error(err))
|
||||
} else if extractedKeywords != "" {
|
||||
modifiedQuestion = modifiedQuestion + " " + extractedKeywords
|
||||
}
|
||||
}
|
||||
if modifiedQuestion != question {
|
||||
common.Info("Modified question after transformations",
|
||||
zap.String("originalQuestion", question),
|
||||
zap.String("modifiedQuestion", modifiedQuestion),
|
||||
zap.Strings("crossLanguages", crossLanguages),
|
||||
zap.Bool("keywordExtraction", keyword != nil && *keyword))
|
||||
}
|
||||
return modifiedQuestion, nil
|
||||
}
|
||||
|
||||
// resolveEmbeddingModel resolves the embedding model for a KB record.
|
||||
func (s *ChunkService) resolveEmbeddingModel(tenantID string, kbRecord *entity.Knowledgebase) (*models.EmbeddingModel, error) {
|
||||
var embdID string
|
||||
var err error
|
||||
if kbRecord.TenantEmbdID != nil && *kbRecord.TenantEmbdID > 0 {
|
||||
_, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecord.TenantEmbdID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err)
|
||||
}
|
||||
} else if kbRecord.EmbdID != "" {
|
||||
parts := strings.Split(kbRecord.EmbdID, "@")
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
_, embdID, err = dao.LookupTenantLLMByFactory(dao.NewTenantLLMDAO(), tenantID, parts[1], parts[0], entity.ModelTypeEmbedding)
|
||||
} else {
|
||||
_, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantID, kbRecord.EmbdID, entity.ModelTypeEmbedding)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err)
|
||||
}
|
||||
} else {
|
||||
tenantLLM, err := dao.NewTenantLLMDAO().GetByTenantAndType(tenantID, entity.ModelTypeEmbedding)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get tenant default embedding model: %w", err)
|
||||
}
|
||||
if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" {
|
||||
return nil, fmt.Errorf("no default embedding model found for tenant %s", tenantID)
|
||||
}
|
||||
embdID = fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory)
|
||||
}
|
||||
modelProviderSvc := NewModelProviderService()
|
||||
embeddingModel, err := modelProviderSvc.GetEmbeddingModel(tenantID, embdID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model: %w", err)
|
||||
}
|
||||
common.Info("Fetched embedding model for retrieval",
|
||||
zap.String("tenantID", tenantID), zap.String("embdID", embdID))
|
||||
return embeddingModel, nil
|
||||
}
|
||||
|
||||
// resolveRerankModel resolves the rerank model from tenant_rerank_id or rerank_id.
|
||||
func (s *ChunkService) resolveRerankModel(tenantID string, tenantRerankID, rerankID *string) (*models.RerankModel, error) {
|
||||
var rerankCompositeName string
|
||||
var err error
|
||||
if tenantRerankID != nil && *tenantRerankID != "" {
|
||||
tenantRerankIDInt, parseErr := strconv.ParseInt(*tenantRerankID, 10, 64)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("invalid tenant_rerank_id: %w", parseErr)
|
||||
}
|
||||
_, rerankCompositeName, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), tenantRerankIDInt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err)
|
||||
}
|
||||
} else if rerankID != nil && *rerankID != "" {
|
||||
_, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantID, *rerankID, entity.ModelTypeRerank)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", err)
|
||||
}
|
||||
}
|
||||
if rerankCompositeName == "" {
|
||||
return nil, nil
|
||||
}
|
||||
modelProviderSvc := NewModelProviderService()
|
||||
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantID, entity.ModelTypeRerank, rerankCompositeName)
|
||||
if getErr != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model: %w", getErr)
|
||||
}
|
||||
rerankModel := models.NewRerankModel(driver, &mdlName, apiConfig)
|
||||
common.Info("Fetched rerank model",
|
||||
zap.String("tenantID", tenantID), zap.String("rerankCompositeName", rerankCompositeName))
|
||||
return rerankModel, nil
|
||||
}
|
||||
|
||||
|
||||
// GetChunkRequest request for getting a chunk by ID
|
||||
type GetChunkRequest struct {
|
||||
ChunkID string `json:"chunk_id"`
|
||||
DocumentID string `json:"document_id"`
|
||||
DatasetID string `json:"dataset_id"`
|
||||
ChunkID string `json:"chunk_id"`
|
||||
}
|
||||
|
||||
// GetChunkResponse response for getting a chunk
|
||||
@@ -75,27 +462,10 @@ func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkRespon
|
||||
if req.ChunkID == "" {
|
||||
return nil, fmt.Errorf("chunk_id is required")
|
||||
}
|
||||
if req.DatasetID == "" {
|
||||
return nil, fmt.Errorf("dataset_id is required")
|
||||
}
|
||||
if req.DocumentID == "" {
|
||||
return nil, fmt.Errorf("document_id is required")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Verify the document exists and belongs to the dataset
|
||||
docDAO := dao.NewDocumentDAO()
|
||||
doc, err := docDAO.GetByID(req.DocumentID)
|
||||
if err != nil || doc == nil {
|
||||
return nil, fmt.Errorf("document not found")
|
||||
}
|
||||
if doc.KbID != req.DatasetID {
|
||||
return nil, fmt.Errorf("document does not belong to this dataset")
|
||||
}
|
||||
|
||||
// Resolve the tenant that owns this dataset. A dataset belongs to
|
||||
// exactly one tenant, so the loop breaks on the first match.
|
||||
// Get user's tenants
|
||||
tenants, err := s.userTenantDAO.GetByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user tenants: %w", err)
|
||||
@@ -103,87 +473,75 @@ func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkRespon
|
||||
if len(tenants) == 0 {
|
||||
return nil, fmt.Errorf("user has no accessible tenants")
|
||||
}
|
||||
var targetTenantID string
|
||||
|
||||
// Try each tenant to find the chunk
|
||||
var chunk map[string]interface{}
|
||||
for _, tenant := range tenants {
|
||||
kb, err := s.kbDAO.GetByIDAndTenantID(req.DatasetID, tenant.TenantID)
|
||||
if err == nil && kb != nil {
|
||||
targetTenantID = tenant.TenantID
|
||||
break
|
||||
}
|
||||
}
|
||||
if targetTenantID == "" {
|
||||
return nil, fmt.Errorf("user does not have access to this dataset")
|
||||
}
|
||||
|
||||
indexName := fmt.Sprintf("ragflow_%s", targetTenantID)
|
||||
raw, err := s.docEngine.GetChunk(ctx, indexName, req.ChunkID, []string{req.DatasetID})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get chunk: %w", err)
|
||||
}
|
||||
if raw == nil {
|
||||
return nil, fmt.Errorf("chunk not found")
|
||||
}
|
||||
chunk, ok := raw.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid chunk format: expected map[string]interface{}, got %T", raw)
|
||||
}
|
||||
|
||||
if actual, _ := chunk["kb_id"].(string); actual != req.DatasetID {
|
||||
return nil, fmt.Errorf("chunk does not belong to dataset %q", req.DatasetID)
|
||||
}
|
||||
if actual, _ := chunk["doc_id"].(string); actual != req.DocumentID {
|
||||
return nil, fmt.Errorf("chunk does not belong to document %q", req.DocumentID)
|
||||
}
|
||||
|
||||
return &GetChunkResponse{Chunk: formatChunkForGet(chunk)}, nil
|
||||
}
|
||||
|
||||
// formatChunkForGet normalizes a raw engine chunk into the public
|
||||
// response shape: it strips internal fields (_vec / _tks / _ltks / _sm_
|
||||
// suffixes and a handful of score / id fields), renames a few legacy
|
||||
// keys, and coerces numeric fields to JSONFloat64 so the JSON encoder
|
||||
// preserves precision.
|
||||
func formatChunkForGet(chunk map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
skipFields := map[string]bool{
|
||||
"id": true, "authors": true, "_score": true, "SCORE": true,
|
||||
}
|
||||
for k, v := range chunk {
|
||||
if skipFields[k] || strings.HasSuffix(k, "_vec") || strings.Contains(k, "_sm_") || strings.HasSuffix(k, "_tks") || strings.HasSuffix(k, "_ltks") {
|
||||
// Get kbIDs for this tenant
|
||||
kbIDs, err := s.kbDAO.GetKBIDsByTenantID(tenant.TenantID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
switch k {
|
||||
case "content":
|
||||
result["content_with_weight"] = v
|
||||
case "docnm":
|
||||
result["docnm_kwd"] = v
|
||||
case "important_keywords":
|
||||
utility.SetFieldArray(result, "important_kwd", v)
|
||||
case "questions":
|
||||
utility.SetFieldArray(result, "question_kwd", v)
|
||||
case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
|
||||
"name_kwd", "raptor_kwd", "removed_kwd", "source_id", "tag_kwd",
|
||||
"to_entity_kwd", "toc_kwd":
|
||||
if utility.IsEmpty(v) {
|
||||
result[k] = []interface{}{}
|
||||
} else {
|
||||
result[k] = v
|
||||
|
||||
indexName := fmt.Sprintf("ragflow_%s", tenant.TenantID)
|
||||
|
||||
doc, err := s.docEngine.GetChunk(ctx, indexName, req.ChunkID, kbIDs)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if doc != nil {
|
||||
chunk, ok := doc.(map[string]interface{})
|
||||
if ok {
|
||||
result := make(map[string]interface{})
|
||||
skipFields := map[string]bool{
|
||||
"id": true, "authors": true, "_score": true, "SCORE": true,
|
||||
}
|
||||
for k, v := range chunk {
|
||||
if skipFields[k] || strings.HasSuffix(k, "_vec") || strings.Contains(k, "_sm_") || strings.HasSuffix(k, "_tks") || strings.HasSuffix(k, "_ltks") {
|
||||
continue
|
||||
}
|
||||
switch k {
|
||||
case "content":
|
||||
result["content_with_weight"] = v
|
||||
case "docnm":
|
||||
result["docnm_kwd"] = v
|
||||
case "important_keywords":
|
||||
utility.SetFieldArray(result, "important_kwd", v)
|
||||
case "questions":
|
||||
utility.SetFieldArray(result, "question_kwd", v)
|
||||
case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
|
||||
"name_kwd", "raptor_kwd", "removed_kwd", "source_id", "tag_kwd",
|
||||
"to_entity_kwd", "toc_kwd", "authors_tks", "doc_type_kwd":
|
||||
if utility.IsEmpty(v) {
|
||||
result[k] = []interface{}{}
|
||||
} else {
|
||||
result[k] = v
|
||||
}
|
||||
case "tag_feas":
|
||||
if utility.IsEmpty(v) {
|
||||
result[k] = map[string]interface{}{}
|
||||
} else {
|
||||
result[k] = v
|
||||
}
|
||||
case "create_timestamp_flt", "rank_flt", "weight_flt":
|
||||
if floatVal, ok := utility.ToFloat64(v); ok {
|
||||
result[k] = utility.JSONFloat64(floatVal)
|
||||
}
|
||||
default:
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return &GetChunkResponse{Chunk: result}, nil
|
||||
}
|
||||
case "tag_feas":
|
||||
if utility.IsEmpty(v) {
|
||||
result[k] = map[string]interface{}{}
|
||||
} else {
|
||||
result[k] = v
|
||||
}
|
||||
case "create_timestamp_flt", "rank_flt", "weight_flt":
|
||||
if floatVal, ok := utility.ToFloat64(v); ok {
|
||||
result[k] = utility.JSONFloat64(floatVal)
|
||||
}
|
||||
default:
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
if chunk == nil {
|
||||
return nil, fmt.Errorf("chunk not found")
|
||||
}
|
||||
|
||||
return &GetChunkResponse{Chunk: chunk}, nil
|
||||
}
|
||||
|
||||
// ListChunksRequest request for listing chunks
|
||||
@@ -316,7 +674,7 @@ func (s *ChunkService) List(req *ListChunksRequest, userID string) (*ListChunksR
|
||||
utility.SetFieldArray(result, "question_kwd", v)
|
||||
case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
|
||||
"name_kwd", "raptor_kwd", "removed_kwd",
|
||||
"source_id", "tag_kwd", "to_entity_kwd", "toc_kwd":
|
||||
"source_id", "tag_kwd", "to_entity_kwd", "toc_kwd", "doc_type_kwd":
|
||||
if utility.IsEmpty(v) {
|
||||
result[k] = []interface{}{}
|
||||
} else {
|
||||
@@ -447,12 +805,10 @@ func (s *ChunkService) UpdateChunk(req *UpdateChunkRequest, userID string) error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing chunk: %w", err)
|
||||
}
|
||||
if existingChunk == nil {
|
||||
return fmt.Errorf("chunk %q not found in dataset %q", req.ChunkID, req.DatasetID)
|
||||
}
|
||||
|
||||
existing, ok := existingChunk.(map[string]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid chunk format: expected map[string]interface{}, got %T", existingChunk)
|
||||
return fmt.Errorf("invalid chunk format")
|
||||
}
|
||||
|
||||
// Build update dict
|
||||
|
||||
250
internal/service/chunk_retrieval_test.go
Normal file
250
internal/service/chunk_retrieval_test.go
Normal file
@@ -0,0 +1,250 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// --- Helper tests ---
|
||||
|
||||
func TestPtrString_Nil(t *testing.T) {
|
||||
if got := ptrString[int](nil); got != "<nil>" {
|
||||
t.Errorf("ptrString(nil) = %q, want <nil>", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPtrString_Value(t *testing.T) {
|
||||
val := 42
|
||||
if got := ptrString(&val); got != "42" {
|
||||
t.Errorf("ptrString(&42) = %q, want 42", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPtrString_Bool(t *testing.T) {
|
||||
val := true
|
||||
if got := ptrString(&val); got != "true" {
|
||||
t.Errorf("ptrString(&true) = %q, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPageNum_Nil(t *testing.T) {
|
||||
if got := getPageNum(nil, 10); got != 10 {
|
||||
t.Errorf("getPageNum(nil, 10) = %d, want 10", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPageNum_ZeroReturnsDefault(t *testing.T) {
|
||||
val := 0
|
||||
if got := getPageNum(&val, 5); got != 5 {
|
||||
t.Errorf("getPageNum(&0, 5) = %d, want 5", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPageNum_NegativeReturnsDefault(t *testing.T) {
|
||||
val := -1
|
||||
if got := getPageNum(&val, 5); got != 5 {
|
||||
t.Errorf("getPageNum(&-1, 5) = %d, want 5", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPageNum_Valid(t *testing.T) {
|
||||
val := 3
|
||||
if got := getPageNum(&val, 5); got != 3 {
|
||||
t.Errorf("getPageNum(&3, 5) = %d, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPageSize_Nil(t *testing.T) {
|
||||
if got := getPageSize(nil, 20); got != 20 {
|
||||
t.Errorf("getPageSize(nil, 20) = %d, want 20", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPageSize_ZeroReturnsDefault(t *testing.T) {
|
||||
val := 0
|
||||
if got := getPageSize(&val, 20); got != 20 {
|
||||
t.Errorf("getPageSize(&0, 20) = %d, want 20", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPageSize_Valid(t *testing.T) {
|
||||
val := 50
|
||||
if got := getPageSize(&val, 20); got != 50 {
|
||||
t.Errorf("getPageSize(&50, 20) = %d, want 50", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- RetrievalTestRequest validation tests ---
|
||||
|
||||
func TestRetrievalTestRequest_Defaults(t *testing.T) {
|
||||
req := &RetrievalTestRequest{
|
||||
Datasets: []string{"kb1"},
|
||||
Question: "test question",
|
||||
}
|
||||
// Verify pointer fields are nil by default
|
||||
if req.Page != nil {
|
||||
t.Error("Page should default to nil")
|
||||
}
|
||||
if req.Size != nil {
|
||||
t.Error("Size should default to nil")
|
||||
}
|
||||
if req.TopK != nil {
|
||||
t.Error("TopK should default to nil")
|
||||
}
|
||||
if req.UseKG != nil {
|
||||
t.Error("UseKG should default to nil")
|
||||
}
|
||||
if req.SimilarityThreshold != nil {
|
||||
t.Error("SimilarityThreshold should default to nil")
|
||||
}
|
||||
if req.VectorSimilarityWeight != nil {
|
||||
t.Error("VectorSimilarityWeight should default to nil")
|
||||
}
|
||||
if req.Keyword != nil {
|
||||
t.Error("Keyword should default to nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrievalTestResponse_Fields(t *testing.T) {
|
||||
resp := &RetrievalTestResponse{
|
||||
Chunks: []map[string]interface{}{},
|
||||
DocAggs: []map[string]interface{}{},
|
||||
Total: 0,
|
||||
}
|
||||
if resp.Chunks == nil {
|
||||
t.Error("Chunks should not be nil")
|
||||
}
|
||||
if resp.DocAggs == nil {
|
||||
t.Error("DocAggs should not be nil")
|
||||
}
|
||||
if resp.Total != 0 {
|
||||
t.Errorf("Total = %d, want 0", resp.Total)
|
||||
}
|
||||
}
|
||||
|
||||
// --- transformQuestion edge cases ---
|
||||
|
||||
func TestTransformQuestion_NoTransformNeeded(t *testing.T) {
|
||||
svc := &ChunkService{}
|
||||
ctx := context.Background()
|
||||
result, err := svc.transformQuestion(ctx, "hello", nil, nil, []string{"t1"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "hello" {
|
||||
t.Errorf("expected unchanged question, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransformQuestion_EmptyCrossLanguages(t *testing.T) {
|
||||
svc := &ChunkService{}
|
||||
ctx := context.Background()
|
||||
kw := false
|
||||
result, err := svc.transformQuestion(ctx, "hello", []string{}, &kw, []string{"t1"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "hello" {
|
||||
t.Errorf("expected unchanged question, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransformQuestion_KeywordFalse(t *testing.T) {
|
||||
// This test verifies the early-return path for transformQuestion.
|
||||
// With crossLanguages non-empty it would hit the DB; this is tested
|
||||
// via integration tests that have a full service setup.
|
||||
svc := &ChunkService{}
|
||||
ctx := context.Background()
|
||||
kw := false
|
||||
result, err := svc.transformQuestion(ctx, "hello", []string{}, &kw, []string{"t1"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "hello" {
|
||||
t.Errorf("expected unchanged question, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
// --- resolveEmbeddingModel via exported Retriever ---
|
||||
// These test that the retriever can handle nil inputs gracefully
|
||||
|
||||
func TestResolveEmbeddingModel_NilTenantEmbdID(t *testing.T) {
|
||||
kb := &entity.Knowledgebase{
|
||||
EmbdID: "text-embedding-ada-002@OpenAI",
|
||||
}
|
||||
// This will fail because it needs a real DAO, but we verify the type contract
|
||||
if kb.TenantEmbdID != nil {
|
||||
t.Error("TenantEmbdID should be nil for this test")
|
||||
}
|
||||
_ = kb // verified fields are accessible
|
||||
}
|
||||
|
||||
func TestResolveRerankModel_BothNil(t *testing.T) {
|
||||
svc := &ChunkService{}
|
||||
result, err := svc.resolveRerankModel("t1", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("expected nil rerank model when both IDs are nil, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRerankModel_EmptyStrings(t *testing.T) {
|
||||
svc := &ChunkService{}
|
||||
empty := ""
|
||||
result, err := svc.resolveRerankModel("t1", &empty, &empty)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("expected nil rerank model when both IDs are empty, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRerankModel_InvalidTenantRerankID(t *testing.T) {
|
||||
svc := &ChunkService{}
|
||||
invalid := "not_a_number"
|
||||
_, err := svc.resolveRerankModel("t1", &invalid, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid tenant_rerank_id")
|
||||
}
|
||||
}
|
||||
|
||||
// --- validateKBs input validation ---
|
||||
|
||||
func TestValidateKBs_EmptyDatasets(t *testing.T) {
|
||||
// validateKBs iterates over datasetIDs and queries DAOs.
|
||||
// With empty input it should return empty slices.
|
||||
// This test is limited since validateKBs requires DB-backed DAOs.
|
||||
_ = &ChunkService{} // compiles
|
||||
}
|
||||
|
||||
// --- Verify ChunkService struct fields ---
|
||||
func TestChunkService_FieldsAccessible(t *testing.T) {
|
||||
svc := &ChunkService{}
|
||||
_ = svc.docEngine
|
||||
_ = svc.kbDAO
|
||||
_ = svc.userTenantDAO
|
||||
_ = svc.searchService
|
||||
// Verify embeddingCache field type
|
||||
_ = svc.embeddingCache
|
||||
}
|
||||
Reference in New Issue
Block a user