Files
ragflow/internal/handler/chunk_test.go
Jack 8f4809d1b5 feat: implement POST /api/v1/searchbots/retrieval_test (#15710)
## What problem does this PR solve?

Implements `POST /api/v1/searchbots/retrieval_test` in the Go API
server, aligning with the Python `bot_api.py` counterpart. Also applies
security hardening and consistency fixes discovered during CTO-level
code review:

- **Missing endpoint**: `retrieval_test` was not available in Go,
requiring Python fallback
- **Security**: Both `chunkHandler` and `searchBotHandler` leaked
`err.Error()` to API consumers
- **Python alignment**: Default values, empty question handling, and
`top_k <= 0` validation differed from Python behavior
- **Test gaps**: `chunkHandler.RetrievalTest` had zero unit tests;
several edge cases were not covered

## Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

## Summary

### New Endpoint
- `POST /api/v1/searchbots/retrieval_test` — retrieval test with full
field support (page, size, top_k, use_kg, cross_languages, keyword,
similarity_threshold, vector_similarity_weight)

### New Type
- `common.StringSlice` — JSON type that accepts both `"kb1"` and
`["kb1", "kb2"]`, matching Python API flexibility

### Security
- Both `searchBotHandler` and `chunkHandler` now use `common.Warn()` +
generic error messages instead of leaking `err.Error()` to API consumers
- All error responses include consistent `"data": nil` shape
- `chunkHandler.RetrievalTest` uses interface-based DI (`chunkService`)
to enable testability

### Python Alignment
- Handler-level defaults align with Python `bot_api.py` (page=1,
size=30, top_k=1024, similarity_threshold=0.0,
vector_similarity_weight=0.3)
- `top_k <= 0` validation matching Python behavior
- Empty/whitespace question returns 200 + empty result (matches
`chunk_api.py`)
- `chunkHandler` `Datasets` field uses `common.StringSlice` for
string-or-array flexibility

### Refactoring
- `ChunkServiceIface` → `ChunkRetriever`, `chunkSvcIface` →
`chunkService` (Go-conventional naming)
- Extracted `applyRetrievalDefaults`, `toRetrievalServiceRequest` from
handler body
- Regex moved to package-level var in `parseRelatedQuestions`
- `service.RetrievalTestRequest.Datasets` type changed to
`common.StringSlice`
- `chunkHandler` now uses consumer-side interface for DI

### Tests
- 37 unit tests across both handlers: auth, validation, defaults,
StringSlice edge cases, empty/whitespace KbID, service errors, JSON
format, `top_k <= 0`, field mapping verification

## Files Changed

| File | Change |
|------|--------|
| `cmd/server_main.go` | Wire new handler + chunkService + difyRetrievalHandler |
| `internal/common/json_types.go` | New StringSlice type |
| `internal/common/json_types_test.go` | StringSlice tests |
| `internal/handler/chunk.go` | Interface-based DI, security, Python alignment, defaults |
| `internal/handler/chunk_test.go` | New — 9 comprehensive tests |
| `internal/handler/searchbot.go` | New endpoint + refactoring + `top_k <= 0` validation |
| `internal/handler/searchbot_test.go` | 18 tests covering all edge cases |
| `internal/router/router.go` | Register new route + difyRetrievalHandler |
| `internal/service/chunk.go` | Datasets type → StringSlice, Question binding relaxed |
🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-08 16:16:56 +08:00

293 lines
9.1 KiB
Go

package handler
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"ragflow/internal/common"
"ragflow/internal/entity"
"ragflow/internal/service"
"github.com/gin-gonic/gin"
)
// mockChunkSvc is a test double for the chunk service that ChunkHandler
// receives through its chunkService field.
// Only RetrievalTest is configurable (via retrievalTestFn); the other
// methods defined on the mock panic when called.
type mockChunkSvc struct {
// retrievalTestFn, when non-nil, is invoked by RetrievalTest instead of
// returning the built-in canned response.
retrievalTestFn func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error)
}
// RetrievalTest forwards to retrievalTestFn when a test has installed one.
// Otherwise it answers with a fixed single-chunk response (Total 1) and a
// nil error.
func (m *mockChunkSvc) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
	if override := m.retrievalTestFn; override != nil {
		return override(req, userID)
	}

	cannedChunk := map[string]interface{}{
		"docnm_kwd":           "test",
		"content_with_weight": "content",
	}
	canned := &service.RetrievalTestResponse{
		Chunks: []map[string]interface{}{cannedChunk},
		Total:  1,
	}
	return canned, nil
}
// Get is an unimplemented stub on the mock; calling it panics.
func (*mockChunkSvc) Get(_ *service.GetChunkRequest, _ string) (*service.GetChunkResponse, error) {
	panic("not implemented")
}
// List is an unimplemented stub on the mock; calling it panics.
func (*mockChunkSvc) List(_ *service.ListChunksRequest, _ string) (*service.ListChunksResponse, error) {
	panic("not implemented")
}
// UpdateChunk is an unimplemented stub on the mock; calling it panics.
func (*mockChunkSvc) UpdateChunk(_ *service.UpdateChunkRequest, _ string) error {
	panic("not implemented")
}
// RemoveChunks is an unimplemented stub on the mock; calling it panics.
func (*mockChunkSvc) RemoveChunks(_ *service.RemoveChunksRequest, _ string) (int64, error) {
	panic("not implemented")
}
// setupChunkRetrievalTest builds a gin engine in test mode whose only route,
// POST /api/v1/datasets/search, is served by ChunkHandler.RetrievalTest backed
// by a fresh mockChunkSvc. A middleware stores an entity.User carrying userID
// under the "user" context key on every request. The mock is returned so the
// caller can install retrievalTestFn before issuing requests.
func setupChunkRetrievalTest(userID string) (*gin.Engine, *mockChunkSvc) {
	gin.SetMode(gin.TestMode)

	svc := &mockChunkSvc{}
	ch := &ChunkHandler{chunkService: svc}

	injectUser := func(c *gin.Context) {
		c.Set("user", &entity.User{ID: userID})
	}

	engine := gin.New()
	engine.Use(injectUser)
	engine.POST("/api/v1/datasets/search", ch.RetrievalTest)
	return engine, svc
}
// setupChunkRetrievalTestNoAuth builds a gin engine in test mode exposing
// POST /api/v1/datasets/search without the user-injecting middleware. The
// handler is constructed with no chunk service, so the router is only meant
// for error-path tests that don't call the service.
func setupChunkRetrievalTestNoAuth() *gin.Engine {
	gin.SetMode(gin.TestMode)

	engine := gin.New()
	engine.POST("/api/v1/datasets/search", (&ChunkHandler{}).RetrievalTest)
	return engine
}
// TestChunkRetrieval_EmptyQuestion posts an empty question and expects HTTP
// 200 with a success code, an empty (non-null) chunks array and total 0.
func TestChunkRetrieval_EmptyQuestion(t *testing.T) {
	r, _ := setupChunkRetrievalTest("user1")

	body := `{"dataset_ids": ["kb1"], "question": ""}`
	// httptest.NewRequest produces a server-side request suitable for
	// ServeHTTP and panics on malformed input, so no error is discarded.
	req := httptest.NewRequest(http.MethodPost, "/api/v1/datasets/search", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if resp["code"] != float64(common.CodeSuccess) {
		t.Errorf("expected code 0, got %v: %q", resp["code"], resp["message"])
	}

	data, ok := resp["data"].(map[string]interface{})
	if !ok {
		t.Fatalf("expected data to be object, got %T", resp["data"])
	}
	// A JSON null or a missing key fails the type assertion, so this also
	// rejects "chunks": null.
	chunks, ok := data["chunks"].([]interface{})
	if !ok || len(chunks) != 0 {
		t.Errorf("expected empty chunks array, got %v", data["chunks"])
	}
	if total, _ := data["total"].(float64); total != 0 {
		t.Errorf("expected total 0, got %v", total)
	}
}
// TestChunkRetrieval_WhitespaceQuestion posts a question consisting only of
// whitespace and expects HTTP 200 with a success code in the body.
func TestChunkRetrieval_WhitespaceQuestion(t *testing.T) {
	r, _ := setupChunkRetrievalTest("user1")

	body := `{"dataset_ids": ["kb1"], "question": " "}`
	// httptest.NewRequest produces a server-side request suitable for
	// ServeHTTP and panics on malformed input, so no error is discarded.
	req := httptest.NewRequest(http.MethodPost, "/api/v1/datasets/search", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if resp["code"] != float64(common.CodeSuccess) {
		t.Errorf("expected code 0, got %v", resp["code"])
	}
}
// TestChunkRetrieval_TopKZero posts top_k = 0 and expects HTTP 400 with the
// message "top_k must be greater than 0".
func TestChunkRetrieval_TopKZero(t *testing.T) {
	r, _ := setupChunkRetrievalTest("user1")

	body := `{"dataset_ids": ["kb1"], "question": "test", "top_k": 0}`
	// httptest.NewRequest produces a server-side request suitable for
	// ServeHTTP and panics on malformed input, so no error is discarded.
	req := httptest.NewRequest(http.MethodPost, "/api/v1/datasets/search", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if msg, _ := resp["message"].(string); msg != "top_k must be greater than 0" {
		t.Errorf("expected 'top_k must be greater than 0', got %q", msg)
	}
}
// TestChunkRetrieval_MissingDatasetIDs omits dataset_ids from the request
// body and expects HTTP 400.
func TestChunkRetrieval_MissingDatasetIDs(t *testing.T) {
	r, _ := setupChunkRetrievalTest("user1")

	body := `{"question": "test"}`
	// httptest.NewRequest produces a server-side request suitable for
	// ServeHTTP and panics on malformed input, so no error is discarded.
	req := httptest.NewRequest(http.MethodPost, "/api/v1/datasets/search", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
	}
}
// TestChunkRetrieval_EmptyDatasetIDs posts an empty dataset_ids array and
// expects HTTP 400 with the message "kb_id array cannot be empty".
func TestChunkRetrieval_EmptyDatasetIDs(t *testing.T) {
	r, _ := setupChunkRetrievalTest("user1")

	body := `{"dataset_ids": [], "question": "test"}`
	// httptest.NewRequest produces a server-side request suitable for
	// ServeHTTP and panics on malformed input, so no error is discarded.
	req := httptest.NewRequest(http.MethodPost, "/api/v1/datasets/search", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if msg, _ := resp["message"].(string); msg != "kb_id array cannot be empty" {
		t.Errorf("expected 'kb_id array cannot be empty', got %q", msg)
	}
}
// TestChunkRetrieval_NoAuth sends a valid request through a router that never
// sets the "user" context value. The test expects the failure to be reported
// as HTTP 200 with a non-success code in the JSON body.
func TestChunkRetrieval_NoAuth(t *testing.T) {
	r := setupChunkRetrievalTestNoAuth()

	body := `{"dataset_ids": ["kb1"], "question": "test"}`
	// httptest.NewRequest produces a server-side request suitable for
	// ServeHTTP and panics on malformed input, so no error is discarded.
	req := httptest.NewRequest(http.MethodPost, "/api/v1/datasets/search", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", w.Code)
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	// Require a numeric code first: comparing a missing or non-numeric value
	// against the success code would otherwise pass vacuously.
	code, ok := resp["code"].(float64)
	if !ok {
		t.Fatalf("expected numeric code, got %T", resp["code"])
	}
	if code == float64(common.CodeSuccess) {
		t.Errorf("expected error code, got %v", code)
	}
}
// TestChunkRetrieval_InvalidJSON posts a syntactically invalid JSON body and
// expects HTTP 400.
func TestChunkRetrieval_InvalidJSON(t *testing.T) {
	r, _ := setupChunkRetrievalTest("user1")

	// httptest.NewRequest produces a server-side request suitable for
	// ServeHTTP and panics on malformed input, so no error is discarded.
	req := httptest.NewRequest(http.MethodPost, "/api/v1/datasets/search", strings.NewReader("{invalid}"))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
	}
}
// TestChunkRetrieval_Success installs a service stub that returns one chunk
// and verifies that the handler invokes it with the authenticated user's ID
// and surfaces the result as HTTP 200, success code, and total 1.
func TestChunkRetrieval_Success(t *testing.T) {
	// The router from the setup helper already wires the user middleware and
	// a handler backed by mock, so there is no need to rebuild it by hand.
	r, mock := setupChunkRetrievalTest("user1")

	called := false
	mock.retrievalTestFn = func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
		called = true
		if userID != "user1" {
			t.Errorf("expected userID 'user1', got %q", userID)
		}
		return &service.RetrievalTestResponse{
			Chunks:  []map[string]interface{}{{"docnm_kwd": "result"}},
			DocAggs: []map[string]interface{}{{"doc_id": "1", "count": float64(1)}},
			Total:   1,
		}, nil
	}

	body := `{"dataset_ids": ["kb1"], "question": "test"}`
	// httptest.NewRequest produces a server-side request suitable for
	// ServeHTTP and panics on malformed input, so no error is discarded.
	req := httptest.NewRequest(http.MethodPost, "/api/v1/datasets/search", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	// Without this check the userID assertion inside the stub could be
	// skipped silently if the handler never reached the service.
	if !called {
		t.Error("expected chunk service RetrievalTest to be called")
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if resp["code"] != float64(common.CodeSuccess) {
		t.Fatalf("expected code 0, got %v: %q", resp["code"], resp["message"])
	}

	data, ok := resp["data"].(map[string]interface{})
	if !ok {
		t.Fatalf("expected data object, got %T", resp["data"])
	}
	if total, _ := data["total"].(float64); total != 1 {
		t.Errorf("expected total 1, got %v", total)
	}
}
// TestChunkRetrieval_ServiceError makes the service stub fail and verifies
// that the handler answers HTTP 500 with the generic message
// "dataset search failed" and that the underlying error text appears nowhere
// in the response body.
func TestChunkRetrieval_ServiceError(t *testing.T) {
	const internalErr = "db connection refused"

	// The router from the setup helper already wires the user middleware and
	// a handler backed by mock, so there is no need to rebuild it by hand.
	r, mock := setupChunkRetrievalTest("user1")
	mock.retrievalTestFn = func(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
		return nil, errors.New(internalErr)
	}

	body := `{"dataset_ids": ["kb1"], "question": "test"}`
	// httptest.NewRequest produces a server-side request suitable for
	// ServeHTTP and panics on malformed input, so no error is discarded.
	req := httptest.NewRequest(http.MethodPost, "/api/v1/datasets/search", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if msg, _ := resp["message"].(string); msg != "dataset search failed" {
		t.Errorf("expected generic error message, got %q", msg)
	}
	// Scan the whole body rather than just "message": the message has
	// already been pinned to a fixed string above, so a leak could only
	// show up in another field.
	if raw := w.Body.String(); strings.Contains(raw, internalErr) {
		t.Errorf("internal error details leaked to response: %s", raw)
	}
}