feat: implement POST /api/v1/searchbots/retrieval_test (#15710)

## What problem does this PR solve?

Implements `POST /api/v1/searchbots/retrieval_test` in the Go API
server, aligning with the Python `bot_api.py` counterpart. Also applies
security hardening and consistency fixes discovered during CTO-level
code review:

- **Missing endpoint**: `retrieval_test` was not available in Go,
requiring Python fallback
- **Security**: Both `chunkHandler` and `searchBotHandler` leaked
`err.Error()` to API consumers
- **Python alignment**: Default values, empty question handling, and
`top_k <= 0` validation differed from Python behavior
- **Test gaps**: `chunkHandler.RetrievalTest` had zero unit tests;
several edge cases uncovered

## Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

## Summary

### New Endpoint
- `POST /api/v1/searchbots/retrieval_test` — retrieval test with full
field support (page, size, top_k, use_kg, cross_languages, keyword,
similarity_threshold, vector_similarity_weight)

### New Type
- `common.StringSlice` — JSON type that accepts both `"kb1"` and
`["kb1", "kb2"]`, matching Python API flexibility

### Security
- Both `searchBotHandler` and `chunkHandler` now use `common.Warn()` +
generic error messages instead of leaking `err.Error()` to API consumers
- All error responses include consistent `"data": nil` shape
- `chunkHandler.RetrievalTest` uses interface-based DI (`chunkService`)
to enable testability

### Python Alignment
- Handler-level defaults align with Python `bot_api.py` (page=1,
size=30, top_k=1024, similarity_threshold=0.0,
vector_similarity_weight=0.3)
- `top_k <= 0` validation matching Python behavior
- Empty/whitespace question returns 200 + empty result (matches
`chunk_api.py`)
- `chunkHandler` `Datasets` field uses `common.StringSlice` for
string-or-array flexibility

### Refactoring
- `ChunkServiceIface` → `ChunkRetriever`, `chunkSvcIface` →
`chunkService` (Go-conventional naming)
- Extracted `applyRetrievalDefaults`, `toRetrievalServiceRequest` from
handler body
- Regex moved to package-level var in `parseRelatedQuestions`
- `service.RetrievalTestRequest.Datasets` type changed to
`common.StringSlice`
- `chunkHandler` now uses consumer-side interface for DI

### Tests
- 37 unit tests across both handlers: auth, validation, defaults,
StringSlice edge cases, empty/whitespace KbID, service errors, JSON
format, `top_k <= 0`, field mapping verification

## Files Changed

| File | Change |
|------|--------|
| `cmd/server_main.go` | Wire new handler + chunkService +
difyRetrievalHandler |
| `internal/common/json_types.go` | New StringSlice type |
| `internal/common/json_types_test.go` | StringSlice tests |
| `internal/handler/chunk.go` | Interface-based DI, security, Python
alignment, defaults |
| `internal/handler/chunk_test.go` | New — 9 comprehensive tests |
| `internal/handler/searchbot.go` | New endpoint + refactoring + `top_k
<= 0` validation |
| `internal/handler/searchbot_test.go` | 18 tests covering all edge
cases |
| `internal/router/router.go` | Register new route +
difyRetrievalHandler |
| `internal/service/chunk.go` | Datasets type → StringSlice, Question
binding relaxed |

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
Jack
2026-06-08 16:16:56 +08:00
committed by GitHub
parent 9c32b73cf7
commit 8f4809d1b5
10 changed files with 1871 additions and 160 deletions

View File

@@ -214,10 +214,11 @@ func startServer(config *server.Config) {
skillSearchHandler := handler.NewSkillSearchHandler(docEngine)
providerHandler := handler.NewProviderHandler(userService, modelProviderService)
agentHandler := handler.NewAgentHandler(service.NewAgentService(), fileService)
relatedQuestionsHandler := handler.NewSearchbotHandler(
searchBotHandler := handler.NewSearchBotHandler(
searchService,
tenantService,
&handler.SearchbotRealLLM{Svc: modelProviderService},
&handler.SearchBotRealLLM{Svc: modelProviderService},
chunkService,
)
pluginHandler := handler.NewPluginHandler(service.NewPluginService())
@@ -234,7 +235,7 @@ func startServer(config *server.Config) {
)
// Initialize router
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, relatedQuestionsHandler, difyRetrievalHandler, pluginHandler)
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, searchBotHandler, difyRetrievalHandler, pluginHandler)
// Create Gin engine
ginEngine := gin.New()

View File

@@ -0,0 +1,47 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package common
import "encoding/json"
// StringSlice is a []string that unmarshals from either a JSON string or
// a JSON array of strings. This matches Python endpoints that accept
// both "kb1" and ["kb1", "kb2"] for list-valued parameters.
type StringSlice []string
// UnmarshalJSON implements json.Unmarshaler.
func (s *StringSlice) UnmarshalJSON(data []byte) error {
// Try array first.
var arr []string
if err := json.Unmarshal(data, &arr); err == nil {
*s = arr
return nil
}
// Fall back to a single string → wrap as one-element slice.
var single string
if err := json.Unmarshal(data, &single); err != nil {
return err
}
*s = StringSlice{single}
return nil
}
// MarshalJSON implements json.Marshaler.
func (s StringSlice) MarshalJSON() ([]byte, error) {
return json.Marshal([]string(s))
}

View File

@@ -0,0 +1,114 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package common
import (
"encoding/json"
"testing"
)
func TestStringSlice_UnmarshalJSON_Array(t *testing.T) {
var s StringSlice
if err := json.Unmarshal([]byte(`["a","b","c"]`), &s); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(s) != 3 || s[0] != "a" || s[1] != "b" || s[2] != "c" {
t.Fatalf("got %v", []string(s))
}
}
func TestStringSlice_UnmarshalJSON_SingleString(t *testing.T) {
var s StringSlice
if err := json.Unmarshal([]byte(`"kb1"`), &s); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(s) != 1 || s[0] != "kb1" {
t.Fatalf("got %v", []string(s))
}
}
func TestStringSlice_UnmarshalJSON_EmptyArray(t *testing.T) {
var s StringSlice
if err := json.Unmarshal([]byte(`[]`), &s); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(s) != 0 {
t.Fatalf("expected empty, got %v", []string(s))
}
}
func TestStringSlice_UnmarshalJSON_EmptyString(t *testing.T) {
var s StringSlice
if err := json.Unmarshal([]byte(`""`), &s); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(s) != 1 || s[0] != "" {
t.Fatalf("got %v", []string(s))
}
}
func TestStringSlice_UnmarshalJSON_InvalidValue(t *testing.T) {
var s StringSlice
err := json.Unmarshal([]byte(`123`), &s)
if err == nil {
t.Fatal("expected error for number value")
}
}
func TestStringSlice_MarshalJSON(t *testing.T) {
s := StringSlice{"x", "y"}
data, err := json.Marshal(s)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(data) != `["x","y"]` {
t.Fatalf("got %s", data)
}
}
func TestStringSlice_MarshalJSON_Empty(t *testing.T) {
s := StringSlice{}
data, err := json.Marshal(s)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(data) != `[]` {
t.Fatalf("got %s", data)
}
}
func TestStringSlice_EmbeddedInStruct(t *testing.T) {
type req struct {
KbIDs StringSlice `json:"kb_id"`
}
// Single string
var r1 req
if err := json.Unmarshal([]byte(`{"kb_id":"kb1"}`), &r1); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(r1.KbIDs) != 1 || r1.KbIDs[0] != "kb1" {
t.Fatalf("got %v", []string(r1.KbIDs))
}
// Array
var r2 req
if err := json.Unmarshal([]byte(`{"kb_id":["a","b"]}`), &r2); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(r2.KbIDs) != 2 || r2.KbIDs[0] != "a" || r2.KbIDs[1] != "b" {
t.Fatalf("got %v", []string(r2.KbIDs))
}
}

View File

@@ -19,28 +19,155 @@ package handler
import (
"encoding/json"
"net/http"
"strings"
"ragflow/internal/common"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"ragflow/internal/service"
)
// chunkService is the consumer-side interface for ChunkHandler's service dependency.
type chunkService interface {
RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
Get(req *service.GetChunkRequest, userID string) (*service.GetChunkResponse, error)
List(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error)
UpdateChunk(req *service.UpdateChunkRequest, userID string) error
RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error)
}
// ChunkHandler chunk handler
type ChunkHandler struct {
chunkService *service.ChunkService
chunkService chunkService
userService *service.UserService
}
// NewChunkHandler create chunk handler
func NewChunkHandler(chunkService *service.ChunkService, userService *service.UserService) *ChunkHandler {
func NewChunkHandler(chunkService chunkService, userService *service.UserService) *ChunkHandler {
return &ChunkHandler{
chunkService: chunkService,
userService: userService,
}
}
// Get retrieves a chunk by ID
// RetrievalTest performs retrieval test for chunks
// @Summary Retrieval Test
// @Description Test retrieval of chunks based on question and knowledge base
// @Tags chunks
// @Accept json
// @Produce json
// @Param request body service.RetrievalTestRequest true "retrieval test parameters"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/datasets/search [post]
func (h *ChunkHandler) RetrievalTest(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
// Bind JSON request
var req service.RetrievalTestRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"code": common.CodeArgumentError,
"data": nil,
"message": err.Error(),
})
return
}
// Set default values for optional parameters
if req.Page == nil {
defaultPage := 1
req.Page = &defaultPage
}
if req.Size == nil {
defaultSize := 30
req.Size = &defaultSize
}
if req.TopK == nil {
defaultTopK := 1024
req.TopK = &defaultTopK
}
if req.UseKG == nil {
defaultUseKG := false
req.UseKG = &defaultUseKG
}
// Strip and validate question. Matching Python chunk_api.py which returns
// an empty result for blank questions rather than an error.
if strings.TrimSpace(req.Question) == "" {
c.JSON(http.StatusOK, gin.H{
"code": int(common.CodeSuccess),
"data": &service.RetrievalTestResponse{
Chunks: []map[string]interface{}{},
DocAggs: []map[string]interface{}{},
Total: 0,
},
"message": "success",
})
return
}
// Validate required fields
if req.Datasets == nil {
c.JSON(http.StatusBadRequest, gin.H{
"code": common.CodeArgumentError,
"data": nil,
"message": "kb_id is required",
})
return
}
if len(req.Datasets) == 0 {
c.JSON(http.StatusBadRequest, gin.H{
"code": common.CodeArgumentError,
"data": nil,
"message": "kb_id array cannot be empty",
})
return
}
if req.TopK != nil && *req.TopK <= 0 {
c.JSON(http.StatusBadRequest, gin.H{
"code": common.CodeArgumentError,
"data": nil,
"message": "top_k must be greater than 0",
})
return
}
// Call service with user ID for permission checks
resp, err := h.chunkService.RetrievalTest(&req, user.ID)
if err != nil {
common.Warn("dataset search failed", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{
"code": common.CodeServerError,
"data": nil,
"message": "dataset search failed",
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": int(common.CodeSuccess),
"data": resp,
"message": "success",
})
}
// Get retrieves a chunk by ID.
// @Summary Get Chunk
// @Description Retrieve a single chunk by its ID.
// @Tags chunks
// @Accept json
// @Produce json
// @Param dataset_id path string true "Dataset ID"
// @Param document_id path string true "Document ID"
// @Param chunk_id path string true "Chunk ID"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id} [get]
func (h *ChunkHandler) Get(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
@@ -49,20 +176,16 @@ func (h *ChunkHandler) Get(c *gin.Context) {
}
chunkID := c.Param("chunk_id")
datasetID := c.Param("dataset_id")
documentID := c.Param("document_id")
if chunkID == "" || datasetID == "" || documentID == "" {
if chunkID == "" {
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "dataset_id, document_id and chunk_id are required",
"message": "chunk_id is required",
})
return
}
req := &service.GetChunkRequest{
ChunkID: chunkID,
DocumentID: documentID,
DatasetID: datasetID,
ChunkID: chunkID,
}
resp, err := h.chunkService.Get(req, user.ID)
@@ -81,7 +204,15 @@ func (h *ChunkHandler) Get(c *gin.Context) {
})
}
// List retrieves chunks for a document
// List retrieves chunks for a document.
// @Summary List Chunks
// @Description Retrieve paginated chunks for a document with optional filtering.
// @Tags chunks
// @Accept json
// @Produce json
// @Param request body service.ListChunksRequest true "List chunks parameters"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/chunk/list [post]
func (h *ChunkHandler) List(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {

View File

@@ -0,0 +1,292 @@
package handler
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"ragflow/internal/common"
"ragflow/internal/entity"
"ragflow/internal/service"
"github.com/gin-gonic/gin"
)
// mockChunkSvc implements chunkSvcIface for testing ChunkHandler.
// Only the methods actually called by the test are set; others panic.
type mockChunkSvc struct {
retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
}
func (m *mockChunkSvc) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
if m.retrievalTestFn != nil {
return m.retrievalTestFn(req, userID)
}
return &service.RetrievalTestResponse{
Chunks: []map[string]interface{}{{"docnm_kwd": "test", "content_with_weight": "content"}},
Total: 1,
}, nil
}
func (m *mockChunkSvc) Get(*service.GetChunkRequest, string) (*service.GetChunkResponse, error) {
panic("not implemented")
}
func (m *mockChunkSvc) List(*service.ListChunksRequest, string) (*service.ListChunksResponse, error) {
panic("not implemented")
}
func (m *mockChunkSvc) UpdateChunk(*service.UpdateChunkRequest, string) error {
panic("not implemented")
}
func (m *mockChunkSvc) RemoveChunks(*service.RemoveChunksRequest, string) (int64, error) {
panic("not implemented")
}
func setupChunkRetrievalTest(userID string) (*gin.Engine, *mockChunkSvc) {
mock := &mockChunkSvc{}
h := &ChunkHandler{chunkService: mock}
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("user", &entity.User{ID: userID})
})
r.POST("/api/v1/datasets/search", h.RetrievalTest)
return r, mock
}
func setupChunkRetrievalTestNoAuth() *gin.Engine {
// Returns a router without the user middleware — used for error-path
// tests that don't call the service.
h := &ChunkHandler{}
gin.SetMode(gin.TestMode)
r := gin.New()
r.POST("/api/v1/datasets/search", h.RetrievalTest)
return r
}
func TestChunkRetrieval_EmptyQuestion(t *testing.T) {
r, _ := setupChunkRetrievalTest("user1")
body := `{"dataset_ids": ["kb1"], "question": ""}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Errorf("expected code 0, got %v: %q", resp["code"], resp["message"])
}
data, ok := resp["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected data to be object, got %T", resp["data"])
}
chunks, _ := data["chunks"].([]interface{})
if chunks == nil || len(chunks) != 0 {
t.Errorf("expected empty chunks array, got %v", chunks)
}
if total, _ := data["total"].(float64); total != 0 {
t.Errorf("expected total 0, got %v", total)
}
}
func TestChunkRetrieval_WhitespaceQuestion(t *testing.T) {
r, _ := setupChunkRetrievalTest("user1")
body := `{"dataset_ids": ["kb1"], "question": " "}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Errorf("expected code 0, got %v", resp["code"])
}
}
func TestChunkRetrieval_TopKZero(t *testing.T) {
r, _ := setupChunkRetrievalTest("user1")
body := `{"dataset_ids": ["kb1"], "question": "test", "top_k": 0}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if msg, _ := resp["message"].(string); msg != "top_k must be greater than 0" {
t.Errorf("expected 'top_k must be greater than 0', got %q", msg)
}
}
func TestChunkRetrieval_MissingDatasetIDs(t *testing.T) {
r, _ := setupChunkRetrievalTest("user1")
body := `{"question": "test"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestChunkRetrieval_EmptyDatasetIDs(t *testing.T) {
r, _ := setupChunkRetrievalTest("user1")
body := `{"dataset_ids": [], "question": "test"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if msg, _ := resp["message"].(string); msg != "kb_id array cannot be empty" {
t.Errorf("expected 'kb_id array cannot be empty', got %q", msg)
}
}
func TestChunkRetrieval_NoAuth(t *testing.T) {
r := setupChunkRetrievalTestNoAuth()
body := `{"dataset_ids": ["kb1"], "question": "test"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
// jsonError returns HTTP 200 with error code in body
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp["code"] == float64(common.CodeSuccess) {
t.Errorf("expected error code, got %v", resp["code"])
}
}
func TestChunkRetrieval_InvalidJSON(t *testing.T) {
r, _ := setupChunkRetrievalTest("user1")
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader("{invalid}"))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestChunkRetrieval_Success(t *testing.T) {
_, mock := setupChunkRetrievalTest("user1")
mock.retrievalTestFn = func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
if userID != "user1" {
t.Errorf("expected userID 'user1', got %q", userID)
}
return &service.RetrievalTestResponse{
Chunks: []map[string]interface{}{{"docnm_kwd": "result"}},
DocAggs: []map[string]interface{}{{"doc_id": "1", "count": float64(1)}},
Total: 1,
}, nil
}
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("user", &entity.User{ID: "user1"})
})
h := &ChunkHandler{chunkService: mock}
r.POST("/api/v1/datasets/search", h.RetrievalTest)
body := `{"dataset_ids": ["kb1"], "question": "test"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Fatalf("expected code 0, got %v: %q", resp["code"], resp["message"])
}
data, ok := resp["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected data object, got %T", resp["data"])
}
if total, _ := data["total"].(float64); total != 1 {
t.Errorf("expected total 1, got %v", total)
}
}
func TestChunkRetrieval_ServiceError(t *testing.T) {
_, mock := setupChunkRetrievalTest("user1")
mock.retrievalTestFn = func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
return nil, errors.New("db connection refused")
}
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("user", &entity.User{ID: "user1"})
})
h := &ChunkHandler{chunkService: mock}
r.POST("/api/v1/datasets/search", h.RetrievalTest)
body := `{"dataset_ids": ["kb1"], "question": "test"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/datasets/search", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
msg, _ := resp["message"].(string)
if msg != "dataset search failed" {
t.Errorf("expected generic error message, got %q", msg)
}
if strings.Contains(msg, "db connection refused") {
t.Errorf("internal error details leaked to response: %q", msg)
}
}

View File

@@ -31,41 +31,72 @@ import (
"go.uber.org/zap"
)
// searchbotLLM is the interface for LLM calls used by SearchbotHandler.
// searchbotLLM is the interface for LLM calls used by SearchBotHandler.
type searchbotLLM interface {
Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error)
}
// SearchbotRealLLM wraps ModelProviderService to implement searchbotLLM.
type SearchbotRealLLM struct {
// ChunkRetriever abstracts chunk retrieval for the searchbots handler.
type ChunkRetriever interface {
RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
}
// SearchBotRealLLM wraps ModelProviderService to implement searchbotLLM.
type SearchBotRealLLM struct {
Svc *service.ModelProviderService
}
func (r *SearchbotRealLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
driver, modelName, apiConfig, _, err := r.Svc.GetModelConfigFromProviderInstance(tenantID, entity.ModelTypeChat, modelID)
func (r *SearchBotRealLLM) Chat(tenantID, modelID string, messages []modelModule.Message, config *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
chatModel, err := r.Svc.GetChatModel(tenantID, modelID)
if err != nil {
return nil, err
}
chatModel := modelModule.NewChatModel(driver, &modelName, apiConfig)
return chatModel.ModelDriver.ChatWithMessages(*chatModel.ModelName, messages, chatModel.APIConfig, config)
}
// SearchbotRequest is the request body for POST /api/v1/searchbots/related_questions.
type SearchbotRequest struct {
// SearchBotRetrievalTestRequest is the request body for POST /api/v1/searchbots/retrieval_test.
type SearchBotRetrievalTestRequest struct {
KbIDs common.StringSlice `json:"kb_id" binding:"required"`
Question string `json:"question" binding:"required"`
Page *int `json:"page,omitempty"`
Size *int `json:"size,omitempty"`
DocIDs []string `json:"doc_ids,omitempty"`
UseKG *bool `json:"use_kg,omitempty"`
TopK *int `json:"top_k,omitempty"`
CrossLanguages []string `json:"cross_languages,omitempty"`
SearchID *string `json:"search_id,omitempty"`
MetaDataFilter map[string]interface{} `json:"meta_data_filter,omitempty"`
TenantRerankID *string `json:"tenant_rerank_id,omitempty"`
RerankID *string `json:"rerank_id,omitempty"`
Keyword *bool `json:"keyword,omitempty"`
SimilarityThreshold *float64 `json:"similarity_threshold,omitempty"`
VectorSimilarityWeight *float64 `json:"vector_similarity_weight,omitempty"`
// TODO: wire highlight to nlp Retrieval when engine supports highlightFields
// Python: bot_api.py → retrieval(highlight=req.get("highlight"))
// → search.py highlightFields → ES get_highlight()
// Issue: https://github.com/infiniflow/ragflow/issues/15712
// Highlight *bool `json:"highlight,omitempty"`
}
// SearchBotRequest is the request body for POST /api/v1/searchbots/related_questions.
type SearchBotRequest struct {
Question string `json:"question" binding:"required"`
SearchID string `json:"search_id,omitempty"`
}
// SearchbotHandler handles POST /api/v1/searchbots/related_questions.
type SearchbotHandler struct {
// SearchBotHandler handles searchbot endpoints:
// POST /api/v1/searchbots/related_questions
// POST /api/v1/searchbots/retrieval_test
type SearchBotHandler struct {
searchSvc *service.SearchService
tenantSvc *service.TenantService
llm searchbotLLM
chunkSvc ChunkRetriever
}
// NewSearchbotHandler creates a new SearchbotHandler.
func NewSearchbotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm searchbotLLM) *SearchbotHandler {
return &SearchbotHandler{searchSvc: searchSvc, tenantSvc: tenantSvc, llm: llm}
// NewSearchBotHandler creates a new SearchBotHandler.
func NewSearchBotHandler(searchSvc *service.SearchService, tenantSvc *service.TenantService, llm searchbotLLM, chunkSvc ChunkRetriever) *SearchBotHandler {
return &SearchBotHandler{searchSvc: searchSvc, tenantSvc: tenantSvc, llm: llm, chunkSvc: chunkSvc}
}
// Handle generates related search questions based on a user query.
@@ -74,17 +105,17 @@ func NewSearchbotHandler(searchSvc *service.SearchService, tenantSvc *service.Te
// @Tags searchbots
// @Accept json
// @Produce json
// @Param request body SearchbotRequest true "Request body"
// @Param request body SearchBotRequest true "Request body"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/searchbots/related_questions [post]
func (h *SearchbotHandler) Handle(c *gin.Context) {
func (h *SearchBotHandler) Handle(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
var req SearchbotRequest
var req SearchBotRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
@@ -148,21 +179,133 @@ func (h *SearchbotHandler) Handle(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"data": questions,
"message": "",
"message": "success",
})
}
// RetrievalTest performs a retrieval test against specified knowledge bases.
// @Summary Retrieval Test
// @Description Test document retrieval across knowledge bases with optional filters, reranking, and KG search.
// @Tags searchbots
// @Accept json
// @Produce json
// @Param request body SearchBotRetrievalTestRequest true "Retrieval test parameters"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/searchbots/retrieval_test [post]
func (h *SearchBotHandler) RetrievalTest(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
c.JSON(http.StatusUnauthorized, gin.H{"code": errorCode, "data": nil, "message": errorMessage})
return
}
var req SearchBotRetrievalTestRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": err.Error()})
return
}
// Filter out empty strings from KbIDs before validation.
filtered := make(common.StringSlice, 0, len(req.KbIDs))
for _, id := range req.KbIDs {
if strings.TrimSpace(id) != "" {
filtered = append(filtered, id)
}
}
req.KbIDs = filtered
if len(req.KbIDs) == 0 || req.Question == "" {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "kb_id and question are required"})
return
}
applyRetrievalDefaults(&req)
if req.TopK != nil && *req.TopK <= 0 {
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "data": nil, "message": "top_k must be greater than 0"})
return
}
svcReq := toRetrievalServiceRequest(&req)
result, err := h.chunkSvc.RetrievalTest(svcReq, user.ID)
if err != nil {
common.Warn("searchbot retrieval test failed", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "data": nil, "message": "retrieval test failed"})
return
}
c.JSON(http.StatusOK, gin.H{"code": int(common.CodeSuccess), "data": result, "message": "success"})
}
// toRetrievalServiceRequest maps the handler DTO to the service DTO.
// The two structs differ in KbIDs (StringSlice → []string) and
// MetaDataFilter (→ Filter) to maintain Python API compatibility.
func toRetrievalServiceRequest(h *SearchBotRetrievalTestRequest) *service.RetrievalTestRequest {
return &service.RetrievalTestRequest{
Datasets: common.StringSlice(h.KbIDs),
Question: h.Question,
Page: h.Page,
Size: h.Size,
DocIDs: h.DocIDs,
UseKG: h.UseKG,
TopK: h.TopK,
CrossLanguages: h.CrossLanguages,
SearchID: h.SearchID,
Filter: h.MetaDataFilter,
TenantRerankID: h.TenantRerankID,
RerankID: h.RerankID,
Keyword: h.Keyword,
SimilarityThreshold: h.SimilarityThreshold,
VectorSimilarityWeight: h.VectorSimilarityWeight,
}
}
// ptrFloat64 returns a pointer to a float64 value.
func ptrFloat64(v float64) *float64 { return &v }
// applyRetrievalDefaults fills in default values for optional fields,
// matching Python bot_api.py retrieval_test endpoint.
func applyRetrievalDefaults(req *SearchBotRetrievalTestRequest) {
if req.Page == nil {
v := 1
req.Page = &v
}
if req.Size == nil {
v := 30
req.Size = &v
}
if req.TopK == nil {
v := 1024
req.TopK = &v
}
if req.UseKG == nil {
v := false
req.UseKG = &v
}
if req.Keyword == nil {
v := false
req.Keyword = &v
}
if req.SimilarityThreshold == nil {
v := 0.0
req.SimilarityThreshold = &v
}
if req.VectorSimilarityWeight == nil {
v := 0.3
req.VectorSimilarityWeight = &v
}
}
var relatedQuestionLineRe = regexp.MustCompile(`^\d+\.\s`)
// parseRelatedQuestions extracts numbered list items from an LLM response.
// Lines matching "^N. " are extracted and the number prefix is stripped.
func parseRelatedQuestions(text string) []string {
lineRe := regexp.MustCompile(`^\d+\.\s`)
var result []string
for _, line := range strings.Split(text, "\n") {
if lineRe.MatchString(line) {
result = append(result, lineRe.ReplaceAllString(line, ""))
if relatedQuestionLineRe.MatchString(line) {
result = append(result, relatedQuestionLineRe.ReplaceAllString(line, ""))
}
}
if result == nil {

View File

@@ -18,17 +18,390 @@ package handler
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"ragflow/internal/common"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
"ragflow/internal/service"
"github.com/gin-gonic/gin"
)
// mockChunkService implements ChunkRetriever for testing.
// It captures the last request received so tests can verify field mapping.
type mockChunkService struct {
retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
LastReq *service.RetrievalTestRequest
LastUserID string
}
func (m *mockChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
m.LastReq = req
m.LastUserID = userID
if m.retrievalTestFn != nil {
return m.retrievalTestFn(req, userID)
}
return &service.RetrievalTestResponse{
Chunks: []map[string]interface{}{{"docnm_kwd": "test", "content_with_weight": "content"}},
}, nil
}
func setupSearchbotsTest(userID string) (*SearchBotHandler, *mockChunkService, *gin.Engine) {
mockSvc := &mockChunkService{}
h := &SearchBotHandler{
chunkSvc: mockSvc,
}
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("user", &entity.User{ID: userID})
})
r.POST("/api/v1/searchbots/retrieval_test", h.RetrievalTest)
return h, mockSvc, r
}
func TestSearchBotsRetrieval_Basic(t *testing.T) {
_, mockSvc, r := setupSearchbotsTest("user1")
body := `{"kb_id": ["kb1"], "question": "test question"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Errorf("expected code 0, got %v", resp["code"])
}
if msg, _ := resp["message"].(string); msg != "success" {
t.Errorf("expected message 'success', got %q", msg)
}
// Verify field mapping: handler → service request
if mockSvc.LastReq == nil {
t.Fatal("RetrievalTest was not called")
}
if len(mockSvc.LastReq.Datasets) != 1 || mockSvc.LastReq.Datasets[0] != "kb1" {
t.Errorf("Datasets = %v, want [\"kb1\"]", mockSvc.LastReq.Datasets)
}
if mockSvc.LastReq.Question != "test question" {
t.Errorf("Question = %q, want \"test question\"", mockSvc.LastReq.Question)
}
if mockSvc.LastUserID != "user1" {
t.Errorf("userID = %q, want \"user1\"", mockSvc.LastUserID)
}
}
func TestSearchBotsRetrieval_MissingKbID(t *testing.T) {
_, _, r := setupSearchbotsTest("user1")
body := `{"question": "test"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
msg, _ := resp["message"].(string)
if msg == "" || msg == "success" {
t.Errorf("expected validation error message, got %q", msg)
}
if !strings.Contains(msg, "KbIDs") || !strings.Contains(msg, "required") {
t.Errorf("expected message to mention 'KbIDs' and 'required', got %q", msg)
}
}
func TestSearchBotsRetrieval_MissingQuestion(t *testing.T) {
_, _, r := setupSearchbotsTest("user1")
body := `{"kb_id": ["kb1"]}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
msg, _ := resp["message"].(string)
if msg == "" || msg == "success" {
t.Errorf("expected validation error message, got %q", msg)
}
if !strings.Contains(msg, "Question") || !strings.Contains(msg, "required") {
t.Errorf("expected message to mention 'Question' and 'required', got %q", msg)
}
}
func TestSearchBotsRetrieval_NoAuth(t *testing.T) {
h := NewSearchBotHandler(nil, nil, nil, &mockChunkService{})
gin.SetMode(gin.TestMode)
r := gin.New()
r.POST("/api/v1/searchbots/retrieval_test", h.RetrievalTest)
w := httptest.NewRecorder()
body := `{"kb_id": ["kb1"], "question": "test"}`
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("expected 401, got %d", w.Code)
}
}
func TestSearchBotsRetrieval_ServiceError(t *testing.T) {
h, _, r := setupSearchbotsTest("user1")
h.chunkSvc = &mockChunkService{
retrievalTestFn: func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
return nil, errors.New("db error")
},
}
w := httptest.NewRecorder()
body := `{"kb_id": ["kb1"], "question": "test"}`
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
code, _ := resp["code"].(float64)
if code == 0 {
t.Errorf("expected non-zero error code, got %v", code)
}
msg, _ := resp["message"].(string)
if msg == "" || msg == "success" {
t.Errorf("expected error message, got %q", msg)
}
}
func TestSearchBotsRetrieval_KbIDSingleString(t *testing.T) {
// Verify "kb1" (string) is accepted and converted to []string{"kb1"}
_, mockSvc, r := setupSearchbotsTest("user1")
body := `{"kb_id": "kb1", "question": "test"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if mockSvc.LastReq == nil {
t.Fatal("RetrievalTest was not called")
}
if len(mockSvc.LastReq.Datasets) != 1 || mockSvc.LastReq.Datasets[0] != "kb1" {
t.Errorf("Datasets = %v, want [\"kb1\"]", mockSvc.LastReq.Datasets)
}
}
func TestSearchBotsRetrieval_KbIDArray(t *testing.T) {
// Verify ["a","b"] (array) still works
_, mockSvc, r := setupSearchbotsTest("user1")
body := `{"kb_id": ["a","b"], "question": "test"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if mockSvc.LastReq == nil {
t.Fatal("RetrievalTest was not called")
}
if len(mockSvc.LastReq.Datasets) != 2 || mockSvc.LastReq.Datasets[0] != "a" || mockSvc.LastReq.Datasets[1] != "b" {
t.Errorf("Datasets = %v, want [\"a\",\"b\"]", mockSvc.LastReq.Datasets)
}
}
func TestSearchBotsRetrieval_InvalidJSON(t *testing.T) {
_, _, r := setupSearchbotsTest("user1")
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader("{invalid}"))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestSearchBotsRetrieval_EmptyStringKbID(t *testing.T) {
_, _, r := setupSearchbotsTest("user1")
body := `{"kb_id": "", "question": "test"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if msg, _ := resp["message"].(string); msg != "kb_id and question are required" {
t.Errorf("expected message 'kb_id and question are required', got %q", msg)
}
}
func TestSearchBotsRetrieval_WhitespaceOnlyKbID(t *testing.T) {
_, _, r := setupSearchbotsTest("user1")
body := `{"kb_id": " ", "question": "test"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if msg, _ := resp["message"].(string); msg != "kb_id and question are required" {
t.Errorf("expected message 'kb_id and question are required', got %q", msg)
}
}
func TestSearchBotsRetrieval_DefaultsApplied(t *testing.T) {
// Verify that when optional fields are omitted, the handler applies
// defaults matching Python bot_api.py retrieval_test endpoint.
_, mockSvc, r := setupSearchbotsTest("user1")
body := `{"kb_id": ["kb1"], "question": "does this default?"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if mockSvc.LastReq == nil {
t.Fatal("RetrievalTest was not called")
}
svcReq := mockSvc.LastReq
if svcReq.Page == nil || *svcReq.Page != 1 {
t.Errorf("Page = %v, want 1", nullableInt(svcReq.Page))
}
if svcReq.Size == nil || *svcReq.Size != 30 {
t.Errorf("Size = %v, want 30", nullableInt(svcReq.Size))
}
if svcReq.TopK == nil || *svcReq.TopK != 1024 {
t.Errorf("TopK = %v, want 1024", nullableInt(svcReq.TopK))
}
if svcReq.UseKG == nil || *svcReq.UseKG != false {
t.Errorf("UseKG = %v, want false", nullableBool(svcReq.UseKG))
}
if svcReq.Keyword == nil || *svcReq.Keyword != false {
t.Errorf("Keyword = %v, want false", nullableBool(svcReq.Keyword))
}
if svcReq.SimilarityThreshold == nil || *svcReq.SimilarityThreshold != 0.0 {
t.Errorf("SimilarityThreshold = %v, want 0.0", nullableFloat(svcReq.SimilarityThreshold))
}
if svcReq.VectorSimilarityWeight == nil || *svcReq.VectorSimilarityWeight != 0.3 {
t.Errorf("VectorSimilarityWeight = %v, want 0.3", nullableFloat(svcReq.VectorSimilarityWeight))
}
}
func TestSearchBotsRetrieval_TopKZero(t *testing.T) {
_, _, r := setupSearchbotsTest("user1")
body := `{"kb_id": ["kb1"], "question": "test", "top_k": 0}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if msg, _ := resp["message"].(string); msg != "top_k must be greater than 0" {
t.Errorf("expected message 'top_k must be greater than 0', got %q", msg)
}
}
func TestSearchBotsRetrieval_TopKNegative(t *testing.T) {
_, _, r := setupSearchbotsTest("user1")
body := `{"kb_id": ["kb1"], "question": "test", "top_k": -1}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
if msg := jsonDecodeMessage(t, w.Body.Bytes()); msg != "top_k must be greater than 0" {
t.Errorf("expected message 'top_k must be greater than 0', got %q", msg)
}
}
func jsonDecodeMessage(t *testing.T, body []byte) string {
t.Helper()
var resp map[string]interface{}
if err := json.Unmarshal(body, &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
msg, _ := resp["message"].(string)
return msg
}
func nullableInt(p *int) string {
if p == nil { return "nil" }
return fmt.Sprintf("%d", *p)
}
func nullableBool(p *bool) string {
if p == nil { return "nil" }
return fmt.Sprintf("%v", *p)
}
func nullableFloat(p *float64) string {
if p == nil { return "nil" }
return fmt.Sprintf("%v", *p)
}
func TestSearchBotsRetrieval_EmptyQuestion(t *testing.T) {
// Send kb_id but empty question — caught by binding:"required" on the DTO.
_, _, r := setupSearchbotsTest("user1")
body := `{"kb_id": ["kb1"], "question": ""}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/searchbots/retrieval_test", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
msg := jsonDecodeMessage(t, w.Body.Bytes())
if !strings.Contains(msg, "Question") || !strings.Contains(msg, "required") {
t.Errorf("expected validation error mentioning Question and required, got %q", msg)
}
}
// fakeSearchbotLLM implements searchbotLLM for testing.
type fakeSearchbotLLM struct {
response string
@@ -42,7 +415,7 @@ func (f *fakeSearchbotLLM) Chat(tenantID, modelID string, messages []modelModule
return &modelModule.ChatResponse{Answer: &f.response}, nil
}
func setupSearchbotRequest(body string) (*gin.Context, *httptest.ResponseRecorder) {
func setupSearchBotRequest(body string) (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -54,17 +427,17 @@ func setupSearchbotRequest(body string) (*gin.Context, *httptest.ResponseRecorde
return c, w
}
// TestSearchbotHandler_Success verifies the happy path.
func TestSearchbotHandler_Success(t *testing.T) {
// TestSearchBotHandler_Success verifies the happy path.
func TestSearchBotHandler_Success(t *testing.T) {
llm := &fakeSearchbotLLM{
response: `Here are some related questions:
1. How do EV impact environment?
2. What are advantages of EV?
3. Cost of EV?`,
}
h := NewSearchbotHandler(nil, nil, llm)
h := NewSearchBotHandler(nil, nil, llm, nil)
c, w := setupSearchbotRequest(`{"question": "EV benefits"}`)
c, w := setupSearchBotRequest(`{"question": "EV benefits"}`)
h.Handle(c)
var resp map[string]interface{}
@@ -72,6 +445,9 @@ func TestSearchbotHandler_Success(t *testing.T) {
if resp["code"] != float64(common.CodeSuccess) {
t.Fatalf("expected code 0, got %v: %v", resp["code"], resp["message"])
}
if msg, _ := resp["message"].(string); msg != "success" {
t.Errorf("expected message 'success', got %q", msg)
}
questions, ok := resp["data"].([]interface{})
if !ok {
@@ -85,14 +461,14 @@ func TestSearchbotHandler_Success(t *testing.T) {
}
}
// TestSearchbotHandler_EmptyResponse verifies empty LLM response returns empty list.
func TestSearchbotHandler_EmptyResponse(t *testing.T) {
// TestSearchBotHandler_EmptyResponse verifies empty LLM response returns empty list.
func TestSearchBotHandler_EmptyResponse(t *testing.T) {
llm := &fakeSearchbotLLM{
response: "No related questions found.",
}
h := NewSearchbotHandler(nil, nil, llm)
h := NewSearchBotHandler(nil, nil, llm, nil)
c, w := setupSearchbotRequest(`{"question": "EV benefits"}`)
c, w := setupSearchBotRequest(`{"question": "EV benefits"}`)
h.Handle(c)
var resp map[string]interface{}
@@ -109,14 +485,14 @@ func TestSearchbotHandler_EmptyResponse(t *testing.T) {
}
}
// TestSearchbotHandler_LLMFailure verifies error handling on LLM failure.
func TestSearchbotHandler_LLMFailure(t *testing.T) {
// TestSearchBotHandler_LLMFailure verifies error handling on LLM failure.
func TestSearchBotHandler_LLMFailure(t *testing.T) {
llm := &fakeSearchbotLLM{
err: errFake{msg: "LLM unavailable"},
}
h := NewSearchbotHandler(nil, nil, llm)
h := NewSearchBotHandler(nil, nil, llm, nil)
c, w := setupSearchbotRequest(`{"question": "EV benefits"}`)
c, w := setupSearchBotRequest(`{"question": "EV benefits"}`)
h.Handle(c)
var resp map[string]interface{}
@@ -127,12 +503,12 @@ func TestSearchbotHandler_LLMFailure(t *testing.T) {
}
}
// TestSearchbotHandler_MissingQuestion verifies validation.
func TestSearchbotHandler_MissingQuestion(t *testing.T) {
// TestSearchBotHandler_MissingQuestion verifies validation.
func TestSearchBotHandler_MissingQuestion(t *testing.T) {
llm := &fakeSearchbotLLM{response: "dummy"}
h := NewSearchbotHandler(nil, nil, llm)
h := NewSearchBotHandler(nil, nil, llm, nil)
c, w := setupSearchbotRequest(`{}`)
c, w := setupSearchBotRequest(`{}`)
h.Handle(c)
var resp map[string]interface{}

View File

@@ -42,7 +42,7 @@ type Router struct {
skillSearchHandler *handler.SkillSearchHandler
providerHandler *handler.ProviderHandler
agentHandler *handler.AgentHandler
relatedQuestionsHandler *handler.SearchbotHandler
searchBotHandler *handler.SearchBotHandler
difyRetrievalHandler *handler.DifyRetrievalHandler
pluginHandler *handler.PluginHandler
}
@@ -68,7 +68,7 @@ func NewRouter(
skillSearchHandler *handler.SkillSearchHandler,
providerHandler *handler.ProviderHandler,
agentHandler *handler.AgentHandler,
relatedQuestionsHandler *handler.SearchbotHandler,
searchBotHandler *handler.SearchBotHandler,
difyRetrievalHandler *handler.DifyRetrievalHandler,
pluginHandler *handler.PluginHandler,
) *Router {
@@ -92,7 +92,7 @@ func NewRouter(
skillSearchHandler: skillSearchHandler,
providerHandler: providerHandler,
agentHandler: agentHandler,
relatedQuestionsHandler: relatedQuestionsHandler,
searchBotHandler: searchBotHandler,
difyRetrievalHandler: difyRetrievalHandler,
pluginHandler: pluginHandler,
}
@@ -226,7 +226,8 @@ func (r *Router) Setup(engine *gin.Engine) {
}
// Searchbot routes
v1.POST("/searchbots/related_questions", r.relatedQuestionsHandler.Handle)
v1.POST("/searchbots/related_questions", r.searchBotHandler.Handle)
v1.POST("/searchbots/retrieval_test", r.searchBotHandler.RetrievalTest)
// Dataset routes
datasets := v1.Group("/datasets")

View File

@@ -19,12 +19,19 @@ package service
import (
"context"
"fmt"
"ragflow/internal/common"
"ragflow/internal/entity"
"ragflow/internal/entity/models"
"ragflow/internal/server"
"strconv"
"strings"
"go.uber.org/zap"
"ragflow/internal/dao"
"ragflow/internal/engine"
"ragflow/internal/engine/types"
"ragflow/internal/server"
"ragflow/internal/service/nlp"
"ragflow/internal/tokenizer"
"ragflow/internal/utility"
)
@@ -54,11 +61,391 @@ func NewChunkService() *ChunkService {
}
}
// GetChunkRequest request for getting a chunk by ID.
// RetrievalTestRequest retrieval test request
type RetrievalTestRequest struct {
Datasets common.StringSlice `json:"dataset_ids" binding:"required"` // string or []string
Question string `json:"question"`
Page *int `json:"page,omitempty"`
Size *int `json:"size,omitempty"`
DocIDs []string `json:"doc_ids,omitempty"`
UseKG *bool `json:"use_kg,omitempty"`
TopK *int `json:"top_k,omitempty"`
CrossLanguages []string `json:"cross_languages,omitempty"`
SearchID *string `json:"search_id,omitempty"`
Filter map[string]interface{} `json:"meta_data_filter,omitempty"`
TenantRerankID *string `json:"tenant_rerank_id,omitempty"`
RerankID *string `json:"rerank_id,omitempty"`
Keyword *bool `json:"keyword,omitempty"`
SimilarityThreshold *float64 `json:"similarity_threshold,omitempty"`
VectorSimilarityWeight *float64 `json:"vector_similarity_weight,omitempty"`
}
// RetrievalTestResponse retrieval test response
type RetrievalTestResponse struct {
Chunks []map[string]interface{} `json:"chunks"`
DocAggs []map[string]interface{} `json:"doc_aggs"`
Labels *map[string]float64 `json:"labels"`
Total int64 `json:"total"`
}
// RetrievalTest performs retrieval test for a given question against specified knowledge bases.
//
// Flow:
// 1. Validate kbs permissions and embedding model
// 2. Apply metadata filter if specified (auto/semi_auto uses LLM, manual uses provided conditions)
// 3. Apply cross_languages transformation if requested (translate question)
// 4. Apply keyword extraction if requested (append keywords to question)
// 5. Get rank features via LabelQuestion() - tag-based weights or pagerank_fld fallback
// 6. Call RetrievalService.Retrieval() which:
// - Computes query embedding
// - Performs hybrid search (text + vector) with rank features
// - Reranks results
// - Builds doc_aggs by aggregating chunks per document
// 7. knowledge graph retrieval (not implemented)
// 8. Apply retrieval by children to group child chunks under parent chunks
func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (*RetrievalTestResponse, error) {
common.Info("RetrievalTest started", zap.String("userID", userID), zap.Any("kbID", req.Datasets), zap.String("question", req.Question))
common.Debug(fmt.Sprintf("RetrievalTest request:\n"+
" kbID=%v\n"+
" question=%s\n"+
" page=%v, size=%v\n"+
" docIDs=%v\n"+
" useKG=%v, topK=%v\n"+
" crossLanguages=%v\n"+
" searchID=%v\n"+
" filter=%v\n"+
" tenantRerankID=%v\n"+
" rerankID=%v\n"+
" keyword=%v\n"+
" similarityThreshold=%v, vectorSimilarityWeight=%v",
req.Datasets, req.Question,
ptrString(req.Page), ptrString(req.Size), req.DocIDs,
ptrString(req.UseKG), ptrString(req.TopK), req.CrossLanguages, ptrString(req.SearchID),
req.Filter,
ptrString(req.TenantRerankID), ptrString(req.RerankID),
ptrString(req.Keyword),
ptrString(req.SimilarityThreshold), ptrString(req.VectorSimilarityWeight)))
if req.Question == "" {
return nil, fmt.Errorf("question is required")
}
ctx := context.Background()
tenantIDs, kbRecords, err := s.validateKBs(userID, req.Datasets)
if err != nil {
return nil, err
}
docIDs, err := s.resolveMetaFilter(ctx, req.SearchID, req.Filter, req.Question, req.DocIDs, req.Datasets, tenantIDs)
if err != nil {
return nil, err
}
modifiedQuestion, err := s.transformQuestion(ctx, req.Question, req.CrossLanguages, req.Keyword, tenantIDs)
if err != nil {
return nil, err
}
// Get tag-based rank features via LabelQuestion
metadataSvc := NewMetadataService()
labels := metadataSvc.LabelQuestion(modifiedQuestion, kbRecords)
common.Debug("LabelQuestion result", zap.Any("labels", labels))
embeddingModel, err := s.resolveEmbeddingModel(tenantIDs[0], kbRecords[0])
if err != nil {
return nil, err
}
rerankModel, err := s.resolveRerankModel(tenantIDs[0], req.TenantRerankID, req.RerankID)
if err != nil {
return nil, err
}
retrievalReq := &nlp.RetrievalRequest{
TenantIDs: tenantIDs,
Question: modifiedQuestion,
KbIDs: []string(req.Datasets),
DocIDs: docIDs,
Page: getPageNum(req.Page, 1),
PageSize: getPageSize(req.Size, 30),
Top: req.TopK,
SimilarityThreshold: req.SimilarityThreshold,
VectorSimilarityWeight: req.VectorSimilarityWeight,
RerankModel: rerankModel,
RankFeature: &labels,
EmbeddingModel: embeddingModel,
}
// Call RetrievalService to perform retrieval
retrievalResult, err := nlp.NewRetrievalService(s.docEngine, s.documentDAO).Retrieval(ctx, retrievalReq)
if err != nil {
return nil, fmt.Errorf("retrieval search failed: %w", err)
}
filteredChunks := retrievalResult.Chunks
// Handle knowledge graph retrieval
// TODO: KG retrieval requires GraphRAG infrastructure which is not yet implemented in Go
if req.UseKG != nil && *req.UseKG {
common.Warn("use_kg is not yet implemented in Go - skipping KG retrieval")
}
// Apply retrieval_by_children - aggregate child chunks into parent chunks
filteredChunks = nlp.RetrievalByChildren(filteredChunks, tenantIDs, s.docEngine, ctx)
// Remove vector field from each chunk
for i := range filteredChunks {
delete(filteredChunks[i], "vector")
}
common.Info("RetrievalTest completed", zap.String("userID", userID), zap.Any("kbID", req.Datasets), zap.String("question", req.Question), zap.Int64("chunkCount", int64(len(filteredChunks))))
return &RetrievalTestResponse{
Chunks: filteredChunks,
DocAggs: retrievalResult.DocAggs,
Labels: &labels,
Total: int64(len(filteredChunks)),
}, nil
}
// validateKBs resolves tenant IDs and KB records for the given dataset IDs.
func (s *ChunkService) validateKBs(userID string, datasetIDs []string) ([]string, []*entity.Knowledgebase, error) {
tenants, err := s.userTenantDAO.GetByUserID(userID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get user tenants: %w", err)
}
if len(tenants) == 0 {
return nil, nil, fmt.Errorf("user has no accessible tenants")
}
common.Debug("Retrieved user tenants from database", zap.String("userID", userID), zap.Int("tenantCount", len(tenants)))
var tenantIDs []string
var kbRecords []*entity.Knowledgebase
for _, datasetID := range datasetIDs {
found := false
for _, tenant := range tenants {
kb, err := s.kbDAO.GetByIDAndTenantID(datasetID, tenant.TenantID)
if err == nil && kb != nil {
common.Debug("Found knowledge base in database",
zap.String("datasetID", datasetID),
zap.String("tenantID", tenant.TenantID),
zap.String("kbName", kb.Name),
zap.String("embdID", kb.EmbdID))
tenantIDs = append(tenantIDs, tenant.TenantID)
kbRecords = append(kbRecords, kb)
found = true
break
}
}
if !found {
return nil, nil, fmt.Errorf("only owner of dataset is authorized for this operation")
}
}
if len(kbRecords) > 1 {
firstEmbdID := kbRecords[0].EmbdID
for i := 1; i < len(kbRecords); i++ {
if kbRecords[i].EmbdID != firstEmbdID {
return nil, nil, fmt.Errorf("cannot retrieve across datasets with different embedding models")
}
}
}
return tenantIDs, kbRecords, nil
}
// resolveMetaFilter resolves a metadata filter from search_id and applies it.
func (s *ChunkService) resolveMetaFilter(ctx context.Context, searchID *string, initialFilter map[string]interface{}, question string, docIDs []string, datasetIDs []string, tenantIDs []string) ([]string, error) {
var chatID string
var chatModelForFilter *models.ChatModel
filter := initialFilter
if searchID != nil && *searchID != "" {
searchDetail, err := s.searchService.GetDetail(*searchID)
if err != nil {
common.Warn("Failed to get search detail for search_id, proceeding without it", zap.String("searchID", *searchID), zap.Error(err))
} else if searchConfig, ok := searchDetail["search_config"].(entity.JSONMap); ok && searchConfig != nil {
if searchMetaFilter, ok := searchConfig["meta_data_filter"].(map[string]interface{}); ok {
filter = searchMetaFilter
}
chatID, _ = searchConfig["chat_id"].(string)
} else {
common.Warn("No search_config found in search detail", zap.String("searchID", *searchID))
}
}
if filter != nil {
method, _ := filter["method"].(string)
if method == "auto" || method == "semi_auto" {
modelProviderSvc := NewModelProviderService()
if chatID != "" {
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeChat, chatID)
if getErr != nil {
common.Warn("Failed to get chat model from search_config chat_id, using tenant default", zap.String("chatID", chatID), zap.Error(getErr))
} else {
chatModelForFilter = models.NewChatModel(driver, &mdlName, apiConfig)
common.Info("Fetched chat model (from search_config) for metadata filter",
zap.String("chatID", chatID), zap.String("tenantID", tenantIDs[0]))
}
}
if chatModelForFilter == nil {
tenantSvc := NewTenantService()
modelName, err := tenantSvc.GetDefaultModelName(tenantIDs[0], entity.ModelTypeChat)
if err != nil || modelName == "" {
common.Warn("Failed to get tenant default chat model name for meta_data_filter", zap.Error(err))
} else {
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeChat, modelName)
if getErr != nil {
common.Warn("Failed to get chat model for meta_data_filter", zap.Error(getErr))
} else {
chatModelForFilter = models.NewChatModel(driver, &mdlName, apiConfig)
common.Info("Fetched chat model (tenant default) for metadata filter",
zap.String("tenantID", tenantIDs[0]), zap.String("modelName", modelName))
}
}
}
}
}
out := make([]string, len(docIDs))
copy(out, docIDs)
if filter != nil {
metadataSvc := NewMetadataService()
flattedMeta, err := metadataSvc.GetFlattedMetaByKBs([]string(datasetIDs))
if err != nil {
common.Warn("Failed to get flatted metadata", zap.Error(err))
} else {
common.Info("metadata filter conditions", zap.Any("filter", filter))
filteredDocIDs, _ := ApplyMetaDataFilter(ctx, filter, flattedMeta, question, chatModelForFilter, docIDs, []string(datasetIDs))
out = filteredDocIDs
common.Info("ApplyMetaDataFilter result", zap.Strings("docIDs", out))
}
}
return out, nil
}
// transformQuestion applies cross-languages translation and keyword extraction.
func (s *ChunkService) transformQuestion(ctx context.Context, question string, crossLanguages []string, keyword *bool, tenantIDs []string) (string, error) {
modifiedQuestion := question
if len(crossLanguages) == 0 && (keyword == nil || !*keyword) {
return modifiedQuestion, nil
}
tenantSvc := NewTenantService()
modelProviderSvc := NewModelProviderService()
modelName, err := tenantSvc.GetDefaultModelName(tenantIDs[0], "chat")
if err != nil || modelName == "" {
common.Warn("Failed to get default chat model name for LLM transformations", zap.Error(err))
return question, nil
}
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeChat, modelName)
if getErr != nil {
common.Warn("Failed to get chat model for LLM transformations", zap.Error(getErr))
return question, nil
}
chatModel := models.NewChatModel(driver, &mdlName, apiConfig)
common.Info("Fetched chat model (tenant default) for cross_languages/keyword_extraction",
zap.String("tenantID", tenantIDs[0]), zap.String("modelName", modelName))
if len(crossLanguages) > 0 {
translated, err := CrossLanguages(ctx, tenantIDs[0], modelName, question, crossLanguages)
if err != nil {
common.Warn("Failed to translate question", zap.Error(err))
} else {
modifiedQuestion = translated
}
}
if keyword != nil && *keyword {
extractedKeywords, err := KeywordExtraction(ctx, chatModel, modifiedQuestion, 3)
if err != nil {
common.Warn("Failed to extract keywords from question", zap.Error(err))
} else if extractedKeywords != "" {
modifiedQuestion = modifiedQuestion + " " + extractedKeywords
}
}
if modifiedQuestion != question {
common.Info("Modified question after transformations",
zap.String("originalQuestion", question),
zap.String("modifiedQuestion", modifiedQuestion),
zap.Strings("crossLanguages", crossLanguages),
zap.Bool("keywordExtraction", keyword != nil && *keyword))
}
return modifiedQuestion, nil
}
// resolveEmbeddingModel resolves the embedding model for a KB record.
func (s *ChunkService) resolveEmbeddingModel(tenantID string, kbRecord *entity.Knowledgebase) (*models.EmbeddingModel, error) {
var embdID string
var err error
if kbRecord.TenantEmbdID != nil && *kbRecord.TenantEmbdID > 0 {
_, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecord.TenantEmbdID)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err)
}
} else if kbRecord.EmbdID != "" {
parts := strings.Split(kbRecord.EmbdID, "@")
if len(parts) == 2 && parts[1] != "" {
_, embdID, err = dao.LookupTenantLLMByFactory(dao.NewTenantLLMDAO(), tenantID, parts[1], parts[0], entity.ModelTypeEmbedding)
} else {
_, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantID, kbRecord.EmbdID, entity.ModelTypeEmbedding)
}
if err != nil {
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err)
}
} else {
tenantLLM, err := dao.NewTenantLLMDAO().GetByTenantAndType(tenantID, entity.ModelTypeEmbedding)
if err != nil {
return nil, fmt.Errorf("failed to get tenant default embedding model: %w", err)
}
if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" {
return nil, fmt.Errorf("no default embedding model found for tenant %s", tenantID)
}
embdID = fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory)
}
modelProviderSvc := NewModelProviderService()
embeddingModel, err := modelProviderSvc.GetEmbeddingModel(tenantID, embdID)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model: %w", err)
}
common.Info("Fetched embedding model for retrieval",
zap.String("tenantID", tenantID), zap.String("embdID", embdID))
return embeddingModel, nil
}
// resolveRerankModel resolves the rerank model from tenant_rerank_id or rerank_id.
func (s *ChunkService) resolveRerankModel(tenantID string, tenantRerankID, rerankID *string) (*models.RerankModel, error) {
var rerankCompositeName string
var err error
if tenantRerankID != nil && *tenantRerankID != "" {
tenantRerankIDInt, parseErr := strconv.ParseInt(*tenantRerankID, 10, 64)
if parseErr != nil {
return nil, fmt.Errorf("invalid tenant_rerank_id: %w", parseErr)
}
_, rerankCompositeName, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), tenantRerankIDInt)
if err != nil {
return nil, fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err)
}
} else if rerankID != nil && *rerankID != "" {
_, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantID, *rerankID, entity.ModelTypeRerank)
if err != nil {
return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", err)
}
}
if rerankCompositeName == "" {
return nil, nil
}
modelProviderSvc := NewModelProviderService()
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantID, entity.ModelTypeRerank, rerankCompositeName)
if getErr != nil {
return nil, fmt.Errorf("failed to get rerank model: %w", getErr)
}
rerankModel := models.NewRerankModel(driver, &mdlName, apiConfig)
common.Info("Fetched rerank model",
zap.String("tenantID", tenantID), zap.String("rerankCompositeName", rerankCompositeName))
return rerankModel, nil
}
// GetChunkRequest request for getting a chunk by ID
type GetChunkRequest struct {
ChunkID string `json:"chunk_id"`
DocumentID string `json:"document_id"`
DatasetID string `json:"dataset_id"`
ChunkID string `json:"chunk_id"`
}
// GetChunkResponse response for getting a chunk
@@ -75,27 +462,10 @@ func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkRespon
if req.ChunkID == "" {
return nil, fmt.Errorf("chunk_id is required")
}
if req.DatasetID == "" {
return nil, fmt.Errorf("dataset_id is required")
}
if req.DocumentID == "" {
return nil, fmt.Errorf("document_id is required")
}
ctx := context.Background()
// Verify the document exists and belongs to the dataset
docDAO := dao.NewDocumentDAO()
doc, err := docDAO.GetByID(req.DocumentID)
if err != nil || doc == nil {
return nil, fmt.Errorf("document not found")
}
if doc.KbID != req.DatasetID {
return nil, fmt.Errorf("document does not belong to this dataset")
}
// Resolve the tenant that owns this dataset. A dataset belongs to
// exactly one tenant, so the loop breaks on the first match.
// Get user's tenants
tenants, err := s.userTenantDAO.GetByUserID(userID)
if err != nil {
return nil, fmt.Errorf("failed to get user tenants: %w", err)
@@ -103,87 +473,75 @@ func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkRespon
if len(tenants) == 0 {
return nil, fmt.Errorf("user has no accessible tenants")
}
var targetTenantID string
// Try each tenant to find the chunk
var chunk map[string]interface{}
for _, tenant := range tenants {
kb, err := s.kbDAO.GetByIDAndTenantID(req.DatasetID, tenant.TenantID)
if err == nil && kb != nil {
targetTenantID = tenant.TenantID
break
}
}
if targetTenantID == "" {
return nil, fmt.Errorf("user does not have access to this dataset")
}
indexName := fmt.Sprintf("ragflow_%s", targetTenantID)
raw, err := s.docEngine.GetChunk(ctx, indexName, req.ChunkID, []string{req.DatasetID})
if err != nil {
return nil, fmt.Errorf("failed to get chunk: %w", err)
}
if raw == nil {
return nil, fmt.Errorf("chunk not found")
}
chunk, ok := raw.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid chunk format: expected map[string]interface{}, got %T", raw)
}
if actual, _ := chunk["kb_id"].(string); actual != req.DatasetID {
return nil, fmt.Errorf("chunk does not belong to dataset %q", req.DatasetID)
}
if actual, _ := chunk["doc_id"].(string); actual != req.DocumentID {
return nil, fmt.Errorf("chunk does not belong to document %q", req.DocumentID)
}
return &GetChunkResponse{Chunk: formatChunkForGet(chunk)}, nil
}
// formatChunkForGet normalizes a raw engine chunk into the public
// response shape: it strips internal fields (_vec / _tks / _ltks / _sm_
// suffixes and a handful of score / id fields), renames a few legacy
// keys, and coerces numeric fields to JSONFloat64 so the JSON encoder
// preserves precision.
func formatChunkForGet(chunk map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
skipFields := map[string]bool{
"id": true, "authors": true, "_score": true, "SCORE": true,
}
for k, v := range chunk {
if skipFields[k] || strings.HasSuffix(k, "_vec") || strings.Contains(k, "_sm_") || strings.HasSuffix(k, "_tks") || strings.HasSuffix(k, "_ltks") {
// Get kbIDs for this tenant
kbIDs, err := s.kbDAO.GetKBIDsByTenantID(tenant.TenantID)
if err != nil {
continue
}
switch k {
case "content":
result["content_with_weight"] = v
case "docnm":
result["docnm_kwd"] = v
case "important_keywords":
utility.SetFieldArray(result, "important_kwd", v)
case "questions":
utility.SetFieldArray(result, "question_kwd", v)
case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
"name_kwd", "raptor_kwd", "removed_kwd", "source_id", "tag_kwd",
"to_entity_kwd", "toc_kwd":
if utility.IsEmpty(v) {
result[k] = []interface{}{}
} else {
result[k] = v
indexName := fmt.Sprintf("ragflow_%s", tenant.TenantID)
doc, err := s.docEngine.GetChunk(ctx, indexName, req.ChunkID, kbIDs)
if err != nil {
continue
}
if doc != nil {
chunk, ok := doc.(map[string]interface{})
if ok {
result := make(map[string]interface{})
skipFields := map[string]bool{
"id": true, "authors": true, "_score": true, "SCORE": true,
}
for k, v := range chunk {
if skipFields[k] || strings.HasSuffix(k, "_vec") || strings.Contains(k, "_sm_") || strings.HasSuffix(k, "_tks") || strings.HasSuffix(k, "_ltks") {
continue
}
switch k {
case "content":
result["content_with_weight"] = v
case "docnm":
result["docnm_kwd"] = v
case "important_keywords":
utility.SetFieldArray(result, "important_kwd", v)
case "questions":
utility.SetFieldArray(result, "question_kwd", v)
case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
"name_kwd", "raptor_kwd", "removed_kwd", "source_id", "tag_kwd",
"to_entity_kwd", "toc_kwd", "authors_tks", "doc_type_kwd":
if utility.IsEmpty(v) {
result[k] = []interface{}{}
} else {
result[k] = v
}
case "tag_feas":
if utility.IsEmpty(v) {
result[k] = map[string]interface{}{}
} else {
result[k] = v
}
case "create_timestamp_flt", "rank_flt", "weight_flt":
if floatVal, ok := utility.ToFloat64(v); ok {
result[k] = utility.JSONFloat64(floatVal)
}
default:
result[k] = v
}
}
return &GetChunkResponse{Chunk: result}, nil
}
case "tag_feas":
if utility.IsEmpty(v) {
result[k] = map[string]interface{}{}
} else {
result[k] = v
}
case "create_timestamp_flt", "rank_flt", "weight_flt":
if floatVal, ok := utility.ToFloat64(v); ok {
result[k] = utility.JSONFloat64(floatVal)
}
default:
result[k] = v
}
}
return result
if chunk == nil {
return nil, fmt.Errorf("chunk not found")
}
return &GetChunkResponse{Chunk: chunk}, nil
}
// ListChunksRequest request for listing chunks
@@ -316,7 +674,7 @@ func (s *ChunkService) List(req *ListChunksRequest, userID string) (*ListChunksR
utility.SetFieldArray(result, "question_kwd", v)
case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
"name_kwd", "raptor_kwd", "removed_kwd",
"source_id", "tag_kwd", "to_entity_kwd", "toc_kwd":
"source_id", "tag_kwd", "to_entity_kwd", "toc_kwd", "doc_type_kwd":
if utility.IsEmpty(v) {
result[k] = []interface{}{}
} else {
@@ -447,12 +805,10 @@ func (s *ChunkService) UpdateChunk(req *UpdateChunkRequest, userID string) error
if err != nil {
return fmt.Errorf("failed to get existing chunk: %w", err)
}
if existingChunk == nil {
return fmt.Errorf("chunk %q not found in dataset %q", req.ChunkID, req.DatasetID)
}
existing, ok := existingChunk.(map[string]interface{})
if !ok {
return fmt.Errorf("invalid chunk format: expected map[string]interface{}, got %T", existingChunk)
return fmt.Errorf("invalid chunk format")
}
// Build update dict

View File

@@ -0,0 +1,250 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"context"
"testing"
"ragflow/internal/entity"
)
// --- Helper tests ---
func TestPtrString_Nil(t *testing.T) {
if got := ptrString[int](nil); got != "<nil>" {
t.Errorf("ptrString(nil) = %q, want <nil>", got)
}
}
func TestPtrString_Value(t *testing.T) {
val := 42
if got := ptrString(&val); got != "42" {
t.Errorf("ptrString(&42) = %q, want 42", got)
}
}
func TestPtrString_Bool(t *testing.T) {
val := true
if got := ptrString(&val); got != "true" {
t.Errorf("ptrString(&true) = %q, want true", got)
}
}
func TestGetPageNum_Nil(t *testing.T) {
if got := getPageNum(nil, 10); got != 10 {
t.Errorf("getPageNum(nil, 10) = %d, want 10", got)
}
}
func TestGetPageNum_ZeroReturnsDefault(t *testing.T) {
val := 0
if got := getPageNum(&val, 5); got != 5 {
t.Errorf("getPageNum(&0, 5) = %d, want 5", got)
}
}
func TestGetPageNum_NegativeReturnsDefault(t *testing.T) {
val := -1
if got := getPageNum(&val, 5); got != 5 {
t.Errorf("getPageNum(&-1, 5) = %d, want 5", got)
}
}
func TestGetPageNum_Valid(t *testing.T) {
val := 3
if got := getPageNum(&val, 5); got != 3 {
t.Errorf("getPageNum(&3, 5) = %d, want 3", got)
}
}
func TestGetPageSize_Nil(t *testing.T) {
if got := getPageSize(nil, 20); got != 20 {
t.Errorf("getPageSize(nil, 20) = %d, want 20", got)
}
}
func TestGetPageSize_ZeroReturnsDefault(t *testing.T) {
val := 0
if got := getPageSize(&val, 20); got != 20 {
t.Errorf("getPageSize(&0, 20) = %d, want 20", got)
}
}
func TestGetPageSize_Valid(t *testing.T) {
val := 50
if got := getPageSize(&val, 20); got != 50 {
t.Errorf("getPageSize(&50, 20) = %d, want 50", got)
}
}
// --- RetrievalTestRequest validation tests ---
func TestRetrievalTestRequest_Defaults(t *testing.T) {
req := &RetrievalTestRequest{
Datasets: []string{"kb1"},
Question: "test question",
}
// Verify pointer fields are nil by default
if req.Page != nil {
t.Error("Page should default to nil")
}
if req.Size != nil {
t.Error("Size should default to nil")
}
if req.TopK != nil {
t.Error("TopK should default to nil")
}
if req.UseKG != nil {
t.Error("UseKG should default to nil")
}
if req.SimilarityThreshold != nil {
t.Error("SimilarityThreshold should default to nil")
}
if req.VectorSimilarityWeight != nil {
t.Error("VectorSimilarityWeight should default to nil")
}
if req.Keyword != nil {
t.Error("Keyword should default to nil")
}
}
func TestRetrievalTestResponse_Fields(t *testing.T) {
resp := &RetrievalTestResponse{
Chunks: []map[string]interface{}{},
DocAggs: []map[string]interface{}{},
Total: 0,
}
if resp.Chunks == nil {
t.Error("Chunks should not be nil")
}
if resp.DocAggs == nil {
t.Error("DocAggs should not be nil")
}
if resp.Total != 0 {
t.Errorf("Total = %d, want 0", resp.Total)
}
}
// --- transformQuestion edge cases ---
func TestTransformQuestion_NoTransformNeeded(t *testing.T) {
svc := &ChunkService{}
ctx := context.Background()
result, err := svc.transformQuestion(ctx, "hello", nil, nil, []string{"t1"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "hello" {
t.Errorf("expected unchanged question, got %q", result)
}
}
func TestTransformQuestion_EmptyCrossLanguages(t *testing.T) {
svc := &ChunkService{}
ctx := context.Background()
kw := false
result, err := svc.transformQuestion(ctx, "hello", []string{}, &kw, []string{"t1"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "hello" {
t.Errorf("expected unchanged question, got %q", result)
}
}
func TestTransformQuestion_KeywordFalse(t *testing.T) {
// This test verifies the early-return path for transformQuestion.
// With crossLanguages non-empty it would hit the DB; this is tested
// via integration tests that have a full service setup.
svc := &ChunkService{}
ctx := context.Background()
kw := false
result, err := svc.transformQuestion(ctx, "hello", []string{}, &kw, []string{"t1"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "hello" {
t.Errorf("expected unchanged question, got %q", result)
}
}
// --- resolveEmbeddingModel via exported Retriever ---
// These test that the retriever can handle nil inputs gracefully
func TestResolveEmbeddingModel_NilTenantEmbdID(t *testing.T) {
kb := &entity.Knowledgebase{
EmbdID: "text-embedding-ada-002@OpenAI",
}
// This will fail because it needs a real DAO, but we verify the type contract
if kb.TenantEmbdID != nil {
t.Error("TenantEmbdID should be nil for this test")
}
_ = kb // verified fields are accessible
}
func TestResolveRerankModel_BothNil(t *testing.T) {
svc := &ChunkService{}
result, err := svc.resolveRerankModel("t1", nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != nil {
t.Errorf("expected nil rerank model when both IDs are nil, got %v", result)
}
}
func TestResolveRerankModel_EmptyStrings(t *testing.T) {
svc := &ChunkService{}
empty := ""
result, err := svc.resolveRerankModel("t1", &empty, &empty)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != nil {
t.Errorf("expected nil rerank model when both IDs are empty, got %v", result)
}
}
func TestResolveRerankModel_InvalidTenantRerankID(t *testing.T) {
svc := &ChunkService{}
invalid := "not_a_number"
_, err := svc.resolveRerankModel("t1", &invalid, nil)
if err == nil {
t.Error("expected error for invalid tenant_rerank_id")
}
}
// --- validateKBs input validation ---
func TestValidateKBs_EmptyDatasets(t *testing.T) {
// validateKBs iterates over datasetIDs and queries DAOs.
// With empty input it should return empty slices.
// This test is limited since validateKBs requires DB-backed DAOs.
_ = &ChunkService{} // compiles
}
// --- Verify ChunkService struct fields ---
func TestChunkService_FieldsAccessible(t *testing.T) {
svc := &ChunkService{}
_ = svc.docEngine
_ = svc.kbDAO
_ = svc.userTenantDAO
_ = svc.searchService
// Verify embeddingCache field type
_ = svc.embeddingCache
}