2026-06-05 21:16:25 +08:00
|
|
|
//
|
|
|
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
|
|
|
//
|
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
|
//
|
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
//
|
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
package handler
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"net/http"
|
|
|
|
|
"strconv"
|
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
|
|
"go.uber.org/zap"
|
2026-06-09 20:07:45 -07:00
|
|
|
"gorm.io/gorm"
|
|
|
|
|
"ragflow/internal/common"
|
2026-06-05 21:16:25 +08:00
|
|
|
"ragflow/internal/engine"
|
|
|
|
|
"ragflow/internal/entity"
|
|
|
|
|
modelModule "ragflow/internal/entity/models"
|
|
|
|
|
"ragflow/internal/service"
|
|
|
|
|
"ragflow/internal/service/kg"
|
|
|
|
|
"ragflow/internal/service/nlp"
|
|
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// --- Interfaces (for testability) ---
|
|
|
|
|
|
|
|
|
|
// KBServiceIface abstracts KnowledgebaseService for the Dify handler.
|
|
|
|
|
type KBServiceIface interface {
|
|
|
|
|
GetByID(kbID string) (*entity.Knowledgebase, error)
|
|
|
|
|
Accessible(kbID, userID string) bool
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ModelServiceIface abstracts ModelProviderService for the Dify handler.
|
|
|
|
|
type ModelServiceIface interface {
|
|
|
|
|
GetEmbeddingModel(tenantID, embdID string) (*modelModule.EmbeddingModel, error)
|
|
|
|
|
GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MetadataServiceIface abstracts MetadataService for the Dify handler.
|
|
|
|
|
type MetadataServiceIface interface {
|
|
|
|
|
GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error)
|
|
|
|
|
LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// RetrievalServiceIface abstracts RetrievalService for the Dify handler.
|
|
|
|
|
type RetrievalServiceIface interface {
|
|
|
|
|
Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// DocumentDAOIface abstracts DocumentDAO for the Dify handler.
|
|
|
|
|
type DocumentDAOIface interface {
|
|
|
|
|
GetByIDs(ids []string) ([]*entity.Document, error)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// --- Request / Response types ---
|
|
|
|
|
|
|
|
|
|
// difyRetrievalRequest is the JSON body / query params for the Dify retrieval endpoint.
|
|
|
|
|
type difyRetrievalRequest struct {
|
2026-06-09 20:07:45 -07:00
|
|
|
KnowledgeID string `json:"knowledge_id" form:"knowledge_id"`
|
|
|
|
|
Query string `json:"query" form:"query"`
|
|
|
|
|
UseKG bool `json:"use_kg" form:"use_kg"`
|
|
|
|
|
RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"`
|
|
|
|
|
MetadataCondition *difyMetadataCondition `json:"metadata_condition"`
|
2026-06-05 21:16:25 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type difyRetrievalSetting struct {
|
2026-06-09 20:07:45 -07:00
|
|
|
TopK *int `json:"top_k" form:"top_k"`
|
|
|
|
|
ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"`
|
2026-06-05 21:16:25 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// difyCondition is a Dify-format metadata filter condition.
|
|
|
|
|
// Dify uses "name"/"comparison_operator" instead of MetaFilterCondition's "key"/"op".
|
|
|
|
|
type difyCondition struct {
|
|
|
|
|
Name string `json:"name"`
|
|
|
|
|
ComparisonOperator string `json:"comparison_operator"`
|
|
|
|
|
Value interface{} `json:"value"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type difyMetadataCondition struct {
|
|
|
|
|
Conditions []difyCondition `json:"conditions"`
|
|
|
|
|
Logic string `json:"logic"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// toMetaFilterConditions converts Dify-format conditions to internal MetaFilterConditions.
|
|
|
|
|
func (c difyMetadataCondition) toMetaFilterConditions() []service.MetaFilterCondition {
|
|
|
|
|
if len(c.Conditions) == 0 {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
result := make([]service.MetaFilterCondition, len(c.Conditions))
|
|
|
|
|
for i, dc := range c.Conditions {
|
|
|
|
|
v := ""
|
|
|
|
|
if dc.Value != nil {
|
|
|
|
|
v = fmt.Sprint(dc.Value)
|
|
|
|
|
}
|
|
|
|
|
result[i] = service.MetaFilterCondition{
|
2026-06-09 20:07:45 -07:00
|
|
|
Key: dc.Name,
|
|
|
|
|
Op: dc.ComparisonOperator,
|
2026-06-05 21:16:25 +08:00
|
|
|
Value: v,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return result
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// difyRecord is one item in the response records array.
|
|
|
|
|
type difyRecord struct {
|
|
|
|
|
Content string `json:"content"`
|
|
|
|
|
Score float64 `json:"score"`
|
|
|
|
|
Title string `json:"title"`
|
|
|
|
|
Metadata map[string]interface{} `json:"metadata"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// --- Handler ---
|
|
|
|
|
|
|
|
|
|
// DifyRetrievalHandler handles Dify-compatible retrieval requests.
|
|
|
|
|
type DifyRetrievalHandler struct {
|
|
|
|
|
kbSvc KBServiceIface
|
|
|
|
|
modelSvc ModelServiceIface
|
|
|
|
|
metadataSvc MetadataServiceIface
|
|
|
|
|
retrievalSvc RetrievalServiceIface
|
|
|
|
|
docDAO DocumentDAOIface
|
|
|
|
|
docEngine engine.DocEngine
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NewDifyRetrievalHandler creates a new DifyRetrievalHandler.
|
|
|
|
|
// The KG pipeline is created inline when use_kg=true to avoid injecting
|
|
|
|
|
// a pipeline that depends on per-request model configuration.
|
|
|
|
|
func NewDifyRetrievalHandler(
|
|
|
|
|
kbSvc KBServiceIface,
|
|
|
|
|
modelSvc ModelServiceIface,
|
|
|
|
|
metadataSvc MetadataServiceIface,
|
|
|
|
|
retrievalSvc RetrievalServiceIface,
|
|
|
|
|
docDAO DocumentDAOIface,
|
|
|
|
|
docEngine engine.DocEngine,
|
|
|
|
|
) *DifyRetrievalHandler {
|
|
|
|
|
return &DifyRetrievalHandler{
|
|
|
|
|
kbSvc: kbSvc,
|
|
|
|
|
modelSvc: modelSvc,
|
|
|
|
|
metadataSvc: metadataSvc,
|
|
|
|
|
retrievalSvc: retrievalSvc,
|
|
|
|
|
docDAO: docDAO,
|
|
|
|
|
docEngine: docEngine,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Retrieval handles POST/GET /api/v1/dify/retrieval.
|
|
|
|
|
// Matches Python: api/apps/restful_apis/dify_retrieval_api.py::retrieval()
|
|
|
|
|
func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) {
|
|
|
|
|
user, errCode, errMsg := GetUser(c)
|
|
|
|
|
if errCode != common.CodeSuccess {
|
|
|
|
|
c.JSON(http.StatusUnauthorized, gin.H{"code": errCode, "message": errMsg})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var req difyRetrievalRequest
|
|
|
|
|
if c.Request.Method == http.MethodGet {
|
|
|
|
|
if err := c.ShouldBindQuery(&req); err != nil {
|
|
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "invalid query parameters"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
// Manually extract top_k and score_threshold from query (flat params, not nested)
|
|
|
|
|
if v := c.Query("top_k"); v != "" {
|
|
|
|
|
if parsed, err := strconv.Atoi(v); err == nil {
|
|
|
|
|
if req.RetrievalSetting == nil {
|
|
|
|
|
req.RetrievalSetting = &difyRetrievalSetting{}
|
|
|
|
|
}
|
|
|
|
|
req.RetrievalSetting.TopK = &parsed
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if v := c.Query("score_threshold"); v != "" {
|
|
|
|
|
if parsed, err := strconv.ParseFloat(v, 64); err == nil {
|
|
|
|
|
if req.RetrievalSetting == nil {
|
|
|
|
|
req.RetrievalSetting = &difyRetrievalSetting{}
|
|
|
|
|
}
|
|
|
|
|
req.RetrievalSetting.ScoreThreshold = &parsed
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "invalid request body"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if req.KnowledgeID == "" || req.Query == "" {
|
|
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "knowledge_id and query are required"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kb, err := h.kbSvc.GetByID(req.KnowledgeID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
c.JSON(http.StatusNotFound, gin.H{"code": common.CodeNotFound, "message": "Knowledgebase not found!"})
|
|
|
|
|
} else {
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": "failed to query knowledgebase"})
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !h.kbSvc.Accessible(req.KnowledgeID, user.ID) {
|
|
|
|
|
c.JSON(http.StatusUnauthorized, gin.H{"code": common.CodeAuthenticationError, "message": "No authorization."})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Parse retrieval options (nil means service uses defaults)
|
|
|
|
|
var topK *int
|
|
|
|
|
if req.RetrievalSetting != nil && req.RetrievalSetting.TopK != nil {
|
|
|
|
|
topK = req.RetrievalSetting.TopK
|
|
|
|
|
}
|
|
|
|
|
var scoreThreshold *float64
|
|
|
|
|
if req.RetrievalSetting != nil && req.RetrievalSetting.ScoreThreshold != nil {
|
|
|
|
|
scoreThreshold = req.RetrievalSetting.ScoreThreshold
|
|
|
|
|
}
|
|
|
|
|
pageSize := 1024
|
|
|
|
|
if topK != nil {
|
|
|
|
|
pageSize = *topK
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Get embedding model
|
|
|
|
|
embModel, err := h.modelSvc.GetEmbeddingModel(kb.TenantID, kb.EmbdID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": fmt.Sprintf("failed to get embedding model: %v", err)})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Metadata filter
|
|
|
|
|
metas, metaErr := h.metadataSvc.GetFlattedMetaByKBs([]string{req.KnowledgeID})
|
|
|
|
|
docIDs := make([]string, 0)
|
|
|
|
|
if metaErr == nil && req.MetadataCondition != nil {
|
2026-06-09 20:07:45 -07:00
|
|
|
logic := req.MetadataCondition.Logic
|
|
|
|
|
if logic == "" {
|
|
|
|
|
logic = "and"
|
|
|
|
|
}
|
|
|
|
|
filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic)
|
2026-06-05 21:16:25 +08:00
|
|
|
docIDs = append(docIDs, filteredIDs...)
|
|
|
|
|
}
|
|
|
|
|
if len(docIDs) == 0 && req.MetadataCondition != nil {
|
2026-06-09 20:07:45 -07:00
|
|
|
docIDs = []string{service.NoMatchDocIDSentinel}
|
2026-06-05 21:16:25 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Label question for rank features
|
|
|
|
|
kbs := []*entity.Knowledgebase{kb}
|
|
|
|
|
rankFeature := h.metadataSvc.LabelQuestion(req.Query, kbs)
|
|
|
|
|
|
|
|
|
|
// Chunk retrieval
|
|
|
|
|
sr := &nlp.RetrievalRequest{
|
2026-06-09 20:07:45 -07:00
|
|
|
Question: req.Query,
|
|
|
|
|
TenantIDs: []string{kb.TenantID},
|
|
|
|
|
KbIDs: []string{req.KnowledgeID},
|
|
|
|
|
DocIDs: docIDs,
|
|
|
|
|
Page: 1,
|
|
|
|
|
PageSize: pageSize,
|
|
|
|
|
Top: topK,
|
|
|
|
|
SimilarityThreshold: scoreThreshold,
|
|
|
|
|
EmbeddingModel: embModel,
|
2026-06-05 21:16:25 +08:00
|
|
|
}
|
|
|
|
|
if rankFeature != nil {
|
|
|
|
|
sr.RankFeature = &rankFeature
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
result, err := h.retrievalSvc.Retrieval(c.Request.Context(), sr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
if strings.Contains(err.Error(), "not_found") {
|
|
|
|
|
c.JSON(http.StatusNotFound, gin.H{"code": common.CodeNotFound, "message": "No chunk found! Check the chunk status please!"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": err.Error()})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Enrich with child chunks
|
|
|
|
|
chunks := nlp.RetrievalByChildren(result.Chunks, []string{kb.TenantID}, h.docEngine, c.Request.Context())
|
|
|
|
|
|
|
|
|
|
// KG retrieval (optional)
|
|
|
|
|
if req.UseKG {
|
|
|
|
|
chatModel, kgErr := h.modelSvc.GetChatModel(kb.TenantID, "")
|
|
|
|
|
if kgErr != nil {
|
|
|
|
|
common.Warn("KG retrieval: failed to get chat model", zap.String("kbID", req.KnowledgeID), zap.Error(kgErr))
|
|
|
|
|
} else if chatModel != nil {
|
|
|
|
|
kgPipeline := kg.NewPipeline(
|
|
|
|
|
h.docEngine,
|
|
|
|
|
[]string{req.KnowledgeID},
|
|
|
|
|
[]string{kb.TenantID},
|
|
|
|
|
req.Query,
|
|
|
|
|
)
|
|
|
|
|
kgPipeline.SetChatModel(chatModel)
|
|
|
|
|
kgPipeline.SetEmbModel(embModel)
|
|
|
|
|
if kgResult, kgErr := kgPipeline.Retrieval(c.Request.Context()); kgErr == nil {
|
|
|
|
|
if content, ok := kgResult["content_with_weight"].(string); ok && content != "" {
|
|
|
|
|
chunks = append([]map[string]interface{}{kgResult}, chunks...)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Collect doc IDs and fetch documents
|
|
|
|
|
docIDSet := make(map[string]struct{})
|
|
|
|
|
for _, ch := range chunks {
|
|
|
|
|
if docID, ok := ch["doc_id"].(string); ok && docID != "" {
|
|
|
|
|
docIDSet[docID] = struct{}{}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
allDocIDs := make([]string, 0, len(docIDSet))
|
|
|
|
|
for id := range docIDSet {
|
|
|
|
|
allDocIDs = append(allDocIDs, id)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
docMap := make(map[string]*entity.Document)
|
|
|
|
|
if len(allDocIDs) > 0 {
|
|
|
|
|
docs, err := h.docDAO.GetByIDs(allDocIDs)
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": fmt.Sprintf("failed to load documents: %v", err)})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
for _, d := range docs {
|
|
|
|
|
docMap[d.ID] = d
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Build response
|
|
|
|
|
records := make([]difyRecord, 0, len(chunks))
|
|
|
|
|
for _, ch := range chunks {
|
|
|
|
|
docID, _ := ch["doc_id"].(string)
|
|
|
|
|
doc := docMap[docID]
|
|
|
|
|
if doc == nil {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Remove vector to reduce response size
|
|
|
|
|
delete(ch, "vector")
|
|
|
|
|
|
|
|
|
|
meta := make(map[string]interface{})
|
|
|
|
|
if doc.MetaFields != nil {
|
|
|
|
|
for k, v := range *doc.MetaFields {
|
|
|
|
|
meta[k] = v
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
meta["doc_id"] = docID
|
|
|
|
|
meta["document_id"] = docID
|
|
|
|
|
|
|
|
|
|
score, _ := ch["similarity"].(float64)
|
|
|
|
|
title, _ := ch["docnm_kwd"].(string)
|
|
|
|
|
content, _ := ch["content_with_weight"].(string)
|
|
|
|
|
|
|
|
|
|
records = append(records, difyRecord{
|
|
|
|
|
Content: content,
|
|
|
|
|
Score: score,
|
|
|
|
|
Title: title,
|
|
|
|
|
Metadata: meta,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"records": records})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// HealthCheck returns a simple health check response.
|
|
|
|
|
func (h *DifyRetrievalHandler) HealthCheck(c *gin.Context) {
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"code": 0, "data": true})
|
|
|
|
|
}
|