Files
ragflow/internal/handler/dify_retrieval_handler_test.go
Jack 5a04ac0864 feat: Dify-compatible retrieval API endpoint (#15704)
## Summary

Dify-compatible retrieval API for external knowledge base integration.

## Changes

- **New handler**: DifyRetrievalHandler with POST/GET
/api/v1/dify/retrieval
- **Health check**: GET /api/v1/dify/retrieval/health
- **Full pipeline**: KB validation -> permission check -> embedding ->
metadata filter -> chunk retrieval -> child chunk aggregation ->
optional KG search -> response assembly
- **12 tests** covering all paths (success, errors, metadata filter, KG
mode)
- **Testability**: Handler dependencies defined as interfaces
(KBServiceIface, ModelServiceIface, etc.)

## Files

| File | Type |
|------|------|
| internal/handler/dify_retrieval_handler.go | New — handler + interfaces |
| internal/handler/dify_retrieval_handler_test.go | New — 12 tests |
| internal/router/router.go | Modified — route registration |
| cmd/server_main.go | Modified — handler wiring |
| internal/service/kg/pipeline.go | Modified — SetChatModel/SetEmbModel |
| internal/service/kg/retrieval.go | New — helper functions |
| internal/service/kg/scoring.go | Moved from service package |
| internal/service/kg/search.go | New — KG search functions |
| internal/service/kg/types.go | New — type definitions |

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-05 21:16:25 +08:00

402 lines
12 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"context"
"encoding/json"
"errors"
"gorm.io/gorm"
"net/http"
"net/http/httptest"
"strings"
"testing"
"ragflow/internal/common"
"ragflow/internal/engine"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
"ragflow/internal/service/nlp"
"ragflow/internal/engine/types"
"github.com/gin-gonic/gin"
)
// --- Mock implementations ---
// mockKBService is a KBServiceIface test double. GetByID and Accessible are
// overridden below. Any other method of the embedded interface panics when
// the embedded value is left nil (as the setup helpers in this file do).
type mockKBService struct {
	KBServiceIface
	getByIDFn    func(kbID string) (*entity.Knowledgebase, error)
	accessibleFn func(kbID, userID string) bool
}

// GetByID delegates to getByIDFn when it is set. Otherwise it returns a
// knowledgebase with the requested ID, TenantID "tenant1" and EmbdID
// "text-embedding".
func (m *mockKBService) GetByID(kbID string) (*entity.Knowledgebase, error) {
	if stub := m.getByIDFn; stub != nil {
		return stub(kbID)
	}
	fallback := &entity.Knowledgebase{
		ID:       kbID,
		TenantID: "tenant1",
		EmbdID:   "text-embedding",
	}
	return fallback, nil
}

// Accessible delegates to accessibleFn when it is set and otherwise grants
// access by returning true.
func (m *mockKBService) Accessible(kbID, userID string) bool {
	if m.accessibleFn == nil {
		return true
	}
	return m.accessibleFn(kbID, userID)
}
// mockModelService is a ModelServiceIface test double. Each overridden
// method calls its stub when set. Otherwise it returns a pointer to a
// zero-value model and a nil error.
type mockModelService struct {
	ModelServiceIface
	getEmbeddingFn func(tenantID, embdID string) (*modelModule.EmbeddingModel, error)
	getChatModelFn func(tenantID, llmID string) (*modelModule.ChatModel, error)
}

// GetEmbeddingModel delegates to getEmbeddingFn when it is set. Otherwise it
// returns a zero-value *EmbeddingModel.
func (m *mockModelService) GetEmbeddingModel(tenantID, embdID string) (*modelModule.EmbeddingModel, error) {
	if m.getEmbeddingFn == nil {
		return new(modelModule.EmbeddingModel), nil
	}
	return m.getEmbeddingFn(tenantID, embdID)
}

// GetChatModel delegates to getChatModelFn when it is set. Otherwise it
// returns a zero-value *ChatModel.
func (m *mockModelService) GetChatModel(tenantID, llmID string) (*modelModule.ChatModel, error) {
	if m.getChatModelFn == nil {
		return new(modelModule.ChatModel), nil
	}
	return m.getChatModelFn(tenantID, llmID)
}
// mockMetadataService is a MetadataServiceIface test double covering
// GetFlattedMetaByKBs and LabelQuestion.
type mockMetadataService struct {
	MetadataServiceIface
	getFlattedMetaFn func(kbIDs []string) (common.MetaData, error)
	labelQuestionFn  func(question string, kbs []*entity.Knowledgebase) map[string]float64
}

// GetFlattedMetaByKBs delegates to getFlattedMetaFn when it is set.
// Otherwise it returns an empty common.MetaData{} and no error.
func (m *mockMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) {
	if stub := m.getFlattedMetaFn; stub != nil {
		return stub(kbIDs)
	}
	return common.MetaData{}, nil
}

// LabelQuestion delegates to labelQuestionFn when it is set. Otherwise it
// returns a nil map (no labels).
func (m *mockMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 {
	if stub := m.labelQuestionFn; stub != nil {
		return stub(question, kbs)
	}
	return nil
}
// mockRetrievalService is a RetrievalServiceIface test double for Retrieval.
type mockRetrievalService struct {
	RetrievalServiceIface
	retrievalFn func(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error)
}

// Retrieval delegates to retrievalFn when it is set. Otherwise it returns a
// result holding exactly one chunk that belongs to document "doc1".
func (m *mockRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) {
	if stub := m.retrievalFn; stub != nil {
		return stub(ctx, req)
	}
	onlyChunk := map[string]interface{}{
		"doc_id":              "doc1",
		"docnm_kwd":           "Test Doc",
		"content_with_weight": "test content",
		"similarity":          0.85,
	}
	return &nlp.RetrievalResult{Chunks: []map[string]interface{}{onlyChunk}}, nil
}
// mockDocDAO is a DocumentDAOIface test double for GetByIDs.
type mockDocDAO struct {
	DocumentDAOIface
	getByIDsFn func(ids []string) ([]*entity.Document, error)
}

// GetByIDs delegates to getByIDsFn when it is set. Otherwise it ignores ids
// and returns the single document "doc1" (name "Test Doc", author metadata
// "Zhang San").
func (m *mockDocDAO) GetByIDs(ids []string) ([]*entity.Document, error) {
	if m.getByIDsFn == nil {
		doc := &entity.Document{
			ID:         "doc1",
			Name:       strPtr("Test Doc"),
			MetaFields: &entity.JSONMap{"author": "Zhang San"},
		}
		return []*entity.Document{doc}, nil
	}
	return m.getByIDsFn(ids)
}
// mockDocEngine stubs engine.DocEngine. Only the methods defined below are
// implemented. Calling any other DocEngine method panics because the
// embedded interface value is nil when the mock is built as &mockDocEngine{}.
type mockDocEngine struct {
	engine.DocEngine
}

// Close always succeeds.
func (*mockDocEngine) Close() error { return nil }

// Ping always succeeds.
func (*mockDocEngine) Ping(context.Context) error { return nil }

// GetType reports the engine type as "mock".
func (*mockDocEngine) GetType() string { return "mock" }

// Search returns an empty SearchResult and no error for any request.
func (*mockDocEngine) Search(context.Context, *types.SearchRequest) (*types.SearchResult, error) {
	return new(types.SearchResult), nil
}

// GetChunk returns an empty, non-nil map and no error for any lookup.
func (*mockDocEngine) GetChunk(context.Context, string, string, []string) (interface{}, error) {
	return make(map[string]interface{}), nil
}
// --- Helper ---
// newMockedDifyRetrievalHandler builds a DifyRetrievalHandler whose
// dependencies are all default mocks. It is shared by both setup helpers so
// that a new handler dependency only needs to be wired in one place. Tests
// may replace individual fields (kbSvc, docDAO, ...) after setup.
func newMockedDifyRetrievalHandler() *DifyRetrievalHandler {
	return &DifyRetrievalHandler{
		kbSvc:        &mockKBService{},
		modelSvc:     &mockModelService{},
		metadataSvc:  &mockMetadataService{},
		retrievalSvc: &mockRetrievalService{},
		docDAO:       &mockDocDAO{},
		docEngine:    &mockDocEngine{},
	}
}

// setupDifyTest returns a mock-backed handler and a gin engine. The engine
// stores an *entity.User with the given ID under the "user" context key,
// then serves POST/GET /api/v1/dify/retrieval and
// GET /api/v1/dify/retrieval/health.
func setupDifyTest(userID string) (*DifyRetrievalHandler, *gin.Engine) {
	h := newMockedDifyRetrievalHandler()
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(func(c *gin.Context) {
		c.Set("user", &entity.User{ID: userID})
	})
	r.POST("/api/v1/dify/retrieval", h.Retrieval)
	r.GET("/api/v1/dify/retrieval", h.Retrieval)
	r.GET("/api/v1/dify/retrieval/health", h.HealthCheck)
	return h, r
}

// setupDifyTestNoAuth is like setupDifyTest, but it installs no middleware,
// so no "user" is present in the request context. Only the POST retrieval
// route is registered.
func setupDifyTestNoAuth() (*DifyRetrievalHandler, *gin.Engine) {
	h := newMockedDifyRetrievalHandler()
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.POST("/api/v1/dify/retrieval", h.Retrieval)
	return h, r
}
// --- Tests ---
// TestDifyRetrieval_HealthCheck expects GET /api/v1/dify/retrieval/health to
// answer 200 with a JSON body whose "data" field is true.
func TestDifyRetrieval_HealthCheck(t *testing.T) {
	_, r := setupDifyTest("user1")
	w := httptest.NewRecorder()
	// httptest.NewRequest panics on a bad target instead of returning an
	// error that would otherwise be silently discarded.
	req := httptest.NewRequest(http.MethodGet, "/api/v1/dify/retrieval/health", nil)
	r.ServeHTTP(w, req)
	// Stop on a wrong status: decoding an unexpected body only adds noise.
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("decoding response body %q: %v", w.Body.String(), err)
	}
	if resp["data"] != true {
		t.Errorf("expected data=true, got %v", resp["data"])
	}
}
// TestDifyRetrieval_Basic posts a minimal {knowledge_id, query} JSON body
// with all-default mocks (mockRetrievalService yields one chunk). It expects
// 200 and a non-empty "records" array.
func TestDifyRetrieval_Basic(t *testing.T) {
	_, r := setupDifyTest("user1")
	body := `{"knowledge_id": "kb1", "query": "test question"}`
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	r.ServeHTTP(w, req)
	// Stop on a wrong status: the record checks below are meaningless then.
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("decoding response body %q: %v", w.Body.String(), err)
	}
	records, ok := resp["records"].([]interface{})
	if !ok || len(records) == 0 {
		t.Errorf("expected non-empty records, got %v", resp["records"])
	}
}
// TestDifyRetrieval_GET expects the GET variant of the retrieval route,
// driven by query-string parameters, to answer 200.
func TestDifyRetrieval_GET(t *testing.T) {
	_, r := setupDifyTest("user1")
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/api/v1/dify/retrieval?knowledge_id=kb1&query=test", nil)
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
}
// TestDifyRetrieval_MissingArgs expects 400 when knowledge_id, query, or both
// are absent from an otherwise well-formed JSON body.
func TestDifyRetrieval_MissingArgs(t *testing.T) {
	_, r := setupDifyTest("user1")
	tests := []struct {
		name string
		body string
	}{
		{"no knowledge_id", `{"query": "test"}`},
		{"no query", `{"knowledge_id": "kb1"}`},
		{"empty body", `{}`},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")
			r.ServeHTTP(w, req)
			if w.Code != http.StatusBadRequest {
				t.Errorf("body %s: expected 400, got %d: %s", tc.body, w.Code, w.Body.String())
			}
		})
	}
}
// TestDifyRetrieval_KBNotFound expects 404 when the knowledgebase lookup
// fails with gorm.ErrRecordNotFound.
func TestDifyRetrieval_KBNotFound(t *testing.T) {
	h, r := setupDifyTest("user1")
	h.kbSvc = &mockKBService{
		getByIDFn: func(kbID string) (*entity.Knowledgebase, error) {
			return nil, gorm.ErrRecordNotFound
		},
	}
	w := httptest.NewRecorder()
	body := `{"knowledge_id": "nonexistent", "query": "test"}`
	req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	r.ServeHTTP(w, req)
	if w.Code != http.StatusNotFound {
		t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
	}
}
// TestDifyRetrieval_NoAuth expects 401 when no "user" has been placed in the
// gin context (router built by setupDifyTestNoAuth).
func TestDifyRetrieval_NoAuth(t *testing.T) {
	_, r := setupDifyTestNoAuth()
	w := httptest.NewRecorder()
	body := `{"knowledge_id": "kb1", "query": "test"}`
	req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	r.ServeHTTP(w, req)
	if w.Code != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d: %s", w.Code, w.Body.String())
	}
}
// TestDifyRetrieval_Unauthorized expects 401 when the authenticated user
// cannot access the requested knowledgebase (Accessible returns false).
// NOTE(review): 403 Forbidden is the more conventional status for an
// authenticated-but-not-permitted caller; confirm 401 is the intended contract.
func TestDifyRetrieval_Unauthorized(t *testing.T) {
	h, r := setupDifyTest("user1")
	h.kbSvc = &mockKBService{
		accessibleFn: func(kbID, userID string) bool { return false },
	}
	w := httptest.NewRecorder()
	body := `{"knowledge_id": "kb1", "query": "test"}`
	req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	r.ServeHTTP(w, req)
	if w.Code != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d: %s", w.Code, w.Body.String())
	}
}
// TestDifyRetrieval_WithMetadataFilter expects 200 for a request that carries
// a metadata_condition.
// NOTE(review): the stub below returns an empty common.MetaData{}, which is
// the same as the mock's default. This test therefore only shows that the
// condition is accepted; it does not show that documents are filtered by it.
func TestDifyRetrieval_WithMetadataFilter(t *testing.T) {
	h, r := setupDifyTest("user1")
	h.metadataSvc = &mockMetadataService{
		getFlattedMetaFn: func(kbIDs []string) (common.MetaData, error) {
			return common.MetaData{}, nil
		},
	}
	body := `{"knowledge_id":"kb1","query":"test","metadata_condition":{"conditions":[{"name":"author","comparison_operator":"eq","value":"Zhang San"}],"logic":"and"}}`
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
}
// TestDifyRetrieval_InvalidJSON expects 400 for a syntactically malformed
// JSON request body.
func TestDifyRetrieval_InvalidJSON(t *testing.T) {
	_, r := setupDifyTest("user1")
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader("{invalid json"))
	req.Header.Set("Content-Type", "application/json")
	r.ServeHTTP(w, req)
	if w.Code != http.StatusBadRequest {
		t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
	}
}
// TestDifyRetrieval_UseKG expects 200 when use_kg is true. LabelQuestion is
// stubbed to return a single tag.
func TestDifyRetrieval_UseKG(t *testing.T) {
	h, r := setupDifyTest("user1")
	h.metadataSvc = &mockMetadataService{
		labelQuestionFn: func(question string, kbs []*entity.Knowledgebase) map[string]float64 {
			return map[string]float64{"tag_1": 0.8}
		},
	}
	body := `{"knowledge_id":"kb1","query":"test","use_kg":true}`
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
}
// strPtr returns a pointer to a freshly allocated copy of s. It is used to
// fill optional *string fields in test fixtures.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// TestDifyRetrieval_KBDBError expects 500 when the knowledgebase lookup fails
// with a generic error rather than a not-found error.
func TestDifyRetrieval_KBDBError(t *testing.T) {
	h, r := setupDifyTest("user1")
	h.kbSvc = &mockKBService{
		getByIDFn: func(kbID string) (*entity.Knowledgebase, error) {
			return nil, errors.New("connection refused")
		},
	}
	w := httptest.NewRecorder()
	body := `{"knowledge_id": "kb1", "query": "test"}`
	req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	r.ServeHTTP(w, req)
	if w.Code != http.StatusInternalServerError {
		t.Errorf("expected 500 for DB error, got %d: %s", w.Code, w.Body.String())
	}
}
// TestDifyRetrieval_DocLoadError expects 500 when loading the documents that
// back the retrieved chunks fails.
func TestDifyRetrieval_DocLoadError(t *testing.T) {
	h, r := setupDifyTest("user1")
	h.docDAO = &mockDocDAO{
		getByIDsFn: func(ids []string) ([]*entity.Document, error) {
			return nil, errors.New("db unavailable")
		},
	}
	w := httptest.NewRecorder()
	body := `{"knowledge_id": "kb1", "query": "test"}`
	req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	r.ServeHTTP(w, req)
	if w.Code != http.StatusInternalServerError {
		t.Errorf("expected 500 for doc load error, got %d: %s", w.Code, w.Body.String())
	}
}
// TestDifyRetrieval_RetrievalNotFound expects 404 when the retrieval service
// fails with an error whose message contains "not_found".
// NOTE(review): this presumably relies on the handler matching the error
// text; a sentinel error checked with errors.Is would be less brittle.
// Confirm against the handler.
func TestDifyRetrieval_RetrievalNotFound(t *testing.T) {
	h, r := setupDifyTest("user1")
	h.retrievalSvc = &mockRetrievalService{
		retrievalFn: func(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) {
			return nil, errors.New("no chunk found: not_found")
		},
	}
	w := httptest.NewRecorder()
	body := `{"knowledge_id": "kb1", "query": "test"}`
	req := httptest.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	r.ServeHTTP(w, req)
	if w.Code != http.StatusNotFound {
		t.Errorf("expected 404 for not_found, got %d: %s", w.Code, w.Body.String())
	}
}