mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 09:11:59 +08:00
### Summary Port the following PRs to GO in this PR https://github.com/infiniflow/ragflow/pull/14210 https://github.com/infiniflow/ragflow/pull/14641 https://github.com/infiniflow/ragflow/pull/15220 https://github.com/infiniflow/ragflow/pull/15228 https://github.com/infiniflow/ragflow/pull/15384 https://github.com/infiniflow/ragflow/pull/15754 https://github.com/infiniflow/ragflow/pull/16413 https://github.com/infiniflow/ragflow/pull/16483 https://github.com/infiniflow/ragflow/pull/16419 https://github.com/infiniflow/ragflow/pull/16361 https://github.com/infiniflow/ragflow/pull/16050
446 lines
14 KiB
Go
446 lines
14 KiB
Go
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
|
|
|
|
package mcp
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// Connector provides the data-access operations the MCP tools need.
|
|
// It abstracts the RAGFlow backend so that tool implementations can use
|
|
// either in-process service calls or out-of-process HTTP calls.
|
|
type Connector interface {
|
|
// ListDatasets returns newline-delimited JSON lines, each containing
|
|
// at minimum {"id": "...", "description": "..."}.
|
|
ListDatasets(page, pageSize int, orderby string, desc bool) (string, error)
|
|
|
|
// ListChats returns newline-delimited JSON lines, each containing
|
|
// at minimum {"id": "...", "name": "...", "description": "..."}.
|
|
ListChats(page, pageSize int, orderby string, desc bool) (string, error)
|
|
|
|
// Retrieval executes a retrieval request and returns the result as
|
|
// a JSON string.
|
|
Retrieval(req RetrievalRequest) (string, error)
|
|
}
|
|
|
|
// RetrievalRequest bundles every parameter of a single retrieval query.
// Field order and JSON tags define the serialized form, so both are part of
// the type's contract.
type RetrievalRequest struct {
	// DatasetIDs lists the datasets to search.
	DatasetIDs []string `json:"dataset_ids"`
	// DocumentIDs lists the documents to search within.
	DocumentIDs []string `json:"document_ids"`
	// Question is the question or query to search for.
	Question string `json:"question"`
	// Page is the page number for pagination.
	Page int `json:"page"`
	// PageSize is the number of results to return per page.
	PageSize int `json:"page_size"`
	// SimilarityThreshold is the minimum similarity threshold for results.
	SimilarityThreshold float64 `json:"similarity_threshold"`
	// VectorSimilarityWeight weighs vector similarity against term similarity.
	VectorSimilarityWeight float64 `json:"vector_similarity_weight"`
	// TopK is the maximum number of results to consider before ranking.
	TopK int `json:"top_k"`
	// RerankID is the optional reranking model identifier; it is left out
	// of the JSON output when empty.
	RerankID string `json:"rerank_id,omitempty"`
	// Keyword enables keyword-based search.
	Keyword bool `json:"keyword"`
	// ForceRefresh asks for fresh dataset and document metadata instead of
	// cached metadata.
	ForceRefresh bool `json:"force_refresh"`
}
|
|
|
|
// Server handles MCP JSON-RPC requests.
|
|
type Server struct {
|
|
connector Connector
|
|
version string
|
|
}
|
|
|
|
// NewServer creates a new MCP Server.
|
|
func NewServer(connector Connector) *Server {
|
|
return &Server{
|
|
connector: connector,
|
|
version: "1.0.0",
|
|
}
|
|
}
|
|
|
|
// HandleRequest dispatches a raw JSON-RPC request body and returns the
|
|
// serialized JSON-RPC response. Returns nil if the request is a
|
|
// notification (no id) and requires no response.
|
|
func (s *Server) HandleRequest(body []byte) ([]byte, bool, error) {
|
|
// Try to decode as a request (with an id) first.
|
|
var req JSONRPCRequest
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
resp := NewParseError()
|
|
data, _ := json.Marshal(resp)
|
|
return data, true, nil
|
|
}
|
|
|
|
// Notifications have no id — do not send a response.
|
|
if req.ID == nil || string(req.ID) == "null" {
|
|
return nil, false, nil
|
|
}
|
|
|
|
// Validate jsonrpc field.
|
|
if req.JSONRPC != JSONRPCVersion {
|
|
resp := NewInvalidRequestError(req.ID, "jsonrpc must be \"2.0\"")
|
|
data, _ := json.Marshal(resp)
|
|
return data, true, nil
|
|
}
|
|
|
|
var resp JSONRPCResponse
|
|
|
|
switch req.Method {
|
|
case "initialize":
|
|
resp = s.handleInitialize(req.ID)
|
|
case "tools/list":
|
|
resp = s.handleListTools(req.ID)
|
|
case "tools/call":
|
|
resp = s.handleCallTool(req.ID, req.Params)
|
|
case "ping":
|
|
resp = s.handlePing(req.ID)
|
|
default:
|
|
resp = NewErrorResponse(req.ID, ErrCodeMethodNotFound,
|
|
fmt.Sprintf("Method not found: %s", req.Method))
|
|
}
|
|
|
|
data, err := json.Marshal(resp)
|
|
if err != nil {
|
|
return nil, true, fmt.Errorf("failed to marshal response: %w", err)
|
|
}
|
|
return data, true, nil
|
|
}
|
|
|
|
func (s *Server) handleInitialize(id json.RawMessage) JSONRPCResponse {
|
|
result := InitializeResult{
|
|
ProtocolVersion: MCPProtocolVersion,
|
|
Capabilities: Capabilities{
|
|
Tools: &ToolsCapability{ListChanged: false},
|
|
},
|
|
ServerInfo: ServerInfo{
|
|
Name: ServerName,
|
|
Version: s.version,
|
|
},
|
|
}
|
|
return NewSuccessResponse(id, result)
|
|
}
|
|
|
|
func (s *Server) handlePing(id json.RawMessage) JSONRPCResponse {
|
|
return NewSuccessResponse(id, struct{}{})
|
|
}
|
|
|
|
func (s *Server) handleListTools(id json.RawMessage) JSONRPCResponse {
|
|
// Fetch dataset and chat descriptions for embedding into tool descriptions,
|
|
// matching the Python MCP server behavior.
|
|
datasetDescription, err := s.connector.ListDatasets(1, 100, "create_time", true)
|
|
if err != nil {
|
|
datasetDescription = ""
|
|
}
|
|
chatDescription, err := s.connector.ListChats(1, 30, "create_time", true)
|
|
if err != nil {
|
|
chatDescription = ""
|
|
}
|
|
|
|
tools := []Tool{
|
|
{
|
|
Name: "ragflow_retrieval",
|
|
Description: "Retrieve relevant chunks from the RAGFlow retrieve interface based on the question. You can optionally specify dataset_ids to search only specific datasets, or omit dataset_ids entirely to search across ALL available datasets. You can also optionally specify document_ids to search within specific documents. When dataset_ids is not provided or is empty, the system will automatically search across all available datasets. Below is the list of all available datasets, including their descriptions and IDs:" + datasetDescription,
|
|
InputSchema: InputSchema{
|
|
Type: "object",
|
|
Properties: map[string]Property{
|
|
"dataset_ids": {
|
|
Type: "array",
|
|
Description: "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched.",
|
|
Items: &Items{Type: "string"},
|
|
},
|
|
"document_ids": {
|
|
Type: "array",
|
|
Description: "Optional array of document IDs to search within.",
|
|
Items: &Items{Type: "string"},
|
|
},
|
|
"question": {
|
|
Type: "string",
|
|
Description: "The question or query to search for.",
|
|
},
|
|
"page": {
|
|
Type: "integer",
|
|
Description: "Page number for pagination",
|
|
Default: 1,
|
|
Minimum: float64Ptr(1),
|
|
},
|
|
"page_size": {
|
|
Type: "integer",
|
|
Description: "Number of results to return per page (default: 10, max recommended: 50 to avoid token limits)",
|
|
Default: 10,
|
|
Minimum: float64Ptr(1),
|
|
Maximum: float64Ptr(100),
|
|
},
|
|
"similarity_threshold": {
|
|
Type: "number",
|
|
Description: "Minimum similarity threshold for results",
|
|
Default: 0.2,
|
|
Minimum: float64Ptr(0),
|
|
Maximum: float64Ptr(1),
|
|
},
|
|
"vector_similarity_weight": {
|
|
Type: "number",
|
|
Description: "Weight for vector similarity vs term similarity",
|
|
Default: 0.3,
|
|
Minimum: float64Ptr(0),
|
|
Maximum: float64Ptr(1),
|
|
},
|
|
"keyword": {
|
|
Type: "boolean",
|
|
Description: "Enable keyword-based search",
|
|
Default: false,
|
|
},
|
|
"top_k": {
|
|
Type: "integer",
|
|
Description: "Maximum results to consider before ranking",
|
|
Default: 1024,
|
|
Minimum: float64Ptr(1),
|
|
Maximum: float64Ptr(1024),
|
|
},
|
|
"rerank_id": {
|
|
Type: "string",
|
|
Description: "Optional reranking model identifier",
|
|
},
|
|
"force_refresh": {
|
|
Type: "boolean",
|
|
Description: "Set to true only if fresh dataset and document metadata is explicitly required. Otherwise, cached metadata is used (default: false).",
|
|
Default: false,
|
|
},
|
|
},
|
|
Required: []string{"question"},
|
|
},
|
|
},
|
|
{
|
|
Name: "ragflow_list_datasets",
|
|
Description: "List all accessible datasets (knowledge bases) in RAGFlow. Returns dataset IDs, names, and descriptions. Use this tool to discover which datasets are available before performing retrieval." + datasetDescription,
|
|
InputSchema: InputSchema{
|
|
Type: "object",
|
|
Properties: map[string]Property{
|
|
"page": {
|
|
Type: "integer",
|
|
Description: "Page number",
|
|
Default: 1,
|
|
Minimum: float64Ptr(1),
|
|
},
|
|
"page_size": {
|
|
Type: "integer",
|
|
Description: "Results per page",
|
|
Default: 100,
|
|
Minimum: float64Ptr(1),
|
|
Maximum: float64Ptr(1000),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "ragflow_list_chats",
|
|
Description: "List all accessible chat assistants in RAGFlow. Returns chat assistant IDs, names, and descriptions. Use this tool to discover available chat assistants that can be used for conversations." + chatDescription,
|
|
InputSchema: InputSchema{
|
|
Type: "object",
|
|
Properties: map[string]Property{
|
|
"page": {
|
|
Type: "integer",
|
|
Description: "Page number",
|
|
Default: 1,
|
|
Minimum: float64Ptr(1),
|
|
},
|
|
"page_size": {
|
|
Type: "integer",
|
|
Description: "Results per page",
|
|
Default: 30,
|
|
Minimum: float64Ptr(1),
|
|
Maximum: float64Ptr(100),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
return NewSuccessResponse(id, ListToolsResult{Tools: tools})
|
|
}
|
|
|
|
func (s *Server) handleCallTool(id json.RawMessage, rawParams json.RawMessage) JSONRPCResponse {
|
|
var params CallToolParams
|
|
if err := json.Unmarshal(rawParams, ¶ms); err != nil {
|
|
return NewErrorResponse(id, ErrCodeInvalidParams,
|
|
fmt.Sprintf("Invalid params: %s", err.Error()))
|
|
}
|
|
|
|
if params.Arguments == nil {
|
|
params.Arguments = make(map[string]interface{})
|
|
}
|
|
|
|
switch params.Name {
|
|
case "ragflow_retrieval":
|
|
return s.callRagflowRetrieval(id, params.Arguments)
|
|
case "ragflow_list_datasets":
|
|
return s.callListDatasets(id, params.Arguments)
|
|
case "ragflow_list_chats":
|
|
return s.callListChats(id, params.Arguments)
|
|
default:
|
|
return NewErrorResponse(id, ErrCodeMethodNotFound,
|
|
fmt.Sprintf("Tool not found: %s", params.Name))
|
|
}
|
|
}
|
|
|
|
func (s *Server) callRagflowRetrieval(id json.RawMessage, args map[string]interface{}) JSONRPCResponse {
|
|
req := RetrievalRequest{
|
|
Page: getBoundedIntArg(args, "page", 1, 1, 1_000_000),
|
|
PageSize: getBoundedIntArg(args, "page_size", 10, 1, 100),
|
|
SimilarityThreshold: getBoundedFloat64Arg(args, "similarity_threshold", 0.2, 0, 1),
|
|
VectorSimilarityWeight: getBoundedFloat64Arg(args, "vector_similarity_weight", 0.3, 0, 1),
|
|
Keyword: getBoolArg(args, "keyword", false),
|
|
TopK: getBoundedIntArg(args, "top_k", 1024, 1, 1024),
|
|
ForceRefresh: getBoolArg(args, "force_refresh", false),
|
|
Question: getStringArg(args, "question", ""),
|
|
RerankID: getStringArg(args, "rerank_id", ""),
|
|
}
|
|
|
|
if v, ok := args["dataset_ids"]; ok {
|
|
req.DatasetIDs = toStringSlice(v)
|
|
}
|
|
if v, ok := args["document_ids"]; ok {
|
|
req.DocumentIDs = toStringSlice(v)
|
|
}
|
|
|
|
if strings.TrimSpace(req.Question) == "" {
|
|
return NewSuccessResponse(id, NewErrorResult("question is required"))
|
|
}
|
|
|
|
result, err := s.connector.Retrieval(req)
|
|
if err != nil {
|
|
return NewSuccessResponse(id, NewErrorResult(err.Error()))
|
|
}
|
|
return NewSuccessResponse(id, NewTextResult(result))
|
|
}
|
|
|
|
func (s *Server) callListDatasets(id json.RawMessage, args map[string]interface{}) JSONRPCResponse {
|
|
page := getBoundedIntArg(args, "page", 1, 1, 1_000_000)
|
|
pageSize := getBoundedIntArg(args, "page_size", 100, 1, 1000)
|
|
|
|
result, err := s.connector.ListDatasets(page, pageSize, "create_time", true)
|
|
if err != nil {
|
|
return NewSuccessResponse(id, NewErrorResult(err.Error()))
|
|
}
|
|
return NewSuccessResponse(id, NewTextResult(result))
|
|
}
|
|
|
|
func (s *Server) callListChats(id json.RawMessage, args map[string]interface{}) JSONRPCResponse {
|
|
page := getBoundedIntArg(args, "page", 1, 1, 1_000_000)
|
|
pageSize := getBoundedIntArg(args, "page_size", 30, 1, 100)
|
|
|
|
result, err := s.connector.ListChats(page, pageSize, "create_time", true)
|
|
if err != nil {
|
|
return NewSuccessResponse(id, NewErrorResult(err.Error()))
|
|
}
|
|
return NewSuccessResponse(id, NewTextResult(result))
|
|
}
|
|
|
|
// --- argument extraction helpers ---
|
|
|
|
// getStringArg returns args[key] when it holds a string. A missing key or a
// value of any other type yields defaultVal.
func getStringArg(args map[string]interface{}, key, defaultVal string) string {
	// Asserting on a missing entry (a nil interface) simply fails, so one
	// check covers both "absent" and "wrong type".
	if text, isString := args[key].(string); isString {
		return text
	}
	return defaultVal
}
|
|
|
|
// getIntArg returns args[key] as an int, or defaultVal when the key is
// missing or the value cannot be read as a number.
//
// Accepted value types are float64, int, int64 and json.Number. Floats are
// truncated toward zero. A float outside the int range saturates at the
// nearest int bound, and NaN yields defaultVal; a plain int(f) conversion
// would give an implementation-dependent result for those inputs. A
// json.Number that is not an integer literal (for example "10.0") is read
// as a float and converted the same way.
func getIntArg(args map[string]interface{}, key string, defaultVal int) int {
	const (
		maxInt = int(^uint(0) >> 1)
		minInt = -maxInt - 1
	)

	// fromFloat converts f to int without relying on out-of-range
	// float-to-int conversion.
	fromFloat := func(f float64) int {
		switch {
		case f != f: // NaN is the only value that differs from itself.
			return defaultVal
		case f >= float64(maxInt):
			return maxInt
		case f <= float64(minInt):
			return minInt
		}
		return int(f)
	}

	v, ok := args[key]
	if !ok {
		return defaultVal
	}

	switch n := v.(type) {
	case float64:
		return fromFloat(n)
	case int:
		return n
	case int64:
		return int(n)
	case json.Number:
		if i, err := n.Int64(); err == nil {
			return int(i)
		}
		// Not an integer literal: fall back to the float reading so that
		// json.Number behaves like a plain float64 argument.
		if f, err := n.Float64(); err == nil {
			return fromFloat(f)
		}
	}
	return defaultVal
}
|
|
|
|
func getBoundedIntArg(args map[string]interface{}, key string, defaultVal, minVal, maxVal int) int {
|
|
v := getIntArg(args, key, defaultVal)
|
|
if v < minVal {
|
|
return minVal
|
|
}
|
|
if v > maxVal {
|
|
return maxVal
|
|
}
|
|
return v
|
|
}
|
|
|
|
// getFloat64Arg returns args[key] as a float64. Values of type float64,
// int, int64 and json.Number are accepted; a missing key, any other type,
// or a json.Number that does not parse as a float yields defaultVal.
func getFloat64Arg(args map[string]interface{}, key string, defaultVal float64) float64 {
	raw, present := args[key]
	if !present {
		return defaultVal
	}

	switch num := raw.(type) {
	case json.Number:
		parsed, err := num.Float64()
		if err != nil {
			return defaultVal
		}
		return parsed
	case int64:
		return float64(num)
	case int:
		return float64(num)
	case float64:
		return num
	default:
		return defaultVal
	}
}
|
|
|
|
func getBoundedFloat64Arg(args map[string]interface{}, key string, defaultVal, minVal, maxVal float64) float64 {
|
|
v := getFloat64Arg(args, key, defaultVal)
|
|
if v < minVal {
|
|
return minVal
|
|
}
|
|
if v > maxVal {
|
|
return maxVal
|
|
}
|
|
return v
|
|
}
|
|
|
|
// getBoolArg returns args[key] as a bool. A bool value is returned as is. A
// string is matched case-insensitively: "true", "1" and "yes" give true,
// "false", "0" and "no" give false. A missing key, an unrecognized string,
// or a value of any other type yields defaultVal.
func getBoolArg(args map[string]interface{}, key string, defaultVal bool) bool {
	switch val := args[key].(type) {
	case bool:
		return val
	case string:
		switch strings.ToLower(val) {
		case "true", "1", "yes":
			return true
		case "false", "0", "no":
			return false
		}
	}
	return defaultVal
}
|
|
|
|
// toStringSlice converts v into a []string.
//
// A []interface{} (the shape a decoded JSON array takes) is converted
// element by element: strings are kept unchanged, nil elements (JSON null)
// are skipped instead of becoming the literal text "<nil>", and any other
// element is rendered with fmt.Sprintf("%v"). A []string is returned as a
// copy. Every other input, including nil, yields nil.
func toStringSlice(v interface{}) []string {
	switch val := v.(type) {
	case []string:
		// Copy so the caller's slice is never aliased.
		out := make([]string, len(val))
		copy(out, val)
		return out
	case []interface{}:
		out := make([]string, 0, len(val))
		for _, item := range val {
			switch elem := item.(type) {
			case nil:
				// A null element carries no identifier.
				continue
			case string:
				out = append(out, elem)
			default:
				out = append(out, fmt.Sprintf("%v", elem))
			}
		}
		return out
	}
	return nil
}
|